[fsdp] add an experimental allocator hook for buffers that participate in collective communication #147146
base: gh/yifuwang/196/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/147146
Note: links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 unrelated failure.) As of commit 16bd23e with merge base 899066e, one job failed but was likely due to flakiness present on trunk (FLAKY).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o [ghstack-poisoned]
@awgu can you please give suggestions on the API name and how it should be exposed?
# Module-level state (defaults to None; set only via the setter below).

def _set_fsdp_comm_allocator(allocator: _Allocator):
    global _fsdp_comm_allocator
    _fsdp_comm_allocator = allocator


def _get_fsdp_comm_allocator() -> _Allocator:
    # Fall back to torch.empty when no custom allocator has been registered.
    if _fsdp_comm_allocator is not None:
        return _fsdp_comm_allocator
    else:
        return torch.empty
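For illustration, here is a minimal sketch of how a caller could register a custom allocator through this hook. The logging allocator and the sizes below are hypothetical; the only assumption is that the hook mirrors the signature of torch.empty, which is the fallback above.

```python
import torch

# Hypothetical allocator: log each request, then defer to torch.empty.
# It mirrors torch.empty's signature because that is the fallback above.
def logging_comm_allocator(*size, **kwargs):
    print(f"[fsdp comm alloc] size={size} kwargs={kwargs}")
    return torch.empty(*size, **kwargs)

_set_fsdp_comm_allocator(logging_comm_allocator)

# FSDP internals would then obtain communication buffers via the hook:
buf = _get_fsdp_comm_allocator()(1024, dtype=torch.bfloat16, device="cuda")
```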
naming is hard 😅
if you make these globals public, I am okay with that
that might be better than passing the allocator as an arg into fully_shard, since then the user is expected to pass the same allocator to all calls of fully_shard (or else I think it does not make too much sense)
> that might be better than passing the allocator as an arg into fully_shard, since then the user is expected to pass the same allocator to all calls of fully_shard (or else I think it does not make too much sense)
Makes sense. I made it global mainly because some allocations are performed in custom ops created for tracing, and I don't want to mess them up.
Hmm, I think we also need to expose the process group on which the collective will be performed 🤔
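One possible shape for that (purely a sketch; this PR does not define such a signature) would be for the hook to also receive the process group the buffer will be used with:

```python
from typing import Optional

import torch
import torch.distributed as dist

# Hypothetical extended signature: the process group is passed as a keyword
# argument so the allocator can, e.g., pick a pool associated with that group.
def comm_allocator(*size, pg: Optional[dist.ProcessGroup] = None, **kwargs):
    # A real implementation might look up group-specific memory here;
    # this sketch just falls back to a plain allocation.
    return torch.empty(*size, **kwargs)
```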
Hmm, right now it seems like the allocator doesn't know when the memory is no longer needed? My understanding is that a tensor doesn't have a callback mechanism to tell the allocator, so I wonder how GC is supposed to work?
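A minimal sketch of one way this could work (an assumption, not something the PR implements): the allocator attaches a finalizer to the tensor it returns, so it is told when the Python tensor object is collected. Note this tracks the tensor object, not the underlying storage, which may outlive it if other tensors alias it.

```python
import weakref

import torch

# Hypothetical allocator that is notified when the returned tensor object is
# garbage collected. The callback must not hold a reference to the tensor.
def tracking_comm_allocator(*size, **kwargs):
    buf = torch.empty(*size, **kwargs)
    weakref.finalize(buf, lambda: print("[fsdp comm alloc] buffer object collected"))
    return buf
```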
@yifuwang Got it. I was slightly confused by the example because the allocation is done in Python, but this makes sense.
…e in collective communication Summary: pytorch#147146 Test Plan: unit test Differential Revision: D69694585
…e in collective communication (pytorch#149150) Summary: pytorch#147146 Test Plan: unit test Differential Revision: D69694585
…e in collective communication (pytorch#149150) Summary: Pull Request resolved: pytorch#149150 pytorch#147146 Test Plan: unit test Differential Revision: D69694585
Looks like this PR hasn't been updated in a while, so we're going to go ahead and mark this as Stale.
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o