Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit cbc17f6

Browse files
authored
Merge pull request #148 from bcaller/bc-get-call-names
Make get_call_names more resilient
2 parents e692581 + 83e496f commit cbc17f6

2 files changed

Lines changed: 24 additions & 13 deletions

File tree

pyt/core/ast_helper.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,28 +44,23 @@ def generate_ast(path):
4444
raise IOError('Input needs to be a file. Path: ' + path)
4545

4646

47-
def _get_call_names_helper(node, result):
47+
def _get_call_names_helper(node):
4848
"""Recursively finds all function names."""
4949
if isinstance(node, ast.Name):
5050
if node.id not in BLACK_LISTED_CALL_NAMES:
51-
result.append(node.id)
52-
return result
53-
elif isinstance(node, ast.Call):
54-
return result
51+
yield node.id
5552
elif isinstance(node, ast.Subscript):
56-
return _get_call_names_helper(node.value, result)
53+
yield from _get_call_names_helper(node.value)
5754
elif isinstance(node, ast.Str):
58-
result.append(node.s)
59-
return result
60-
else:
61-
result.append(node.attr)
62-
return _get_call_names_helper(node.value, result)
55+
yield node.s
56+
elif isinstance(node, ast.Attribute):
57+
yield node.attr
58+
yield from _get_call_names_helper(node.value)
6359

6460

6561
def get_call_names(node):
6662
"""Get a list of call names."""
67-
result = list()
68-
return reversed(_get_call_names_helper(node, result))
63+
return reversed(list(_get_call_names_helper(node)))
6964

7065

7166
def _list_to_dotted_string(list_of_components):

tests/cfg/import_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,3 +733,19 @@ def test_get_call_names_multi(self):
733733
result = get_call_names_as_string(call.func)
734734

735735
self.assertEqual(result, 'abc.defg.hi')
736+
737+
def test_get_call_names_with_binop(self):
738+
m = ast.parse('(date.today() - timedelta(days=1)).strftime("%Y-%m-%d")')
739+
call = m.body[0].value
740+
741+
result = get_call_names_as_string(call.func)
742+
743+
self.assertEqual(result, 'strftime')
744+
745+
def test_get_call_names_with_comprehension(self):
746+
m = ast.parse('{a for a in b()}.union(c)')
747+
call = m.body[0].value
748+
749+
result = get_call_names_as_string(call.func)
750+
751+
self.assertEqual(result, 'union')

0 commit comments

Comments
 (0)