Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
e584f19
Update
XuehaiPan Oct 14, 2024
b4ab24b
Update
XuehaiPan Oct 14, 2024
243d2cb
Update
XuehaiPan Oct 14, 2024
31689c3
Update
XuehaiPan Oct 14, 2024
2766bd3
Update
XuehaiPan Oct 14, 2024
aa59c70
Update
XuehaiPan Oct 14, 2024
160dbac
Update
XuehaiPan Oct 16, 2024
504143c
Update
XuehaiPan Oct 16, 2024
1ad3005
Update
XuehaiPan Oct 16, 2024
82c2cdd
Update
XuehaiPan Oct 16, 2024
bcd1181
Update
XuehaiPan Oct 16, 2024
703664f
Update
XuehaiPan Oct 16, 2024
27eae9c
Update
XuehaiPan Oct 16, 2024
9a1e176
Update
XuehaiPan Oct 16, 2024
bdd6018
Update
XuehaiPan Oct 16, 2024
758881d
Update
XuehaiPan Oct 16, 2024
33d14fa
Update
XuehaiPan Oct 16, 2024
c105236
Update
XuehaiPan Oct 17, 2024
c3ee9b1
Update
XuehaiPan Oct 17, 2024
441411e
Update
XuehaiPan Oct 17, 2024
d76e737
Update
XuehaiPan Oct 25, 2024
2ba8ee6
Update
XuehaiPan Oct 26, 2024
8035b99
Update
XuehaiPan Oct 29, 2024
5b6d258
Update
XuehaiPan Oct 29, 2024
5cfe929
Update
XuehaiPan Oct 29, 2024
5569bd0
Update
XuehaiPan Oct 29, 2024
ee1263b
Update
XuehaiPan Oct 29, 2024
4c4a9aa
Update
XuehaiPan Oct 29, 2024
245352b
Update
XuehaiPan Oct 30, 2024
008e423
Update
XuehaiPan Oct 30, 2024
7a5697b
Update
XuehaiPan Nov 5, 2024
49259b7
Update
XuehaiPan Nov 11, 2024
5313336
Update
XuehaiPan Nov 17, 2024
c33d8c4
Update
XuehaiPan Nov 20, 2024
1e9879f
Update
XuehaiPan Nov 20, 2024
cd6a342
Update
XuehaiPan Nov 20, 2024
ec06824
Update
XuehaiPan Nov 20, 2024
3aaa0d1
Update
XuehaiPan Nov 20, 2024
0d8c6d5
Update
XuehaiPan Nov 20, 2024
76c3aa1
Update
XuehaiPan Nov 20, 2024
434f34a
Update
XuehaiPan Nov 21, 2024
205337b
Update
XuehaiPan Nov 21, 2024
c2aedc3
Update
XuehaiPan Nov 21, 2024
c3c65b0
Update
XuehaiPan Nov 21, 2024
3eff6fd
Update
XuehaiPan Nov 22, 2024
7e59277
Update
XuehaiPan Nov 22, 2024
6d5e8ae
Update
XuehaiPan Nov 26, 2024
9e88565
Update
XuehaiPan Nov 26, 2024
1668995
Update
XuehaiPan Nov 27, 2024
065d3c3
Update
XuehaiPan Dec 2, 2024
ffd7442
Update
XuehaiPan Dec 2, 2024
04bfccf
Update
XuehaiPan Dec 7, 2024
bccc9d4
Update
XuehaiPan Dec 7, 2024
a6db8c5
Update
XuehaiPan Dec 9, 2024
4748775
Update
XuehaiPan Dec 13, 2024
009bdef
Update
XuehaiPan Dec 13, 2024
55bab5f
Update
XuehaiPan Dec 25, 2024
87e50bf
Update
XuehaiPan Dec 25, 2024
02954a7
Update
XuehaiPan Dec 25, 2024
82235e6
Update
XuehaiPan Jan 7, 2025
4d29b05
Update
XuehaiPan Jan 8, 2025
3944ed8
Update
XuehaiPan Jan 8, 2025
9e7e277
Update
XuehaiPan Jan 8, 2025
8e38bfb
Update
XuehaiPan Jan 8, 2025
3ba4a9c
Update
XuehaiPan Jan 9, 2025
56cd519
Update
XuehaiPan Jan 9, 2025
e6407da
Update
XuehaiPan Jan 9, 2025
04f392b
Update
XuehaiPan Jan 10, 2025
5ebf0f9
Update
XuehaiPan Jan 10, 2025
b024156
Update
XuehaiPan Feb 4, 2025
923a308
Update
XuehaiPan Feb 9, 2025
90c71ef
Update
XuehaiPan Feb 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/functorch/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functorch_additional_op_db import additional_op_db

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from functorch import vmap
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_device_type import toleranceOverride
Expand Down
2 changes: 1 addition & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
import torch._dynamo as torchdynamo
import torch.nn as nn
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from functorch import grad, jacrev, make_fx, vjp, vmap
from functorch.compile import (
aot_function,
Expand Down
6 changes: 1 addition & 5 deletions test/functorch/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from functorch.experimental import control_flow
from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException
from torch._dynamo.testing import normalize_gm
Expand Down Expand Up @@ -1741,8 +1741,6 @@ def f(x, y):
self.assertEqual(expected_grads, grads)

def test_map_autograd_nested_list(self):
import torch.utils.pytree.python as pytree

def f(x, y):
a, b = x
c, d = a
Expand Down Expand Up @@ -5341,8 +5339,6 @@ def g(xs, y):
self.check_map_count(gm, 2)

def test_tracing_map_autograd_symbolic_list(self):
import torch.utils.pytree.python as pytree

def f(x, y):
return [x[0].cos() + y.sin(), x[1].sin() * y.cos()]

Expand Down
2 changes: 1 addition & 1 deletion test/functorch/test_eager_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@
TestCase,
xfailIfTorchDynamo,
)
from torch.utils.pytree.python import tree_flatten, tree_map, tree_unflatten
from torch.utils.pytree import tree_flatten, tree_map, tree_unflatten


