2 changes: 2 additions & 0 deletions tests/pytorch/attention/test_attention.py
@@ -372,6 +372,8 @@ def test_dpa_num_splits(dtype, model_configs, model):
"fa4_base_1": ModelConfig(4, 128, 16, 64),
"fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
"fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
# head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
"fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
Comment on lines +375 to +376 (Contributor)
P1 SM100-only config silently falls back on other architectures

"fa4_base_hdim256" is added to the shared model_configs_fa4_base dict, which is parametrized into test_dpa_fa4_base without any SM100 guard. The inline comment says this config requires SM100 (dedicated kernel, FA4 > 4.0.0b10). On SM90 or other architectures, _validate_head_dims will raise AssertionError, FA4 will be disabled in get_attention_backend, and the test will silently run via a fallback backend — not exercising FA4 at all. The test neither fails nor is skipped, giving a false green signal.

Consider adding a per-test `pytest.mark.skipif` based on compute capability (or moving this config into a separate SM100-only dict) so the CI result is unambiguous.
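A minimal sketch of the runtime-skip variant; the helper name and the test signature below are assumptions for illustration, not code from this PR (the real parametrization of `test_dpa_fa4_base` may differ):

```python
import pytest
import torch


def _sm100_available() -> bool:
    """True when the current GPU reports compute capability 10.x (SM100)."""
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10


@pytest.mark.parametrize("model", model_configs_fa4_base.keys())
def test_dpa_fa4_base(dtype, model_configs, model):
    # Skip rather than fall back: a fallback run would report green without
    # touching the FA4 head_dim=256 path this config exists to cover.
    if model == "fa4_base_hdim256" and not _sm100_available():
        pytest.skip("fa4_base_hdim256 needs the SM100-only head_dim=256 FA4 kernel")
    ...  # existing test body unchanged
```

A collection-time `pytest.mark.skipif` attached to the parameter via `pytest.param(..., marks=...)` would surface the skip earlier, but it evaluates the device capability at import time, so the runtime check inside the test is usually the safer default in CI.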

     # GQA
     "fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
     "fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -167,8 +167,10 @@
     from flash_attn.cute.interface import ( # pylint: disable=ungrouped-imports,no-name-in-module
         flash_attn_func as flash_attn_func_v4,
         flash_attn_varlen_func as flash_attn_varlen_func_v4,
+        _validate_head_dims as _fa4_validate_head_dims,
     )
 
+    fa_utils.v4_validate_head_dims = _fa4_validate_head_dims
     fa_utils.set_flash_attention_4_params()
 
 # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
51 changes: 29 additions & 22 deletions transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -149,6 +149,9 @@ class FlashAttentionUtils:
     v4_installation_steps = """\
 pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
     v4_warning_printed = False
+    # Set by backends.py if FA4 is installed; calls flash_attn.cute.interface._validate_head_dims
+    # which raises AssertionError for unsupported (head_dim, head_dim_v) combinations.
+    v4_validate_head_dims = None
 
     @staticmethod
     def set_flash_attention_version():
@@ -792,21 +795,24 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_flash_attention_3 = False
 
-    if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
-        # FA4 head dimension support is architecture-dependent
-        # (matches _validate_head_dims in flash_attn.cute.interface):
-        # SM90: head_dim <= 256 and head_dim_v <= 256
-        # SM100/110: head_dim <= 128 and head_dim_v <= 128,
-        # OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
-        # SM80/120: constrained by shared memory (~256 max in practice)
-        _fa4_hdim_ok = True
-        if (10, 0) <= device_compute_capability < (12, 0):
-            _is_standard = head_dim_qk <= 128 and head_dim_v <= 128
-            _is_deepseek = head_dim_qk == 192 and head_dim_v == 128
-            _fa4_hdim_ok = _is_standard or _is_deepseek
-        else:
-            _fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
-        if not _fa4_hdim_ok:
+    if (
+        use_flash_attention_4
+        and FlashAttentionUtils.v4_is_installed
+        and FlashAttentionUtils.v4_validate_head_dims is not None
+    ):
+        # Defer to FA4's own _validate_head_dims to keep TE in sync with FA4 supported shapes
+        # (e.g., (256, 256) on SM100, (192, 128) DeepSeek, (64, 512) MLA-absorbed).
+        # The function asserts on unsupported combinations; SM80/SM120 have no validation branch
+        # in FA4 so the call passes through silently for those archs.
+        _fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
+        try:
+            FlashAttentionUtils.v4_validate_head_dims(
+                head_dim_qk,
+                head_dim_v,
+                device_compute_capability[0],
+                _fa4_alignment,
+            )
+        except AssertionError:
             logger.debug(
                 "Disabling FlashAttention 4 due to unsupported head dimensions. "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
@@ -815,13 +821,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 device_compute_capability[0] * 10 + device_compute_capability[1],
             )
             use_flash_attention_4 = False
-        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
-        # FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
-        # based on Q/K head_dim but reuses it for dV TMEM load atoms. When
-        # (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
-        # See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
-        elif (
-            _fa4_hdim_ok
+        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128) for the
+        # standard (non-dedicated) kernel path. FlashAttentionBackwardSm100 computes
+        # dK_reduce_ncol = gcd(32, tile_hdim // 2) based on Q/K head_dim but reuses it for
+        # dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
+        # misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
+        # not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
+        if (
+            use_flash_attention_4
             and is_training
             and head_dim_qk != head_dim_v
             and head_dim_qk >= 128