Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 823034c

Browse files
authored
Merge branch 'main' into use-value-for-completion
2 parents 606767f + 135015a commit 823034c

4 files changed

Lines changed: 220 additions & 12 deletions

File tree

IPython/core/completer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,8 @@ def _trim_expr(self, code: str) -> str:
13521352
assert res is not None
13531353
if len(res.body) != 1:
13541354
continue
1355+
if not isinstance(res.body[0], ast.Expr):
1356+
continue
13551357
expr = res.body[0].value
13561358
if isinstance(expr, ast.Tuple) and not code[-1] == ")":
13571359
# we skip implicit tuple, like when trimming `fun(a,b`<completion>

IPython/core/guarded_eval.py

Lines changed: 89 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@
3030

3131
from IPython.utils.decorators import undoc
3232

33-
34-
from typing import Self, LiteralString
33+
import types
34+
from typing import Self, LiteralString, get_type_hints
3535

3636
if sys.version_info < (3, 12):
3737
from typing_extensions import TypeAliasType
@@ -403,6 +403,9 @@ class EvaluationContext:
403403
class_transients: dict | None = None
404404
#: Instance variable name used in the method definition
405405
instance_arg_name: str | None = None
406+
#: Currently associated value
407+
#: Useful for adding items to _Duck on annotated assignment
408+
current_value: ast.AST | None = None
406409

