
[JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences #2522

Open
KshitijLakhani wants to merge 6 commits into NVIDIA:main from
KshitijLakhani:klakhani/mem-optimize-seqlens-offsets-thd

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented Dec 16, 2025

Description

The current mechanism in TE JAX attention for calculating the THD seqlens and offsets materializes the full mask and then uses it to compute the seqlens and seqoffsets. This is O(T²) and can result in OOM failures when running larger sequences. This PR moves to an O(T) approach, thereby allowing the processing of larger sequences along with some performance advantages.

Benchmarked with a standalone script, the current mask-based approach is O(T²) in both FLOPS and intermediate HBM, while the new approach is O(T). At T=128k, the new approach needs ~1.5 MiB of scratch vs the current approach's 16 GiB (~10,000x less) and does ~6,900x fewer FLOPS per call.
NOTE: These tests were run on a standalone implementation to give the reader an estimate of the scope of the change. Please refer to the API used for calculating these numbers in the Testing section.

Fixes #2700

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Calculate the seqlens and seqoffsets using only the segment ids (without materializing any mask) for the THD layout. The new algorithm relies on changes in segment ids to find segment boundaries.
This code path affects THD fused attn as well as THD CP P2P fused attn.
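The boundary-detection idea above can be sketched in NumPy for a single row (names and shapes here are illustrative, not the PR's actual helpers; the real `_get_seqlens_thd`/`_get_seqoffsets_thd` are batched `jax.vmap` versions of the same logic, under the PR's stated assumption of contiguous segments with only inter-segment padding):

```python
import numpy as np

def seqlens_offsets_thd(segment_ids, max_segments):
    # segment_ids: [T] of 1-based segment ids, 0 = padding; segments are
    # contiguous, with only inter-segment padding allowed.
    # Seqlens in O(T): one bincount; the id=0 (padding) bucket is dropped,
    # and empty segment slots are marked with -1.
    seqlens = np.bincount(segment_ids, minlength=max_segments + 1)[1 : max_segments + 1]
    seqlens = np.where(seqlens == 0, -1, seqlens)
    # Offsets in O(T): a position starts a segment when its (non-padding)
    # id differs from its predecessor's id.
    prev = np.concatenate([[0], segment_ids[:-1]])
    starts = (segment_ids != prev) & (segment_ids != 0)
    offsets = np.full(max_segments + 1, -1, dtype=np.int64)
    idx = np.flatnonzero(starts)
    offsets[: idx.size] = idx
    return seqlens, offsets
```

Nothing here materializes a [T, T] mask; scratch space is a handful of length-T vectors, which is where the O(T) memory bound comes from.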

Testing

The bench compiles each helper (current and new) as a standalone jax.jit program on randomized [B, T] segment inputs (T up to 128k, ~16 segments/row with some intra-segment padding) and reads three signals off the compiled object:

  • Memory and FLOPS/bytes come from compiled.memory_analysis() (buffer-assignment plan: arg/out/temp/total bytes) and compiled.cost_analysis() (HLO instruction counts) — both are static, post-compile reads that do not execute the kernel
  • Latency actually runs the compiled kernel (5x warmup + 20x timed, wall-clocked via time.perf_counter() with jax.block_until_ready() to sync the GPU)
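A minimal harness following the three-signal recipe above might look like this (a sketch, not the actual bench script; note that `Compiled.cost_analysis()` returned a list of dicts in older JAX releases, and `memory_analysis()` can return None on the CPU backend):

```python
import time
import jax
import jax.numpy as jnp

def bench(fn, *args, warmup=5, iters=20):
    """Compile fn, read its static analyses, and time the compiled kernel."""
    compiled = jax.jit(fn).lower(*args).compile()

    # Static, post-compile reads -- neither executes the kernel.
    mem = compiled.memory_analysis()   # buffer-assignment plan (arg/out/temp bytes)
    cost = compiled.cost_analysis()    # HLO-derived flops / bytes-accessed estimates

    # Wall-clock latency: warm up, then time with a device sync per call.
    for _ in range(warmup):
        jax.block_until_ready(compiled(*args))
    times = []
    for _ in range(iters):
        t0 = time.perf_counter()
        jax.block_until_ready(compiled(*args))
        times.append(time.perf_counter() - t0)
    times.sort()
    return mem, cost, times[len(times) // 2]  # median latency in seconds
```

Because the memory and FLOP figures come off the compiled object, they reflect what XLA actually plans to allocate post-fusion rather than what the Python-level ops suggest.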

MEMORY SUMMARY (XLA static plan: what is actually allocated post-fusion)


  B        T  max_seg           layout      curr temp      new temp    curr/new     curr total     new total    curr/new
  1     8192       16   mixed_with_pad      2.03 MiB    105.28 KiB     19.77x      2.06 MiB    137.57 KiB     15.36x
  1    16384       16   mixed_with_pad      8.06 MiB    208.53 KiB     39.60x      8.13 MiB    272.82 KiB     30.50x
  1    32768       16   mixed_with_pad     16.13 MiB    385.03 KiB     42.89x     16.25 MiB    513.32 KiB     32.42x
  1    65536       16   mixed_with_pad      4.06 GiB    769.03 KiB   5539.56x      4.06 GiB      1.00 MiB   4155.14x
  1   131072       16   mixed_with_pad     16.13 GiB      1.53 MiB  10762.87x     16.13 GiB      2.03 MiB   8116.52x

WORK SUMMARY (cost_analysis: per-call work; immune to XLA fusion)


  B        T  max_seg           layout     curr flops     new flops    curr/new     curr bytes     new bytes    curr/new
  1     8192       16   mixed_with_pad      537.56 M      976.65 K    550.42x       4.53 MB     509.56 KB      8.89x
  1    16384       16   mixed_with_pad        2.15 G        1.94 M   1109.27x      17.45 MB       1.03 MB     16.93x
  1    32768       16   mixed_with_pad        8.59 G        3.88 M   2214.91x      68.46 MB       2.05 MB     33.46x
  1    65536       16   mixed_with_pad       30.07 G        7.76 M   3875.74x       8.86 GB       4.06 MB   2182.17x
  1   131072       16   mixed_with_pad      120.27 G       17.34 M   6934.07x      34.90 GB      10.25 MB   3405.67x

LATENCY SUMMARY (median wall-clock)


  B        T  max_seg           layout    curr med us    new med us    curr/new 
  1     2048       16   mixed_with_pad        124.86        118.03      1.06x
  1     8192       16   mixed_with_pad        167.68        124.64      1.35x  
  1    32768       16   mixed_with_pad      26941.88        171.38    157.21x
  1    65536       16   mixed_with_pad       5132.06        226.80     22.63x

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani self-assigned this Dec 16, 2025
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 8a7da45 to 1e15b00 on March 18, 2026 22:31
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 1e15b00 to 7c891bd on April 16, 2026 20:42
KshitijLakhani and others added 3 commits April 24, 2026 21:46
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/mem-optimize-seqlens-offsets-thd branch from 7c891bd to 642a0d6 on April 24, 2026 21:47
KshitijLakhani and others added 3 commits April 27, 2026 23:14
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
…the seqoffsets calculation API

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review April 28, 2026 00:11
@KshitijLakhani KshitijLakhani changed the title [JAX] Calculate seqlens and offsets in O(N) space instead of O(N*N) space for THD sequences [JAX] Calculate seqlens and offsets in O(T) space instead of O(T*T) space for THD sequences Apr 28, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Apr 28, 2026

Greptile Summary

This PR replaces the O(T²) mask-materialisation path used to compute THD seqlens and offsets for BRCM and SWA attention types with two new O(T) helpers (_get_seqlens_thd via bincount and _get_seqoffsets_thd via boundary detection), matching the approach already used by the fast causal path. The change is logically correct for valid, monotonically-ordered segment IDs with no intra-segment padding, and the benchmark numbers demonstrate significant memory and FLOP reductions. No new unit tests covering the affected BRCM/SWA THD code path were added (both test checkboxes remain unchecked in the PR).

Confidence Score: 4/5

Safe to merge for causal+THD workloads; the BRCM/SWA path is refactored correctly for valid inputs but lacks new test coverage.

All findings are P2 (dead code, redundant ops, misleading comment). The core algorithmic change is correct — _get_seqlens_thd/_get_seqoffsets_thd produce results equivalent to the old mask-derived approach for inputs satisfying the documented constraints (monotonic segments, kv_len ≥ q_len for BRCM). Score capped at 4 due to P2s and absence of new tests for the changed code path.

transformer_engine/jax/attention.py — specifically the dead _mask_to_seqlens_offset and the BRCM/SWA code path which has no new unit tests.

Important Files Changed

Filename Overview
transformer_engine/jax/attention.py Replaces O(T²) mask-materialisation path for THD BRCM/SWA seqlens+offsets with two O(T) helpers (_get_seqlens_thd, _get_seqoffsets_thd); _mask_to_seqlens_offset is now dead code, and _get_seqlens_thd has a redundant compaction step

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["_segment_ids_pos_to_seqlens_offsets(segment_ids_q, segment_ids_kv, ...)"] --> B{Causal, no SWA?}
    B -- Yes --> C["_segment_ids_pos_to_seqlens_offsets_fast_causal_path()\n[existing O(T) path]"]
    C --> D["_get_seqlens_and_offsets() × 2\n(bincount + argwhere)"]
    D --> E["_fast_causal_adjust_seqlen_and_offsets()\n(trim boundary tokens for cross-attn)"]
    E --> Z["return q_seqlen, kv_seqlen, q_offset, kv_offset"]
    B -- No: BRCM / SWA --> F["NEW: O(T) direct helpers"]
    F --> G["_get_seqlens_thd(segment_ids_q)\nbincount on compacted IDs → [B, max_seg]"]
    F --> H["_get_seqlens_thd(segment_ids_kv)\nsame → [B, max_seg]"]
    F --> I["_get_seqoffsets_thd(segment_ids_q)\nboundary detection → [B, max_seg+1]"]
    F --> J["_get_seqoffsets_thd(segment_ids_kv)\nsame → [B, max_seg+1]"]
    G & H & I & J --> Z
    OLD["OLD BRCM/SWA path (removed)"] --> M["make_attention_mask() × 2 — O(T²) mask"]
    M --> N["apply causal/BRCM mask — O(T²)"]
    N --> O["_mask_to_seqlens_offset() — O(T²)"]
    O -.->|"replaced by"| F

Comments Outside Diff (1)

  1. transformer_engine/jax/attention.py, line 461-467 (link)

    P2 Dead code: _mask_to_seqlens_offset is now unreachable

    This function is no longer called from anywhere in the file after the refactor. _segment_ids_pos_to_seqlens_offsets previously called it for the BRCM/SWA path; that call is now replaced by _get_seqlens_thd / _get_seqoffsets_thd. It should be removed to avoid confusion about whether this path is still exercised.

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile

Comment on lines +551 to +571
# Gather the indices of non-padding tokens per row into a dense prefix;
# slots past the last valid token are filled with -1.
non_zero_mask = segment_ids != 0
max_size = segment_ids.shape[-1]
non_zero_indices = jax.vmap(
    lambda mask_row: jnp.where(mask_row, size=max_size, fill_value=-1)[0]
)(non_zero_mask)

# Materialise a padding-free view of segment_ids by gathering at
# non_zero_indices. Slots whose index was -1 are explicitly set
# to 0 so they end up in the id=0 bucket (that we drop below).
clipped_indices = jnp.clip(non_zero_indices, 0, None)
valid_segment_ids = jnp.where(
    non_zero_indices >= 0, jnp.take_along_axis(segment_ids, clipped_indices, axis=-1), 0
)
# Per-row bincount of ids -> segment length, discarding the
# id=0 bucket (padding) and capping at max_segments_per_seq.
seqlens_all = jax.vmap(
    lambda sp_row: jnp.bincount(sp_row, length=max_segments_per_seq + 1)[1:]
)(valid_segment_ids)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)
P2 Unnecessary compaction step — bincount handles zeros natively

The non_zero_indices gather-and-clip block compacts padding positions to the back before calling bincount, but bincount already places padding tokens (id=0) into bucket 0, which is discarded via [1:]. The compaction produces the same result as calling bincount directly on the original segment_ids (same logic used by _get_seqlens_and_offsets). The extra vmap(jnp.where) + take_along_axis adds XLA ops that aren't needed:

# Equivalent, simpler form:
seqlens_all = jax.vmap(
    lambda row: jnp.bincount(row.astype(jnp.int32), length=max_segments_per_seq + 1)[1:]
)(segment_ids)
seqlens_all_pad_neg = jnp.where(seqlens_all == 0, -1, seqlens_all)

Comment on lines +592 to +593
# Remove any padded region segment changes (this also handles intra-segment padding correctly)
segment_changes_masked = jnp.where(segment_ids != 0, segment_changes, False)

P2 Misleading comment contradicts the documented assumption

The comment "this also handles intra-segment padding correctly" conflicts with assumption 2 stated at line 615: "No intra-segment padding, only inter-segment padding allowed." In fact, if intra-segment padding occurs (e.g., [1, 1, 0, 1, 2, 2]), segment_changes_masked will fire on the re-entry position (index 3 in the example), registering a spurious third "segment start" when only two segments exist. The comment should be removed or updated to say the masking step discards padding-region transitions, not that it handles intra-segment padding.
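The reviewer's failure mode is easy to reproduce on a toy input (NumPy here purely for illustration of the masked-transition logic quoted above):

```python
import numpy as np

# Intra-segment padding: segment 1 is split by a padding token at index 2,
# so only two real segments exist.
segment_ids = np.array([1, 1, 0, 1, 2, 2])

# Transition-based boundary detection, as in the quoted snippet.
prev = np.concatenate([[0], segment_ids[:-1]])
segment_changes = segment_ids != prev
# Masking only discards transitions landing *on* padding positions...
segment_changes_masked = np.where(segment_ids != 0, segment_changes, False)

# ...so re-entering segment 1 at index 3 still registers as a "start".
starts = np.flatnonzero(segment_changes_masked)
print(starts)  # [0 3 4] -- three detected starts for two segments
```

This supports the suggested comment fix: the masking step discards padding-region transitions, but does not make the detection robust to intra-segment padding.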

@KshitijLakhani KshitijLakhani added 2.15.0 performance Performance issues labels Apr 28, 2026