diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 4b245b7c5b53a..508c9782c23ef 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -12,7 +12,7 @@ from functorch_additional_op_db import additional_op_db import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from functorch import vmap from torch.testing._internal.autograd_function_db import autograd_function_db from torch.testing._internal.common_device_type import toleranceOverride diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 90f2dbec8e29d..bc1386919b2b8 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -20,7 +20,7 @@ import torch import torch._dynamo as torchdynamo import torch.nn as nn -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from functorch import grad, jacrev, make_fx, vjp, vmap from functorch.compile import ( aot_function, diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 29fb9f04e689b..b395b5cc34cbc 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4,7 +4,7 @@ import unittest import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from functorch.experimental import control_flow from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException from torch._dynamo.testing import normalize_gm @@ -1741,8 +1741,6 @@ def f(x, y): self.assertEqual(expected_grads, grads) def test_map_autograd_nested_list(self): - import torch.utils.pytree.python as pytree - def f(x, y): a, b = x c, d = a @@ -5341,8 +5339,6 @@ def g(xs, y): self.check_map_count(gm, 2) def test_tracing_map_autograd_symbolic_list(self): - import torch.utils.pytree.python as pytree - def f(x, y): return [x[0].cos() + y.sin(), x[1].sin() * y.cos()] diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index cc7a48e747264..2c534c06cd17a 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -78,7 +78,7 @@ TestCase, xfailIfTorchDynamo, ) -from torch.utils.pytree.python import tree_flatten, tree_map, tree_unflatten +from torch.utils.pytree import tree_flatten, tree_map, tree_unflatten USE_TORCHVISION = False diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 32225bb5dfa1f..32ac3e6a90e53 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -32,7 +32,6 @@ import torch import torch.autograd.forward_ad as fwAD -import torch.utils.pytree.python as pytree from functorch import grad, jacfwd, jacrev, vjp, vmap from torch import Tensor from torch._functorch.eager_transforms import _as_tuple, jvp @@ -59,7 +58,7 @@ xfailIfS390X, ) from torch.testing._internal.opinfo.core import SampleInput -from torch.utils.pytree.python import tree_flatten, tree_map, tree_unflatten +from torch.utils.pytree import tree_flatten, tree_leaves, tree_map, tree_unflatten aten = torch.ops.aten @@ -161,7 +160,7 @@ def normalize_op_input_output3( f, args, kwargs, sample_args, output_process_fn_grad=None ): flat_args, args_spec = tree_flatten(args) - flat_sample_args = pytree.tree_leaves(sample_args) + flat_sample_args = tree_leaves(sample_args) diff_argnums = tuple( i for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args)) @@ -299,8 +298,8 @@ def wrapped(*args): if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, tangents @@ -334,8 +333,8 @@ def wrapped(*args): if isinstance(primals_out, torch.Tensor): return (primals_out, tangents_out) else: - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) return wrapped, primals + tangents @@ -1085,7 +1084,7 @@ def test_vmapvjpvjp(self, device, dtype, op): fn, args = get_vjpfull_variant(op, sample) result = fn(*args) cotangents = tree_map(lambda x: torch.randn_like(x), result) - cotangents = pytree.tree_leaves(cotangents) + cotangents = tree_leaves(cotangents) num_args = len(args) args_and_cotangents = tuple(args) + tuple(cotangents) @@ -1095,8 +1094,8 @@ def vjp_of_vjp(*args_and_cotangents): cotangents = args_and_cotangents[num_args:] result, vjp_fn = vjp(fn, *args) result_vjps = vjp_fn(cotangents) - result = pytree.tree_leaves(result) - result_vjps = pytree.tree_leaves(result_vjps) + result = tree_leaves(result) + result_vjps = tree_leaves(result_vjps) return (*result, *result_vjps) is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs) @@ -2110,8 +2109,8 @@ def jvp_of_vjp(*args): (primals, tangents) = tree_unflatten(args, spec) primals_out, tangents_out = jvp(push_vjp, primals, tangents) - flat_primals_out = pytree.tree_leaves(primals_out) - flat_tangents_out = pytree.tree_leaves(tangents_out) + flat_primals_out = tree_leaves(primals_out) + flat_tangents_out = tree_leaves(tangents_out) return tuple(flat_primals_out + flat_tangents_out) is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs) @@ -2428,7 +2427,7 @@ def is_differentiable(inp): ) def get_flat_differentiable(tree): - flattened = pytree.tree_leaves(tree) + flattened = tree_leaves(tree) return tuple(i for i in flattened if is_differentiable(i)) def get_differentiable_linked(list1, list2): @@ -2441,7 +2440,7 @@ def get_differentiable_linked(list1, list2): return zip(*paired_list) def filter_none(out): - flattened = pytree.tree_leaves(out) + flattened = tree_leaves(out) return tuple(o for o in flattened if o is not None) if not op.supports_autograd: @@ -2459,8 +2458,8 @@ def compute_grad(cotangents): out_flattened = out cotangents_flattened = cotangents if not isinstance(out_flattened, torch.Tensor): - out_flattened = pytree.tree_leaves(out) - cotangents_flattened = pytree.tree_leaves(cotangents) + out_flattened = tree_leaves(out) + cotangents_flattened = tree_leaves(cotangents) out_flattened, cotangents_flattened = get_differentiable_linked( out_flattened, cotangents_flattened ) diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 714d2e83abac1..54cd06cb4c875 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -38,7 +38,7 @@ import functorch import torch import torch.nn.functional as F -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap from functorch.experimental import chunk_vmap from torch import Tensor @@ -1342,7 +1342,7 @@ def _vmap_test( check_propagates_grad=True, ): result = vmap(op, in_dims, out_dims)(*inputs) - are_nested = [t.is_nested for t in pytree.tree_leaves(result)] + are_nested = [t.is_nested for t in pytree.tree_iter(result)] reference_result = reference_vmap( op, inputs, in_dims, out_dims, return_nt=any(are_nested) ) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index b7fb63b2d44c1..12f45ea360c3c 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -15,7 +15,7 @@ from typing import Callable, Optional import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._guards import detect_fake_mode from torch._logging import getArtifactLogger diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index b91dba85896b2..376aae14da4d3 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -9,7 +9,7 @@ import torch import torch.utils.dlpack -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 1b418f4a4c3c2..9222ed574deea 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -12,10 +12,9 @@ import contextlib import itertools -from typing import Any, Optional, Union +from typing import Any, Optional, TYPE_CHECKING, Union import torch -import torch.utils.pytree.python as pytree from torch import Tensor from torch._C._dynamo.guards import compute_overlapping_tensors from torch._functorch._aot_autograd.schemas import PlainTensorMeta @@ -35,6 +34,10 @@ from .utils import strict_zip +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + zip = strict_zip @@ -345,8 +348,8 @@ def _graph_output_names(gm): def create_graph_signature( fx_g: torch.fx.GraphModule, fw_metadata: ViewAndMutationMeta, - in_spec: pytree.TreeSpec, - out_spec: pytree.TreeSpec, + in_spec: "PyTreeSpec", + out_spec: "PyTreeSpec", *, user_args_flat: list[Tensor], params_and_buffers_flat: list[Tensor], diff --git a/torch/_functorch/_aot_autograd/schemas.py b/torch/_functorch/_aot_autograd/schemas.py index 915325fbf3824..da45af2441d0c 100644 --- a/torch/_functorch/_aot_autograd/schemas.py +++ b/torch/_functorch/_aot_autograd/schemas.py @@ -10,10 +10,10 @@ from collections.abc import Iterable from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, NewType, Optional, Union +from typing import Any, Callable, NewType, Optional, TYPE_CHECKING, Union import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._guards import Source from torch._ops import OpOverload from torch._subclasses import FakeTensor @@ -28,6 +28,10 @@ from .utils import strict_zip +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + zip = strict_zip @@ -749,8 +753,8 @@ class GraphSignature: buffers_to_mutate: dict[GraphOutputName, FQN] user_inputs_to_mutate: dict[GraphOutputName, GraphInputName] - in_spec: pytree.TreeSpec - out_spec: pytree.TreeSpec + in_spec: "PyTreeSpec" + out_spec: "PyTreeSpec" backward_signature: Optional[BackwardSignature] @@ -761,8 +765,8 @@ class GraphSignature: def from_tracing_metadata( cls, *, - in_spec: pytree.TreeSpec, - out_spec: pytree.TreeSpec, + in_spec: "PyTreeSpec", + out_spec: "PyTreeSpec", graph_input_names: list[str], graph_output_names: list[str], view_mutation_metadata: ViewAndMutationMeta, diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index ff7a8f31d975c..4f54325721698 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Optional, TypeVar, Union import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch import SymInt, Tensor from torch._subclasses.fake_tensor import get_plain_tensors from torch.utils._python_dispatch import is_traceable_wrapper_subclass @@ -31,7 +31,7 @@ def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool: - args_flattened = pytree.arg_tree_leaves(*args) + args_flattened = pytree.tree_leaves(args) any_subclass_args = any( is_traceable_wrapper_subclass(x) for x in args_flattened diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index ee142f134f39c..dd971fb2fac0e 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -11,7 +11,7 @@ from typing import Any, Callable, Optional, Union import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._library.fake_class_registry import FakeScriptObject from torch._logging import getArtifactLogger from torch._subclasses.fake_tensor import FakeTensor @@ -138,19 +138,19 @@ def call_func_at_runtime_with_args( # Inspired by autodidax (thanks!) class PytreeThunk: - spec: Optional[pytree.TreeSpec] = None + spec: Optional["pytree.PyTreeSpec"] = None # These are some kinda dumb microoptimizations that save about 3-4 us of overhead. is_simple: Optional[ bool ] = None # if the output spec is a tuple/list, we won't bother unflattening it. is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec - def set(self, spec: pytree.TreeSpec) -> None: + def set(self, spec: "pytree.PyTreeSpec") -> None: assert self.spec is None or self.spec == spec assert spec is not None - self.spec: pytree.TreeSpec = spec + self.spec = spec if self.spec.type in {tuple, list} and all( - child.is_leaf() for child in spec.children_specs + child.is_leaf() for child in spec.children() ): self.is_simple = True if self.spec.is_leaf(): @@ -172,7 +172,7 @@ def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThu if kwargs is None: kwargs = {} # Save the args_spec for flat_tensor_args to unflatten while tracing - _, tensor_args_spec = pytree.tree_flatten((args, kwargs)) + tensor_args_spec = pytree.tree_structure((args, kwargs)) out_spec = PytreeThunk() def flat_fn(*flat_args): diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 5f0978bbf6796..56130a531b63d 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -11,7 +11,7 @@ import torch._dynamo.logging import torch.nn as nn import torch.utils.dlpack -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch import Tensor from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions from torch._dispatch.python import enable_python_dispatcher @@ -917,7 +917,7 @@ def aot_function( def returned_function(*args, **kwargs): nonlocal cached_res # Now flatten the tensor args - flat_args = pytree.arg_tree_leaves(*args, **kwargs) + flat_args = pytree.tree_leaves((args, kwargs)) # Compile the function and save it in the cache if cached_res is None: @@ -1397,7 +1397,7 @@ def flattened_joint(*args): fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args) - user_args_flat = pytree.arg_tree_leaves(*args, **kwargs) + user_args_flat = pytree.tree_leaves((args, kwargs)) return fx_g, create_graph_signature( fx_g, metadata, @@ -1453,7 +1453,7 @@ def aot_export_joint_simple( args, decompositions=decompositions, ) - in_spec, _kw_in_spec = in_spec.children_specs + in_spec, _kw_in_spec = in_spec.children() # At this point, we can just directly return the (joint or inference graph) that we traced. # First though: a bunch of assertions to make sure that our graph doesn't require # any calling convention changes compared to the original function. @@ -1480,7 +1480,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}" ) - if not all(child.is_leaf() for child in in_spec.children_specs): + if not all(child.is_leaf() for child in in_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}" ) @@ -1488,7 +1488,7 @@ def aot_export_joint_simple( raise RuntimeError( f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}" ) - if not all(child.is_leaf() for child in out_spec.children_specs): + if not all(child.is_leaf() for child in out_spec.children()): raise RuntimeError( f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}" ) diff --git a/torch/_functorch/autograd_function.py b/torch/_functorch/autograd_function.py index 72dd3bf87950b..de45209760ec0 100644 --- a/torch/_functorch/autograd_function.py +++ b/torch/_functorch/autograd_function.py @@ -2,7 +2,7 @@ from typing import NamedTuple import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._C._functorch import ( _unwrap_for_grad, _wrap_for_grad, @@ -167,8 +167,8 @@ def jvp(ctx, *tangents): def wrap_outputs_maintaining_identity( outputs, unwrapped_inputs, orig_inputs, wrap_fn, out_dims=NO_OUT_DIMS ): - flat_unwrapped_inputs = pytree.arg_tree_leaves(*unwrapped_inputs) - flat_orig_inputs = pytree.arg_tree_leaves(*orig_inputs) + flat_unwrapped_inputs = pytree.tree_leaves(unwrapped_inputs) + flat_orig_inputs = pytree.tree_leaves(orig_inputs) unwrapped_input_to_orig_input = { id(unwrapped): orig @@ -185,12 +185,13 @@ def wrap_outputs_maintaining_identity( # _broadcast_to_and_flatten returns None if it is unable to broadcast. # TODO: update following link from master to stable once that's out if flat_out_dims is None: + out_dims_spec = pytree.tree_structure(out_dims) raise RuntimeError( f"The autograd.Function's vmap staticmethod returned an " f"incompatible (output, out_dims) tuple. " f"Expected out_dims={out_dims} " f"to be compatible with the structure of `output`. " - f"out_dims has structure {pytree.tree_flatten(out_dims)[1]} " + f"out_dims has structure {pytree.treespec_pprint(out_dims_spec)} " f"but output has structure {spec}. " f"For more details, please see " f"https://pytorch.org/docs/main/notes/extending.func.html" @@ -275,10 +276,7 @@ def validate_vmap_returns_tuple_of_two_elements(result): @custom_function_call.py_impl(TransformType.Vmap) def custom_function_call_vmap(interpreter, autograd_function, *operands, **kwargs): - if any( - isinstance(val, torch.Tensor) - for val in torch.utils.pytree.python.tree_flatten(kwargs)[0] - ): + if pytree.tree_any(torch.is_tensor, kwargs): raise NotImplementedError( f"Run vmap on autograd.Function with kwarg-only Tensor args. " f"Please do not pass kwarg-only Tensors to autograd.Function. " @@ -518,7 +516,7 @@ def backward_no_context(inputs): # the corresponding in_dims with None. def get_tangents_in_dims(input_dims, tangents): flat_in_dims, spec = pytree.tree_flatten(input_dims) - flat_tangents = pytree.arg_tree_leaves(*tangents) + flat_tangents = pytree.tree_leaves(tangents) result = [ None if tangent is None else in_dim for in_dim, tangent in zip(flat_in_dims, flat_tangents) diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py index 65db0ed7dd426..d7a6147a893b8 100644 --- a/torch/_functorch/compile_utils.py +++ b/torch/_functorch/compile_utils.py @@ -5,9 +5,8 @@ import torch import torch.fx as fx -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch.multiprocessing.reductions import StorageWeakRef -from torch.utils.pytree.python import tree_flatten aten = torch.ops.aten @@ -102,7 +101,7 @@ def checkable_node(node: fx.Node) -> bool: # substitute args and kwargs members to their mapping in env if exists # specs can be used to reconstruct nested list/dictionaries def substitute(arg_list): - arg_list, spec = tree_flatten(arg_list) + arg_list, spec = pytree.tree_flatten(arg_list) for i in range(len(arg_list)): v = arg_list[i] if isinstance(v, torch.fx.node.Node) and v in env: diff --git a/torch/_functorch/compilers.py b/torch/_functorch/compilers.py index de8c42b6713cd..8d8a262b8f64b 100644 --- a/torch/_functorch/compilers.py +++ b/torch/_functorch/compilers.py @@ -14,7 +14,7 @@ import torch import torch.fx as fx import torch.nn as nn -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch import SymInt from torch._decomp import get_decompositions from torch.fx.experimental.symbolic_shapes import bind_symbols @@ -160,8 +160,8 @@ def check(nv, rv, desc): r = super().run_node(n) if "val" in n.meta: - n_vals, _n_spec = pytree.tree_flatten(n.meta["val"]) - r_vals, _r_spec = pytree.tree_flatten(r) + n_vals = pytree.tree_leaves(n.meta["val"]) + r_vals = pytree.tree_leaves(r) # TODO: There is some sort of problem where we record that an # operator returned a tuple/list, and then later it turns out the # real version of the operator returned a list/tuple. Need to diff --git a/torch/_functorch/eager_transforms.py b/torch/_functorch/eager_transforms.py index 77cfc0486f260..df23677af38c5 100644 --- a/torch/_functorch/eager_transforms.py +++ b/torch/_functorch/eager_transforms.py @@ -12,7 +12,6 @@ import torch import torch.autograd.forward_ad as fwAD -import torch.utils.pytree.python as pytree from torch._C._functorch import ( _assert_wrapped_functional, _func_decrement_nesting, @@ -35,11 +34,15 @@ from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental import const_fold from torch.fx.experimental.proxy_tensor import make_fx -from torch.utils.pytree.python import ( +from torch.utils.pytree import ( + tree_any, tree_flatten, + tree_iter, + tree_leaves, tree_map, tree_map_, tree_map_only, + tree_structure, tree_unflatten, treespec_pprint, ) @@ -100,8 +103,7 @@ def _is_differentiable(maybe_tensor): def _any_differentiable(tensor_or_tuple_of_tensors): - flat_args, _ = tree_unflatten(tensor_or_tuple_of_tensors) - return any(tuple(map(_is_differentiable, flat_args))) + return tree_any(_is_differentiable, tensor_or_tuple_of_tensors) def _wrap_tensor_for_grad(maybe_tensor, level): @@ -414,8 +416,7 @@ def _safe_zero_index(x): # jacrev and jacfwd don't support complex functions # Helper function to throw appropriate error. def error_if_complex(func_name, args, is_input): - flat_args = pytree.tree_leaves(args) - for idx, arg in enumerate(flat_args): + for idx, arg in enumerate(tree_iter(args)): if isinstance(arg, torch.Tensor) and arg.dtype.is_complex: input_or_output = "inputs" if is_input else "outputs" err_msg = ( @@ -610,7 +611,7 @@ def compute_jacobian_stacked(): else: # chunk_size is None or chunk_size != 1 chunked_result = vmap(vjp_fn)(basis) - flat_results = pytree.tree_leaves(chunked_result) + flat_results = tree_leaves(chunked_result) if chunk_size == 1: flat_results = tree_map( @@ -667,7 +668,7 @@ def compute_jacobian_preallocate_and_copy(): else: # chunk_size is None or chunk_size != 1 chunked_result = vmap(vjp_fn)(basis) - flat_results = pytree.tree_leaves(chunked_result) + flat_results = tree_leaves(chunked_result) # Short-circuit if we have a single chunk. if chunk_size is None or chunk_size >= out_vec_size: @@ -1275,9 +1276,7 @@ def push_jvp(basis): results, aux = results # aux is in the standard basis format, e.g. NxN matrix # We need to fetch the first element as original `func` output - flat_aux, aux_spec = tree_flatten(aux) - flat_aux = [value[0] for value in flat_aux] - aux = tree_unflatten(flat_aux, aux_spec) + aux = tree_map(lambda value: value[0], aux) jac_outs, spec = tree_flatten(results) # Most probably below output check can never raise an error @@ -1637,10 +1636,10 @@ def wrapped(*args, **kwargs): func_args = _wrap_all_tensors_to_functional(args, func_level) func_kwargs = _wrap_all_tensors_to_functional(kwargs, func_level) - flattened_unwrapped_args = pytree.arg_tree_leaves(*args) - flattened_wrapped_args = pytree.arg_tree_leaves(*func_args) - flattened_unwrapped_kwargs = pytree.arg_tree_leaves(**kwargs) - flattened_wrapped_kwargs = pytree.arg_tree_leaves(**func_kwargs) + flattened_unwrapped_args = tree_leaves(args) + flattened_wrapped_args = tree_leaves(func_args) + flattened_unwrapped_kwargs = tree_leaves(kwargs) + flattened_wrapped_kwargs = tree_leaves(func_kwargs) func_outputs = func(*func_args, **func_kwargs) outputs = _unwrap_all_tensors_from_functional( @@ -1721,7 +1720,7 @@ def linearize(func: Callable, *primals) -> tuple[Any, Callable]: # make_fx such that it also returns the output. output = func(*primals) - _, output_spec = tree_flatten(output) + output_spec = tree_structure(output) flat_primals, primals_argspec = tree_flatten(primals) diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index 2afc904763ccd..d5602c6c76261 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -14,7 +14,7 @@ import torch import torch._inductor.inductor_prims import torch.fx as fx -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types from torch.fx.experimental.sym_node import magic_methods, method_to_operator @@ -196,7 +196,7 @@ def _extract_graph_with_inputs_outputs( elif node.op == "placeholder": env[node] = InvalidNode # type: ignore[assignment] elif node.op == "call_function": - all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs) + all_args = pytree.tree_leaves((node.args, node.kwargs)) all_args = [ isinstance(env[x], InvalidNodeBase) for x in all_args @@ -274,8 +274,8 @@ def _must_be_in_backward(node: fx.Node) -> bool: def _extract_fwd_bwd_outputs( joint_module: fx.GraphModule, *, num_fwd_outputs ) -> tuple[list[fx.Node], list[fx.Node]]: - outputs = pytree.arg_tree_leaves( - *(node.args for node in joint_module.graph.find_nodes(op="output")) + outputs = pytree.tree_leaves( + [node.args for node in joint_module.graph.find_nodes(op="output")] ) fwd_outputs = outputs[:num_fwd_outputs] bwd_outputs = outputs[num_fwd_outputs:] diff --git a/torch/_functorch/pyfunctorch.py b/torch/_functorch/pyfunctorch.py index 1bae89e882ae8..2f657dce42fce 100644 --- a/torch/_functorch/pyfunctorch.py +++ b/torch/_functorch/pyfunctorch.py @@ -4,7 +4,7 @@ from typing import Any import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._C._functorch import ( CFunctionalizeInterpreterPtr, CGradInterpreterPtr, @@ -163,7 +163,7 @@ def __init__(self, cdata: CInterpreter): def lift(self, args, kwargs): args, kwargs = pytree.tree_map_only( - torch.Tensor, self._cptr.lift, [args, kwargs] + torch.Tensor, self._cptr.lift, (args, kwargs) ) return args, kwargs @@ -197,7 +197,7 @@ def __init__(self, cdata: CInterpreter): def lift(self, args, kwargs): args, kwargs = pytree.tree_map_only( - torch.Tensor, self._cptr.lift, [args, kwargs] + torch.Tensor, self._cptr.lift, (args, kwargs) ) return args, kwargs diff --git a/torch/_functorch/vmap.py b/torch/_functorch/vmap.py index 814ca1f12c77f..c4eec7fc1ced8 100644 --- a/torch/_functorch/vmap.py +++ b/torch/_functorch/vmap.py @@ -12,7 +12,7 @@ import os import threading from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import torch from torch import Tensor @@ -23,15 +23,18 @@ _vmap_increment_nesting, is_batchedtensor, ) -from torch.utils.pytree.python import ( +from torch.utils.pytree import ( _broadcast_to_and_flatten, tree_flatten, tree_map_, tree_unflatten, - TreeSpec, ) +if TYPE_CHECKING: + from torch.utils.pytree import PyTreeSpec + + in_dims_t = Union[int, tuple] out_dims_t = Union[int, tuple[int, ...]] @@ -91,7 +94,7 @@ def _as_tuple( def _process_batched_inputs( in_dims: in_dims_t, args: tuple, func: Callable -) -> tuple[int, list[Any], list[Any], TreeSpec]: +) -> tuple[int, list[Any], list[Any], "PyTreeSpec"]: if not isinstance(in_dims, int) and not isinstance(in_dims, tuple): raise ValueError( f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(): " diff --git a/torch/_higher_order_ops/map.py b/torch/_higher_order_ops/map.py index d1528997f4234..cde1511288f6e 100644 --- a/torch/_higher_order_ops/map.py +++ b/torch/_higher_order_ops/map.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import torch -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._C import DispatchKey from torch._dispatch.python import suspend_functionalization from torch._functorch.aot_autograd import AOTConfig, create_joint diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 238fea2e9e167..60ab25ccde92d 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -6,7 +6,7 @@ import torch import torch.fx.traceback as fx_traceback -import torch.utils.pytree.python as pytree +import torch.utils.pytree as pytree from torch._guards import detect_fake_mode from torch._ops import OperatorBase from torch._subclasses.fake_tensor import FakeTensor @@ -38,7 +38,7 @@ def autograd_not_implemented_inner( """ with torch._C._AutoDispatchBelowAutograd(): result = operator(*args, **kwargs) - flat_operands = pytree.arg_tree_leaves(*args) + flat_operands = pytree.tree_leaves(args) if torch.is_grad_enabled() and any( f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor) ): diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 23c0769a1289f..6290d4ccb3618 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -1632,7 +1632,7 @@ def _is_impure(node): ), buffers_to_mutate={}, user_inputs_to_mutate={}, - in_spec=in_spec, + in_spec=in_spec, # type: ignore[arg-type] out_spec=out_spec, # type: ignore[arg-type] backward_signature=None, input_tokens=[], diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 188fe01259e10..ce25f5e32fa31 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -862,6 +862,12 @@ def __eq__(self, other: PyTree) -> bool: def is_leaf(self) -> bool: return self.num_nodes == 1 and self.num_leaves == 1 + def children(self) -> list["TreeSpec"]: + return self.children_specs.copy() + + def child(self, index: int) -> "TreeSpec": + return self.children_specs[index] + def flatten_up_to(self, tree: PyTree) -> list[PyTree]: def helper(treespec: TreeSpec, tree: PyTree, subtrees: list[PyTree]) -> None: if treespec.is_leaf(): diff --git a/torch/utils/pytree/__init__.py b/torch/utils/pytree/__init__.py index db7fb01a8d97e..3bd3bfa21959a 100644 --- a/torch/utils/pytree/__init__.py +++ b/torch/utils/pytree/__init__.py @@ -41,6 +41,7 @@ from types import ModuleType from torch.utils._cxx_pytree import ( # noqa: TC004 + _broadcast_to_and_flatten as _broadcast_to_and_flatten, PyTreeSpec as PyTreeSpec, tree_all as tree_all, tree_all_only as tree_all_only, @@ -227,6 +228,10 @@ def exported(*args: _P.args, **kwargs: _P.kwargs) -> _R: treespec_pprint = _reexport(implementation.module.treespec_pprint) +# Used in vmap +_broadcast_to_and_flatten = _reexport(implementation.module._broadcast_to_and_flatten) + + del _reexport del PyTreeImplementation