From e3c62d078a17c51b48e1679d31c98f322c7544f4 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Tue, 2 Jun 2026 22:23:21 +0900 Subject: [PATCH 1/3] zero3: invalidate coordinator trace on hook re-registration Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/parameter_offload.py | 7 +++++++ deepspeed/runtime/zero/partitioned_param_coordinator.py | 3 +++ 2 files changed, 10 insertions(+) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index aba0cde6266d..f30ade650c2e 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -305,6 +305,13 @@ def mark_persistent_parameters(self, param_threshold, model_threshold): return persistent_params def _register_deepspeed_module(self, module, count=[0]): + # re-registering hooks on the root module (e.g. after an out-of-band generate that + # ran with hooks removed) leaves the coordinator trace stale; invalidate so it + # re-records on the next forward. + if module is self.module: + coordinator = self.get_param_coordinator() + if coordinator is not None and not coordinator.is_invalid_trace(): + coordinator._invalidate_trace() my_count = count[0] module.ds_id = my_count diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 4877b44c8934..7d9476a35521 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -160,6 +160,9 @@ def _clear_trace_structures(self) -> None: self.__submodule_order = [] self.__param_order = [] self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) + # clear the fetch-step deque too; a stale entry here causes record_parameters() to + # pop an empty deque (IndexError) after trace invalidation. + self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) self.__param_queue = None def is_complete_trace(self) -> bool: From da50b98fb548def0a2461010bf2c4dcd99587a7f Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Tue, 2 Jun 2026 22:32:07 +0900 Subject: [PATCH 2/3] cover trace invalidation on hook re-registration Signed-off-by: Sung Hyun Cho --- tests/unit/runtime/zero/test_unwrap_model.py | 27 +++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py index 8a7bde215301..528c9aa6d75b 100644 --- a/tests/unit/runtime/zero/test_unwrap_model.py +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -3,11 +3,13 @@ # DeepSpeed Team +import torch + import deepspeed from deepspeed.runtime.zero import unwrap_model_for_generation from deepspeed.accelerator import get_accelerator -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel config = { @@ -65,3 +67,26 @@ def hooks_exist(engine): # assert hooks assert hooks_exist(engine) + + +class TestUnwrapModelTraceInvalidate(DistributedTest): + # re-registering hooks on the root module (e.g. via unwrap_model_for_generation around + # an on-policy generate) must leave the coordinator's recorded trace invalidated so + # the next training forward re-records cleanly instead of popping a stale deque. + world_size = 2 + + def test(self): + model = SimpleModel(hidden_dim=100) + engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) + coordinator = engine.optimizer.parameter_offload.get_param_coordinator() + + # run one step so the coordinator records a trace (RECORD -> COMPLETE) + x = torch.randn(2, 100, device=engine.device, dtype=preferred_dtype()) + y = torch.empty(2, dtype=torch.long, device=engine.device).random_(100) + engine(x, y) + assert not coordinator.is_invalid_trace() + + # the wrap cycle around an out-of-band forward must invalidate the recorded trace + with unwrap_model_for_generation(engine): + pass + assert coordinator.is_invalid_trace() From 86cf38323b7667e8d41f1cc1c4cce15a55a44436 Mon Sep 17 00:00:00 2001 From: Sung Hyun Cho Date: Tue, 2 Jun 2026 22:41:11 +0900 Subject: [PATCH 3/3] trim comments Signed-off-by: Sung Hyun Cho --- deepspeed/runtime/zero/parameter_offload.py | 5 ++--- tests/unit/runtime/zero/test_unwrap_model.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index f30ade650c2e..1edd666e532d 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -305,9 +305,8 @@ def mark_persistent_parameters(self, param_threshold, model_threshold): return persistent_params def _register_deepspeed_module(self, module, count=[0]): - # re-registering hooks on the root module (e.g. after an out-of-band generate that - # ran with hooks removed) leaves the coordinator trace stale; invalidate so it - # re-records on the next forward. + # re-registering hooks on the root module leaves the coordinator trace stale; + # invalidate so it re-records on the next forward. if module is self.module: coordinator = self.get_param_coordinator() if coordinator is not None and not coordinator.is_invalid_trace(): diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py index 528c9aa6d75b..ed20661b8172 100644 --- a/tests/unit/runtime/zero/test_unwrap_model.py +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -70,9 +70,8 @@ def hooks_exist(engine): class TestUnwrapModelTraceInvalidate(DistributedTest): - # re-registering hooks on the root module (e.g. via unwrap_model_for_generation around - # an on-policy generate) must leave the coordinator's recorded trace invalidated so - # the next training forward re-records cleanly instead of popping a stale deque. + # unwrap_model_for_generation removes and re-registers the ZeRO-3 hooks; the + # coordinator's recorded trace must be invalidated so the next forward re-records. world_size = 2 def test(self):