[Common, PyTorch] Add triton mHC kernels & pytorch APIs #2790
kainzhong wants to merge 23 commits into NVIDIA:main from
Conversation
7a2cdea to 4657371
Benchmark result:
Naming: For this benchmark n=4 is fixed, since it's the most reasonable hyper-connection size (n=1 is worse than the baseline according to ByteDance's original hyper-connection paper, n=2 requires padding to utilize Tensor Cores for the projection GEMM, and for n>4 the activation usage is also multiplied by n). "triton" means my Triton kernels, and "cutile" means the cutile implementation from the Megatron PR NVIDIA/Megatron-LM#3828. B: batch size. My benchmark script:

Expand & Combine
FWD: out(M,C,4) = f(M,C,1) @ H_post(M,1,4) + x(M,C,4) @ H_res(M,4,4)
BWD: grads for f, H_post, x, H_res, bias

Aggregate
FWD: out(M,C) = x(M,C,4) @ H_pre(M,4,1)
BWD: grads for x, H_pre

Projection + RMSNorm
FWD: Hs(M,24) = x(M,4C) @ phi(4C,24)^T, ms(M,) = mean(x^2)
BWD: back through scale + projection; grads for x, phi
Note: I cheated a bit in my backward implementation because I leave grad_phi to PyTorch to compute, which is a pure GEMM and is hard to fuse into the grad_x kernel efficiently. Since grad_phi (4C, 24) is relatively negligible compared to x (M, 4C) when M is large, I used the same IO for my SOL analysis, since I don't think it makes much difference. I also ran the end-to-end operator-level comparison using Triton's benchmark tool, which includes the PyTorch GEMM time for grad_phi. My implementation is still more efficient (this diagram is backward only):

Sinkhorn
FWD: 20 Sinkhorn iterations on H_res(M,4,4) in fp32. Data is tiny — kernel is occupancy/launch-bound, not BW-bound.
BWD: backprop through 20 iterations with recompute
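For a concrete picture of what each fused op computes, here is a rough PyTorch reference mirroring the formulas above for n = 4. This is only a sketch: tensor layouts, the bias handling in Expand & Combine, and the exact Sinkhorn variant are assumptions, not the real kernel signatures.

```python
import torch

def expand_combine_ref(f, x, H_post, H_res):
    # out(M, C, 4) = f(M, C, 1) @ H_post(M, 1, 4) + x(M, C, 4) @ H_res(M, 4, 4)
    return f.unsqueeze(-1) @ H_post.unsqueeze(1) + x @ H_res

def aggregate_ref(x, H_pre):
    # out(M, C) = x(M, C, 4) @ H_pre(M, 4, 1)
    return (x @ H_pre.unsqueeze(-1)).squeeze(-1)

def projection_rmsnorm_ref(x, phi):
    # Hs(M, 24) = x(M, 4C) @ phi^T with phi stored as (N, K) = (24, 4C); ms(M,) = mean(x^2)
    return x @ phi.t(), x.float().pow(2).mean(dim=-1)

def sinkhorn_ref(H_res, n_iter=20):
    # Textbook Sinkhorn-Knopp on a positive (M, 4, 4) matrix in fp32; the actual kernel
    # may enforce positivity differently before the row/column normalizations.
    P = H_res.float()
    for _ in range(n_iter):
        P = P / P.sum(dim=-1, keepdim=True)  # normalize rows
        P = P / P.sum(dim=-2, keepdim=True)  # normalize columns
    return P
```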
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
5fbf1ec to 1d30e88
Greptile Summary
This PR implements DeepSeek's mHC (Manifold-Constrained Hyper-Connections) paper via five new Triton kernels.
Confidence Score: 5/5. Safe to merge; only P2 style findings, all correctness-critical paths are sound. All five kernel backward passes are mathematically correct. The two findings are both P2: a dead memory allocation and unused kernel parameters. No logic bugs, shape mismatches, or thread-safety issues were found.
Important Files Changed: transformer_engine/pytorch/triton/mhc.py (dead grad_x allocation), transformer_engine/common/triton/mhc.py (unused stride params)
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["x: (M, nC)\nphi: (N, K)"] -->|mhc_fused_projection| B["H: (M,32)\nms: (M,)"]
B -->|mhc_fused_scale| C["H_pre: (M,n)\nH_post: (M,n)\nH_res: (M,n²)"]
C -->|mhc_fused_sinkhorn| D["H_res_norm: (s,b,n,n)\ndoubly-stochastic"]
E["x_stream: (s,b,C,n)"] -->|mhc_fused_aggregate + H_pre| F["out: (s,b,C)"]
F -->|Attention/FFN sublayer| G["f: (s,b,C)"]
G -->|mhc_fused_expand_combine + H_post + H_res| H["out: (s,b,C,n)"]
E --> H
D --> H
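Reading the flowchart as code, one sublayer's forward pass would be wired roughly as below. The call signatures are guesses based only on the flowchart edges; argument names and order are not the real API.

```python
# Hypothetical wiring of the five fused ops for one sublayer; shapes follow the flowchart.
H, ms = mhc_fused_projection(x_flat, phi)                   # x_flat: (M, n*C) -> H: (M, 32), ms: (M,)
H_pre, H_post, H_res = mhc_fused_scale(H, ms)               # split into (M, n), (M, n), (M, n^2)
H_res = mhc_fused_sinkhorn(H_res)                           # -> doubly-stochastic (s, b, n, n)
h = mhc_fused_aggregate(x_stream, H_pre)                    # (s, b, C, n) -> (s, b, C)
f = sublayer(h)                                             # attention / FFN sublayer
out = mhc_fused_expand_combine(f, x_stream, H_post, H_res)  # -> (s, b, C, n)
```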
Reviews (15): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
    mhc_fused_projection,
)

reset_rng_states()
This should be done per test to enable running both the full test suite and the individual tests with the same inputs.
Made each test call this function instead now
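One way to do that (a sketch, assuming pytest and the module's existing reset_rng_states helper; the fixture name is hypothetical) is an autouse fixture so every test starts from the same RNG state:

```python
import pytest

@pytest.fixture(autouse=True)
def _reset_rng():
    # Hypothetical fixture: reset RNG state before each test so the full suite
    # and an individually-run test see identical inputs.
    reset_rng_states()
    yield
```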
# Enable TF32 for matmul to ensure consistency between the fused and reference implementations
torch.backends.cuda.matmul.allow_tf32 = False
Either the comment is wrong or the code is wrong here.
Sorry I forgot to change the comment. Fixed now
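Presumably the fix keeps the code and corrects the comment, something like:

```python
# Disable TF32 for matmul to ensure consistency between the fused and reference implementations
torch.backends.cuda.matmul.allow_tf32 = False
```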
n: int = 4  # Number of Hyper Connection streams

allow_n = [
    4,
In the end, do we only allow 4 here or do we also work for n equal to 2?
I have another branch for the n=2 kernels but I haven't done the pretraining and gather metrics. I was thinking about making that a separate PR later but I can combine them into one as well. It shouldn't be too hard since n=4 is already validated so I expect n=2 to be working as well.
if dtype == torch.bfloat16:
    tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
    tols = dict(atol=5e-3, rtol=5e-3)
For FP32 this tolerance seems a little high. What is the test that needs that tolerance?
I used this for all FP32 allclose tests. I tried stricter tolerances, but I think my extensive use of atomic adds makes it hard to fully match PyTorch's result. One mismatch comes from the projection kernel, where I use split-K with atomic add for the GEMM, and another comes from aggregate and expand's H_pre / H_post, where the gradient is computed as (1, C) @ (C, n) = (1, n) and C >> n.
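The order sensitivity comes from floating-point addition not being associative, so an atomic-add reduction whose ordering varies run to run (or differs from PyTorch's fixed reduction order) can land within tolerance without bit-matching. A one-line illustration:

```python
# Floating-point addition is not associative, so summation order changes the result slightly.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```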
if use_tf32:
    _mhc_projection_fwd_fused[grid](
        x_ptr=x,  # (M, K)
        phi_ptr=phi,  # (N, K)
        h_ptr=H,  # (M, 32)
        ms_ptr=ms,  # (M,)
        M=M,
        N=N,
        K=K,
        stride_xm=K,
        stride_xk=1,
        stride_phin=K,
        stride_phik=1,
        stride_hm=32,
        stride_hn=1,
        stride_ms=1,
        BLOCK_SIZE_N=32,
        precision="tf32",
    )
else:
    _mhc_projection_fwd_fused[grid](
        x_ptr=x,  # (M, K)
        phi_ptr=phi,  # (N, K)
        h_ptr=H,  # (M, 32)
        ms_ptr=ms,  # (M,)
        M=M,
        N=N,
        K=K,
        stride_xm=K,
        stride_xk=1,
        stride_phin=K,
        stride_phik=1,
        stride_hm=32,
        stride_hn=1,
        stride_ms=1,
        BLOCK_SIZE_N=32,
        precision="ieee",
    )
Suggested change:
_mhc_projection_fwd_fused[grid](
    x_ptr=x,  # (M, K)
    phi_ptr=phi,  # (N, K)
    h_ptr=H,  # (M, 32)
    ms_ptr=ms,  # (M,)
    M=M,
    N=N,
    K=K,
    stride_xm=K,
    stride_xk=1,
    stride_phin=K,
    stride_phik=1,
    stride_hm=32,
    stride_hn=1,
    stride_ms=1,
    BLOCK_SIZE_N=32,
    precision="tf32" if use_tf32 else "ieee",
)
And similar for the other cases in this file.
def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True):
    """
    Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, seciton 4.3.1 of the DeepSeek mHC paper):
Suggested change:
    Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper):
)

h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn
tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed")
Please add the guards so that there is an error when somebody tries to run this function with NVTE_ALLOW_NONDETERMINISTIC_ALGO set to 0.
Added an assertion for NVTE_ALLOW_NONDETERMINISTIC_ALGO in APIs where atomic_add might be called afterwards.
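A minimal sketch of such a guard (the helper name and message are hypothetical; the real check in the PR may look different):

```python
import os

def _assert_nondeterminism_allowed(op_name: str) -> None:
    # Hypothetical guard: these kernels reduce with atomic adds, which are not
    # bitwise reproducible, so refuse to run when determinism is requested.
    if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) == 0:
        raise RuntimeError(
            f"{op_name} uses Triton atomic-add reductions and is nondeterministic; "
            "it cannot run with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0."
        )
```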
)


def mhc_fused_sinkhorn(
Please add the functions to the Sphinx documentation.
Added in docs/api/pytorch.rst
There are some other typos in the docstrings besides the one I flagged; could you check those docstrings?
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
/te-ci pytorch
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
2f258c9 to 2a4225d
allow_bf16_reduced_precision_reduction = (
    torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
)
# Use FP32 accumulator in case of pytorch choosing a path with BF16 accumulator which hurts accuracy,
# which seems to happen on Ampere but not on Hopper and Blackwell
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
grad_phi = (grad_H.T @ x)[:N, :].to(
    ctx.phi_dtype
)  # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC), note that the last dimension of grad_H is already padded to 32
# Recover the original pytorch setting
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
    allow_bf16_reduced_precision_reduction
)
Global state mutation in backward is not thread-safe
mHCProjectionOp.backward temporarily sets torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False on a process-global object. If two backward passes run concurrently (e.g., in a DataParallel setup where the autograd engine spawns multiple CUDA streams, or when users manually overlap backward passes), one thread can restore the flag while another is still executing the BF16 matmul — silently producing lower-precision gradients or overwriting the caller's original setting.
A safer approach is to cast inputs to float32 explicitly before the matmul instead of relying on the global flag.
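A minimal sketch of that alternative, reusing grad_H, x, N, and ctx.phi_dtype from the snippet above: cast locally to FP32 so the accumulator precision no longer depends on the process-global flag.

```python
# Local FP32 cast instead of toggling the global
# torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction flag.
grad_phi = (grad_H.float().T @ x.float())[:N, :].to(ctx.phi_dtype)
```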
/te-ci pytorch
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
/te-ci pytorch
for more information, see https://pre-commit.ci
@kainzhong please check out my implementation of mHC-lite in Triton. I mostly focused on SM120 and SM90, and all the kernels perform at ~80-95% SOL at T=65536, H=768, N=4. I've implemented it in our protein language modeling project with Rostlab.
@kainzhong have we compared with the original TileLang implementation: https://github.com/deepseek-ai/TileKernels/tree/main?
@sbhavani I asked Claude Code to vibe-code some benchmark scripts and ran them on B200. Here are the results: https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/benchmark.md (triton - this implementation, cutile - Megatron-LM's fused kernels, tilelang - DeepSeek's tilelang kernels). The naming for the different implementations is a bit chaotic, so refer to https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/kernel_comparison.md#kernel-name--framework-op-direction-mapping if you get lost.

In terms of performance, I think the major difference is that DeepSeek's implementation prioritizes determinism while my implementation prefers efficiency. In my implementation I use atomic adds heavily, whereas DeepSeek seems to avoid that. I can work on a set of deterministic kernels if needed. (Btw, I took a brief look at @alint77's kernels and they seem to be deterministic.)

Another thing is that DeepSeek seems to have a CUDA C++ kernel for eq. 14-15 in https://github.com/deepseek-ai/DeepGEMM/blob/891d57b4db1071624b5c8fa0d1e51cb317fa709f/deep_gemm/__init__.py#L69 which supports split-K for this GEMM, whereas the tilelang implementation doesn't. I'll profile this one separately, and if the current implementation turns out to be slower I can work on a CuTeDSL / CUTLASS version and try to match the performance.

Aside from performance, DeepSeek has some additional optimizations that are not mentioned in the original mHC paper. I asked Claude Code to write a summary in https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/triton_vs_tilelang.md, but I don't fully trust it (I think it hallucinates in some parts). I'll take a closer look next week.
I only skimmed through DeepSeek's code, and this is my rudimentary analysis of their differences:

Eq. 14-15 (projection + RMSNorm)

Eq. 16-18

Eq. 19

There are another two kernels for F_pre = H_pre @ x and F_post_res = H_res @ x + H_post @ out, but they don't differ much functionally. In addition, they have a few other kernels.

I can work with the Megatron team to integrate these kernels if needed.
/te-ci pytorch
@kainzhong thanks for the benchmark! It's promising that Triton performs best in the per-op training kernels. I was also wondering if TileLang's extra fusions might benefit e2e training performance. Did you also run a full training step with Megatron?
@sbhavani Note that the triton kernels gain an advantage by aggressively applying a split-K strategy to maximize parallelism, whereas the tilelang kernels sacrifice some performance for determinism. See section 3.3 of the DeepSeek V4 tech report:
As for e2e performance I'll run some experiments this week -- though I believe at least |
/te-ci pytorch
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
/te-ci pytorch
for more information, see https://pre-commit.ci






Description
Implementation of DeepSeek's mHC: Manifold-Constrained Hyper-Connections paper
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: