zero3: invalidate coordinator trace on hook re-registration#8043
Open
roycho96 wants to merge 3 commits into
Open
zero3: invalidate coordinator trace on hook re-registration#8043roycho96 wants to merge 3 commits into
roycho96 wants to merge 3 commits into
Conversation
Signed-off-by: Sung Hyun Cho <[email protected]>
Signed-off-by: Sung Hyun Cho <[email protected]>
Signed-off-by: Sung Hyun Cho <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Re-registering ZeRO-3 module hooks after they were removed (e.g. via
unwrap_model_for_generation) leaves the param coordinator's recorded trace stale. The next training forward raisesIndexError: pop from an empty dequefrom_start_of_forward_hook -> reset_step -> record_parameters -> popleft.Repro
DeepSpeed master, torch 2.8.0+cu128, transformers, peft. Single GPU.
Run with
torchrun --nproc-per-node=1 repro.py. Second iteration raises the IndexError.Fix
Two small edits in
deepspeed/runtime/zero/:parameter_offload.py::_register_deepspeed_module: when the root module is re-registered, invalidate the coordinator trace so the next forward re-records cleanly.partitioned_param_coordinator.py::_clear_trace_structures: also clear__step_id_module_fetched_for, which was being left populated and caused the empty-deque pop.Both guards are no-ops on initial registration (trace is already INVALID) and on non-root submodule walks.
Test
tests/unit/runtime/zero/test_unwrap_model.py::TestUnwrapModelTraceInvalidatecovers the path: run one training step, wrap withunwrap_model_for_generation, assert the coordinator returns to INVALID. World size 2.