@@ -3172,3 +3172,219 @@ def test_cli_file_input(self):
3172
3172
self .assertEqual (res .err , b"" )
3173
3173
self .assertEqual (expected .splitlines (), res .out .decode ("utf8" ).splitlines ())
3174
3174
self .assertEqual (res .rc , 0 )
3175
+
3176
+ def compare (left , right ):
3177
+ return ast .dump (left ) == ast .dump (right )
3178
+
3179
+ class ASTOptimiziationTests (unittest .TestCase ):
3180
+ binop = {
3181
+ "+" : ast .Add (),
3182
+ "-" : ast .Sub (),
3183
+ "*" : ast .Mult (),
3184
+ "/" : ast .Div (),
3185
+ "%" : ast .Mod (),
3186
+ "<<" : ast .LShift (),
3187
+ ">>" : ast .RShift (),
3188
+ "|" : ast .BitOr (),
3189
+ "^" : ast .BitXor (),
3190
+ "&" : ast .BitAnd (),
3191
+ "//" : ast .FloorDiv (),
3192
+ "**" : ast .Pow (),
3193
+ }
3194
+
3195
+ unaryop = {
3196
+ "~" : ast .Invert (),
3197
+ "+" : ast .UAdd (),
3198
+ "-" : ast .USub (),
3199
+ }
3200
+
3201
+ def wrap_expr (self , expr ):
3202
+ return ast .Module (body = [ast .Expr (value = expr )])
3203
+
3204
+ def wrap_for (self , for_statement ):
3205
+ return ast .Module (body = [for_statement ])
3206
+
3207
+ def assert_ast (self , code , non_optimized_target , optimized_target ):
3208
+
3209
+ non_optimized_tree = ast .parse (code , optimize = - 1 )
3210
+ optimized_tree = ast .parse (code , optimize = 1 )
3211
+
3212
+ # Is a non-optimized tree equal to a non-optimized target?
3213
+ self .assertTrue (
3214
+ compare (non_optimized_tree , non_optimized_target ),
3215
+ f"{ ast .dump (non_optimized_target )} must equal "
3216
+ f"{ ast .dump (non_optimized_tree )} " ,
3217
+ )
3218
+
3219
+ # Is a optimized tree equal to a non-optimized target?
3220
+ self .assertFalse (
3221
+ compare (optimized_tree , non_optimized_target ),
3222
+ f"{ ast .dump (non_optimized_target )} must not equal "
3223
+ f"{ ast .dump (non_optimized_tree )} "
3224
+ )
3225
+
3226
+ # Is a optimized tree is equal to an optimized target?
3227
+ self .assertTrue (
3228
+ compare (optimized_tree , optimized_target ),
3229
+ f"{ ast .dump (optimized_target )} must equal "
3230
+ f"{ ast .dump (optimized_tree )} " ,
3231
+ )
3232
+
3233
+ def test_folding_binop (self ):
3234
+ code = "1 %s 1"
3235
+ operators = self .binop .keys ()
3236
+
3237
+ def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3238
+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3239
+
3240
+ for op in operators :
3241
+ result_code = code % op
3242
+ non_optimized_target = self .wrap_expr (create_binop (op ))
3243
+ optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3244
+
3245
+ with self .subTest (
3246
+ result_code = result_code ,
3247
+ non_optimized_target = non_optimized_target ,
3248
+ optimized_target = optimized_target
3249
+ ):
3250
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3251
+
3252
+ # Multiplication of constant tuples must be folded
3253
+ code = "(1,) * 3"
3254
+ non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3255
+ optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3256
+
3257
+ self .assert_ast (code , non_optimized_target , optimized_target )
3258
+
3259
+ def test_folding_unaryop (self ):
3260
+ code = "%s1"
3261
+ operators = self .unaryop .keys ()
3262
+
3263
+ def create_unaryop (operand ):
3264
+ return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3265
+
3266
+ for op in operators :
3267
+ result_code = code % op
3268
+ non_optimized_target = self .wrap_expr (create_unaryop (op ))
3269
+ optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3270
+
3271
+ with self .subTest (
3272
+ result_code = result_code ,
3273
+ non_optimized_target = non_optimized_target ,
3274
+ optimized_target = optimized_target
3275
+ ):
3276
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3277
+
3278
+ def test_folding_not (self ):
3279
+ code = "not (1 %s (1,))"
3280
+ operators = {
3281
+ "in" : ast .In (),
3282
+ "is" : ast .Is (),
3283
+ }
3284
+ opt_operators = {
3285
+ "is" : ast .IsNot (),
3286
+ "in" : ast .NotIn (),
3287
+ }
3288
+
3289
+ def create_notop (operand ):
3290
+ return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3291
+ left = ast .Constant (value = 1 ),
3292
+ ops = [operators [operand ]],
3293
+ comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3294
+ ))
3295
+
3296
+ for op in operators .keys ():
3297
+ result_code = code % op
3298
+ non_optimized_target = self .wrap_expr (create_notop (op ))
3299
+ optimized_target = self .wrap_expr (
3300
+ ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3301
+ )
3302
+
3303
+ with self .subTest (
3304
+ result_code = result_code ,
3305
+ non_optimized_target = non_optimized_target ,
3306
+ optimized_target = optimized_target
3307
+ ):
3308
+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3309
+
3310
+ def test_folding_format (self ):
3311
+ code = "'%s' % (a,)"
3312
+
3313
+ non_optimized_target = self .wrap_expr (
3314
+ ast .BinOp (
3315
+ left = ast .Constant (value = "%s" ),
3316
+ op = ast .Mod (),
3317
+ right = ast .Tuple (elts = [ast .Name (id = 'a' )]))
3318
+ )
3319
+ optimized_target = self .wrap_expr (
3320
+ ast .JoinedStr (
3321
+ values = [
3322
+ ast .FormattedValue (value = ast .Name (id = 'a' ), conversion = 115 )
3323
+ ]
3324
+ )
3325
+ )
3326
+
3327
+ self .assert_ast (code , non_optimized_target , optimized_target )
3328
+
3329
+
3330
+ def test_folding_tuple (self ):
3331
+ code = "(1,)"
3332
+
3333
+ non_optimized_target = self .wrap_expr (ast .Tuple (elts = [ast .Constant (1 )]))
3334
+ optimized_target = self .wrap_expr (ast .Constant (value = (1 ,)))
3335
+
3336
+ self .assert_ast (code , non_optimized_target , optimized_target )
3337
+
3338
+ def test_folding_comparator (self ):
3339
+ code = "1 %s %s1%s"
3340
+ operators = [("in" , ast .In ()), ("not in" , ast .NotIn ())]
3341
+ braces = [
3342
+ ("[" , "]" , ast .List , (1 ,)),
3343
+ ("{" , "}" , ast .Set , frozenset ({1 })),
3344
+ ]
3345
+ for left , right , non_optimized_comparator , optimized_comparator in braces :
3346
+ for op , node in operators :
3347
+ non_optimized_target = self .wrap_expr (ast .Compare (
3348
+ left = ast .Constant (1 ), ops = [node ],
3349
+ comparators = [non_optimized_comparator (elts = [ast .Constant (1 )])]
3350
+ ))
3351
+ optimized_target = self .wrap_expr (ast .Compare (
3352
+ left = ast .Constant (1 ), ops = [node ],
3353
+ comparators = [ast .Constant (value = optimized_comparator )]
3354
+ ))
3355
+ self .assert_ast (code % (op , left , right ), non_optimized_target , optimized_target )
3356
+
3357
+ def test_folding_iter (self ):
3358
+ code = "for _ in %s1%s: pass"
3359
+ braces = [
3360
+ ("[" , "]" , ast .List , (1 ,)),
3361
+ ("{" , "}" , ast .Set , frozenset ({1 })),
3362
+ ]
3363
+
3364
+ for left , right , ast_cls , optimized_iter in braces :
3365
+ non_optimized_target = self .wrap_for (ast .For (
3366
+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3367
+ iter = ast_cls (elts = [ast .Constant (1 )]),
3368
+ body = [ast .Pass ()]
3369
+ ))
3370
+ optimized_target = self .wrap_for (ast .For (
3371
+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3372
+ iter = ast .Constant (value = optimized_iter ),
3373
+ body = [ast .Pass ()]
3374
+ ))
3375
+
3376
+ self .assert_ast (code % (left , right ), non_optimized_target , optimized_target )
3377
+
3378
+ def test_folding_subscript (self ):
3379
+ code = "(1,)[0]"
3380
+
3381
+ non_optimized_target = self .wrap_expr (
3382
+ ast .Subscript (value = ast .Tuple (elts = [ast .Constant (value = 1 )]), slice = ast .Constant (value = 0 ))
3383
+ )
3384
+ optimized_target = self .wrap_expr (ast .Constant (value = 1 ))
3385
+
3386
+ self .assert_ast (code , non_optimized_target , optimized_target )
3387
+
3388
+
3389
+ if __name__ == "__main__" :
3390
+ unittest .main ()
0 commit comments