Commit 797c7e2

blzheng authored and pytorchmergebot committed
[Quant][PT2E]change flatten recipe for X86InductorQuantizer (pytorch#136298)
This PR modifies the flatten recipe: if none of the users of the flatten node are quantizable ops, int8 flatten will be disabled to avoid unnecessary dtype conversions.

Pull Request resolved: pytorch#136298
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
1 parent 3be1506 commit 797c7e2
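
For context, here is a minimal end-to-end sketch (not part of this commit) of the scenario the new recipe targets: a conv -> flatten -> transpose model quantized with X86InductorQuantizer through the PT2E flow. The model and quantizer setup mirror the new tests in this diff; the export entry point (export_for_training here) and the calibration step are assumptions and vary across PyTorch versions.

import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer


class ConvFlattenTranspose(torch.nn.Module):
    """flatten's only user is transpose, which is not a quantizable op."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))

    def forward(self, x):
        return self.conv(x).flatten(2).transpose(1, 2)


m = ConvFlattenTranspose().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)
# Export entry point is an assumption; older releases use capture_pre_autograd_graph.
exported = torch.export.export_for_training(m, example_inputs).module()
quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration pass
converted = convert_pt2e(prepared)
# With this recipe change, no quantize/dequantize pair is inserted around flatten's
# output, since none of its users (only transpose) is a quantizable op.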

File tree

3 files changed: +128 -10 lines changed

test/inductor/test_mkldnn_pattern_matcher.py
test/quantization/pt2e/test_x86inductor_quantizer.py
torch/ao/quantization/quantizer/x86_inductor_quantizer.py

test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 7 additions & 3 deletions
@@ -2084,7 +2084,7 @@ def matcher_check_fn():
     @skipIfNoDynamoSupport
     def test_qflatten(self):
         r"""
-        This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten pattern.
+        This testcase will quantize Conv2d->AdaptiveAvgPool2d->flatten->cat pattern.
         """
 
         class M(torch.nn.Module):
@@ -2099,8 +2099,12 @@ def __init__(
                 self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
 
             def forward(self, x):
-                return torch.flatten(
-                    self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1
+                return torch.cat(
+                    [
+                        torch.flatten(
+                            self.adaptive_avg_pool2d(self.relu(self.conv(x))), 1
+                        )
+                    ]
                 )
 
         mod = M().eval()

test/quantization/pt2e/test_x86inductor_quantizer.py

Lines changed: 112 additions & 7 deletions
@@ -534,6 +534,29 @@ def forward(self, x):
             weighted = torch.matmul(attention, v)
             return weighted
 
+    class Conv2dFlattenTranspose(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.projection = torch.nn.Conv2d(
+                3, 768, kernel_size=(16, 16), stride=(16, 16)
+            )
+            self.cls_token = torch.rand(1, 1, 768)
+
+        def forward(self, pixel_values):
+            embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+            embeddings = torch.cat((self.cls_token, embeddings), dim=1)
+            return embeddings
+
+    class Conv2dFlattenCatTranspose(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
+
+        def forward(self, x):
+            y = self.conv(x).flatten(2)
+            y = torch.cat([y, y], dim=-1)
+            return y.transpose(1, 2)
+
 
 class X86InductorQuantTestCase(QuantizationTestCase):
     def _test_quantizer(
@@ -944,15 +967,97 @@ def test_adaptive_avg_pool2d_recipe(self):
     @skipIfNoX86
     def test_flatten_recipe(self):
         r"""
-        Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow)
-        Since flatten is a int8_in_int8_out_op, there is obs between flatten and pow.
+        Test pattern: conv -> flatten -> cat -> transpose
         """
-        self._single_op_share_observer_recipe_test_helper(
-            TestHelperModules.Conv2dSingleOpPowModule(
-                lambda x: torch.flatten(x, 1)
-            ).eval(),
-            torch.rand(1, 2, 14, 14),
+        m = TestHelperModules.Conv2dFlattenCatTranspose().eval()
+        x = torch.randn(1, 3, 224, 224)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (x,)
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            # quantize_per_channel for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
             torch.ops.aten.flatten.using_ints,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.cat.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        _, prepare_model, _ = self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+        # Check Flatten has share observer at input and output
+        for node in prepare_model.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.flatten.using_ints
+            ):
+                single_op_node = node
+                input_obs_of_single_op = getattr(
+                    prepare_model, single_op_node.args[0].target
+                )
+                output_obs_of_single_op = getattr(
+                    prepare_model, next(iter(single_op_node.users)).target
+                )
+            elif (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.conv2d.default
+            ):
+                conv_node = node
+                input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
+        self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase))
+        self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase))
+        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
+        self.assertTrue(input_obs_of_single_op is output_obs_of_single_op)
+        self.assertTrue(input_obs_of_single_op is not input_obs_of_conv)
+
+    @skipIfNoX86
+    def test_flatten_recipe2(self):
+        r"""
+        Test pattern: conv -> flatten -> transpose
+        """
+        m = TestHelperModules.Conv2dFlattenTranspose().eval()
+        x = torch.randn(1, 3, 224, 224)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (x,)
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+            # quantize_per_channel for weights are const propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.flatten.using_ints,
+            torch.ops.aten.transpose.int,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
         )
 
     @skipIfNoX86

torch/ao/quantization/quantizer/x86_inductor_quantizer.py

Lines changed: 9 additions & 0 deletions
@@ -1379,6 +1379,15 @@ def is_all_inputs_connected_to_quantized_op(input_nodes):
                 if not is_all_inputs_connected_to_quantized_op(input_nodes_to_check):
                     return
                 self._annotate_cat(node, quantization_config)
+            elif (
+                node.target is torch.ops.aten.flatten.using_ints
+                and len(node.users) > 0
+                and not any(
+                    user.target in quantizable_ops for user in node.users.keys()
+                )
+            ):
+                # Recipe of flatten: check if any users of flatten node are quantizable ops or not
+                return
             else:
                 input_node = node.all_input_nodes[0]
                 if not is_all_inputs_connected_to_quantized_op(
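For readers skimming the diff above, here is a hypothetical standalone paraphrase of the new condition (the helper name is mine, not from the PR): annotation of a flatten node is skipped exactly when the node has users but none of them is a quantizable op.

def flatten_int8_is_disabled(node, quantizable_ops) -> bool:
    # Mirrors the new elif branch: flatten stays fp32 when it has at least one
    # user and none of those users is a quantizable op.
    # node: a torch.fx.Node targeting aten.flatten.using_ints;
    # quantizable_ops: the quantizer's set of quantizable op overloads.
    return len(node.users) > 0 and not any(
        user.target in quantizable_ops for user in node.users
    )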
