Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ def mark_persistent_parameters(self, param_threshold, model_threshold):
return persistent_params

def _register_deepspeed_module(self, module, count=[0]):
# re-registering hooks on the root module leaves the coordinator trace stale;
# invalidate so it re-records on the next forward.
if module is self.module:
coordinator = self.get_param_coordinator()
if coordinator is not None and not coordinator.is_invalid_trace():
coordinator._invalidate_trace()
my_count = count[0]
module.ds_id = my_count

Expand Down
3 changes: 3 additions & 0 deletions deepspeed/runtime/zero/partitioned_param_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def _clear_trace_structures(self) -> None:
self.__submodule_order = []
self.__param_order = []
self.__most_recent_step_id_param_fetched_for = collections.defaultdict(lambda: int(-1e10))
# clear the fetch-step deque too; a stale entry here causes record_parameters() to
# pop an empty deque (IndexError) after trace invalidation.
self.__step_id_module_fetched_for = collections.defaultdict(lambda: collections.deque())
self.__param_queue = None

def is_complete_trace(self) -> bool:
Expand Down
26 changes: 25 additions & 1 deletion tests/unit/runtime/zero/test_unwrap_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@

# DeepSpeed Team

import torch

import deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from deepspeed.accelerator import get_accelerator

from unit.common import DistributedTest
from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import SimpleModel

config = {
Expand Down Expand Up @@ -65,3 +67,25 @@ def hooks_exist(engine):

# assert hooks
assert hooks_exist(engine)


class TestUnwrapModelTraceInvalidate(DistributedTest):
# unwrap_model_for_generation removes and re-registers the ZeRO-3 hooks; the
# coordinator's recorded trace must be invalidated so the next forward re-records.
world_size = 2

def test(self):
model = SimpleModel(hidden_dim=100)
engine, _, _, _ = deepspeed.initialize(args=None, model=model, config=config)
coordinator = engine.optimizer.parameter_offload.get_param_coordinator()

# run one step so the coordinator records a trace (RECORD -> COMPLETE)
x = torch.randn(2, 100, device=engine.device, dtype=preferred_dtype())
y = torch.empty(2, dtype=torch.long, device=engine.device).random_(100)
engine(x, y)
assert not coordinator.is_invalid_trace()

# the wrap cycle around an out-of-band forward must invalidate the recorded trace
with unwrap_model_for_generation(engine):
pass
assert coordinator.is_invalid_trace()
Loading