-
-
Notifications
You must be signed in to change notification settings - Fork 5.6k
ENH: stats: add axis tuple support to _axis_nan_policy_factory decorators #15257
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| Broadcast shapes, ignoring incompatibility of specified axes | ||
| """ | ||
| if not shapes: | ||
| return shapes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If shapes is an empty array, just return it; we don't want an error.
scipy/stats/_axis_nan_policy.py
Outdated
|
|
||
| # Remove the shape elements of the axes to be ignored, but remember them. | ||
| if axis is not None: | ||
| axis = np.atleast_1d(axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now this is already done above.
| if axis is not None: | ||
| axis = np.atleast_1d(axis) | ||
| axis[axis < 0] = n_dims + axis[axis < 0] | ||
| axis = np.sort(axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The was actually needed to fix a bug in the existing _broadcast_shapes, which was supposed to work for axis tuples already - but previously it only worked if the axes were already sorted. test_other_axis_tuples now checks that everything works even when axes are passed in out of order.
| axis = np.atleast_1d(axis) | ||
| axis[axis < 0] = n_dims + axis[axis < 0] | ||
| axis = np.sort(axis) | ||
| if axis[-1] >= n_dims or axis[0] < 0: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that axes are converted to be all positive above.
| # standardize to always work along last axis | ||
| if axis is None: | ||
| samples = [sample.ravel() for sample in samples] | ||
| axis = 0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed this to axis=-1 below. For axis is None, they're the same.
| elif axis != int(axis): | ||
| raise ValueError('`axis` must be an integer') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Input validation is handled in _broadcast_arrays now.
| axis = np.atleast_1d(axis) | ||
| n_axes = len(axis) | ||
| # move all axes in `axis` to the end to be raveled | ||
| samples = [np.moveaxis(sample, axis, range(-len(axis), 0)) | ||
| for sample in samples] | ||
| shapes = [sample.shape for sample in samples] | ||
| # New shape is unchanged for all axes _not_ in `axis` | ||
| # At the end, we append the product of the shapes of the axes | ||
| # in `axis`. Appending -1 doesn't work for zero-size arrays! | ||
| new_shapes = [shape[:-n_axes] + (np.prod(shape[-n_axes:]),) | ||
| for shape in shapes] | ||
| samples = [sample.reshape(new_shape) | ||
| for sample, new_shape in zip(samples, new_shapes)] | ||
| axis = -1 # work over the last axis |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the heart of the changes. It's a little more complicated than first expected because we have to calculate the size of the last axis (all axes in axis raveled) manually; -1 doesn't work for empty arrays.
|
|
||
|
|
||
| @pytest.mark.parametrize(("axis"), range(-2, 2)) | ||
| @pytest.mark.parametrize(("axis"), range(-3, 3)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test (already merged) was supposed to work for this range, but there was a bug before. It's fixed now.
scipy/stats/_axis_nan_policy.py
Outdated
| raise ValueError("Array shapes are incompatible for broadcasting.") | ||
| return tuple(new_shape) | ||
| shapes = _broadcast_shapes(shapes, axis) | ||
| shape = np.delete(shapes[0], axis) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wouldn't work when axis=None:
In [1]: from from scipy.stats._axis_nan_policy import _broadcast_shapes_remove_axis
In [2]: _broadcast_shapes_remove_axis([(3, 9, 2, 5), (3, 9, 2, 5)])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-35-66e4d08b298f> in <module>
----> 1 _broadcast_shapes_remove_axis([(3, 9, 2, 5), (3, 9, 2, 5)])
~/Desktop/scipy_source/scipy/stats/_axis_nan_policy.py in _broadcast_shapes_remove_axis(shapes, axis)
122 """
123 shapes = _broadcast_shapes(shapes, axis)
--> 124 shape = np.delete(shapes[0], axis)
125 return tuple(shape)
126
<__array_function__ internals> in delete(*args, **kwargs)
~/Desktop/scipy_source/scipy-dev/lib/python3.9/site-packages/numpy/lib/function_base.py in delete(arr, obj, axis)
4550 else:
4551 keep = ones(N, dtype=bool)
-> 4552 keep[obj,] = False
4553
4554 slobj[axis] = keep
IndexError: arrays used as indices must be of integer (or boolean) typeI guess we never use this function unless we want to remove some axes. In that case, it's better to not default axis to None so the signature is just _broadcast_shapes_remove_axis(shapes, axis)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The previous version did work with axis=None. I guess we should preserve that behavior:
def _broadcast_shapes_remove_axis(shapes, axis=None):
"""
Broadcast shapes, dropping specified axes
Same as _broadcast_array_shapes, but given a sequence
of array shapes `shapes` instead of the arrays themselves.
"""
shapes = _broadcast_shapes(shapes, axis)
if axis is not None:
shapes = shapes[0]
shape = np.delete(shapes, axis)
return tuple(shape)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I'll change these things when we get CI back.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done! Er, I think you meant:
shapes = _broadcast_shapes(shapes, axis)
shape = shapes[0]
if axis is not None:
shape = np.delete(shape, axis)
return tuple(shape)That's what I did.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, you are right! Although, this would still fail for cases like these:
In [1]: import numpy as np
In [2]: from scipy.stats._axis_nan_policy import _check_empty_inputs
In [3]: _check_empty_inputs([np.array([]), np.array([])], None)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-8945ea70ea1a> in <module>
----> 1 _check_empty_inputs([np.array([]), np.array([])], None)
~/Desktop/scipy_source/scipy/stats/_axis_nan_policy.py in _check_empty_inputs(samples, axis)
222 # otherwise, the statistic and p-value will be either empty arrays or
223 # arrays with NaNs. Produce the appropriate array and return it.
--> 224 output_shape = _broadcast_array_shapes_remove_axis(samples, axis)
225 output = np.ones(output_shape) * np.nan
226 return output
~/Desktop/scipy_source/scipy/stats/_axis_nan_policy.py in _broadcast_array_shapes_remove_axis(arrays, axis)
111 # ravel arrays before broadcasting.
112 shapes = [arr.shape for arr in arrays]
--> 113 return _broadcast_shapes_remove_axis(shapes, axis)
114
115
~/Desktop/scipy_source/scipy/stats/_axis_nan_policy.py in _broadcast_shapes_remove_axis(shapes, axis)
125 if axis is not None:
126 shape = np.delete(shape, axis)
--> 127 return tuple(shape)
128
129
TypeError: 'numpy.int64' object is not iterableBut from the _axis_nan_policy_factory code, it looks like the case of axis=None is never hit. So, I think we don't need to worry about this.
|
@mdhaber it's going to take me a while to read back into all this. I'm also going to be busy with family stuff until Christmas is over. I'll should have enough time to give this a thorough review between Christmas and New Years. Is that going to be too late for you? |
|
@Kai-Striega Thanks!
I can't expect other volunteers to work on my timetable! That said, I have some time this week, and I would like to move on to the more exciting part of all this - applying the decorator to more functions and testing those out. Since @tirthasheshpatel reviewed the recent PR and started to review this PR, maybe it's not necessary for you to dig deeply into the guts of the decorator right now, and you could spend your review time in the coming weeks on the next step (applying the decorator to more functions and testing those)? @tirthasheshpatel Were you pretty happy with this? I made the change that you suggested and fixed a lint error. More of the CI suite is running now than when I submitted, and that's all looking good. What would you prefer Kai work on (reviewing this PR or the application to new functions)? |
Yeah, this PR looks in a very good shape!
I'd agree with you. @Kai-Striega I looked at the internals of the decorator but could use a helping hand in experimenting with the decorator and applying it to stats functions. Feel free to review this PR if you have time, otherwise, it would be great if you could review the PRs that apply the decorator! |
|
So @tirthasheshpatel is this ready? As soon as this is merged I'll post the gmean PR again. |
Yes, I was just doing some experiments with tuple axis and it seems to work as expected. I will approve a merge this! |
tirthasheshpatel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test failures are unrelated. Although the decorator has become quite complex, the tests are very strong and seem to validate that it works. So, merging! Thanks @mdhaber!
Reference issue
gh-14651 (checks the second box)
What does this implement/fix?
gh-13312 added
_axis_nan_policy_factory, which returns a decorator that addsaxisandnan_policyarguments to stats functions. With the additions in this PR, the decorators produced now supportaxisarguments that are tuples of integers.Additional information
This implements the steps discussed in this comment.
To facilitate review, I've left some self-review comments to explain the changes. Feel free to mark them resolved.
I would also suggest viewing the changes in three separate steps.
_broadcast_arraysand related functions are moved (verbatim) from_hypotests.pyinto_axis_nan_policy.py, a better permanent home_broadcast...functions are consolidated, so that the main logic is only in_broadcast_arrays_axis_nan_policy_factoryto supportaxistuples and add tests.Details
Selecting the last two commits in the last step will also select the first three commits (because of the merge commit).
This should be squash-merged; the commit history is messier than intended.