@@ -17,8 +17,8 @@ def f(x: int) -> int:
1717from typing import Dict , List , Tuple , Optional , Union
1818
1919from mypy .nodes import (
20- Node , MypyFile , FuncDef , ReturnStmt , AssignmentStmt , OpExpr , IntExpr , NameExpr , LDEF , Var ,
21- IfStmt , UnaryExpr , ComparisonExpr , WhileStmt , Argument , CallExpr , IndexExpr , Block ,
20+ Node , MypyFile , SymbolNode , FuncDef , ReturnStmt , AssignmentStmt , OpExpr , IntExpr , NameExpr ,
21+ LDEF , Var , IfStmt , UnaryExpr , ComparisonExpr , WhileStmt , Argument , CallExpr , IndexExpr , Block ,
2222 Expression , ListExpr , ExpressionStmt , MemberExpr , ForStmt , RefExpr , Lvalue , BreakStmt ,
2323 ContinueStmt , ConditionalExpr , OperatorAssignmentStmt , TupleExpr , ClassDef , TypeInfo ,
2424 Import , ImportFrom , ImportAll , DictExpr , StrExpr , CastExpr , TempNode , ARG_POS , MODULE_REF ,
@@ -235,6 +235,7 @@ def __init__(self,
235235 self .types = types
236236 self .environment = Environment ()
237237 self .environments = [self .environment ]
238+ self .ret_types = [] # type: List[RType]
238239 self .blocks = [] # type: List[List[BasicBlock]]
239240 self .functions = [] # type: List[FuncIR]
240241 self .classes = [] # type: List[ClassIR]
@@ -324,23 +325,45 @@ def visit_import_all(self, node: ImportAll) -> Value:
324325 return INVALID_VALUE
325326
326327 def gen_func_def (self , fdef : FuncDef , class_name : Optional [str ] = None ) -> FuncIR :
327- self .enter ()
328+ # If there is more than one environment in the environment stack, then we are visiting a
329+ # non-global function.
330+ is_nested = len (self .environments ) > 1
328331
332+ self .enter (fdef .name ())
333+
334+ if is_nested :
335+ # If this is a nested function, then add a 'self' field to the environment, since we
336+ # will be instantiating the function as a method of a new class representing that
337+ # original function.
338+ self .environment .add_local (Var ('self' ), object_rprimitive , is_arg = True )
329339 for arg in fdef .arguments :
330340 assert arg .variable .type , "Function argument missing type"
331341 self .environment .add_local (arg .variable , self .type_to_rtype (arg .variable .type ),
332342 is_arg = True )
333- self .ret_type = self .convert_return_type (fdef )
343+ self .ret_types [- 1 ] = self .convert_return_type (fdef )
344+
334345 fdef .body .accept (self )
335346
336- if is_none_rprimitive (self .ret_type ) or is_object_rprimitive (self .ret_type ):
347+ if (is_none_rprimitive (self .ret_types [- 1 ]) or
348+ is_object_rprimitive (self .ret_types [- 1 ])):
337349 self .add_implicit_return ()
338350 else :
339351 self .add_implicit_unreachable ()
340352
341- blocks , env = self .leave ()
353+ blocks , env , ret_type = self .leave ()
342354 args = self .convert_args (fdef )
343- return FuncIR (fdef .name (), class_name , self .module_name , args , self .ret_type , blocks , env )
355+
356+ if is_nested :
357+ namespace = self .generate_function_namespace ()
358+ func_ir = self .generate_function_class (fdef , namespace , blocks , env , ret_type )
359+
360+ # Instantiate the callable class and load it into a register in the current environment
361+ # immediately so that it does not have to be loaded every time the function is called.
362+ self .instantiate_function_class (fdef , namespace )
363+ else :
364+ func_ir = FuncIR (fdef .name (), class_name , self .module_name , args , ret_type , blocks ,
365+ env )
366+ return func_ir
344367
345368 def visit_func_def (self , fdef : FuncDef ) -> Value :
346369 self .functions .append (self .gen_func_def (fdef ))
@@ -379,7 +402,7 @@ def visit_expression_stmt(self, stmt: ExpressionStmt) -> Value:
379402 def visit_return_stmt (self , stmt : ReturnStmt ) -> Value :
380403 if stmt .expr :
381404 retval = self .accept (stmt .expr )
382- retval = self .coerce (retval , self .ret_type , stmt .line )
405+ retval = self .coerce (retval , self .ret_types [ - 1 ] , stmt .line )
383406 else :
384407 retval = self .add (PrimitiveOp ([], none_op , line = - 1 ))
385408 self .add (Return (retval ))
@@ -813,15 +836,15 @@ def visit_name_expr(self, expr: NameExpr) -> Value:
813836 if not self .is_native_name_expr (expr ):
814837 return self .load_static_module_attr (expr )
815838
816- # TODO: We assume that this is a Var or FuncDef node, which is very limited
817- if isinstance (expr .node , Var ):
818- return self .environment .lookup (expr .node )
819- if isinstance (expr .node , FuncDef ):
820- # If we have a function, then we can look it up in the global variables dictionary.
839+ # TODO: Behavior currently only defined for Var and FuncDef node types.
840+ if expr .kind == LDEF :
841+ try :
842+ return self .environment .lookup (expr .node )
843+ except KeyError :
844+ assert False , 'expression %s not defined in current scope' .format (expr .name )
845+ else :
821846 return self .load_global (expr )
822847
823- assert False , 'node must be of either Var or FuncDef type'
824-
825848 def is_global_name (self , name : str ) -> bool :
826849 # TODO: this is pretty hokey
827850 for _ , names in self .from_imports .items ():
@@ -1264,9 +1287,10 @@ def visit_yield_expr(self, o: YieldExpr) -> Value:
12641287
12651288 # Helpers
12661289
1267- def enter (self ) -> None :
1268- self .environment = Environment ()
1290+ def enter (self , name : Optional [ str ] = None ) -> None :
1291+ self .environment = Environment (name )
12691292 self .environments .append (self .environment )
1293+ self .ret_types .append (none_rprimitive )
12701294 self .blocks .append ([])
12711295 self .new_block ()
12721296
@@ -1282,18 +1306,22 @@ def goto_new_block(self) -> BasicBlock:
12821306 goto .label = block .label
12831307 return block
12841308
1285- def leave (self ) -> Tuple [List [BasicBlock ], Environment ]:
1309+ def leave (self ) -> Tuple [List [BasicBlock ], Environment , RType ]:
12861310 blocks = self .blocks .pop ()
12871311 env = self .environments .pop ()
1312+ ret_type = self .ret_types .pop ()
12881313 self .environment = self .environments [- 1 ]
1289- return blocks , env
1314+ return blocks , env , ret_type
12901315
12911316 def add (self , op : Op ) -> Value :
12921317 self .blocks [- 1 ][- 1 ].ops .append (op )
12931318 if isinstance (op , RegisterOp ):
12941319 self .environment .add_op (op )
12951320 return op
12961321
1322+ def generate_function_namespace (self ) -> str :
1323+ return '_' .join (env .name for env in self .environments if env .name )
1324+
12971325 def primitive_op (self , desc : OpDescription , args : List [Value ], line : int ) -> Value :
12981326 assert desc .result_type is not None
12991327 coerced = []
@@ -1345,6 +1373,44 @@ def unbox_or_cast(self, src: Value, target_type: RType, line: int) -> Value:
13451373 def box_expr (self , expr : Expression ) -> Value :
13461374 return self .box (self .accept (expr ))
13471375
1376+ def generate_function_class (self ,
1377+ fdef : FuncDef ,
1378+ namespace : str ,
1379+ blocks : List [BasicBlock ],
1380+ env : Environment ,
1381+ ret_type : RType ) -> FuncIR :
1382+ """Generates a callable class representing a nested function.
1383+
1384+ This takes a FuncDef and its associated namespace, blocks, environment, and return type and
1385+ builds a ClassIR with its '__call__' method implemented to represent the function. Note
1386+ that the name of the function is changed to be '__call__', and a 'self' parameter is added
1387+ to its list of arguments, as it becomes a class method. The name of the newly constructed
1388+ class is generated using the names of the functions that enclose the given nested function.
1389+
1390+ Returns a newly constructed FuncIR associated with the given FuncDef.
1391+ """
1392+ class_name = '{}_{}_obj' .format (fdef .name (), namespace )
1393+ args = self .convert_args (fdef )
1394+ args .insert (0 , RuntimeArg ('self' , object_rprimitive ))
1395+ func_ir = FuncIR ('__call__' , class_name , self .module_name , args , ret_type , blocks , env )
1396+ class_ir = ClassIR (class_name , self .module_name )
1397+ class_ir .methods .append (func_ir )
1398+ self .classes .append (class_ir )
1399+ return func_ir
1400+
1401+ def instantiate_function_class (self , fdef : FuncDef , namespace : str ) -> Value :
1402+ """Assigns a callable class to a register named after the given function definition."""
1403+ temp_reg = self .load_function_class (fdef , namespace )
1404+ func_reg = self .environment .add_local (fdef , object_rprimitive )
1405+ return self .add (Assign (func_reg , temp_reg ))
1406+
1407+ def load_function_class (self , fdef : FuncDef , namespace : str ) -> Value :
1408+ """Loads a callable class representing a nested function into a register."""
1409+ return self .add (Call (self .convert_return_type (fdef ),
1410+ '{}.{}_{}_obj' .format (self .module_name , fdef .name (), namespace ),
1411+ [],
1412+ fdef .line ))
1413+
13481414 def load_global (self , expr : NameExpr ) -> Value :
13491415 """Loads a Python-level global.
13501416
0 commit comments