This module provides an interface between JAX and Pint that lets JAX operate on quantities with units. Units are propagated at trace time, so jitted functions should see no runtime cost from the unit bookkeeping itself. This library is experimental, so expect some sharp edges.
For example:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import jpu
>>>
>>> u = jpu.UnitRegistry()
>>>
>>> @jax.jit
... def add_two_lengths(a, b):
...     return a + b
...
>>> add_two_lengths(3 * u.m, jnp.array([4.5, 1.2, 3.9]) * u.cm)
<Quantity([3.045 3.012 3.039], 'meter')>
```

To install, use pip:
```bash
python -m pip install jpu
```

The only dependencies are jax and pint, and these will also be installed if they are not already in your environment. Take a look at the JAX docs for more information about installing JAX on different systems.
Here is a slightly more complete example:
```python
>>> import jax
>>> import numpy as np
>>> from jpu import UnitRegistry, numpy as jnpu
>>>
>>> u = UnitRegistry()
>>>
>>> @jax.jit
... def projectile_motion(v_init, theta, time, g=u.standard_gravity):
...     """Compute the motion of a projectile with support for units"""
...     x = v_init * time * jnpu.cos(theta)
...     y = v_init * time * jnpu.sin(theta) - 0.5 * g * jnpu.square(time)
...     return x.to(u.m), y.to(u.m)
...
>>> x, y = projectile_motion(
...     5.0 * u.km / u.h, 60 * u.deg, np.linspace(0, 1, 50) * u.s
... )
```

The most significant limitation of this library is that users must call `jpu.numpy` functions, rather than the `jax.numpy` interface, when operating on quantities with units. This is because JAX does not (yet?) provide a general interface for dispatching ufuncs on custom array classes. I have experimented with the undocumented `__jax_array__` interface, but it isn't flexible enough for this purpose, and it isn't currently compatible with Pytree objects.
So far, only a subset of the numpy/jax.numpy interface is implemented. Pull
requests adding broader support (including submodules) would be welcome!