diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py
index c9ea791444..6d4f3b45a9 100644
--- a/tests/pytorch/attention/test_attention.py
+++ b/tests/pytorch/attention/test_attention.py
@@ -372,6 +372,8 @@ def test_dpa_num_splits(dtype, model_configs, model):
     "fa4_base_1": ModelConfig(4, 128, 16, 64),
     "fa4_base_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"),
     "fa4_base_3": ModelConfig(2, 1024, 8, 96, attn_mask_type="causal"),
+    # head_dim=256 (SM100 only via dedicated kernel; flash-attn-4 > 4.0.0b10)
+    "fa4_base_hdim256": ModelConfig(2, 1024, 8, 256, attn_mask_type="causal"),
     # GQA
     "fa4_gqa_1": ModelConfig(2, 1024, 32, 128, num_gqa_groups=8, attn_mask_type="causal"),
     "fa4_gqa_2": ModelConfig(2, 1024, 16, 128, num_gqa_groups=1, attn_mask_type="causal"),
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 4104820a1c..61e7651207 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -167,8 +167,10 @@
     from flash_attn.cute.interface import (  # pylint: disable=ungrouped-imports,no-name-in-module
         flash_attn_func as flash_attn_func_v4,
         flash_attn_varlen_func as flash_attn_varlen_func_v4,
+        _validate_head_dims as _fa4_validate_head_dims,
     )
 
+    fa_utils.v4_validate_head_dims = _fa4_validate_head_dims
     fa_utils.set_flash_attention_4_params()
 
 # Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index ed87423534..996c6fac37 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -149,6 +149,9 @@ class FlashAttentionUtils:
     v4_installation_steps = """\
 pip install flash-attn-4==4.0.0b8 nvidia-cutlass-dsl[cu13]"""
     v4_warning_printed = False
+    # Set by backends.py if FA4 is installed; calls flash_attn.cute.interface._validate_head_dims
+    # which raises AssertionError for unsupported (head_dim, head_dim_v) combinations.
+    v4_validate_head_dims = None
 
     @staticmethod
     def set_flash_attention_version():
@@ -792,21 +795,24 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
         )
         use_flash_attention_3 = False
 
-    if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed:
-        # FA4 head dimension support is architecture-dependent
-        # (matches _validate_head_dims in flash_attn.cute.interface):
-        # SM90: head_dim <= 256 and head_dim_v <= 256
-        # SM100/110: head_dim <= 128 and head_dim_v <= 128,
-        #     OR DeepSeek MLA shape (head_dim=192, head_dim_v=128)
-        # SM80/120: constrained by shared memory (~256 max in practice)
-        _fa4_hdim_ok = True
-        if (10, 0) <= device_compute_capability < (12, 0):
-            _is_standard = head_dim_qk <= 128 and head_dim_v <= 128
-            _is_deepseek = head_dim_qk == 192 and head_dim_v == 128
-            _fa4_hdim_ok = _is_standard or _is_deepseek
-        else:
-            _fa4_hdim_ok = head_dim_qk <= 256 and head_dim_v <= 256
-        if not _fa4_hdim_ok:
+    if (
+        use_flash_attention_4
+        and FlashAttentionUtils.v4_is_installed
+        and FlashAttentionUtils.v4_validate_head_dims is not None
+    ):
+        # Defer to FA4's own _validate_head_dims to keep TE in sync with FA4 supported shapes
+        # (e.g., (256, 256) on SM100, (192, 128) DeepSeek, (64, 512) MLA-absorbed).
+        # The function asserts on unsupported combinations; SM80/SM120 have no validation branch
+        # in FA4 so the call passes through silently for those archs.
+        _fa4_alignment = 16 // torch.empty(0, dtype=qkv_dtype).element_size()
+        try:
+            FlashAttentionUtils.v4_validate_head_dims(
+                head_dim_qk,
+                head_dim_v,
+                device_compute_capability[0],
+                _fa4_alignment,
+            )
+        except AssertionError:
             logger.debug(
                 "Disabling FlashAttention 4 due to unsupported head dimensions. "
                 "Found: head_dim_qk = %s, head_dim_v = %s, on sm%s.",
@@ -815,13 +821,14 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 device_compute_capability[0] * 10 + device_compute_capability[1],
             )
             use_flash_attention_4 = False
-        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128).
-        # FlashAttentionBackwardSm100 computes dK_reduce_ncol = gcd(32, tile_hdim // 2)
-        # based on Q/K head_dim but reuses it for dV TMEM load atoms. When
-        # (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are misaligned.
-        # See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
-        elif (
-            _fa4_hdim_ok
+        # Workaround: SM100 backward kernel bug when MLA + 2CTA (head_dim_qk >= 128) for the
+        # standard (non-dedicated) kernel path. FlashAttentionBackwardSm100 computes
+        # dK_reduce_ncol = gcd(32, tile_hdim // 2) based on Q/K head_dim but reuses it for
+        # dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
+        # misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
+        # not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
+        if (
+            use_flash_attention_4
             and is_training
             and head_dim_qk != head_dim_v
            and head_dim_qk >= 128
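
Note: for reviewers, a minimal standalone sketch of the gating pattern this patch adopts, i.e. defer shape validation to a validator callable that asserts on unsupported combinations and treat an AssertionError as "disable the backend" instead of re-deriving the rules in TE. The fake_validate_head_dims helper and its allow-list below are hypothetical stand-ins, not flash_attn.cute.interface._validate_head_dims; only the wiring and the try/except structure mirror the change to utils.py.

def fake_validate_head_dims(head_dim, head_dim_v, arch_major, alignment):
    # Stand-in rules, NOT FA4's real ones: require alignment plus a small
    # allow-list of (head_dim, head_dim_v) pairs on a hypothetical SM100.
    assert head_dim % alignment == 0 and head_dim_v % alignment == 0
    if arch_major == 10:
        assert (head_dim, head_dim_v) in {(64, 64), (128, 128), (192, 128), (256, 256)}


def fa4_supports(head_dim_qk, head_dim_v, arch_major, elem_size_bytes, validator=None):
    """Mirror the patched TE gating: if a validator is wired up, let it decide."""
    if validator is None:
        # backends.py never stored a validator -> treat FA4 as unavailable
        return False
    alignment = 16 // elem_size_bytes  # e.g. 8 for 2-byte fp16/bf16 elements
    try:
        validator(head_dim_qk, head_dim_v, arch_major, alignment)
    except AssertionError:
        # Unsupported shape: same outcome as the logger.debug + disable path above
        return False
    return True


if __name__ == "__main__":
    print(fa4_supports(256, 256, 10, 2, fake_validate_head_dims))  # True
    print(fa4_supports(192, 160, 10, 2, fake_validate_head_dims))  # False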