Update SPDBuresWassersteinMetric.squared_dist to fix numerical stability issues #1987
Checklist
Description
This is a response to Bug: Error due to complex number distance in Bures-Wasserstein metric, using pytorch backend #1882 (closes #1882)
I've been working with the Bures-Wasserstein metric for SPD matrices. The issue comes from the sqrtm function being used. The solution I have found is to use the fact that SPD matrices are orthogonally diagonalizable (because they are symmetric) and full rank (because they are positive definite) to calculate square roots via diagonalization. This both fixes the issue of imaginary outputs (even for matrices with spd.belongs(sigma)==True; see Bures_distance_check) and increases the speed. I also have a torch implementation that allows batch computation on the GPU (of space.metric.dist_pairwise), which greatly speeds up the computation, though it is not backend agnostic.
To be clear, this only works for SPD matrices. Bures-Wasserstein is technically a distance for symmetric positive semidefinite matrices, so my approach is not valid for the Bures metric in general, but it is correct (as far as I know) as the Riemannian metric for SPD matrices.
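For readers who want to see the idea concretely, here is a minimal NumPy sketch of computing the squared Bures-Wasserstein distance through eigendecomposition instead of sqrtm. It is not the code in this PR and not the geomstats implementation; the function names `sqrtm_spd` and `bures_wasserstein_squared_dist` are hypothetical, and clipping small negative eigenvalues to zero is an assumption made here to absorb round-off.

```python
import numpy as np


def sqrtm_spd(mat):
    """Square root of a symmetric positive definite matrix via eigh.

    Hypothetical helper: eigh returns real eigenvalues for symmetric
    input, so the result is real whenever the eigenvalues are positive.
    """
    eigvals, eigvecs = np.linalg.eigh(mat)
    # Scale each eigenvector column by sqrt(eigenvalue), then rotate back.
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def bures_wasserstein_squared_dist(a, b):
    """tr(A) + tr(B) - 2 tr((A^{1/2} B A^{1/2})^{1/2}) for SPD A and B."""
    a_sqrt = sqrtm_spd(a)
    cross = a_sqrt @ b @ a_sqrt
    # Symmetrize to remove round-off asymmetry before calling eigvalsh.
    cross = (cross + cross.T) / 2
    cross_eigvals = np.linalg.eigvalsh(cross)
    # Assumption: clip tiny negative eigenvalues caused by round-off.
    trace_sqrt = np.sum(np.sqrt(np.clip(cross_eigvals, 0.0, None)))
    return np.trace(a) + np.trace(b) - 2.0 * trace_sqrt


if __name__ == "__main__":
    a = np.diag([1.0, 4.0])
    b = np.diag([9.0, 16.0])
    print(bures_wasserstein_squared_dist(a, b))
```

For commuting matrices such as the two diagonal ones above, the squared distance reduces to the sum of squared differences of the square roots of the eigenvalues, which gives a quick sanity check.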
Issue
sqrtm in Bures-Wasserstein metric for spd's sometimes returns imaginary numbers
I wrote this code months ago. Now that I'm looking at it again, I think the
gs.sqrt(Lx*(Lx>0))
might actually be intended to work for non-SPD matrices, so this might not be completely correct.
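To make the concern concrete, the following sketch uses plain NumPy in place of the `gs` backend (an assumption for illustration) to show what the masking expression does to a vector of eigenvalues. The helper name `clamped_sqrt` is hypothetical.

```python
import numpy as np


def clamped_sqrt(eigvals):
    """Zero out non-positive eigenvalues, then take the square root."""
    return np.sqrt(eigvals * (eigvals > 0))


# One eigenvalue is slightly negative, as can happen from round-off.
eigvals = np.array([-1e-12, 1.0, 4.0])

# A plain square root produces nan for the negative entry.
with np.errstate(invalid="ignore"):
    naive = np.sqrt(eigvals)

clamped = clamped_sqrt(eigvals)

print(naive)
print(clamped)
```

The mask removes the nan, but it also silently discards any genuinely negative eigenvalue, which is why the expression behaves like a projection onto the positive semidefinite cone rather than a square root that is only meaningful for SPD input.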
Additional context
I'm not really sure how to make a pull request.
Reproducible example:
import numpy as np
import geomstats
from geomstats.geometry.spd_matrices import SPDMatrices, SPDBuresWassersteinMetric
spd = SPDMatrices(3)
spd.equip_with_metric(SPDBuresWassersteinMetric)
X = np.random.randn(3,1)
Sigma_x = X@X.T
Y = np.random.randn(3,1)
Sigma_y = Y@Y.T
#matrices not spd - should return false
print(spd.belongs(Sigma_x))
print(spd.belongs(Sigma_y))
Sigma_x_star = spd.projection(Sigma_x)
Sigma_y_star = spd.projection(Sigma_y)
#should return True
print(spd.belongs(Sigma_x_star))
print(spd.belongs(Sigma_y_star))
#will sometimes return imaginary numbers due to sqrtm
spd.metric.dist(Sigma_x_star, Sigma_y_star)
Solutions:
Accuracy check: https://github.com/mwilson221/dtmrpy/blob/main/Bures_distance_check.ipynb
Speed check: https://github.com/mwilson221/dtmrpy/blob/main/Bures_distance_speed_check.ipynb
Thanks