Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit e5e4532

Browse files
ilevkivskyiJukkaL
authored andcommitted
Add plugin hook for dynamic class definition (#5875)
Fixes #5508. The new hook allows this: ``` from some_lib import dynamic_base Base = dynamic_base() class C(Base): # No error, the plugin acts and replaces 'Base' with a TypeInfo ... ``` This plugin hook is useful for SQLAlchemy ORM, for user-defined constructs similar to namedtuple(), and maybe we can even re-implement namedtuple() as a plugin at some point.
1 parent 52d937a commit e5e4532

7 files changed

Lines changed: 201 additions & 8 deletions

File tree

mypy/interpreted_plugin.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,8 @@ def get_customize_class_mro_hook(self, fullname: str
6767
) -> Optional[Callable[['mypy.plugin.ClassDefContext'],
6868
None]]:
6969
return None
70+
71+
def get_dynamic_class_hook(self, fullname: str
72+
) -> Optional[Callable[['mypy.plugin.DynamicClassDefContext'],
73+
None]]:
74+
return None

mypy/plugin.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from mypy.nodes import (
1111
Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr, ClassDef,
12-
TypeInfo, SymbolTableNode, MypyFile
12+
TypeInfo, SymbolTableNode, MypyFile, CallExpr
1313
)
1414
from mypy.tvar_scope import TypeVarScope
1515
from mypy.types import (
@@ -69,6 +69,7 @@ class SemanticAnalyzerPluginInterface:
6969

7070
modules = None # type: Dict[str, MypyFile]
7171
options = None # type: Options
72+
cur_mod_id = None # type: str
7273
msg = None # type: MessageBuilder
7374

7475
@abstractmethod
@@ -117,6 +118,15 @@ def lookup_qualified(self, name: str, ctx: Context,
117118
def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None:
118119
raise NotImplementedError
119120

121+
@abstractmethod
122+
def add_symbol_table_node(self, name: str, stnode: SymbolTableNode) -> None:
123+
"""Add node to global symbol table (or to nearest class if there is one)."""
124+
raise NotImplementedError
125+
126+
@abstractmethod
127+
def qualified_name(self, n: str) -> str:
128+
raise NotImplementedError
129+
120130

121131
# A context for a function hook that infers the return type of a function with
122132
# a special signature.
@@ -165,12 +175,21 @@ def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> N
165175

166176
# A context for a class hook that modifies the class definition.
167177
ClassDefContext = NamedTuple(
168-
'ClassDecoratorContext', [
178+
'ClassDefContext', [
169179
('cls', ClassDef), # The class definition
170180
('reason', Expression), # The expression being applied (decorator, metaclass, base class)
171181
('api', SemanticAnalyzerPluginInterface)
172182
])
173183

184+
# A context for dynamic class definitions like
185+
# Base = declarative_base()
186+
DynamicClassDefContext = NamedTuple(
187+
'DynamicClassDefContext', [
188+
('call', CallExpr), # The r.h.s. of dynamic class definition
189+
('name', str), # The name this class is being assigned to
190+
('api', SemanticAnalyzerPluginInterface)
191+
])
192+
174193

175194
class Plugin:
176195
"""Base class of all type checker plugins.
@@ -225,6 +244,10 @@ def get_customize_class_mro_hook(self, fullname: str
225244
) -> Optional[Callable[[ClassDefContext], None]]:
226245
return None
227246

247+
def get_dynamic_class_hook(self, fullname: str
248+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
249+
return None
250+
228251

229252
T = TypeVar('T')
230253

@@ -280,6 +303,10 @@ def get_customize_class_mro_hook(self, fullname: str
280303
) -> Optional[Callable[[ClassDefContext], None]]:
281304
return self.plugin.get_customize_class_mro_hook(fullname)
282305

306+
def get_dynamic_class_hook(self, fullname: str
307+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
308+
return self.plugin.get_dynamic_class_hook(fullname)
309+
283310

284311
class ChainedPlugin(Plugin):
285312
"""A plugin that represents a sequence of chained plugins.
@@ -337,6 +364,10 @@ def get_customize_class_mro_hook(self, fullname: str
337364
) -> Optional[Callable[[ClassDefContext], None]]:
338365
return self._find_hook(lambda plugin: plugin.get_customize_class_mro_hook(fullname))
339366

367+
def get_dynamic_class_hook(self, fullname: str
368+
) -> Optional[Callable[[DynamicClassDefContext], None]]:
369+
return self._find_hook(lambda plugin: plugin.get_dynamic_class_hook(fullname))
370+
340371
def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
341372
for plugin in self._plugins:
342373
hook = lookup(plugin)

mypy/semanal.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@
8080
from mypy.sametypes import is_same_type
8181
from mypy.options import Options
8282
from mypy import experiments
83-
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
83+
from mypy.plugin import (
84+
Plugin, ClassDefContext, SemanticAnalyzerPluginInterface,
85+
DynamicClassDefContext
86+
)
8487
from mypy.util import get_prefix, correct_relative_import
8588
from mypy.semanal_shared import SemanticAnalyzerInterface, set_callable_name
8689
from mypy.scope import Scope
@@ -1729,6 +1732,7 @@ def final_cb(keep_final: bool) -> None:
17291732
# Store type into nodes.
17301733
for lvalue in s.lvalues:
17311734
self.store_declared_types(lvalue, s.type)
1735+
self.apply_dynamic_class_hook(s)
17321736
self.check_and_set_up_type_alias(s)
17331737
self.newtype_analyzer.process_newtype_declaration(s)
17341738
self.process_typevar_declaration(s)
@@ -1744,6 +1748,21 @@ def final_cb(keep_final: bool) -> None:
17441748
isinstance(s.rvalue, (ListExpr, TupleExpr))):
17451749
self.add_exports(s.rvalue.items)
17461750

1751+
def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None:
1752+
if len(s.lvalues) > 1:
1753+
return
1754+
lval = s.lvalues[0]
1755+
if not isinstance(lval, NameExpr) or not isinstance(s.rvalue, CallExpr):
1756+
return
1757+
call = s.rvalue
1758+
if not isinstance(call.callee, RefExpr):
1759+
return
1760+
fname = call.callee.fullname
1761+
if fname:
1762+
hook = self.plugin.get_dynamic_class_hook(fname)
1763+
if hook:
1764+
hook(DynamicClassDefContext(call, lval.name, self))
1765+
17471766
def unwrap_final(self, s: AssignmentStmt) -> None:
17481767
"""Strip Final[...] if present in an assignment.
17491768

mypy/test/testdiff.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from mypy.server.astdiff import snapshot_symbol_table, compare_symbol_table_snapshots
1313
from mypy.test.config import test_temp_dir
1414
from mypy.test.data import DataDrivenTestCase, DataSuite
15-
from mypy.test.helpers import assert_string_arrays_equal
15+
from mypy.test.helpers import assert_string_arrays_equal, parse_options
1616

1717

1818
class ASTDiffSuite(DataSuite):
@@ -22,9 +22,10 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
2222
first_src = '\n'.join(testcase.input)
2323
files_dict = dict(testcase.files)
2424
second_src = files_dict['tmp/next.py']
25+
options = parse_options(first_src, testcase, 1)
2526

26-
messages1, files1 = self.build(first_src)
27-
messages2, files2 = self.build(second_src)
27+
messages1, files1 = self.build(first_src, options)
28+
messages2, files2 = self.build(second_src, options)
2829

2930
a = []
3031
if messages1:
@@ -47,8 +48,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
4748
'Invalid output ({}, line {})'.format(testcase.file,
4849
testcase.line))
4950

50-
def build(self, source: str) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]:
51-
options = Options()
51+
def build(self, source: str,
52+
options: Options) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]:
5253
options.use_builtins_fixtures = True
5354
options.show_traceback = True
5455
options.cache_dir = os.devnull

test-data/unit/check-custom-plugin.test

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,54 @@ reveal_type(FullyQualifiedTestNamedTuple('')._asdict()) # E: Revealed type is 'b
280280
[[mypy]
281281
plugins=<ROOT>/test-data/unit/plugins/fully_qualified_test_hook.py
282282
[builtins fixtures/classmethod.pyi]
283+
284+
[case testDynamicClassPlugin]
285+
# flags: --config-file tmp/mypy.ini
286+
from mod import declarative_base, Column, Instr
287+
288+
Base = declarative_base()
289+
290+
class Model(Base):
291+
x: Column[int]
292+
class Other:
293+
x: Column[int]
294+
295+
reveal_type(Model().x) # E: Revealed type is 'mod.Instr[builtins.int]'
296+
reveal_type(Other().x) # E: Revealed type is 'mod.Column[builtins.int]'
297+
[file mod.py]
298+
from typing import Generic, TypeVar
299+
def declarative_base(): ...
300+
301+
T = TypeVar('T')
302+
303+
class Column(Generic[T]): ...
304+
class Instr(Generic[T]): ...
305+
306+
[file mypy.ini]
307+
[[mypy]
308+
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
309+
310+
[case testDynamicClassPluginNegatives]
311+
# flags: --config-file tmp/mypy.ini
312+
from mod import declarative_base, Column, Instr, non_declarative_base
313+
314+
Bad1 = non_declarative_base()
315+
Bad2 = Bad3 = declarative_base()
316+
317+
class C1(Bad1): ... # E: Invalid base class
318+
class C2(Bad2): ... # E: Invalid base class
319+
class C3(Bad3): ... # E: Invalid base class
320+
321+
[file mod.py]
322+
from typing import Generic, TypeVar
323+
def declarative_base(): ...
324+
def non_declarative_base(): ...
325+
326+
T = TypeVar('T')
327+
328+
class Column(Generic[T]): ...
329+
class Instr(Generic[T]): ...
330+
331+
[file mypy.ini]
332+
[[mypy]
333+
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py

test-data/unit/diff.test

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,3 +1070,42 @@ class C:
10701070
pass
10711071
[out]
10721072
__main__.C.m
1073+
1074+
[case testDynamicBasePluginDiff]
1075+
# flags: --config-file tmp/mypy.ini
1076+
from mod import declarative_base, Column, Instr
1077+
1078+
Base = declarative_base()
1079+
1080+
class Model(Base):
1081+
x: Column[int]
1082+
class Other:
1083+
x: Column[int]
1084+
class Diff:
1085+
x: Column[int]
1086+
[file next.py]
1087+
from mod import declarative_base, Column, Instr
1088+
1089+
Base = declarative_base()
1090+
1091+
class Model(Base):
1092+
x: Column[int]
1093+
class Other:
1094+
x: Column[int]
1095+
class Diff(Base):
1096+
x: Column[int]
1097+
[file mod.py]
1098+
from typing import Generic, TypeVar
1099+
def declarative_base(): ...
1100+
1101+
T = TypeVar('T')
1102+
1103+
class Column(Generic[T]): ...
1104+
class Instr(Generic[T]): ...
1105+
1106+
[file mypy.ini]
1107+
[[mypy]
1108+
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
1109+
[out]
1110+
__main__.Diff
1111+
__main__.Diff.x
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from mypy.plugin import Plugin
2+
from mypy.nodes import (
3+
ClassDef, Block, TypeInfo, SymbolTable, SymbolTableNode, GDEF, Var
4+
)
5+
from mypy.types import Instance
6+
7+
DECL_BASES = set()
8+
9+
class DynPlugin(Plugin):
10+
def get_dynamic_class_hook(self, fullname):
11+
if fullname == 'mod.declarative_base':
12+
return add_info_hook
13+
return None
14+
15+
def get_base_class_hook(self, fullname: str):
16+
if fullname in DECL_BASES:
17+
return replace_col_hook
18+
return None
19+
20+
def add_info_hook(ctx):
21+
class_def = ClassDef(ctx.name, Block([]))
22+
class_def.fullname = ctx.api.qualified_name(ctx.name)
23+
24+
info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
25+
class_def.info = info
26+
obj = ctx.api.builtin_type('builtins.object')
27+
info.mro = [info, obj.type]
28+
info.bases = [obj]
29+
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
30+
DECL_BASES.add(class_def.fullname)
31+
32+
def replace_col_hook(ctx):
33+
info = ctx.cls.info
34+
for sym in info.names.values():
35+
node = sym.node
36+
if isinstance(node, Var) and isinstance(node.type, Instance):
37+
if node.type.type.fullname() == 'mod.Column':
38+
new_sym = ctx.api.lookup_fully_qualified_or_none('mod.Instr')
39+
if new_sym:
40+
new_info = new_sym.node
41+
assert isinstance(new_info, TypeInfo)
42+
node.type = Instance(new_info, node.type.args.copy(),
43+
node.type.line,
44+
node.type.column)
45+
46+
def plugin(version):
47+
return DynPlugin

0 commit comments

Comments
 (0)