diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 7b9b711c22..a91f0e04dc 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 7b9b711c22b6823e87150213ecd8449260db8610 +Subproject commit a91f0e04dcea10515f0f776fc5a89535e316a9c8 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 377c9ddb00..8c95f70e28 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -50,6 +50,7 @@ NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_headdim256_cudnn_fe.xml $TE_PATH/tests/pytorch/attention/test_headdim256_cudnn_fe.py || test_fail "test_headdim256_cudnn_fe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then diff --git a/tests/pytorch/attention/test_headdim256_cudnn_fe.py b/tests/pytorch/attention/test_headdim256_cudnn_fe.py new file mode 100644 index 0000000000..43c8f9e4b0 --- /dev/null +++ b/tests/pytorch/attention/test_headdim256_cudnn_fe.py @@ -0,0 +1,154 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Smoke test for head_dim=256 via the cuDNN frontend Python SDPA (CuTe DSL) on SM100+.""" + +from __future__ import annotations + +import pytest +import torch + +import transformer_engine.pytorch as te +from transformer_engine.pytorch.attention.dot_product_attention import cudnn_fe_sdpa + + +def _sm100_or_newer() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +pytestmark = pytest.mark.skipif( + not (_sm100_or_newer() and cudnn_fe_sdpa.is_available()), + reason="Requires SM100+ GPU and cudnn frontend Python SDPA d=256 kernels", +) + + +def _reference(q, k, v, mask=None, scale=None): + """Plain-attention reference in FP32.""" + d = q.shape[-1] + scale = scale if scale is not None else 1.0 / (d**0.5) + q32 = q.float() + k32 = k.float() + v32 = v.float() + # q: (B, H, S, D), k: (B, H, S, D), v: (B, H, S, D) + s = torch.einsum("bhqd,bhkd->bhqk", q32, k32) * scale + if mask is not None: + s = s.masked_fill(mask, float("-inf")) + p = torch.softmax(s, dim=-1) + out = torch.einsum("bhqk,bhkd->bhqd", p, v32) + return out.to(q.dtype) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("attn_mask_type", ["no_mask", "causal"]) +@pytest.mark.parametrize("seqlen", [512, 2048]) +def test_cudnn_fe_fwd_bwd_bshd(dtype, attn_mask_type, seqlen): + batch, heads, head_dim = 2, 4, 256 + torch.manual_seed(0) + device = "cuda" + + # BSHD layout (torch-contiguous) + q = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + v = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + d_o = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device) + + window_size = (-1, 0) if attn_mask_type == "causal" else (-1, -1) + + # cuDNN-FE direct call + out, aux_ctx = cudnn_fe_sdpa.fused_attn_fwd( + max_seqlen_q=seqlen, + max_seqlen_kv=seqlen, + cu_seqlens_q=None, + cu_seqlens_kv=None, + q=q, + k=k, + v=v, + qkv_format="bshd", + attn_mask_type=attn_mask_type, + attn_scale=None, + window_size=window_size, + ) + assert out.shape == q.shape, f"Unexpected fwd out shape {out.shape}" + + dq, dk, dv = cudnn_fe_sdpa.fused_attn_bwd( + max_seqlen_q=seqlen, + max_seqlen_kv=seqlen, + cu_seqlens_q=None, + cu_seqlens_kv=None, + q=q, + k=k, + v=v, + o=out, + d_o=d_o, + aux_ctx_tensors=aux_ctx, + qkv_format="bshd", + attn_mask_type=attn_mask_type, + attn_scale=None, + window_size=window_size, + ) + assert dq.shape == q.shape + assert dk.shape == k.shape + assert dv.shape == v.shape + + # FP32 reference over (B, H, S, D) + q_ref = q.detach().float().transpose(1, 2).contiguous().requires_grad_(True) + k_ref = k.detach().float().transpose(1, 2).contiguous().requires_grad_(True) + v_ref = v.detach().float().transpose(1, 2).contiguous().requires_grad_(True) + + mask = None + if attn_mask_type == "causal": + mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=device), diagonal=1) + out_ref = _reference(q_ref, k_ref, v_ref, mask=mask) + # transpose back to BSHD for comparison + out_ref_bshd = out_ref.transpose(1, 2).contiguous().to(dtype) + out_ref.backward(d_o.transpose(1, 2).contiguous().float()) + + tol = {"atol": 5e-2, "rtol": 5e-2} + torch.testing.assert_close(out.float(), out_ref_bshd.float(), **tol) + torch.testing.assert_close( + dq.float(), q_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol 
+ ) + torch.testing.assert_close( + dk.float(), k_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol + ) + torch.testing.assert_close( + dv.float(), v_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol + ) + + +def test_cudnn_fe_fused_attention_module(monkeypatch): + """Integration test: exercise through the DotProductAttention module.""" + # Scope env var mutations to this test — pytest shares a process across + # tests, so a bare ``os.environ[...] = ...`` would leak these flags into + # every later test in the session. + monkeypatch.setenv("NVTE_FUSED_ATTN", "1") + monkeypatch.setenv("NVTE_FLASH_ATTN", "0") + monkeypatch.setenv("NVTE_UNFUSED_ATTN", "0") + + dtype = torch.bfloat16 + batch, seqlen, heads, head_dim = 2, 1024, 4, 256 + device = "cuda" + + torch.manual_seed(42) + q = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + k = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + v = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True) + + dpa = te.DotProductAttention( + num_attention_heads=heads, + kv_channels=head_dim, + qkv_format="bshd", + attention_type="self", + ).cuda() + out = dpa(q, k, v, attention_mask=None) + assert out.shape == (batch, seqlen, heads * head_dim) + + loss = out.sum() + loss.backward() + assert q.grad is not None and q.grad.shape == q.shape + assert k.grad is not None and k.grad.shape == k.shape + assert v.grad is not None and v.grad.shape == v.shape diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 3d6e3a0aac..eb1205c734 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -339,7 +339,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && - cudnn_runtime_version >= 91100)) && + cudnn_runtime_version >= 91100) || + // cuDNN-FE Python SDPA (CuTe DSL): d=256 + Blackwell + training + non-paged + // C++ kernel does not handle this; intercepted in Python for fwd+bwd. + (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 && + layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) && // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && diff --git a/transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py b/transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py new file mode 100644 index 0000000000..bd2c55845b --- /dev/null +++ b/transformer_engine/pytorch/attention/dot_product_attention/cudnn_fe_sdpa.py @@ -0,0 +1,277 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Dispatch layer for the cuDNN frontend Python SDPA (CuTe DSL), head_dim=256. + +The kernels in ``cudnn.sdpa`` are Python-only (CuTe DSL), so invocation happens +entirely on the Python side. 
This module adapts TransformerEngine's +``fused_attn_fwd``/``fused_attn_bwd`` calling convention to the wrappers +``sdpa_fwd_wrapper_sm100_d256``/``sdpa_bwd_wrapper_sm100_d256`` and hides the +layout massaging required by the kernel. +""" + +from __future__ import annotations + +import functools +from typing import List, Optional, Tuple + +import torch + + +@functools.lru_cache(maxsize=None) +def _sdpa_fwd_wrapper(): + from cudnn import sdpa_fwd_wrapper_sm100_d256 # pylint: disable=no-name-in-module + + return sdpa_fwd_wrapper_sm100_d256 + + +@functools.lru_cache(maxsize=None) +def _sdpa_bwd_wrapper(): + from cudnn import sdpa_bwd_wrapper_sm100_d256 # pylint: disable=no-name-in-module + + return sdpa_bwd_wrapper_sm100_d256 + + +@functools.lru_cache(maxsize=None) +def is_available() -> bool: + """Whether the cuDNN-FE Python SDPA d=256 kernels can be imported.""" + try: + _sdpa_fwd_wrapper() + _sdpa_bwd_wrapper() + except ImportError: + return False + return True + + +_SUPPORTED_MASKS = ( + "no_mask", + "causal", + "causal_bottom_right", + "padding", + "padding_causal", + "padding_causal_bottom_right", +) + + +def is_supported( + *, + head_dim_qk: int, + head_dim_v: int, + qkv_dtype: torch.dtype, + qkv_format: str, + attn_mask_type: str, + attn_bias_type: str, + softmax_type: str, + dropout: float, + window_size: Tuple[int, int], + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + deterministic: bool, + device_compute_capability: Tuple[int, int], + return_max_logit: bool = False, +) -> bool: + """Whether the cuDNN-FE SDPA d=256 kernel can service this configuration.""" + if device_compute_capability[0] < 10: + return False + if head_dim_qk != 256 or head_dim_v != 256: + return False + if qkv_dtype not in (torch.float16, torch.bfloat16): + return False + if qkv_format not in ("bshd", "thd"): + return False + if attn_bias_type != "no_bias": + return False + if softmax_type != "vanilla": + return False + if dropout != 0.0: + return False + if return_max_logit: + return False + if attn_mask_type not in _SUPPORTED_MASKS: + return False + if qkv_format == "thd" and "padding" not in attn_mask_type: + return False + if qkv_format == "bshd" and "padding" in attn_mask_type: + return False + # The kernel's causal implementation aligns the end of Q with the end of K + # (i.e., bottom-right). TE's plain "causal" means top-left, which only + # matches when max_seqlen_q == max_seqlen_kv. For cross-attention the user + # must opt in explicitly via a "_bottom_right" mask. + if attn_mask_type in ("causal", "padding_causal") and max_seqlen_q != max_seqlen_kv: + return False + is_causal = "causal" in attn_mask_type + left, right = window_size + if not is_causal: + if (left, right) != (-1, -1): + return False + else: + if right not in (-1, 0): + return False + # Backward uses atomic adds on dQ → non-deterministic. + if is_training and deterministic: + return False + if not is_available(): + return False + return True + + +def _to_kernel_shape(x: torch.Tensor, qkv_format: str) -> torch.Tensor: + """Arrange a TE tensor in the (B, H, S, D) view expected by the kernel. + + The kernel accepts either: + * ``(B, H, S, D)`` where the underlying memory is BSHD-contiguous (i.e. + a ``.transpose(1, 2)`` view of a BSHD tensor), or + * ``(T, H, D)`` for variable-length THD. 
+ """ + if qkv_format == "bshd": + return x.transpose(1, 2) + return x + + +def _from_kernel_shape(x: torch.Tensor, qkv_format: str) -> torch.Tensor: + """Inverse of :func:`_to_kernel_shape`.""" + if qkv_format == "bshd": + return x.transpose(1, 2) + return x + + +def _causal_and_window( + attn_mask_type: str, window_size: Tuple[int, int] +) -> Tuple[bool, Tuple[int, int]]: + is_causal = "causal" in attn_mask_type + left, _ = window_size + if is_causal: + return True, (left if left is not None else -1, 0) + return False, (-1, -1) + + +def _cum_seqlens_for_kernel( + cu_seqlens: Optional[torch.Tensor], +) -> Optional[torch.Tensor]: + if cu_seqlens is None: + return None + if cu_seqlens.dtype == torch.int32: + return cu_seqlens + return cu_seqlens.to(dtype=torch.int32) + + +def fused_attn_fwd( + *, + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_kv: Optional[torch.Tensor], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + qkv_format: str, + attn_mask_type: str, + attn_scale: Optional[float], + window_size: Tuple[int, int], +) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Run the cuDNN-FE Python SDPA forward for head_dim=256 on SM100+. + + Returns + ------- + out : torch.Tensor + Attention output, same layout as ``q``. + aux_ctx_tensors : list of torch.Tensor + ``[softmax_lse, rng_state_placeholder]`` — kept two-element for + compatibility with TE's aux-context convention. The rng placeholder is + an empty tensor because this kernel does not support dropout. + """ + sdpa_fwd = _sdpa_fwd_wrapper() + is_causal, window = _causal_and_window(attn_mask_type, window_size) + + q_i = _to_kernel_shape(q, qkv_format) + k_i = _to_kernel_shape(k, qkv_format) + v_i = _to_kernel_shape(v, qkv_format) + + cum_q = _cum_seqlens_for_kernel(cu_seqlens_q) if qkv_format == "thd" else None + cum_k = _cum_seqlens_for_kernel(cu_seqlens_kv) if qkv_format == "thd" else None + + current_stream = torch.cuda.current_stream().cuda_stream + result = sdpa_fwd( + q_tensor=q_i, + k_tensor=k_i, + v_tensor=v_i, + cum_seqlen_q_tensor=cum_q, + cum_seqlen_k_tensor=cum_k, + max_s_q=max_seqlen_q, + max_s_k=max_seqlen_kv, + is_causal=is_causal, + window_size=window, + scale_softmax=attn_scale, + current_stream=current_stream, + ) + + o_i = result["o_tensor"] + lse = result["lse_tensor"] + out = _from_kernel_shape(o_i, qkv_format) + + # Rng state placeholder (no dropout support on this path); kept so the + # aux-context shape matches other F16 fused backends' (lse, rng) pair. 
+ rng_state = torch.empty(2, dtype=torch.int64, device=q.device) + return out, [lse, rng_state] + + +def fused_attn_bwd( + *, + max_seqlen_q: int, + max_seqlen_kv: int, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_kv: Optional[torch.Tensor], + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + o: torch.Tensor, + d_o: torch.Tensor, + aux_ctx_tensors: List[torch.Tensor], + qkv_format: str, + attn_mask_type: str, + attn_scale: Optional[float], + window_size: Tuple[int, int], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the cuDNN-FE Python SDPA backward for head_dim=256 on SM100+.""" + sdpa_bwd = _sdpa_bwd_wrapper() + is_causal, window = _causal_and_window(attn_mask_type, window_size) + + q_i = _to_kernel_shape(q, qkv_format) + k_i = _to_kernel_shape(k, qkv_format) + v_i = _to_kernel_shape(v, qkv_format) + o_i = _to_kernel_shape(o, qkv_format) + do_i = _to_kernel_shape(d_o, qkv_format) + + cum_q = _cum_seqlens_for_kernel(cu_seqlens_q) if qkv_format == "thd" else None + cum_k = _cum_seqlens_for_kernel(cu_seqlens_kv) if qkv_format == "thd" else None + + lse = aux_ctx_tensors[0] + + current_stream = torch.cuda.current_stream().cuda_stream + result = sdpa_bwd( + q_tensor=q_i, + k_tensor=k_i, + v_tensor=v_i, + o_tensor=o_i, + do_tensor=do_i, + lse_tensor=lse, + cum_seqlen_q_tensor=cum_q, + cum_seqlen_k_tensor=cum_k, + max_s_q=max_seqlen_q, + max_s_k=max_seqlen_kv, + is_causal=is_causal, + window_size=window, + scale_softmax=attn_scale, + current_stream=current_stream, + ) + + dq_i = result["dq_tensor"] + dk_i = result["dk_tensor"] + dv_i = result["dv_tensor"] + + dq = _from_kernel_shape(dq_i, qkv_format) + dk = _from_kernel_shape(dk_i, qkv_format) + dv = _from_kernel_shape(dv_i, qkv_format) + return dq, dk, dv diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 20228ddb80..4c0349acd8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1173,6 +1173,46 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt use_fused_attention = False fused_attention_backend = None + # head_dim=256 on SM100+ is serviced by the cuDNN frontend Python SDPA + # (CuTe DSL) rather than the C++ kernel. Promote the backend to the + # Python-only sentinel so ``fused_attn_fwd`` / ``fused_attn_bwd`` route + # through ``cudnn_fe_sdpa`` without re-checking. Disable FusedAttention + # if the Python kernel can't service this config. 
+ if ( + use_fused_attention + and fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + and head_dim_qk == 256 + and head_dim_v == 256 + and device_compute_capability[0] >= 10 + ): + from .cudnn_fe_sdpa import is_supported as _cudnn_fe_supported + + if _cudnn_fe_supported( + head_dim_qk=head_dim_qk, + head_dim_v=head_dim_v, + qkv_dtype=qkv_dtype, + qkv_format=qkv_format, + attn_mask_type=attn_mask_type, + attn_bias_type=fu_core_attention_bias_type, + softmax_type=softmax_type, + dropout=attention_dropout, + window_size=window_size, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + is_training=is_training, + deterministic=deterministic, + device_compute_capability=device_compute_capability, + return_max_logit=return_max_logit, + ): + fused_attention_backend = FusedAttnBackend["F16_cudnn_fe_sdpa"] + else: + logger.debug( + "Disabling FusedAttention: cuDNN frontend Python SDPA (d=256) does not" + " support this config or is not importable" + ) + use_fused_attention = False + fused_attention_backend = None + # Filter: Determinism # backend | deterministic # --------------------------------------------- diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 06bfb6ef3c..b2e2a6cb37 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -18,7 +18,6 @@ from ..quantized_tensor import Quantizer from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx - __all__ = [ "fused_attn_fwd", "fused_attn_bwd", @@ -98,12 +97,17 @@ "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, + # Python-only sentinel (no C++ counterpart). Set by ``get_attention_backend`` + # when the cuDNN frontend CuTe-DSL d=256 SDPA path is applicable; checked + # in ``fused_attn_fwd`` / ``fused_attn_bwd`` to route to the Python kernel. + "F16_cudnn_fe_sdpa": "F16_cudnn_fe_sdpa", "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 + META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1 META_O = FP8FwdTensorIdx.GEMM2_INPUT @@ -281,6 +285,31 @@ def fused_attn_fwd( f"attn_bias.dtype={attn_bias.dtype} but q.dtype={q.dtype}." ) + # Route head_dim=256 on SM100+ through the cuDNN frontend Python SDPA (CuTe DSL). + # Eligibility was already decided in ``get_attention_backend``; a matching + # backend value here means "use the Python kernel". 
+    if fused_attention_backend == FusedAttnBackend["F16_cudnn_fe_sdpa"]:
+        from ..attention.dot_product_attention import cudnn_fe_sdpa
+
+        qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+        out, aux_ctx_tensors = cudnn_fe_sdpa.fused_attn_fwd(
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            q=q,
+            k=k,
+            v=v,
+            qkv_format=qkv_format,
+            attn_mask_type=attn_mask_type,
+            attn_scale=attn_scale,
+            window_size=window_size,
+        )
+        # ``is_supported`` rejects ``return_max_logit=True``, so this branch only pads the return tuple for arity consistency
+        if return_max_logit:
+            return out, aux_ctx_tensors, None
+        return out, aux_ctx_tensors
+
     if fused_attention_backend == FusedAttnBackend["No_Backend"]:
         raise ValueError(
             "Fused attention does not support this input combination:"
@@ -537,6 +566,30 @@ def fused_attn_bwd(
         d = q.size(-1)
         attn_scale = 1.0 / math.sqrt(d)
 
+    # Route head_dim=256 on SM100+ through the cuDNN frontend Python SDPA backward.
+    # ``ctx.fused_attention_backend`` carries the sentinel forward-to-backward.
+    if fused_attention_backend == FusedAttnBackend["F16_cudnn_fe_sdpa"]:
+        from ..attention.dot_product_attention import cudnn_fe_sdpa
+
+        qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
+        dq, dk, dv = cudnn_fe_sdpa.fused_attn_bwd(
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_kv=max_seqlen_kv,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_kv=cu_seqlens_kv,
+            q=q,
+            k=k,
+            v=v,
+            o=o,
+            d_o=d_o,
+            aux_ctx_tensors=aux_ctx_tensors,
+            qkv_format=qkv_format,
+            attn_mask_type=attn_mask_type,
+            attn_scale=attn_scale,
+            window_size=window_size,
+        )
+        return dq, dk, dv
+
     if fused_attention_backend == FusedAttnBackend["No_Backend"]:
         raise ValueError(
             "Fused attention backward does not support this input combination:"