anh-tong/soft-dtw-jax

SoftDTW in JAX

Soft-DTW implementation in JAX with a custom gradient

Why implement SoftDTW again?

I find the implementations using jax.lax.scan very interesting (see this). However, none of the available implementations actually follows Algorithm 2 in Soft-DTW: a Differentiable Loss Function for Time-Series.
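
For reference, the two recursions involved can be summarized as follows. This is our transcription of Algorithms 1 and 2 of the paper, with δ_{i,j} = δ(x_i, y_j) the pairwise cost between x of length n and y of length m. Algorithm 2 runs backwards over the matrix R produced by the forward pass and returns E, the gradient of dtw_γ with respect to the cost matrix Δ = [δ_{i,j}].

```latex
% Forward recursion (Algorithm 1)
\begin{aligned}
\mathrm{softmin}_\gamma(a_1, \dots, a_k) &= -\gamma \log \sum_{l=1}^{k} e^{-a_l/\gamma} \\
r_{0,0} = 0, \qquad r_{i,0} &= r_{0,j} = +\infty \\
r_{i,j} &= \delta_{i,j} + \mathrm{softmin}_\gamma\big(r_{i-1,j-1},\, r_{i-1,j},\, r_{i,j-1}\big),
\qquad \mathrm{dtw}_\gamma(x, y) = r_{n,m}
\end{aligned}

% Backward recursion (Algorithm 2), for i = n, ..., 1 and j = m, ..., 1, with borders
% \delta_{i,m+1} = \delta_{n+1,j} = 0, e_{i,m+1} = e_{n+1,j} = 0, r_{i,m+1} = r_{n+1,j} = -\infty,
% \delta_{n+1,m+1} = 0, e_{n+1,m+1} = 1, r_{n+1,m+1} = r_{n,m}
\begin{aligned}
a_{i,j} &= \exp\big((r_{i+1,j} - r_{i,j} - \delta_{i+1,j})/\gamma\big) \\
b_{i,j} &= \exp\big((r_{i,j+1} - r_{i,j} - \delta_{i,j+1})/\gamma\big) \\
c_{i,j} &= \exp\big((r_{i+1,j+1} - r_{i,j} - \delta_{i+1,j+1})/\gamma\big) \\
e_{i,j} &= e_{i+1,j}\, a_{i,j} + e_{i,j+1}\, b_{i,j} + e_{i+1,j+1}\, c_{i,j} \\
\nabla_x \mathrm{dtw}_\gamma(x, y) &= \Big(\frac{\partial \Delta}{\partial x}\Big)^{\top} E, \qquad E = [e_{i,j}]_{i,j}
\end{aligned}
```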

This small repository provides an implementation of Soft-DTW in JAX with a custom gradient. It is implemented using:

  • jax.custom_vjp (see more in JAX docs)
  • jax.lax.scan, which carries out Algorithm 2 of the Soft-DTW paper (the backward recursion that computes the gradient)

See the notebook barycenter.ipynb for a demonstration and a small performance comparison.
