Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6d81d91

Browse files
authored
Merge pull request #150 from bcaller/starred-tuple-assign
Starred tuple assignment
2 parents c6820ff + 80113af commit 6d81d91

6 files changed

Lines changed: 94 additions & 11 deletions

File tree

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
a, *b, c, d, e = f, *g, *h, f + i, j

pyt/cfg/stmt_visitor.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -327,28 +327,59 @@ def visit_Try(self, node):
327327
return ControlFlowNode(try_node, last_statements, break_statements=body.break_statements)
328328

329329
def assign_tuple_target(self, node, right_hand_side_variables):
330-
new_assignment_nodes = list()
331-
for i, target in enumerate(node.targets[0].elts):
332-
value = node.value.elts[i]
330+
new_assignment_nodes = []
331+
remaining_variables = list(right_hand_side_variables)
332+
remaining_targets = list(node.targets[0].elts)
333+
remaining_values = list(node.value.elts) # May contain duplicates
333334

335+
def visit(target, value):
334336
label = LabelVisitor()
335337
label.visit(target)
336-
338+
rhs_visitor = RHSVisitor()
339+
rhs_visitor.visit(value)
337340
if isinstance(value, ast.Call):
338341
new_ast_node = ast.Assign(target, value)
339-
new_ast_node.lineno = node.lineno
340-
342+
ast.copy_location(new_ast_node, node)
341343
new_assignment_nodes.append(self.assignment_call_node(label.result, new_ast_node))
342-
343344
else:
344345
label.result += ' = '
345346
label.visit(value)
346-
347347
new_assignment_nodes.append(self.append_node(AssignmentNode(
348348
label.result,
349349
extract_left_hand_side(target),
350350
ast.Assign(target, value),
351-
right_hand_side_variables,
351+
rhs_visitor.result,
352+
line_number=node.lineno,
353+
path=self.filenames[-1]
354+
)))
355+
remaining_targets.remove(target)
356+
remaining_values.remove(value)
357+
for var in rhs_visitor.result:
358+
remaining_variables.remove(var)
359+
360+
# Pair targets and values until a Starred node is reached
361+
for target, value in zip(node.targets[0].elts, node.value.elts):
362+
if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
363+
break
364+
visit(target, value)
365+
366+
# If there was a Starred node, pair remaining targets and values from the end
367+
for target, value in zip(reversed(list(remaining_targets)), reversed(list(remaining_values))):
368+
if isinstance(target, ast.Starred) or isinstance(value, ast.Starred):
369+
break
370+
visit(target, value)
371+
372+
if remaining_targets:
373+
label = LabelVisitor()
374+
label.handle_comma_separated(remaining_targets)
375+
label.result += ' = '
376+
label.handle_comma_separated(remaining_values)
377+
for target in remaining_targets:
378+
new_assignment_nodes.append(self.append_node(AssignmentNode(
379+
label.result,
380+
extract_left_hand_side(target),
381+
ast.Assign(target, remaining_values[0]),
382+
remaining_variables,
352383
line_number=node.lineno,
353384
path=self.filenames[-1]
354385
)))
@@ -380,8 +411,8 @@ def assign_multi_target(self, node, right_hand_side_variables):
380411
def visit_Assign(self, node):
381412
rhs_visitor = RHSVisitor()
382413
rhs_visitor.visit(node.value)
383-
if isinstance(node.targets[0], ast.Tuple): # x,y = [1,2]
384-
if isinstance(node.value, ast.Tuple):
414+
if isinstance(node.targets[0], (ast.Tuple, ast.List)): # x,y = [1,2]
415+
if isinstance(node.value, (ast.Tuple, ast.List)):
385416
return self.assign_tuple_target(node, rhs_visitor.result)
386417
elif isinstance(node.value, ast.Call):
387418
call = None

pyt/cfg/stmt_visitor_helper.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ def _get_names(node, result):
7979
return node.id + result
8080
elif isinstance(node, ast.Subscript):
8181
return result
82+
elif isinstance(node, ast.Starred):
83+
return _get_names(node.value, result)
8284
else:
8385
return _get_names(node.value, result + '.' + node.attr)
8486

pyt/helper_visitors/label_visitor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,7 @@ def visit_FormattedValue(self, node):
320320
self.result += ':'
321321
self.visit_joined_str(node.format_spec)
322322
self.result += '}'
323+
324+
def visit_Starred(self, node):
325+
self.result += '*'
326+
self.visit(node.value)

tests/cfg/cfg_test.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import ast
2+
13
from .cfg_base_test_case import CFGBaseTestCase
24

35
from pyt.core.node_types import (
@@ -779,6 +781,45 @@ def test_assignment_tuple_value(self):
779781

780782
self.assertEqual(self.cfg.nodes[node].label, 'a = (x, y)')
781783

784+
def test_assignment_starred(self):
785+
self.cfg_create_from_file('examples/example_inputs/assignment_starred.py')
786+
787+
middle_nodes = self.cfg.nodes[1:-1]
788+
self.assert_length(middle_nodes, expected_length=5)
789+
790+
visited = [self.cfg.nodes[0]]
791+
while True:
792+
current_node = visited[-1]
793+
if len(current_node.outgoing) != 1:
794+
break
795+
visited.append(current_node.outgoing[0])
796+
self.assertCountEqual(self.cfg.nodes, visited, msg="Did not complete a path from Entry to Exit")
797+
798+
self.assertEqual(middle_nodes[0].label, 'a = f')
799+
self.assertCountEqual( # We don't assert a specific order for the assignment nodes
800+
[n.label for n in middle_nodes],
801+
['a = f', 'd = f + i', 'e = j'] + ['*b, c = *g, *h'] * 2,
802+
)
803+
self.assertCountEqual(
804+
[(n.left_hand_side, n.right_hand_side_variables) for n in middle_nodes],
805+
[('a', ['f']), ('b', ['g', 'h']), ('c', ['g', 'h']), ('d', ['f', 'i']), ('e', ['j'])],
806+
)
807+
808+
def test_assignment_starred_list(self):
809+
self.cfg_create_from_ast(ast.parse('[a, b, c] = *d, e'))
810+
811+
middle_nodes = self.cfg.nodes[1:-1]
812+
self.assert_length(middle_nodes, expected_length=3)
813+
814+
self.assertCountEqual(
815+
[n.label for n in middle_nodes],
816+
['a, b = *d', 'a, b = *d', 'c = e'],
817+
)
818+
self.assertCountEqual(
819+
[(n.left_hand_side, n.right_hand_side_variables) for n in middle_nodes],
820+
[('a', ['d']), ('b', ['d']), ('c', ['e'])],
821+
)
822+
782823

783824
class CFGComprehensionTest(CFGBaseTestCase):
784825
def test_nodes(self):

tests/helper_visitors/label_visitor_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,7 @@ def test_joined_str(self):
7979
def test_joined_str_with_format_spec(self):
8080
label = self.perform_labeling_on_expression('f"a{b!s:.{length}}"')
8181
self.assertEqual(label.result, 'f\'a{b!s:.{length}}\'')
82+
83+
def test_starred(self):
84+
label = self.perform_labeling_on_expression('[a, *b] = *c, d')
85+
self.assertEqual(label.result, '[a, *b] = (*c, d)')

0 commit comments

Comments
 (0)