
Conversation

@mwilson221
Contributor

@mwilson221 mwilson221 commented Apr 6, 2024

Checklist

  • My pull request has a clear and explanatory title.
  • If necessary, my code is vectorized.
  • I added appropriate unit tests.
  • I made sure the code passes all unit tests. (refer to comment below)
  • My PR follows PEP8 guidelines. (refer to comment below)
  • My PR follows geomstats coding style and API.
  • My code is properly documented and I made sure the documentation renders properly. (Link)
  • I linked to issues and PRs that are relevant to this PR.

Description

This is a response to Bug: Error due to complex number distance in Bures-Wasserstein metric, using pytorch backend #1882 (closes #1882)

I've been working with Bures-Wasserstein for SPD matrices. The issue comes from the sqrtm function being used. The solution I have found is to use the fact that SPD matrices are orthogonally diagonalizable (because they are symmetric) and full rank with strictly positive eigenvalues (because they are positive definite), so their square roots can be computed via diagonalization. This both fixes the issue of imaginary outputs (which occur even for matrices with spd.belongs(sigma)==True; see Bures_distance_check) and increases the speed. I also have a torch implementation that allows batch computation of space.metric.dist_pairwise on the GPU, which greatly speeds up the computation, though it is not backend agnostic.

To be clear, this only works for spd-matrices. Bures-Wasserstein is technically a distance for symmetric positive semidefinite matrices. So, my approach is not actually valid for the Bures metric in general, but is correct (as far as I know) as the Riemannian metric for spd's.
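The diagonalization idea described above can be sketched in plain NumPy. This is a minimal illustration, not the code merged in this PR: the function names `sym_sqrtm` and `bw_squared_dist` are hypothetical, NumPy stands in for the geomstats backend, and the formula used is the standard Bures-Wasserstein expression tr(A) + tr(B) - 2 tr((A^{1/2} B A^{1/2})^{1/2}).

```python
import numpy as np


def sym_sqrtm(mat):
    """Square root of a symmetric PSD matrix via eigendecomposition.

    Hypothetical helper for illustration; not part of geomstats.
    """
    # eigh is meant for symmetric input and returns real eigenvalues.
    eigvals, eigvecs = np.linalg.eigh(mat)
    # Clamp round-off negatives so the square root stays real.
    eigvals = np.clip(eigvals, 0.0, None)
    # V @ diag(sqrt(w)) @ V.T, written with broadcasting.
    return (eigvecs * np.sqrt(eigvals)) @ eigvecs.T


def bw_squared_dist(a, b):
    """Squared Bures-Wasserstein distance between two SPD matrices."""
    sqrt_a = sym_sqrtm(a)
    cross = sqrt_a @ b @ sqrt_a
    # Symmetrize to remove round-off asymmetry before calling eigvalsh.
    cross = (cross + cross.T) / 2
    # The trace of a matrix square root is the sum of sqrt(eigenvalues).
    cross_eigvals = np.clip(np.linalg.eigvalsh(cross), 0.0, None)
    return np.trace(a) + np.trace(b) - 2.0 * np.sum(np.sqrt(cross_eigvals))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    m1 = rng.standard_normal((3, 3))
    m2 = rng.standard_normal((3, 3))
    # M @ M.T + I is symmetric positive definite.
    sigma_1 = m1 @ m1.T + np.eye(3)
    sigma_2 = m2 @ m2.T + np.eye(3)
    print(bw_squared_dist(sigma_1, sigma_2))
```

Because every intermediate matrix is symmetric, the eigenvalues are real by construction, which is why this route cannot produce complex output the way a general-purpose sqrtm can.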

Issue

sqrtm in Bures-Wasserstein metric for spd's sometimes returns imaginary numbers

I wrote this code months ago. Now that I'm looking at it again, I think the

gs.sqrt(Lx*(Lx>0))

might actually be intended to work for non-SPD matrices, so this might not be completely correct.
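To make the concern about the masked square root concrete, here is a small sketch of what the `Lx*(Lx>0)` pattern does to a vector of eigenvalues. NumPy stands in for the `gs` backend here, and `clamped_sqrt` is a hypothetical name used only for this illustration.

```python
import numpy as np


def clamped_sqrt(eigvals):
    """Square root of eigenvalues after zeroing out the non-positive ones."""
    # (eigvals > 0) is a boolean mask; multiplying by it maps every
    # negative entry to zero and leaves positive entries unchanged.
    return np.sqrt(eigvals * (eigvals > 0))


if __name__ == "__main__":
    # Round-off case: a tiny negative eigenvalue is treated as zero.
    print(clamped_sqrt(np.array([-1e-17, 0.0, 4.0])))
    # Indefinite case: a genuinely negative eigenvalue is silently dropped.
    print(clamped_sqrt(np.array([-1.0, 4.0])))
```

For SPD input the mask only absorbs round-off. For a matrix that is not positive semidefinite, it silently replaces the matrix with its projection onto the PSD cone instead of raising an error, which is the behavior questioned above.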

Additional context

I'm not really sure how to make a pull request.

Reproducible example:

import numpy as np
import geomstats
from geomstats.geometry.spd_matrices import SPDMatrices, SPDBuresWassersteinMetric

spd = SPDMatrices(3)
spd.equip_with_metric(SPDBuresWassersteinMetric)

X = np.random.randn(3, 1)
Sigma_x = X @ X.T

Y = np.random.randn(3, 1)
Sigma_y = Y @ Y.T

# rank-1 matrices are positive semidefinite but not SPD - should return False
print(spd.belongs(Sigma_x))
print(spd.belongs(Sigma_y))

Sigma_x_star = spd.projection(Sigma_x)
Sigma_y_star = spd.projection(Sigma_y)

# after projection onto the SPD manifold - should return True
print(spd.belongs(Sigma_x_star))
print(spd.belongs(Sigma_y_star))

# will sometimes return a complex value (nonzero imaginary part) due to sqrtm
print(spd.metric.dist(Sigma_x_star, Sigma_y_star))

Solutions:

Accuracy check: https://github.com/mwilson221/dtmrpy/blob/main/Bures_distance_check.ipynb
Speed check: https://github.com/mwilson221/dtmrpy/blob/main/Bures_distance_speed_check.ipynb
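As a rough, backend-specific illustration of the batching idea mentioned in the description, the sketch below computes all pairwise squared distances in one vectorized pass. It is an assumption-laden NumPy stand-in, not the torch implementation from the notebooks: `batched_sym_sqrtm` and `bw_squared_dist_pairwise` are hypothetical names, and it relies only on the fact that `np.linalg.eigh` and `np.linalg.eigvalsh` accept stacks of matrices.

```python
import numpy as np


def batched_sym_sqrtm(mats):
    """Square roots of a stack of symmetric PSD matrices, shape (..., n, n)."""
    eigvals, eigvecs = np.linalg.eigh(mats)
    eigvals = np.clip(eigvals, 0.0, None)
    # Scale each eigenvector column by sqrt(eigenvalue), then multiply by V^T.
    scaled = eigvecs * np.sqrt(eigvals)[..., None, :]
    return scaled @ np.swapaxes(eigvecs, -1, -2)


def bw_squared_dist_pairwise(mats):
    """Pairwise squared Bures-Wasserstein distances.

    mats has shape (m, n, n); the result has shape (m, m).
    """
    sqrt_mats = batched_sym_sqrtm(mats)
    traces = np.trace(mats, axis1=-2, axis2=-1)
    # cross[i, j] = sqrt(mats[i]) @ mats[j] @ sqrt(mats[i]) via broadcasting.
    cross = sqrt_mats[:, None] @ mats[None, :] @ sqrt_mats[:, None]
    cross = (cross + np.swapaxes(cross, -1, -2)) / 2
    cross_eigvals = np.clip(np.linalg.eigvalsh(cross), 0.0, None)
    cross_term = np.sqrt(cross_eigvals).sum(axis=-1)
    return traces[:, None] + traces[None, :] - 2.0 * cross_term


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    raw = rng.standard_normal((5, 3, 3))
    # Each raw[i] @ raw[i].T + I is symmetric positive definite.
    spd_batch = raw @ np.swapaxes(raw, -1, -2) + np.eye(3)
    print(bw_squared_dist_pairwise(spd_batch).shape)
```

The intermediate array has shape (m, m, n, n), so memory grows quadratically with the number of matrices; for large batches the pairs would need to be processed in chunks.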

Thanks

@mwilson221 mwilson221 changed the title from "Update spd_matrices.py" to "Update SPDBuresWassersteinMetric.squared_dist to fix numerical stability issues" Apr 6, 2024
@luisfpereira luisfpereira self-requested a review April 8, 2024 07:24
@luisfpereira
Collaborator

luisfpereira commented Apr 8, 2024

Very nice, thanks for the contribution @mwilson221! The notebooks with the benchmarks are spot on (we may need to make some changes in geomstats to make the (agnostic) use of gpu more straightforward).

I've fixed the linting and the vectorization. The more "controversial" change I've made is to add an abs to the output, because while testing I realized that the output is sometimes a tiny negative number, something like -1e-16, which creates nan when calling sqrt to get the distance.
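The effect of guarding the squared distance before taking the square root can be seen on a few hand-picked values. This is only an illustration in NumPy with made-up inputs; it compares the abs guard mentioned above with clipping at zero, which is an alternative shown here for contrast and not something proposed in this PR.

```python
import numpy as np

# Made-up squared distances: round-off negative, exact zero, ordinary value.
sq_dist = np.array([-1e-16, 0.0, 2.25])

# Unguarded: the negative entry becomes nan (warning suppressed for the demo).
with np.errstate(invalid="ignore"):
    naive = np.sqrt(sq_dist)

# Guard with abs: the negative entry is reflected to a small positive value.
with_abs = np.sqrt(np.abs(sq_dist))

# Guard with clip: the negative entry is mapped to exactly zero.
with_clip = np.sqrt(np.clip(sq_dist, 0.0, None))

print(naive)
print(with_abs)
print(with_clip)
```

Both guards agree on non-negative inputs. They differ only on negative ones: abs would also turn a large negative value (a sign of a real bug) into a plausible-looking distance, while clip would report zero.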

After running the tests several times, it looks like the code is quite stable now.

For the future, we should keep in mind:

To be clear, this only works for spd-matrices. Bures-Wasserstein is technically a distance for symmetric positive semidefinite matrices. So, my approach is not actually valid for the Bures metric in general, but is correct (as far as I know) as the Riemannian metric for spd's.

@luisfpereira luisfpereira merged commit bacdc1d into geomstats:main Apr 8, 2024
