Add Gemma 4 Megatron model support #736

Draft
FurtherAI wants to merge 498 commits into main from austin/gemma_4_model_support

FurtherAI commented Jun 23, 2026

Collaborator

Summary

Adds Gemma 4 MoE support to ART's Megatron backend.

The branch also brings in the supporting Megatron/vLLM runtime infrastructure needed for this model family: sliding-window attention (SWA) masks (including under context parallelism, CP), vLLM token-id based RL tokenization, native vLLM LoRA support for Gemma4 MoE, train/inference mismatch validation, and a length trainability workflow that replaces the yes/no trainability test.

Semantic Change Groups

Gemma 4 model handler and Megatron bridge support

Adds the Gemma 4 model-support handler and registry wiring. The handler covers Gemma4-specific provider setup, layer-family discovery, proportional RoPE handling, router replay, fused expert loading, shared expert overlap, full activation recompute, and LoRA export.

This was needed because Gemma 4 differs from existing Qwen handlers in several important ways: K equals V behavior, fused expert layout, tuple rotary outputs, SWA/global attention layer mix, and bridge/runtime config expectations.

Sliding-window attention and context parallelism

Adds ART flex-attention SWA mask support and wires it through shared-prefix state and CP mask preparation. The CP path now prepares masks up front and the forward path requires the prepared mask, avoiding host-side work and accidental runtime mask construction.

This was needed so Gemma 4's SWA layers match HF/vLLM behavior while preserving the GPU-only forward path and keeping CP planning outside the hot model forward.
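As a rough illustration of what a sliding-window causal mask computes, here is a minimal pure-Python sketch. The predicate follows the `(b, h, q_idx, kv_idx)` signature that PyTorch flex attention uses for a `mask_mod`. The function names, the window size, and the convention that a query sees the most recent `window` keys (itself included) are assumptions for illustration, not ART's actual code.

```python
from functools import partial


def sliding_window_causal(b, h, q_idx, kv_idx, *, window):
    """Return True where query position q_idx may attend to key position kv_idx.

    Written with `&` rather than `and` so the same predicate would also work
    elementwise on index tensors. `window` is an assumed convention: each
    query sees at most `window` keys, counting itself.
    """
    causal = q_idx >= kv_idx               # never attend to future tokens
    in_window = (q_idx - kv_idx) < window  # only the most recent `window` keys
    return causal & in_window


def dense_mask(mask_mod, q_len, kv_len):
    """Materialize the predicate as a nested list, for inspection only."""
    return [[bool(mask_mod(0, 0, q, k)) for k in range(kv_len)] for q in range(q_len)]


# Hypothetical tiny example: 5 tokens, window of 3.
swa_mask = dense_mask(partial(sliding_window_causal, window=3), q_len=5, kv_len=5)
for row in swa_mask:
    print("".join("x" if allowed else "." for allowed in row))
```

Building the mask once from a predicate like this, ahead of the forward pass, is the general shape of "prepare masks up front". How ART combines it with shared-prefix state and CP sharding is not shown here.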

RL tokenization via vLLM token ids

Cuts the RL tokenization path over to vLLM-returned token ids. ART now requests vLLM token ids and stores the native vLLM fields on Choice.model_extra:

  • prompt_token_ids
  • token_ids

Tokenization then uses those ids directly, including append-only multi-turn collapse when prompt ids prove equivalence.

This removes fragile chat-template re-rendering for RL trajectories and makes multi-turn/tool/thinking behavior follow the actual serving prompt seen by vLLM.
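The append-only collapse can be sketched roughly as follows. This is a simplified illustration, not the ART implementation: the `Turn` class and `collapse_append_only` helper are hypothetical, and only the field names `prompt_token_ids` and `token_ids` come from the description above. The idea is that a later turn can be folded into a single sequence only if its prompt ids begin with everything seen so far.

```python
from dataclasses import dataclass


@dataclass
class Turn:
    prompt_token_ids: list  # ids vLLM reports for the rendered prompt
    token_ids: list         # ids vLLM sampled for the completion


def collapse_append_only(turns):
    """Collapse turns into (token_ids, trainable_mask), or return None.

    A turn is append-only when its prompt starts with the previous turn's
    prompt + completion. Any other shape means the prompts diverged, so the
    turns cannot be merged into one sequence.
    """
    if not turns:
        return None
    tokens = list(turns[0].prompt_token_ids)
    mask = [False] * len(tokens)  # prompt tokens are not trained on
    for i, turn in enumerate(turns):
        if i > 0:
            prompt = turn.prompt_token_ids
            if prompt[: len(tokens)] != tokens:
                return None  # prefix mismatch: not append-only
            extra = prompt[len(tokens):]  # e.g. the next user or tool message
            tokens += extra
            mask += [False] * len(extra)
        tokens += turn.token_ids
        mask += [True] * len(turn.token_ids)  # sampled tokens are trainable
    return tokens, mask


# Hypothetical two-turn conversation.
turns = [Turn([1, 2, 3], [4, 5]), Turn([1, 2, 3, 4, 5, 6, 7], [8])]
collapsed = collapse_append_only(turns)
print(collapsed)
```

When the prefix check fails, a caller would presumably fall back to treating the turns as separate sequences; that fallback is outside this sketch.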

@Kovbo This should mostly have changed only the RL path, but SFT now has its own simple tokenization path. Can you check this out?

vLLM runtime and LoRA serving

Upgrades the ART vLLM runtime to vLLM 0.23.0 and updates runtime patches for the new API surface. Adds an isolated Gemma4 MoE LoRA patch in vllm_runtime so native LoRA serving works for Gemma4 MoE until upstream support exists.

The branch also adds compact LoRA delta publishing and merged-weight transfer improvements (the LoRA is sent to vLLM, then merged and applied there). After validation, however, Gemma4 is configured to use native vLLM LoRA by default.
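For readers unfamiliar with what "merging" a LoRA means, the standard LoRA merge is W' = W + (alpha / r) * B A. The sketch below shows that arithmetic on plain nested lists. The helper names and the tiny matrices are made up for illustration, and it says nothing about how ART or vLLM actually transfer or apply the weights.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(weight, lora_a, lora_b, alpha):
    """Return weight + (alpha / rank) * (lora_b @ lora_a).

    Shapes (assumed layout): weight is out x in, lora_a is rank x in,
    lora_b is out x rank.
    """
    rank = len(lora_a)
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Hypothetical rank-1 adapter on a 2x2 identity weight.
merged = merge_lora(
    weight=[[1, 0], [0, 1]],
    lora_a=[[1, 2]],
    lora_b=[[1], [0]],
    alpha=2,
)
print(merged)
```

Serving the adapter natively, as this PR does by default for Gemma4, skips this merge and applies the low-rank update at inference time instead.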

Train/inference mismatch validation

Extends the real-path train/inference mismatch stage for Gemma4, including long prompts so that SWA is exercised, routed-expert replay, CP scoring, and native-LoRA rollout settings.
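One common way to quantify train/inference mismatch is to compare the per-token log-probabilities the trainer assigns to a rollout against those the inference engine reported when sampling it. The sketch below assumes that framing. The function name, the metrics, and the sample values are hypothetical, and the actual stage in this PR may measure something different.

```python
def logprob_mismatch(train_logprobs, inference_logprobs):
    """Summarize the per-token gap between trainer and inference log-probs."""
    if len(train_logprobs) != len(inference_logprobs):
        raise ValueError("both sides must score the same tokens")
    if not train_logprobs:
        raise ValueError("need at least one token to compare")
    diffs = [abs(t - i) for t, i in zip(train_logprobs, inference_logprobs)]
    return {
        "mean_abs_diff": sum(diffs) / len(diffs),
        "max_abs_diff": max(diffs),
    }


# Hypothetical log-probs for a two-token completion.
report = logprob_mismatch([-0.5, -1.0], [-0.5, -1.25])
print(report)
```

A check like this only means something if both sides see identical token ids, which is one reason the token-id based tokenization path matters for validation.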

Trainability workflow

Replaces the awkward yes/no trainability default with the length trainability workflow. The new test trains only on generated-length error, uses a dedicated Megatron/PipelineTrainer mode, stops early once the target error is reached, and has explicit initial/final error assertions.

This gives a cleaner trainability signal for Gemma4 and avoids relying on a prompt/task shape on which Gemma4 starts unusually high (it scored a reward of 0.9375 at step 0, leaving no training signal).
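The overall shape of a length-based trainability check might look like the sketch below. Everything in it is an assumption for illustration: the relative-error definition, the reward, the thresholds, and the simulated per-step lengths that stand in for real rollouts. It is not the workflow's actual code.

```python
TARGET_LENGTH = 100   # hypothetical target completion length, in tokens
TARGET_ERROR = 0.05   # hypothetical early-stop threshold on mean error


def length_error(num_generated_tokens, target_length):
    """Relative distance between the generated length and the target."""
    return abs(num_generated_tokens - target_length) / target_length


def length_reward(num_generated_tokens, target_length):
    """Reward peaks at 0.0 when the generated length hits the target exactly."""
    return -length_error(num_generated_tokens, target_length)


def mean_error(lengths, target_length):
    return sum(length_error(n, target_length) for n in lengths) / len(lengths)


# Simulated completion lengths per training step (stand-in for real rollouts).
simulated_steps = [[200, 40, 150], [120, 90, 110], [101, 99, 100], [100, 100, 100]]

history = []
for lengths in simulated_steps:
    history.append(mean_error(lengths, TARGET_LENGTH))
    if history[-1] <= TARGET_ERROR:
        break  # stop early once the target error is reached

# Explicit initial/final assertions: start far from target, end within it.
assert history[0] > TARGET_ERROR
assert history[-1] <= TARGET_ERROR
print(history)
```

Unlike a yes/no task that the base model may already nearly solve, a length target can be chosen so the initial error is large, which leaves room for a visible improvement.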

Runtime config and sequence-length handling

Adds an explicit Megatron runtime config singleton and removes scattered topology/packed-length mutation paths. Model max sequence length is derived from model config, while packed sequence length is treated as a runtime packing capacity rather than a model-context constraint.

This removes the misleading requirement to set max sequence length equal to packed sequence length. The singleton design also prevents subtle recompilation and throughput regressions caused by runtime topology or packed-length changes.
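A set-once runtime config can be sketched as below. This is a generic pattern, not the code in this PR: the class names, fields, and error messages are hypothetical. The point is that topology and packing capacity are fixed once and any later attempt to change them fails loudly instead of silently triggering recompilation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RuntimeConfig:
    """Hypothetical runtime settings; frozen so fields cannot be mutated."""
    context_parallel_size: int
    expert_parallel_size: int
    packed_sequence_length: int  # packing capacity, not a model-context limit


class RuntimeConfigHolder:
    """Process-wide holder that accepts exactly one configuration."""
    _config = None

    @classmethod
    def initialize(cls, config):
        if cls._config is not None and cls._config != config:
            raise RuntimeError("runtime config is already set and cannot change")
        cls._config = config
        return cls._config

    @classmethod
    def get(cls):
        if cls._config is None:
            raise RuntimeError("runtime config has not been initialized")
        return cls._config


# Hypothetical values for illustration.
RuntimeConfigHolder.initialize(
    RuntimeConfig(context_parallel_size=4, expert_parallel_size=4, packed_sequence_length=8192)
)
print(RuntimeConfigHolder.get())
```

Re-initializing with an identical config is allowed in this sketch so that repeated setup calls are harmless. Only a conflicting config raises.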

All model-support workflow stages passed. Full-model throughput was also measured: 22k+ tok/s, compared to 27k tok/s for Qwen 3.6 (CP4 EP4 5k + 16x100, repeated to ~196k).

FurtherAI added 30 commits June 18, 2026 20:44
…support

# Conflicts:
#	src/art/megatron/service.py
