-
-
Notifications
You must be signed in to change notification settings - Fork 3.2k
Expand file tree
/
Copy pathapplytype.py
More file actions
303 lines (269 loc) · 11.8 KB
/
applytype.py
File metadata and controls
303 lines (269 loc) · 11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
from __future__ import annotations
from collections.abc import Callable, Iterable, Sequence
import mypy.subtypes
from mypy.erasetype import erase_typevars
from mypy.expandtype import expand_type
from mypy.nodes import Context, TypeInfo
from mypy.type_visitor import TypeTranslator
from mypy.typeops import get_all_type_vars
from mypy.types import (
AnyType,
CallableType,
Instance,
Parameters,
ParamSpecFlavor,
ParamSpecType,
PartialType,
ProperType,
Type,
TypeAliasType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
UninhabitedType,
UnpackType,
get_proper_type,
remove_dups,
)
def get_target_type(
tvar: TypeVarLikeType,
type: Type,
callable: CallableType,
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool,
) -> Type | None:
p_type = get_proper_type(type)
if isinstance(p_type, UninhabitedType) and tvar.has_default():
return tvar.default
if isinstance(tvar, ParamSpecType):
return type
if isinstance(tvar, TypeVarTupleType):
return type
assert isinstance(tvar, TypeVarType)
values = tvar.values
if values:
if isinstance(p_type, AnyType):
return type
if isinstance(p_type, TypeVarType) and p_type.values:
# Allow substituting T1 for T if every allowed value of T1
# is also a legal value of T.
if all(any(mypy.subtypes.is_same_type(v, v1) for v in values) for v1 in p_type.values):
return type
matching = []
for value in values:
if mypy.subtypes.is_subtype(type, value):
matching.append(value)
if matching:
best = matching[0]
# If there are more than one matching value, we select the narrowest
for match in matching[1:]:
if mypy.subtypes.is_subtype(match, best):
best = match
return best
if skip_unsatisfied:
return None
report_incompatible_typevar_value(callable, type, tvar.name, context)
else:
upper_bound = tvar.upper_bound
if tvar.name == "Self":
# Internally constructed Self-types contain class type variables in upper bound,
# so we need to erase them to avoid false positives. This is safe because we do
# not support type variables in upper bounds of user defined types.
upper_bound = erase_typevars(upper_bound)
if not mypy.subtypes.is_subtype(type, upper_bound):
if skip_unsatisfied:
return None
report_incompatible_typevar_value(callable, type, tvar.name, context)
return type
def apply_generic_arguments(
callable: CallableType,
orig_types: Sequence[Type | None],
report_incompatible_typevar_value: Callable[[CallableType, Type, str, Context], None],
context: Context,
skip_unsatisfied: bool = False,
) -> CallableType:
"""Apply generic type arguments to a callable type.
For example, applying [int] to 'def [T] (T) -> T' results in
'def (int) -> int'.
Note that each type can be None; in this case, it will not be applied.
If `skip_unsatisfied` is True, then just skip the types that don't satisfy type variable
bound or constraints, instead of giving an error.
"""
tvars = callable.variables
assert len(orig_types) <= len(tvars)
# Check that inferred type variable values are compatible with allowed
# values and bounds. Also, promote subtype values to allowed values.
# Create a map from type variable id to target type.
id_to_type: dict[TypeVarId, Type] = {}
for tvar, type in zip(tvars, orig_types):
assert not isinstance(type, PartialType), "Internal error: must never apply partial type"
if type is None:
continue
target_type = get_target_type(
tvar, type, callable, report_incompatible_typevar_value, context, skip_unsatisfied
)
if target_type is not None:
id_to_type[tvar.id] = target_type
# TODO: validate arg_kinds/arg_names for ParamSpec and TypeVarTuple replacements,
# not just type variable bounds above.
param_spec = callable.param_spec()
if param_spec is not None:
nt = id_to_type.get(param_spec.id)
if nt is not None:
# ParamSpec expansion is special-cased, so we need to always expand callable
# as a whole, not expanding arguments individually.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(
variables=[tv for tv in tvars if tv.id not in id_to_type]
)
# Apply arguments to argument types.
var_arg = callable.var_arg()
if var_arg is not None and isinstance(var_arg.typ, UnpackType):
# Same as for ParamSpec, callable with variadic types needs to be expanded as a whole.
callable = expand_type(callable, id_to_type)
assert isinstance(callable, CallableType)
return callable.copy_modified(variables=[tv for tv in tvars if tv.id not in id_to_type])
else:
callable = callable.copy_modified(
arg_types=[expand_type(at, id_to_type) for at in callable.arg_types]
)
# Apply arguments to TypeGuard and TypeIs if any.
if callable.type_guard is not None:
type_guard = expand_type(callable.type_guard, id_to_type)
else:
type_guard = None
if callable.type_is is not None:
type_is = expand_type(callable.type_is, id_to_type)
else:
type_is = None
# The callable may retain some type vars if only some were applied.
# TODO: move apply_poly() logic here when new inference
# becomes universally used (i.e. in all passes + in unification).
# With this new logic we can actually *add* some new free variables.
remaining_tvars: list[TypeVarLikeType] = []
for tv in tvars:
if tv.id in id_to_type:
continue
if not tv.has_default():
remaining_tvars.append(tv)
continue
# TypeVarLike isn't in id_to_type mapping.
# Only expand the TypeVar default here.
typ = expand_type(tv, id_to_type)
assert isinstance(typ, TypeVarLikeType)
remaining_tvars.append(typ)
return callable.copy_modified(
ret_type=expand_type(callable.ret_type, id_to_type),
variables=remaining_tvars,
type_guard=type_guard,
type_is=type_is,
)
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
"""Make free type variables generic in the type if possible.
This will translate the type `tp` while trying to create valid bindings for
type variables `poly_tvars` while traversing the type. This follows the same rules
as we do during semantic analysis phase, examples:
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
* List[T] -> None (not possible)
"""
try:
return tp.copy_modified(
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
variables=[],
)
except PolyTranslationError:
return None
class PolyTranslationError(Exception):
pass
class PolyTranslator(TypeTranslator):
"""Make free type variables generic in the type if possible.
See docstring for apply_poly() for details.
"""
def __init__(
self,
poly_tvars: Iterable[TypeVarLikeType],
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
super().__init__()
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars = bound_tvars
self.seen_aliases = seen_aliases
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
for arg in t.arg_types:
for tv in get_all_type_vars(arg):
if isinstance(tv, ParamSpecType):
normalized: TypeVarLikeType = tv.copy_modified(
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
)
else:
normalized = tv
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
found_vars.append(normalized)
return remove_dups(found_vars)
def visit_callable_type(self, t: CallableType) -> Type:
found_vars = self.collect_vars(t)
self.bound_tvars |= set(found_vars)
result = super().visit_callable_type(t)
self.bound_tvars -= set(found_vars)
assert isinstance(result, ProperType) and isinstance(result, CallableType)
result.variables = result.variables + tuple(found_vars)
return result
def visit_type_var(self, t: TypeVarType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var(t)
def visit_param_spec(self, t: ParamSpecType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_param_spec(t)
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
if t in self.poly_tvars and t not in self.bound_tvars:
raise PolyTranslationError()
return super().visit_type_var_tuple(t)
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
if not t.args:
return t.copy_modified()
if not t.is_recursive:
return get_proper_type(t).accept(self)
# We can't handle polymorphic application for recursive generic aliases
# without risking an infinite recursion, just give up for now.
raise PolyTranslationError()
def visit_instance(self, t: Instance) -> Type:
if t.type.has_param_spec_type:
# We need this special-casing to preserve the possibility to store a
# generic function in an instance type. Things like
# forall T . Foo[[x: T], T]
# are not really expressible in current type system, but this looks like
# a useful feature, so let's keep it.
param_spec_index = next(
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
)
p = get_proper_type(t.args[param_spec_index])
if isinstance(p, Parameters):
found_vars = self.collect_vars(p)
self.bound_tvars |= set(found_vars)
new_args = [a.accept(self) for a in t.args]
self.bound_tvars -= set(found_vars)
repl = new_args[param_spec_index]
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
repl.variables = list(repl.variables) + list(found_vars)
return t.copy_modified(args=new_args)
# There is the same problem with callback protocols as with aliases
# (callback protocols are essentially more flexible aliases to callables).
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
)
return super().visit_instance(t)