@@ -56,7 +56,6 @@ def get_stub_names(
5656 ModulePath (tuple (module_name .split ("." ))),
5757 is_init = is_init ,
5858 file_path = path ,
59- is_py_file = path .suffix == ".py" ,
6059 )
6160
6261
@@ -65,16 +64,11 @@ def parse_ast(
6564 search_context : SearchContext ,
6665 module_name : ModulePath ,
6766 * ,
67+ file_path : Path ,
6868 is_init : bool = False ,
69- file_path : Optional [Path ] = None ,
70- is_py_file : bool = False ,
7169) -> NameDict :
7270 visitor = _NameExtractor (
73- search_context ,
74- module_name ,
75- is_init = is_init ,
76- file_path = file_path ,
77- is_py_file = is_py_file ,
71+ search_context , module_name , is_init = is_init , file_path = file_path
7872 )
7973 name_dict : NameDict = {}
8074 try :
@@ -208,15 +202,17 @@ def __init__(
208202 ctx : SearchContext ,
209203 module_name : ModulePath ,
210204 * ,
205+ file_path : Path ,
211206 is_init : bool = False ,
212- file_path : Optional [Path ],
213- is_py_file : bool = False ,
214207 ) -> None :
215208 self .ctx = ctx
216209 self .module_name = module_name
217210 self .is_init = is_init
218211 self .file_path = file_path
219- self .is_py_file = is_py_file
212+
213+ @property
214+ def is_py_file (self ) -> bool :
215+ return self .file_path .suffix == ".py"
220216
221217 def visit_Module (self , node : ast .Module ) -> list [NameInfo ]:
222218 return [info for child in node .body for info in self .visit (child )]
@@ -305,15 +301,9 @@ def visit_If(self, node: ast.If) -> Iterable[NameInfo]:
305301 yield from self .visit (stmt )
306302
307303 def _visit_condition (self , expr : ast .expr ) -> Optional [bool ]:
308- visitor = LiteralEvalVisitor (self .ctx , self .file_path )
309- try :
310- value = visitor .visit (expr )
311- except InvalidStub :
312- if not self .is_py_file :
313- raise
314- return None
315- else :
316- return bool (value )
304+ return evaluate_expression_truthiness (
305+ expr , ctx = self .ctx , file_path = self .file_path
306+ )
317307
318308 def visit_Try (self , node : ast .Try ) -> Iterable [NameInfo ]:
319309 # try-except sometimes gets used with conditional imports. We assume
@@ -430,17 +420,41 @@ def generic_visit(self, node: ast.AST) -> Iterable[NameInfo]:
430420 raise InvalidStub (f"Cannot handle node { ast .dump (node )} " , self .file_path )
431421
432422
433- class LiteralEvalVisitor (ast .NodeVisitor ):
434- """Visitor to evaluate the truthiness of a ``test`` expression in an ``ast.Compare`` node.
423+ def evaluate_expression_truthiness (
424+ expr : ast .expr , * , ctx : SearchContext , file_path : Path
425+ ) -> Optional [bool ]:
426+ """Attempt to statically evaluate the truthiness of the expression represented by ``expr``.
427+
428+ This is useful for evaluating conditions that are used for branches in stubs, such as
429+ ``if sys.platform == "linux": ...`` or ``if sys.version_info >= (3, 8): ...``. It is usually
430+ desirable for a type checker only to consider one of these branches as reachable code for a
431+ given configuration of the type checker.
435432
436- ``LiteralEvalVisitor(ctx, path).visit(node)`` will return ``True`` if ``node`` is an
437- expression that can be statically determined to always be ``True``, ``False`` if it can
438- be statically determined to always be ``False``, and ``None`` if its truthiness cannot
439- be determined statically. For example, if passed an AST node representing the expression
440- ``sys.platform == "linux"``, it will return ``True`` if ``ctx.platform`` is equal to
441- ``"linux"``, otherwise ``False``.
433+ Details:
434+ * If the truthiness can be statically determined to always be ``True``, it returns ``True``.
435+ * If the truthiness can be statically determined to always be ``False``, it returns ``False``.
436+ * If the truthiness cannot be statically determined:
437+ * If ``file_path`` has a ``.pyi`` extension, ``InvalidStub`` is raised
438+ * If ``file_path`` has a any other extension, however, it returns ``None``, since it is
439+ expected that non-stub Python source files may contain dynamic expressions in ``if`` tests
440+ that cannot be evaluated statically.
441+
442+ For example, if passed an AST node representing the expression ``sys.platform == "linux"``,
443+ it will return ``True`` if ``ctx.platform`` is equal to ``"linux"``, otherwise ``False``.
442444 """
443445
446+ visitor = _LiteralEvalVisitor (ctx , file_path )
447+ try :
448+ value = visitor .visit (expr )
449+ except InvalidStub :
450+ if file_path .suffix == ".pyi" :
451+ raise
452+ return None
453+ else :
454+ return bool (value )
455+
456+
457+ class _LiteralEvalVisitor (ast .NodeVisitor ):
444458 def __init__ (self , ctx : SearchContext , file_path : Optional [Path ]) -> None :
445459 self .ctx = ctx
446460 self .file_path = file_path
0 commit comments