Thanks to visit codestin.com
Credit goes to github.com

Skip to content

ENH/WIP: introduce a where keyword for reductions. #12635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions numpy/core/src/multiarray/nditer_constr.c
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,9 @@ npyiter_check_per_op_flags(npy_uint32 op_flags, npyiter_opitflags *op_itflags)
}

*op_itflags = NPY_OP_ITFLAG_READ;
if (op_flags & NPY_ITER_UPDATEIFCOPY) {
*op_itflags |= NPY_OP_ITFLAG_CAST;
}
}
else if (op_flags & NPY_ITER_READWRITE) {
/* The read/write flags are mutually exclusive */
Expand Down
43 changes: 28 additions & 15 deletions numpy/core/src/umath/reduction.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,9 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,

/* Iterator parameters */
NpyIter *iter = NULL;
PyArrayObject *op[2];
PyArray_Descr *op_dtypes[2];
npy_uint32 flags, op_flags[2];
PyArrayObject *op[4];
PyArray_Descr *op_dtypes[4];
npy_uint32 flags, op_flags[4];

/* More than one axis means multiple orders are possible */
if (!reorderable && count_axes(PyArray_NDIM(operand), axis_flags) > 1) {
Expand All @@ -457,15 +457,6 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
return NULL;
}


/* Validate that the parameters for future expansion are NULL */
if (wheremask != NULL) {
PyErr_SetString(PyExc_RuntimeError,
"Reduce operations in NumPy do not yet support "
"a where mask");
return NULL;
}

/*
* This either conforms 'out' to the ndim of 'operand', or allocates
* a new array appropriate for this reduction.
Expand Down Expand Up @@ -493,6 +484,13 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
Py_INCREF(op_view);
}
else {
/* Cannot use where when we initialize from the operand */
if (wheremask != NULL) {
PyErr_SetString(PyExc_RuntimeError,
"Reduce operations with no idenity do not yet support "
"a where mask");
return NULL;
}
op_view = PyArray_InitializeReduceResult(
result, operand, axis_flags, &skip_first_count, funcname);
if (op_view == NULL) {
Expand Down Expand Up @@ -523,9 +521,24 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
NPY_ITER_ALIGNED |
NPY_ITER_NO_SUBTYPE;
op_flags[1] = NPY_ITER_READONLY |
NPY_ITER_ALIGNED;
NPY_ITER_ALIGNED |
NPY_ITER_NO_BROADCAST;
if (wheremask != NULL) {
op_flags[1] |= NPY_ITER_UPDATEIFCOPY;
op[2] = wheremask;
op_dtypes[2] = PyArray_DescrFromType(NPY_BOOL);
if (op_dtypes[2] == NULL) {
goto fail;
}
op_flags[2] = NPY_ITER_READONLY |
NPY_ITER_ALIGNED;
op[3] = (PyArrayObject *)PyArray_FromScalar(identity, operand_dtype);
op_dtypes[3] = operand_dtype;
op_flags[3] = NPY_ITER_READONLY |
NPY_ITER_ALIGNED;
}

iter = NpyIter_AdvancedNew(2, op, flags,
iter = NpyIter_AdvancedNew(wheremask == NULL ? 2 : 4, op, flags,
NPY_KEEPORDER, casting,
op_flags,
op_dtypes,
Expand Down Expand Up @@ -568,7 +581,7 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
goto fail;
}
}

/* Check whether any errors occurred during the loop */
if (PyErr_Occurred() ||
_check_ufunc_fperr(errormask, NULL, "reduce") < 0) {
Expand Down
77 changes: 61 additions & 16 deletions numpy/core/src/umath/ufunc_object.c
Original file line number Diff line number Diff line change
Expand Up @@ -3457,21 +3457,38 @@ reduce_type_resolver(PyUFuncObject *ufunc, PyArrayObject *arr,
return 0;
}

static void
clear_masked_items(char *dataptr, npy_intp s_data,
char *maskptr, npy_intp s_mask, char *identityptr,
npy_intp count)
{
int n;
char *data = dataptr, *mask = maskptr;
for (n = 0; n < count; n++, data += s_data, mask += s_mask) {
if (!(*mask)) {
memcpy(data, identityptr, s_data);
}
}
}

static int
reduce_loop(NpyIter *iter, char **dataptrs, npy_intp *strides,
npy_intp *countptr, NpyIter_IterNextFunc *iternext,
int needs_api, npy_intp skip_first_count, void *data)
{
PyArray_Descr *dtypes[3], **iter_dtypes;
PyArray_Descr *dtypes[4], **iter_dtypes;
PyUFuncObject *ufunc = (PyUFuncObject *)data;
char *dataptrs_copy[3];
npy_intp strides_copy[3];
char *dataptrs_copy[4];
npy_intp strides_copy[4];
npy_bool where_mask;

/* The normal selected inner loop */
PyUFuncGenericFunction innerloop = NULL;
void *innerloopdata = NULL;

NPY_BEGIN_THREADS_DEF;
/* Get the number of operands, to determine whether "where" is used */
where_mask = (NpyIter_GetNOp(iter) == 4);

/* Get the inner loop */
iter_dtypes = NpyIter_GetDescrArray(iter);
Expand Down Expand Up @@ -3524,6 +3541,13 @@ reduce_loop(NpyIter *iter, char **dataptrs, npy_intp *strides,
} while (iternext(iter));
}
do {
if (where_mask) {
printf("strides=%ld,%ld,%ld\n", strides[0], strides[1],
strides[2]);
clear_masked_items(dataptrs[1], strides[1],
dataptrs[2], strides[2],
dataptrs[3], *countptr);
}
/* Turn the two items into three for the inner loop */
dataptrs_copy[0] = dataptrs[0];
dataptrs_copy[1] = dataptrs[1];
Expand Down Expand Up @@ -3561,7 +3585,7 @@ reduce_loop(NpyIter *iter, char **dataptrs, npy_intp *strides,
static PyArrayObject *
PyUFunc_Reduce(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
int naxes, int *axes, PyArray_Descr *odtype, int keepdims,
PyObject *initial)
PyObject *initial, PyArrayObject *wheremask)
{
int iaxes, ndim;
npy_bool reorderable;
Expand Down Expand Up @@ -3627,7 +3651,7 @@ PyUFunc_Reduce(PyUFuncObject *ufunc, PyArrayObject *arr, PyArrayObject *out,
return NULL;
}

result = PyUFunc_ReduceWrapper(arr, out, NULL, dtype, dtype,
result = PyUFunc_ReduceWrapper(arr, out, wheremask, dtype, dtype,
NPY_UNSAFE_CASTING,
axis_flags, reorderable,
keepdims, 0,
Expand Down Expand Up @@ -4384,16 +4408,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
int i, naxes=0, ndim;
int axes[NPY_MAXDIMS];
PyObject *axes_in = NULL;
PyArrayObject *mp = NULL, *ret = NULL;
PyObject *op;
PyArrayObject *mp = NULL, *wheremask = NULL, *ret = NULL;
PyObject *op, *where = NULL;
PyObject *obj_ind, *context;
PyArrayObject *indices = NULL;
PyArray_Descr *otype = NULL;
PyArrayObject *out = NULL;
int keepdims = 0;
PyObject *initial = NULL;
static char *reduce_kwlist[] = {
"array", "axis", "dtype", "out", "keepdims", "initial", NULL};
"array", "axis", "dtype", "out", "keepdims", "initial", "where", NULL};
static char *accumulate_kwlist[] = {
"array", "axis", "dtype", "out", NULL};
static char *reduceat_kwlist[] = {
Expand Down Expand Up @@ -4456,24 +4480,43 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
}
else if (operation == UFUNC_ACCUMULATE) {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO&O&:accumulate",
accumulate_kwlist,
&op,
&axes_in,
PyArray_DescrConverter2, &otype,
PyArray_OutputConverter, &out)) {
accumulate_kwlist,
&op,
&axes_in,
PyArray_DescrConverter2, &otype,
PyArray_OutputConverter, &out)) {
goto fail;
}
}
else {
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO&O&iO:reduce",
if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|OO&O&iOO:reduce",
reduce_kwlist,
&op,
&axes_in,
PyArray_DescrConverter2, &otype,
PyArray_OutputConverter, &out,
&keepdims, &initial)) {
&keepdims, &initial, &where)) {
goto fail;
}
/* Interpret mask */
if (where != NULL) {
PyArray_Descr *dtype;
dtype = PyArray_DescrFromType(NPY_BOOL);
if (dtype == NULL) {
goto fail;
}
/*
* Optimization: where=True is the same as no where argument.
* This lets us document it as a default argument.
*/
if (where != Py_True) {
wheremask = (PyArrayObject *)PyArray_FromAny(where, dtype,
0, 0, 0, NULL);
if (wheremask == NULL) {
goto fail;
}
}
}
}
/* Ensure input is an array */
if (!PyArray_Check(op) && !PyArray_IsScalar(op, Generic)) {
Expand Down Expand Up @@ -4602,7 +4645,8 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
switch(operation) {
case UFUNC_REDUCE:
ret = PyUFunc_Reduce(ufunc, mp, out, naxes, axes,
otype, keepdims, initial);
otype, keepdims, initial, wheremask);
Py_XDECREF(wheremask);
break;
case UFUNC_ACCUMULATE:
if (naxes != 1) {
Expand Down Expand Up @@ -4660,6 +4704,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
fail:
Py_XDECREF(otype);
Py_XDECREF(mp);
Py_XDECREF(wheremask);
return NULL;
}

Expand Down
2 changes: 1 addition & 1 deletion numpy/core/tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,7 +1758,7 @@ def test_reduce_arguments(self):
# too little
assert_raises(TypeError, f)
# too much
assert_raises(TypeError, f, d, 0, None, None, False, 0, 1)
assert_raises(TypeError, f, d, 0, None, None, False, 0, True, 1)
# invalid axis
assert_raises(TypeError, f, d, "invalid")
assert_raises(TypeError, f, d, axis="invalid")
Expand Down