USE_TORCHVISION = False
Expand Down
31 changes: 15 additions & 16 deletions test/functorch/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@

import torch
import torch.autograd.forward_ad as fwAD
import torch.utils.pytree.python as pytree
from functorch import grad, jacfwd, jacrev, vjp, vmap
from torch import Tensor
from torch._functorch.eager_transforms import _as_tuple, jvp
Expand All @@ -59,7 +58,7 @@
xfailIfS390X,
)
from torch.testing._internal.opinfo.core import SampleInput
from torch.utils.pytree.python import tree_flatten, tree_map, tree_unflatten
from torch.utils.pytree import tree_flatten, tree_leaves, tree_map, tree_unflatten


aten = torch.ops.aten
Expand Down Expand Up @@ -161,7 +160,7 @@ def normalize_op_input_output3(
f, args, kwargs, sample_args, output_process_fn_grad=None
):
flat_args, args_spec = tree_flatten(args)
flat_sample_args = pytree.tree_leaves(sample_args)
flat_sample_args = tree_leaves(sample_args)
diff_argnums = tuple(
i
for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
Expand Down Expand Up @@ -299,8 +298,8 @@ def wrapped(*args):
if isinstance(primals_out, torch.Tensor):
return (primals_out, tangents_out)
else:
flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
flat_primals_out = tree_leaves(primals_out)
flat_tangents_out = tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)

return wrapped, tangents
Expand Down Expand Up @@ -334,8 +333,8 @@ def wrapped(*args):
if isinstance(primals_out, torch.Tensor):
return (primals_out, tangents_out)
else:
flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
flat_primals_out = tree_leaves(primals_out)
flat_tangents_out = tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)

return wrapped, primals + tangents
Expand Down Expand Up @@ -1085,7 +1084,7 @@ def test_vmapvjpvjp(self, device, dtype, op):
fn, args = get_vjpfull_variant(op, sample)
result = fn(*args)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
cotangents = pytree.tree_leaves(cotangents)
cotangents = tree_leaves(cotangents)
num_args = len(args)

args_and_cotangents = tuple(args) + tuple(cotangents)
Expand All @@ -1095,8 +1094,8 @@ def vjp_of_vjp(*args_and_cotangents):
cotangents = args_and_cotangents[num_args:]
result, vjp_fn = vjp(fn, *args)
result_vjps = vjp_fn(cotangents)
result = pytree.tree_leaves(result)
result_vjps = pytree.tree_leaves(result_vjps)
result = tree_leaves(result)
result_vjps = tree_leaves(result_vjps)
return (*result, *result_vjps)

is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
Expand Down Expand Up @@ -2110,8 +2109,8 @@ def jvp_of_vjp(*args):
(primals, tangents) = tree_unflatten(args, spec)
primals_out, tangents_out = jvp(push_vjp, primals, tangents)

flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
flat_primals_out = tree_leaves(primals_out)
flat_tangents_out = tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)

is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
Expand Down Expand Up @@ -2428,7 +2427,7 @@ def is_differentiable(inp):
)

def get_flat_differentiable(tree):
flattened = pytree.tree_leaves(tree)
flattened = tree_leaves(tree)
return tuple(i for i in flattened if is_differentiable(i))

def get_differentiable_linked(list1, list2):
Expand All @@ -2441,7 +2440,7 @@ def get_differentiable_linked(list1, list2):
return zip(*paired_list)

def filter_none(out):
flattened = pytree.tree_leaves(out)
flattened = tree_leaves(out)
return tuple(o for o in flattened if o is not None)

if not op.supports_autograd:
Expand All @@ -2459,8 +2458,8 @@ def compute_grad(cotangents):
out_flattened = out
cotangents_flattened = cotangents
if not isinstance(out_flattened, torch.Tensor):
out_flattened = pytree.tree_leaves(out)
cotangents_flattened = pytree.tree_leaves(cotangents)
out_flattened = tree_leaves(out)
cotangents_flattened = tree_leaves(cotangents)
out_flattened, cotangents_flattened = get_differentiable_linked(
out_flattened, cotangents_flattened
)
Expand Down
4 changes: 2 additions & 2 deletions test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import functorch
import torch
import torch.nn.functional as F
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from functorch import grad, grad_and_value, jacfwd, jvp, vjp, vmap
from functorch.experimental import chunk_vmap
from torch import Tensor
Expand Down Expand Up @@ -1342,7 +1342,7 @@ def _vmap_test(
check_propagates_grad=True,
):
result = vmap(op, in_dims, out_dims)(*inputs)
are_nested = [t.is_nested for t in pytree.tree_leaves(result)]
are_nested = [t.is_nested for t in pytree.tree_iter(result)]
reference_result = reference_vmap(
op, inputs, in_dims, out_dims, return_nt=any(are_nested)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Callable, Optional

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch import Tensor
from torch._guards import detect_fake_mode
from torch._logging import getArtifactLogger
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch
import torch.utils.dlpack
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch import Tensor
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code
Expand Down
11 changes: 7 additions & 4 deletions torch/_functorch/_aot_autograd/input_output_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@

import contextlib
import itertools
from typing import Any, Optional, Union
from typing import Any, Optional, TYPE_CHECKING, Union

import torch
import torch.utils.pytree.python as pytree
from torch import Tensor
from torch._C._dynamo.guards import compute_overlapping_tensors
from torch._functorch._aot_autograd.schemas import PlainTensorMeta
Expand All @@ -35,6 +34,10 @@
from .utils import strict_zip


if TYPE_CHECKING:
from torch.utils.pytree import PyTreeSpec


zip = strict_zip


Expand Down Expand Up @@ -345,8 +348,8 @@ def _graph_output_names(gm):
def create_graph_signature(
fx_g: torch.fx.GraphModule,
fw_metadata: ViewAndMutationMeta,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
in_spec: "PyTreeSpec",
out_spec: "PyTreeSpec",
*,
user_args_flat: list[Tensor],
params_and_buffers_flat: list[Tensor],
Expand Down
16 changes: 10 additions & 6 deletions torch/_functorch/_aot_autograd/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from collections.abc import Iterable
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, NewType, Optional, Union
from typing import Any, Callable, NewType, Optional, TYPE_CHECKING, Union

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch._guards import Source
from torch._ops import OpOverload
from torch._subclasses import FakeTensor
Expand All @@ -28,6 +28,10 @@
from .utils import strict_zip


if TYPE_CHECKING:
from torch.utils.pytree import PyTreeSpec


zip = strict_zip


Expand Down Expand Up @@ -749,8 +753,8 @@ class GraphSignature:
buffers_to_mutate: dict[GraphOutputName, FQN]
user_inputs_to_mutate: dict[GraphOutputName, GraphInputName]

in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
in_spec: "PyTreeSpec"
out_spec: "PyTreeSpec"

backward_signature: Optional[BackwardSignature]

Expand All @@ -761,8 +765,8 @@ class GraphSignature:
def from_tracing_metadata(
cls,
*,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
in_spec: "PyTreeSpec",
out_spec: "PyTreeSpec",
graph_input_names: list[str],
graph_output_names: list[str],
view_mutation_metadata: ViewAndMutationMeta,
Expand Down
4 changes: 2 additions & 2 deletions torch/_functorch/_aot_autograd/subclass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Callable, Optional, TypeVar, Union

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch import SymInt, Tensor
from torch._subclasses.fake_tensor import get_plain_tensors
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
Expand All @@ -31,7 +31,7 @@


def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
args_flattened = pytree.arg_tree_leaves(*args)
args_flattened = pytree.tree_leaves(args)
any_subclass_args = any(
is_traceable_wrapper_subclass(x)
for x in args_flattened
Expand Down
12 changes: 6 additions & 6 deletions torch/_functorch/_aot_autograd/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any, Callable, Optional, Union

import torch
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch._library.fake_class_registry import FakeScriptObject
from torch._logging import getArtifactLogger
from torch._subclasses.fake_tensor import FakeTensor
Expand Down Expand Up @@ -138,19 +138,19 @@ def call_func_at_runtime_with_args(

# Inspired by autodidax (thanks!)
class PytreeThunk:
spec: Optional[pytree.TreeSpec] = None
spec: Optional["pytree.PyTreeSpec"] = None
# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
is_simple: Optional[
bool
] = None # if the output spec is a tuple/list, we won't bother unflattening it.
is_really_simple: Optional[bool] = None # if the output spec is a LeafSpec

def set(self, spec: pytree.TreeSpec) -> None:
def set(self, spec: "pytree.PyTreeSpec") -> None:
assert self.spec is None or self.spec == spec
assert spec is not None
self.spec: pytree.TreeSpec = spec
self.spec = spec
if self.spec.type in {tuple, list} and all(
child.is_leaf() for child in spec.children_specs
child.is_leaf() for child in spec.children()
):
self.is_simple = True
if self.spec.is_leaf():
Expand All @@ -172,7 +172,7 @@ def create_tree_flattened_fn(fn, args, kwargs=None) -> tuple[Callable, PytreeThu
if kwargs is None:
kwargs = {}
# Save the args_spec for flat_tensor_args to unflatten while tracing
_, tensor_args_spec = pytree.tree_flatten((args, kwargs))
tensor_args_spec = pytree.tree_structure((args, kwargs))
out_spec = PytreeThunk()

def flat_fn(*flat_args):
Expand Down
12 changes: 6 additions & 6 deletions torch/_functorch/aot_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch._dynamo.logging
import torch.nn as nn
import torch.utils.dlpack
import torch.utils.pytree.python as pytree
import torch.utils.pytree as pytree
from torch import Tensor
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from torch._dispatch.python import enable_python_dispatcher
Expand Down Expand Up @@ -917,7 +917,7 @@ def aot_function(
def returned_function(*args, **kwargs):
nonlocal cached_res
# Now flatten the tensor args
flat_args = pytree.arg_tree_leaves(*args, **kwargs)
flat_args = pytree.tree_leaves((args, kwargs))

# Compile the function and save it in the cache
if cached_res is None:
Expand Down Expand Up @@ -1397,7 +1397,7 @@ def flattened_joint(*args):

fx_g = make_fx(flattened_joint, record_module_stack=True)(*full_args)

user_args_flat = pytree.arg_tree_leaves(*args, **kwargs)
user_args_flat = pytree.tree_leaves((args, kwargs))
return fx_g, create_graph_signature(
fx_g,
metadata,
Expand Down Expand Up @@ -1453,7 +1453,7 @@ def aot_export_joint_simple(
args,
decompositions=decompositions,
)
in_spec, _kw_in_spec = in_spec.children_specs
in_spec, _kw_in_spec = in_spec.children()
# At this point, we can just directly return the (joint or inference graph) that we traced.
# First though: a bunch of assertions to make sure that our graph doesn't require
# any calling convention changes compared to the original function.
Expand All @@ -1480,15 +1480,15 @@ def aot_export_joint_simple(
raise RuntimeError(
f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}"
)
if not all(child.is_leaf() for child in in_spec.children_specs):
if not all(child.is_leaf() for child in in_spec.children()):
raise RuntimeError(
f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}"
)
if out_spec.is_leaf():
raise RuntimeError(
f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}"
)
if not all(child.is_leaf() for child in out_spec.children_specs):
if not all(child.is_leaf() for child in out_spec.children()):
raise RuntimeError(
f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}"
)
Expand Down
Loading