From 9172c065a9704036560c4cbaf338352eda0a29cd Mon Sep 17 00:00:00 2001 From: NuojCheng Date: Mon, 15 Jun 2026 19:25:59 +0000 Subject: [PATCH] add fan out for scatter bwd --- src/maxtext/kernels/ragged/ragged_gather.py | 99 +++++++++++++++++++-- src/maxtext/kernels/ragged/ragged_sort.py | 54 ++++++----- tests/unit/moe_test.py | 2 +- 3 files changed, 124 insertions(+), 31 deletions(-) diff --git a/src/maxtext/kernels/ragged/ragged_gather.py b/src/maxtext/kernels/ragged/ragged_gather.py index 483bbd2086..5cb18dba83 100644 --- a/src/maxtext/kernels/ragged/ragged_gather.py +++ b/src/maxtext/kernels/ragged/ragged_gather.py @@ -41,6 +41,7 @@ def main_kernel( end_ref: jax.Ref, in_hbm_ref: jax.Ref, indices_hbm_ref: jax.Ref, + weights_hbm_ref: jax.Ref, # Outputs. out_hbm_ref: jax.Ref, # Scratch. @@ -48,12 +49,14 @@ def main_kernel( end_vmem_ref: jax.Ref, out_vmem_ref: jax.Ref, indices_vmem_ref: jax.Ref, + weights_vmem_ref: jax.Ref, sem_ref: jax.Ref, *, core_axis_name: str, subcore_axis_name: str, + has_weights: bool, ): - """Core ragged gather operation""" + """Core ragged gather operation with per-row weighting.""" tpu_info = pltpu.get_tpu_info() sc_info = tpu_info.sparse_core assert sc_info is not None @@ -109,9 +112,16 @@ def _(): indices_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)], indices_vmem_ref, ) + if has_weights: + pltpu.sync_copy( + weights_hbm_ref.at[pl.ds(row_tile_start, num_simd_lanes)], + weights_vmem_ref, + ) # HBM to VMEM transfer. indices = indices_vmem_ref[...] + if has_weights: + weights = weights_vmem_ref[...] dtype = out_hbm_ref.dtype dtype_bits = jax.dtypes.itemsize_bits(dtype) @@ -189,6 +199,50 @@ def dma_write_loop(col_vmem_start): row_dst = row_src // packing out_vmem_ref[row_dst, col_slice] = out + # Apply per-row weights after unpacking if needed. + # For packing == 1 (float32), we can apply directly. + # For packing > 1 (bf16), the data is already packed; we apply below. + if has_weights: + if packing == 1: + # float32 path: data is already in float32 layout, one row per sublane. + for col_compute_offset in range(0, num_lanes, num_simd_lanes): + col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes) + for row_vmem in range(num_simd_lanes): + data = out_vmem_ref[row_vmem, col_slice] + data_f32 = jax.lax.bitcast_convert_type(data, jnp.float32) + data_f32 = data_f32 * weights[row_vmem] + out_vmem_ref[row_vmem, col_slice] = jax.lax.bitcast_convert_type(data_f32, jnp.uint32) + else: + # bf16 path: data is packed, packing=2. Each packed row contains 2 + # bf16 values from consecutive source rows. We need to unpack each + # bf16 to float32, multiply by its weight, then repack. + for col_compute_offset in range(0, num_lanes, num_simd_lanes): + col_slice = pl.ds(col_vmem_start + col_compute_offset, num_simd_lanes) + for row_dst in range(num_simd_lanes // packing): + packed_data = out_vmem_ref[row_dst, col_slice] + result = jnp.zeros_like(packed_data) + for sub in range(packing): + row_src = row_dst * packing + sub + # Extract the sub-element. + shift_right = sub * dtype_bits + shift_left = sub * dtype_bits + elem = jnp.bitwise_right_shift(packed_data, shift_right) + elem = jnp.bitwise_and(elem, 2**dtype_bits - 1) + # Convert bf16 bits to float32 for weighting. + # bf16 is stored in the lower 16 bits; shift to upper 16 for + # bitcast to float32. + elem_f32 = jnp.bitwise_left_shift(elem, 16) + elem_f32 = jax.lax.bitcast_convert_type(elem_f32, jnp.float32) + elem_f32 = elem_f32 * weights[row_src] + # Convert back: bitcast float32 -> uint32, shift right 16 to + # get bf16 bits, then shift left to target position. + elem_u32 = jax.lax.bitcast_convert_type(elem_f32, jnp.uint32) + elem_bf16 = jnp.bitwise_right_shift(elem_u32, 16) + elem_bf16 = jnp.bitwise_and(elem_bf16, 2**dtype_bits - 1) + elem_bf16 = jnp.bitwise_left_shift(elem_bf16, shift_left) + result = jnp.bitwise_or(result, elem_bf16) + out_vmem_ref[row_dst, col_slice] = result + # Start dma write. for row_vmem in range(num_simd_lanes // packing): row_hbm = row_tile_start // packing + row_vmem @@ -234,9 +288,31 @@ def calculate_col_size(hidden_size: int) -> int: return pl.cdiv(hidden_size, (num_cols * num_lanes)) * num_lanes -@jax.jit -def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.Array) -> jax.Array: - """Perform gather on indices within dynamic array start and end.""" +@functools.partial(jax.jit, static_argnames=("has_weights",)) +def ragged_gather( + x: jax.Array, + indices: jax.Array, + start: jax.Array, + end: jax.Array, + weights: jax.Array | None = None, + has_weights: bool = False, +) -> jax.Array: + """Perform gather on indices within dynamic array start and end. + + Args: + x: 2D input array of shape ``(input_size, hidden_size)``. + indices: 1D array of gather indices. + start: Scalar or 1D array indicating the start of the valid range. + end: Scalar or 1D array indicating the end of the valid range. + weights: Optional 1D array of per-row weights. When provided, each + gathered row is multiplied by its corresponding weight inside the + kernel, avoiding an extra HBM read-write pass. + has_weights: Static bool flag indicating whether ``weights`` should be + applied. Must be ``True`` when ``weights`` is not ``None``. + + Returns: + Gathered output of shape ``(indices_size, hidden_size)``. + """ assert x.ndim == 2, "Ragged gather only supports 2d inputs." assert indices.ndim == 1, "Ragged gather only supports 1d indices." @@ -257,7 +333,10 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A sc_info = pltpu.get_tpu_info().sparse_core if sc_info is None: # Sparse core is not available. Fallback to regular gather. - return x[indices] + out = x[indices] + if has_weights: + out = out * weights[:, None] + return out hidden_size = x.shape[-1] out_size = indices.size @@ -271,6 +350,12 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A out_pad_size = pl.cdiv(out_size, block_size) * block_size - out_size indices = jnp.pad(indices, ((0, out_pad_size))) + if has_weights: + weights = jnp.pad(weights, ((0, out_pad_size)), constant_values=1.0) + else: + # Provide a dummy weights array; the kernel won't use it. + weights = jnp.ones((out_size + out_pad_size,), dtype=jnp.float32) + aligned_hidden_size = pl.cdiv(hidden_size, col_size) * col_size vector_mesh = plsc.VectorSubcoreMesh( @@ -284,6 +369,7 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A main_kernel, core_axis_name=vector_mesh.core_axis_name, subcore_axis_name=vector_mesh.subcore_axis_name, + has_weights=has_weights, ), compiler_params=pltpu.CompilerParams( use_tc_tiling_on_sc=True, @@ -298,7 +384,8 @@ def ragged_gather(x: jax.Array, indices: jax.Array, start: jax.Array, end: jax.A pltpu.VMEM((num_simd_lanes,), jnp.int32), pltpu.VMEM((num_simd_lanes, col_size), jnp.uint32), pltpu.VMEM((num_simd_lanes,), jnp.int32), + pltpu.VMEM((num_simd_lanes,), jnp.float32), pltpu.SemaphoreType.DMA((2,)), ], }, - )(start, end, x, indices)[:out_size, :hidden_size] + )(start, end, x, indices, weights)[:out_size, :hidden_size] diff --git a/src/maxtext/kernels/ragged/ragged_sort.py b/src/maxtext/kernels/ragged/ragged_sort.py index b88dd7f7f6..6f1ccc64f6 100644 --- a/src/maxtext/kernels/ragged/ragged_sort.py +++ b/src/maxtext/kernels/ragged/ragged_sort.py @@ -71,7 +71,7 @@ def _ring_ragged_sort(hidden_states_local, topk_indices_local): """Sort and gather activations to different EP shards.""" return _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local)[0] - @jax.named_scope("ragged-gather-fwd") + @jax.named_scope("ragged-sort-fwd") def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): """Sort and gather activations forward pass.""" @@ -136,7 +136,7 @@ def _ring_ragged_sort_fwd(hidden_states_local, topk_indices_local): return out, res - @jax.named_scope("ragged-gather-bwd") + @jax.named_scope("ragged-sort-bwd") def _ring_ragged_sort_bwd(res, g_out): """Backward pass for the gather: a Pallas SC ragged gather reduce. The forward gathers ``hidden_states_local[token_indices_sorted[i]]`` into @@ -248,7 +248,7 @@ def _ring_ragged_unsort(sorted_tokens_local, group_sizes_local, topk_argsort_rev sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat )[0] - @jax.named_scope("ragged-scatter-fwd") + @jax.named_scope("ragged-unsort-fwd") def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort_revert_indices, topk_weights_flat): """Executes unsorting sending tokens back.""" group_offsets = jnp.cumulative_sum(group_sizes_local, include_initial=True) @@ -309,7 +309,7 @@ def _ring_ragged_unsort_fwd(sorted_tokens_local, group_sizes_local, topk_argsort return out, res - @jax.named_scope("ragged-scatter-bwd") + @jax.named_scope("ragged-unsort-bwd") def _ring_ragged_unsort_bwd(res, g_out): """Backward pass for the scatter with routing weights. @@ -330,38 +330,44 @@ def _ring_ragged_unsort_bwd(res, g_out): ) = res g_hidden_states_local = g_out - # Expand g_out from [num_tokens] to [num_tokens * topk] by repeating each - # row topk times, so that g_expanded[i] = g_out[i // topk]. - g_expanded = jnp.repeat(g_hidden_states_local, topk, axis=0) - - # Apply per-slot routing weights: g_weighted[i] = w[i] * g_out[i // topk] - g_weighted = g_expanded * topk_weights_flat[:, None] - n = topk_argsort_revert_indices.shape[0] + # Build the inverse permutation idx_inv such that idx_inv[j] = i + # where revert[i] = j. + idx_inv = jnp.argsort(topk_argsort_revert_indices) # Handle the same two buffering modes for backward pass. + # We let ragged_gather do both the fan-out (by indexing into the + # un-expanded g_hidden_states_local via idx_inv // topk) and the + # per-slot weight application (via the fused weights parameter), + # avoiding an extra HBM read-write pass. if buffer_size >= n: - # We want: g_sorted_tokens[j] = g_weighted[i] where revert[i]=j. - # Build the inverse permutation idx_inv such that idx_inv[j] = i. - idx_inv = jnp.argsort(topk_argsort_revert_indices) - # Because revert is a permutation, gathering with idx_inv reorders correctly. + # ragged_gather fans out g_hidden_states_local by reading the same row + # multiple times when idx_inv // topk maps multiple positions to it. + # Per-slot routing weights are applied inside the kernel. + weight_for_sorted = topk_weights_flat[idx_inv] grad_sorted_tokens = ragged_gather( - g_weighted, - idx_inv, + g_hidden_states_local, + idx_inv // topk, shard_output_start[None], shard_output_end[None], + weights=weight_for_sorted, + has_weights=True, ) else: # Slice the inverse permutation to match the packed local buffer. - idx_inv = jnp.argsort(topk_argsort_revert_indices) padded_idx_inv = jnp.pad(idx_inv, (0, buffer_size)) sliced_idx_inv = jax.lax.dynamic_slice_in_dim(padded_idx_inv, shard_output_start, buffer_size, axis=0) gather_end = jnp.minimum(shard_output_end - shard_output_start, buffer_size) + # Slice the per-slot routing weights to match the packed local buffer. + padded_weights = jnp.pad(topk_weights_flat[idx_inv], (0, buffer_size)) + sliced_weights = jax.lax.dynamic_slice_in_dim(padded_weights, shard_output_start, buffer_size, axis=0) grad_sorted_tokens = ragged_gather( - g_weighted, - sliced_idx_inv, + g_hidden_states_local, + sliced_idx_inv // topk, jnp.int32(0)[None], gather_end[None], + weights=sliced_weights, + has_weights=True, ) return grad_sorted_tokens, None, None, None @@ -409,7 +415,7 @@ def a2a_ragged_sort(inputs, sort_indices, valid_end): def _a2a_ragged_sort(inputs, sort_indices, valid_end): return _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end)[0] - @jax.named_scope("local-ragged-gather-fwd") + @jax.named_scope("local-ragged-sort-fwd") def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end): start = jnp.int32(0) end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end) @@ -420,7 +426,7 @@ def _a2a_ragged_sort_fwd(inputs, sort_indices, valid_end): res = (sort_indices, end, inputs.shape) return out, res - @jax.named_scope("local-ragged-gather-bwd") + @jax.named_scope("local-ragged-sort-bwd") def _a2a_ragged_sort_bwd(res, g_out): sort_indices, end, _ = res n = sort_indices.shape[0] @@ -474,7 +480,7 @@ def a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end): def _a2a_ragged_unsort(sorted_tokens, revert_indices, valid_end): return _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end)[0] - @jax.named_scope("local-ragged-scatter-fwd") + @jax.named_scope("local-ragged-unsort-fwd") def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end): start = jnp.int32(0) end = valid_end.astype(jnp.int32) if hasattr(valid_end, "astype") else jnp.int32(valid_end) @@ -490,7 +496,7 @@ def _a2a_ragged_unsort_fwd(sorted_tokens, revert_indices, valid_end): res = (revert_indices, end, sorted_tokens.shape, start) return out, res - @jax.named_scope("local-ragged-scatter-bwd") + @jax.named_scope("local-ragged-unsort-bwd") def _a2a_ragged_unsort_bwd(res, g_out): revert_indices, end, sorted_tokens_shape, start = res # g_sorted_tokens[revert_indices[i]] = g_out[i] for i in [0, end). diff --git a/tests/unit/moe_test.py b/tests/unit/moe_test.py index 51b7d1ba3e..18c675948d 100644 --- a/tests/unit/moe_test.py +++ b/tests/unit/moe_test.py @@ -618,7 +618,7 @@ def _build_cfg(use_ragged_sort: bool): enable_checkpointing=False, model_name="mixtral-8x7b", override_model_config=True, - base_emb_dim=256, + base_emb_dim=2048, # we want emb dim being multiple of 1024 for fully using the kernel base_mlp_dim=256, base_moe_mlp_dim=256, dtype="bfloat16",