From 5205b895214cb29273c79f502e23893e0a370d46 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Tue, 21 Apr 2026 14:12:22 -0700
Subject: [PATCH 1/3] [PyTorch][CP] Double-buffer P2P KV comm buffers in
 forward pass
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The forward pass allocated cp_size KV communication buffers but only ever
needed 2 live at any time (current compute + next recv). This mirrors the
backward pass, which already uses 2-entry double-buffering.

Convert p2p_comm_buffers from a cp_size-length list to a 2-entry list with
i%2 rotation. Saves (cp_size-3) buffers' worth of memory at peak: measured
2.6 GB at cp=8 (S=262k, B=2, H=16, D=128) with zero perf regression.

Also add bariamis benchmark configs (H=16, S=4k/8k) to the flash_attn test
suite for correctness coverage at the exact configs used in CP
communication benchmarking.

Signed-off-by: Sudhakar Singh
---
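A minimal sketch of the 2-entry rotation, with a toy tensor standing in
for the flattened KV buffer; p2p_send/p2p_recv are illustrative names,
not the real TE communication wrappers:

    import torch

    def ring_exchange_sketch(cp_size: int, kv_local: torch.Tensor) -> None:
        # Two slots instead of cp_size: slot i % 2 holds the chunk consumed
        # at step i, slot (i + 1) % 2 receives the next chunk.
        bufs = [kv_local, None]
        for i in range(cp_size):
            if i < cp_size - 1:
                bufs[(i + 1) % 2] = torch.empty_like(bufs[i % 2])
                # p2p_send(bufs[i % 2]); p2p_recv(bufs[(i + 1) % 2])
            _ = bufs[i % 2].sum()  # placeholder for the attention step

Only two KV chunks are reachable at any step, so each chunk's memory can
be released once its slot is overwritten on a later iteration.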
 tests/pytorch/attention/test_attention_with_cp.py |  2 ++
 .../dot_product_attention/context_parallel.py     | 14 ++++++++------
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py
index 5aaf67061b..a0b80078fe 100644
--- a/tests/pytorch/attention/test_attention_with_cp.py
+++ b/tests/pytorch/attention/test_attention_with_cp.py
@@ -51,6 +51,8 @@
         2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128
     ),  # MLA
     "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128),  # MLA
+    "bariamis_4k": ModelConfig(2, 4096, 16, 128, attn_mask_type="causal"),  # bariamis default
+    "bariamis_8k": ModelConfig(2, 8192, 16, 128, attn_mask_type="causal"),
 }

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index 64cccaac6e..f3f6f3dada 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -1557,7 +1557,7 @@ def forward(
         # synchronize fwd results correction across steps
         fwd_results_correction_done = torch.cuda.Event()
 
-        p2p_comm_buffers = [None for _ in range(cp_size)]
+        p2p_comm_buffers = [None, None]
         k_shape = k.shape
         k_numel = k.numel()
         v_shape = v.shape
@@ -1576,18 +1576,20 @@ def forward(
                         req.wait()
 
                     if i < (cp_size - 1):
-                        p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i])
+                        p2p_comm_buffers[(i + 1) % 2] = torch.empty_like(
+                            p2p_comm_buffers[i % 2]
+                        )
                         send_recv_reqs[i % 2] = flash_attn_p2p_communicate(
                             rank,
-                            p2p_comm_buffers[i],
+                            p2p_comm_buffers[i % 2],
                             send_dst,
-                            p2p_comm_buffers[i + 1],
+                            p2p_comm_buffers[(i + 1) % 2],
                             recv_src,
                             cp_group,
                             batch_p2p_comm,
                         )
 
-                    kv_inputs[i % 2] = p2p_comm_buffers[i]
+                    kv_inputs[i % 2] = p2p_comm_buffers[i % 2]
                     k_part = kv_inputs[i % 2][:k_numel].view(*k_shape)
                     v_part = kv_inputs[i % 2][k_numel:].view(*v_shape)
                     q_part = q
@@ -1952,7 +1954,7 @@ def forward(
         ctx.fp8 = fp8 and is_bwd_fp8
 
         kv_fp8 = None
-        kv = p2p_comm_buffers[-1]
+        kv = p2p_comm_buffers[(cp_size - 1) % 2]
         if fp8:
             q_fp8, kv_fp8 = [
                 Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype)

From 02d129e05b18450fdd0c020a261d303163261107 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Tue, 21 Apr 2026 14:45:12 -0700
Subject: [PATCH 2/3] [PyTorch][CP] Incremental output correction in P2P
 forward pass

Replace the post-loop output correction with an online softmax merge
inside the main loop. Each step's partial output is immediately merged
into a running accumulator using exp(old_lse - new_lse) rescaling,
eliminating the need to store all cp_size out_per_step and
softmax_lse_per_step tensors.

Double-buffers out_per_step, softmax_lse_per_step, and max_logit_per_step
(2 slots each). rng_states and attn_biases remain at cp_size for the
backward pass.

All three QKV formats are supported:
- bshd/sbhd: new @jit_fuser incremental correction helpers
- THD packed LSE: Python mul_ rescale + existing thd_out_correction kernel
- THD unpacked LSE (legacy): clone+zero+reconstruct fallback

Signed-off-by: Sudhakar Singh
---
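A self-contained sketch of the merge rule (toy 2-D shapes and plain
torch; the real helpers additionally movedim/unsqueeze the LSE to match
the bshd/sbhd layouts). It checks that merging per-chunk results
reproduces attention computed over all keys at once:

    import torch

    def merge(out, lse, out_step, lse_step):
        # Online softmax merge of two partial outputs over disjoint keys.
        new_lse = torch.logaddexp(lse, lse_step)
        out = out * torch.exp(lse - new_lse).unsqueeze(-1) + (
            out_step * torch.exp(lse_step - new_lse).unsqueeze(-1)
        )
        return out, new_lse

    def partial(q, k, v):
        s = q @ k.T  # [sq, sk] logits
        return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

    torch.manual_seed(0)
    q = torch.randn(4, 8)
    k1, v1, k2, v2 = (torch.randn(6, 8) for _ in range(4))
    o1, l1 = partial(q, k1, v1)
    o2, l2 = partial(q, k2, v2)
    o_merged, _ = merge(o1, l1, o2, l2)
    o_full, _ = partial(q, torch.cat([k1, k2]), torch.cat([v1, v2]))
    assert torch.allclose(o_merged, o_full, atol=1e-5)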
 .../dot_product_attention/context_parallel.py | 256 +++++++++++-------
 1 file changed, 162 insertions(+), 94 deletions(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index f3f6f3dada..f94cf3f963 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -172,6 +172,41 @@ def flash_attn_fwd_second_half_softmax_lse_correction(
     softmax_lse_.copy_(new_scale)
 
 
+@jit_fuser
+def flash_attn_fwd_incremental_out_correction(
+    out: torch.Tensor,
+    out_per_step: torch.Tensor,
+    old_softmax_lse: torch.Tensor,
+    new_softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+    seq_dim: int,
+):
+    """Online softmax merge: rescale accumulated output and add new step's contribution."""
+    scale_old = torch.exp(old_softmax_lse - new_softmax_lse).movedim(2, seq_dim).unsqueeze(-1)
+    scale_new = torch.exp(softmax_lse_per_step - new_softmax_lse).movedim(2, seq_dim).unsqueeze(-1)
+    out.mul_(scale_old)
+    out.addcmul_(out_per_step, scale_new)
+
+
+@jit_fuser
+def flash_attn_fwd_incremental_second_half_out_correction(
+    out: torch.Tensor,
+    out_per_step: torch.Tensor,
+    old_softmax_lse: torch.Tensor,
+    new_softmax_lse: torch.Tensor,
+    softmax_lse_per_step: torch.Tensor,
+    seq_dim: int,
+):
+    """Online softmax merge for second-half tokens only (causal upper-triangle steps)."""
+    out_ = out.select(seq_dim, 1)
+    old_lse_ = old_softmax_lse.view(*old_softmax_lse.shape[:-1], 2, -1)[..., 1, :]
+    new_lse_ = new_softmax_lse.view(*new_softmax_lse.shape[:-1], 2, -1)[..., 1, :]
+    scale_old = torch.exp(old_lse_ - new_lse_).movedim(2, seq_dim).unsqueeze(-1)
+    scale_new = torch.exp(softmax_lse_per_step - new_lse_).movedim(2, seq_dim).unsqueeze(-1)
+    out_.mul_(scale_old)
+    out_.addcmul_(out_per_step, scale_new)
+
+
 @jit_fuser
 def get_cu_seqlens_on_cp_rank(
     cu_seqlens: torch.Tensor,
@@ -1347,7 +1382,7 @@ def forward(
         amax_per_step = None
         S_quantizer_per_step = [None for _ in range(cp_size)]
         O_quantizer_per_step = [None for _ in range(cp_size)]
-        max_logit_per_step = [None for _ in range(cp_size)]
+        max_logit_per_step = [None, None]
         max_logit = None
 
         assert isinstance(k, q.__class__) and isinstance(
@@ -1439,7 +1474,7 @@ def forward(
             fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"]
         if return_max_logit:
             max_logit_per_step = [
-                torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size)
+                torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(2)
             ]
 
         # split qkv to two halves and prepare for load balancing
@@ -1547,8 +1582,8 @@ def forward(
         # set up inputs for forward
         q_inputs = [None, None]
         kv_inputs = [None, None]
-        out_per_step = [None for _ in range(cp_size)]
-        softmax_lse_per_step = [None for _ in range(cp_size)]
+        out_per_step = [None, None]
+        softmax_lse_per_step = [None, None]
         rng_states = [None for _ in range(cp_size)]
         attn_biases = [None for _ in range(cp_size)]
 
@@ -1568,6 +1603,7 @@ def forward(
         # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype
         # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8
         out = None
+        second_half_lse_seqlen = None
         for i in range(cp_size + 1):
             if i < cp_size:
                 with torch.cuda.stream(flash_attn_streams[i % 2]):
@@ -1676,16 +1712,16 @@ def forward(
                             q_inputs[i % 2] = q_part
                             if use_fused_attention:
                                 (
-                                    out_per_step[i],
-                                    softmax_lse_per_step[i],
+                                    out_per_step[i % 2],
+                                    softmax_lse_per_step[i % 2],
                                     rng_states[i],
                                     attn_biases[i],
-                                    max_logit_per_step[i],
+                                    max_logit_per_step[i % 2],
                                 ) = cp_p2p_fwd_fused_attn(
                                     *fused_attn_inputs, *prepare_outputs, section
                                 )
                             else:
-                                out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
+                                out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = (
                                     cp_p2p_fwd_flash_attn(
                                         *flash_attn_inputs, *prepare_outputs, section
                                     )
@@ -1703,16 +1739,16 @@ def forward(
                             q_inputs[i % 2] = q_part
                             if use_fused_attention:
                                 (
-                                    out_per_step[i],
-                                    softmax_lse_per_step[i],
+                                    out_per_step[i % 2],
+                                    softmax_lse_per_step[i % 2],
                                     rng_states[i],
                                     attn_biases[i],
-                                    max_logit_per_step[i],
+                                    max_logit_per_step[i % 2],
                                 ) = cp_p2p_fwd_fused_attn(
                                     *fused_attn_inputs, *prepare_outputs, section
                                 )
                             else:
-                                out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
+                                out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = (
                                     cp_p2p_fwd_flash_attn(
                                         *flash_attn_inputs, *prepare_outputs, section
                                     )
@@ -1730,16 +1766,16 @@ def forward(
                             q_inputs[i % 2] = q_part
                             if use_fused_attention:
                                 (
-                                    out_per_step[i],
-                                    softmax_lse_per_step[i],
+                                    out_per_step[i % 2],
+                                    softmax_lse_per_step[i % 2],
                                     rng_states[i],
                                     attn_biases[i],
-                                    max_logit_per_step[i],
+                                    max_logit_per_step[i % 2],
                                 ) = cp_p2p_fwd_fused_attn(
                                     *fused_attn_inputs, *prepare_outputs, section
                                 )
                             else:
-                                out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
+                                out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = (
                                     cp_p2p_fwd_flash_attn(
                                         *flash_attn_inputs, *prepare_outputs, section
                                     )
@@ -1758,18 +1794,18 @@ def forward(
                         q_inputs[i % 2] = q_part
                         if use_fused_attention:
                             (
-                                out_per_step[i],
-                                softmax_lse_per_step[i],
+                                out_per_step[i % 2],
+                                softmax_lse_per_step[i % 2],
                                 rng_states[i],
                                 attn_biases[i],
-                                max_logit_per_step[i],
+                                max_logit_per_step[i % 2],
                             ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section)
                         else:
-                            out_per_step[i], softmax_lse_per_step[i], rng_states[i] = (
+                            out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = (
                                 cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section)
                             )
 
-            # softmax_lse correction
+            # Incremental softmax_lse + output correction (online softmax merge)
             if i > 0:
                 # wait until fwd results correction of last step is done
                 if i > 1:
@@ -1779,54 +1815,143 @@ def forward(
                     if use_fused_attention:
                         # [b, h, sq, 1] -> [b, h, sq] or
                         # [t, h, 1] -> [t, np]
-                        softmax_lse_per_step[i - 1].squeeze_(-1)
+                        softmax_lse_per_step[(i - 1) % 2].squeeze_(-1)
                         if softmax_lse_in_packed_format:
-                            softmax_lse_per_step[i - 1] = (
-                                softmax_lse_per_step[i - 1].transpose(0, 1).contiguous()
+                            softmax_lse_per_step[(i - 1) % 2] = (
+                                softmax_lse_per_step[(i - 1) % 2].transpose(0, 1).contiguous()
                             )
 
                     if fp8:
                         # dequantize out_per_step to torch.float32
                         if fp8_recipe.delayed():
-                            out_per_step[i - 1] = out_per_step[i - 1].dequantize(
+                            out_per_step[(i - 1) % 2] = out_per_step[(i - 1) % 2].dequantize(
                                 dtype=torch.float32
                             )
                         if fp8_recipe.float8_current_scaling():
-                            out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32)
+                            out_per_step[(i - 1) % 2] = out_per_step[(i - 1) % 2].to(
+                                dtype=torch.float32
+                            )
 
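+                    # Online merge: fold step (i-1)'s partial output into the
+                    # running (out, softmax_lse) pair here, instead of keeping
+                    # cp_size per-step outputs alive for a post-loop pass.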
                     if i == 1:
                         softmax_lse = torch.clone(softmax_lse_per_step[0])
                         if qkv_format == "thd":
+                            out = out_per_step[0].to(torch.float32)
                             if enable_mla:
-                                out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
-                                    v_shape
-                                )
+                                out = out.view(v_shape)
                             else:
-                                # MHA or GQA
-                                out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
-                                    q.shape
-                                )
+                                out = out.view(q.shape)
+                        elif qkv_format in ["bshd", "sbhd"]:
+                            out = out_per_step[0].to(torch.float32)
+                            if enable_mla:
+                                out = out.view(v_shape)
+                            else:
+                                out = out.view(q.shape)
                     elif (i - 1) <= rank or not causal:
+                        old_softmax_lse = softmax_lse.clone()
                         flash_attn_fwd_softmax_lse_correction(
-                            softmax_lse, softmax_lse_per_step[i - 1]
+                            softmax_lse, softmax_lse_per_step[(i - 1) % 2]
                         )
+                        if qkv_format in ["bshd", "sbhd"]:
+                            flash_attn_fwd_incremental_out_correction(
+                                out.view(*out_per_step[(i - 1) % 2].shape),
+                                out_per_step[(i - 1) % 2],
+                                old_softmax_lse,
+                                softmax_lse,
+                                softmax_lse_per_step[(i - 1) % 2],
+                                seq_dim,
+                            )
+                        elif qkv_format == "thd":
+                            if softmax_lse_in_packed_format:
+                                # LSE is [h, t], out is [t, h, d]
+                                scale = torch.exp(
+                                    old_softmax_lse - softmax_lse
+                                ).transpose(0, 1).unsqueeze(-1)
+                                out.mul_(scale)
+                            else:
+                                # LSE is [b*h, s_padded] — zero out and reconstruct
+                                # via two thd_out_correction calls
+                                out_backup = out.clone()
+                                out.zero_()
+                                tex.thd_out_correction(
+                                    out,
+                                    out_backup,
+                                    softmax_lse,
+                                    old_softmax_lse,
+                                    cu_seqlens_q_padded,
+                                    False,
+                                    softmax_lse_in_packed_format,
+                                )
+                                tex.thd_out_correction(
+                                    out,
+                                    out_per_step[(i - 1) % 2],
+                                    softmax_lse,
+                                    softmax_lse_per_step[(i - 1) % 2],
+                                    cu_seqlens_q_padded,
+                                    False,
+                                    softmax_lse_in_packed_format,
+                                )
                     else:
+                        old_softmax_lse = softmax_lse.clone()
                         if qkv_format == "thd":
                             tex.thd_second_half_lse_correction(
                                 softmax_lse,
-                                softmax_lse_per_step[i - 1],
+                                softmax_lse_per_step[(i - 1) % 2],
+                                cu_seqlens_q_padded,
+                                softmax_lse_in_packed_format,
+                            )
+                            # Rescale out for affected (second-half) tokens;
+                            # first-half tokens have old==new so scale=1.0
+                            if softmax_lse_in_packed_format:
+                                scale = torch.exp(
+                                    old_softmax_lse - softmax_lse
+                                ).transpose(0, 1).unsqueeze(-1)
+                                out.mul_(scale)
+                            else:
+                                # Zero out and reconstruct; use only_second_half=False
+                                # because first-half has old==new so scale=1.0
+                                out_backup = out.clone()
+                                out.zero_()
+                                tex.thd_out_correction(
+                                    out,
+                                    out_backup,
+                                    softmax_lse,
+                                    old_softmax_lse,
+                                    cu_seqlens_q_padded,
+                                    False,
+                                    softmax_lse_in_packed_format,
+                                )
+                                tex.thd_out_correction(
+                                    out,
+                                    out_per_step[(i - 1) % 2],
+                                    softmax_lse,
+                                    softmax_lse_per_step[(i - 1) % 2],
                                 cu_seqlens_q_padded,
+                                    True,
                                 softmax_lse_in_packed_format,
                             )
                         else:
                             flash_attn_fwd_second_half_softmax_lse_correction(
                                 softmax_lse.view(*softmax_lse.shape[:-1], 2, -1),
-                                softmax_lse_per_step[i - 1],
+                                softmax_lse_per_step[(i - 1) % 2],
+                            )
+                            flash_attn_fwd_incremental_second_half_out_correction(
+                                out,
+                                out_per_step[(i - 1) % 2],
+                                old_softmax_lse,
+                                softmax_lse,
+                                softmax_lse_per_step[(i - 1) % 2],
+                                seq_dim,
                             )
                     if return_max_logit:
                         if i == 1:
                             max_logit = torch.clone(max_logit_per_step[0])
                         else:
-                            max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1])
+                            max_logit = torch.maximum(max_logit, max_logit_per_step[(i - 1) % 2])
+
+                    # Capture second_half_lse_seqlen from the last step's LSE
+                    if i == cp_size and causal and rank < (cp_size - 1):
+                        second_half_lse_seqlen = (
+                            softmax_lse_per_step[(cp_size - 1) % 2].shape[-1]
+                        )
 
             if i < cp_size:
                 flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done)
@@ -1837,63 +1962,6 @@ def forward(
                 max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group
             )
 
-        second_half_lse_seqlen = None
-        if causal and rank < (cp_size - 1):
-            second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1]
-
-        # fwd output correction: out in torch.float32
-        for i in range(cp_size):
-            if i <= rank or not causal:
-                if qkv_format in ["bshd", "sbhd"]:
-                    if i == 0:
-                        out = flash_attn_fwd_out_correction_init(
-                            out_per_step[0],
-                            softmax_lse,
-                            softmax_lse_per_step[0],
-                            seq_dim,
-                        )
-                        if enable_mla:
-                            out = out.view(v_shape)
-                        else:
-                            out = out.view(q.shape)
-                    else:
-                        flash_attn_fwd_out_correction(
-                            out.view(*out_per_step[i].shape),
-                            out_per_step[i],
-                            softmax_lse,
-                            softmax_lse_per_step[i],
-                            seq_dim,
-                        )
-                elif qkv_format == "thd":
-                    tex.thd_out_correction(
-                        out,
-                        out_per_step[i],
-                        softmax_lse,
-                        softmax_lse_per_step[i],
-                        cu_seqlens_q_padded,
-                        False,
-                        softmax_lse_in_packed_format,
-                    )
-            else:
-                if qkv_format in ["bshd", "sbhd"]:
-                    flash_attn_fwd_second_half_out_correction(
-                        out,
-                        out_per_step[i],
-                        softmax_lse,
-                        softmax_lse_per_step[i],
-                        seq_dim,
-                    )
-                elif qkv_format == "thd":
-                    tex.thd_out_correction(
-                        out,
-                        out_per_step[i],
-                        softmax_lse,
-                        softmax_lse_per_step[i],
-                        cu_seqlens_q_padded,
-                        True,
-                        softmax_lse_in_packed_format,
-                    )
-
         if qkv_format == "bshd":
             out = out.view(out.shape[0], -1, *out.shape[-2:])
             ctx.batch_size = out.shape[0]

From 369f62dd09d0382673d28a7673521e16b22669d2 Mon Sep 17 00:00:00 2001
From: Sudhakar Singh
Date: Tue, 21 Apr 2026 15:19:46 -0700
Subject: [PATCH 3/3] [PyTorch][CP] Fix THD dtype mismatch in incremental
 output correction
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The thd_out_correction CUDA kernel requires out and out_per_step to share
the same dtype. Using .to(torch.float32) for THD init broke this contract
since out_per_step stays in bf16/fp16. Use .clone() instead; the kernel
handles float promotion internally per-element.

Signed-off-by: Sudhakar Singh
---
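A toy illustration of the dtype contract (plain torch; the actual check
lives in the CUDA kernel):

    import torch

    out_per_step = torch.randn(16, 8, 64, dtype=torch.bfloat16)  # [t, h, d]

    bad = out_per_step.to(torch.float32)  # dtype no longer matches
    good = out_per_step.clone()           # stays bf16, same values

    assert good.dtype == out_per_step.dtype
    assert bad.dtype != out_per_step.dtype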
 .../pytorch/attention/dot_product_attention/context_parallel.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
index f94cf3f963..8d9533f16c 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
@@ -1834,7 +1834,7 @@ def forward(
                     if i == 1:
                         softmax_lse = torch.clone(softmax_lse_per_step[0])
                         if qkv_format == "thd":
-                            out = out_per_step[0].to(torch.float32)
+                            out = out_per_step[0].clone()
                             if enable_mla:
                                 out = out.view(v_shape)
                             else:
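A closing shape sketch for the packed-LSE rescale path added in patch 2
(toy sizes; mirrors the transpose/unsqueeze in that diff):

    import torch

    t, h, d = 32, 8, 64
    out = torch.randn(t, h, d)            # accumulated output, [t, h, d]
    old_lse = torch.randn(h, t)           # packed LSE layout is [h, t]
    new_lse = old_lse + torch.rand(h, t)  # stand-in for the merged LSE

    scale = torch.exp(old_lse - new_lse).transpose(0, 1).unsqueeze(-1)
    assert scale.shape == (t, h, 1)       # broadcasts over the head dim d
    out.mul_(scale)                       # in-place rescale, as in the patch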