diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3a2b32aaef5..9dbe2b8f02d5 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -175,16 +175,17 @@ def log_validation( # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - if args.validation_images is None: - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) - else: - for image in args.validation_images: - image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index c7e0c290fa8e..1e8309d40589 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -191,7 +191,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b6baccc4bc99..367cf909ad8a 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -155,17 +155,19 @@ def log_validation( if args.validation_images is None: images = [] - for _ in range(args.num_validation_images): - with torch.amp.autocast(accelerator.device.type): - image = pipeline(**pipeline_args, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with torch.amp.autocast(accelerator.device.type): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) else: images = [] - for image in args.validation_images: - image = Image.open(image) - with torch.amp.autocast(accelerator.device.type): - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + for image in args.validation_images: + image = Image.open(image) + with torch.amp.autocast(accelerator.device.type): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e0e7d2e40e56..c201a3aef110 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -210,12 +210,13 @@ def log_validation( pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] ) images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24d098add017..c6788575b1b5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -209,13 +209,14 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index e18909e6dfd7..19c7016b7662 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -209,14 +209,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - image=pipeline_args["image"], - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 268d0148e446..2fcb0b870f1b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -209,13 +209,14 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 0205f2e9e65f..ee0e5795498e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -209,15 +209,16 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - image=pipeline_args["image"], - prompt_embeds=pipeline_args["prompt_embeds"], - negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index dee65761e92b..e9a20eb7634b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -213,15 +213,16 @@ def log_validation( prompt = pipeline_args_cp.pop("prompt") prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None) images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - **pipeline_args_cp, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + **pipeline_args_cp, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index bd2fb8db2d21..eef2bafb74dd 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -212,18 +212,19 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds_t5=pipeline_args["prompt_embeds_t5"], - prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"], - negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"], - negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"], - pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"], - negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds_t5=pipeline_args["prompt_embeds_t5"], + prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"], + negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"], + negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"], + pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"], + negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index 48eba4c5041d..8aac5c742644 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -171,7 +171,7 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index a1e2fa0f6052..05f3545951fe 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -200,14 +200,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - prompt_embeds_mask=pipeline_args["prompt_embeds_mask"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + prompt_embeds_mask=pipeline_args["prompt_embeds_mask"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 3b295163b73d..4485d2f9c01b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -192,7 +192,8 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + with torch.no_grad(): + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 4f49ef4bd801..29a7ed986981 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -204,7 +204,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 502ce1a3f1ec..8d173c8c3e66 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -216,7 +216,7 @@ def log_validation( else: autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 623ae4d2aca3..b02a124bc159 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -208,14 +208,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt=args.validation_prompt, - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=args.validation_prompt, + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 98e7d2d66cbc..28ce2b4d8a62 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -180,7 +180,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: