add axiswise granularity to Float8Tensor #919
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/919
✅ No failures as of commit 1f01df9 with merge base 5dd0132. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
I skimmed through it, LGTM!
I have one suggestion for a more streamlined and extensible API; see below.
scaling_granularity=ScalingGranularity.AXISWISE,
axiswise_dim=axiswise_dim,
Suggestion for another API: instead of an enum plus extra parameters added case by case, we could reuse the idea that @drisspg used in the _scaled_mm operator and deduce the kind of scaling from the shape of the desired scale tensor!
Concretely, we could add a single scale_shape=... parameter, which for row-wise scaling would be [-1, 1], indicating that:
- all columns (second dim) should be grouped and reduced into a single scaling factor (because the second element has a value of 1)
- but that for the rows (first dim) there should be as many scaling factors as there are rows (because the first element has a value of -1, which gets replaced with the size of that dim of the input tensor).
The scale shape is right-aligned to the shape of the tensor (following PyTorch's standard broadcasting semantics) and then left-padded with 1s (again, standard semantics). This means that tensor-wise scaling is achieved with scale_shape=[].
This convention would later let us express block-wise scaling (e.g., 128x128), group-wise scaling (1x128), and maybe even column-wise scaling if that ever becomes a thing!
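The proposed convention can be sketched as a small shape resolver. This assumes only the -1 and 1 entries described in the comment; block-wise or group-wise entries would need extra rules the proposal does not spell out, so the sketch rejects them. `resolve_scale_shape` and `reduction_dims` are hypothetical names, not an existing torchao or PyTorch API.

```python
from typing import Sequence, Tuple


def resolve_scale_shape(scale_shape: Sequence[int],
                        tensor_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper for the proposed `scale_shape` convention."""
    if len(scale_shape) > len(tensor_shape):
        raise ValueError("scale_shape has more dims than the tensor")
    # Left-pad with 1s so the scale shape is right-aligned to the tensor,
    # as in PyTorch broadcasting.
    padded = [1] * (len(tensor_shape) - len(scale_shape)) + list(scale_shape)
    resolved = []
    for s, t in zip(padded, tensor_shape):
        if s == -1:
            resolved.append(t)  # one scaling factor per index along this dim
        elif s == 1:
            resolved.append(1)  # this dim is reduced into a single factor
        else:
            # Block/group-wise entries are out of scope for this sketch.
            raise ValueError(f"unsupported scale_shape entry: {s}")
    return tuple(resolved)


def reduction_dims(resolved: Sequence[int]) -> Tuple[int, ...]:
    """Dims collapsed to size 1 are the ones to reduce (e.g. with amax)."""
    return tuple(i for i, s in enumerate(resolved) if s == 1)


rowwise = resolve_scale_shape([-1, 1], (16, 32))        # (16, 1)
tensorwise = resolve_scale_shape([], (16, 32))          # (1, 1)
rowwise_3d = resolve_scale_shape([-1, 1], (2, 16, 32))  # (1, 16, 1)
```

Under these semantics a rank-3 input with `[-1, 1]` also reduces the leading batch dim, which is one place where the rank of the input would need an explicit decision.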
One wrinkle to work through is that Float8Tensor can be of any rank, but the operands to torch._scaled_mm must be rank 2, to match torch.mm / torch.addmm.
I'm definitely open to making this more flexible in the future. We've been careful to keep Float8Tensor and these utility functions out of the public API, which gives us the freedom to make these kinds of changes as other scaling types become more important.
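To make the rank wrinkle concrete, a minimal sketch, assuming the common approach of flattening leading dims into a 2-D view before the matmul and a row-wise scale reduced over the last dim. `flatten_rowwise` is a hypothetical helper; it does not call `torch._scaled_mm` (which typically needs a float8-capable GPU) and only checks that the scale layout survives the reshape.

```python
import torch


def flatten_rowwise(x: torch.Tensor):
    """Hypothetical helper: flatten an N-D operand to 2-D with a matching row-wise amax."""
    # torch._scaled_mm, like torch.mm, takes 2-D operands, so (*lead, K)
    # is viewed as (prod(lead), K) before the matmul.
    x_2d = x.reshape(-1, x.shape[-1])
    # A row-wise amax on the N-D tensor has shape (*lead, 1) ...
    amax_nd = x.abs().amax(dim=-1, keepdim=True)
    # ... and flattening its leading dims gives one entry per row of x_2d.
    return x_2d, amax_nd.reshape(-1, 1)


torch.manual_seed(0)
x = torch.randn(2, 3, 8)
x_2d, amax = flatten_rowwise(x)
```

Reducing over the last dim keeps the scale aligned with the rows of the 2-D view; reducing over a leading dim would not map onto a row-wise scale of that view this way, which is the kind of case a more general `scale_shape` API would likely have to handle.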
Also, if someone puts up a PR for the above, SGTM!
# would be initialized in every iteration.
self.enable_pre_and_post_forward = self.config.enable_pre_and_post_forward

# See the comments in config.py for more details of this option.
Technically not related to this PR, but this makes the test logs less spammy for now; we can add this back in a better way later.
* Use ao's int4 quantizer
* Point AO to commit hash of Jerry's fix
* When device is cuda, only run for dtype==bfloat16
* Typo
* Use tensor subclass for int4 weight only quant
* Fix bug
* Fix
* Use both quantizer and subclass API
* Bug
* unwrap tensor subclass for aoti
* Add import
* Eval fix
* Evaluate AOTI
---------
Co-authored-by: Mengwei Liu
Summary:
This is a copy-paste of meta-pytorch/float8_experimental#352, which never landed.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags: