Correctly pad scaling factor inverses to satisfy cuteDSL requirements #2924

ksivaman wants to merge 8 commits into NVIDIA:main from
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Greptile Summary

This PR fixes grouped MXFP8 swizzle when per-expert rows are not a multiple of 128. The core change introduces a "compact" vs "per-tensor-padded" layout distinction: the quantize kernel writes a compact buffer (no padding between experts), while the swizzle output must be padded per expert to a multiple of (128, 4).

Confidence Score: 5/5

Safe to merge; no P0/P1 issues found; logic is correct and well-tested across edge cases. All findings are P2 or below. The compact-layout detection, OOB-load prevention, and output-buffer allocation are logically correct and consistent between swizzle.cu and swizzle.cpp. The test suite covers aligned, unaligned, and mixed shapes, including the originally failing workload shape. No files require special attention.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["maybe_swizzle_grouped_tensor (swizzle.cpp)"]
    A -->|allocate output| B["compute_padded_grouped_scale_shape\nnum_tensors × roundup(M,128) × roundup(⌈K/32⌉,4)"]
    A --> C["nvte_swizzle_grouped_scaling_factors (swizzle.cu)"]
    C --> D{Detect input layout}
    D -->|numel == num_tensors × padded_scale_elems| E["input_is_compact = false\ninput_stride = padded_m × padded_k"]
    D -->|numel == compact_total_scale_elems| F["input_is_compact = true\ninput_stride = m × padded_k (rowwise)\nor ⌈M/32⌉ × padded_m (colwise)"]
    D -->|mismatch| G[NVTE_ERROR]
    E --> H[dispatch_swizzle_*_kernel_impl]
    F --> H
    H -->|IS_PADDED_M=true, row ≥ original_M| I["Zero register, skip __ldg"]
    H -->|IS_PADDED_K=true, k_coord ≥ original_K| J["Zero register, skip __ldg"]
    H -->|in-bounds| K["__ldg + per-byte boundary zeroing"]
    I --> L["Output: per-tensor padded layout\noutput_stride = padded_m × padded_k"]
    J --> L
    K --> L
```
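The allocation step in the flowchart pads each expert's scale block to `roundup(M,128) × roundup(⌈K/32⌉,4)`. A minimal sketch of that rule follows; `round_up` and `padded_grouped_scale_shape` are illustrative stand-ins, not the actual helper signatures in swizzle.cpp:

```cpp
#include <cstddef>

// Round x up to the nearest multiple of m.
static size_t round_up(size_t x, size_t m) { return (x + m - 1) / m * m; }

struct ScaleShape { size_t rows, cols; };

// Per-tensor padded MXFP8 rowwise scale shape: rows padded to a multiple
// of 128; scale columns (one scale per 32 elements of K) padded to a
// multiple of 4.
static ScaleShape padded_grouped_scale_shape(size_t m, size_t k) {
  const size_t scale_cols = (k + 31) / 32;  // ceil(K / 32)
  return {round_up(m, 128), round_up(scale_cols, 4)};
}
```

For example, an expert with 130 rows and K = 96 gets a 256 × 4 padded scale block even though only 130 × 3 entries are meaningful.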
Reviews (3): Last reviewed commit: "Add test for swizzle + padding fusion"
```cpp
const auto logical_shape_nvte = input.logical_shape();
NVTE_CHECK(logical_shape_nvte.ndim >= 2,
           "Grouped GEMM swizzle expects logical_shape with ndim >= 2.");
const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
```
**Silent truncation when `logical_shape_nvte.data[0]` is not divisible by `num_tensors`**

`per_tensor_first_dim` is computed with plain integer division. If `logical_shape_nvte.data[0]` is not an exact multiple of `num_tensors` (e.g. due to a caller bug or an unexpected grouped layout), the result is silently truncated, causing `padded_m` to be underestimated and the output buffer to be too small. A divisibility assertion would catch this much earlier with a clear error message.
Suggested change:

```diff
 const size_t per_tensor_first_dim = logical_shape_nvte.data[0] / num_tensors;
+NVTE_CHECK(logical_shape_nvte.data[0] % num_tensors == 0,
+           "Grouped GEMM swizzle expects logical_shape first dim to be divisible by num_tensors.");
```
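To make the failure mode concrete, here is a minimal sketch with made-up shapes (130 grouped rows across 4 tensors; the helper name is hypothetical):

```cpp
#include <cstddef>

// Plain integer division silently truncates when first_dim is not an
// exact multiple of num_tensors, e.g. 130 / 4 == 32, and 32 * 4 == 128
// leaves two rows of scales unaccounted for in the output buffer.
static size_t per_tensor_rows(size_t first_dim, size_t num_tensors) {
  return first_dim / num_tensors;
}
```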
```diff
 bool input_is_compact;
 if (input_scale_numel == input->num_tensors * padded_scale_elems) {
   input_is_compact = false;
 } else if (input_scale_numel == compact_total_scale_elems) {
   input_is_compact = true;
 } else {
-  NVTE_CHECK(input->columnwise_scale_inv.numel() == input->num_tensors * scale_elems,
-             "Grouped input columnwise_scale_inv size does not match expected packed size.");
-  NVTE_CHECK(output->columnwise_scale_inv.numel() == output->num_tensors * scale_elems,
-             "Grouped output columnwise_scale_inv size does not match expected packed size.");
+  NVTE_ERROR("Grouped input ", (rowwise ? "scale_inv" : "columnwise_scale_inv"),
+             " size does not match expected packed size (got ", input_scale_numel,
+             ", expected either ", input->num_tensors * padded_scale_elems,
+             " (per-tensor padded) or ", compact_total_scale_elems, " (compact)).");
 }
```
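The two buffer sizes this detection accepts can be modeled as below. This is a simplified sketch: the variable names mirror the snippet, but the exact formulas (in particular the 128-row alignment slack on the compact total) are assumptions based on the surrounding discussion:

```cpp
#include <cstddef>

static size_t round_up(size_t x, size_t m) { return (x + m - 1) / m * m; }

// "Per-tensor padded": every tensor occupies a full padded_m * padded_k block.
static size_t per_tensor_padded_numel(size_t num_tensors, size_t padded_m,
                                      size_t padded_k) {
  return num_tensors * padded_m * padded_k;
}

// "Compact": rows packed back-to-back with no per-tensor padding; only the
// total first dim carries alignment slack (128 rowwise).
static size_t compact_total_numel(size_t num_tensors, size_t m, size_t padded_k) {
  return round_up(num_tensors * m, 128) * padded_k;
}
```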
**Implicit contract on compact-buffer alignment is not validated**

The `compact_total_scale_elems` formula assumes the upstream quantize kernel allocates the compact scale buffer with its total first dim rounded up to 128 (rowwise) or 4 (colwise). If a caller passes a "plain compact" buffer of size exactly `num_tensors * m * padded_k` (without trailing alignment slack), neither branch matches and `NVTE_ERROR` fires with a size-mismatch message that may be hard to diagnose.

Consider also accepting `num_tensors * compact_scale_elems` as a valid compact size, or documenting this alignment requirement in the error message.
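A hedged numeric illustration of the mismatch, using made-up shapes and assuming the size formulas discussed above (4 tensors of 200 rows, padded_k = 4 scale columns):

```cpp
#include <cstddef>

static size_t round_up(size_t x, size_t m) { return (x + m - 1) / m * m; }

// "Plain compact": no trailing alignment slack at all.
static size_t plain_compact(size_t n, size_t m, size_t pk) { return n * m * pk; }
// Compact with the 128-row total alignment the detection expects.
static size_t aligned_compact(size_t n, size_t m, size_t pk) {
  return round_up(n * m, 128) * pk;
}
// Per-tensor padded: each tensor padded to a multiple of 128 rows.
static size_t per_tensor_padded(size_t n, size_t m, size_t pk) {
  return n * round_up(m, 128) * pk;
}
```

For n = 4, m = 200, pk = 4 the plain-compact size matches neither accepted size, so the detection falls through to the error branch.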
@ksivaman Could you add a test exercising the change?
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
Description
Fix grouped MXFP8 swizzle when per-expert rows are not a multiple of 128, padding each expert's scaling-factor inverses to a multiple of (128, 4).
Type of change
Changes
Checklist: