|
| 1 | +"""Constant folding of IR values. |
| 2 | +
|
| 3 | +For example, 3 + 5 can be constant folded into 8. |
| 4 | +""" |
| 5 | + |
| 6 | +from typing import Optional, Union |
| 7 | +from typing_extensions import Final |
| 8 | + |
| 9 | +from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var |
| 10 | +from mypyc.irbuild.builder import IRBuilder |
| 11 | + |
| 12 | + |
| 13 | +# All possible result types of constant folding |
| 14 | +ConstantValue = Union[int, str] |
| 15 | +CONST_TYPES: Final = (int, str) |
| 16 | + |
| 17 | + |
| 18 | +def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[ConstantValue]: |
| 19 | + """Return the constant value of an expression for supported operations. |
| 20 | +
|
| 21 | + Return None otherwise. |
| 22 | + """ |
| 23 | + if isinstance(expr, IntExpr): |
| 24 | + return expr.value |
| 25 | + if isinstance(expr, StrExpr): |
| 26 | + return expr.value |
| 27 | + elif isinstance(expr, NameExpr): |
| 28 | + node = expr.node |
| 29 | + if isinstance(node, Var) and node.is_final: |
| 30 | + value = node.final_value |
| 31 | + if isinstance(value, (CONST_TYPES)): |
| 32 | + return value |
| 33 | + elif isinstance(expr, MemberExpr): |
| 34 | + final = builder.get_final_ref(expr) |
| 35 | + if final is not None: |
| 36 | + fn, final_var, native = final |
| 37 | + if final_var.is_final: |
| 38 | + value = final_var.final_value |
| 39 | + if isinstance(value, (CONST_TYPES)): |
| 40 | + return value |
| 41 | + elif isinstance(expr, OpExpr): |
| 42 | + left = constant_fold_expr(builder, expr.left) |
| 43 | + right = constant_fold_expr(builder, expr.right) |
| 44 | + if isinstance(left, int) and isinstance(right, int): |
| 45 | + return constant_fold_binary_int_op(expr.op, left, right) |
| 46 | + elif isinstance(left, str) and isinstance(right, str): |
| 47 | + return constant_fold_binary_str_op(expr.op, left, right) |
| 48 | + elif isinstance(expr, UnaryExpr): |
| 49 | + value = constant_fold_expr(builder, expr.expr) |
| 50 | + if isinstance(value, int): |
| 51 | + return constant_fold_unary_int_op(expr.op, value) |
| 52 | + return None |
| 53 | + |
| 54 | + |
| 55 | +def constant_fold_binary_int_op(op: str, left: int, right: int) -> Optional[int]: |
| 56 | + if op == '+': |
| 57 | + return left + right |
| 58 | + if op == '-': |
| 59 | + return left - right |
| 60 | + elif op == '*': |
| 61 | + return left * right |
| 62 | + elif op == '//': |
| 63 | + if right != 0: |
| 64 | + return left // right |
| 65 | + elif op == '%': |
| 66 | + if right != 0: |
| 67 | + return left % right |
| 68 | + elif op == '&': |
| 69 | + return left & right |
| 70 | + elif op == '|': |
| 71 | + return left | right |
| 72 | + elif op == '^': |
| 73 | + return left ^ right |
| 74 | + elif op == '<<': |
| 75 | + if right >= 0: |
| 76 | + return left << right |
| 77 | + elif op == '>>': |
| 78 | + if right >= 0: |
| 79 | + return left >> right |
| 80 | + elif op == '**': |
| 81 | + if right >= 0: |
| 82 | + return left ** right |
| 83 | + return None |
| 84 | + |
| 85 | + |
| 86 | +def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]: |
| 87 | + if op == '-': |
| 88 | + return -value |
| 89 | + elif op == '~': |
| 90 | + return ~value |
| 91 | + elif op == '+': |
| 92 | + return value |
| 93 | + return None |
| 94 | + |
| 95 | + |
| 96 | +def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]: |
| 97 | + if op == '+': |
| 98 | + return left + right |
| 99 | + return None |
0 commit comments