examples/pytorch/qwen3_moe/README.md (new file, 47 additions, 0 deletions)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Qwen3 MoE with TransformerEngine

Single-GPU implementation of [Qwen3 MoE](https://huggingface.co/Qwen/Qwen3-235B-A22B)
using TransformerEngine for FP8 training and fused kernels.

## Architecture

Same architecture as HuggingFace `Qwen3MoeForCausalLM`, with standard PyTorch
modules replaced by TE equivalents (a construction sketch follows the table):

| HuggingFace | TransformerEngine |
|---|---|
| `self_attn` (Q/K/V proj + attention + O proj) | `te.MultiheadAttention` (fused LN + QKV + QK-norm + RoPE + attn + O) |
| `post_attention_layernorm` | `te.RMSNorm` |
| Expert MLP (SwiGLU) | `te_ops.Sequential(GroupedLinear, SwiGLU, GroupedLinear)` |
| `model.norm` | `te.RMSNorm` |
| `lm_head` | `te.Linear` |
| RoPE | `te.RotaryPositionEmbedding` |
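
As a rough illustration of the attention-side mapping, the sketch below constructs a `te.MultiheadAttention` that covers the fused pre-norm, QKV projection, GQA, RoPE application, and output projection. It is a minimal sketch under stated assumptions, not the code in `model.py`: the keyword arguments shown (`kv_channels`, `num_gqa_groups`, `input_layernorm`, `normalization`, `params_dtype`) follow TransformerEngine's documented PyTorch API, the sizes match the small comparison test rather than the full model, and the QK-norm wiring is omitted.

```python
# Illustrative sketch only; not the code in model.py.
import torch
import transformer_engine.pytorch as te

hidden_size, num_heads, num_kv_heads = 256, 32, 4
head_dim = hidden_size // num_heads

attn = te.MultiheadAttention(
    hidden_size=hidden_size,
    num_attention_heads=num_heads,
    kv_channels=head_dim,          # per-head dimension
    num_gqa_groups=num_kv_heads,   # grouped-query attention
    input_layernorm=True,          # fuse the pre-attention norm into the module
    normalization="RMSNorm",
    bias=False,
    attention_dropout=0.0,
    params_dtype=torch.bfloat16,
)

rope = te.RotaryPositionEmbedding(head_dim)
freqs = rope(max_seq_len=128)      # precomputed rotary frequencies

# Default qkv_format is "sbhd": [sequence, batch, hidden].
x = torch.randn(128, 1, hidden_size, dtype=torch.bfloat16, device="cuda")
out = attn(x, rotary_pos_emb=freqs)
```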

MoE token routing uses `te.moe_permute_with_probs` / `te.moe_unpermute` for token
permutation, and expert computation uses `te_ops.GroupedLinear` for fused batched
GEMMs.
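
The routing step ahead of that permutation is ordinary top-k gating. Below is a minimal sketch of the routing math in plain PyTorch, using the `top_k` and `norm_topk_prob` settings from `config.py`; it is illustrative only, and `Qwen3MoeRouter` in `model.py` may differ in details such as dtype handling and the sizes used here.

```python
# Minimal top-k routing sketch (plain PyTorch, illustrative sizes).
import torch
import torch.nn.functional as F

num_tokens, hidden_size, num_experts, top_k = 16, 256, 8, 2
norm_topk_prob = False

gate = torch.nn.Linear(hidden_size, num_experts, bias=False)
hidden_states = torch.randn(num_tokens, hidden_size)

router_logits = gate(hidden_states)                            # [tokens, experts]
probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)  # routing probabilities
topk_probs, topk_idx = probs.topk(top_k, dim=-1)               # chosen experts per token
if norm_topk_prob:
    topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True)

# The selected tokens are then permuted so rows routed to the same expert are
# contiguous (te.moe_permute_with_probs), run through the grouped expert GEMMs,
# and scattered back to their original positions (te.moe_unpermute).
```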

## Files

| File | Description |
|---|---|
| `config.py` | `Qwen3MoeConfig` dataclass (defaults match HuggingFace) |
| `model.py` | Full model: `Qwen3MoeRouter`, `Qwen3MoeBlock`, `Qwen3MoeDecoderLayer`, `Qwen3MoeModel`, `Qwen3MoeForCausalLM` |
| `test_vs_hf.py` | Numerical comparison against HuggingFace (forward logits + backward gradients) |

## Running the comparison test

```bash
pip install transformers
python test_vs_hf.py [--seed 42]
```

This builds a small model (2 layers, 8 experts, hidden_size=256), copies weights
from HuggingFace into the TE model, and runs two checks (a comparison sketch follows the list):

1. **Forward**: `softmax(logits)` matches at `atol=1e-5`.
2. **Backward**: all parameter gradients match at `atol=1e-2`.
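
A minimal sketch of how these two checks can be written is shown below; it is not the exact code in `test_vs_hf.py`, and the tensor names, the gradient dictionary, and the parameter-name mapping are illustrative assumptions.

```python
# Illustrative comparison sketch; test_vs_hf.py may structure this differently.
import torch


def compare_forward(te_logits: torch.Tensor, hf_logits: torch.Tensor) -> None:
    # Compare probabilities rather than raw logits so a shared constant offset cancels.
    torch.testing.assert_close(
        torch.softmax(te_logits.float(), dim=-1),
        torch.softmax(hf_logits.float(), dim=-1),
        atol=1e-5,
        rtol=0.0,
    )


def compare_backward(te_model: torch.nn.Module, hf_grads: dict) -> None:
    # hf_grads is assumed to map TE parameter names to the matching HF gradients.
    for name, param in te_model.named_parameters():
        torch.testing.assert_close(param.grad, hf_grads[name], atol=1e-2, rtol=0.0)
```
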
examples/pytorch/qwen3_moe/config.py (new file, 54 additions, 0 deletions)
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Configuration for Qwen3 MoE model.

Default values match the HuggingFace Transformers Qwen3MoeConfig.
"""

from dataclasses import dataclass


@dataclass(frozen=True)
class Qwen3MoeConfig:
"""Configuration class for Qwen3 MoE model.

Attributes:
vocab_size: Size of the vocabulary.
hidden_size: Dimensionality of the hidden representations.
moe_intermediate_size: Dimensionality of each MoE expert's intermediate layer.
num_hidden_layers: Number of decoder layers.
num_attention_heads: Number of query attention heads.
num_key_value_heads: Number of key/value attention heads (for GQA).
max_position_embeddings: Maximum sequence length supported by RoPE.
initializer_range: Standard deviation for weight initialization.
rms_norm_eps: Epsilon for RMSNorm layers.
rope_theta: Base frequency for RoPE.
attention_bias: Whether to use bias in attention projections.
attention_dropout: Dropout rate for attention weights.
num_experts: Total number of MoE experts.
top_k: Number of experts selected per token (top-k).
norm_topk_prob: Whether to renormalize top-k routing probabilities.
"""

vocab_size: int = 151936
hidden_size: int = 2048
moe_intermediate_size: int = 768
num_hidden_layers: int = 24
num_attention_heads: int = 32
num_key_value_heads: int = 4
max_position_embeddings: int = 32768
initializer_range: float = 0.02
rms_norm_eps: float = 1e-6
rope_theta: float = 1000000.0
attention_bias: bool = False
attention_dropout: float = 0.0
num_experts: int = 128
top_k: int = 8
norm_topk_prob: bool = False

@property
def head_dim(self) -> int:
"""Dimensionality of each attention head."""
return self.hidden_size // self.num_attention_heads
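

# Illustrative usage, not part of the original file: the overrides below match
# the small configuration used by the README's comparison test, and head_dim is
# derived as hidden_size // num_attention_heads.
if __name__ == "__main__":
    cfg = Qwen3MoeConfig(num_hidden_layers=2, num_experts=8, hidden_size=256)
    print(cfg.head_dim)  # 256 // 32 == 8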