Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

MAINT: do not use copyswap in where internals #23770

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions benchmarks/benchmarks/bench_function_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ def time_sort_worst(self):
class Where(Benchmark):
def setup(self):
self.d = np.arange(20000)
self.d_o = self.d.astype(object)
self.e = self.d.copy()
self.e_o = self.d_o.copy()
self.cond = (self.d > 5000)
size = 1024 * 1024 // 8
rnd_array = np.random.rand(size)
Expand All @@ -332,6 +334,11 @@ def time_1(self):
def time_2(self):
np.where(self.cond, self.d, self.e)

def time_2_object(self):
# object and byteswapped arrays have a
# special slow path in the where internals
np.where(self.cond, self.d_o, self.e_o)

def time_2_broadcast(self):
np.where(self.cond, self.d, 0)

Expand Down
90 changes: 66 additions & 24 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "convert_datatype.h"
#include "conversion_utils.h"
#include "nditer_pywrap.h"
#define NPY_ITERATOR_IMPLEMENTATION_CODE
#include "nditer_impl.h"
#include "methods.h"
#include "_datetime.h"
#include "datetime_strings.h"
Expand All @@ -67,6 +69,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
#include "mem_overlap.h"
#include "typeinfo.h"
#include "convert.h" /* for PyArray_AssignZero */
#include "lowlevel_strided_loops.h"
#include "dtype_transfer.h"

#include "get_attr_string.h"
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */
Expand Down Expand Up @@ -3381,6 +3385,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
return NULL;
}

NPY_cast_info cast_info = {.func = NULL};

ax = (PyArrayObject*)PyArray_FROM_O(x);
ay = (PyArrayObject*)PyArray_FROM_O(y);
if (ax == NULL || ay == NULL) {
Expand All @@ -3394,14 +3400,15 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
};
npy_uint32 op_flags[4] = {
NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_NO_SUBTYPE,
NPY_ITER_READONLY, NPY_ITER_READONLY, NPY_ITER_READONLY
NPY_ITER_READONLY,
NPY_ITER_READONLY | NPY_ITER_ALIGNED,
NPY_ITER_READONLY | NPY_ITER_ALIGNED
};
PyArray_Descr * common_dt = PyArray_ResultType(2, &op_in[0] + 2,
0, NULL);
PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL),
common_dt, common_dt};
NpyIter * iter;
int needs_api;
NPY_BEGIN_THREADS_DEF;

if (common_dt == NULL || op_dt[1] == NULL) {
Expand All @@ -3418,61 +3425,94 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
goto fail;
}

needs_api = NpyIter_IterationNeedsAPI(iter);

/* Get the result from the iterator object array */
ret = (PyObject*)NpyIter_GetOperandArray(iter)[0];

NPY_BEGIN_THREADS_NDITER(iter);
npy_intp itemsize = common_dt->elsize;

int has_ref = PyDataType_REFCHK(common_dt);

NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;

npy_intp transfer_strides[2] = {itemsize, itemsize};
npy_intp one = 1;

if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) &&
(itemsize != 2) && (itemsize != 1))) {
// The iterator has NPY_ITER_ALIGNED flag so no need to check alignment
// of the input arrays.
//
// There's also no need to set up a cast for y, since the iterator
// ensures both casts are identical.
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize, common_dt, common_dt, 0,
&cast_info, &transfer_flags) != NPY_SUCCEED) {
goto fail;
}
}

transfer_flags = PyArrayMethod_COMBINED_FLAGS(
transfer_flags, NpyIter_GetTransferFlags(iter));

if (!(transfer_flags & NPY_METH_REQUIRES_PYAPI)) {
NPY_BEGIN_THREADS_THRESHOLDED(NpyIter_GetIterSize(iter));
}

if (NpyIter_GetIterSize(iter) != 0) {
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL);
npy_intp * innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
char **dataptrarray = NpyIter_GetDataPtrArray(iter);
npy_intp *strides = NpyIter_GetInnerStrideArray(iter);

do {
PyArray_Descr * dtx = NpyIter_GetDescrArray(iter)[2];
PyArray_Descr * dty = NpyIter_GetDescrArray(iter)[3];
int axswap = PyDataType_ISBYTESWAPPED(dtx);
int ayswap = PyDataType_ISBYTESWAPPED(dty);
PyArray_CopySwapFunc *copyswapx = dtx->f->copyswap;
PyArray_CopySwapFunc *copyswapy = dty->f->copyswap;
int native = (axswap == ayswap) && (axswap == 0) && !needs_api;
npy_intp n = (*innersizeptr);
npy_intp itemsize = NpyIter_GetDescrArray(iter)[0]->elsize;
npy_intp cstride = NpyIter_GetInnerStrideArray(iter)[1];
npy_intp xstride = NpyIter_GetInnerStrideArray(iter)[2];
npy_intp ystride = NpyIter_GetInnerStrideArray(iter)[3];
char * dst = dataptrarray[0];
char * csrc = dataptrarray[1];
char * xsrc = dataptrarray[2];
char * ysrc = dataptrarray[3];

// the iterator might mutate these pointers,
// so need to update them every iteration
npy_intp cstride = strides[1];
npy_intp xstride = strides[2];
npy_intp ystride = strides[3];

/* constant sizes so compiler replaces memcpy */
if (native && itemsize == 16) {
if (!has_ref && itemsize == 16) {
INNER_WHERE_LOOP(16);
}
else if (native && itemsize == 8) {
else if (!has_ref && itemsize == 8) {
INNER_WHERE_LOOP(8);
}
else if (native && itemsize == 4) {
else if (!has_ref && itemsize == 4) {
INNER_WHERE_LOOP(4);
}
else if (native && itemsize == 2) {
else if (!has_ref && itemsize == 2) {
INNER_WHERE_LOOP(2);
}
else if (native && itemsize == 1) {
else if (!has_ref && itemsize == 1) {
INNER_WHERE_LOOP(1);
}
else {
/* copyswap is faster than memcpy even if we are native */
npy_intp i;
for (i = 0; i < n; i++) {
if (*csrc) {
copyswapx(dst, xsrc, axswap, ret);
char *args[2] = {xsrc, dst};

if (cast_info.func(
&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
goto fail;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there a path where has_ref is false but we get here because of a strange itemsize, and then we call this function without the GIL (after NPY_BEGIN_THREADS_THRESHOLDED)? (The same applies to the cast_info call below.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we get here, we account for whether or not the GIL is needed (e.g. the GIL should be released for the string dtype prototype). That is done via the transfer_flags.

The only thing we assume is that if !has_ref is true, then a value of that dtype can be copied via a memcpy (or pointer assignment really). At this point the function call never performs a cast, only a copy.

}
}
else {
copyswapy(dst, ysrc, ayswap, ret);
char *args[2] = {ysrc, dst};

if (cast_info.func(
&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
goto fail;
}
}
dst += itemsize;
xsrc += xstride;
Expand All @@ -3489,6 +3529,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
Py_DECREF(arr);
Py_DECREF(ax);
Py_DECREF(ay);
NPY_cast_info_xfree(&cast_info);

if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
Py_DECREF(ret);
Expand All @@ -3502,6 +3543,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y)
Py_DECREF(arr);
Py_XDECREF(ax);
Py_XDECREF(ay);
NPY_cast_info_xfree(&cast_info);
return NULL;
}

Expand Down