MoePien/sliced_OT_for_GMMs

Sliced Wasserstein Distances for Gaussian Mixture Models

This repository provides PyTorch-based implementations of sliced Wasserstein-type distances between Gaussian Mixture Models (GMMs). These metrics are useful for comparing distributions in high-dimensional spaces. Check out our preprint (Piening and Beinert, 2025) for details! Our main reference is Delon and Desolneux (2021), which provides additional background.
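For orientation, the distance being sliced is the mixture Wasserstein distance of Delon and Desolneux (2021). For GMMs $\mu_0 = \sum_k \pi_0^k \mu_0^k$ and $\mu_1 = \sum_l \pi_1^l \mu_1^l$ with Gaussian components, it restricts optimal transport to couplings $w$ between the component weights, so the cost between two components is the closed-form Gaussian $W_2^2$. The notation below is our own, not taken from the repository:

```latex
\mathrm{MW}_2^2(\mu_0,\mu_1) = \min_{w \in \Pi(\pi_0,\pi_1)} \sum_{k=1}^{K_0}\sum_{l=1}^{K_1} w_{kl}\, W_2^2(\mu_0^k,\mu_1^l),
\qquad
W_2^2\big(\mathcal{N}(m_0,\Sigma_0),\mathcal{N}(m_1,\Sigma_1)\big) = \|m_0-m_1\|^2 + \operatorname{tr}\Big(\Sigma_0+\Sigma_1-2\big(\Sigma_0^{1/2}\Sigma_1\Sigma_0^{1/2}\big)^{1/2}\Big)
```

Here $\Pi(\pi_0,\pi_1)$ is the set of nonnegative $K_0 \times K_1$ matrices with row sums $\pi_0$ and column sums $\pi_1$. In one dimension, with standard deviations $\sigma_0, \sigma_1$, the Gaussian term reduces to $(m_0-m_1)^2 + (\sigma_0-\sigma_1)^2$, which is what makes slicing cheap.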

Features

  • 📐 MSW (Mixture Sliced Wasserstein): (Partially) Sliced Wasserstein-type distance for high-dimensional GMMs.
  • 📏 DSMW (Double Sliced Mixture Wasserstein) or SMSW (Sliced Mixture Sliced Wasserstein): Fully Sliced Wasserstein-type distance for high-dimensional GMMs with many components.
  • Parallel & Vectorized Implementations: Fast, differentiable versions of DSMW (SMSW) and MSW.
  • 🧮 Support for Full and Diagonal Covariances: Easily generate and manipulate random or structured GMMs.
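Both MSW and DSMW rest on slicing: a Gaussian $\mathcal{N}(m, \Sigma)$ projected onto a unit direction $\theta$ is the 1D Gaussian $\mathcal{N}(\theta^\top m, \theta^\top \Sigma \theta)$, and 1D Gaussians have a closed-form squared $W_2$. The NumPy sketch below uses our own hypothetical function names (not the repository's `project_mu` / `project_sigma`) and checks that slicing never increases the distance, because projection onto a unit vector is 1-Lipschitz.

```python
import numpy as np
from scipy.linalg import sqrtm


def project_gaussian(mean, cov, theta):
    """Push N(mean, cov) forward under x -> <theta, x>; return the 1D (mean, std)."""
    theta = theta / np.linalg.norm(theta)  # slicing uses unit directions
    return float(theta @ mean), float(np.sqrt(theta @ cov @ theta))


def w2_sq_gaussian_1d(m0, s0, m1, s1):
    """Closed-form squared W2 between N(m0, s0^2) and N(m1, s1^2)."""
    return (m0 - m1) ** 2 + (s0 - s1) ** 2


def w2_sq_gaussian(m0, c0, m1, c1):
    """Closed-form squared W2 between two full-covariance Gaussians."""
    r0 = sqrtm(c0)
    cross = sqrtm(r0 @ c1 @ r0)
    return float(np.sum((m0 - m1) ** 2) + np.trace(c0 + c1 - 2 * cross).real)


rng = np.random.default_rng(0)
A, B = rng.normal(size=(3, 3)), rng.normal(size=(3, 3))
c0, c1 = A @ A.T + np.eye(3), B @ B.T + np.eye(3)
m0, m1 = rng.normal(size=3), rng.normal(size=3)
theta = rng.normal(size=3)
sliced = w2_sq_gaussian_1d(*project_gaussian(m0, c0, theta), *project_gaussian(m1, c1, theta))
print(sliced <= w2_sq_gaussian(m0, c0, m1, c1))  # True: projection onto a unit vector is 1-Lipschitz
```

Averaging such 1D quantities over many random directions is what turns one expensive high-dimensional computation into many cheap 1D ones.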

Contents

Core Modules

  • sliced_mw.py:

    • Implements calc_MSW, calc_SMSW, calc_parallel_SMSW for sliced mixture Wasserstein distances
    • Projection utilities: project_mu, project_sigma, project_gmm_1d
  • gmm_utils.py:

    • GaussianMixtureModel class with sampling and conversion support
    • Utility functions for matrix operations: nearest_psd, get_cholesky, generate_random_covariances
  • ImageGMM.py:

    • Builds GMMs from MNIST images for use in GMMBarycenter_GMMQuantization.ipynb
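gmm_utils.py lists `nearest_psd` and `generate_random_covariances` among its matrix utilities. A common way to repair a covariance that has lost positive definiteness is to symmetrize it and clip its eigenvalues; clipping at zero gives the nearest PSD matrix in Frobenius norm. The sketch below clips at a small positive floor `eps` so that a Cholesky factorization succeeds afterwards. The names `nearest_psd_sketch` and `random_covariances_sketch` are hypothetical, and gmm_utils.py may implement these utilities differently.

```python
import numpy as np


def nearest_psd_sketch(mat, eps=1e-6):
    """Symmetrize, then clip eigenvalues at eps (hypothetical stand-in for nearest_psd)."""
    sym = 0.5 * (mat + mat.T)
    vals, vecs = np.linalg.eigh(sym)
    vals = np.clip(vals, eps, None)
    return (vecs * vals) @ vecs.T  # V diag(vals) V^T


def random_covariances_sketch(k, d, rng, jitter=1e-3):
    """Draw k random SPD d x d matrices as A A^T + jitter * I (hypothetical helper)."""
    a = rng.normal(size=(k, d, d))
    return a @ np.swapaxes(a, 1, 2) + jitter * np.eye(d)


bad = np.array([[1.0, 2.0], [2.0, 1.0]])  # eigenvalues 3 and -1, so not PSD
fixed = nearest_psd_sketch(bad)
chol = np.linalg.cholesky(fixed)  # would raise LinAlgError on `bad`
covs = random_covariances_sketch(5, 3, np.random.default_rng(0))
print(covs.shape)  # (5, 3, 3)
```

Matrices that are already positive definite with eigenvalues above `eps` pass through unchanged, so the repair is safe to apply after every update.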

Dependencies

  • torch
  • numpy
  • POT (Python Optimal Transport) – pip install POT
  • scipy, joblib, scikit-learn

Example Usage

import torch
import sliced_mw as smw
import gmm_utils as GMM

# Define two random 2D GMMs with 10 components each
K = 10
D = 2
gmm1 = GMM.RandomGaussianMixtureModel(K, D, device="cpu")
gmm2 = GMM.RandomGaussianMixtureModel(K, D, device="cpu")

# Compute sliced mixture Wasserstein distance
DSMW_distance = smw.calc_parallel_SMSW(gmm1, gmm2, pnum=1000) # SMSW=DSMW
print(f"SMSW Distance: {DSMW_distance.item():.4f}")
MSW_distance = smw.calc_MSW(gmm1, gmm2, pnum=1000)
print(f"MSW Distance: {MSW_distance.item():.4f}")
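The paper gives the precise definitions of MSW and DSMW. As an unofficial illustration of the general recipe, the sketch below projects both GMMs onto random unit directions, solves the small mixture Wasserstein problem between the projected 1D GMMs with `scipy.optimize.linprog`, and averages the results. This is one plausible sliced construction under our own assumptions (NumPy/SciPy instead of PyTorch, GMMs as `(weights, means, covs)` tuples, hypothetical function names); it is not claimed to reproduce `calc_MSW` or `calc_parallel_SMSW`.

```python
import numpy as np
from scipy.optimize import linprog


def mixture_w2_sq_1d(w0, m0, s0, w1, m1, s1):
    """Mixture Wasserstein (squared) between two 1D GMMs via a small linear program."""
    k0, k1 = len(w0), len(w1)
    # Closed-form squared W2 between every pair of 1D Gaussian components.
    cost = (m0[:, None] - m1[None, :]) ** 2 + (s0[:, None] - s1[None, :]) ** 2
    # Coupling is flattened row-major: entry (i, j) sits at index i * k1 + j.
    a_eq = np.zeros((k0 + k1, k0 * k1))
    for i in range(k0):
        a_eq[i, i * k1:(i + 1) * k1] = 1.0  # row sums equal w0
    for j in range(k1):
        a_eq[k0 + j, j::k1] = 1.0  # column sums equal w1
    b_eq = np.concatenate([w0, w1])
    res = linprog(cost.ravel(), A_eq=a_eq, b_eq=b_eq, bounds=(0, None), method="highs")
    return float(res.fun)


def sliced_mixture_w2_sq(gmm0, gmm1, n_dirs=200, seed=0):
    """Monte Carlo average of the 1D mixture distance over random unit directions."""
    rng = np.random.default_rng(seed)
    w0, mu0, c0 = gmm0
    w1, mu1, c1 = gmm1
    thetas = rng.normal(size=(n_dirs, mu0.shape[1]))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
    total = 0.0
    for th in thetas:
        ps0 = np.sqrt(np.einsum("d,kde,e->k", th, c0, th))  # projected stds
        ps1 = np.sqrt(np.einsum("d,kde,e->k", th, c1, th))
        total += mixture_w2_sq_1d(w0, mu0 @ th, ps0, w1, mu1 @ th, ps1)
    return total / n_dirs


rng = np.random.default_rng(0)
K, D = 3, 2
a = rng.normal(size=(K, D, D))
gmm = (np.full(K, 1.0 / K), rng.normal(size=(K, D)), a @ a.transpose(0, 2, 1) + 0.1 * np.eye(D))
shifted = (gmm[0], gmm[1] + 1.0, gmm[2])  # translate every mean by v = (1, 1)
print(sliced_mixture_w2_sq(gmm, gmm))      # close to 0 for identical GMMs
print(sliced_mixture_w2_sq(gmm, shifted))  # close to |v|^2 / D = 1
```

For a GMM and its translate by $v$, every direction $\theta$ contributes exactly $(\theta^\top v)^2$, so the second value approximates $\|v\|^2 / D$; this makes a handy sanity check for any sliced implementation.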

Experiments

  • ComputationTime.ipynb: Measure the computation time of the distances
  • ClusterDetection.ipynb: Detect the ideal number of GMM clusters
  • PerceptualGMMOT.ipynb: Use our sliced metrics for perceptual evaluation based on the framework presented in Luzi et al., 2023
  • GMMBarycenter_GMMQuantization.ipynb: Gradient descent on GMMs for barycenter and component reduction
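The barycenter notebook runs gradient descent on GMM parameters. As a much smaller toy of the same idea (our own example, not the notebook's code), gradient descent on the mean and standard deviation of a single 1D Gaussian, minimizing the weighted sum of squared $W_2$ distances to fixed 1D Gaussians, recovers the known closed-form barycenter $\mathcal{N}(\sum_i \lambda_i m_i, (\sum_i \lambda_i \sigma_i)^2)$. The function name `barycenter_gd_1d` is hypothetical.

```python
import numpy as np


def barycenter_gd_1d(lams, means, stds, lr=0.1, steps=500):
    """Gradient descent on F(m, s) = sum_i lam_i * ((m - m_i)^2 + (s - s_i)^2)."""
    m, s = 0.0, 1.0  # initial guess
    for _ in range(steps):
        grad_m = 2 * np.sum(lams * (m - means))
        grad_s = 2 * np.sum(lams * (s - stds))
        m -= lr * grad_m
        s -= lr * grad_s
    return m, s


lams = np.array([0.25, 0.75])
means, stds = np.array([-2.0, 2.0]), np.array([0.5, 1.5])
m, s = barycenter_gd_1d(lams, means, stds)
print(round(m, 6), round(s, 6))  # 1.0 1.25
```

With weights summing to one, each step contracts the error by a factor of 0.8, so 500 steps are far more than enough; for GMMs the objective is no longer this simple, which is where differentiable sliced distances help.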

Example of Quantization

Figure: density of the input 2D GMM (Input) and of the reduced 2D GMM (Reduced).

Notes

  • All distances are based on the squared Wasserstein-2 distance unless noted otherwise.
  • Parallelized variants (calc_parallel_*) are faster and fully differentiable.
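One reason batched variants can be faster is that all directions and all components can be projected with a single tensor contraction instead of nested loops. The NumPy sketch below (hypothetical helper names; in PyTorch, `torch.einsum` supports autograd, so a batched contraction of this kind also stays differentiable) checks that both approaches give identical projected means and standard deviations.

```python
import numpy as np


def project_gmm_loop(means, covs, thetas):
    """Project every component onto every direction with nested Python loops."""
    n, k = len(thetas), len(means)
    pm, ps = np.empty((n, k)), np.empty((n, k))
    for a, th in enumerate(thetas):
        for b in range(k):
            pm[a, b] = th @ means[b]
            ps[a, b] = np.sqrt(th @ covs[b] @ th)
    return pm, ps


def project_gmm_batched(means, covs, thetas):
    """Same projections as one contraction: N directions x K components -> (N, K)."""
    pm = thetas @ means.T
    ps = np.sqrt(np.einsum("nd,kde,ne->nk", thetas, covs, thetas))
    return pm, ps


rng = np.random.default_rng(0)
means = rng.normal(size=(4, 3))
a = rng.normal(size=(4, 3, 3))
covs = a @ a.transpose(0, 2, 1) + 0.1 * np.eye(3)
thetas = rng.normal(size=(50, 3))
thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
loop_m, loop_s = project_gmm_loop(means, covs, thetas)
batch_m, batch_s = project_gmm_batched(means, covs, thetas)
print(np.allclose(loop_m, batch_m), np.allclose(loop_s, batch_s))  # True True
```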

Citation

@article{piening2025slicing,
  title={Slicing the Gaussian Mixture Wasserstein Distance},
  author={Piening, Moritz and Beinert, Robert},
  journal={arXiv preprint arXiv:2504.08544},
  year={2025}
}

License

MIT License

About

"Slicing the Gaussian Mixture Wasserstein Distance" (TMLR)
