Fix failing SKRL multi-gpu training error#6035
Conversation
There was a problem hiding this comment.
🤖 IsaacLab Review Bot — PR #6035
Title: Fix failing SKRL multi-gpu training error
Author: @StafaH
Actual change: 1 commit (5eb949a2) — 2 files touched
Summary
This PR fixes a JAX import error in SKRL's multi-GPU/multi-host training path. On recent JAX versions, jax.experimental.multihost_utils must be explicitly imported before skrl's distributed models reference it. The fix adds a preemptive import of this submodule in the SkrlVecEnvWrapper function when the JAX ML framework is selected.
Files Changed (actual fix only)
| File | Change |
|---|---|
source/isaaclab_rl/isaaclab_rl/skrl.py |
Added import jax.experimental.multihost_utils preload |
source/isaaclab_rl/changelog.d/fix-skrl-jax-multihost-import.rst |
Changelog fragment |
Analysis
Correctness: ✅ Looks correct
- The fix adds
import jax.experimental.multihost_utils # noqa: F401before importingskrl.envs.wrappers.jax.wrap_env - This is the right placement — preloading the submodule before skrl attempts to use it ensures the module is available in
sys.modules - The
# noqa: F401comment correctly suppresses the "imported but unused" lint warning since this is a side-effect import
Observations:
-
Minimal and targeted — The fix is a single import line, narrowly scoped to the JAX code path. No risk to torch/warp paths.
-
Root cause is upstream — The actual bug is in skrl's distributed models referencing
jax.experimental.multihost_utilswithout importing it. This workaround is a reasonable IsaacLab-side fix while waiting for an upstream skrl fix. -
No version guard — The comment says "broken on recent JAX" but there's no version check. This is fine since the import is harmless on older JAX versions where the module already loads transitively.
-
Changelog fragment is well-written — Clearly describes the issue and references the relevant function.
Suggestions (non-blocking)
- Consider adding a brief comment or link to the skrl issue/version where this broke, to help future maintainers know when this workaround can be removed.
- The PR description checklist is entirely unchecked — minor housekeeping gap.
Verdict
✅ LGTM — Clean, minimal, well-placed fix for a real multi-GPU training failure. No concerns.
0 blocking findings · 1 suggestion
There was a problem hiding this comment.
🤖 IsaacLab Review Bot — PR #6035 (Updated)
Title: Fix failing SKRL multi-gpu training error
Author: @StafaH
Branch: mh/qa_fix_skrl → develop
Incremental Review Update (commit 02325a8)
New commit: 02325a8 — Merge branch 'develop' into mh/qa_fix_skrl
This is a merge commit to sync the PR branch with develop. No new feature changes or code related to the SKRL fix were introduced.
✅ Previous review findings remain unchanged — the core JAX import fix and supporting changes in commit 5eb949a are still valid.
Original Review Summary (commit 5eb949a)
This PR addresses a JAX import error in SKRL multi-GPU training:
- Core fix: Pre-imports
jax.experimental.multihost_utilsbefore skrl's JAX wrapper to preventModuleNotFoundError - New
train_multigpu.pylauncher: Unified CLI for multi-GPU training across rl_games, rsl_rl, and skrl - Warp framework support: Adds
"warp"as a validml_frameworkoption - SKRL config modernization: Renames
state_preprocessor→observation_preprocessor - Bug fix: Adds missing
raisekeyword beforeValueError - Test modernization & lazy imports
Outstanding Items from Previous Review
⚠️ Medium: Add changelog coverage for thestate_preprocessor→observation_preprocessorrename- 💡 Suggestion: Consider adding a quick existence check for
TRAIN_SCRIPTpath intrain_multigpu.py - 📝 PR hygiene: PR checklist is unchecked; ensure merge conflicts are resolved
Reviewed commit: 02325a8 (merge commit from develop)
Original review: 5eb949a
## Summary - Cherry-pick #6035 from develop into release/3.0.0-beta2 - Preload jax.experimental.multihost_utils before importing skrl's JAX wrapper - Add the isaaclab_rl changelog fragment ## Validation - git diff --check refs/remotes/upstream/release/3.0.0-beta2...HEAD - python3 -m py_compile source/isaaclab_rl/isaaclab_rl/skrl.py Co-authored-by: Mustafa H <34825877+StafaH@users.noreply.github.com>
Description
Fixes a skrl error where jax experimental is not imported and must be imported manually
Checklist
pre-commitchecks with./isaaclab.sh --formatconfig/extension.tomlfileCONTRIBUTORS.mdor my name already exists there