Remove zero_point_domain from quant configs #2058
Merged
@@ -156,48 +156,109 @@ def from_plain( | |
| zero_point: Optional[torch.Tensor], | ||
| layout: Layout, | ||
| bias: Optional[torch.Tensor] = None, | ||
| *, | ||
| validate_inputs: bool = True, | ||
| ): | ||
| assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout) | ||
| assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" | ||
| assert layout.target in [ | ||
| t for t, _ in _TARGET_AND_STR | ||
| ], f"Unexpected target: {layout.target}" | ||
| assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain" | ||
|
|
||
| if layout.target != Target.ATEN: | ||
| _check_torchao_ops_loaded() | ||
| else: | ||
| assert ( | ||
| TORCH_VERSION_AT_LEAST_2_6 | ||
| ), "aten target is requires torch version > 2.6.0" | ||
| assert ( | ||
| torch.backends.kleidiai.is_available() | ||
| ), "ATEN target requires torch.backends.kleidiai.is_available()" | ||
| layout.bit_width == 4, "ATEN target only supports torch.int4" | ||
| assert not layout.has_weight_zeros, "ATEN target does not support zeros" | ||
|
|
||
| data_dtype = getattr(torch, f"int{layout.bit_width}") | ||
| qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype] | ||
|
|
||
| int_types = [torch.int8, torch.int16, torch.int32, torch.int64] | ||
|
|
||
| # Check int_data | ||
| assert int_data.device == torch.device("cpu") | ||
| assert int_data.dtype in int_types | ||
| n, k = int_data.shape | ||
| assert int_data.dtype in int_types, f"int_data.dtype must be {int_types}" | ||
| assert k % layout.group_size == 0, "k must be divisible by group_size" | ||
| if validate_inputs: | ||
| assert int_data.min().item() >= qmin | ||
| assert int_data.max().item() <= qmax | ||
| int_data = int_data.to(torch.int8) | ||
|
|
||
| assert scale.dtype == torch.float32, "scale must be float32" | ||
| # Check scale | ||
| assert scale.device == torch.device("cpu") | ||
| if scale.dtype != torch.float32: | ||
| logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32") | ||
| scale = scale.to(torch.float32) | ||
| n_, _ = scale.shape | ||
| assert n_ == n | ||
| assert ( | ||
| scale.numel() * layout.group_size == int_data.numel() | ||
| ), "must have 1 scale per group" | ||
|
|
||
| assert (zero_point is not None) == ( | ||
| layout.has_weight_zeros | ||
| ), "zero_point being None must be consistent with layout.has_weight_zeros" | ||
| if zero_point is not None: | ||
| if validate_inputs: | ||
| assert scale.min().item() > 0 | ||
| # Some targets round scales to bfloat16, give warning if scales are at higher precision | ||
| scale_is_rounded_to_bf16 = torch.allclose( | ||
| scale, scale.to(torch.bfloat16).to(torch.float32) | ||
| ) | ||
| if not scale_is_rounded_to_bf16: | ||
| if layout.target == Target.ATEN and (layout.group_size < k): | ||
| logging.warning( | ||
| "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16" | ||
| ) | ||
| if layout.target in [Target.AUTO, Target.KLEIDIAI]: | ||
| logging.warning( | ||
| "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16" | ||
| ) | ||
|
|
||
| # Check zero_point | ||
| if zero_point is None: | ||
| assert ( | ||
| zero_point.dtype in int_types | ||
| ), f"zero_point.dtype must be {int_types}" | ||
| not layout.has_weight_zeros | ||
| ), "zero_point must be provided if has_weight_zeros=True" | ||
| else: | ||
| assert zero_point.device == torch.device("cpu") | ||
| assert zero_point.shape == scale.shape | ||
| assert zero_point.dtype in int_types | ||
| assert ( | ||
| zero_point.numel() * layout.group_size == int_data.numel() | ||
| ), "must have 1 zero_point per group" | ||
| if validate_inputs: | ||
| zero_point_min = zero_point.min().item() | ||
| zero_point_max = zero_point.max().item() | ||
| assert zero_point.min().item() >= qmin | ||
| assert zero_point.max().item() <= qmax | ||
| has_weight_zeros = True | ||
| if zero_point_min == 0 and zero_point_max == 0: | ||
| has_weight_zeros = False | ||
| assert ( | ||
| has_weight_zeros == layout.has_weight_zeros | ||
| ), "zero_point being all zeros must be consistent with layout.has_weight_zeros" | ||
| zero_point = zero_point.to(torch.int8) | ||
|
|
||
| assert (bias is not None) == ( | ||
| layout.has_bias | ||
| # Check bias | ||
| has_bias = bias is not None | ||
| assert ( | ||
| has_bias == layout.has_bias | ||
| ), "bias being None must be consistent with layout.has_bias" | ||
| if bias is not None: | ||
| assert bias.dtype == torch.float32, "bias.dtype must be float32" | ||
| assert bias.shape == (n,), "bias must have shape n" | ||
| if has_bias: | ||
| assert bias.device == torch.device("cpu") | ||
| if bias.dtype != torch.float32: | ||
| logging.info( | ||
| f"bias has dtype {bias.dtype}, converting to torch.float32" | ||
| ) | ||
| bias = bias.to(torch.float32) | ||
| assert bias.shape == (n,) | ||
|
|
||
| # Construct packed_weight | ||
| if layout.target == Target.ATEN: | ||
| assert ( | ||
| TORCH_VERSION_AT_LEAST_2_6 | ||
| ), "aten target is requires torch version > 2.6.0" | ||
| int_data = int_data.add(8) | ||
| int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) | ||
|
|
||
|
|
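As an aside on the ATEN packing step in the hunk above (`int_data.add(8)` followed by the nibble pack): a minimal standalone sketch of the same transform on toy values, shifting signed int4 values into [0, 15] and then packing adjacent column pairs into one uint8 byte. The tensor contents below are made up for illustration.

```python
import torch

# Toy int4 weight values in [-8, 7], shape (n, k) with k even.
int_data = torch.tensor([[-8, 7, 0, -1]])

# Shift into the unsigned range [0, 15], as the ATEN branch does with add(8).
shifted = int_data.add(8)          # [[0, 15, 8, 7]]

# Pack each pair of adjacent columns into one byte:
# odd columns become the high nibble, even columns the low nibble.
packed = (shifted[::, 1::2] << 4 | shifted[::, ::2]).to(torch.uint8)
print(packed)                      # tensor([[240, 120]], dtype=torch.uint8)
```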
@@ -213,12 +274,11 @@ def from_plain( | |
| args = [ | ||
| int_data, | ||
| scale.reshape(-1), | ||
| zero_point.reshape(-1) if zero_point is not None else None, | ||
| zero_point.reshape(-1) if layout.has_weight_zeros else None, | ||
| layout.group_size, | ||
| bias, | ||
| target_to_str(layout.target) if layout.target != Target.AUTO else None, | ||
| ] | ||
|
|
||
| packed_weight = getattr( | ||
| torch.ops.torchao, | ||
| f"_pack_8bit_act_{layout.bit_width}bit_weight", | ||
|
|
@@ -358,79 +418,35 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor( | |
| assert TORCH_VERSION_AT_LEAST_2_6, "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0" | ||
|
|
||
| layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target) | ||
| if layout.target != Target.ATEN: | ||
| _check_torchao_ops_loaded() | ||
| else: | ||
| assert ( | ||
| torch.backends.kleidiai.is_available() | ||
| ), "ATEN target requires torch.backends.kleidiai.is_available()" | ||
| assert data_dtype == torch.int4, "ATEN target only supports torch.int4" | ||
| assert zero_point is None, "ATEN target does not support zeros" | ||
|
|
||
| assert data_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)] | ||
| qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype] | ||
| bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype] | ||
| qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype] | ||
|
|
||
| int_types = [torch.int8, torch.int16, torch.int32, torch.int64] | ||
|
|
||
| # Check int_data | ||
| assert int_data.device == torch.device("cpu") | ||
| assert int_data.dtype in int_types | ||
| n, k = int_data.shape | ||
| if validate_inputs: | ||
| assert int_data.min().item() >= qmin | ||
| assert int_data.max().item() <= qmax | ||
|
|
||
| # Check scale | ||
| assert scale.device == torch.device("cpu") | ||
| if scale.dtype != torch.float32: | ||
| logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32") | ||
| scale = scale.to(torch.float32) | ||
| n_, groups_per_k = scale.shape | ||
| assert n_ == n | ||
| assert k % groups_per_k == 0 | ||
| group_size = k // groups_per_k | ||
| if validate_inputs: | ||
| assert scale.min().item() > 0 | ||
|
|
||
| if validate_inputs: | ||
| # Some targets round scales to bfloat16, give warning if scales are at higher precision | ||
| scale_is_rounded_to_bf16 = torch.allclose( | ||
| scale, scale.to(torch.bfloat16).to(torch.float32) | ||
| ) | ||
| if not scale_is_rounded_to_bf16: | ||
| if layout.target == Target.ATEN and (group_size < k): | ||
| logging.warning( | ||
| "When using Target.ATEN with group_size < k, scales will be rounded to bfloat16" | ||
| ) | ||
| if layout.target in [Target.AUTO, Target.KLEIDIAI]: | ||
| logging.warning( | ||
| "When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16" | ||
| ) | ||
|
|
||
| # Check zero_point | ||
| has_weight_zeros = zero_point is not None | ||
| if has_weight_zeros: | ||
| assert zero_point.device == torch.device("cpu") | ||
| assert zero_point.shape == scale.shape | ||
| assert zero_point.dtype in int_types | ||
| if validate_inputs: | ||
| assert zero_point.min().item() >= qmin | ||
| assert zero_point.max().item() <= qmax | ||
| has_weight_zeros = True | ||
| if zero_point is None: | ||
| has_weight_zeros = False | ||
| else: | ||
| zero_point_min = zero_point.min().item() | ||
| zero_point_max = zero_point.max().item() | ||
| if zero_point_min == 0 and zero_point_max == 0: | ||
| has_weight_zeros = False | ||
|
Comment on lines +434 to +437

nit: might be better to have a util to check for zero_point being all zero and use it everywhere to keep consistent
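A rough sketch of what the suggested helper could look like; the name `_is_all_zero` and its placement are hypothetical, not part of this PR.

```python
import torch


def _is_all_zero(zero_point: torch.Tensor) -> bool:
    """Hypothetical helper: True if every element of zero_point is exactly zero.

    Both from_plain and make_packed_linear_int8_dynamic_activation_intx_weight_tensor
    could call this instead of comparing min()/max() against 0 inline.
    """
    return bool(torch.count_nonzero(zero_point) == 0)


# Example:
zp = torch.zeros(4, 8, dtype=torch.int8)
assert _is_all_zero(zp)
assert not _is_all_zero(zp + 1)
```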
||
|
|
||
| # Check bias | ||
| has_bias = bias is not None | ||
| if has_bias: | ||
| assert bias.device == torch.device("cpu") | ||
| if bias.dtype != torch.float32: | ||
| logging.info(f"bias has dtype {bias.dtype}, converting to torch.float32") | ||
| bias = bias.to(torch.float32) | ||
| assert bias.shape == (n,) | ||
|
|
||
| layout.set_params(bit_width, group_size, has_weight_zeros, has_bias) | ||
| assert layout.has_params_set() | ||
| tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain( | ||
| int_data, scale, zero_point, layout, bias | ||
| int_data, | ||
| scale, | ||
| zero_point, | ||
| layout, | ||
| bias, | ||
| validate_inputs=validate_inputs, | ||
| ) | ||
|
|
||
| return AffineQuantizedTensor( | ||
|
|
@@ -439,7 +455,5 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor( | |
| shape=int_data.shape, | ||
| quant_min=qmin, | ||
| quant_max=qmax, | ||
| zero_point_domain=ZeroPointDomain.INT | ||
| if has_weight_zeros | ||
| else ZeroPointDomain.NONE, | ||
| zero_point_domain=ZeroPointDomain.INT, | ||
| ) | ||
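To make the shape bookkeeping in the hunks above concrete, here is a small worked example of the relationship between int_data, scale, and group_size that the asserts enforce; the dimensions are illustrative only.

```python
import torch

n, k = 32, 256                     # weight shape: (out_features, in_features)
group_size = 64                    # quantization group size along k
groups_per_k = k // group_size     # 4 groups per row

# One float32 scale per group; scales must be strictly positive.
scale = torch.rand(n, groups_per_k, dtype=torch.float32) + 0.1
int_data = torch.randint(-8, 8, (n, k))  # toy int4 values in [qmin, qmax]

# The invariants checked in the code above:
assert k % groups_per_k == 0                           # group_size = k // groups_per_k
assert scale.numel() * group_size == int_data.numel()  # one scale per group
assert scale.min().item() > 0
```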
why is this needed? I think ideally this should be enforced by construction in choose qparams
You do not need to go through the quant config to construct a packed tensor. You can construct one directly with from_plain.
When going through the quant config, it passes validate_inputs=False to from_plain because the inputs are already valid by construction. But if someone calls from_plain outside of that path, I wanted the inputs to be validated.
OK I see
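To illustrate the two construction paths described above, here is a rough sketch of calling from_plain directly with validation enabled. Shapes and values are illustrative, the import path is an assumption about where these classes live, and actually running it requires the torchao CPU kernels to be built and loaded.

```python
import torch

# NOTE: this import path is an assumption; adjust it to wherever the layout
# classes live in your torchao checkout.
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl,
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

n, k, group_size, bit_width = 32, 256, 64, 4
int_data = torch.randint(-8, 8, (n, k), dtype=torch.int8)
scale = torch.rand(n, k // group_size, dtype=torch.float32) + 0.1
zero_point = torch.ones(n, k // group_size, dtype=torch.int8)  # not all zero

layout = PackedLinearInt8DynamicActivationIntxWeightLayout()
layout.set_params(bit_width, group_size, True, False)  # has_weight_zeros, no bias

# Direct construction keeps the default validate_inputs=True, so the
# qmin/qmax, shape, and consistency checks in from_plain all run.
impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
    int_data, scale, zero_point, layout, bias=None
)

# The quant-config path instead goes through
# make_packed_linear_int8_dynamic_activation_intx_weight_tensor(...), which
# passes validate_inputs=False because its inputs come from choose qparams /
# quantize and are already valid by construction.
```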