@jit_fuser
def flash_attn_fwd_incremental_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    old_softmax_lse: torch.Tensor,
    new_softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Online softmax merge: rescale accumulated output and add new step's contribution."""
    # Merge rule: out_new = out_old * exp(lse_old - lse_new)
    #                     + out_step * exp(lse_step - lse_new).
    # LSE tensors carry the sequence length in dim 2; move it to the output's
    # sequence dim and broadcast over the trailing head dim.
    rescale_old = (old_softmax_lse - new_softmax_lse).exp().movedim(2, seq_dim).unsqueeze(-1)
    rescale_step = (softmax_lse_per_step - new_softmax_lse).exp().movedim(2, seq_dim).unsqueeze(-1)
    out.mul_(rescale_old)
    out.addcmul_(out_per_step, rescale_step)


@jit_fuser
def flash_attn_fwd_incremental_second_half_out_correction(
    out: torch.Tensor,
    out_per_step: torch.Tensor,
    old_softmax_lse: torch.Tensor,
    new_softmax_lse: torch.Tensor,
    softmax_lse_per_step: torch.Tensor,
    seq_dim: int,
):
    """Online softmax merge for second-half tokens only (causal upper-triangle steps)."""
    # The load-balanced layout splits the query sequence into two chunks along
    # seq_dim; only chunk index 1 (the second half) is touched by these steps.
    half_out = out.select(seq_dim, 1)
    old_half_lse = old_softmax_lse.view(*old_softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    new_half_lse = new_softmax_lse.view(*new_softmax_lse.shape[:-1], 2, -1)[..., 1, :]
    # Same merge rule as the full correction, restricted to second-half tokens.
    half_out.mul_((old_half_lse - new_half_lse).exp().movedim(2, seq_dim).unsqueeze(-1))
    half_out.addcmul_(
        out_per_step,
        (softmax_lse_per_step - new_half_lse).exp().movedim(2, seq_dim).unsqueeze(-1),
    )
dtype=torch.uint8 out = None + second_half_lse_seqlen = None for i in range(cp_size + 1): if i < cp_size: with torch.cuda.stream(flash_attn_streams[i % 2]): @@ -1576,18 +1612,20 @@ def forward( req.wait() if i < (cp_size - 1): - p2p_comm_buffers[i + 1] = torch.empty_like(p2p_comm_buffers[i]) + p2p_comm_buffers[(i + 1) % 2] = torch.empty_like( + p2p_comm_buffers[i % 2] + ) send_recv_reqs[i % 2] = flash_attn_p2p_communicate( rank, - p2p_comm_buffers[i], + p2p_comm_buffers[i % 2], send_dst, - p2p_comm_buffers[i + 1], + p2p_comm_buffers[(i + 1) % 2], recv_src, cp_group, batch_p2p_comm, ) - kv_inputs[i % 2] = p2p_comm_buffers[i] + kv_inputs[i % 2] = p2p_comm_buffers[i % 2] k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) q_part = q @@ -1674,16 +1712,16 @@ def forward( q_inputs[i % 2] = q_part if use_fused_attention: ( - out_per_step[i], - softmax_lse_per_step[i], + out_per_step[i % 2], + softmax_lse_per_step[i % 2], rng_states[i], attn_biases[i], - max_logit_per_step[i], + max_logit_per_step[i % 2], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) else: - out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = ( cp_p2p_fwd_flash_attn( *flash_attn_inputs, *prepare_outputs, section ) @@ -1701,16 +1739,16 @@ def forward( q_inputs[i % 2] = q_part if use_fused_attention: ( - out_per_step[i], - softmax_lse_per_step[i], + out_per_step[i % 2], + softmax_lse_per_step[i % 2], rng_states[i], attn_biases[i], - max_logit_per_step[i], + max_logit_per_step[i % 2], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) else: - out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = ( cp_p2p_fwd_flash_attn( *flash_attn_inputs, *prepare_outputs, section ) @@ -1728,16 +1766,16 @@ def forward( q_inputs[i % 2] = q_part if use_fused_attention: ( - 
out_per_step[i], - softmax_lse_per_step[i], + out_per_step[i % 2], + softmax_lse_per_step[i % 2], rng_states[i], attn_biases[i], - max_logit_per_step[i], + max_logit_per_step[i % 2], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) else: - out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = ( cp_p2p_fwd_flash_attn( *flash_attn_inputs, *prepare_outputs, section ) @@ -1756,18 +1794,18 @@ def forward( q_inputs[i % 2] = q_part if use_fused_attention: ( - out_per_step[i], - softmax_lse_per_step[i], + out_per_step[i % 2], + softmax_lse_per_step[i % 2], rng_states[i], attn_biases[i], - max_logit_per_step[i], + max_logit_per_step[i % 2], ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: - out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + out_per_step[i % 2], softmax_lse_per_step[i % 2], rng_states[i] = ( cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) ) - # softmax_lse correction + # Incremental softmax_lse + output correction (online softmax merge) if i > 0: # wait until fwd results correction of last step is done if i > 1: @@ -1777,54 +1815,143 @@ def forward( if use_fused_attention: # [b, h, sq, 1] -> [b, h, sq] or # [t, h, 1] -> [t, np] - softmax_lse_per_step[i - 1].squeeze_(-1) + softmax_lse_per_step[(i - 1) % 2].squeeze_(-1) if softmax_lse_in_packed_format: - softmax_lse_per_step[i - 1] = ( - softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() + softmax_lse_per_step[(i - 1) % 2] = ( + softmax_lse_per_step[(i - 1) % 2].transpose(0, 1).contiguous() ) if fp8: # dequantize out_per_step to torch.float32 if fp8_recipe.delayed(): - out_per_step[i - 1] = out_per_step[i - 1].dequantize( + out_per_step[(i - 1) % 2] = out_per_step[(i - 1) % 2].dequantize( dtype=torch.float32 ) if fp8_recipe.float8_current_scaling(): - out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) + out_per_step[(i - 1) % 2] = 
out_per_step[(i - 1) % 2].to( + dtype=torch.float32 + ) if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": + out = out_per_step[0].clone() if enable_mla: - out = torch.zeros_like(v if not fp8 else out_per_step[0]).view( - v_shape - ) + out = out.view(v_shape) else: - # MHA or GQA - out = torch.zeros_like(q if not fp8 else out_per_step[0]).view( - q.shape - ) + out = out.view(q.shape) + elif qkv_format in ["bshd", "sbhd"]: + out = out_per_step[0].to(torch.float32) + if enable_mla: + out = out.view(v_shape) + else: + out = out.view(q.shape) elif (i - 1) <= rank or not causal: + old_softmax_lse = softmax_lse.clone() flash_attn_fwd_softmax_lse_correction( - softmax_lse, softmax_lse_per_step[i - 1] + softmax_lse, softmax_lse_per_step[(i - 1) % 2] ) + if qkv_format in ["bshd", "sbhd"]: + flash_attn_fwd_incremental_out_correction( + out.view(*out_per_step[(i - 1) % 2].shape), + out_per_step[(i - 1) % 2], + old_softmax_lse, + softmax_lse, + softmax_lse_per_step[(i - 1) % 2], + seq_dim, + ) + elif qkv_format == "thd": + if softmax_lse_in_packed_format: + # LSE is [h, t], out is [t, h, d] + scale = torch.exp( + old_softmax_lse - softmax_lse + ).transpose(0, 1).unsqueeze(-1) + out.mul_(scale) + else: + # LSE is [b*h, s_padded] — zero out and reconstruct + # via two thd_out_correction calls + out_backup = out.clone() + out.zero_() + tex.thd_out_correction( + out, + out_backup, + softmax_lse, + old_softmax_lse, + cu_seqlens_q_padded, + False, + softmax_lse_in_packed_format, + ) + tex.thd_out_correction( + out, + out_per_step[(i - 1) % 2], + softmax_lse, + softmax_lse_per_step[(i - 1) % 2], + cu_seqlens_q_padded, + False, + softmax_lse_in_packed_format, + ) else: + old_softmax_lse = softmax_lse.clone() if qkv_format == "thd": tex.thd_second_half_lse_correction( softmax_lse, - softmax_lse_per_step[i - 1], + softmax_lse_per_step[(i - 1) % 2], + cu_seqlens_q_padded, + softmax_lse_in_packed_format, + ) + # Rescale out for affected (second-half) 
tokens; + # first-half tokens have old==new so scale=1.0 + if softmax_lse_in_packed_format: + scale = torch.exp( + old_softmax_lse - softmax_lse + ).transpose(0, 1).unsqueeze(-1) + out.mul_(scale) + else: + # Zero out and reconstruct; use only_second_half=False + # because first-half has old==new so scale=1.0 + out_backup = out.clone() + out.zero_() + tex.thd_out_correction( + out, + out_backup, + softmax_lse, + old_softmax_lse, + cu_seqlens_q_padded, + False, + softmax_lse_in_packed_format, + ) + tex.thd_out_correction( + out, + out_per_step[(i - 1) % 2], + softmax_lse, + softmax_lse_per_step[(i - 1) % 2], cu_seqlens_q_padded, + True, softmax_lse_in_packed_format, ) else: flash_attn_fwd_second_half_softmax_lse_correction( softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), - softmax_lse_per_step[i - 1], + softmax_lse_per_step[(i - 1) % 2], + ) + flash_attn_fwd_incremental_second_half_out_correction( + out, + out_per_step[(i - 1) % 2], + old_softmax_lse, + softmax_lse, + softmax_lse_per_step[(i - 1) % 2], + seq_dim, ) if return_max_logit: if i == 1: max_logit = torch.clone(max_logit_per_step[0]) else: - max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) + max_logit = torch.maximum(max_logit, max_logit_per_step[(i - 1) % 2]) + + # Capture second_half_lse_seqlen from the last step's LSE + if i == cp_size and causal and rank < (cp_size - 1): + second_half_lse_seqlen = ( + softmax_lse_per_step[(cp_size - 1) % 2].shape[-1] + ) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) @@ -1835,63 +1962,6 @@ def forward( max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group ) - second_half_lse_seqlen = None - if causal and rank < (cp_size - 1): - second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] - - # fwd output correction: out in torch.float32 - for i in range(cp_size): - if i <= rank or not causal: - if qkv_format in ["bshd", "sbhd"]: - if i == 0: - out = flash_attn_fwd_out_correction_init( - out_per_step[0], - 
softmax_lse, - softmax_lse_per_step[0], - seq_dim, - ) - if enable_mla: - out = out.view(v_shape) - else: - out = out.view(q.shape) - else: - flash_attn_fwd_out_correction( - out.view(*out_per_step[i].shape), - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - seq_dim, - ) - elif qkv_format == "thd": - tex.thd_out_correction( - out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q_padded, - False, - softmax_lse_in_packed_format, - ) - else: - if qkv_format in ["bshd", "sbhd"]: - flash_attn_fwd_second_half_out_correction( - out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - seq_dim, - ) - elif qkv_format == "thd": - tex.thd_out_correction( - out, - out_per_step[i], - softmax_lse, - softmax_lse_per_step[i], - cu_seqlens_q_padded, - True, - softmax_lse_in_packed_format, - ) - if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] @@ -1952,7 +2022,7 @@ def forward( ctx.fp8 = fp8 and is_bwd_fp8 kv_fp8 = None - kv = p2p_comm_buffers[-1] + kv = p2p_comm_buffers[(cp_size - 1) % 2] if fp8: q_fp8, kv_fp8 = [ Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype)