[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences #2522
Conversation
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…the seqoffsets calculation API Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
/te-ci jax L0 L1 L2
Greptile Summary

This PR replaces the O(T²) mask-materialisation path used to compute THD seqlens and offsets for BRCM and SWA attention types with two new O(T) helpers (`_get_seqlens_thd` and `_get_seqoffsets_thd`).

Confidence Score: 4/5. Safe to merge for causal+THD workloads; the BRCM/SWA path is refactored correctly for valid inputs but lacks new test coverage. All findings are P2 (dead code, redundant ops, misleading comment). The core algorithmic change is correct.

Important Files Changed: transformer_engine/jax/attention.py
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_segment_ids_pos_to_seqlens_offsets(segment_ids_q, segment_ids_kv, ...)"] --> B{Causal, no SWA?}
    B -- Yes --> C["_segment_ids_pos_to_seqlens_offsets_fast_causal_path()\n[existing O(T) path]"]
    C --> D["_get_seqlens_and_offsets() × 2\n(bincount + argwhere)"]
    D --> E["_fast_causal_adjust_seqlen_and_offsets()\n(trim boundary tokens for cross-attn)"]
    E --> Z["return q_seqlen, kv_seqlen, q_offset, kv_offset"]
    B -- No: BRCM / SWA --> F["NEW: O(T) direct helpers"]
    F --> G["_get_seqlens_thd(segment_ids_q)\nbincount on compacted IDs → [B, max_seg]"]
    F --> H["_get_seqlens_thd(segment_ids_kv)\nsame → [B, max_seg]"]
    F --> I["_get_seqoffsets_thd(segment_ids_q)\nboundary detection → [B, max_seg+1]"]
    F --> J["_get_seqoffsets_thd(segment_ids_kv)\nsame → [B, max_seg+1]"]
    G & H & I & J --> Z
    OLD["OLD BRCM/SWA path (removed)"] --> M["make_attention_mask() × 2 — O(T²) mask"]
    M --> N["apply causal/BRCM mask — O(T²)"]
    N --> O["_mask_to_seqlens_offset() — O(T²)"]
    O -.->|"replaced by"| F
```
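For reference, the removed O(T²) shape can be sketched roughly as follows (a hypothetical NumPy reconstruction, not the exact TE code): the [T, T] boolean mask is the intermediate that dominated memory.

```python
import numpy as np

def mask_based_q_seqlens(segment_ids_q, segment_ids_kv, max_segments):
    # Hypothetical sketch of the old path: O(T*T) intermediate,
    # one boolean per (q, kv) token pair.
    mask = (segment_ids_q[:, None] == segment_ids_kv[None, :]) & (
        segment_ids_q[:, None] != 0
    )
    # A q token is valid if it attends to at least one kv token;
    # count valid tokens per segment id (bucket 0 = padding, dropped).
    valid_q = mask.any(axis=-1)
    ids = np.where(valid_q, segment_ids_q, 0)
    return np.bincount(ids, minlength=max_segments + 1)[1 : max_segments + 1]
```

At T = 128k the `mask` array alone is 128k² bytes (~16 GiB), which is exactly the scratch this PR eliminates.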
```python
# Gather the indices of non-padding tokens per row into a dense prefix;
# slots past the last valid token are filled with -1.
non_zero_mask = segment_ids != 0
max_size = segment_ids.shape[-1]
non_zero_indices = jax.vmap(
    lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)

# Materialise a padding-free view of segment_ids by gathering at
# non_zero_indices. Slots whose index was -1 are explicitly set
# to 0 so they end up in the id=0 bucket (that we drop below).
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
    non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0
)
# Per-row bincount of ids -> segment length, discarding the
# id=0 bucket (padding) and capping at max_segments_per_seq.
seqlens_all = jax.vmap(
    lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:]
)(valid_segment_ids)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
```
Unnecessary compaction step — bincount handles zeros natively
The non_zero_indices gather-and-clip block compacts padding positions to the back before calling bincount, but bincount already places padding tokens (id=0) into bucket 0, which is discarded via [1:]. The compaction produces the same result as calling bincount directly on the original segment_ids (same logic used by _get_seqlens_and_offsets). The extra vmap(jnp.where) + take_along_axis adds XLA ops that aren't needed:
```python
# Equivalent, simpler form:
seqlens_all = jax.vmap(
    lambda row: jnp.bincount(row.astype(jnp.int32), length=max_segments_per_seq + 1)[1:]
)(segment_ids)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
```

```python
# Remove any padded region segment changes (this also handles intra-segment padding correctly)
segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False)
```
Misleading comment contradicts the documented assumption
The comment "this also handles intra-segment padding correctly" conflicts with assumption 2 stated at line 615: "No intra-segment padding, only inter-segment padding allowed." In fact, if intra-segment padding occurs (e.g., [1, 1, 0, 1, 2, 2]), segment_changes_masked will fire on the re-entry position (index 3 in the example), registering a spurious third "segment start" when only two segments exist. The comment should be removed or updated to say the masking step discards padding-region transitions, not that it handles intra-segment padding.
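The reviewer's example can be reproduced in a few lines of NumPy (illustrative only; the shift-and-compare `segment_changes` here is the standard boundary detection, not necessarily the exact TE code):

```python
import numpy as np

segment_ids = np.array([1, 1, 0, 1, 2, 2])

# A boundary "fires" wherever the id differs from its predecessor
# (the first position always counts as a start).
segment_changes = np.concatenate(([True], segment_ids[1:] != segment_ids[:-1]))

# Mask out transitions that land on padding positions (id == 0).
segment_changes_masked = np.where(segment_ids != 0, segment_changes, False)

starts = np.flatnonzero(segment_changes_masked)
# starts == [0, 3, 4]: index 3 is the re-entry after the intra-segment
# padding token and registers a spurious third "segment start", even
# though only two real segments (ids 1 and 2) exist.
```

This is consistent with assumption 2: the masking discards transitions that land on padding, but it cannot undo the extra transition created when a segment resumes after padding.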
Description
The current mechanism in TE JAX attention for calculating the THD seqlens and offsets materializes the full mask and then uses it to calculate the seqlens and seqoffsets. However, this is O(T²) and can result in OOM failures when running larger sequences. This PR moves to a newer O(T) approach, thereby allowing the processing of larger sequences along with some perf advantages.
Benched on a standalone script, the current mask-based approach is O(T²) in both FLOPS and intermediate HBM, while the newer approach is O(T). At T=128k, the new approach needs ~1.5 MiB of scratch vs the current approach's ~16 GiB (a [T, T] boolean mask at T = 128k is 128k² bytes ≈ 16 GiB), roughly 10,000x less, and does ~6,900x fewer FLOPS per call.
NOTE: These tests were run on a standalone implementation to give the user some estimation of the scope of the change. Please refer to the API used for calculating these numbers in the Testing section.
Fixes #2700
Type of change
Changes
Calculate the seqlens and seqoffsets using only the segment ids (and without materializing any mask) for the THD layout. The new algorithm relies on changes in segment ids to find segment breaks.
This code path affects THD fused attn as well as THD CP P2P fused attn.
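A minimal sketch of the idea (hypothetical helper, assuming `segment_ids` is an integer [B, T] array with 0 marking padding and no intra-segment padding; the actual `_get_seqlens_thd` / `_get_seqoffsets_thd` implementations may differ):

```python
import jax
import jax.numpy as jnp

def seqlens_offsets_from_segment_ids(segment_ids, max_segments_per_seq):
    # Per-row segment lengths via bincount; bucket 0 collects padding
    # tokens and is dropped, capping at max_segments_per_seq segments.
    seqlens = jax.vmap(
        lambda row: jnp.bincount(row, length=max_segments_per_seq + 1)[1:]
    )(segment_ids)
    # Segment starts: positions where the id changes and is non-zero.
    shifted = jnp.pad(segment_ids[:, :-1], ((0, 0), (1, 0)), constant_values=-1)
    starts = (segment_ids != shifted) & (segment_ids != 0)
    # Offsets of the first max_segments_per_seq starts per row; -1 where absent.
    offsets = jax.vmap(
        lambda row: jnp.where(row, size=max_segments_per_seq, fill_value=-1)[0]
    )(starts)
    return seqlens, offsets
```

Every step is elementwise or a fixed-length pass over T, so both time and scratch memory stay O(T) per row; no [T, T] mask is ever formed.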
Testing
The bench compiles each helper (current and new) as a standalone jax.jit program on randomized [B, T] segment inputs (T up to 128k, ~16 segments/row with some intra-segment padding) and reads three signals off the compiled object:
- compiled.memory_analysis() (buffer-assignment plan: arg/out/temp/total bytes) and compiled.cost_analysis() (HLO instruction counts) — both are static, post-compile reads that do not execute the kernel

MEMORY SUMMARY (XLA static plan: what is actually allocated post-fusion)
WORK SUMMARY (cost_analysis: per-call work; immune to XLA fusion)
LATENCY SUMMARY (median wall-clock)
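The static reads can be reproduced on any jitted helper along these lines (a sketch; the bench script itself is not part of the PR, and the `cost_analysis` return type varies slightly across JAX versions):

```python
import jax
import jax.numpy as jnp

def seqlens_thd_like(segment_ids):
    # Stand-in for the helper under test (hypothetical).
    return jax.vmap(lambda row: jnp.bincount(row, length=17)[1:])(segment_ids)

x = jnp.zeros((2, 1024), dtype=jnp.int32)
compiled = jax.jit(seqlens_thd_like).lower(x).compile()

# Static, post-compile reads -- neither executes the kernel.
cost = compiled.cost_analysis()        # HLO-level work estimates
cost = cost[0] if isinstance(cost, list) else cost  # older JAX returns a list
mem = compiled.memory_analysis()       # buffer plan; backend-dependent
```

Wall-clock latency, by contrast, requires actually running the compiled function (with `.block_until_ready()`) and taking a median over repeats.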
Checklist: