zero3: defer param release during retain_graph backward #7352 (#8045)

Open
nathon-lee wants to merge 21 commits into deepspeedai:master from nathon-lee:test/zero-multi-loss-separate-backward-7352

Contributor

@nathon-lee nathon-lee commented Jun 3, 2026

Summary

Fix ZeRO-3 so two separate backward passes on the same forward graph work correctly when retain_graph=True is used on the first backward.

What changed

  • Propagate retain_graph through the ZeRO backward path.
  • Delay ZeRO-3 parameter release during retained backward so saved tensors remain valid for the second backward.
  • Clear the retained-backward flag safely in finally to avoid state leakage.
  • Re-enable ZeRO-3 coverage for the two-loss separate-backward regression test.
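
The flag handling described in the list above can be sketched in plain Python. This is a toy model, not the DeepSpeed implementation: `ToyZeroOptimizer`, `on_grad_ready`, `fake_autograd`, and `failing_autograd` are hypothetical names, and only the flag name `retain_graph_on_current_backward` comes from this PR. It shows the flag being set for one backward, consulted by a simulated hook, and cleared in `finally` even if backward raises.

```python
class ToyZeroOptimizer:
    """Hypothetical stand-in for a ZeRO-3 optimizer; not DeepSpeed code."""

    def __init__(self):
        self.retain_graph_on_current_backward = False
        self.released = []  # params whose gathered copy was freed
        self.deferred = []  # params kept gathered for a later backward

    def on_grad_ready(self, name):
        # Simulated post-backward hook: defer the release while the flag is set.
        if self.retain_graph_on_current_backward:
            self.deferred.append(name)
        else:
            self.released.append(name)

    def backward(self, run_autograd, retain_graph=False):
        self.retain_graph_on_current_backward = bool(retain_graph)
        try:
            run_autograd(self)
        finally:
            # Always clear the flag so it cannot leak into the next backward.
            self.retain_graph_on_current_backward = False


def fake_autograd(opt):
    # Assumption: two parameters whose hooks fire during backward.
    for name in ("w1", "w2"):
        opt.on_grad_ready(name)


def failing_autograd(opt):
    opt.on_grad_ready("w1")
    raise RuntimeError("simulated failure inside backward")


opt = ToyZeroOptimizer()
opt.backward(fake_autograd, retain_graph=True)  # first backward: release deferred
opt.backward(fake_autograd)                     # second backward: normal release
```

Using `finally` rather than resetting the flag at the end of the happy path means an exception raised inside backward does not leave the optimizer stuck in the deferred-release state.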

Why

This fixes the ZeRO-3 failure where the second backward on the same forward graph hit a tensor size mismatch after zero_grad(). The regression is tracked by issue #7352.
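
To make the failure mode concrete, here is a framework-free toy model of it. It is neither PyTorch nor DeepSpeed code: `ToyParam`, `ToyNode`, and `two_backwards` are hypothetical, and the "release" step simply empties the parameter's data to mimic a partitioned placeholder. The sketch assumes the retained graph holds a reference to the parameter, so releasing it between the two backward passes leaves the graph with data of the wrong size.

```python
class ToyParam:
    """Hypothetical parameter that can be gathered (full) or released (empty)."""

    def __init__(self, values):
        self.data = list(values)

    def release(self):
        # Mimics ZeRO-3 replacing the gathered tensor with a placeholder.
        self.data = []


class ToyNode:
    """Toy graph node for y = sum(w_i * x_i); keeps a reference to the param."""

    def __init__(self, param, x):
        self.param = param
        self.x = list(x)

    def backward(self, grad_out):
        if len(self.param.data) != len(self.x):
            raise RuntimeError(
                f"size mismatch: expected {len(self.x)}, got {len(self.param.data)}"
            )
        # Gradient with respect to x is grad_out * w.
        return [grad_out * w for w in self.param.data]


def two_backwards(defer_release):
    param = ToyParam([1.0, 2.0, 3.0])
    node = ToyNode(param, [0.5, 0.5, 0.5])
    first = node.backward(1.0)
    if not defer_release:
        param.release()  # old behavior: free params after the first backward
    second = node.backward(2.0)
    return first, second
```

With `defer_release=True` both passes succeed; with `defer_release=False` the second pass raises a `RuntimeError`, which is the shape of the problem this PR addresses.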

Validation

Test command

CUDA_VISIBLE_DEVICES=0,1 DS_DISABLE_REUSE_DIST_ENV=1 NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 TORCH_NCCL_ASYNC_ERROR_HANDLING=1 TORCH_DISTRIBUTED_DEBUG=DETAIL DS_UNITTEST_TIMEOUT=120 pytest test_zero_user_backward.py -k "test_two_losses_separate_backward_gas1" -vv -s -rs

Result

3 passed
0 skipped
58 deselected

root@964c299dd1b5:/workspace/DeepSpeed_woo/tests# CUDA_VISIBLE_DEVICES=0,1 DS_DISABLE_REUSE_DIST_ENV=1 NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 TORCH_NCCL_ASYNC_ERROR_HANDLING=1 TORCH_DISTRIBUTED_DEBUG=DETAIL DS_UNITTEST_TIMEOUT=120 pytest unit/v1/zero/test_zero_user_backward.py -k "test_two_losses_separate_manual_backward_gas1" -vv -s -rs
================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /usr/bin/python3.12
cachedir: .pytest_cache
rootdir: /workspace/DeepSpeed_woo/tests
configfile: pytest.ini
plugins: anyio-4.12.0
collected 61 items / 58 deselected / 3 selected                                                                                                          

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[1] [Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:29:17.335227904 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[W603 07:29:18.304862233 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
PASSED
unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[2] [Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:29:30.820676734 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[W603 07:29:31.829038701 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
PASSED
unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[3] [Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:29:42.066717923 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
PASSED

==================================================================== warnings summary ====================================================================
<string>:8
  <string>:8: PytestDeprecationWarning: A private pytest class or function was used.

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[1]
  /workspace/DeepSpeed_woo/tests/conftest.py:47: UserWarning: Running test without verifying torch version, please provide an expected torch version with --torch_ver
    warnings.warn(

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[1]
  /workspace/DeepSpeed_woo/tests/conftest.py:54: UserWarning: Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================================== slowest durations ====================================================================
12.53s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[2]
12.37s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[1]
12.00s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_manual_backward_gas1[3]

(6 durations < 1s hidden.)
===================================================== 3 passed, 58 deselected, 3 warnings in 46.71s ======================================================
root@964c299dd1b5:/workspace/DeepSpeed_woo/tests# CUDA_VISIBLE_DEVICES=0,1 DS_DISABLE_REUSE_DIST_ENV=1 NCCL_IB_DISABLE=1 NCCL_P2P_DISABLE=1 TORCH_NCCL_ASYNC_ERROR_HANDLING=1 TORCH_DISTRIBUTED_DEBUG=DETAIL DS_UNITTEST_TIMEOUT=120 pytest unit/v1/zero/test_zero_user_backward.py -k "test_two_losses_separate_backward_gas1" -vv -s -rs
================================================================== test session starts ===================================================================
platform linux -- Python 3.12.3, pytest-9.0.3, pluggy-1.6.0 -- /usr/bin/python3.12
cachedir: .pytest_cache
rootdir: /workspace/DeepSpeed_woo/tests
configfile: pytest.ini
plugins: anyio-4.12.0
collected 61 items / 58 deselected / 3 selected                                                                                                          

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[1] [Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:30:49.080985243 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[W603 07:30:50.032584977 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
PASSED
unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[2] [Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:31:01.285483828 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[W603 07:31:02.257472436 ProcessGroupNCCL.cpp:1524] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
PASSED
unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[3] [Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[rank0]:[W603 07:31:14.971748647 ProcessGroupNCCL.cpp:5072] Guessing device ID based on global rank. This can cause a hang if rank to GPU mapping is heterogeneous. You can specify device_id in init_process_group()
[Gloo] Rank 1 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
[Gloo] Rank 0 is connected to 1 peer ranks. Expected number of connected peer ranks is : 1
PASSED

==================================================================== warnings summary ====================================================================
<string>:8
  <string>:8: PytestDeprecationWarning: A private pytest class or function was used.

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[1]
  /workspace/DeepSpeed_woo/tests/conftest.py:47: UserWarning: Running test without verifying torch version, please provide an expected torch version with --torch_ver
    warnings.warn(

unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[1]
  /workspace/DeepSpeed_woo/tests/conftest.py:54: UserWarning: Running test without verifying cuda version, please provide an expected cuda version with --cuda_ver
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================================== slowest durations ====================================================================
12.47s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[3]
12.23s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[2]
12.09s call     unit/v1/zero/test_zero_user_backward.py::TestZeroUserBackwardSeparateLoss::test_two_losses_separate_backward_gas1[1]

(6 durations < 1s hidden.)
===================================================== 3 passed, 58 deselected, 3 warnings in 44.97s ======================================================
root@964c299dd1b5:/workspace/DeepSpeed_woo/tests# 

Notes

The test previously reproduced a ZeRO-3 RuntimeError on the second backward.
The current run confirms the stage 1/2/3 regression coverage is passing.

@nathon-lee nathon-lee marked this pull request as ready for review June 3, 2026 06:55
@nathon-lee nathon-lee changed the title from "tests: temporarily skip ZeRO-3 in two-loss separate-backward regression" to "tests: adds a regression test for the behavior reported in issue #7352" Jun 3, 2026
@nathon-lee nathon-lee marked this pull request as draft June 3, 2026 06:59
@nathon-lee nathon-lee force-pushed the test/zero-multi-loss-separate-backward-7352 branch 2 times, most recently from 60938ff to 5bc41ca Compare June 3, 2026 07:13
Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: add ZeRO-3 second backward after retain_graph=True fails with tensor size mismatch

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: Stage 3 Temporarily change the exemption from xfail to skip (for this test case only)

Signed-off-by: nathon-lee <leejianwoo@gmail.com>

fix: Fix ZeRO-3 behavior for two separate backward passes on the same forward graph.

Signed-off-by: nathon-lee <leejianwoo@gmail.com>
@nathon-lee nathon-lee force-pushed the test/zero-multi-loss-separate-backward-7352 branch from 5bc41ca to b41bb4c Compare June 3, 2026 07:14
@nathon-lee nathon-lee changed the title from "tests: adds a regression test for the behavior reported in issue #7352" to "zero3: defer param release during retain_graph backward #7352" Jun 3, 2026
@nathon-lee nathon-lee marked this pull request as ready for review June 3, 2026 07:19
@chatgpt-codex-connector chatgpt-codex-connector Bot left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: b41bb4cc1b

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

Comment on lines +558 to +560
retain_graph_backward = bool(zero_optimizer is not None
                             and getattr(zero_optimizer, "retain_graph_on_current_backward", False))
if not retain_graph_backward:
P2: Preserve retained params for manual backward

When users take the supported torch-style path engine.scale(loss).backward(retain_graph=True), engine.backward() is bypassed, so retain_graph_on_current_backward is never set before these hooks run. In that scenario this condition evaluates false and ZeRO-3 still releases the gathered parameters after the first backward, leaving the retained graph with partitioned saved tensors for the next backward over the same forward.

Useful? React with 👍 / 👎.

Contributor Author

@nathon-lee nathon-lee Jun 3, 2026

Good catch, thanks. You're right that the original fix only set retain_graph_on_current_backward in engine.backward() / ZeROOptimizer.backward(), so the torch-style manual path engine.scale(loss).backward(retain_graph=True) bypassed it and ZeRO-3 still released the gathered params after the first backward.

Since autograd hooks can't see the user's retain_graph argument on a direct .backward() call, I propagate it explicitly through engine.scale(): it now accepts a retain_graph argument that sets the flag on the manual path. The flag is reset in the shared _backward_epilogue(), which both the engine.backward() and manual paths reach after the gradient hooks run.
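
A rough sketch of that manual-path flow, using toy classes rather than the real engine: `ToyEngine`, `ToyLoss`, `_grad_hook`, and the `events` list are hypothetical, and the exact `scale()` signature here is an assumption. Only the idea (flag set by `scale()`, read by the hooks, reset in `_backward_epilogue()`) is taken from the description above.

```python
class ToyEngine:
    """Hypothetical engine; models only the retain-graph flag."""

    def __init__(self):
        self.retain_graph_on_current_backward = False
        self.events = []  # (param name, "keep" or "release") per hook call

    def scale(self, loss, retain_graph=False):
        # The manual path has no engine.backward(), so the flag is set here.
        self.retain_graph_on_current_backward = bool(retain_graph)
        return loss

    def _grad_hook(self, name):
        action = "keep" if self.retain_graph_on_current_backward else "release"
        self.events.append((name, action))

    def _backward_epilogue(self):
        # Shared reset point, reached after the gradient hooks have run.
        self.retain_graph_on_current_backward = False


class ToyLoss:
    """Hypothetical loss whose backward() fires the engine's hooks."""

    def __init__(self, engine, param_names):
        self.engine = engine
        self.param_names = param_names

    def backward(self, retain_graph=False):
        # As with real autograd hooks, retain_graph is not visible to the hooks.
        for name in self.param_names:
            self.engine._grad_hook(name)
        self.engine._backward_epilogue()


engine = ToyEngine()
loss = ToyLoss(engine, ["w"])
engine.scale(loss, retain_graph=True).backward(retain_graph=True)  # params kept
engine.scale(loss).backward()                                      # params released
```

In this toy, passing `retain_graph=True` only to `.backward()` and not to `scale()` would still record a release, which mirrors the gap the review comment pointed out.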

Added a regression test test_two_losses_separate_manual_backward_gas1 covering the manual path for ZeRO stages 1/2/3 (two separate backwards over one forward with zero_grad() in between). Both it and the existing test_two_losses_separate_backward_gas1 pass (3 passed each).

Signed-off-by: nathon-lee <leejianwoo@gmail.com>
@nathon-lee nathon-lee requested a review from hwchen2017 as a code owner June 3, 2026 08:16
@nathon-lee nathon-lee force-pushed the test/zero-multi-loss-separate-backward-7352 branch from 8ae5113 to 5c75f99 Compare June 3, 2026 08:20