Skip to content

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974

Open
vthumbe1503 wants to merge 5 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix
Open

[Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit#2974
vthumbe1503 wants to merge 5 commits into
NVIDIA:mainfrom
vthumbe1503:fsdp2_dcp_laod_fix

Conversation

@vthumbe1503
Copy link
Copy Markdown
Collaborator

@vthumbe1503 vthumbe1503 commented May 11, 2026

Description

Fixes DCP Sync and Async checkpoint loading for MXFP8/NVFP4.
Fixes DCP Async checkpoint loading for all Quantization recipes

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • DCP Sync Checkpoint loading

    • untyped_storage is now defined for the base QuantizedTensor to return empty storage. Untyped_storage refers to the backing storage that we use to create all the internal tensors. Since we use make_wrapper_subclass to create TE QuantizedTensors, we use dont have any backing storage associated with the tensor. data_ptr on our Custom QuantizedTensor also returns 0.
    • The main issue is that FSDP2 maintains sharded param tensor for checkpointing. It does so by calling view(-1) on our Quantized sharded model parameters. We return back a dequantized 1D tensor in TE. So, the sharded tensor that FSDP2 maintains for checkpointing is BF16 and Quantized sharded param is our custom FP8 tensor. It evaluates untyped_storage(BF16 sharded tensor reloaded from disk) == untyped_storage(Quantized sharded parameter) to see if the same_tensor. With us returning empty storage now, this would never be equal to sharded tensor's untyped storage.
  • DCP Async Checkpointing

    • to_new_empty function with device="cpu" is being used in Async Checkpointing. This function returned Quantizer.make_empty without setting the device. For device = "cpu" we now dequantize. So that the Async checkpointing directly saves the bf16 data on disk and reload works fine.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented May 11, 2026

Greptile Summary

This PR fixes DCP (Distributed Checkpoint) sync and async checkpoint loading for FSDP2 models using MXFP8 and NVFP4 quantization recipes. It achieves this by adding a untyped_storage() override to the QuantizedTensor base class (returning a 0-byte storage) and a new aten.new_empty.default CPU-device branch that dequantizes the tensor into a plain float CPU buffer for DCP staging, with re-quantization happening on load via copy_/quantize_().

  • quantized_tensor.py: Adds untyped_storage() (0-byte) to the base class and a CPU-device path in new_empty dispatch that dequantizes to a plain tensor for DCP staging; removes the subclass-specific override from Float8BlockwiseQTensor.
  • Test xfails removed: Several previously-broken FSDP2 DCP paths for MXFP8/NVFP4/Float8BlockScaling are now unblocked, validated by removing pytest.xfail markers.
  • Tolerances relaxed: NVFP4 allgather and DCP round-trip checks use atol=0.125, rtol=0.25 to account for the lossy dequantize\u2192quantize staging round-trip inherent to 4-bit precision.

Confidence Score: 4/5

The core logic is functionally sound and xfail removals appear justified, but the new_empty CPU branch diverges from PyTorch uninitialized-buffer semantics in a way that could silently break if DCP staging order changes.

The aten.new_empty.default handler returns a pre-populated dequantized tensor rather than an uninitialized buffer. Today this works because DCP async staging overwrites the buffer via a follow-up copy_, but the contract is undocumented and a future change to DCP internals could cause staging to silently use stale data. The NVFP4 tolerance relaxations are expected given 4-bit precision limits.

transformer_engine/pytorch/quantized_tensor.py — the new_empty CPU branch and untyped_storage override warrant a second look for correctness across all DCP code paths.

Important Files Changed

Filename Overview
transformer_engine/pytorch/quantized_tensor.py Adds untyped_storage() (0-byte storage) to the base class and a CPU-device branch in new_empty dispatch that dequantizes and returns a plain tensor for DCP staging. The new_empty semantic divergence (returning filled data instead of an uninitialized buffer) is the main concern.
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Removes the Float8BlockwiseQTensor-specific untyped_storage() override, deferring to the new base-class implementation that returns a 0-byte storage. Straightforward cleanup.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py Removes several xfail markers for MXFP8/NVFP4 DCP paths that were previously broken, and adds relaxed tolerance (atol=0.25, rtol=0.125) branches for NVFP4 output comparison; contains a minor typo ("neec") in two comment lines.
tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py Removes three xfail blocks for MXFP8/Float8BlockScaling/NVFP4 FSDP2 init combinations; relaxes NVFP4 allgather tolerance from atol=5e-4 to atol=0.125. The 250× tolerance jump for allgather verification is notable but reflects 4-bit quantization precision limits.

Sequence Diagram

sequenceDiagram
    participant DCP as DCP (async save)
    participant QT as QuantizedTensor (GPU)
    participant CPU as Plain CPU Tensor
    participant Disk as Checkpoint Storage

    Note over DCP,Disk: Async Save Path
    DCP->>QT: "new_empty(size, device=cpu, pin_memory=True)"
    QT->>QT: "dequantize(dtype=target_dtype)"
    QT-->>CPU: plain fp32 CPU tensor
    DCP->>CPU: copy_ from GPU tensor
    CPU-->>Disk: serialize fp32 values

    Note over DCP,Disk: Load Path
    Disk-->>CPU: deserialize fp32 values
    DCP->>QT: "copy_(dst=QuantizedTensor, src=CPU fp32)"
    QT->>QT: dst.quantize_(src)
    Note over QT: NVFP4/MXFP8 weights restored
Loading

Reviews (2): Last reviewed commit: "Merge branch 'main' into fsdp2_dcp_laod_..." | Re-trigger Greptile

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Empty storage breaks shared-storage detection in existing callers

QuantizedTensor.untyped_storage() now returns a freshly allocated zero-byte storage every call. Code in module/_common.py:128 compares tensors[0].untyped_storage().nbytes() against expected size to decide between a no-op view and an out-of-place torch.cat. With 0 bytes returned, that condition is always true, silently disabling the in-place fast path for any QuantizedTensor through ConcatMerge.forward. More critically, utils.py:403-412 in SplitAlongDim.backward uses data_ptr() for noop detection — if all zero-size CUDA allocations return data_ptr() == 0, every QuantizedTensor pair incorrectly appears co-located, setting noop_ok = True and crashing on ret.set_() against a 0-byte storage.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, while I don't think we use QuantizedTensors in the SplitAlongDim ever, the concat sounds plausible to be hit.

Comment on lines +820 to +828
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Typo: "neec" should be "need" — appears in both NVFP4 tolerance blocks.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
loaded_output,
ref_output,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Fresh model loaded from DCP checkpoint produces different output: {x}",
)

Comment on lines +867 to +875
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Same typo ("neec") in the second NVFP4 tolerance block for the post-training-step check.

Suggested change
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so need to relax tolerances
torch.testing.assert_close(
out2,
out1,
rtol=0.125,
atol=0.25,
msg=lambda x: f"NVFP4BlockScaling: Training step after DCP load produces different output: {x}",
)

Comment on lines 243 to 244
# NVFP4 scale unpad/repad through FSDP2 introduces small numerical
# differences vs the manual dequantize-then-allgather path.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Tolerance relaxed 250× for NVFP4 allgather verification

The tolerance for _check_fp8_fsdp2_allgather on NVFP4Tensor jumped from atol=5e-4, rtol=5e-3 to atol=0.125, rtol=0.25. This test compares param.dequantize() against fp32_allgathered_params[name], validating round-trip numerical fidelity of the all-gather path. A 25% relative tolerance makes the check nearly a no-op for FP4 values. A comment citing the 4-bit mantissa precision ceiling would justify the new values.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@vthumbe1503 vthumbe1503 changed the title [Pytorch][Bug] DCP Checkpoint Load Fixes for FSDP2 with QuantizedModelInit [Pytorch][Bug] DCP Checkpoint Loading Fixes for FSDP2 with QuantizedModelInit May 11, 2026
@vthumbe1503
Copy link
Copy Markdown
Collaborator Author

/te-ci L1 pytorch

msg=lambda x: f"Fresh model loaded from DCP checkpoint produces different output: {x}",
)
elif recipe_name == "NVFP4BlockScaling":
# NVFP4 DCP load goes through a dequant + quant, so neec to relax tolerances
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need dequant + quant here?

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Copy link
Copy Markdown
Collaborator

@timmoon10 timmoon10 May 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix seems ad hoc to me. It's not obvious why qtensor.new_empty(..., device="cpu") returns a quantized tensor while qtensor.new_empty(..., device="cuda") returns a plain tensor. I wonder if it would be cleaner to just return a plain tensor in all cases. Thoughts:

  • It's uncomfortable how new_empty and empty_like would have different behavior. I suppose we could interpret empty_like as "make a tensor that matches the input" and new_empty as "call torch.empty with defaults taken from input", but that would be a private interpretation that no one else follows.
  • Would this affect FSDP or CPU offloading?
  • Given the weirdness, would it be worthwhile raising a warning if new_empty is called outside of DCP?

# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An empty size is valid and it corresponds to a tensor with 1 entry (for the same reason 2^0=1).

>>> import torch
>>> x = torch.ones(123).new_empty([])
>>> print(x.numel())
1
Suggested change
target_size = torch.Size(size) if len(size) > 0 else tensor.size()
target_size = size

Comment on lines +536 to +545
def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The correct behavior for these functions is to fall back to the slow path for QuantizedTensor s, unless it has a dedicated implementation to handle quantized data.

# differences vs the manual dequantize-then-allgather path.
if isinstance(param, NVFP4Tensor):
tols = dict(atol=5e-4, rtol=5e-3)
tols = dict(atol=0.125, rtol=0.25)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are the tolerances so much bigger? Is it also due to the dequant+quant path? If so, the comment above is no longer relevant and should be replaced with a better one (but I would still like an explanation why we cannot just load the nvfp4 values from the checkpoint).

Comment on lines +613 to +616
# When a CPU copy of a quantized tensor is requested (e.g. by
# torch DCP staging via ``x.new_empty(..., device="cpu")``), we
# save the high-precision values in a plain CPU dense tensor.
# For the DCP load path, we will re-quantize the high-precision values.
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I see now why you want to dequantize. I don't think this is needed though - we should be able to create the QuantlizedTensor on the CPU and save it, no? I remember that the CPU offloading of the activations faced similar problem and already had to support some CPU ops on the QuantizedTensor anyway.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants