Fix DreamBooth LoRA fp16 training crash after validation#13510

Open
Ricardo-M-L wants to merge 1 commit into huggingface:main from Ricardo-M-L:fix/dreambooth-lora-fp16-validation-unscale

Conversation

@Ricardo-M-L

What does this PR do?

Fixes a regression that aborts `examples/dreambooth/train_dreambooth_lora.py` on the first training step after the first validation, whenever the user combines `--mixed_precision=fp16` and `--validation_prompt`:

```
ValueError: Attempting to unscale FP16 gradients.
```

Root cause

  1. LoRA trainable params are upcast to fp32 once, before the training loop, via `cast_training_params(models, dtype=torch.float32)` (the pre-existing mitigation from #6514 and #6554, which fixed fp16 training resumption for the SDXL and SD LoRA DreamBooth scripts).
  2. Validation builds a pipeline with `unet=unwrap_model(unet)` (i.e. the same module object) and `torch_dtype=weight_dtype`.
  3. `log_validation` then does `pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)` (line 150 of the script), which casts every component in-place — including the shared `unet` — back down to fp16. The LoRA adapter weights are nested inside that module and get cast along with the rest.
  4. On the next training step the backward pass produces fp16 grads, which `accelerate`'s `GradScaler.unscale_` refuses to process.
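The downcast in step 3 can be reproduced with a few lines of plain PyTorch. This is a hypothetical minimal sketch, not the actual script: the `nn.Linear` stands in for the LoRA-injected UNet, and the manual `.data` upcast stands in for `cast_training_params`.

```python
import torch
from torch import nn

# Stand-in for the LoRA-injected UNet: mostly fp16, but the trainable
# param was upcast to fp32 (as cast_training_params does before training).
unet = nn.Linear(4, 4).to(torch.float16)
unet.weight.data = unet.weight.data.to(torch.float32)  # "trainable" param upcast
assert unet.weight.dtype == torch.float32

# Validation hands the SAME module object to the pipeline, and
# pipeline.to(device, dtype=torch.float16) casts every parameter
# in place -- the trainable fp32 params are swept along with the rest:
shared = unet                    # same object, as with unet=unwrap_model(unet)
shared.to(dtype=torch.float16)   # what pipeline.to(...) does per component

print(unet.weight.dtype)  # torch.float16 -> next backward yields fp16 grads
```

Because `nn.Module.to` converts parameters in place and returns the same object, the training loop's reference to `unet` sees the fp16 weights too.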

Fix

Mirror the pre-training upcast immediately after `log_validation` returns: if `args.mixed_precision == "fp16"`, re-run `cast_training_params(..., dtype=torch.float32)` on the trainable modules. bf16 mixed-precision is unaffected because no grad scaler is involved there.
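The shape of the change is roughly the following. This is a sketch with a simplified stand-in for `diffusers.training_utils.cast_training_params` (the real helper also accepts a list of models), and `mixed_precision` stands in for `args.mixed_precision`:

```python
import torch
from torch import nn

def cast_training_params(model, dtype=torch.float32):
    # Simplified stand-in for diffusers.training_utils.cast_training_params:
    # upcast only the params that require grad, leaving frozen fp16
    # weights untouched so memory use stays low.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(dtype)

# After log_validation returns, the shared unet may have been cast to
# fp16 by pipeline.to(...); mirror the pre-training upcast:
unet = nn.Linear(4, 4).to(torch.float16)  # stand-in for the post-validation unet
mixed_precision = "fp16"                  # stand-in for args.mixed_precision
if mixed_precision == "fp16":
    cast_training_params(unet, dtype=torch.float32)

print(unet.weight.dtype)  # torch.float32 -> fp32 grads, unscale_ succeeds
```

Restricting the upcast to `requires_grad` params is what keeps this cheap: only the LoRA adapter weights are touched, not the frozen base model.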

Only one file is touched; 10 lines added, 0 removed. No behavioral change when `--validation_prompt` is not set, or when mixed-precision is anything other than fp16.

Fixes #13124 ("train_dreambooth_lora.py -- ValueError: Attempting to unscale FP16 gradients caused by `--validation_prompt` param").

Before submitting

  • Did you read the contributor guideline?
  • This PR fixes a bug (linked above).
  • Minimal, targeted change — no refactor.

Who can review?

cc @sayakpaul

When `--mixed_precision=fp16` and `--validation_prompt` are both set,
training aborts on the first step after the first validation with:

  ValueError: Attempting to unscale FP16 gradients.

Root cause:

* The LoRA trainable params are upcast to fp32 once, before training,
  via `cast_training_params(models, dtype=torch.float32)`.
* Validation builds `DiffusionPipeline.from_pretrained(unet=unwrap_model(unet),
  torch_dtype=weight_dtype, ...)` and hands the pipeline to `log_validation`.
* `log_validation` calls `pipeline.to(accelerator.device, dtype=torch_dtype)`,
  which casts the *shared* `unet` module — including the LoRA adapter weights
  registered as trainable — back down to fp16.
* The next backward then produces fp16 grads, and the grad scaler refuses to
  unscale them.

Re-run `cast_training_params(..., dtype=torch.float32)` immediately after
`log_validation` returns (only when `mixed_precision == "fp16"`), mirroring
the pre-training upcast. bf16 mixed-precision is unaffected since no grad
scaler is in play there.

Fixes huggingface#13124
