Hi Kevin,
I find the notes very helpful! For the gradient of the bias term, shouldn't d_b1 = jnp.mean(d_h2, axis=1) be d_b1 = jnp.sum(d_h2, axis=1), since the bias feeds into every datapoint in the batch and we have to sum the gradient contributions over all of those paths? I also verified in Colab that replacing mean with sum gives jnp.allclose(jax_grad_output[1], manual_grad_output[0][1]) == True :)
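In case it helps anyone reproduce the check, here is a minimal sketch of the same comparison. It is not the code from the notes: it assumes a single linear layer with datapoints stored as columns (so the batch is axis 1), a mean-squared-error loss, and made-up names (loss_fn, d_h, db_sum, db_mean). Any 1/N factor from averaging the loss is already contained in the upstream gradient d_h, so the bias gradient only needs a sum over the batch axis.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy layer: h = W @ x + b, with datapoints as columns (batch on axis 1).
def loss_fn(params, x, y):
    W, b = params
    h = W @ x + b[:, None]            # shape (out_features, batch)
    return jnp.mean((h - y) ** 2)     # averaged over all entries of h

k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
W = jax.random.normal(k1, (3, 4))
b = jax.random.normal(k2, (3,))
x = jax.random.normal(k3, (4, 5))     # 5 datapoints
y = jax.random.normal(k4, (3, 5))

# Reference gradient from autodiff.
auto_dW, auto_db = jax.grad(loss_fn)((W, b), x, y)

# Manual backward pass: upstream gradient dL/dh already carries the 1/size factor.
h = W @ x + b[:, None]
d_h = 2.0 * (h - y) / h.size

db_sum = jnp.sum(d_h, axis=1)         # sum over the batch axis
db_mean = jnp.mean(d_h, axis=1)       # divides by the batch size a second time

sum_matches = bool(jnp.allclose(auto_db, db_sum, atol=1e-6))
mean_matches = bool(jnp.allclose(auto_db, db_mean, atol=1e-6))
print("sum matches autodiff: ", sum_matches)
print("mean matches autodiff:", mean_matches)
```

In this setup the mean version is too small by a factor of the batch size, which is the kind of mismatch I saw before switching to sum.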