Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

awgu
Copy link
Collaborator

@awgu awgu commented Apr 25, 2024

Stack from ghstack (oldest at bottom):

This PR renames the FSDP class to FSDPModule. This is a BC breaking change. The rationale is that FSDPModule is more descriptive since fully_shard is a module-level API (applied to a module arg), so the FSDP class will always correspond to a module.

Also, users commonly import FullyShardedDataParallel as FSDP, so this can help avoid some name conflict in some cases.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k

Copy link

pytorch-bot bot commented Apr 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/124955

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 93cf3c1 with merge base c82fcb7 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category ci-td-distributed labels Apr 25, 2024
awgu pushed a commit that referenced this pull request Apr 25, 2024
ghstack-source-id: 68b0cf6
Pull Request resolved: #124955
cc mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse H-Huang kwen2501 penguinwu fegin XilunWu wanchaol fduwjj wz337 tianyu-l wconstab yf225 chauhang d4l3k

[ghstack-poisoned]
awgu pushed a commit that referenced this pull request Apr 26, 2024
ghstack-source-id: 388def4
Pull Request resolved: #124955
cls = module.__class__
dct = {"__deepcopy__": unimplemented_deepcopy}
new_cls = type(f"FSDP{cls.__name__}", (FSDP, cls), dct)
new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't affect the __repr__ of the wrapped class iiuc?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should not.

@awgu
Copy link
Collaborator Author

awgu commented Apr 26, 2024

cc: @wconstab If we land this change, then I think we need to change the PiPPy code to migrate FSDP -> FSDPModule.

@awgu awgu added ciflow/trunk Trigger trunk jobs on your pull request release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Apr 26, 2024
@awgu awgu marked this pull request as ready for review April 27, 2024 02:18
@awgu awgu requested a review from wconstab April 27, 2024 02:18
Copy link
Contributor

@wconstab wconstab left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks for the heads up RE pippy, i dont mind changing that to match.

@awgu
Copy link
Collaborator Author

awgu commented Apr 29, 2024

@pytorchbot merge

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

pytorch-bot bot pushed a commit that referenced this pull request May 3, 2024
This PR renames the `FSDP` class to `FSDPModule`. This is a BC breaking change. The rationale is that `FSDPModule` is more descriptive since `fully_shard` is a module-level API (applied to a `module` arg), so the `FSDP` class will always correspond to a module.

Also, users commonly import `FullyShardedDataParallel` as `FSDP`, so this can help avoid some name conflict in some cases.

Pull Request resolved: #124955
Approved by: https://github.com/wanchaol, https://github.com/wconstab
ghstack dependencies: #124651, #124741, #124767, #124768, #124780, #124787
@weifengpy
Copy link
Contributor

oops. now I will use try-except in TorchTune to import FSDPModule

@github-actions github-actions bot deleted the gh/awgu/573/head branch June 9, 2024 01:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-td-distributed ciflow/trunk Trigger trunk jobs on your pull request Merged oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp2) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants