Add Gemma 4 Megatron model support #736
Draft
FurtherAI wants to merge 498 commits into
Summary
Adds Gemma 4 MoE support to ART's Megatron backend.
The branch also brings in the supporting Megatron/vLLM runtime infrastructure needed for this model family: sliding-window attention masks (including in CP), vLLM token-id based RL tokenization, native vLLM LoRA support for Gemma4 MoE, train/inference mismatch validation, and the length trainability workflow to replace the yes-no trainability test.
Semantic Change Groups
Gemma 4 model handler and Megatron bridge support
Adds the Gemma 4 model-support handler and registry wiring. The handler covers Gemma4-specific provider setup, layer-family discovery, proportional RoPE handling, router replay, fused expert loading, shared expert overlap, full activation recompute, and LoRA export.
This was needed because Gemma 4 differs from existing Qwen handlers in several important ways: K equals V behavior, fused expert layout, tuple rotary outputs, SWA/global attention layer mix, and bridge/runtime config expectations.
Sliding-window attention and context parallelism
Adds ART flex-attention SWA mask support and wires it through shared-prefix state and CP mask preparation. The CP path now prepares masks up front and the forward path requires the prepared mask, avoiding host-side work and accidental runtime mask construction.
This was needed so Gemma 4's SWA layers match HF/vLLM behavior while preserving the GPU-only forward path and keeping CP planning outside the hot model forward.
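As a rough illustration of what a sliding-window causal mask computes, here is a minimal pure-Python sketch. It is not ART's implementation: the function names are hypothetical, the `(b, h, q_idx, kv_idx)` predicate shape simply mirrors the `mask_mod` convention used by PyTorch flex attention, and the window convention (the window counts the query token itself) is an assumption that should be checked against the model config.

```python
from typing import List


def sliding_window_causal(b: int, h: int, q_idx: int, kv_idx: int, window: int) -> bool:
    """Return True if query position q_idx may attend to key position kv_idx.

    b and h (batch and head index) are unused here; they are kept only so the
    predicate has the same shape as a flex-attention style mask function.
    """
    causal = kv_idx <= q_idx
    # Assumption: the window includes the query token itself, so each query
    # sees at most `window` keys.
    in_window = (q_idx - kv_idx) < window
    return causal and in_window


def dense_mask(seq_len: int, window: int) -> List[List[bool]]:
    """Materialize the predicate as a dense [query][key] boolean matrix."""
    return [
        [sliding_window_causal(0, 0, q, k, window) for k in range(seq_len)]
        for q in range(seq_len)
    ]


if __name__ == "__main__":
    # Print the mask with one row per query position.
    for row in dense_mask(seq_len=5, window=3):
        print("".join("x" if allowed else "." for allowed in row))
```

Building the mask once from sequence length and window size, rather than inside the forward pass, is the same idea as preparing CP masks up front: the hot path only consumes an already prepared mask.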
RL tokenization via vLLM token ids
Cuts the RL tokenization path over to vLLM-returned token ids. ART now requests vLLM token ids and stores the native vLLM fields on Choice.model_extra: prompt_token_ids and token_ids.
Tokenization then uses those ids directly, including append-only multi-turn collapse when prompt ids prove equivalence.
This removes fragile chat-template re-rendering for RL trajectories and makes multi-turn/tool/thinking behavior follow the actual serving prompt seen by vLLM.
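The append-only collapse can be pictured with a small sketch. This is an assumption about the general technique, not the code in this branch: the `Turn` class and `collapse_append_only` function are hypothetical, and only the field names `prompt_token_ids` and `token_ids` come from the description above. A later turn is folded into the running sequence only when its prompt ids start with everything seen so far (previous prompt plus previous completion).

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Turn:
    prompt_token_ids: List[int]  # ids the server reported for the prompt
    token_ids: List[int]         # ids the server sampled for the completion


def collapse_append_only(turns: List[Turn]) -> Optional[Tuple[List[int], List[bool]]]:
    """Collapse turns into one (ids, trainable_mask) pair, or return None.

    None means some turn's prompt is not a strict extension of the tokens seen
    so far, so the turns cannot be proven equivalent to one long sequence.
    """
    if not turns:
        return None
    ids = list(turns[0].prompt_token_ids)
    mask = [False] * len(ids)  # prompt tokens are not trained on
    for i, turn in enumerate(turns):
        if i > 0:
            prompt = turn.prompt_token_ids
            # The new prompt must begin with every token accumulated so far.
            if prompt[: len(ids)] != ids:
                return None
            extra = prompt[len(ids):]  # e.g. tool output or the next user message
            ids.extend(extra)
            mask.extend([False] * len(extra))
        ids.extend(turn.token_ids)
        mask.extend([True] * len(turn.token_ids))  # sampled tokens are trainable
    return ids, mask
```

Because the check compares ids rather than re-rendered text, it never depends on the chat template producing the same string twice.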
@Kovbo This should mostly have changed only the RL path, but SFT now has its own simple tokenize path. Can you check this out?
vLLM runtime and LoRA serving
Upgrades the ART vLLM runtime to vLLM 0.23.0 and updates runtime patches for the new API surface. Adds an isolated Gemma4 MoE LoRA patch in vllm_runtime so native LoRA serving works for Gemma4 MoE until upstream support exists.
The branch also adds compact LoRA delta publishing and merged-weight transfer improvements (send the LoRA to vLLM, then merge and apply it there), but Gemma4 is now configured to use native vLLM LoRA by default after validation.
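For readers unfamiliar with the merged-weight path, the standard LoRA merge is W' = W + (alpha / r) * B A, where A has shape r x in and B has shape out x r. The sketch below shows that arithmetic in pure Python on nested lists. It is only the textbook formula: the function names are hypothetical, and nothing here reflects how this branch transfers deltas or handles fused MoE expert weights.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product (a is m x k, b is k x n)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def merge_lora(weight: Matrix, lora_a: Matrix, lora_b: Matrix, alpha: float) -> Matrix:
    """Return weight + (alpha / rank) * (lora_b @ lora_a).

    lora_a is rank x in_features, lora_b is out_features x rank, and weight is
    out_features x in_features.
    """
    rank = len(lora_a)
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]
```

Sending only A and B and merging on the serving side moves rank-sized matrices over the wire instead of full weight matrices, which is the usual motivation for this kind of transfer.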
Train/inference mismatch validation
Extends the real-path train/inference mismatch stage for Gemma4, including long prompts so SWA is exercised, routed-expert replay, CP scoring, and native-LoRA rollout settings.
Trainability workflow
Replaces the awkward yes/no trainability default with the length trainability workflow. The new test trains only on generated length error, uses dedicated Megatron/PipelineTrainer mode, stops early once target error is reached, and has explicit initial/final error assertions.
This gives a cleaner trainability signal for Gemma4 and avoids relying on a prompt/task shape that Gemma4 already starts unusually high on (it was getting a reward of 0.9375 at step 0, which left no training signal).
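The shape of such a check can be sketched as follows. Everything here is hypothetical (function names, the relative-error definition, the thresholds); it only illustrates the three properties described above: a metric based purely on generated length, early stopping at a target error, and explicit initial and final assertions. The real workflow drives an actual trainer rather than a list of precomputed lengths.

```python
from typing import Iterable, List, Sequence


def length_error(generated_len: int, target_len: int) -> float:
    """Relative distance between generated and target length (assumed metric)."""
    return abs(generated_len - target_len) / target_len


def mean_length_error(lengths: Sequence[int], target_len: int) -> float:
    return sum(length_error(n, target_len) for n in lengths) / len(lengths)


def run_length_trainability(
    step_lengths: Iterable[Sequence[int]],
    target_len: int,
    min_initial_error: float,
    target_error: float,
) -> List[float]:
    """Consume per-step generated lengths and return the error history.

    Stops as soon as the mean error reaches target_error, then asserts that the
    run started with real headroom and ended at or below the target.
    """
    history: List[float] = []
    for lengths in step_lengths:
        history.append(mean_length_error(lengths, target_len))
        if history[-1] <= target_error:
            break  # early stop: target reached
    assert history, "no training steps were run"
    assert history[0] >= min_initial_error, "task starts too easy to give a signal"
    assert history[-1] <= target_error, "target length error was never reached"
    return history
```

The initial-error assertion is the guard against the failure mode mentioned above, where the model starts so close to the optimum that training cannot show improvement.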
Runtime config and sequence-length handling
Adds an explicit Megatron runtime config singleton and removes scattered topology/packed-length mutation paths. Model max sequence length is derived from model config, while packed sequence length is treated as a runtime packing capacity rather than a model-context constraint.
This removes the annoying and misleading need to set max sequence length to packed sequence length. The singleton design also prevents subtle recompilation and throughput regressions caused by runtime topology or packed-length changes.
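One way to picture a set-once runtime config is a frozen dataclass behind a module-level accessor. This is a generic sketch, not the class added in this branch: `RuntimeConfig`, its fields, and the two accessor functions are hypothetical names chosen for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RuntimeConfig:
    """Immutable runtime settings (hypothetical fields)."""

    context_parallel_size: int
    expert_parallel_size: int
    # Packing capacity for a training batch, not a limit on model context.
    packed_sequence_length: int


_CONFIG: Optional[RuntimeConfig] = None


def set_runtime_config(config: RuntimeConfig) -> RuntimeConfig:
    """Install the config once; reject any later attempt to change it."""
    global _CONFIG
    if _CONFIG is not None and _CONFIG != config:
        raise RuntimeError("runtime config already set with different values")
    if _CONFIG is None:
        _CONFIG = config
    return _CONFIG


def get_runtime_config() -> RuntimeConfig:
    """Return the installed config, failing loudly if none was set."""
    if _CONFIG is None:
        raise RuntimeError("runtime config has not been initialized")
    return _CONFIG
```

Because the dataclass is frozen and a conflicting second `set_runtime_config` call raises, topology or packed length cannot drift mid-run, which is what would otherwise trigger recompilation.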
All model-support workflow stages passed. Full-model throughput was also measured: 22k+ tok/s, compared to 27k tok/s for Qwen 3.6 (CP4 EP4, 5k + 16x100, repeated to ~196k).