From 4bd1efb16cc69c67bc0b11e1dc73c1f3fb824be3 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 11:16:52 -0700
Subject: [PATCH 1/7] add qwen3 model to example

Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/config.py     |  50 ++++
 examples/pytorch/qwen3_moe/model.py      | 336 +++++++++++++++++++++++
 examples/pytorch/qwen3_moe/test_vs_hf.py | 219 +++++++++++++++
 3 files changed, 605 insertions(+)
 create mode 100644 examples/pytorch/qwen3_moe/config.py
 create mode 100644 examples/pytorch/qwen3_moe/model.py
 create mode 100644 examples/pytorch/qwen3_moe/test_vs_hf.py

diff --git a/examples/pytorch/qwen3_moe/config.py b/examples/pytorch/qwen3_moe/config.py
new file mode 100644
index 0000000000..5c34ab733b
--- /dev/null
+++ b/examples/pytorch/qwen3_moe/config.py
@@ -0,0 +1,50 @@
+"""Configuration for Qwen3 MoE model.
+
+Default values match the HuggingFace Transformers Qwen3MoeConfig.
+"""
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class Qwen3MoeConfig:
+    """Configuration class for Qwen3 MoE model.
+
+    Attributes:
+        vocab_size: Size of the vocabulary.
+        hidden_size: Dimensionality of the hidden representations.
+        moe_intermediate_size: Dimensionality of each MoE expert's intermediate layer.
+        num_hidden_layers: Number of decoder layers.
+        num_attention_heads: Number of query attention heads.
+        num_key_value_heads: Number of key/value attention heads (for GQA).
+        max_position_embeddings: Maximum sequence length supported by RoPE.
+        initializer_range: Standard deviation for weight initialization.
+        rms_norm_eps: Epsilon for RMSNorm layers.
+        rope_theta: Base frequency for RoPE.
+        attention_bias: Whether to use bias in attention projections.
+        attention_dropout: Dropout rate for attention weights.
+        num_experts: Total number of MoE experts.
+        top_k: Number of experts selected per token (top-k).
+        norm_topk_prob: Whether to renormalize top-k routing probabilities.
+    """
+
+    vocab_size: int = 151936
+    hidden_size: int = 2048
+    moe_intermediate_size: int = 768
+    num_hidden_layers: int = 24
+    num_attention_heads: int = 32
+    num_key_value_heads: int = 4
+    max_position_embeddings: int = 32768
+    initializer_range: float = 0.02
+    rms_norm_eps: float = 1e-6
+    rope_theta: float = 1000000.0
+    attention_bias: bool = False
+    attention_dropout: float = 0.0
+    num_experts: int = 128
+    top_k: int = 8
+    norm_topk_prob: bool = False
+
+    @property
+    def head_dim(self) -> int:
+        """Dimensionality of each attention head."""
+        return self.hidden_size // self.num_attention_heads
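Reviewer note (not part of the patch): the config is a plain frozen dataclass, so a
quick usage sketch; the override values below are illustrative only.

```python
from config import Qwen3MoeConfig

# Defaults mirror HF's Qwen3MoeConfig; any field can be overridden at construction.
cfg = Qwen3MoeConfig(num_hidden_layers=2, num_experts=8, top_k=2)

# head_dim is derived, not stored: hidden_size // num_attention_heads = 2048 // 32.
assert cfg.head_dim == 64
```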
diff --git a/examples/pytorch/qwen3_moe/model.py b/examples/pytorch/qwen3_moe/model.py
new file mode 100644
index 0000000000..488c107a25
--- /dev/null
+++ b/examples/pytorch/qwen3_moe/model.py
@@ -0,0 +1,336 @@
+"""Qwen3 MoE model implementation using TransformerEngine modules.
+
+Same architecture as HuggingFace Transformers Qwen3MoeForCausalLM, with PyTorch modules replaced by TransformerEngine
+equivalents for FP8 training and fused kernels.
+
+TE module mapping (HF -> TE):
+    self_attn (full block)    -> te.MultiheadAttention (fused LN + QKV + QK-norm + RoPE + attn + O)
+    post_attn_layernorm (MoE) -> te.RMSNorm
+    expert MLP (SwiGLU)       -> te_ops.Sequential(GroupedLinear, SwiGLU, GroupedLinear)
+    final norm                -> te.RMSNorm
+    lm_head                   -> te.Linear
+    RoPE frequencies          -> te.RotaryPositionEmbedding
+"""
+
+from collections.abc import Callable
+from typing import override
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch import attention as te_attention
+from transformer_engine.pytorch import ops as te_ops
+
+import config as qwen3_moe_config
+
+
+def _make_init_fn(std: float) -> Callable[[torch.Tensor], None]:
+    """Create a normal-distribution weight initializer for TE modules."""
+
+    def _init(weight: torch.Tensor) -> None:
+        nn.init.normal_(weight, mean=0.0, std=std)
+
+    return _init
+
+
+class Qwen3MoeRouter(nn.Module):
+    """Top-k softmax router for MoE expert selection.
+
+    Computes softmax over expert logits, selects the top-k experts per token, and returns
+    outputs in the mask format expected by ``te.moe_permute_with_probs`` / ``te.moe_unpermute``.
+
+    Args:
+        hidden_size: Dimensionality of input hidden states.
+        num_experts: Total number of experts.
+        top_k: Number of experts selected per token (top-k).
+        norm_topk_prob: Whether to renormalize top-k probabilities to sum to 1.
+        initializer_range: Std for normal initialization of the routing weight.
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_experts: int,
+        top_k: int,
+        norm_topk_prob: bool,
+        initializer_range: float,
+    ) -> None:
+        super().__init__()
+        self.top_k = top_k
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.weight = nn.Parameter(torch.empty(num_experts, hidden_size))
+        nn.init.normal_(self.weight, mean=0.0, std=initializer_range)
+
+    @override
+    def forward(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute expert routing for a batch of tokens.
+
+        Args:
+            hidden_states: ``(num_tokens, hidden_size)``.
+
+        Returns:
+            merging_probs: ``(num_tokens, num_experts)`` with top-k entries filled, rest zero.
+            routing_map: ``(num_tokens, num_experts)`` int32 mask.
+            tokens_per_expert: ``(num_experts,)`` token counts.
+            router_logits: ``(num_tokens, num_experts)`` pre-softmax logits.
+        """
+        router_logits = F.linear(hidden_states, self.weight)  # pylint: disable=not-callable
+        probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
+        topk_probs, topk_indices = torch.topk(probs, self.top_k, dim=-1)
+
+        if self.norm_topk_prob:
+            topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)
+
+        routing_map = torch.zeros(
+            hidden_states.shape[0], self.num_experts, dtype=torch.int32, device=hidden_states.device
+        )
+        routing_map.scatter_(1, topk_indices, 1)
+
+        merging_probs = torch.zeros_like(probs)
+        merging_probs.scatter_(1, topk_indices, topk_probs)
+
+        tokens_per_expert = routing_map.sum(dim=0)
+        return merging_probs, routing_map, tokens_per_expert, router_logits
+
+
+class Qwen3MoeBlock(nn.Module):
+    """Mixture-of-Experts feed-forward block with SwiGLU activation.
+
+    Routes tokens to top-k experts via ``Qwen3MoeRouter``, then applies per-expert SwiGLU using
+    ``te_ops.Sequential(GroupedLinear, SwiGLU, GroupedLinear)`` for fused batched GEMMs + activation.
+    Token dispatch and combine are handled by ``te.moe_permute_with_probs`` / ``te.moe_unpermute``.
+
+    Args:
+        config: Model configuration.
+    """
+
+    def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
+        super().__init__()
+        self.num_experts = config.num_experts
+
+        self.router = Qwen3MoeRouter(
+            hidden_size=config.hidden_size,
+            num_experts=config.num_experts,
+            top_k=config.top_k,
+            norm_topk_prob=config.norm_topk_prob,
+            initializer_range=config.initializer_range,
+        )
+
+        self.expert_mlp = te_ops.Sequential(
+            te_ops.GroupedLinear(
+                config.num_experts, config.hidden_size, 2 * config.moe_intermediate_size, bias=False
+            ),
+            te_ops.SwiGLU(),
+            te_ops.GroupedLinear(
+                config.num_experts, config.moe_intermediate_size, config.hidden_size, bias=False
+            ),
+        )
+        init_fn = _make_init_fn(config.initializer_range)
+        for param in self.expert_mlp.parameters():
+            init_fn(param)
+
+    @override
+    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        """Route tokens to experts and apply SwiGLU.
+
+        Args:
+            hidden_states: ``(batch, seq_len, hidden_size)``.
+
+        Returns:
+            output: ``(batch, seq_len, hidden_size)`` after expert computation.
+            router_logits: ``(batch * seq_len, num_experts)`` pre-softmax logits.
+        """
+        batch_size, seq_len, hidden_dim = hidden_states.shape
+        hidden_flat = hidden_states.view(-1, hidden_dim)
+
+        merging_probs, routing_map, tokens_per_expert, router_logits = self.router(hidden_flat)
+
+        num_out_tokens = self.router.top_k * router_logits.shape[0]
+        permuted_input, _, row_id_map = te.moe_permute_with_probs(
+            hidden_flat, merging_probs, routing_map, num_out_tokens=num_out_tokens
+        )
+
+        # Expert computation (fused GroupedLinear -> SwiGLU -> GroupedLinear).
+        # Pass tokens_per_expert tensor directly — no .tolist() CPU sync.
+        expert_out = self.expert_mlp(permuted_input, tokens_per_expert, tokens_per_expert)
+
+        # Combine: scatter back to original order with probability weighting
+        output = te.moe_unpermute(
+            expert_out, row_id_map, merging_probs=merging_probs, restore_shape=hidden_flat.shape
+        )
+
+        return output.view(batch_size, seq_len, hidden_dim), router_logits
+
+
+class Qwen3MoeDecoderLayer(nn.Module):
+    """Pre-norm decoder layer: fused attention + MoE feed-forward.
+
+    Architecture: ``input_layernorm + MultiheadAttention`` (fused inside TE) followed by
+    ``post_attention_layernorm + Qwen3MoeBlock``, both with residual connections.
+
+    Args:
+        config: Model configuration.
+        layer_idx: Zero-based layer index (passed as ``layer_number=idx+1`` to
+            ``te.MultiheadAttention`` for internal bookkeeping).
+    """
+
+    def __init__(
+        self,
+        config: qwen3_moe_config.Qwen3MoeConfig,
+        layer_idx: int,
+    ) -> None:
+        super().__init__()
+        init_fn = _make_init_fn(config.initializer_range)
+
+        self.self_attn = te.MultiheadAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            kv_channels=config.head_dim,
+            num_gqa_groups=config.num_key_value_heads,
+            attention_dropout=config.attention_dropout,
+            layernorm_epsilon=config.rms_norm_eps,
+            init_method=init_fn,
+            output_layer_init_method=init_fn,
+            layer_number=layer_idx + 1,
+            attn_mask_type="causal",
+            input_layernorm=True,
+            normalization="RMSNorm",
+            bias=config.attention_bias,
+            qkv_format="bshd",
+            qk_norm_type="RMSNorm",
+            qk_norm_eps=config.rms_norm_eps,
+            qk_norm_before_rope=True,
+        )
+
+        self.post_attention_layernorm = te.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.mlp = Qwen3MoeBlock(config)
+
+    @override
+    def forward(  # type: ignore[override]
+        self,
+        hidden_states: torch.Tensor,
+        freqs: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Apply attention and MoE feed-forward with residual connections.
+
+        Args:
+            hidden_states: ``(batch, seq_len, hidden_size)``.
+            freqs: Rotary position embedding frequencies from ``te.RotaryPositionEmbedding``.
+            attention_mask: Optional ``(batch, seq_len)`` mask (1 = valid).
+
+        Returns:
+            hidden_states: ``(batch, seq_len, hidden_size)``.
+            router_logits: ``(batch * seq_len, num_experts)`` from the MoE block.
+        """
+        residual = hidden_states
+        hidden_states = self.self_attn(
+            hidden_states, attention_mask=attention_mask, rotary_pos_emb=freqs
+        )
+        hidden_states = residual + hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states, router_logits = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        return hidden_states, router_logits
+
+
+class Qwen3MoeModel(nn.Module):
+    """Qwen3 MoE transformer backbone: embedding + decoder stack + final norm.
+
+    Embeds input token IDs, applies rotary position embeddings, runs through
+    ``num_hidden_layers`` decoder layers, and applies a final RMSNorm. Returns
+    hidden states and per-layer router logits (for auxiliary loss).
+
+    Args:
+        config: Model configuration.
+    """
+
+    def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
+        super().__init__()
+        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=config.initializer_range)
+
+        self.layers = nn.ModuleList(
+            [Qwen3MoeDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)]
+        )
+        self.norm = te.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.rotary_emb = te_attention.RotaryPositionEmbedding(
+            dim=config.head_dim, rotary_base=config.rope_theta
+        )
+
+    @override
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_ids: ``(batch, seq_len)`` token IDs.
+            attention_mask: Optional ``(batch, seq_len)`` mask (1 = valid).
+
+        Returns:
+            hidden_states: ``(batch, seq_len, hidden_size)``.
+            all_router_logits: List of per-layer router logit tensors.
+        """
+        hidden_states = self.embed_tokens(input_ids)
+        seq_len = input_ids.shape[1]
+
+        freqs = self.rotary_emb(max_seq_len=seq_len)
+        freqs = freqs.to(device=hidden_states.device, dtype=torch.float32)
+
+        all_router_logits: list[torch.Tensor] = []
+        for layer in self.layers:
+            hidden_states, router_logits = layer(hidden_states, freqs, attention_mask)
+            all_router_logits.append(router_logits)
+
+        hidden_states = self.norm(hidden_states)
+        return hidden_states, all_router_logits
+
+
+class Qwen3MoeForCausalLM(nn.Module):
+    """Qwen3 MoE causal language model: backbone + LM head.
+
+    Wraps ``Qwen3MoeModel`` and adds a ``te.Linear`` LM head that projects hidden
+    states to vocabulary logits.
+
+    Args:
+        config: Model configuration.
+    """
+
+    def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
+        super().__init__()
+        self.model = Qwen3MoeModel(config)
+
+        init_fn = _make_init_fn(config.initializer_range)
+        self.lm_head = te.Linear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            init_method=init_fn,
+        )
+
+    @override
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """
+        Args:
+            input_ids: ``(batch, seq_len)`` token IDs.
+            attention_mask: Optional ``(batch, seq_len)`` padding mask.
+
+        Returns:
+            logits: ``(batch, seq_len, vocab_size)``.
+            all_router_logits: Per-layer router logits from MoE layers.
+        """
+        hidden_states, all_router_logits = self.model(input_ids, attention_mask)
+        logits = self.lm_head(hidden_states)
+        return logits, all_router_logits
diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
new file mode 100644
index 0000000000..59ae4685b3
--- /dev/null
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -0,0 +1,219 @@
+"""Compare Qwen3 MoE TE implementation against HuggingFace reference.
+
+Runs forward and backward passes on both models with identical weights and
+verifies that logits and gradients match.
+
+Usage:
+    python -m examples.pytorch.qwen3_moe.test_vs_hf [--seed 42] [--device cuda]
+
+Requirements:
+    pip install transformers
+"""
+
+import argparse
+import re
+
+import torch
+from transformers.models.qwen3_moe import configuration_qwen3_moe, modeling_qwen3_moe
+
+import config as qwen3_moe_config
+import model as qwen3_moe_model
+
+_BATCH = 1
+_SEQ_LEN = 32
+
+_TEST_CONFIG = qwen3_moe_config.Qwen3MoeConfig(
+    vocab_size=512,
+    hidden_size=256,
+    moe_intermediate_size=128,
+    num_hidden_layers=2,
+    num_attention_heads=8,
+    num_key_value_heads=2,
+    max_position_embeddings=128,
+    num_experts=8,
+    top_k=2,
+    norm_topk_prob=True,
+)
+
+
+def _to_hf_config(
+    config: qwen3_moe_config.Qwen3MoeConfig,
+) -> configuration_qwen3_moe.Qwen3MoeConfig:
+    """Create a matching HuggingFace config."""
+    return configuration_qwen3_moe.Qwen3MoeConfig(
+        vocab_size=config.vocab_size,
+        hidden_size=config.hidden_size,
+        intermediate_size=config.moe_intermediate_size,
+        moe_intermediate_size=config.moe_intermediate_size,
+        num_hidden_layers=config.num_hidden_layers,
+        num_attention_heads=config.num_attention_heads,
+        num_key_value_heads=config.num_key_value_heads,
+        max_position_embeddings=config.max_position_embeddings,
+        num_experts=config.num_experts,
+        num_experts_per_tok=config.top_k,
+        norm_topk_prob=config.norm_topk_prob,
+        output_router_logits=True,
+        use_cache=False,
+        tie_word_embeddings=False,
+    )
+
+
+def _te_name_to_hf_param(
+    name: str,
+    hf_model: modeling_qwen3_moe.Qwen3MoeForCausalLM,
+    *,
+    return_grad: bool = False,
+) -> torch.Tensor | None:
+    """Match a TE param name to the corresponding HF tensor (.data or .grad)."""
+
+    def _v(param: torch.nn.Parameter) -> torch.Tensor:
+        if return_grad:
+            assert param.grad is not None, f"Expected gradient for {name}"
+            return param.grad
+        return param.data
+
+    if name == "model.embed_tokens.weight":
+        return _v(hf_model.model.embed_tokens.weight)
+    if name.startswith("model.norm."):
+        return _v(hf_model.model.norm.weight)
+    if "lm_head" in name and "weight" in name:
+        return _v(hf_model.lm_head.weight)
+
+    m = re.match(r"model\.layers\.(\d+)\.(.*)", name)
+    if not m:
+        return None
+    hf_layer = hf_model.model.layers[int(m.group(1))]
+    hf_attn = hf_layer.self_attn
+    suffix = m.group(2)
+
+    if "post_attention_layernorm" in suffix:
+        return _v(hf_layer.post_attention_layernorm.weight)
+    if "layer_norm" in suffix:
+        return _v(hf_layer.input_layernorm.weight)
+    if "q_norm" in suffix:
+        return _v(hf_attn.q_norm.weight)
+    if "k_norm" in suffix:
+        return _v(hf_attn.k_norm.weight)
+    if "query" in suffix and "weight" in suffix:
+        return _v(hf_attn.q_proj.weight)
+    if "key" in suffix and "weight" in suffix:
+        return _v(hf_attn.k_proj.weight)
+    if "value" in suffix and "weight" in suffix:
+        return _v(hf_attn.v_proj.weight)
+    if "qkv" in suffix and "weight" in suffix:
+        return torch.cat(
+            [_v(hf_attn.q_proj.weight), _v(hf_attn.k_proj.weight), _v(hf_attn.v_proj.weight)], dim=0
+        )
+    if "self_attn" in suffix and "proj" in suffix and "weight" in suffix:
+        return _v(hf_attn.o_proj.weight)
+    if "router" in suffix and "weight" in suffix:
+        return _v(hf_layer.mlp.gate.weight)
+
+    # te_ops.Sequential stores expert weights as expert_mlp.{seq_idx}.weight{expert_idx}
+    # seq_idx 0 = gate_up_proj (first GroupedLinear), seq_idx 2 = down_proj (second GroupedLinear)
+    em = re.search(r"expert_mlp\.(\d+)\.weight(\d+)", suffix)
+    if em:
+        seq_idx, expert_idx = int(em.group(1)), int(em.group(2))
+        exp = hf_layer.mlp.experts
+        if seq_idx == 0:
+            # gate_up_proj: concat gate and up for this expert
+            if hasattr(exp, "gate_up_proj"):
+                return _v(exp.gate_up_proj)[expert_idx]
+            return torch.cat(
+                [_v(exp[expert_idx].gate_proj.weight), _v(exp[expert_idx].up_proj.weight)], dim=0
+            )
+        if seq_idx == 2:
+            # down_proj
+            if hasattr(exp, "down_proj"):
+                return _v(exp.down_proj)[expert_idx]
+            return _v(exp[expert_idx].down_proj.weight)
+
+    return None
+
+
+def _copy_hf_to_te(
+    hf_model: modeling_qwen3_moe.Qwen3MoeForCausalLM,
+    te_model: qwen3_moe_model.Qwen3MoeForCausalLM,
+) -> None:
+    """Copy every parameter from the HF model into the TE model."""
+    for name, param in te_model.named_parameters():
+        hf_val = _te_name_to_hf_param(name, hf_model)
+        if hf_val is None:
+            raise ValueError(f"Unmapped TE parameter: {name} {tuple(param.shape)}")
+        param.data.copy_(hf_val)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Compare TE vs HF Qwen3 MoE")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--device", type=str, default="cuda")
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+    device = args.device
+
+    hf_config = _to_hf_config(_TEST_CONFIG)
+    hf_model = modeling_qwen3_moe.Qwen3MoeForCausalLM(hf_config).to(
+        dtype=torch.float32, device=device
+    )
+    # Use powers-of-two weights for exact reproducibility.
+    for param in hf_model.parameters():
+        param.data.copy_(2.0 ** torch.randint_like(param, -4, 0))
+
+    te_model = qwen3_moe_model.Qwen3MoeForCausalLM(_TEST_CONFIG).to(
+        dtype=torch.float32, device=device
+    )
+    _copy_hf_to_te(hf_model, te_model)
+
+    input_ids = torch.randint(0, _TEST_CONFIG.vocab_size, (_BATCH, _SEQ_LEN), device=device)
+
+    print("Running forward pass...")
+    hf_model.eval()
+    te_model.eval()
+
+    with torch.no_grad():
+        hf_logits = hf_model(input_ids=input_ids).logits
+        te_logits, _ = te_model(input_ids)
+
+    torch.testing.assert_close(
+        torch.softmax(te_logits, dim=-1), torch.softmax(hf_logits, dim=-1), atol=1e-5, rtol=0
+    )
+    print(" Forward PASSED — logits match (atol=1e-5)")
+
+    print("Running backward pass...")
+    hf_model.zero_grad(set_to_none=True)
+    te_model.zero_grad(set_to_none=True)
+
+    hf_logits = hf_model(input_ids=input_ids).logits
+    grad_output = 2.0 ** torch.randint_like(hf_logits, -4, 0)
+    hf_logits.backward(grad_output)
+
+    te_logits, _ = te_model(input_ids)
+    # Use identical logits so the backward graph sees the same values.
+    te_logits.data.copy_(hf_logits.detach())
+    te_logits.backward(grad_output)
+
+    max_grad_err = 0.0
+    for name, te_param in te_model.named_parameters():
+        if te_param.grad is None:
+            continue
+        hf_grad = _te_name_to_hf_param(name, hf_model, return_grad=True)
+        if hf_grad is None:
+            raise ValueError(f"Unmapped TE parameter: {name} {tuple(te_param.shape)}")
+
+        torch.testing.assert_close(
+            te_param.grad,
+            hf_grad,
+            atol=1e-2,
+            rtol=1e-2,
+            msg=lambda m: f"{m}\n{name} {tuple(te_param.shape)}",  # pylint: disable=cell-var-from-loop
+        )
+        err = (te_param.grad - hf_grad).abs().max().item()
+        max_grad_err = max(max_grad_err, err)
+
+    print(f" Backward PASSED — all gradients match (atol=1e-2, max_err={max_grad_err:.2e})")
+    print("All checks passed.")
+
+
+if __name__ == "__main__":
+    main()
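Reviewer note (not part of the patch): the model returns per-layer router logits "for
auxiliary loss", but no auxiliary loss is implemented in this example. As a sketch,
one common formulation is Switch-Transformer-style load balancing; this is
illustrative only, not the PR's method.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(
    all_router_logits: list[torch.Tensor], num_experts: int, top_k: int
) -> torch.Tensor:
    """Mean over layers of num_experts * sum_i(frac_tokens_i * mean_prob_i)."""
    losses = []
    for logits in all_router_logits:  # each: (num_tokens, num_experts)
        probs = F.softmax(logits, dim=-1, dtype=torch.float32)
        _, topk_idx = torch.topk(probs, top_k, dim=-1)
        # 0/1 mask of the experts selected for each token, same shape as probs.
        mask = F.one_hot(topk_idx, num_experts).amax(dim=1).float()
        frac_tokens = mask.mean(dim=0)  # fraction of tokens routed to each expert
        mean_probs = probs.mean(dim=0)  # mean router probability per expert
        losses.append(num_experts * torch.sum(frac_tokens * mean_probs))
    return torch.stack(losses).mean()
```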
From c56277efc89fc1afa652216e13d919ea8b009467 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 11:28:43 -0700
Subject: [PATCH 2/7] add readme

Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/README.md     | 43 ++++++++++++++++++++++++
 examples/pytorch/qwen3_moe/test_vs_hf.py |  5 ++---
 2 files changed, 45 insertions(+), 3 deletions(-)
 create mode 100644 examples/pytorch/qwen3_moe/README.md

diff --git a/examples/pytorch/qwen3_moe/README.md b/examples/pytorch/qwen3_moe/README.md
new file mode 100644
index 0000000000..eb98e09bc2
--- /dev/null
+++ b/examples/pytorch/qwen3_moe/README.md
@@ -0,0 +1,43 @@
+# Qwen3 MoE with TransformerEngine
+
+Single-GPU implementation of [Qwen3 MoE](https://huggingface.co/Qwen/Qwen3-235B-A22B)
+using TransformerEngine for FP8 training and fused kernels.
+
+## Architecture
+
+Same architecture as HuggingFace `Qwen3MoeForCausalLM`, with standard PyTorch
+modules replaced by TE equivalents:
+
+| HuggingFace | TransformerEngine |
+|---|---|
+| `self_attn` (Q/K/V proj + attention + O proj) | `te.MultiheadAttention` (fused LN + QKV + QK-norm + RoPE + attn + O) |
+| `post_attention_layernorm` | `te.RMSNorm` |
+| Expert MLP (SwiGLU) | `te_ops.Sequential(GroupedLinear, SwiGLU, GroupedLinear)` |
+| `model.norm` | `te.RMSNorm` |
+| `lm_head` | `te.Linear` |
+| RoPE | `te.RotaryPositionEmbedding` |
+
+MoE token routing uses `te.moe_permute_with_probs` / `te.moe_unpermute` for
+permutation, and expert computation uses `te_ops.GroupedLinear` for fused
+batched GEMMs.
+
+## Files
+
+| File | Description |
+|---|---|
+| `config.py` | `Qwen3MoeConfig` dataclass (defaults match HuggingFace) |
+| `model.py` | Full model: `Qwen3MoeRouter`, `Qwen3MoeBlock`, `Qwen3MoeDecoderLayer`, `Qwen3MoeModel`, `Qwen3MoeForCausalLM` |
+| `test_vs_hf.py` | Numerical comparison against HuggingFace (forward logits + backward gradients) |
+
+## Running the comparison test
+
+```bash
+pip install transformers
+python test_vs_hf.py [--seed 42]
+```
+
+This builds a small model (2 layers, 8 experts, hidden_size=256), copies weights
+from HuggingFace into the TE model, and checks:
+
+1. **Forward**: softmax(logits) match at `atol=1e-5`.
+2. **Backward**: all parameter gradients match at `atol=1e-2`.
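Reviewer note (not part of the patch): the README advertises FP8 training, but the
example never enables it. A sketch of how the forward pass might be wrapped, assuming
TE's `fp8_autocast` context with a `DelayedScaling` recipe and FP8-capable hardware
(e.g. Hopper); note FP8 also imposes dimension constraints that the tiny test config
may not satisfy.

```python
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(fp8_format=recipe.Format.HYBRID)

# model and input_ids built as in the earlier sketch.
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    logits, router_logits = model(input_ids)
```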
diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
index 59ae4685b3..ff5ca0eff0 100644
--- a/examples/pytorch/qwen3_moe/test_vs_hf.py
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -4,7 +4,7 @@
 verifies that logits and gradients match.
 
 Usage:
-    python -m examples.pytorch.qwen3_moe.test_vs_hf [--seed 42] [--device cuda]
+    python -m examples.pytorch.qwen3_moe.test_vs_hf [--seed 42]
 
 Requirements:
     pip install transformers
@@ -146,11 +146,10 @@
 def main() -> None:
     parser = argparse.ArgumentParser(description="Compare TE vs HF Qwen3 MoE")
     parser.add_argument("--seed", type=int, default=42)
-    parser.add_argument("--device", type=str, default="cuda")
     args = parser.parse_args()
 
     torch.manual_seed(args.seed)
-    device = args.device
+    device = "cuda"
 
     hf_config = _to_hf_config(_TEST_CONFIG)
     hf_model = modeling_qwen3_moe.Qwen3MoeForCausalLM(hf_config).to(

From 2ea1d363e591ec9a61ba636aa8772a425d539076 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 11:33:11 -0700
Subject: [PATCH 3/7] remove python 3.12 feature

Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/model.py | 8 +-------
 1 file changed, 1 insertion(+), 7 deletions(-)

diff --git a/examples/pytorch/qwen3_moe/model.py b/examples/pytorch/qwen3_moe/model.py
index 488c107a25..6222a7b89d 100644
--- a/examples/pytorch/qwen3_moe/model.py
+++ b/examples/pytorch/qwen3_moe/model.py
@@ -13,7 +13,6 @@
 """
 
 from collections.abc import Callable
-from typing import override
 
 import torch
 import torch.nn as nn
@@ -63,7 +62,6 @@ def __init__(
         self.weight = nn.Parameter(torch.empty(num_experts, hidden_size))
         nn.init.normal_(self.weight, mean=0.0, std=initializer_range)
 
-    @override
     def forward(
         self, hidden_states: torch.Tensor
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -133,7 +131,6 @@ def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
         for param in self.expert_mlp.parameters():
             init_fn(param)
 
-    @override
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """Route tokens to experts and apply SwiGLU.
 
@@ -209,8 +206,7 @@ def __init__(
         self.post_attention_layernorm = te.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.mlp = Qwen3MoeBlock(config)
 
-    @override
-    def forward(  # type: ignore[override]
+    def forward(
         self,
         hidden_states: torch.Tensor,
         freqs: torch.Tensor,
@@ -265,7 +261,6 @@ def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
             dim=config.head_dim, rotary_base=config.rope_theta
         )
 
-    @override
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -316,7 +311,6 @@ def __init__(self, config: qwen3_moe_config.Qwen3MoeConfig) -> None:
             init_method=init_fn,
         )
 
-    @override
     def forward(
         self,
         input_ids: torch.Tensor,
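Reviewer note (not part of the patch): `typing.override` is only available on Python
3.12+, so dropping the decorators is the simplest fix. Had they been worth keeping, the
`typing_extensions` backport is the usual alternative:

```python
try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override  # pip install typing-extensions
```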
From ed3cf31a62fb54f10dd8215f3fbe1189c771ae47 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 11:34:37 -0700
Subject: [PATCH 4/7] add license

Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/README.md     | 4 ++++
 examples/pytorch/qwen3_moe/config.py     | 4 ++++
 examples/pytorch/qwen3_moe/model.py      | 4 ++++
 examples/pytorch/qwen3_moe/test_vs_hf.py | 4 ++++
 4 files changed, 16 insertions(+)

diff --git a/examples/pytorch/qwen3_moe/README.md b/examples/pytorch/qwen3_moe/README.md
index eb98e09bc2..e7ee873510 100644
--- a/examples/pytorch/qwen3_moe/README.md
+++ b/examples/pytorch/qwen3_moe/README.md
@@ -1,3 +1,7 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
 # Qwen3 MoE with TransformerEngine
 
 Single-GPU implementation of [Qwen3 MoE](https://huggingface.co/Qwen/Qwen3-235B-A22B)
diff --git a/examples/pytorch/qwen3_moe/config.py b/examples/pytorch/qwen3_moe/config.py
index 5c34ab733b..ba15ce7b0b 100644
--- a/examples/pytorch/qwen3_moe/config.py
+++ b/examples/pytorch/qwen3_moe/config.py
@@ -1,3 +1,7 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
 """Configuration for Qwen3 MoE model.
 
 Default values match the HuggingFace Transformers Qwen3MoeConfig.
diff --git a/examples/pytorch/qwen3_moe/model.py b/examples/pytorch/qwen3_moe/model.py
index 6222a7b89d..f31fd7ebb7 100644
--- a/examples/pytorch/qwen3_moe/model.py
+++ b/examples/pytorch/qwen3_moe/model.py
@@ -1,3 +1,7 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
 """Qwen3 MoE model implementation using TransformerEngine modules.
 
 Same architecture as HuggingFace Transformers Qwen3MoeForCausalLM, with PyTorch modules replaced by TransformerEngine
diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
index ff5ca0eff0..0f5ae72019 100644
--- a/examples/pytorch/qwen3_moe/test_vs_hf.py
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -1,3 +1,7 @@
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
 """Compare Qwen3 MoE TE implementation against HuggingFace reference.
 
 Runs forward and backward passes on both models with identical weights and

From 4063918fef4b17d84827016695b021b0888dd3c4 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 12:30:24 -0700
Subject: [PATCH 5/7] Update examples/pytorch/qwen3_moe/test_vs_hf.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/test_vs_hf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
index 0f5ae72019..7855e968af 100644
--- a/examples/pytorch/qwen3_moe/test_vs_hf.py
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -5,10 +5,11 @@
 """Compare Qwen3 MoE TE implementation against HuggingFace reference.
 
 Runs forward and backward passes on both models with identical weights and
 verifies that logits and gradients match.
 
 Usage:
-    python -m examples.pytorch.qwen3_moe.test_vs_hf [--seed 42]
+    cd examples/pytorch/qwen3_moe
+    python test_vs_hf.py [--seed 42]
 
 Requirements:
     pip install transformers

From fc0ecf7c8f229a681161564bd8e0b8cede66da43 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 12:30:34 -0700
Subject: [PATCH 6/7] Update examples/pytorch/qwen3_moe/test_vs_hf.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/test_vs_hf.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
index 7855e968af..b4090ce012 100644
--- a/examples/pytorch/qwen3_moe/test_vs_hf.py
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -192,10 +192,8 @@ def main() -> None:
 
     te_logits, _ = te_model(input_ids)
     # Use identical logits so the backward graph sees the same values.
-    te_logits.data.copy_(hf_logits.detach())
+    te_logits, _ = te_model(input_ids)
     te_logits.backward(grad_output)
-
-    max_grad_err = 0.0
     for name, te_param in te_model.named_parameters():
         if te_param.grad is None:
             continue

From 189b18fd27a656d01b7cc902faae704340e011d9 Mon Sep 17 00:00:00 2001
From: Hao Wu
Date: Wed, 15 Apr 2026 12:49:11 -0700
Subject: [PATCH 7/7] rollback wrong changes initiated by AI

Signed-off-by: Hao Wu
---
 examples/pytorch/qwen3_moe/test_vs_hf.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/examples/pytorch/qwen3_moe/test_vs_hf.py b/examples/pytorch/qwen3_moe/test_vs_hf.py
index b4090ce012..97b40b2d29 100644
--- a/examples/pytorch/qwen3_moe/test_vs_hf.py
+++ b/examples/pytorch/qwen3_moe/test_vs_hf.py
@@ -192,8 +192,9 @@ def main() -> None:
 
     te_logits, _ = te_model(input_ids)
     # Use identical logits so the backward graph sees the same values.
-    te_logits, _ = te_model(input_ids)
+    te_logits.data.copy_(hf_logits.detach())
     te_logits.backward(grad_output)
+    max_grad_err = 0.0
     for name, te_param in te_model.named_parameters():
         if te_param.grad is None:
             continue
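Reviewer note (not part of the patch): on why the test can use such tight tolerances,
products and partial sums of the power-of-two weights and activations are exact in
float32, so TE and HF can agree closely even when kernels fuse or reorder GEMMs. A
standalone illustration of the idea:

```python
import torch

torch.manual_seed(0)
w = 2.0 ** torch.randint(-4, 0, (256, 256))  # entries in {1/16, 1/8, 1/4, 1/2}
x = 2.0 ** torch.randint(-4, 0, (32, 256))

y_full = x @ w
# The same GEMM split in two, i.e. a different accumulation order.
y_split = x[:, :128] @ w[:128] + x[:, 128:] @ w[128:]

# Every partial product and sum here is exactly representable in float32,
# so the two results match bit for bit.
print(torch.equal(y_full, y_split))  # True
```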