Skip to content

[Experimental][DNM till upstream PR merges][AMD] perf: hybrid MXFP8 MoE for MiniMax M3 on MI300X#1753

Open
Oseltamivir wants to merge 18 commits into
mainfrom
feat/m3-mi300x-mxfp8
Open

[Experimental][DNM till upstream PR merges][AMD] perf: hybrid MXFP8 MoE for MiniMax M3 on MI300X#1753
Oseltamivir wants to merge 18 commits into
mainfrom
feat/m3-mi300x-mxfp8

Conversation

@Oseltamivir

@Oseltamivir Oseltamivir commented Jun 14, 2026

Copy link
Copy Markdown
Collaborator

Summary

  • apply the gfx94x MiniMax M3 MXFP8 implementation from
    [ROCm] Add fused MXFP8 MoE for gfx94x vllm-project/vllm#45567 to the existing ROCm image
  • preserve the main TP8 and TP8+EP8 configurations and concurrency matrix
  • keep the measured hybrid dispatch between load-time BF16 experts and native
    compressed W8A8
  • pack gfx94x E8M0 weight scales as [expert, K/32, N] so the native kernel
    loads adjacent output-column scales contiguously

Why this improves the end-to-end path

The previous native kernel benchmark used random routing. Capturing and
replaying InferenceX routes exposed a production-specific memory-layout issue.

For TP8 GEMM1, K=6144, so K/32=192. The checkpoint layout
[expert, N, K/32] makes scale bytes for adjacent output columns 192 bytes
apart at a fixed K group. The new load-time layout
[expert, K/32, N] makes those bytes contiguous.

Captured TP8 route latency:

Input tokens Canonical scales (ms) Packed scales (ms) Latency reduction
16 0.177704 0.172976 2.7%
32 0.232465 0.212536 8.6%
64 0.301167 0.269059 10.7%
128 0.420344 0.329749 21.6%

The kernel gain is diluted end to end by attention, collectives, shared
experts, routing/alignment, and top-k reduction, but the serving curve still
improves materially.

Short-K GEMM2 specialization

The matched route corpus shows that the dominant TP decode GEMM2 is
N=6144, K=384 with BLOCK_M=16. A dedicated
BLOCK_N=64, BLOCK_K=64, num_warps=2 configuration replaces the generic
32/128/1 configuration only for this short-K gfx94x GEMM2 regime.

Paired same-GPU replay:

Route sample Generic (ms) Specialized (ms) Reduction
median, GPU 0 0.215182 0.210647 2.15%
median, GPU 1 0.212965 0.208603 2.09%
p90, GPU 2 0.235560 0.230146 2.35%
p90, GPU 3 0.234272 0.230142 1.79%

GEMM1, expert-parallel GEMM2, CDNA4, and larger routed batches keep their
previous configurations. Numerical error remains in the same native W8A8
envelope.

Serving results

Reference runs:

The final sweep completed with 29 successful jobs, no failures, 18 serving
result rows, and both accuracy-eval jobs passing. The TP8 and TP8+EP8
parallelism configurations are unchanged from main.

1K/1K

Parallelism Concurrency BF16 emulation Final hybrid Change
TP8 4 80.10 79.64 -0.58%
TP8 8 128.71 130.58 +1.45%
TP8 16 203.66 214.95 +5.54%
TP8 32 299.64 340.83 +13.75%
TP8 64 365.84 497.38 +35.96%
TP8 128 590.42 715.70 +21.22%
TP8+EP8 256 777.81 852.48 +9.60%

Concurrency 4 is on the BF16 fallback for 9,527 of 9,529 captured forwards,
so it does not measure the native kernel. The official point also had a
1.072 s p99 TTFT outlier versus 0.282 s in the baseline. Three exact-command
same-node repeats measured 80.01, 80.38, and 80.41 tok/s/GPU; the median is
+0.34% versus baseline.

8K/1K

Parallelism Concurrency BF16 emulation Final hybrid Change
TP8 4 318.55 323.57 +1.58%
TP8 8 467.39 474.84 +1.59%
TP8 16 672.49 697.18 +3.67%
TP8 32 879.53 938.28 +6.68%
TP8 64 976.99 1200.55 +22.88%
TP8+EP8 128 1111.17 1142.22 +2.79%
TP8+EP8 256 1190.74 1216.81 +2.19%

Why not dequantize weights inside BF16 MMA

MiniMax M3 enters the MoE layer with BF16 hidden states, so native W8A8 does
launch an activation-quantization kernel before GEMM1. A Marlin-style W8A16
prototype was tested to avoid that launch by expanding compressed weights in
registers before BF16 MMA.

Provider Captured 64-token route (ms) Relative error vs. BF16
packed native W8A8 0.267857 about 4.2%
Marlin-style W8A16 0.464929 0.452%

W8A16 avoids quantizing the M*K activation tensor but repeatedly converts
much larger N*K expert weight tiles. MI300X BF16 MFMA also consumes half as
many K elements per instruction as FP8 MFMA. It is more accurate, but 1.74x
slower on this production route.

Scale algebra

For each K=32 MXFP8 group:

sum(k in g) (q_a[k] * s_a) * (q_b[k] * s_b)
  = (s_a * s_b) * sum(k in g) q_a[k] * q_b[k]

The native kernel applies s_a * s_b after each K=32 partial dot, not after
the complete matmul. This is valid because the scales are constant inside
that group; applying one scale across multiple groups would be incorrect.

Scope

This PR changes only minimaxm3-fp8-mi300x-vllm. The separate
minimaxm3-fp8-mi300x-vllm-mtp EAGLE3 recipe and its sweep are intentionally
unchanged and were not part of this validation matrix.

The runtime patch is applied only by
benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh. MI355X does
not load it. Upstream scale packing is gated to gfx94x FNUZ, while
gfx950/MI355X uses OCP FP8.

MI355X sanity reference:
https://github.com/SemiAnalysisAI/InferenceX/actions/runs/27452497472
completed successfully with 63 result rows and no failed benchmarks.

Validation

  • patch applies cleanly to image commit
    4a560dd8db67c270f5e2afb614558271b76f2294
  • launcher now fails if patch fails or the backend marker is absent
  • image-compatible files pass python -m py_compile
  • python -m pytest utils/matrix_logic/ -q: 156 passed
  • bash -n benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh
  • git diff --check
  • image-targeted MI300X MXFP8 tests: 48 passed, 3 skipped, 30 deselected
  • exact route replay numerical error: 4.09-4.33%, cosine floor 0.999066
  • GSM8K strict exact match: 0.956785
  • GSM8K flexible exact match: 0.956027

Note

Medium Risk
Runtime patching of vLLM MoE/quantization paths affects numerical behavior and serving correctness on MI300X only, but the change is large, experimental until upstream merges, and scoped to one benchmark recipe.

Overview
Adds a runtime vLLM patch for minimaxm3-fp8-mi300x-vllm: minimaxm3_fp8_mi300x.sh applies minimaxm3_mi300x_mxfp8.patch to the installed package before vllm serve, with idempotent apply and a hard fail if the gfx94x backend marker is missing.

The patch introduces a fused CDNA3 (gfx942) MXFP8 MoE path—E4M3FNUZ weights, in-kernel E8M0 scale products, packed scales as [expert, K/32, N], and MI300X Triton tile configs—plus a hybrid dispatch: BF16 experts for TP decode (≤8 tokens) and large prefill (≥832), BF16 emulation under expert parallelism, and native compressed W8A8 between those bands; for long context, every fifth layer can store BF16-only weights while others keep dual MXFP8/BF16 buffers.

Config comments and perf-changelog.yaml are updated to describe this hybrid recipe instead of pure BF16 emulation; the TP8 / TP8+EP8 sweep matrix is unchanged.

Reviewed by Cursor Bugbot for commit d1638a0. Bugbot is set up for automated code reviews on this repo. Configure here.

@github-actions

Copy link
Copy Markdown
Contributor

Thanks for the contribution! For vLLM & SGLang, please ensure that your recipes is similar to the official vLLM recipes and/or the SGLang cookbook

If it is not, please create a PR first before we can merge your single node PR into the master branch. Let's ensure that the documentation is first class such that the entire ML community can benefit from your hard work! Thank you

PR authors are responsible for ensuring that after merging, all GitHub Action jobs fully pass. A lot of the time, failures are just flakes and simply re-running the failed jobs will fix it. If re-running failed jobs is attempted, PR authors are responsible for ensuring it passes. See GitHub's docs on re-running failed jobs: https://docs.github.com/en/actions/how-tos/manage-workflow-runs/re-run-workflows-and-jobs#re-running-failed-jobs-in-a-workflow

As a rule of thumb, generally, PR authors should request a review & get a PR approval from the respective companies' CODEOWNERS before requesting a review from core maintainers.

If additional help is needed, PR authors can reach out to core maintainers over Slack.

@Oseltamivir Oseltamivir marked this pull request as ready for review June 14, 2026 06:42
@Oseltamivir Oseltamivir requested a review from a team June 14, 2026 06:42
@claude

claude Bot commented Jun 14, 2026

Copy link
Copy Markdown
Contributor

Claude finished @Oseltamivir's task in 3m 50s —— View job


Review of PR #1753

  • Gather context and read changed files
  • Validate master config changes
  • Validate perf-changelog entries
  • Validate benchmark script
  • Review runtime patch (Triton kernels, normalization, oracle changes)
  • Post review results

LGTM — no blocking issues found.

Summary: This PR adds a fused CDNA3 MXFP8 MoE backend for MiniMax-M3 on MI300X via a runtime vLLM patch. I reviewed all four changed files and the 656-line Triton kernel patch in detail. Key findings:

  • E4M3FN → E4M3FNUZ normalization (normalize_mxfp8_e4m3fn_to_e4m3fnuz): The math is correct — FNUZ's exponent bias is 1 higher, so incrementing the E8M0 scale compensates; negative zero (0x80) is properly canonicalized. The overflow guard (scales.max() >= 254) correctly prevents incrementing into E8M0 NaN (255).
  • FNUZ kernel scale computation (tl.exp2(asc + bsc - 254.0)): Correct — two E8M0 dequants 2^(a-127) * 2^(b-127) = 2^(a+b-254).
  • Split-K: Occupancy-aware selection, FP32 accumulation buffer when split_k > 1, tl.atomic_add for reduction — all consistent.
  • Grid bound (min(sorted_token_ids.shape[0], M_routed * block_m)): Valid upper bound since active experts ≤ M_routed, and the kernel's num_post guard handles any overestimate.
  • Benchmark script: Expert parallelism correctly conditioned on EP_SIZE, vllm serve arguments properly formatted on separate lines, patch application is idempotent via the grep guard + --forward.
  • Perf-changelog: New entry correctly appended at the end of the file.
  • Master config: Only the comment was updated; no functional config or image changes.

@functionstackx functionstackx left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz create upstream PR and have it reviewed before merging this patch

@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

Comment thread benchmarks/single_node/fixed_seq_len/minimaxm3_fp8_mi300x.sh
@Oseltamivir

Copy link
Copy Markdown
Collaborator Author

Opened the requested upstream vLLM PR: vllm-project/vllm#45567. It is stacked on the active MiniMax M3 model branch/PR (#45381), includes the tested gfx94x MXFP8 kernel and benchmark, and passes all vLLM pre-commit hooks. The InferenceX patch has also been updated to the optimized tile selection and no longer uses split-K.

