From b7bf1da479141be281590ba8b4563ee4c10d0165 Mon Sep 17 00:00:00 2001 From: Jeremy Kloth Date: Sat, 26 Mar 2022 13:18:02 -0600 Subject: [PATCH 1/3] bpo-47131: Speedup AST comparisons by using node traversal --- Lib/test/test_unparse.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index e38b33574ccccf..554290ce66e033 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -130,7 +130,23 @@ class Foo: pass class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): - self.assertEqual(ast.dump(ast1), ast.dump(ast2)) + missing = object() + def compare(a, b): + if type(a) is not type(b): + self.fail("type(a) != type(b)") + if isinstance(a, ast.AST): + for field in a._fields: + value1 = getattr(a, field, missing) + value2 = getattr(b, field, missing) + compare(value1, value2) + elif isinstance(a, list): + if len(a) != len(b): + self.fail("len(a) != len(b)") + for node1, node2 in zip(a, b): + compare(node1, node2) + elif a != b: + self.fail(f"{a!r} != {b!r}") + compare(ast1, ast2) def check_ast_roundtrip(self, code1, **kwargs): with self.subTest(code1=code1, ast_parse_kwargs=kwargs): From 1f03499c62c07717603ce0994196554d9b78f502 Mon Sep 17 00:00:00 2001 From: Jeremy Kloth Date: Mon, 28 Mar 2022 09:41:29 -0600 Subject: [PATCH 2/3] Add additional error checking with a bit of description --- Lib/test/test_unparse.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index 554290ce66e033..d40eb27d656dcc 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -130,23 +130,32 @@ class Foo: pass class ASTTestCase(unittest.TestCase): def assertASTEqual(self, ast1, ast2): + # Ensure the comparisons start at an AST node + self.assertIsInstance(ast1, ast.AST) + self.assertIsInstance(ast2, ast.AST) + + # An AST comparison routine modeled after ast.dump(), but + # instead of string building, it traverses the two trees + # in lock-step. Defining an inner function also prevents + # the `missing` value from polluting the "real" function + # signature or module namespace. missing = object() - def compare(a, b): + def traverse_compare(a, b): if type(a) is not type(b): self.fail("type(a) != type(b)") if isinstance(a, ast.AST): for field in a._fields: value1 = getattr(a, field, missing) value2 = getattr(b, field, missing) - compare(value1, value2) + traverse_compare(value1, value2) elif isinstance(a, list): if len(a) != len(b): self.fail("len(a) != len(b)") for node1, node2 in zip(a, b): - compare(node1, node2) + traverse_compare(node1, node2) elif a != b: self.fail(f"{a!r} != {b!r}") - compare(ast1, ast2) + traverse_compare(ast1, ast2) def check_ast_roundtrip(self, code1, **kwargs): with self.subTest(code1=code1, ast_parse_kwargs=kwargs): From dd8db6adfbca03fa8ea9a805866e33456db556c8 Mon Sep 17 00:00:00 2001 From: Jeremy Kloth Date: Wed, 30 Mar 2022 10:46:59 -0600 Subject: [PATCH 3/3] Improve error messaging --- Lib/test/test_unparse.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/Lib/test/test_unparse.py b/Lib/test/test_unparse.py index d40eb27d656dcc..f999ae8c16ceaf 100644 --- a/Lib/test/test_unparse.py +++ b/Lib/test/test_unparse.py @@ -136,23 +136,34 @@ def assertASTEqual(self, ast1, ast2): # An AST comparison routine modeled after ast.dump(), but # instead of string building, it traverses the two trees - # in lock-step. Defining an inner function also prevents - # the `missing` value from polluting the "real" function - # signature or module namespace. - missing = object() - def traverse_compare(a, b): + # in lock-step. + def traverse_compare(a, b, missing=object()): if type(a) is not type(b): - self.fail("type(a) != type(b)") + self.fail(f"{type(a)!r} is not {type(b)!r}") if isinstance(a, ast.AST): for field in a._fields: value1 = getattr(a, field, missing) value2 = getattr(b, field, missing) - traverse_compare(value1, value2) + # Singletons are equal by definition, so further + # testing can be skipped. + if value1 is not value2: + traverse_compare(value1, value2) elif isinstance(a, list): - if len(a) != len(b): - self.fail("len(a) != len(b)") - for node1, node2 in zip(a, b): - traverse_compare(node1, node2) + try: + for node1, node2 in zip(a, b, strict=True): + traverse_compare(node1, node2) + except ValueError: + # Attempt a "pretty" error ala assertSequenceEqual() + len1 = len(a) + len2 = len(b) + if len1 > len2: + what = "First" + diff = len1 - len2 + else: + what = "Second" + diff = len2 - len1 + msg = f"{what} list contains {diff} additional elements." + raise self.failureException(msg) from None elif a != b: self.fail(f"{a!r} != {b!r}") traverse_compare(ast1, ast2)