Disable the RHT fusion for non-SM100 family devices #2968
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary
This PR fixes a runtime crash/failure on SM120 (GB20x) devices by restricting the RHT cast fusion kernel to SM100-family hardware (compute capability 100–110), where the required MMA shared memory fits within device limits.
Confidence Score: 5/5
Safe to merge — the change is a narrow, additive guard that only disables a fusion path on hardware where it was already broken. The fix is a single boolean condition added to an eligibility check; it cannot regress correct behavior on SM100 devices and safely falls back to the non-fused path on everything else. The header inclusion and the sm_arch() API are both well-established in this codebase. No files require special attention.
Flowchart
```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
  A[quantize_impl called] --> B{Check eligibility for RHT cast fusion}
  B --> C{dtype == BFloat16?}
  C -- No --> F[eligible = false]
  C -- Yes --> D{rows % 64 == 0 AND cols % 128 == 0?}
  D -- No --> F
  D -- Yes --> E{sm_arch in range 100..110?}
  E -- No --> F
  E -- Yes --> G[eligible = true]
  G --> H[Use fused RHT cast kernel]
  F --> I[Use non-fused path]
```
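The eligibility decision in the flowchart can be sketched in Python. The helper below is purely illustrative — `rht_cast_fusion_eligible` is not a Transformer Engine API, and the real check is the C++ condition in the diff that follows; it simply mirrors that condition using `torch.cuda.get_device_capability()`.

```python
import torch

def rht_cast_fusion_eligible(dtype: torch.dtype, rows: int, cols: int) -> bool:
    """Hypothetical mirror of the C++ eligibility guard added in this PR."""
    major, minor = torch.cuda.get_device_capability()
    sm_arch = major * 10 + minor  # e.g. 100 for SM100, 120 for SM120 (GB20x)
    return (
        dtype == torch.bfloat16
        and rows % 64 == 0
        and cols % 128 == 0
        # SM100 family only: the fused kernel's shared-memory footprint does not fit on SM120.
        and 100 <= sm_arch <= 110
    )

# On an SM100 device this can return True; on e.g. an RTX 5080 (sm_120) it is
# always False, so quantize_impl takes the non-fused path.
print(rht_cast_fusion_eligible(torch.bfloat16, rows=128, cols=128))
```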
```diff
   bool eligible_for_rht_cast_fusion =
-      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
+      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
+      transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
```
The upper bound `<= 110` is tighter than the stated intent ("non-SM100 family"). Using `< 120` more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so `< 120` reads as clearly intentional.
Suggested change:
```diff
-      transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
+      transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120;
```
Empirical verification on RTX 5080 (sm_120) — confirms PR fix works

We filed the original failure as #2956 (cycle #289 stream X.5.2, May 2 2026). Just verified this PR end-to-end on RTX 5080.

Numerical correctness: the separate-ops fallback produces NVFP4 output within the expected noise floor — max relative error vs bf16 reference is 0.124–0.163 across all shapes (NVFP4 E2M1 + per-block scale; ~12–16% rel-err vs bf16 is the expected quantization noise, not a correctness bug). Same band on the working 64×64×64 baseline (rel_err = 0.147), so the patch does not introduce additional drift.

Reproducer:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling

recipe = NVFP4BlockScaling()
for M, K, N in [(64, 64, 64), (128, 128, 128), (1024, 4096, 1024)]:
    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    layer = te.Linear(K, N, params_dtype=torch.bfloat16, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x)
    torch.cuda.synchronize()
    print(f"OK M={M} K={K} N={N}")
```

On 2.14.1 unpatched: dim 64 OK, dim 128 onwards FAIL with

Hardware: RTX 5080 16GB, sm_120, CUDA 13.0.2, PyTorch 2.11.0+cu130, Linux WSL2 Ubuntu 24.04. TE built editable from a fresh clone of 2.14.1+366798e with the PR diff applied via

Performance: separate-ops fallback ms/op for the patched path is within 5% of the working baseline shapes (e.g., 0.176 ms at 128×128×128 vs 0.184 ms at 96×96×96 baseline). Did not benchmark TFLOPS or compare against fouroversix — our project (Volkov VLM training, RTX 5080) uses fouroversix for production NVFP4 on sm_120 because of its fused-kernel sm_120 GEMM (cycle #321 Action 2 + cycle #372 Action 2 three-way validation = bit-identical at max_delta=0.000 vs TE NVFP4 reference).

LGTM for SM120 functional correctness. Reviewer's earlier suggestion of

Looking forward to seeing this in 2.15 — that graduates TE-NVFP4 from RED to AMBER on consumer Blackwell as a fallback option for users without fouroversix.
/te-ci pytorch
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch
for more information, see https://pre-commit.ci
Description
Disable the RHT fusion for non-sm100 class devices (the kernel uses too much shared memory to be runnable on e.g. sm120).
Fixes #2956
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: