[Common, PyTorch] Add triton mHC kernels & pytorch APIs #2790

Open
kainzhong wants to merge 23 commits into NVIDIA:main from kainzhong:feat/mhc_kernels

Conversation

@kainzhong
Collaborator

@kainzhong kainzhong commented Mar 23, 2026

Description

Implementation of DeepSeek's mHC: Manifold-Constrained Hyper-Connections paper

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added triton mHC kernels
  • Added pytorch interface

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@kainzhong kainzhong force-pushed the feat/mhc_kernels branch 3 times, most recently from 7a2cdea to 4657371 on April 6, 2026 02:24
@kainzhong
Collaborator Author

kainzhong commented Apr 8, 2026

Benchmark result:

image image

Naming:
eq. 14-15: Projection
eq. 16-18: Scale (this operation's overhead is negligible -- for B=4, T=8192, bwd+fwd takes < 10us; also, Megatron's implementation lets torch.compile handle this, so there's nothing to compare)
eq. 19: Sinkhorn
$F_{pre}$: Aggregate
$F_{post}$: Expand & Combine

For this benchmark n=4 is fixed, since it's the most reasonable hyper-connection size (n=1 is worse than the baseline according to ByteDance's original hyper-connections paper, n=2 requires padding to utilize TensorCores for the projection GEMM, and for n>4 the activation memory usage is also multiplied by n).

"triton" means my triton kernels, and "cutile" means the cutile implementation from this Megatron PR NVIDIA/Megatron-LM#3828

B: batch size
T: sequence length
C: hidden size
dtype is bf16
t_min is the SOL time (since all kernels are memory-bandwidth bound, it's simply IO / bandwidth)

My benchmark script:
mhc_bench.sh
mhc_bench.py
These results are on B200
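
For reference, a rough sketch of how t_min is derived (this assumes ~8 TB/s of HBM bandwidth on B200, which is what the numbers below imply; the exact constant used in mhc_bench.py may differ):

```python
def sol_time_us(io_bytes: float, bandwidth_bytes_per_s: float = 8e12) -> float:
    """SOL time for a memory-bandwidth-bound kernel: total IO divided by peak bandwidth."""
    return io_bytes / bandwidth_bytes_per_s * 1e6

# Example: Expand & Combine FWD at B1_T2048_C12288 moves ~453.1 MB
print(f"{sol_time_us(453.1e6):.1f} us")  # ~56.6 us, matching t_min in the table below
```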

Expand & Combine

FWD: out(M,C,4) = f(M,C,1) @ H_post(M,1,4) + x(M,C,4) @ H_res(M,4,4)
IO_fwd = M * (18C + 40) bytes — read f(2MC), H_post(8M), x(8MC), H_res(32M); write out(8MC)

BWD: grads for f, H_post, x, H_res, bias
IO_bwd = M * (28C + 80) bytes — read grad_out(8MC), f(2MC), H_post(8M), x(8MC), H_res(32M); write grad_f(2MC), grad_H_post(8M), grad_x(8MC), grad_H_res(32M)
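
For clarity, a plain-PyTorch reference of what the fused kernel computes (illustrative sketch only, using the shapes above with n=4; the actual implementation is the fused Triton kernel in this PR):

```python
import torch

def expand_combine_ref(f, x, H_post, H_res):
    # f: (M, C), x: (M, C, n), H_post: (M, n), H_res: (M, n, n)
    # out[m, c, j] = f[m, c] * H_post[m, j] + sum_i x[m, c, i] * H_res[m, i, j]
    return f.unsqueeze(-1) * H_post.unsqueeze(1) + torch.einsum("mci,mij->mcj", x, H_res)
```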

| Config | M | IO FWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 453.1 | 56.6 | 66.0 | 85.8% | 90.4 | 62.6% | 1.37x |
| B1_T4096_C6656 | 4096 | 490.9 | 61.4 | 72.0 | 85.3% | 120.3 | 51.0% | 1.67x |
| B1_T4096_C8192 | 4096 | 604.1 | 75.5 | 87.4 | 86.4% | 117.2 | 64.4% | 1.34x |
| B1_T8192_C4096 | 8192 | 604.3 | 75.5 | 87.5 | 86.3% | 111.1 | 68.0% | 1.27x |
| B1_T8192_C8192 | 8192 | 1208.3 | 151.0 | 173.1 | 87.2% | 226.2 | 66.8% | 1.31x |
| B1_T8192_C16384 | 8192 | 2415.8 | 302.0 | 345.5 | 87.4% | 454.1 | 66.5% | 1.31x |
| B2_T4096_C5120 | 8192 | 755.3 | 94.4 | 110.0 | 85.8% | 138.1 | 68.4% | 1.26x |
| B3_T8192_C2048 | 24576 | 907.0 | 113.4 | 130.6 | 86.8% | 192.9 | 58.8% | 1.48x |
| B4_T2048_C4096 | 8192 | 604.3 | 75.5 | 87.5 | 86.3% | 110.6 | 68.3% | 1.26x |
| B4_T4096_C4096 | 16384 | 1208.6 | 151.1 | 173.8 | 86.9% | 216.8 | 69.7% | 1.25x |
| B4_T8192_C4096 | 32768 | 2417.2 | 302.2 | 346.1 | 87.3% | 428.0 | 70.6% | 1.24x |
| B4_T8192_C7168 | 32768 | 4229.2 | 528.6 | 599.8 | 88.1% | 762.6 | 69.3% | 1.27x |
| B8_T2048_C2560 | 16384 | 755.6 | 94.5 | 109.2 | 86.5% | 190.8 | 49.5% | 1.75x |

| Config | M | IO BWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 704.8 | 88.1 | 135.6 | 65.0% | 466.1 | 18.9% | 3.44x |
| B1_T4096_C6656 | 4096 | 763.7 | 95.5 | 146.7 | 65.1% | 599.4 | 15.9% | 4.09x |
| B1_T4096_C8192 | 4096 | 939.9 | 117.5 | 179.0 | 65.6% | 635.5 | 18.5% | 3.55x |
| B1_T8192_C4096 | 8192 | 940.2 | 117.5 | 181.0 | 64.9% | 679.7 | 17.3% | 3.76x |
| B1_T8192_C8192 | 8192 | 1879.7 | 234.9 | 351.3 | 66.9% | 1259.6 | 18.7% | 3.59x |
| B1_T8192_C16384 | 8192 | 3758.8 | 469.8 | 705.3 | 66.6% | 2422.6 | 19.4% | 3.44x |
| B2_T4096_C5120 | 8192 | 1175.1 | 146.9 | 221.9 | 66.2% | 826.1 | 17.8% | 3.72x |
| B3_T8192_C2048 | 24576 | 1411.3 | 176.4 | 267.9 | 65.8% | 1157.8 | 15.2% | 4.32x |
| B4_T2048_C4096 | 8192 | 940.2 | 117.5 | 180.8 | 65.0% | 679.5 | 17.3% | 3.76x |
| B4_T4096_C4096 | 16384 | 1880.4 | 235.0 | 355.1 | 66.2% | 1351.6 | 17.4% | 3.81x |
| B4_T8192_C4096 | 32768 | 3760.7 | 470.1 | 704.7 | 66.7% | 2696.4 | 17.4% | 3.83x |
| B4_T8192_C7168 | 32768 | 6579.3 | 822.4 | 1221.1 | 67.4% | 4426.5 | 18.6% | 3.63x |
| B8_T2048_C2560 | 16384 | 1175.7 | 147.0 | 222.9 | 65.9% | 1005.3 | 14.6% | 4.51x |

Aggregate

FWD: out(M,C) = x(M,C,4) @ H_pre(M,4,1)
IO_fwd = M * (10C + 8) bytes — read x(8MC), H_pre(8M); write out(2MC)

BWD: grads for x, H_pre
IO_bwd = M * (18C + 16) bytes — read grad_out(2MC), x(8MC), H_pre(8M); write grad_x(8MC), grad_H_pre(8M)
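
Again, an unfused plain-PyTorch reference for comparison (illustrative sketch only):

```python
import torch

def aggregate_ref(x, H_pre):
    # x: (M, C, n), H_pre: (M, n) -> out: (M, C)
    return torch.einsum("mcn,mn->mc", x, H_pre)
```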

| Config | M | IO FWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 251.8 | 31.5 | 36.1 | 87.2% | 43.7 | 72.1% | 1.21x |
| B1_T4096_C6656 | 4096 | 272.7 | 34.1 | 39.5 | 86.3% | 210.8 | 16.2% | 5.34x |
| B1_T4096_C8192 | 4096 | 335.6 | 42.0 | 51.0 | 82.4% | 59.1 | 71.1% | 1.16x |
| B1_T8192_C4096 | 8192 | 335.6 | 42.0 | 50.3 | 83.5% | 60.4 | 69.5% | 1.20x |
| B1_T8192_C8192 | 8192 | 671.2 | 83.9 | 100.2 | 83.7% | 107.7 | 77.9% | 1.07x |
| B1_T8192_C16384 | 8192 | 1342.2 | 167.8 | 195.8 | 85.7% | 199.9 | 83.9% | 1.02x |
| B2_T4096_C5120 | 8192 | 419.5 | 52.4 | 64.1 | 81.7% | 73.6 | 71.2% | 1.15x |
| B3_T8192_C2048 | 24576 | 503.5 | 62.9 | 76.5 | 82.2% | 110.3 | 57.0% | 1.44x |
| B4_T2048_C4096 | 8192 | 335.6 | 42.0 | 50.8 | 82.7% | 59.7 | 70.4% | 1.18x |
| B4_T4096_C4096 | 16384 | 671.2 | 83.9 | 100.0 | 83.9% | 110.2 | 76.1% | 1.10x |
| B4_T8192_C4096 | 32768 | 1342.4 | 167.8 | 195.9 | 85.7% | 209.4 | 80.1% | 1.07x |
| B4_T8192_C7168 | 32768 | 2349.1 | 293.6 | 334.9 | 87.7% | 342.6 | 85.7% | 1.02x |
| B8_T2048_C2560 | 16384 | 419.6 | 52.4 | 64.0 | 81.9% | 341.6 | 15.3% | 5.34x |

| Config | M | IO BWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 453.0 | 56.6 | 76.2 | 74.3% | 144.8 | 39.1% | 1.90x |
| B1_T4096_C6656 | 4096 | 490.8 | 61.3 | 82.3 | 74.5% | 163.2 | 37.6% | 1.98x |
| B1_T4096_C8192 | 4096 | 604.0 | 75.5 | 100.7 | 75.0% | 177.9 | 42.4% | 1.77x |
| B1_T8192_C4096 | 8192 | 604.1 | 75.5 | 100.6 | 75.0% | 186.6 | 40.5% | 1.85x |
| B1_T8192_C8192 | 8192 | 1208.1 | 151.0 | 198.6 | 76.0% | 348.4 | 43.3% | 1.75x |
| B1_T8192_C16384 | 8192 | 2416.1 | 302.0 | 390.5 | 77.3% | 671.0 | 45.0% | 1.72x |
| B2_T4096_C5120 | 8192 | 755.1 | 94.4 | 125.1 | 75.5% | 227.2 | 41.5% | 1.82x |
| B3_T8192_C2048 | 24576 | 906.4 | 113.3 | 149.1 | 76.0% | 305.6 | 37.1% | 2.05x |
| B4_T2048_C4096 | 8192 | 604.1 | 75.5 | 100.5 | 75.1% | 187.0 | 40.4% | 1.86x |
| B4_T4096_C4096 | 16384 | 1208.2 | 151.0 | 197.8 | 76.3% | 365.8 | 41.3% | 1.85x |
| B4_T8192_C4096 | 32768 | 2416.4 | 302.1 | 391.3 | 77.2% | 724.1 | 41.7% | 1.85x |
| B4_T8192_C7168 | 32768 | 4228.4 | 528.5 | 682.6 | 77.4% | 1201.9 | 44.0% | 1.76x |
| B8_T2048_C2560 | 16384 | 755.2 | 94.4 | 125.3 | 75.3% | 271.0 | 34.8% | 2.16x |

Projection + RMSNorm

FWD: Hs(M,24) = x(M,4C) @ phi(24,4C)^T, ms(M,) = mean(x^2)
IO_fwd = M * (8C + 52) + 192C bytes — read x(8MC), phi(192C); write Hs(48M), ms(4M)

BWD: back through scale + projection; grads for x, phi
IO_bwd = M * (16C + 100) + 384C bytes — read grad_H(48M), phi(192C), x(8MC), ms(4M) + saved intermediates(~48M); write grad_x(8MC), grad_phi(192C)
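
An unfused reference sketch of the forward math (illustrative only; phi is assumed to be (24, 4C) as in the kernel launch shown in the review below, and the fused kernel pads the H output to (M, 32)):

```python
import torch

def projection_rmsnorm_ref(x, phi):
    # x: (M, 4C), phi: (24, 4C)
    Hs = x @ phi.T                      # (M, 24) raw H_pre/H_post/H_res values
    ms = x.float().pow(2).mean(dim=-1)  # (M,) mean square, reused by the RMSNorm-style scale
    return Hs, ms
```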

Note: I cheated a bit in my backward implementation: I leave grad_phi for pytorch to compute, since it is a pure GEMM and hard to fuse efficiently into the grad_x kernel. Because grad_phi (4C, 24) is relatively negligible compared to x (M, 4C) when M is large, I used the same IO for my SOL analysis; I don't think it makes much of a difference.

I also ran an end-to-end operator-level comparison using triton's benchmark tool, which includes the pytorch GEMM time for grad_phi. My implementation is still more efficient (this diagram is backward only):
image

| Config | M | IO FWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 203.8 | 25.5 | 43.3 | 58.9% | 342.9 | 7.4% | 7.92x |
| B1_T4096_C6656 | 4096 | 219.6 | 27.4 | 46.5 | 58.9% | 192.5 | 14.2% | 4.14x |
| B1_T4096_C8192 | 4096 | 270.2 | 33.8 | 60.3 | 56.1% | 236.0 | 14.3% | 3.91x |
| B1_T8192_C4096 | 8192 | 269.6 | 33.7 | 58.1 | 58.0% | 121.8 | 27.7% | 2.10x |
| B1_T8192_C8192 | 8192 | 538.9 | 67.4 | 105.8 | 63.7% | 238.0 | 28.3% | 2.25x |
| B1_T8192_C16384 | 8192 | 1077.3 | 134.7 | 196.4 | 68.6% | 466.5 | 28.9% | 2.38x |
| B2_T4096_C5120 | 8192 | 337.0 | 42.1 | 69.7 | 60.4% | 151.4 | 27.8% | 2.17x |
| B3_T8192_C2048 | 24576 | 404.3 | 50.5 | 83.5 | 60.5% | 124.9 | 40.4% | 1.50x |
| B4_T2048_C4096 | 8192 | 269.6 | 33.7 | 59.0 | 57.1% | 121.7 | 27.7% | 2.06x |
| B4_T4096_C4096 | 16384 | 538.5 | 67.3 | 106.0 | 63.5% | 124.8 | 53.9% | 1.18x |
| B4_T8192_C4096 | 32768 | 1076.2 | 134.5 | 194.5 | 69.2% | 243.4 | 55.3% | 1.25x |
| B4_T8192_C7168 | 32768 | 1882.1 | 235.3 | 331.6 | 71.0% | 414.0 | 56.8% | 1.25x |
| B8_T2048_C2560 | 16384 | 336.9 | 42.1 | 69.7 | 60.4% | 81.6 | 51.6% | 1.17x |

| Config | M | IO BWD (MB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 407.6 | 50.9 | 92.4 | 55.1% | 328.8 | 15.5% | 3.56x |
| B1_T4096_C6656 | 4096 | 439.2 | 54.9 | 99.0 | 55.5% | 376.7 | 14.6% | 3.81x |
| B1_T4096_C8192 | 4096 | 540.4 | 67.6 | 119.3 | 56.7% | 384.2 | 17.6% | 3.22x |
| B1_T8192_C4096 | 8192 | 539.3 | 67.4 | 120.7 | 55.8% | 391.4 | 17.2% | 3.24x |
| B1_T8192_C8192 | 8192 | 1077.7 | 134.7 | 235.4 | 57.2% | 747.2 | 18.0% | 3.17x |
| B1_T8192_C16384 | 8192 | 2154.6 | 269.3 | 470.0 | 57.3% | 1397.1 | 19.3% | 2.97x |
| B2_T4096_C5120 | 8192 | 673.9 | 84.2 | 148.7 | 56.6% | 693.3 | 12.1% | 4.66x |
| B3_T8192_C2048 | 24576 | 808.6 | 101.1 | 176.5 | 57.3% | 1094.1 | 9.2% | 6.20x |
| B4_T2048_C4096 | 8192 | 539.3 | 67.4 | 120.4 | 56.0% | 390.9 | 17.2% | 3.25x |
| B4_T4096_C4096 | 16384 | 1077.0 | 134.6 | 234.3 | 57.4% | 766.4 | 17.6% | 3.27x |
| B4_T8192_C4096 | 32768 | 2152.3 | 269.0 | 466.5 | 57.7% | 1517.8 | 17.7% | 3.25x |
| B4_T8192_C7168 | 32768 | 3764.1 | 470.5 | 815.7 | 57.7% | 2736.0 | 17.2% | 3.35x |
| B8_T2048_C2560 | 16384 | 673.7 | 84.2 | 148.3 | 56.8% | 735.7 | 11.4% | 4.96x |

Sinkhorn

FWD: 20 Sinkhorn iterations on H_res(M,4,4) in fp32. Data is tiny — kernel is occupancy/launch-bound, not BW-bound.
IO_fwd = M * 96 bytes — read H_res(64M, fp32); write out(32M, bf16)

BWD: backprop through 20 iterations with recompute
IO_bwd = M * 128 bytes — read grad_out(32M, bf16), H_res(64M, fp32); write grad_H_res(32M, bf16)
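
A minimal log-space Sinkhorn sketch for reference (illustrative only; it treats the H_res entries as logits and ignores the eps handling, iteration order, and dtype details of the actual kernel):

```python
import torch

def sinkhorn_log_ref(H_res, n_iters=20):
    # H_res: (M, n, n); returns an (approximately) doubly-stochastic matrix per row of M.
    # Working in log space avoids the overflow/underflow a linear-space version can hit.
    log_P = H_res.float()
    for _ in range(n_iters):
        log_P = log_P - torch.logsumexp(log_P, dim=-1, keepdim=True)  # normalize rows
        log_P = log_P - torch.logsumexp(log_P, dim=-2, keepdim=True)  # normalize columns
    return log_P.exp()
```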

| Config | M | IO FWD (KB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 192 | 0.024 | 7.4 | 0.3% | 90.8 | 0.03% | 12.3x |
| B1_T4096_C6656 | 4096 | 384 | 0.048 | 7.8 | 0.6% | 91.3 | 0.05% | 11.7x |
| B1_T4096_C8192 | 4096 | 384 | 0.048 | 7.8 | 0.6% | 100.7 | 0.05% | 12.9x |
| B1_T8192_C4096 | 8192 | 768 | 0.096 | 9.9 | 1.0% | 92.0 | 0.10% | 9.29x |
| B1_T8192_C8192 | 8192 | 768 | 0.096 | 9.8 | 1.0% | 91.7 | 0.10% | 9.36x |
| B1_T8192_C16384 | 8192 | 768 | 0.096 | 9.9 | 1.0% | 91.8 | 0.10% | 9.27x |
| B2_T4096_C5120 | 8192 | 768 | 0.096 | 9.9 | 1.0% | 92.2 | 0.10% | 9.31x |
| B3_T8192_C2048 | 24576 | 2304 | 0.288 | 18.8 | 1.5% | 98.8 | 0.29% | 5.26x |
| B4_T2048_C4096 | 8192 | 768 | 0.096 | 9.9 | 1.0% | 91.9 | 0.10% | 9.28x |
| B4_T4096_C4096 | 16384 | 1536 | 0.192 | 13.5 | 1.4% | 93.6 | 0.21% | 6.93x |
| B4_T8192_C4096 | 32768 | 3072 | 0.384 | 22.5 | 1.7% | 100.3 | 0.38% | 4.46x |
| B4_T8192_C7168 | 32768 | 3072 | 0.384 | 22.4 | 1.7% | 99.6 | 0.39% | 4.45x |
| B8_T2048_C2560 | 16384 | 1536 | 0.192 | 13.6 | 1.4% | 93.3 | 0.21% | 6.86x |

| Config | M | IO BWD (KB) | t_min (us) | Triton (us) | Triton SOL | Cutile (us) | Cutile SOL | Speedup |
|---|---|---|---|---|---|---|---|---|
| B1_T2048_C12288 | 2048 | 256 | 0.032 | 11.0 | 0.3% | 238.6 | 0.01% | 21.7x |
| B1_T4096_C6656 | 4096 | 512 | 0.064 | 11.1 | 0.6% | 200.9 | 0.03% | 18.1x |
| B1_T4096_C8192 | 4096 | 512 | 0.064 | 11.1 | 0.6% | 210.3 | 0.03% | 18.9x |
| B1_T8192_C4096 | 8192 | 1024 | 0.128 | 13.5 | 0.9% | 210.5 | 0.06% | 15.6x |
| B1_T8192_C8192 | 8192 | 1024 | 0.128 | 13.5 | 0.9% | 238.0 | 0.05% | 17.6x |
| B1_T8192_C16384 | 8192 | 1024 | 0.128 | 13.5 | 0.9% | 238.0 | 0.05% | 17.6x |
| B2_T4096_C5120 | 8192 | 1024 | 0.128 | 13.6 | 0.9% | 238.2 | 0.05% | 17.5x |
| B3_T8192_C2048 | 24576 | 3072 | 0.384 | 28.3 | 1.4% | 255.9 | 0.15% | 9.04x |
| B4_T2048_C4096 | 8192 | 1024 | 0.128 | 13.5 | 0.9% | 210.3 | 0.06% | 15.6x |
| B4_T4096_C4096 | 16384 | 2048 | 0.256 | 20.1 | 1.3% | 240.5 | 0.11% | 12.0x |
| B4_T8192_C4096 | 32768 | 4096 | 0.512 | 29.2 | 1.8% | 266.5 | 0.19% | 9.13x |
| B4_T8192_C7168 | 32768 | 4096 | 0.512 | 29.2 | 1.8% | 265.8 | 0.19% | 9.10x |
| B8_T2048_C2560 | 16384 | 2048 | 0.256 | 20.2 | 1.3% | 238.2 | 0.11% | 11.8x |

@kainzhong
Collaborator Author

kainzhong commented Apr 8, 2026

Validated on a 2B nemotron model

Training loss:
image
Validation loss:
image

Both configs use pure DP on 16 nodes, each with 8 B200s, so DP size is 128; GBS=768, MBS=3, grad_acc=2. From the Megatron logs (I don't know how to ask Megatron to export this to tensorboard yet...) my throughput is ~465 TFLOP/s per GPU with ~1100 ms elapsed time per iteration; the current Megatron dev branch's cutile kernels reach ~445 TFLOP/s per GPU with ~1160 ms per iteration. The baseline (no mHC) is ~535 TFLOP/s per GPU with ~960 ms per iteration. The improvement might vary across training configs, but I haven't tried other configs yet.

Note that I never managed to reach 10K steps with Megatron's cutile implementation. It always crashes after a few thousand steps, which is why the cutile curve in my tensorboard stops early. I suspect it's because their sinkhorn is not computed in log space as mine is, which causes numerical issues in bf16 (I'm not sure though; I didn't look deeply into their implementation).

@kainzhong kainzhong marked this pull request as ready for review April 8, 2026 18:12
@kainzhong kainzhong changed the title [WIP] [Common, PyTorch] Add triton mHC kernels & pytorch operators [Common, PyTorch] Add triton mHC kernels & pytorch APIs Apr 8, 2026
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@greptile-apps
Contributor

greptile-apps Bot commented Apr 8, 2026

Greptile Summary

This PR implements DeepSeek's mHC (Manifold-Constrained Hyper-Connections) paper via five new Triton kernels (projection, scale, sinkhorn, aggregate, expand_combine) with corresponding PyTorch autograd wrappers. The mathematical correctness of all forward and backward passes was verified — sigmoid/RMSNorm gradients in the scale kernel, Sinkhorn log-space iteration reversals, and outer-product chain-rules in expand/combine are all correct.

Confidence Score: 5/5

Safe to merge; only P2 style findings, all correctness-critical paths are sound

All five kernel backward passes are mathematically correct. The two findings are both P2: a dead memory allocation and unused kernel parameters. No logic bugs, shape mismatches, or thread-safety issues were found.

transformer_engine/pytorch/triton/mhc.py (dead grad_x allocation), transformer_engine/common/triton/mhc.py (unused stride params)

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/triton/mhc.py | New Triton kernels for all five mHC operations; mathematically correct backward passes; minor unused stride parameters in the projection backward kernel |
| transformer_engine/pytorch/triton/mhc.py | PyTorch autograd wrappers for all five kernels; shape handling and chain-rule derivations are correct; contains one dead grad_x allocation in mHCProjectionOp.backward |
| tests/pytorch/test_mhc.py | Comprehensive forward and backward correctness tests for all five ops across fp32/bf16; reference implementations are clear and cover all gradient paths |
| docs/api/pytorch.rst | Adds five autoapifunction directives for the new mHC wrapper functions, correctly using autoapifunction (not autoapiclass) |
| qa/L0_pytorch_unittest/test.sh | Adds test_mhc.py to the CI runner with NVTE_DISABLE_TRITON_AUTOTUNING=1 and NVIDIA_TF32_OVERRIDE=0 for reproducible, faster tests |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["x: (M, nC)\nphi: (N, K)"] -->|mhc_fused_projection| B["H: (M,32)\nms: (M,)"]
    B -->|mhc_fused_scale| C["H_pre: (M,n)\nH_post: (M,n)\nH_res: (M,n²)"]
    C -->|mhc_fused_sinkhorn| D["H_res_norm: (s,b,n,n)\ndoubly-stochastic"]
    E["x_stream: (s,b,C,n)"] -->|mhc_fused_aggregate + H_pre| F["out: (s,b,C)"]
    F -->|Attention/FFN sublayer| G["f: (s,b,C)"]
    G -->|mhc_fused_expand_combine + H_post + H_res| H["out: (s,b,C,n)"]
    E --> H
    D --> H

Reviews (15): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

Comment thread tests/pytorch/test_mhc.py
Comment thread transformer_engine/common/triton/mhc.py Outdated
Comment thread transformer_engine/pytorch/triton/mhc.py Outdated
Comment thread tests/pytorch/test_mhc.py Outdated
Comment thread tests/pytorch/test_mhc.py Outdated
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Comment thread transformer_engine/pytorch/triton/mhc.py Outdated
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Comment thread tests/pytorch/test_mhc.py Outdated
mhc_fused_projection,
)

reset_rng_states()
Member

This should be done per test to enable running both the full test suite and the individual tests with the same inputs.

Collaborator Author

Each test now calls this function instead.

Comment thread tests/pytorch/test_mhc.py Outdated
Comment on lines +21 to +22
# Enable TF32 for matmul to ensure consistency between the fused and reference implementations
torch.backends.cuda.matmul.allow_tf32 = False
Member

Either the comment is wrong or the code is wrong here.

Collaborator Author

Sorry I forgot to change the comment. Fixed now

Comment thread tests/pytorch/test_mhc.py
n: int = 4 # Number of Hyper Connection streams

allow_n = [
4,
Member

In the end, do we only allow 4 here or do we also work for n equal to 2?

Collaborator Author

I have another branch with the n=2 kernels, but I haven't done the pretraining or gathered metrics yet. I was thinking of making that a separate PR later, but I can combine them into one as well. It shouldn't be too hard: since n=4 is already validated, I expect n=2 to work as well.

Comment thread tests/pytorch/test_mhc.py
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=5e-3, rtol=5e-3)
Member

For FP32 this tolerance seems a little high. What is the test that needs that tolerance?

Collaborator Author

I used this for all FP32 allclose tests. I tried stricter tolerances, but my extensive use of atomic add makes it hard to fully match pytorch's result. One mismatch comes from the projection kernel, where I use split-K with atomic add for the GEMM; another comes from aggregate and expand's H_pre / H_post, where the gradient is computed as (1, C) @ (C, n) = (1, n) with C >> n.
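
To illustrate the second case with hypothetical shapes (not the PR's code): the per-token H_pre gradient is a long reduction over C with a tiny n-sized output, and the fused kernel splits that sum across blocks and combines partials with atomic adds, so the accumulation order (and rounding) differs from a single cuBLAS reduction:

```python
import torch

C, n = 4096, 4
x_token = torch.randn(C, n)
grad_out_token = torch.randn(C)
# grad_H_pre[j] = sum_c grad_out[c] * x[c, j]  -- effectively a (1, C) @ (C, n) GEMM
grad_H_pre = torch.einsum("cn,c->n", x_token, grad_out_token)
```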

Comment on lines +262 to +299
if use_tf32:
_mhc_projection_fwd_fused[grid](
x_ptr=x, # (M, K)
phi_ptr=phi, # (N, K)
h_ptr=H, # (M, 32)
ms_ptr=ms, # (M,)
M=M,
N=N,
K=K,
stride_xm=K,
stride_xk=1,
stride_phin=K,
stride_phik=1,
stride_hm=32,
stride_hn=1,
stride_ms=1,
BLOCK_SIZE_N=32,
precision="tf32",
)
else:
_mhc_projection_fwd_fused[grid](
x_ptr=x, # (M, K)
phi_ptr=phi, # (N, K)
h_ptr=H, # (M, 32)
ms_ptr=ms, # (M,)
M=M,
N=N,
K=K,
stride_xm=K,
stride_xk=1,
stride_phin=K,
stride_phik=1,
stride_hm=32,
stride_hn=1,
stride_ms=1,
BLOCK_SIZE_N=32,
precision="ieee",
)
Member

Suggested change
if use_tf32:
_mhc_projection_fwd_fused[grid](
x_ptr=x, # (M, K)
phi_ptr=phi, # (N, K)
h_ptr=H, # (M, 32)
ms_ptr=ms, # (M,)
M=M,
N=N,
K=K,
stride_xm=K,
stride_xk=1,
stride_phin=K,
stride_phik=1,
stride_hm=32,
stride_hn=1,
stride_ms=1,
BLOCK_SIZE_N=32,
precision="tf32",
)
else:
_mhc_projection_fwd_fused[grid](
x_ptr=x, # (M, K)
phi_ptr=phi, # (N, K)
h_ptr=H, # (M, 32)
ms_ptr=ms, # (M,)
M=M,
N=N,
K=K,
stride_xm=K,
stride_xk=1,
stride_phin=K,
stride_phik=1,
stride_hm=32,
stride_hn=1,
stride_ms=1,
BLOCK_SIZE_N=32,
precision="ieee",
)
_mhc_projection_fwd_fused[grid](
x_ptr=x, # (M, K)
phi_ptr=phi, # (N, K)
h_ptr=H, # (M, 32)
ms_ptr=ms, # (M,)
M=M,
N=N,
K=K,
stride_xm=K,
stride_xk=1,
stride_phin=K,
stride_phik=1,
stride_hm=32,
stride_hn=1,
stride_ms=1,
BLOCK_SIZE_N=32,
precision="tf32" if use_tf32 else "ieee",
)

Member

And similar for the other cases in this file.

Collaborator Author

Fixed


def mhc_fused_projection(x: torch.Tensor, phi: torch.Tensor, use_tf32: bool = True):
"""
Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, seciton 4.3.1 of the DeepSeek mHC paper):
Member

Suggested change
Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, seciton 4.3.1 of the DeepSeek mHC paper):
Fused projection operation to compute H matrices and mean square for RMSNorm (see eq. 14-15, section 4.3.1 of the DeepSeek mHC paper):

Collaborator Author

Fixed

)

h_ptrs = h_ptr + offs_m[:, None] * stride_hm + offs_n_full[None, :] * stride_hn
tl.atomic_add(h_ptrs, h_acc, mask=mask_m[:, None], sem="relaxed")
Member

Please add the guards so that there is an error when somebody tries to run this function with NVTE_ALLOW_NONDETERMINISTIC_ALGO set to 0.

Collaborator Author

Added an assertion for NVTE_ALLOW_NONDETERMINISTIC_ALGO in the APIs where atomic_add may be called afterwards.
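
For illustration, a hypothetical sketch of such a guard (the actual assertion added in the PR may differ in wording and placement):

```python
import os

def _assert_nondeterminism_allowed(op_name: str) -> None:
    # Atomic adds make the accumulation order (and rounding) nondeterministic across runs.
    if int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")) == 0:
        raise RuntimeError(
            f"{op_name} uses atomic adds and is nondeterministic; "
            "it cannot be used with NVTE_ALLOW_NONDETERMINISTIC_ALGO=0."
        )
```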

)


def mhc_fused_sinkhorn(
Member

Please add the functions to the Sphinx documentation.

Collaborator Author

Added in docs/api/pytorch.rst

@ptrendx
Member

ptrendx commented Apr 20, 2026

There are some other typos in the docstrings besides the one I flagged; could you check those docstrings?

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
ptrendx previously approved these changes Apr 23, 2026
@ptrendx
Member

ptrendx commented Apr 23, 2026

/te-ci pytorch

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
Comment on lines +336 to +348
allow_bf16_reduced_precision_reduction = (
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction
)
# Use FP32 accumulator in case of pytorch choosing a path with BF16 accumulator which hurts accuracy,
# which seems to happen on Ampere but not on Hopper and Blackwell
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
grad_phi = (grad_H.T @ x)[:N, :].to(
ctx.phi_dtype
) # (2n + n^2, M) @ (M, nC) = (2n + n^2, nC), note that the last dimension of grad_H is already padded to 32
# Recover the original pytorch setting
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (
allow_bf16_reduced_precision_reduction
)
Contributor

P1 Global state mutation in backward is not thread-safe

mHCProjectionOp.backward temporarily sets torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False on a process-global object. If two backward passes run concurrently (e.g., in a DataParallel setup where the autograd engine spawns multiple CUDA streams, or when users manually overlap backward passes), one thread can restore the flag while another is still executing the BF16 matmul — silently producing lower-precision gradients or overwriting the caller's original setting.

A safer approach is to cast inputs to float32 explicitly before the matmul instead of relying on the global flag.
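
For illustration, a minimal sketch of that alternative (names follow the snippet quoted above; whether the extra fp32 copies are acceptable here is a separate trade-off):

```python
# Accumulate in fp32 explicitly instead of toggling the process-global matmul flag.
grad_phi = (grad_H.T.float() @ x.float())[:N, :].to(ctx.phi_dtype)
```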

@kainzhong
Collaborator Author

/te-ci pytorch

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong
Collaborator Author

/te-ci pytorch

ptrendx previously approved these changes Apr 25, 2026
@alint77

alint77 commented Apr 26, 2026

@kainzhong please check out my implementation of mHC-lite in triton. I mostly focused on SM120 and SM90, and all the kernels perform at ~80-95% SOL at T=65536, H=768, N=4.

I've implemented it in our protein language modeling project with Rostlab

@sbhavani
Collaborator

@kainzhong have we compared with the original TileLang implementation: https://github.com/deepseek-ai/TileKernels/tree/main?

@kainzhong
Collaborator Author

kainzhong commented Apr 27, 2026

@sbhavani I asked claude code to vibe code some benchmark scripts and ran on B200. Here are the results: https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/benchmark.md

(triton - this implementation, cutile - Megatron-LM's fused kernels, tilelang - DeepSeek's tilelang kernels)

The naming for the different implementations is a bit chaotic, so refer to https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/kernel_comparison.md#kernel-name--framework-op-direction-mapping if you get lost.

In terms of performance, I think the major difference is that DeepSeek's implementation prioritizes determinism while mine prefers efficiency: I use atomic add heavily, whereas DeepSeek seems to avoid it. I can work on a set of deterministic kernels if needed. (Btw, I took a brief look at @alint77's kernels and they seem to be deterministic.)
Another difference is that I use an (s, b, C, n) layout, which I think gives more coalesced accesses than (s, b, n, C).

Another difference: DeepSeek seems to have a CUDA C++ kernel for eq. 14-15 in https://github.com/deepseek-ai/DeepGEMM/blob/891d57b4db1071624b5c8fa0d1e51cb317fa709f/deep_gemm/__init__.py#L69 that supports split-K for this GEMM, whereas the tilelang implementation doesn't:

# TileLang implementation doesn't support split-k, so we set n_splits to 1
# You may want to adopt the DeepGEMM implementation with split-k for better performance
n_splits = 1

I'll profile this one separately, and if the current implementation turns out to be slower I can work on a CuTeDSL / cutlass version and try to match the performance.

Aside from performance, DeepSeek has some additional optimizations that are not mentioned in the original mHC paper. I asked claude code to write a summary in https://github.com/kainzhong/mhc_bench/blob/benchmark_b200/triton_vs_tilelang.md, but I don't fully trust it (I think it hallucinates in some parts). I'll take a closer look next week.

image

I only skimmed through DeepSeek's code, and this is my rudimentary analysis of the differences:

Eq. 14-15 (projection + RMSnorm)

  • mhc_norm_weight: they support this additional parameter. I think in this path it applies an affine parameter, so it's the same as torch.nn.RMSNorm with elementwise_affine=True, whereas my implementation is the elementwise_affine=False path. I can work on supporting this later.
  • mhc_norm_eps: they support passing a user-defined eps, whereas I hardcoded one. This should be a minor difference.
  • fuse_grad_acc: it looks like they have an additional optimization that adds the pytorch grad to the grad accumulator directly, skipping the copy from the pytorch grad to the accumulator. I can work with the Megatron devs to support this later.

Eq. 16-18

  • mhc_post_mult_value: they allow the user to pass a custom multiplier for H_post's sigmoid. I hardcoded it to 2 since that is the value the mHC paper uses.
  • mhc_pre_eps: another user-defined eps.

Eq. 19

  • eps: another user-defined eps.
    In addition, they compute sinkhorn in linear space whereas I compute it in log space. Mathematically they should be equivalent, but I think log space is friendlier to numerical stability.

There are two more kernels for F_pre = H_pre @ x and F_post_res = H_res @ x + H_post @ out, but they don't differ much functionally.

In addition, they have a few other kernels.

  • mhc_head_compute_mix: I haven't fully figured it out, but it has torch.sigmoid(mhc_head_layer_mix), so I feel like this is the path where you only need the H_pre matrix (reducing (s, b, n, C) to (s, b, C) without expanding it back), so I guess this is used in the last layer?
  • expand_to_mhc: expand (s, b, C) to (s, b, n, C). I wonder why they need a kernel for that? torch.view should do this efficiently enough...
  • mhc_pre_big_fuse: for inference only, which seems to fuse eq. 14-19 together
  • mhc_multilayer_recompute: I feel like it's for recomputation

I can work with the Megatron team to integrate these kernels if needed.

@ptrendx
Member

ptrendx commented Apr 27, 2026

/te-ci pytorch

@sbhavani
Collaborator

@kainzhong thanks for the benchmark! It's promising that Triton performs best in the per-op training kernels.

I was also wondering if TileLang's extra fusions might benefit e2e training performance. Did you also run a full training step with Megatron?

@kainzhong
Collaborator Author

@sbhavani Note that the triton kernels gain their advantage by aggressively applying a split-K strategy to maximize parallelism, whereas the tilelang kernels sacrifice some performance for determinism. See section 3.3 of the DeepSeek V4 tech report:

• Matrix Multiplication in mHC. mHC involves a matrix multiplication with an output dimension of only 24. For very small batch sizes, we are compelled to use the split-k (Osama et al., 2023) algorithm, whose naive implementation will cause non-determinism. To overcome this, we output each split part separately and perform a deterministic reduction in a subsequent kernel, thereby preserving both performance and determinism.

If Megatron plans to support the same determinism, we might need a different set of kernels.

As for e2e performance, I'll run some experiments this week -- though I believe at least mhc_multilayer_recompute will benefit recomputation. For mhc_head_compute_mix I'm not fully sure where it's applied, and for expand_to_mhc I'm not sure how much it will help.

@kainzhong
Collaborator Author

/te-ci pytorch

Signed-off-by: Kaining Zhong <kainingz@nvidia.com>
@kainzhong
Collaborator Author

/te-ci pytorch
