@@ -534,6 +534,29 @@ def forward(self, x):
             weighted = torch.matmul(attention, v)
             return weighted
 
+    class Conv2dFlattenTranspose(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.projection = torch.nn.Conv2d(
+                3, 768, kernel_size=(16, 16), stride=(16, 16)
+            )
+            self.cls_token = torch.rand(1, 1, 768)
+
+        def forward(self, pixel_values):
+            embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
+            embeddings = torch.cat((self.cls_token, embeddings), dim=1)
+            return embeddings
+
+    class Conv2dFlattenCatTranspose(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = torch.nn.Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
+
+        def forward(self, x):
+            y = self.conv(x).flatten(2)
+            y = torch.cat([y, y], dim=-1)
+            return y.transpose(1, 2)
+
 
 class X86InductorQuantTestCase(QuantizationTestCase):
     def _test_quantizer(
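
The two helpers mimic a ViT-style patch embedding that prepends a CLS token (`Conv2dFlattenTranspose`) and a `flatten -> cat -> transpose` chain (`Conv2dFlattenCatTranspose`). A minimal shape sanity check, assuming the two class definitions from the hunk above are pasted into a scratch script:

```python
import torch
import torch.nn as nn  # needed by the pasted class definitions

x = torch.randn(1, 3, 224, 224)

# 224 / 16 = 14 patches per side -> 196 tokens; the CLS token makes 197.
out1 = Conv2dFlattenTranspose()(x)
assert out1.shape == (1, 197, 768)

# cat along the last dim doubles the 196 flattened positions to 392.
out2 = Conv2dFlattenCatTranspose()(x)
assert out2.shape == (1, 392, 768)
```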
@@ -944,15 +967,97 @@ def test_adaptive_avg_pool2d_recipe(self):
     @skipIfNoX86
     def test_flatten_recipe(self):
         r"""
-        Test pattern: int8_in_int8_out_ops(flatten) - non_quantizable op(pow)
-        Since flatten is a int8_in_int8_out_op, there is obs between flatten and pow.
+        Test pattern: conv -> flatten -> cat -> transpose
         """
-        self._single_op_share_observer_recipe_test_helper(
-            TestHelperModules.Conv2dSingleOpPowModule(
-                lambda x: torch.flatten(x, 1)
-            ).eval(),
-            torch.rand(1, 2, 14, 14),
+        m = TestHelperModules.Conv2dFlattenCatTranspose().eval()
+        x = torch.randn(1, 3, 224, 224)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (x,)
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 4,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 4,
+            # quantize_per_channel nodes for weights are const-propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
             torch.ops.aten.flatten.using_ints,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.cat.default,
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        ]
+        _, prepare_model, _ = self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
+        )
+        # Check that flatten shares one observer between its input and output
+        for node in prepare_model.graph.nodes:
+            if (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.flatten.using_ints
+            ):
+                single_op_node = node
+                input_obs_of_single_op = getattr(
+                    prepare_model, single_op_node.args[0].target
+                )
+                output_obs_of_single_op = getattr(
+                    prepare_model, next(iter(single_op_node.users)).target
+                )
+            elif (
+                node.op == "call_function"
+                and node.target is torch.ops.aten.conv2d.default
+            ):
+                conv_node = node
+                input_obs_of_conv = getattr(prepare_model, conv_node.args[0].target)
+        self.assertTrue(isinstance(input_obs_of_single_op, ObserverBase))
+        self.assertTrue(isinstance(output_obs_of_single_op, ObserverBase))
+        self.assertTrue(isinstance(input_obs_of_conv, ObserverBase))
+        self.assertTrue(input_obs_of_single_op is output_obs_of_single_op)
+        self.assertTrue(input_obs_of_single_op is not input_obs_of_conv)
+
+    @skipIfNoX86
+    def test_flatten_recipe2(self):
+        r"""
+        Test pattern: conv -> flatten -> transpose
+        """
+        m = TestHelperModules.Conv2dFlattenTranspose().eval()
+        x = torch.randn(1, 3, 224, 224)
+        quantizer = X86InductorQuantizer().set_global(
+            xiq.get_default_x86_inductor_quantization_config()
+        )
+        example_inputs = (x,)
+        node_occurrence = {
+            torch.ops.quantized_decomposed.quantize_per_tensor.default: 1,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default: 1,
+            # quantize_per_channel nodes for weights are const-propagated
+            torch.ops.quantized_decomposed.quantize_per_channel.default: 0,
+            torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
+        }
+        node_list = [
+            torch.ops.quantized_decomposed.quantize_per_tensor.default,
+            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.flatten.using_ints,
+            torch.ops.aten.transpose.int,
+        ]
+        self._test_quantizer(
+            m,
+            example_inputs,
+            quantizer,
+            node_occurrence,
+            node_list,
         )
 
     @skipIfNoX86
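
For context, a minimal sketch of the PT2E flow these tests drive, outside the test harness. The capture API name varies across PyTorch releases (`capture_pre_autograd_graph` is assumed here), and `TestHelperModules.Conv2dFlattenCatTranspose` is only in scope inside this test file:

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
)

# Assumes the helper class from the first hunk is in scope.
m = TestHelperModules.Conv2dFlattenCatTranspose().eval()
example_inputs = (torch.randn(1, 3, 224, 224),)

quantizer = X86InductorQuantizer().set_global(
    xiq.get_default_x86_inductor_quantization_config()
)

# Trace, insert observers, run a calibration pass, then convert.
exported = capture_pre_autograd_graph(m, example_inputs)
prepared = prepare_pt2e(exported, quantizer)
prepared(*example_inputs)  # calibration: observers record min/max
converted = convert_pt2e(prepared)  # graph now holds the q/dq nodes counted above
```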