diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index da91e5c170..75e28c64ac 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -7,6 +7,7 @@
 #include
 
 #include "common.h"
+#include "common/util/cuda_runtime.h"
 #include "common/util/system.h"
 #include "pybind.h"
 #include "torch/torch.h"
@@ -2243,7 +2244,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
 
   // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
   bool eligible_for_rht_cast_fusion =
-      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
+      input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
+      transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
 
   // Stochastic rounding
   // When both rowwise and columnwise quantization are used with RHT,