Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

XuehaiPan
Copy link
Collaborator

@XuehaiPan XuehaiPan commented Feb 28, 2025

Stack from ghstack (oldest at bottom):

Differences between torch.pytree and torch.utils.pytree:

  1. APIs in torch.utils.pytree have a tree_ prefix:

    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)
    
    leaves, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)

    This is similar to the JAX pytree API: jax.tree_util.tree_* vs. jax.tree.*.

  2. The argument order of unflatten is reversed for better functools.partial support:

    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)
    
    tree = torch.pytree.unflatten(treespec, leaves)
    
    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)

    This is also aligned with the JAX pytree API: jax.tree.unflatten(treedef, leaves).

Because we are adding a completely new module, there are no BC issues.

cc @zou3519

[ghstack-poisoned]
Copy link

pytorch-bot bot commented Feb 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/148180

Note: Links to docs will display an error until the docs builds have been completed.

❗ 2 Active SEVs

There are 2 currently active SEVs. If your PR is affected, please view them below:

✅ No Failures

As of commit 3f0985b with merge base 12d7cc5 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
@XuehaiPan XuehaiPan requested a review from albanD as a code owner February 28, 2025 14:11
[ghstack-poisoned]
XuehaiPan added a commit that referenced this pull request Feb 28, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 0bf5096
Pull Request resolved: #148180
]


def unflatten(treespec: PyTreeSpec, leaves: Iterable[_Any]) -> PyTree:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use a TypeVar to dispatch iterable typing to the stub. No reason to erase the typing info here in case we improve typing in the future.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PyTree is an alias of typing.Any which is not a generic type.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that might change in the futre though. Good to not erase typing if unnecessary and make it flexible to future refactors

@zou3519 zou3519 requested a review from vmoens February 28, 2025 15:19
Copy link
Contributor

@vmoens vmoens left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The change is localized and well documented + motivated, I'm happy with it
Thanks @XuehaiPan

[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jun 27, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 7c73c7a
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jun 28, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: e64ef6f
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jul 3, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: ac60aaa
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jul 9, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: acb3131
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jul 17, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: c09eaed
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jul 25, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 2bb4cc0
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Jul 31, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: e456656
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Aug 8, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: d64ad3f
Pull Request resolved: pytorch#148180
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Aug 8, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: d64ad3f
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Aug 17, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: b3c6c37
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Sep 6, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 063982a
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Sep 19, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: bc86ea5
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 8, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 0a5bf5f
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 8, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 1b874cb
Pull Request resolved: pytorch#148180
[ghstack-poisoned]
XuehaiPan added a commit to XuehaiPan/pytorch that referenced this pull request Oct 11, 2025
Differences between `torch.pytree` and `torch.utils.pytree`:

1. APIs in `torch.utils.pytree` have a `tree_` prefix:

    ```python
    leaves, treespec = torch.utils.pytree.tree_flatten(tree)
    new_tree = torch.utils.pytree.tree_map(func, tree)

    leaevs, treespec = torch.pytree.flatten(tree)
    new_tree = torch.pytree.map(func, tree)
    ```

2. The argument order of `unflatten` is reversed for better `functools.partial` support:

    ```python
    tree = torch.utils.pytree.tree_unflatten(leaves, treespec)

    tree = torch.pytree.unflatten(treespec, leaves)

    unflatten_fn = functools.partial(torch.pytree.unflatten, treespec)
    tree1 = unflatten_fn(leaves1)
    tree2 = unflatten_fn(leaves2)
    ```

    This is also aligned with the JAX pytree API: `jax.tree.unflatten(treedef, leaves)`.

ghstack-source-id: 0a4e826
Pull Request resolved: pytorch#148180
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-test-showlocals Show local variables on test failures ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request module: dynamo module: pytree open source topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants