Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions examples/dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,17 @@ def log_validation(
# run inference
generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
if args.validation_images is None:
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
images.append(image)
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
with torch.no_grad():
if args.validation_images is None:
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
images.append(image)
else:
for image in args.validation_images:
image = Image.open(image)
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def log_validation(
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

with autocast_ctx:
with torch.no_grad(), autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down
20 changes: 11 additions & 9 deletions examples/dreambooth/train_dreambooth_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,19 @@ def log_validation(

if args.validation_images is None:
images = []
for _ in range(args.num_validation_images):
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, generator=generator).images[0]
images.append(image)
else:
images = []
for image in args.validation_images:
image = Image.open(image)
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)
with torch.no_grad():
for image in args.validation_images:
image = Image.open(image)
with torch.amp.autocast(accelerator.device.type):
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
13 changes: 7 additions & 6 deletions examples/dreambooth/train_dreambooth_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,12 +210,13 @@ def log_validation(
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
15 changes: 8 additions & 7 deletions examples/dreambooth/train_dreambooth_lora_flux2.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,14 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
17 changes: 9 additions & 8 deletions examples/dreambooth/train_dreambooth_lora_flux2_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,14 +209,15 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
image=pipeline_args["image"],
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
image=pipeline_args["image"],
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
15 changes: 8 additions & 7 deletions examples/dreambooth/train_dreambooth_lora_flux2_klein.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,13 +209,14 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
19 changes: 10 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,16 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
image=pipeline_args["image"],
prompt_embeds=pipeline_args["prompt_embeds"],
negative_prompt_embeds=pipeline_args["negative_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
image=pipeline_args["image"],
prompt_embeds=pipeline_args["prompt_embeds"],
negative_prompt_embeds=pipeline_args["negative_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
19 changes: 10 additions & 9 deletions examples/dreambooth/train_dreambooth_lora_flux_kontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,16 @@ def log_validation(
prompt = pipeline_args_cp.pop("prompt")
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
**pipeline_args_cp,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
**pipeline_args_cp,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
25 changes: 13 additions & 12 deletions examples/dreambooth/train_dreambooth_lora_hidream.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,19 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds_t5=pipeline_args["prompt_embeds_t5"],
prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"],
negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"],
negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"],
pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"],
negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds_t5=pipeline_args["prompt_embeds_t5"],
prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"],
negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"],
negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"],
pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"],
negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_lumina2.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def log_validation(
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

with autocast_ctx:
with torch.no_grad(), autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down
17 changes: 9 additions & 8 deletions examples/dreambooth/train_dreambooth_lora_qwen_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,15 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
prompt_embeds_mask=pipeline_args["prompt_embeds_mask"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt_embeds=pipeline_args["prompt_embeds"],
prompt_embeds_mask=pipeline_args["prompt_embeds_mask"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
3 changes: 2 additions & 1 deletion examples/dreambooth/train_dreambooth_lora_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def log_validation(
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None

images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
with torch.no_grad():
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def log_validation(
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

with autocast_ctx:
with torch.no_grad(), autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def log_validation(
else:
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

with autocast_ctx:
with torch.no_grad(), autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down
17 changes: 9 additions & 8 deletions examples/dreambooth/train_dreambooth_lora_z_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,15 @@ def log_validation(
autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()

images = []
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt=args.validation_prompt,
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)
with torch.no_grad():
for _ in range(args.num_validation_images):
with autocast_ctx:
image = pipeline(
prompt=args.validation_prompt,
prompt_embeds=pipeline_args["prompt_embeds"],
generator=generator,
).images[0]
images.append(image)

for tracker in accelerator.trackers:
phase_name = "test" if is_final_validation else "validation"
Expand Down
2 changes: 1 addition & 1 deletion examples/dreambooth/train_dreambooth_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def log_validation(
# autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
autocast_ctx = nullcontext()

with autocast_ctx:
with torch.no_grad(), autocast_ctx:
images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]

for tracker in accelerator.trackers:
Expand Down