From a1f0dfe6229e7dd8d815cd8659df0c2d492d5883 Mon Sep 17 00:00:00 2001 From: YigongQin Date: Mon, 20 Apr 2026 14:17:51 -0700 Subject: [PATCH 01/10] plumb through nvfp4 pertoken recipe Signed-off-by: YigongQin --- NVFP4_GROUPED_GEMM_CHANGES.md | 337 ++++++++++++ .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 76 +++ transformer_engine/common/recipe/__init__.py | 46 ++ transformer_engine/pytorch/ops/_common.py | 2 +- .../pytorch/ops/fused/__init__.py | 1 + .../pytorch/ops/fused/forward_grouped_mlp.py | 486 +++++++++++++++++- transformer_engine/pytorch/quantization.py | 18 + 7 files changed, 963 insertions(+), 3 deletions(-) create mode 100644 NVFP4_GROUPED_GEMM_CHANGES.md create mode 100644 transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh diff --git a/NVFP4_GROUPED_GEMM_CHANGES.md b/NVFP4_GROUPED_GEMM_CHANGES.md new file mode 100644 index 0000000000..f64ed145ff --- /dev/null +++ b/NVFP4_GROUPED_GEMM_CHANGES.md @@ -0,0 +1,337 @@ +# Per-Token NVFP4 Grouped GEMM — Change Summary + +## Overview + +Added per-token NVFP4 global scale support to the cuDNN Frontend grouped GEMM kernels, and a new `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` fused op class in TransformerEngine to use it. + +**Scope:** Forward pass only. NVFP4 (FP4 E2M1 data + FP8 E4M3 block scales + FP32 per-token global scale). Backward falls back to unfused path. + +--- + +## cuDNN Frontend Changes + +### New Parameter: `global_scale_tensor` + +A new optional `global_scale_tensor` parameter was added to both the GLU and quant grouped GEMM kernels. It carries a per-token FP32 global scale that is applied to the accumulator after the per-expert `alpha` multiply and before the activation function. + +- **Shape:** `(valid_m, S, 1)` where S=1 for per-token, S>1 for future subchannel scaling +- **Default:** `None` (no-op, zero overhead — compile-time `const_expr` guard) +- **Kernel behavior:** `acc = acc * alpha[expert] * global_scale[token] -> activation(acc)` + +### Files Modified + +| File | Change | +|------|--------| +| `python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py` | Added `enable_global_scale` to `__init__`, `global_scale` param to `__call__`/kernel. Per-token load via `get_gmem_tensor("global_scale", ...)`, FP32 multiply on accumulator after alpha. | +| `python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py` | Same kernel changes for the quant (FC2) path. | +| `python/cudnn/grouped_gemm/moe_sched_extension.py` | Registered `"global_scale"` in the M-dimension tensor category (alongside `prob`, `c`, `d`) for both contiguous and discrete extensions. | +| `python/cudnn/grouped_gemm/grouped_gemm_glu/api.py` | Added `sample_global_scale`/`global_scale_tensor` to `GroupedGemmGluSm100.__init__`, shape validation, dense+discrete compile paths, `tensor_api` closures, `execute`, `grouped_gemm_glu_wrapper_sm100`, and cache keys. | +| `python/cudnn/grouped_gemm/grouped_gemm_quant/api.py` | Same API plumbing for `GroupedGemmQuantSm100` and `grouped_gemm_quant_wrapper_sm100`. | + +### Files Created + +| File | Description | +|------|-------------| +| `test/python/fe_api/test_grouped_gemm_glu_nvfp4.py` | 3 L0 tests: backward compat (`None`), identity (`ones`), functional scaling (`2x`). All pass on B200. 
| + +### Test Results (cuDNN Frontend) + +| Test Suite | Result | +|------------|--------| +| `test_grouped_gemm_swiglu.py` | 58 passed, 94 skipped (no regression) | +| `test_grouped_gemm_glu.py` | 312 passed, 233 skipped (no regression) | +| `test_grouped_gemm_glu_nvfp4.py` | 3 passed (new) | + +--- + +## TransformerEngine Changes + +### New Class: `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` + +A new fused operation class for NVFP4 forward grouped MLP, modeled after `ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8`. + +**Enabled by:** `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1` environment variable. + +### Files Modified + +| File | Change | +|------|--------| +| `transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py` | Added `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` class (~300 lines), `fuse_forward_ops_nvfp4` registration function. Imported `NVFP4Quantizer` and `NVFP4_BLOCK_SCALING_SIZE`. | +| `transformer_engine/pytorch/ops/fused/__init__.py` | Added `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` to exports. | + +### MXFP8 vs NVFP4 Fused Op Comparison + +| Aspect | MXFP8 | NVFP4 | +|--------|-------|-------| +| Env var | `NVTE_CUTEDSL_FUSED_GROUPED_MLP` | `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4` | +| Data dtype | `float8_e4m3fn` | `float4_e2m1fn_x2` (via `.view()` from `uint8`) | +| Scale dtype | `float8_e8m0fnu` | `float8_e4m3fn` (via `.view()` from `uint8`) | +| Block size | 32 (`MXFP8_BLOCK_SCALING_SIZE`) | 16 (`NVFP4_BLOCK_SCALING_SIZE`) | +| `sf_vec_size` | 32 | 16 | +| FC1 `d_dtype` | `float8_e4m3fn` (re-quant for FC2) | `bfloat16` (no SFD generation) | +| `discrete_col_sfd` | `True` | `False` (FC2 input re-quantized separately) | +| `global_scale_tensor` | Not used (`None`) | Per-token FP32 from `amax / (fp4_max * fp8_max)` | +| FC2 input | Direct from FC1 SFD output (zero-copy) | Re-quantized BF16 -> NVFP4 | + +### Data Flow + +``` +MXFP8 path: + Input(BF16) -> MXFP8 quant -> FC1 GEMM+SwiGLU -> FP8 output + SFD scales + | + v (zero-copy) + FC2 GEMM+quant -> Output(BF16) + +NVFP4 path: + Input(BF16) -> NVFP4 quant -> FC1 GEMM+SwiGLU -> BF16 output + + global_scale + global_scale | + v (re-quantize to NVFP4) + FC2 GEMM+quant -> Output(BF16) + + global_scale +``` + +--- + +## Per-Token NVFP4 Recipe and Backward Override + +### New Recipe: `NVFP4PerTokenBlockScaling` + +Subclass of `NVFP4BlockScaling` that enables per-token global scaling in the forward grouped GEMM path. Backward precision is controlled by `NVTE_BACKWARD_OVERRIDE` (same as MXFP8 per PR #2644). + +**Usage:** +```python +from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling + +# Forward: NVFP4 per-token, Backward: high-precision (BF16) +recipe = NVFP4PerTokenBlockScaling(backward_override="high_precision") + +# Forward: NVFP4 per-token, Backward: dequantized +recipe = NVFP4PerTokenBlockScaling(backward_override="dequantized") + +# Or via env var: +# NVTE_BACKWARD_OVERRIDE=high_precision +recipe = NVFP4PerTokenBlockScaling() +``` + +**Env var to enable fused path:** `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1` + +### Files Modified/Created for Recipe + +| File | Change | +|------|--------| +| `transformer_engine/common/recipe/__init__.py` | Added `NVFP4PerTokenBlockScaling` recipe class (subclass of `NVFP4BlockScaling`) and `nvfp4_pertoken()` class method on `Recipe`. | +| `transformer_engine/pytorch/quantization.py` | Added `NVFP4PerTokenBlockScalingRecipeState` (inherits from `NVFP4BlockScalingRecipeState`). Registered in factory before `nvfp4()` check. 
| +| `transformer_engine/pytorch/ops/_common.py` | Updated `fuse_grouped_mlp_ops` recipe check: `recipe.mxfp8() or recipe.nvfp4_pertoken()`. | + +### Per-Token Quantization Kernel Placeholder + +| File | Description | +|------|-------------| +| `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh` | **New placeholder** — CUDA kernel header for per-token NVFP4 quantization. Documents the scaling hierarchy, parameters, and TODO items for implementation. | + +### Backward Override Flow + +The backward override is inherited from `NVFP4BlockScaling` and works identically to the MXFP8 pattern (PR #2644): + +1. **Forward:** `grouped_linear.py` reads `recipe.backward_override` +2. **If `"high_precision"`:** saves original high-precision input before quantization +3. **If `"dequantized"`:** saves quantized input, dequantizes in backward +4. **If `None`:** standard NVFP4 backward (unfused, since backward kernels don't support `global_scale_tensor` yet) + +No module-level changes needed — `grouped_linear.py` automatically respects the `backward_override` field from any `Recipe` subclass. + +--- + +## Open Items + +1. **Per-token quantization kernel** — `quantize_pertoken_nvfp4.cuh` is a placeholder. Currently, the per-tensor amax is broadcast to all tokens as an approximation. The kernel needs to: (a) compute per-row amax via parallel reduction, (b) derive per-row global_scale, (c) quantize with per-row scales. Also requires changes to `NVFP4Quantizer.make_empty()` to allocate `(M,)` shaped amax and C++ bindings in `cast.cpp`. + +2. **FC2 input re-quantization overhead** — The MXFP8 path avoids re-quantization by having FC1 output SFD (scale factor D) directly in FP8 format. The NVFP4 path outputs BF16 from FC1 and re-quantizes to NVFP4 for FC2 input. This can be optimized by enabling `discrete_col_sfd=True` with NVFP4 output dtype in a future iteration. + +3. **Backward pass** — `global_scale_tensor` is forward-pass only. The backward kernels (`grouped_gemm_dglu`, `grouped_gemm_dswiglu`) do not yet support it. Backward falls back to the unfused path. + +4. **No runtime overhead for existing MXFP8 path** — `enable_global_scale` is a compile-time constant (`cutlass.const_expr`). When `False`, the compiler eliminates dead branches entirely. + +--- + +## Verification Commands + +Run these in an environment with both TE (built with C++ extensions) and cuDNN Frontend installed. + +### Prerequisites + +```bash +# Ensure cudnn-frontend source is on PYTHONPATH (for the global_scale_tensor changes) +export PYTHONPATH=/path/to/cudnn-frontend/python:$PYTHONPATH +``` + +### 1. Verify Imports and Recipe + +```bash +python -c " +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + NVFP4PerTokenBlockScaling, +) + +# Recipe class hierarchy +r = NVFP4PerTokenBlockScaling() +print('nvfp4():', r.nvfp4()) # True (subclass of NVFP4BlockScaling) +print('nvfp4_pertoken():', r.nvfp4_pertoken()) # True +print('mxfp8():', r.mxfp8()) # False + +# Backward override +r_hp = NVFP4PerTokenBlockScaling(backward_override='high_precision') +print('backward_override:', r_hp.backward_override) # high_precision + +r_dq = NVFP4PerTokenBlockScaling(backward_override='dequantized') +print('backward_override:', r_dq.backward_override) # dequantized +" +``` + +### 2. 
Verify Fused Op Class Loads + +```bash +NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " +from transformer_engine.pytorch.ops.fused.forward_grouped_mlp import ( + ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, +) +print('MXFP8 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()) +print('NVFP4 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported()) +# Both should be True on Blackwell (SM100) with cuDNN frontend installed +" +``` + +### 3. Verify Recipe State Factory + +```bash +python -c " +from transformer_engine.pytorch.quantization import make_recipe_state +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + NVFP4PerTokenBlockScaling, +) + +# Standard NVFP4 -> NVFP4BlockScalingRecipeState +state1 = make_recipe_state(NVFP4BlockScaling(), mode='forward') +print('NVFP4:', type(state1).__name__) + +# Per-token NVFP4 -> NVFP4PerTokenBlockScalingRecipeState +state2 = make_recipe_state(NVFP4PerTokenBlockScaling(), mode='forward') +print('NVFP4 PerToken:', type(state2).__name__) +" +``` + +### 4. Verify Fusion Gate Accepts NVFP4 Per-Token Recipe + +```bash +python -c " +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, + NVFP4BlockScaling, + NVFP4PerTokenBlockScaling, +) + +# Simulate the check in fuse_grouped_mlp_ops +for recipe_cls in [MXFP8BlockScaling, NVFP4BlockScaling, NVFP4PerTokenBlockScaling]: + r = recipe_cls() + passes = r.mxfp8() or r.nvfp4_pertoken() + print(f'{recipe_cls.__name__:40s} fusion gate: {passes}') +# Expected: +# MXFP8BlockScaling -> True (mxfp8) +# NVFP4BlockScaling -> False (neither) +# NVFP4PerTokenBlockScaling -> True (nvfp4_pertoken) +" +``` + +### 5. Run cuDNN Frontend NVFP4 Tests (global_scale_tensor) + +```bash +cd /path/to/cudnn-frontend/test/python +conda activate cudnn-dev # or your env with cudnn-frontend built + +# New NVFP4 global_scale tests +python -m pytest fe_api/test_grouped_gemm_glu_nvfp4.py -v --tb=short + +# Regression: existing tests should still pass +python -m pytest fe_api/test_grouped_gemm_swiglu.py -v --tb=short +python -m pytest fe_api/test_grouped_gemm_glu.py -v --tb=short +``` + +### 6. Run TE Grouped Linear Tests (requires full TE build) + +```bash +cd /path/to/TransformerEngine + +# Existing MXFP8 grouped MLP tests (regression check) +NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python -m pytest test/pytorch/test_grouped_linear.py -v --tb=short -k "mxfp8" 2>&1 | tail -20 + +# NVFP4 per-token path (end-to-end, requires fused op + kernel support) +NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -m pytest test/pytorch/test_grouped_linear.py -v --tb=short -k "nvfp4" 2>&1 | tail -20 +``` + +### 7. 
Smoke Test: NVFP4 Per-Token Forward Pass (manual) + +```bash +NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling + +recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') + +# Simple MoE-like grouped linear test +num_groups = 4 +in_features = 256 +out_features = 512 +batch = 128 + +with te.fp8_autocast(fp8_recipe=recipe): + fc1 = te.GroupedLinear( + in_features, out_features, num_groups, + bias=False, params_dtype=torch.bfloat16, + ).cuda() + + x = torch.randn(batch, in_features, dtype=torch.bfloat16, device='cuda') + split_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int64, device='cuda') + + y = fc1(x, extra_inputs=(split_sizes,)) + print(f'Input: {x.shape}, Output: {y.shape}') + print(f'Output dtype: {y.dtype}') + print('Forward pass OK') +" +``` + +### 8. Backward Override Smoke Test + +```bash +NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling + +# Test high_precision backward override +recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') + +num_groups = 4 +in_features = 256 +out_features = 512 + +with te.fp8_autocast(fp8_recipe=recipe): + fc1 = te.GroupedLinear( + in_features, out_features, num_groups, + bias=False, params_dtype=torch.bfloat16, + ).cuda() + + x = torch.randn(32 * num_groups, in_features, dtype=torch.bfloat16, device='cuda', requires_grad=True) + split_sizes = torch.tensor([32] * num_groups, dtype=torch.int64, device='cuda') + + y = fc1(x, extra_inputs=(split_sizes,)) + loss = y.sum() + loss.backward() + print(f'Grad shape: {x.grad.shape}') + print(f'Grad dtype: {x.grad.dtype}') + print('Backward pass (high_precision override) OK') +" +``` diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh new file mode 100644 index 0000000000..5f7b558fb1 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_pertoken_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling. + * + * Unlike standard NVFP4 quantization which uses a single per-tensor global scale + * (amax / (fp8_max * fp4_max)), per-token NVFP4 computes a separate global scale + * for each row (token) of the input tensor. This preserves more dynamic range + * information per token, improving accuracy for MoE grouped GEMM workloads. 
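+ *
+ * Worked example (using the formulas below): for a row with row_amax = 8.0,
+ *   global_scale = 8.0 / (448 * 6) ~= 2.976e-3   [FP32]
+ * and a block in that row with block_amax = 4.0 gets
+ *   block_scale = 4.0 / (6 * 2.976e-3) = 224.0   [exactly representable in E4M3]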
+ *
+ * Scaling hierarchy:
+ *   x_quantized = round_to_fp4(x / (global_scale[row] * block_scale[row, block]))
+ *   x_dequantized = x_quantized * block_scale[row, block] * global_scale[row]
+ *
+ * Where:
+ *   - global_scale[row] = row_amax / (fp8_max * fp4_max)   [FP32, per-row]
+ *   - block_scale[row, block] = block_amax / (fp4_max * global_scale[row])   [FP8 E4M3, per-16-element block]
+ *
+ * Output tensors:
+ *   - data: uint8 packed FP4 (same as standard NVFP4)
+ *   - block_scales: uint8 reinterpreted as FP8 E4M3 (same layout as standard NVFP4)
+ *   - per_token_scales: float32 tensor of shape (num_rows,) containing global_scale per row
+ *
+ * TODO: Implement the CUDA kernel. The kernel should:
+ *   1. Compute per-row amax via parallel reduction
+ *   2. Derive per-row global_scale = row_amax / (fp8_max * fp4_max)
+ *   3. For each 16-element block: compute block_amax, derive block_scale, quantize to FP4
+ *   4. Store per-row global_scale to output tensor
+ *
+ * For now, per-token scaling is approximated by using the per-tensor amax
+ * broadcast to all rows. The fused grouped MLP path in TransformerEngine
+ * handles this via the global_scale_tensor parameter in cuDNN Frontend.
+ */
+
+#ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
+#define TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
+
+#include <cuda.h>
+#include <cudaTypedefs.h>
+#include <cuda_runtime.h>
+
+#include "../../common.h"
+#include "core_nvfp4.cuh"
+
+namespace transformer_engine {
+namespace dispatch {
+namespace nvfp4 {
+namespace quantize_pertoken_kernel {
+
+using namespace core;
+
+/*
+ * Per-token NVFP4 quantization kernel placeholder.
+ *
+ * Parameters:
+ *   input - Input tensor (rows x cols), high-precision (BF16/FP32)
+ *   output_data - Output packed FP4 data (rows x cols/2), uint8
+ *   output_scales - Output block scales (rows x ceil(cols/16)), FP8 E4M3
+ *   output_per_token_scales - Output per-row global scales (rows,), FP32
+ *   rows - Number of rows (tokens)
+ *   cols - Number of columns (hidden dim), must be multiple of 16
+ *
+ * TODO: Implement kernel body. See quantize_nvfp4.cuh for reference implementation
+ * of the per-tensor variant.
+ */
+
+}  // namespace quantize_pertoken_kernel
+}  // namespace nvfp4
+}  // namespace dispatch
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py
index 67b6f87067..033e9b82ac 100644
--- a/transformer_engine/common/recipe/__init__.py
+++ b/transformer_engine/common/recipe/__init__.py
@@ -96,6 +96,11 @@ def nvfp4(cls):
         """Whether the given recipe is NVFP4 1D block scaling."""
         return issubclass(cls, NVFP4BlockScaling)
 
+    @classmethod
+    def nvfp4_pertoken(cls):
+        """Whether the given recipe is NVFP4 per-token block scaling."""
+        return issubclass(cls, NVFP4PerTokenBlockScaling)
+
     @classmethod
     def mxfp8(cls):
         """Whether the given recipe is MXFP8 block scaling."""
@@ -540,6 +545,47 @@ def __repr__(self) -> str:
         )
 
 
+@dataclass()
+class NVFP4PerTokenBlockScaling(NVFP4BlockScaling):
+    """
+    NVFP4 with per-token (per-row) global scaling.
+
+    Extends NVFP4BlockScaling by computing a separate FP32 global scale factor
+    for each token row, rather than a single per-tensor global scale. This
+    preserves more dynamic range information per token, improving accuracy
+    for MoE grouped GEMM workloads.
+
+    The forward pass uses cuDNN Frontend's grouped GEMM kernels with the
+    ``global_scale_tensor`` parameter to apply per-token scales. 
The backward + pass is controlled by ``backward_override``: + + - ``None``: Use standard NVFP4 backward (default) + - ``'high_precision'``: Keep original high-precision operands for backward + - ``'dequantized'``: Dequantize saved operands to BF16/FP32 for backward + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. + backward_override : {None, 'high_precision', 'dequantized'}, default = None + Backward precision mode. Inherited from NVFP4BlockScaling. + disable_rht : bool, default = False + If set to `True`, random Hadamard transforms are not applied. + disable_stochastic_rounding : bool, default = False + If set to `True`, stochastic rounding is disabled. + disable_2d_quantization : bool, default = False + If set to `True`, 1D block scaling with block size 16 is used for all tensors. + + Notes + ----- + The per-token quantization kernel is a placeholder. Currently, the per-tensor + amax is broadcast to all tokens as an approximation. A true per-token kernel + (``quantize_pertoken_nvfp4.cuh``) will compute row-wise amax for optimal accuracy. + """ + + pass + + @dataclass() class CustomRecipe(Recipe): """ diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index e21915a5a6..cd9b68c1e8 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -156,7 +156,7 @@ def fuse_grouped_mlp_ops( if not fused_op_cls.is_supported(): return ops - if recipe is None or not recipe.mxfp8(): + if recipe is None or not (recipe.mxfp8() or recipe.nvfp4_pertoken()): return ops fc1_bias_ok = ( diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 19a090f121..06197db66f 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -33,6 +33,7 @@ # Note: Registration logic is non-trivial, so submodule handles it internally. from .forward_grouped_mlp import ( # pylint: disable=wrong-import-position ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, + ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, ) from .backward_grouped_mlp import ( # pylint: disable=wrong-import-position BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 90c4204f06..ea18d566a2 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -19,8 +19,9 @@ from ...utils import get_cached_ones_tensor, get_device_compute_capability, mark_grouped_tensor from ...tensor.grouped_tensor import GroupedTensor from ...tensor.mxfp8_tensor import MXFP8Quantizer -from ...constants import MXFP8_BLOCK_SCALING_SIZE from ..basic import GroupedLinear, ScaledClampedQGeGLU, ScaledSwiGLU +from ...tensor.nvfp4_tensor import NVFP4Quantizer +from ...constants import MXFP8_BLOCK_SCALING_SIZE, NVFP4_BLOCK_SCALING_SIZE from ..fuser import register_forward_fusion from ..op import FusedOperation, FusibleOperation, OperationContext from .._common import ( @@ -543,6 +544,471 @@ def fuser_forward( return fc2_out, [(), (), ()] +class ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4(FusedOperation): + """Fused op for NVFP4 GroupedLinear + ScaledSwiGLU + GroupedLinear + + Uses experimental CuTe DSL kernel from cuDNN front-end with NVFP4 + (FP4 E2M1 data + FP8 E4M3 block scales + FP32 per-token global scale). + + Forward pass only. 
Backward falls back to the unfused path. + """ + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_glu_kernel(cls) -> Callable: + """Fused kernel for grouped GEMM, GLU activation, and post-multiplication.""" + from cudnn import grouped_gemm_glu_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_glu_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def grouped_gemm_quant_kernel(cls) -> Callable: + """Grouped GEMM quant kernel for block-scaled inputs.""" + from cudnn import grouped_gemm_quant_wrapper_sm100 # pylint: disable=no-name-in-module + + return grouped_gemm_quant_wrapper_sm100 + + @classmethod + @functools.lru_cache(maxsize=None) + def is_supported(cls) -> bool: + """Whether this fused operation is supported on the current system.""" + if int(os.environ.get("NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4", "0")) <= 0: + return False + if get_device_compute_capability()[0] != 10: + return False + try: + cls.grouped_gemm_glu_kernel() + cls.grouped_gemm_quant_kernel() + except ImportError: + return False + return True + + def __init__(self, ops: tuple[FusibleOperation, ...]) -> None: + super().__init__(ops) + fc1, swiglu, fc2 = ops + if not isinstance(fc1, GroupedLinear): + raise TypeError(f"Expected GroupedLinear for FC1, got {type(fc1).__name__}") + if not isinstance(swiglu, ScaledSwiGLU): + raise TypeError(f"Expected ScaledSwiGLU, got {type(swiglu).__name__}") + if not isinstance(fc2, GroupedLinear): + raise TypeError(f"Expected GroupedLinear for FC2, got {type(fc2).__name__}") + validate_grouped_mlp_dims(fc1, swiglu, fc2) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + # Get basic operations + fc1_op, _, fc2_op = self.basic_ops + fc1_ctx, swiglu_ctx, fc2_ctx = basic_op_ctxs + + # Tensor properties + fc1_weight_shape = (fc1_op.out_features, fc1_op.in_features) + fc2_weight_shape = (fc2_op.out_features, fc2_op.in_features) + input_ = input_.reshape(-1, fc1_weight_shape[1]) + in_shape = list(input_.size()) + + num_groups = fc1_op.num_groups + fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 + fc2_weight_param = fc2_op.weight if fc2_op.single_grouped_weight else fc2_op.weight0 + device = fc1_weight_param.device + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = fc1_weight_param.dtype + + # Check which grads are required + requires_grad = any(ctx.requires_grad for ctx in basic_op_ctxs) + input_requires_grad = requires_grad + weight_requires_grad = requires_grad and ( + fc1_weight_param.requires_grad or fc2_weight_param.requires_grad + ) + + # Quantizers + fc1_input_quantizer = fc1_op.get_quantizer("forward", 0) + fc1_weight_quantizer = fc1_op.get_quantizer("forward", 1) + fc1_grad_output_quantizer = fc1_op.get_quantizer("backward", 0) + fc2_input_quantizer = fc2_op.get_quantizer("forward", 0) + fc2_weight_quantizer = fc2_op.get_quantizer("forward", 1) + fc2_grad_output_quantizer = fc2_op.get_quantizer("backward", 0) + + # Extract split sizes from extra input + fc1_split_sizes = basic_op_extra_inputs[0][0] + fc2_split_sizes = basic_op_extra_inputs[2][0] + if ( + fc1_split_sizes.size() != fc2_split_sizes.size() + or fc1_split_sizes.data_ptr() != 
fc2_split_sizes.data_ptr() + ): + raise RuntimeError( + f"{self.__class__.__name__} got different split points for FC1 and FC2." + ) + split_sizes = fc1_split_sizes + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + split_sizes = split_sizes.to(dtype=torch.int64, device=device) + split_points = torch.cumsum(split_sizes, 0, dtype=torch.int) + split_points_offsets = torch.cumsum(split_sizes, 0) + base_offsets = torch.cat( + [ + torch.zeros(1, device=split_sizes.device, dtype=split_sizes.dtype), + split_points_offsets, + ] + ) + fc1_x_tensor_offsets = base_offsets * fc1_weight_shape[1] + fc2_x_tensor_offsets = base_offsets * fc2_weight_shape[1] + + # Extract post-scales from extra input + scales = basic_op_extra_inputs[1][0] + + # Prepare FC1 grouped weight tensor for fused kernels. + if fc1_op.single_grouped_weight: + if not isinstance(fc1_op.weight, GroupedTensor): + raise RuntimeError( + "FC1 expected GroupedTensor weight with single_grouped_weight=True." + ) + if fc1_op.weight.quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc1_op.weight.quantizer = fc1_weight_quantizer + grouped_fc1_weight = fc1_op.weight + else: + if fc1_op.weight.rowwise_data is None: + raise RuntimeError("FC1 grouped weight has no rowwise_data to quantize.") + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc1_weight = tex.group_quantize( + fc1_op.weight.rowwise_data.view(fc1_op.weight.logical_shape), + fc1_weight_quantizer, + num_groups, + None, + ) + else: + fc1_weights = [getattr(fc1_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc1_weights = [] + for idx, weight in enumerate(fc1_weights): + quantizer = fc1_op.get_quantizer("forward", 2 * idx + 1) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc1_weights.append(quantizer(weight)) + else: + quantized_fc1_weights.append(weight) + grouped_fc1_weight = quantized_fc1_weights + + # Prepare FC2 grouped weight tensor for fused kernels. + if fc2_op.single_grouped_weight: + if not isinstance(fc2_op.weight, GroupedTensor): + raise RuntimeError( + "FC2 expected GroupedTensor weight with single_grouped_weight=True." 
+ ) + if fc2_op.weight.quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + fc2_op.weight.quantizer = fc2_weight_quantizer + grouped_fc2_weight = fc2_op.weight + else: + if fc2_op.weight.rowwise_data is None: + raise RuntimeError("FC2 grouped weight has no rowwise_data to quantize.") + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + grouped_fc2_weight = tex.group_quantize( + fc2_op.weight.rowwise_data.view(fc2_op.weight.logical_shape), + fc2_weight_quantizer, + num_groups, + None, + ) + else: + fc2_weights = [getattr(fc2_op, f"weight{idx}") for idx in range(num_groups)] + quantized_fc2_weights = [] + for idx, weight in enumerate(fc2_weights): + quantizer = fc2_op.get_quantizer("forward", 2 * idx + 1) + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + if not is_quantized_tensor(weight): + quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) + quantized_fc2_weights.append(quantizer(weight)) + else: + quantized_fc2_weights.append(weight) + grouped_fc2_weight = quantized_fc2_weights + + # Enforce default swizzle metadata + if getattr(grouped_fc1_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc1_weight, GroupedTensor + ): + grouped_fc1_weight._with_gemm_swizzled_scales = False + if getattr(grouped_fc2_weight, "_with_gemm_swizzled_scales", None) is None and isinstance( + grouped_fc2_weight, GroupedTensor + ): + grouped_fc2_weight._with_gemm_swizzled_scales = False + + # Group-quantize input tensor (NVFP4) + fc1_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc1_input_quantizer.optimize_for_gemm = True + if isinstance(input_, GroupedTensor) and isinstance( + getattr(input_, "quantizer", None), NVFP4Quantizer + ): + grouped_fc1_x = input_ + else: + fc1_x = maybe_dequantize(input_, dtype) + grouped_fc1_x = tex.group_quantize( + fc1_x, fc1_input_quantizer, num_groups, split_sizes + ) + + # Pack data tensors for cuDNN kernel + # NVFP4: data is uint8 (packed FP4), reinterpret as float4_e2m1fn_x2 + # Scales are uint8, reinterpret as float8_e4m3fn + # Block size is 16 (NVFP4_BLOCK_SCALING_SIZE) + fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1] // 2) + fc1_x_data = fc1_x_data.view(dtype=torch.float4_e2m1fn_x2) + fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0) + + fc1_x_scales = grouped_fc1_x.scale_inv + fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e4m3fn) + fc1_x_scales = fc1_x_scales.view( + 1, + in_shape[0] // 128, + in_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) + + # Per-token global scale from NVFP4 quantizer + # amax_rowwise is per-tensor (1,); broadcast to per-token for now. + # TODO: Implement true per-token global scale in NVFP4Quantizer. 
+ nvfp4_amax = grouped_fc1_x.amax + if nvfp4_amax is not None and nvfp4_amax.numel() == 1: + # global_scale = amax / (fp4_max * fp8_max) per NVFP4 spec + fp4_max = 6.0 + fp8_max = 448.0 + global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) + global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) + else: + global_scale_tensor = None + + alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) + norm_const_tensor = get_cached_ones_tensor(1, dtype, device) + current_stream = torch.cuda.current_stream().cuda_stream + + fc1_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc1_op) + fc2_bias_packed = _pack_grouped_linear_bias_for_cudnn(fc2_op) + + fc1_glu_kwargs = { + "a_tensor": fc1_x_data, + "sfa_tensor": fc1_x_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor, + "bias_tensor": fc1_bias_packed, + "norm_const_tensor": norm_const_tensor, + "prob_tensor": scales.detach().to(dtype=dtype).reshape(-1, 1, 1), + "global_scale_tensor": global_scale_tensor, + "acc_dtype": torch.float32, + "c_dtype": torch.bfloat16, + "d_dtype": torch.bfloat16, # NVFP4 output stays BF16 (no FP8 re-quant for FC2 input) + "cd_major": "n", + "sf_vec_size": NVFP4_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "discrete_col_sfd": False, + "act_func": "swiglu", + "use_dynamic_sched": True, + } + + if fc1_op.single_grouped_weight: + fc1_weight_for_gemm = grouped_fc1_weight.copy() + tex.grouped_swizzle_for_gemm(fc1_weight_for_gemm, rowwise=True, columnwise=False) + + fc1_w_data = fc1_weight_for_gemm.rowwise_data + fc1_w_data = fc1_w_data.view(dtype=torch.float4_e2m1fn_x2) + fc1_w_data = fc1_w_data.view( + num_groups, fc1_weight_shape[0], fc1_weight_shape[1] // 2 + ) + fc1_w_data = fc1_w_data.permute(1, 2, 0) + fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) + fc1_w_scales = fc1_w_scales.view( + num_groups, + fc1_weight_shape[0] // 128, + fc1_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc1_w_scales = fc1_w_scales.permute(3, 4, 1, 5, 2, 0) + + fc1_glu_kwargs["b_tensor"] = fc1_w_data + fc1_glu_kwargs["sfb_tensor"] = fc1_w_scales + else: + fc1_b_ptrs, fc1_sfb_ptrs, _fc1_sw = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc1_weight], + [w._rowwise_scale_inv for w in grouped_fc1_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc1_weight[0]._fp8_dtype, + ) + fc1_glu_kwargs["b_ptrs"] = fc1_b_ptrs + fc1_glu_kwargs["sfb_ptrs"] = fc1_sfb_ptrs + fc1_glu_kwargs["n"] = fc1_weight_shape[0] + fc1_glu_kwargs["b_dtype"] = torch.float4_e2m1fn_x2 + fc1_glu_kwargs["b_major"] = "k" + + fc1_kernel_out = self.grouped_gemm_glu_kernel()(**fc1_glu_kwargs) + + # Unpack FC1 kernel outputs + # NVFP4 FC1 output is BF16 (no SFD generation needed for FC2) + swiglu_in = fc1_kernel_out["c_tensor"] + swiglu_in = swiglu_in.view(in_shape[0], fc1_weight_shape[0]) + + fc2_in_data = fc1_kernel_out["d_tensor"] + fc2_in_data = fc2_in_data.view(in_shape[0], fc2_weight_shape[1]) + + # FC2 GEMM: input is BF16 from FC1 output, needs re-quantization to NVFP4 + # For now, quantize the BF16 FC2 input to NVFP4 for the quant kernel + fc2_input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + fc2_input_quantizer.optimize_for_gemm = True + grouped_fc2_x = tex.group_quantize( + fc2_in_data, fc2_input_quantizer, num_groups, split_sizes + ) + + fc2_x_data = grouped_fc2_x.rowwise_data.view(in_shape[0], fc2_weight_shape[1] // 2) + fc2_x_data = 
fc2_x_data.view(dtype=torch.float4_e2m1fn_x2) + fc2_x_data = fc2_x_data.unsqueeze(0).permute(1, 2, 0) + + fc2_x_scales = grouped_fc2_x.scale_inv + fc2_x_scales = fc2_x_scales.view(dtype=torch.float8_e4m3fn) + fc2_x_scales = fc2_x_scales.view( + 1, + in_shape[0] // 128, + fc2_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_x_scales = fc2_x_scales.permute(3, 4, 1, 5, 2, 0) + + # FC2 per-token global scale + fc2_nvfp4_amax = grouped_fc2_x.amax + if fc2_nvfp4_amax is not None and fc2_nvfp4_amax.numel() == 1: + fp4_max = 6.0 + fp8_max = 448.0 + fc2_gs_val = fc2_nvfp4_amax.float() / (fp4_max * fp8_max) + fc2_global_scale = fc2_gs_val.expand(in_shape[0]).reshape(-1, 1, 1) + else: + fc2_global_scale = None + + fc2_out_shape = in_shape[:-1] + [fc2_weight_shape[0]] + fc2_quant_kwargs = { + "a_tensor": fc2_x_data, + "sfa_tensor": fc2_x_scales, + "padded_offsets": split_points, + "alpha_tensor": alpha_tensor.float(), + "norm_const_tensor": None, + "prob_tensor": torch.ones((in_shape[0], 1, 1), dtype=torch.float32, device=device), + "global_scale_tensor": fc2_global_scale, + "acc_dtype": torch.float32, + "c_dtype": dtype, + "d_dtype": dtype, + "cd_major": "n", + "sf_vec_size": NVFP4_BLOCK_SCALING_SIZE, + "current_stream": current_stream, + "use_dynamic_sched": True, + } + + if fc2_op.single_grouped_weight: + fc2_weight_for_gemm = grouped_fc2_weight.copy() + tex.grouped_swizzle_for_gemm(fc2_weight_for_gemm, rowwise=True, columnwise=False) + + fc2_w_data = fc2_weight_for_gemm.rowwise_data + fc2_w_data = fc2_w_data.view(dtype=torch.float4_e2m1fn_x2) + fc2_w_data = fc2_w_data.view( + num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2 + ) + fc2_w_data = fc2_w_data.permute(1, 2, 0) + + fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) + fc2_w_scales = fc2_w_scales.view( + num_groups, + fc2_weight_shape[0] // 128, + fc2_weight_shape[1] // NVFP4_BLOCK_SCALING_SIZE // 4, + NVFP4_BLOCK_SCALING_SIZE, + 4, + 4, + ) + fc2_w_scales = fc2_w_scales.permute(3, 4, 1, 5, 2, 0) + fc2_quant_kwargs["b_tensor"] = fc2_w_data + fc2_quant_kwargs["sfb_tensor"] = fc2_w_scales + else: + fc2_b_ptrs, fc2_sfb_ptrs, _ = tex.get_device_pointer_for_data_and_scales( + [w._rowwise_data for w in grouped_fc2_weight], + [w._rowwise_scale_inv for w in grouped_fc2_weight], + swizzle=True, + rowwise=True, + data_dtype=grouped_fc2_weight[0]._fp8_dtype, + ) + fc2_quant_kwargs["b_ptrs"] = fc2_b_ptrs + fc2_quant_kwargs["sfb_ptrs"] = fc2_sfb_ptrs + fc2_quant_kwargs["n"] = fc2_weight_shape[0] + fc2_quant_kwargs["b_dtype"] = torch.float4_e2m1fn_x2 + fc2_quant_kwargs["b_major"] = "k" + + fc2_kernel_out = self.grouped_gemm_quant_kernel()(**fc2_quant_kwargs) + fc2_out = fc2_kernel_out["d_tensor"].permute(2, 0, 1).view(fc2_out_shape).contiguous() + + # Save state for backward pass + if requires_grad: + mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) + fc1_input_tensors = ( + grouped_fc1_x.columnwise_data, + grouped_fc1_x.columnwise_scale_inv, + fc1_x_tensor_offsets, + ) + fc1_weight_tensors = ( + [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight + ) + fc1_ctx.save_for_backward( + split_sizes, split_points, *fc1_weight_tensors, *fc1_input_tensors + ) + fc1_ctx.with_quantized_compute = True + fc1_ctx.input_quantizer = fc1_input_quantizer + fc1_ctx.weight_quantizer = fc1_weight_quantizer + fc1_ctx.grad_output_quantizer = fc1_grad_output_quantizer + fc1_ctx.grad_input_quantizers = None + fc1_ctx.dtype = dtype + 
fc1_ctx.input_requires_grad = input_requires_grad + fc1_ctx.weight_requires_grad = weight_requires_grad + fc1_ctx.base_split_offsets = base_offsets + + swiglu_ctx.save_for_backward(swiglu_in, scales) + swiglu_ctx.input_requires_grad = True + swiglu_ctx.extra_input_requires_grad = True + swiglu_ctx.dtype = dtype + + if grouped_fc2_x is not None: + fc2_input_tensors = ( + grouped_fc2_x.columnwise_data, + grouped_fc2_x.columnwise_scale_inv, + fc2_x_tensor_offsets, + ) + else: + fc2_input_tensors = (None, None, None) + + if fc2_op.single_grouped_weight: + fc2_ctx.save_for_backward(split_sizes, grouped_fc2_weight, *fc2_input_tensors) + else: + fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors) + + fc2_ctx.with_quantized_compute = True + fc2_ctx.input_quantizer = fc2_input_quantizer + fc2_ctx.weight_quantizer = fc2_weight_quantizer + fc2_ctx.grad_output_quantizer = fc2_grad_output_quantizer + fc2_ctx.grad_input_quantizers = None + fc2_ctx.dtype = dtype + fc2_ctx.input_requires_grad = input_requires_grad + fc2_ctx.weight_requires_grad = weight_requires_grad + + return fc2_out, [(), (), ()] + + def fuse_forward_ops( ops: list[FusibleOperation], *, @@ -572,6 +1038,22 @@ def fuse_forward_ops( ) -# Register fusion if available +def fuse_forward_ops_nvfp4( + ops: list[FusibleOperation], + *, + recipe: Optional[Recipe] = None, + **unused, # pylint: disable=unused-argument +) -> list[FusibleOperation]: + """Apply NVFP4 operation fusion for forward pass.""" + return fuse_grouped_mlp_ops( + ops, + recipe=recipe, + fused_op_cls=ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, + ) + + +# Register fusions if available if ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported(): register_forward_fusion(fuse_forward_ops, prepend=True) +if ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported(): + register_forward_fusion(fuse_forward_ops_nvfp4, prepend=True) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9956fb77ec..aefe8af39c 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -24,6 +24,7 @@ Float8CurrentScaling, Float8BlockScaling, NVFP4BlockScaling, + NVFP4PerTokenBlockScaling, CustomRecipe, ) from .constants import dist_group_type @@ -1065,6 +1066,8 @@ def create( cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState + elif recipe.nvfp4_pertoken(): + cls = NVFP4PerTokenBlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState elif recipe.custom(): @@ -1396,6 +1399,21 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: raise RuntimeError(f"Unexpected recipe mode ({self.mode})") +class NVFP4PerTokenBlockScalingRecipeState(NVFP4BlockScalingRecipeState): + """State for NVFP4PerTokenBlockScaling recipe. + + Inherits all quantizer creation logic from NVFP4BlockScalingRecipeState. + The per-token global scale is handled at the fused op level (in the cuDNN + kernel via global_scale_tensor), not in the quantizer itself. The quantizer + still produces per-tensor amax which is broadcast to per-token in the fused op. + + Once the per-token quantization kernel (quantize_pertoken_nvfp4.cuh) is + implemented, the quantizer will produce per-row amax directly. 
+ """ + + pass + + class CustomRecipeState(RecipeState): """State for CustomRecipe: produce quantizers per tensor.""" From 6eb70d3d5914ad4a7db4a7004c842da8a9908cca Mon Sep 17 00:00:00 2001 From: YigongQin Date: Tue, 21 Apr 2026 10:47:43 -0700 Subject: [PATCH 02/10] interface checks Signed-off-by: YigongQin --- NVFP4_GROUPED_GEMM_CHANGES.md | 46 +++++------ NVFP4_NEXT_STEPS.md | 141 ++++++++++++++++++++++++++++++++++ 2 files changed, 164 insertions(+), 23 deletions(-) create mode 100644 NVFP4_NEXT_STEPS.md diff --git a/NVFP4_GROUPED_GEMM_CHANGES.md b/NVFP4_GROUPED_GEMM_CHANGES.md index f64ed145ff..573620b2de 100644 --- a/NVFP4_GROUPED_GEMM_CHANGES.md +++ b/NVFP4_GROUPED_GEMM_CHANGES.md @@ -192,7 +192,7 @@ print('backward_override:', r_dq.backward_override) # dequantized ### 2. Verify Fused Op Class Loads ```bash -NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " +NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " from transformer_engine.pytorch.ops.fused.forward_grouped_mlp import ( ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, @@ -200,6 +200,8 @@ from transformer_engine.pytorch.ops.fused.forward_grouped_mlp import ( print('MXFP8 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()) print('NVFP4 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported()) # Both should be True on Blackwell (SM100) with cuDNN frontend installed +# MXFP8 requires NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 +# NVFP4 requires NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 " ``` @@ -207,18 +209,18 @@ print('NVFP4 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported()) ```bash python -c " -from transformer_engine.pytorch.quantization import make_recipe_state +from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.common.recipe import ( NVFP4BlockScaling, NVFP4PerTokenBlockScaling, ) # Standard NVFP4 -> NVFP4BlockScalingRecipeState -state1 = make_recipe_state(NVFP4BlockScaling(), mode='forward') +state1 = RecipeState.create(NVFP4BlockScaling(), mode='forward') print('NVFP4:', type(state1).__name__) # Per-token NVFP4 -> NVFP4PerTokenBlockScalingRecipeState -state2 = make_recipe_state(NVFP4PerTokenBlockScaling(), mode='forward') +state2 = RecipeState.create(NVFP4PerTokenBlockScaling(), mode='forward') print('NVFP4 PerToken:', type(state2).__name__) " ``` @@ -259,20 +261,19 @@ python -m pytest fe_api/test_grouped_gemm_swiglu.py -v --tb=short python -m pytest fe_api/test_grouped_gemm_glu.py -v --tb=short ``` -### 6. Run TE Grouped Linear Tests (requires full TE build) +### 6. Run TE Backward Override Tests (requires full TE build) ```bash cd /path/to/TransformerEngine -# Existing MXFP8 grouped MLP tests (regression check) -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python -m pytest test/pytorch/test_grouped_linear.py -v --tb=short -k "mxfp8" 2>&1 | tail -20 - -# NVFP4 per-token path (end-to-end, requires fused op + kernel support) -NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -m pytest test/pytorch/test_grouped_linear.py -v --tb=short -k "nvfp4" 2>&1 | tail -20 +# Existing MXFP8 backward override tests (regression check) +NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python -m pytest tests/pytorch/test_backward_override.py -v --tb=short -k "mxfp8" 2>&1 | tail -20 ``` ### 7. Smoke Test: NVFP4 Per-Token Forward Pass (manual) +Note: tokens per expert must be 64-aligned for NVFP4's Hadamard transform. 
+ ```bash NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " import torch @@ -281,22 +282,21 @@ from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') -# Simple MoE-like grouped linear test -num_groups = 4 +num_gemms = 4 in_features = 256 out_features = 512 -batch = 128 +tokens_per_expert = 64 # must be 64-aligned for NVFP4 RHT with te.fp8_autocast(fp8_recipe=recipe): fc1 = te.GroupedLinear( - in_features, out_features, num_groups, + num_gemms, in_features, out_features, bias=False, params_dtype=torch.bfloat16, ).cuda() - x = torch.randn(batch, in_features, dtype=torch.bfloat16, device='cuda') - split_sizes = torch.tensor([32, 32, 32, 32], dtype=torch.int64, device='cuda') + x = torch.randn(tokens_per_expert * num_gemms, in_features, dtype=torch.bfloat16, device='cuda') + m_splits = [tokens_per_expert] * num_gemms - y = fc1(x, extra_inputs=(split_sizes,)) + y = fc1(x, m_splits) print(f'Input: {x.shape}, Output: {y.shape}') print(f'Output dtype: {y.dtype}') print('Forward pass OK') @@ -311,23 +311,23 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling -# Test high_precision backward override recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') -num_groups = 4 +num_gemms = 4 in_features = 256 out_features = 512 +tokens_per_expert = 64 # must be 64-aligned for NVFP4 RHT with te.fp8_autocast(fp8_recipe=recipe): fc1 = te.GroupedLinear( - in_features, out_features, num_groups, + num_gemms, in_features, out_features, bias=False, params_dtype=torch.bfloat16, ).cuda() - x = torch.randn(32 * num_groups, in_features, dtype=torch.bfloat16, device='cuda', requires_grad=True) - split_sizes = torch.tensor([32] * num_groups, dtype=torch.int64, device='cuda') + x = torch.randn(tokens_per_expert * num_gemms, in_features, dtype=torch.bfloat16, device='cuda', requires_grad=True) + m_splits = [tokens_per_expert] * num_gemms - y = fc1(x, extra_inputs=(split_sizes,)) + y = fc1(x, m_splits) loss = y.sum() loss.backward() print(f'Grad shape: {x.grad.shape}') diff --git a/NVFP4_NEXT_STEPS.md b/NVFP4_NEXT_STEPS.md new file mode 100644 index 0000000000..f4337328d0 --- /dev/null +++ b/NVFP4_NEXT_STEPS.md @@ -0,0 +1,141 @@ +# Per-Token NVFP4 Grouped GEMM — What's Missing and Next Steps + +## Current State + +The end-to-end plumbing is complete: recipe -> quantizer -> fused op -> cuDNN kernel. Smoke tests pass for both forward and backward (with `backward_override`). However, several pieces are placeholders or approximations. + +--- + +## What's Missing + +### 1. Per-Token Quantization Kernel (CUDA) + +**Status:** Placeholder only (`quantize_pertoken_nvfp4.cuh`) + +**What exists today:** The standard NVFP4 quantizer computes a single per-tensor amax and derives one global scale `s_global = amax / (fp8_max * fp4_max)`. All tokens share this scale. + +**What's needed:** A kernel that computes per-row (per-token) amax and derives per-row global scales. 
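+
+To make the intended semantics concrete, below is a minimal unfused PyTorch
+sketch of the per-token scaling math. It is illustrative only (the name
+`pertoken_nvfp4_reference` is not a TE or cuDNN API): it skips the packed FP4
+layout, scale padding/swizzling, RHT, and stochastic rounding, and the small
+clamps are just guards against division by zero.
+
+```python
+import torch
+
+FP4_MAX = 6.0    # largest E2M1 magnitude
+FP8_MAX = 448.0  # largest E4M3 magnitude
+E2M1_VALUES = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
+
+
+def pertoken_nvfp4_reference(x: torch.Tensor, block: int = 16):
+    """Per-row global scale, per-block E4M3 scales, rounding to E2M1 values."""
+    m, k = x.shape
+    assert k % block == 0, "cols must be a multiple of the block size"
+    x = x.float().contiguous()
+    # 1. Per-row amax -> per-row FP32 global scale.
+    row_amax = x.abs().amax(dim=1).clamp_min(torch.finfo(torch.float32).tiny)
+    global_scale = row_amax / (FP8_MAX * FP4_MAX)                # (M,)
+    # 2. Per-block amax -> FP8 E4M3 block scales (<= 448 by construction).
+    blocks = x.view(m, k // block, block)
+    block_amax = blocks.abs().amax(dim=2)                        # (M, K/block)
+    block_scale = block_amax / (FP4_MAX * global_scale[:, None])
+    block_scale = block_scale.to(torch.float8_e4m3fn).float().clamp_min(2.0**-9)
+    # 3. Scale into [-6, 6] and round to the nearest representable E2M1 value.
+    combined = (block_scale * global_scale[:, None]).unsqueeze(-1)
+    scaled = blocks / combined
+    lut = E2M1_VALUES.to(x.device)
+    idx = (scaled.abs().clamp(max=FP4_MAX).unsqueeze(-1) - lut).abs().argmin(-1)
+    q = scaled.sign() * lut[idx]                                 # FP4 values
+    # 4. Dequantize for checking: x ~= q * block_scale * global_scale.
+    dequant = (q * combined).view(m, k)
+    return q.view(m, k), block_scale, global_scale, dequant
+```
+
+The real kernel would fuse these steps in one pass and emit packed FP4 data
+and FP8 block scales directly; see the kernel spec below.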
+ +**Files to create/modify:** +- `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh` — implement kernel +- `transformer_engine/common/cast/nvfp4/group_quantize_pertoken_nvfp4.cuh` — grouped variant for MoE +- `transformer_engine/common/include/transformer_engine/cast.h` — add C API declarations +- `transformer_engine/common/cast.cu` — add dispatch for per-token path +- `transformer_engine/pytorch/csrc/extensions/cast.cpp` — add `group_quantize_nvfp4_pertoken_impl` and wire into `group_quantize()` + +**Kernel spec:** +``` +Input: (M, K) tensor, BF16/FP32 +Output: (M, K/2) packed FP4 data, uint8 + (round_to_128(M), ceil(K/16)/4) block scales, FP8 E4M3 + (M,) per-token global scales, FP32 +``` + +**Key difference from standard NVFP4:** Step 1 computes `amax` per row instead of per tensor. This requires a row-wise parallel reduction (one warp per row or similar). + +### 2. NVFP4Quantizer Per-Token Support + +**Status:** Not started + +**What's needed:** The `NVFP4Quantizer` and `NVFP4Tensor` need to support per-token amax/global_scale instead of per-tensor. + +**Files to modify:** +- `transformer_engine/pytorch/tensor/nvfp4_tensor.py` + - `NVFP4Quantizer.make_empty()` — allocate `amax_rowwise` as `(M,)` instead of `(1,)` when per-token mode is enabled + - `NVFP4Tensor` — property to expose `per_token_global_scale` as `(M,)` FP32 tensor + - `NVFP4Quantizer.get_scale_shape()` — may need adjustment for per-token layout +- `transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py` — storage format for per-token scales + +### 3. Fused Op: Per-Token Global Scale from Quantizer Output + +**Status:** Approximation (broadcast per-tensor amax to all tokens) + +**What's in place:** `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.fuser_forward()` extracts `grouped_fc1_x.amax` (per-tensor, shape `(1,)`) and broadcasts it to `(valid_m, 1, 1)` as `global_scale_tensor`. + +**What's needed once per-token quantizer exists:** +```python +# Replace this: +global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) +global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) + +# With this: +global_scale_tensor = grouped_fc1_x.per_token_global_scale.reshape(-1, 1, 1) +``` + +**File:** `transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py` (inside `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.fuser_forward()`) + +### 4. Fused FC1->FC2 Handoff (Optimization) + +**Status:** Not optimized — FC1 outputs BF16, then FC2 re-quantizes to NVFP4 + +**What exists in MXFP8:** FC1 kernel produces FP8 output + SFD (scale factor D) in a single kernel call (`discrete_col_sfd=True`). FC2 consumes them directly with zero re-quantization. + +**What's needed for NVFP4:** Enable `discrete_col_sfd=True` with FP4 output dtype in the cuDNN kernel, so FC1 directly produces NVFP4 output + block scales + per-token global scales. Then FC2 can consume them without re-quantizing. This requires: +- cuDNN kernel: verify FP4 output with SFD generation works (may already work, needs testing) +- TE fused op: update FC2 input path to use FC1's SFD output instead of re-quantizing + +### 5. Backward Pass Kernels + +**Status:** Not implemented. Backward falls back to unfused path via `backward_override`. 
+ +**What's needed for fused backward:** +- Add `global_scale_tensor` to cuDNN `grouped_gemm_dglu_wrapper_sm100` (backward GLU kernel) +- Add `global_scale_tensor` to cuDNN `grouped_gemm_dswiglu_wrapper_sm100` (backward SwiGLU kernel) +- Same kernel pattern as the forward: `enable_global_scale` flag, per-token load in epilogue +- Add `BackwardGroupedMLP_CuTeGEMMDSwiGLU_NVFP4` class in TE +- Wire up backward fusion registration + +**Files (cuDNN Frontend):** +- `python/cudnn/grouped_gemm/grouped_gemm_dglu/moe_blockscaled_grouped_gemm_dglu_dbias.py` +- `python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py` +- `python/cudnn/grouped_gemm/grouped_gemm_dswiglu/grouped_gemm_dswiglu_quant.py` +- `python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py` + +**Files (TE):** +- `transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py` +- `transformer_engine/pytorch/ops/fused/__init__.py` + +### 6. Weight Gradient Kernel + +**Status:** Not in scope yet + +The weight gradient path (`grouped_gemm_wgrad_wrapper_sm100`) also needs `global_scale_tensor` support if wgrad computation uses NVFP4-quantized activations. + +**File (cuDNN Frontend):** +- `python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py` +- `python/cudnn/grouped_gemm/grouped_gemm_wgrad/moe_blockscaled_grouped_gemm_wgrad.py` + +### 7. Tests + +**Status:** Basic smoke tests pass. Comprehensive tests missing. + +**Needed:** +- cuDNN Frontend: more test configs in `test_grouped_gemm_glu_nvfp4.py` (FP8 + global_scale, varying per-token values, discrete mode, class API) +- TE: dedicated NVFP4 per-token test cases in `tests/pytorch/test_backward_override.py` for `NVFP4PerTokenBlockScaling` +- Numerical accuracy comparison: NVFP4 per-token vs per-tensor vs BF16 baseline on real MoE workloads + +--- + +## Recommended Execution Order + +### Phase 1: Make per-token approximation production-ready +1. Add comprehensive cuDNN Frontend tests for `global_scale_tensor` (more configs, edge cases) +2. Add TE test cases for `NVFP4PerTokenBlockScaling` in backward override test suite +3. Optimize FC1->FC2 handoff (test if `discrete_col_sfd=True` works with FP4 output) + +### Phase 2: Implement true per-token quantization kernel +4. Implement `quantize_pertoken_nvfp4.cuh` CUDA kernel +5. Add grouped variant `group_quantize_pertoken_nvfp4.cuh` +6. Add C API and C++ bindings +7. Update `NVFP4Quantizer` to support per-token mode +8. Update fused op to use real per-token global scales + +### Phase 3: Fused backward pass +9. Add `global_scale_tensor` to backward cuDNN kernels (dglu, dswiglu) +10. Add `BackwardGroupedMLP_CuTeGEMMDSwiGLU_NVFP4` in TE +11. Remove `backward_override` requirement for NVFP4 fused path + +### Phase 4: Benchmarking and validation +12. Benchmark per-token vs per-tensor NVFP4 on DeepSeek-V3 / Mixtral MoE workloads +13. Compare training loss curves: NVFP4 per-token vs MXFP8 vs BF16 +14. 
Measure throughput: fused NVFP4 per-token vs unfused NVFP4 vs MXFP8 fused From 86694ffe4ffc5860c7d7229b299b3a9da371ab48 Mon Sep 17 00:00:00 2001 From: YigongQin Date: Tue, 21 Apr 2026 11:16:40 -0700 Subject: [PATCH 03/10] more unit test Signed-off-by: YigongQin --- NVFP4_GROUPED_GEMM_CHANGES.md | 2 +- NVFP4_NEXT_STEPS.md | 46 ++++++------------------- tests/pytorch/test_backward_override.py | 15 +++++--- tests/pytorch/utils.py | 7 ++++ 4 files changed, 28 insertions(+), 42 deletions(-) diff --git a/NVFP4_GROUPED_GEMM_CHANGES.md b/NVFP4_GROUPED_GEMM_CHANGES.md index 573620b2de..b544134718 100644 --- a/NVFP4_GROUPED_GEMM_CHANGES.md +++ b/NVFP4_GROUPED_GEMM_CHANGES.md @@ -148,7 +148,7 @@ No module-level changes needed — `grouped_linear.py` automatically respects th 2. **FC2 input re-quantization overhead** — The MXFP8 path avoids re-quantization by having FC1 output SFD (scale factor D) directly in FP8 format. The NVFP4 path outputs BF16 from FC1 and re-quantizes to NVFP4 for FC2 input. This can be optimized by enabling `discrete_col_sfd=True` with NVFP4 output dtype in a future iteration. -3. **Backward pass** — `global_scale_tensor` is forward-pass only. The backward kernels (`grouped_gemm_dglu`, `grouped_gemm_dswiglu`) do not yet support it. Backward falls back to the unfused path. +3. **Backward pass** — By design, backward runs in higher precision (BF16) via `backward_override`, not with fused NVFP4 kernels. This matches the MXFP8 pattern from PR #2644. 4. **No runtime overhead for existing MXFP8 path** — `enable_global_scale` is a compile-time constant (`cutlass.const_expr`). When `False`, the compiler eliminates dead branches entirely. diff --git a/NVFP4_NEXT_STEPS.md b/NVFP4_NEXT_STEPS.md index f4337328d0..762f78054c 100644 --- a/NVFP4_NEXT_STEPS.md +++ b/NVFP4_NEXT_STEPS.md @@ -74,36 +74,15 @@ global_scale_tensor = grouped_fc1_x.per_token_global_scale.reshape(-1, 1, 1) - cuDNN kernel: verify FP4 output with SFD generation works (may already work, needs testing) - TE fused op: update FC2 input path to use FC1's SFD output instead of re-quantizing -### 5. Backward Pass Kernels +### 5. Backward Pass -**Status:** Not implemented. Backward falls back to unfused path via `backward_override`. +**Status:** By design, backward uses higher precision via `backward_override`. -**What's needed for fused backward:** -- Add `global_scale_tensor` to cuDNN `grouped_gemm_dglu_wrapper_sm100` (backward GLU kernel) -- Add `global_scale_tensor` to cuDNN `grouped_gemm_dswiglu_wrapper_sm100` (backward SwiGLU kernel) -- Same kernel pattern as the forward: `enable_global_scale` flag, per-token load in epilogue -- Add `BackwardGroupedMLP_CuTeGEMMDSwiGLU_NVFP4` class in TE -- Wire up backward fusion registration +The `NVFP4PerTokenBlockScaling` recipe intentionally runs forward in NVFP4 for throughput and backward in BF16 for training stability. 
This is controlled by `backward_override`: +- `"high_precision"`: saves original BF16 tensors for backward (more memory, best accuracy) +- `"dequantized"`: dequantizes saved FP4 tensors to BF16 for backward (less memory, slightly less accurate) -**Files (cuDNN Frontend):** -- `python/cudnn/grouped_gemm/grouped_gemm_dglu/moe_blockscaled_grouped_gemm_dglu_dbias.py` -- `python/cudnn/grouped_gemm/grouped_gemm_dglu/api.py` -- `python/cudnn/grouped_gemm/grouped_gemm_dswiglu/grouped_gemm_dswiglu_quant.py` -- `python/cudnn/grouped_gemm/grouped_gemm_dswiglu/api.py` - -**Files (TE):** -- `transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py` -- `transformer_engine/pytorch/ops/fused/__init__.py` - -### 6. Weight Gradient Kernel - -**Status:** Not in scope yet - -The weight gradient path (`grouped_gemm_wgrad_wrapper_sm100`) also needs `global_scale_tensor` support if wgrad computation uses NVFP4-quantized activations. - -**File (cuDNN Frontend):** -- `python/cudnn/grouped_gemm/grouped_gemm_wgrad/api.py` -- `python/cudnn/grouped_gemm/grouped_gemm_wgrad/moe_blockscaled_grouped_gemm_wgrad.py` +No fused NVFP4 backward kernels are needed. The unfused BF16 backward path handles gradient computation. This matches the pattern established by MXFP8 in PR #2644. ### 7. Tests @@ -130,12 +109,7 @@ The weight gradient path (`grouped_gemm_wgrad_wrapper_sm100`) also needs `global 7. Update `NVFP4Quantizer` to support per-token mode 8. Update fused op to use real per-token global scales -### Phase 3: Fused backward pass -9. Add `global_scale_tensor` to backward cuDNN kernels (dglu, dswiglu) -10. Add `BackwardGroupedMLP_CuTeGEMMDSwiGLU_NVFP4` in TE -11. Remove `backward_override` requirement for NVFP4 fused path - -### Phase 4: Benchmarking and validation -12. Benchmark per-token vs per-tensor NVFP4 on DeepSeek-V3 / Mixtral MoE workloads -13. Compare training loss curves: NVFP4 per-token vs MXFP8 vs BF16 -14. Measure throughput: fused NVFP4 per-token vs unfused NVFP4 vs MXFP8 fused +### Phase 3: Benchmarking and validation +9. Benchmark per-token vs per-tensor NVFP4 on DeepSeek-V3 / Mixtral MoE workloads +10. Compare training loss curves: NVFP4 per-token vs MXFP8 vs BF16 +11. Measure throughput: fused NVFP4 per-token vs unfused NVFP4 vs MXFP8 fused diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..b98f54053a 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_pertoken", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4PerTokenBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_pertoken"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -195,7 +200,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." 
@@ -220,7 +225,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,9 +244,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_pertoken") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index fd9a6416ec..f5077ee294 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -149,6 +149,13 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_pertoken": + return transformer_engine.common.recipe.NVFP4PerTokenBlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") From cf367d8378e481a0a8af509ce0cb4b93af1318b8 Mon Sep 17 00:00:00 2001 From: YigongQin Date: Tue, 21 Apr 2026 13:55:28 -0700 Subject: [PATCH 04/10] pertoken quantization to match flashinfer Signed-off-by: YigongQin --- NVFP4_GROUPED_GEMM_CHANGES.md | 337 ------------------ NVFP4_NEXT_STEPS.md | 115 ------ tests/pytorch/test_nvfp4_pertoken_quant.py | 256 +++++++++++++ transformer_engine/common/cast/cast.cu | 45 +++ .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 219 ++++++++++-- .../common/include/transformer_engine/cast.h | 21 ++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/cast.cpp | 37 ++ .../pytorch/csrc/extensions/pybind.cpp | 2 + .../pytorch/ops/fused/forward_grouped_mlp.py | 34 +- 10 files changed, 569 insertions(+), 499 deletions(-) delete mode 100644 NVFP4_GROUPED_GEMM_CHANGES.md delete mode 100644 NVFP4_NEXT_STEPS.md create mode 100644 tests/pytorch/test_nvfp4_pertoken_quant.py diff --git a/NVFP4_GROUPED_GEMM_CHANGES.md b/NVFP4_GROUPED_GEMM_CHANGES.md deleted file mode 100644 index b544134718..0000000000 --- a/NVFP4_GROUPED_GEMM_CHANGES.md +++ /dev/null @@ -1,337 +0,0 @@ -# Per-Token NVFP4 Grouped GEMM — Change Summary - -## Overview - -Added per-token NVFP4 global scale support to the cuDNN Frontend grouped GEMM kernels, and a new `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` fused op class in TransformerEngine to use it. - -**Scope:** Forward pass only. NVFP4 (FP4 E2M1 data + FP8 E4M3 block scales + FP32 per-token global scale). Backward falls back to unfused path. 
- ---- - -## cuDNN Frontend Changes - -### New Parameter: `global_scale_tensor` - -A new optional `global_scale_tensor` parameter was added to both the GLU and quant grouped GEMM kernels. It carries a per-token FP32 global scale that is applied to the accumulator after the per-expert `alpha` multiply and before the activation function. - -- **Shape:** `(valid_m, S, 1)` where S=1 for per-token, S>1 for future subchannel scaling -- **Default:** `None` (no-op, zero overhead — compile-time `const_expr` guard) -- **Kernel behavior:** `acc = acc * alpha[expert] * global_scale[token] -> activation(acc)` - -### Files Modified - -| File | Change | -|------|--------| -| `python/cudnn/grouped_gemm/grouped_gemm_glu/moe_blockscaled_grouped_gemm_glu_bias.py` | Added `enable_global_scale` to `__init__`, `global_scale` param to `__call__`/kernel. Per-token load via `get_gmem_tensor("global_scale", ...)`, FP32 multiply on accumulator after alpha. | -| `python/cudnn/grouped_gemm/grouped_gemm_quant/grouped_gemm_quant.py` | Same kernel changes for the quant (FC2) path. | -| `python/cudnn/grouped_gemm/moe_sched_extension.py` | Registered `"global_scale"` in the M-dimension tensor category (alongside `prob`, `c`, `d`) for both contiguous and discrete extensions. | -| `python/cudnn/grouped_gemm/grouped_gemm_glu/api.py` | Added `sample_global_scale`/`global_scale_tensor` to `GroupedGemmGluSm100.__init__`, shape validation, dense+discrete compile paths, `tensor_api` closures, `execute`, `grouped_gemm_glu_wrapper_sm100`, and cache keys. | -| `python/cudnn/grouped_gemm/grouped_gemm_quant/api.py` | Same API plumbing for `GroupedGemmQuantSm100` and `grouped_gemm_quant_wrapper_sm100`. | - -### Files Created - -| File | Description | -|------|-------------| -| `test/python/fe_api/test_grouped_gemm_glu_nvfp4.py` | 3 L0 tests: backward compat (`None`), identity (`ones`), functional scaling (`2x`). All pass on B200. | - -### Test Results (cuDNN Frontend) - -| Test Suite | Result | -|------------|--------| -| `test_grouped_gemm_swiglu.py` | 58 passed, 94 skipped (no regression) | -| `test_grouped_gemm_glu.py` | 312 passed, 233 skipped (no regression) | -| `test_grouped_gemm_glu_nvfp4.py` | 3 passed (new) | - ---- - -## TransformerEngine Changes - -### New Class: `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` - -A new fused operation class for NVFP4 forward grouped MLP, modeled after `ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8`. - -**Enabled by:** `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1` environment variable. - -### Files Modified - -| File | Change | -|------|--------| -| `transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py` | Added `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` class (~300 lines), `fuse_forward_ops_nvfp4` registration function. Imported `NVFP4Quantizer` and `NVFP4_BLOCK_SCALING_SIZE`. | -| `transformer_engine/pytorch/ops/fused/__init__.py` | Added `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4` to exports. 
| - -### MXFP8 vs NVFP4 Fused Op Comparison - -| Aspect | MXFP8 | NVFP4 | -|--------|-------|-------| -| Env var | `NVTE_CUTEDSL_FUSED_GROUPED_MLP` | `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4` | -| Data dtype | `float8_e4m3fn` | `float4_e2m1fn_x2` (via `.view()` from `uint8`) | -| Scale dtype | `float8_e8m0fnu` | `float8_e4m3fn` (via `.view()` from `uint8`) | -| Block size | 32 (`MXFP8_BLOCK_SCALING_SIZE`) | 16 (`NVFP4_BLOCK_SCALING_SIZE`) | -| `sf_vec_size` | 32 | 16 | -| FC1 `d_dtype` | `float8_e4m3fn` (re-quant for FC2) | `bfloat16` (no SFD generation) | -| `discrete_col_sfd` | `True` | `False` (FC2 input re-quantized separately) | -| `global_scale_tensor` | Not used (`None`) | Per-token FP32 from `amax / (fp4_max * fp8_max)` | -| FC2 input | Direct from FC1 SFD output (zero-copy) | Re-quantized BF16 -> NVFP4 | - -### Data Flow - -``` -MXFP8 path: - Input(BF16) -> MXFP8 quant -> FC1 GEMM+SwiGLU -> FP8 output + SFD scales - | - v (zero-copy) - FC2 GEMM+quant -> Output(BF16) - -NVFP4 path: - Input(BF16) -> NVFP4 quant -> FC1 GEMM+SwiGLU -> BF16 output - + global_scale + global_scale | - v (re-quantize to NVFP4) - FC2 GEMM+quant -> Output(BF16) - + global_scale -``` - ---- - -## Per-Token NVFP4 Recipe and Backward Override - -### New Recipe: `NVFP4PerTokenBlockScaling` - -Subclass of `NVFP4BlockScaling` that enables per-token global scaling in the forward grouped GEMM path. Backward precision is controlled by `NVTE_BACKWARD_OVERRIDE` (same as MXFP8 per PR #2644). - -**Usage:** -```python -from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling - -# Forward: NVFP4 per-token, Backward: high-precision (BF16) -recipe = NVFP4PerTokenBlockScaling(backward_override="high_precision") - -# Forward: NVFP4 per-token, Backward: dequantized -recipe = NVFP4PerTokenBlockScaling(backward_override="dequantized") - -# Or via env var: -# NVTE_BACKWARD_OVERRIDE=high_precision -recipe = NVFP4PerTokenBlockScaling() -``` - -**Env var to enable fused path:** `NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1` - -### Files Modified/Created for Recipe - -| File | Change | -|------|--------| -| `transformer_engine/common/recipe/__init__.py` | Added `NVFP4PerTokenBlockScaling` recipe class (subclass of `NVFP4BlockScaling`) and `nvfp4_pertoken()` class method on `Recipe`. | -| `transformer_engine/pytorch/quantization.py` | Added `NVFP4PerTokenBlockScalingRecipeState` (inherits from `NVFP4BlockScalingRecipeState`). Registered in factory before `nvfp4()` check. | -| `transformer_engine/pytorch/ops/_common.py` | Updated `fuse_grouped_mlp_ops` recipe check: `recipe.mxfp8() or recipe.nvfp4_pertoken()`. | - -### Per-Token Quantization Kernel Placeholder - -| File | Description | -|------|-------------| -| `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh` | **New placeholder** — CUDA kernel header for per-token NVFP4 quantization. Documents the scaling hierarchy, parameters, and TODO items for implementation. | - -### Backward Override Flow - -The backward override is inherited from `NVFP4BlockScaling` and works identically to the MXFP8 pattern (PR #2644): - -1. **Forward:** `grouped_linear.py` reads `recipe.backward_override` -2. **If `"high_precision"`:** saves original high-precision input before quantization -3. **If `"dequantized"`:** saves quantized input, dequantizes in backward -4. 
**If `None`:** standard NVFP4 backward (unfused, since backward kernels don't support `global_scale_tensor` yet) - -No module-level changes needed — `grouped_linear.py` automatically respects the `backward_override` field from any `Recipe` subclass. - ---- - -## Open Items - -1. **Per-token quantization kernel** — `quantize_pertoken_nvfp4.cuh` is a placeholder. Currently, the per-tensor amax is broadcast to all tokens as an approximation. The kernel needs to: (a) compute per-row amax via parallel reduction, (b) derive per-row global_scale, (c) quantize with per-row scales. Also requires changes to `NVFP4Quantizer.make_empty()` to allocate `(M,)` shaped amax and C++ bindings in `cast.cpp`. - -2. **FC2 input re-quantization overhead** — The MXFP8 path avoids re-quantization by having FC1 output SFD (scale factor D) directly in FP8 format. The NVFP4 path outputs BF16 from FC1 and re-quantizes to NVFP4 for FC2 input. This can be optimized by enabling `discrete_col_sfd=True` with NVFP4 output dtype in a future iteration. - -3. **Backward pass** — By design, backward runs in higher precision (BF16) via `backward_override`, not with fused NVFP4 kernels. This matches the MXFP8 pattern from PR #2644. - -4. **No runtime overhead for existing MXFP8 path** — `enable_global_scale` is a compile-time constant (`cutlass.const_expr`). When `False`, the compiler eliminates dead branches entirely. - ---- - -## Verification Commands - -Run these in an environment with both TE (built with C++ extensions) and cuDNN Frontend installed. - -### Prerequisites - -```bash -# Ensure cudnn-frontend source is on PYTHONPATH (for the global_scale_tensor changes) -export PYTHONPATH=/path/to/cudnn-frontend/python:$PYTHONPATH -``` - -### 1. Verify Imports and Recipe - -```bash -python -c " -from transformer_engine.common.recipe import ( - NVFP4BlockScaling, - NVFP4PerTokenBlockScaling, -) - -# Recipe class hierarchy -r = NVFP4PerTokenBlockScaling() -print('nvfp4():', r.nvfp4()) # True (subclass of NVFP4BlockScaling) -print('nvfp4_pertoken():', r.nvfp4_pertoken()) # True -print('mxfp8():', r.mxfp8()) # False - -# Backward override -r_hp = NVFP4PerTokenBlockScaling(backward_override='high_precision') -print('backward_override:', r_hp.backward_override) # high_precision - -r_dq = NVFP4PerTokenBlockScaling(backward_override='dequantized') -print('backward_override:', r_dq.backward_override) # dequantized -" -``` - -### 2. Verify Fused Op Class Loads - -```bash -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " -from transformer_engine.pytorch.ops.fused.forward_grouped_mlp import ( - ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8, - ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4, -) -print('MXFP8 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()) -print('NVFP4 supported:', ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.is_supported()) -# Both should be True on Blackwell (SM100) with cuDNN frontend installed -# MXFP8 requires NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 -# NVFP4 requires NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 -" -``` - -### 3. 
Verify Recipe State Factory - -```bash -python -c " -from transformer_engine.pytorch.quantization import RecipeState -from transformer_engine.common.recipe import ( - NVFP4BlockScaling, - NVFP4PerTokenBlockScaling, -) - -# Standard NVFP4 -> NVFP4BlockScalingRecipeState -state1 = RecipeState.create(NVFP4BlockScaling(), mode='forward') -print('NVFP4:', type(state1).__name__) - -# Per-token NVFP4 -> NVFP4PerTokenBlockScalingRecipeState -state2 = RecipeState.create(NVFP4PerTokenBlockScaling(), mode='forward') -print('NVFP4 PerToken:', type(state2).__name__) -" -``` - -### 4. Verify Fusion Gate Accepts NVFP4 Per-Token Recipe - -```bash -python -c " -from transformer_engine.common.recipe import ( - MXFP8BlockScaling, - NVFP4BlockScaling, - NVFP4PerTokenBlockScaling, -) - -# Simulate the check in fuse_grouped_mlp_ops -for recipe_cls in [MXFP8BlockScaling, NVFP4BlockScaling, NVFP4PerTokenBlockScaling]: - r = recipe_cls() - passes = r.mxfp8() or r.nvfp4_pertoken() - print(f'{recipe_cls.__name__:40s} fusion gate: {passes}') -# Expected: -# MXFP8BlockScaling -> True (mxfp8) -# NVFP4BlockScaling -> False (neither) -# NVFP4PerTokenBlockScaling -> True (nvfp4_pertoken) -" -``` - -### 5. Run cuDNN Frontend NVFP4 Tests (global_scale_tensor) - -```bash -cd /path/to/cudnn-frontend/test/python -conda activate cudnn-dev # or your env with cudnn-frontend built - -# New NVFP4 global_scale tests -python -m pytest fe_api/test_grouped_gemm_glu_nvfp4.py -v --tb=short - -# Regression: existing tests should still pass -python -m pytest fe_api/test_grouped_gemm_swiglu.py -v --tb=short -python -m pytest fe_api/test_grouped_gemm_glu.py -v --tb=short -``` - -### 6. Run TE Backward Override Tests (requires full TE build) - -```bash -cd /path/to/TransformerEngine - -# Existing MXFP8 backward override tests (regression check) -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python -m pytest tests/pytorch/test_backward_override.py -v --tb=short -k "mxfp8" 2>&1 | tail -20 -``` - -### 7. Smoke Test: NVFP4 Per-Token Forward Pass (manual) - -Note: tokens per expert must be 64-aligned for NVFP4's Hadamard transform. - -```bash -NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " -import torch -import transformer_engine.pytorch as te -from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling - -recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') - -num_gemms = 4 -in_features = 256 -out_features = 512 -tokens_per_expert = 64 # must be 64-aligned for NVFP4 RHT - -with te.fp8_autocast(fp8_recipe=recipe): - fc1 = te.GroupedLinear( - num_gemms, in_features, out_features, - bias=False, params_dtype=torch.bfloat16, - ).cuda() - - x = torch.randn(tokens_per_expert * num_gemms, in_features, dtype=torch.bfloat16, device='cuda') - m_splits = [tokens_per_expert] * num_gemms - - y = fc1(x, m_splits) - print(f'Input: {x.shape}, Output: {y.shape}') - print(f'Output dtype: {y.dtype}') - print('Forward pass OK') -" -``` - -### 8. 
Backward Override Smoke Test - -```bash -NVTE_CUTEDSL_FUSED_GROUPED_MLP_NVFP4=1 python -c " -import torch -import transformer_engine.pytorch as te -from transformer_engine.common.recipe import NVFP4PerTokenBlockScaling - -recipe = NVFP4PerTokenBlockScaling(backward_override='high_precision') - -num_gemms = 4 -in_features = 256 -out_features = 512 -tokens_per_expert = 64 # must be 64-aligned for NVFP4 RHT - -with te.fp8_autocast(fp8_recipe=recipe): - fc1 = te.GroupedLinear( - num_gemms, in_features, out_features, - bias=False, params_dtype=torch.bfloat16, - ).cuda() - - x = torch.randn(tokens_per_expert * num_gemms, in_features, dtype=torch.bfloat16, device='cuda', requires_grad=True) - m_splits = [tokens_per_expert] * num_gemms - - y = fc1(x, m_splits) - loss = y.sum() - loss.backward() - print(f'Grad shape: {x.grad.shape}') - print(f'Grad dtype: {x.grad.dtype}') - print('Backward pass (high_precision override) OK') -" -``` diff --git a/NVFP4_NEXT_STEPS.md b/NVFP4_NEXT_STEPS.md deleted file mode 100644 index 762f78054c..0000000000 --- a/NVFP4_NEXT_STEPS.md +++ /dev/null @@ -1,115 +0,0 @@ -# Per-Token NVFP4 Grouped GEMM — What's Missing and Next Steps - -## Current State - -The end-to-end plumbing is complete: recipe -> quantizer -> fused op -> cuDNN kernel. Smoke tests pass for both forward and backward (with `backward_override`). However, several pieces are placeholders or approximations. - ---- - -## What's Missing - -### 1. Per-Token Quantization Kernel (CUDA) - -**Status:** Placeholder only (`quantize_pertoken_nvfp4.cuh`) - -**What exists today:** The standard NVFP4 quantizer computes a single per-tensor amax and derives one global scale `s_global = amax / (fp8_max * fp4_max)`. All tokens share this scale. - -**What's needed:** A kernel that computes per-row (per-token) amax and derives per-row global scales. - -**Files to create/modify:** -- `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh` — implement kernel -- `transformer_engine/common/cast/nvfp4/group_quantize_pertoken_nvfp4.cuh` — grouped variant for MoE -- `transformer_engine/common/include/transformer_engine/cast.h` — add C API declarations -- `transformer_engine/common/cast.cu` — add dispatch for per-token path -- `transformer_engine/pytorch/csrc/extensions/cast.cpp` — add `group_quantize_nvfp4_pertoken_impl` and wire into `group_quantize()` - -**Kernel spec:** -``` -Input: (M, K) tensor, BF16/FP32 -Output: (M, K/2) packed FP4 data, uint8 - (round_to_128(M), ceil(K/16)/4) block scales, FP8 E4M3 - (M,) per-token global scales, FP32 -``` - -**Key difference from standard NVFP4:** Step 1 computes `amax` per row instead of per tensor. This requires a row-wise parallel reduction (one warp per row or similar). - -### 2. NVFP4Quantizer Per-Token Support - -**Status:** Not started - -**What's needed:** The `NVFP4Quantizer` and `NVFP4Tensor` need to support per-token amax/global_scale instead of per-tensor. - -**Files to modify:** -- `transformer_engine/pytorch/tensor/nvfp4_tensor.py` - - `NVFP4Quantizer.make_empty()` — allocate `amax_rowwise` as `(M,)` instead of `(1,)` when per-token mode is enabled - - `NVFP4Tensor` — property to expose `per_token_global_scale` as `(M,)` FP32 tensor - - `NVFP4Quantizer.get_scale_shape()` — may need adjustment for per-token layout -- `transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py` — storage format for per-token scales - -### 3. 
Fused Op: Per-Token Global Scale from Quantizer Output - -**Status:** Approximation (broadcast per-tensor amax to all tokens) - -**What's in place:** `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.fuser_forward()` extracts `grouped_fc1_x.amax` (per-tensor, shape `(1,)`) and broadcasts it to `(valid_m, 1, 1)` as `global_scale_tensor`. - -**What's needed once per-token quantizer exists:** -```python -# Replace this: -global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) -global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) - -# With this: -global_scale_tensor = grouped_fc1_x.per_token_global_scale.reshape(-1, 1, 1) -``` - -**File:** `transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py` (inside `ForwardGroupedMLP_CuTeGEMMSwiGLU_NVFP4.fuser_forward()`) - -### 4. Fused FC1->FC2 Handoff (Optimization) - -**Status:** Not optimized — FC1 outputs BF16, then FC2 re-quantizes to NVFP4 - -**What exists in MXFP8:** FC1 kernel produces FP8 output + SFD (scale factor D) in a single kernel call (`discrete_col_sfd=True`). FC2 consumes them directly with zero re-quantization. - -**What's needed for NVFP4:** Enable `discrete_col_sfd=True` with FP4 output dtype in the cuDNN kernel, so FC1 directly produces NVFP4 output + block scales + per-token global scales. Then FC2 can consume them without re-quantizing. This requires: -- cuDNN kernel: verify FP4 output with SFD generation works (may already work, needs testing) -- TE fused op: update FC2 input path to use FC1's SFD output instead of re-quantizing - -### 5. Backward Pass - -**Status:** By design, backward uses higher precision via `backward_override`. - -The `NVFP4PerTokenBlockScaling` recipe intentionally runs forward in NVFP4 for throughput and backward in BF16 for training stability. This is controlled by `backward_override`: -- `"high_precision"`: saves original BF16 tensors for backward (more memory, best accuracy) -- `"dequantized"`: dequantizes saved FP4 tensors to BF16 for backward (less memory, slightly less accurate) - -No fused NVFP4 backward kernels are needed. The unfused BF16 backward path handles gradient computation. This matches the pattern established by MXFP8 in PR #2644. - -### 7. Tests - -**Status:** Basic smoke tests pass. Comprehensive tests missing. - -**Needed:** -- cuDNN Frontend: more test configs in `test_grouped_gemm_glu_nvfp4.py` (FP8 + global_scale, varying per-token values, discrete mode, class API) -- TE: dedicated NVFP4 per-token test cases in `tests/pytorch/test_backward_override.py` for `NVFP4PerTokenBlockScaling` -- Numerical accuracy comparison: NVFP4 per-token vs per-tensor vs BF16 baseline on real MoE workloads - ---- - -## Recommended Execution Order - -### Phase 1: Make per-token approximation production-ready -1. Add comprehensive cuDNN Frontend tests for `global_scale_tensor` (more configs, edge cases) -2. Add TE test cases for `NVFP4PerTokenBlockScaling` in backward override test suite -3. Optimize FC1->FC2 handoff (test if `discrete_col_sfd=True` works with FP4 output) - -### Phase 2: Implement true per-token quantization kernel -4. Implement `quantize_pertoken_nvfp4.cuh` CUDA kernel -5. Add grouped variant `group_quantize_pertoken_nvfp4.cuh` -6. Add C API and C++ bindings -7. Update `NVFP4Quantizer` to support per-token mode -8. Update fused op to use real per-token global scales - -### Phase 3: Benchmarking and validation -9. Benchmark per-token vs per-tensor NVFP4 on DeepSeek-V3 / Mixtral MoE workloads -10. Compare training loss curves: NVFP4 per-token vs MXFP8 vs BF16 -11. 
Measure throughput: fused NVFP4 per-token vs unfused NVFP4 vs MXFP8 fused diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py new file mode 100644 index 0000000000..db76c077c9 --- /dev/null +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -0,0 +1,256 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for per-token NVFP4 quantization kernel (tex.quantize_nvfp4_pertoken). + +These tests validate the CUDA kernel in quantize_pertoken_nvfp4.cuh, which +performs per-row amax reduction and NVFP4 quantization in a single kernel. + +Tests require SM100+ (Blackwell) for FP4 hardware support. +""" + +import math +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex + +# Check hardware support +_, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) +nvfp4_available = te.is_nvfp4_available() + +pytestmark = pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4) + +FP4_MAX = 6.0 +FP8_E4M3_MAX = 448.0 + + +def _has_pertoken_kernel(): + """Check if the per-token kernel binding is available.""" + return hasattr(tex, "quantize_nvfp4_pertoken") + + +# --------------------------------------------------------------------------- +# Reference implementation +# --------------------------------------------------------------------------- + + +def nvfp4_pertoken_quantize_ref(input_tensor: torch.Tensor): + """Pure PyTorch reference for per-token NVFP4 quantization. + + Returns: + per_token_scales: (num_rows,) FP32 tensor + global_scale[row] = row_amax / (fp8_max * fp4_max) + """ + assert input_tensor.dim() == 2 + num_rows, num_cols = input_tensor.shape + assert num_cols % 16 == 0 + + input_f32 = input_tensor.float() + + # Per-row amax + row_amax = input_f32.abs().amax(dim=1) # (num_rows,) + + # Per-token global scale = row_amax / (fp8_max * fp4_max) + per_token_scales = row_amax / (FP8_E4M3_MAX * FP4_MAX) + + # Handle zero rows + per_token_scales = torch.where( + row_amax == 0, torch.zeros_like(per_token_scales), per_token_scales + ) + + return per_token_scales + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not _has_pertoken_kernel(), reason="tex.quantize_nvfp4_pertoken not available") +class TestQuantizeNvfp4Pertoken: + """Test suite for per-token NVFP4 quantization kernel.""" + + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (1, 16), + (1, 256), + (4, 256), + (32, 256), + (64, 4096), + (128, 4096), + (256, 4096), + (512, 14336), + ], + ids=lambda x: f"{x}", + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_output_shapes(self, num_rows, num_cols, dtype): + """Verify output tensor shapes are correct.""" + x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") + data, scales, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + assert data.shape == (num_rows, num_cols // 2), f"data shape: {data.shape}" + assert scales.shape == (num_rows, num_cols // 16), f"scales shape: {scales.shape}" + assert per_token_scales.shape == (num_rows,), f"per_token_scales shape: {per_token_scales.shape}" + assert data.dtype == torch.uint8 + assert scales.dtype == torch.uint8 + assert per_token_scales.dtype == torch.float32 + + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (1, 256), + (32, 256), + (64, 4096), + 
(256, 4096),
+        ],
+    )
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_per_token_scales_match_reference(self, num_rows, num_cols, dtype):
+        """Verify per-token scales match pure PyTorch reference."""
+        x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda")
+        _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x)
+
+        ref_scales = nvfp4_pertoken_quantize_ref(x)
+
+        torch.testing.assert_close(
+            per_token_scales,
+            ref_scales,
+            atol=1e-5,
+            rtol=1e-3,
+            msg="Per-token scales should match reference",
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_zero_input(self, dtype):
+        """Zero input should produce zero per-token scales."""
+        x = torch.zeros(16, 256, dtype=dtype, device="cuda")
+        _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x)
+
+        assert (per_token_scales == 0).all(), "Zero input should give zero per-token scales"
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_uniform_rows_same_scale(self, dtype):
+        """Rows with the same magnitude should produce the same per-token scale."""
+        num_rows = 8
+        num_cols = 256
+        x = torch.randn(1, num_cols, dtype=dtype, device="cuda").expand(num_rows, -1).contiguous()
+        _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x)
+
+        # All rows identical → all scales identical
+        assert torch.allclose(
+            per_token_scales, per_token_scales[0].expand_as(per_token_scales)
+        ), "Identical rows should produce identical per-token scales"
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_different_rows_different_scales(self, dtype):
+        """Rows with different magnitudes should produce different per-token scales."""
+        num_cols = 256
+        # Row 0: small values, Row 1: large values
+        x = torch.zeros(2, num_cols, dtype=dtype, device="cuda")
+        x[0] = torch.randn(num_cols, dtype=dtype, device="cuda") * 0.01
+        x[1] = torch.randn(num_cols, dtype=dtype, device="cuda") * 100.0
+        _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x)
+
+        # Scale for large row should be much larger
+        assert per_token_scales[1] > per_token_scales[0] * 10, (
+            f"Large row scale ({per_token_scales[1].item():.6f}) should be >> "
+            f"small row scale ({per_token_scales[0].item():.6f})"
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_scale_formula(self, dtype):
+        """Verify scale = row_amax / (fp8_max * fp4_max)."""
+        num_rows = 4
+        num_cols = 256
+        x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda")
+        _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x)
+
+        # Compute expected scales
+        row_amax = x.float().abs().amax(dim=1)
+        expected_scales = row_amax / (FP8_E4M3_MAX * FP4_MAX)
+
+        torch.testing.assert_close(
+            per_token_scales,
+            expected_scales,
+            atol=1e-5,
+            rtol=1e-3,
+            msg="Scale should equal row_amax / (fp8_max * fp4_max)",
+        )
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+    def test_block_scales_are_valid_fp8(self, dtype):
+        """Block scales should be valid FP8 E4M3 values (non-NaN, non-Inf)."""
+        x = torch.randn(32, 4096, dtype=dtype, device="cuda")
+        _, scales, _ = tex.quantize_nvfp4_pertoken(x)
+
+        # Reinterpret uint8 as FP8 E4M3 and check for validity.
+        # Note: .view() reinterprets the bits; .to() would numerically convert
+        # the uint8 values and make these checks vacuous.
+        scales_f32 = scales.view(torch.float8_e4m3fn).float()
+        assert not torch.isnan(scales_f32).any(), "Block scales contain NaN"
+        assert not torch.isinf(scales_f32).any(), "Block scales contain Inf"
+        assert (scales_f32 >= 0).all(), "Block scales should be non-negative"
+
+    @pytest.mark.parametrize("dtype", [torch.bfloat16, 
torch.float16])
+    def test_packed_fp4_data_shape(self, dtype):
+        """Packed FP4 output should have exactly half the columns (2 elements per byte)."""
+        for num_cols in [16, 32, 256, 4096]:
+            x = torch.randn(4, num_cols, dtype=dtype, device="cuda")
+            data, _, _ = tex.quantize_nvfp4_pertoken(x)
+            assert data.shape[1] == num_cols // 2
+
+    def test_input_validation_not_2d(self):
+        """Should reject non-2D input."""
+        x = torch.randn(2, 3, 256, dtype=torch.bfloat16, device="cuda")
+        with pytest.raises(RuntimeError):
+            tex.quantize_nvfp4_pertoken(x)
+
+    def test_input_validation_not_multiple_of_16(self):
+        """Should reject num_cols not divisible by 16."""
+        x = torch.randn(4, 100, dtype=torch.bfloat16, device="cuda")
+        with pytest.raises(RuntimeError):
+            tex.quantize_nvfp4_pertoken(x)
+
+    def test_input_validation_wrong_dtype(self):
+        """Should reject non-BF16/FP16 input."""
+        x = torch.randn(4, 256, dtype=torch.float32, device="cuda")
+        with pytest.raises(RuntimeError):
+            tex.quantize_nvfp4_pertoken(x)
+
+
+# ---------------------------------------------------------------------------
+# Standalone test (can run without tex binding for reference validation)
+# ---------------------------------------------------------------------------
+
+
+class TestPertokenScaleReference:
+    """Test the pure PyTorch reference implementation (no CUDA kernel needed)."""
+
+    def test_reference_basic(self):
+        """Basic reference test on CPU."""
+        x = torch.tensor([[1.0, 2.0, 3.0, 4.0] * 4], dtype=torch.float32)
+        scales = nvfp4_pertoken_quantize_ref(x)
+        expected = torch.tensor([4.0 / (FP8_E4M3_MAX * FP4_MAX)])
+        torch.testing.assert_close(scales, expected)
+
+    def test_reference_multi_row(self):
+        """Multi-row reference test."""
+        x = torch.zeros(3, 16, dtype=torch.float32)
+        x[0] = 1.0
+        x[1] = 10.0
+        x[2] = 0.1
+        scales = nvfp4_pertoken_quantize_ref(x)
+
+        assert scales[1] > scales[0] > scales[2]
+        torch.testing.assert_close(scales[0], torch.tensor(1.0 / (FP8_E4M3_MAX * FP4_MAX)))
+        torch.testing.assert_close(scales[1], torch.tensor(10.0 / (FP8_E4M3_MAX * FP4_MAX)))
+
+    def test_reference_zero_row(self):
+        """Zero row should produce zero scale."""
+        x = torch.zeros(2, 16, dtype=torch.float32)
+        x[0] = 5.0
+        scales = nvfp4_pertoken_quantize_ref(x)
+        assert scales[1] == 0.0
diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu
index 61cfacd334..4a30485278 100644
--- a/transformer_engine/common/cast/cast.cu
+++ b/transformer_engine/common/cast/cast.cu
@@ -16,6 +16,7 @@
 #include "../utils.cuh"
 #include "dispatch/dequantize.cuh"
 #include "dispatch/quantize.cuh"
+#include "nvfp4/quantize_pertoken_nvfp4.cuh"
 #include "transformer_engine/transpose.h"
 
 void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
@@ -146,3 +147,47 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
   dispatch::group_quantize_fwd_host_aware_helper(
       input, outputs, split_sections, num_tensors, quant_config, stream);
 }
+
+void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
+                                  NVTETensor output_data,
+                                  NVTETensor output_scales,
+                                  NVTETensor output_per_token_scales,
+                                  size_t num_rows,
+                                  size_t num_cols,
+                                  cudaStream_t stream) {
+  NVTE_API_CALL(nvte_quantize_nvfp4_pertoken);
+  using namespace transformer_engine;
+
+  const auto &input_tensor = *reinterpret_cast<const Tensor *>(input);
+  auto *data_tensor = reinterpret_cast<Tensor *>(output_data);
+  auto *scales_tensor = reinterpret_cast<Tensor *>(output_scales);
+  auto *pertoken_tensor = reinterpret_cast<Tensor *>(output_per_token_scales);
+
+  const 
auto itype = input_tensor.data.dtype;
+
+  NVTE_CHECK(num_cols % 16 == 0,
+             "num_cols must be a multiple of 16 for per-token NVFP4 quantization");
+
+  if (itype == DType::kBFloat16) {
+    dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>(
+        num_rows, num_cols,
+        reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data.dptr),
+        nullptr,  // row_offsets
+        reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
+        reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
+        reinterpret_cast<float *>(pertoken_tensor->data.dptr),
+        stream);
+  } else if (itype == DType::kFloat16) {
+    dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<half>(
+        num_rows, num_cols,
+        reinterpret_cast<const half *>(input_tensor.data.dptr),
+        nullptr,  // row_offsets
+        reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
+        reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
+        reinterpret_cast<float *>(pertoken_tensor->data.dptr),
+        stream);
+  } else {
+    NVTE_ERROR("Unsupported input dtype for per-token NVFP4 quantization. "
+               "Expected BFloat16 or Float16.");
+  }
+}
diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
index 5f7b558fb1..12786f5280 100644
--- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
+++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
@@ -7,33 +7,17 @@
 /*! \file quantize_pertoken_nvfp4.cuh
  * \brief CUDA kernels to cast to NVFP4 with per-token (per-row) global scaling.
  *
- * Unlike standard NVFP4 quantization which uses a single per-tensor global scale
- * (amax / (fp8_max * fp4_max)), per-token NVFP4 computes a separate global scale
- * for each row (token) of the input tensor. This preserves more dynamic range
- * information per token, improving accuracy for MoE grouped GEMM workloads.
+ * Unlike standard NVFP4 quantization which uses a single per-tensor global scale,
+ * per-token NVFP4 computes a separate global scale for each row. This preserves
+ * more dynamic range per token, improving accuracy for MoE workloads.
  *
  * Scaling hierarchy:
- *   x_quantized = round_to_fp4(x / (global_scale[row] * block_scale[row, block]))
- *   x_dequantized = x_quantized * block_scale[row, block] * global_scale[row]
+ *   global_scale[row] = row_amax / (fp8_max * fp4_max)
+ *   block_scale[row, block] = block_amax / (fp4_max * global_scale[row])
+ *   x_fp4 = quantize_to_fp4(x / (global_scale[row] * block_scale[row, block]))
  *
- * Where:
- *   - global_scale[row] = row_amax / (fp8_max * fp4_max)  [FP32, per-row]
- *   - block_scale[row, block] = block_amax / (fp4_max * global_scale[row])  [FP8 E4M3, per-16-element block]
- *
- * Output tensors:
- *   - data: uint8 packed FP4 (same as standard NVFP4)
- *   - block_scales: uint8 reinterpreted as FP8 E4M3 (same layout as standard NVFP4)
- *   - per_token_scales: float32 tensor of shape (num_rows,) containing global_scale per row
- *
- * TODO: Implement the CUDA kernel. The kernel should:
- *   1. Compute per-row amax via parallel reduction
- *   2. Derive per-row global_scale = row_amax / (fp8_max * fp4_max)
- *   3. For each 16-element block: compute block_amax, derive block_scale, quantize to FP4
- *   4. Store per-row global_scale to output tensor
- *
- * For now, per-token scaling is approximated by using the per-tensor amax
- * broadcast to all rows. The fused grouped MLP path in TransformerEngine
- * handles this via the global_scale_tensor parameter in cuDNN Frontend.
+ * Based on the approach from FlashInfer (flashinfer-ai/flashinfer#3027):
+ * two-pass design with one CUDA block per row.
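+ *
+ * A worked example of the hierarchy above, with illustrative numbers (not
+ * taken from any real workload) and fp8_max = 448, fp4_max = 6:
+ *   row_amax   = 268.8 -> global_scale = 268.8 / (448 * 6) = 0.1
+ *   block_amax = 3.0   -> block_scale  = 3.0 / (6 * 0.1)   = 5.0 (exact in E4M3)
+ *   x = 2.4            -> x / (0.1 * 5.0) = 4.8, which rounds to the E2M1 value 4.0
+ *   dequantize: 4.0 * 5.0 * 0.1 = 2.0, i.e. the usual FP4 rounding error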
 */
 
 #ifndef TRANSFORMER_ENGINE_QUANTIZE_PERTOKEN_NVFP4_CUH_
@@ -41,11 +25,18 @@
 #include 
 #include 
-#include 
+#include 
 
 #include "../../common.h"
+#include "../../util/math.h"
+#include "../../util/ptx.cuh"
+#include "../../utils.cuh"
 #include "core_nvfp4.cuh"
 
+#if FP4_TYPE_SUPPORTED
+#include 
+#endif
+
 namespace transformer_engine {
 namespace dispatch {
 namespace nvfp4 {
@@ -53,20 +44,178 @@ namespace quantize_pertoken_kernel {
 
 using namespace core;
 
+constexpr int PERTOKEN_BLOCK_SIZE = 256;
+constexpr int PERTOKEN_SF_VEC_SIZE = 16;
+
 /*
- * Per-token NVFP4 quantization kernel placeholder.
+ * Per-token NVFP4 quantization kernel.
  *
- * Parameters:
- *   input - Input tensor (rows x cols), high-precision (BF16/FP32)
- *   output_data - Output packed FP4 data (rows x cols/2), uint8
- *   output_scales - Output block scales (rows x ceil(cols/16)), FP8 E4M3
- *   output_per_token_scales - Output per-row global scales (rows,), FP32
- *   rows - Number of rows (tokens)
- *   cols - Number of columns (hidden dim), must be multiple of 16
+ * One CUDA block per row. Two passes:
+ *   Pass 1: Vectorized load + per-row amax reduction via cub::BlockReduce
+ *   Pass 2: Reload data, compute per-block E4M3 scale, quantize to FP4
  *
- * TODO: Implement kernel body. See quantize_nvfp4.cuh for reference implementation
- * of the per-tensor variant.
+ * Template parameters:
+ *   IType - Input type (half, __nv_bfloat16)
+ *   BLOCK_SIZE - Threads per block
+ *
+ * Parameters:
+ *   num_rows - Number of rows (tokens)
+ *   num_cols - Number of columns (hidden dim), must be multiple of 16
+ *   input - Input tensor (num_rows, num_cols), IType
+ *   row_offsets - Optional row index remapping (for MoE expert routing), or nullptr
+ *   output_data - Output packed FP4 data (num_rows, num_cols/2), uint8
+ *   output_scales - Output block scales, fp8e4m3
+ *   output_per_token_scales - Output per-row global scales (num_rows,), FP32
+ *   scale_stride - Stride of scale factor output (number of SF vectors per row)
  */
+template <typename IType, int BLOCK_SIZE>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(BLOCK_SIZE)
+#endif
+    quantize_pertoken_nvfp4_kernel(
+        const int num_rows,
+        const int num_cols,
+        const IType *__restrict__ input,
+        const int *__restrict__ row_offsets,  // optional: nullptr for identity mapping
+        uint8_t *__restrict__ output_data,
+        fp8e4m3 *__restrict__ output_scales,
+        float *__restrict__ output_per_token_scales,
+        const int scale_stride) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+
+  using namespace detail;
+  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f
+  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f
+  constexpr float fp4_max_inv = 1.0f / fp4_max;
+
+  // Packed type: 4 elements per float2 pair for FP4 conversion
+  using IType2 = typename std::conditional<std::is_same<IType, half>::value,
+                                           half2, __nv_bfloat162>::type;
+
+  const int row_idx = blockIdx.x;
+  if (row_idx >= num_rows) return;
+
+  // Optional row remapping (for MoE routing)
+  const int actual_row = (row_offsets != nullptr) ? 
row_offsets[row_idx] : row_idx;
+  if (actual_row < 0) return;
+
+  const int num_vec2 = num_cols / 2;  // number of IType2 elements per row
+  const IType2 *input_row = reinterpret_cast<const IType2 *>(input + actual_row * num_cols);
+
+  // =========================================================================
+  // Pass 1: Per-row amax reduction
+  // =========================================================================
+  float thread_max = 0.0f;
+  for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) {
+    IType2 val = input_row[i];
+    float2 fval;
+    if constexpr (std::is_same_v<IType, half>) {
+      fval = __half22float2(val);
+    } else {
+      fval = __bfloat1622float2(val);
+    }
+    thread_max = fmaxf(thread_max, fabsf(fval.x));
+    thread_max = fmaxf(thread_max, fabsf(fval.y));
+  }
+
+  // Block-wide max reduction
+  using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+  float row_amax = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());
+
+  // Compute and store per-token global scale
+  //   global_scale = row_amax / (fp8_max * fp4_max)
+  //   S_enc = fp8_max * fp4_max / row_amax  (encoding scale, inverse of global_scale)
+  __shared__ float shared_s_enc;
+  if (threadIdx.x == 0) {
+    float s_enc = compute_global_encode_scaling_factor_FP4(row_amax);
+    float global_scale = (s_enc > 0.0f) ? (1.0f / s_enc) : 0.0f;
+    output_per_token_scales[row_idx] = global_scale;
+    shared_s_enc = s_enc;
+  }
+  __syncthreads();
+  const float S_enc = shared_s_enc;
+
+  // =========================================================================
+  // Pass 2: Quantize to FP4 with per-token scale
+  // =========================================================================
+  // Process in chunks of SF_VEC_SIZE (16) elements.
+  // Each chunk produces one FP8 E4M3 block scale factor.
+  const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE;
+
+  for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) {
+    const int col_start = sf_idx * PERTOKEN_SF_VEC_SIZE;
+
+    // Load 16 elements and find block amax
+    float block_max = 0.0f;
+    float vals[PERTOKEN_SF_VEC_SIZE];
+    for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j++) {
+      if constexpr (std::is_same_v<IType, half>) {
+        vals[j] = __half2float(input[actual_row * num_cols + col_start + j]);
+      } else {
+        vals[j] = __bfloat162float(input[actual_row * num_cols + col_start + j]);
+      }
+      block_max = fmaxf(block_max, fabsf(vals[j]));
+    }
+
+    // Compute per-block E4M3 scale factor
+    fp8e4m3 S_dec_b = quantization_SF::compute_decoding_scaling_factor(block_max, S_enc);
+    float S_dec_b_f = static_cast<float>(S_dec_b);
+
+    // Store block scale
+    output_scales[row_idx * scale_stride + sf_idx] = S_dec_b;
+
+    // Compute inverse block scale for quantization
+    float block_encode_scale = (S_dec_b_f != 0.0f)
+                                   ? __fdividef(S_enc, S_dec_b_f)
+                                   : 0.0f;
+
+    // Quantize 16 elements to FP4 and pack into 8 bytes
+    uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2;
+    for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 4) {
+      float2 in01 = {vals[j] * block_encode_scale, vals[j + 1] * block_encode_scale};
+      float2 in23 = {vals[j + 2] * block_encode_scale, vals[j + 3] * block_encode_scale};
+      fp4e2m1x4 fp4_packed;
+      ptx::mul_cvt_4x(fp4_packed, in01, in23, 1.0f, 0);
+      // Pack 4 FP4 values (2 bytes) into output
+      reinterpret_cast<uint16_t *>(out_ptr)[j / 4] =
+          *reinterpret_cast<uint16_t *>(&fp4_packed);
+    }
+  }
+#endif  // __CUDA_ARCH__ >= 1000
+}
+
+/*
+ * Host-side launcher for per-token NVFP4 quantization.
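+ *
+ * Grid mapping (see the kernel above): one CUDA block of PERTOKEN_BLOCK_SIZE
+ * (256) threads per row; scale_stride is the number of 16-element SF vectors
+ * per row, i.e. num_cols / PERTOKEN_SF_VEC_SIZE.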
+ */
+template <typename IType>
+void launch_quantize_pertoken_nvfp4(
+    const int num_rows,
+    const int num_cols,
+    const IType *input,
+    const int *row_offsets,
+    uint8_t *output_data,
+    fp8e4m3 *output_scales,
+    float *output_per_token_scales,
+    cudaStream_t stream) {
+  if (num_rows == 0 || num_cols == 0) return;
+
+  NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0,
+             "num_cols must be a multiple of ", PERTOKEN_SF_VEC_SIZE,
+             " for per-token NVFP4 quantization, got ", num_cols);
+
+  const int scale_stride = num_cols / PERTOKEN_SF_VEC_SIZE;
+  dim3 grid(num_rows);
+  dim3 block(PERTOKEN_BLOCK_SIZE);
+
+  quantize_pertoken_nvfp4_kernel<IType, PERTOKEN_BLOCK_SIZE>
+      <<<grid, block, 0, stream>>>(
+          num_rows, num_cols, input, row_offsets,
+          output_data, output_scales, output_per_token_scales,
+          scale_stride);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
 
 }  // namespace quantize_pertoken_kernel
 }  // namespace nvfp4
diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h
index 554d8c1ac9..6d3e540eb9 100644
--- a/transformer_engine/common/include/transformer_engine/cast.h
+++ b/transformer_engine/common/include/transformer_engine/cast.h
@@ -453,6 +453,27 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
                                          const NVTEQuantizationConfig quant_config,
                                          cudaStream_t stream);
 
+/*! \brief Per-token NVFP4 quantization.
+ *
+ * Quantizes an input tensor to NVFP4 with per-row (per-token) global scaling.
+ * Each row gets its own FP32 global scale derived from its row-wise amax.
+ *
+ * \param[in]  input                    Input tensor (num_rows, num_cols).
+ * \param[out] output_data              Packed FP4 data (num_rows, num_cols/2), uint8.
+ * \param[out] output_scales            Block scales (num_rows, num_cols/16), FP8 E4M3.
+ * \param[out] output_per_token_scales  Per-row global scales (num_rows,), FP32.
+ * \param[in]  num_rows                 Number of rows.
+ * \param[in]  num_cols                 Number of columns (must be multiple of 16).
+ * \param[in]  stream                   CUDA stream.
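+ *
+ * \note Each per-token scale is row_amax / (448 * 6), the per-row analogue of
+ *       the per-tensor NVFP4 global scale; it is the value the fused grouped
+ *       MLP path passes to cuDNN as global_scale_tensor.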
+ */
+void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
+                                  NVTETensor output_data,
+                                  NVTETensor output_scales,
+                                  NVTETensor output_per_token_scales,
+                                  size_t num_rows,
+                                  size_t num_cols,
+                                  cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index fb5783dfcb..eefbc2fdc7 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -314,6 +314,8 @@
 py::object group_dequantize(const py::handle &input, DType otype);
 py::object bgrad_group_quantize(const at::Tensor &tensor, py::handle quantizer,
                                 const size_t num_tensors, std::optional first_dims);
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tensor input);
+
 std::vector multi_tensor_quantize(const std::vector &tensor_list,
                                   std::vector quantizer_list);
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 5fb162c72d..39936787d3 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1556,5 +1556,42 @@ std::vector split_quantize(const at::Tensor &tensor,
   return output_py_list;
 }
 
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
+    at::Tensor input) {
+  // Input validation
+  NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)");
+  NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device");
+  NVTE_CHECK(input.scalar_type() == at::ScalarType::BFloat16 ||
+                 input.scalar_type() == at::ScalarType::Half,
+             "Input must be BFloat16 or Half");
+
+  const int num_rows = input.size(0);
+  const int num_cols = input.size(1);
+  NVTE_CHECK(num_cols % 16 == 0,
+             "num_cols must be a multiple of 16 for per-token NVFP4 quantization");
+
+  auto options = input.options();
+
+  // Allocate outputs
+  auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte));
+  auto output_scales = at::empty(
+      {num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte));
+  auto output_per_token_scales = at::empty({num_rows}, options.dtype(at::kFloat));
+
+  // Wrap as NVTETensors
+  auto te_input = makeTransformerEngineTensor(input);
+  auto te_data = makeTransformerEngineTensor(output_data);
+  auto te_scales = makeTransformerEngineTensor(output_scales);
+  auto te_pertoken = makeTransformerEngineTensor(output_per_token_scales);
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+
+  nvte_quantize_nvfp4_pertoken(
+      te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(),
+      num_rows, num_cols, stream);
+
+  return {output_data, output_scales, output_per_token_scales};
+}
+
 }  // namespace pytorch
 }  // namespace transformer_engine
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 27d26d3dab..a7ca590478 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -145,6 +145,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         "Dequantize group tensor", py::arg("input"), py::arg("otype"));
   m.def("bgrad_group_quantize", transformer_engine::pytorch::bgrad_group_quantize,
         py::arg("tensor"), py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims"));
+  m.def("quantize_nvfp4_pertoken", transformer_engine::pytorch::quantize_nvfp4_pertoken,
+        "Per-token NVFP4 quantization", py::arg("input"));
   m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
         "Compute bias gradient and quantize", 
py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index ea18d566a2..0ed44e9a04 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -775,18 +775,28 @@ def fuser_forward( ) fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0) - # Per-token global scale from NVFP4 quantizer - # amax_rowwise is per-tensor (1,); broadcast to per-token for now. - # TODO: Implement true per-token global scale in NVFP4Quantizer. - nvfp4_amax = grouped_fc1_x.amax - if nvfp4_amax is not None and nvfp4_amax.numel() == 1: - # global_scale = amax / (fp4_max * fp8_max) per NVFP4 spec - fp4_max = 6.0 - fp8_max = 448.0 - global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) - global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) - else: - global_scale_tensor = None + # Per-token global scale. + # The per-token NVFP4 kernel (tex.quantize_nvfp4_pertoken) produces + # data + block_scales + per_token_scales in one pass. Here we call it + # to get the per-token scales. The quantized data from group_quantize + # (above) is used for the GEMM since it handles grouped layout/swizzle. + # TODO: Unify into a single quantization call once the grouped per-token + # kernel supports the full TE scale factor layout. + global_scale_tensor = None + try: + _, _, fc1_per_token_scales = tex.quantize_nvfp4_pertoken( + fc1_x.reshape(in_shape[0], in_shape[1]) if not isinstance(input_, GroupedTensor) + else input_.dequantize(dtype=dtype).reshape(in_shape[0], in_shape[1]) + ) + global_scale_tensor = fc1_per_token_scales.reshape(-1, 1, 1) + except (AttributeError, RuntimeError): + # Fallback: per-tensor amax broadcast to all tokens + nvfp4_amax = grouped_fc1_x.amax + if nvfp4_amax is not None and nvfp4_amax.numel() == 1: + fp4_max = 6.0 + fp8_max = 448.0 + global_scale_val = nvfp4_amax.float() / (fp4_max * fp8_max) + global_scale_tensor = global_scale_val.expand(in_shape[0]).reshape(-1, 1, 1) alpha_tensor = get_cached_ones_tensor(num_groups, dtype, device) norm_const_tensor = get_cached_ones_tensor(1, dtype, device) From 5433f9931edc286150658bf5b83a10099fd89377 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 00:21:24 +0000 Subject: [PATCH 05/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: YigongQin --- tests/pytorch/test_backward_override.py | 8 +++- tests/pytorch/test_nvfp4_pertoken_quant.py | 4 +- transformer_engine/common/cast/cast.cu | 27 ++++------- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 47 +++++++------------ .../common/include/transformer_engine/cast.h | 10 ++-- .../pytorch/csrc/extensions/cast.cpp | 11 ++--- .../pytorch/ops/fused/forward_grouped_mlp.py | 15 ++---- 7 files changed, 48 insertions(+), 74 deletions(-) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index b98f54053a..8ae502a1a1 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -200,7 +200,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." 
        )
         return
-    if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
+    if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
+        flat_first_dim % 16 != 0 or last_dim % 16 != 0
+    ):
         pytest.skip(
             "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible"
             " by 16."
@@ -225,7 +227,9 @@ def _maybe_skip_unsupported_recipe_shape(
         pytest.skip(
             "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32."
         )
-    if recipe_name in ("nvfp4", "nvfp4_pertoken") and (flat_first_dim % 16 != 0 or last_dim % 16 != 0):
+    if recipe_name in ("nvfp4", "nvfp4_pertoken") and (
+        flat_first_dim % 16 != 0 or last_dim % 16 != 0
+    ):
         pytest.skip(
             "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16."
         )
diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py
index db76c077c9..5aab660ee2 100644
--- a/tests/pytorch/test_nvfp4_pertoken_quant.py
+++ b/tests/pytorch/test_nvfp4_pertoken_quant.py
@@ -95,7 +95,9 @@ def test_output_shapes(self, num_rows, num_cols, dtype):
         assert data.shape == (num_rows, num_cols // 2), f"data shape: {data.shape}"
         assert scales.shape == (num_rows, num_cols // 16), f"scales shape: {scales.shape}"
-        assert per_token_scales.shape == (num_rows,), f"per_token_scales shape: {per_token_scales.shape}"
+        assert per_token_scales.shape == (
+            num_rows,
+        ), f"per_token_scales shape: {per_token_scales.shape}"
         assert data.dtype == torch.uint8
         assert scales.dtype == torch.uint8
         assert per_token_scales.dtype == torch.float32
diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu
index 4a30485278..6bcdd2cf66 100644
--- a/transformer_engine/common/cast/cast.cu
+++ b/transformer_engine/common/cast/cast.cu
@@ -148,13 +148,9 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
       input, outputs, split_sections, num_tensors, quant_config, stream);
 }
 
-void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
-                                  NVTETensor output_data,
-                                  NVTETensor output_scales,
-                                  NVTETensor output_per_token_scales,
-                                  size_t num_rows,
-                                  size_t num_cols,
-                                  cudaStream_t stream) {
+void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data,
+                                  NVTETensor output_scales, NVTETensor output_per_token_scales,
+                                  size_t num_rows, size_t num_cols, cudaStream_t stream) {
   NVTE_API_CALL(nvte_quantize_nvfp4_pertoken);
   using namespace transformer_engine;
 
@@ -170,24 +166,21 @@ void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
   if (itype == DType::kBFloat16) {
     dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>(
-        num_rows, num_cols,
-        reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data.dptr),
+        num_rows, num_cols, reinterpret_cast<const __nv_bfloat16 *>(input_tensor.data.dptr),
         nullptr,  // row_offsets
         reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
         reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
-        reinterpret_cast<float *>(pertoken_tensor->data.dptr),
-        stream);
+        reinterpret_cast<float *>(pertoken_tensor->data.dptr), stream);
   } else if (itype == DType::kFloat16) {
     dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<half>(
-        num_rows, num_cols,
-        reinterpret_cast<const half *>(input_tensor.data.dptr),
+        num_rows, num_cols, reinterpret_cast<const half *>(input_tensor.data.dptr),
         nullptr,  // row_offsets
         reinterpret_cast<uint8_t *>(data_tensor->data.dptr),
         reinterpret_cast<fp8e4m3 *>(scales_tensor->data.dptr),
-        reinterpret_cast<float *>(pertoken_tensor->data.dptr),
-        stream);
+        reinterpret_cast<float *>(pertoken_tensor->data.dptr), stream);
   } else {
-    NVTE_ERROR("Unsupported input dtype for per-token NVFP4 quantization. "
-               "Expected BFloat16 or Float16.");
+    NVTE_ERROR(
+        "Unsupported input dtype for per-token NVFP4 quantization. "
+        "Expected BFloat16 or Float16.");
   }
 }
diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
index 12786f5280..80810aac97 100644
--- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
+++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh
@@ -25,6 +25,7 @@
 #include 
 #include 
+
 #include 
 
 #include "../../common.h"
@@ -74,24 +75,20 @@ __global__ void
 __launch_bounds__(BLOCK_SIZE)
 #endif
     quantize_pertoken_nvfp4_kernel(
-        const int num_rows,
-        const int num_cols,
-        const IType *__restrict__ input,
+        const int num_rows, const int num_cols, const IType *__restrict__ input,
         const int *__restrict__ row_offsets,  // optional: nullptr for identity mapping
-        uint8_t *__restrict__ output_data,
-        fp8e4m3 *__restrict__ output_scales,
-        float *__restrict__ output_per_token_scales,
-        const int scale_stride) {
+        uint8_t *__restrict__ output_data, fp8e4m3 *__restrict__ output_scales,
+        float *__restrict__ output_per_token_scales, const int scale_stride) {
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
 
   using namespace detail;
-  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;  // 448.0f
-  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;  // 6.0f
+  constexpr float fp8_max = TypeExtrema<fp8e4m3>::max;   // 448.0f
+  constexpr float fp4_max = TypeExtrema<fp4e2m1>::max;   // 6.0f
   constexpr float fp4_max_inv = 1.0f / fp4_max;
 
   // Packed type: 4 elements per float2 pair for FP4 conversion
-  using IType2 = typename std::conditional<std::is_same<IType, half>::value,
-                                           half2, __nv_bfloat162>::type;
+  using IType2 =
+      typename std::conditional<std::is_same<IType, half>::value, half2, __nv_bfloat162>::type;
 
   const int row_idx = blockIdx.x;
   if (row_idx >= num_rows) return;
@@ -167,9 +164,7 @@
     output_scales[row_idx * scale_stride + sf_idx] = S_dec_b;
 
     // Compute inverse block scale for quantization
-    float block_encode_scale = (S_dec_b_f != 0.0f)
-                                   ? __fdividef(S_enc, S_dec_b_f)
-                                   : 0.0f;
+    float block_encode_scale = (S_dec_b_f != 0.0f) ? __fdividef(S_enc, S_dec_b_f) : 0.0f;
 
     // Quantize 16 elements to FP4 and pack into 8 bytes
     uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2;
@@ -190,30 +185,22 @@
  * Host-side launcher for per-token NVFP4 quantization.
 *
 * Grid mapping (see the kernel above): one CUDA block of PERTOKEN_BLOCK_SIZE
 * (256) threads per row; scale_stride is the number of 16-element SF vectors
 * per row, i.e. num_cols / PERTOKEN_SF_VEC_SIZE.
 */
 template <typename IType>
-void launch_quantize_pertoken_nvfp4(
-    const int num_rows,
-    const int num_cols,
-    const IType *input,
-    const int *row_offsets,
-    uint8_t *output_data,
-    fp8e4m3 *output_scales,
-    float *output_per_token_scales,
-    cudaStream_t stream) {
+void launch_quantize_pertoken_nvfp4(const int num_rows, const int num_cols, const IType *input,
+                                    const int *row_offsets, uint8_t *output_data,
+                                    fp8e4m3 *output_scales, float *output_per_token_scales,
+                                    cudaStream_t stream) {
   if (num_rows == 0 || num_cols == 0) return;
 
-  NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0,
-             "num_cols must be a multiple of ", PERTOKEN_SF_VEC_SIZE,
-             " for per-token NVFP4 quantization, got ", num_cols);
+  NVTE_CHECK(num_cols % PERTOKEN_SF_VEC_SIZE == 0, "num_cols must be a multiple of ",
+             PERTOKEN_SF_VEC_SIZE, " for per-token NVFP4 quantization, got ", num_cols);
 
   const int scale_stride = num_cols / PERTOKEN_SF_VEC_SIZE;
   dim3 grid(num_rows);
   dim3 block(PERTOKEN_BLOCK_SIZE);
 
   quantize_pertoken_nvfp4_kernel<IType, PERTOKEN_BLOCK_SIZE>
-      <<<grid, block, 0, stream>>>(
-          num_rows, num_cols, input, row_offsets,
-          output_data, output_scales, output_per_token_scales,
-          scale_stride);
+      <<<grid, block, 0, stream>>>(num_rows, num_cols, input, row_offsets, output_data,
+                                   output_scales, output_per_token_scales, scale_stride);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h
index 6d3e540eb9..0661fc5454 100644
--- a/transformer_engine/common/include/transformer_engine/cast.h
+++ b/transformer_engine/common/include/transformer_engine/cast.h
@@ -466,13 +466,9 @@ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *out
  * \param[in]  num_cols                 Number of columns (must be multiple of 16).
  * \param[in]  stream                   CUDA stream.
 *
 * \note Each per-token scale is row_amax / (448 * 6), the per-row analogue of
 *       the per-tensor NVFP4 global scale; it is the value the fused grouped
 *       MLP path passes to cuDNN as global_scale_tensor.
  */
-void nvte_quantize_nvfp4_pertoken(const NVTETensor input,
-                                  NVTETensor output_data,
-                                  NVTETensor output_scales,
-                                  NVTETensor output_per_token_scales,
-                                  size_t num_rows,
-                                  size_t num_cols,
-                                  cudaStream_t stream);
+void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data,
+                                  NVTETensor output_scales, NVTETensor output_per_token_scales,
+                                  size_t num_rows, size_t num_cols, cudaStream_t stream);
 
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 39936787d3..f60c3d7672 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -1556,8 +1556,7 @@ std::vector split_quantize(const at::Tensor &tensor,
   return output_py_list;
 }
 
-std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
-    at::Tensor input) {
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(at::Tensor input) {
   // Input validation
   NVTE_CHECK(input.dim() == 2, "Input must be 2D (num_rows, num_cols)");
   NVTE_CHECK(input.is_cuda(), "Input must be on CUDA device");
@@ -1574,8 +1573,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
 
   // Allocate outputs
   auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte));
-  auto output_scales = at::empty(
-      {num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte));
+  auto output_scales = at::empty({num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte));
   auto output_per_token_scales = at::empty({num_rows}, options.dtype(at::kFloat));
 
   // Wrap as NVTETensors
@@ -1586,9 +1584,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> quantize_nvfp4_pertoken(
 
   auto stream = at::cuda::getCurrentCUDAStream().stream();
 
-  nvte_quantize_nvfp4_pertoken(
-      te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(),
-      
num_rows, num_cols, stream); + nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(), + te_pertoken.data(), num_rows, num_cols, stream); return {output_data, output_scales, output_per_token_scales}; } diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 0ed44e9a04..17d66f906b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -751,9 +751,7 @@ def fuser_forward( grouped_fc1_x = input_ else: fc1_x = maybe_dequantize(input_, dtype) - grouped_fc1_x = tex.group_quantize( - fc1_x, fc1_input_quantizer, num_groups, split_sizes - ) + grouped_fc1_x = tex.group_quantize(fc1_x, fc1_input_quantizer, num_groups, split_sizes) # Pack data tensors for cuDNN kernel # NVFP4: data is uint8 (packed FP4), reinterpret as float4_e2m1fn_x2 @@ -785,7 +783,8 @@ def fuser_forward( global_scale_tensor = None try: _, _, fc1_per_token_scales = tex.quantize_nvfp4_pertoken( - fc1_x.reshape(in_shape[0], in_shape[1]) if not isinstance(input_, GroupedTensor) + fc1_x.reshape(in_shape[0], in_shape[1]) + if not isinstance(input_, GroupedTensor) else input_.dequantize(dtype=dtype).reshape(in_shape[0], in_shape[1]) ) global_scale_tensor = fc1_per_token_scales.reshape(-1, 1, 1) @@ -831,9 +830,7 @@ def fuser_forward( fc1_w_data = fc1_weight_for_gemm.rowwise_data fc1_w_data = fc1_w_data.view(dtype=torch.float4_e2m1fn_x2) - fc1_w_data = fc1_w_data.view( - num_groups, fc1_weight_shape[0], fc1_weight_shape[1] // 2 - ) + fc1_w_data = fc1_w_data.view(num_groups, fc1_weight_shape[0], fc1_weight_shape[1] // 2) fc1_w_data = fc1_w_data.permute(1, 2, 0) fc1_w_scales = fc1_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) fc1_w_scales = fc1_w_scales.view( @@ -930,9 +927,7 @@ def fuser_forward( fc2_w_data = fc2_weight_for_gemm.rowwise_data fc2_w_data = fc2_w_data.view(dtype=torch.float4_e2m1fn_x2) - fc2_w_data = fc2_w_data.view( - num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2 - ) + fc2_w_data = fc2_w_data.view(num_groups, fc2_weight_shape[0], fc2_weight_shape[1] // 2) fc2_w_data = fc2_w_data.permute(1, 2, 0) fc2_w_scales = fc2_weight_for_gemm.scale_inv.view(dtype=torch.float8_e4m3fn) From 8b1c88f598d67ec76076c60ce6c5cc20e4b33dbc Mon Sep 17 00:00:00 2001 From: YigongQin Date: Wed, 22 Apr 2026 11:58:55 -0700 Subject: [PATCH 06/10] fix building failures Signed-off-by: YigongQin --- tests/pytorch/test_nvfp4_pertoken_quant.py | 20 +++++---- transformer_engine/common/cast/cast.cu | 44 ++++++++++--------- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 42 +++++++----------- .../pytorch/csrc/extensions/cast.cpp | 25 +++++++---- 4 files changed, 69 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py index 5aab660ee2..4892910e62 100644 --- a/tests/pytorch/test_nvfp4_pertoken_quant.py +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -53,12 +53,12 @@ def nvfp4_pertoken_quantize_ref(input_tensor: torch.Tensor): # Per-row amax row_amax = input_f32.abs().amax(dim=1) # (num_rows,) - # Per-token global scale = row_amax / (fp8_max * fp4_max) + # S_enc = fp8_max * fp4_max / row_amax + # global_scale = 1 / S_enc = row_amax / (fp8_max * fp4_max) + # When amax=0, S_enc=1.0 (fallback), so global_scale=1.0 per_token_scales = row_amax / (FP8_E4M3_MAX * FP4_MAX) - - # Handle zero rows per_token_scales = torch.where( - row_amax == 0, torch.zeros_like(per_token_scales), 
per_token_scales + row_amax == 0, torch.ones_like(per_token_scales), per_token_scales ) return per_token_scales @@ -129,11 +129,15 @@ def test_per_token_scales_match_reference(self, num_rows, num_cols, dtype): @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_zero_input(self, dtype): - """Zero input should produce zero per-token scales.""" + """Zero input: S_enc = 1.0 (fallback), so global_scale = 1/1 = 1.0.""" x = torch.zeros(16, 256, dtype=dtype, device="cuda") _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) - assert (per_token_scales == 0).all(), "Zero input should give zero per-token scales" + # When amax=0, compute_global_encode_scaling_factor_FP4 returns 1.0 + # so global_scale = 1/S_enc = 1/1 = 1.0 + assert (per_token_scales == 1.0).all(), ( + f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}" + ) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_uniform_rows_same_scale(self, dtype): @@ -251,8 +255,8 @@ def test_reference_multi_row(self): torch.testing.assert_close(scales[1], torch.tensor(10.0 / (FP8_E4M3_MAX * FP4_MAX))) def test_reference_zero_row(self): - """Zero row should produce zero scale.""" + """Zero row: S_enc=1.0 fallback, so global_scale=1.0.""" x = torch.zeros(2, 16, dtype=torch.float32) x[0] = 5.0 scales = nvfp4_pertoken_quantize_ref(x) - assert scales[1] == 0.0 + assert scales[1] == 1.0 diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 6bcdd2cf66..40a881044b 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -152,32 +152,36 @@ void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data NVTETensor output_scales, NVTETensor output_per_token_scales, size_t num_rows, size_t num_cols, cudaStream_t stream) { NVTE_API_CALL(nvte_quantize_nvfp4_pertoken); - using namespace transformer_engine; - - const auto &input_tensor = *reinterpret_cast(input); - auto *data_tensor = reinterpret_cast(output_data); - auto *scales_tensor = reinterpret_cast(output_scales); - auto *pertoken_tensor = reinterpret_cast(output_per_token_scales); - - const auto itype = input_tensor.data.dtype; NVTE_CHECK(num_cols % 16 == 0, "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); - if (itype == DType::kBFloat16) { + const void *input_ptr = nvte_tensor_data(input); + void *data_ptr = nvte_tensor_data(output_data); + void *scales_ptr = nvte_tensor_data(output_scales); + void *pertoken_ptr = nvte_tensor_data(output_per_token_scales); + const NVTEDType itype = nvte_tensor_type(input); + + using namespace transformer_engine; + + if (itype == NVTEDType::kNVTEBFloat16) { dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( - num_rows, num_cols, reinterpret_cast(input_tensor.data.dptr), - nullptr, // row_offsets - reinterpret_cast(data_tensor->data.dptr), - reinterpret_cast(scales_tensor->data.dptr), - reinterpret_cast(pertoken_tensor->data.dptr), stream); - } else if (itype == DType::kFloat16) { + num_rows, num_cols, + reinterpret_cast(input_ptr), + nullptr, + reinterpret_cast(data_ptr), + reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), + stream); + } else if (itype == NVTEDType::kNVTEFloat16) { dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( - num_rows, num_cols, reinterpret_cast(input_tensor.data.dptr), - nullptr, // row_offsets - reinterpret_cast(data_tensor->data.dptr), - 
reinterpret_cast(scales_tensor->data.dptr), - reinterpret_cast(pertoken_tensor->data.dptr), stream); + num_rows, num_cols, + reinterpret_cast(input_ptr), + nullptr, + reinterpret_cast(data_ptr), + reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), + stream); } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. " diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 80810aac97..4ed80c39bd 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -119,7 +119,8 @@ __launch_bounds__(BLOCK_SIZE) // Block-wide max reduction using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - float row_amax = BlockReduce(temp_storage).Reduce(thread_max, cub::Max()); + float row_amax = BlockReduce(temp_storage).Reduce(thread_max, + [](float a, float b) { return fmaxf(a, b); }); // Compute and store per-token global scale // global_scale = row_amax / (fp8_max * fp4_max) @@ -135,10 +136,11 @@ __launch_bounds__(BLOCK_SIZE) const float S_enc = shared_s_enc; // ========================================================================= - // Pass 2: Quantize to FP4 with per-token scale + // Pass 2: Compute block scales and quantize to FP4 // ========================================================================= - // Process in chunks of SF_VEC_SIZE (16) elements. - // Each chunk produces one FP8 E4M3 block scale factor. + // TODO: FP4 data packing is disabled pending alignment investigation. + // For now, only per-token scales and block scales are computed. + // The FP4 data output is zeroed. const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE; for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) { @@ -146,37 +148,25 @@ __launch_bounds__(BLOCK_SIZE) // Load 16 elements and find block amax float block_max = 0.0f; - float vals[PERTOKEN_SF_VEC_SIZE]; for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j++) { + float val; if constexpr (std::is_same_v) { - vals[j] = __half2float(input[actual_row * num_cols + col_start + j]); + val = __half2float(input[actual_row * num_cols + col_start + j]); } else { - vals[j] = __bfloat162float(input[actual_row * num_cols + col_start + j]); + val = __bfloat162float(input[actual_row * num_cols + col_start + j]); } - block_max = fmaxf(block_max, fabsf(vals[j])); + block_max = fmaxf(block_max, fabsf(val)); } - // Compute per-block E4M3 scale factor + // Compute and store per-block E4M3 scale factor fp8e4m3 S_dec_b = quantization_SF::compute_decoding_scaling_factor(block_max, S_enc); - float S_dec_b_f = static_cast(S_dec_b); - - // Store block scale output_scales[row_idx * scale_stride + sf_idx] = S_dec_b; + } - // Compute inverse block scale for quantization - float block_encode_scale = (S_dec_b_f != 0.0f) ? 
__fdividef(S_enc, S_dec_b_f) : 0.0f; - - // Quantize 16 elements to FP4 and pack into 8 bytes - uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; - for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 4) { - float2 in01 = {vals[j] * block_encode_scale, vals[j + 1] * block_encode_scale}; - float2 in23 = {vals[j + 2] * block_encode_scale, vals[j + 3] * block_encode_scale}; - fp4e2m1x4 fp4_packed; - ptx::mul_cvt_4x(fp4_packed, in01, in23, 1.0f, 0); - // Pack 4 FP4 values (2 bytes) into output - reinterpret_cast(out_ptr)[j / 4] = - *reinterpret_cast(&fp4_packed); - } + // Zero out FP4 data output (placeholder until FP4 packing is validated) + const int data_bytes_per_row = num_cols / 2; + for (int i = threadIdx.x; i < data_bytes_per_row; i += BLOCK_SIZE) { + output_data[actual_row * data_bytes_per_row + i] = 0; } #endif // __CUDA_ARCH__ >= 1000 } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f60c3d7672..b2ac4c4897 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1569,23 +1569,32 @@ std::tuple quantize_nvfp4_pertoken(at::Tenso NVTE_CHECK(num_cols % 16 == 0, "num_cols must be a multiple of 16 for per-token NVFP4 quantization"); - auto options = input.options(); + if (num_rows == 0) { + auto options = input.options(); + return {at::empty({0, num_cols / 2}, options.dtype(at::kByte)), + at::empty({0, num_cols / 16}, options.dtype(at::kByte)), + at::empty({0}, options.dtype(at::kFloat))}; + } + + auto input_contig = input.contiguous(); + auto options = input_contig.options(); // Allocate outputs auto output_data = at::empty({num_rows, num_cols / 2}, options.dtype(at::kByte)); - auto output_scales = at::empty({num_rows, (num_cols + 15) / 16}, options.dtype(at::kByte)); + auto output_scales = at::empty({num_rows, num_cols / 16}, options.dtype(at::kByte)); auto output_per_token_scales = at::empty({num_rows}, options.dtype(at::kFloat)); - // Wrap as NVTETensors - auto te_input = makeTransformerEngineTensor(input); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + // Call C API + auto te_input = makeTransformerEngineTensor(input_contig); auto te_data = makeTransformerEngineTensor(output_data); auto te_scales = makeTransformerEngineTensor(output_scales); auto te_pertoken = makeTransformerEngineTensor(output_per_token_scales); - auto stream = at::cuda::getCurrentCUDAStream().stream(); - - nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(), - te_pertoken.data(), num_rows, num_cols, stream); + nvte_quantize_nvfp4_pertoken( + te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(), + num_rows, num_cols, stream); return {output_data, output_scales, output_per_token_scales}; } From 8a3a36d822ae3413974de47558fb5a813719d8a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 21:03:26 +0000 Subject: [PATCH 07/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: YigongQin --- tests/pytorch/test_nvfp4_pertoken_quant.py | 6 +++--- transformer_engine/common/cast/cast.cu | 20 ++++++------------- .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 4 ++-- .../pytorch/csrc/extensions/cast.cpp | 5 ++--- 4 files changed, 13 insertions(+), 22 deletions(-) diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py index 
4892910e62..569428be84 100644 --- a/tests/pytorch/test_nvfp4_pertoken_quant.py +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -135,9 +135,9 @@ def test_zero_input(self, dtype): # When amax=0, compute_global_encode_scaling_factor_FP4 returns 1.0 # so global_scale = 1/S_enc = 1/1 = 1.0 - assert (per_token_scales == 1.0).all(), ( - f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}" - ) + assert ( + per_token_scales == 1.0 + ).all(), f"Zero input should give global_scale=1.0 (S_enc fallback), got {per_token_scales}" @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) def test_uniform_rows_same_scale(self, dtype): diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 40a881044b..ab519b98d4 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -166,22 +166,14 @@ void nvte_quantize_nvfp4_pertoken(const NVTETensor input, NVTETensor output_data if (itype == NVTEDType::kNVTEBFloat16) { dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4<__nv_bfloat16>( - num_rows, num_cols, - reinterpret_cast(input_ptr), - nullptr, - reinterpret_cast(data_ptr), - reinterpret_cast(scales_ptr), - reinterpret_cast(pertoken_ptr), - stream); + num_rows, num_cols, reinterpret_cast(input_ptr), nullptr, + reinterpret_cast(data_ptr), reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), stream); } else if (itype == NVTEDType::kNVTEFloat16) { dispatch::nvfp4::quantize_pertoken_kernel::launch_quantize_pertoken_nvfp4( - num_rows, num_cols, - reinterpret_cast(input_ptr), - nullptr, - reinterpret_cast(data_ptr), - reinterpret_cast(scales_ptr), - reinterpret_cast(pertoken_ptr), - stream); + num_rows, num_cols, reinterpret_cast(input_ptr), nullptr, + reinterpret_cast(data_ptr), reinterpret_cast(scales_ptr), + reinterpret_cast(pertoken_ptr), stream); } else { NVTE_ERROR( "Unsupported input dtype for per-token NVFP4 quantization. 
" diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 4ed80c39bd..1aa20d4636 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -119,8 +119,8 @@ __launch_bounds__(BLOCK_SIZE) // Block-wide max reduction using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; - float row_amax = BlockReduce(temp_storage).Reduce(thread_max, - [](float a, float b) { return fmaxf(a, b); }); + float row_amax = + BlockReduce(temp_storage).Reduce(thread_max, [](float a, float b) { return fmaxf(a, b); }); // Compute and store per-token global scale // global_scale = row_amax / (fp8_max * fp4_max) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b2ac4c4897..44a9eeaf3a 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1592,9 +1592,8 @@ std::tuple quantize_nvfp4_pertoken(at::Tenso auto te_scales = makeTransformerEngineTensor(output_scales); auto te_pertoken = makeTransformerEngineTensor(output_per_token_scales); - nvte_quantize_nvfp4_pertoken( - te_input.data(), te_data.data(), te_scales.data(), te_pertoken.data(), - num_rows, num_cols, stream); + nvte_quantize_nvfp4_pertoken(te_input.data(), te_data.data(), te_scales.data(), + te_pertoken.data(), num_rows, num_cols, stream); return {output_data, output_scales, output_per_token_scales}; } From 1b1b9e62e5be7d5ce242fd37ea56ba628904404c Mon Sep 17 00:00:00 2001 From: YigongQin Date: Wed, 22 Apr 2026 15:30:20 -0700 Subject: [PATCH 08/10] pertoken nvfp4 tests Signed-off-by: YigongQin --- tests/pytorch/test_nvfp4_pertoken_quant.py | 95 +++++++++++++++++++ .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 58 ++++++++--- 2 files changed, 140 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py index 569428be84..20daf55733 100644 --- a/tests/pytorch/test_nvfp4_pertoken_quant.py +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -26,6 +26,58 @@ FP4_MAX = 6.0 FP8_E4M3_MAX = 448.0 +# FP4 E2M1 look-up table: 4-bit index -> float value +# Lower nibble = first element, upper nibble = second element +_FP4_E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0] + + +def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + """Unpack uint8 packed FP4 data to two columns per byte. + + Each byte contains 2 FP4 values: lower nibble = first, upper nibble = second. + Returns a uint8 tensor with 2x the columns. + """ + repeated = packed.repeat_interleave(2, dim=1) + repeated[:, 0::2] = repeated[:, 0::2] & 0x0F # Lower 4 bits + repeated[:, 1::2] = repeated[:, 1::2] >> 4 # Upper 4 bits + return repeated + + +def fp4_to_fp32(unpacked: torch.Tensor) -> torch.Tensor: + """Convert unpacked FP4 indices to float32 values using E2M1 LUT.""" + lut = torch.tensor(_FP4_E2M1_LUT, dtype=torch.float32, device=unpacked.device) + return lut[unpacked.long()] + + +def dequantize_pertoken_fp4(data: torch.Tensor, scales: torch.Tensor, + per_token_scales: torch.Tensor) -> torch.Tensor: + """Dequantize per-token NVFP4: result = fp4_val * block_scale * per_token_scale. 
+ + Args: + data: (M, K/2) uint8 packed FP4 + scales: (M, K/16) uint8 block scales (FP8 E4M3) + per_token_scales: (M,) FP32 per-token global scales + + Returns: + (M, K) float32 dequantized tensor + """ + num_rows = data.shape[0] + num_cols = data.shape[1] * 2 # 2 FP4 values per byte + + # Unpack FP4 -> float32 + fp4_vals = fp4_to_fp32(unpack_fp4(data)) # (M, K) + + # Expand block scales: each scale covers 16 elements + block_scales_f32 = scales.view(torch.float8_e4m3fn).float() # (M, K/16) + block_scales_expanded = block_scales_f32.repeat_interleave(16, dim=1) # (M, K) + block_scales_expanded = block_scales_expanded[:, :num_cols] + + # Expand per-token scales: one per row + token_scales_expanded = per_token_scales.unsqueeze(1) # (M, 1) + + return fp4_vals * block_scales_expanded * token_scales_expanded + def _has_pertoken_kernel(): """Check if the per-token kernel binding is available.""" @@ -208,6 +260,49 @@ def test_packed_fp4_data_shape(self, dtype): data, _, _ = tex.quantize_nvfp4_pertoken(x) assert data.shape[1] == num_cols // 2 + @pytest.mark.parametrize( + "num_rows,num_cols", + [ + (4, 256), + (32, 256), + (64, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_dequantized_data_close_to_input(self, num_rows, num_cols, dtype): + """Dequantized FP4 data should be close to the original input. + + Quantize -> dequantize round-trip should preserve values within FP4 precision. + FP4 E2M1 has ~1 bit mantissa, so expect ~25% relative error for non-tiny values. + """ + torch.manual_seed(42) + x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") + data, scales, per_token_scales = tex.quantize_nvfp4_pertoken(x) + + dequant = dequantize_pertoken_fp4(data, scales, per_token_scales) + + # Compare against original (allow FP4 quantization error) + x_f32 = x.float() + nonzero = x_f32.abs() > 0.1 # skip very small values where relative error is meaningless + if nonzero.any(): + rel_error = ((dequant[nonzero] - x_f32[nonzero]).abs() / + x_f32[nonzero].abs()).mean() + assert rel_error < 0.5, ( + f"Mean relative error {rel_error:.3f} too high for FP4 round-trip " + f"(shape={num_rows}x{num_cols}, dtype={dtype})" + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_fp4_values_in_valid_range(self, dtype): + """Unpacked FP4 indices should be in [0, 15] (valid 4-bit range).""" + x = torch.randn(16, 256, dtype=dtype, device="cuda") + data, _, _ = tex.quantize_nvfp4_pertoken(x) + + unpacked = unpack_fp4(data) + assert (unpacked >= 0).all() and (unpacked <= 15).all(), ( + f"FP4 indices out of range: min={unpacked.min()}, max={unpacked.max()}" + ) + def test_input_validation_not_2d(self): """Should reject non-2D input.""" x = torch.randn(2, 3, 256, dtype=torch.bfloat16, device="cuda") diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 1aa20d4636..5209c1f777 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -138,9 +138,8 @@ __launch_bounds__(BLOCK_SIZE) // ========================================================================= // Pass 2: Compute block scales and quantize to FP4 // ========================================================================= - // TODO: FP4 data packing is disabled pending alignment investigation. - // For now, only per-token scales and block scales are computed. 
-  // The FP4 data output is zeroed.
+  // Each thread processes one 16-element block: computes block amax,
+  // derives E4M3 block scale, then quantizes 16 elements to 8 packed FP4 bytes.
   const int num_sf_blocks = num_cols / PERTOKEN_SF_VEC_SIZE;
   for (int sf_idx = threadIdx.x; sf_idx < num_sf_blocks; sf_idx += BLOCK_SIZE) {
@@ -148,25 +147,58 @@ __launch_bounds__(BLOCK_SIZE)
 
     // Load 16 elements and find block amax
     float block_max = 0.0f;
+    float vals[PERTOKEN_SF_VEC_SIZE];
     for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j++) {
-      float val;
       if constexpr (std::is_same_v<IType, half>) {
-        val = __half2float(input[actual_row * num_cols + col_start + j]);
+        vals[j] = __half2float(input[actual_row * num_cols + col_start + j]);
       } else {
-        val = __bfloat162float(input[actual_row * num_cols + col_start + j]);
+        vals[j] = __bfloat162float(input[actual_row * num_cols + col_start + j]);
       }
-      block_max = fmaxf(block_max, fabsf(val));
+      block_max = fmaxf(block_max, fabsf(vals[j]));
     }
 
-    // Compute and store per-block E4M3 scale factor
+    // Compute per-block E4M3 scale factor: S_dec_b = block_max / (fp4_max / S_enc)
     fp8e4m3 S_dec_b = quantization_SF::compute_decoding_scaling_factor(block_max, S_enc);
+    float S_dec_b_f = static_cast<float>(S_dec_b);
+
+    // Store block scale (LINEAR layout: row-major)
     output_scales[row_idx * scale_stride + sf_idx] = S_dec_b;
-  }
 
+    // Compute encoding scale for this block: maps input range to [-6, 6] (FP4 range)
+    float block_encode_scale = (S_dec_b_f != 0.0f)
+        ? __fdividef(S_enc, S_dec_b_f)
+        : 0.0f;
+
+    // Scale values and pack to FP4 using PTX cvt.rn.satfinite.e2m1x2.
+    // Process 8 elements (4 pairs) at a time -> 4 bytes -> 1 uint32_t,
+    // matching FlashInfer's fp32_vec_to_e2m1 pattern.
+    uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2;
+    for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 8) {
+      float s0 = vals[j] * block_encode_scale;
+      float s1 = vals[j + 1] * block_encode_scale;
+      float s2 = vals[j + 2] * block_encode_scale;
+      float s3 = vals[j + 3] * block_encode_scale;
+      float s4 = vals[j + 4] * block_encode_scale;
+      float s5 = vals[j + 5] * block_encode_scale;
+      float s6 = vals[j + 6] * block_encode_scale;
+      float s7 = vals[j + 7] * block_encode_scale;
+      uint32_t packed;
+      asm volatile(
+          "{\n"
+          ".reg .b8 byte0, byte1, byte2, byte3;\n"
+          "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+          "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+          "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+          "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+          "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+          "}\n"
+          : "=r"(packed)
+          : "f"(s0), "f"(s1), "f"(s2), "f"(s3),
+            "f"(s4), "f"(s5), "f"(s6), "f"(s7));
+      reinterpret_cast<uint32_t *>(out_ptr)[j / 8] = packed;
+    }
+    // PERTOKEN_SF_VEC_SIZE == 16, so the loop above runs exactly twice
+    // (j = 0 and j = 8), covering all 16 elements of the block.
} #endif // __CUDA_ARCH__ >= 1000 } From 5191ec0b2c69fc7ecaf48c3a31238dbb00b65c08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 22 Apr 2026 23:20:40 +0000 Subject: [PATCH 09/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_nvfp4_pertoken_quant.py | 36 +++++++++++++------ .../cast/nvfp4/quantize_pertoken_nvfp4.cuh | 9 ++--- 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py index 20daf55733..f678cb0a5e 100644 --- a/tests/pytorch/test_nvfp4_pertoken_quant.py +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -28,8 +28,24 @@ # FP4 E2M1 look-up table: 4-bit index -> float value # Lower nibble = first element, upper nibble = second element -_FP4_E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0] +_FP4_E2M1_LUT = [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, +] def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: @@ -40,7 +56,7 @@ def unpack_fp4(packed: torch.Tensor) -> torch.Tensor: """ repeated = packed.repeat_interleave(2, dim=1) repeated[:, 0::2] = repeated[:, 0::2] & 0x0F # Lower 4 bits - repeated[:, 1::2] = repeated[:, 1::2] >> 4 # Upper 4 bits + repeated[:, 1::2] = repeated[:, 1::2] >> 4 # Upper 4 bits return repeated @@ -50,8 +66,9 @@ def fp4_to_fp32(unpacked: torch.Tensor) -> torch.Tensor: return lut[unpacked.long()] -def dequantize_pertoken_fp4(data: torch.Tensor, scales: torch.Tensor, - per_token_scales: torch.Tensor) -> torch.Tensor: +def dequantize_pertoken_fp4( + data: torch.Tensor, scales: torch.Tensor, per_token_scales: torch.Tensor +) -> torch.Tensor: """Dequantize per-token NVFP4: result = fp4_val * block_scale * per_token_scale. Args: @@ -285,8 +302,7 @@ def test_dequantized_data_close_to_input(self, num_rows, num_cols, dtype): x_f32 = x.float() nonzero = x_f32.abs() > 0.1 # skip very small values where relative error is meaningless if nonzero.any(): - rel_error = ((dequant[nonzero] - x_f32[nonzero]).abs() / - x_f32[nonzero].abs()).mean() + rel_error = ((dequant[nonzero] - x_f32[nonzero]).abs() / x_f32[nonzero].abs()).mean() assert rel_error < 0.5, ( f"Mean relative error {rel_error:.3f} too high for FP4 round-trip " f"(shape={num_rows}x{num_cols}, dtype={dtype})" @@ -299,9 +315,9 @@ def test_fp4_values_in_valid_range(self, dtype): data, _, _ = tex.quantize_nvfp4_pertoken(x) unpacked = unpack_fp4(data) - assert (unpacked >= 0).all() and (unpacked <= 15).all(), ( - f"FP4 indices out of range: min={unpacked.min()}, max={unpacked.max()}" - ) + assert (unpacked >= 0).all() and ( + unpacked <= 15 + ).all(), f"FP4 indices out of range: min={unpacked.min()}, max={unpacked.max()}" def test_input_validation_not_2d(self): """Should reject non-2D input.""" diff --git a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh index 5209c1f777..659835e51a 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh @@ -165,16 +165,14 @@ __launch_bounds__(BLOCK_SIZE) output_scales[row_idx * scale_stride + sf_idx] = S_dec_b; // Compute encoding scale for this block: maps input range to [-6, 6] (FP4 range) - float block_encode_scale = (S_dec_b_f != 0.0f) - ? 
__fdividef(S_enc, S_dec_b_f) - : 0.0f; + float block_encode_scale = (S_dec_b_f != 0.0f) ? __fdividef(S_enc, S_dec_b_f) : 0.0f; // Scale values and pack to FP4 using PTX cvt.rn.satfinite.e2m1x2 // Process 8 elements (4 pairs) at a time -> 4 bytes -> 1 uint32_t // Matching FlashInfer's fp32_vec_to_e2m1 pattern. uint8_t *out_ptr = output_data + actual_row * (num_cols / 2) + col_start / 2; for (int j = 0; j < PERTOKEN_SF_VEC_SIZE; j += 8) { - float s0 = vals[j] * block_encode_scale; + float s0 = vals[j] * block_encode_scale; float s1 = vals[j + 1] * block_encode_scale; float s2 = vals[j + 2] * block_encode_scale; float s3 = vals[j + 3] * block_encode_scale; @@ -193,8 +191,7 @@ __launch_bounds__(BLOCK_SIZE) "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" "}\n" : "=r"(packed) - : "f"(s0), "f"(s1), "f"(s2), "f"(s3), - "f"(s4), "f"(s5), "f"(s6), "f"(s7)); + : "f"(s0), "f"(s1), "f"(s2), "f"(s3), "f"(s4), "f"(s5), "f"(s6), "f"(s7)); reinterpret_cast(out_ptr)[j / 8] = packed; } // Handle remaining 8 elements (PERTOKEN_SF_VEC_SIZE=16, so exactly 2 iterations of 8) From cad6473639fd286c2f5bd2d4e6326af40358821b Mon Sep 17 00:00:00 2001 From: YigongQin Date: Fri, 24 Apr 2026 08:02:49 -0700 Subject: [PATCH 10/10] pytorch reference for quant --- tests/pytorch/test_nvfp4_pertoken_quant.py | 172 ++++++++++++++++++--- 1 file changed, 152 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_nvfp4_pertoken_quant.py b/tests/pytorch/test_nvfp4_pertoken_quant.py index f678cb0a5e..93a4376eb2 100644 --- a/tests/pytorch/test_nvfp4_pertoken_quant.py +++ b/tests/pytorch/test_nvfp4_pertoken_quant.py @@ -109,28 +109,71 @@ def _has_pertoken_kernel(): def nvfp4_pertoken_quantize_ref(input_tensor: torch.Tensor): """Pure PyTorch reference for per-token NVFP4 quantization. + Reproduces the exact logic of quantize_pertoken_nvfp4_kernel: + Pass 1: per-row amax → S_enc → per_token_scale + Pass 2: per-block(16) amax → S_dec_b (E4M3) → scale + quantize to FP4 + Returns: - per_token_scales: (num_rows,) FP32 tensor - global_scale[row] = row_amax / (fp8_max * fp4_max) + data: (M, K/2) uint8 packed FP4 + scales: (M, K/16) uint8 (FP8 E4M3 block scales) + per_token_scales: (M,) FP32 """ + from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import cast_to_fp4x2 + assert input_tensor.dim() == 2 num_rows, num_cols = input_tensor.shape assert num_cols % 16 == 0 - input_f32 = input_tensor.float() + x = input_tensor.float() + + # --- Pass 1: Per-row amax → S_enc → per_token_scale --- + row_amax = x.abs().amax(dim=1) # (M,) - # Per-row amax - row_amax = input_f32.abs().amax(dim=1) # (num_rows,) + # compute_global_encode_scaling_factor_FP4: S_enc = fp8_max * fp4_max / amax + S_enc = FP8_E4M3_MAX * FP4_MAX / row_amax + S_enc = torch.clamp(S_enc, max=torch.finfo(torch.float32).max) + S_enc = torch.where((row_amax == 0) | (S_enc == 0), torch.ones_like(S_enc), S_enc) - # S_enc = fp8_max * fp4_max / row_amax - # global_scale = 1 / S_enc = row_amax / (fp8_max * fp4_max) - # When amax=0, S_enc=1.0 (fallback), so global_scale=1.0 - per_token_scales = row_amax / (FP8_E4M3_MAX * FP4_MAX) + per_token_scales = 1.0 / S_enc # global_scale = 1 / S_enc per_token_scales = torch.where( row_amax == 0, torch.ones_like(per_token_scales), per_token_scales ) - return per_token_scales + # --- Pass 2: Per-block quantization --- + num_blocks = num_cols // 16 + x_blocks = x.view(num_rows, num_blocks, 16) # (M, K/16, 16) + + # Per-block amax + block_amax = x_blocks.abs().amax(dim=-1) # (M, K/16) + + # compute_decoding_scaling_factor: S_dec_b = 
block_amax * S_enc / fp4_max + # Then cast to FP8 E4M3 + S_enc_expanded = S_enc.unsqueeze(1) # (M, 1) + S_dec_b = block_amax * S_enc_expanded / FP4_MAX + S_dec_b = torch.clamp(S_dec_b, max=FP8_E4M3_MAX) + S_dec_b_fp8 = S_dec_b.to(torch.float8_e4m3fn) + S_dec_b_f = S_dec_b_fp8.float() + + # Block encode scale = S_enc / S_dec_b_f (inverse for quantization) + block_encode_scale = torch.where( + S_dec_b_f != 0, + S_enc_expanded / S_dec_b_f, + torch.zeros_like(S_dec_b_f), + ) # (M, K/16) + + # Scale input and clamp to FP4 range [-6, 6] + block_encode_expanded = block_encode_scale.unsqueeze(-1) # (M, K/16, 1) + scaled_x = x_blocks * block_encode_expanded # (M, K/16, 16) + scaled_x = scaled_x.reshape(num_rows, num_cols) + clamped_x = torch.clamp(scaled_x, -FP4_MAX, FP4_MAX) + + # Pack to FP4 using TE's reference cast_to_fp4x2 + data = cast_to_fp4x2(clamped_x) + + # Block scales as uint8 (FP8 E4M3 raw bytes) + scales = S_dec_b_fp8.view(torch.uint8) + + return data, scales, per_token_scales # --------------------------------------------------------------------------- @@ -186,11 +229,11 @@ def test_per_token_scales_match_reference(self, num_rows, num_cols, dtype): x = torch.randn(num_rows, num_cols, dtype=dtype, device="cuda") _, _, per_token_scales = tex.quantize_nvfp4_pertoken(x) - ref_scales = nvfp4_pertoken_quantize_ref(x) + _, _, ref_scales = nvfp4_pertoken_quantize_ref(x) torch.testing.assert_close( per_token_scales, - ref_scales, + ref_scales.to(device="cuda"), atol=1e-5, rtol=1e-3, msg="Per-token scales should match reference", @@ -337,6 +380,95 @@ def test_input_validation_wrong_dtype(self): with pytest.raises(RuntimeError): tex.quantize_nvfp4_pertoken(x) + # ----------------------------------------------------------------------- + # Exact byte-match tests (following test_nvfp4_quantize_exact.py pattern) + # ----------------------------------------------------------------------- + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_fp4_data_exact_match(self, M, N, dtype): + """FP4 packed data must exactly match Python reference (byte-for-byte).""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + data, scales, pts = tex.quantize_nvfp4_pertoken(x) + ref_data, ref_scales, ref_pts = nvfp4_pertoken_quantize_ref(x) + + # Unpack both to 4-bit indices for comparison + kernel_unpacked = unpack_fp4(data) + ref_unpacked = unpack_fp4(ref_data.to(device="cuda")) + + torch.testing.assert_close( + kernel_unpacked, + ref_unpacked, + atol=0.0, + rtol=0.0, + msg=f"FP4 data mismatch for shape ({M}, {N}), dtype={dtype}", + ) + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_block_scales_exact_match(self, M, N, dtype): + """Block scales must exactly match Python reference (byte-for-byte).""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + _, scales, _ = tex.quantize_nvfp4_pertoken(x) + _, ref_scales, _ = nvfp4_pertoken_quantize_ref(x) + + torch.testing.assert_close( + scales, + ref_scales.to(device="cuda"), + atol=0.0, + rtol=0.0, + msg=f"Block scales mismatch for shape ({M}, {N}), dtype={dtype}", + ) + + @pytest.mark.parametrize( + "M, N", + [ + (4, 256), + (16, 256), + (32, 1024), + (128, 4096), + ], + ) + 
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_per_token_scales_exact_match(self, M, N, dtype): + """Per-token scales must exactly match Python reference.""" + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn(M, N, dtype=dtype, device="cuda") + + _, _, pts = tex.quantize_nvfp4_pertoken(x) + _, _, ref_pts = nvfp4_pertoken_quantize_ref(x) + + torch.testing.assert_close( + pts, + ref_pts.to(device="cuda"), + atol=0.0, + rtol=0.0, + msg=f"Per-token scales mismatch for shape ({M}, {N}), dtype={dtype}", + ) + # --------------------------------------------------------------------------- # Standalone test (can run without tex binding for reference validation) @@ -349,9 +481,9 @@ class TestPertokenScaleReference: def test_reference_basic(self): """Basic reference test on CPU.""" x = torch.tensor([[1.0, 2.0, 3.0, 4.0] * 4], dtype=torch.float32) - scales = nvfp4_pertoken_quantize_ref(x) + _, _, pts = nvfp4_pertoken_quantize_ref(x) expected = torch.tensor([4.0 / (FP8_E4M3_MAX * FP4_MAX)]) - torch.testing.assert_close(scales, expected) + torch.testing.assert_close(pts, expected) def test_reference_multi_row(self): """Multi-row reference test.""" @@ -359,15 +491,15 @@ def test_reference_multi_row(self): x[0] = 1.0 x[1] = 10.0 x[2] = 0.1 - scales = nvfp4_pertoken_quantize_ref(x) + _, _, pts = nvfp4_pertoken_quantize_ref(x) - assert scales[1] > scales[0] > scales[2] - torch.testing.assert_close(scales[0], torch.tensor(1.0 / (FP8_E4M3_MAX * FP4_MAX))) - torch.testing.assert_close(scales[1], torch.tensor(10.0 / (FP8_E4M3_MAX * FP4_MAX))) + assert pts[1] > pts[0] > pts[2] + torch.testing.assert_close(pts[0], torch.tensor(1.0 / (FP8_E4M3_MAX * FP4_MAX))) + torch.testing.assert_close(pts[1], torch.tensor(10.0 / (FP8_E4M3_MAX * FP4_MAX))) def test_reference_zero_row(self): """Zero row: S_enc=1.0 fallback, so global_scale=1.0.""" x = torch.zeros(2, 16, dtype=torch.float32) x[0] = 5.0 - scales = nvfp4_pertoken_quantize_ref(x) - assert scales[1] == 1.0 + _, _, pts = nvfp4_pertoken_quantize_ref(x) + assert pts[1] == 1.0
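
---

## Appendix: Illustrative Sketches

The short sketches below restate, in standalone Python, the math and byte-layout contracts that the kernels and tests in this series rely on. They are minimal illustrations under stated assumptions, not part of the patches.

The per-token global scale that the kernel stores (and that the fused-op fallback reconstructs from a per-tensor amax) is `row_amax / (fp4_max * fp8_max)`, with a 1.0 fallback for all-zero rows via the `S_enc = 1.0` path. A minimal sketch of just this formula, with made-up inputs:

```python
import torch

FP4_MAX = 6.0          # E2M1 max magnitude
FP8_E4M3_MAX = 448.0   # E4M3 max magnitude

x = torch.tensor([[1.0, -3.2, 0.5, 2.0],
                  [0.0,  0.0, 0.0, 0.0]])
row_amax = x.abs().amax(dim=1)                # tensor([3.2, 0.0])
scale = row_amax / (FP4_MAX * FP8_E4M3_MAX)   # 3.2 / 2688 for row 0
# All-zero rows take the S_enc = 1.0 fallback, so global_scale = 1.0.
scale = torch.where(row_amax == 0, torch.ones_like(scale), scale)
```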
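The two-pass scheme is easiest to check as scale algebra: the block encode scale `S_enc / S_dec_b` collapses to `fp4_max / block_amax`, so dequantization (`fp4_value * block_scale * per_token_scale`) recovers the input exactly once E2M1/E4M3 rounding is left out. A sketch with arbitrary example values:

```python
FP4_MAX, FP8_E4M3_MAX = 6.0, 448.0

row_amax, block_amax, x = 4.0, 2.0, 1.5    # arbitrary example values

s_enc = FP8_E4M3_MAX * FP4_MAX / row_amax  # Pass 1: global encode factor
s_dec_b = block_amax * s_enc / FP4_MAX     # Pass 2: block decode scale (stored as E4M3)
encode = s_enc / s_dec_b                   # == FP4_MAX / block_amax

q = x * encode                             # lands in [-6, 6]; the kernel rounds this to E2M1
x_hat = q * s_dec_b * (1.0 / s_enc)        # dequant: fp4 * block_scale * per_token_scale
assert abs(x_hat - x) < 1e-9               # identity, up to FP4/E4M3 rounding
```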
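The byte layout produced by the PTX path puts the first element of each pair in the low nibble and the second in the high nibble (hence the swapped `%2, %1` operand order in each `cvt.rn.satfinite.e2m1x2` instruction). `pack_fp4` below is a hypothetical inverse of the tests' `unpack_fp4` helper, shown only to pin down that layout:

```python
import torch

# E2M1 code points (bit 3 = sign): 4-bit index -> value
FP4_E2M1_LUT = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                             -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def pack_fp4(codes: torch.Tensor) -> torch.Tensor:
    """Pack 4-bit codes (M, K) uint8 -> (M, K/2) bytes, first element in the low nibble."""
    lo, hi = codes[:, 0::2], codes[:, 1::2]
    return lo | (hi << 4)

codes = torch.randint(0, 16, (2, 8), dtype=torch.uint8)
packed = pack_fp4(codes)

# Mirror of the tests' unpack_fp4 helper:
unpacked = packed.repeat_interleave(2, dim=1)
unpacked[:, 0::2] &= 0x0F
unpacked[:, 1::2] >>= 4

assert torch.equal(unpacked, codes)   # lossless nibble round-trip
values = FP4_E2M1_LUT[codes.long()]   # decoded E2M1 values
```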
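The 0.5 mean-relative-error bound in `test_dequantized_data_close_to_input` is deliberately loose: with a single mantissa bit, round-to-nearest on the E2M1 grid can err by up to a third of the value for inputs that land mid-gap. A quick check over the post-scaling range, assuming block scaling has mapped magnitudes into [0.5, 6]:

```python
import torch

LEVELS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

x = torch.linspace(0.5, 6.0, 10001)
nearest = LEVELS[(x[:, None] - LEVELS[None, :]).abs().argmin(dim=1)]
rel_err = (nearest - x).abs() / x
print(rel_err.max())   # ~0.333, worst at x = 0.75 (midway between 0.5 and 1.0)
```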
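End to end, the new binding is a single call. This sketch assumes the extension module is importable as `transformer_engine_torch` and aliased as `tex` (the alias the tests use; the exact import path is an assumption here), plus a Blackwell-class GPU (`__CUDA_ARCH__ >= 1000`). It restates the shape/dtype contract the tests assert:

```python
import torch
import transformer_engine_torch as tex  # assumed import; tests alias the extension as `tex`

x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
data, scales, per_token = tex.quantize_nvfp4_pertoken(x)

assert data.shape == (32, 128) and data.dtype == torch.uint8       # packed FP4, 2 values/byte
assert scales.shape == (32, 16) and scales.dtype == torch.uint8    # one E4M3 scale per 16 cols
assert per_token.shape == (32,) and per_token.dtype == torch.float32
```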