
Merged
24 changes: 15 additions & 9 deletions benchmarks/microbenchmarks/utils.py
@@ -255,24 +255,30 @@ def string_to_config(
group_size = int(_quant_args[2])
return UIntXWeightOnlyConfig(dtype, group_size, use_hqq=use_hqq)
elif "int8_dynamic_activation_intx_weight" in quantization:
from torchao.experimental.quant_api import (
Int8DynamicActivationIntxWeightConfig,
)
from torchao.quantization.granularity import PerGroup

assert (
high_precision_dtype == torch.float32
), "int8_dynamic_activation_intx_weight requires using high_precision_dtype=torch.float32"

from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
)

# Quantize model
_quant_args = quantization.split("-")
weight_dtype = getattr(torch, f"int{_quant_args[1]}")
granularity = PerGroup(int(_quant_args[2]))
has_weight_zeros = bool(_quant_args[3])
group_size = int(_quant_args[2])
granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
is_asymmetric = bool(_quant_args[3])
return Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
granularity=granularity,
has_weight_zeros=has_weight_zeros,
weight_granularity=granularity,
weight_mapping_type=MappingType.ASYMMETRIC
if is_asymmetric
else MappingType.SYMMETRIC,
weight_scale_dtype=torch.bfloat16,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)
elif "float8wo" in quantization:
return Float8WeightOnlyConfig()
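
As a reading aid for the hunk above, a hedged sketch of how the hyphen-separated quantization string decomposes into the new config. The parse_intx_quant helper is illustrative and not part of torchao; the config, granularity, and layout imports are the ones used in this hunk, while the MappingType import path is an assumption.

# Hedged sketch (not part of the PR): mirrors the string parsing in string_to_config above.
# Format: "int8_dynamic_activation_intx_weight-<bits>-<group_size>-<asym_flag>"
import torch
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.quant_primitives import MappingType  # assumed import path


def parse_intx_quant(quantization: str) -> Int8DynamicActivationIntxWeightConfig:
    """Illustrative helper mirroring the parsing shown in the hunk above."""
    _quant_args = quantization.split("-")
    weight_dtype = getattr(torch, f"int{_quant_args[1]}")  # e.g. torch.int4
    group_size = int(_quant_args[2])
    # group_size 0 means per-channel (PerAxis) rather than grouped quantization
    granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
    is_asymmetric = bool(_quant_args[3])
    return Int8DynamicActivationIntxWeightConfig(
        weight_dtype=weight_dtype,
        weight_granularity=granularity,
        weight_mapping_type=MappingType.ASYMMETRIC
        if is_asymmetric
        else MappingType.SYMMETRIC,
        weight_scale_dtype=torch.bfloat16,
        layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
    )


config = parse_intx_quant("int8_dynamic_activation_intx_weight-4-32-1")
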
10 changes: 4 additions & 6 deletions torchao/_models/llama/generate.py
@@ -568,24 +568,22 @@ def ffn_or_attn_only(mod, fqn):
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
ZeroPointDomain,
)

# Quantize model
_quant_args = quantization.split("-")
weight_dtype = getattr(torch, f"int{_quant_args[1]}")
group_size = int(_quant_args[2])
granularity = PerGroup(group_size) if group_size > 0 else PerAxis(0)
has_weight_zeros = bool(_quant_args[3])
is_asymmetric = bool(_quant_args[3])
quantize_(
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=weight_dtype,
weight_granularity=granularity,
weight_zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
weight_mapping_type=MappingType.ASYMMETRIC,
weight_mapping_type=MappingType.ASYMMETRIC
if is_asymmetric
else MappingType.SYMMETRIC,
weight_scale_dtype=torch.bfloat16,
layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
),
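
For completeness, a minimal usage sketch of applying such a config the way generate.py does, via torchao's quantize_ entry point. The config kwargs mirror the hunk above; the toy model and the Linear filter are placeholders (generate.py restricts quantization to FFN/attention modules via ffn_or_attn_only, whose call tail is folded away above), and actually packing the weights requires the torchao CPU kernels to be built.

# Hedged usage sketch (not part of the PR) mirroring the generate.py hunk above.
import torch
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.quantization.quant_primitives import MappingType  # assumed import path

# Toy float32 model standing in for the llama model used in generate.py.
model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.Linear(256, 64))

config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,          # "...-4-..." in the quantization string
    weight_granularity=PerGroup(32),  # group_size 32; PerAxis(0) when group_size == 0
    weight_mapping_type=MappingType.SYMMETRIC,
    weight_scale_dtype=torch.bfloat16,
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(),
)

# generate.py passes ffn_or_attn_only as the filter; a simple Linear filter stands in here.
quantize_(model, config, filter_fn=lambda mod, fqn: isinstance(mod, torch.nn.Linear))
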
@@ -156,48 +156,109 @@ def from_plain(
zero_point: Optional[torch.Tensor],
layout: Layout,
bias: Optional[torch.Tensor] = None,
*,
validate_inputs: bool = True,
jerryzh168 (Contributor), Apr 15, 2025:
why is this needed? I think ideally this should be enforced by construction in choose_qparams

Contributor Author:
You do not need to go through the quant config to construct a packed tensor; you can construct one directly with from_plain.

When going through the quant config, it sets validate_inputs=False in from_plain because the inputs are already enforced by construction. But if someone calls from_plain outside of that path, I wanted the inputs to be validated. (A sketch of such a direct from_plain call appears at the end of this diff.)

jerryzh168 (Contributor):
OK, I see.

):
assert isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"
assert layout.target in [
t for t, _ in _TARGET_AND_STR
], f"Unexpected target: {layout.target}"
assert layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params must be set before calling from_plain"

if layout.target != Target.ATEN:
_check_torchao_ops_loaded()
else:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target is requires torch version > 2.6.0"
assert (
torch.backends.kleidiai.is_available()
), "ATEN target requires torch.backends.kleidiai.is_available()"
assert layout.bit_width == 4, "ATEN target only supports torch.int4"
assert not layout.has_weight_zeros, "ATEN target does not support zeros"

data_dtype = getattr(torch, f"int{layout.bit_width}")
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]

int_types = [torch.int8, torch.int16, torch.int32, torch.int64]

# Check int_data
assert int_data.device == torch.device("cpu")
assert int_data.dtype in int_types
n, k = int_data.shape
assert int_data.dtype in int_types, f"int_data.dtype must be {int_types}"
assert k % layout.group_size == 0, "k must be divisible by group_size"
if validate_inputs:
assert int_data.min().item() >= qmin
assert int_data.max().item() <= qmax
int_data = int_data.to(torch.int8)

assert scale.dtype == torch.float32, "scale must be float32"
# Check scale
assert scale.device == torch.device("cpu")
if scale.dtype != torch.float32:
logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
scale = scale.to(torch.float32)
n_, _ = scale.shape
assert n_ == n
assert (
scale.numel() * layout.group_size == int_data.numel()
), "must have 1 scale per group"

assert (zero_point is not None) == (
layout.has_weight_zeros
), "zero_point being None must be consistent with layout.has_weight_zeros"
if zero_point is not None:
if validate_inputs:
assert scale.min().item() > 0
# Some targets round scales to bfloat16, give warning if scales are at higher precision
scale_is_rounded_to_bf16 = torch.allclose(
scale, scale.to(torch.bfloat16).to(torch.float32)
)
if not scale_is_rounded_to_bf16:
if layout.target == Target.ATEN and (layout.group_size < k):
logging.warning(
"When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
)
if layout.target in [Target.AUTO, Target.KLEIDIAI]:
logging.warning(
"When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
)

# Check zero_point
if zero_point is None:
assert (
zero_point.dtype in int_types
), f"zero_point.dtype must be {int_types}"
not layout.has_weight_zeros
), "zero_point must be provided if has_weight_zeros=True"
else:
assert zero_point.device == torch.device("cpu")
assert zero_point.shape == scale.shape
assert zero_point.dtype in int_types
assert (
zero_point.numel() * layout.group_size == int_data.numel()
), "must have 1 zero_point per group"
if validate_inputs:
zero_point_min = zero_point.min().item()
zero_point_max = zero_point.max().item()
assert zero_point.min().item() >= qmin
assert zero_point.max().item() <= qmax
has_weight_zeros = True
if zero_point_min == 0 and zero_point_max == 0:
has_weight_zeros = False
assert (
has_weight_zeros == layout.has_weight_zeros
), "zero_point being all zeros must be consistent with layout.has_weight_zeros"
zero_point = zero_point.to(torch.int8)

assert (bias is not None) == (
layout.has_bias
# Check bias
has_bias = bias is not None
assert (
has_bias == layout.has_bias
), "bias being None must be consistent with layout.has_bias"
if bias is not None:
assert bias.dtype == torch.float32, "bias.dtype must be float32"
assert bias.shape == (n,), "bias must have shape n"
if has_bias:
assert bias.device == torch.device("cpu")
if bias.dtype != torch.float32:
logging.info(
f"bias has dtype {bias.dtype}, converting to torch.float32"
)
bias = bias.to(torch.float32)
assert bias.shape == (n,)

# Construct packed_weight
if layout.target == Target.ATEN:
assert (
TORCH_VERSION_AT_LEAST_2_6
), "aten target is requires torch version > 2.6.0"
int_data = int_data.add(8)
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
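
The two lines above pack each pair of signed int4 columns into a single uint8 for the ATEN target. Below is a hedged, standalone round-trip sketch of the same transform on toy shapes; it is not part of the PR.

# Round-trip of the int4 nibble packing used above: shift [-8, 7] to [0, 15],
# then store odd columns in the high nibble and even columns in the low nibble.
import torch

w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # toy int4-valued weight; k must be even

shifted = w.to(torch.int16).add(8)  # widen before shifting to avoid int8 overflow
packed = (shifted[:, 1::2] << 4 | shifted[:, ::2]).to(torch.uint8)

# Unpack: the low nibble restores even columns, the high nibble restores odd columns.
restored = torch.empty_like(w)
restored[:, ::2] = (packed & 0x0F).to(torch.int8) - 8
restored[:, 1::2] = (packed >> 4).to(torch.int8) - 8

assert torch.equal(w, restored)
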

@@ -213,12 +274,11 @@ def from_plain(
args = [
int_data,
scale.reshape(-1),
zero_point.reshape(-1) if zero_point is not None else None,
zero_point.reshape(-1) if layout.has_weight_zeros else None,
layout.group_size,
bias,
target_to_str(layout.target) if layout.target != Target.AUTO else None,
]

packed_weight = getattr(
torch.ops.torchao,
f"_pack_8bit_act_{layout.bit_width}bit_weight",
@@ -358,79 +418,35 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
assert TORCH_VERSION_AT_LEAST_2_6, "Using PackedLinearInt8DynamicActivationIntxWeightLayout requires torch version > 2.6.0"

layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target=target)
if layout.target != Target.ATEN:
_check_torchao_ops_loaded()
else:
assert (
torch.backends.kleidiai.is_available()
), "ATEN target requires torch.backends.kleidiai.is_available()"
assert data_dtype == torch.int4, "ATEN target only supports torch.int4"
assert zero_point is None, "ATEN target does not support zeros"

assert data_dtype in [getattr(torch, f"int{x}") for x in range(1, 9)]
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]
bit_width = _DTYPE_TO_BIT_WIDTH[data_dtype]
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[data_dtype]

int_types = [torch.int8, torch.int16, torch.int32, torch.int64]

# Check int_data
assert int_data.device == torch.device("cpu")
assert int_data.dtype in int_types
n, k = int_data.shape
if validate_inputs:
assert int_data.min().item() >= qmin
assert int_data.max().item() <= qmax

# Check scale
assert scale.device == torch.device("cpu")
if scale.dtype != torch.float32:
logging.info(f"scale has dtype {scale.dtype}, converting to torch.float32")
scale = scale.to(torch.float32)
n_, groups_per_k = scale.shape
assert n_ == n
assert k % groups_per_k == 0
group_size = k // groups_per_k
if validate_inputs:
assert scale.min().item() > 0

if validate_inputs:
# Some targets round scales to bfloat16, give warning if scales are at higher precision
scale_is_rounded_to_bf16 = torch.allclose(
scale, scale.to(torch.bfloat16).to(torch.float32)
)
if not scale_is_rounded_to_bf16:
if layout.target == Target.ATEN and (group_size < k):
logging.warning(
"When using Target.ATEN with group_size < k, scales will be rounded to bfloat16"
)
if layout.target in [Target.AUTO, Target.KLEIDIAI]:
logging.warning(
"When using [Target.AUTO, Target.KLEIDIAI], scales will be rounded to bfloat16"
)

# Check zero_point
has_weight_zeros = zero_point is not None
if has_weight_zeros:
assert zero_point.device == torch.device("cpu")
assert zero_point.shape == scale.shape
assert zero_point.dtype in int_types
if validate_inputs:
assert zero_point.min().item() >= qmin
assert zero_point.max().item() <= qmax
has_weight_zeros = True
if zero_point is None:
has_weight_zeros = False
else:
zero_point_min = zero_point.min().item()
zero_point_max = zero_point.max().item()
if zero_point_min == 0 and zero_point_max == 0:
has_weight_zeros = False
Comment on lines +434 to +437, Reviewer (Contributor):
nit: might be better to have a util to check for zero_point being all zeros and use it everywhere, to keep it consistent (a sketch of such a helper is included in the example at the end of this diff).


# Check bias
has_bias = bias is not None
if has_bias:
assert bias.device == torch.device("cpu")
if bias.dtype != torch.float32:
logging.info(f"bias has dtype {bias.dtype}, converting to torch.float32")
bias = bias.to(torch.float32)
assert bias.shape == (n,)

layout.set_params(bit_width, group_size, has_weight_zeros, has_bias)
assert layout.has_params_set()
tensor_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
int_data, scale, zero_point, layout, bias
int_data,
scale,
zero_point,
layout,
bias,
validate_inputs=validate_inputs,
)

return AffineQuantizedTensor(
@@ -439,7 +455,5 @@ def make_packed_linear_int8_dynamic_activation_intx_weight_tensor(
shape=int_data.shape,
quant_min=qmin,
quant_max=qmax,
zero_point_domain=ZeroPointDomain.INT
if has_weight_zeros
else ZeroPointDomain.NONE,
zero_point_domain=ZeroPointDomain.INT,
)
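
Tying together the two review threads above, here is a hedged, standalone sketch of constructing a packed weight directly through from_plain with validate_inputs=True (the path described by the author), including the small "zero_point is all zeros" helper suggested in the nit. The layout import, the positional set_params call, and the from_plain call pattern are taken from this diff; the tensor-impl import path, the default layout target, and the helper name are assumptions, and actually running this requires the torchao CPU kernels to be built.

# Hedged sketch, not part of the PR: build a packed weight directly via from_plain.
import torch
from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout
# Assumed module path for the tensor impl class defined in this file:
from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl,
)


def zero_point_is_all_zeros(zero_point):
    # Small helper along the lines of the nit above (the name is illustrative).
    return zero_point is None or (
        zero_point.min().item() == 0 and zero_point.max().item() == 0
    )


n, k, group_size, bit_width = 128, 256, 32, 4
qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1  # [-8, 7] for int4

int_data = torch.randint(qmin, qmax + 1, (n, k), dtype=torch.int8)
# Pre-round scales to bfloat16 so the bf16-rounding warning in from_plain stays quiet.
scale = (torch.rand(n, k // group_size) + 0.01).to(torch.bfloat16).to(torch.float32)
zero_point = None  # symmetric weights
bias = None

has_weight_zeros = not zero_point_is_all_zeros(zero_point)
has_bias = bias is not None

# Default target is assumed to be AUTO; set_params is called positionally as in the diff.
layout = PackedLinearInt8DynamicActivationIntxWeightLayout()
layout.set_params(bit_width, group_size, has_weight_zeros, has_bias)

# validate_inputs=True because this bypasses the quant config's by-construction checks.
packed_impl = PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl.from_plain(
    int_data, scale, zero_point, layout, bias, validate_inputs=True
)
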