1 change: 1 addition & 0 deletions docs/source/api_ref_dtypes.rst
@@ -14,6 +14,7 @@ torchao.dtypes
to_affine_quantized_intx
to_affine_quantized_floatx
to_affine_quantized_intx_static
to_affine_quantized_floatx_static
AffineQuantizedTensor

..
2 changes: 2 additions & 0 deletions torchao/dtypes/__init__.py
@@ -8,6 +8,7 @@
# experimental, will be merged into floatx in the future
to_affine_quantized_fpx,
to_affine_quantized_floatx,
to_affine_quantized_floatx_static,
LayoutType,
PlainLayoutType,
SemiSparseLayoutType,
@@ -25,6 +26,7 @@
"to_affine_quantized_intx_static",
"to_affine_quantized_fpx",
"to_affine_quantized_floatx",
"to_affine_quantized_floatx_static",
"LayoutType",
"PlainLayoutType",
"SemiSparseLayoutType",
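With the re-export above, the static floatx factory becomes part of the public `torchao.dtypes` namespace. A quick sanity check (a sketch; assumes a torchao build that includes this change):

```python
# Confirm the new export resolves from the public namespace and is the
# classmethod alias defined in affine_quantized_tensor.py below.
from torchao.dtypes import AffineQuantizedTensor, to_affine_quantized_floatx_static

assert to_affine_quantized_floatx_static == AffineQuantizedTensor.from_hp_to_floatx_static
print("to_affine_quantized_floatx_static exported")
```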
33 changes: 31 additions & 2 deletions torchao/dtypes/affine_quantized_tensor.py
@@ -269,14 +269,17 @@ def from_hp_to_intx_static(
cls,
input_float: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
zero_point: Optional[torch.Tensor],
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
layout_type: LayoutType = PlainLayoutType(),
):
if target_dtype not in FP8_TYPES:
assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types"
assert zero_point is not None, "zero_point must be specified for non-fp8 types"
original_shape = input_float.shape
input_float = layout_type.pre_process(input_float)

@@ -325,6 +328,31 @@ def from_hp_to_floatx(
else:
raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx")

@classmethod
def from_hp_to_floatx_static(
cls,
input_float: torch.Tensor,
scale: torch.Tensor,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
layout_type: LayoutType,
):

if target_dtype in FP8_TYPES:
return cls.from_hp_to_intx_static(
input_float=input_float,
scale=scale,
zero_point=None,
block_size=block_size,
target_dtype=target_dtype,
quant_min=math.ceil(torch.finfo(target_dtype).min),
quant_max=math.ceil(torch.finfo(target_dtype).max),
zero_point_domain=None,
layout_type=layout_type,
)
else:
raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static")

@classmethod
def from_hp_to_fpx(
cls,
@@ -1304,6 +1332,7 @@ def _(func, types, args, kwargs):
to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx
to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static
to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx
to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static
# experimental, will be merged into floatx
to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx

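The new `from_hp_to_floatx_static` classmethod takes a precomputed scale and, for fp8 dtypes, routes through `from_hp_to_intx_static` with `zero_point=None` and `zero_point_domain=None`. A minimal usage sketch of the exported alias (assumes a CUDA device and a torchao build with this change; the abs-max scale here is illustrative — the tutorial below derives scales from calibration observers and uses per-row blocks for weights rather than the per-tensor block shown here):

```python
import torch
from torchao.dtypes import to_affine_quantized_floatx_static, Float8LayoutType

weight = torch.randn(32, 64, dtype=torch.bfloat16, device="cuda")

# Illustrative static per-tensor scale: map the observed abs-max onto the fp8 range.
scale = (weight.abs().amax() / torch.finfo(torch.float8_e4m3fn).max).to(torch.float32)

qweight = to_affine_quantized_floatx_static(
    weight,                            # high-precision input
    scale,                             # precomputed (static) scale; no zero_point for fp8
    weight.shape,                      # per-tensor block size
    torch.float8_e4m3fn,               # fp8 target dtype
    Float8LayoutType(mm_config=None),  # layout used by the tutorial below
)
print(qweight.shape, qweight.dtype)
```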
2 changes: 0 additions & 2 deletions torchao/quantization/quant_api.py
@@ -28,8 +28,6 @@
PlainLayoutType,
AffineQuantizedTensor,
SemiSparseLayoutType,
to_affine_quantized_floatx,
Float8AQTLayout,
Float8LayoutType
)
from torchao.utils import (
2 changes: 2 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -676,6 +676,8 @@ def _choose_qparams_affine(
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC.name, MappingType.ASYMMETRIC.name], f"Unsupported mapping type: {mapping_type}"
if target_dtype in FP8_TYPES:
assert mapping_type == MappingType.SYMMETRIC.name, f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"

if input is not None:
if scale_dtype is None:
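The added check restricts fp8 qparam selection to symmetric mapping, which matches the static float8 path above passing no zero point. For intuition, a minimal sketch (not the library's internal implementation) of how a symmetric fp8 scale follows from a calibrated abs-max:

```python
import torch

def fp8_symmetric_scale(max_abs: torch.Tensor,
                        target_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    """Illustrative symmetric scale: map the calibrated abs-max onto the fp8 representable
    range. No zero point is produced, mirroring zero_point=None in the fp8 flow above."""
    fp8_max = torch.finfo(target_dtype).max  # 448.0 for float8_e4m3fn
    eps = torch.finfo(torch.float32).eps
    return (max_abs.to(torch.float32) / fp8_max).clamp(min=eps)

# e.g. a calibrated abs-max of 12.5 gives a scale of roughly 0.0279
print(fp8_symmetric_scale(torch.tensor(12.5)))
```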
165 changes: 103 additions & 62 deletions tutorials/calibration_flow/static_quant.py
@@ -6,7 +6,11 @@

import torch.nn.functional as F
from torch import Tensor
from torchao.dtypes import to_affine_quantized_intx_static
from torchao.dtypes import (
to_affine_quantized_intx_static,
to_affine_quantized_floatx_static,
Float8LayoutType,
)
from torchao.quantization.utils import compute_error
from torchao.quantization import quantize_
from torchao.quantization import to_linear_activation_quantized
@@ -18,6 +22,7 @@
)
from torchao.quantization.quant_primitives import (
MappingType,
FP8_TYPES,
)


@@ -51,53 +56,81 @@ def replacement_fn(m):

# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_static_quant(observed_linear):
target_dtype = torch.uint8

# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
def weight_quant_func(weight):
block_size = (1, weight.shape[1])
return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)

return linear

def apply_static_quant(target_dtype: torch.dtype):
# target_dtype = torch.uint8
def _apply_static_quant_to_linear(observed_linear):
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()
def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
elif target_dtype == torch.float8_e4m3fn:
return to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")
linear = torch.nn.Linear(observed_linear.in_features, observed_linear.out_features, False, device=observed_linear.weight.device, dtype=observed_linear.weight.dtype)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

linear.weight = torch.nn.Parameter(weight_quant_func(linear.weight), requires_grad=False)

# activation quantization
act_scale, act_zero_point = observed_linear.act_obs.calculate_qparams()
if target_dtype == torch.uint8:
input_quant_func = lambda x: to_affine_quantized_intx_static(x, act_scale, act_zero_point, x.shape, target_dtype)
elif target_dtype == torch.float8_e4m3fn:
input_quant_func = lambda x: to_affine_quantized_floatx_static(x, act_scale, x.shape, target_dtype, Float8LayoutType(mm_config=None))
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")
linear.weight = torch.nn.Parameter(to_linear_activation_quantized(linear.weight, input_quant_func), requires_grad=False)

return linear

return _apply_static_quant_to_linear

# alternative for converting observed linear module to quantized linear module
class QuantizedLinear(torch.nn.Module):
def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor):
def __init__(self, in_features: int, out_features: int, act_obs: torch.nn.Module, weight_obs: torch.nn.Module, weight: torch.Tensor, bias: torch.Tensor, target_dtype: torch.dtype):
super().__init__()
self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
weight_scale, weight_zero_point = weight_obs.calculate_qparams()
assert weight.dim() == 2
block_size = (1, weight.shape[1])
target_dtype = torch.uint8
self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, target_dtype)
self.target_dtype = target_dtype
self.bias = bias
if self.target_dtype == torch.uint8:
self.qweight = to_affine_quantized_intx_static(weight, weight_scale, weight_zero_point, block_size, self.target_dtype)
elif self.target_dtype == torch.float8_e4m3fn:
self.qweight = to_affine_quantized_floatx_static(weight, weight_scale, block_size, target_dtype, Float8LayoutType(mm_config=None))
else:
raise ValueError(f"Unsupported target dtype {self.target_dtype}")

def forward(self, input: Tensor):
block_size = input.shape
target_dtype = torch.uint8
qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, target_dtype)
if self.target_dtype == torch.uint8:
qinput = to_affine_quantized_intx_static(input, self.act_scale, self.act_zero_point, block_size, self.target_dtype)
elif self.target_dtype == torch.float8_e4m3fn:
qinput = to_affine_quantized_floatx_static(input, self.act_scale, block_size, self.target_dtype, Float8LayoutType(mm_config=None))
else:
raise ValueError(f"Unsupported target dtype {self.target_dtype}")
return F.linear(qinput, self.qweight, self.bias)

@classmethod
def from_observed(cls, observed_linear):
quantized_linear = cls(observed_linear.in_features, observed_linear.out_features, observed_linear.act_obs, observed_linear.weight_obs, observed_linear.weight, observed_linear.bias)
def from_observed(cls, observed_linear, target_dtype):
quantized_linear = cls(observed_linear.in_features,
observed_linear.out_features,
observed_linear.act_obs,
observed_linear.weight_obs,
observed_linear.weight,
observed_linear.bias,
target_dtype)
return quantized_linear

def apply_static_quant2(observed_linear):
return QuantizedLinear.from_observed(observed_linear)
def apply_static_quant2(target_dtype: torch.dtype):
def _apply_static_quant2(observed_linear):
return QuantizedLinear.from_observed(observed_linear, target_dtype)
return _apply_static_quant2

class ToyLinearModel(torch.nn.Module):
def __init__(self, m=64, n=32, k=64):
@@ -113,46 +146,54 @@ def forward(self, x):
x = self.linear2(x)
return x

torch.manual_seed(0)

dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")
def test_static_quant(target_dtype: torch.dtype, mapping_type: MappingType):
print(f"Testing {target_dtype} static quantization:")
torch.manual_seed(0)

dtype = torch.bfloat16
m = ToyLinearModel().eval().to(dtype).to("cuda")

m_for_test = copy.deepcopy(m)

m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
print("example inputs shape:", example_inputs[0].shape)

m_for_test = copy.deepcopy(m)
m_bf16 = torch.compile(m_bf16, mode='max-autotune')

m_bf16 = copy.deepcopy(m)
example_inputs = m.example_inputs(dtype=dtype, device="cuda")
print("example inputs shape:", example_inputs[0].shape)
act_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)
weight_obs = AffineQuantizedMinMaxObserver(mapping_type, target_dtype, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.float32)

m_bf16 = torch.compile(m_bf16, mode='max-autotune')
before_quant = m(*example_inputs)

act_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerTensor(), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
weight_obs = AffineQuantizedMinMaxObserver(MappingType.ASYMMETRIC, torch.uint8, granularity_type=PerAxis(axis=0), eps=torch.finfo(torch.float32).eps, scale_dtype=torch.float32, zero_point_dtype=torch.int32)
insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
m(*example_inputs)

before_quant = m(*example_inputs)
after_obs = m(*example_inputs)

insert_observers_(m, act_obs, weight_obs)
# calibrating / training
for _ in range(10):
m(*example_inputs)
m2 = copy.deepcopy(m)

after_obs = m(*example_inputs)
is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

m2 = copy.deepcopy(m)
# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant(target_dtype), is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
print("test passed")

is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)
# quantized linear as a standalone module
quantize_(m2, apply_static_quant2(target_dtype), is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
print("test passed")

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_static_quant, is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")

# quantized linear as a standalone module
quantize_(m2, apply_static_quant2, is_observed_linear)
print("quantized model (quantized module):", m2)
after_quant = m2(*example_inputs)
assert compute_error(before_quant, after_quant) > 30
print("test passed")
if __name__ == "__main__":
test_static_quant(torch.uint8, MappingType.ASYMMETRIC)
test_static_quant(torch.float8_e4m3fn, MappingType.SYMMETRIC)
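A note on the pass criterion used in `test_static_quant`: `compute_error` from `torchao.quantization.utils` returns a signal-to-quantization-noise ratio in dB (to the best of my understanding), so the `> 25` assertion requires the quantized model to track the bf16 reference within roughly 25 dB SQNR. A minimal sketch of that metric, under that assumption:

```python
import torch

def sqnr_db(reference: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    """Signal-to-quantization-noise ratio in dB; higher means the quantized output
    is closer to the high-precision reference."""
    signal = torch.linalg.vector_norm(reference.float())
    noise = torch.linalg.vector_norm((reference - quantized).float())
    return 20 * torch.log10(signal / noise)

# e.g. the tutorial asserts sqnr_db(before_quant, after_quant) > 25 for both uint8 and fp8 paths
```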