Thanks to visit codestin.com
Credit goes to github.com

Skip to content

zero3: invalidate coordinator trace on hook re-registration#8043

Open
roycho96 wants to merge 3 commits into
deepspeedai:masterfrom
roycho96:fix/zero3-hook-cycle-trace-invalidate
Open

zero3: invalidate coordinator trace on hook re-registration#8043
roycho96 wants to merge 3 commits into
deepspeedai:masterfrom
roycho96:fix/zero3-hook-cycle-trace-invalidate

Conversation

@roycho96
Copy link
Copy Markdown
Contributor

@roycho96 roycho96 commented Jun 2, 2026

Summary

Re-registering ZeRO-3 module hooks after they were removed (e.g. via unwrap_model_for_generation) leaves the param coordinator's recorded trace stale. The next training forward raises IndexError: pop from an empty deque from _start_of_forward_hook -> reset_step -> record_parameters -> popleft.

Repro

DeepSpeed master, torch 2.8.0+cu128, transformers, peft. Single GPU.

import torch, deepspeed
from deepspeed.runtime.zero import unwrap_model_for_generation
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

m = "hf-internal-testing/tiny-random-gpt2"
tok = AutoTokenizer.from_pretrained(m); tok.pad_token = tok.eos_token
model = get_peft_model(AutoModelForCausalLM.from_pretrained(m, dtype=torch.bfloat16),
                       LoraConfig(task_type=TaskType.CAUSAL_LM, r=4, target_modules=["c_attn"]))
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

cfg = {"train_micro_batch_size_per_gpu": 1, "bf16": {"enabled": True},
       "zero_optimization": {"stage": 3, "stage3_param_persistence_threshold": 0},
       "optimizer": {"type": "Adam", "params": {"lr": 1e-3}}}
engine, *_ = deepspeed.initialize(model=model, config=cfg,
                                  model_parameters=[p for p in model.parameters() if p.requires_grad])

ids = tok("hello", return_tensors="pt").input_ids.to(engine.device)
for _ in range(2):
    with unwrap_model_for_generation(engine) as unwrapped:
        with torch.no_grad():
            unwrapped.generate(ids, max_new_tokens=4, do_sample=False, pad_token_id=tok.pad_token_id)
    out = engine(input_ids=ids, labels=ids)
    engine.backward(out.loss); engine.step()

Run with torchrun --nproc-per-node=1 repro.py. Second iteration raises the IndexError.

Fix

Two small edits in deepspeed/runtime/zero/:

  • parameter_offload.py::_register_deepspeed_module: when the root module is re-registered, invalidate the coordinator trace so the next forward re-records cleanly.
  • partitioned_param_coordinator.py::_clear_trace_structures: also clear __step_id_module_fetched_for, which was being left populated and caused the empty-deque pop.

Both guards are no-ops on initial registration (trace is already INVALID) and on non-root submodule walks.

Test

tests/unit/runtime/zero/test_unwrap_model.py::TestUnwrapModelTraceInvalidate covers the path: run one training step, wrap with unwrap_model_for_generation, assert the coordinator returns to INVALID. World size 2.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant