ENH: Add (put|take)_along_axis #11105
Overall, still looks very good and I think it makes sense to get this simpler version in first. Does need tests for `put_along_axis` (which I think has a problem as is for `axis=None`). Also, a suggestion for a clearer description.
numpy/lib/shape_base.py
Outdated
def take_along_axis(arr, indices, axis):
    """
    Take the elements described by `indices` along each 1-D slice of the given
I'd try for a true one-liner, maybe just the one you have above: "Take elements by matching the array and the index arrays" (though perhaps in all places adding "along axis"?)
How about "Take values from the input array by matching 1d index and data slices."?
That's not bad. I'd take that or perhaps
"Take values by matching 1d index and data slices along axis."
(and then also update the one-liners in the other file)
Now looking at the full docstring, I think your one-liner is really fine as is. I would still change it also in the See-also sections.
numpy/lib/shape_base.py
Outdated
for ii in ndindex(Ni):
    for j in range(J):
        for kk in ndindex(Nk):
            a_1d = a[ii + s_[j,] + kk]
This selects a single element! Would need to be s_[:,]
Oops, good catch
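To make the difference concrete, here is a small sketch of the two spellings. The array and the index tuples are made up for illustration; only `np.s_` and plain tuple concatenation are used.

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
ii, kk = (1,), (2,)   # illustrative index tuples for the axes before/after axis=1
j = 0

# s_[j,] is the tuple (j,), so this is a[1, 0, 2]: a single element
single = a[ii + np.s_[j,] + kk]

# s_[:,] is the tuple (slice(None),), so this is a[1, :, 2]: the whole 1-D slice
slice_1d = a[ii + np.s_[:,] + kk]

print(np.ndim(single), slice_1d.shape)   # 0 (3,)
```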
numpy/lib/shape_base.py
Outdated
    This function can be used to index with the result of `argsort`, `argmax`,
    and other `arg` functions.

    This is equivalent to (but faster than) the following use of `ndindex` and
I find the description very hard to parse even though I implemented my own version of this! In part, it is because of using two other numpy routines which (I think) are rarely used. E.g., I had to look up what `np.s_` was; I think one might as well omit that, and perhaps attempt to explain even more what `ndindex` does, as in:
This is equivalent to (but faster than) the following use of `ndindex` to iterate
over the axes before and after the requested axis:

    Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
    J = indices.shape[axis]
    out = np.empty(indices.shape, dtype=a.dtype)
    for ii in ndindex(Ni):  # ii is a tuple of indices into the axes before axis
        for kk in ndindex(Nk):  # kk is a tuple of indices into the axes after axis
            a_1d = a[ii + (slice(0, M),) + kk]  # values along axis for this ii, kk
            indices_1d = indices[ii + (slice(0, J),) + kk]  # corresponding indices
            for j in range(J):
                out[ii + (j,) + kk] = a_1d[indices_1d[j]]

Equivalently, eliminating the inner loop, the last two lines would be,

    out[ii + (slice(0, J),) + kk] = a_1d[indices_1d]
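For reference, the suggested loop can be wrapped into a runnable check against `np.take_along_axis` as it was eventually merged. This is only a sketch: the helper name `take_along_axis_reference` is made up here, and it assumes a non-negative `axis` and that `a` and `indices` have identical shapes outside `axis` (no broadcasting).

```python
import numpy as np

def take_along_axis_reference(a, indices, axis):
    """Loop-based sketch; assumes axis >= 0 and matching non-axis shapes."""
    Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis + 1:]
    J = indices.shape[axis]  # note that J does not have to equal M
    out = np.empty(indices.shape, dtype=a.dtype)
    for ii in np.ndindex(Ni):       # indices into the axes before `axis`
        for kk in np.ndindex(Nk):   # indices into the axes after `axis`
            a_1d = a[ii + (slice(0, M),) + kk]
            indices_1d = indices[ii + (slice(0, J),) + kk]
            out[ii + (slice(0, J),) + kk] = a_1d[indices_1d]
    return out

rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))
indices = np.argsort(a, axis=1)[:, :2, :]   # keep the two smallest: J = 2, M = 4

expected = np.take_along_axis(a, indices, axis=1)
result = take_along_axis_reference(a, indices, axis=1)
print(np.array_equal(result, expected))
```

Choosing `J != M` in the example exercises the point raised in this thread that the two lengths are independent.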
Looking at this again, it might even be an idea to add

    out_1d = out[ii + (slice(0, J),) + kk]  # corresponding view of output
    for j in range(J):
        out_1d[j] = a_1d[indices_1d[j]]
Note that this does at least link to `ndindex` and `s_`. I do like the extra temporaries you introduce here, but I would prefer to keep `s_` as it makes the connection to `:` clear.
Yes, the `np.s_[:,]` is not as confusing as `np.s_[j,]` (though the comma looks very odd). One argument in favour of the slice as I wrote it is that it uses `M` and `J` and thus might help make clearer that those dimensions are not necessarily the same (which I think is useful). But of course one could just write `np.s_[0:M,]`, etc.
On the flipside, using `:M` suggests that the dimension may not actually be of size M, which would also be confusing.
I've incorporated some of your ideas into the docstrings - perhaps best to create a new review thread on the updated one
What I'd really like here is to just be able to spell it `out[*ii, j, *kk]` - I tried patching python to allow this, but wasn't able to work out where to go after patching the grammar.
Yes, I tried the same!
I do think your update is a lot clearer already. I would argue to keep the definition of `out = np.empty(...)` since that helps to understand what the result shape should be and makes the code already runnable. And perhaps add a comment in the example after extracting `J` along the lines of `# Note that J does not have to equal M=a.shape[axis]`
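As a side note on the `out[*ii, j, *kk]` spelling wished for above: wrapping the unpacking in an explicit tuple already works, since starred items inside a tuple display have been allowed since Python 3.5. As far as I know, Python 3.11 and later also accept the un-parenthesized form. A minimal sketch, with made-up values for `ii`, `j` and `kk`:

```python
import numpy as np

out = np.zeros((2, 3, 4))
ii, j, kk = (1,), 2, (3,)   # illustrative index tuples and axis index

# Explicit tuple display: builds (1, 2, 3), same as ii + (j,) + kk
out[(*ii, j, *kk)] = 7.0

print(out[ii + (j,) + kk])
```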
numpy/lib/shape_base.py
Outdated
        axis = 0
    else:
        axis = normalize_axis_index(axis, arr.ndim)
    if not _nx.issubdtype(indices.dtype, _nx.integer):
Move this up right after the `np.asanyarray(indices)`, or, perhaps better, to `_make_along_axis_idx` (since that does other sanity checks as well that are shared with `put_along_axis`)
numpy/lib/shape_base.py
Outdated
def put_along_axis(arr, indices, values, axis):
    """
    Put `values` at the elements described by `indices` along each 1-D slice of
See comments on take_along_axis
numpy/lib/shape_base.py
Outdated
    # normalize inputs
    indices = asanyarray(indices)
    if axis is None:
        arr = arr.ravel()
This won't work if one cannot ravel the array without copying (in which case the assignment below would not affect the input array; do make a test case...). I think you'd need to use `arr = arr.flat` (for symmetry, might as well do the same for `take_along_axis`).
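The copy problem is easy to reproduce with a transposed view, which cannot be raveled in C order without copying. The following is a small illustration rather than the test that ended up in the PR; the variable names are made up.

```python
import numpy as np

arr = np.zeros((2, 3))
view = arr.T                 # transposed view, not C-contiguous

flat_copy = view.ravel()     # ravel has to copy here
flat_copy[0] = 1.0           # only the copy is modified
changed_by_ravel = bool(view[0, 0] == 1.0)

view.flat[0] = 2.0           # the flatiter writes through to the original data
changed_by_flat = bool(view[0, 0] == 2.0)

print(changed_by_ravel, changed_by_flat)   # False True
```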
numpy/lib/tests/test_shape_base.py
Outdated
 from numpy.lib.shape_base import (
     apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
-    vsplit, dstack, column_stack, kron, tile, expand_dims,
+    vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis
Needs tests for `put_along_axis` as well...
Will add these in a future commit. There's not any obvious way to test the argfunc/func equivalence like there was for `take_along_axis`, as far as I can tell.
Fine to do in a separate commit. One can test, e.g., adding a very large number and then ensuring that `argmax` gives back `indices`.
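One possible shape for such a test is sketched below. It is not the test that was committed; the array sizes, the seed and the value `1e6` are arbitrary choices, and it relies on the random data lying in `[0, 1)` so that the inserted value is the unique maximum of each slice.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 4, 5))          # values in [0, 1)
axis = 1

# One target position per 1-D slice along `axis`; shape (3, 1, 5)
indices = rng.integers(0, a.shape[axis], size=(3, 1, 5))

np.put_along_axis(a, indices, 1e6, axis=axis)   # modifies `a` in place

# argmax drops the axis, so restore it before comparing with `indices`
recovered = np.expand_dims(np.argmax(a, axis=axis), axis)
print(np.array_equal(recovered, indices))
```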
a = rand(3, 4, 5)

funcs = [
    (np.sort, np.argsort, dict()),
Total nitpick, but replace `dict()` by `{}` (and `{kth: 2}` further down)
This was a deliberate choice, because `dict(kth=2)` looks more like `some_func(kth=2)`
Definitely fine to leave as is.
numpy/lib/shape_base.py
Outdated
    arr = asanyarray(arr)
    indices = asanyarray(indices)
    if axis is None:
        arr = arr.ravel()
Is there a use-case for setting `axis=None`? Personally I find this auto-flattening behavior quite distasteful.
The only legitimate use-case I can think of is making things simpler for 1D arrays, but that would need a default axis. Then the default value `axis=None` would mean "axis=0 for 1D arrays, otherwise raise an error".
Without this behaviour, you can't use it to fully wrap `argsort` back into `sort`. I don't think `argsort` and `sort` should accept `axis=None` either, but given they do, we want to support it here too.
See the maskedarray change in this PR for a case where we sadly have to support None.
OK, that makes sense to me. Let's just make sure the behavior is documented, then.
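For the record, the `axis=None` round trip discussed here looks roughly like this with the function as merged: `argsort(axis=None)` returns 1-D indices into the flattened array, and `take_along_axis` with `axis=None` applies them to the flattened input. The sample array is made up.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

order = np.argsort(a, axis=None)                      # 1-D indices into a.flat
flat_sorted = np.take_along_axis(a, order, axis=None)

print(flat_sorted)   # [10 20 30 40 50 60]
```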
I like this. My only objection is the auto-flattening behavior with
numpy/lib/shape_base.py
Outdated
    arr: array_like (Ni..., M, Nk...)
        source array
    indices: array_like (Ni..., J, Nk...)
        indices to take along each 1d slice of `arr`
This should mention the broadcasting behavior with size 1 dimensions.
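A short sketch of the size-1 broadcasting in question, using the function as merged: `indices` must have the same number of dimensions as `arr`, but a length-1 dimension away from `axis` is broadcast against `arr`. The sample values are made up.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

# Shape (1, 2): the leading size-1 dimension broadcasts over both rows of `a`,
# so columns 2 and 0 are taken from every row.
shared = np.array([[2, 0]])
picked = np.take_along_axis(a, shared, axis=1)

print(picked)
# [[20 10]
#  [50 60]]
```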
A case for

    def sort_by(func, x, axis):
        ind = func(x).argsort(axis=axis)
        return np.take_along_axis(x, ind, axis=axis)

This allows
Do we go with the status quo here, or make a deliberate effort to start dropping
I agree that it's better to be consistent with
Well, unless
Updated with re-ordered and improved docstrings, and correct handling of
Fixup commits have a description as their third line. Will squash once approved.
Only a docstring copy & paste left, plus making the one-liners in see-also consistent.
]


def _make_along_axis_idx(arr_shape, indices, axis):
I had considered asking you to just pass on the shape, but thought it was too much trouble, but now I see it ended up being needed! Nicer.
I also considered adding `.shape` and `.ndim` to flatiter, but then I'd need to do it to `MaskedArray.flat` too, and it was a little unclear whether `.shape` and `.ndim` should refer to the flattened or original array.
numpy/lib/shape_base.py
Outdated
    See Also
    --------
    take_along_axis : Take along an axis without matching up subspaces
Oops, copy & paste error...
Fixed and reworded here and for the other function
numpy/lib/shape_base.py
Outdated
           [40, 60]])
    """
    # normalize inputs
    arr = asanyarray(arr)
More as a note, really, but this one line will prevent me from actually using the routine in astropy's `Time` class (which is not an `ndarray` subclass, but does allow indexing like it). I guess right now there is little choice but to do it, and eventually this should become a check for sufficient duck-typing. Anyway, an additional argument to make `_make_along_axis_idx` public in a follow-up PR (which means thinking of a better name....).
I think a better follow-up would be to implement this as a gufunc
Perhaps I should just omit this? I don't see any point in allowing the caller to pass a list, and all I actually need here is `.shape`, `.flat`, and `.__getitem__`
Yes, why not indeed, since realistically this is not going to be used on lists, but on things that indices were taken from.
Do I leave the `asanyarray` on indices?
Might as well remove it too. These indices would be very hard to construct correctly as a nested list anyway.
ad0ec85 to fcdced0
@mhvk:
Darn, now your test with invalid index input raises a different error...
This is the reduced version that does not allow any insertion of extra dimensions
fcdced0 to 7a3c50a
Hadn't noticed this was now fixed - all looks OK, so will merge. Nice addition!
Wonder if we should have run this by the mailing list, for the sake of name bike-shedding...
Could still do it?
Done
This is a less ambitious version of #8714, as I proposed in this comment
Essentially, this limits the function to the obvious broadcasting cases when the indices and array have the same number of dimensions. This is sufficiently powerful that any other broadcasting behavior can be achieved by inserting extra dimensions, but keeps to:
An expansion of these functions is given in their docstring (using the format adopted by #9946), but the one-line summaries are:
out[ii + s_[j,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[j,] + kk]]
a[ii + s_[:,] + kk][indices[ii + s_[j,] + kk]] = values[ii + s_[j,] + kk]
Or perhaps more clearly in terms of basic indexing only:
out[ii + s_[j,] + kk] = a[ii + s_[indices[ii + s_[j,] + kk],] + kk]
a[ii + s_[indices[ii + s_[j,] + kk],] + kk] = values[ii + s_[j,] + kk]
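The second basic-indexing one-liner (the `put_along_axis` case) can be spelled out as a loop and compared with the merged function. This is a sketch only: the helper name `put_along_axis_reference` is made up, and it assumes a non-negative `axis`, `values` with the same shape as `indices`, and matching shapes outside `axis`.

```python
import numpy as np

def put_along_axis_reference(a, indices, values, axis):
    """Loop-based sketch of the one-liner above; modifies `a` in place."""
    Ni, Nk = a.shape[:axis], a.shape[axis + 1:]
    J = indices.shape[axis]
    for ii in np.ndindex(Ni):
        for kk in np.ndindex(Nk):
            for j in range(J):
                target = indices[ii + np.s_[j,] + kk]   # position along `axis`
                a[ii + np.s_[target,] + kk] = values[ii + np.s_[j,] + kk]

a = np.array([[10, 30, 20],
              [60, 40, 50]])
indices = np.expand_dims(np.argmax(a, axis=1), axis=1)   # [[1], [0]]
values = np.full(indices.shape, 99)

expected = a.copy()
np.put_along_axis(expected, indices, values, axis=1)

result = a.copy()
put_along_axis_reference(result, indices, values, axis=1)
print(np.array_equal(result, expected))
```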