[Enhance] Muon: unify comm interface & reduce peak memory#1881

Draft
nil0x9 wants to merge 3 commits into InternLM:main from nil0x9:linty/opt-muon
nil0x9 commented Jun 5, 2026

Collaborator

This PR makes the following changes to the current Muon implementation.

  • Replace the multiple, mutually exclusive process-group parameters (process_group, subgroup_process_group, world_size, device_rank, subgroup_size, subgroup_rank, use_agrs) with a single process_group plus an
    explicit comm_strategy: Literal["agrs", "subgroup_allgather", "all_to_all", "replicated"]
  • Reduce peak memory by extracting the orthogonalization logic into sub-functions (intermediate tensors lose their last reference, and can be freed, as soon as the sub-function returns) and by eagerly deleting temporaries
  • Move the existing tests from test_muon_compile.py into test_muon.py, and add independent unit tests for all 4 communication strategies
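
To make the first bullet concrete, here is a minimal sketch of what a single process_group plus an explicit comm_strategy could look like. Only the four strategy names come from this PR description; the `MuonCommConfig` container and its validation are hypothetical and are not the actual optimizer signature.

```python
from dataclasses import dataclass
from typing import Any, Literal, Optional, get_args

# The four strategy names are taken from the PR description.
CommStrategy = Literal["agrs", "subgroup_allgather", "all_to_all", "replicated"]


@dataclass(frozen=True)
class MuonCommConfig:
    """Hypothetical container; the real optimizer arguments may differ."""

    comm_strategy: CommStrategy
    # Assumption: in practice this would be a torch.distributed.ProcessGroup,
    # with None meaning the default group.
    process_group: Optional[Any] = None

    def __post_init__(self) -> None:
        # Literal annotations are not enforced at runtime, so check explicitly.
        valid = get_args(CommStrategy)
        if self.comm_strategy not in valid:
            raise ValueError(
                f"comm_strategy must be one of {valid}, got {self.comm_strategy!r}"
            )
```

With one enumerated field, an invalid combination of flags can no longer be expressed, and an unknown strategy name fails at construction time rather than inside a collective call.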
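
The memory argument in the second bullet can be illustrated without any tensors. The sketch below is not code from this PR: `Buffer` is a hypothetical stand-in for a large intermediate tensor, and the behaviour shown assumes CPython's reference counting, where an object is freed as soon as its last reference disappears.

```python
import weakref


class Buffer:
    """Hypothetical stand-in for a large intermediate tensor."""

    def __init__(self, n):
        self.data = [0.0] * n


live = []  # weak references to every intermediate created so far


def make_intermediate(n):
    buf = Buffer(n)
    live.append(weakref.ref(buf))
    return buf


def count_alive():
    # A weak reference returns None once its target has been freed.
    return sum(ref() is not None for ref in live)


def orthogonalize_helper(n):
    # The intermediate is local to this helper, so its last reference
    # goes away when the helper returns.
    tmp = make_intermediate(n)
    return len(tmp.data)


def step_inline(n):
    tmp = make_intermediate(n)
    result = len(tmp.data)
    # tmp is still referenced here, so it stays allocated until the step ends.
    return result, count_alive()


def step_with_helper(n):
    result = orthogonalize_helper(n)
    return result, count_alive()


def step_with_del(n):
    tmp = make_intermediate(n)
    result = len(tmp.data)
    del tmp  # eager deletion drops the last reference immediately
    return result, count_alive()
```

On CPython, `step_inline` still reports one live intermediate at the point where it measures, while `step_with_helper` and `step_with_del` report none. The same reasoning applies to tensors held in local variables: a temporary that outlives its usefulness inside a long function contributes to peak memory until the name is rebound, deleted, or the function returns.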
