Description
Hi,
I'm trying to use orbax to checkpoint arbitrary penzai models. Their parameters are of type pz.ParameterValue, which in turn contain pz.NamedArray instances, and those carry named axes. So I thought I'd try implementing a class derived from type_handlers.TypeHandler. However, I can't see where in this workflow the axis names would be stored. There is a TypeHandler.metadata method, but it seems to be called only during restore. And TypeHandler.serialize doesn't seem to offer any way to customize behavior except at a very low level, the tensorstore level.
Am I missing something else? It would be nice to be able to use orbax with penzai models.
On the penzai side, the docs say that orbax can be used, but there is no example of saving/loading an arbitrary model.
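One possible workaround, sketched here without testing it against the real libraries, is to avoid a custom TypeHandler altogether. The idea is to split each named array into its raw data (which a standard array handler can already save) and its axis names (which are plain strings and can go into a separate JSON item), then re-attach the names after restore. `HypotheticalNamedArray`, `split_named_arrays` and `merge_named_arrays` are hypothetical names made up for this sketch, not penzai or orbax APIs. The sketch also assumes the parameters have been flattened into a dict keyed by parameter path, whereas a real penzai model is a nested pytree.

```python
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class HypotheticalNamedArray:
    """Hypothetical stand-in for pz.nx.NamedArray (NOT penzai's real class)."""
    data: Any                    # in practice a jax/numpy array
    axis_names: tuple[str, ...]  # one name per axis of `data` (assumption)


def split_named_arrays(params: dict[str, HypotheticalNamedArray]):
    """Separate raw arrays from axis-name metadata.

    Returns (arrays, names): `arrays` holds only plain array leaves that an
    ordinary array checkpoint handler could save; `names` is JSON-serializable.
    """
    arrays = {key: p.data for key, p in params.items()}
    names = {key: list(p.axis_names) for key, p in params.items()}
    return arrays, names


def merge_named_arrays(arrays: dict[str, Any], names: dict[str, list[str]]):
    """Rebuild named arrays from restored raw arrays plus the stored names."""
    return {
        key: HypotheticalNamedArray(arrays[key], tuple(names[key]))
        for key in arrays
    }


# Round trip: the names survive a JSON encode/decode, as they would if saved
# as a separate JSON item next to the array checkpoint.
params = {"linear.weights": HypotheticalNamedArray([[1.0, 2.0]], ("features", "out"))}
arrays, names = split_named_arrays(params)
restored = merge_named_arrays(arrays, json.loads(json.dumps(names)))
assert restored == params
```

With orbax, the two halves could then presumably be written into the same checkpoint step as separate items, e.g. `arrays` via `ocp.args.StandardSave` and `names` via `ocp.args.JsonSave` inside an `ocp.args.Composite`, if your orbax version provides those. The trade-off is that the names live outside the array metadata, so they must be kept in sync with the array tree by hand.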