Closed
Labels: oncall: distributed
Description
🐛 Describe the bug
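The new distributed state-dict APIs, get_model_state_dict and set_model_state_dict, only work after at least one forward pass has run, because the forward pass is what triggers _lazy_init. Without a forward pass, I have to call _lazy_init manually on every FSDP module first. For example: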
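from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp._runtime_utils import _lazy_init

# Workaround: manually trigger FSDP lazy initialization, which would
# otherwise only happen during the first forward pass.
for module in self.model.modules():
    if isinstance(module, FSDP):
        _lazy_init(module, module)

set_model_state_dict(
    model=self.model,
    model_state_dict=state_dict['model'],
    options=StateDictOptions(
        full_state_dict=self.fsdp_state_dict_type != 'sharded',
        strict=strict,
        cpu_offload=True,
    ),
)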
I believe get/set_model_state_dict (and maybe get/set_optim_state_dict) should call _lazy_init themselves, so that callers do not need this workaround.
Versions
Torch 2.3
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k