Add a tqdm progress bar to your JAX scans and loops.
Install with pip:
pip install jax-tqdm

Use scan_tqdm to decorate the body function of a lax.scan:

from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp
n = 10_000
@scan_tqdm(n)
def step(carry, x):
    return carry + 1, carry + 1
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

Use loop_tqdm to decorate the body function of a lax.fori_loop:

from jax_tqdm import loop_tqdm
from jax import lax
n = 10_000
@loop_tqdm(n)
def step(i, val):
    return val + 1
last_number = lax.fori_loop(0, n, step, 0)

By default, the progress bar is updated 20 times over the course of the scan or loop
(for performance reasons, see below). This update rate can be controlled manually
with the print_rate keyword argument. For example:
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp
n = 10_000
@scan_tqdm(n, print_rate=2)
def step(carry, x):
    return carry + 1, carry + 1
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

This will update the progress bar every other step.
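
To make the relationship between print_rate and the number of refreshes concrete, here is a small pure-Python sketch. It is not jax-tqdm's code: the helper name `update_steps` is hypothetical, and it assumes that the bar refreshes whenever the iteration index is a multiple of print_rate and that the default rate works out to roughly n // 20.

```python
def update_steps(n, print_rate=None):
    """Return the iteration indices at which the bar would refresh.

    Hypothetical helper for illustration only; the refresh rule and the
    default rate are assumptions, not taken from the library.
    """
    if print_rate is None:
        # Assumed default: about 20 refreshes over the whole run.
        print_rate = max(1, n // 20)
    return [i for i in range(n) if i % print_rate == 0]


default_steps = update_steps(10_000)
every_other = update_steps(10_000, print_rate=2)

print(len(default_steps))  # 20 refreshes with the assumed default
print(len(every_other))    # 5000 refreshes with print_rate=2
```

Under these assumptions, a smaller print_rate means more frequent refreshes and therefore more host round trips during the computation.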
Any additional keyword arguments are passed to the tqdm progress bar constructor. For example:
from jax_tqdm import scan_tqdm
from jax import lax
import jax.numpy as jnp
n = 10_000
@scan_tqdm(n, print_rate=1, desc='progress bar', position=0, leave=False)
def step(carry, x):
    return carry + 1, carry + 1
last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))

JAX functions are pure, so side effects such as printing progress while running scans and loops are not allowed. However, the host_callback module has primitives for calling Python functions on the host from JAX code. These can be used to update a Python tqdm progress bar at regular intervals during the computation. JAX-tqdm implements this for JAX scans and loops; to use it, add a decorator to the body function of your scan or loop.
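
The general idea of calling back to the host from inside a scan can be sketched as follows. This is a minimal illustration, not jax-tqdm's implementation: it uses `jax.debug.callback` instead of the host_callback module, it records iteration numbers in a list rather than driving a real tqdm bar, and the names `with_progress`, `report`, and `reported` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 100
print_rate = 10   # hypothetical: report every 10th iteration
reported = []     # host-side record of the iterations that were reported


def report(i):
    # Runs on the host as ordinary Python. A real progress bar would
    # update itself here instead of appending to a list.
    reported.append(int(i))


def with_progress(body):
    # Hypothetical decorator for a lax.scan body whose second argument
    # is the iteration index.
    def wrapped(carry, i):
        # Only call back to the host on every print_rate-th iteration.
        lax.cond(
            i % print_rate == 0,
            lambda: jax.debug.callback(report, i),
            lambda: None,
        )
        return body(carry, i)

    return wrapped


@with_progress
def step(carry, x):
    return carry + 1, carry + 1


last_number, all_numbers = lax.scan(step, 0, jnp.arange(n))
jax.effects_barrier()  # wait until all pending host callbacks have run
```

Because the callback sits inside `lax.cond`, the host is contacted only on the selected iterations; all other iterations stay entirely on the device.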
Note that because the tqdm progress bar is updated only 20 times during the scan or loop by default, the performance penalty is negligible.
The code is explained in more detail in this blog post.
Dependencies can be installed with poetry by running

poetry install

Pre-commit hooks can be installed by running

pre-commit install

Pre-commit checks can then be run using

task lint

Tests can be run with

task test