Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 53f1253

Browse files
authored
[mypyc] Merge string equality ops (#9363)
This PR merges str `==` and `!=` by directly building them in irbuild. The old primitive relies on `ERR_MAGIC` to handle the exception, now we use an `err_occurred_op` with `keep_propagating_op` to represent the same semantics. Actually, our first several commits to replace `PrimitiveOp` with `CallC` cause an incorrect primitive lookup logic that at the point the generic compare is merged, string primitives would just use the generic ones. This PR will also fix this since we can obsolete the old binary op registry.
1 parent a05f19e commit 53f1253

7 files changed

Lines changed: 131 additions & 93 deletions

File tree

mypyc/ir/ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,6 +1166,7 @@ def __init__(self,
11661166
args: List[Value],
11671167
ret_type: RType,
11681168
steals: StealsDescription,
1169+
is_borrowed: bool,
11691170
error_kind: int,
11701171
line: int,
11711172
var_arg_idx: int = -1) -> None:
@@ -1175,6 +1176,7 @@ def __init__(self,
11751176
self.args = args
11761177
self.type = ret_type
11771178
self.steals = steals
1179+
self.is_borrowed = is_borrowed
11781180
self.var_arg_idx = var_arg_idx # the position of the first variable argument in args
11791181

11801182
def to_str(self, env: Environment) -> str:

mypyc/irbuild/ll_builder.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
bool_rprimitive, list_rprimitive, str_rprimitive, is_none_rprimitive, object_rprimitive,
3030
c_pyssize_t_rprimitive, is_short_int_rprimitive, is_tagged, PyVarObject, short_int_rprimitive,
3131
is_list_rprimitive, is_tuple_rprimitive, is_dict_rprimitive, is_set_rprimitive, PySetObject,
32-
none_rprimitive, RTuple, is_bool_rprimitive
32+
none_rprimitive, RTuple, is_bool_rprimitive, is_str_rprimitive, c_int_rprimitive,
33+
pointer_rprimitive
3334
)
3435
from mypyc.ir.func_ir import FuncDecl, FuncSignature
3536
from mypyc.ir.class_ir import ClassIR, all_concrete_classes
@@ -38,7 +39,7 @@
3839
STATIC_PREFIX
3940
)
4041
from mypyc.primitives.registry import (
41-
binary_ops, method_ops, func_ops,
42+
method_ops, func_ops,
4243
c_method_call_ops, CFunctionDescription, c_function_ops,
4344
c_binary_ops, c_unary_ops
4445
)
@@ -56,6 +57,8 @@
5657
none_object_op, fast_isinstance_op, bool_op, type_is_op
5758
)
5859
from mypyc.primitives.int_ops import int_comparison_op_mapping
60+
from mypyc.primitives.exc_ops import err_occurred_op, keep_propagating_op
61+
from mypyc.primitives.str_ops import unicode_compare
5962
from mypyc.rt_subtype import is_runtime_subtype
6063
from mypyc.subtype import is_subtype
6164
from mypyc.sametype import is_same_type
@@ -567,15 +570,15 @@ def binary_op(self,
567570
if expr_op in ('is', 'is not'):
568571
return self.translate_is_op(lreg, rreg, expr_op, line)
569572

573+
if (is_str_rprimitive(lreg.type) and is_str_rprimitive(rreg.type)
574+
and expr_op in ('==', '!=')):
575+
return self.compare_strings(lreg, rreg, expr_op, line)
576+
570577
if is_tagged(lreg.type) and is_tagged(rreg.type) and expr_op in int_comparison_op_mapping:
571578
return self.compare_tagged(lreg, rreg, expr_op, line)
572579

573580
call_c_ops_candidates = c_binary_ops.get(expr_op, [])
574581
target = self.matching_call_c(call_c_ops_candidates, [lreg, rreg], line)
575-
if target:
576-
return target
577-
ops = binary_ops.get(expr_op, [])
578-
target = self.matching_primitive_op(ops, [lreg, rreg], line)
579582
assert target, 'Unsupported binary operation: %s' % expr_op
580583
return target
581584

@@ -626,6 +629,32 @@ def compare_tagged(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
626629
self.goto_and_activate(out)
627630
return result
628631

632+
def compare_strings(self, lhs: Value, rhs: Value, op: str, line: int) -> Value:
633+
"""Compare two strings"""
634+
compare_result = self.call_c(unicode_compare, [lhs, rhs], line)
635+
error_constant = self.add(LoadInt(-1, line, c_int_rprimitive))
636+
compare_error_check = self.add(ComparisonOp(compare_result,
637+
error_constant, ComparisonOp.EQ, line))
638+
exception_check, propagate, final_compare = BasicBlock(), BasicBlock(), BasicBlock()
639+
branch = Branch(compare_error_check, exception_check, final_compare, Branch.BOOL_EXPR)
640+
branch.negated = False
641+
self.add(branch)
642+
self.activate_block(exception_check)
643+
check_error_result = self.call_c(err_occurred_op, [], line)
644+
null = self.add(LoadInt(0, line, pointer_rprimitive))
645+
compare_error_check = self.add(ComparisonOp(check_error_result,
646+
null, ComparisonOp.NEQ, line))
647+
branch = Branch(compare_error_check, propagate, final_compare, Branch.BOOL_EXPR)
648+
branch.negated = False
649+
self.add(branch)
650+
self.activate_block(propagate)
651+
self.call_c(keep_propagating_op, [], line)
652+
self.goto(final_compare)
653+
self.activate_block(final_compare)
654+
op_type = ComparisonOp.EQ if op == '==' else ComparisonOp.NEQ
655+
return self.add(ComparisonOp(compare_result,
656+
self.add(LoadInt(0, line, c_int_rprimitive)), op_type, line))
657+
629658
def compare_tuples(self,
630659
lhs: Value,
631660
rhs: Value,
@@ -840,7 +869,7 @@ def call_c(self,
840869
extra_int_constant = self.add(LoadInt(val, line, rtype=typ))
841870
coerced.append(extra_int_constant)
842871
target = self.add(CallC(desc.c_function_name, coerced, desc.return_type, desc.steals,
843-
desc.error_kind, line, var_arg_idx))
872+
desc.is_borrowed, desc.error_kind, line, var_arg_idx))
844873
if desc.truncated_type is None:
845874
result = target
846875
else:

mypyc/primitives/exc_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@
4141
c_function_name='CPy_NoErrOccured',
4242
error_kind=ERR_FALSE)
4343

44+
err_occurred_op = c_custom_op(
45+
arg_types=[],
46+
return_type=object_rprimitive,
47+
c_function_name='PyErr_Occurred',
48+
error_kind=ERR_NEVER,
49+
is_borrowed=True)
4450

4551
# Keep propagating a raised exception by unconditionally giving an error value.
4652
# This doesn't actually raise an exception.

mypyc/primitives/registry.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
('c_function_name', str),
5252
('error_kind', int),
5353
('steals', StealsDescription),
54+
('is_borrowed', bool),
5455
('ordering', Optional[List[int]]),
5556
('extra_int_constants', List[Tuple[int, RType]]),
5657
('priority', int)])
@@ -61,9 +62,6 @@
6162
('type', RType),
6263
('src', str)]) # name of the target to load
6364

64-
# Primitive binary ops (key is operator such as '+')
65-
binary_ops = {} # type: Dict[str, List[OpDescription]]
66-
6765
# Primitive ops for built-in functions (key is function name such as 'builtins.len')
6866
func_ops = {} # type: Dict[str, List[OpDescription]]
6967

@@ -116,31 +114,6 @@ def call_emit(func: str) -> EmitCallback:
116114
return simple_emit('{dest} = %s({comma_args});' % func)
117115

118116

119-
def binary_op(op: str,
120-
arg_types: List[RType],
121-
result_type: RType,
122-
error_kind: int,
123-
emit: EmitCallback,
124-
format_str: Optional[str] = None,
125-
steals: StealsDescription = False,
126-
is_borrowed: bool = False,
127-
priority: int = 1) -> None:
128-
"""Define a PrimitiveOp for a binary operation.
129-
130-
Arguments are similar to func_op(), but exactly two argument types
131-
are expected.
132-
133-
This will be automatically generated by matching against the AST.
134-
"""
135-
assert len(arg_types) == 2
136-
ops = binary_ops.setdefault(op, [])
137-
if format_str is None:
138-
format_str = '{dest} = {args[0]} %s {args[1]}' % op
139-
desc = OpDescription(op, arg_types, result_type, False, error_kind, format_str, emit,
140-
steals, is_borrowed, priority)
141-
ops.append(desc)
142-
143-
144117
def func_op(name: str,
145118
arg_types: List[RType],
146119
result_type: RType,
@@ -281,6 +254,7 @@ def c_method_op(name: str,
281254
ordering: Optional[List[int]] = None,
282255
extra_int_constants: List[Tuple[int, RType]] = [],
283256
steals: StealsDescription = False,
257+
is_borrowed: bool = False,
284258
priority: int = 1) -> CFunctionDescription:
285259
"""Define a c function call op that replaces a method call.
286260
@@ -303,12 +277,13 @@ def c_method_op(name: str,
303277
accepted by the python syntax(before reordering)
304278
extra_int_constants: optional extra integer constants as the last arguments to a C call
305279
steals: description of arguments that this steals (ref count wise)
280+
is_borrowed: if True, returned value is borrowed (no need to decrease refcount)
306281
priority: if multiple ops match, the one with the highest priority is picked
307282
"""
308283
ops = c_method_call_ops.setdefault(name, [])
309284
desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type,
310-
c_function_name, error_kind, steals, ordering, extra_int_constants,
311-
priority)
285+
c_function_name, error_kind, steals, is_borrowed, ordering,
286+
extra_int_constants, priority)
312287
ops.append(desc)
313288
return desc
314289

