You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The new simplified NNX transform implementations don’t support this argument. However, to keep supporting the behavior, a new `nnx.transform_metadata` transform is introduced that can be inserted to get back the same results.
253
+
The new simplified NNX transform implementations don’t support this argument. However, to keep supporting the behavior, a new `nnx.transform_metadata` transform is introduced that can be inserted to get back the same results. TODO: mention it works on `jax.vmap`.
251
254
252
255
```py
253
256
# new code
@@ -256,9 +259,82 @@ The new simplified NNX transform implementations don’t support this argument.
`transform_metada` accepts `in_axes` and `out_axes`, these should match the values passed to the corresponding transform.
269
+
270
+
### Module.sow
271
+
272
+
Previously, `Module.sow` used graph updates to capture intermediate values during computations and propagate them outside, it was used in conjunction with `nnx.pop` to log and extract intermediates:
In general, `nnx.capture` takes a function or Module to be transformed, a `nnx.Variable` subclass to collect, and an optional `init` argument to initialize the collected state, which will be stored within `nnx.Variable` objects. `nnx.capture` creates a `__captures__: tuple[Variable, ...]` attribute on each `Module` instance, each Variable in `__captures__` contains a dictionary which `sow` and `perturb` populate.
300
+
301
+
### Module.perturb
302
+
303
+
Similarly, `Module.perturb` was previously used to extract the gradients of intermediate values. This was done in two steps: initializing a perturbation state by running a module once, and then passing the perturbation state as a differentiable target to `grad`.
nnx.pop(model, nnx.Perturbation) # clean up perturbations
324
+
return interm_grads
325
+
```
326
+
327
+
Similar pattern can be used with `nnx.capture` during both perturbation initialization and when running the forward pass to insert the differentiable perturbations state. In this version explicitly pass the `perturbs` state as a separate argument and use `argnums` to specify that both arguments are differentiable:
0 commit comments