ENH: Add (put|take)_along_axis #11105

Merged 2 commits on May 29, 2018

@eric-wieser (Member) commented May 16, 2018

This is a less ambitious version of #8714, as I proposed in this comment

Essentially, this limits the function to the obvious broadcasting cases when the indices and array have the same number of dimensions. This is sufficiently powerful that any other broadcasting behavior can be achieved by inserting extra dimensions, but keeps to:

In the face of ambiguity, refuse the temptation to guess.

An expansion of these functions is given in their docstring (using the format adopted by #9946), but the one-line summaries are:

  • take: out[ii + s_[j,] + kk] = a[ii + s_[:,] + kk][indices[ii + s_[j,] + kk]]
  • put: a[ii + s_[:,] + kk][indices[ii + s_[j,] + kk]] = values[ii + s_[j,] + kk]

Or perhaps more clearly in terms of basic indexing only:

  • take: out[ii + s_[j,] + kk] = a[ii + s_[indices[ii + s_[j,] + kk],] + kk]
  • put: a[ii + s_[indices[ii + s_[j,] + kk],] + kk] = values[ii + s_[j,] + kk]
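
As an illustration (not part of the PR text), the two operations can be tried on a small array using the functions as they ship in NumPy 1.15 and later. The array contents and variable names are arbitrary choices for this sketch.

```python
import numpy as np

a = np.array([[10, 30, 20],
              [60, 40, 50]])

# take: gather along axis 1 using indices with the same ndim as `a`
order = np.argsort(a, axis=1)
sorted_a = np.take_along_axis(a, order, axis=1)

# arg-reductions drop the reduced axis, so re-insert it before indexing
imax = np.expand_dims(np.argmax(a, axis=1), axis=1)
maxima = np.take_along_axis(a, imax, axis=1)

# put: scatter along axis 1; modifies `b` in place and returns None
b = a.copy()
np.put_along_axis(b, imax, 99, axis=1)
```

Here `sorted_a` matches `np.sort(a, axis=1)`, and `b` has 99 written at the position of each row's maximum.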

@mhvk (Contributor) left a comment

Overall, still looks very good and I think it makes sense to get this simpler version in first. Does need tests for put_along_axis (which I think has a problem as is for axis=None). Also, a suggestion for a clearer description.


def take_along_axis(arr, indices, axis):
    """
    Take the elements described by `indices` along each 1-D slice of the given

Contributor

I'd try for a true one-liner, maybe just the one you have above: "Take elements by matching the array and the index arrays" (though perhaps in all places adding "along axis"?)

Member Author

How about "Take values from the input array by matching 1d index and data slices."?

Contributor

That's not bad. I'd take that or perhaps
"Take values by matching 1d index and data slices along axis."
(and then also update the one-liners in the other file)

Contributor

Now looking at the full docstring, I think your one-liner is really fine as is. I would still change it also in the See-also sections.

for ii in ndindex(Ni):
    for j in range(J):
        for kk in ndindex(Nk):
            a_1d = a[ii + s_[j,] + kk]

Contributor

This selects a single element! Would need to be s_[:,]

Member Author

Oops, good catch

This function can be used to index with the result of `argsort`, `argmax`,
and other `arg` functions.

This is equivalent to (but faster than) the following use of `ndindex` and

Contributor

I find the description very hard to parse even though I implemented my own version of this! In part, it is because of using two other numpy routines which (I think) are rarely used. E.g., I had to look up what np.s_ was; I think one might as well omit that, and perhaps attempt to explain even more what ndindex does, as in:

    This is equivalent to (but faster than) the following use of `ndindex` to iterate
    over the axes before and after the requested axis:

        Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis+1:]
        J = indices.shape[axis]
        out = np.empty(indices.shape, dtype=a.dtype)

        for ii in ndindex(Ni):  # ii is a tuple of indices into the axes before axis
            for kk in ndindex(Nk):  # kk is a tuple of indices into the axes after axis
                a_1d = a[ii + (slice(0, M),) + kk]  # values along axis for this ii, kk
                indices_1d = indices[ii + (slice(0, J),) + kk]  # corresponding indices
                for j in range(J):
                    out[ii + (j,) + kk] = a_1d[indices_1d[j]]

    Equivalently, eliminating the inner loop, the last two lines would be,

                out[ii + (slice(0, J),) + kk] = a_1d[indices_1d]
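
For reference, the loop-based description above can be written as a runnable sketch. The helper name `take_along_axis_reference` is hypothetical, and the sketch assumes a non-negative `axis` and indices whose non-indexed dimensions match the array exactly (no broadcasting).

```python
import numpy as np

def take_along_axis_reference(a, indices, axis):
    """Loop-based reference (hypothetical helper); far slower than NumPy's version."""
    Ni, M, Nk = a.shape[:axis], a.shape[axis], a.shape[axis + 1:]
    J = indices.shape[axis]  # note that J does not have to equal M
    out = np.empty(Ni + (J,) + Nk, dtype=a.dtype)
    for ii in np.ndindex(Ni):        # indices into the axes before `axis`
        for kk in np.ndindex(Nk):    # indices into the axes after `axis`
            a_1d = a[ii + np.s_[:,] + kk]
            indices_1d = indices[ii + np.s_[:,] + kk]
            out[ii + np.s_[:,] + kk] = a_1d[indices_1d]
    return out

rng = np.random.default_rng(1)
a = rng.random((2, 6, 3))                 # M = 6
indices = rng.integers(0, 6, (2, 4, 3))   # J = 4
result = take_along_axis_reference(a, indices, axis=1)
```

The result has the shape of `indices`, and agrees with `np.take_along_axis(a, indices, axis=1)`.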

Contributor

Looking at this again, it might even be an idea to add

out_1d = out[ii + (slice(0, J),) + kk]  # corresponding view of output
for j in range(J):
    out_1d[j] = a_1d[indices_1d[j]]

Member Author (@eric-wieser, May 16, 2018)

Note that this does at least link to ndindex and s_. I do like the extra temporaries you introduce here, but I would prefer to keep s_ as it makes the connection to : clear.

Contributor

Yes, the np.s_[:,] is not as confusing as np.s_[j,] (though the comma looks very odd). One argument in favour of the slice as I wrote it is that it uses M and J and thus might help make clearer that those dimensions are not necessarily the same (which I think is useful). But of course one could just write np.s_[0:M,], etc.

Member Author

On the flipside, using :M suggests that the dimension may not actually be of size M, which would also be confusing.

I've incorporated some of your ideas into the docstrings - perhaps best to create a new review thread on the updated one

Member Author

What I'd really like here is to just be able to spell it out[*ii, j, *kk] - I tried patching python to allow this, but wasn't able to work out where to go after patching the grammar.

Contributor

Yes, I tried the same!

I do think your update is a lot clearer already. I would argue to keep the definition of out = np.empty(...) since that helps to understand what the result shape should be and makes the code already runnable. And perhaps add a comment in the example after extracting J along the lines of # Note that J does not have to equal M=a.shape[axis]

    axis = 0
else:
    axis = normalize_axis_index(axis, arr.ndim)
if not _nx.issubdtype(indices.dtype, _nx.integer):

Contributor

Move this up right after the np.asanyarray(indices), or, perhaps better, to _make_along_axis_idx (since that does other sanity checks as well that are shared with put_along_axis)
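
As a sketch of what these sanity checks look like from the caller's side, the snippet below probes the released `np.take_along_axis`. The helper `raised` is hypothetical, and the exception types are the ones the released function is understood to raise, not something stated in this thread.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)
idx = np.argsort(a, axis=1)

def raised(exc_type, *args, **kwargs):
    """Return True if take_along_axis raises `exc_type` for these arguments."""
    try:
        np.take_along_axis(*args, **kwargs)
    except exc_type:
        return True
    return False

# non-integer indices are rejected
float_rejected = raised(IndexError, a, idx.astype(float), axis=1)
# indices must have the same number of dimensions as the array
ndim_rejected = raised(ValueError, a, idx[0], axis=1)
```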


def put_along_axis(arr, indices, values, axis):
    """
    Put `values` at the elements described by `indices` along each 1-D slice of

Contributor

See comments on take_along_axis

# normalize inputs
indices = asanyarray(indices)
if axis is None:
    arr = arr.ravel()

Contributor

This won't work if one cannot ravel the array without copying (in which case the assignment below would not affect the input array; do make a test case...). I think you'd need to use arr = arr.flat (for symmetry, might as well do the same for take_along_axis).
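
The difference between `ravel()` and `.flat` can be seen on a transposed array, which cannot be flattened in C order without copying. This is a minimal sketch with made-up values, separate from the PR's own code.

```python
import numpy as np

arr = np.zeros((3, 4))
view = arr.T                      # non-contiguous view of `arr`

flat_copy = view.ravel()          # C-order ravel of a transposed array is a copy
flat_copy[0] = 1.0                # writes to the copy only
ravel_wrote_through = bool(arr[0, 0] == 1.0)

view.flat[0] = 2.0                # flatiter assignment writes into the shared data
flat_wrote_through = bool(arr[0, 0] == 2.0)

shares = np.shares_memory(flat_copy, arr)
```

Assigning through the raveled copy leaves `arr` untouched, while assigning through `.flat` modifies it.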


from numpy.lib.shape_base import (
    apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
-   vsplit, dstack, column_stack, kron, tile, expand_dims,
+   vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis

Contributor

Needs tests for put_along_axis as well...

Member Author

Will add these in a future commit. There's not any obvious way to test the argfunc/func equivalence like there was for take_along_axis, as far as I can tell.

Contributor

Fine to do in a separate commit. One can test, e.g., adding a very large number and then ensuring that argmax gives back indices.
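
One possible shape for such a test is sketched below. It is not the test that was added to NumPy; the shapes, seed, and the value `1e6` are arbitrary assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)
a = rng.random((3, 4, 5))         # all values lie in [0, 1)
axis = 1

# pick an arbitrary target position along `axis` for every 1-D slice
target = rng.integers(0, a.shape[axis], size=(3, 1, 5))

# write a value far larger than anything in `a` at those positions
b = a.copy()
np.put_along_axis(b, target, 1e6, axis=axis)

# argmax should now recover exactly the positions that were written
recovered = np.expand_dims(np.argmax(b, axis=axis), axis=axis)
```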

a = rand(3, 4, 5)

funcs = [
    (np.sort, np.argsort, dict()),

Contributor

Total nitpick, but replace dict() by {} (and {'kth': 2} further down)

Member Author

This was a deliberate choice, because dict(kth=2) looks more like some_func(kth=2)

Contributor

Definitely fine to leave as is.

arr = asanyarray(arr)
indices = asanyarray(indices)
if axis is None:
    arr = arr.ravel()

Member

Is there a use-case for setting axis=None? Personally I find this auto-flattening behavior quite distasteful.

The only legitimate use-case I can think of is making things simpler for 1D arrays, but that would need a default axis. Then the default value axis=None would mean "axis=0 for 1D arrays, otherwise raise an error".

Member Author

Without this behaviour, you can't use it to fully wrap argsort back into sort. I don't think argsort and sort should accept axis=None either, but given they do, we want to support it here too.

See the maskedarray change in this PR for a case where we sadly have to support None.

Member

OK, that makes sense to me. Let's just make sure the behavior is documented, then.

@shoyer (Member) commented May 16, 2018

I like this. My only objection is the auto-flattening behavior with axis=None.

arr: array_like (Ni..., M, Nk...)
    source array
indices: array_like (Ni..., J, Nk...)
    indices to take along each 1d slice of `arr`

Member

This should mention the broadcasting behavior with size 1 dimensions.
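
A small sketch of the size-1 broadcasting in question, using the released `np.take_along_axis` with illustrative values: the indices keep the same number of dimensions as the array, but have length 1 on the axis that is not being indexed.

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

# shape (1, 2): the same two column picks are reused for every row of `a`
cols = np.array([[3, 0]])
picked = np.take_along_axis(a, cols, axis=1)
```

The result has shape `(3, 2)`, with the leading size-1 dimension of `cols` broadcast against the three rows of `a`.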

@eric-wieser (Member Author)

A case for axis=None:

def sort_by(func, x, axis):
    ind = func(x).argsort(axis=axis)
    return np.take_along_axis(x, ind, axis=axis)

This allows sort_by(np.abs, x, axis=None) to be used where np.sort(x, axis=None) was previously used. I don't think any of these functions should accept None, but they do.

Do we go with the status quo here, or make a deliberate effort to start dropping axis=None support from all new functions?
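
To make the `axis=None` case concrete, here is the `sort_by` helper from this comment applied to a small made-up array, once along an axis and once flattened.

```python
import numpy as np

def sort_by(func, x, axis):
    ind = func(x).argsort(axis=axis)
    return np.take_along_axis(x, ind, axis=axis)

x = np.array([[3, -1],
              [-4, 2]])

by_abs_rows = sort_by(np.abs, x, axis=1)     # sort each row by absolute value
by_abs_flat = sort_by(np.abs, x, axis=None)  # flatten, then sort by absolute value
```

With `axis=None`, `argsort` returns 1-D indices into the flattened array, so the result is 1-D as well.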

@shoyer (Member) commented May 16, 2018

> Do we go with the status quo here, or make a deliberate effort to start dropping axis=None support from all new functions?

I agree that it's better to be consistent with argsort() and similar functions. In the broader scheme of things, flattening with axis=None is not so terrible.

@eric-wieser (Member Author)

> flattening with axis=None is not so terrible.

Well, unless axis=None is the default, and then it's pretty terrible.

@eric-wieser (Member Author) commented May 17, 2018

Updated with re-ordered and improved docstrings, and correct handling of axis=None.

Fixup commits have a description as their third line. Will squash once approved.

@mhvk (Contributor) left a comment

Only a docstring copy & paste left, plus making the one-liners in see-also consistent.

]


def _make_along_axis_idx(arr_shape, indices, axis):

Contributor

I had considered asking you to just pass on the shape, but thought it was too much trouble, but now I see it ended up being needed! Nicer.

Member Author

I also considered adding .shape and .ndim to flatiter, but then I'd need to do it to MaskedArray.flat too, and it was a little unclear whether .shape and .ndim should refer to the flattened or original array.


See Also
--------
take_along_axis : Take along an axis without matching up subspaces

Contributor

Oops, copy & paste error...

Member Author

Fixed and reworded here and for the other function

[40, 60]])
"""
# normalize inputs
arr = asanyarray(arr)

Contributor

More as a note, really, but this one line will prevent me from actually using the routine in astropy's Time class (which is not an ndarray subclass, but does allow indexing like it). I guess right now there is little choice but to do it, and eventually this should become a check for sufficient duck-typing. Anyway, an additional argument to make _make_along_axis_idx public in a follow-up PR (which means thinking of a better name....).

Member Author

I think a better follow-up would be to implement this as a gufunc

Member Author

Perhaps I should just omit this? I don't see any point in allowing the caller to pass a list, and all I actually need here is .shape, .flat, and .__getitem__

Contributor

Yes, why not indeed, since realistically this is not going to be used on lists, but on things that indices were taken from.

Member Author

Do I leave the asanyarray on indices?

Contributor

Might as well remove it too. These indices would be very hard to construct correctly as a nested list anyway.

@eric-wieser force-pushed the take_along_axis-strict branch from ad0ec85 to fcdced0 (May 25, 2018 16:36)
@eric-wieser (Member Author)

@mhvk: asanyarray removed, and squashed

@mhvk (Contributor) commented May 25, 2018

Darn, now your test with invalid index input raises a different error...

This is the reduced version that does not allow any insertion of extra dimensions
@eric-wieser force-pushed the take_along_axis-strict branch from fcdced0 to 7a3c50a (May 26, 2018 05:56)
@mhvk (Contributor) commented May 29, 2018

Hadn't noticed this was now fixed - all looks OK, so will merge. Nice addition!

@eric-wieser (Member Author)

Wonder if we should have run this by the mailing list, for the sake of name bike-shedding...

@eric-wieser added this to the 1.15.0 release milestone (May 29, 2018)
@mhvk (Contributor) commented May 29, 2018

Could still do it?

@eric-wieser (Member Author)

Done
