[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
vthumbe1503 wants to merge 5 commits into
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary

This PR fixes DCP (Distributed Checkpoint) sync and async checkpoint loading for FSDP2 models using MXFP8 and NVFP4 quantization recipes. It achieves this by adding a
Confidence Score: 4/5

The core logic is functionally sound and the xfail removals appear justified, but the

transformer_engine/pytorch/quantized_tensor.py — the

Important Files Changed
Sequence Diagram

sequenceDiagram
participant DCP as DCP (async save)
participant QT as QuantizedTensor (GPU)
participant CPU as Plain CPU Tensor
participant Disk as Checkpoint Storage
Note over DCP,Disk: Async Save Path
DCP->>QT: "new_empty(size, device=cpu, pin_memory=True)"
QT->>QT: "dequantize(dtype=target_dtype)"
QT-->>CPU: plain fp32 CPU tensor
DCP->>CPU: copy_ from GPU tensor
CPU-->>Disk: serialize fp32 values
Note over DCP,Disk: Load Path
Disk-->>CPU: deserialize fp32 values
DCP->>QT: "copy_(dst=QuantizedTensor, src=CPU fp32)"
QT->>QT: dst.quantize_(src)
Note over QT: NVFP4/MXFP8 weights restored
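The save/load round trip in the diagram can be sketched with a stand-in class (hypothetical; not the real TE `QuantizedTensor` API — quantization is simulated by rounding to a coarse grid): `new_empty` on CPU hands DCP a plain high-precision tensor, and `copy_` into the quantized side re-quantizes on load.

```python
import torch

class FakeQuantizedTensor:
    """Stand-in for a quantized tensor: stores values rounded to a 0.5 grid."""

    def __init__(self, data: torch.Tensor):
        self._q = torch.round(data * 2) / 2  # simulated quantize_

    def dequantize(self, dtype=torch.float32) -> torch.Tensor:
        return self._q.to(dtype)

    def new_empty_cpu(self, size) -> torch.Tensor:
        # DCP staging path: hand back a plain high-precision CPU tensor.
        return torch.empty(size, dtype=torch.float32, device="cpu")

    def copy_(self, src: torch.Tensor) -> "FakeQuantizedTensor":
        # DCP load path: re-quantize the incoming high-precision values.
        self._q = torch.round(src * 2) / 2
        return self

# Save: stage to a plain CPU tensor holding dequantized values.
qt = FakeQuantizedTensor(torch.tensor([0.26, 1.9, -0.74]))
staged = qt.new_empty_cpu(3)
staged.copy_(qt.dequantize())

# Load: copy_ back into a fresh quantized tensor re-quantizes,
# so the restored weights match the originally saved quantized values.
restored = FakeQuantizedTensor(torch.zeros(3)).copy_(staged)
assert torch.equal(restored.dequantize(), qt.dequantize())
```

Because the checkpoint stores dequantized values, the load path is quantize(dequantize(x)) rather than a bit-exact restore, which is what the relaxed NVFP4 tolerances below account for.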
Reviews (2) — Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..."
def untyped_storage(self) -> torch.UntypedStorage:
    """Return an empty UntypedStorage on the tensor's device.

    ``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
    backing storage of its own; the actual bytes live in the inner
    buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
    an implementation detail of the quantization scheme. Need to define
    this method to avoid DCP staging errors with FSDP2.
    """
    return torch.UntypedStorage(0, device=self.device)
Empty storage breaks shared-storage detection in existing callers
QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.
The correct behavior for these functions is to fall back to the slow path for QuantizedTensors, unless there is a dedicated implementation to handle quantized data.
Yeah, while I don't think we ever use QuantizedTensors in SplitAlongDim, the concat path sounds plausible to hit.
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
    torch.testing.assert_close(
        loaded_output,
        ref_output,
        rtol=0.125,
        atol=0.25,
        msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
    )
Typo: "neec" should be "need" — appears in both NVFP4 tolerance blocks.
Suggested change:
 elif recipe_name == "NVFP4BlockScaling":
-    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
+    # NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
     torch.testing.assert_close(
         loaded_output,
         ref_output,
         rtol=0.125,
         atol=0.25,
         msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
     )
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
    torch.testing.assert_close(
        out2,
        out1,
        rtol=0.125,
        atol=0.25,
        msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
    )
Same typo ("neec") in the second NVFP4 tolerance block for the post-training-step check.
Suggested change:
 elif recipe_name == "NVFP4BlockScaling":
-    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
+    # NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
     torch.testing.assert_close(
         out2,
         out1,
         rtol=0.125,
         atol=0.25,
         msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
     )
# NVFP4 scale unpad/repad through FSDP2 introduces small numerical
# differences vs the manual dequantize-then-allgather path.
Tolerance relaxed 250× for NVFP4 allgather verification
The tolerance for _check_fp8_fsdp2_allgather on NVFP4Tensor jumped from atol=5e-4, rtol=5e-3 to atol=0.125, rtol=0.25. This test compares param.dequantize() against fp32_allgathered_params[name], validating round-trip numerical fidelity of the all-gather path. A 25% relative tolerance makes the check nearly a no-op for FP4 values. A comment citing the 4-bit mantissa precision ceiling would justify the new values.
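For context (illustrative only; this ignores NVFP4's per-block scale factors and the second-level FP8 scale), the positive E2M1 value grid is {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and round-to-nearest onto that grid alone can already cost about a third of the value in relative terms at grid midpoints, so tolerances on the order of 0.125/0.25 are in the right ballpark for a dequant+requant round trip:

```python
# Worst-case round-to-nearest relative error onto the E2M1 value grid,
# scanned over inputs in [0.5, 6] (excluding the underflow-to-zero region).
grid = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

def nearest(x: float) -> float:
    return min(grid, key=lambda g: abs(g - x))

worst = 0.0
x = 0.5
while x <= 6.0:
    rel = abs(nearest(x) - x) / x
    worst = max(worst, rel)
    x += 1e-3
print(f"max relative rounding error: {worst:.3f}")  # ~0.333, near x = 0.75
```

The worst case sits at the midpoint 0.75 between grid points 0.5 and 1.0, where the rounding error of 0.25 is a third of the input.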
/te-ci L1 pytorch
        msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
    )
elif recipe_name == "NVFP4BlockScaling":
    # NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Why do we need dequant + quant here?
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a plain tensor while qtensor.new_empty(..., device="cuda") returns a quantized tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:
- It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
- Would this affect FSDP or CPU offloading?
- Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?
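For reference, plain torch already distinguishes the two along the lines described: empty_like copies the input's full metadata (shape included), while new_empty takes a new shape and only inherits dtype/device defaults from the input.

```python
import torch

src = torch.ones(3, 4, dtype=torch.float16)

# empty_like: "make a tensor that matches the input"
like = torch.empty_like(src)
assert like.shape == (3, 4) and like.dtype == torch.float16

# new_empty: "call torch.empty with defaults taken from input"
fresh = src.new_empty(7, dtype=torch.float32)
assert fresh.shape == (7,) and fresh.dtype == torch.float32

# Defaults come from src when not overridden explicitly.
inherit = src.new_empty(2)
assert inherit.dtype == torch.float16
```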
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).
>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1

Suggested change:
-target_size = torch.Size(size) if len(size) > 0 else tensor.size()
+target_size = size
 # differences vs the manual dequantize-then-allgather path.
 if isinstance(param, NVFP4Tensor):
-    tols = dict(atol=5e-4, rtol=5e-3)
+    tols = dict(atol=0.125, rtol=0.25)
Why are the tolerances so much bigger? Is it also due to the dequant+quant path? If so, the comment above is no longer relevant and should be replaced with a better one (but I would still like an explanation why we cannot just load the nvfp4 values from the checkpoint).
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Ok, I see now why you want to dequantize. I don't think this is needed though: we should be able to create the QuantizedTensor on the CPU and save it, no? I remember that CPU offloading of activations faced a similar problem and already had to support some CPU ops on the QuantizedTensor anyway.
Description
Fixes DCP Sync and Async checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes
Type of change
Changes
Please list the changes introduced in this PR:
- DCP Sync Checkpoint loading
- DCP Async Checkpointing
Checklist: