Disable the RHT fusion for non-SM100 family devices #2968

Open

ptrendx wants to merge 4 commits into NVIDIA:main from ptrendx:pr_fix_rht_fusion

Conversation

ptrendx (Member) commented May 8, 2026

Description

Disable the RHT fusion for non-SM100 family devices (the kernel uses too much shared memory to be runnable on, e.g., SM120).

Fixes #2956
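For context on the failure mode, a minimal sketch (illustration only, not part of this PR): the per-block shared memory ceiling differs across architectures and can be queried directly, and a kernel launch that requests more than the device's opt-in limit fails at launch time.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative sketch: print the shared-memory-per-block limits that determine
// whether a large-smem kernel (such as the fused RHT cast) can launch at all.
int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  std::printf("sm_%d%d: %zu bytes smem/block (default), %zu bytes (opt-in)\n",
              prop.major, prop.minor,
              prop.sharedMemPerBlock, prop.sharedMemPerBlockOptin);
  return 0;
}
```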

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a check on the SM architecture when testing for fusion eligibility.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 May 8, 2026 00:07
greptile-apps bot commented May 8, 2026

Greptile Summary

This PR fixes a runtime crash/failure on SM120 (GB20x) devices by restricting the RHT cast fusion kernel to SM100-family hardware (compute capability 100–110), where the required MMA shared memory fits within device limits.

  • Adds sm_arch() >= 100 && sm_arch() <= 110 to the eligible_for_rht_cast_fusion guard in NVFP4Quantizer::quantize_impl, effectively disabling the fused path on SM120 and other non-SM100 architectures.
  • Pulls in common/util/cuda_runtime.h (previously unused in this translation unit) to expose the sm_arch() utility.
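As a hedged aside on the encoding (assuming sm_arch() returns 10 * major + minor, which is what makes 100–110 read as the SM100 family):

```cpp
// Sketch only: the guard's range keys off compute capability encoded as
// 10 * major + minor; the helper name here is illustrative.
constexpr int encode_cc(int major, int minor) { return 10 * major + minor; }

static_assert(encode_cc(10, 0) == 100, "B200 (CC 10.0): in range, fused path");
static_assert(encode_cc(12, 0) == 120, "GB20x (CC 12.0): out of range, fallback");
```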

Confidence Score: 5/5

Safe to merge — the change is a narrow, additive guard that only disables a fusion path on hardware where it was already broken.

The fix is a single boolean condition added to an eligibility check; it cannot regress correct behavior on SM100 devices and safely falls back to the non-fused path on everything else. The header inclusion and the sm_arch() API are both well-established in this codebase.

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/pytorch/csrc/quantizer.cpp | Adds an SM architecture range check (100–110) to eligible_for_rht_cast_fusion, restricting the RHT fusion kernel to SM100-family Blackwell devices (and calling sm_arch() twice in the expression). |

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_impl called] --> B{Check eligibility for RHT cast fusion}
    B --> C{dtype == BFloat16?}
    C -- No --> F[eligible = false]
    C -- Yes --> D{rows % 64 == 0 AND cols % 128 == 0?}
    D -- No --> F
    D -- Yes --> E{sm_arch in range 100..110?}
    E -- No --> F
    E -- Yes --> G[eligible = true]
    G --> H[Use fused RHT cast kernel]
    F --> I[Use non-fused path]
```


```diff
 bool eligible_for_rht_cast_fusion =
-    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
+    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
+    transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
```

P2 The upper bound <= 110 is tighter than the stated intent ("non-SM100 family"). Using < 120 more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so < 120 reads as clearly intentional.

Suggested change:

```diff
-transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
+transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120;
```
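Taken together with the file-overview note that sm_arch() is called twice in the expression, one possible final shape (a sketch, not the PR's actual code) caches the query and uses the < 120 bound:

```cpp
// Sketch only: query the architecture once and reuse it in the guard.
const int arch = transformer_engine::cuda::sm_arch();
bool eligible_for_rht_cast_fusion =
    input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
    arch >= 100 && arch < 120;  // SM100-family only; excludes SM120 (GB20x)
```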

timmoon10 previously approved these changes May 8, 2026
@osubotin

Empirical verification on RTX 5080 (sm_120) — confirms PR fix works

We filed the original failure as #2956 (cycle #289 stream X.5.2, May 2 2026). Just verified this PR end-to-end on RTX 5080.

| Dim (M=K=N) | TE 2.14.1 (unpatched) | TE 2.14.1 + this PR |
| --- | --- | --- |
| 64×64×64 | OK | OK |
| 96×96×96 | OK | OK |
| 128×128×128 | FAIL: `RuntimeError: row_cast_col_hadamard_transform_cast_fusion.cu:1200 in function row_col_rht_gemm_ntt_w_sfc: CUDA Error: invalid argument` | OK |
| 192×192×192 | FAIL: `AcceleratorError` CUDA invalid argument | OK |
| 256×256×256 | FAIL: invalid argument | OK |
| 384×384×384 | FAIL: invalid argument | OK |
| 512×512×512 | FAIL: invalid argument | OK |
| 1024×1024×1024 | FAIL: invalid argument | OK |
| 1024×4096×1024 (production) | FAIL: invalid argument | OK |

Numerical correctness: the separate-ops fallback produces NVFP4 output within the expected noise floor. Max relative error vs the bf16 reference is 0.124–0.163 across all shapes (NVFP4 E2M1 + per-block scale; ~12–16% relative error vs bf16 is the expected quantization noise, not a correctness bug). The same band holds on the working 64×64×64 baseline (rel_err = 0.147), so the patch does not introduce additional drift.
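For reference, a minimal sketch of the kind of error-band check behind those numbers (the layout here is assumed; it is not our exact harness):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

# Assumed-layout sketch: run the same layer in plain bf16 and under the NVFP4
# recipe, then compare the outputs to get a relative-error figure.
M, K, N = 128, 128, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")

with torch.no_grad():
    y_ref = layer(x)  # bf16 reference, no quantization
    with te.fp8_autocast(enabled=True, fp8_recipe=NVFP4BlockScaling()):
        y_q = layer(x)  # NVFP4 path (separate-ops fallback on sm_120 with this PR)

rel_err = (y_q.float() - y_ref.float()).abs().max() / y_ref.float().abs().max()
print(f"max relative error vs bf16: {rel_err.item():.3f}")  # ~0.12-0.16 expected
```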

Reproducer:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

recipe = NVFP4BlockScaling()
for M, K, N in [(64, 64, 64), (128, 128, 128), (1024, 4096, 1024)]:
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)
    torch.cuda.synchronize()
    print(f"OK M={M} K={K} N={N}")
```

On 2.14.1 unpatched: dim 64 OK, dim 128 onwards FAIL with invalid argument. With this PR applied: all shapes OK.

Hardware: RTX 5080 16GB, sm_120, CUDA 13.0.2, PyTorch 2.11.0+cu130, Linux WSL2 Ubuntu 24.04. TE was built editable from a fresh clone of 2.14.1+366798e with the PR diff applied via git apply. Note: TE 2.14.1's quantizer.cpp does not transitively include common/util/cuda_runtime.h (which declares transformer_engine::cuda::sm_arch()), so we had to add that include locally for the PR to compile against the 2.14.1 base. On main it is presumably already pulled in elsewhere, but it is worth a sanity check that the PR builds clean against the actual base it will merge into.
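Concretely, the local change was just the one include (header path as quoted in the Greptile summary above):

```cpp
#include "common/util/cuda_runtime.h"  // declares transformer_engine::cuda::sm_arch()
```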

Performance: separate-ops fallback ms/op for the patched path is within 5% of the working baseline shapes (e.g., 0.176 ms at 128×128×128 vs 0.184 ms at 96×96×96 baseline). Did not benchmark TFLOPS or compare against fouroversix — our project (Volkov VLM training, RTX 5080) uses fouroversix for production NVFP4 on sm_120 because of its fused-kernel sm_120 GEMM (cycle #321 Action 2 + cycle #372 Action 2 three-way validation = bit-identical at max_delta=0.000 vs TE NVFP4 reference).

LGTM for SM120 functional correctness. Reviewer's earlier suggestion of < 120 upper bound is also clean — functionally identical to <= 110 since SM111/SM112 don't exist, just slightly more readable as "everything below sm_120".

Looking forward to seeing this in 2.15 — that graduates TE-NVFP4 from RED to AMBER on consumer Blackwell as a fallback option for users without fouroversix.

ptrendx (Member, Author) commented May 11, 2026

/te-ci pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
ptrendx (Member, Author) commented May 11, 2026

/te-ci pytorch
