@@ -396,6 +396,10 @@ class EvaluationContext:
396396 policy_overrides : dict = field (default_factory = dict )
397397 #: Transient local namespace used to store mocks
398398 transient_locals : dict = field (default_factory = dict )
399+ #: Transients of class level
400+ class_transients : dict | None = None
401+ #: Instance variable name used in the method definition
402+ instance_arg_name : str | None = None
399403
400404 def replace (self , / , ** changes ):
401405 """Return a new copy of the context, with specified changes"""
@@ -560,6 +564,8 @@ def _validate_policy_overrides(
560564def _handle_assign (node : ast .Assign , context : EvaluationContext ):
561565 value = eval_node (node .value , context )
562566 transient_locals = context .transient_locals
567+ policy = get_policy (context )
568+ class_transients = context .class_transients
563569 for target in node .targets :
564570 if isinstance (target , (ast .Tuple , ast .List )):
565571 # Handle unpacking assignment
@@ -572,20 +578,81 @@ def _handle_assign(node: ast.Assign, context: EvaluationContext):
572578
573579 # Before starred
574580 for i in range (star_or_last_idx ):
575- transient_locals [targets [i ].id ] = values [i ]
581+ # Check for self.x assignment
582+ if _is_instance_attribute_assignment (targets [i ], context ):
583+ class_transients [targets [i ].attr ] = values [i ]
584+ else :
585+ transient_locals [targets [i ].id ] = values [i ]
576586
577587 # Starred if exists
578588 if starred :
579589 end = len (values ) - (len (targets ) - star_or_last_idx - 1 )
580- transient_locals [targets [star_or_last_idx ].value .id ] = values [
581- star_or_last_idx :end
582- ]
590+ if _is_instance_attribute_assignment (
591+ targets [star_or_last_idx ], context
592+ ):
593+ class_transients [targets [star_or_last_idx ].attr ] = values [
594+ star_or_last_idx :end
595+ ]
596+ else :
597+ transient_locals [targets [star_or_last_idx ].value .id ] = values [
598+ star_or_last_idx :end
599+ ]
583600
584601 # After starred
585602 for i in range (star_or_last_idx + 1 , len (targets )):
586- transient_locals [targets [i ].id ] = values [
587- len (values ) - (len (targets ) - i )
588- ]
603+ if _is_instance_attribute_assignment (targets [i ], context ):
604+ class_transients [targets [i ].attr ] = values [
605+ len (values ) - (len (targets ) - i )
606+ ]
607+ else :
608+ transient_locals [targets [i ].id ] = values [
609+ len (values ) - (len (targets ) - i )
610+ ]
611+ elif isinstance (target , ast .Subscript ):
612+ if isinstance (target .value , ast .Name ):
613+ name = target .value .id
614+ container = transient_locals .get (name )
615+ if container is None :
616+ container = context .locals .get (name )
617+ if container is None :
618+ container = context .globals .get (name )
619+ if container is None :
620+ raise NameError (
621+ f"{ name } not found in locals, globals, nor builtins"
622+ )
623+ storage_dict = transient_locals
624+ storage_key = name
625+ elif isinstance (
626+ target .value , ast .Attribute
627+ ) and _is_instance_attribute_assignment (target .value , context ):
628+ attr = target .value .attr
629+ container = class_transients .get (attr , None )
630+ if container is None :
631+ raise NameError (f"{ attr } not found in class transients" )
632+ storage_dict = class_transients
633+ storage_key = attr
634+ else :
635+ return
636+
637+ key = eval_node (target .slice , context )
638+ attributes = (
639+ dict .fromkeys (dir (container ))
640+ if policy .can_call (container .__dir__ )
641+ else {}
642+ )
643+ items = {}
644+
645+ if policy .can_get_item (container , None ):
646+ try :
647+ items = dict (container .items ())
648+ except Exception :
649+ pass
650+
651+ items [key ] = value
652+ duck_container = _Duck (attributes = attributes , items = items )
653+ storage_dict [storage_key ] = duck_container
654+ elif _is_instance_attribute_assignment (target , context ):
655+ class_transients [target .attr ] = value
589656 else :
590657 transient_locals [target .id ] = value
591658 return None
@@ -605,6 +672,30 @@ def _extract_args_and_kwargs(node: ast.Call, context: EvaluationContext):
605672 return args , kwargs
606673
607674
675+ def _is_instance_attribute_assignment (
676+ target : ast .AST , context : EvaluationContext
677+ ) -> bool :
678+ """Return True if target is an attribute access on the instance argument."""
679+ return (
680+ context .class_transients is not None
681+ and context .instance_arg_name is not None
682+ and isinstance (target , ast .Attribute )
683+ and isinstance (getattr (target , "value" , None ), ast .Name )
684+ and getattr (target .value , "id" , None ) == context .instance_arg_name
685+ )
686+
687+
688+ def _get_coroutine_attributes () -> dict [str , Optional [object ]]:
689+ async def _dummy ():
690+ return None
691+
692+ coro = _dummy ()
693+ try :
694+ return {attr : getattr (coro , attr , None ) for attr in dir (coro )}
695+ finally :
696+ coro .close ()
697+
698+
608699def eval_node (node : Union [ast .AST , None ], context : EvaluationContext ):
609700 """Evaluate AST node in provided context.
610701
@@ -641,10 +732,13 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
641732 for child_node in node .body :
642733 result = eval_node (child_node , context )
643734 return result
644- if isinstance (node , ast .FunctionDef ):
645- # we ignore body and only extract the return type
735+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef )):
736+ is_async = isinstance (node , ast .AsyncFunctionDef )
737+ func_locals = context .transient_locals .copy ()
738+ func_context = context .replace (transient_locals = func_locals )
646739 is_property = False
647-
740+ is_static = False
741+ is_classmethod = False
648742 for decorator_node in node .decorator_list :
649743 try :
650744 decorator = eval_node (decorator_node , context )
@@ -654,42 +748,85 @@ def eval_node(node: Union[ast.AST, None], context: EvaluationContext):
654748 continue
655749 if decorator is property :
656750 is_property = True
751+ elif decorator is staticmethod :
752+ is_static = True
753+ elif decorator is classmethod :
754+ is_classmethod = True
755+
756+ if func_context .class_transients is not None :
757+ if not is_static and not is_classmethod :
758+ func_context .instance_arg_name = (
759+ node .args .args [0 ].arg if node .args .args else None
760+ )
657761
658762 return_type = eval_node (node .returns , context = context )
659763
764+ for child_node in node .body :
765+ eval_node (child_node , func_context )
766+
660767 if is_property :
661- context .transient_locals [node .name ] = _resolve_annotation (
662- return_type , context
663- )
768+ if return_type is not None :
769+ context .transient_locals [node .name ] = _resolve_annotation (
770+ return_type , context
771+ )
772+ else :
773+ return_value = _infer_return_value (node , func_context )
774+ context .transient_locals [node .name ] = return_value
775+
664776 return None
665777
666778 def dummy_function (* args , ** kwargs ):
667779 pass
668780
669- dummy_function .__annotations__ ["return" ] = return_type
781+ if return_type is not None :
782+ dummy_function .__annotations__ ["return" ] = return_type
783+ else :
784+ inferred_return = _infer_return_value (node , func_context )
785+ if inferred_return is not None :
786+ dummy_function .__inferred_return__ = inferred_return
787+
670788 dummy_function .__name__ = node .name
671789 dummy_function .__node__ = node
790+ dummy_function .__is_async__ = is_async
672791 context .transient_locals [node .name ] = dummy_function
673792 return None
793+ if isinstance (node , ast .Lambda ):
794+
795+ def dummy_function (* args , ** kwargs ):
796+ pass
797+
798+ dummy_function .__inferred_return__ = eval_node (node .body , context )
799+ return dummy_function
674800 if isinstance (node , ast .ClassDef ):
675801 # TODO support class decorators?
676802 class_locals = {}
677- class_context = context .replace (transient_locals = class_locals )
803+ outer_locals = context .locals .copy ()
804+ outer_locals .update (context .transient_locals )
805+ class_context = context .replace (
806+ transient_locals = class_locals , locals = outer_locals
807+ )
808+ class_context .class_transients = class_locals
678809 for child_node in node .body :
679810 eval_node (child_node , class_context )
680811 bases = tuple ([eval_node (base , context ) for base in node .bases ])
681812 dummy_class = type (node .name , bases , class_locals )
682813 context .transient_locals [node .name ] = dummy_class
683814 return None
815+ if isinstance (node , ast .Await ):
816+ value = eval_node (node .value , context )
817+ if hasattr (value , "__awaited_type__" ):
818+ return value .__awaited_type__
819+ return value
684820 if isinstance (node , ast .Assign ):
685821 return _handle_assign (node , context )
686822 if isinstance (node , ast .AnnAssign ):
687- if not node .simple :
688- # for now only handle simple annotations
689- return None
690- context .transient_locals [node .target .id ] = _resolve_annotation (
691- eval_node (node .annotation , context ), context
692- )
823+ if node .simple :
824+ value = _resolve_annotation (eval_node (node .annotation , context ), context )
825+ context .transient_locals [node .target .id ] = value
826+ # Handle non-simple annotated assignments only for self.x: type = value
827+ if _is_instance_attribute_assignment (node .target , context ):
828+ value = _resolve_annotation (eval_node (node .annotation , context ), context )
829+ context .class_transients [node .target .attr ] = value
693830 return None
694831 if isinstance (node , ast .Expression ):
695832 return eval_node (node .body , context )
@@ -807,6 +944,12 @@ def dummy_function(*args, **kwargs):
807944 if isinstance (node , ast .Name ):
808945 return _eval_node_name (node .id , context )
809946 if isinstance (node , ast .Attribute ):
947+ if (
948+ context .class_transients is not None
949+ and isinstance (node .value , ast .Name )
950+ and node .value .id == context .instance_arg_name
951+ ):
952+ return context .class_transients .get (node .attr )
810953 value = eval_node (node .value , context )
811954 if policy .can_get_attr (value , node .attr ):
812955 return getattr (value , node .attr )
@@ -836,7 +979,17 @@ def dummy_function(*args, **kwargs):
836979 return overridden_return_type
837980 return _create_duck_for_heap_type (func )
838981 else :
982+ inferred_return = getattr (func , "__inferred_return__" , NOT_EVALUATED )
839983 return_type = _eval_return_type (func , node , context )
984+ if getattr (func , "__is_async__" , False ):
985+ awaited_type = (
986+ inferred_return if inferred_return is not None else return_type
987+ )
988+ coroutine_duck = _Duck (attributes = _get_coroutine_attributes ())
989+ coroutine_duck .__awaited_type__ = awaited_type
990+ return coroutine_duck
991+ if inferred_return is not NOT_EVALUATED :
992+ return inferred_return
840993 if return_type is not NOT_EVALUATED :
841994 return return_type
842995 raise GuardRejection (
@@ -853,6 +1006,101 @@ def dummy_function(*args, **kwargs):
8531006 return None
8541007
8551008
1009+ def _merge_values (values , policy : EvaluationPolicy ):
1010+ """Recursively merge multiple values, combining attributes and dict items."""
1011+ if len (values ) == 1 :
1012+ return values [0 ]
1013+
1014+ types = {type (v ) for v in values }
1015+ merged_items = None
1016+ key_values = {}
1017+ attributes = set ()
1018+ for v in values :
1019+ if policy .can_call (v .__dir__ ):
1020+ attributes .update (dir (v ))
1021+ try :
1022+ if policy .can_call (v .items ):
1023+ try :
1024+ for k , val in v .items ():
1025+ key_values .setdefault (k , []).append (val )
1026+ except Exception as e :
1027+ pass
1028+ elif policy .can_call (v .keys ):
1029+ try :
1030+ for k in v .keys ():
1031+ key_values .setdefault (k , []).append (None )
1032+ except Exception as e :
1033+ pass
1034+ except Exception as e :
1035+ pass
1036+
1037+ if key_values :
1038+ merged_items = {
1039+ k : _merge_values (vals , policy ) if vals [0 ] is not None else None
1040+ for k , vals in key_values .items ()
1041+ }
1042+
1043+ if len (types ) == 1 :
1044+ t = next (iter (types ))
1045+ if t not in (dict ,) and not (
1046+ hasattr (next (iter (values )), "__getitem__" )
1047+ and (
1048+ hasattr (next (iter (values )), "items" )
1049+ or hasattr (next (iter (values )), "keys" )
1050+ )
1051+ ):
1052+ if t in (list , set , tuple ):
1053+ return t
1054+ return values [0 ]
1055+
1056+ return _Duck (attributes = dict .fromkeys (attributes ), items = merged_items )
1057+
1058+
1059+ def _infer_return_value (node : ast .FunctionDef , context : EvaluationContext ):
1060+ """Infer the return value(s) of a function by evaluating all return statements."""
1061+ return_values = _collect_return_values (node .body , context )
1062+
1063+ if not return_values :
1064+ return None
1065+ if len (return_values ) == 1 :
1066+ return return_values [0 ]
1067+
1068+ policy = get_policy (context )
1069+ return _merge_values (return_values , policy )
1070+
1071+
1072+ def _collect_return_values (body , context ):
1073+ """Recursively collect return values from a list of AST statements."""
1074+ return_values = []
1075+ for stmt in body :
1076+ if isinstance (stmt , ast .Return ):
1077+ if stmt .value is None :
1078+ continue
1079+ try :
1080+ value = eval_node (stmt .value , context )
1081+ if value is not None and value is not NOT_EVALUATED :
1082+ return_values .append (value )
1083+ except Exception :
1084+ pass
1085+ if isinstance (
1086+ stmt , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef , ast .Lambda )
1087+ ):
1088+ continue
1089+ elif hasattr (stmt , "body" ) and isinstance (stmt .body , list ):
1090+ return_values .extend (_collect_return_values (stmt .body , context ))
1091+ if isinstance (stmt , ast .Try ):
1092+ for h in stmt .handlers :
1093+ if hasattr (h , "body" ):
1094+ return_values .extend (_collect_return_values (h .body , context ))
1095+ if hasattr (stmt , "orelse" ):
1096+ return_values .extend (_collect_return_values (stmt .orelse , context ))
1097+ if hasattr (stmt , "finalbody" ):
1098+ return_values .extend (_collect_return_values (stmt .finalbody , context ))
1099+ if hasattr (stmt , "orelse" ) and isinstance (stmt .orelse , list ):
1100+ return_values .extend (_collect_return_values (stmt .orelse , context ))
1101+ return return_values
1102+
1103+
8561104def _eval_return_type (func : Callable , node : ast .Call , context : EvaluationContext ):
8571105 """Evaluate return type of a given callable function.
8581106
@@ -1118,6 +1366,8 @@ def _list_methods(cls, source=None):
11181366 * _list_methods (collections .Counter , dict_non_mutating_methods ),
11191367 collections .Counter .elements ,
11201368 collections .Counter .most_common ,
1369+ object .__dir__ ,
1370+ type .__dir__ ,
11211371}
11221372
11231373BUILTIN_GETATTR : set [MayHaveGetattr ] = {
0 commit comments