|
10 | 10 | from mypy.plugins.common import try_getting_str_literals |
11 | 11 | from mypy.types import ( |
12 | 12 | Type, Instance, AnyType, TypeOfAny, CallableType, NoneType, TypedDictType, |
13 | | - TypeVarType, TPDICT_FB_NAMES, get_proper_type |
| 13 | + TypeVarType, TPDICT_FB_NAMES, get_proper_type, LiteralType |
14 | 14 | ) |
15 | 15 | from mypy.subtypes import is_subtype |
16 | 16 | from mypy.typeops import make_simplified_union |
| 17 | +from mypy.checkexpr import is_literal_type_like |
17 | 18 |
|
18 | 19 |
|
19 | 20 | class DefaultPlugin(Plugin): |
@@ -57,6 +58,8 @@ def get_method_hook(self, fullname: str |
57 | 58 | return typed_dict_get_callback |
58 | 59 | elif fullname == 'builtins.int.__pow__': |
59 | 60 | return int_pow_callback |
| 61 | + elif fullname == 'builtins.int.__neg__': |
| 62 | + return int_neg_callback |
60 | 63 | elif fullname in set(n + '.setdefault' for n in TPDICT_FB_NAMES): |
61 | 64 | return typed_dict_setdefault_callback |
62 | 65 | elif fullname in set(n + '.pop' for n in TPDICT_FB_NAMES): |
@@ -417,3 +420,30 @@ def int_pow_callback(ctx: MethodContext) -> Type: |
417 | 420 | else: |
418 | 421 | return ctx.api.named_generic_type('builtins.float', []) |
419 | 422 | return ctx.default_return_type |
| 423 | + |
| 424 | + |
| 425 | +def int_neg_callback(ctx: MethodContext) -> Type: |
| 426 | + """Infer a more precise return type for int.__neg__. |
| 427 | +
|
| 428 | + This is mainly used to infer the return type as LiteralType |
| 429 | + if the original underlying object is a LiteralType object |
| 430 | + """ |
| 431 | + if isinstance(ctx.type, Instance) and ctx.type.last_known_value is not None: |
| 432 | + value = ctx.type.last_known_value.value |
| 433 | + fallback = ctx.type.last_known_value.fallback |
| 434 | + if isinstance(value, int): |
| 435 | + if is_literal_type_like(ctx.api.type_context[-1]): |
| 436 | + return LiteralType(value=-value, fallback=fallback) |
| 437 | + else: |
| 438 | + return ctx.type.copy_modified(last_known_value=LiteralType( |
| 439 | + value=-value, |
| 440 | + fallback=ctx.type, |
| 441 | + line=ctx.type.line, |
| 442 | + column=ctx.type.column, |
| 443 | + )) |
| 444 | + elif isinstance(ctx.type, LiteralType): |
| 445 | + value = ctx.type.value |
| 446 | + fallback = ctx.type.fallback |
| 447 | + if isinstance(value, int): |
| 448 | + return LiteralType(value=-value, fallback=fallback) |
| 449 | + return ctx.default_return_type |
0 commit comments