-
-
Notifications
You must be signed in to change notification settings - Fork 11k
MAINT: do not use copyswap in where internals #23770
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
9d08632
f8b1a3e
46cf47d
efa004b
20463ea
2db272b
01a251b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -49,6 +49,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; | |
#include "convert_datatype.h" | ||
#include "conversion_utils.h" | ||
#include "nditer_pywrap.h" | ||
#define NPY_ITERATOR_IMPLEMENTATION_CODE | ||
#include "nditer_impl.h" | ||
#include "methods.h" | ||
#include "_datetime.h" | ||
#include "datetime_strings.h" | ||
|
@@ -67,6 +69,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; | |
#include "mem_overlap.h" | ||
#include "typeinfo.h" | ||
#include "convert.h" /* for PyArray_AssignZero */ | ||
#include "lowlevel_strided_loops.h" | ||
#include "dtype_transfer.h" | ||
|
||
#include "get_attr_string.h" | ||
#include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */ | ||
|
@@ -3381,6 +3385,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) | |
return NULL; | ||
} | ||
|
||
NPY_cast_info cast_info = {.func = NULL}; | ||
|
||
ax = (PyArrayObject*)PyArray_FROM_O(x); | ||
ay = (PyArrayObject*)PyArray_FROM_O(y); | ||
if (ax == NULL || ay == NULL) { | ||
|
@@ -3394,14 +3400,15 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) | |
}; | ||
npy_uint32 op_flags[4] = { | ||
NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_NO_SUBTYPE, | ||
NPY_ITER_READONLY, NPY_ITER_READONLY, NPY_ITER_READONLY | ||
NPY_ITER_READONLY, | ||
NPY_ITER_READONLY | NPY_ITER_ALIGNED, | ||
NPY_ITER_READONLY | NPY_ITER_ALIGNED | ||
}; | ||
PyArray_Descr * common_dt = PyArray_ResultType(2, &op_in[0] + 2, | ||
0, NULL); | ||
PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL), | ||
common_dt, common_dt}; | ||
NpyIter * iter; | ||
int needs_api; | ||
NPY_BEGIN_THREADS_DEF; | ||
|
||
if (common_dt == NULL || op_dt[1] == NULL) { | ||
|
@@ -3418,61 +3425,94 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) | |
goto fail; | ||
} | ||
|
||
needs_api = NpyIter_IterationNeedsAPI(iter); | ||
|
||
/* Get the result from the iterator object array */ | ||
ret = (PyObject*)NpyIter_GetOperandArray(iter)[0]; | ||
|
||
NPY_BEGIN_THREADS_NDITER(iter); | ||
npy_intp itemsize = common_dt->elsize; | ||
|
||
int has_ref = PyDataType_REFCHK(common_dt); | ||
|
||
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0; | ||
|
||
npy_intp transfer_strides[2] = {itemsize, itemsize}; | ||
npy_intp one = 1; | ||
|
||
if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) && | ||
(itemsize != 2) && (itemsize != 1))) { | ||
// The iterator has NPY_ITER_ALIGNED flag so no need to check alignment | ||
// of the input arrays. | ||
// | ||
// There's also no need to set up a cast for y, since the iterator | ||
// ensures both casts are identical. | ||
if (PyArray_GetDTypeTransferFunction( | ||
1, itemsize, itemsize, common_dt, common_dt, 0, | ||
&cast_info, &transfer_flags) != NPY_SUCCEED) { | ||
goto fail; | ||
} | ||
} | ||
|
||
transfer_flags = PyArrayMethod_COMBINED_FLAGS( | ||
transfer_flags, NpyIter_GetTransferFlags(iter)); | ||
|
||
if (!(transfer_flags & NPY_METH_REQUIRES_PYAPI)) { | ||
NPY_BEGIN_THREADS_THRESHOLDED(NpyIter_GetIterSize(iter)); | ||
} | ||
|
||
if (NpyIter_GetIterSize(iter) != 0) { | ||
NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL); | ||
npy_intp * innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); | ||
char **dataptrarray = NpyIter_GetDataPtrArray(iter); | ||
npy_intp *strides = NpyIter_GetInnerStrideArray(iter); | ||
|
||
do { | ||
PyArray_Descr * dtx = NpyIter_GetDescrArray(iter)[2]; | ||
PyArray_Descr * dty = NpyIter_GetDescrArray(iter)[3]; | ||
int axswap = PyDataType_ISBYTESWAPPED(dtx); | ||
int ayswap = PyDataType_ISBYTESWAPPED(dty); | ||
PyArray_CopySwapFunc *copyswapx = dtx->f->copyswap; | ||
PyArray_CopySwapFunc *copyswapy = dty->f->copyswap; | ||
int native = (axswap == ayswap) && (axswap == 0) && !needs_api; | ||
npy_intp n = (*innersizeptr); | ||
npy_intp itemsize = NpyIter_GetDescrArray(iter)[0]->elsize; | ||
npy_intp cstride = NpyIter_GetInnerStrideArray(iter)[1]; | ||
npy_intp xstride = NpyIter_GetInnerStrideArray(iter)[2]; | ||
npy_intp ystride = NpyIter_GetInnerStrideArray(iter)[3]; | ||
char * dst = dataptrarray[0]; | ||
char * csrc = dataptrarray[1]; | ||
char * xsrc = dataptrarray[2]; | ||
char * ysrc = dataptrarray[3]; | ||
|
||
// the iterator might mutate these pointers, | ||
// so need to update them every iteration | ||
npy_intp cstride = strides[1]; | ||
npy_intp xstride = strides[2]; | ||
npy_intp ystride = strides[3]; | ||
|
||
/* constant sizes so compiler replaces memcpy */ | ||
if (native && itemsize == 16) { | ||
if (!has_ref && itemsize == 16) { | ||
INNER_WHERE_LOOP(16); | ||
} | ||
else if (native && itemsize == 8) { | ||
else if (!has_ref && itemsize == 8) { | ||
INNER_WHERE_LOOP(8); | ||
} | ||
else if (native && itemsize == 4) { | ||
else if (!has_ref && itemsize == 4) { | ||
INNER_WHERE_LOOP(4); | ||
} | ||
else if (native && itemsize == 2) { | ||
else if (!has_ref && itemsize == 2) { | ||
INNER_WHERE_LOOP(2); | ||
} | ||
else if (native && itemsize == 1) { | ||
else if (!has_ref && itemsize == 1) { | ||
INNER_WHERE_LOOP(1); | ||
} | ||
else { | ||
/* copyswap is faster than memcpy even if we are native */ | ||
npy_intp i; | ||
for (i = 0; i < n; i++) { | ||
if (*csrc) { | ||
copyswapx(dst, xsrc, axswap, ret); | ||
char *args[2] = {xsrc, dst}; | ||
|
||
if (cast_info.func( | ||
&cast_info.context, args, &one, | ||
transfer_strides, cast_info.auxdata) < 0) { | ||
goto fail; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't there a path where There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we go here, we account for whether or not the GIL is needed (e.g. the GIL should be released for the string dtype prototype). That is done via the The only thing we assume is that if |
||
} | ||
} | ||
else { | ||
copyswapy(dst, ysrc, ayswap, ret); | ||
char *args[2] = {ysrc, dst}; | ||
|
||
if (cast_info.func( | ||
&cast_info.context, args, &one, | ||
transfer_strides, cast_info.auxdata) < 0) { | ||
goto fail; | ||
} | ||
} | ||
dst += itemsize; | ||
xsrc += xstride; | ||
|
@@ -3489,6 +3529,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) | |
Py_DECREF(arr); | ||
Py_DECREF(ax); | ||
Py_DECREF(ay); | ||
NPY_cast_info_xfree(&cast_info); | ||
|
||
if (NpyIter_Deallocate(iter) != NPY_SUCCEED) { | ||
Py_DECREF(ret); | ||
|
@@ -3502,6 +3543,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) | |
Py_DECREF(arr); | ||
Py_XDECREF(ax); | ||
Py_XDECREF(ay); | ||
NPY_cast_info_xfree(&cast_info); | ||
return NULL; | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.