Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f7a53d8

Browse files
authored
[mypyc] Constant fold int operations and str concat (#11194)
Work on mypyc/mypyc#772. Replace things like '5 + 8' with '13' during IR building. Also adds support for negative int literals, such as -5 (the negation gets constant folded). Arithmetic and bitwise operations that produce int results are supported, plus string concatenation. Comparisons and float operations are not supported yet, among other things. This is a little tricky because of error cases, such as division by zero. The approach here is to avoid constant folding error cases. We still won't produce compile-time errors for them. We do potentially lots of extra work by repeatedly trying to constant fold expressions, but it didn't seem to slow down self-compilation measurably, so the effect seems minor at most. Some caching could help this if it becomes a problem.
1 parent 82f767a commit f7a53d8

13 files changed

Lines changed: 524 additions & 39 deletions

mypyc/codegen/emitfunc.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -483,14 +483,20 @@ def visit_comparison_op(self, op: ComparisonOp) -> None:
483483
rhs = self.reg(op.rhs)
484484
lhs_cast = ""
485485
rhs_cast = ""
486-
signed_op = {ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE}
487-
unsigned_op = {ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE}
488-
if op.op in signed_op:
486+
if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE):
487+
# Always signed comparison op
489488
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
490489
rhs_cast = self.emit_signed_int_cast(op.rhs.type)
491-
elif op.op in unsigned_op:
490+
elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE):
491+
# Always unsigned comparison op
492492
lhs_cast = self.emit_unsigned_int_cast(op.lhs.type)
493493
rhs_cast = self.emit_unsigned_int_cast(op.rhs.type)
494+
elif isinstance(op.lhs, Integer) and op.lhs.value < 0:
495+
# Force signed ==/!= with negative operand
496+
rhs_cast = self.emit_signed_int_cast(op.rhs.type)
497+
elif isinstance(op.rhs, Integer) and op.rhs.value < 0:
498+
# Force signed ==/!= with negative operand
499+
lhs_cast = self.emit_signed_int_cast(op.lhs.type)
494500
self.emit_line('%s = %s%s %s %s%s;' % (dest, lhs_cast, lhs,
495501
op.op_str[op.op], rhs_cast, rhs))
496502

@@ -542,7 +548,12 @@ def reg(self, reg: Value) -> str:
542548
s = str(val)
543549
if val >= (1 << 31):
544550
# Avoid overflowing signed 32-bit int
545-
s += 'U'
551+
s += 'ULL'
552+
elif val == -(1 << 63):
553+
# Avoid overflowing C integer literal
554+
s = '(-9223372036854775807LL - 1)'
555+
elif val <= -(1 << 31):
556+
s += 'LL'
546557
return s
547558
else:
548559
return self.emitter.reg(reg)

mypyc/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
# Note: Assume that the compiled code uses the same bit width as mypyc, except for
4848
# Python 3.5 on macOS.
4949
MAX_LITERAL_SHORT_INT: Final = sys.maxsize >> 1 if not IS_MIXED_32_64_BIT_BUILD else 2 ** 30 - 1
50+
MIN_LITERAL_SHORT_INT: Final = -MAX_LITERAL_SHORT_INT - 1
5051

