@@ -82,6 +82,7 @@ def generate_c_for_modules(self) -> str:
8282 module_irs = [module_ir for _ , module_ir in self .modules ]
8383
8484 for module_name , module in self .modules :
85+ self .declare_module (module_name , emitter )
8586 self .declare_internal_globals (module_name , emitter )
8687 self .declare_imports (module .imports , emitter )
8788
@@ -162,8 +163,12 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
162163 else :
163164 declaration = 'PyObject *CPyInit_{}(void)'
164165 emitter .emit_lines (declaration .format (module_name ),
165- '{' ,
166- 'PyObject *m;' )
166+ '{' )
167+ module_static = self .module_static_name (module_name , emitter )
168+ emitter .emit_lines ('if ({} != NULL) {{' .format (module_static ),
169+ 'Py_INCREF({});' .format (module_static ),
170+ 'return {};' .format (module_static ),
171+ '}' )
167172 for cl in module .classes :
168173 type_struct = emitter .type_struct_name (cl )
169174 if cl .traits :
@@ -175,15 +180,16 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
175180
176181 emitter .emit_lines ('if (PyType_Ready(&{}) < 0)' .format (type_struct ),
177182 ' return NULL;' )
178- emitter .emit_lines ('m = PyModule_Create(&{}module);' .format (module_prefix ),
179- 'if (m == NULL)' ,
183+ emitter .emit_lines ('{} = PyModule_Create(&{}module);' .format (module_static , module_prefix ),
184+ 'if ({} == NULL)' . format ( module_static ) ,
180185 ' return NULL;' )
181186 module_globals = emitter .static_name ('globals' , module_name )
182- emitter .emit_lines ('{} = PyModule_GetDict(m );' .format (module_globals ),
187+ emitter .emit_lines ('{} = PyModule_GetDict({} );' .format (module_globals , module_static ),
183188 'if ({} == NULL)' .format (module_globals ),
184189 ' return NULL;' )
185190 self .generate_imports_init_section (module .imports , emitter )
186191 self .generate_from_imports_init_section (
192+ module_static ,
187193 module .imports ,
188194 module .from_imports ,
189195 emitter ,
@@ -216,8 +222,9 @@ def generate_module_def(self, emitter: Emitter, module_name: str, module: Module
216222 type_struct = emitter .type_struct_name (cl )
217223 emitter .emit_lines (
218224 'Py_INCREF(&{});' .format (type_struct ),
219- 'PyModule_AddObject(m, "{}", (PyObject *)&{});' .format (name , type_struct ))
220- emitter .emit_line ('return m;' )
225+ 'PyModule_AddObject({}, "{}", (PyObject *)&{});' .format (module_static , name ,
226+ type_struct ))
227+ emitter .emit_line ('return {};' .format (module_static ))
221228 emitter .emit_line ('}' )
222229
223230 def toposort_declarations (self ) -> List [HeaderDeclaration ]:
@@ -262,13 +269,16 @@ def declare_internal_globals(self, module_name: str, emitter: Emitter) -> None:
262269 static_name = emitter .static_name ('globals' , module_name )
263270 self .declare_global ('PyObject *' , static_name )
264271
265- def declare_import (self , imp : str , emitter : Emitter ) -> None :
266- static_name = emitter .static_name ('module' , imp )
272+ def module_static_name (self , module_name : str , emitter : Emitter ) -> str :
273+ return emitter .static_name ('module' , module_name )
274+
275+ def declare_module (self , module_name : str , emitter : Emitter ) -> None :
276+ static_name = self .module_static_name (module_name , emitter )
267277 self .declare_global ('CPyModule *' , static_name )
268278
269279 def declare_imports (self , imps : Iterable [str ], emitter : Emitter ) -> None :
270280 for imp in imps :
271- self .declare_import (imp , emitter )
281+ self .declare_module (imp , emitter )
272282
273283 def declare_static_pyobject (self , identifier : str , emitter : Emitter ) -> None :
274284 symbol = emitter .static_name (identifier , None )
@@ -280,7 +290,7 @@ def generate_imports_init_section(self, imps: List[str], emitter: Emitter) -> No
280290 self .generate_import (imp , emitter , check_for_null = True )
281291
282292 def generate_import (self , imp : str , emitter : Emitter , check_for_null : bool ) -> None :
283- c_name = emitter . static_name ( 'module' , imp )
293+ c_name = self . module_static_name ( imp , emitter )
284294 if check_for_null :
285295 emitter .emit_line ('if ({} == NULL) {{' .format (c_name ))
286296 emitter .emit_line ('{} = PyImport_ImportModule("{}");' .format (c_name , imp ))
@@ -290,21 +300,22 @@ def generate_import(self, imp: str, emitter: Emitter, check_for_null: bool) -> N
290300 emitter .emit_line ('}' )
291301
292302 def generate_from_imports_init_section (self ,
303+ module_static : str ,
293304 imps : List [str ],
294305 from_imps : Dict [str , List [Tuple [str , str ]]],
295306 emitter : Emitter ) -> None :
296307 for imp , import_names in from_imps .items ():
297308 # Only import it again if we haven't imported it from the main
298309 # imports section
299310 if imp not in imps :
300- c_name = emitter . static_name ( 'module' , imp )
311+ c_name = self . module_static_name ( imp , emitter )
301312 emitter .emit_line ('CPyModule *{};' .format (c_name ))
302313 self .generate_import (imp , emitter , check_for_null = False )
303314
304315 for original_name , as_name in import_names :
305316 # Obtain a reference to the original object
306317 object_temp_name = emitter .temp_name ()
307- c_name = emitter . static_name ( 'module' , imp )
318+ c_name = self . module_static_name ( imp , emitter )
308319 emitter .emit_line ('PyObject *{} = PyObject_GetAttrString({}, "{}");' .format (
309320 object_temp_name ,
310321 c_name ,
@@ -315,7 +326,8 @@ def generate_from_imports_init_section(self,
315326 ' return NULL;' ,
316327 )
317328 # and add it to the namespace of the current module, which eats the ref
318- emitter .emit_line ('if (PyModule_AddObject(m, "{}", {}) < 0)' .format (
329+ emitter .emit_line ('if (PyModule_AddObject({}, "{}", {}) < 0)' .format (
330+ module_static ,
319331 as_name ,
320332 object_temp_name ,
321333 ))
@@ -324,7 +336,7 @@ def generate_from_imports_init_section(self,
324336 # This particular import isn't saved as a global so we should decref it
325337 # and not keep it around
326338 if imp not in imps :
327- c_name = emitter . static_name ( 'module' , imp )
339+ c_name = self . module_static_name ( imp , emitter )
328340 emitter .emit_line ('Py_DECREF({});' .format (c_name ))
329341
330342
0 commit comments