@@ -323,6 +298,7 @@ def c_function_op(name: str,
323298
ordering: Optional[List[int]] = None,
324299
extra_int_constants: List[Tuple[int, RType]] = [],
325300
steals: StealsDescription = False,
301+
is_borrowed: bool = False,
326302
priority: int = 1) -> CFunctionDescription:
327303
"""Define a c function call op that replaces a function call.
328304
@@ -336,8 +312,8 @@ def c_function_op(name: str,
336312
"""
337313
ops = c_function_ops.setdefault(name, [])
338314
desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type,
339-
c_function_name, error_kind, steals, ordering, extra_int_constants,
340-
priority)
315+
c_function_name, error_kind, steals, is_borrowed, ordering,
316+
extra_int_constants, priority)
341317
ops.append(desc)
342318
return desc
343319

@@ -352,6 +328,7 @@ def c_binary_op(name: str,
352328
ordering: Optional[List[int]] = None,
353329
extra_int_constants: List[Tuple[int, RType]] = [],
354330
steals: StealsDescription = False,
331+
is_borrowed: bool = False,
355332
priority: int = 1) -> CFunctionDescription:
356333
"""Define a c function call op for a binary operation.
357334
@@ -362,8 +339,8 @@ def c_binary_op(name: str,
362339
"""
363340
ops = c_binary_ops.setdefault(name, [])
364341
desc = CFunctionDescription(name, arg_types, return_type, var_arg_type, truncated_type,
365-
c_function_name, error_kind, steals, ordering, extra_int_constants,
366-
priority)
342+
c_function_name, error_kind, steals, is_borrowed, ordering,
343+
extra_int_constants, priority)
367344
ops.append(desc)
368345
return desc
369346

@@ -376,13 +353,14 @@ def c_custom_op(arg_types: List[RType],
376353
truncated_type: Optional[RType] = None,
377354
ordering: Optional[List[int]] = None,
378355
extra_int_constants: List[Tuple[int, RType]] = [],
379-
steals: StealsDescription = False) -> CFunctionDescription:
356+
steals: StealsDescription = False,
357+
is_borrowed: bool = False) -> CFunctionDescription:
380358
"""Create a one-off CallC op that can't be automatically generated from the AST.
381359
382360
Most arguments are similar to c_method_op().
383361
"""
384362
return CFunctionDescription('<custom>', arg_types, return_type, var_arg_type, truncated_type,
385-
c_function_name, error_kind, steals, ordering,
363+
c_function_name, error_kind, steals, is_borrowed, ordering,
386364
extra_int_constants, 0)
387365

388366

@@ -395,6 +373,7 @@ def c_unary_op(name: str,
395373
ordering: Optional[List[int]] = None,
396374
extra_int_constants: List[Tuple[int, RType]] = [],
397375
steals: StealsDescription = False,
376+
is_borrowed: bool = False,
398377
priority: int = 1) -> CFunctionDescription:
399378
"""Define a c function call op for an unary operation.
400379
@@ -405,8 +384,8 @@ def c_unary_op(name: str,
405384
"""
406385
ops = c_unary_ops.setdefault(name, [])
407386
desc = CFunctionDescription(name, [arg_type], return_type, None, truncated_type,
408-
c_function_name, error_kind, steals, ordering, extra_int_constants,
409-
priority)
387+
c_function_name, error_kind, steals, is_borrowed, ordering,
388+
extra_int_constants, priority)
410389
ops.append(desc)
411390
return desc
412391

