
Commit 1fcc2a5

Cristian Garcia authored and Flax Authors committed
check aliases on all transform args and simplify apply_variable_updates
PiperOrigin-RevId: 885586300
1 parent b313d06 commit 1fcc2a5

6 files changed

Lines changed: 119 additions & 39 deletions

docs_nnx/flip/5310-tree-mode-nnx.md

Lines changed: 80 additions & 4 deletions
Original file line number / Diff line number / Diff line change
@@ -1,7 +1,7 @@
11
# Tree Mode NNX
22

3-
Mar 4, 2026,
4-
Cristian Garcia, Flax Team
3+
Mar 4, 2026
4+
Cristian Garcia, Samuel Anklesaria, Flax Team
55

66
## Motivation
77

@@ -67,6 +67,7 @@ These new transforms are highly simplified compared to current transforms, they
6767
```py
6868
def transform_wrapper(*args):
6969
if graph: args = to_tree(args)
70+
check_no_aliases(args=args)
7071

7172
@jax_transform
7273
def transformed_f(*args):
@@ -122,6 +123,7 @@ Code that relies on prefix filters such as StateAxes, StateSharding, and DiffSta
122123
```py
123124
# previous code
124125
state_axes = nnx.StateAxes({some_filter: 0, ...: None})
126+
125127
@nnx.vmap(in_axis=state_axes, graph=True, graph_updates=True)
126128
def f(model):
127129
...
@@ -184,7 +186,7 @@ def loss_fn(model: Foo):
184186
grads: Foo = nnx.grad(loss_fn)(model)
185187
```
186188

187-
### nnx.custom\_vjp
189+
### nnx.custom_vjp
188190

189191
Previously `nnx.custom_vjp` did two particular things:
190192

@@ -242,12 +244,13 @@ Previously NNX transforms like `vmap` and `scan` had a `transform_metadata` meta
242244
@nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: 'din'})
243245
class create_stack(rngs): # 'din' added to out_sharding metadata
244246
return nnx.Variable(rngs.uniform((16,)), out_sharding=('dout',))
247+
245248
v_stack = create_stack(nnx.Rngs(0))
246249
assert v_stack.shape == (8, 16)
247250
assert v_stack.out_shardings == ('din', 'dout')
248251
```
249252

250-
The new simplified NNX transform implementations don’t support this argument. However, to keep supporting the behavior, a new `nnx.transform_metadata` transform is introduced that can be inserted to get back the same results.
253+
The new simplified NNX transform implementations don’t support this argument. However, to keep supporting the behavior, a new `nnx.transform_metadata` transform is introduced that can be inserted to get back the same results. TODO: mention it works on `jax.vmap`.
251254

252255
```py
253256
# new code
@@ -256,9 +259,82 @@ The new simplified NNX transform implementations don’t support this argument.
256259
@nnx.transform_metadata(in_axes=0, out_axes=0, partition='din')
257260
class create_stack(rngs): # 'din' added to out_sharding metadata
258261
return nnx.Variable(rngs.uniform((16,)), out_sharding=('dout',))
262+
259263
v_stack = create_stack(nnx.Rngs(0))
260264
assert v_stack.shape == (8, 16)
261265
assert v_stack.out_shardings == ('din', 'dout')
262266
```
263267

264268
`transform_metada` accepts `in_axes` and `out_axes`, these should match the values passed to the corresponding transform.
269+
270+
### Module.sow
271+
272+
Previously, `Module.sow` used graph updates to capture intermediate values during computations and propagate them outside, it was used in conjunction with `nnx.pop` to log and extract intermediates:
273+
274+
```py
275+
# old code
276+
class Foo(nnx.Module):
277+
def __call__(self, x):
278+
self.sow(nnx.Intermediate, "y_mean", jnp.mean(x))
279+
return x
280+
281+
model = Foo()
282+
result = model(x)
283+
intermediates = nnx.pop(model, nnx.Intermediate) # extract intermediate values
284+
```
285+
286+
To achieve the same without graph updates we’ve added a new `nnx.capture` API which allows for a similar workflow.
287+
288+
```py
289+
# New Code
290+
class Foo(nnx.Module):
291+
def __call__(self, x):
292+
self.sow(nnx.Intermediate, "y_mean", jnp.mean(x))
293+
return x
294+
295+
model = Foo()
296+
result, intermediates = nnx.capture(model, nnx.Intermediate)(x)
297+
```
298+
299+
In general, `nnx.capture` takes a function or Module to be transformed, a `nnx.Variable` subclass to collect, and an optional `init` argument to initialize the collected state, which will be stored within `nnx.Variable` objects. `nnx.capture` creates a `__captures__: tuple[Variable, ...]` attribute on each `Module` instance, each Variable in `__captures__` contains a dictionary which `sow` and `perturb` populate.
300+
301+
### Module.perturb
302+
303+
Similarly, `Module.perturb` was previously used to extract the gradients of intermediate values. This was done in two steps: initializing a perturbation state by running a module once, and then passing the perturbation state as a differentiable target to `grad`.
304+
305+
```py
306+
class Model(nnx.Module):
307+
def __call__(self, x):
308+
x = self.perturb('grad_of_x', x)
309+
...
310+
return y
311+
312+
# old code
313+
@nnx.jit
314+
def train_step(model, optimizer, x, y):
315+
model(x) # Initialize perturbation state
316+
def loss_fn(model):
317+
y_pred = model(x)
318+
return jnp.mean((y_pred - y) ** 2)
319+
diff_state = nnx.DiffState(0, (nnx.Param, nnx.Perturbation))
320+
grads = nnx.grad(loss_fn, argnums=diff_state)(model)
321+
grads, interm_grads = nnx.state(grads, nnx.Param, nnx.Perturbation)
322+
optimizer.update(model, grads)
323+
nnx.pop(model, nnx.Perturbation) # clean up perturbations
324+
return interm_grads
325+
```
326+
327+
Similar pattern can be used with `nnx.capture` during both perturbation initialization and when running the forward pass to insert the differentiable perturbations state. In this version explicitly pass the `perturbs` state as a separate argument and use `argnums` to specify that both arguments are differentiable:
328+
329+
```py
330+
# new code
331+
@nnx.jit
332+
def train_step(model, optimizer, x, y):
333+
_, perturbs = nnx.capture(model, nnx.Perturbation)(x) # init perturbations
334+
def loss_fn(model, perturbs):
335+
y_pred = nnx.capture(model, init=perturbs)(x)
336+
return jnp.mean((y_pred - y) ** 2)
337+
grads, interm_grads = nnx.grad(loss_fn, argnums=(0, 1))(model, perturbs)
338+
optimizer.update(model, grads)
339+
return interm_grads
340+
```

flax/nnx/extract.py

Lines changed: 5 additions & 20 deletions
@@ -580,29 +580,14 @@ def _mask_updates(path, current, snapshot):
580580
)
581581

582582

583-
def apply_variable_updates(
584-
args_tree: A, updates_tree: A, *, fn_name: str,
585-
) -> None:
583+
def apply_variable_updates(args_tree: A, updates_tree: A):
586584
is_leaf = lambda x: isinstance(x, variablelib.Variable) or isinstance(x, Mask)
587-
args_leaves, treedef = jax.tree.flatten_with_path(args_tree, is_leaf=is_leaf)
585+
args_leaves = jax.tree.leaves(args_tree, is_leaf=is_leaf)
586+
_, treedef = jax.tree.flatten(args_tree, is_leaf=is_leaf)
588587
updates_leaves = treedef.flatten_up_to(updates_tree)
589-
seen: dict[int, jax.tree_util.KeyPath] = {}
590-
for (path, variable), update in zip(args_leaves, updates_leaves):
591-
if not isinstance(variable, variablelib.Variable):
592-
continue
593-
var_id = id(variable)
594-
if var_id in seen:
595-
path_str = jax.tree_util.keystr(path)
596-
seen_path_str = jax.tree_util.keystr(seen[var_id])
597-
raise ValueError(
598-
f'Duplicate {variable}\nfound at paths:\n\n'
599-
f' - {seen_path_str}\n'
600-
f' - {path_str}\n\n'
601-
f'Tree mode (graph=False) does not support shared references. '
602-
+ graphlib._tree_mode_suggestion_transform(fn_name)
603-
)
604-
seen[var_id] = path
588+
for variable, update in zip(args_leaves, updates_leaves, strict=True):
605589
if isinstance(update, variablelib.Variable):
590+
assert isinstance(variable, variablelib.Variable)
606591
variable.update_from_state(update)
607592

608593
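
Note on the hunk above: the duplicate-reference check deleted from `apply_variable_updates` appears to move into `extract.check_no_aliases`, which the transforms in this commit call before running the wrapped function. The body of that helper is not part of this diff. The following is a minimal plain-Python sketch of the same kind of check, modeled on the removed lines. `Variable`, `iter_variables`, and `check_no_aliases_sketch` are hypothetical stand-ins for illustration, not Flax APIs, and the real helper may differ.

```python
from typing import Any, Iterator


class Variable:
  """Hypothetical stand-in for nnx.Variable; only object identity matters."""

  def __init__(self, value: Any):
    self.value = value


def iter_variables(tree: Any, path: str = '') -> Iterator[tuple[str, Variable]]:
  """Yields (path, variable) pairs from nested dicts, lists, and tuples."""
  if isinstance(tree, Variable):
    yield path, tree
  elif isinstance(tree, dict):
    for key, child in tree.items():
      yield from iter_variables(child, f'{path}[{key!r}]')
  elif isinstance(tree, (list, tuple)):
    for index, child in enumerate(tree):
      yield from iter_variables(child, f'{path}[{index}]')


def check_no_aliases_sketch(fn_name: str, **trees: Any) -> None:
  """Raises ValueError if one Variable object is reachable via two paths."""
  seen: dict[int, str] = {}  # id(variable) -> first path where it was found
  for path, variable in iter_variables(trees):
    var_id = id(variable)
    if var_id in seen:
      raise ValueError(
          f'{fn_name}: duplicate Variable found at paths '
          f'{seen[var_id]} and {path}'
      )
    seen[var_id] = path


# Usage: distinct Variables pass, a shared reference is rejected.
shared = Variable(1.0)
check_no_aliases_sketch('jit', args=(shared, Variable(2.0)))
try:
  check_no_aliases_sketch('jit', args=(shared,), kwargs={'w': shared})
except ValueError as err:
  print(err)
```

Passing each argument group as a keyword (`args=...`, `kwargs=...`) mirrors the call sites in this commit and lets the error message name where each duplicate came from.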

flax/nnx/transforms/autodiff.py

Lines changed: 12 additions & 5 deletions
@@ -168,6 +168,8 @@ def tree_grad_wrapper(*args, **kwargs):
168168
(args, kwargs), prefix=(args_prefix, False),
169169
)
170170

171+
extract.check_no_aliases('grad', args=args, kwargs=kwargs)
172+
171173
fn_out = gradded_fn(*args, **kwargs)
172174

173175
if return_value:
@@ -189,7 +191,7 @@ def tree_grad_wrapper(*args, **kwargs):
189191
if graph: grads = extract.from_tree2(grads)
190192
result = grads
191193

192-
extract.apply_variable_updates((args, kwargs), updates, fn_name='grad')
194+
extract.apply_variable_updates((args, kwargs), updates)
193195
return result
194196

195197
return tree_grad_wrapper
@@ -691,6 +693,7 @@ def vjp(
691693

692694
if graph:
693695
primals = extract.to_tree2(primals)
696+
extract.check_no_aliases('vjp', primals=primals)
694697
primals_out, vjp_fn, aux = jax.vjp(
695698
SimpleVjpFn(f_unbound, has_aux=has_aux, graph=graph),
696699
*primals,
@@ -706,7 +709,7 @@ def vjp(
706709
raw_vjp_fn = vjp_fn
707710
def vjp_fn(g):
708711
return extract.from_tree2(raw_vjp_fn(g))
709-
extract.apply_variable_updates(primals, updates, fn_name='vjp')
712+
extract.apply_variable_updates(primals, updates)
710713
if has_aux:
711714
return primals_out, vjp_fn, user_aux
712715
else:
@@ -865,6 +868,8 @@ def jvp(
865868
if graph:
866869
primals = extract.to_tree2(primals)
867870
tangents = extract.to_tree2(tangents)
871+
extract.check_no_aliases('jvp', primals=primals)
872+
extract.check_no_aliases('jvp', tangents=tangents)
868873
if has_aux:
869874
(primals_out, updates), (tangent_out, _updates_tangent), aux = jax.jvp(
870875
SimpleJvpFn(f_unbound, has_aux=True, graph=graph),
@@ -881,7 +886,7 @@ def jvp(
881886
if graph:
882887
primals_out = extract.from_tree2(primals_out)
883888
tangent_out = extract.from_tree2(tangent_out)
884-
extract.apply_variable_updates(primals, updates, fn_name='jvp')
889+
extract.apply_variable_updates(primals, updates)
885890
if has_aux:
886891
return primals_out, tangent_out, aux
887892
else:
@@ -982,6 +987,7 @@ def __call__(
982987
i not in self.nondiff_argnums for i in range(len(args))
983988
)
984989
args = extract.to_tree2(args, prefix=prefix)
990+
extract.check_no_aliases('custom_vjp', args=args)
985991
(out, updates) = self.custom_vjp_fn(*args)
986992
# check that differentiable arguments were not mutated
987993
diff_argnums = tuple(
@@ -1007,7 +1013,7 @@ def __call__(
10071013
)
10081014
if self.graph:
10091015
out = extract.from_tree2(out)
1010-
extract.apply_variable_updates(args, updates, fn_name='custom_vjp')
1016+
extract.apply_variable_updates(args, updates)
10111017
return out
10121018

10131019
def defvjp(
@@ -1661,10 +1667,11 @@ def remat(
16611667
def simple_remat_wrapper(*args, **kwargs):
16621668
if graph:
16631669
args, kwargs = extract.to_tree2((args, kwargs))
1670+
extract.check_no_aliases('remat', args=args, kwargs=kwargs)
16641671
out, updates = checkpointed_fn(*args, **kwargs)
16651672
if graph:
16661673
out = extract.from_tree2(out)
1667-
extract.apply_variable_updates((args, kwargs), updates, fn_name='remat')
1674+
extract.apply_variable_updates((args, kwargs), updates)
16681675
return out
16691676

16701677
return simple_remat_wrapper # type: ignore[return-value]

flax/nnx/transforms/compilation.py

Lines changed: 6 additions & 3 deletions
@@ -571,9 +571,10 @@ def _maybe_from_tree(self, out):
571571

572572
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
573573
args, kwargs = self._maybe_to_tree(args, kwargs)
574+
extract.check_no_aliases('jit', args=args, kwargs=kwargs)
574575
out, updates = self.jitted_fn(*self.partial_args, *args, **kwargs)
575576
extract.apply_variable_updates(
576-
((*self.partial_args, *args), kwargs), updates, fn_name='jit')
577+
((*self.partial_args, *args), kwargs), updates)
577578
return self._maybe_from_tree(out)
578579

579580
def __get__(self, obj, objtype=None):
@@ -1149,8 +1150,9 @@ def call(*args, **kwargs):
11491150

11501151
def __call__(self, *args, **kwargs):
11511152
args, kwargs = self.jit_wrapped._maybe_to_tree(args, kwargs)
1153+
extract.check_no_aliases('jit', args=args, kwargs=kwargs)
11521154
out, updates = self.compiled(*args, **kwargs)
1153-
extract.apply_variable_updates((args, kwargs), updates, fn_name='jit')
1155+
extract.apply_variable_updates((args, kwargs), updates)
11541156
return self.jit_wrapped._maybe_from_tree(out)
11551157

11561158
@property
@@ -1545,8 +1547,9 @@ def shard_map_wrapper(*args, **kwargs):
15451547
prefix=in_specs,
15461548
check_aliasing=in_specs is not None,
15471549
)
1550+
extract.check_no_aliases('shard_map', args=args)
15481551
out, updates = shard_map_fn(*args, **kwargs)
1549-
extract.apply_variable_updates(args, updates, fn_name='shard_map')
1552+
extract.apply_variable_updates(args, updates)
15501553
if graph:
15511554
out = extract.from_tree2(out)
15521555
return out

flax/nnx/transforms/iteration.py

Lines changed: 10 additions & 4 deletions
@@ -128,6 +128,7 @@ def wrapper(*in_args, **in_kwargs):
128128
in_args = resolve_kwargs(f, in_args, in_kwargs)
129129
if graph:
130130
in_args = extract.to_tree2(in_args, prefix=in_axes)
131+
extract.check_no_aliases('transform_metadata', args=in_args)
131132
args = graphlib.clone(in_args, graph=graph)
132133
_apply_axis_fn(args, in_axes, metadata, spmd.remove_axis)
133134
updates, snapshot = extract.updates_and_snapshot(args)
@@ -136,10 +137,11 @@ def wrapper(*in_args, **in_kwargs):
136137
out = f(*args)
137138
if graph:
138139
out = extract.to_tree2(out, prefix=out_axes)
140+
extract.check_no_aliases('transform_metadata', args=updates, out=out)
139141
_apply_axis_fn(args, in_axes, metadata, spmd.add_axis)
140142
_apply_axis_fn(out, out_axes, metadata, spmd.add_axis)
141143
updates = extract.mask_variable_updates(updates, snapshot)
142-
extract.apply_variable_updates(in_args, updates, fn_name='transform_metadata')
144+
extract.apply_variable_updates(in_args, updates)
143145
if graph:
144146
out = extract.from_tree2(out)
145147
return out
@@ -524,8 +526,9 @@ def simple_vmap_wrapper(*args, **kwargs):
524526
else None,
525527
check_aliasing=in_axes is not None,
526528
)
529+
extract.check_no_aliases('vmap', args=args, kwargs=kwargs)
527530
out, updates = vmapped_fn(*args, **kwargs)
528-
extract.apply_variable_updates((args, kwargs), updates, fn_name='vmap')
531+
extract.apply_variable_updates((args, kwargs), updates)
529532
if graph:
530533
out = extract.from_tree2(out)
531534
return out
@@ -791,8 +794,9 @@ def simple_pmap_wrapper(*args, **kwargs):
791794
else None,
792795
check_aliasing=in_axes is not None,
793796
)
797+
extract.check_no_aliases('pmap', args=args, kwargs=kwargs)
794798
out, updates = pmapped_fn(*args, **kwargs)
795-
extract.apply_variable_updates((args, kwargs), updates, fn_name='pmap')
799+
extract.apply_variable_updates((args, kwargs), updates)
796800
if graph:
797801
out = extract.from_tree2(out)
798802
return out
@@ -1668,6 +1672,8 @@ def simple_scan_wrapper(*args):
16681672
if graph:
16691673
args = extract.to_tree2(args, prefix=in_axes)
16701674

1675+
extract.check_no_aliases('scan', args=args)
1676+
16711677
result = pure_jax_fancy_scan(
16721678
simple_scan_fn,
16731679
*args,
@@ -1687,7 +1693,7 @@ def simple_scan_wrapper(*args):
16871693
out, updates = result
16881694

16891695
masked_args = extract.mask_at(args, carry_arg_index)
1690-
extract.apply_variable_updates(masked_args, updates, fn_name='scan')
1696+
extract.apply_variable_updates(masked_args, updates)
16911697

16921698
if carry_arg_index is not None:
16931699
carry_in = args[carry_arg_index]

flax/nnx/transforms/transforms.py

Lines changed: 6 additions & 3 deletions
@@ -434,10 +434,11 @@ def checkify(
434434
def simple_checkify_wrapper(*args):
435435
if graph:
436436
args = extract.to_tree2(args)
437+
extract.check_no_aliases('checkify', args=args)
437438
error, (out, updates) = checkify_fn(*args)
438439
if graph:
439440
out = extract.from_tree2(out)
440-
extract.apply_variable_updates(args, updates, fn_name='checkify')
441+
extract.apply_variable_updates(args, updates)
441442
return error, out
442443

443444
return simple_checkify_wrapper # type: ignore
@@ -519,6 +520,7 @@ def cond(
519520
if not graph or not graph_updates:
520521
if graph:
521522
operands = extract.to_tree2(operands)
523+
extract.check_no_aliases('cond', operands=operands)
522524
out, updates = jax.lax.cond(
523525
pred,
524526
SimpleCondFn(true_fun, graph=graph),
@@ -527,7 +529,7 @@ def cond(
527529
)
528530
if graph:
529531
out = extract.from_tree2(out)
530-
extract.apply_variable_updates(operands, updates, fn_name='cond')
532+
extract.apply_variable_updates(operands, updates)
531533
return out
532534

533535
@general.split_inputs(ctxtag='cond')
@@ -573,14 +575,15 @@ def switch(
573575
if not graph or not graph_updates:
574576
if graph:
575577
operands = extract.to_tree2(operands)
578+
extract.check_no_aliases('switch', operands=operands)
576579
out, updates = jax.lax.switch(
577580
index,
578581
[SimpleCondFn(f, graph=graph) for f in branches],
579582
*operands,
580583
)
581584
if graph:
582585
out = extract.from_tree2(out)
583-
extract.apply_variable_updates(operands, updates, fn_name='switch')
586+
extract.apply_variable_updates(operands, updates)
584587
return out
585588

586589
@general.split_inputs(ctxtag='switch')
