Implement per-token NVFP4 fprop recipe #2931
Conversation
Signed-off-by: Ziang Li <ziangli@umich.edu> Co-authored-by: Yigong Qin <qqqyyy1233@outlook.com>
Greptile Summary

This PR adds a row-level NVFP4 quantization recipe for the forward pass, introducing a new CUDA kernel, C++ quantizer infrastructure, and a post-GEMM FP32 scaling step in the PyTorch GEMM extension. Backward usage is explicitly blocked via assertions in both
Confidence Score: 4/5

Safe to merge for fprop-only workloads; backward pass with row-scaled tensors will crash, but test infrastructure already guards against that path via skip helpers. The pre-existing P1 (backward quantizers receiving the row-level activation flag without a recipe-level guard) sets a ceiling of 4/5. No new P0s were found. The two new P2 findings (assert-as-API-guard in grouped GEMM, avoidable extra copy in grouped GEMM loop) are non-blocking. CUDA kernel logic, quantizer plumbing, and post-GEMM scaling are all sound for the declared fprop-only scope.

Important Files Changed

transformer_engine/pytorch/quantization.py and transformer_engine/common/recipe/__init__.py — backward quantizer receiving the row-level activation flag without a backward-override guard is the primary risk.
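As a rough illustration of the flag and the backward restriction described above — a minimal sketch, assuming the `per_token_activation` recipe field and `NVTE_NVFP4_PER_TOKEN_ACTIVATION` environment variable listed in the Changes section; the class and function names here are hypothetical, not the actual Transformer Engine API:

```python
import os

# Hypothetical stand-ins: Nvfp4RecipeSketch and make_bwd_quantizer are NOT
# Transformer Engine APIs; they only illustrate how the flag could be plumbed.
class Nvfp4RecipeSketch:
    def __init__(self):
        # Per-token (row-wise) activation scaling for fprop, off by default.
        self.per_token_activation = bool(
            int(os.getenv("NVTE_NVFP4_PER_TOKEN_ACTIVATION", "0")))

def make_bwd_quantizer(recipe: Nvfp4RecipeSketch):
    # The PR scopes row-scaled activations to fprop only, so a backward-side
    # guard along these lines would reject the flag before reaching the kernels.
    assert not recipe.per_token_activation, (
        "per-token NVFP4 activation scaling is fprop-only; backward is not supported")
    ...
```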
Sequence Diagram

```mermaid
sequenceDiagram
    participant FwdPass as Forward Pass
    participant Quantizer as NVFP4Quantizer
    participant Kernel as CUDA Kernel
    participant GEMM as general_gemm
    participant PostScale as FP32 Post-Scale
    FwdPass->>Quantizer: split_quantize(activation, rowwise=True)
    Quantizer->>Kernel: quantize rowwise nvfp4 kernel
    Note over Kernel: Row amax reduction, one block per row
    Kernel-->>Quantizer: fp4 data, fp8 scales, amax vector
    Quantizer-->>FwdPass: NVFP4TensorStorage with rowwise amax
    FwdPass->>GEMM: general_gemm(weight_A, activation_B)
    GEMM->>GEMM: detect rowwise amax tensor in B
    GEMM->>GEMM: replace B amax with ones tensor
    GEMM->>GEMM: tex.generic_gemm produces fp32 output
    GEMM->>PostScale: multiply output rows by rowwise scales
    Note over PostScale: out_2d.mul_(amax_B.view(-1, 1))
    PostScale-->>FwdPass: scaled output converted to target dtype
```
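The post-GEMM step in the diagram can be sketched in PyTorch as below; `out_2d` and `amax_B` follow the names in the diagram, but the function itself is illustrative and not the actual code in `gemm.py`:

```python
import torch

def apply_per_row_post_scale(out_2d: torch.Tensor, amax_B: torch.Tensor) -> torch.Tensor:
    """Scale each row of the FP32 GEMM output by its per-row (per-token) factor.

    out_2d: [num_tokens, out_features] FP32 output computed with B's amax replaced by ones
    amax_B: [num_tokens] row-wise amax/scale vector carried with the NVFP4 activation tensor
    """
    # Same operation as shown in the diagram: out_2d.mul_(amax_B.view(-1, 1))
    return out_2d.mul_(amax_B.view(-1, 1))

# Illustrative usage with random data:
out = torch.randn(8, 16, dtype=torch.float32)
row_scales = torch.rand(8, dtype=torch.float32)
scaled = apply_per_row_post_scale(out, row_scales)
```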
| // Compute "correct" per-block encoding scaling factor | ||
| const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; | ||
| const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : | ||
| fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm); |
We have to change this here to stay aligned with the PyTorch reference.
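For comparison with that reference, the same clamped expression written in Python — a sketch of the math only, with float32 max used as a stand-in for `Numeric_Traits<float>::maxNorm`:

```python
import numpy as np

FLOAT32_MAX = np.finfo(np.float32).max  # stand-in for Numeric_Traits<float>::maxNorm

def per_block_encode_scale(S_dec_b_fp32: float, S_enc: float) -> float:
    """Clamped per-block encoding scale, mirroring the updated kernel expression."""
    if S_dec_b_fp32 == 0.0:
        return 0.0
    # fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), maxNorm)
    return min(1.0 / (S_dec_b_fp32 * (1.0 / S_enc)), float(FLOAT32_MAX))
```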
Description
Implement per-token NVFP4 recipe with fprop only.
Currently, the per-token scaling is handled by separate PyTorch code.
The quantization kernels are bitwise exact with the existing TE reference implementation.
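As a rough picture of what "per-token" means here, the sketch below computes the two-level scales row by row (per-token amax instead of a tensor-wide amax, plus per-16-element E4M3 block scales). The constants and formulas follow the general NVFP4 scheme and are illustrative only; this is not a transcription of the kernel or of the TE reference implementation:

```python
import torch

# Illustrative constants for the two-level NVFP4 scheme (E2M1 values under E4M3 block scales).
FP4_E2M1_MAX = 6.0
FP8_E4M3_MAX = 448.0

def pertoken_nvfp4_scales(x: torch.Tensor, block: int = 16):
    """Sketch of the scale math that distinguishes per-token from per-tensor NVFP4.

    Each row (token) gets its own encode scale from the row amax, and every
    16-element block inside the row gets an FP8 (E4M3) decode scale. The FP4
    value rounding done by the actual kernel is omitted here.
    """
    rows, cols = x.shape
    row_amax = x.abs().amax(dim=1).clamp_min(1e-12)       # per-token amax, not tensor-wide
    s_enc = (FP4_E2M1_MAX * FP8_E4M3_MAX) / row_amax      # per-row encode scale, shape [rows]
    xb = x.view(rows, cols // block, block)
    block_amax = xb.abs().amax(dim=2)                     # shape [rows, cols // block]
    # Per-block decode scale, stored as FP8 E4M3 (round-trip emulates the storage format)
    s_dec_b = (block_amax * s_enc.unsqueeze(1)) / FP4_E2M1_MAX
    s_dec_b_fp8 = s_dec_b.to(torch.float8_e4m3fn).float()
    return s_enc, s_dec_b_fp8, row_amax
```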
The following tests passed on B200:
Type of change
Changes
Please list the changes introduced in this PR:
- `per_token_activation` field in the NVFP4 recipe; can be turned on by `NVTE_NVFP4_PER_TOKEN_ACTIVATION`
- `transformer_engine/common/cast/nvfp4/quantize_pertoken_nvfp4.cuh`, bitwise exact with the existing TE PyTorch reference implementation and the per-tensor NVFP4 emulated implementation
- `transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh` updated to correctly handle this per-token NVFP4
- `TransformerEngine/transformer_engine/pytorch/cpp_extensions/gemm.py`: if per-token NVFP4 is detected, it conducts separate per-token scaling using PyTorch code after the cuBLAS GEMM
- `tests/cpp/operator/test_cast_nvfp4_transpose.cu` updated to align with PyTorch reference numerics

Checklist: