Conversation

weifengpy (Contributor) commented Apr 9, 2024

For FSDP (SHARD_GRAD_OP + use_orig_params) + TP, the parameters seen in the backward pass are DTensors. However, DTensor.untyped_storage().data_ptr() does not work in _same_storage, so we desugar to DTensor._local_tensor.untyped_storage().data_ptr() instead. #123272

Credit to @bigning for the original fix. After this lands, the patching in Mosaic Composer (https://github.com/mosaicml/composer/pull/3175/files) will no longer be needed.

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang
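
A minimal sketch of the desugaring described above, assuming the standard DTensor import path; the function name mirrors FSDP's internal `_same_storage` helper, but the body here is illustrative rather than the exact upstream change:

```python
from torch import Tensor
from torch.distributed._tensor import DTensor


def _same_storage(a: Tensor, b: Tensor) -> bool:
    # Simplified illustration of the fix, not the exact upstream helper:
    # under FSDP (SHARD_GRAD_OP + use_orig_params) + TP, the tensors seen in
    # backward can be DTensors, and DTensor.untyped_storage().data_ptr() does
    # not give a usable pointer, so unwrap to the local shard first.
    if isinstance(a, DTensor):
        a = a._local_tensor
    if isinstance(b, DTensor):
        b = b._local_tensor
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
```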

weifengpy requested a review from a team as a code owner on April 9, 2024 01:45
pytorch-bot commented Apr 9, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/123617

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit e9861b0 with merge base 61be884:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the ciflow/inductor, oncall: distributed, and release notes: distributed (fsdp) labels on Apr 9, 2024
weifengpy marked this pull request as draft on April 9, 2024 01:45
weifengpy marked this pull request as ready for review on April 9, 2024 16:52
weifengpy requested a review from awgu on April 9, 2024 16:52
awgu (Collaborator) left a comment


LGTM! We may need to move the DTensor import into _same_storage() to avoid breaking internal builds.

```python
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed._tensor import DTensor
```
Collaborator:

This is not a great state to be in, but I always remember that we cannot import DTensor at the top level of this file, or else we may break some internal torch.package or torch::deploy use case.

I am not too familiar with the issue, though :/

Contributor:

Is the breakage caused by a circular dependency, if you can recall?

Collaborator:

I honestly cannot remember :(

Contributor Author:

Got it. I will import DTensor inside the function.
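
As a sketch, the agreed-upon placement would look roughly like this, assuming the unwrap-and-compare logic from the PR description; the point is that the DTensor import only runs when `_same_storage()` is called, so importing the FSDP module itself no longer pulls in `torch.distributed._tensor` at load time:

```python
from torch import Tensor


def _same_storage(a: Tensor, b: Tensor) -> bool:
    # Import DTensor lazily, inside the function, so that merely importing
    # this FSDP module does not import torch.distributed._tensor (which
    # reviewers recall breaking internal torch.package / torch::deploy use).
    from torch.distributed._tensor import DTensor

    if isinstance(a, DTensor):
        a = a._local_tensor
    if isinstance(b, DTensor):
        b = b._local_tensor
    return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr()
```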

```diff
 fsdp_world_size = self.world_size // tp_world_size
 assert (
-    type(tp_fsdp_model) is FSDP and len(list(tp_fsdp_model.parameters())) == 1
+    type(tp_fsdp_model) is FSDP
```
Collaborator:

IIUC, this change is to make the check stricter to more accurately reflect our assumptions?

Contributor Author (weifengpy), Apr 9, 2024:

This is to make it work for use_orig_params=True, where tp_fsdp_model has more than one parameter.

```
torch.cat(
    [
        torch.flatten(param.grad)
        if param.grad is not None
```
Collaborator:

Is this change needed for use_orig_params=True specifically?

Contributor Author:

Yes, for use_orig_params=True specifically.
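
For context, here is a self-contained sketch of the pattern in the excerpt above. The else branch (substituting flattened zeros when a parameter has no gradient) and the helper name flatten_grads are assumptions, since the excerpt is cut off; the idea is that with use_orig_params=True some original parameters may carry grad=None on a rank, and zero-filling keeps the concatenated vector at a fixed layout for comparison against the FSDP-only reference.

```python
import torch
import torch.nn as nn


def flatten_grads(model: nn.Module) -> torch.Tensor:
    # Flatten and concatenate all parameter gradients; parameters whose grad
    # is None (possible with use_orig_params=True) contribute zeros so the
    # resulting vector has a stable, comparable layout.
    return torch.cat(
        [
            torch.flatten(param.grad)
            if param.grad is not None
            else torch.zeros_like(torch.flatten(param))
            for param in model.parameters()
        ]
    )
```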

```
flat_param.grad[~sharded_mask] = grad[~sharded_mask]
# Average *all* gradient elements to match the FSDP only semantics
flat_param.grad /= tp_world_size
for flat_param in tp_fsdp_model.params:
```
Collaborator:

Is len(tp_fsdp_model.params) > 1 iff use_orig_params=True?

Contributor Author:

That's right.
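
A toy, self-contained illustration of the point (all values below are made up; sharded_mask, tp_world_size, and the per-flat-param fix-up only mirror names from the test excerpt): with use_orig_params=True an FSDP instance can hold more than one flat parameter, so the gradient adjustment has to loop over tp_fsdp_model.params rather than assume a single one.

```python
import torch

tp_world_size = 2
# Stand-ins for tp_fsdp_model.params, which holds more than one flat
# parameter when use_orig_params=True.
flat_params = [torch.zeros(4, requires_grad=True), torch.zeros(4, requires_grad=True)]
reference_grads = [torch.full((4,), 2.0), torch.full((4,), 4.0)]
sharded_mask = torch.tensor([True, True, False, False])

for flat_param, ref_grad in zip(flat_params, reference_grads):
    flat_param.grad = torch.zeros_like(flat_param)
    # Copy the elements that were not sharded by TP ...
    flat_param.grad[~sharded_mask] = ref_grad[~sharded_mask]
    # ... then average *all* gradient elements to match the FSDP-only semantics.
    flat_param.grad /= tp_world_size
```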

weifengpy (Contributor Author):
@pytorchmergebot merge

pytorch-bot added the ciflow/trunk label on Apr 10, 2024
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

sanketpurandare pushed a commit to sanketpurandare/pytorch that referenced this pull request Apr 22, 2024
For FSDP (SHARD_GRAD_OP + use_orig_params) + TP, the parameters seen in the backward pass are DTensors. However, ``DTensor.untyped_storage().data_ptr()`` does not work in ``_same_storage``, so desugar to ``DTensor._local_tensor.untyped_storage().data_ptr()`` instead. pytorch#123272

Credit to @bigning for the original fix. After landing, the patching in Mosaic Composer (https://github.com/mosaicml/composer/pull/3175/files) is no longer needed.

Pull Request resolved: pytorch#123617
Approved by: https://github.com/awgu
mvpatel2000 (Contributor):

@weifengpy do you think we can include this in torch 2.3.1?
#125425

weifengpy (Contributor Author):

@mvpatel2000 Just checked: I need to cherry-pick this commit, otherwise torch 2.3.1 won't include this fix. I will file a PR to see if we can make it.

petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
weifengpy added a commit to weifengpy/pytorch that referenced this pull request May 16, 2024
weifengpy added a commit to weifengpy/pytorch that referenced this pull request May 23, 2024
weifengpy added a commit to weifengpy/pytorch that referenced this pull request May 23, 2024
atalman pushed a commit that referenced this pull request May 23, 2024

Labels: ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp)
