
Calling get_model_state_dict/set_model_state_dict requires forward pass for _lazy_init #125170

@mvpatel2000

Description

πŸ› Describe the bug

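The new distributed APIs get_model_state_dict/set_model_state_dict only work after at least one forward pass has run, because that forward pass is what triggers FSDP's _lazy_init. For example, to load a model state dict before any forward pass, I currently have to call _lazy_init manually on every FSDP module first: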

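    from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp._runtime_utils import _lazy_init

    # Workaround: force the FSDP lazy initialization that would normally
    # happen on the first forward pass, then load the state dict.
    for module in self.model.modules():
        if isinstance(module, FSDP):
            _lazy_init(module, module)
    set_model_state_dict(
        model=self.model,
        model_state_dict=state_dict['model'],
        options=StateDictOptions(
            full_state_dict=self.fsdp_state_dict_type != 'sharded',
            strict=strict,
            cpu_offload=True,
        ),
    )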

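Should get_model_state_dict/set_model_state_dict (and possibly get_optimizer_state_dict/set_optimizer_state_dict) call _lazy_init themselves, so that no forward pass or manual workaround is needed beforehand?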
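To make the request concrete, here is a toy model in plain Python of the behavior being asked for. It does not use PyTorch: ToyShardedModule, toy_get_model_state_dict and toy_set_model_state_dict are hypothetical stand-ins, and real FSDP lazy initialization does considerably more work than flipping a flag. The point it illustrates is that an idempotent init step can be called unconditionally from every entry point (forward and the state-dict getters/setters) without any one of them depending on another having run first.

```python
from typing import Dict


class ToyShardedModule:
    """Hypothetical stand-in for an FSDP-wrapped module (not PyTorch code).

    The deferred setup that real FSDP performs in _lazy_init is reduced here
    to setting a flag and counting how many times the setup actually ran.
    """

    def __init__(self, params: Dict[str, float]) -> None:
        self._params = dict(params)
        self._initialized = False
        self.init_calls = 0  # number of times the setup really executed

    def _lazy_init(self) -> None:
        # Idempotent: only the first call does any work, so every entry
        # point can call this without coordinating with the others.
        if self._initialized:
            return
        self._initialized = True
        self.init_calls += 1

    def forward(self, x: float) -> float:
        # Today's behavior: forward is what triggers the lazy init.
        self._lazy_init()
        return x * sum(self._params.values())


def toy_get_model_state_dict(model: ToyShardedModule) -> Dict[str, float]:
    # Proposed behavior: the getter triggers lazy init itself.
    model._lazy_init()
    return dict(model._params)


def toy_set_model_state_dict(model: ToyShardedModule, state: Dict[str, float]) -> None:
    # Proposed behavior: the setter triggers lazy init itself.
    model._lazy_init()
    model._params.update(state)
```

Because the guard returns early after the first call, loading a checkpoint before training and then running forward still performs the setup exactly once.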

Versions

Torch 2.3

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

Labels

oncall: distributed