5152
# Runtime C library files
5253
RUNTIME_C_FILES: Final = [

mypyc/irbuild/constant_fold.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Constant folding of IR values.
2+
3+
For example, 3 + 5 can be constant folded into 8.
4+
"""
5+
6+
from typing import Optional, Union
7+
from typing_extensions import Final
8+
9+
from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var
10+
from mypyc.irbuild.builder import IRBuilder
11+
12+
13+
# All possible result types of constant folding
14+
ConstantValue = Union[int, str]
15+
CONST_TYPES: Final = (int, str)
16+
17+
18+
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[ConstantValue]:
19+
"""Return the constant value of an expression for supported operations.
20+
21+
Return None otherwise.
22+
"""
23+
if isinstance(expr, IntExpr):
24+
return expr.value
25+
if isinstance(expr, StrExpr):
26+
return expr.value
27+
elif isinstance(expr, NameExpr):
28+
node = expr.node
29+
if isinstance(node, Var) and node.is_final:
30+
value = node.final_value
31+
if isinstance(value, (CONST_TYPES)):
32+
return value
33+
elif isinstance(expr, MemberExpr):
34+
final = builder.get_final_ref(expr)
35+
if final is not None:
36+
fn, final_var, native = final
37+
if final_var.is_final:
38+
value = final_var.final_value
39+
if isinstance(value, (CONST_TYPES)):
40+
return value
41+
elif isinstance(expr, OpExpr):
42+
left = constant_fold_expr(builder, expr.left)
43+
right = constant_fold_expr(builder, expr.right)
44+
if isinstance(left, int) and isinstance(right, int):
45+
return constant_fold_binary_int_op(expr.op, left, right)
46+
elif isinstance(left, str) and isinstance(right, str):
47+
return constant_fold_binary_str_op(expr.op, left, right)
48+
elif isinstance(expr, UnaryExpr):
49+
value = constant_fold_expr(builder, expr.expr)
50+
if isinstance(value, int):
51+
return constant_fold_unary_int_op(expr.op, value)
52+
return None
53+
54+
55+
def constant_fold_binary_int_op(op: str, left: int, right: int) -> Optional[int]:
56+
if op == '+':
57+
return left + right
58+
if op == '-':
59+
return left - right
60+
elif op == '*':
61+
return left * right
62+
elif op == '//':
63+
if right != 0:
64+
return left // right
65+
elif op == '%':
66+
if right != 0:
67+
return left % right
68+
elif op == '&':
69+
return left & right
70+
elif op == '|':
71+
return left | right
72+
elif op == '^':
73+
return left ^ right
74+
elif op == '<<':
75+
if right >= 0:
76+
return left << right
77+
elif op == '>>':
78+
if right >= 0:
79+
return left >> right
80+
elif op == '**':
81+
if right >= 0:
82+
return left ** right
83+
return None
84+
85+
86+
def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]:
87+
if op == '-':
88+
return -value
89+
elif op == '~':
90+
return ~value
91+
elif op == '+':
92+
return value
93+
return None
94+
95+
96+
def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]:
97+
if op == '+':
98+
return left + right
99+
return None

mypyc/irbuild/expression.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
translate_list_comprehension, translate_set_comprehension,
4545
comprehension_helper
4646
)
47+
from mypyc.irbuild.constant_fold import constant_fold_expr
4748

4849

4950
# Name and attribute references
@@ -378,6 +379,10 @@ def translate_cast_expr(builder: IRBuilder, expr: CastExpr) -> Value:
378379

379380

380381
def transform_unary_expr(builder: IRBuilder, expr: UnaryExpr) -> Value:
382+
folded = try_constant_fold(builder, expr)
383+
if folded:
384+
return folded
385+
381386
return builder.unary_op(builder.accept(expr.expr), expr.op, expr.line)
382387

383388

@@ -391,6 +396,10 @@ def transform_op_expr(builder: IRBuilder, expr: OpExpr) -> Value:
391396
if ret is not None:
392397
return ret
393398

399+
folded = try_constant_fold(builder, expr)
400+
if folded:
401+
return folded
402+
394403
return builder.binary_op(
395404
builder.accept(expr.left), builder.accept(expr.right), expr.op, expr.line
396405
)
@@ -413,6 +422,19 @@ def transform_index_expr(builder: IRBuilder, expr: IndexExpr) -> Value:
413422
base, '__getitem__', [index_reg], builder.node_type(expr), expr.line)
414423

415424

425+
def try_constant_fold(builder: IRBuilder, expr: Expression) -> Optional[Value]:
426+
"""Return the constant value of an expression if possible.
427+
428+
Return None otherwise.
429+
"""
430+
value = constant_fold_expr(builder, expr)
431+
if isinstance(value, int):
432+
return builder.load_int(value)
433+
elif isinstance(value, str):
434+
return builder.load_str(value)
435+
return None
436+
437+
416438
def try_gen_slice_op(builder: IRBuilder, base: Value, index: SliceExpr) -> Optional[Value]:
417439
"""Generate specialized slice op for some index expressions.
418440

mypyc/irbuild/ll_builder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@
3939
from mypyc.ir.func_ir import FuncDecl, FuncSignature
4040
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
4141
from mypyc.common import (
42-
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT, PLATFORM_SIZE, use_vectorcall,
43-
use_method_vectorcall
42+
FAST_ISINSTANCE_MAX_SUBCLASSES, MAX_LITERAL_SHORT_INT, MIN_LITERAL_SHORT_INT, PLATFORM_SIZE,
43+
use_vectorcall, use_method_vectorcall
4444
)
4545
from mypyc.primitives.registry import (
4646
method_call_ops, CFunctionDescription,
@@ -789,7 +789,7 @@ def none_object(self) -> Value:
789789

790790
def load_int(self, value: int) -> Value:
791791
"""Load a tagged (Python) integer literal value."""
792-
if abs(value) > MAX_LITERAL_SHORT_INT:
792+
if value > MAX_LITERAL_SHORT_INT or value < MIN_LITERAL_SHORT_INT:
793793
return self.add(LoadLiteral(value, int_rprimitive))
794794
else:
795795
return Integer(value)

mypyc/test-data/analysis.test

Lines changed: 17 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,8 @@ def lol(x):
536536
r2 :: object
537537
r3 :: str
538538
r4 :: object
539-
r5 :: bit
540-
r6 :: int
541-
r7 :: bit
542-
r8, r9 :: int
539+
r5, r6 :: bit
540+
r7, r8 :: int
543541
L0:
544542
L1:
545543
r0 = CPyTagged_Id(x)
@@ -555,9 +553,8 @@ L3:
555553
r5 = CPy_ExceptionMatches(r4)
556554
if r5 goto L4 else goto L5 :: bool
557555
L4:
558-
r6 = CPyTagged_Negate(2)
559556
CPy_RestoreExcInfo(r1)
560-
return r6
557+
return -2
561558
L5:
562559
CPy_Reraise()
563560
if not 0 goto L8 else goto L6 :: bool
@@ -568,16 +565,16 @@ L7:
568565
goto L10
569566
L8:
570567
CPy_RestoreExcInfo(r1)
571-
r7 = CPy_KeepPropagating()
572-
if not r7 goto L11 else goto L9 :: bool
568+
r6 = CPy_KeepPropagating()
569+
if not r6 goto L11 else goto L9 :: bool
573570
L9:
574571
unreachable
575572
L10:
576-
r8 = CPyTagged_Add(st, 2)
577-
return r8
573+
r7 = CPyTagged_Add(st, 2)
574+
return r7
578575
L11:
579-
r9 = <error> :: int
580-
return r9
576+
r8 = <error> :: int
577+
return r8
581578
(0, 0) {x} {x}
582579
(1, 0) {x} {r0}
583580
(1, 1) {r0} {st}
@@ -589,20 +586,18 @@ L11:
589586
(2, 4) {r1, r4} {r1, r4}
590587
(3, 0) {r1, r4} {r1, r5}
591588
(3, 1) {r1, r5} {r1}
592-
(4, 0) {r1} {r1, r6}
593-
(4, 1) {r1, r6} {r6}
594-
(4, 2) {r6} {}
589+
(4, 0) {r1} {}
590+
(4, 1) {} {}
595591
(5, 0) {r1} {r1}
596592
(5, 1) {r1} {r1}
597593
(6, 0) {} {}
598594
(7, 0) {r1, st} {st}
599595
(7, 1) {st} {st}
600596
(8, 0) {r1} {}
601-
(8, 1) {} {r7}
602-
(8, 2) {r7} {}
597+
(8, 1) {} {r6}
598+
(8, 2) {r6} {}
603599
(9, 0) {} {}
604-
(10, 0) {st} {r8}
605-
(10, 1) {r8} {}
606-
(11, 0) {} {r9}
607-
(11, 1) {r9} {}
608-
600+
(10, 0) {st} {r7}
601+
(10, 1) {r7} {}
602+
(11, 0) {} {r8}
603+
(11, 1) {r8} {}

mypyc/test-data/fixtures/ir.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __sub__(self, n: int) -> int: pass
3636
def __mul__(self, n: int) -> int: pass
3737
def __pow__(self, n: int, modulo: Optional[int] = None) -> int: pass
3838
def __floordiv__(self, x: int) -> int: pass
39+
def __truediv__(self, x: float) -> float: pass
3940
def __mod__(self, x: int) -> int: pass
4041
def __neg__(self) -> int: pass
4142
def __pos__(self) -> int: pass
@@ -271,6 +272,10 @@ class NotImplementedError(RuntimeError): pass
271272
class StopIteration(Exception):
272273
value: Any
273274

275+
class ArithmeticError(Exception): pass
276+
277+
class ZeroDivisionError(Exception): pass
278+
274279
def any(i: Iterable[T]) -> bool: pass
275280
def all(i: Iterable[T]) -> bool: pass
276281
def reversed(object: Sequence[T]) -> Iterator[T]: ...

mypyc/test-data/irbuild-basic.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,12 +581,12 @@ L8:
581581

582582
[case testUnaryMinus]
583583
def f(n: int) -> int:
584-
return -1
584+
return -n
585585
[out]
586586
def f(n):
587587
n, r0 :: int
588588
L0:
589-
r0 = CPyTagged_Negate(2)
589+
r0 = CPyTagged_Negate(n)
590590
return r0
591591

592592
[case testConditionalExpr]

0 commit comments

Comments
 (0)