2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 213 files
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
@@ -50,6 +50,7 @@ NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_headdim256_cudnn_fe.xml $TE_PATH/tests/pytorch/attention/test_headdim256_cudnn_fe.py || test_fail "test_headdim256_cudnn_fe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
154 changes: 154 additions & 0 deletions tests/pytorch/attention/test_headdim256_cudnn_fe.py
@@ -0,0 +1,154 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Smoke test for head_dim=256 via the cuDNN frontend Python SDPA (CuTe DSL) on SM100+."""

from __future__ import annotations

import pytest
import torch

import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention.dot_product_attention import cudnn_fe_sdpa


def _sm100_or_newer() -> bool:
    """True on GPUs with compute capability 10.0 or newer (Blackwell, SM100+)."""
    if not torch.cuda.is_available():
        return False
    major, _ = torch.cuda.get_device_capability()
    return major >= 10


pytestmark = pytest.mark.skipif(
    not (_sm100_or_newer() and cudnn_fe_sdpa.is_available()),
    reason="Requires SM100+ GPU and cudnn frontend Python SDPA d=256 kernels",
)
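
# This file is invoked from qa/L0_pytorch_unittest/test.sh; an assumed
# standalone invocation mirroring that script:
#   python3 -m pytest tests/pytorch/attention/test_headdim256_cudnn_fe.py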


def _reference(q, k, v, mask=None, scale=None):
    """Plain-attention reference in FP32; `mask` is boolean with True = masked out."""
    d = q.shape[-1]
    scale = scale if scale is not None else 1.0 / (d**0.5)
    q32 = q.float()
    k32 = k.float()
    v32 = v.float()
    # q, k, v: (B, H, S, D)
    s = torch.einsum("bhqd,bhkd->bhqk", q32, k32) * scale
    if mask is not None:
        s = s.masked_fill(mask, float("-inf"))
    p = torch.softmax(s, dim=-1)
    out = torch.einsum("bhqk,bhkd->bhqd", p, v32)
    return out.to(q.dtype)
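

# A minimal cross-check sketch (not used by the tests below): PyTorch's
# built-in SDPA computes the same math as `_reference`. This assumes
# torch.nn.functional.scaled_dot_product_attention with the `scale` keyword
# (PyTorch >= 2.1); inputs are (B, H, S, D), and the default scale is
# 1/sqrt(d) in both implementations.
def _reference_torch_sdpa(q, k, v, is_causal=False, scale=None):
    out = torch.nn.functional.scaled_dot_product_attention(
        q.float(), k.float(), v.float(), is_causal=is_causal, scale=scale
    )
    return out.to(q.dtype)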


@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("attn_mask_type", ["no_mask", "causal"])
@pytest.mark.parametrize("seqlen", [512, 2048])
def test_cudnn_fe_fwd_bwd_bshd(dtype, attn_mask_type, seqlen):
    batch, heads, head_dim = 2, 4, 256
    torch.manual_seed(0)
    device = "cuda"

    # BSHD layout (torch-contiguous)
    q = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    d_o = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device)

    # TE window convention: (-1, 0) = causal (unbounded left, no look-ahead),
    # (-1, -1) = no windowing.
    window_size = (-1, 0) if attn_mask_type == "causal" else (-1, -1)

    # cuDNN-FE direct call; cu_seqlens=None: fixed-length sequences (no padding) in this test.
    out, aux_ctx = cudnn_fe_sdpa.fused_attn_fwd(
        max_seqlen_q=seqlen,
        max_seqlen_kv=seqlen,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        q=q,
        k=k,
        v=v,
        qkv_format="bshd",
        attn_mask_type=attn_mask_type,
        attn_scale=None,
        window_size=window_size,
    )
    assert out.shape == q.shape, f"Unexpected fwd out shape {out.shape}"

    dq, dk, dv = cudnn_fe_sdpa.fused_attn_bwd(
        max_seqlen_q=seqlen,
        max_seqlen_kv=seqlen,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        q=q,
        k=k,
        v=v,
        o=out,
        d_o=d_o,
        aux_ctx_tensors=aux_ctx,
        qkv_format="bshd",
        attn_mask_type=attn_mask_type,
        attn_scale=None,
        window_size=window_size,
    )
    assert dq.shape == q.shape
    assert dk.shape == k.shape
    assert dv.shape == v.shape

    # FP32 reference over (B, H, S, D)
    q_ref = q.detach().float().transpose(1, 2).contiguous().requires_grad_(True)
    k_ref = k.detach().float().transpose(1, 2).contiguous().requires_grad_(True)
    v_ref = v.detach().float().transpose(1, 2).contiguous().requires_grad_(True)

    mask = None
    if attn_mask_type == "causal":
        # True above the diagonal = future positions are masked out.
        mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=device), diagonal=1)
    out_ref = _reference(q_ref, k_ref, v_ref, mask=mask)
    # transpose back to BSHD for comparison
    out_ref_bshd = out_ref.transpose(1, 2).contiguous().to(dtype)
    out_ref.backward(d_o.transpose(1, 2).contiguous().float())

    # Loose tolerances: bf16/fp16 fused kernels vs. an FP32 reference.
    tol = {"atol": 5e-2, "rtol": 5e-2}
    torch.testing.assert_close(out.float(), out_ref_bshd.float(), **tol)
    torch.testing.assert_close(
        dq.float(), q_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol
    )
    torch.testing.assert_close(
        dk.float(), k_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol
    )
    torch.testing.assert_close(
        dv.float(), v_ref.grad.transpose(1, 2).contiguous().to(dtype).float(), **tol
    )


def test_cudnn_fe_fused_attention_module(monkeypatch):
    """Integration test: exercise through the DotProductAttention module."""
    # Scope env var mutations to this test: pytest shares a process across
    # tests, so a bare ``os.environ[...] = ...`` would leak these flags into
    # every later test in the session.
    monkeypatch.setenv("NVTE_FUSED_ATTN", "1")
    monkeypatch.setenv("NVTE_FLASH_ATTN", "0")
    monkeypatch.setenv("NVTE_UNFUSED_ATTN", "0")

    dtype = torch.bfloat16
    batch, seqlen, heads, head_dim = 2, 1024, 4, 256
    device = "cuda"

    torch.manual_seed(42)
    q = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    k = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)
    v = torch.randn(batch, seqlen, heads, head_dim, dtype=dtype, device=device, requires_grad=True)

    dpa = te.DotProductAttention(
        num_attention_heads=heads,
        kv_channels=head_dim,
        qkv_format="bshd",
        attention_type="self",
    ).cuda()
    out = dpa(q, k, v, attention_mask=None)
    # DotProductAttention flattens heads into the hidden dimension.
    assert out.shape == (batch, seqlen, heads * head_dim)

    loss = out.sum()
    loss.backward()
    assert q.grad is not None and q.grad.shape == q.shape
    assert k.grad is not None and k.grad.shape == k.shape
    assert v.grad is not None and v.grad.shape == v.shape
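

# Hedged sketch for a possible varlen (THD) follow-up: fused_attn_fwd/bwd above
# accept cu_seqlens_q/cu_seqlens_kv, which in TE are int32 cumulative sequence
# lengths of shape (batch + 1,) on the GPU. The helper below only builds such a
# tensor; whether cudnn_fe_sdpa supports qkv_format="thd" is an assumption not
# exercised here.
def _make_cu_seqlens(seqlens, device="cuda"):
    """E.g. [3, 5] -> tensor([0, 3, 8], dtype=torch.int32)."""
    lens = torch.tensor([0] + list(seqlens), dtype=torch.int32, device=device)
    return torch.cumsum(lens, dim=0, dtype=torch.int32)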
6 changes: 5 additions & 1 deletion transformer_engine/common/fused_attn/fused_attn.cpp
@@ -339,7 +339,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
       attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) ||
     // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged
     (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 &&
      cudnn_runtime_version >= 91100)) &&
      cudnn_runtime_version >= 91100) ||
     // cuDNN-FE Python SDPA (CuTe DSL): d=256 + Blackwell + training + non-paged.
     // The C++ kernel does not handle this case; it is intercepted in Python for fwd and bwd.
     (head_dim_qk == 256 && head_dim_v == 256 && is_training && sm_arch_ >= 100 &&
      layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD)) &&
    // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA
    // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed
    (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 &&