Description
Hi,
I'm currently playing around with nnx (awesome library, I love the way model surgery works!) and JAX's new hijax feature, so I'm very happy to see this development in #4960.
My problem is as follows:
I'm trying to construct a `Parameter` class that is freezable (i.e. not differentiable), controlled by a metadata attribute that I called `frozen`:
```python
from typing import TypeVar

import jax
import jax.numpy as jnp
from flax import nnx

T = TypeVar("T")
sg = jax.lax.stop_gradient


class Parameter(nnx.Variable[T]):
    """Freezable parameter class."""

    def __init__(self, value: T, frozen: bool, **kwargs):
        value = jnp.asarray(value)
        if frozen:
            value = sg(value)
        super().__init__(value=value, **kwargs)
        self.frozen = frozen

    @property
    def value(self):
        value = super().value
        return sg(value) if self.frozen else value


# build a collection of params
class Params(nnx.Module):
    def __init__(self):
        self.p1 = Parameter(0.0, frozen=True)
        self.p2 = Parameter(2.0, frozen=False)
        self.o = nnx.Param(2.0)


my_params = Params()
jax.tree.map(jax.grad(jnp.sin), my_params)
# Params( # Param: 1 (4 B), Parameter: 2 (8 B), Total: 3 (12 B)
#   o=Param( # 1 (4 B)
#     value=Array(-0.41614684, dtype=float32, weak_type=True)
#   ),
#   p1=Parameter( # 1 (4 B)
#     value=Array(1., dtype=float32, weak_type=True),  <--- this should be 0. because it's frozen
#     frozen=True
#   ),
#   p2=Parameter( # 1 (4 B)
#     value=Array(-0.41614684, dtype=float32, weak_type=True),
#     frozen=False
#   )
# )
```

This doesn't seem to work. I'd also like this to be dynamic (maybe I have to set a hook instead?): when I change `frozen` from `False` to `True`, the `Parameter` (its value) should become non-differentiable, something like this:
```python
def fun(param: Parameter):
    ...
    param.frozen = True  # essentially `jax.lax.stop_gradient` for `param.value`
    ...
```

In my understanding this should be able to work (somehow), because `.frozen` is just another metadata attribute that is known at trace/compile time (like `.shape`) and is independent of the data (the numeric values of the array).
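A minimal sketch in plain JAX (no nnx) of the trace-time reasoning above, under stated assumptions. Part (a) shows that `jax.lax.stop_gradient` applied eagerly to a concrete array, as in `Parameter.__init__`, does not affect a later `jax.grad`. Part (b) applies the stop-gradient on *read*, inside the traced function, gated by an ordinary Python bool. `FreezableRef` and `fun` are hypothetical names for illustration, not nnx APIs. One plausible (unconfirmed) reading of the `p1` result is that `jax.tree.map` hands `jax.grad` the raw leaf of each Variable, so neither the eager `sg` nor the overridden `value` property ends up inside the differentiated function.

```python
import jax
import jax.numpy as jnp

# (a) Eagerly, stop_gradient on a concrete array simply returns an array with the
# same values. jax.grad then traces a fresh function of `x`, so the earlier call is
# not part of the differentiated computation.
x = jax.lax.stop_gradient(jnp.asarray(0.0))
eager_grad = jax.grad(jnp.sin)(x)  # cos(0.0), i.e. 1.0, like `p1` in the output above


# (b) Hypothetical plain-Python holder (not an nnx API): the frozen flag is ordinary
# Python state, resolved while tracing (like `.shape`), and the stop-gradient is
# applied each time `.value` is read inside the traced function.
class FreezableRef:
    def __init__(self, raw, frozen: bool = False):
        self.raw = raw
        self.frozen = frozen

    @property
    def value(self):
        return jax.lax.stop_gradient(self.raw) if self.frozen else self.raw


def fun(raw):
    p = FreezableRef(raw)
    a = jnp.sin(p.value)  # differentiable read
    p.frozen = True       # mirrors `param.frozen = True` in the question
    b = jnp.sin(p.value)  # gradient is cut for this read only
    return a + b


mixed_grad = jax.grad(fun)(2.0)  # only the first read contributes: cos(2.0)
```

Because the flag is consumed at trace time, flipping it inside one traced call works as shown; under `jax.jit`, flipping it *between* calls would only take effect if the flag is part of the trace signature (for example a static argument), which forces a retrace.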
Would this be possible to achieve somehow? Also, how (if at all) does this change in the context of the new hijax approach?
Thanks, Peter