ssax is a high-performance JAX implementation of the Sinkhorn Step—a batch gradient-free optimizer for highly non-convex objectives.
It leverages the power of OTT-JAX for robust and differentiable optimal transport solvers, enabling efficient exploration of complex landscapes.
Note: This repository provides the JAX implementation of the Sinkhorn Step as a standalone optimizer. For the MPOT trajectory optimizer (PyTorch version), please visit mpot.
This work was accepted at NeurIPS 2023: *Accelerating Motion Planning via Optimal Transport*.
- Clone the repository:

```shell
git clone https://github.com/anindex/ssax.git
cd ssax
```

- Install dependencies (a fresh conda environment is recommended):

```shell
pip install -e .
```

- GPU support: to run on GPU, install the CUDA-enabled build of JAX manually (check the JAX Installation Guide for your specific CUDA version):

```shell
pip install -U "jax[cuda13]"
```
Run a simple optimization on the Ackley function:

```shell
# Run with default settings (Ackley 2D)
python scripts/example.py --objective Ackley --num_points 1000

# Try a different function
python scripts/example.py --objective Rosenbrock --num_points 5000 --max_iters 50
```

Explore the landscape of the available synthetic test functions:
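For reference, the 2D Ackley objective used above can be written in a few lines. This is a pure-NumPy sketch for illustration; the definitions shipped with ssax may differ in details:

```python
import numpy as np

def ackley(x, a=20.0, b=0.2, c=2.0 * np.pi):
    """Ackley function; x has shape (..., d). Global minimum f(0) = 0."""
    x = np.asarray(x, dtype=float)
    d = x.shape[-1]
    term1 = -a * np.exp(-b * np.sqrt(np.sum(x**2, axis=-1) / d))
    term2 = -np.exp(np.sum(np.cos(c * x), axis=-1) / d)
    return term1 + term2 + a + np.e

print(ackley(np.zeros(2)))  # ~0 at the global minimum
```

The many shallow local minima induced by the cosine term are what make this a standard stress test for gradient-free optimizers.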
```shell
# View a gallery of all functions
python scripts/visualize_objectives.py

# Inspect a specific function in 3D
python scripts/visualize_objectives.py --name Rastrigin
```

Use Hydra to run configured experiments. Results (plots/animations) are saved in the logs/ directory.
```shell
python scripts/run.py experiment=ss-al
```

Available experiments:

| Config Name | Function | Dim |
|---|---|---|
| ss-al | Ackley | 2D |
| ss-al-10d | Ackley | 10D |
| ss-bk | Bukin | 2D |
| ss-eh | EggHolder | 2D |
| ss-rb | Rosenbrock | 2D |
| ss-st | Styblinski-Tang | 2D |
Hyperparameter Tuning: The critical parameters are `step_radius`, `probe_radius`, and `ent_epsilon`. You can override them from the command line:

```shell
python scripts/run.py experiment=ss-al optimizer.kwargs.step_radius=0.2
```
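To see how these three knobs interact, here is a minimal, self-contained NumPy sketch of a Sinkhorn-Step-like update (illustrative only; ssax builds its probe set from a regular polytope and uses OTT-JAX solvers, and its actual API differs). Each particle probes the objective at `probe_radius` along a set of directions, an entropic OT plan with regularization `ent_epsilon` couples particles to directions by cost, and each particle moves `step_radius` along its barycentric direction:

```python
import numpy as np

def _logsumexp(z, axis):
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))

def sinkhorn_plan(C, eps, n_iters=100):
    """Entropic OT plan with uniform marginals (log-domain Sinkhorn for stability)."""
    n, k = C.shape
    loga, logb = np.full(n, -np.log(n)), np.full(k, -np.log(k))
    f, g = np.zeros(n), np.zeros(k)
    for _ in range(n_iters):
        f = eps * (loga - _logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (logb - _logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)

def sinkhorn_step(x, objective, step_radius=0.1, probe_radius=0.2, ent_epsilon=0.05):
    """One gradient-free update: probe around each particle, move along the OT barycenter."""
    n, d = x.shape
    dirs = np.concatenate([np.eye(d), -np.eye(d)])            # toy +/- axis probe set
    probes = x[:, None, :] + probe_radius * dirs[None, :, :]  # (n, 2d, d) probe points
    C = objective(probes)                                     # cost = objective at probes
    P = sinkhorn_plan(C, eps=ent_epsilon)                     # (n, 2d) transport plan
    bary = (P / P.sum(axis=1, keepdims=True)) @ dirs          # barycentric step direction
    return x + step_radius * bary

# Demo: drive a swarm of particles toward the minimum of a simple quadratic.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2)) * 3.0
f = lambda p: np.sum(p**2, axis=-1)
init_cost = f(x).mean()
for _ in range(100):
    x = sinkhorn_step(x, f)
print(init_cost, f(x).mean())  # mean cost should drop substantially
```

Intuitively, a larger `step_radius` moves particles faster but overshoots near minima, `probe_radius` sets the scale at which the landscape is sampled, and a larger `ent_epsilon` smooths the plan so the step direction averages over more probe directions.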
We evaluate the quality of the Sinkhorn Step direction compared to the true gradient (cosine similarity).
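Here cosine similarity is the usual normalized inner product between each particle's step direction and the gradient-based reference direction. A minimal NumPy helper, shown for illustration only (the benchmark scripts compute this internally):

```python
import numpy as np

def batch_cosine_similarity(steps, refs, eps=1e-12):
    """Row-wise cosine similarity between step directions and reference directions."""
    num = np.sum(steps * refs, axis=-1)
    den = np.linalg.norm(steps, axis=-1) * np.linalg.norm(refs, axis=-1)
    return num / (den + eps)

steps = np.array([[1.0, 0.0], [1.0, 1.0]])
refs = np.array([[2.0, 0.0], [-1.0, -1.0]])
print(batch_cosine_similarity(steps, refs))  # approximately [1., -1.]
```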
Run single-seed analysis (violin/box plots):

```shell
python scripts/benchmark_cosin_similarity_single.py experiment=ss-st-cosin-sim num_seeds=20
```

Run multi-epsilon sweep (line plots):

```shell
python scripts/benchmark_cosin_similarity.py experiment=ss-st-cosin-sim num_seeds=5
```

If you found this work useful, please consider citing this reference:
```bibtex
@article{le2023accelerating,
  title={Accelerating motion planning via optimal transport},
  author={Le, An T and Chalvatzaki, Georgia and Biess, Armin and Peters, Jan R},
  journal={Advances in Neural Information Processing Systems},
  volume={36},
  pages={78453--78482},
  year={2023}
}
```

- OTT-JAX: The backbone for our differentiable linear solvers.
- SFU Optimization Library: Source for synthetic test function definitions.