From 2daab54a071958c6caad7c5703dc81d13508e5c9 Mon Sep 17 00:00:00 2001 From: gambletan Date: Mon, 16 Mar 2026 22:13:09 +0800 Subject: [PATCH] fix: wrap validation inference with torch.no_grad() in dreambooth examples When `log_validation()` runs pipeline inference during training, it uses the same UNet/transformer that is being trained. Without `torch.no_grad()`, PyTorch computes and stores gradients during validation. With `--mixed_precision="fp16"`, this causes the gradient scaler to encounter FP16 gradients from the validation pass when training resumes, resulting in: ValueError: Attempting to unscale FP16 gradients. This adds `torch.no_grad()` around all pipeline inference calls in `log_validation()` across all dreambooth training scripts to prevent gradient computation during validation. Fixes huggingface/diffusers#13124 --- examples/dreambooth/train_dreambooth.py | 21 ++++++++-------- examples/dreambooth/train_dreambooth_flux.py | 2 +- examples/dreambooth/train_dreambooth_lora.py | 20 ++++++++------- .../dreambooth/train_dreambooth_lora_flux.py | 13 +++++----- .../dreambooth/train_dreambooth_lora_flux2.py | 15 +++++------ .../train_dreambooth_lora_flux2_img2img.py | 17 +++++++------ .../train_dreambooth_lora_flux2_klein.py | 15 +++++------ ...ain_dreambooth_lora_flux2_klein_img2img.py | 19 +++++++------- .../train_dreambooth_lora_flux_kontext.py | 19 +++++++------- .../train_dreambooth_lora_hidream.py | 25 ++++++++++--------- .../train_dreambooth_lora_lumina2.py | 2 +- .../train_dreambooth_lora_qwen_image.py | 17 +++++++------ .../dreambooth/train_dreambooth_lora_sana.py | 3 ++- .../dreambooth/train_dreambooth_lora_sd3.py | 2 +- .../dreambooth/train_dreambooth_lora_sdxl.py | 2 +- .../train_dreambooth_lora_z_image.py | 17 +++++++------ examples/dreambooth/train_dreambooth_sd3.py | 2 +- 17 files changed, 112 insertions(+), 99 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3a2b32aaef5..9dbe2b8f02d5 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -175,16 +175,17 @@ def log_validation( # run inference generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) images = [] - if args.validation_images is None: - for _ in range(args.num_validation_images): - with torch.autocast("cuda"): - image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] - images.append(image) - else: - for image in args.validation_images: - image = Image.open(image) - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + if args.validation_images is None: + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] + images.append(image) + else: + for image in args.validation_images: + image = Image.open(image) + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: if tracker.name == "tensorboard": diff --git a/examples/dreambooth/train_dreambooth_flux.py b/examples/dreambooth/train_dreambooth_flux.py index c7e0c290fa8e..1e8309d40589 100644 --- a/examples/dreambooth/train_dreambooth_flux.py +++ b/examples/dreambooth/train_dreambooth_flux.py @@ -191,7 +191,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index b6baccc4bc99..367cf909ad8a 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -155,17 +155,19 @@ def log_validation( if args.validation_images is None: images = [] - for _ in range(args.num_validation_images): - with torch.amp.autocast(accelerator.device.type): - image = pipeline(**pipeline_args, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with torch.amp.autocast(accelerator.device.type): + image = pipeline(**pipeline_args, generator=generator).images[0] + images.append(image) else: images = [] - for image in args.validation_images: - image = Image.open(image) - with torch.amp.autocast(accelerator.device.type): - image = pipeline(**pipeline_args, image=image, generator=generator).images[0] - images.append(image) + with torch.no_grad(): + for image in args.validation_images: + image = Image.open(image) + with torch.amp.autocast(accelerator.device.type): + image = pipeline(**pipeline_args, image=image, generator=generator).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index e0e7d2e40e56..c201a3aef110 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -210,12 +210,13 @@ def log_validation( pipeline_args["prompt"], prompt_2=pipeline_args["prompt"] ) images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2.py b/examples/dreambooth/train_dreambooth_lora_flux2.py index 24d098add017..c6788575b1b5 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2.py @@ -209,13 +209,14 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py index e18909e6dfd7..19c7016b7662 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_img2img.py @@ -209,14 +209,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - image=pipeline_args["image"], - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py index 268d0148e446..2fcb0b870f1b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein.py @@ -209,13 +209,14 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py index 0205f2e9e65f..ee0e5795498e 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py +++ b/examples/dreambooth/train_dreambooth_lora_flux2_klein_img2img.py @@ -209,15 +209,16 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - image=pipeline_args["image"], - prompt_embeds=pipeline_args["prompt_embeds"], - negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + image=pipeline_args["image"], + prompt_embeds=pipeline_args["prompt_embeds"], + negative_prompt_embeds=pipeline_args["negative_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py index dee65761e92b..e9a20eb7634b 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux_kontext.py +++ b/examples/dreambooth/train_dreambooth_lora_flux_kontext.py @@ -213,15 +213,16 @@ def log_validation( prompt = pipeline_args_cp.pop("prompt") prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None) images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - **pipeline_args_cp, - prompt_embeds=prompt_embeds, - pooled_prompt_embeds=pooled_prompt_embeds, - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + **pipeline_args_cp, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_hidream.py b/examples/dreambooth/train_dreambooth_lora_hidream.py index bd2fb8db2d21..eef2bafb74dd 100644 --- a/examples/dreambooth/train_dreambooth_lora_hidream.py +++ b/examples/dreambooth/train_dreambooth_lora_hidream.py @@ -212,18 +212,19 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds_t5=pipeline_args["prompt_embeds_t5"], - prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"], - negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"], - negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"], - pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"], - negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds_t5=pipeline_args["prompt_embeds_t5"], + prompt_embeds_llama3=pipeline_args["prompt_embeds_llama3"], + negative_prompt_embeds_t5=pipeline_args["negative_prompt_embeds_t5"], + negative_prompt_embeds_llama3=pipeline_args["negative_prompt_embeds_llama3"], + pooled_prompt_embeds=pipeline_args["pooled_prompt_embeds"], + negative_pooled_prompt_embeds=pipeline_args["negative_pooled_prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_lumina2.py b/examples/dreambooth/train_dreambooth_lora_lumina2.py index 48eba4c5041d..8aac5c742644 100644 --- a/examples/dreambooth/train_dreambooth_lora_lumina2.py +++ b/examples/dreambooth/train_dreambooth_lora_lumina2.py @@ -171,7 +171,7 @@ def log_validation( generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_qwen_image.py b/examples/dreambooth/train_dreambooth_lora_qwen_image.py index a1e2fa0f6052..05f3545951fe 100644 --- a/examples/dreambooth/train_dreambooth_lora_qwen_image.py +++ b/examples/dreambooth/train_dreambooth_lora_qwen_image.py @@ -200,14 +200,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt_embeds=pipeline_args["prompt_embeds"], - prompt_embeds_mask=pipeline_args["prompt_embeds_mask"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt_embeds=pipeline_args["prompt_embeds"], + prompt_embeds_mask=pipeline_args["prompt_embeds_mask"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_sana.py b/examples/dreambooth/train_dreambooth_lora_sana.py index 3b295163b73d..4485d2f9c01b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sana.py +++ b/examples/dreambooth/train_dreambooth_lora_sana.py @@ -192,7 +192,8 @@ def log_validation( # run inference generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None - images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] + with torch.no_grad(): + images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index 4f49ef4bd801..29a7ed986981 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -204,7 +204,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 502ce1a3f1ec..8d173c8c3e66 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -216,7 +216,7 @@ def log_validation( else: autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: diff --git a/examples/dreambooth/train_dreambooth_lora_z_image.py b/examples/dreambooth/train_dreambooth_lora_z_image.py index 623ae4d2aca3..b02a124bc159 100644 --- a/examples/dreambooth/train_dreambooth_lora_z_image.py +++ b/examples/dreambooth/train_dreambooth_lora_z_image.py @@ -208,14 +208,15 @@ def log_validation( autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() images = [] - for _ in range(args.num_validation_images): - with autocast_ctx: - image = pipeline( - prompt=args.validation_prompt, - prompt_embeds=pipeline_args["prompt_embeds"], - generator=generator, - ).images[0] - images.append(image) + with torch.no_grad(): + for _ in range(args.num_validation_images): + with autocast_ctx: + image = pipeline( + prompt=args.validation_prompt, + prompt_embeds=pipeline_args["prompt_embeds"], + generator=generator, + ).images[0] + images.append(image) for tracker in accelerator.trackers: phase_name = "test" if is_final_validation else "validation" diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 98e7d2d66cbc..28ce2b4d8a62 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -180,7 +180,7 @@ def log_validation( # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext() autocast_ctx = nullcontext() - with autocast_ctx: + with torch.no_grad(), autocast_ctx: images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)] for tracker in accelerator.trackers: