Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Next Next commit
Use polymorphic inference in unification
  • Loading branch information
ilevkivskyi committed Jun 8, 2024
commit fad1094afa288d65aec7664c5adb733191f2b473
40 changes: 33 additions & 7 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,9 @@ def __init__(
# new uses of this, as this may cause leaking `UnboundType`s to type checking.
self.allow_unbound_tvars = False

# Used to pass information about current overload index to visit_func_def().
self.current_overload_item: int | None = None

# mypyc doesn't properly handle implementing an abstractproperty
# with a regular attribute so we make them properties
@property
Expand Down Expand Up @@ -869,6 +872,15 @@ def visit_func_def(self, defn: FuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_func_def(defn)

def function_fullname(self, fullname: str) -> str:
if self.current_overload_item is None:
return fullname
if self.current_overload_item < 0:
suffix = "impl"
else:
suffix = str(self.current_overload_item)
return f"{fullname}#{suffix}"

def analyze_func_def(self, defn: FuncDef) -> None:
if self.push_type_args(defn.type_args, defn) is None:
self.defer(defn)
Expand All @@ -895,7 +907,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.prepare_method_signature(defn, self.type, has_self_type)

# Analyze function signature
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname =self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
if defn.type:
self.check_classvar_in_signature(defn.type)
assert isinstance(defn.type, CallableType)
Expand All @@ -904,7 +917,7 @@ def analyze_func_def(self, defn: FuncDef) -> None:
analyzer = self.type_analyzer()
tag = self.track_incomplete_refs()
result = analyzer.visit_callable_type(
defn.type, nested=False, namespace=defn.fullname
defn.type, nested=False, namespace=fullname
)
# Don't store not ready types (including placeholders).
if self.found_incomplete_ref(tag) or has_placeholder(result):
Expand Down Expand Up @@ -1117,7 +1130,8 @@ def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem)
if defn is generic. Return True, if the signature contains typing.Self
type, or False otherwise.
"""
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
a = self.type_analyzer()
fun_type.variables, has_self_type = a.bind_function_type_variables(fun_type, defn)
if has_self_type and self.type is not None:
Expand Down Expand Up @@ -1175,6 +1189,14 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
with self.scope.function_scope(defn):
self.analyze_overloaded_func_def(defn)

@contextmanager
def overload_item_set(self, item: int) -> Iterator[None]:
self.current_overload_item = item
try:
yield
finally:
self.current_overload_item = None

def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
# OverloadedFuncDef refers to any legitimate situation where you have
# more than one declaration for the same function in a row. This occurs
Expand All @@ -1187,7 +1209,8 @@ def analyze_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:

first_item = defn.items[0]
first_item.is_overload = True
first_item.accept(self)
with self.overload_item_set(0):
first_item.accept(self)

if isinstance(first_item, Decorator) and first_item.func.is_property:
# This is a property.
Expand Down Expand Up @@ -1272,7 +1295,8 @@ def analyze_overload_sigs_and_impl(
if i != 0:
# Assume that the first item was already visited
item.is_overload = True
item.accept(self)
with self.overload_item_set(i if i < len(defn.items) - 1 else -1):
item.accept(self)
# TODO: support decorated overloaded functions properly
if isinstance(item, Decorator):
callable = function_type(item.func, self.named_type("builtins.function"))
Expand Down Expand Up @@ -1444,15 +1468,17 @@ def add_function_to_symbol_table(self, func: FuncDef | OverloadedFuncDef) -> Non
self.add_symbol(func.name, func, func)

def analyze_arg_initializers(self, defn: FuncItem) -> None:
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
# Analyze default arguments
for arg in defn.arguments:
if arg.initializer:
arg.initializer.accept(self)

def analyze_function_body(self, defn: FuncItem) -> None:
is_method = self.is_class_scope()
with self.tvar_scope_frame(self.tvar_scope.method_frame(defn.fullname)):
fullname = self.function_fullname(defn.fullname)
with self.tvar_scope_frame(self.tvar_scope.method_frame(fullname)):
# Bind the type variables again to visit the body.
if defn.type:
a = self.type_analyzer()
Expand Down
9 changes: 8 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
is a linear constraint. This is however not true in presence of union types, for example
T :> Union[S, int] vs S <: T. Trying to solve such constraints would be detected ambiguous
as (T, S) form a non-linear SCC. However, simply removing the linear part results in a valid
solution T = Union[S, int], S = <free>.
solution T = Union[S, int], S = <free>. A similar scenario is when we get T <: Union[T, int],
such constraints carry no information, and will equally confuse linearity check.

TODO: a cleaner solution may be to avoid inferring such constraints in first place, but
this would require passing around a flag through all infer_constraints() calls.
Expand All @@ -525,7 +526,13 @@ def skip_reverse_union_constraints(cs: list[Constraint]) -> list[Constraint]:
if isinstance(p_target, UnionType):
for item in p_target.items:
if isinstance(item, TypeVarType):
if item == c.origin_type_var and c.op == SUBTYPE_OF:
reverse_union_cs.add(c)
continue
# These two forms are semantically identical, but are different from
# the point of view of Constraint.__eq__().
reverse_union_cs.add(Constraint(item, neg_op(c.op), c.origin_type_var))
reverse_union_cs.add(Constraint(c.origin_type_var, c.op, item))
return [c for c in cs if c not in reverse_union_cs]


Expand Down
4 changes: 3 additions & 1 deletion mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,7 +1879,9 @@ def unify_generic_callable(
constraints = [
c for c in constraints if not isinstance(get_proper_type(c.target), NoneType)
]
inferred_vars, _ = mypy.solve.solve_constraints(type.variables, constraints)
inferred_vars, _ = mypy.solve.solve_constraints(
type.variables, constraints, allow_polymorphic=True
)
if None in inferred_vars:
return None
non_none_inferred_vars = cast(List[Type], inferred_vars)
Expand Down
18 changes: 18 additions & 0 deletions test-data/unit/check-generics.test
Original file line number Diff line number Diff line change
Expand Up @@ -3442,3 +3442,21 @@ reveal_type(dec(g)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[b
h: Callable[[Unpack[Us]], Foo[int]]
reveal_type(dec(h)) # N: Revealed type is "def (builtins.int) -> __main__.Foo[builtins.int]"
[builtins fixtures/list.pyi]

[case testHigherOrderGenericPartial]
from typing import TypeVar, Callable

T = TypeVar("T")
S = TypeVar("S")
U = TypeVar("U")
def apply(f: Callable[[T], S], x: T) -> S: ...
def id(x: U) -> U: ...

A1 = TypeVar("A1")
A2 = TypeVar("A2")
R = TypeVar("R")
def fake_partial(fun: Callable[[A1, A2], R], arg: A1) -> Callable[[A2], R]: ...

f_pid = fake_partial(apply, id)
reveal_type(f_pid) # N: Revealed type is "def [A2] (A2`2) -> A2`2"
reveal_type(f_pid(1)) # N: Revealed type is "builtins.int"
22 changes: 8 additions & 14 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -371,20 +371,18 @@ def foo(t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and
def foo(t: T, s: T) -> str: ...
def foo(t, s): pass

# TODO: examples below are technically unsafe.
class Wrapper(Generic[T]):
@overload
def foo(self, t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def foo(self, t: List[T], s: T) -> int: ...
@overload
def foo(self, t: T, s: T) -> str: ...
def foo(self, t, s): pass

class Dummy(Generic[T]): pass

# Same root issue: why does the additional constraint bound T <: T
# cause the constraint solver to not infer T = object like it did in the
# first example?
@overload
def bar(d: Dummy[T], t: List[T], s: T) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def bar(d: Dummy[T], t: List[T], s: T) -> int: ...
@overload
def bar(d: Dummy[T], t: T, s: T) -> str: ...
def bar(d: Dummy[T], t, s): pass
Expand Down Expand Up @@ -2865,11 +2863,8 @@ class Wrapper(Generic[T]):
def f(self, x: T) -> T: ...
def f(self, x): ...

# TODO: This shouldn't trigger an error message?
# Related to testTypeCheckOverloadImplementationTypeVarDifferingUsage2?
# See https://github.com/python/mypy/issues/5510
@overload
def g(self, x: int) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g(self, x: int) -> int: ...
@overload
def g(self, x: T) -> T: ...
def g(self, x): ...
Expand All @@ -2892,16 +2887,15 @@ class Wrapper(Generic[T]):
def f2(self, x: List[T]) -> List[T]: ...
def f2(self, x): ...

# TODO: This shouldn't trigger an error message?
# See https://github.com/python/mypy/issues/5510
@overload
def g1(self, x: List[int]) -> int: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g1(self, x: List[int]) -> int: ...
@overload
def g1(self, x: List[T]) -> T: ...
def g1(self, x): ...

# TODO: this is technically unsafe.
@overload
def g2(self, x: List[int]) -> List[int]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def g2(self, x: List[int]) -> List[int]: ...
@overload
def g2(self, x: List[T]) -> List[T]: ...
def g2(self, x): ...
Expand Down Expand Up @@ -6483,7 +6477,7 @@ P = ParamSpec("P")
R = TypeVar("R")

@overload
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ... # E: Overloaded function signatures 1 and 2 overlap with incompatible return types
def func(x: Callable[Concatenate[Any, P], R]) -> Callable[P, R]: ...
@overload
def func(x: Callable[P, R]) -> Callable[Concatenate[str, P], R]: ...
def func(x: Callable[..., R]) -> Callable[..., R]: ...
Expand Down