Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions test/quantization/test_quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,18 @@ def test_ao_per_module_config_embedding_linear(self):
assert isinstance(model.emb.weight._layout, QDQLayout)
assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_ao_per_module_config_skip(self):
    """A ``None`` entry in AOPerModuleConfig leaves that module unquantized.

    ``linear1`` has no explicit entry and so picks up the ``"_default"``
    int4 weight-only config; ``linear2`` is mapped to ``None`` and must keep
    a non-quantized weight.
    """
    int4_default = Int4WeightOnlyConfig(group_size=32)
    per_module = AOPerModuleConfig({"_default": int4_default, "linear2": None})

    toy = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
    inputs = toy.example_inputs(device="cuda", dtype=torch.bfloat16)
    quantize_(toy, per_module)
    # Run a forward pass so the mixed quantized / skipped model is exercised.
    toy(*inputs)

    quantized_weight = toy.linear1.weight
    assert isinstance(quantized_weight, AffineQuantizedTensor)
    assert isinstance(quantized_weight._layout, TensorCoreTiledLayout)
    assert not isinstance(toy.linear2.weight, AffineQuantizedTensor)


class TestMultiTensorFlow(TestCase):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
Expand Down
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
PlainLayout,
TensorCoreTiledLayout,
UIntXWeightOnlyConfig,
Expand Down Expand Up @@ -139,6 +140,7 @@
"Float8StaticActivationFloat8WeightConfig",
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
"UIntXWeightOnlyConfig",
"IntxWeightOnlyConfig",
"FPXWeightOnlyConfig",
"GemliteUIntXWeightOnlyConfig",
"AOPerModuleConfig",
Expand Down
16 changes: 9 additions & 7 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ def quantize_(

"""
filter_fn = _is_linear if filter_fn is None else filter_fn

if isinstance(config, AOPerModuleConfig):
_replace_with_custom_fn_if_matches_filter_with_name(
model,
Expand Down Expand Up @@ -1975,18 +1976,19 @@ class AOPerModuleConfig(AOBaseConfig):
def _ao_per_module_config_handler(
    module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
):
    """Apply the config that ``config`` assigns to ``module_fqn`` to ``module``.

    Lookup order:

    1. If ``module_fqn`` is a key of ``config.module_fqn_to_config``, that
       value is used as-is. A value of ``None`` means "skip this module" and
       takes precedence over any ``"_default"`` entry.
    2. Otherwise the ``"_default"`` entry is used, if one exists.

    Args:
        module: the module that may be quantized.
        module_fqn: the module's name, used as the lookup key into
            ``config.module_fqn_to_config``.
        config: per-module mapping from module names to quantization configs.

    Returns:
        Whatever the handler registered for the selected config's type
        returns, or ``module`` unchanged when no config applies.

    Raises:
        Whatever indexing ``_QUANTIZE_CONFIG_HANDLER`` raises when no handler
        is registered for the selected config's type (presumably KeyError if
        it is a dict -- confirm against its definition).
    """
    fqn_to_config = config.module_fqn_to_config
    if module_fqn in fqn_to_config:
        # An explicit entry wins even when it is ``None``: that is how callers
        # opt a single module out of the ``"_default"`` config.
        # Maybe: we can add module type specific config in the future, if needed
        c = fqn_to_config[module_fqn]
    else:
        # Fall back to the default if no module-specific config is provided.
        c = fqn_to_config.get("_default", None)

    if c is None:
        # Nothing to apply: return the module untouched instead of reaching
        # an unbound ``handler``.
        return module

    handler = _QUANTIZE_CONFIG_HANDLER[type(c)]
    return handler(module, c)


if TORCH_VERSION_AT_LEAST_2_5:
Expand Down
Loading