From 861efcd3fffb84f5570e88eb727a2bd848d733e7 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 5 Jul 2024 13:07:29 +0800 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- torch/utils/_cxx_pytree.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 01adf0a4f9b1f..d57d2c683bdcf 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -75,6 +75,10 @@ ] +__TORCH_DICT_SESSION = optree.dict_insertion_ordered(True, namespace="torch") +__TORCH_DICT_SESSION.__enter__() # enable globally and permanently + + T = TypeVar("T") S = TypeVar("S") U = TypeVar("U") From 452f75009a1ccda17b2b793496f8d9d6408e3501 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 5 Jul 2024 17:28:58 +0800 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- torch/utils/_cxx_pytree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index d57d2c683bdcf..7289a49c082e3 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -264,7 +264,7 @@ def tree_flatten( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_flatten(tree) - ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf)) + ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) >>> tree_flatten(1) ([1], PyTreeSpec(*, NoneIsLeaf)) >>> tree_flatten(None) From 16f36dfc9e2ae351c7270e4d0144a97c2611c86b Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 5 Jul 2024 17:33:01 +0800 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- torch/utils/_cxx_pytree.py | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 7289a49c082e3..4710e15129d55 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -264,20 +264,15 @@ def tree_flatten( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_flatten(tree) - ([1, 2, 3, 4, None, 5], PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) + ([2, 3, 4, 1, None, 5], PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch')) >>> tree_flatten(1) - ([1], PyTreeSpec(*, NoneIsLeaf)) + ([1], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) >>> tree_flatten(None) - ([None], PyTreeSpec(*, NoneIsLeaf)) - - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` - if you want to keep the keys in the insertion order. - + ([None], PyTreeSpec(*, NoneIsLeaf, namespace='torch')) >>> from collections import OrderedDict >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) >>> tree_flatten(tree) - ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict([('b', (*, [*, *])), ('a', *), ('c', *), ('d', *)]), NoneIsLeaf)) + ([2, 3, 4, 1, None, 5], PyTreeSpec(OrderedDict({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}), NoneIsLeaf, namespace='torch')) Args: tree (pytree): A pytree to flatten. @@ -406,11 +401,11 @@ def tree_structure( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_structure(tree) - PyTreeSpec({'a': *, 'b': (*, [*, *]), 'c': *, 'd': *}, NoneIsLeaf) + PyTreeSpec({'b': (*, [*, *]), 'a': *, 'c': *, 'd': *}, NoneIsLeaf, namespace='torch') >>> tree_structure(1) - PyTreeSpec(*, NoneIsLeaf) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') >>> tree_structure(None) - PyTreeSpec(*, NoneIsLeaf) + PyTreeSpec(*, NoneIsLeaf, namespace='torch') Args: tree (pytree): A pytree to flatten. From 6353a5b6574cc18d7f168461b0a33fbbc3a6a8ff Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 5 Jul 2024 19:01:01 +0800 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- test/test_pytree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_pytree.py b/test/test_pytree.py index 0a1c480a8fa7d..e8bbed6e55321 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1284,7 +1284,7 @@ def test_treespec_repr(self): _, spec = cxx_pytree.tree_flatten(pytree) self.assertEqual( repr(spec), - ("PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)"), + ("PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')"), ) @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") From 2dda457258f0d11368fdd91134b9466530aeca9f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 5 Jul 2024 22:43:15 +0800 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- torch/utils/_cxx_pytree.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 4710e15129d55..d4066811a6968 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -331,7 +331,7 @@ def tree_iter( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> list(tree_iter(tree)) - [1, 2, 3, 4, None, 5] + [2, 3, 4, 1, None, 5] >>> list(tree_iter(1)) [1] >>> list(tree_iter(None)) @@ -366,7 +366,7 @@ def tree_leaves( >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} >>> tree_leaves(tree) - [1, 2, 3, 4, None, 5] + [2, 3, 4, 1, None, 5] >>> tree_leaves(1) [1] >>> tree_leaves(None) From 27f15645041f576503b9f69e3f7846da0b458457 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sun, 7 Jul 2024 01:53:53 +0800 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- test/test_pytree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_pytree.py b/test/test_pytree.py index e8bbed6e55321..988b6a7055ccb 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1294,7 +1294,7 @@ def test_treespec_repr_dynamo(self): _, spec = cxx_pytree.tree_flatten(pytree) self.assertExpectedInline( repr(spec), - "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)", + "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')", ) @parametrize( From 8a7abfe510a04adb4d365b5661644dd1611185f4 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 5 Aug 2024 23:41:42 +0800 Subject: [PATCH 7/7] Update [ghstack-poisoned] --- test/test_pytree.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/test_pytree.py b/test/test_pytree.py index cd0357dac32fc..89c101a503dd9 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -1308,7 +1308,9 @@ def test_treespec_repr(self): # Check that it looks sane pytree = (0, [0, 0, [0]]) _, spec = cxx_pytree.tree_flatten(pytree) - self.assertEqual(repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)") + self.assertEqual( + repr(spec), "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')" + ) @unittest.skipIf(not TEST_WITH_TORCHDYNAMO, "Eager test in test_treespec_repr.") def test_treespec_repr_dynamo(self): @@ -1317,7 +1319,7 @@ def test_treespec_repr_dynamo(self): _, spec = cxx_pytree.tree_flatten(pytree) self.assertExpectedInline( repr(spec), - "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf)", + "PyTreeSpec((*, [*, *, [*]]), NoneIsLeaf, namespace='torch')", ) @parametrize(