Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8825338

Browse files
authored
Merge branch 'main' into dep/lower-bounds
2 parents f7bad14 + f74b500 commit 8825338

4 files changed

Lines changed: 521 additions & 32 deletions

File tree

IPython/core/guarded_eval.py

Lines changed: 271 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,10 @@ class EvaluationContext:
396396
policy_overrides: dict = field(default_factory=dict)
397397
#: Transient local namespace used to store mocks
398398
transient_locals: dict = field(default_factory=dict)
399+
#: Transients of class level
400+
class_transients: dict | None = None
401+
#: Instance variable name used in the method definition
402+
instance_arg_name: str | None = None
399403

400404
def replace(self, /, **changes):
401405
"""Return a new copy of the context, with specified changes"""
@@ -560,6 +564,8 @@ def _validate_policy_overrides(
560564
def _handle_assign(node: ast.Assign, context: EvaluationContext):
561565
value = eval_node(node.value, context)
562566
transient_locals = context.transient_locals
567+
policy = get_policy(context)
568+
class_transients = context.class_transients
563569
for target in node.targets:
564570
if isinstance(target, (ast.Tuple, ast.List)):
565571
# Handle unpacking assignment
@@ -572,20 +578,81 @@ def _handle_assign(node: ast.Assign, context: EvaluationContext):
572578

573579
# Before starred
574580
for i in range(star_or_last_idx):
575-
transient_locals[targets[i].id] = values[i]
581+
# Check for self.x assignment
582+
if _is_instance_attribute_assignment(targets[i], context):
583+
class_transients[targets[i].attr] = values[i]
584+
else:
585+
transient_locals[targets[i].id] = values[i]
576586

577587
# Starred if exists
578588
if starred:
579589
end = len(values) - (len(targets) - star_or_last_idx - 1)
580-
transient_locals[targets[star_or_last_idx].value.id] = values[
581-
star_or_last_idx:end
582-
]
590+
if _is_instance_attribute_assignment(
591+
targets[star_or_last_idx], context
592+
):
593+
class_transients[targets[star_or_last_idx].attr] = values[
594+
star_or_last_idx:end
595+
]
596+
else:
597+
transient_locals[targets[star_or_last_idx].value.id] = values[
598+
star_or_last_idx:end
599+
]
583600

584601
# After starred
585602
for i in range(star_or_last_idx + 1, len(targets)):
586-
transient_locals[targets[i].id] = values[
587-
len(values) - (len(targets) - i)
588-
]
603+
if _is_instance_attribute_assignment(targets[i], context):
604+
class_transients[targets[i].attr] = values[
605+
len(values) - (len(targets) - i)
606+
]
607+
else:
608+
transient_locals[targets[i].id] = values[
609+
len(values) - (len(targets) - i)
610+
]
611+
elif isinstance(target, ast.Subscript):
612+
if isinstance(target.value, ast.Name):
613+
name = target.value.id
614+
container = transient_locals.get(name)
615+
if container is None:
616+
container = context.locals.get(name)
617+
if container is None:
618+
container = context.globals.get(name)
619+
if container is None:
620+
raise NameError(
621+
f"{name} not found in locals, globals, nor builtins"
622+
)
623+
storage_dict = transient_locals
624+
storage_key = name
625+
elif isinstance(
626+
target.value, ast.Attribute
627+
) and _is_instance_attribute_assignment(target.value, context):
628+
attr = target.value.attr
629+
container = class_transients.get(attr, None)
630+
if container is None:
631+
raise NameError(f"{attr} not found in class transients")
632+
storage_dict = class_transients
633+
storage_key = attr
634+
else:
635+
return
636+
637+
key = eval_node(target.slice, context)
638+
attributes = (
639+
dict.fromkeys(dir(container))
640+
if policy.can_call(container.__dir__)
641+
else {}
642+
)
643+
items = {}
644+
645+
if policy.can_get_item(container, None):
646+
try:
647+
items = dict(container.items())
648+
except Exception:
649+
pass
650+
651+
items[key] = value
652+
duck_container = _Duck(attributes=attributes, items=items)
653+
storage_dict[storage_key] = duck_container
654+
elif _is_instance_attribute_assignment(target, context):
655+
class_transients[target.attr] = value
589656
else:
590657
transient_locals[target.id] = value
591658
return None
@@ -605,6 +672,30 @@ def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
605672
return args, kwargs
606673

607674

675+
def _is_instance_attribute_assignment(
676+
target: ast.AST, context: EvaluationContext
677+
) -> bool:
678+
"""Return True if target is an attribute access on the instance argument."""
679+
return (
680+
context.class_transients is not None
681+
and context.instance_arg_name is not None
682+
and isinstance(target, ast.Attribute)
683+
and isinstance(getattr(target, "value", None), ast.Name)
684+
and getattr(target.value, "id", None) == context.instance_arg_name
685+
)
686+
687+
688+
def _get_coroutine_attributes() -> dict[str, Optional[object]]:
689+
async def _dummy():
690+
return None
691+
692+
coro = _dummy()
693+
try:
694+
return {attr: getattr(coro, attr, None) for attr in dir(coro)}
695+
finally:
696+
coro.close()
697+
698+
608699
def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
609700
"""Evaluate AST node in provided context.
610701
@@ -641,10 +732,13 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
641732
for child_node in node.body:
642733
result = eval_node(child_node, context)
643734
return result
644-
if isinstance(node, ast.FunctionDef):
645-
# we ignore body and only extract the return type
735+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
736+
is_async = isinstance(node, ast.AsyncFunctionDef)
737+
func_locals = context.transient_locals.copy()
738+
func_context = context.replace(transient_locals=func_locals)
646739
is_property = False
647-
740+
is_static = False
741+
is_classmethod = False
648742
for decorator_node in node.decorator_list:
649743
try:
650744
decorator = eval_node(decorator_node, context)
@@ -654,42 +748,85 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
654748
continue
655749
if decorator is property:
656750
is_property = True
751+
elif decorator is staticmethod:
752+
is_static = True
753+
elif decorator is classmethod:
754+
is_classmethod = True
755+
756+
if func_context.class_transients is not None:
757+
if not is_static and not is_classmethod:
758+
func_context.instance_arg_name = (
759+
node.args.args[0].arg if node.args.args else None
760+
)
657761

658762
return_type = eval_node(node.returns, context=context)
659763

764+
for child_node in node.body:
765+
eval_node(child_node, func_context)
766+
660767
if is_property:
661-
context.transient_locals[node.name] = _resolve_annotation(
662-
return_type, context
663-
)
768+
if return_type is not None:
769+
context.transient_locals[node.name] = _resolve_annotation(
770+
return_type, context
771+
)
772+
else:
773+
return_value = _infer_return_value(node, func_context)
774+
context.transient_locals[node.name] = return_value
775+
664776
return None
665777

666778
def dummy_function(*args, **kwargs):
667779
pass
668780

669-
dummy_function.__annotations__["return"] = return_type
781+
if return_type is not None:
782+
dummy_function.__annotations__["return"] = return_type
783+
else:
784+
inferred_return = _infer_return_value(node, func_context)
785+
if inferred_return is not None:
786+
dummy_function.__inferred_return__ = inferred_return
787+
670788
dummy_function.__name__ = node.name
671789
dummy_function.__node__ = node
790+
dummy_function.__is_async__ = is_async
672791
context.transient_locals[node.name] = dummy_function
673792
return None
793+
if isinstance(node, ast.Lambda):
794+
795+
def dummy_function(*args, **kwargs):
796+
pass
797+
798+
dummy_function.__inferred_return__ = eval_node(node.body, context)
799+
return dummy_function
674800
if isinstance(node, ast.ClassDef):
675801
# TODO support class decorators?
676802
class_locals = {}
677-
class_context = context.replace(transient_locals=class_locals)
803+
outer_locals = context.locals.copy()
804+
outer_locals.update(context.transient_locals)
805+
class_context = context.replace(
806+
transient_locals=class_locals, locals=outer_locals
807+
)
808+
class_context.class_transients = class_locals
678809
for child_node in node.body:
679810
eval_node(child_node, class_context)
680811
bases = tuple([eval_node(base, context) for base in node.bases])
681812
dummy_class = type(node.name, bases, class_locals)
682813
context.transient_locals[node.name] = dummy_class
683814
return None
815+
if isinstance(node, ast.Await):
816+
value = eval_node(node.value, context)
817+
if hasattr(value, "__awaited_type__"):
818+
return value.__awaited_type__
819+
return value
684820
if isinstance(node, ast.Assign):
685821
return _handle_assign(node, context)
686822
if isinstance(node, ast.AnnAssign):
687-
if not node.simple:
688-
# for now only handle simple annotations
689-
return None
690-
context.transient_locals[node.target.id] = _resolve_annotation(
691-
eval_node(node.annotation, context), context
692-
)
823+
if node.simple:
824+
value = _resolve_annotation(eval_node(node.annotation, context), context)
825+
context.transient_locals[node.target.id] = value
826+
# Handle non-simple annotated assignments only for self.x: type = value
827+
if _is_instance_attribute_assignment(node.target, context):
828+
value = _resolve_annotation(eval_node(node.annotation, context), context)
829+
context.class_transients[node.target.attr] = value
693830
return None
694831
if isinstance(node, ast.Expression):
695832
return eval_node(node.body, context)
@@ -807,6 +944,12 @@ def dummy_function(*args, **kwargs):
807944
if isinstance(node, ast.Name):
808945
return _eval_node_name(node.id, context)
809946
if isinstance(node, ast.Attribute):
947+
if (
948+
context.class_transients is not None
949+
and isinstance(node.value, ast.Name)
950+
and node.value.id == context.instance_arg_name
951+
):
952+
return context.class_transients.get(node.attr)
810953
value = eval_node(node.value, context)
811954
if policy.can_get_attr(value, node.attr):
812955
return getattr(value, node.attr)
@@ -836,7 +979,17 @@ def dummy_function(*args, **kwargs):
836979
return overridden_return_type
837980
return _create_duck_for_heap_type(func)
838981
else:
982+
inferred_return = getattr(func, "__inferred_return__", NOT_EVALUATED)
839983
return_type = _eval_return_type(func, node, context)
984+
if getattr(func, "__is_async__", False):
985+
awaited_type = (
986+
inferred_return if inferred_return is not None else return_type
987+
)
988+
coroutine_duck = _Duck(attributes=_get_coroutine_attributes())
989+
coroutine_duck.__awaited_type__ = awaited_type
990+
return coroutine_duck
991+
if inferred_return is not NOT_EVALUATED:
992+
return inferred_return
840993
if return_type is not NOT_EVALUATED:
841994
return return_type
842995
raise GuardRejection(
@@ -853,6 +1006,101 @@ def dummy_function(*args, **kwargs):
8531006
return None
8541007

8551008

1009+
def _merge_values(values, policy: EvaluationPolicy):
1010+
"""Recursively merge multiple values, combining attributes and dict items."""
1011+
if len(values) == 1:
1012+
return values[0]
1013+
1014+
types = {type(v) for v in values}
1015+
merged_items = None
1016+
key_values = {}
1017+
attributes = set()
1018+
for v in values:
1019+
if policy.can_call(v.__dir__):
1020+
attributes.update(dir(v))
1021+
try:
1022+
if policy.can_call(v.items):
1023+
try:
1024+
for k, val in v.items():
1025+
key_values.setdefault(k, []).append(val)
1026+
except Exception as e:
1027+
pass
1028+
elif policy.can_call(v.keys):
1029+
try:
1030+
for k in v.keys():
1031+
key_values.setdefault(k, []).append(None)
1032+
except Exception as e:
1033+
pass
1034+
except Exception as e:
1035+
pass
1036+
1037+
if key_values:
1038+
merged_items = {
1039+
k: _merge_values(vals, policy) if vals[0] is not None else None
1040+
for k, vals in key_values.items()
1041+
}
1042+
1043+
if len(types) == 1:
1044+
t = next(iter(types))
1045+
if t not in (dict,) and not (
1046+
hasattr(next(iter(values)), "__getitem__")
1047+
and (
1048+
hasattr(next(iter(values)), "items")
1049+
or hasattr(next(iter(values)), "keys")
1050+
)
1051+
):
1052+
if t in (list, set, tuple):
1053+
return t
1054+
return values[0]
1055+
1056+
return _Duck(attributes=dict.fromkeys(attributes), items=merged_items)
1057+
1058+
1059+
def _infer_return_value(node: ast.FunctionDef, context: EvaluationContext):
1060+
"""Infer the return value(s) of a function by evaluating all return statements."""
1061+
return_values = _collect_return_values(node.body, context)
1062+
1063+
if not return_values:
1064+
return None
1065+
if len(return_values) == 1:
1066+
return return_values[0]
1067+
1068+
policy = get_policy(context)
1069+
return _merge_values(return_values, policy)
1070+
1071+
1072+
def _collect_return_values(body, context):
1073+
"""Recursively collect return values from a list of AST statements."""
1074+
return_values = []
1075+
for stmt in body:
1076+
if isinstance(stmt, ast.Return):
1077+
if stmt.value is None:
1078+
continue
1079+
try:
1080+
value = eval_node(stmt.value, context)
1081+
if value is not None and value is not NOT_EVALUATED:
1082+
return_values.append(value)
1083+
except Exception:
1084+
pass
1085+
if isinstance(
1086+
stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)
1087+
):
1088+
continue
1089+
elif hasattr(stmt, "body") and isinstance(stmt.body, list):
1090+
return_values.extend(_collect_return_values(stmt.body, context))
1091+
if isinstance(stmt, ast.Try):
1092+
for h in stmt.handlers:
1093+
if hasattr(h, "body"):
1094+
return_values.extend(_collect_return_values(h.body, context))
1095+
if hasattr(stmt, "orelse"):
1096+
return_values.extend(_collect_return_values(stmt.orelse, context))
1097+
if hasattr(stmt, "finalbody"):
1098+
return_values.extend(_collect_return_values(stmt.finalbody, context))
1099+
if hasattr(stmt, "orelse") and isinstance(stmt.orelse, list):
1100+
return_values.extend(_collect_return_values(stmt.orelse, context))
1101+
return return_values
1102+
1103+
8561104
def _eval_return_type(func: Callable, node: ast.Call, context: EvaluationContext):
8571105
"""Evaluate return type of a given callable function.
8581106
@@ -1118,6 +1366,8 @@ def _list_methods(cls, source=None):
11181366
*_list_methods(collections.Counter, dict_non_mutating_methods),
11191367
collections.Counter.elements,
11201368
collections.Counter.most_common,
1369+
object.__dir__,
1370+
type.__dir__,
11211371
}
11221372

11231373
BUILTIN_GETATTR: set[MayHaveGetattr] = {

0 commit comments

Comments
 (0)