[PyTorch] Enable head dim 256 for FA4 #2932
Conversation
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary: This PR enables head_dim=256 support for FlashAttention 4 on SM100/103 (Blackwell) by delegating head-dimension validation to FA4's own `v4_validate_head_dims` helper.
Confidence Score: 4/5 — functional on FA4 4.0.0b11+, with a minor concern about the combined import in transformer_engine/pytorch/attention/dot_product_attention/backends.py.
Important Files Changed: transformer_engine/pytorch/attention/dot_product_attention/backends.py
Reviews (4): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."
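A minimal sketch of the delegation pattern the summary describes, assuming `FlashAttentionUtils.v4_validate_head_dims` is populated with FA4's validator when the installed FA4 (>= 4.0.0b11) exposes one and left as None otherwise; the function name and fallback limit here are illustrative, not TE's actual code:

```python
class FlashAttentionUtils:
    # Stand-in for TE's utils class: assumed to hold FA4's own validator when
    # FA4 >= 4.0.0b11 provides one, and None on older FA4 builds.
    v4_validate_head_dims = None

def head_dims_supported(head_dim_qk: int, head_dim_v: int) -> bool:
    """Illustrative check, not TE's actual logic."""
    validate = FlashAttentionUtils.v4_validate_head_dims
    if validate is not None:
        # Delegate to FA4, which knows exactly which kernels it ships.
        return validate(head_dim_qk, head_dim_v)
    # Older FA4 without a validator: keep the previous conservative limit.
    return max(head_dim_qk, head_dim_v) <= 128
```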
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
@vcherepanov-nv @KshitijLakhani Please review.
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Should this still be checked when `FlashAttentionUtils.v4_validate_head_dims` is None?
I double-checked: this is a bug in FA4. The kernels produce wrong results on these shapes even though `v4_validate_head_dims` allows them, so we have to filter them out manually.
Raised an issue with FA4: Dao-AILab/flash-attention#2552
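For reference, a hedged sketch of the manual filter discussed above; `tile_hdimv`, `dk_reduce_ncol`, and the constant value are assumptions mirroring the diff comment, not the exact code in backends.py:

```python
def hits_fa4_dv_tmem_bug(head_dim_qk: int, head_dim_v: int) -> bool:
    """Reject shapes that FA4's validator accepts but whose dV results are wrong.

    Mirrors the diff comment: when (tile_hdimv // 2) % dK_reduce_ncol != 0,
    dV TMEM loads are misaligned (Dao-AILab/flash-attention#2552).
    """
    # The dedicated (256, 256) kernel uses its own TMEM layout; unaffected.
    if (head_dim_qk, head_dim_v) == (256, 256):
        return False
    tile_hdimv = head_dim_v   # assumption: tile extent matches head_dim_v
    dk_reduce_ncol = 64       # assumed constant from flash_bwd_sm100.py
    return (tile_hdimv // 2) % dk_reduce_ncol != 0
```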
LGTM
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
Description
Needs FA4 version 4.0.0b11 or later.
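A small sketch of how such a version gate could look, assuming the FA4 distribution is named flash-attn (the check TE actually performs may differ):

```python
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def fa4_is_new_enough() -> bool:
    """True if the installed flash-attn is recent enough for head_dim=256."""
    try:
        installed = Version(version("flash-attn"))  # distribution name assumed
    except PackageNotFoundError:
        return False
    return installed >= Version("4.0.0b11")
```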
Type of change
Changes
Please list the changes introduced in this PR:
- Delegate head-dimension validation to FA4's own `v4_validate_head_dims` when the installed FA4 provides it, enabling head_dim=256 on SM100/103.
- Manually filter out head-dim combinations that FA4's validator accepts but that produce wrong dV results (Dao-AILab/flash-attention#2552).
Checklist: