diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py index aba0cde6266d..1edd666e532d 100644 --- a/deepspeed/runtime/zero/parameter_offload.py +++ b/deepspeed/runtime/zero/parameter_offload.py @@ -305,6 +305,12 @@ def mark_persistent_parameters(self, param_threshold, model_threshold): return persistent_params def _register_deepspeed_module(self, module, count=[0]): + # re-registering hooks on the root module leaves the coordinator trace stale; + # invalidate so it re-records on the next forward. + if module is self.module: + coordinator = self.get_param_coordinator() + if coordinator is not None and not coordinator.is_invalid_trace(): + coordinator._invalidate_trace() my_count = count[0] module.ds_id = my_count diff --git a/deepspeed/runtime/zero/partitioned_param_coordinator.py b/deepspeed/runtime/zero/partitioned_param_coordinator.py index 4877b44c8934..7d9476a35521 100644 --- a/deepspeed/runtime/zero/partitioned_param_coordinator.py +++ b/deepspeed/runtime/zero/partitioned_param_coordinator.py @@ -160,6 +160,9 @@ def _clear_trace_structures(self) -> None: self.__submodule_order = [] self.__param_order = [] self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10)) + # clear the fetch-step deque too; a stale entry here causes record_parameters() to + # pop an empty deque (IndexError) after trace invalidation. 
+ self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque()) self.__param_queue = None def is_complete_trace(self) -> bool: diff --git a/tests/unit/runtime/zero/test_unwrap_model.py b/tests/unit/runtime/zero/test_unwrap_model.py index 8a7bde215301..ed20661b8172 100644 --- a/tests/unit/runtime/zero/test_unwrap_model.py +++ b/tests/unit/runtime/zero/test_unwrap_model.py @@ -3,11 +3,13 @@ # DeepSpeed Team +import torch + import deepspeed from deepspeed.runtime.zero import unwrap_model_for_generation from deepspeed.accelerator import get_accelerator -from unit.common import DistributedTest +from unit.common import DistributedTest, preferred_dtype from unit.simple_model import SimpleModel config = { @@ -65,3 +67,25 @@ def hooks_exist(engine): # assert hooks assert hooks_exist(engine) + + +class TestUnwrapModelTraceInvalidate(DistributedTest): + # unwrap_model_for_generation removes and re-registers the ZeRO-3 hooks; the + # coordinator's recorded trace must be invalidated so the next forward re-records. + world_size = 2 + + def test(self): + model = SimpleModel(hidden_dim=100) + engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config) + coordinator = engine.optimizer.parameter_offload.get_param_coordinator() + + # run one forward pass so the coordinator records a trace (RECORD -> COMPLETE) + x = torch.randn(2, 100, device=engine.device, dtype=preferred_dtype()) + y = torch.empty(2, dtype=torch.long, device=engine.device).random_(100) + engine(x, y) + assert not coordinator.is_invalid_trace() + + # the wrap cycle around an out-of-band forward must invalidate the recorded trace + with unwrap_model_for_generation(engine): + pass + assert coordinator.is_invalid_trace()