
Implement per-token NVFP4 fprop recipe #2931

Open

zianglih wants to merge 15 commits into NVIDIA:main from zianglih:fp4-per-token

Conversation

zianglih (Contributor) commented Apr 27, 2026

Description


Implement a per-token NVFP4 recipe with fprop only.
Currently, the per-token scaling is handled by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.

The following tests passed on B200:

python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
python3 -m pytest --tb=auto tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
python3 -m pytest --tb=auto tests/pytorch/test_backward_override.py
python3 -m pytest --tb=auto tests/pytorch/test_sanity.py
python3 -m pytest --tb=auto tests/pytorch/test_recipe.py
python3 -m pytest --tb=auto tests/pytorch/test_torch_compile.py
python3 -m pytest --tb=auto tests/pytorch/test_cpu_offloading.py
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto tests/pytorch/test_cuda_graphs.py
NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a per_token_activation field to the NVFP4 recipe, which can be enabled via the NVTE_NVFP4_PER_TOKEN_ACTIVATION environment variable
  • New per-token NVFP4 quantize kernels in transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation
  • Expand the dequant kernel transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh to correctly handle per-token NVFP4
  • In transformer_engine/pytorch/cpp_extensions/gemm.py, if per-token NVFP4 is detected, apply separate per-token scaling in PyTorch code after the cuBLAS GEMM
  • Broaden test coverage by expanding 7 test files
  • Modify the 1D quant reference implementation in tests/cpp/operator/test_cast_nvfp4_transpose.cu to align with the PyTorch reference numerics
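The first bullet can be sketched as an env-var-gated recipe field. This is a minimal illustration with hypothetical class and helper names (only the field name and the NVTE_NVFP4_PER_TOKEN_ACTIVATION variable come from the PR description; the actual TE recipe class differs):

```python
import os
from dataclasses import dataclass, field


def _env_flag(name: str) -> bool:
    # Treat "1"/"true"/"yes" (case-insensitive) as enabled.
    return os.getenv(name, "0").lower() in ("1", "true", "yes")


@dataclass
class NVFP4Recipe:  # hypothetical stand-in for the TE recipe class
    # Per-token activation scaling, off by default; the env var
    # NVTE_NVFP4_PER_TOKEN_ACTIVATION flips the default on.
    per_token_activation: bool = field(
        default_factory=lambda: _env_flag("NVTE_NVFP4_PER_TOKEN_ACTIVATION")
    )
```

Using default_factory (rather than reading the env var at import time) means the flag is re-evaluated each time a recipe is constructed.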

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

zianglih and others added 13 commits April 26, 2026 23:07
Signed-off-by: Ziang Li <ziangli@umich.edu>
Co-authored-by: Yigong Qin <qqqyyy1233@outlook.com>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih marked this pull request as draft April 27, 2026 06:24
greptile-apps (Bot) commented Apr 27, 2026

Greptile Summary

This PR adds a row-level NVFP4 quantization recipe for the forward pass, introducing a new CUDA kernel, C++ quantizer infrastructure, and a post-GEMM FP32 scaling step in the PyTorch GEMM extension. Backward usage is explicitly blocked via assertions in both general_gemm and general_grouped_gemm; test infrastructure skips unsupported backward paths.

  • assert statements as API guards in general_grouped_gemm (P2): the same pattern flagged for general_gemm recurs in the new grouped branch — silently disabled under python -O, producing wrong numerics instead of a clear error.
  • Backward quantizers receive the row-level activation flag without a recipe-level guard (pre-existing P1): NVFP4BlockScalingRecipeState forwards the flag to both forward and backward quantizers, so a backward GEMM whose B operand is a row-scaled tensor will crash at assert not grad inside general_gemm.
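The assert-as-API-guard finding can be made concrete: Python compiles away `assert` statements under `python -O`, so a guard like `assert not grad` silently disappears and the GEMM proceeds with wrong numerics. A hedged sketch of the safer pattern, using a hypothetical function rather than TE's actual signature:

```python
def check_per_token_fprop_only(grad: bool, per_token: bool) -> None:
    # An `assert not grad` here would vanish under `python -O`,
    # letting a backward GEMM run with a per-token tensor and
    # produce wrong numerics. An explicit raise survives -O.
    if per_token and grad:
        raise NotImplementedError(
            "per-token NVFP4 is fprop-only; backward GEMM is not supported"
        )
```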

Confidence Score: 4/5

Safe to merge for fprop-only workloads; backward pass with row-scaled tensors will crash, but test infrastructure already guards against that path via skip helpers.

The pre-existing P1 (backward quantizers receiving the row-level activation flag without a recipe-level guard) sets a ceiling of 4/5. No new P0s were found. The two new P2 findings (assert-as-API-guard in grouped GEMM, avoidable extra copy in grouped GEMM loop) are non-blocking. CUDA kernel logic, quantizer plumbing, and post-GEMM scaling are all sound for the declared fprop-only scope.

transformer_engine/pytorch/quantization.py and transformer_engine/common/recipe/init.py — backward quantizer receiving the row-level activation flag without a backward-override guard is the primary risk.

Important Files Changed

Filename Overview
transformer_engine/common/cast/nvfp4/quantize_per_token_nvfp4.cuh New 486-line CUDA header: one block per row for rowwise amax+quantize (CUB block-reduce), one block per column for the columnwise transpose path. Logic is sound; the float32 input path reads 8-byte float2 pairs without an explicit alignment check.
transformer_engine/pytorch/cpp_extensions/gemm.py Adds per-token GEMM detection and post-GEMM FP32 scaling for both single and grouped GEMM; bias is correctly removed before per-token scaling and re-added. The new general_grouped_gemm branch uses assert for API contract enforcement and incurs an extra alloc+copy per sub-GEMM.
transformer_engine/pytorch/quantization.py Forwards per_token_activation to both forward and backward quantizers without restriction; backward GEMMs using a row-scaled B tensor will crash at assert not grad — only avoided in tests by skip_unsupported_backward_override.
transformer_engine/pytorch/csrc/extensions/cast.cpp Extends split_quantize and bulk_allocate_nvfp4_tensors with per-token amax buffer sizing; adds standalone quantize_nvfp4_per_token entry-point. Allocation/quantization method selection correctly bypasses 128-alignment and contiguity fallbacks for the per-token path.
transformer_engine/pytorch/csrc/quantizer.cpp Propagates per_token_activation through create_tensor, convert_and_update_tensor, and quantize_impl; amax tensors correctly sized to flat_first_dim rows for per-token mode. quantize_with_amax is blocked with an explicit NVTE_CHECK.
transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh Extends dequantize kernel with amax_numel parameter; tensor_amax[y] lookup for per-token path correctly uses row index y = thread_idx / M. Change is minimal and correct.
transformer_engine/common/recipe/init.py Adds per_token_activation field with env-var default; no post_init validation requiring a non-default backward_override when per-token is enabled — pre-existing P1 from previous review.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Correctly computes total_amax_elements = sum(prod(s[:-1]) for s in shape) for per-token mode and slices the flat amax buffer via nvfp4_per_token_amax_offsets when constructing individual tensor views.
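To make the "one amax per row" scheme in the kernel description concrete, here is a rough pure-Python sketch of per-token FP4 quantization: compute one amax per row, derive a decode scale from the E2M1 max normal (6.0), and round each element to the nearest representable magnitude. This is illustrative only; the real kernel also emits per-block FP8 scales and a columnwise transpose path, neither of which is modeled here.

```python
# Representable magnitudes of FP4 E2M1 (sign handled separately).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def quantize_row_nvfp4(row):
    """Per-token quantize: one amax, and thus one decode scale, per row."""
    amax = max(abs(x) for x in row)
    scale = amax / 6.0 if amax > 0.0 else 1.0  # decode scale
    q = []
    for x in row:
        mag = abs(x) / scale
        nearest = min(E2M1_VALUES, key=lambda v: abs(v - mag))
        q.append(nearest if x >= 0 else -nearest)
    return q, scale


def dequantize_row(q, scale):
    # Mirrors the dequant kernel's per-row amax lookup: every element
    # of the row shares the same scale.
    return [v * scale for v in q]
```

Values that land exactly on the scaled E2M1 grid round-trip bitwise, which is the property the exactness tests above rely on.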

Sequence Diagram

sequenceDiagram
    participant FwdPass as Forward Pass
    participant Quantizer as NVFP4Quantizer
    participant Kernel as CUDA Kernel
    participant GEMM as general_gemm
    participant PostScale as FP32 Post-Scale

    FwdPass->>Quantizer: split_quantize(activation, rowwise=True)
    Quantizer->>Kernel: quantize rowwise nvfp4 kernel
    Note over Kernel: Row amax reduction, one block per row
    Kernel-->>Quantizer: fp4 data, fp8 scales, amax vector
    Quantizer-->>FwdPass: NVFP4TensorStorage with rowwise amax

    FwdPass->>GEMM: general_gemm(weight_A, activation_B)
    GEMM->>GEMM: detect rowwise amax tensor in B
    GEMM->>GEMM: replace B amax with ones tensor
    GEMM->>GEMM: tex.generic_gemm produces fp32 output
    GEMM->>PostScale: multiply output rows by rowwise scales
    Note over PostScale: out_2d.mul_(amax_B.view(-1, 1))
    PostScale-->>FwdPass: scaled output converted to target dtype
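The post-GEMM step in the diagram amounts to an in-place row-wise multiply, the plain-Python equivalent of `out_2d.mul_(amax_B.view(-1, 1))` (shapes assumed: one scale factor per output row):

```python
def scale_rows(out_2d, row_scales):
    # Scale row i of the FP32 GEMM output by the i-th per-token
    # factor, in place, mirroring out_2d.mul_(amax.view(-1, 1)).
    for i, s in enumerate(row_scales):
        out_2d[i] = [x * s for x in out_2d[i]]
    return out_2d
```

Because bias is removed before this multiply and re-added afterward (per the gemm.py summary above), the bias term is not scaled by the per-token factors.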

Reviews (2): Last reviewed commit: "Expand .cu test"

Comment thread on transformer_engine/pytorch/cpp_extensions/gemm.py
// Compute "correct" per-block encoding scaling factor
- const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
+ const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
+     fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
zianglih (Contributor, Author) commented:

We have to change this to stay aligned with the PyTorch reference.

@zianglih zianglih marked this pull request as ready for review April 27, 2026 09:14
1 participant