@github-actions

Copy link
Copy Markdown
Contributor

@Oseltamivir Oseltamivir changed the title [AMD] feat: native MXFP8 MoE for MiniMax M3 on MI300X [AMD] perf: hybrid MXFP8 MoE for MiniMax M3 on MI300X Jun 14, 2026
@github-actions

Copy link
Copy Markdown
Contributor

@cursor cursor Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cursor Bugbot has reviewed your changes and found 1 potential issue.

There are 2 total unresolved issues (including 1 from previous review).

Fix All in Cursor

❌ Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

Reviewed by Cursor Bugbot for commit c3cdc37. Configure here.

MXFP8_ORACLE="$VLLM_PACKAGE_ROOT/vllm/model_executor/layers/fused_moe/oracle/mxfp8.py"
if ! grep -q "Using fused CDNA3 (gfx94x)" "$MXFP8_ORACLE"; then
patch --batch --forward -d "$VLLM_PACKAGE_ROOT" -p1 < "$MXFP8_PATCH"
fi

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MTP script skips MXFP8 patch

Medium Severity

Runtime MXFP8 patching was added only to the non-MTP MI300X benchmark script. launch_mi300x-amds.sh runs minimaxm3_fp8_mi300x_mtp.sh for spec-decoding: mtp configs, so those jobs never apply minimaxm3_mi300x_mxfp8.patch despite the MTP script claiming it mirrors this recipe.

Fix in Cursor Fix in Web

Reviewed by Cursor Bugbot for commit c3cdc37. Configure here.

@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

1 similar comment
@github-actions

Copy link
Copy Markdown
Contributor

Oseltamivir and others added 2 commits June 14, 2026 09:37
Co-authored-by: OpenAI Codex <codex@openai.com>
@github-actions

Copy link
Copy Markdown
Contributor

Oseltamivir and others added 2 commits June 14, 2026 13:34
Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>

# Conflicts:
#	perf-changelog.yaml
@Oseltamivir

Copy link
Copy Markdown
Collaborator Author

Packed-scale follow-up is pushed in 7678b0bc (merge refresh 684b6a3d). Matched local results, with parallelism unchanged:

  • 1K/1K TP8: 214.239 / 329.709 / 491.056 / 707.749 tok/s/GPU at concurrency 16/32/64/128
  • 8K/1K TP8 c64: 1199.146 tok/s/GPU
  • 8K c64 is +5.60% vs the previous hybrid sweep and +22.74% vs BF16 emulation

Full validation sweep: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/27511311644

@functionstackx

Copy link
Copy Markdown
Collaborator

@Oseltamivir 's AI agent, remember to have ur search space start at conc=1 like i am fixing it rn #1760

@functionstackx functionstackx changed the title [AMD] perf: hybrid MXFP8 MoE for MiniMax M3 on MI300X [Experimental][DNM till upstream PR merges][AMD] perf: hybrid MXFP8 MoE for MiniMax M3 on MI300X Jun 14, 2026
@Oseltamivir

Copy link
Copy Markdown
Collaborator Author

Official sweep update: the pushed 8K/1K TP8 c64 job passed at 1167.724 tok/s/GPU. That is +2.83% over run 27506382432 and +19.52% over the original BF16-emulation run 27489075807. Job: https://github.com/SemiAnalysisAI/InferenceX/actions/runs/27511311644/job/81312766431

@github-actions

Copy link
Copy Markdown
Contributor

Oseltamivir and others added 2 commits June 14, 2026 17:13
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
@github-actions

Copy link
Copy Markdown
Contributor

Oseltamivir and others added 2 commits June 14, 2026 18:41
Co-authored-by: OpenAI Codex <codex@openai.com>
Signed-off-by: Oseltamivir <58582368+Oseltamivir@users.noreply.github.com>
@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

@github-actions

Copy link
Copy Markdown
Contributor

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

2 participants