
[Common/PyTorch] Add MXFP8 cast-and-transpose op #2930

Open

jeweldave wants to merge 3 commits into NVIDIA:main from jeweldave:feat/mxfp8-transpose-cast

Conversation

@jeweldave

Summary

Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose. The op surfaces in
three layers:

  • C API (additive, ABI-safe):
    • nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
      output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
      signature, E4M3 output, non-swizzled scales.
    • nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
      with_gemm_swizzled_scales, stream) — extended signature.
  • PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
    (default kwargs map to the minimal C symbol's behavior).
  • Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
    columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
    returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.

No existing C symbol, Python signature, or default behavior is changed.
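
For concreteness, here is a minimal usage sketch of the new Python surface. It is
illustrative only: the quantizer constructor arguments, the tensor shapes, and the
_columnwise_scale_inv attribute name are assumptions about the existing
MXFP8Quantizer/MXFP8Tensor classes, not something this PR defines.

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

# High-precision source; dims must be MXFP8-block-aligned (see the error-path tests).
x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")

# Existing column-wise quantization path. Constructor arguments are assumptions
# about the current MXFP8Quantizer API, used only to set up the example.
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=False, columnwise=True)
x_mxfp8 = quantizer.quantize(x)

# New helper from this PR: emit row-wise compact MXFP8 storage for x.T directly,
# reusing the compact column-wise E8M0 scales rather than recomputing them or
# copying the column-wise payload bytes.
xt_mxfp8 = quantizer.quantize_rowwise_transpose(
    x,
    x_mxfp8._columnwise_scale_inv,  # attribute name is an assumption
)
assert xt_mxfp8.shape == (4096, 256)  # logical shape is x.T
```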

Why

The standard MXFP8Quantizer path can already produce row-wise and
column-wise MXFP8 from BF16/FP16/FP32 input. There is currently no public TE
path that, given X and its compact column-wise scales S_col(X), produces
the row-wise compact MXFP8 storage for the logical transpose X.T without
either:

  • emitting it from BF16 in a separate pass through the standard quantizer
    (re-reads BF16 source), or
  • copying the existing column-wise MXFP8 payload and column-wise scales into
    transposed row-wise storage (extra payload+scale byte traffic).

This op closes that gap. It is the building block needed to route MXFP8
backward through TN GEMMs on hardware where cuBLASLt does not currently
support MXFP8 backward NN/NT GEMM layouts (NVIDIA Spark / sm_12.1). On
hardware where backward MXFP8 NN/NT is supported (B200 sm_10.0, H100
sm_9.0) it is unused by default; downstream code can still call it for any
path that wants direct transposed-rowwise MXFP8 emission without a payload
copy.

Detailed motivation, measurements, and out-of-scope notes are in
docs/motivation.md of the proposal directory.

What's in the change

  • transformer_engine/common/include/transformer_engine/recipe.h — declare
    nvte_mxfp8_scaling_transpose_cast and _v2.
  • transformer_engine/common/recipe/mxfp8_scaling.cu — two new kernels
    (mxfp8_scaling_transpose_cast_kernel for the FP8 payload tile transpose,
    mxfp8_scaling_transpose_scales_kernel for compact-or-swizzled scale
    transpose), one new C++ entry point, two new C symbols.
  • transformer_engine/pytorch/csrc/extensions.h,
    extensions/fp8_partial_cast.cpp, and extensions/pybind.cpp — new
    mxfp8_scaling_transpose_cast PyTorch binding routed through _v2.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py — new
    MXFP8Quantizer.quantize_rowwise_transpose helper.

Numerics

For an input tensor X quantized with the standard MXFP8 column-wise path,
the new op's output is bit-for-bit equal to taking the existing column-wise
MXFP8 payload + scales and transposing those bytes. Confirmed on GB10 in the
existing cppmega probe at (M, N, K) = (64, 96, 128) and (256, 4096, 4096):
max_payload_abs_byte_delta == 0, payload_equal == True,
scale_equal == True.

The included tests/test_mxfp8_scaling_transpose_cast.py exercises this
equivalence both via the raw extension call and via the
quantize_rowwise_transpose helper.
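
A minimal sketch of what that byte-level comparison looks like; the MXFP8Tensor
attribute names and the handling of scale padding below are assumptions, not
code quoted from the test.

```python
import torch

def check_transpose_cast_byte_equivalence(xt_mxfp8, x_mxfp8):
    """Compare the new op's output for x.T against the copy-based reference.

    xt_mxfp8: result of quantize_rowwise_transpose(x, ...).
    x_mxfp8:  x quantized via the standard column-wise MXFP8 path.
    """
    # Reference payload: transpose the existing column-wise MXFP8 payload bytes.
    payload_ref = x_mxfp8._columnwise_data.view(torch.uint8).t().contiguous()
    payload_new = xt_mxfp8._rowwise_data.view(torch.uint8)
    assert torch.equal(payload_new, payload_ref)   # payload_equal == True

    # Reference scales: transpose the existing compact column-wise E8M0 scales.
    scales_ref = x_mxfp8._columnwise_scale_inv.t().contiguous()
    scales_new = xt_mxfp8._rowwise_scale_inv
    assert torch.equal(scales_new, scales_ref)     # scale_equal == True
```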

Tests

Drop-in pytest files intended for tests/pytorch/ (this PR places them in
tests/):

  • test_mxfp8_scaling_transpose_cast.py:
    • byte equivalence vs. column-wise-then-copy reference, multiple shapes,
      E4M3 and E5M2;
    • Python helper equivalence;
    • decoded-value reconstruction is within MXFP8 quantization tolerance of
      the native re-quantized transpose;
    • error path: high-precision input is required (FP8 input rejected);
    • error path: source dims must be MXFP8-block-aligned.
  • test_mxfp8_scaling_transpose_cast_swizzled.py:
    • with with_gemm_swizzled_scales=True, emitted scales match the bytes
      produced by the standard MXFP8Quantizer.quantize swizzled path on the
      actual transposed source.

All tests gate on CUDA being present and the new extension symbol being
built into the loaded transformer_engine_torch module.
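
The gating can be expressed with ordinary pytest skip markers, roughly as
below; the skip reasons are illustrative.

```python
import pytest
import torch
import transformer_engine_torch as tex

# Skip the whole module unless CUDA is available and the extension was built
# with the new symbol.
pytestmark = [
    pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required"),
    pytest.mark.skipif(
        not hasattr(tex, "mxfp8_scaling_transpose_cast"),
        reason="transformer_engine_torch built without mxfp8_scaling_transpose_cast",
    ),
]
```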

Compatibility

  • C API: additive only. The unsuffixed symbol carries the minimal, long-term
    stable signature; _v2 carries the extra knobs. No existing symbol's
    signature is changed.
  • Python: additive only. New method on MXFP8Quantizer; no existing method
    signature or default behavior is changed.
  • Build system: no new build flags or files; only changes existing files
    that already participate in the MXFP8 recipe build.

Out of scope (intentionally not in this PR)

  • Wiring TE Linear backward to use this op on GB10. That depends on
    cuBLASLt behavior we cannot upstream and on a downstream backward-rewrite
    shim.
  • Changing default behavior of quantize / quantize_rowwise.
  • Any change to cuBLASLt transposed operand consumption.
  • Dequantize support for with_gemm_swizzled_scales=True MXFP8 tensors. By
    design, TE rejects this in cast/mxfp8/dequantize_mxfp8.cuh (and the
    symmetric paths in cast/nvfp4/dequantize_nvfp4.cuh and
    cast/mxfp8/group_dequantize_mxfp8.cuh) with "Input must have scales in
    compact format": swizzled scales are a one-way GEMM-operand layout, and the
    dequantize kernels do not carry an inverse unswizzle_scale_idx path. Our
    swizzled-scale test (test_mxfp8_scaling_transpose_cast_swizzled.py)
    therefore compares the emitted row-wise payload and scale bytes
    byte-for-byte against the standard MXFP8Quantizer.quantize(...,
    with_gemm_swizzled_scales=True) output rather than comparing decoded
    values, since both paths target the same GEMM-ready layout for the same
    logical row-wise tensor (see the sketch below).
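
A rough illustration of that byte-for-byte comparison; the quantizer
construction, the placement of the with_gemm_swizzled_scales keyword, and the
MXFP8Tensor attribute names are assumptions, not code quoted from the test.

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

x = torch.randn(256, 4096, dtype=torch.bfloat16, device="cuda")
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)

# Standard path: re-quantize the actual transposed source with swizzled scales.
ref = quantizer.quantize(x.t().contiguous(), with_gemm_swizzled_scales=True)

# New path: transpose-cast directly from x plus its compact column-wise scales.
x_mxfp8 = quantizer.quantize(x)
out = quantizer.quantize_rowwise_transpose(
    x, x_mxfp8._columnwise_scale_inv, with_gemm_swizzled_scales=True
)

# Both target the same GEMM-ready layout for the same logical row-wise tensor,
# so compare the emitted payload and scale bytes directly.
assert torch.equal(out._rowwise_data.view(torch.uint8), ref._rowwise_data.view(torch.uint8))
assert torch.equal(out._rowwise_scale_inv.view(torch.uint8), ref._rowwise_scale_inv.view(torch.uint8))
```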

apstenku123 and others added 2 commits April 26, 2026 23:28
Add a fused MXFP8 cast-and-transpose op that takes a high-precision tensor
plus the source's existing compact column-wise E8M0 scales and emits row-wise
compact MXFP8 storage for the source's logical transpose.

The standard MXFP8Quantizer path can already produce row-wise and column-wise
MXFP8 from BF16/FP16/FP32 input. There is currently no public TE path that,
given X and its compact column-wise scales S_col(X), produces the row-wise
compact MXFP8 storage for the logical transpose X.T without either re-reading
the BF16 source or copying the existing column-wise MXFP8 payload and scales
into transposed row-wise storage. This op closes that gap. It is the building
block needed to route MXFP8 backward through TN GEMMs on hardware where
cuBLASLt does not currently support MXFP8 backward NN/NT layouts (NVIDIA Spark
sm_12.1). On B200 / H100 the new op is unused by default; downstream code can
still call it for any path that wants direct transposed-rowwise MXFP8 emission
without a payload copy.

Surfaces in three layers, all additive:

* C API (ABI-safe):
  - nvte_mxfp8_scaling_transpose_cast(input, scale_inv_colwise,
    output_rowwise, output_rowwise_scale_inv, rows, cols, stream) — minimal
    signature, E4M3 output, non-swizzled scales.
  - nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype,
    with_gemm_swizzled_scales, stream) — extended signature.
* PyTorch extension: transformer_engine_torch.mxfp8_scaling_transpose_cast
  (default kwargs match the minimal C symbol's behavior).
* Python: MXFP8Quantizer.quantize_rowwise_transpose(tensor,
  columnwise_scale_inv, *, fake_dtype=None, with_gemm_swizzled_scales=None)
  returns a row-wise-only MXFP8Tensor whose logical shape is tensor.T.

No existing C symbol, Python signature, or default behavior is changed.

Tests in tests/pytorch/mxfp8/:
* test_mxfp8_scaling_transpose_cast.py — byte equivalence vs. column-wise-
  then-copy reference (E4M3 + E5M2, multiple shapes), Python helper
  equivalence, decoded-value reconstruction within MXFP8 quantization
  tolerance, error paths for FP8 input and non-block-aligned dims.
* test_mxfp8_scaling_transpose_cast_swizzled.py — with
  with_gemm_swizzled_scales=True, emitted row-wise payload and scales match
  the bytes produced by the standard MXFP8Quantizer.quantize swizzled path
  on the actual transposed source. Comparison is byte-for-byte rather than
  via decoded values because TE's dequantize kernels intentionally reject
  with_gemm_swizzled_scales=True inputs (one-way GEMM-operand layout).

Tested on NVIDIA GB10 (sm_12.1) with TE rebuilt from this change: all 14
parametrized tests pass.

Signed-off-by: David Gornshtein <davidgornshtein@gmail.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 26, 2026

Greptile Summary

This PR adds a fused MXFP8 cast-and-transpose op surfaced at three layers: two new C API symbols (nvte_mxfp8_scaling_transpose_cast / _v2), a new PyTorch extension binding, and an MXFP8Quantizer.quantize_rowwise_transpose helper. The implementation is additive only — no existing symbols or defaults are changed. Kernel indexing for both the tile-transpose payload kernel and the scale-transpose kernel is correct, and the Python helper includes thorough input validation.

Confidence Score: 5/5

Safe to merge; only P2 style findings, no correctness issues.

All findings are P2 (magic literal, nested NVTX ranges). Core kernel logic, index arithmetic, buffer validation, and swizzle routing are correct. Tests cover byte-equivalence, numerical reconstruction, and error paths.

transformer_engine/common/recipe/mxfp8_scaling.cu — two minor style nits

Important Files Changed

  • transformer_engine/common/recipe/mxfp8_scaling.cu — Adds two new kernels and
    two new C entry points. Kernel logic is correct (tile transpose with
    shared-memory padding, scale transpose with proper row-major indexing). Two
    minor style issues: magic 32 instead of kRowsPerTile in the cast kernel, and
    nested NVTE_API_CALL in the v1→v2 chain.
  • transformer_engine/pytorch/tensor/mxfp8_tensor.py — Adds the
    quantize_rowwise_transpose helper with full validation (CUDA device, dtype,
    2D, block-alignment, scale shape). Buffer allocation, shape assignment, and
    with_gemm_swizzled_scales propagation all look correct.
  • transformer_engine/pytorch/csrc/extensions/fp8_partial_cast.cpp — New PyTorch
    binding validates contiguity for all four tensors and correctly routes
    through the v2 C symbol.
  • transformer_engine/common/include/transformer_engine/recipe.h — Additive-only
    declarations for both new C symbols with clear Doxygen, matching the existing
    style.
  • tests/pytorch/mxfp8/test_mxfp8_scaling_transpose_cast.py — Good coverage:
    byte-equivalence vs. copy-adapter, Python helper equivalence, numerical
    reconstruction tolerance, FP8-input rejection, and block-alignment
    enforcement.

Sequence Diagram

sequenceDiagram
    participant Py as Python (MXFP8Quantizer)
    participant Bind as PyTorch Binding (fp8_partial_cast.cpp)
    participant V1 as nvte_mxfp8_scaling_transpose_cast (C API v1)
    participant V2 as nvte_mxfp8_scaling_transpose_cast_v2 (C API v2)
    participant ScaleK as mxfp8_scaling_transpose_scales_kernel (CUDA)
    participant CastK as mxfp8_scaling_transpose_cast_kernel (CUDA)

    Py->>Bind: tex.mxfp8_scaling_transpose_cast(source_2d, colwise_scale_inv, ...)
    Bind->>V2: nvte_mxfp8_scaling_transpose_cast_v2(..., fp8_dtype, swizzled)
    V2->>ScaleK: transpose colwise→rowwise E8M0 scales (compact or swizzled)
    V2->>CastK: tile-transpose HP input, apply scale_inv, emit FP8 payload
    CastK-->>V2: rowwise_data [cols x rows]
    ScaleK-->>V2: rowwise_scale_inv [cols_pad x rows/32_pad]
    V2-->>Bind: done
    Bind-->>Py: returns
    Py->>Py: wrap as MXFP8Tensor(shape=(cols,rows))

    Note over Py,V1: v1 stable symbol wraps v2 with E4M3 + compact defaults
    Note over V1,V2: Both emit NVTE_API_CALL (nested NVTX ranges)


Comment on lines +389 to +394
*convertNVTETensorCheck(input), *convertNVTETensorCheck(scale_inv_colwise),
*convertNVTETensorCheck(output_rowwise), *convertNVTETensorCheck(output_rowwise_scale_inv),
rows, cols, static_cast<DType>(fp8_dtype), with_gemm_swizzled_scales, stream);
}

void nvte_mxfp8_scaling_transpose_cast(const NVTETensor input, const NVTETensor scale_inv_colwise,
Contributor


P2 Missing __launch_bounds__ on transpose cast kernel

All other __global__ kernels in this file (mxfp8_scaling_compute_partial_amax_kernel, mxfp8_scaling_partial_cast_kernel, mxfp8_scaling_transpose_scales_kernel) carry __launch_bounds__(kThreadsPerBlock). This kernel is launched with block = dim3(kTransposeTileDim, kTransposeTileDim) = 256 threads, so the appropriate hint would be __launch_bounds__(kTransposeTileDim * kTransposeTileDim). Without it the compiler cannot optimize register allocation for the stated block size.

* Add NVTE_API_CALL(nvte_mxfp8_scaling_transpose_cast) to the v1 entry
  point so profiling/tracing tools attribute calls to the actual symbol
  the caller used instead of v2.
* Add __launch_bounds__(kTransposeTileDim * kTransposeTileDim) on the
  transpose-cast payload kernel to match the launch shape and let the
  compiler tune register allocation, consistent with the other __global__
  kernels in this file.
* Drop unused source = _make_source(64, 128) allocation from
  test_transpose_cast_requires_block_aligned_dims; only bad_source and
  bad_scale are exercised.

Signed-off-by: David Gornshtein <davidgornshtein@gmail.com>
