|
28 | 28 | from mypyc.common import ( |
29 | 29 | BITMAP_BITS, |
30 | 30 | BITMAP_TYPE, |
| 31 | + CPYFUNCTION_NAME, |
31 | 32 | NATIVE_PREFIX, |
32 | 33 | PREFIX, |
33 | 34 | REG_PREFIX, |
@@ -411,7 +412,9 @@ def emit_line() -> None: |
411 | 412 |
|
412 | 413 | emitter.emit_line() |
413 | 414 | if generate_full: |
414 | | - generate_setup_for_class(cl, defaults_fn, vtable_name, shadow_vtable_name, emitter) |
| 415 | + generate_setup_for_class( |
| 416 | + cl, defaults_fn, vtable_name, shadow_vtable_name, coroutine_setup_name, emitter |
| 417 | + ) |
415 | 418 | emitter.emit_line() |
416 | 419 | generate_constructor_for_class(cl, cl.ctor, init_fn, setup_name, vtable_name, emitter) |
417 | 420 | emitter.emit_line() |
@@ -603,6 +606,7 @@ def generate_setup_for_class( |
603 | 606 | defaults_fn: FuncIR | None, |
604 | 607 | vtable_name: str, |
605 | 608 | shadow_vtable_name: str | None, |
| 609 | + coroutine_setup_name: str, |
606 | 610 | emitter: Emitter, |
607 | 611 | ) -> None: |
608 | 612 | """Generate a native function that allocates an instance of a class.""" |
@@ -658,6 +662,13 @@ def generate_setup_for_class( |
658 | 662 | if defaults_fn is not None: |
659 | 663 | emit_attr_defaults_func_call(defaults_fn, "self", emitter) |
660 | 664 |
|
| 665 | + # Initialize function wrapper for callable classes. As opposed to regular functions, |
| 666 | + # each instance of a callable class needs its own wrapper because they might be instantiated |
| 667 | + # inside other functions. |
| 668 | + if cl.coroutine_name: |
| 669 | + emitter.emit_line(f"if ({NATIVE_PREFIX}{coroutine_setup_name}((PyObject *)self) != 1)") |
| 670 | + emitter.emit_line(" return NULL;") |
| 671 | + |
661 | 672 | emitter.emit_line("return (PyObject *)self;") |
662 | 673 | emitter.emit_line("}") |
663 | 674 |
|
@@ -1281,27 +1292,40 @@ def generate_coroutine_setup( |
1281 | 1292 | emitter.emit_line(f"{NATIVE_PREFIX}{coroutine_setup_name}(PyObject *type)") |
1282 | 1293 | emitter.emit_line("{") |
1283 | 1294 |
|
1284 | | - if not any(fn.decl.is_coroutine for fn in cl.methods.values()): |
| 1295 | + error_stmt = " return 2;" |
| 1296 | + |
| 1297 | + def emit_instance(fn: FuncIR, fn_name: str) -> str: |
| 1298 | + filepath = emitter.filepath or "" |
| 1299 | + return emitter.emit_cpyfunction_instance(fn, fn_name, filepath, error_stmt) |
| 1300 | + |
| 1301 | + def success() -> None: |
1285 | 1302 | emitter.emit_line("return 1;") |
1286 | 1303 | emitter.emit_line("}") |
1287 | | - return |
| 1304 | + |
| 1305 | + if cl.coroutine_name: |
| 1306 | + # Callable class generated for a coroutine. It stores its function wrapper as an attribute. |
| 1307 | + wrapper_name = emit_instance(cl.methods["__call__"], cl.coroutine_name) |
| 1308 | + struct_name = cl.struct_name(emitter.names) |
| 1309 | + attr = emitter.attr(CPYFUNCTION_NAME) |
| 1310 | + emitter.emit_line(f"(({struct_name} *)type)->{attr} = {wrapper_name};") |
| 1311 | + return success() |
| 1312 | + |
| 1313 | + if not any(fn.decl.is_coroutine for fn in cl.methods.values()): |
| 1314 | + return success() |
1288 | 1315 |
|
1289 | 1316 | emitter.emit_line("PyTypeObject *tp = (PyTypeObject *)type;") |
1290 | 1317 |
|
1291 | 1318 | for fn in cl.methods.values(): |
1292 | 1319 | if not fn.decl.is_coroutine: |
1293 | 1320 | continue |
1294 | 1321 |
|
1295 | | - filepath = emitter.filepath or "" |
1296 | | - error_stmt = " return 2;" |
1297 | 1322 | name = short_id_from_name(fn.name, fn.decl.shortname, fn.line) |
1298 | | - wrapper_name = emitter.emit_cpyfunction_instance(fn, name, filepath, error_stmt) |
| 1323 | + wrapper_name = emit_instance(fn, name) |
1299 | 1324 | name_obj = f"{wrapper_name}_name" |
1300 | 1325 | emitter.emit_line(f'PyObject *{name_obj} = PyUnicode_FromString("{fn.name}");') |
1301 | 1326 | emitter.emit_line(f"if (unlikely(!{name_obj}))") |
1302 | 1327 | emitter.emit_line(error_stmt) |
1303 | 1328 | emitter.emit_line(f"if (PyDict_SetItem(tp->tp_dict, {name_obj}, {wrapper_name}) < 0)") |
1304 | 1329 | emitter.emit_line(error_stmt) |
1305 | 1330 |
|
1306 | | - emitter.emit_line("return 1;") |
1307 | | - emitter.emit_line("}") |
| 1331 | + return success() |
0 commit comments