Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6339727

Browse files
JukkaLmsullivan
authored andcommitted
Support import cycles (mypyc/mypyc#195)
This approach is pretty janky -- the module init function can be called multiple times if there are import cycles, and we just return the previously created module object (which may be only partially initialized) if it's available. This Seems To Work but might break in some unforeseen ways. However, I propose that we keep this for now, at least until we support module top levels -- as that's where the potential problems with this approach seem likely to materialize. I manually verified that references to other modules within an import cycle seem to generate efficient code (i.e. not go through Python semantics) but there are no tests for that yet. I'll create a follow-up issue about this. Fixes mypyc/mypyc#164.
1 parent 7c07741 commit 6339727

5 files changed

Lines changed: 173 additions & 36 deletions

File tree

mypyc/emitmodule.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def generate_c_for_modules(self) -> str:
8282
module_irs = [module_ir for _, module_ir in self.modules]
8383

8484
for module_name, module in self.modules:
85+
self.declare_module(module_name, emitter)
8586
self.declare_internal_globals(module_name, emitter)
8687
self.declare_imports(module.imports, emitter)
8788

@@ -162,8 +163,12 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
162163
else:
163164
declaration = 'PyObject *CPyInit_{}(void)'
164165
emitter.emit_lines(declaration.format(module_name),
165-
'{',
166-
'PyObject *m;')
166+
'{')
167+
module_static = self.module_static_name(module_name, emitter)
168+
emitter.emit_lines('if ({} != NULL) {{'.format(module_static),
169+
'Py_INCREF({});'.format(module_static),
170+
'return {};'.format(module_static),
171+
'}')
167172
for cl in module.classes:
168173
type_struct = emitter.type_struct_name(cl)
169174
if cl.traits:
@@ -175,15 +180,16 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
175180

