Fix DreamBooth LoRA fp16 training crash after validation#13510
Open
Ricardo-M-L wants to merge 1 commit intohuggingface:mainfrom
Open
Fix DreamBooth LoRA fp16 training crash after validation#13510Ricardo-M-L wants to merge 1 commit intohuggingface:mainfrom
Ricardo-M-L wants to merge 1 commit intohuggingface:mainfrom
Conversation
When `--mixed_precision=fp16` and `--validation_prompt` are both set, training aborts on the first step after the first validation with: ValueError: Attempting to unscale FP16 gradients. Root cause: * The LoRA trainable params are upcast to fp32 once, before training, via `cast_training_params(models, dtype=torch.float32)`. * Validation builds `DiffusionPipeline.from_pretrained(unet=unwrap_model(unet), torch_dtype=weight_dtype, ...)` and hands the pipeline to `log_validation`. * `log_validation` calls `pipeline.to(accelerator.device, dtype=torch_dtype)`, which casts the *shared* `unet` module — including the LoRA adapter weights registered as trainable — back down to fp16. * The next backward then produces fp16 grads, and the grad scaler refuses to unscale them. Re-run `cast_training_params(..., dtype=torch.float32)` immediately after `log_validation` returns (only when `mixed_precision == "fp16"`), mirroring the pre-training upcast. bf16 mixed-precision is unaffected since no grad scaler is in play there. Fixes huggingface#13124
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What does this PR do?
Fixes a regression that aborts `examples/dreambooth/train_dreambooth_lora.py` on the first training step after the first validation, whenever the user combines `--mixed_precision=fp16` and `--validation_prompt`:
```
ValueError: Attempting to unscale FP16 gradients.
```
Root cause
Fix
Mirror the pre-training upcast immediately after `log_validation` returns: if `args.mixed_precision == "fp16"`, re-run `cast_training_params(..., dtype=torch.float32)` on the trainable modules. bf16 mixed-precision is unaffected because no grad scaler is involved there.
Only one file is touched; 10 lines added, 0 removed. No behavioral change when `--validation_prompt` is not set, or when mixed-precision is anything other than fp16.
Fixes #13124
Before submitting
Who can review?
cc @sayakpaul