407410
def replace(self, /, **changes):
408411
"""Return a new copy of the context, with specified changes"""
@@ -566,6 +569,30 @@ def _validate_policy_overrides(
566569
return all_good
567570

568571

572+
def _is_type_annotation(obj) -> bool:
573+
"""
574+
Returns True if obj is a type annotation, False otherwise.
575+
"""
576+
if isinstance(obj, type):
577+
return True
578+
if isinstance(obj, types.GenericAlias):
579+
return True
580+
if hasattr(types, "UnionType") and isinstance(obj, types.UnionType):
581+
return True
582+
if isinstance(obj, (typing._SpecialForm, typing._BaseGenericAlias)):
583+
return True
584+
if isinstance(obj, typing.TypeVar):
585+
return True
586+
# Types that support __class_getitem__
587+
if isinstance(obj, type) and hasattr(obj, "__class_getitem__"):
588+
return True
589+
# Fallback: check if get_origin returns something
590+
if hasattr(typing, "get_origin") and get_origin(obj) is not None:
591+
return True
592+
593+
return False
594+
595+
569596
def _handle_assign(node: ast.Assign, context: EvaluationContext):
570597
value = eval_node(node.value, context)
571598
transient_locals = context.transient_locals
@@ -664,12 +691,17 @@ def _handle_assign(node: ast.Assign, context: EvaluationContext):
664691

665692

666693
def _handle_annassign(node, context):
667-
annotation_value = _resolve_annotation(eval_node(node.annotation, context), context)
668-
669-
# Use Value for generic types
670-
use_value = (
671-
isinstance(annotation_value, GENERIC_CONTAINER_TYPES) and node.value is not None
672-
)
694+
context_with_value = context.replace(current_value=getattr(node, "value", None))
695+
annotation_result = eval_node(node.annotation, context_with_value)
696+
if _is_type_annotation(annotation_result):
697+
annotation_value = _resolve_annotation(annotation_result, context)
698+
# Use Value for generic types
699+
use_value = (
700+
isinstance(annotation_value, GENERIC_CONTAINER_TYPES) and node.value is not None
701+
)
702+
else:
703+
annotation_value = annotation_result
704+
use_value = False
673705

674706
# LOCAL VARIABLE
675707
if getattr(node, "simple", False) and isinstance(node.target, ast.Name):
@@ -801,9 +833,12 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
801833

802834
if is_property:
803835
if return_type is not None:
804-
context.transient_locals[node.name] = _resolve_annotation(
805-
return_type, context
806-
)
836+
if _is_type_annotation(return_type):
837+
context.transient_locals[node.name] = _resolve_annotation(
838+
return_type, context
839+
)
840+
else:
841+
context.transient_locals[node.name] = return_type
807842
else:
808843
return_value = _infer_return_value(node, func_context)
809844
context.transient_locals[node.name] = return_value
@@ -814,7 +849,10 @@ def dummy_function(*args, **kwargs):
814849
pass
815850

816851
if return_type is not None:
817-
dummy_function.__annotations__["return"] = return_type
852+
if _is_type_annotation(return_type):
853+
dummy_function.__annotations__["return"] = return_type
854+
else:
855+
dummy_function.__inferred_return__ = return_type
818856
else:
819857
inferred_return = _infer_return_value(node, func_context)
820858
if inferred_return is not None:
@@ -952,6 +990,29 @@ def dummy_function(*args, **kwargs):
952990
if isinstance(node, ast.BinOp):
953991
left = eval_node(node.left, context)
954992
right = eval_node(node.right, context)
993+
if (
994+
isinstance(node.op, ast.BitOr)
995+
and _is_type_annotation(left)
996+
and _is_type_annotation(right)
997+
):
998+
left_duck = (
999+
_Duck(dict.fromkeys(dir(left)))
1000+
if policy.can_call(left.__dir__)
1001+
else _Duck()
1002+
)
1003+
right_duck = (
1004+
_Duck(dict.fromkeys(dir(right)))
1005+
if policy.can_call(right.__dir__)
1006+
else _Duck()
1007+
)
1008+
value_node = context.current_value
1009+
if value_node is not None and isinstance(value_node, ast.Dict):
1010+
if dict in [left, right]:
1011+
return _merge_values(
1012+
[left_duck, right_duck, ast.literal_eval(value_node)],
1013+
policy=get_policy(context),
1014+
)
1015+
return _merge_values([left_duck, right_duck], policy=get_policy(context))
9551016
dunders = _find_dunder(node.op, BINARY_OP_DUNDERS)
9561017
if dunders:
9571018
if policy.can_operate(dunders, left, right):
@@ -1061,6 +1122,22 @@ def dummy_function(*args, **kwargs):
10611122
value = eval_node(node.value, context)
10621123
if policy.can_get_attr(value, node.attr):
10631124
return getattr(value, node.attr)
1125+
try:
1126+
cls = (
1127+
value if isinstance(value, type) else getattr(value, "__class__", None)
1128+
)
1129+
if cls is not None:
1130+
resolved_hints = get_type_hints(
1131+
cls,
1132+
globalns=(context.globals or {}),
1133+
localns=(context.locals or {}),
1134+
)
1135+
if node.attr in resolved_hints:
1136+
annotated = resolved_hints[node.attr]
1137+
return _resolve_annotation(annotated, context)
1138+
except Exception:
1139+
# Fall through to the guard rejection
1140+
pass
10641141
raise GuardRejection(
10651142
"Attribute access (`__getattr__`) for",
10661143
type(value), # not joined to avoid calling `repr`

tests/test_completer.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1572,6 +1572,39 @@ def __getattr__(self, attr):
15721572
self.assertNotIn(".append", matches)
15731573
self.assertNotIn(".keys", matches)
15741574

1575+
def test_completion_fallback_to_annotation_for_attribute(self):
1576+
code = textwrap.dedent(
1577+
"""
1578+
class StringMethods:
1579+
def a():
1580+
pass
1581+
1582+
class Test:
1583+
str: StringMethods
1584+
def __init__(self):
1585+
self.str = StringMethods()
1586+
def __getattr__(self, name):
1587+
raise AttributeError(f"{name} not found")
1588+
"""
1589+
)
1590+
1591+
repro = types.ModuleType("repro")
1592+
sys.modules["repro"] = repro
1593+
exec(code, repro.__dict__)
1594+
1595+
ip = get_ipython()
1596+
ip.user_ns["repro"] = repro
1597+
exec("r = repro.Test()", ip.user_ns)
1598+
1599+
complete = ip.Completer.complete
1600+
try:
1601+
with evaluation_policy("limited"), jedi_status(False):
1602+
_, matches = complete(line_buffer="r.str.")
1603+
self.assertIn(".a", matches)
1604+
finally:
1605+
sys.modules.pop("repro", None)
1606+
ip.user_ns.pop("r", None)
1607+
15751608
def test_policy_warnings(self):
15761609
with self.assertWarns(
15771610
UserWarning,
@@ -2494,6 +2527,15 @@ def _(expected):
24942527
),
24952528
"bit_length",
24962529
],
2530+
[
2531+
"\n".join(
2532+
[
2533+
"t: list[str]",
2534+
"t[0].",
2535+
]
2536+
),
2537+
["capitalize"],
2538+
],
24972539
],
24982540
)
24992541
def test_undefined_variables(use_jedi, evaluation, code, insert_text):
@@ -2521,6 +2563,42 @@ def test_undefined_variables(use_jedi, evaluation, code, insert_text):
25212563
]
25222564
),
25232565
["append"],
2566+
],
2567+
"\n".join(
2568+
[
2569+
"t: int | dict = {'a': []}",
2570+
"t.",
2571+
]
2572+
),
2573+
["keys", "bit_length"],
2574+
],
2575+
[
2576+
"\n".join(
2577+
[
2578+
"t: int | dict = {'a': []}",
2579+
"t['a'].",
2580+
]
2581+
),
2582+
"append",
2583+
],
2584+
# Test union types
2585+
[
2586+
"\n".join(
2587+
[
2588+
"t: int | str",
2589+
"t.",
2590+
]
2591+
),
2592+
["bit_length", "capitalize"],
2593+
],
2594+
[
2595+
"\n".join(
2596+
[
2597+
"def func() -> int | str: pass",
2598+
"func().",
2599+
]
2600+
),
2601+
["bit_length", "capitalize"],
25242602
],
25252603
[
25262604
"\n".join(
@@ -2531,6 +2609,18 @@ def test_undefined_variables(use_jedi, evaluation, code, insert_text):
25312609
),
25322610
["capitalize"],
25332611
],
2612+
[
2613+
"\n".join(
2614+
[
2615+
"class T:",
2616+
" @property",
2617+
" def p(self) -> int | str: pass",
2618+
"t = T()",
2619+
"t.p.",
2620+
]
2621+
),
2622+
["bit_length", "capitalize"],
2623+
],
25342624
],
25352625
)
25362626
def test_undefined_variables_without_jedi(code, insert_text):
@@ -2692,6 +2782,7 @@ def test_misc_no_jedi_completions(setup, code, expected, not_expected):
26922782
("x = {1, y", "y"),
26932783
("x = [1, y", "y"),
26942784
("x = fun(1, y", "y"),
2785+
(" assert a", "a"),
26952786
],
26962787
)
26972788
def test_trim_expr(code, expected):

tests/test_guarded_eval.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212
TypeGuard,
1313
Union,
1414
TypedDict,
15+
TypeVar,
16+
List,
17+
Callable,
18+
Any,
19+
Dict,
1520
)
1621
from functools import partial
1722
from IPython.core.guarded_eval import (
@@ -22,6 +27,7 @@
2227
)
2328
from IPython.testing import decorators as dec
2429
import pytest
30+
from IPython.core.guarded_eval import _is_type_annotation
2531

2632

2733
from typing import Self, LiteralString
@@ -579,6 +585,38 @@ def test_mock_class_and_func_instances(code, expected):
579585
assert isinstance(value, expected)
580586

581587

588+
@pytest.mark.parametrize(
589+
"annotation,expected",
590+
[
591+
# Basic types
592+
(int, True),
593+
(str, True),
594+
(list, True),
595+
# Typing generics
596+
(list[str], True),
597+
(dict[str, int], True),
598+
(Optional[int], True),
599+
(Union[int, str], True),
600+
# Special forms
601+
(AnyStr, True),
602+
(TypeVar("T"), True),
603+
(Callable[[int], str], True),
604+
(Literal["GET", "POST"], True),
605+
(Any, True),
606+
(str | int, True),
607+
# Nested
608+
(List[Dict[str, int]], True),
609+
# Non-annotations
610+
(42, False),
611+
("string", False),
612+
([1, 2, 3], False),
613+
(None, False),
614+
],
615+
)
616+
def test_is_type_annotation(annotation, expected):
617+
assert _is_type_annotation(annotation) == expected
618+
619+
582620
@pytest.mark.parametrize(
583621
"code,expected",
584622
[

0 commit comments

Comments
 (0)