176181
emitter.emit_lines('if (PyType_Ready(&{}) < 0)'.format(type_struct),
177182
' return NULL;')
178-
emitter.emit_lines('m = PyModule_Create(&{}module);'.format(module_prefix),
179-
'if (m == NULL)',
183+
emitter.emit_lines('{} = PyModule_Create(&{}module);'.format(module_static, module_prefix),
184+
'if ({} == NULL)'.format(module_static),
180185
' return NULL;')
181186
module_globals = emitter.static_name('globals', module_name)
182-
emitter.emit_lines('{} = PyModule_GetDict(m);'.format(module_globals),
187+
emitter.emit_lines('{} = PyModule_GetDict({});'.format(module_globals, module_static),
183188
'if ({} == NULL)'.format(module_globals),
184189
' return NULL;')
185190
self.generate_imports_init_section(module.imports, emitter)
186191
self.generate_from_imports_init_section(
192+
module_static,
187193
module.imports,
188194
module.from_imports,
189195
emitter,
@@ -216,8 +222,9 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
216222
type_struct = emitter.type_struct_name(cl)
217223
emitter.emit_lines(
218224
'Py_INCREF(&{});'.format(type_struct),
219-
'PyModule_AddObject(m, "{}", (PyObject *)&{});'.format(name, type_struct))
220-
emitter.emit_line('return m;')
225+
'PyModule_AddObject({}, "{}", (PyObject *)&{});'.format(module_static, name,
226+
type_struct))
227+
emitter.emit_line('return {};'.format(module_static))
221228
emitter.emit_line('}')
222229

223230
def toposort_declarations(self) -> List[HeaderDeclaration]:
@@ -262,13 +269,16 @@ def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None:
262269
static_name = emitter.static_name('globals', module_name)
263270
self.declare_global('PyObject *', static_name)
264271

265-
def declare_import(self, imp: str, emitter: Emitter) -> None:
266-
static_name = emitter.static_name('module', imp)
272+
def module_static_name(self, module_name: str, emitter: Emitter) -> str:
273+
return emitter.static_name('module', module_name)
274+
275+
def declare_module(self, module_name: str, emitter: Emitter) -> None:
276+
static_name = self.module_static_name(module_name, emitter)
267277
self.declare_global('CPyModule *', static_name)
268278

269279
def declare_imports(self, imps: Iterable[str], emitter: Emitter) -> None:
270280
for imp in imps:
271-
self.declare_import(imp, emitter)
281+
self.declare_module(imp, emitter)
272282

273283
def declare_static_pyobject(self, identifier: str, emitter: Emitter) -> None:
274284
symbol = emitter.static_name(identifier, None)
@@ -280,7 +290,7 @@ def generate_imports_init_section(self, imps: List[str], emitter: Emitter) -> No
280290
self.generate_import(imp, emitter, check_for_null=True)
281291

282292
def generate_import(self, imp: str, emitter: Emitter, check_for_null: bool) -> None:
283-
c_name = emitter.static_name('module', imp)
293+
c_name = self.module_static_name(imp, emitter)
284294
if check_for_null:
285295
emitter.emit_line('if ({} == NULL) {{'.format(c_name))
286296
emitter.emit_line('{} = PyImport_ImportModule("{}");'.format(c_name, imp))
@@ -290,21 +300,22 @@ def generate_import(self, imp: str, emitter: Emitter, check_for_null: bool) -> N
290300
emitter.emit_line('}')
291301

292302
def generate_from_imports_init_section(self,
303+
module_static: str,
293304
imps: List[str],
294305
from_imps: Dict[str, List[Tuple[str, str]]],
295306
emitter: Emitter) -> None:
296307
for imp, import_names in from_imps.items():
297308
# Only import it again if we haven't imported it from the main
298309
# imports section
299310
if imp not in imps:
300-
c_name = emitter.static_name('module', imp)
311+
c_name = self.module_static_name(imp, emitter)
301312
emitter.emit_line('CPyModule *{};'.format(c_name))
302313
self.generate_import(imp, emitter, check_for_null=False)
303314

304315
for original_name, as_name in import_names:
305316
# Obtain a reference to the original object
306317
object_temp_name = emitter.temp_name()
307-
c_name = emitter.static_name('module', imp)
318+
c_name = self.module_static_name(imp, emitter)
308319
emitter.emit_line('PyObject *{} = PyObject_GetAttrString({}, "{}");'.format(
309320
object_temp_name,
310321
c_name,
@@ -315,7 +326,8 @@ def generate_from_imports_init_section(self,
315326
' return NULL;',
316327
)
317328
# and add it to the namespace of the current module, which eats the ref
318-
emitter.emit_line('if (PyModule_AddObject(m, "{}", {}) < 0)'.format(
329+
emitter.emit_line('if (PyModule_AddObject({}, "{}", {}) < 0)'.format(
330+
module_static,
319331
as_name,
320332
object_temp_name,
321333
))
@@ -324,7 +336,7 @@ def generate_from_imports_init_section(self,
324336
# This particular import isn't saved as a global so we should decref it
325337
# and not keep it around
326338
if imp not in imps:
327-
c_name = emitter.static_name('module', imp)
339+
c_name = self.module_static_name(imp, emitter)
328340
emitter.emit_line('Py_DECREF({});'.format(c_name))
329341

330342

mypyc/genops.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -981,16 +981,16 @@ def visit_int_expr(self, expr: IntExpr) -> Value:
981981
def visit_float_expr(self, expr: FloatExpr) -> Value:
982982
return self.load_static_float(expr.value)
983983

984-
def is_native_name_expr(self, expr: NameExpr) -> bool:
985-
assert expr.node, "RefExpr not resolved"
984+
def is_native_ref_expr(self, expr: RefExpr) -> bool:
985+
if expr.node is None:
986+
return False
986987
if '.' in expr.node.fullname():
987988
module_name = '.'.join(expr.node.fullname().split('.')[:-1])
988989
return module_name in self.modules
989-
990990
return True
991991

992-
def is_native_module_name_expr(self, expr: NameExpr) -> bool:
993-
return self.is_native_name_expr(expr) and expr.kind == GDEF
992+
def is_native_module_ref_expr(self, expr: RefExpr) -> bool:
993+
return self.is_native_ref_expr(expr) and expr.kind == GDEF
994994

995995
def visit_name_expr(self, expr: NameExpr) -> Value:
996996
assert expr.node, "RefExpr not resolved"
@@ -1059,14 +1059,17 @@ def visit_call_expr(self, expr: CallExpr) -> Value:
10591059
callee = callee.analyzed.expr # Unwrap type application
10601060

10611061
if isinstance(callee, MemberExpr):
1062-
# TODO: Could be call to module-level function
1063-
return self.translate_method_call(expr, callee)
1062+
if self.is_native_ref_expr(callee):
1063+
# Call to module-level function or such
1064+
return self.translate_call(expr, callee)
1065+
else:
1066+
return self.translate_method_call(expr, callee)
10641067
else:
10651068
return self.translate_call(expr, callee)
10661069

10671070
def translate_call(self, expr: CallExpr, callee: Expression) -> Value:
10681071
"""Translate a non-method call."""
1069-
assert isinstance(callee, NameExpr) # TODO: Allow arbitrary callees
1072+
assert isinstance(callee, RefExpr) # TODO: Allow arbitrary callees
10701073

10711074
# Gen the args
10721075
fullname = callee.fullname
@@ -1080,9 +1083,9 @@ def translate_call(self, expr: CallExpr, callee: Expression) -> Value:
10801083
if (fullname == 'builtins.isinstance'
10811084
and len(expr.args) == 2
10821085
and expr.arg_kinds == [ARG_POS, ARG_POS]
1083-
and isinstance(expr.args[1], NameExpr)
1086+
and isinstance(expr.args[1], RefExpr)
10841087
and isinstance(expr.args[1].node, TypeInfo)
1085-
and self.is_native_module_name_expr(expr.args[1])):
1088+
and self.is_native_module_ref_expr(expr.args[1])):
10861089
# Special case native isinstance() checks as this makes them much faster.
10871090
return self.primitive_op(fast_isinstance_op, args, expr.line)
10881091

@@ -1111,13 +1114,13 @@ def translate_call(self, expr: CallExpr, callee: Expression) -> Value:
11111114
function = self.accept(callee)
11121115
return self.py_call(function, args, target_type, expr.line)
11131116

1114-
def get_native_signature(self, callee: NameExpr) -> Optional[CallableType]:
1117+
def get_native_signature(self, callee: RefExpr) -> Optional[CallableType]:
11151118
"""Get the signature of a native function, or return None if not available.
11161119
11171120
This only works for normal functions, not methods.
11181121
"""
11191122
signature = None
1120-
if self.is_native_module_name_expr(callee):
1123+
if self.is_native_module_ref_expr(callee):
11211124
node = callee.node
11221125
if isinstance(node, TypeInfo):
11231126
node = node['__init__'].node
@@ -1599,7 +1602,7 @@ def instantiate_function_class(self, fdef: FuncDef, namespace: str) -> Value:
15991602
func_reg = self.environment.add_local(fdef, object_rprimitive)
16001603
return self.add(Assign(func_reg, temp_reg))
16011604

1602-
def is_builtin_name_expr(self, expr: NameExpr) -> bool:
1605+
def is_builtin_ref_expr(self, expr: RefExpr) -> bool:
16031606
assert expr.node, "RefExpr not resolved"
16041607
return '.' in expr.node.fullname() and expr.node.fullname().split('.')[0] == 'builtins'
16051608

@@ -1610,9 +1613,9 @@ def load_global(self, expr: NameExpr) -> Value:
16101613
from the _globals dictionary in the C-generated code.
16111614
"""
16121615
# If the global is from 'builtins', turn it into a module attr load instead
1613-
if self.is_builtin_name_expr(expr):
1616+
if self.is_builtin_ref_expr(expr):
16141617
return self.load_static_module_attr(expr)
1615-
if self.is_native_module_name_expr(expr) and isinstance(expr.node, TypeInfo):
1618+
if self.is_native_module_ref_expr(expr) and isinstance(expr.node, TypeInfo):
16161619
assert expr.fullname is not None
16171620
return self.load_native_type_object(expr.fullname)
16181621
_globals = self.add(LoadStatic(object_rprimitive, 'globals', self.module_name))

mypyc/ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from abc import abstractmethod, abstractproperty
1414
import re
1515
from typing import (
16-
List, Sequence, Dict, Generic, TypeVar, Optional, Any, NamedTuple, Tuple, NewType, Callable,
16+
List, Sequence, Dict, Generic, TypeVar, Optional, Any, NamedTuple, Tuple, Callable,
1717
Union, Iterable, Type,
1818
)
1919
from collections import OrderedDict

test-data/module-output.test

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ def f(x: int) -> int:
88
#include <Python.h>
99
#include <CPy.h>
1010

11+
static CPyModule *CPyStatic_module;
1112
static PyObject *CPyStatic_globals;
1213
static CPyModule *CPyStatic_builtins_module;
1314
static CPyTagged CPyDef_f(CPyTagged cpy_r_x);
@@ -29,19 +30,22 @@ static struct PyModuleDef module = {
2930

3031
PyMODINIT_FUNC PyInit_prog(void)
3132
{
32-
PyObject *m;
33-
m = PyModule_Create(&module);
34-
if (m == NULL)
33+
if (CPyStatic_module != NULL) {
34+
Py_INCREF(CPyStatic_module);
35+
return CPyStatic_module;
36+
}
37+
CPyStatic_module = PyModule_Create(&module);
38+
if (CPyStatic_module == NULL)
3539
return NULL;
36-
CPyStatic_globals = PyModule_GetDict(m);
40+
CPyStatic_globals = PyModule_GetDict(CPyStatic_module);
3741
if (CPyStatic_globals == NULL)
3842
return NULL;
3943
if (CPyStatic_builtins_module == NULL) {
4044
CPyStatic_builtins_module = PyImport_ImportModule("builtins");
4145
if (CPyStatic_builtins_module == NULL)
4246
return NULL;
4347
}
44-
return m;
48+
return CPyStatic_module;
4549
}
4650

4751
static CPyTagged CPyDef_f(CPyTagged cpy_r_x) {

test-data/run-multimodule.test

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,121 @@ Traceback (most recent call last):
212212
fail2()
213213
File "tmp/other.py", line 3, in fail2
214214
IndexError: list assignment index out of range
215+
216+
[case testMultiModuleCycle]
217+
import other
218+
219+
def f1() -> int:
220+
return other.f2()
221+
222+
def f3() -> int:
223+
return 5
224+
[file other.py]
225+
import native
226+
227+
def f2() -> int:
228+
return native.f3()
229+
[file driver.py]
230+
from native import f1
231+
assert f1() == 5
232+
233+
[case testMultiModuleCycleWithClasses]
234+
import other
235+
236+
class D: pass
237+
238+
def f() -> other.C:
239+
return other.C()
240+
241+
def g(c: other.C) -> D:
242+
return c.d
243+
244+
[file other.py]
245+
import native
246+
247+
class C:
248+
def __init__(self) -> None:
249+
self.d = native.D()
250+
251+
def h(d: native.D) -> None:
252+
pass
253+
254+
[file driver.py]
255+
from native import f, g
256+
from other import C, h
257+
258+
c = f()
259+
assert isinstance(c, C)
260+
assert g(c) is c.d
261+
h(c.d)
262+
263+
try:
264+
g(1)
265+
except TypeError:
266+
pass
267+
else:
268+
assert False
269+
270+
try:
271+
h(1)
272+
except TypeError:
273+
pass
274+
else:
275+
assert False
276+
277+
[case testMultiModuleCycleWithInheritance]
278+
import other
279+
280+
class Deriv1(other.Base1):
281+
pass
282+
283+
class Base2:
284+
y: int
285+
def __init__(self) -> None:
286+
self.y = 2
287+
288+
[file other.py]
289+
import native
290+
291+
class Base1:
292+
x: int
293+
def __init__(self) -> None:
294+
self.x = 1
295+
296+
class Deriv2(native.Base2):
297+
pass
298+
299+
[file driver.py]
300+
from native import Deriv1
301+
from other import Deriv2
302+
a = Deriv1()
303+
assert a.x == 1
304+
b = Deriv2()
305+
assert b.y == 2
306+
307+
[case testImportCycleWithNonCompiledModule]
308+
import m
309+
310+
class C: pass
311+
312+
def f1() -> int:
313+
m.D()
314+
return m.f2()
315+
316+
def f3() -> int:
317+
return 2
318+
319+
[file m.py]
320+
# This module is NOT compiled
321+
import native
322+
323+
class D: pass
324+
325+
def f2() -> int:
326+
native.C()
327+
return native.f3()
328+
329+
[file driver.py]
330+
from native import f1
331+
332+
assert f1() == 2

0 commit comments

Comments
 (0)