Soft-DTW implementation in JAX with custom gradient
I find the implementations using jax.lax.scan very interesting (see this). However, none of the available implementations actually follows Algorithm 2 (the backward recursion that computes the gradient) in "Soft-DTW: a Differentiable Loss Function for Time-Series".
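For context, here is a minimal sketch of the scan-based approach: the forward recursion r_{i,j} = δ_{i,j} + softmin_γ(r_{i-1,j-1}, r_{i-1,j}, r_{i,j-1}) written with nested jax.lax.scan, with the gradient obtained by ordinary autodiff through the scan rather than by Algorithm 2. The row-by-row scan order, the +inf boundary padding and the names `softmin`, `soft_dtw_forward` and `squared_euclidean` are assumptions for illustration; the implementations referenced above may, for example, scan over anti-diagonals instead.

```python
import jax
import jax.numpy as jnp


def softmin(a, b, c, gamma):
    # Smoothed minimum: -gamma * log(exp(-a/gamma) + exp(-b/gamma) + exp(-c/gamma)).
    return -gamma * jax.nn.logsumexp(jnp.stack([a, b, c]) / -gamma)


def soft_dtw_forward(D, gamma):
    """Soft-DTW value r_{n,m} for an (n, m) cost matrix D (hypothetical helper)."""
    n, m = D.shape
    inf = jnp.array(jnp.inf, dtype=D.dtype)
    # Boundary row: R[0, 0] = 0 and R[0, j] = +inf for j >= 1.
    row0 = jnp.full((m + 1,), jnp.inf, dtype=D.dtype).at[0].set(0.0)

    def row_step(prev_row, d_row):
        # prev_row holds R[i-1, 0..m]; d_row holds D[i-1, 0..m-1].
        def col_step(left, inputs):
            diag, up, d = inputs  # R[i-1, j-1], R[i-1, j], D[i-1, j-1]
            r = d + softmin(diag, up, left, gamma)
            return r, r

        # The inner scan carries R[i, j-1], starting from R[i, 0] = +inf.
        _, cols = jax.lax.scan(col_step, inf, (prev_row[:-1], prev_row[1:], d_row))
        new_row = jnp.concatenate([inf[None], cols])
        return new_row, None

    last_row, _ = jax.lax.scan(row_step, row0, D)
    return last_row[-1]


def squared_euclidean(x, y):
    # Cost matrix delta(x_i, y_j) = ||x_i - y_j||^2 for x: (n, p), y: (m, p).
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


# Usage: the gradient w.r.t. x comes from autodiff through both scans.
x = jnp.array([[0.0], [1.0], [2.0]])
y = jnp.array([[0.0], [2.0]])
loss = lambda x: soft_dtw_forward(squared_euclidean(x, y), gamma=0.01)
value, grad_x = jax.value_and_grad(loss)(x)
```

With a small γ the value approaches the hard DTW cost of this pair (1.0). Autodiff here has to store every intermediate of the scans for the reverse pass, which is what a dedicated backward recursion avoids.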
This small repository provides an implementation of the custom gradient for Soft-DTW in JAX:
- the custom gradient is registered with jax.custom_vjp (see more in the JAX docs);
- Algorithm 2 of the Soft-DTW paper is implemented with jax.lax.scan.
See the notebook barycenter.ipynb for a demonstration and a small performance comparison.