How to construct freezable Parameters? #4995

@pfackeldey

Hi,
I'm currently playing around with nnx (awesome library, I love the way model surgery works!) and with JAX's new hijax feature, so I'm very happy to see this development in #4960.

My problem is as follows:
I'm trying to construct a Parameter class that is freezable (i.e. not differentiable), controlled by a metadata attribute that I called 'frozen':

```python
from typing import TypeVar
from flax import nnx
import jax
import jax.numpy as jnp

T = TypeVar("T")
sg = jax.lax.stop_gradient

class Parameter(nnx.Variable[T]):
  """Freezable parameter class."""

  def __init__(self, value: T, frozen: bool, **kwargs):
    value = jnp.asarray(value)
    if frozen:
      value = sg(value)
    super().__init__(value=value, **kwargs)
    self.frozen = frozen

  @property
  def value(self):
    value = super().value
    return sg(value) if self.frozen else value


# build a collection of params
class Params(nnx.Module):
  def __init__(self):
    self.p1 = Parameter(0.0, frozen=True)
    self.p2 = Parameter(2.0, frozen=False)

    self.o = nnx.Param(2.0)

my_params = Params()
jax.tree.map(jax.grad(jnp.sin), my_params)
# Params( # Param: 1 (4 B), Parameter: 2 (8 B), Total: 3 (12 B)
#   o=Param( # 1 (4 B)
#     value=Array(-0.41614684, dtype=float32, weak_type=True)
#   ),
#   p1=Parameter( # 1 (4 B)
#     value=Array(1., dtype=float32, weak_type=True), <--- this should be 0. because it's frozen
#     frozen=True
#   ),
#   p2=Parameter( # 1 (4 B)
#     value=Array(-0.41614684, dtype=float32, weak_type=True),
#     frozen=False
#   )
# )
```
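
One plausible reading of the output above (an assumption, not verified against nnx internals): `jax.tree.map` hands the raw leaf arrays to `jax.grad(jnp.sin)` without going through the overridden `value` property, and the `stop_gradient` in `__init__` runs eagerly on a concrete array, so it is never recorded in the trace that `jax.grad` builds later. This minimal pure-JAX sketch (no nnx involved) isolates that second effect:

```python
import jax
import jax.numpy as jnp

# stop_gradient applied eagerly to a concrete array, outside any transformation
x0 = jax.lax.stop_gradient(jnp.asarray(0.0))

# jax.grad traces a fresh input, so the earlier stop_gradient is not part of this trace:
# the result is cos(0) = 1, just like the frozen entry p1 above
g_eager = jax.grad(jnp.sin)(x0)

# stop_gradient applied *inside* the differentiated function does block the gradient
g_traced = jax.grad(lambda x: jnp.sin(jax.lax.stop_gradient(x)))(x0)
```

Only the second variant, where `stop_gradient` sits inside the function being differentiated, yields a zero gradient.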

This doesn't seem to work. I'd also like this to be dynamic (maybe I have to set a hook instead?), i.e. when I change frozen from False to True, the Parameter's value should become non-differentiable, something like this:

```python
def fun(param: Parameter):
  ...
  param.frozen = True  # essentially `jax.lax.stop_gradient` for `param.value`
  ...
```

As I understand it, this should be possible in principle, because .frozen is just another metadata attribute that is known at trace/compile time (like .shape) and is independent of the data (the numeric values of the array).
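
To make the trace-time argument concrete, here is a minimal pure-JAX sketch (no nnx; the helpers `maybe_frozen` and `loss` are hypothetical) in which the freeze flag is a plain Python bool that is resolved while tracing, much like `.shape`:

```python
import jax
import jax.numpy as jnp

def maybe_frozen(x, frozen: bool):
    # `frozen` is a plain Python bool, so this branch is decided while tracing,
    # not from the numeric value of x
    return jax.lax.stop_gradient(x) if frozen else x

def loss(x, frozen: bool):
    return jnp.sin(maybe_frozen(x, frozen))

# jax.grad differentiates w.r.t. argument 0 only; `frozen` is passed through as-is
grad_frozen = jax.grad(loss)(0.0, True)   # gradient blocked
grad_live = jax.grad(loss)(0.0, False)    # cos(0)

# Under jit the flag must be static, so each flag value gets its own trace
jit_grad = jax.jit(jax.grad(loss), static_argnums=1)
```

Because the flag is static, flipping it under `jax.jit` triggers a retrace rather than a data-dependent branch.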

Is this possible to achieve somehow? Also, does this change (and if so, how) in the context of the new hijax approach?

Thanks, Peter