Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit b913793

Browse files
hughhan1msullivan
authored andcommitted
Nested functions working without free variables (mypyc/mypyc#149)
Nested functions that do not reference free variables now work. Functions can be multiply nested, and same-named functions in different scopes won't conflict. The environment symbol table can now take FuncDef node types in addition to Var node types.
1 parent 6b364e6 commit b913793

6 files changed

Lines changed: 290 additions & 35 deletions

File tree

mypyc/emitclass.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ def emit_line() -> None:
4848
init_name = '0'
4949
init_args = []
5050

51+
call_fn = cl.get_method('__call__')
52+
call_name = '{}{}'.format(PREFIX, call_fn.cname(emitter.names)) if call_fn else '0'
53+
5154
emitter.emit_line('static PyObject *{}(void);'.format(setup_name))
5255
# TODO: Use RInstance
5356
ctor = FuncIR(cl.name, None, module, init_args, object_rprimitive, [], Environment())
@@ -88,7 +91,7 @@ def emit_line() -> None:
8891
0, /* tp_as_sequence */
8992
0, /* tp_as_mapping */
9093
0, /* tp_hash */
91-
0, /* tp_call */
94+
{tp_call}, /* tp_call */
9295
0, /* tp_str */
9396
0, /* tp_getattro */
9497
0, /* tp_setattro */
@@ -119,6 +122,7 @@ def emit_line() -> None:
119122
traverse_name=traverse_name,
120123
clear_name=clear_name,
121124
dealloc_name=dealloc_name,
125+
tp_call=call_name,
122126
new_name=new_name,
123127
methods_name=methods_name,
124128
getseters_name=getseters_name,

mypyc/genops.py

Lines changed: 85 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def f(x: int) -> int:
1717
from typing import Dict, List, Tuple, Optional, Union
1818

1919
from mypy.nodes import (
20-
Node, MypyFile, FuncDef, ReturnStmt, AssignmentStmt, OpExpr, IntExpr, NameExpr, LDEF, Var,
21-
IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, Argument, CallExpr, IndexExpr, Block,
20+
Node, MypyFile, SymbolNode, FuncDef, ReturnStmt, AssignmentStmt, OpExpr, IntExpr, NameExpr,
21+
LDEF, Var, IfStmt, UnaryExpr, ComparisonExpr, WhileStmt, Argument, CallExpr, IndexExpr, Block,
2222
Expression, ListExpr, ExpressionStmt, MemberExpr, ForStmt, RefExpr, Lvalue, BreakStmt,
2323
ContinueStmt, ConditionalExpr, OperatorAssignmentStmt, TupleExpr, ClassDef, TypeInfo,
2424
Import, ImportFrom, ImportAll, DictExpr, StrExpr, CastExpr, TempNode, ARG_POS, MODULE_REF,
@@ -235,6 +235,7 @@ def __init__(self,
235235
self.types = types
236236
self.environment = Environment()
237237
self.environments = [self.environment]
238+
self.ret_types = [] # type: List[RType]
238239
self.blocks = [] # type: List[List[BasicBlock]]
239240
self.functions = [] # type: List[FuncIR]
240241
self.classes = [] # type: List[ClassIR]
@@ -324,23 +325,45 @@ def visit_import_all(self, node: ImportAll) -> Value:
324325
return INVALID_VALUE
325326

326327
def gen_func_def(self, fdef: FuncDef, class_name: Optional[str] = None) -> FuncIR:
327-
self.enter()
328+
# If there is more than one environment in the environment stack, then we are visiting a
329+
# non-global function.
330+
is_nested = len(self.environments) > 1
328331

332+
self.enter(fdef.name())
333+
334+
if is_nested:
335+
# If this is a nested function, then add a 'self' field to the environment, since we
336+
# will be instantiating the function as a method of a new class representing that
337+
# original function.
338+
self.environment.add_local(Var('self'), object_rprimitive, is_arg=True)
329339
for arg in fdef.arguments:
330340
assert arg.variable.type, "Function argument missing type"
331341
self.environment.add_local(arg.variable, self.type_to_rtype(arg.variable.type),
332342
is_arg=True)
333-
self.ret_type = self.convert_return_type(fdef)
343+
self.ret_types[-1] = self.convert_return_type(fdef)
344+
334345
fdef.body.accept(self)
335346

336-
if is_none_rprimitive(self.ret_type) or is_object_rprimitive(self.ret_type):
347+
if (is_none_rprimitive(self.ret_types[-1]) or
348+
is_object_rprimitive(self.ret_types[-1])):
337349
self.add_implicit_return()
338350
else:
339351
self.add_implicit_unreachable()
340352

341-
blocks, env = self.leave()
353+
blocks, env, ret_type = self.leave()
342354
args = self.convert_args(fdef)
343-
return FuncIR(fdef.name(), class_name, self.module_name, args, self.ret_type, blocks, env)
355+
356+
if is_nested:
357+
namespace = self.generate_function_namespace()
358+
func_ir = self.generate_function_class(fdef, namespace, blocks, env, ret_type)
359+
360+
# Instantiate the callable class and load it into a register in the current environment
361+
# immediately so that it does not have to be loaded every time the function is called.
362+
self.instantiate_function_class(fdef, namespace)
363+
else:
364+
func_ir = FuncIR(fdef.name(), class_name, self.module_name, args, ret_type, blocks,
365+
env)
366+
return func_ir
344367

345368
def visit_func_def(self, fdef: FuncDef) -> Value:
346369
self.functions.append(self.gen_func_def(fdef))
@@ -379,7 +402,7 @@ def visit_expression_stmt(self, stmt: ExpressionStmt) -> Value:
379402
def visit_return_stmt(self, stmt: ReturnStmt) -> Value:
380403
if stmt.expr:
381404
retval = self.accept(stmt.expr)
382-
retval = self.coerce(retval, self.ret_type, stmt.line)
405+
retval = self.coerce(retval, self.ret_types[-1], stmt.line)
383406
else:
384407
retval = self.add(PrimitiveOp([], none_op, line=-1))
385408
self.add(Return(retval))
@@ -813,15 +836,15 @@ def visit_name_expr(self, expr: NameExpr) -> Value:
813836
if not self.is_native_name_expr(expr):
814837
return self.load_static_module_attr(expr)
815838

816-
# TODO: We assume that this is a Var or FuncDef node, which is very limited
817-
if isinstance(expr.node, Var):
818-
return self.environment.lookup(expr.node)
819-
if isinstance(expr.node, FuncDef):
820-
# If we have a function, then we can look it up in the global variables dictionary.
839+
# TODO: Behavior currently only defined for Var and FuncDef node types.
840+
if expr.kind == LDEF:
841+
try:
842+
return self.environment.lookup(expr.node)
843+
except KeyError:
844+
assert False, 'expression %s not defined in current scope'.format(expr.name)
845+
else:
821846
return self.load_global(expr)
822847

823-
assert False, 'node must be of either Var or FuncDef type'
824-
825848
def is_global_name(self, name: str) -> bool:
826849
# TODO: this is pretty hokey
827850
for _, names in self.from_imports.items():
@@ -1264,9 +1287,10 @@ def visit_yield_expr(self, o: YieldExpr) -> Value:
12641287

12651288
# Helpers
12661289

1267-
def enter(self) -> None:
1268-
self.environment = Environment()
1290+
def enter(self, name: Optional[str] = None) -> None:
1291+
self.environment = Environment(name)
12691292
self.environments.append(self.environment)
1293+
self.ret_types.append(none_rprimitive)
12701294
self.blocks.append([])
12711295
self.new_block()
12721296

@@ -1282,18 +1306,22 @@ def goto_new_block(self) -> BasicBlock:
12821306
goto.label = block.label
12831307
return block
12841308

1285-
def leave(self) -> Tuple[List[BasicBlock], Environment]:
1309+
def leave(self) -> Tuple[List[BasicBlock], Environment, RType]:
12861310
blocks = self.blocks.pop()
12871311
env = self.environments.pop()
1312+
ret_type = self.ret_types.pop()
12881313
self.environment = self.environments[-1]
1289-
return blocks, env
1314+
return blocks, env, ret_type
12901315

12911316
def add(self, op: Op) -> Value:
12921317
self.blocks[-1][-1].ops.append(op)
12931318
if isinstance(op, RegisterOp):
12941319
self.environment.add_op(op)
12951320
return op
12961321

1322+
def generate_function_namespace(self) -> str:
1323+
return '_'.join(env.name for env in self.environments if env.name)
1324+
12971325
def primitive_op(self, desc: OpDescription, args: List[Value], line: int) -> Value:
12981326
assert desc.result_type is not None
12991327
coerced = []
@@ -1345,6 +1373,44 @@ def unbox_or_cast(self, src: Value, target_type: RType, line: int) -> Value:
13451373
def box_expr(self, expr: Expression) -> Value:
13461374
return self.box(self.accept(expr))
13471375

1376+
def generate_function_class(self,
1377+
fdef: FuncDef,
1378+
namespace: str,
1379+
blocks: List[BasicBlock],
1380+
env: Environment,
1381+
ret_type: RType) -> FuncIR:
1382+
"""Generates a callable class representing a nested function.
1383+
1384+
This takes a FuncDef and its associated namespace, blocks, environment, and return type and
1385+
builds a ClassIR with its '__call__' method implemented to represent the function. Note
1386+
that the name of the function is changed to be '__call__', and a 'self' parameter is added
1387+
to its list of arguments, as it becomes a class method. The name of the newly constructed
1388+
class is generated using the names of the functions that enclose the given nested function.
1389+
1390+
Returns a newly constructed FuncIR associated with the given FuncDef.
1391+
"""
1392+
class_name = '{}_{}_obj'.format(fdef.name(), namespace)
1393+
args = self.convert_args(fdef)
1394+
args.insert(0, RuntimeArg('self', object_rprimitive))
1395+
func_ir = FuncIR('__call__', class_name, self.module_name, args, ret_type, blocks, env)
1396+
class_ir = ClassIR(class_name, self.module_name)
1397+
class_ir.methods.append(func_ir)
1398+
self.classes.append(class_ir)
1399+
return func_ir
1400+
1401+
def instantiate_function_class(self, fdef: FuncDef, namespace: str) -> Value:
1402+
"""Assigns a callable class to a register named after the given function definition."""
1403+
temp_reg = self.load_function_class(fdef, namespace)
1404+
func_reg = self.environment.add_local(fdef, object_rprimitive)
1405+
return self.add(Assign(func_reg, temp_reg))
1406+
1407+
def load_function_class(self, fdef: FuncDef, namespace: str) -> Value:
1408+
"""Loads a callable class representing a nested function into a register."""
1409+
return self.add(Call(self.convert_return_type(fdef),
1410+
'{}.{}_{}_obj'.format(self.module_name, fdef.name(), namespace),
1411+
[],
1412+
fdef.line))
1413+
13481414
def load_global(self, expr: NameExpr) -> Value:
13491415
"""Loads a Python-level global.
13501416

mypyc/ops.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
import re
1515
from typing import (
1616
List, Dict, Generic, TypeVar, Optional, Any, NamedTuple, Tuple, NewType, Callable, Union,
17-
Iterable, Type,
17+
Iterable, Type, Sequence,
1818
)
1919
from collections import OrderedDict
2020

21-
from mypy.nodes import Var
21+
from mypy.nodes import SymbolNode, Var, FuncDef
2222

2323
from mypyc.namegen import NameGenerator
2424

@@ -322,9 +322,10 @@ def __hash__(self) -> int:
322322
class Environment:
323323
"""Maintain the register symbol table and manage temp generation"""
324324

325-
def __init__(self) -> None:
325+
def __init__(self, name: Optional[str] = None) -> None:
326+
self.name = name
326327
self.indexes = OrderedDict() # type: Dict[Value, int]
327-
self.symtable = {} # type: Dict[Var, Register]
328+
self.symtable = {} # type: Dict[SymbolNode, Register]
328329
self.temp_index = 0
329330

330331
def regs(self) -> Iterable['Value']:
@@ -334,16 +335,16 @@ def add(self, reg: 'Value', name: str) -> None:
334335
reg.name = name
335336
self.indexes[reg] = len(self.indexes)
336337

337-
def add_local(self, var: Var, typ: RType, is_arg: bool = False) -> 'Register':
338-
assert isinstance(var, Var)
339-
reg = Register(typ, var.line, is_arg = is_arg)
338+
def add_local(self, symbol: SymbolNode, typ: RType, is_arg: bool = False) -> 'Register':
339+
assert isinstance(symbol, SymbolNode)
340+
reg = Register(typ, symbol.line, is_arg = is_arg)
340341

341-
self.symtable[var] = reg
342-
self.add(reg, var.name())
342+
self.symtable[symbol] = reg
343+
self.add(reg, symbol.name())
343344
return reg
344345

345-
def lookup(self, var: Var) -> 'Register':
346-
return self.symtable[var]
346+
def lookup(self, symbol: SymbolNode) -> 'Register':
347+
return self.symtable[symbol]
347348

348349
def add_temp(self, typ: RType) -> 'Register':
349350
assert isinstance(typ, RType)
@@ -663,7 +664,7 @@ class Call(RegisterOp):
663664
error_kind = ERR_MAGIC
664665

665666
# TODO: take a FuncIR and extract the ret type
666-
def __init__(self, ret_type: RType, fn: str, args: List[Value], line: int) -> None:
667+
def __init__(self, ret_type: RType, fn: str, args: Sequence[Value], line: int) -> None:
667668
super().__init__(line)
668669
self.fn = fn
669670
self.args = args
@@ -679,7 +680,7 @@ def to_str(self, env: Environment) -> str:
679680
return s
680681

681682
def sources(self) -> List[Value]:
682-
return self.args[:]
683+
return list(self.args[:])
683684

684685
def accept(self, visitor: 'OpVisitor[T]') -> T:
685686
return visitor.visit_call(self)

mypyc/test/test_emitfunc.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,8 @@ def setUp(self) -> None:
268268

269269
def test_simple(self) -> None:
270270
self.block.ops.append(Return(self.reg))
271-
fn = FuncIR('myfunc', None, 'mod', [self.arg], int_rprimitive, [self.block], self.env)
271+
fn = FuncIR('myfunc', None, 'mod', [self.arg], int_rprimitive, [self.block],
272+
self.env)
272273
emitter = Emitter(EmitterContext(['mod']))
273274
generate_native_function(fn, emitter, 'prog.py')
274275
result = emitter.fragments
@@ -286,7 +287,8 @@ def test_register(self) -> None:
286287
op = LoadInt(5)
287288
self.block.ops.append(op)
288289
self.env.add_op(op)
289-
fn = FuncIR('myfunc', None, 'mod', [self.arg], list_rprimitive, [self.block], self.env)
290+
fn = FuncIR('myfunc', None, 'mod', [self.arg], list_rprimitive, [self.block],
291+
self.env)
290292
emitter = Emitter(EmitterContext(['mod']))
291293
generate_native_function(fn, emitter, 'prog.py')
292294
result = emitter.fragments

0 commit comments

Comments
 (0)