zero3: defer param release during retain_graph backward #7352 #8045
nathon-lee wants to merge 21 commits into
Force-pushed from 60938ff to 5bc41ca.
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
fix: add ZeRO-3 second backward after retain_graph=True fails with tensor size mismatch
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
fix: Stage 3 Temporarily change the exemption from xfail to skip (for this test case only)
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
fix: Fix ZeRO-3 behavior for two separate backward passes on the same forward graph.
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
Force-pushed from 5bc41ca to b41bb4c.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: b41bb4cc1b
    retain_graph_backward = bool(zero_optimizer is not None
                                 and getattr(zero_optimizer, "retain_graph_on_current_backward", False))
    if not retain_graph_backward:
Preserve retained params for manual backward
When users take the supported torch-style path engine.scale(loss).backward(retain_graph=True), engine.backward() is bypassed, so retain_graph_on_current_backward is never set before these hooks run. In that scenario this condition evaluates false and ZeRO-3 still releases the gathered parameters after the first backward, leaving the retained graph with partitioned saved tensors for the next backward over the same forward.
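The guard under review can be read in isolation as the sketch below. It is a simplified stand-in, not the DeepSpeed hook itself: `should_release_params` and `FakeZeroOptimizer` are hypothetical names, and only the attribute name `retain_graph_on_current_backward` comes from the diff. It illustrates why the manual path still releases parameters: if nothing sets the flag, the `getattr` default of `False` applies and release goes ahead.

```python
class FakeZeroOptimizer:
    """Hypothetical stand-in for a ZeRO optimizer; only the flag matters here."""

    def __init__(self, retain_graph_on_current_backward=False):
        self.retain_graph_on_current_backward = retain_graph_on_current_backward


def should_release_params(zero_optimizer):
    """Mirror the guard from the diff: release unless a retain_graph backward is flagged."""
    retain_graph_backward = bool(
        zero_optimizer is not None
        and getattr(zero_optimizer, "retain_graph_on_current_backward", False)
    )
    return not retain_graph_backward


# Flag never set (the manual-backward scenario described above): params are released.
release_when_unset = should_release_params(FakeZeroOptimizer())
# Flag set before the hooks run: release is deferred.
release_when_set = should_release_params(FakeZeroOptimizer(True))
```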
Good catch, thanks. You're right that the original fix only set retain_graph_on_current_backward in engine.backward() / ZeROOptimizer.backward(), so the torch-style manual path engine.scale(loss).backward(retain_graph=True) bypassed it and ZeRO-3 still released the gathered params after the first backward.
Since autograd hooks can't see the user's retain_graph argument on a direct .backward() call, I propagate it explicitly through engine.scale(): it now accepts a retain_graph argument that sets the flag on the manual path. The flag is reset in the shared _backward_epilogue(), which both the engine.backward() and manual paths reach after the gradient hooks run.
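The flag lifecycle described here can be modeled with a small pure-Python sketch. `ToyEngine`, `run_backward`, `_grad_hook`, and the `released` list are hypothetical and exist only for illustration; the sketch assumes nothing about DeepSpeed internals beyond what this reply states (a `retain_graph` argument on `scale()`, a flag read by the hooks, and a reset in a shared epilogue).

```python
class ToyEngine:
    """Hypothetical model of the flag lifecycle; not DeepSpeed's engine."""

    def __init__(self):
        self.retain_graph_on_current_backward = False
        self.released = []  # one entry per backward: True if params were released

    def scale(self, loss, retain_graph=False):
        # Manual path: the caller declares retain_graph up front, because
        # hooks cannot observe the argument passed to a direct .backward().
        self.retain_graph_on_current_backward = bool(retain_graph)
        return loss

    def _grad_hook(self):
        # Stand-in for the ZeRO-3 hook: release only when the flag is unset.
        self.released.append(not self.retain_graph_on_current_backward)

    def _backward_epilogue(self):
        # Shared epilogue: clear the flag so it cannot leak into the next backward.
        self.retain_graph_on_current_backward = False

    def run_backward(self, hook_fails=False):
        # Simulates one backward pass; the reset runs even if a hook raises.
        try:
            if hook_fails:
                raise RuntimeError("simulated hook failure")
            self._grad_hook()
        finally:
            self._backward_epilogue()


engine = ToyEngine()
engine.scale(1.0, retain_graph=True)
engine.run_backward()   # first backward: release deferred
engine.scale(1.0)
engine.run_backward()   # second backward: normal release
```

Resetting in one shared place, rather than in each entry point, is what keeps the two backward paths from drifting apart in this model.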
Added a regression test test_two_losses_separate_manual_backward_gas1 covering the manual path for ZeRO stages 1/2/3 (two separate backwards over one forward with zero_grad() in between). Both it and the existing test_two_losses_separate_backward_gas1 pass (3 passed each).
Signed-off-by: nathon-lee <leejianwoo@gmail.com>
Force-pushed from 8ae5113 to 5c75f99.
Summary
Fix ZeRO-3 so two separate backward passes on the same forward graph work correctly when retain_graph=True is used on the first backward.
What changed
- […] retain_graph through the ZeRO backward path.
- […] finally to avoid state leakage.
Why
This fixes the ZeRO-3 failure where the second backward on the same forward graph hit a tensor size mismatch after zero_grad(). The regression is tracked by issue #7352.
Validation
Test command
CUDA_VISIBLE_DEVICES=0,1 DS_DISABLE_REUSE_DIST_ENV=1 NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 TORCH_NCCL_ASYNC_ERROR_HANDLING=1 TORCH_DISTRIBUTED_DEBUG=DETAIL DS_UNITTEST_TIMEOUT=120 pytest test_zero_user_backward.py -k "test_two_losses_separate_backward_gas1" -vv -s -rs
Result
3 passed
0 skipped
55 deselected
Notes
The test previously reproduced a ZeRO-3 RuntimeError on the second backward.
The current run confirms the stage 1/2/3 regression coverage is passing.