From e584f1980df643a55d792c5e842ee15e9a7c34ba Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 14 Oct 2024 17:31:16 +0800 Subject: [PATCH 1/4] Update [ghstack-poisoned] --- test/functorch/common_utils.py | 2 +- test/functorch/test_aotdispatch.py | 2 +- test/functorch/test_control_flow.py | 6 +-- test/functorch/test_eager_transforms.py | 2 +- test/functorch/test_ops.py | 31 ++++++----- test/functorch/test_vmap.py | 4 +- .../collect_metadata_analysis.py | 7 +-- .../dispatch_and_compile_graph.py | 2 +- .../_aot_autograd/input_output_analysis.py | 11 ++-- torch/_functorch/_aot_autograd/schemas.py | 25 ++++++--- .../_aot_autograd/subclass_utils.py | 4 +- .../traced_function_transforms.py | 2 +- torch/_functorch/_aot_autograd/utils.py | 12 ++--- torch/_functorch/aot_autograd.py | 12 ++--- torch/_functorch/autograd_function.py | 16 +++--- torch/_functorch/compile_utils.py | 5 +- torch/_functorch/compilers.py | 6 +-- torch/_functorch/eager_transforms.py | 52 +++++++++---------- torch/_functorch/partitioners.py | 8 +-- torch/_functorch/pyfunctorch.py | 6 +-- torch/_functorch/vmap.py | 11 ++-- torch/_higher_order_ops/map.py | 2 +- torch/_higher_order_ops/utils.py | 6 +-- torch/return_types.py | 3 +- torch/utils/_cxx_pytree.py | 5 +- torch/utils/_pytree.py | 6 +++ torch/utils/pytree.py | 23 +++++++- 27 files changed, 154 insertions(+), 117 deletions(-) diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 3cc61b84a52cb..5070531eeacb8 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -12,7 +12,7 @@ from functorch_additional_op_db import additional_op_db import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from functorch import vmap from torch.testing._internal.autograd_function_db import autograd_function_db from torch.testing._internal.common_device_type import toleranceOverride diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 3e1eeb8255b75..5052601fec988 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -20,7 +20,7 @@ import torch import torch._dynamo as torchdynamo import torch.nn as nn -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from functorch import grad, jacrev, make_fx, vjp, vmap from functorch.compile import ( aot_function, diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e4714fe768fb5..0aa073b72454e 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4,7 +4,7 @@ import unittest import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from functorch.experimental import control_flow from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException from torch._higher_order_ops.associative_scan import associative_scan @@ -1293,8 +1293,6 @@ def f(x, y): self.assertEqual(expected_grads, grads) def test_map_autograd_nested_list(self): - import torch.utils._pytree as pytree - def f(x, y): a, b = x c, d = a @@ -4304,8 +4302,6 @@ def g(xs, y): self.check_map_count(gm, 2) def test_tracing_map_autograd_symbolic_list(self): - import torch.utils._pytree as pytree - def f(x, y): return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index a1bd52a2fbb80..ddd3453b66f66 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -79,7 +79,7 @@ TestCase, xfailIfTorchDynamo, ) -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils.pytree import tree_flatten, tree_map, tree_unflatten USE_TORCHVISION = False diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 54136a4f7babb..2b92bf6bc3605 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -58,8 +58,7 @@ unMarkDynamoStrictTest, ) from torch.testing._internal.opinfo.core import SampleInput -from torch.utils import _pytree as pytree -from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten +from torch.utils.pytree import tree_flatten, tree_leaves, tree_map, tree_unflatten aten = torch.ops.aten @@ -161,7 +160,7 @@ def normalize_op_input_output3( f, args, kwargs, sample_args, output_process_fn_grad=None ): flat_args, args_spec = tree_flatten(args) - flat_sample_args = pytree.tree_leaves(sample_args) + flat_sample_args = tree_leaves(sample_args) diff_argnums = tuple( i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args)) @@ -299,8 +298,8 @@ def wrapped(*args): if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, tangents @@ -334,8 +333,8 @@ def wrapped(*args): if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, primals + tangents @@ -1086,7 +1085,7 @@ def test_vmapvjpvjp(self, device, dtype, op): fn, args = get_vjpfull_variant(op, sample) result = fn(*args) cotangents = tree_map(lambda x: torch.randn_like(x), result) - cotangents = pytree.tree_leaves(cotangents) + cotangents = tree_leaves(cotangents) num_args = len(args) args_and_cotangents = tuple(args) + tuple(cotangents) @@ -1096,8 +1095,8 @@ def vjp_of_vjp(*args_and_cotangents): cotangents = args_and_cotangents[num_args:] result, vjp_fn = vjp(fn, *args) result_vjps = vjp_fn(cotangents) - result = pytree.tree_leaves(result) - result_vjps = pytree.tree_leaves(result_vjps) + result = tree_leaves(result) + result_vjps = tree_leaves(result_vjps) return (*result, *result_vjps) is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) @@ -2103,8 +2102,8 @@ def jvp_of_vjp(*args): (primals, tangents) = tree_unflatten(args, spec) primals_out, tangents_out = jvp(push_vjp, primals, tangents) - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs) @@ -2421,7 +2420,7 @@ def is_differentiable(inp): ) def get_flat_differentiable(tree): - flattened = pytree.tree_leaves(tree) + flattened = tree_leaves(tree) return tuple(i for i in flattened if is_differentiable(i)) def get_differentiable_linked(list1, list2): @@ -2434,7 +2433,7 @@ def get_differentiable_linked(list1, list2): return zip(*paired_list) def filter_none(out): - flattened = pytree.tree_leaves(out) + flattened = tree_leaves(out) return tuple(o for o in flattened if o is not None) if not op.supports_autograd: @@ -2452,8 +2451,8 @@ def compute_grad(cotangents): out_flattened = out cotangents_flattened = cotangents if not isinstance(out_flattened, torch.Tensor): - out_flattened = pytree.tree_leaves(out) - cotangents_flattened = pytree.tree_leaves(cotangents) + out_flattened = tree_leaves(out) + cotangents_flattened = tree_leaves(cotangents) out_flattened, cotangents_flattened = get_differentiable_linked( out_flattened, cotangents_flattened ) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 870b0e61b26e5..67ebd12fa23a2 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -38,6 +38,7 @@ import functorch import torch import torch.nn.functional as F +import torch.utils.pytree as pytree from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap from functorch.experimental import chunk_vmap from torch import Tensor @@ -76,7 +77,6 @@ xfailIfTorchDynamo, ) from torch.testing._internal.custom_op_db import custom_op_db -from torch.utils import _pytree as pytree def get_platform_specific_sdpa(): @@ -1340,7 +1340,7 @@ def _vmap_test( check_propagates_grad=True, ): result = vmap(op, in_dims, out_dims)(*inputs) - are_nested = [t.is_nested for t in pytree.tree_leaves(result)] + are_nested = [t.is_nested for t in pytree.tree_iter(result)] reference_result = reference_vmap( op, inputs, in_dims, out_dims, return_nt=any(are_nested) ) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 9be7779789196..a0a0088d886a8 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -15,7 +15,7 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Set import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._guards import detect_fake_mode from torch._logging import getArtifactLogger @@ -694,11 +694,8 @@ def view_avoid_dupes_with_primals(t): view_avoid_dupes_with_primals, traced_tangents ) - output_tangents_start_idx = len(f_input_tangents) - output_tangents_end_idx = output_tangents_start_idx + len(f_output_tangents) tangents_and_memory_formats = [ - coerce_tangent_and_suggest_memory_format(tt) - for i, tt in enumerate(traced_tangents) + coerce_tangent_and_suggest_memory_format(tt) for tt in traced_tangents ] traced_tangents = [t[0] for t in tangents_and_memory_formats] traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats] diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 62cf7b68cd3fa..87dd02c3c41d1 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -8,8 +8,8 @@ from typing import Any, Dict, List, Optional, Tuple import torch -import torch.utils._pytree as pytree import torch.utils.dlpack +import torch.utils.pytree as pytree from torch import Tensor from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import lazy_format_graph_code diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 8f330a056b7ae..0a828187946c0 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -11,10 +11,9 @@ """ import itertools -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch -import torch.utils._pytree as pytree from torch import Tensor from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental.symbolic_shapes import is_concrete_int @@ -32,6 +31,10 @@ from .utils import strict_zip +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + zip = strict_zip @@ -421,8 +424,8 @@ def _graph_output_names(gm): def create_graph_signature( fx_g: torch.fx.GraphModule, fw_metadata: ViewAndMutationMeta, - in_spec: pytree.TreeSpec, - out_spec: pytree.TreeSpec, + in_spec: "PyTreeSpec", + out_spec: "PyTreeSpec", *, user_args_flat: List[Tensor], params_and_buffers_flat: List[Tensor], diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 4439908c601fb..62176f8470563 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -9,10 +9,19 @@ import functools from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, List, NewType, Optional, Set, Union +from typing import ( + Any, + Callable, + Dict, + List, + NewType, + Optional, + Set, + TYPE_CHECKING, + Union, +) import torch -import torch.utils._pytree as pytree from torch._guards import Source from torch._ops import OpOverload from torch._subclasses import FakeTensor @@ -27,6 +36,10 @@ from .utils import strict_zip +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + zip = strict_zip @@ -691,8 +704,8 @@ class GraphSignature: buffers_to_mutate: Dict[GraphOutputName, FQN] user_inputs_to_mutate: Dict[GraphOutputName, GraphInputName] - in_spec: pytree.TreeSpec - out_spec: pytree.TreeSpec + in_spec: "PyTreeSpec" + out_spec: "PyTreeSpec" backward_signature: Optional[BackwardSignature] @@ -703,8 +716,8 @@ class GraphSignature: def from_tracing_metadata( cls, *, - in_spec: pytree.TreeSpec, - out_spec: pytree.TreeSpec, + in_spec: "PyTreeSpec", + out_spec: "PyTreeSpec", graph_input_names: List[str], graph_output_names: List[str], view_mutation_metadata: ViewAndMutationMeta, diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index ad8de0eac069f..18cae82f81b37 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -8,7 +8,7 @@ import typing from typing import Any, List, Optional, Tuple, Union -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._subclasses.fake_tensor import get_plain_tensors from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -21,7 +21,7 @@ def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: - args_flattened = pytree.arg_tree_leaves(*args) + args_flattened = pytree.tree_leaves(args) any_subclass_args = any( is_traceable_wrapper_subclass(x) for x in args_flattened diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index 6f1462febef76..982deeb8f99ff 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -19,7 +19,7 @@ import torch import torch.fx.traceback as fx_traceback -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._decomp.decompositions_for_rng import PhiloxStateTracker from torch._guards import detect_fake_mode diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index ca26234fdaab3..3b942262e0e33 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -11,7 +11,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch._library.fake_class_registry import FakeScriptObject from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor @@ -138,19 +138,19 @@ def call_func_at_runtime_with_args( # Inspired by autodidax (thanks!) class PytreeThunk: - spec: Optional[pytree.TreeSpec] = None + spec: Optional["pytree.PyTreeSpec"] = None # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. is_simple: Optional[ bool ] = None # if the output spec is a tuple/list, we won't bother unflattening it. is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec - def set(self, spec: pytree.TreeSpec) -> None: + def set(self, spec: "pytree.PyTreeSpec") -> None: assert self.spec is None or self.spec == spec assert spec is not None - self.spec: pytree.TreeSpec = spec + self.spec = spec if self.spec.type in {tuple, list} and all( - child.is_leaf() for child in spec.children_specs + child.is_leaf() for child in spec.children() ): self.is_simple = True if self.spec.is_leaf(): @@ -172,7 +172,7 @@ def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThu if kwargs is None: kwargs = {} # Save the args_spec for flat_tensor_args to unflatten while tracing - _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + tensor_args_spec = pytree.tree_structure((args, kwargs)) out_spec = PytreeThunk() def flat_fn(*flat_args): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 87c49887dea15..1a1061356fc49 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -9,8 +9,8 @@ import torch import torch._dynamo.logging import torch.nn as nn -import torch.utils._pytree as pytree import torch.utils.dlpack +import torch.utils.pytree as pytree from torch import Tensor from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions from torch._dispatch.python import enable_python_dispatcher @@ -859,7 +859,7 @@ def aot_function( def returned_function(*args, **kwargs): nonlocal cached_res # Now flatten the tensor args - flat_args = pytree.arg_tree_leaves(*args, **kwargs) + flat_args = pytree.tree_leaves((args, kwargs)) # Compile the function and save it in the cache if cached_res is None: @@ -1314,7 +1314,7 @@ def flattened_joint(*args): fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args) - user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) + user_args_flat = pytree.tree_leaves((args, kwargs)) return fx_g, create_graph_signature( fx_g, metadata, @@ -1370,7 +1370,7 @@ def aot_export_joint_simple( args, decompositions=decompositions, ) - in_spec, _kw_in_spec = in_spec.children_specs + in_spec, _kw_in_spec = in_spec.children() # At this point, we can just directly return the (joint or inference graph) that we traced. # First though: a bunch of assertions to make sure that our graph doesn't require # any calling convention changes compared to the original function. @@ -1397,7 +1397,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" ) - if not all(child.is_leaf() for child in in_spec.children_specs): + if not all(child.is_leaf() for child in in_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" ) @@ -1405,7 +1405,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" ) - if not all(child.is_leaf() for child in out_spec.children_specs): + if not all(child.is_leaf() for child in out_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" ) diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 0d66cb7a50cbb..f5d0c7ff7b9be 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -2,7 +2,7 @@ from typing import Any, NamedTuple, Tuple import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch._C._functorch import ( _unwrap_for_grad, _wrap_for_grad, @@ -167,8 +167,8 @@ def jvp(ctx, *tangents): def wrap_outputs_maintaining_identity( outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS ): - flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs) - flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs) + flat_unwrapped_inputs = pytree.tree_leaves(unwrapped_inputs) + flat_orig_inputs = pytree.tree_leaves(orig_inputs) unwrapped_input_to_orig_input = { id(unwrapped): orig @@ -185,12 +185,13 @@ def wrap_outputs_maintaining_identity( # _broadcast_to_and_flatten returns None if it is unable to broadcast. # TODO: update following link from master to stable once that's out if flat_out_dims is None: + out_dims_spec = pytree.tree_structure(out_dims) raise RuntimeError( f"The autograd.Function's vmap staticmethod returned an " f"incompatible (output, out_dims) tuple. " f"Expected out_dims={out_dims} " f"to be compatible with the structure of `output`. " - f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} " + f"out_dims has structure {pytree.treespec_pprint(out_dims_spec)} " f"but output has structure {spec}. " f"For more details, please see " f"https://pytorch.org/docs/main/notes/extending.func.html" @@ -275,10 +276,7 @@ def validate_vmap_returns_tuple_of_two_elements(result): @custom_function_call.py_impl(TransformType.Vmap) def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs): - if any( - isinstance(val, torch.Tensor) - for val in torch.utils._pytree.tree_flatten(kwargs)[0] - ): + if pytree.tree_any(torch.is_tensor, kwargs): raise NotImplementedError( f"Run vmap on autograd.Function with kwarg-only Tensor args. " f"Please do not pass kwarg-only Tensors to autograd.Function. " @@ -505,7 +503,7 @@ def get_out_dims(): # the corresponding in_dims with None. def get_tangents_in_dims(input_dims, tangents): flat_in_dims, spec = pytree.tree_flatten(input_dims) - flat_tangents = pytree.arg_tree_leaves(*tangents) + flat_tangents = pytree.tree_leaves(tangents) result = [ None if tangent is None else in_dim for in_dim, tangent in zip(flat_in_dims, flat_tangents) diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 3bf61f1af3bf3..d7a6147a893b8 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -5,9 +5,8 @@ import torch import torch.fx as fx +import torch.utils.pytree as pytree from torch.multiprocessing.reductions import StorageWeakRef -from torch.utils import _pytree as pytree -from torch.utils._pytree import tree_flatten aten = torch.ops.aten @@ -102,7 +101,7 @@ def checkable_node(node: fx.Node) -> bool: # substitute args and kwargs members to their mapping in env if exists # specs can be used to reconstruct nested list/dictionaries def substitute(arg_list): - arg_list, spec = tree_flatten(arg_list) + arg_list, spec = pytree.tree_flatten(arg_list) for i in range(len(arg_list)): v = arg_list[i] if isinstance(v, torch.fx.node.Node) and v in env: diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index b420daca5ac34..fb106f6d47b84 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -14,7 +14,7 @@ import torch import torch.fx as fx import torch.nn as nn -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch import SymInt from torch._decomp import get_decompositions from torch.fx.experimental.symbolic_shapes import bind_symbols @@ -160,8 +160,8 @@ def check(nv, rv, desc): r = super().run_node(n) if "val" in n.meta: - n_vals, n_spec = pytree.tree_flatten(n.meta["val"]) - r_vals, r_spec = pytree.tree_flatten(r) + n_vals = pytree.tree_leaves(n.meta["val"]) + r_vals = pytree.tree_leaves(r) # TODO: There is some sort of problem where we record that an # operator returned a tuple/list, and then later it turns out the # real version of the operator returned a list/tuple. Need to diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index d389c7fda7894..48ea4e5622720 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -32,12 +32,15 @@ from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental import const_fold from torch.fx.experimental.proxy_tensor import make_fx -from torch.utils import _pytree as pytree -from torch.utils._pytree import ( +from torch.utils.pytree import ( + tree_any, tree_flatten, + tree_iter, + tree_leaves, tree_map, tree_map_, tree_map_only, + tree_structure, tree_unflatten, treespec_pprint, ) @@ -64,8 +67,8 @@ def enable_inplace_requires_grad(enabled): def _vjp_treespec_compare(primals_out, cotangents): # Revert this once #116264 gets fixed - _, primals_out_spec = tree_flatten(primals_out) - _, cotangents_spec = tree_flatten(cotangents) + primals_out_spec = tree_structure(primals_out) + cotangents_spec = tree_structure(cotangents) # Dynamo fails to trace operator.ne below. To bypass this limitation, this # function is not inlined. if primals_out_spec != cotangents_spec: @@ -79,8 +82,8 @@ def _vjp_treespec_compare(primals_out, cotangents): def _jvp_treespec_compare(primals, tangents): # Revert this once #116264 gets fixed - _, primals_spec = tree_flatten(primals) - _, tangents_spec = tree_flatten(tangents) + primals_spec = tree_structure(primals) + tangents_spec = tree_structure(tangents) if primals_spec != tangents_spec: raise RuntimeError( f"{jvp_str}: Expected primals and tangents to have the same python " @@ -92,8 +95,8 @@ def _jvp_treespec_compare(primals, tangents): def _linearize_treespec_compare(primals, tangents): # Revert this once #116264 gets fixed - _, primals_argspec = tree_flatten(primals) - _, tangent_argspec = tree_flatten(tangents) + primals_argspec = tree_structure(primals) + tangent_argspec = tree_structure(tangents) if tangent_argspec != primals_argspec: raise RuntimeError( f"Expected the tangents {tangent_argspec} to have " @@ -139,8 +142,7 @@ def _is_differentiable(maybe_tensor): def _any_differentiable(tensor_or_tuple_of_tensors): - flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) - return any(tuple(map(_is_differentiable, flat_args))) + return tree_any(_is_differentiable, tensor_or_tuple_of_tensors) def _wrap_tensor_for_grad(maybe_tensor, level): @@ -407,7 +409,7 @@ def _vjp_with_argnums( primals_out, aux = primals_out aux = _undo_create_differentiable(aux, level) - flat_primals_out, primals_out_spec = tree_flatten(primals_out) + flat_primals_out = tree_leaves(primals_out) assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)") flat_diff_primals, primals_spec = tree_flatten(diff_primals) results = _undo_create_differentiable(primals_out, level) @@ -425,7 +427,7 @@ def _vjp_with_argnums( def wrapper(cotangents, retain_graph=True, create_graph=None): if create_graph is None: create_graph = torch.is_grad_enabled() - flat_cotangents, cotangents_spec = tree_flatten(cotangents) + flat_cotangents = tree_leaves(cotangents) _vjp_treespec_compare(primals_out, cotangents) result = _autograd_grad( flat_primals_out, @@ -450,8 +452,7 @@ def _safe_zero_index(x): # jacrev and jacfwd don't support complex functions # Helper function to throw appropriate error. def error_if_complex(func_name, args, is_input): - flat_args = pytree.tree_leaves(args) - for idx, arg in enumerate(flat_args): + for idx, arg in enumerate(tree_iter(args)): if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: input_or_output = "inputs" if is_input else "outputs" err_msg = ( @@ -645,7 +646,7 @@ def compute_jacobian_stacked(): else: # chunk_size is None or chunk_size != 1 chunked_result = vmap(vjp_fn)(basis) - flat_results = pytree.tree_leaves(chunked_result) + flat_results = tree_leaves(chunked_result) if chunk_size == 1: flat_results = tree_map( @@ -702,7 +703,7 @@ def compute_jacobian_preallocate_and_copy(): else: # chunk_size is None or chunk_size != 1 chunked_result = vmap(vjp_fn)(basis) - flat_results = pytree.tree_leaves(chunked_result) + flat_results = tree_leaves(chunked_result) # Short-circuit if we have a single chunk. if chunk_size is None or chunk_size >= out_vec_size: @@ -1115,7 +1116,7 @@ def _jvp_with_argnums( ) diff_args = primals if argnums is None else _slice_argnums(primals, argnums) flat_primals, primals_spec = tree_flatten(diff_args) - flat_tangents, tangents_spec = tree_flatten(tangents) + flat_tangents = tree_leaves(tangents) _jvp_treespec_compare(diff_args, tangents) assert_non_empty_list_of_tensors(flat_primals, jvp_str, "primals") assert_non_empty_list_of_tensors(flat_tangents, jvp_str, "tangents") @@ -1312,9 +1313,7 @@ def push_jvp(basis): results, aux = results # aux is in the standard basis format, e.g. NxN matrix # We need to fetch the first element as original `func` output - flat_aux, aux_spec = tree_flatten(aux) - flat_aux = [value[0] for value in flat_aux] - aux = tree_unflatten(flat_aux, aux_spec) + aux = tree_map(lambda value: value[0], aux) jac_outs, spec = tree_flatten(results) # Most probably below output check can never raise an error @@ -1680,16 +1679,15 @@ def wrapped(*args, **kwargs): func_args = _wrap_all_tensors_to_functional(args, func_level) func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level) - flattened_unwrapped_args = pytree.arg_tree_leaves(*args) - flattened_wrapped_args = pytree.arg_tree_leaves(*func_args) - flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs) - flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs) + flattened_unwrapped_args = tree_leaves(args) + flattened_wrapped_args = tree_leaves(func_args) + flattened_unwrapped_kwargs = tree_leaves(kwargs) + flattened_wrapped_kwargs = tree_leaves(func_kwargs) func_outputs = func(*func_args, **func_kwargs) outputs = _unwrap_all_tensors_from_functional( func_outputs, reapply_views=reapply_views ) - flat_outputs, func_out_spec = tree_flatten(outputs) for a in flattened_wrapped_args + flattened_wrapped_kwargs: if isinstance(a, torch.Tensor): @@ -1765,7 +1763,7 @@ def linearize(func: Callable, *primals) -> Tuple[Any, Callable]: # make_fx such that it also returns the output. output = func(*primals) - _, output_spec = tree_flatten(output) + output_spec = tree_structure(output) flat_primals, primals_argspec = tree_flatten(primals) @@ -1826,7 +1824,7 @@ def forward_ad_checks(flat_tangents): # It takes care of checking the argspec of tangents, # calling the folded fx graph and unflattening fx graph output def jvp_fn(*tangents): - flat_tangents, tangent_argspec = tree_flatten(tangents) + flat_tangents = tree_leaves(tangents) _linearize_treespec_compare(primals, tangents) forward_ad_checks(flat_tangents) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index d5ba05eca0390..0cf505be94c3f 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -14,7 +14,7 @@ import torch import torch._inductor.inductor_prims import torch.fx as fx -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.sym_node import magic_methods, method_to_operator @@ -189,7 +189,7 @@ def _extract_graph_with_inputs_outputs( elif node.op == "placeholder": env[node] = InvalidNode # type: ignore[assignment] elif node.op == "call_function": - all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) + all_args = pytree.tree_leaves((node.args, node.kwargs)) all_args = [ isinstance(env[x], InvalidNodeBase) for x in all_args @@ -267,8 +267,8 @@ def _must_be_in_backward(node: fx.Node) -> bool: def _extract_fwd_bwd_outputs( joint_module: fx.GraphModule, *, num_fwd_outputs ) -> Tuple[List[fx.Node], List[fx.Node]]: - outputs = pytree.arg_tree_leaves( - *(node.args for node in joint_module.graph.find_nodes(op="output")) + outputs = pytree.tree_leaves( + [node.args for node in joint_module.graph.find_nodes(op="output")] ) fwd_outputs = outputs[:num_fwd_outputs] bwd_outputs = outputs[num_fwd_outputs:] diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py index b2dfaa116f729..d0cb80698025e 100644 --- a/torch/_functorch/pyfunctorch.py +++ b/torch/_functorch/pyfunctorch.py @@ -4,7 +4,7 @@ from typing import Any, List, Tuple import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch._C._functorch import ( CFunctionalizeInterpreterPtr, CGradInterpreterPtr, @@ -163,7 +163,7 @@ def __init__(self, cdata: CInterpreter): def lift(self, args, kwargs): args, kwargs = pytree.tree_map_only( - torch.Tensor, self._cptr.lift, [args, kwargs] + torch.Tensor, self._cptr.lift, (args, kwargs) ) return args, kwargs @@ -197,7 +197,7 @@ def __init__(self, cdata: CInterpreter): def lift(self, args, kwargs): args, kwargs = pytree.tree_map_only( - torch.Tensor, self._cptr.lift, [args, kwargs] + torch.Tensor, self._cptr.lift, (args, kwargs) ) return args, kwargs diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py index fcb96ad06d24e..293aff9cc7f9d 100644 --- a/torch/_functorch/vmap.py +++ b/torch/_functorch/vmap.py @@ -12,7 +12,7 @@ import os import threading from functools import partial -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch import Tensor @@ -23,15 +23,18 @@ _vmap_increment_nesting, is_batchedtensor, ) -from torch.utils._pytree import ( +from torch.utils.pytree import ( _broadcast_to_and_flatten, tree_flatten, tree_map_, tree_unflatten, - TreeSpec, ) +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + in_dims_t = Union[int, Tuple] out_dims_t = Union[int, Tuple[int, ...]] @@ -91,7 +94,7 @@ def _as_tuple( def _process_batched_inputs( in_dims: in_dims_t, args: Tuple, func: Callable -) -> Tuple[int, List[Any], List[Any], TreeSpec]: +) -> Tuple[int, List[Any], List[Any], "PyTreeSpec"]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index d57d68d5e473f..e6e2fc05d3e3c 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import torch -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization from torch._functorch.aot_autograd import AOTConfig, create_joint diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 139e9a160cbe2..8aca379c2ec6f 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -6,7 +6,7 @@ import torch import torch.fx.traceback as fx_traceback -import torch.utils._pytree as pytree +import torch.utils.pytree as pytree from torch._ops import OperatorBase from torch.fx.experimental.proxy_tensor import make_fx from torch.multiprocessing.reductions import StorageWeakRef @@ -35,7 +35,7 @@ def autograd_not_implemented_inner( """ with torch._C._AutoDispatchBelowAutograd(): result = operator(*args, **kwargs) - flat_operands = pytree.arg_tree_leaves(*args) + flat_operands = pytree.tree_leaves(args) if torch.is_grad_enabled() and any( f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) ): @@ -185,7 +185,7 @@ def check_alias(out): return out_storage in input_storages return False - if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): + if pytree.tree_any(check_alias, node.args): return True for _, module in gm.named_children(): diff --git a/torch/return_types.py b/torch/return_types.py index d456742be4b88..9861e202a4a70 100644 --- a/torch/return_types.py +++ b/torch/return_types.py @@ -1,7 +1,6 @@ import inspect import torch -from torch.utils._pytree import register_pytree_node, SequenceKey __all__ = ["pytree_register_structseq", "all_return_types"] @@ -13,6 +12,8 @@ def pytree_register_structseq(cls): + from torch.utils._pytree import register_pytree_node, SequenceKey + def structseq_flatten(structseq): return list(structseq), None diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index eef5f275d8004..c9f5cbccdcddf 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -30,7 +30,10 @@ from typing_extensions import deprecated, Self import optree -from optree import PyTreeSpec as TreeSpec # direct import for type annotations +from optree import ( # noqa: F401 # direct import for type annotations + PyTreeSpec as PyTreeSpec, + PyTreeSpec as TreeSpec, +) import torch.utils._pytree as python_pytree from torch.utils._pytree import KeyEntry as KeyEntry diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 2bbad00c64e67..ac6e8e6b872d0 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -717,6 +717,12 @@ def __repr__(self, indent: int = 0) -> str: def is_leaf(self) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 + def children(self) -> List["TreeSpec"]: + return self.children_specs.copy() + + def child(self, index: int) -> "TreeSpec": + return self.children_specs[index] + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: if self.is_leaf(): subtrees.append(tree) diff --git a/torch/utils/pytree.py b/torch/utils/pytree.py index ae21f196f676a..ded410e108c7a 100644 --- a/torch/utils/pytree.py +++ b/torch/utils/pytree.py @@ -15,6 +15,7 @@ import os as _os from dataclasses import dataclass as _dataclass from typing import ( + Any as _Any, Callable as _Callable, Literal as _Literal, TYPE_CHECKING as _TYPE_CHECKING, @@ -29,7 +30,9 @@ from types import ModuleType import torch.utils._cxx_pytree as cxx - from torch.utils._cxx_pytree import ( + from torch.utils._cxx_pytree import ( # noqa: TCH004; noqa: F401 + _broadcast_to_and_flatten as _broadcast_to_and_flatten, + PyTreeSpec as PyTreeSpec, register_pytree_node as register_pytree_node, tree_all as tree_all, tree_all_only as tree_all_only, @@ -49,6 +52,7 @@ __all__ = [ + "PyTreeSpec", "register_pytree_node", "tree_flatten", "tree_unflatten", @@ -130,5 +134,22 @@ def exported(*args: _P.args, **kwargs: _P.kwargs) -> _T: treespec_pprint = _reexport(implementation.module.treespec_pprint) +# Used in vmap +_broadcast_to_and_flatten = _reexport(implementation.module._broadcast_to_and_flatten) + + del _reexport del PyTreeImplementation + + +# Use the __getattr__ function allowing us to change the underlying `implementation` at runtime. +def __getattr__(name: str) -> _Any: + name = {"PyTreeSpec": "TreeSpec"}.get(name, name) + try: + return getattr(implementation.module, name) + except AttributeError as ex: + raise AttributeError( + f"module {__name__!r} has no attribute {name!r}: " + f"no attribute {name!r} in " + f"{implementation.name} implementation: {implementation.module.__name__!r}" + ) from ex From b4ab24bb49cbbf0ccf01bbca517afde6ff541971 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Mon, 14 Oct 2024 17:59:23 +0800 Subject: [PATCH 2/4] Update [ghstack-poisoned] --- torch/export/_trace.py | 2 +- torch/utils/_cxx_pytree.py | 5 ----- torch/utils/_pytree.py | 5 ----- torch/utils/pytree.py | 2 +- 4 files changed, 2 insertions(+), 12 deletions(-) diff --git a/torch/export/_trace.py b/torch/export/_trace.py index ded8d0376bb09..3d72980efafce 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1498,7 +1498,7 @@ def wrapped_fn(*args): ), buffers_to_mutate={}, user_inputs_to_mutate={}, - in_spec=in_spec, + in_spec=in_spec, # type: ignore[arg-type] out_spec=out_spec, # type: ignore[arg-type] backward_signature=None, input_tokens=[], diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index c9f5cbccdcddf..08db973a4d361 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -309,11 +309,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: The reconstructed pytree, containing the ``leaves`` placed in the structure described by ``treespec``. """ - if not isinstance(treespec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(treespec)}." - ) return optree.tree_unflatten(treespec, leaves) # type: ignore[arg-type] diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index ac6e8e6b872d0..41fe4dd5009d8 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -885,11 +885,6 @@ def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: """Given a list of values and a TreeSpec, builds a pytree. This is the inverse operation of `tree_flatten`. """ - if not isinstance(treespec, TreeSpec): - raise TypeError( - f"tree_unflatten(leaves, treespec): Expected `treespec` to be " - f"instance of TreeSpec but got item of type {type(treespec)}.", - ) return treespec.unflatten(leaves) diff --git a/torch/utils/pytree.py b/torch/utils/pytree.py index ded410e108c7a..ea9cdb7617e1c 100644 --- a/torch/utils/pytree.py +++ b/torch/utils/pytree.py @@ -30,7 +30,7 @@ from types import ModuleType import torch.utils._cxx_pytree as cxx - from torch.utils._cxx_pytree import ( # noqa: TCH004; noqa: F401 + from torch.utils._cxx_pytree import ( # noqa: TCH004 _broadcast_to_and_flatten as _broadcast_to_and_flatten, PyTreeSpec as PyTreeSpec, register_pytree_node as register_pytree_node, From bccc9d49c0449deb7f79dacded09eef4e40d0923 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Sat, 7 Dec 2024 23:33:37 +0800 Subject: [PATCH 3/4] Update [ghstack-poisoned] --- torch/_functorch/eager_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index 7b04edfc6f551..18c52b1b54392 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -367,7 +367,7 @@ def _vjp_with_argnums( primals_out, aux = primals_out aux = _undo_create_differentiable(aux, level) - flat_primals_out = tree_leaves(primals_out) + flat_primals_out, primals_out_spec = tree_flatten(primals_out) assert_non_empty_tensor_output(flat_primals_out, "vjp(f, *primals)") flat_diff_primals, primals_spec = tree_flatten(diff_primals) results = _undo_create_differentiable(primals_out, level) From 4d29b05f4ad1df01aad0fbc9b8cd35d2ae0e7bae Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 8 Jan 2025 20:15:56 +0800 Subject: [PATCH 4/4] Update [ghstack-poisoned] --- torch/_functorch/_aot_autograd/traced_function_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_functorch/_aot_autograd/traced_function_transforms.py b/torch/_functorch/_aot_autograd/traced_function_transforms.py index fcc5767e77980..8c5b392275585 100644 --- a/torch/_functorch/_aot_autograd/traced_function_transforms.py +++ b/torch/_functorch/_aot_autograd/traced_function_transforms.py @@ -19,7 +19,7 @@ import torch import torch.fx.traceback as fx_traceback -import torch.utils.pytree as pytree +import torch.utils._pytree as pytree from torch import Tensor from torch._decomp.decompositions_for_rng import PhiloxStateTracker from torch._guards import detect_fake_mode