[PyTorch] Enable head dim 256 for FA4 #2932
Signed-off-by: Xin Yao <xiny@nvidia.com>
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary

This PR replaces the hardcoded, architecture-specific head-dimension checks for Flash Attention 4 in `get_attention_backend` with the library-provided `v4_validate_head_dims` helper, and adds a head_dim=256 test config.
Confidence Score: 3/5

Merging will not cause a crash, but the head_dim=256 test can silently pass without exercising FA4 on non-SM100 hardware, obscuring CI coverage. A P1 finding (the test silently falls back to a non-FA4 backend on non-SM100 GPUs, providing false-green CI) combined with P2 findings (the MLA workaround is now conditionally skipped; installation steps are stale) pulls the score below the P1 ceiling of 4.
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[use_flash_attention_4 AND v4_is_installed] -->|No| Z[Skip FA4 head-dim validation]
    A -->|Yes| B{v4_validate_head_dims is not None?}
    B -->|No| Z
    B -->|Yes| C[Compute _fa4_alignment\n= 16 // element_size of qkv_dtype]
    C --> D[Call v4_validate_head_dims\nhead_dim_qk, head_dim_v,\nsmXX, alignment]
    D -->|AssertionError| E[Log debug: unsupported dims\nuse_flash_attention_4 = False]
    D -->|OK| F[use_flash_attention_4 stays True]
    E --> Z2[Downstream: FA4 not used]
    F --> G{is_training AND\nhead_dim_qk != head_dim_v\nAND head_dim_qk >= 128\nAND SM100/110?}
    G -->|Yes| H{dK_reduce_ncol\nmisalignment check}
    H -->|misaligned| I[Log debug: SM100 MLA bwd bug\nuse_flash_attention_4 = False]
    H -->|OK| J[FA4 used ✓]
    G -->|No| J
```
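In code, the validation step in the flowchart amounts to roughly the following. This is a sketch reconstructed from the node labels, not copied from the PR's diff: the wrapper function, the helper's exact signature, and the logger are assumptions.

```python
import logging
import torch

logger = logging.getLogger(__name__)

def fa4_head_dims_supported(v4_validate_head_dims, head_dim_qk, head_dim_v,
                            sm, qkv_dtype) -> bool:
    """Mirror of the flowchart's validation step (sketch; the helper's
    signature is assumed from the flowchart node labels)."""
    if v4_validate_head_dims is None:
        return True  # nothing to validate against; flowchart skips validation
    # FA4 needs 16-byte alignment, expressed here in elements of the QKV
    # dtype (e.g. 8 elements for 2-byte fp16/bf16).
    alignment = 16 // torch.tensor([], dtype=qkv_dtype).element_size()
    try:
        # flash-attn-4's own check; raises AssertionError on unsupported dims.
        v4_validate_head_dims(head_dim_qk, head_dim_v, sm, alignment)
    except AssertionError as e:
        logger.debug("Disabling FlashAttention 4: unsupported head dims (%s)", e)
        return False
    return True
```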
```python
# head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
"fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
```
SM100-only config silently falls back on other architectures
"fa4_base_hdim256" is added to the shared model_configs_fa4_base dict, which is parametrized into test_dpa_fa4_base without any SM100 guard. The inline comment says this config requires SM100 (dedicated kernel, FA4 > 4.0.0b10). On SM90 or other architectures, _validate_head_dims will raise AssertionError, FA4 will be disabled in get_attention_backend, and the test will silently run via a fallback backend — not exercising FA4 at all. The test neither fails nor is skipped, giving a false green signal.
Consider adding a per-test pytest.mark.skipif based on compute capability (or moving this config into a separate SM100-only dict) so the CI result is unambiguous.
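A minimal sketch of the suggested guard, assuming the test module already imports `torch`; the helper and marker names are illustrative, not from the PR:

```python
import pytest
import torch

def _is_sm100() -> bool:
    """True only on compute capability 10.0 devices (SM100);
    the exact capability check is an assumption."""
    if not torch.cuda.is_available():
        return False
    return torch.cuda.get_device_capability() == (10, 0)

# Applied per-test (or per-config) so non-SM100 runs show as SKIPPED,
# not as a false green.
requires_sm100 = pytest.mark.skipif(
    not _is_sm100(),
    reason="fa4_base_hdim256 needs the SM100-only FA4 head_dim=256 kernel",
)
```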
Description
TODO: We are still waiting for FA4 to release a new beta version (probably 4.0.0b11) before enabling the test in CI.
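Until then, the test can also be gated on the installed version. A minimal sketch, assuming the package is distributed as `flash-attn` and follows PEP 440 pre-release versioning (both assumptions):

```python
import pytest
from importlib.metadata import PackageNotFoundError, version
from packaging.version import Version

def _fa4_at_least(minimum: str) -> bool:
    """True if the installed flash-attn is new enough; the distribution
    name "flash-attn" is an assumption."""
    try:
        return Version(version("flash-attn")) >= Version(minimum)
    except PackageNotFoundError:
        return False

# "4.0.0b11" matches the beta the PR description expects.
requires_fa4_b11 = pytest.mark.skipif(
    not _fa4_at_least("4.0.0b11"),
    reason="head_dim=256 requires flash-attn-4 > 4.0.0b10",
)
```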
Changes

- Replace the hardcoded, architecture-specific FA4 head-dimension checks with the library's own `v4_validate_head_dims` helper (see flowchart above).
- Add a head_dim=256 test config (`fa4_base_hdim256`), currently SM100-only via a dedicated kernel and requiring flash-attn-4 > 4.0.0b10.