[PyTorch] Enable head dim 256 for FA4 #2932

Open
yaox12 wants to merge 5 commits into NVIDIA:main from yaox12:xiny/headdim256_fa

Conversation

@yaox12
Member

@yaox12 yaox12 commented Apr 27, 2026

Description

Enables head_dim=256 for FlashAttention 4. Requires FA4 version 4.0.0b11.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Delegate head-dim validation to FA4's own _validate_head_dims (available in flash-attn-4 > 4.0.0b10), enabling head_dim=256 on SM100/SM103.
  • Add test_dpa_fa4_hdim256 and bump the FA4 version in the QA test suite to 4.0.0b11.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 on April 27, 2026 09:31
@greptile-apps
Contributor

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR enables head_dim=256 support for FlashAttention 4 on SM100/103 (Blackwell) by delegating head-dimension validation to FA4's own _validate_head_dims function (available in flash-attn-4 > 4.0.0b10), and adds a properly guarded test_dpa_fa4_hdim256 test that addresses the previous review thread. It also removes stale cuDNN ≥8.9.1 guards from FA4 tests (FA4 is a standalone CUDA kernel independent of cuDNN).

  • backends.py: Imports _validate_head_dims from flash_attn.cute.interface and stores it as FlashAttentionUtils.v4_validate_head_dims; the call is now deferred to FA4's own assertion-based validator instead of TE's inline check.
  • utils.py: Replaces the static SM architecture table with a live call to v4_validate_head_dims, adds the Callable type annotation, updates the installation hint to 4.0.0b11, and restructures the MLA backward-kernel workaround from elif to a standalone if so it runs independently of the head-dim check (a delegation sketch follows this list).
  • test_attention.py: Adds test_dpa_fa4_hdim256 with an explicit device_compute_capability not in ((10, 0), (10, 3)) skip guard; bumps FA QA test suite to 4.0.0b11 on both SM90 and SM100+ branches.
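As context for the utils.py change above, a minimal sketch of the delegation pattern, assuming _validate_head_dims takes the two head dims as positional arguments and raises AssertionError for unsupported combinations; the wrapper name fa4_supports_head_dims and the argument names are illustrative, not taken from the TE source.

```python
from typing import Callable, Optional


class FlashAttentionUtils:
    # Set at import time to flash_attn.cute.interface._validate_head_dims when
    # flash-attn-4 > 4.0.0b10 is installed, otherwise left as None.
    v4_validate_head_dims: Optional[Callable] = None


def fa4_supports_head_dims(head_dim_qk: int, head_dim_v: int) -> bool:
    """Ask FA4 itself whether it ships a kernel for these head dims."""
    if FlashAttentionUtils.v4_validate_head_dims is None:
        return False  # FA4 too old to expose the validator
    try:
        # Assumed call signature; FA4's validator is assertion-based, so an
        # unsupported combination raises AssertionError.
        FlashAttentionUtils.v4_validate_head_dims(head_dim_qk, head_dim_v)
    except AssertionError:
        return False
    return True
```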

Confidence Score: 4/5

Functional on FA4 4.0.0b11+, but importing _validate_head_dims in the same block as the main FA4 symbols will crash backends.py module load for any environment with an older FA4 installed.

The _validate_head_dims symbol is imported in the same from flash_attn.cute.interface import (...) statement as flash_attn_func and flash_attn_varlen_func. The outer except PackageNotFoundError does not catch ImportError, so any user with FA4 older than 4.0.0b11 will hit an unhandled exception at module load of backends.py, breaking all attention backends — not just FA4. This issue was flagged in the previous review round and remains unaddressed.

transformer_engine/pytorch/attention/dot_product_attention/backends.py — the combined import of _validate_head_dims alongside the two main FA4 symbols needs its own try/except to preserve the graceful fallback.
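A minimal sketch of the separately guarded import this note suggests; the _v4 aliases and the fall-back-to-None behavior are illustrative assumptions, not the actual names or structure in backends.py.

```python
# Hedged sketch of the suggested fix, not the current backends.py code.
try:
    # Main FA4 entry points, present in the flash-attn-4 releases TE already supports.
    from flash_attn.cute.interface import (
        flash_attn_func as flash_attn_func_v4,
        flash_attn_varlen_func as flash_attn_varlen_func_v4,
    )
except ImportError:
    flash_attn_func_v4 = None
    flash_attn_varlen_func_v4 = None

try:
    # Only exposed by flash-attn-4 > 4.0.0b10; guarding it separately lets an
    # older FA4 degrade gracefully instead of crashing module load.
    from flash_attn.cute.interface import _validate_head_dims
except ImportError:
    _validate_head_dims = None
```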

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/backends.py: Imports _validate_head_dims in the same from flash_attn.cute.interface import (...) block as the main FA4 functions; an ImportError (not caught by the outer except PackageNotFoundError) crashes backends.py loading for any user with FA4 < 4.0.0b11.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Adds the v4_validate_head_dims class attribute and delegates head-dim validation to FA4's own _validate_head_dims; installation steps correctly updated to 4.0.0b11; the MLA workaround is restructured with if (was elif) so it runs independently after the head-dim check.
  • tests/pytorch/attention/test_attention.py: Adds test_dpa_fa4_hdim256 with an explicit SM100/103 skipif guard, addressing the previous thread (see the skip-guard sketch below), and removes stale cuDNN ≥8.9.1 guards from all FA4 tests (FA4 is a standalone CUDA kernel, not cuDNN-dependent).
  • qa/L3_pytorch_FA_versions_test/test.sh: Bumps the FA4 test version from 4.0.0b8 to 4.0.0b11 on both the SM90 and SM100+ branches, matching the minimum version required for _validate_head_dims and the new hdim=256 dedicated kernel.
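A minimal sketch of the SM100/103 skip guard described for test_dpa_fa4_hdim256, assuming get_device_compute_capability from transformer_engine.pytorch.utils returns a (major, minor) tuple as in TE's existing tests; the test body itself is elided.

```python
import pytest

from transformer_engine.pytorch.utils import get_device_compute_capability


@pytest.mark.skipif(
    get_device_compute_capability() not in ((10, 0), (10, 3)),
    reason="FA4 head_dim=256 kernels require SM100 or SM103 (Blackwell)",
)
def test_dpa_fa4_hdim256():
    ...  # exercises DotProductAttention with head_dim=256 through FA4
```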

Reviews (4): Last reviewed commit: "Merge branch 'main' into xiny/headdim256..."

Comment thread tests/pytorch/attention/test_attention.py Outdated
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from 3b3f7d0 to 9a93156 on May 6, 2026 02:44
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from ae74e44 to 8aa5242 on May 6, 2026 02:55
@yaox12
Member Author

yaox12 commented May 6, 2026

/te-ci pytorch L3

@yaox12 yaox12 marked this pull request as ready for review May 6, 2026 02:59
@yaox12
Member Author

yaox12 commented May 6, 2026

@vcherepanov-nv @KshitijLakhani Please review.

@KshitijLakhani KshitijLakhani requested a review from mk-61 May 8, 2026 06:34
Comment thread tests/pytorch/attention/test_attention.py Outdated
Comment thread tests/pytorch/attention/test_attention.py
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
Collaborator

Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?

Member Author

I double-checked: this is a bug in FA4. The kernels produce wrong results on these shapes even though v4_validate_head_dims allows them, so we have to filter them out manually.
Raised an issue with FA4: Dao-AILab/flash-attention#2552
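Purely as an illustration of the alignment condition quoted in this thread (the parameter names mirror FA4's internal kernel variables, not anything in TE-side code):

```python
def dv_reads_misaligned(tile_hdimv: int, dk_reduce_ncol: int) -> bool:
    """Condition from the quoted FA4 comment: when the halved dV tile does not
    divide evenly across the dK reduction columns, dV TMEM loads are misaligned.
    The dedicated (256, 256) kernel uses its own tmem layout, so it is exempt."""
    return (tile_hdimv // 2) % dk_reduce_ncol != 0
```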

@vcherepanov-nv
Collaborator

LGTM

yaox12 added 2 commits May 10, 2026 22:28
@yaox12
Copy link
Copy Markdown
Member Author

yaox12 commented May 11, 2026

/te-ci pytorch L3
