
[PyTorch] Enable head dim 256 for FA4 #2932

Draft

yaox12 wants to merge 1 commit into NVIDIA:main from yaox12:xiny/headdim256_fa

Conversation


@yaox12 yaox12 commented Apr 27, 2026

Description

TODO: We are still waiting for FA4 to release a new beta version (probably 4.0.0b11) before enabling the test in CI.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 marked this pull request as draft April 27, 2026 09:31
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 force-pushed the xiny/headdim256_fa branch from bdcc02e to 3b3f7d0 on April 27, 2026 09:31

greptile-apps Bot commented Apr 27, 2026

Greptile Summary

This PR replaces the hardcoded, architecture-specific head-dimension checks for Flash Attention 4 in get_attention_backend with a delegated call to FA4's own _validate_head_dims, and adds a test config for head_dim=256 (SM100 dedicated kernel, FA4 > 4.0.0b10).

  • The new fa4_base_hdim256 test config is inserted into the shared model_configs_fa4_base dict without an SM100 guard; on non-SM100 hardware FA4 is silently disabled by the try/except and the test runs via a fallback backend, giving a false-green CI signal instead of a skip.
  • The SM100 MLA backward-kernel workaround is now nested inside the v4_validate_head_dims is not None condition, making it contingent on that attribute being populated, whereas it previously applied whenever the head dims were valid.
  • v4_installation_steps still advertises 4.0.0b8, but head_dim=256 requires FA4 > 4.0.0b10 (noted in the PR description as a pending release).
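
For concreteness, the snippet below sketches the delegated validation the summary describes. It is a minimal, self-contained sketch, not the utils.py code: the helper name, the alignment formula, and the exact _validate_head_dims signature (head_dim_qk, head_dim_v, sm, alignment) are assumptions inferred from the flowchart further down.

```python
# Hedged sketch only -- not the actual utils.py implementation.
import torch


def fa4_head_dims_ok(validate_head_dims, head_dim_qk, head_dim_v, qkv_dtype, sm):
    """Return True if FA4's own validator accepts the head dims (assumed signature)."""
    # The flowchart suggests an alignment of 16 // element size of the QKV dtype,
    # i.e. 8 for fp16/bf16 and 16 for fp8.
    alignment = 16 // torch.empty((), dtype=qkv_dtype).element_size()
    try:
        # Delegate to FA4's validator instead of hardcoding per-architecture rules.
        validate_head_dims(head_dim_qk, head_dim_v, sm, alignment)
    except AssertionError:
        return False  # caller logs a debug message and disables FA4
    return True
```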

Confidence Score: 3/5

Merging will not cause a crash, but the head_dim=256 test can silently pass without exercising FA4 on non-SM100 hardware, obscuring CI coverage.

A P1 finding (test silently falls back to a non-FA4 backend on non-SM100 GPUs, providing false-green CI) combined with P2 findings (MLA workaround now conditionally skipped, installation steps stale) pulls the score below the P1 ceiling of 4.

tests/pytorch/attention/test_attention.py — the fa4_base_hdim256 config needs an SM100 compute-capability skip or a separate test function. utils.py — MLA workaround placement and v4_installation_steps string.

Important Files Changed

transformer_engine/pytorch/attention/dot_product_attention/utils.py
    Replaces hardcoded architecture-specific head-dim checks with a delegated call to FA4's own _validate_head_dims; the MLA backward workaround is now nested inside the v4_validate_head_dims is not None guard, and v4_installation_steps still advertises 4.0.0b8.

transformer_engine/pytorch/attention/dot_product_attention/backends.py
    Imports _validate_head_dims from FA4 and assigns it to FlashAttentionUtils.v4_validate_head_dims; straightforward and correct.

tests/pytorch/attention/test_attention.py
    Adds the fa4_base_hdim256 config (head_dim=256, SM100-only) to the shared model_configs_fa4_base dict without an SM100 compute-capability skip, risking silent fallback to a non-FA4 backend on other GPUs.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[use_flash_attention_4 AND v4_is_installed] -->|No| Z[Skip FA4 head-dim validation]
    A -->|Yes| B{v4_validate_head_dims is not None?}
    B -->|No| Z
    B -->|Yes| C[Compute _fa4_alignment\n= 16 // element_size of qkv_dtype]
    C --> D[Call v4_validate_head_dims\nhead_dim_qk, head_dim_v,\nsmXX, alignment]
    D -->|AssertionError| E[Log debug: unsupported dims\nuse_flash_attention_4 = False]
    D -->|OK| F[use_flash_attention_4 stays True]
    E --> Z2[Downstream: FA4 not used]
    F --> G{is_training AND\nhead_dim_qk != head_dim_v\nAND head_dim_qk >= 128\nAND SM100/110?}
    G -->|Yes| H{dK_reduce_ncol\nmisalignment check}
    H -->|misaligned| I[Log debug: SM100 MLA bwd bug\nuse_flash_attention_4 = False]
    H -->|OK| J[FA4 used ✓]
    G -->|No| J

Comments Outside Diff (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 743-793 (link)

    P2 MLA workaround now gated behind v4_validate_head_dims is not None

    The SM100 backward-kernel MLA workaround (lines 775–793) is now nested inside the outer if that requires FlashAttentionUtils.v4_validate_head_dims is not None. If FA4 is installed but _validate_head_dims is absent (e.g., a downgraded or modified FA4 build where the import in backends.py somehow falls back gracefully), the workaround is silently skipped, risking misaligned dV reads in training.

    The practical risk is low because the import in backends.py hard-fails on an ImportError if the symbol is missing, ensuring v4_validate_head_dims is either set or FA4 is not usable at all. Still, moving the MLA workaround block outside the v4_validate_head_dims is not None guard (guarded only by use_flash_attention_4 and v4_is_installed) would make the invariant explicit and match the original intent: the workaround applies to every valid FA4 invocation, independent of which validation path checked the head dims (see the sketch after this list).

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 144-145 (link)

    P2 v4_installation_steps not updated for the required FA4 version

    The class attribute still advertises flash-attn-4==4.0.0b8, but the PR description and test comment indicate head_dim=256 requires FA4 > 4.0.0b10 (likely 4.0.0b11). Users who follow these installation steps will install a version too old to use the new feature, and will likely see the test fa4_base_hdim256 silently fall back without a clear error message.
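
Picking up the suggestion from comment 1 above, the snippet below is a simplified, self-contained sketch of the proposed guard structure. Every name and condition is a stand-in (the MLA misalignment check in particular is reduced to a placeholder), so it illustrates the ordering only, not the real get_attention_backend code.

```python
# Illustrative stand-in for the FA4 gating flow suggested in comment 1.
import logging

logger = logging.getLogger(__name__)


def gate_fa4(use_fa4, v4_is_installed, v4_validate_head_dims,
             head_dim_qk, head_dim_v, is_training):
    if use_fa4 and v4_is_installed:
        if v4_validate_head_dims is not None:
            try:
                # Signature is an assumption; FA4's validator raises
                # AssertionError on unsupported head dims.
                v4_validate_head_dims(head_dim_qk, head_dim_v)
            except AssertionError:
                logger.debug("Disabling FA4: unsupported head dims")
                use_fa4 = False
        # Keep the SM100 MLA backward workaround OUTSIDE the
        # `v4_validate_head_dims is not None` guard so it applies to every
        # valid FA4 invocation, whichever path validated the head dims.
        if use_fa4 and is_training and head_dim_qk != head_dim_v:
            mla_bwd_misaligned = False  # placeholder for the real alignment check
            if mla_bwd_misaligned:
                logger.debug("Disabling FA4: SM100 MLA backward-kernel bug")
                use_fa4 = False
    return use_fa4
```

With this ordering, the workaround block depends only on use_flash_attention_4 and v4_is_installed, so a missing validator attribute can no longer skip it.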

Reviews (1): Last reviewed commit: "enable head dim 256 for FA4"

Comment on lines +375 to +376
# head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
"fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),


P1 SM100-only config silently falls back on other architectures

"fa4_base_hdim256" is added to the shared model_configs_fa4_base dict, which is parametrized into test_dpa_fa4_base without any SM100 guard. The inline comment says this config requires SM100 (dedicated kernel, FA4 > 4.0.0b10). On SM90 or other architectures, _validate_head_dims will raise AssertionError, FA4 will be disabled in get_attention_backend, and the test will silently run via a fallback backend — not exercising FA4 at all. The test neither fails nor is skipped, giving a false green signal.

Consider adding a per-test pytest.mark.skipif based on compute capability (or moving this config into a separate SM100-only dict) so the CI result is unambiguous.
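
A minimal sketch of the per-test guard this comment suggests, assuming torch and CUDA are importable in the test environment; the marker and test-function names are hypothetical and do not correspond to the existing test_dpa_fa4_base parametrization.

```python
# Hypothetical sketch of an SM100-only skip for the head_dim=256 config.
import pytest
import torch

requires_sm100 = pytest.mark.skipif(
    not torch.cuda.is_available() or torch.cuda.get_device_capability() != (10, 0),
    reason="head_dim=256 needs the SM100-only FA4 kernel (flash-attn-4 > 4.0.0b10)",
)


@requires_sm100
def test_dpa_fa4_hdim256():
    # Run only the fa4_base_hdim256 config here, so non-SM100 GPUs report a skip
    # instead of silently falling back to a non-FA4 backend.
    ...
```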

