From f86782dac144c8d0f8c9cb2121aac7b2c0574c22 Mon Sep 17 00:00:00 2001 From: Sivan Mehta Date: Wed, 21 Sep 2022 14:34:13 -0700 Subject: [PATCH 1/4] add --batch parameter --- text2image.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/text2image.py b/text2image.py index 0e79fcc..3ebc607 100644 --- a/text2image.py +++ b/text2image.py @@ -46,6 +46,10 @@ "--steps", type=int, default=50, help="number of ddim sampling steps" ) +parser.add_argument( + "--batch", type=int, default=1, help="number of images to generate" +) + parser.add_argument( "--seed", type=int, @@ -71,8 +75,17 @@ num_steps=args.steps, unconditional_guidance_scale=args.scale, temperature=1, - batch_size=1, + batch_size=args.batch, seed=args.seed, ) -Image.fromarray(img[0]).save(args.output) -print(f"saved at {args.output}") + +if(args.batch > 1): + Image.fromarray(img[0]).save(args.output) + print(f"saved at {args.output}") +else: + for i in range(args.batch): + filename = f"{i}-args.output" + Image.fromarray(img[i]).save(args.output) + + print(f"saved {args.batch} images at {args.output}") + From 3c8c21b3f441a156a200751fb337928d0b844c61 Mon Sep 17 00:00:00 2001 From: Sivan Mehta Date: Wed, 21 Sep 2022 15:46:01 -0700 Subject: [PATCH 2/4] better logging --- text2image.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/text2image.py b/text2image.py index 3ebc607..d0c1f0e 100644 --- a/text2image.py +++ b/text2image.py @@ -83,9 +83,13 @@ Image.fromarray(img[0]).save(args.output) print(f"saved at {args.output}") else: + split_filename = args.output.split(".") + filename = split_filename[0:-1] + extension = split_filename[-1] + generate_filename = lambda x: f"{filename}-{x}.{extension}" for i in range(args.batch): - filename = f"{i}-args.output" + filename = generate_filename(i + 1) Image.fromarray(img[i]).save(args.output) - print(f"saved {args.batch} images at {args.output}") + print(f"saved {args.batch} images as {generate_filename(f"1, {args.batch + 1}")}") From cbeccb6e6963b836e0ebcb80117442172bf5d648 Mon Sep 17 00:00:00 2001 From: Sivan Mehta Date: Wed, 21 Sep 2022 16:45:50 -0700 Subject: [PATCH 3/4] flip around conditional --- text2image.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/text2image.py b/text2image.py index d0c1f0e..fa646bd 100644 --- a/text2image.py +++ b/text2image.py @@ -79,17 +79,19 @@ seed=args.seed, ) -if(args.batch > 1): - Image.fromarray(img[0]).save(args.output) - print(f"saved at {args.output}") +if(args.batch <= 1): + Image.fromarray(img[0]).save(args.output) + print(f"saved at {args.output}") else: split_filename = args.output.split(".") - filename = split_filename[0:-1] + filename = ''.join(split_filename[0:-1]) extension = split_filename[-1] - generate_filename = lambda x: f"{filename}-{x}.{extension}" + def generate_filename(suffix): + return f"{filename}-{suffix}.{extension}" + for i in range(args.batch): - filename = generate_filename(i + 1) - Image.fromarray(img[i]).save(args.output) + generated_filename = generate_filename(i + 1) + Image.fromarray(img[i]).save(generated_filename) - print(f"saved {args.batch} images as {generate_filename(f"1, {args.batch + 1}")}") + print(f"saved {args.batch} images as {generate_filename(f'[{1}, {args.batch}]')}") From 60b4e4029b251077ea0f71683874a74fb4c19ecc Mon Sep 17 00:00:00 2001 From: Sivan Mehta Date: Wed, 21 Sep 2022 16:50:33 -0700 Subject: [PATCH 4/4] restore generation --- text2image.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/text2image.py b/text2image.py index fa646bd..a2163c3 100644 --- a/text2image.py +++ b/text2image.py @@ -86,12 +86,9 @@ split_filename = args.output.split(".") filename = ''.join(split_filename[0:-1]) extension = split_filename[-1] - def generate_filename(suffix): - return f"{filename}-{suffix}.{extension}" for i in range(args.batch): - generated_filename = generate_filename(i + 1) + generated_filename = f"{filename}-{i+1}.{extension}" Image.fromarray(img[i]).save(generated_filename) - - print(f"saved {args.batch} images as {generate_filename(f'[{1}, {args.batch}]')}") + print(f"saved at {generated_filename}")