
fix: wrap validation inference with torch.no_grad() in dreambooth examples#13273

Open
gambletan wants to merge 1 commit into huggingface:main from gambletan:fix/dreambooth-validation-fp16-gradient-crash

Conversation

@gambletan

Summary

Fixes #13124

log_validation() in dreambooth training scripts runs pipeline inference using the same UNet/transformer that is being trained, without torch.no_grad(). With --mixed_precision="fp16", gradients are computed and stored during validation inference. When the training loop resumes, the gradient scaler tries to unscale these FP16 gradients, causing:

ValueError: Attempting to unscale FP16 gradients.

This PR wraps all pipeline inference calls in log_validation() with torch.no_grad() across all 17 dreambooth training scripts, not just the one reported in the issue.
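The change follows this pattern. Below is a minimal sketch with a stand-in pipeline (the real diffusers pipelines need model weights, so `TinyPipeline` and its shapes are purely illustrative, not the actual `log_validation()` code):

```python
import torch

class TinyPipeline:
    """Stand-in for a diffusers pipeline whose UNet/transformer is still being trained."""
    def __init__(self):
        self.unet = torch.nn.Linear(8, 8)  # proxy for the trained UNet

    def __call__(self, latents):
        return self.unet(latents)

pipeline = TinyPipeline()
latents = torch.randn(1, 8)

# Before the fix: autograd records the forward pass, so the output is
# attached to a graph over the trained parameters.
tracked = pipeline(latents)

# After the fix: validation inference runs under no_grad, so no graph is
# built and no gradients can be produced for the training loop to trip over.
with torch.no_grad():
    untracked = pipeline(latents)

print(tracked.requires_grad, untracked.requires_grad)  # True False
```

The inference output itself is unchanged; `torch.no_grad()` only disables graph construction, which is why the Test Plan below can assert identical validation images.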

Changes

  • train_dreambooth.py
  • train_dreambooth_flux.py
  • train_dreambooth_sd3.py
  • train_dreambooth_lora.py
  • train_dreambooth_lora_flux.py
  • train_dreambooth_lora_flux2.py
  • train_dreambooth_lora_flux2_img2img.py
  • train_dreambooth_lora_flux2_klein.py
  • train_dreambooth_lora_flux2_klein_img2img.py
  • train_dreambooth_lora_flux_kontext.py
  • train_dreambooth_lora_hidream.py
  • train_dreambooth_lora_lumina2.py
  • train_dreambooth_lora_qwen_image.py
  • train_dreambooth_lora_sana.py
  • train_dreambooth_lora_sd3.py
  • train_dreambooth_lora_sdxl.py
  • train_dreambooth_lora_z_image.py

Root Cause

Validation runs the pipeline, which internally calls the UNet/transformer forward pass. Since torch.no_grad() is not set, autograd tracks these operations and FP16 gradient tensors end up attached to the trained parameters. The GradScaler in the training loop then fails, because it expects to unscale FP32 gradients.
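A condensed illustration of the failing state, assuming a parameter that ended up in FP16 (as under --mixed_precision="fp16"): once autograd records operations on it and backward runs, the stored gradient is itself an FP16 tensor, which is exactly what GradScaler.unscale_() rejects.

```python
import torch

# Stand-in for a trained weight held in FP16 under --mixed_precision="fp16".
param = torch.randn(4, dtype=torch.float16, requires_grad=True)

# With autograd active, any op on the parameter is recorded in the graph...
loss = (param * 2.0).sum()
loss.backward()

# ...and the resulting gradient is FP16. GradScaler.unscale_() refuses this,
# raising "ValueError: Attempting to unscale FP16 gradients."
print(param.grad.dtype)  # torch.float16

# Under torch.no_grad() no graph is recorded, so no such gradient can appear.
with torch.no_grad():
    out = param * 2.0
print(out.grad_fn)  # None
```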

Test Plan

  • Verify train_dreambooth_lora.py with --mixed_precision="fp16" and --validation_prompt no longer crashes with ValueError: Attempting to unscale FP16 gradients
  • Verify validation images are still generated correctly (inference output is identical with or without no_grad)

fix: wrap validation inference with torch.no_grad() in dreambooth examples

When `log_validation()` runs pipeline inference during training, it uses the
same UNet/transformer that is being trained. Without `torch.no_grad()`, PyTorch
computes and stores gradients during validation. With `--mixed_precision="fp16"`,
this causes the gradient scaler to encounter FP16 gradients from the validation
pass when training resumes, resulting in:

    ValueError: Attempting to unscale FP16 gradients.

This adds `torch.no_grad()` around all pipeline inference calls in
`log_validation()` across all dreambooth training scripts to prevent gradient
computation during validation.

Fixes huggingface#13124


Development

Successfully merging this pull request may close these issues.

train_dreambooth_lora.py -- ValueError: Attempting to unscale FP16 gradients caused by "--validation_prompt" param.
