Cholesky space #1142
Merged
106 commits
Commits (all by SaitejaUtpala):
- 99129e9 first commit
- 6140e6a add aux methods
- e1bb569 clean
- c5ac1df projection methods
- 980335c to from vecs
- 6755c6f tf tril_to_vec
- 6372c49 add cholesky factor
- c5d2417 cleaning
- 995b485 testing
- 878e476 cholesky metric
- 7e308c6 metric
- cbf57b0 R-exp and R-log
- ed89a0a differential
- 1ea9273 gram and its differential
- 4e1345b comments
- 2782f2a inv differential
- 981ee1b some tests
- 7c6cabb tests
- 9d5136a tests for lwt
- ed8da9c squared_dist for cholesky
- 8393bcf resolve conflicts
- 176dfe5 (gh-action-bot) Format Python code with black push
- c64b4f3 corrections
- c63b4b1 merged
- 1f5c2d1 cleaning
- 9e7c8c1 typo
- 8fc1d43 Rexp and Rlog maps
- 4a4243a (gh-action-bot) Format Python code with black push
- b922f4e correction in squared_dist
- 575d978 linting
- 0712536 vec to diag without assignment
- 13fcf24 (gh-action-bot) Format Python code with black push
- 4a609b6 remove unn code
- 5c8eaa4 reformat
- 336b816 renaming
- 011731b renaming-2
- dc4e1e0 (gh-action-bot) Format Python code with black push
- fd7285b tf_and_pytorch vec_to_diag
- 555579a Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
- 1ff02a3 reformat
- e7b6d92 fix error in tets
- 1ca14f2 typo
- ea3f941 change projection and random point
- 185b958 (gh-action-bot) Format Python code with black push
- 279bbb6 gs.diagonal -> Matrices.diagonal
- 265ff61 change projection method to handle case of 0
- af720ad typo
- b5f9da0 cleaning
- b2517a1 tests for metric on cholesky manifold
- d2097f9 test for cholesky factor
- da85c6d differential of cholesky factor
- 4fddf16 tests for Matrices method
- 962d27b (gh-action-bot) Format Python code with black push
- ae421fb test for squared dist
- 2e067b5 Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
- babce12 more tests
- 160167f more tets
- 1bd87e3 fix few erorrs
- 146823f remove some more errors
- f733768 (gh-action-bot) Format Python code with black push
- 4dbacec fix errors
- 62e7d0f fix inv_diff gram
- 9bec436 belongs tests
- d130d07 Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
- 84879e5 tests
- 567600f (gh-action-bot) Format Python code with black push
- 92783a8 fix more errors
- f507dea Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
- 003ce99 fix bug in to diagonal
- ef8abbd fix ellipsis error
- 096183d linting
- 7672af0 testing
- b1fe05d debugging
- b075ff5 debug
- dbdcd51 reuse inv differential gram
- 67aa543 reduce duplication
- 66be7ea fix tf tril
- 737ac5b fix bug in tril
- 48572b9 one last fix
- 0f11886 rm flake8
- 3a23065 (gh-action-bot) Format Python code with black push
- f9ae8f2 redo
- 4fc9747 Merge branch 'cholesky-space' of github.com:saitejautpala/geomstats i…
- 0acf925 test get basis method
- e0501dc typo fix
- d18d8b8 fix typo again
- 448490d make it work for tf
- a88bcaf cumprod fix
- 91906ea reshape
- 1b2c2cc fix deepsource
- 5c02cc9 chol
- 5c1be75 some changes
- 79183c1 resolve conflicts
- 2acf22b add dots
- bb0d188 address pr review comments-1
- 8aff1a1 fix
- dbb49ba change to tests
- 72924eb cholesky
- f1883f0 checking
- 6007984 cleaning
- 175f107 testing
- 1411e5d testing
- 812c9aa test even more
- 3ae2113 add
- 81b02c2 final
- d7f4584 realy final
```diff
@@ -92,8 +92,8 @@
     trace,
     transpose,
     tril,
-    triu,
     tril_indices,
+    triu,
     triu_indices,
     uint8,
     unique,
```
```diff
@@ -385,6 +385,18 @@ def array_from_sparse(indices, data, target_shape):
     return array(coo_matrix((data, list(zip(*indices))), target_shape).todense())
 
 
+def vec_to_diag(vec):
+    """Convert vector to diagonal matrix."""
+    d = vec.shape[-1]
+    return np.squeeze(vec[..., None, :] * np.eye(d)[None, :, :])
+
+
+def tril_to_vec(x, k=0):
+    n = x.shape[-1]
+    rows, cols = tril_indices(n, k=k)
+    return x[..., rows, cols]
+
+
 def triu_to_vec(x, k=0):
     n = x.shape[-1]
     rows, cols = triu_indices(n, k=k)
```

On `vec_to_diag`: ninamiolane marked this conversation as resolved.

On `tril_to_vec` (reviewer comment): "Ditto: put this in the algebra_utils.py"
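A minimal NumPy sketch of the two helpers added in this hunk, for illustration only. It differs from the diff in one respect: as written there, `np.squeeze` also removes a leading batch axis of length 1 (an input of shape `(1, d)` comes back as `(d, d)`), whereas relying on broadcasting alone keeps the shape `(..., d, d)`.

```python
import numpy as np


def vec_to_diag(vec):
    """Map a (..., d) array to a (..., d, d) array of diagonal matrices.

    vec[..., None, :] has shape (..., 1, d); multiplying by the (d, d)
    identity broadcasts vec[..., j] down column j and zeroes every
    off-diagonal entry. No squeeze is needed, so batch axes survive.
    """
    d = vec.shape[-1]
    return vec[..., None, :] * np.eye(d, dtype=vec.dtype)


def tril_to_vec(x, k=0):
    """Flatten the lower triangle (diagonal offset k) of the last two axes, row by row."""
    n = x.shape[-1]
    rows, cols = np.tril_indices(n, k=k)
    return x[..., rows, cols]


batch = np.array([[1.0, 2.0, 3.0]])  # one batch element, d = 3
print(vec_to_diag(batch).shape)      # (1, 3, 3)
m = np.arange(9).reshape(3, 3)
print(tril_to_vec(m))                # [0 3 4 6 7 8]
print(tril_to_vec(m, k=-1))          # [3 6 7]
```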
```diff
@@ -415,3 +427,12 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
     mat[..., j, k] = tri_upp
     mat[..., k, j] = tri_low
     return mat
+
+
+def ravel_tril_indices(n, k=0, m=None):
+    if m is None:
+        size = (n, n)
+    else:
+        size = (n, m)
+    idxs = np.tril_indices(n, k, m)
+    return np.ravel_multi_index(idxs, size)
```
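`ravel_tril_indices` returns the flat, row-major positions of the lower-triangle entries, so indexing a matrix flattened over its last two axes with them gives the same vector as `tril_to_vec`. A short NumPy check of that equivalence, reusing the function body from the diff:

```python
import numpy as np


def ravel_tril_indices(n, k=0, m=None):
    """Flat (row-major) positions of the lower-triangle entries of an n x m matrix."""
    size = (n, n) if m is None else (n, m)
    rows, cols = np.tril_indices(n, k, m)
    return np.ravel_multi_index((rows, cols), size)


print(ravel_tril_indices(3))       # [0 3 4 6 7 8]
print(ravel_tril_indices(2, m=3))  # [0 3 4]

# Gathering from the flattened matrices matches gathering with (rows, cols).
x = np.arange(18).reshape(2, 3, 3)
rows, cols = np.tril_indices(3)
print(np.array_equal(x.reshape(2, 9)[:, ravel_tril_indices(3)], x[:, rows, cols]))  # True
```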
```diff
@@ -763,16 +763,26 @@ def cumprod(a, axis=None):
     return tf.math.cumprod(a, axis=axis)
 
 
-def tril(m, k=0):
-    if k != 0:
-        raise NotImplementedError("Only k=0 supported so far")
-    return tf.linalg.band_part(m, -1, 0)
-
-
-def triu(m, k=0):
-    if k != 0:
-        raise NotImplementedError("Only k=0 supported so far")
-    return tf.linalg.band_part(m, 0, -1)
+# (sait) there is tf.experimental.tril (we can use it once it moves to stable)
+def tril(mat, k=0):
+    if k not in (0, -1):
+        raise NotImplementedError("Only k=0 and k=-1 supported so far")
+    tril = tf.linalg.band_part(mat, -1, 0)
+    if k == 0:
+        return tril
+    zero_diag = tf.zeros(mat.shape[:-1])
+    return tf.linalg.set_diag(tril, zero_diag)
+
+
+# TODO(sait) use tf.experimental.triu once it becomes stable.
+def triu(mat, k=0):
+    if k not in (0, 1):
+        raise NotImplementedError("Only k=0 and k=1 supported so far")
+    triu = tf.linalg.band_part(mat, 0, -1)
+    if k == 0:
+        return triu
+    zero_diag = tf.zeros(mat.shape[:-1])
+    return tf.linalg.set_diag(triu, zero_diag)
 
 
 def diag_indices(*args, **kwargs):
```

On the `# (sait) there is tf.experimental.tril ...` line: SaitejaUtpala marked this conversation as resolved.
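The new TF `tril`/`triu` handle `k=-1`/`k=1` in two steps: take the band including the diagonal, then overwrite the diagonal with zeros. A NumPy sketch of the same two steps, with a hypothetical `band_part` helper standing in for `tf.linalg.band_part`, makes it easy to check the result against `np.tril`/`np.triu`. One point to keep in mind for the TF code itself: `tf.zeros` defaults to `float32`, and `tf.linalg.set_diag` expects the diagonal to have the same dtype as the matrix, so `float64` inputs would presumably need `tf.zeros(..., dtype=mat.dtype)`.

```python
import numpy as np


def band_part(mat, num_lower, num_upper):
    """Hypothetical NumPy stand-in for tf.linalg.band_part.

    Keeps entry (i, j) when (num_lower < 0 or i - j <= num_lower) and
    (num_upper < 0 or j - i <= num_upper); every other entry becomes 0.
    """
    i = np.arange(mat.shape[-2])[:, None]
    j = np.arange(mat.shape[-1])[None, :]
    keep = ((num_lower < 0) | (i - j <= num_lower)) & (
        (num_upper < 0) | (j - i <= num_upper)
    )
    return np.where(keep, mat, 0)


def _zero_diag(mat):
    """Copy of mat with its main diagonal set to 0 (the role of set_diag(..., zeros))."""
    out = mat.copy()
    idx = np.arange(min(mat.shape[-2:]))
    out[..., idx, idx] = 0
    return out


def tril(mat, k=0):
    if k not in (0, -1):
        raise NotImplementedError("Only k=0 and k=-1 supported so far")
    low = band_part(mat, -1, 0)  # lower triangle, diagonal included
    return low if k == 0 else _zero_diag(low)


def triu(mat, k=0):
    if k not in (0, 1):
        raise NotImplementedError("Only k=0 and k=1 supported so far")
    upp = band_part(mat, 0, -1)  # upper triangle, diagonal included
    return upp if k == 0 else _zero_diag(upp)


a = np.arange(2 * 4 * 4, dtype=float).reshape(2, 4, 4)  # a batch of two 4 x 4 matrices
print(np.array_equal(tril(a, k=-1), np.tril(a, k=-1)))   # True
print(np.array_equal(triu(a, k=1), np.triu(a, k=1)))     # True
```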
```diff
@@ -802,6 +812,19 @@ def where(condition, x=None, y=None):
     return tf.where(condition, x, y)
 
 
+def tril_to_vec(x, k=0):
+    n = x.shape[-1]
+    axis = 1 if x.ndim == 3 else 0
+    mask = tf.ones((n, n))
+    mask_a = tf.linalg.band_part(mask, -1, 0)
+    if k < 0:
+        mask_b = tf.linalg.band_part(mask, -k - 1, 0)
+    else:
+        mask_b = tf.zeros_like(mask_a)
+    mask = tf.cast(mask_a - mask_b, dtype=tf.bool)
+    return tf.boolean_mask(x, mask, axis=axis)
+
+
 def triu_to_vec(x, k=0):
     n = x.shape[-1]
     axis = 1 if x.ndim == 3 else 0
```

On `tril_to_vec`: ninamiolane marked this conversation as resolved.
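The TF `tril_to_vec` builds a boolean mask instead of gathering with index arrays: `mask_a` keeps the lower triangle with the diagonal, and for `k < 0` subtracting `mask_b` (the diagonal plus the first `-k - 1` subdiagonals) leaves only entries at least `-k` below the diagonal. For `k > 0`, `mask_b` is all zeros, so the result is the same as for `k = 0`. The NumPy sketch below rebuilds the same mask, with comparisons in place of `tf.linalg.band_part`, and checks that the row-major order of the masked entries matches `np.tril_indices`. NumPy's `x[..., mask]` also accepts any number of leading batch axes, whereas the diff's `axis = 1 if x.ndim == 3 else 0` covers at most one.

```python
import numpy as np


def tril_mask(n, k=0):
    """Boolean (n, n) mask built like the TF diff; for k <= 0 it selects i - j >= -k."""
    i = np.arange(n)[:, None]
    j = np.arange(n)[None, :]
    mask_a = (i >= j).astype(int)  # like band_part(ones, -1, 0)
    if k < 0:
        # like band_part(ones, -k - 1, 0): diagonal plus (-k - 1) subdiagonals
        mask_b = ((i >= j) & (i - j <= -k - 1)).astype(int)
    else:
        mask_b = np.zeros_like(mask_a)
    return (mask_a - mask_b).astype(bool)


def tril_to_vec(x, k=0):
    """Masked entries of the last two axes, read in row-major order."""
    return x[..., tril_mask(x.shape[-1], k)]


x = np.arange(2 * 4 * 4).reshape(2, 4, 4)
rows, cols = np.tril_indices(4, k=-1)
print(np.array_equal(tril_to_vec(x, k=-1), x[..., rows, cols]))  # True
```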
```diff
@@ -823,8 +846,13 @@ def tile(x, multiples):
     return tf.tile(x_reshape, multiples)
 
 
+def vec_to_diag(vec):
+    """Convert vec to diagonal matrix"""
+    return tf.linalg.diag(vec)
+
+
 def vec_to_triu(vec):
-    """Take vec and forms strictly upper traingular matrix.
+    """Take vec and forms strictly upper triangular matrix.
 
     Parameters
     ---------
```

On `vec_to_diag` (reviewer comment): "Put on top of the file as, see:"
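This hunk only fixes a typo in the docstring of `vec_to_triu`; its body is not part of the diff. As an illustration of what forming a strictly upper triangular matrix from a vector involves, here is a hypothetical NumPy version that assumes the entries are laid out row by row, in the order of `np.triu_indices(n, k=1)`. It is a sketch under that assumption, not the backend's implementation.

```python
import math

import numpy as np


def vec_to_triu(vec):
    """Hypothetical: place a (..., m) array into the strict upper triangle of (..., n, n).

    m must equal n * (n - 1) / 2; entries are filled row by row,
    in the order of np.triu_indices(n, k=1).
    """
    m = vec.shape[-1]
    # Solve n * (n - 1) / 2 = m for n using an exact integer square root.
    n = (1 + math.isqrt(1 + 8 * m)) // 2
    if n * (n - 1) // 2 != m:
        raise ValueError(f"length {m} is not a triangular number n(n-1)/2")
    rows, cols = np.triu_indices(n, k=1)
    mat = np.zeros(vec.shape[:-1] + (n, n), dtype=vec.dtype)
    mat[..., rows, cols] = vec
    return mat


print(vec_to_triu(np.array([1, 2, 3])))
# [[0 1 2]
#  [0 0 3]
#  [0 0 0]]
```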
```diff
@@ -896,3 +924,17 @@ def mat_from_diag_triu_tril(diag, tri_upp, tri_low):
     triu_tril_mat = triu_mat + tril_mat
     mat = tf.linalg.set_diag(triu_tril_mat, diag)
     return mat
+
+
+def _ravel_multi_index(multi_index, shape):
+    strides = tf.math.cumprod(shape, exclusive=True, reverse=True)
+    return tf.reduce_sum(multi_index * tf.expand_dims(strides, 1), axis=0)
+
+
+def ravel_tril_indices(n, k=0, m=None):
+    if m is None:
+        size = (n, n)
+    else:
+        size = (n, m)
+    idxs = tril_indices(n, k, m)
+    return _ravel_multi_index(idxs, size)
```
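`_ravel_multi_index` computes row-major strides as an exclusive, reversed cumulative product of the shape: for shape `(d0, d1, d2)` the strides are `(d1 * d2, d2, 1)`, and each flat index is the stride-weighted sum of the coordinates. A NumPy sketch of the same computation (the name `ravel_multi_index_via_strides` is made up here), checked against `np.ravel_multi_index`:

```python
import numpy as np


def ravel_multi_index_via_strides(multi_index, shape):
    """Row-major flat indices from a (ndim, K) multi-index, via explicit strides."""
    shape = np.asarray(shape, dtype=np.int64)
    # Exclusive, reversed cumprod: prepend 1 to the reversed shape (minus its
    # first dimension), take the cumulative product, then reverse back.
    strides = np.cumprod(np.concatenate(([1], shape[::-1][:-1])))[::-1]
    # Weight each coordinate row by its stride and sum over dimensions.
    return np.sum(np.asarray(multi_index) * strides[:, None], axis=0)


idx = (np.array([0, 1, 1]), np.array([2, 0, 2]), np.array([3, 1, 0]))
print(ravel_multi_index_via_strides(idx, (2, 3, 4)))  # [11 13 20]
print(np.ravel_multi_index(idx, (2, 3, 4)))           # [11 13 20]
```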