mypyc/primitives/str_ops.py

Lines changed: 10 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
"""Primitive str ops."""
22

3-
from typing import List, Callable, Tuple
3+
from typing import List, Tuple
44

5-
from mypyc.ir.ops import ERR_MAGIC, EmitterInterface
5+
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
66
from mypyc.ir.rtypes import (
7-
RType, object_rprimitive, str_rprimitive, bool_rprimitive, int_rprimitive, list_rprimitive,
7+
RType, object_rprimitive, str_rprimitive, int_rprimitive, list_rprimitive,
88
c_int_rprimitive, pointer_rprimitive
99
)
1010
from mypyc.primitives.registry import (
11-
binary_op, c_method_op, c_binary_op, c_function_op,
12-
load_address_op
11+
c_method_op, c_binary_op, c_function_op,
12+
load_address_op, c_custom_op
1313
)
1414

1515

@@ -80,30 +80,8 @@
8080
steals=[True, False])
8181

8282

83-
def emit_str_compare(comparison: str) -> Callable[[EmitterInterface, List[str], str], None]:
84-
def emit(emitter: EmitterInterface, args: List[str], dest: str) -> None:
85-
temp = emitter.temp_name()
86-
emitter.emit_declaration('int %s;' % temp)
87-
emitter.emit_lines(
88-
'%s = PyUnicode_Compare(%s, %s);' % (temp, args[0], args[1]),
89-
'if (%s == -1 && PyErr_Occurred())' % temp,
90-
' %s = 2;' % dest,
91-
'else',
92-
' %s = (%s %s);' % (dest, temp, comparison))
93-
94-
return emit
95-
96-
97-
# str1 == str2
98-
binary_op(op='==',
99-
arg_types=[str_rprimitive, str_rprimitive],
100-
result_type=bool_rprimitive,
101-
error_kind=ERR_MAGIC,
102-
emit=emit_str_compare('== 0'))
103-
104-
# str1 != str2
105-
binary_op(op='!=',
106-
arg_types=[str_rprimitive, str_rprimitive],
107-
result_type=bool_rprimitive,
108-
error_kind=ERR_MAGIC,
109-
emit=emit_str_compare('!= 0'))
83+
unicode_compare = c_custom_op(
84+
arg_types=[str_rprimitive, str_rprimitive],
85+
return_type=c_int_rprimitive,
86+
c_function_name='PyUnicode_Compare',
87+
error_kind=ERR_NEVER)

mypyc/test-data/irbuild-str.test

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,50 @@ L8:
5555
L9:
5656
r13 = PyUnicode_Split(s, 0, -1)
5757
return r13
58+
59+
[case testStrEquality]
60+
def eq(x: str, y: str) -> bool:
61+
return x == y
62+
63+
def neq(x: str, y: str) -> bool:
64+
return x != y
65+
66+
[out]
67+
def eq(x, y):
68+
x, y :: str
69+
r0 :: int32
70+
r1 :: bool
71+
r2 :: object
72+
r3, r4, r5 :: bool
73+
L0:
74+
r0 = PyUnicode_Compare(x, y)
75+
r1 = r0 == -1
76+
if r1 goto L1 else goto L3 :: bool
77+
L1:
78+
r2 = PyErr_Occurred()
79+
r3 = r2 != 0
80+
if r3 goto L2 else goto L3 :: bool
81+
L2:
82+
r4 = CPy_KeepPropagating()
83+
L3:
84+
r5 = r0 == 0
85+
return r5
86+
def neq(x, y):
87+
x, y :: str
88+
r0 :: int32
89+
r1 :: bool
90+
r2 :: object
91+
r3, r4, r5 :: bool
92+
L0:
93+
r0 = PyUnicode_Compare(x, y)
94+
r1 = r0 == -1
95+
if r1 goto L1 else goto L3 :: bool
96+
L1:
97+
r2 = PyErr_Occurred()
98+
r3 = r2 != 0
99+
if r3 goto L2 else goto L3 :: bool
100+
L2:
101+
r4 = CPy_KeepPropagating()
102+
L3:
103+
r5 = r0 != 0
104+
return r5

0 commit comments

Comments
 (0)