@@ -3035,3 +3035,216 @@ def test_cli_file_input(self):
30353035 self .assertEqual (expected .splitlines (),
30363036 res .out .decode ("utf8" ).splitlines ())
30373037 self .assertEqual (res .rc , 0 )
3038+
3039+
3040+ class ASTOptimiziationTests (unittest .TestCase ):
3041+ binop = {
3042+ "+" : ast .Add (),
3043+ "-" : ast .Sub (),
3044+ "*" : ast .Mult (),
3045+ "/" : ast .Div (),
3046+ "%" : ast .Mod (),
3047+ "<<" : ast .LShift (),
3048+ ">>" : ast .RShift (),
3049+ "|" : ast .BitOr (),
3050+ "^" : ast .BitXor (),
3051+ "&" : ast .BitAnd (),
3052+ "//" : ast .FloorDiv (),
3053+ "**" : ast .Pow (),
3054+ }
3055+
3056+ unaryop = {
3057+ "~" : ast .Invert (),
3058+ "+" : ast .UAdd (),
3059+ "-" : ast .USub (),
3060+ }
3061+
3062+ def wrap_expr (self , expr ):
3063+ return ast .Module (body = [ast .Expr (value = expr )])
3064+
3065+ def wrap_for (self , for_statement ):
3066+ return ast .Module (body = [for_statement ])
3067+
3068+ def assert_ast (self , code , non_optimized_target , optimized_target ):
3069+ non_optimized_tree = ast .parse (code , optimize = - 1 )
3070+ optimized_tree = ast .parse (code , optimize = 1 )
3071+
3072+ # Is a non-optimized tree equal to a non-optimized target?
3073+ self .assertTrue (
3074+ ast .compare (non_optimized_tree , non_optimized_target ),
3075+ f"{ ast .dump (non_optimized_target )} must equal "
3076+ f"{ ast .dump (non_optimized_tree )} " ,
3077+ )
3078+
3079+ # Is a optimized tree equal to a non-optimized target?
3080+ self .assertFalse (
3081+ ast .compare (optimized_tree , non_optimized_target ),
3082+ f"{ ast .dump (non_optimized_target )} must not equal "
3083+ f"{ ast .dump (non_optimized_tree )} "
3084+ )
3085+
3086+ # Is a optimized tree is equal to an optimized target?
3087+ self .assertTrue (
3088+ ast .compare (optimized_tree , optimized_target ),
3089+ f"{ ast .dump (optimized_target )} must equal "
3090+ f"{ ast .dump (optimized_tree )} " ,
3091+ )
3092+
3093+ def test_folding_binop (self ):
3094+ code = "1 %s 1"
3095+ operators = self .binop .keys ()
3096+
3097+ def create_binop (operand , left = ast .Constant (1 ), right = ast .Constant (1 )):
3098+ return ast .BinOp (left = left , op = self .binop [operand ], right = right )
3099+
3100+ for op in operators :
3101+ result_code = code % op
3102+ non_optimized_target = self .wrap_expr (create_binop (op ))
3103+ optimized_target = self .wrap_expr (ast .Constant (value = eval (result_code )))
3104+
3105+ with self .subTest (
3106+ result_code = result_code ,
3107+ non_optimized_target = non_optimized_target ,
3108+ optimized_target = optimized_target
3109+ ):
3110+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3111+
3112+ # Multiplication of constant tuples must be folded
3113+ code = "(1,) * 3"
3114+ non_optimized_target = self .wrap_expr (create_binop ("*" , ast .Tuple (elts = [ast .Constant (value = 1 )]), ast .Constant (value = 3 )))
3115+ optimized_target = self .wrap_expr (ast .Constant (eval (code )))
3116+
3117+ self .assert_ast (code , non_optimized_target , optimized_target )
3118+
3119+ def test_folding_unaryop (self ):
3120+ code = "%s1"
3121+ operators = self .unaryop .keys ()
3122+
3123+ def create_unaryop (operand ):
3124+ return ast .UnaryOp (op = self .unaryop [operand ], operand = ast .Constant (1 ))
3125+
3126+ for op in operators :
3127+ result_code = code % op
3128+ non_optimized_target = self .wrap_expr (create_unaryop (op ))
3129+ optimized_target = self .wrap_expr (ast .Constant (eval (result_code )))
3130+
3131+ with self .subTest (
3132+ result_code = result_code ,
3133+ non_optimized_target = non_optimized_target ,
3134+ optimized_target = optimized_target
3135+ ):
3136+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3137+
3138+ def test_folding_not (self ):
3139+ code = "not (1 %s (1,))"
3140+ operators = {
3141+ "in" : ast .In (),
3142+ "is" : ast .Is (),
3143+ }
3144+ opt_operators = {
3145+ "is" : ast .IsNot (),
3146+ "in" : ast .NotIn (),
3147+ }
3148+
3149+ def create_notop (operand ):
3150+ return ast .UnaryOp (op = ast .Not (), operand = ast .Compare (
3151+ left = ast .Constant (value = 1 ),
3152+ ops = [operators [operand ]],
3153+ comparators = [ast .Tuple (elts = [ast .Constant (value = 1 )])]
3154+ ))
3155+
3156+ for op in operators .keys ():
3157+ result_code = code % op
3158+ non_optimized_target = self .wrap_expr (create_notop (op ))
3159+ optimized_target = self .wrap_expr (
3160+ ast .Compare (left = ast .Constant (1 ), ops = [opt_operators [op ]], comparators = [ast .Constant (value = (1 ,))])
3161+ )
3162+
3163+ with self .subTest (
3164+ result_code = result_code ,
3165+ non_optimized_target = non_optimized_target ,
3166+ optimized_target = optimized_target
3167+ ):
3168+ self .assert_ast (result_code , non_optimized_target , optimized_target )
3169+
3170+ def test_folding_format (self ):
3171+ code = "'%s' % (a,)"
3172+
3173+ non_optimized_target = self .wrap_expr (
3174+ ast .BinOp (
3175+ left = ast .Constant (value = "%s" ),
3176+ op = ast .Mod (),
3177+ right = ast .Tuple (elts = [ast .Name (id = 'a' )]))
3178+ )
3179+ optimized_target = self .wrap_expr (
3180+ ast .JoinedStr (
3181+ values = [
3182+ ast .FormattedValue (value = ast .Name (id = 'a' ), conversion = 115 )
3183+ ]
3184+ )
3185+ )
3186+
3187+ self .assert_ast (code , non_optimized_target , optimized_target )
3188+
3189+
3190+ def test_folding_tuple (self ):
3191+ code = "(1,)"
3192+
3193+ non_optimized_target = self .wrap_expr (ast .Tuple (elts = [ast .Constant (1 )]))
3194+ optimized_target = self .wrap_expr (ast .Constant (value = (1 ,)))
3195+
3196+ self .assert_ast (code , non_optimized_target , optimized_target )
3197+
3198+ def test_folding_comparator (self ):
3199+ code = "1 %s %s1%s"
3200+ operators = [("in" , ast .In ()), ("not in" , ast .NotIn ())]
3201+ braces = [
3202+ ("[" , "]" , ast .List , (1 ,)),
3203+ ("{" , "}" , ast .Set , frozenset ({1 })),
3204+ ]
3205+ for left , right , non_optimized_comparator , optimized_comparator in braces :
3206+ for op , node in operators :
3207+ non_optimized_target = self .wrap_expr (ast .Compare (
3208+ left = ast .Constant (1 ), ops = [node ],
3209+ comparators = [non_optimized_comparator (elts = [ast .Constant (1 )])]
3210+ ))
3211+ optimized_target = self .wrap_expr (ast .Compare (
3212+ left = ast .Constant (1 ), ops = [node ],
3213+ comparators = [ast .Constant (value = optimized_comparator )]
3214+ ))
3215+ self .assert_ast (code % (op , left , right ), non_optimized_target , optimized_target )
3216+
3217+ def test_folding_iter (self ):
3218+ code = "for _ in %s1%s: pass"
3219+ braces = [
3220+ ("[" , "]" , ast .List , (1 ,)),
3221+ ("{" , "}" , ast .Set , frozenset ({1 })),
3222+ ]
3223+
3224+ for left , right , ast_cls , optimized_iter in braces :
3225+ non_optimized_target = self .wrap_for (ast .For (
3226+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3227+ iter = ast_cls (elts = [ast .Constant (1 )]),
3228+ body = [ast .Pass ()]
3229+ ))
3230+ optimized_target = self .wrap_for (ast .For (
3231+ target = ast .Name (id = "_" , ctx = ast .Store ()),
3232+ iter = ast .Constant (value = optimized_iter ),
3233+ body = [ast .Pass ()]
3234+ ))
3235+
3236+ self .assert_ast (code % (left , right ), non_optimized_target , optimized_target )
3237+
3238+ def test_folding_subscript (self ):
3239+ code = "(1,)[0]"
3240+
3241+ non_optimized_target = self .wrap_expr (
3242+ ast .Subscript (value = ast .Tuple (elts = [ast .Constant (value = 1 )]), slice = ast .Constant (value = 0 ))
3243+ )
3244+ optimized_target = self .wrap_expr (ast .Constant (value = 1 ))
3245+
3246+ self .assert_ast (code , non_optimized_target , optimized_target )
3247+
3248+
3249+ if __name__ == "__main__" :
3250+ unittest .main ()
0 commit comments