Tree factorisation, TreeArrays and their manipulation #18

@Wuhyun

While furax currently provides many efficient operations, it is often difficult for users to manipulate high-dimensional data into the format they need. This is especially tricky when some of the 'axes' are pytrees.

I propose some new features for the efficient and modular handling of pytrees that consist of traced arrays. All names and feature details may be adjusted for clarity and/or efficiency.

Tree factorisation

Two treedefs A and B can be composed to give A ⊗ B, where each leaf of A is replaced by B.

We call a treedef P irreducible if there are no non-trivial treedefs A and B (i.e., neither is a single leaf *) such that P = A ⊗ B.

Then, any treedef A can be uniquely factorised into a sequence of irreducible treedefs, A = P_0 ⊗ P_1 ⊗ ...⊗ P_(n-1), where n is at most the maximum depth of A.

For example:

  • PyTreeDef({'a': [*, *], 'b': {'x': [*, *], 'y': [*, *]}}) = PyTreeDef({'a': *, 'b': {'x': *, 'y': *}}) ⊗ PyTreeDef([*, *]).
  • PyTreeDef({'a': [*], 'b': [*, *]}) = PyTreeDef({'a': [*], 'b': [*, *]}) (irreducible).
furax.tree.factorize_treedef(A: PyTreeDef) -> list[PyTreeDef]:
  ''' Factorize a given treedef into irreducible treedefs.  '''
  ...

This can be implemented, e.g., as a recursive algorithm that goes one tree depth at a time and checks if all child trees have the same structure. If so, extract it as an irreducible treedef; otherwise, go deeper down the tree.
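A minimal sketch of `factorize_treedef` in JAX, using a slightly different strategy from the level-by-level check above: it tries every proper subtree B of A as an inner factor, largest first, and cuts A top-down wherever a subtree has exactly the structure of B. The largest B that works leaves an irreducible outer factor, because any split of that outer factor would produce an even larger valid inner factor. The helpers `_inner_candidates` and `_split_outer` are hypothetical, and the code assumes the treedef can be unflattened with integer placeholder leaves.

```python
from __future__ import annotations

import jax


def _inner_candidates(treedef: jax.tree_util.PyTreeDef) -> list:
    """Distinct proper subtrees of `treedef` that are internal nodes with at least one leaf."""
    found = []
    stack = list(jax.tree_util.treedef_children(treedef))
    while stack:
        sub = stack.pop()
        if sub.num_leaves == 0 or jax.tree_util.treedef_is_leaf(sub):
            continue
        if sub not in found:
            found.append(sub)
        stack.extend(jax.tree_util.treedef_children(sub))
    return found


def _split_outer(treedef, inner):
    """Return Q such that treedef == Q ⊗ inner, or None if no such Q exists."""
    dummy = treedef.unflatten(list(range(treedef.num_leaves)))  # placeholder leaves
    matches = lambda node: jax.tree_util.tree_structure(node) == inner
    # Cut the tree top-down at every subtree whose structure equals `inner`.
    leaves, outer = jax.tree_util.tree_flatten(dummy, is_leaf=matches)
    return outer if all(matches(leaf) for leaf in leaves) else None


def factorize_treedef(treedef) -> list:
    """Factorize `treedef` into irreducible treedefs [P_0, ..., P_(n-1)], outermost first."""
    if jax.tree_util.treedef_is_leaf(treedef):
        return []  # a bare leaf * is the identity for ⊗
    # Trying the largest inner factor first guarantees the outer factor is irreducible.
    for inner in sorted(_inner_candidates(treedef), key=lambda t: t.num_nodes, reverse=True):
        outer = _split_outer(treedef, inner)
        if outer is not None:
            return [outer] + factorize_treedef(inner)
    return [treedef]  # no non-trivial split exists: irreducible
```

Sorting by `num_nodes` rather than `num_leaves` matters for factors such as `[*]`, which do not change the leaf count.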

We may want to add some generic treedef operations:

furax.tree.treedef_prod(A: PyTreeDef, B: PyTreeDef) -> PyTreeDef: '''(A, B) -> (A ⊗ B)''' ...

furax.tree.treedef_divide(A: PyTreeDef, B: PyTreeDef) -> PyTreeDef: '''(A ⊗ B, B) -> (A)''' ...
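For the product, JAX already provides `PyTreeDef.compose`, which replaces every leaf of the outer treedef with the inner one, so `treedef_prod` could be a thin wrapper. The `treedef_divide` below is a sketch under that assumption, not existing furax code: it cuts A ⊗ B top-down at every subtree whose structure equals B and verifies the result by recomposing.

```python
import jax


def treedef_prod(outer, inner):
    """(A, B) -> A ⊗ B: replace every leaf of `outer` by `inner`."""
    return outer.compose(inner)


def treedef_divide(product, inner):
    """(A ⊗ B, B) -> A, raising ValueError if `product` does not factor as A ⊗ B."""
    dummy = product.unflatten(list(range(product.num_leaves)))  # placeholder leaves
    matches = lambda node: jax.tree_util.tree_structure(node) == inner
    # Cut the tree top-down wherever a subtree has exactly the structure of `inner`.
    leaves, outer = jax.tree_util.tree_flatten(dummy, is_leaf=matches)
    if not all(matches(leaf) for leaf in leaves) or outer.compose(inner) != product:
        raise ValueError(f"{product} is not of the form A ⊗ {inner}")
    return outer
```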

furax.tree.treedef_transpose(A: PyTreeDef) -> PyTreeDef:
  '''P_0 ⊗ P_1 ⊗ ...⊗ P_(n-1) -> P_(n-1) ⊗ P_(n-2) ⊗ ...⊗ P_0'''
  ...

furax.tree.moveaxis(A: PyTreeDef | PyTree, source, destination) -> PyTreeDef | PyTree:
  '''Equivalent of np.moveaxis but for the tree factorization A = P_0 ⊗ P_1 ⊗ ...⊗ P_(n-1).
     For example, moveaxis(P_0 ⊗ P_1 ⊗ P_2, 0, 1) = P_1 ⊗ P_0 ⊗ P_2 if the P_i are irreducible.'''
  ...
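The leaves of P_0 ⊗ ... ⊗ P_(n-1) are flattened in row-major order over (d_0, ..., d_(n-1)), where d_i is the number of leaves of P_i. Moving a tree axis of a PyTree therefore reduces to permuting its flat leaf list with `np.moveaxis` applied to an index array. The sketch below assumes the factor list is already known (e.g., from a factorisation step); `compose_all` and `tree_moveaxis` are hypothetical names.

```python
import functools

import jax
import numpy as np


def compose_all(factors):
    """Return P_0 ⊗ P_1 ⊗ ... ⊗ P_(n-1) for a non-empty list of treedefs."""
    return functools.reduce(lambda outer, inner: outer.compose(inner), factors)


def treedef_transpose(factors):
    """P_0 ⊗ P_1 ⊗ ... ⊗ P_(n-1) -> P_(n-1) ⊗ ... ⊗ P_0, given the factor list."""
    return compose_all(factors[::-1])


def tree_moveaxis(tree, factors, source, destination):
    """np.moveaxis over the tree axes of `tree`, whose treedef is compose_all(factors)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    if treedef != compose_all(factors):
        raise ValueError("factors do not multiply to the tree's treedef")
    n = len(factors)
    # Leaves are flattened in row-major order over (d_0, ..., d_(n-1)).
    index = np.arange(len(leaves)).reshape([f.num_leaves for f in factors])
    order = np.moveaxis(index, source, destination).ravel()
    # Reorder the factors exactly as np.moveaxis reorders the axes.
    perm = list(range(n))
    perm.insert(destination % n, perm.pop(source % n))
    new_treedef = compose_all([factors[i] for i in perm])
    return new_treedef.unflatten([leaves[i] for i in order])
```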

TreeArrays

TreeArrays are pytrees whose leaves are all traced arrays of identical shape and data type.

Suppose that a TreeArray x has treedef A and leaves (traced arrays) each of shape (s_0, ..., s_(m-1)). We may factorise A = P_0 ⊗ P_1 ⊗ ...⊗ P_(n-1), where each irreducible treedef P_i has d_i leaves. Then, the TreeArray x can be stored as a single traced array of shape (d_0, ..., d_(n-1), s_0, ..., s_(m-1)).

TreeArrays could be implemented as a dataclass holding the full traced array together with the pytreedef as static metadata. Such a dataclass can also be registered as a pytree via jax.tree_util.register_dataclass.
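A minimal sketch of such a dataclass, assuming the factor sizes (d_0, ..., d_(n-1)) are passed in explicitly as `tree_shape` (in the proposal they would come from factorising the treedef). The treedef and `tree_shape` are registered as static metadata, so a TreeArray can cross `jax.jit` boundaries. The methods `from_tree`, `as_tree`, and the same-structure `__add__` are illustrative, not existing furax API.

```python
from __future__ import annotations

import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True, eq=False)
class TreeArray:
    """Hypothetical TreeArray: one stacked array plus static tree metadata."""

    array: jax.Array  # shape (d_0, ..., d_(n-1), s_0, ..., s_(m-1))
    treedef: jax.tree_util.PyTreeDef  # static metadata
    tree_shape: tuple[int, ...]  # static (d_0, ..., d_(n-1)); product == treedef.num_leaves

    @classmethod
    def from_tree(cls, tree, tree_shape=None) -> TreeArray:
        """Stack the leaves of `tree` (flatten order) and split the stacked axis into tree_shape."""
        leaves, treedef = jax.tree_util.tree_flatten(tree)
        tree_shape = (len(leaves),) if tree_shape is None else tuple(tree_shape)
        stacked = jnp.stack(leaves)  # requires identical leaf shapes (dtypes are promoted)
        return cls(stacked.reshape(tree_shape + stacked.shape[1:]), treedef, tree_shape)

    def as_tree(self):
        """Unstack the tree axes back into a pytree with structure `treedef`."""
        leaf_shape = self.array.shape[len(self.tree_shape):]
        flat = self.array.reshape((self.treedef.num_leaves,) + leaf_shape)
        return self.treedef.unflatten(list(flat))

    def __add__(self, other: TreeArray) -> TreeArray:
        """Elementwise sum of two TreeArrays with identical tree structure."""
        if (self.treedef, self.tree_shape) != (other.treedef, other.tree_shape):
            raise ValueError("TreeArrays have different tree structures")
        return TreeArray(self.array + other.array, self.treedef, self.tree_shape)


# `array` is traced data; `treedef` and `tree_shape` are static (hashable) metadata.
jax.tree_util.register_dataclass(
    TreeArray, data_fields=["array"], meta_fields=["treedef", "tree_shape"]
)
```

A single `self.array + other.array` replaces one addition per leaf, which is the point of the stacked layout.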

This has many benefits:

  • Operations along the tree axes act on a single array instead of being serialised leaf by leaf
  • The tree structure is maintained for efficient operations along the tree dimensions
  • Array operations can easily be broadcast along the tree dimensions
  • Allows easy conversion between the tree and array dimensions

We need the following generic operations:

furax.tree.TreeArray.[add/subtract/prod/div](self: TreeArray, y: TreeArray) -> TreeArray:
  ''' The usual traced array operations on two TreeArrays, broadcasting if needed'''
  ...

furax.tree.TreeArray.as_tree(self: TreeArray) -> PyTree:
  ''' Return a PyTree as specified by the TreeArray's treedef '''
  ...

furax.tree.TreeArray.to_tree(self: TreeArray, new_treedef: PyTreeDef | None) -> TreeArray:
  ''' Return a TreeArray with a new treedef. The new treedef should be compatible with the
      original shape (d_0, ..., d_(n-1), s_0, ..., s_(m-1)).
      The traced array stored in TreeArray remains unchanged.

      >>> x = TreeArray(jax.numpy.zeros((2, 3)), treedef=jax.tree.structure({'a': 0, 'b': 0}))
      >>> x.as_tree()
      {'a': Array([0., 0., 0.]), 'b': Array([0., 0., 0.])}
      >>> relabelled = x.to_tree(new_treedef=jax.tree.structure({'x': 0, 'y': 0}))
      >>> relabelled.as_tree()
      {'x': Array([0., 0., 0.]), 'y': Array([0., 0., 0.])}
      >>> unstacked = x.to_tree(new_treedef=jax.tree.structure({'a': (0, 0, 0), 'b': (0, 0, 0)}))
      >>> unstacked.as_tree()
      {'a': (Array(0.), Array(0.), Array(0.)), 'b': (Array(0.), Array(0.), Array(0.))}
      >>> stacked = x.to_tree(new_treedef=None)
      >>> stacked.as_tree()
      Array([[0., 0., 0.], [0., 0., 0.]])
  ''' 
  ...
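Since `to_tree` leaves the stored array unchanged, its core is a reshape of that array into leaves. The hypothetical helper below reproduces the docstring examples under one assumed rule: the smallest number of leading axes whose sizes multiply to the new leaf count become tree axes, and the remaining axes form the leaf shape.

```python
import jax
import jax.numpy as jnp


def unflatten_leading(array, new_treedef):
    """Hypothetical helper: view `array` as a pytree with structure `new_treedef`.

    The leading axes whose sizes multiply to new_treedef.num_leaves become tree axes;
    the remaining axes form the shape of every leaf. None returns the array itself.
    """
    if new_treedef is None:
        return array
    n_leaves = new_treedef.num_leaves
    n_tree_axes, size = 0, 1
    while size < n_leaves and n_tree_axes < array.ndim:
        size *= array.shape[n_tree_axes]
        n_tree_axes += 1
    if size != n_leaves:
        raise ValueError(f"{n_leaves} leaves do not match the leading axes of {array.shape}")
    leaves = array.reshape((n_leaves,) + array.shape[n_tree_axes:])
    return new_treedef.unflatten(list(leaves))
```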

furax.tree.TreeArray.moveaxis(self: TreeArray, source, destination, stack=True) -> TreeArray:
  ''' The equivalent of jax.numpy.moveaxis for TreeArrays. The source and destination indices
      span both the tree and array dimensions.
      jax.numpy.moveaxis is applied to the underlying traced array of the TreeArray.
      The treedef also changes accordingly. After the reordering of the tree and array axes:
      If stack=True, all tree dimensions to the right of the leftmost array dimension are stacked
      into the array (they become array dimensions), so equal or fewer tree dimensions remain.
      If stack=False, all array dimensions to the left of the rightmost tree dimension are unstacked
      into lists (they become tree dimensions), so equal or more tree dimensions exist afterwards.
  '''
  ...
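The `stack` rule can be expressed as bookkeeping on axis labels: apply `jnp.moveaxis` to the array, track where each original axis landed, and decide which leading axes remain tree axes. In this sketch (hypothetical name and signature, operating on a raw array plus its factor list rather than a TreeArray), former array axes that become tree axes under `stack=False` are assigned plain list treedefs of matching length. This reproduces the StokesIQU-to-list conversion mentioned later in this issue.

```python
import functools

import jax
import jax.numpy as jnp


def _compose(factors):
    if not factors:
        return None  # no tree axes left: a plain array
    return functools.reduce(lambda outer, inner: outer.compose(inner), factors)


def treearray_moveaxis(array, factors, source, destination, stack=True):
    """Hypothetical sketch of TreeArray.moveaxis on (array, [P_0, ..., P_(n-1)]).

    The first len(factors) axes of `array` are tree axes, the rest are array axes.
    Returns the moved array, its new factor list and the composed treedef (or None).
    """
    n_tree = len(factors)
    axes = list(range(array.ndim))
    axes.insert(destination % array.ndim, axes.pop(source % array.ndim))
    moved = jnp.moveaxis(array, source, destination)
    if stack:
        # Keep only the leading run of former tree axes; later tree axes become array axes.
        n_new = next((i for i, ax in enumerate(axes) if ax >= n_tree), len(axes))
    else:
        # Every axis up to the rightmost former tree axis stays (or becomes) a tree axis.
        n_new = max((i + 1 for i, ax in enumerate(axes) if ax < n_tree), default=0)
    # Former array axes that become tree axes are unstacked into lists of matching length.
    new_factors = [
        factors[ax] if ax < n_tree else jax.tree_util.tree_structure([0] * array.shape[ax])
        for ax in axes[:n_new]
    ]
    return moved, new_factors, _compose(new_factors)
```

For example, a StokesIQU tree of Array[n_freq, n_pix] is an array of shape (3, n_freq, n_pix) with one tree factor; moving axis 1 to position 0 with `stack=False` yields an [n_freq]-list of StokesIQU trees.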
  
furax.tree.take(self: TreeArray, indices, axis) -> TreeArray:
  ''' The equivalent of jax.numpy.take for TreeArrays. Slices along a given tree or array dimension.
  '''
  ...

furax.tree.einsum(subscripts, /, *operands, stack=True, enforce_treedef=True) -> TreeArray:
  ''' The equivalent of jax.numpy.einsum for TreeArrays.
      jax.numpy.einsum is applied to the TreeArrays' underlying traced arrays.
      The tree and array axes that are summed over must have equal sizes.
      If enforce_treedef=True, the treedefs on the axes summed over must be identical.
      The new treedef is computed in a way similar to TreeArray.moveaxis().
  '''
  ...
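A sketch of the einsum treedef bookkeeping, under stated simplifications: operands are (array, factors) pairs whose first len(factors) axes are tree axes, subscripts are explicit with `->` and no ellipsis, and only `stack=True` is handled. Shared tree letters are checked for identical treedefs when `enforce_treedef=True`, and the leading run of tree letters in the output determines the new treedef. The function name `tree_einsum` is hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


def tree_einsum(subscripts, *operands, enforce_treedef=True):
    """Hypothetical sketch of furax.tree.einsum with stack=True.

    Each operand is an (array, factors) pair whose first len(factors) axes are tree
    axes. Subscripts must be explicit ('...->...') and must not use ellipses.
    """
    inputs, output = subscripts.replace(" ", "").split("->")
    letter_factor = {}
    for letters, (_, factors) in zip(inputs.split(","), operands):
        for letter, factor in zip(letters, factors):  # zip stops after the tree axes
            known = letter_factor.setdefault(letter, factor)
            if enforce_treedef and known != factor:
                raise ValueError(f"treedefs differ on axis {letter!r}")
    result = jnp.einsum(f"{inputs}->{output}", *(array for array, _ in operands))
    # As in moveaxis(stack=True): only the leading run of tree letters stays a tree axis.
    new_factors = []
    for letter in output:
        if letter not in letter_factor:
            break
        new_factors.append(letter_factor[letter])
    treedef = functools.reduce(lambda o, i: o.compose(i), new_factors) if new_factors else None
    return result, new_factors, treedef
```

For instance, `'ij,jk->ik'` with `i`, `j` as the outer and inner tree axes of the first operand and `j` as the tree axis of the second contracts over the shared treedef, in the spirit of the `furax.tree.matmat()` remark below.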

Furax TreeArrayOperators

Once TreeArrays are implemented, we should provide wrappers for the above functions as furax's AbstractLinearOperators.

Note that the current implementation of TreeOperator (by Pierre) allows leaves of different shapes. However, in most scientific applications, the leaves are traced arrays of identical shape. Optimised operators and utilities on TreeArrays would therefore be very useful.

For example, users can convert from (StokesIQU pytree of Array[n_freq, n_pix]) to a ([n_freq]-list of StokesIQU pytrees of Array[n_pix]) simply using TreeArray.moveaxis() (this was a task from recent work). Note also that furax.tree.matmat() is an example of einsum between two TreeArrays, where the summation is taken over the 'inner_treedef' of A and 'outer_treedef' of B.

Comments and suggestions are welcome.