Merged
106 commits
99129e9
first commit
SaitejaUtpala Sep 7, 2021
6140e6a
add aux methods
SaitejaUtpala Sep 7, 2021
e1bb569
clean
SaitejaUtpala Sep 7, 2021
c5ac1df
projection methods
SaitejaUtpala Sep 7, 2021
980335c
to from vecs
SaitejaUtpala Sep 7, 2021
6755c6f
tf tril_to_vec
SaitejaUtpala Sep 7, 2021
6372c49
add cholesky factor
SaitejaUtpala Sep 7, 2021
c5d2417
cleaning
SaitejaUtpala Sep 7, 2021
995b485
testing
SaitejaUtpala Sep 7, 2021
878e476
cholesky metric
SaitejaUtpala Sep 9, 2021
7e308c6
metric
SaitejaUtpala Sep 9, 2021
cbf57b0
R-exp and R-log
SaitejaUtpala Sep 9, 2021
ed89a0a
differential
SaitejaUtpala Sep 9, 2021
1ea9273
gram and its differential
SaitejaUtpala Sep 9, 2021
4e1345b
comments
SaitejaUtpala Sep 9, 2021
2782f2a
inv differential
SaitejaUtpala Sep 9, 2021
981ee1b
some tests
SaitejaUtpala Sep 12, 2021
7c6cabb
tests
SaitejaUtpala Sep 12, 2021
9d5136a
tests for lwt
SaitejaUtpala Sep 12, 2021
ed8da9c
squared_dist for cholesky
SaitejaUtpala Sep 25, 2021
8393bcf
resolve conflicts
SaitejaUtpala Sep 25, 2021
176dfe5
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 25, 2021
c64b4f3
corrections
SaitejaUtpala Sep 25, 2021
c63b4b1
merged
SaitejaUtpala Sep 25, 2021
1f5c2d1
cleaning
SaitejaUtpala Sep 25, 2021
9e7c8c1
typo
SaitejaUtpala Sep 25, 2021
8fc1d43
Rexp and Rlog maps
SaitejaUtpala Sep 26, 2021
4a4243a
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 26, 2021
b922f4e
correction in squared_dist
SaitejaUtpala Sep 26, 2021
575d978
linting
SaitejaUtpala Sep 26, 2021
0712536
vec to diag without assignment
SaitejaUtpala Sep 26, 2021
13fcf24
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 26, 2021
4a609b6
remove unn code
SaitejaUtpala Sep 26, 2021
5c8eaa4
reformat
SaitejaUtpala Sep 26, 2021
336b816
renaming
SaitejaUtpala Sep 26, 2021
011731b
renaming-2
SaitejaUtpala Sep 26, 2021
dc4e1e0
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 26, 2021
fd7285b
tf_and_pytorch vec_to_diag
SaitejaUtpala Sep 26, 2021
555579a
Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
SaitejaUtpala Sep 26, 2021
1ff02a3
reformat
SaitejaUtpala Sep 26, 2021
e7b6d92
fix error in tets
SaitejaUtpala Sep 26, 2021
1ca14f2
typo
SaitejaUtpala Sep 26, 2021
ea3f941
change projection and random point
SaitejaUtpala Sep 26, 2021
185b958
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 26, 2021
279bbb6
gs.diagonal -> Matrices.diagonal
SaitejaUtpala Sep 26, 2021
265ff61
change projection method to handle case of 0
SaitejaUtpala Sep 26, 2021
af720ad
typo
SaitejaUtpala Sep 26, 2021
b5f9da0
cleaning
SaitejaUtpala Sep 26, 2021
b2517a1
tests for metric on cholesky manifold
SaitejaUtpala Sep 26, 2021
d2097f9
test for cholesky factor
SaitejaUtpala Sep 27, 2021
da85c6d
differential of cholesky factor
SaitejaUtpala Sep 27, 2021
4fddf16
tests for Matrices method
SaitejaUtpala Sep 27, 2021
962d27b
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 27, 2021
ae421fb
test for squared dist
SaitejaUtpala Sep 27, 2021
2e067b5
Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
SaitejaUtpala Sep 27, 2021
babce12
more tests
SaitejaUtpala Sep 28, 2021
160167f
more tets
SaitejaUtpala Sep 28, 2021
1bd87e3
fix few erorrs
SaitejaUtpala Sep 28, 2021
146823f
remove some more errors
SaitejaUtpala Sep 28, 2021
f733768
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 28, 2021
4dbacec
fix errors
SaitejaUtpala Sep 28, 2021
62e7d0f
fix inv_diff gram
SaitejaUtpala Sep 28, 2021
9bec436
belongs tests
SaitejaUtpala Sep 28, 2021
d130d07
Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
SaitejaUtpala Sep 28, 2021
84879e5
tests
SaitejaUtpala Sep 28, 2021
567600f
(gh-action-bot) Format Python code with black push
SaitejaUtpala Sep 28, 2021
92783a8
fix more errors
SaitejaUtpala Sep 28, 2021
f507dea
Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
SaitejaUtpala Sep 28, 2021
003ce99
fix bug in to diagonal
SaitejaUtpala Sep 28, 2021
ef8abbd
fix ellipsis error
SaitejaUtpala Sep 28, 2021
096183d
linting
SaitejaUtpala Sep 28, 2021
7672af0
testing
SaitejaUtpala Sep 28, 2021
b1fe05d
debugging
SaitejaUtpala Sep 28, 2021
b075ff5
debug
SaitejaUtpala Sep 28, 2021
dbdcd51
reuse inv differential gram
SaitejaUtpala Sep 28, 2021
67aa543
reduce duplication
SaitejaUtpala Sep 28, 2021
66be7ea
fix tf tril
SaitejaUtpala Sep 28, 2021
737ac5b
fix bug in tril
SaitejaUtpala Sep 28, 2021
48572b9
one last fix
SaitejaUtpala Sep 28, 2021
0f11886
rm flake8
SaitejaUtpala Dec 4, 2021
3a23065
(gh-action-bot) Format Python code with black push
SaitejaUtpala Dec 4, 2021
f9ae8f2
redo
SaitejaUtpala Dec 4, 2021
4fc9747
Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
SaitejaUtpala Dec 4, 2021
0acf925
test get basis method
SaitejaUtpala Dec 10, 2021
e0501dc
typo fix
SaitejaUtpala Dec 10, 2021
d18d8b8
fix typo again
SaitejaUtpala Dec 10, 2021
448490d
make it work for tf
SaitejaUtpala Dec 10, 2021
a88bcaf
cumprod fix
SaitejaUtpala Dec 10, 2021
91906ea
reshape
SaitejaUtpala Dec 10, 2021
1b2c2cc
fix deepsource
SaitejaUtpala Dec 10, 2021
5c02cc9
chol
SaitejaUtpala Dec 10, 2021
5c1be75
some changes
SaitejaUtpala Dec 11, 2021
79183c1
resolve conflicts
SaitejaUtpala Dec 16, 2021
2acf22b
add dots
SaitejaUtpala Dec 16, 2021
bb0d188
address pr review comments-1
SaitejaUtpala Dec 16, 2021
8aff1a1
fix
SaitejaUtpala Dec 16, 2021
dbb49ba
change to tests
SaitejaUtpala Dec 16, 2021
72924eb
cholesky
SaitejaUtpala Dec 16, 2021
f1883f0
checking
SaitejaUtpala Dec 16, 2021
6007984
cleaning
SaitejaUtpala Dec 16, 2021
175f107
testing
SaitejaUtpala Dec 16, 2021
1411e5d
testing
SaitejaUtpala Dec 16, 2021
812c9aa
test even more
SaitejaUtpala Dec 16, 2021
3ae2113
add
SaitejaUtpala Dec 16, 2021
81b02c2
final
SaitejaUtpala Dec 16, 2021
d7f4584
realy final
SaitejaUtpala Dec 17, 2021
3 changes: 3 additions & 0 deletions geomstats/_backend/__init__.py
@@ -89,6 +89,7 @@
"polygamma",
"power",
"prod",
"ravel_tril_indices",
"real",
"repeat",
"reshape",
@@ -117,7 +118,9 @@
"triu",
"tril_indices",
"triu_indices",
"tril_to_vec",
"triu_to_vec",
"vec_to_diag",
"unique",
"vectorize",
"vstack",
25 changes: 24 additions & 1 deletion geomstats/_backend/autograd/__init__.py
@@ -92,8 +92,8 @@
trace,
transpose,
tril,
triu,
tril_indices,
triu,
triu_indices,
uint8,
unique,
@@ -385,12 +385,26 @@ def array_from_sparse(indices, data, target_shape):
return array(coo_matrix((data, list(zip(*indices))), target_shape).todense())


def tril_to_vec(x, k=0):
""" """
n = x.shape[-1]
rows, cols = tril_indices(n, k=k)
return x[..., rows, cols]


def triu_to_vec(x, k=0):
""" """
n = x.shape[-1]
rows, cols = triu_indices(n, k=k)
return x[..., rows, cols]


def vec_to_diag(vec):
"""Convert vector to diagonal matrix."""
d = vec.shape[-1]
return np.squeeze(vec[..., None, :] * np.eye(d)[None, :, :])


def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
"""Build matrix from given components.

@@ -415,3 +429,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
mat[..., j, k] = tri_upp
mat[..., k, j] = tri_low
return mat


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
else:
size = (n, m)
idxs = np.tril_indices(n, k, m)
return np.ravel_multi_index(idxs, size)
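The `tril_to_vec` and `triu_to_vec` helpers above gather the entries on and below (or on and above) the `k`-th diagonal of the last two axes, in row-major order. The sketch below restates the NumPy/autograd version locally (it is not imported from geomstats). It shows what `k` selects and that leading batch axes are preserved. For `k=0`, an `n x n` matrix yields `n(n+1)/2` entries.

```python
import numpy as np


def tril_to_vec(x, k=0):
    # Local restatement of the backend helper: entries on and below the
    # k-th diagonal of the last two axes, read row by row.
    n = x.shape[-1]
    rows, cols = np.tril_indices(n, k=k)
    return x[..., rows, cols]


def triu_to_vec(x, k=0):
    # Entries on and above the k-th diagonal of the last two axes.
    n = x.shape[-1]
    rows, cols = np.triu_indices(n, k=k)
    return x[..., rows, cols]


mat = np.arange(9.0).reshape(3, 3)
# mat = [[0, 1, 2],
#        [3, 4, 5],
#        [6, 7, 8]]
print(tril_to_vec(mat))        # values 0, 3, 4, 6, 7, 8 (diagonal included)
print(tril_to_vec(mat, k=-1))  # values 3, 6, 7 (strictly lower)
print(triu_to_vec(mat, k=1))   # values 1, 2, 5 (strictly upper)

batch = np.stack([mat, 2 * mat])
print(tril_to_vec(batch).shape)  # (2, 6): one vector per matrix
```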
23 changes: 22 additions & 1 deletion geomstats/_backend/numpy/__init__.py
@@ -92,8 +92,8 @@
trace,
transpose,
tril,
triu,
tril_indices,
triu,
triu_indices,
uint8,
unique,
@@ -385,6 +385,18 @@ def array_from_sparse(indices, data, target_shape):
return array(coo_matrix((data, list(zip(*indices))), target_shape).todense())


def vec_to_diag(vec):
"""Convert vector to diagonal matrix."""
d = vec.shape[-1]
return np.squeeze(vec[..., None, :] * np.eye(d)[None, :, :])


def tril_to_vec(x, k=0):
Collaborator
Ditto: put this in the algebra_utils.py

n = x.shape[-1]
rows, cols = tril_indices(n, k=k)
return x[..., rows, cols]


def triu_to_vec(x, k=0):
n = x.shape[-1]
rows, cols = triu_indices(n, k=k)
@@ -415,3 +427,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
mat[..., j, k] = tri_upp
mat[..., k, j] = tri_low
return mat


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
else:
size = (n, m)
idxs = np.tril_indices(n, k, m)
return np.ravel_multi_index(idxs, size)
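In the NumPy and autograd backends, `vec_to_diag` broadcasts the vector against `np.eye(d)` and then calls `np.squeeze`. The squeeze drops every axis of length one. As a result, a batch containing a single vector loses its batch axis, and a length-1 vector collapses to a 0-d array. By contrast, `torch.diag_embed` and `tf.linalg.diag`, used by the other backends here, keep all leading axes.

The sketch compares the backend expression with a hypothetical squeeze-free variant. `vec_to_diag_keepdims` is an illustrative name and is not part of the PR.

```python
import numpy as np


def vec_to_diag_squeeze(vec):
    # Same expression as the NumPy/autograd backend in this diff.
    d = vec.shape[-1]
    return np.squeeze(vec[..., None, :] * np.eye(d)[None, :, :])


def vec_to_diag_keepdims(vec):
    # Hypothetical alternative: put vec[..., i] on row i, then mask with the
    # identity. No squeeze, so every leading batch axis is kept.
    d = vec.shape[-1]
    return vec[..., :, None] * np.eye(d)


single = np.array([1.0, 2.0, 3.0])  # shape (3,)
batch_of_one = single[None, :]      # shape (1, 3)
scalar_vec = np.array([5.0])        # shape (1,)

print(vec_to_diag_squeeze(single).shape)        # (3, 3)
print(vec_to_diag_squeeze(batch_of_one).shape)  # (3, 3): batch axis dropped
print(vec_to_diag_squeeze(scalar_vec).shape)    # (): 1x1 result collapses

print(vec_to_diag_keepdims(batch_of_one).shape)  # (1, 3, 3)
print(vec_to_diag_keepdims(scalar_vec).shape)    # (1, 1)
```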
49 changes: 32 additions & 17 deletions geomstats/_backend/pytorch/__init__.py
@@ -12,6 +12,9 @@
from torch import (
ceil,
clip,
complex32,
complex64,
complex128,
cos,
cosh,
cross,
@@ -24,23 +27,10 @@
flatten,
float32,
float64,
complex32,
complex64,
complex128,
floor,
)
from torch import fmod as mod
from torch import (
greater,
hstack,
imag,
int32,
int64,
isnan,
less,
log,
logical_or,
)
from torch import greater, hstack, imag, int32, int64, isnan, less, log, logical_or
from torch import max as amax
from torch import mean, meshgrid
from torch import min as amin
@@ -57,8 +47,6 @@
std,
tan,
tanh,
tril,
triu,
uint8,
unique,
vstack,
@@ -514,6 +502,14 @@ def diag_indices(*args, **kwargs):
return tuple(map(torch.from_numpy, _np.diag_indices(*args, **kwargs)))


def tril(mat, k=0):
return torch.tril(mat, diagonal=k)


def triu(mat, k=0):
return torch.triu(mat, diagonal=k)


def tril_indices(n, k=0, m=None):
if m is None:
m = n
@@ -780,6 +776,16 @@ def vectorize(x, pyfunc, multiple_args=False, **kwargs):
return stack(list(map(pyfunc, x)))


def vec_to_diag(vec):
return torch.diag_embed(vec, offset=0)


def tril_to_vec(x, k=0):
n = x.shape[-1]
rows, cols = tril_indices(n, k=k)
return x[..., rows, cols]


def triu_to_vec(x, k=0):
n = x.shape[-1]
rows, cols = triu_indices(n, k=k)
@@ -812,6 +818,15 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
return mat


def sort(a, axis=- 1):
def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
else:
size = (n, m)
idxs = _np.tril_indices(n, k, m)
return torch.from_numpy(_np.ravel_multi_index(idxs, size))


def sort(a, axis=-1):
sorted_a, _ = torch.sort(a, dim=axis)
return sorted_a
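`ravel_tril_indices`, added to every backend in this diff, returns the positions of the lower-triangular entries in the row-major flattening of an `n x m` matrix. Indexing a flattened matrix with these positions therefore reproduces `tril_to_vec`. For a 2-D shape `(n, m)`, the flat index is `row * m + col`. The TensorFlow `_ravel_multi_index` obtains the same strides `(m, 1)` from an exclusive, reversed `tf.math.cumprod` of the shape.

A NumPy-only sketch, restating the helper locally:

```python
import numpy as np


def ravel_tril_indices(n, k=0, m=None):
    # Local restatement of the backend helper from this diff.
    size = (n, n) if m is None else (n, m)
    idxs = np.tril_indices(n, k, m)
    return np.ravel_multi_index(idxs, size)


n = 3
mat = np.arange(n * n, dtype=float).reshape(n, n)
flat_idx = ravel_tril_indices(n)
print(flat_idx)  # [0 3 4 6 7 8]

rows, cols = np.tril_indices(n)
# Row-major raveling of an (n, n) matrix: flat index = row * n + col.
assert np.array_equal(flat_idx, rows * n + cols)
# Indexing the flattened matrix matches x[..., rows, cols].
assert np.array_equal(mat.reshape(-1)[flat_idx], mat[rows, cols])

# Batched use: flatten only the last two axes, then index.
batch = np.stack([mat, -mat])
vecs = batch.reshape(batch.shape[0], -1)[:, flat_idx]
print(vecs.shape)  # (2, 6)
```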
5 changes: 1 addition & 4 deletions geomstats/_backend/pytorch/linalg.py
@@ -42,6 +42,7 @@ def backward(ctx, grad):
return Logm._logm(backward_tensor).to(tensor.dtype)[..., :n, n:]


cholesky = torch.linalg.cholesky
eig = torch.linalg.eig
eigh = torch.linalg.eigh
eigvalsh = torch.linalg.eigvalsh
@@ -53,10 +54,6 @@ def backward(ctx, grad):
logm = Logm.apply


def cholesky(a):
return torch.cholesky(a, upper=False)


def sqrtm(x):
np_sqrtm = np.vectorize(scipy.linalg.sqrtm, signature="(n,m)->(n,m)")(x)
return torch.as_tensor(np_sqrtm, dtype=x.dtype)
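The PyTorch `linalg` change replaces the `torch.cholesky(a, upper=False)` wrapper with `torch.linalg.cholesky`, which PyTorch recommends over the deprecated `torch.cholesky`. Both return the lower-triangular factor `L` with `A = L L^T` for a symmetric positive-definite `A`, batched over leading axes.

The sketch checks these properties with `numpy.linalg.cholesky`, which follows the same lower-triangular convention. The random test matrix is an arbitrary choice for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
b = rng.standard_normal((n, n))
spd = b @ b.T + n * np.eye(n)  # symmetric positive definite by construction

chol = np.linalg.cholesky(spd)  # lower-triangular factor L

assert np.allclose(chol, np.tril(chol))  # nothing above the diagonal
assert np.all(np.diag(chol) > 0)         # strictly positive diagonal
assert np.allclose(chol @ chol.T, spd)   # L L^T reconstructs the input

# Stacked inputs are factored matrix by matrix over the last two axes.
batch = np.stack([spd, 2 * spd])
chols = np.linalg.cholesky(batch)
print(chols.shape)  # (2, 3, 3)
```

Because the factor with positive diagonal is unique, scaling `A` by 2 scales `L` by `sqrt(2)`.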
64 changes: 53 additions & 11 deletions geomstats/_backend/tensorflow/__init__.py
@@ -763,16 +763,26 @@ def cumprod(a, axis=None):
return tf.math.cumprod(a, axis=axis)


def tril(m, k=0):
if k != 0:
raise NotImplementedError("Only k=0 supported so far")
return tf.linalg.band_part(m, -1, 0)


def triu(m, k=0):
if k != 0:
raise NotImplementedError("Only k=0 supported so far")
return tf.linalg.band_part(m, 0, -1)
# (sait) there is tf.experimental.tril (we can use it once it moves to stable)
def tril(mat, k=0):
if k not in (0, -1):
raise NotImplementedError("Only k=0 and k=-1 supported so far")
tril = tf.linalg.band_part(mat, -1, 0)
if k == 0:
return tril
zero_diag = tf.zeros(mat.shape[:-1])
return tf.linalg.set_diag(tril, zero_diag)


# TODO(sait) use tf.experimental.triu once it becomes stable.
def triu(mat, k=0):
if k not in (0, 1):
raise NotImplementedError("Only k=0 and k=1 supported so far")
triu = tf.linalg.band_part(mat, 0, -1)
if k == 0:
return triu
zero_diag = tf.zeros(mat.shape[:-1])
return tf.linalg.set_diag(triu, zero_diag)


def diag_indices(*args, **kwargs):
@@ -802,6 +812,19 @@ def where(condition, x=None, y=None):
return tf.where(condition, x, y)


def tril_to_vec(x, k=0):
n = x.shape[-1]
axis = 1 if x.ndim == 3 else 0
mask = tf.ones((n, n))
mask_a = tf.linalg.band_part(mask, -1, 0)
if k < 0:
mask_b = tf.linalg.band_part(mask, -k - 1, 0)
else:
mask_b = tf.zeros_like(mask_a)
mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
return tf.boolean_mask(x, mask, axis=axis)


def triu_to_vec(x, k=0):
n = x.shape[-1]
axis = 1 if x.ndim == 3 else 0
@@ -823,8 +846,13 @@ def tile(x, multiples):
return tf.tile(x_reshape, multiples)


def vec_to_diag(vec):
Collaborator
Put on top of the file as, see:

"""Convert vec to diagonal matrix"""
return tf.linalg.diag(vec)


def vec_to_triu(vec):
"""Take vec and forms strictly upper traingular matrix.
"""Take vec and forms strictly upper triangular matrix.

Parameters
---------
@@ -896,3 +924,17 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
triu_tril_mat = triu_mat + tril_mat
mat = tf.linalg.set_diag(triu_tril_mat, diag)
return mat


def _ravel_multi_index(multi_index, shape):
strides = tf.math.cumprod(shape, exclusive=True, reverse=True)
return tf.reduce_sum(multi_index * tf.expand_dims(strides, 1), axis=0)


def ravel_tril_indices(n, k=0, m=None):
if m is None:
size = (n, n)
else:
size = (n, m)
idxs = tril_indices(n, k, m)
return _ravel_multi_index(idxs, size)
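The TensorFlow `tril`/`triu` (for `k=-1`/`k=1`) and `tril_to_vec` build their masks from `tf.linalg.band_part`. That op keeps entry `(i, j)` when `(num_lower < 0 or i - j <= num_lower)` and `(num_upper < 0 or j - i <= num_upper)`.

Below, a small NumPy emulation of `band_part` reproduces the `tril_to_vec` mask and checks it against `np.tri`. The emulation is an assumption for illustration: it is not a TF API and only broadcasts over the last two axes. It also shows that the TF `tril_to_vec` ignores a positive `k`, whereas the NumPy backend would pass it to `tril_indices`.

```python
import numpy as np


def band_part(mat, num_lower, num_upper):
    # NumPy emulation of tf.linalg.band_part over the last two axes: keep
    # (i, j) when (num_lower < 0 or i - j <= num_lower) and
    # (num_upper < 0 or j - i <= num_upper); zero everything else.
    n_rows, n_cols = mat.shape[-2:]
    i = np.arange(n_rows)[:, None]
    j = np.arange(n_cols)[None, :]
    keep = np.logical_and(
        np.logical_or(num_lower < 0, i - j <= num_lower),
        np.logical_or(num_upper < 0, j - i <= num_upper),
    )
    return np.where(keep, mat, 0)


def tril_mask(n, k=0):
    # Same steps as the TF tril_to_vec in the diff, with NumPy in place of TF.
    ones = np.ones((n, n))
    mask_a = band_part(ones, -1, 0)          # lower triangle incl. diagonal
    if k < 0:
        mask_b = band_part(ones, -k - 1, 0)  # diagonal and (-k - 1) subdiagonals
    else:
        mask_b = np.zeros_like(mask_a)       # any k >= 0 behaves like k = 0
    return (mask_a - mask_b).astype(bool)


n = 4
for k in (0, -1, -2):
    # np.tri marks the entries on and below the k-th diagonal.
    assert np.array_equal(tril_mask(n, k), np.tri(n, k=k, dtype=bool))

# A positive k does not add the superdiagonal, unlike np.tril_indices(n, k=1).
assert np.array_equal(tril_mask(n, 1), np.tri(n, k=0, dtype=bool))

# TF tril(mat, k=-1): band_part(mat, -1, 0) followed by a zeroed diagonal.
mat = np.arange(16.0).reshape(4, 4)
strict = band_part(mat, -1, 0)
np.fill_diagonal(strict, 0.0)
assert np.array_equal(strict, np.tril(mat, -1))
```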