Conversation

awgu
Collaborator

@awgu awgu commented Apr 29, 2024

Stack from ghstack (oldest at bottom):

For microbatching use cases (e.g. PP), we may use fp32 reduce-scatter (i.e. `MixedPrecisionPolicy(reduce_dtype=torch.float32)`), where we want to accumulate the unsharded gradients in fp32 across microbatches until reduce-scattering in fp32 on the last microbatch.

Note that the `unsharded_param` is in bf16, so we must save the fp32 accumulated gradient to an attribute different from `.grad`. Moreover, saving a new attribute on the `torch.Tensor` leads to some annoying type-checking issues (where the attribute may not be defined), so this PR prefers to save the attribute on the `FSDPParam` class instead.

One could argue that this behavior should be configurable, but since for large-scale training everyone seems to be leaning toward fp32 accumulation across microbatches, let us avoid adding another argument for now.
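For context, here is a minimal usage sketch (not part of this PR's diff), assuming the FSDP2 `fully_shard` / `MixedPrecisionPolicy` API as of this change and the `set_requires_gradient_sync` method for deferring the reduce-scatter; the import path, model, and microbatch data are placeholders:

```python
import torch
import torch.nn as nn
# Import path as of this PR; it may differ in later releases.
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy

# bf16 unsharded params/grads with fp32 gradient reduction: unsharded
# gradients accumulate in fp32 across microbatches and are
# reduce-scattered in fp32 on the last microbatch.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)

# (Process-group / device-mesh initialization omitted; this sketch will
# not run standalone.)
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024))
for layer in model:
    fully_shard(layer, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

# Microbatching (e.g. PP): skip the reduce-scatter for all but the last
# microbatch so that gradients accumulate unsharded in fp32.
microbatches = [torch.randn(8, 1024) for _ in range(4)]  # placeholder data
model.set_requires_gradient_sync(False)
for x in microbatches[:-1]:
    model(x).sum().backward()
model.set_requires_gradient_sync(True)
model(microbatches[-1]).sum().backward()
```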

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @penguinwu @fegin @XilunWu @wanchaol @fduwjj @wz337 @tianyu-l @wconstab @yf225 @chauhang @d4l3k


pytorch-bot bot commented Apr 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/125191

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 4 Unrelated Failures

As of commit 0505ac7 with merge base 935a946:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ci-td-distributed oncall: distributed Add this issue/PR to distributed oncall triage queue release notes: distributed (fsdp) release notes category labels Apr 29, 2024
@awgu awgu added release notes: distributed (fsdp2) release notes category and removed release notes: distributed (fsdp) release notes category labels Apr 29, 2024
@awgu awgu marked this pull request as ready for review April 29, 2024 21:06
@awgu awgu requested review from wanchaol and weifengpy April 29, 2024 21:06
Contributor

@weifengpy weifengpy left a comment


I had some confusion about CPU offloading + `fsdp_param.unsharded_accumulated_grad_data`, but it turns out it does not matter: CPU offloading is only applied to sharded params.

awgu pushed a commit that referenced this pull request Apr 29, 2024
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Apr 29, 2024
_sharded_post_forward_param_data: Optional[torch.Tensor] # 1D
_sharded_post_forward_param: Optional[nn.Parameter] # ND
_unsharded_param: nn.Parameter # ND
unsharded_accumulated_grad: Optional[torch.Tensor] # ND
Collaborator


nit: keep this as a private variable to align with all the other fields?

Collaborator Author


Sorry, I think the notion of private and public is non-obvious for the `FSDPParam` class. The way I have it is that if `FSDPParamGroup` should access an attribute on `FSDPParam`, then that attribute should be public. (Note that `FSDPParam` will never be accessed publicly by the user, as there is no way to do that today and there should not be a way in the future.)

For this attribute and the current implementation, `FSDPParamGroup` needs to check whether it is `None` to know if there is an unsharded accumulated gradient to reduce-scatter, so I have it as public.
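For illustration only, a skeleton of the pattern described above; the real `FSDPParam` / `FSDPParamGroup` classes carry much more state, and the method names here are hypothetical:

```python
from typing import List, Optional

import torch


class FSDPParam:  # illustrative skeleton, not the real class
    # fp32 running sum of unsharded (ND) gradients across microbatches
    unsharded_accumulated_grad: Optional[torch.Tensor] = None

    def accumulate_unsharded_grad(self, new_grad: torch.Tensor) -> None:
        # The unsharded param (and thus new_grad) is bf16, so keep the
        # accumulator in the reduce dtype (fp32) on a separate attribute
        # rather than on .grad.
        if self.unsharded_accumulated_grad is None:
            self.unsharded_accumulated_grad = new_grad.to(torch.float32)
        else:
            self.unsharded_accumulated_grad += new_grad


class FSDPParamGroup:  # illustrative skeleton, not the real class
    def __init__(self, fsdp_params: List[FSDPParam]) -> None:
        self.fsdp_params = fsdp_params

    def grads_to_reduce_scatter(self) -> List[torch.Tensor]:
        # The attribute is public so that the group can check whether an
        # accumulated gradient exists for each parameter.
        return [
            p.unsharded_accumulated_grad
            for p in self.fsdp_params
            if p.unsharded_accumulated_grad is not None
        ]
```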

@awgu
Collaborator Author

awgu commented Apr 30, 2024

@pytorchbot merge -i

@pytorchmergebot
Collaborator

@awgu
Collaborator Author

awgu commented Apr 30, 2024

Failures were all from the open registration test.

pytorchmergebot pushed a commit that referenced this pull request May 1, 2024
…125269)

1. This PR removes the logic for saving and removing the pre-backward hook handles (which are registered via `register_multi_grad_hook(mode="any")`); a standalone sketch of this hook pattern follows after this commit message.
2. This PR removes the logic for _trying_ to guard against mistargeted prefetches, which relies on querying whether the engine will execute the module output tensors' `grad_fn`s. (See #118118 for the original motivation.)

For 1, the logic was error-prone since it relied on `set_is_last_backward(False)` being set correctly, or else pre-backward hooks could be de-registered too early. We would prefer to match the hook lifetimes with the lifetime of the autograd graph. This fixes a bug with a 1F1B interleaved schedule.

If we directly remove the manual saving/removing of hook handles, then we have a ref cycle because the tensors' `grad_fn`s are passed to the hook function. We decided to simply remove this `grad_fn` logic since (1) it cannot perfectly prevent mistargeted prefetches and (2) it introduces undesired complexity. In the future, we may prefer a different mechanism to override the prefetching for more complex/dynamic use cases.

Pull Request resolved: #125269
Approved by: https://github.com/weifengpy
ghstack dependencies: #125190, #125191
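For context, a standalone sketch of the `register_multi_grad_hook(mode="any")` pattern referenced in item 1 above; the toy model and hook body are placeholders, not FSDP code:

```python
import torch
from torch import nn
from torch.autograd.graph import register_multi_grad_hook


def pre_backward(grad: torch.Tensor) -> None:
    # In FSDP2 this is roughly where parameters would be all-gathered
    # (unsharded) before their gradients are computed.
    print("pre-backward hook fired")


lin = nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
out1 = lin(x)
out2 = lin(x).relu()

# mode="any": the hook runs once, as soon as the gradient with respect to
# any one of the registered output tensors is computed during backward.
handle = register_multi_grad_hook((out1, out2), pre_backward, mode="any")
(out1.sum() + out2.sum()).backward()

# Before #125269, FSDP2 saved handles like this and removed them manually;
# after it, the hook lifetime follows the autograd graph instead.
handle.remove()
```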
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…25191)


Pull Request resolved: pytorch#125191
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#125190
petrex pushed a commit to petrex/pytorch that referenced this pull request May 3, 2024
…ytorch#125269)


Pull Request resolved: pytorch#125269
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#125190, pytorch#125191
@github-actions github-actions bot deleted the gh/awgu/575/head branch June 4, 2024 01:57
