Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 93 additions & 6 deletions src/maxtext/kernels/ragged/ragged_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,22 @@ def main_kernel(
end_ref: jax.Ref,
in_hbm_ref: jax.Ref,
indices_hbm_ref: jax.Ref,
weights_hbm_ref: jax.Ref,
# Outputs.
out_hbm_ref: jax.Ref,
# Scratch.
start_vmem_ref: jax.Ref,
end_vmem_ref: jax.Ref,
out_vmem_ref: jax.Ref,
indices_vmem_ref: jax.Ref,
weights_vmem_ref: jax.Ref,
sem_ref: jax.Ref,
*,
core_axis_name: str,
subcore_axis_name: str,
has_weights: bool,
):
"""Core ragged gather operation"""
"""Core ragged gather operation with per-row weighting."""
tpu_info = pltpu.get_tpu_info()
sc_info = tpu_info.sparse_core
assert sc_info is not None
Expand Down Expand Up @@ -109,9 +112,16 @@ def _():
indices_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)],
indices_vmem_ref,
)
if has_weights:
pltpu.sync_copy(
weights_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)],
weights_vmem_ref,
)

# HBM to VMEM transfer.
indices = indices_vmem_ref[...]
if has_weights:
weights = weights_vmem_ref[...]

dtype = out_hbm_ref.dtype
dtype_bits = jax.dtypes.itemsize_bits(dtype)
Expand Down Expand Up @@ -189,6 +199,50 @@ def dma_write_loop(col_vmem_start):
row_dst = row_src // packing
out_vmem_ref[row_dst, col_slice] = out

# Apply per-row weights after unpacking if needed.
# For packing == 1 (float32), we can apply directly.
# For packing > 1 (bf16), the data is already packed; we apply below.
if has_weights:
if packing == 1:
# float32 path: data is already in float32 layout, one row per sublane.
for col_compute_offset in range(0, num_lanes, num_simd_lanes):
col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes)
for row_vmem in range(num_simd_lanes):
data = out_vmem_ref[row_vmem, col_slice]
data_f32 = jax.lax.bitcast_convert_type(data, jnp.float32)
data_f32 = data_f32 * weights[row_vmem]
out_vmem_ref[row_vmem, col_slice] = jax.lax.bitcast_convert_type(data_f32, jnp.uint32)
else:
# bf16 path: data is packed, packing=2. Each packed row contains 2
# bf16 values from consecutive source rows. We need to unpack each
# bf16 to float32, multiply by its weight, then repack.
for col_compute_offset in range(0, num_lanes, num_simd_lanes):
col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes)
for row_dst in range(num_simd_lanes // packing):
packed_data = out_vmem_ref[row_dst, col_slice]
result = jnp.zeros_like(packed_data)
for sub in range(packing):
row_src = row_dst * packing + sub
# Extract the sub-element.
shift_right = sub * dtype_bits
shift_left = sub * dtype_bits
elem = jnp.bitwise_right_shift(packed_data, shift_right)
elem = jnp.bitwise_and(elem, 2**dtype_bits - 1)
# Convert bf16 bits to float32 for weighting.
# bf16 is stored in the lower 16 bits; shift to upper 16 for
# bitcast to float32.
elem_f32 = jnp.bitwise_left_shift(elem, 16)
elem_f32 = jax.lax.bitcast_convert_type(elem_f32, jnp.float32)
elem_f32 = elem_f32 * weights[row_src]
# Convert back: bitcast float32 -> uint32, shift right 16 to
# get bf16 bits, then shift left to target position.
elem_u32 = jax.lax.bitcast_convert_type(elem_f32, jnp.uint32)
elem_bf16 = jnp.bitwise_right_shift(elem_u32, 16)
elem_bf16 = jnp.bitwise_and(elem_bf16, 2**dtype_bits - 1)
elem_bf16 = jnp.bitwise_left_shift(elem_bf16, shift_left)
result = jnp.bitwise_or(result, elem_bf16)
out_vmem_ref[row_dst, col_slice] = result

# Start dma write.
for row_vmem in range(num_simd_lanes // packing):
row_hbm = row_tile_start // packing + row_vmem
Expand Down Expand Up @@ -234,9 +288,31 @@ def calculate_col_size(hidden_size: int) -> int:
return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes


@jax.jit
def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.Array) -> jax.Array:
"""Perform gather on indices within dynamic array start and end."""
@functools.partial(jax.jit, static_argnames=("has_weights",))
def ragged_gather(
x: jax.Array,
indices: jax.Array,
start: jax.Array,
end: jax.Array,
weights: jax.Array | None = None,
has_weights: bool = False,
) -> jax.Array:
"""Perform gather on indices within dynamic array start and end.

Args:
x: 2D input array of shape ``(input_size, hidden_size)``.
indices: 1D array of gather indices.
start: Scalar or 1D array indicating the start of the valid range.
end: Scalar or 1D array indicating the end of the valid range.
weights: Optional 1D array of per-row weights. When provided, each
gathered row is multiplied by its corresponding weight inside the
kernel, avoiding an extra HBM read-write pass.
has_weights: Static bool flag indicating whether ``weights`` should be
applied. Must be ``True`` when ``weights`` is not ``None``.

Returns:
Gathered output of shape ``(indices_size, hidden_size)``.
"""

assert x.ndim == 2, "Ragged gather only supports 2d inputs."
assert indices.ndim == 1, "Ragged gather only supports 1d indices."
Expand All @@ -257,7 +333,10 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
sc_info = pltpu.get_tpu_info().sparse_core
if sc_info is None:
# Sparse core is not available. Fallback to regular gather.
return x[indices]
out = x[indices]
if has_weights:
out = out * weights[:, None]
return out

hidden_size = x.shape[-1]
out_size = indices.size
Expand All @@ -271,6 +350,12 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
out_pad_size = pl.cdiv(out_size, block_size) * block_size - out_size
indices = jnp.pad(indices, ((0, out_pad_size)))

if has_weights:
weights = jnp.pad(weights, ((0, out_pad_size)), constant_values=1.0)
else:
# Provide a dummy weights array; the kernel won't use it.
weights = jnp.ones((out_size + out_pad_size,), dtype=jnp.float32)

aligned_hidden_size = pl.cdiv(hidden_size, col_size) * col_size

vector_mesh = plsc.VectorSubcoreMesh(
Expand All @@ -284,6 +369,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
main_kernel,
core_axis_name=vector_mesh.core_axis_name,
subcore_axis_name=vector_mesh.subcore_axis_name,
has_weights=has_weights,
),
compiler_params=pltpu.CompilerParams(
use_tc_tiling_on_sc=True,
Expand All @@ -298,7 +384,8 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A
pltpu.VMEM((num_simd_lanes,), jnp.int32),
pltpu.VMEM((num_simd_lanes, col_size), jnp.uint32),
pltpu.VMEM((num_simd_lanes,), jnp.int32),
pltpu.VMEM((num_simd_lanes,), jnp.float32),
pltpu.SemaphoreType.DMA((2,)),
],
},
)(start, end, x, indices)[:out_size, :hidden_size]
)(start, end, x, indices, weights)[:out_size, :hidden_size]
54 changes: 30 additions & 24 deletions src/maxtext/kernels/ragged/ragged_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _ring_ragged_sort(hidden_states_local, topk_indices_local):
"""Sort and gather activations to different EP shards."""
return _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local)[0]

@jax.named_scope("ragged-gather-fwd")
@jax.named_scope("ragged-sort-fwd")
def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):
"""Sort and gather activations forward pass."""

Expand Down Expand Up @@ -136,7 +136,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local):

return out, res

@jax.named_scope("ragged-gather-bwd")
@jax.named_scope("ragged-sort-bwd")
def _ring_ragged_sort_bwd(res, g_out):
"""Backward pass for the gather: a Pallas SC ragged gather reduce.
The forward gathers ``hidden_states_local[token_indices_sorted[i]]`` into
Expand Down Expand Up @@ -248,7 +248,7 @@ def _ring_ragged_unsort(sorted_tokens_local, group_sizes_local, topk_argsort_rev
sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat
)[0]

@jax.named_scope("ragged-scatter-fwd")
@jax.named_scope("ragged-unsort-fwd")
def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat):
"""Executes unsorting sending tokens back."""
group_offsets = jnp.cumulative_sum(group_sizes_local, include_initial=True)
Expand Down Expand Up @@ -309,7 +309,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort

return out, res

@jax.named_scope("ragged-scatter-bwd")
@jax.named_scope("ragged-unsort-bwd")
def _ring_ragged_unsort_bwd(res, g_out):
"""Backward pass for the scatter with routing weights.

Expand All @@ -330,38 +330,44 @@ def _ring_ragged_unsort_bwd(res, g_out):
) = res
g_hidden_states_local = g_out

# Expand g_out from [num_tokens] to [num_tokens * topk] by repeating each
# row topk times, so that g_expanded[i] = g_out[i // topk].
g_expanded = jnp.repeat(g_hidden_states_local, topk, axis=0)

