# Runtime value debugging in JAX

Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page gives a TL;DR summary of each tool; follow the "Read more" links at the bottom for the full guides.

Table of contents:

* [Interactive inspection with `jax.debug`](print_breakpoint)
* [Functional error checks with `jax.experimental.checkify`](checkify_guide)
* [Throwing Python errors with JAX's debug flags](flags)

## [Interactive inspection with `jax.debug`](print_breakpoint)

  **TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-, `jax.pmap`-, and `pjit`-decorated functions,
  and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack:

  ```python
  import jax
  import jax.numpy as jnp

  @jax.jit
  def f(x):
    jax.debug.print("🤯 {x} 🤯", x=x)
    y = jnp.sin(x)
    jax.debug.breakpoint()
    jax.debug.print("🤯 {y} 🤯", y=y)
    return y

  f(2.)
  # Prints:
  # 🤯 2.0 🤯
  # Enters breakpoint to inspect values!
  # 🤯 0.9092974662780762 🤯
  ```
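
To see why `jax.debug.print` is needed, compare it with Python's built-in `print` inside a jitted function. The sketch below uses a hypothetical function `g` and relies on standard `jax.jit` behavior: ordinary Python code runs only while JAX traces the function, so `print` sees an abstract tracer and fires only when the function is (re)traced, whereas `jax.debug.print` is staged into the compiled computation and prints the concrete value on every call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x):
  # Runs only at trace time: `x` is an abstract tracer here, not a number.
  print("print:", x)
  # Staged into the compiled program: prints the runtime value on every call.
  jax.debug.print("jax.debug.print: {x}", x=x)
  return jnp.cos(x)

g(1.0)  # traces and compiles: both lines print (the first shows a tracer)
g(3.0)  # same shape and dtype, so the cached compilation is reused and
        # only the jax.debug.print line appears
```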

Click [here](print_breakpoint) to learn more!

## [Functional error checks with `jax.experimental.checkify`](checkify_guide)

  **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:

  ```python
  from jax.experimental import checkify
  import jax
  import jax.numpy as jnp

  def f(x, i):
    checkify.check(i >= 0, "index needs to be non-negative!")
    y = x[i]
    z = jnp.sin(y)
    return z

  jittable_f = checkify.checkify(f)

  err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
  print(err.get())
  # >> index needs to be non-negative! (check failed at <...>:6 (f))
  ```

  You can also use checkify to automatically add common checks:

  ```python
  errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
  checked_f = checkify.checkify(f, errors=errors)

  err, z = checked_f(jnp.ones((5,)), 100)
  err.throw()
  # ValueError: out-of-bounds indexing at <...>:7 (f)

  err, z = checked_f(jnp.ones((5,)), -1)
  err.throw()
  # ValueError: index needs to be non-negative! (check failed at <...>:6 (f))

  err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
  err.throw()
  # ValueError: nan generated by primitive sin at <...>:8 (f)
  ```
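
When no check fails, the returned error is empty: `err.get()` returns `None` and `err.throw()` does nothing. The sketch below shows both outcomes for a jitted, checkified function. It uses a hypothetical `safe_log` function whose check predicate depends on array values (a traced boolean scalar) rather than on a Python integer.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
  # The predicate is a traced boolean scalar, so the check also works under jit.
  checkify.check(jnp.all(x > 0), "log input must be positive!")
  return jnp.log(x)

checked_log = jax.jit(checkify.checkify(safe_log))

err_ok, y_ok = checked_log(jnp.array([1.0, 2.0]))
print(err_ok.get())  # None: no check failed
err_ok.throw()       # no-op when there is no error

err_bad, y_bad = checked_log(jnp.array([1.0, -2.0]))
print(err_bad.get())  # the check's message plus the location where it failed
```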

Click [here](checkify_guide) to learn more!

## [Throwing Python errors with JAX's debug flags](flags)

**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap`- or `pjit`-compiled code). Enable the `jax_disable_jit` flag to turn off JIT compilation, so that you can use traditional Python debugging tools like `print` and `pdb`.

```python
import jax

jax.config.update("jax_debug_nans", True)

def f(x, y):
  return x / y

jax.jit(f)(0., 0.)  # ==> raises FloatingPointError exception!
```
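
The `jax_disable_jit` flag can be set globally with `jax.config.update("jax_disable_jit", True)`. The sketch below instead uses the `jax.disable_jit()` context manager to turn compilation off for a single call. Inside the context, the hypothetical function `h` runs eagerly, op by op, so Python's `print` shows a concrete value and you could drop into `pdb` (for example with `breakpoint()`) at the same spot.

```python
import jax
import jax.numpy as jnp

@jax.jit
def h(x):
  y = jnp.sin(x)
  # With JIT disabled, `y` is a concrete array rather than a tracer, so plain
  # Python tools (print, pdb) see real values.
  print("y =", y)
  return 2 * y

with jax.disable_jit():
  out = h(2.0)  # executes eagerly and prints the concrete value of sin(2.0)
```

Eager execution is much slower than compiled code, so scoping it with the context manager keeps the rest of the program compiled.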

Click [here](flags) to learn more!

```{toctree}
:caption: Read more
:maxdepth: 1

print_breakpoint
checkify_guide
flags
```

