@@ -3035,3 +3035,216 @@ def test_cli_file_input(self):
3035
3035
self .assertEqual (expected .splitlines (),
3036
3036
res .out .decode ("utf8" ).splitlines ())
3037
3037
self .assertEqual (res .rc , 0 )
3038
+
3039
+
3040
+ class ASTOptimiziationTests (unittest .TestCase ):
3041
+ binop = {
3042
+ "+" : ast .Add (),
3043
+ "-" : ast .Sub (),
3044
+ "*" : ast .Mult (),
3045
+ "/" : ast .Div (),
3046
+ "%" : ast .Mod (),
3047
+ "<<" : ast .LShift (),
3048
+ ">>" : ast .RShift (),
3049
+ "|" : ast .BitOr (),
3050
+ "^" : ast .BitXor (),
3051
+ "&" : ast .BitAnd (),
3052
+ "//" : ast .FloorDiv (),
3053
+ "**" : ast .Pow (),
3054
+ }
3055
+
3056
+ unaryop = {
3057
+ "~" : ast .Invert (),
3058
+ "+" : ast .UAdd (),
3059
+ "-" : ast .USub (),
3060
+ }
3061
+
3062
+ def wrap_expr (self , expr ):
3063
+ return ast .Module (body = [ast .Expr (value = expr )])
3064
+
3065
+ def wrap_for (self , for_statement ):
3066
+ return ast .Module (body = [for_statement ])
3067
+
3068
+ def assert_ast (self , code , non_optimized_target , optimized_target ):
3069
+ non_optimized_tree = ast .parse (code , optimize = - 1 )
3070
+ optimized_tree = ast .parse (code , optimize = 1 )
3071
+
3072
+ # Is a non-optimized tree equal to a non-optimized target?
3073
+ self .assertTrue (
3074
+ ast .compare (non_optimized_tree , non_optimized_target ),
3075
+ f"{ ast .dump (non_optimized_target )} must equal "
3076
+ f"{ ast .dump (non_optimized_tree )} " ,
3077
+ )
3078
+
3079
+ # Is a optimized tree equal to a non-optimized target?
3080
+ self .assertFalse (
3081
+ ast .compare (optimized_tree , non_optimized_target ),
3082
+ f"{ ast .dump (non_optimized_target )} must not equal "
3083
+ f"{ ast .dump (non_optimized_tree )} "
3084
+ )
3085
+
3086
+ # Is a optimized tree is equal to an optimized target?
3087
+ self .assertTrue (
3088
+ ast .compare (optimized_tree , optimized_target ),
3089
+ f"{ ast .dump (optimized_target )} must equal "
3090
+ f"{ ast .dump (optimized_tree )} " ,
3091
+ )
3092
+
3093
+ def test_folding_binop (self ):
3094
+ code = "1 %s 1"
3095
+ operators = self .binop .keys ()
3096
+
3097
+ def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3098
+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3099
+
3100
+ for op in operators :
3101
+ result_code = code % op
3102
+ non_optimized_target = self .wrap_expr (create_binop (op ))
3103
+ optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3104
+
3105
+ with self .subTest (
3106
+ result_code = result_code ,
3107
+ non_optimized_target = non_optimized_target ,
3108
+ optimized_target = optimized_target
3109
+ ):
3110
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3111
+
3112
+ # Multiplication of constant tuples must be folded
3113
+ code = "(1,) * 3"
3114
+ non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3115
+ optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3116
+
3117
+ self .assert_ast (code , non_optimized_target , optimized_target )
3118
+
3119
+ def test_folding_unaryop (self ):
3120
+ code = "%s1"
3121
+ operators = self .unaryop .keys ()
3122
+
3123
+ def create_unaryop (operand ):
3124
+ return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3125
+
3126
+ for op in operators :
3127
+ result_code = code % op
3128
+ non_optimized_target = self .wrap_expr (create_unaryop (op ))
3129
+ optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3130
+
3131
+ with self .subTest (
3132
+ result_code = result_code ,
3133
+ non_optimized_target = non_optimized_target ,
3134
+ optimized_target = optimized_target
3135
+ ):
3136
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3137
+
3138
+ def test_folding_not (self ):
3139
+ code = "not (1 %s (1,))"
3140
+ operators = {
3141
+ "in" : ast .In (),
3142
+ "is" : ast .Is (),
3143
+ }
3144
+ opt_operators = {
3145
+ "is" : ast .IsNot (),
3146
+ "in" : ast .NotIn (),
3147
+ }
3148
+
3149
+ def create_notop (operand ):
3150
+ return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3151
+ left = ast .Constant (value = 1 ),
3152
+ ops = [operators [operand ]],
3153
+ comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3154
+ ))
3155
+
3156
+ for op in operators .keys ():
3157
+ result_code = code % op
3158
+ non_optimized_target = self .wrap_expr (create_notop (op ))
3159
+ optimized_target = self .wrap_expr (
3160
+ ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3161
+ )
3162
+
3163
+ with self .subTest (
3164
+ result_code = result_code ,
3165
+ non_optimized_target = non_optimized_target ,
3166
+ optimized_target = optimized_target
3167
+ ):
3168
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3169
+
3170
+ def test_folding_format (self ):
3171
+ code = "'%s' % (a,)"
3172
+
3173
+ non_optimized_target = self .wrap_expr (
3174
+ ast .BinOp (
3175
+ left = ast .Constant (value = "%s" ),
3176
+ op = ast .Mod (),
3177
+ right = ast .Tuple (elts = [ast .Name (id = 'a' )]))
3178
+ )
3179
+ optimized_target = self .wrap_expr (
3180
+ ast .JoinedStr (
3181
+ values = [
3182
+ ast .FormattedValue (value = ast .Name (id = 'a' ), conversion = 115 )
3183
+ ]
3184
+ )
3185
+ )
3186
+
3187
+ self .assert_ast (code , non_optimized_target , optimized_target )
3188
+
3189
+
3190
+ def test_folding_tuple (self ):
3191
+ code = "(1,)"
3192
+
3193
+ non_optimized_target = self .wrap_expr (ast .Tuple (elts = [ast .Constant (1 )]))
3194
+ optimized_target = self .wrap_expr (ast .Constant (value = (1 ,)))
3195
+
3196
+ self .assert_ast (code , non_optimized_target , optimized_target )
3197
+
3198
+ def test_folding_comparator (self ):
3199
+ code = "1 %s %s1%s"
3200
+ operators = [("in" , ast .In ()), ("not in" , ast .NotIn ())]
3201
+ braces = [
3202
+ ("[" , "]" , ast .List , (1 ,)),
3203
+ ("{" , "}" , ast .Set , frozenset ({1 })),
3204
+ ]
3205
+ for left , right , non_optimized_comparator , optimized_comparator in braces :
3206
+ for op , node in operators :
3207
+ non_optimized_target = self .wrap_expr (ast .Compare (
3208
+ left = ast .Constant (1 ), ops = [node ],
3209
+ comparators = [non_optimized_comparator (elts = [ast .Constant (1 )])]
3210
+ ))
3211
+ optimized_target = self .wrap_expr (ast .Compare (
3212
+ left = ast .Constant (1 ), ops = [node ],
3213
+ comparators = [ast .Constant (value = optimized_comparator )]
3214
+ ))
3215
+ self .assert_ast (code % (op , left , right ), non_optimized_target , optimized_target )
3216
+
3217
+ def test_folding_iter (self ):
3218
+ code = "for _ in %s1%s: pass"
3219
+ braces = [
3220
+ ("[" , "]" , ast .List , (1 ,)),
3221
+ ("{" , "}" , ast .Set , frozenset ({1 })),
3222
+ ]
3223
+
3224
+ for left , right , ast_cls , optimized_iter in braces :
3225
+ non_optimized_target = self .wrap_for (ast .For (
3226
+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3227
+ iter = ast_cls (elts = [ast .Constant (1 )]),
3228
+ body = [ast .Pass ()]
3229
+ ))
3230
+ optimized_target = self .wrap_for (ast .For (
3231
+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3232
+ iter = ast .Constant (value = optimized_iter ),
3233
+ body = [ast .Pass ()]
3234
+ ))
3235
+
3236
+ self .assert_ast (code % (left , right ), non_optimized_target , optimized_target )
3237
+
3238
+ def test_folding_subscript (self ):
3239
+ code = "(1,)[0]"
3240
+
3241
+ non_optimized_target = self .wrap_expr (
3242
+ ast .Subscript (value = ast .Tuple (elts = [ast .Constant (value = 1 )]), slice = ast .Constant (value = 0 ))
3243
+ )
3244
+ optimized_target = self .wrap_expr (ast .Constant (value = 1 ))
3245
+
3246
+ self .assert_ast (code , non_optimized_target , optimized_target )
3247
+
3248
+
3249
+ if __name__ == "__main__" :
3250
+ unittest .main ()
0 commit comments