@@ -1486,9 +1486,11 @@ def backward(ctx, d_out, *_args):
         rest = [None]
         if ctx.use_FAv2_bwd:
             softmax_lse, rng_state = aux_ctx_tensors
-            dq = torch.empty_like(q)
-            dk = torch.empty_like(k)
-            dv = torch.empty_like(v)
+            # THD + CP fix: zeros_like ensures padded positions start from safe values,
+            # preventing garbage from propagating through backward gradient accumulation.
+            dq = torch.zeros_like(q)
+            dk = torch.zeros_like(k)
+            dv = torch.zeros_like(v)
             d_out, q, k, v, out = [dpa_utils.maybe_contiguous(x) for x in (d_out, q, k, v, out)]
             # from transformer_engine.pytorch.attention.dot_product_attention import flash_attn_cuda_bwd
             flash_attn_cuda_bwd(
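For context, a minimal standalone sketch of the failure mode the new comment describes (hypothetical shapes and a fake kernel, not the real flash_attn_cuda_bwd path): with empty_like, rows past the actual token count hold uninitialized memory, and any later accumulation over the full buffer — for example an all-reduce across CP ranks — folds that garbage into real gradients; zeros_like keeps the padded tail at a safe value.

import torch

# Hypothetical sizes: T_padded rows allocated, only T_actual written by the kernel.
T_padded, T_actual, h, d = 8, 6, 2, 4

def fake_bwd_kernel(dq):
    # Stand-in for the real backward kernel: it only writes the first T_actual rows.
    dq[:T_actual] = 1.0

dq_unsafe = torch.empty(T_padded, h, d)   # tail rows contain whatever was in memory
dq_safe = torch.zeros(T_padded, h, d)     # tail rows are guaranteed to be zero

fake_bwd_kernel(dq_unsafe)
fake_bwd_kernel(dq_safe)

print(dq_unsafe[T_actual:].abs().max())   # arbitrary: stale memory, possibly inf/NaN
print(dq_safe[T_actual:].abs().max())     # always 0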
@@ -2044,6 +2044,17 @@ def forward(

         nvtx_range_pop(f"{nvtx_label}")

+        # THD CUDA Graph: zero-fill output at padded positions after CP assembly.
+        # cu_seqlens_q_padded is GLOBAL; divide by cp_size to get local actual_T.
+        if qkv_format == "thd" and out_ret is not None and hasattr(out_ret, "shape"):
+            import torch as _torch

Review comment (Contributor), on the import torch as _torch line above:

P2: Redundant in-function import torch as _torch

torch is already imported at the top of this module. The local alias _torch adds no value and makes the code harder to grep. The same pattern appears in the backward block at line ~2754.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

+
+            _local_aT = cu_seqlens_q_padded[-1] // cp_size
+            if out_ret.shape[0] > 0:
+                _m = _torch.arange(out_ret.shape[0], device=out_ret.device) >= _local_aT
+                out_ret.data[_m] = 0
+                out.data[_m.view(-1, *([1] * (out.dim() - 1))).expand_as(out)] = 0

Review comment (Contributor), on lines +2052 to +2056:

P1: None-dereference if cu_seqlens_q_padded is absent

cu_seqlens_q_padded[-1] is accessed unconditionally, but the parameter is optional and can be None (e.g., when THD is used without inter-sequence padding). A TypeError: 'NoneType' object is not subscriptable would be raised in that case. Add a None guard so the zero-fill only executes when the padded lengths are actually present.
Suggested change (replacing the block above):

        if qkv_format == "thd" and out_ret is not None and hasattr(out_ret, "shape") and cu_seqlens_q_padded is not None:
            import torch as _torch
            _local_aT = cu_seqlens_q_padded[-1] // cp_size
            if out_ret.shape[0] > 0:
                _m = _torch.arange(out_ret.shape[0], device=out_ret.device) >= _local_aT
                out_ret.data[_m] = 0
                out.data[_m.view(-1, *([1] * (out.dim() - 1))).expand_as(out)] = 0


         if return_max_logit:
             return out_ret, max_logit
         return out_ret
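As an aside, a self-contained sketch of the masking pattern used in the block above (names and shapes are illustrative, not the TE API): the padded tail is zeroed entirely with device-side ops, so no tensor value is ever read back to the host, which is what makes the pattern legal inside a CUDA graph capture. masked_fill_ is used here in place of boolean-index assignment, but the idea is the same.

import torch

def zero_pad_tail(out: torch.Tensor, local_actual_t: torch.Tensor) -> None:
    # out: [T_padded, h, d]; local_actual_t: 0-d integer tensor on the same device.
    if out.shape[0] == 0:
        return
    # arange(...) >= tensor compares on the device; no .item()/int() host sync.
    mask = torch.arange(out.shape[0], device=out.device) >= local_actual_t
    out.masked_fill_(mask.view(-1, 1, 1), 0)

if torch.cuda.is_available():
    out = torch.randn(10, 2, 4, device="cuda")
    # Stand-in for cu_seqlens_q_padded[-1] // cp_size: stays a device tensor throughout.
    actual_t = torch.tensor(7, device="cuda")
    zero_pad_tail(out, actual_t)
    assert out[7:].abs().sum().item() == 0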
@@ -2680,10 +2691,17 @@ def backward(ctx, dout, *_args):
         dim = ctx.qkv_format.index("s")
         dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]]

+        # THD CUDA Graph fix: reading cu_seqlens[-1] as a Python index triggers
+        # GPU->CPU sync during graph capture. Use .shape[0] instead when capturing.
         if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
-            dq[cu_seqlens_q_padded[-1] :].fill_(0)
-            dk[cu_seqlens_kv_padded[-1] :].fill_(0)
-            dv[cu_seqlens_kv_padded[-1] :].fill_(0)
+            if torch.cuda.is_current_stream_capturing():
+                _q_end, _kv_end = dq.shape[0], dk.shape[0]
+            else:
+                _q_end = cu_seqlens_q_padded[-1]
+                _kv_end = cu_seqlens_kv_padded[-1]
+            dq[_q_end:].fill_(0)
+            dk[_kv_end:].fill_(0)
+            dv[_kv_end:].fill_(0)

         if ctx.fp8 and ctx.is_input_fp8:
             dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer)
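A standalone sketch of the capture-aware branch above (illustrative names, not the TE code path): using a value stored in a CUDA tensor as a Python slice bound forces a device-to-host copy, which is exactly the sync that is illegal while a CUDA graph is being captured, so the code falls back to a bound that is already a plain Python int.

import torch

def trim_padded_tail(grad: torch.Tensor, cu_seqlens_padded: torch.Tensor) -> None:
    if torch.cuda.is_current_stream_capturing():
        # Static Python int known at capture time; the slice below becomes a no-op.
        end = grad.shape[0]
    else:
        # Host read of a device value; fine in eager mode, forbidden during capture.
        end = int(cu_seqlens_padded[-1])
    grad[end:].fill_(0)

During capture the slice is therefore empty, and the padded tail is instead cleaned up by the device-side zero-fill that this PR adds further down in the same backward.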
@@ -2731,6 +2749,16 @@ def backward(ctx, dout, *_args):

         nvtx_range_pop(f"{nvtx_label}")

+        # THD CUDA Graph: zero-fill dQ/dK/dV at padded positions after CP backward.
+        if ctx.qkv_format == "thd":
+            import torch as _torch
+
+            _local_aT_bwd = cu_seqlens_q_padded[-1] // get_distributed_world_size(ctx.cp_group)
+            for _dg in [dq, dk, dv]:
+                if _dg is not None and hasattr(_dg, "shape") and _dg.shape[0] > 0:
+                    _mb = _torch.arange(_dg.shape[0], device=_dg.device) >= _local_aT_bwd
+                    _dg[_mb] = 0

Review comment (Contributor), on lines +2752 to +2760:

P1: None-dereference if cu_seqlens_q_padded is absent (backward)

Same as the forward: cu_seqlens_q_padded[-1] is accessed unconditionally inside if ctx.qkv_format == "thd":, but cu_seqlens_q_padded can be None when no inter-sequence padding was used. The backward always restores it from ctx exactly as it was saved, so a None saved in the forward becomes a None here, crashing the backward pass.
Suggested change (adding the None guard):

        if ctx.qkv_format == "thd" and cu_seqlens_q_padded is not None:
            import torch as _torch
            _local_aT_bwd = cu_seqlens_q_padded[-1] // get_distributed_world_size(ctx.cp_group)
            for _dg in [dq, dk, dv]:
                if _dg is not None and hasattr(_dg, "shape") and _dg.shape[0] > 0:
                    _mb = _torch.arange(_dg.shape[0], device=_dg.device) >= _local_aT_bwd
                    _dg[_mb] = 0


         return (
             None,
             dq,
@@ -1330,13 +1330,10 @@ def forward(
         # check if there is padding between sequences when qkv_format='thd'
         if pad_between_seqs is None:
             if qkv_format == "thd":
-                pad_between_seqs = (
-                    cu_seqlens_q_padded is not None
-                    and not torch.equal(cu_seqlens_q_padded[:-1], cu_seqlens_q[:-1])
-                ) or (
-                    cu_seqlens_kv_padded is not None
-                    and not torch.equal(cu_seqlens_kv_padded[:-1], cu_seqlens_kv[:-1])
-                )
+                # THD + CUDA Graph fix: torch.equal() triggers GPU->CPU sync,
+                # which is forbidden during CUDA graph capture.
+                # pad_between_seqs=True is always safe for THD with padded cu_seqlens.
+                pad_between_seqs = True

Review comment (Contributor), on lines 1332 to +1336:

P1: pad_between_seqs=True even when no padded cu_seqlens exist

The original logic returned False when cu_seqlens_q_padded is None (no padding present). The new unconditional True means callers that legitimately pass cu_seqlens_q_padded=None with THD format will now be routed into the padding code path, triggering the new zero-fill in fused_attn_fwd/bwd and any downstream masked-fill logic. If cu_seqlens_q_padded is None, kernels that rely on it for masking may see incorrect results or raise exceptions. Consider scoping the override to cases where padding is actually present:

Suggested change (replacing the unconditional pad_between_seqs = True):

                pad_between_seqs = (
                    cu_seqlens_q_padded is not None
                    or cu_seqlens_kv_padded is not None
                )

             else:
                 pad_between_seqs = False

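To illustrate the reviewer's point, a small sketch with simplified signatures (not the TE functions): torch.equal() returns a Python bool, so it has to copy the comparison result from the GPU to the host — the sync that breaks graph capture — whereas a None check only inspects Python-level metadata and never touches tensor values.

import torch

def pad_between_seqs_with_sync(cu_q, cu_q_padded):
    # torch.equal returns a Python bool -> device-to-host copy of the result.
    return cu_q_padded is not None and not torch.equal(cu_q_padded[:-1], cu_q[:-1])

def pad_between_seqs_capture_safe(cu_q_padded, cu_kv_padded):
    # Metadata-only decision (the reviewer's suggestion): no tensor values are read.
    return cu_q_padded is not None or cu_kv_padded is not None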
17 changes: 17 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -351,6 +351,15 @@ def fused_attn_fwd(
         cuda_graph,
     )

+    # THD CUDA Graph: zero-fill output at positions beyond cu_seqlens[-1].
+    # Uses pure CUDA ops (no CPU sync) for CUDA graph capture compatibility.
+    if qkv_layout in ("t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
+        _out = output_tensors[0]
+        _aT_fwd = cu_seqlens_q[-1]
+        if _out.shape[0] > 0:
+            _m_fwd = torch.arange(_out.shape[0], device=_out.device) >= _aT_fwd
+            _out[_m_fwd] = 0

     if return_max_logit:
         qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
         # thd (newer cuDNN runtimes, non-sm120): output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
@@ -607,4 +616,12 @@ def fused_attn_bwd(
         cuda_graph,
     )

+    # THD CUDA Graph: zero-fill dQ/dK/dV at positions beyond cu_seqlens[-1].
+    if qkv_layout in ("t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"):
+        _aT_bwd = cu_seqlens_q[-1]
+        for _dt in output_tensors[:3]:
+            if hasattr(_dt, "shape") and _dt.shape[0] > 0:
+                _m_bwd = torch.arange(_dt.shape[0], device=_dt.device) >= _aT_bwd
+                _dt[_m_bwd] = 0

     return output_tensors
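Finally, an end-to-end toy sketch of why the zero-fill has to be expressed with device-side ops (toy shapes and a stand-in compute step, not the TE kernels): the region is captured once and replayed many times, so anything that changes between replays — here the actual token count — must live in a tensor that is updated in place rather than in a Python int baked into the capture.

import torch

if torch.cuda.is_available():
    T, h, d = 16, 2, 8
    out = torch.zeros(T, h, d, device="cuda")
    actual_t = torch.tensor(12, device="cuda")   # stand-in for cu_seqlens_padded[-1]

    def step():
        out.normal_()                                      # stand-in for the attention output
        mask = torch.arange(T, device="cuda") >= actual_t  # device-side tail mask
        out.masked_fill_(mask.view(-1, 1, 1), 0)

    # Warm up on a side stream, then capture (standard CUDA graph pattern).
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        step()
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        step()

    actual_t.fill_(9)   # move the padded boundary without recapturing
    g.replay()
    assert out[9:].abs().sum().item() == 0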