# Apply per-slot routing weights: g_weighted[i] = w[i] * g_out[i // topk]
g_weighted = g_expanded * topk_weights_flat[:, None]

n = topk_argsort_revert_indices.shape[0]
# Build the inverse permutation idx_inv such that idx_inv[j] = i
# where revert[i] = j.
idx_inv = jnp.argsort(topk_argsort_revert_indices)

# Handle the same two buffering modes for backward pass.
# We let ragged_gather do both the fan-out (by indexing into the
# un-expanded g_hidden_states_local via idx_inv // topk) and the
# per-slot weight application (via the fused weights parameter),
# avoiding an extra HBM read-write pass.
if buffer_size >= n:
# We want: g_sorted_tokens[j] = g_weighted[i] where revert[i]=j.
# Build the inverse permutation idx_inv such that idx_inv[j] = i.
idx_inv = jnp.argsort(topk_argsort_revert_indices)
# Because revert is a permutation, gathering with idx_inv reorders correctly.
# ragged_gather fans out g_hidden_states_local by reading the same row
# multiple times when idx_inv // topk maps multiple positions to it.
# Per-slot routing weights are applied inside the kernel.
weight_for_sorted = topk_weights_flat[idx_inv]
grad_sorted_tokens = ragged_gather(
g_weighted,
idx_inv,
g_hidden_states_local,
idx_inv // topk,
shard_output_start[None],
shard_output_end[None],
weights=weight_for_sorted,
has_weights=True,
)
else:
# Slice the inverse permutation to match the packed local buffer.
idx_inv = jnp.argsort(topk_argsort_revert_indices)
padded_idx_inv = jnp.pad(idx_inv, (0, buffer_size))
sliced_idx_inv = jax.lax.dynamic_slice_in_dim(padded_idx_inv, shard_output_start, buffer_size, axis=0)
gather_end = jnp.minimum(shard_output_end - shard_output_start, buffer_size)
# Slice the per-slot routing weights to match the packed local buffer.
padded_weights = jnp.pad(topk_weights_flat[idx_inv], (0, buffer_size))
sliced_weights = jax.lax.dynamic_slice_in_dim(padded_weights, shard_output_start, buffer_size, axis=0)
grad_sorted_tokens = ragged_gather(
g_weighted,
sliced_idx_inv,
g_hidden_states_local,
sliced_idx_inv // topk,
jnp.int32(0)[None],
gather_end[None],
weights=sliced_weights,
has_weights=True,
)
return grad_sorted_tokens, None, None, None

Expand Down Expand Up @@ -409,7 +415,7 @@ def a2a_ragged_sort(inputs, sort_indices, valid_end):
def _a2a_ragged_sort(inputs, sort_indices, valid_end):
return _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end)[0]

@jax.named_scope("local-ragged-gather-fwd")
@jax.named_scope("local-ragged-sort-fwd")
def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end):
start = jnp.int32(0)
end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end)
Expand All @@ -420,7 +426,7 @@ def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end):
res = (sort_indices, end, inputs.shape)
return out, res

@jax.named_scope("local-ragged-gather-bwd")
@jax.named_scope("local-ragged-sort-bwd")
def _a2a_ragged_sort_bwd(res, g_out):
sort_indices, end, _ = res
n = sort_indices.shape[0]
Expand Down Expand Up @@ -474,7 +480,7 @@ def a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end):
def _a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end):
return _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end)[0]

@jax.named_scope("local-ragged-scatter-fwd")
@jax.named_scope("local-ragged-unsort-fwd")
def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end):
start = jnp.int32(0)
end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end)
Expand All @@ -490,7 +496,7 @@ def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end):
res = (revert_indices, end, sorted_tokens.shape, start)
return out, res

@jax.named_scope("local-ragged-scatter-bwd")
@jax.named_scope("local-ragged-unsort-bwd")
def _a2a_ragged_unsort_bwd(res, g_out):
revert_indices, end, sorted_tokens_shape, start = res
# g_sorted_tokens[revert_indices[i]] = g_out[i] for i in [0, end).
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/moe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def _build_cfg(use_ragged_sort: bool):
enable_checkpointing=False,
model_name="mixtral-8x7b",
override_model_config=True,
base_emb_dim=256,
base_emb_dim=2048, # we want emb dim being multiple of 1024 for fully using the kernel

@gobbleturk gobbleturk Jun 16, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit strange, why does emb dim matter?

base_mlp_dim=256,
base_moe_mlp_dim=256,
dtype="bfloat16",
Expand Down
Loading