Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion transformer_engine/pytorch/csrc/quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <pybind.h>

#include "common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "pybind.h"
#include "torch/torch.h"
Expand Down Expand Up @@ -2243,7 +2244,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou

// Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
bool eligible_for_rht_cast_fusion =
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The upper bound <= 110 is tighter than the stated intent ("non-SM100 family"). Using < 120 more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so < 120 reads as clearly intentional.

Suggested change
transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120;


// Stochastic rounding
// When both rowwise and columnwise quantization are used with RHT,
Expand Down
Loading