diff --git a/benchmarks/benchmarks/bench_function_base.py b/benchmarks/benchmarks/bench_function_base.py index cc37bef3994b..1655868b3b55 100644 --- a/benchmarks/benchmarks/bench_function_base.py +++ b/benchmarks/benchmarks/bench_function_base.py @@ -308,7 +308,9 @@ def time_sort_worst(self): class Where(Benchmark): def setup(self): self.d = np.arange(20000) + self.d_o = self.d.astype(object) self.e = self.d.copy() + self.e_o = self.d_o.copy() self.cond = (self.d > 5000) size = 1024 * 1024 // 8 rnd_array = np.random.rand(size) @@ -332,6 +334,11 @@ def time_1(self): def time_2(self): np.where(self.cond, self.d, self.e) + def time_2_object(self): + # object and byteswapped arrays have a + # special slow path in the where internals + np.where(self.cond, self.d_o, self.e_o) + def time_2_broadcast(self): np.where(self.cond, self.d, 0) diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index d7d19493b754..0fa1af11d0ab 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -49,6 +49,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "convert_datatype.h" #include "conversion_utils.h" #include "nditer_pywrap.h" +#define NPY_ITERATOR_IMPLEMENTATION_CODE +#include "nditer_impl.h" #include "methods.h" #include "_datetime.h" #include "datetime_strings.h" @@ -67,6 +69,8 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0; #include "mem_overlap.h" #include "typeinfo.h" #include "convert.h" /* for PyArray_AssignZero */ +#include "lowlevel_strided_loops.h" +#include "dtype_transfer.h" #include "get_attr_string.h" #include "experimental_public_dtype_api.h" /* _get_experimental_dtype_api */ @@ -3381,6 +3385,8 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) return NULL; } + NPY_cast_info cast_info = {.func = NULL}; + ax = (PyArrayObject*)PyArray_FROM_O(x); ay = (PyArrayObject*)PyArray_FROM_O(y); if (ax == NULL || ay == NULL) { @@ -3394,14 +3400,15 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) }; npy_uint32 op_flags[4] = { NPY_ITER_WRITEONLY | NPY_ITER_ALLOCATE | NPY_ITER_NO_SUBTYPE, - NPY_ITER_READONLY, NPY_ITER_READONLY, NPY_ITER_READONLY + NPY_ITER_READONLY, + NPY_ITER_READONLY | NPY_ITER_ALIGNED, + NPY_ITER_READONLY | NPY_ITER_ALIGNED }; PyArray_Descr * common_dt = PyArray_ResultType(2, &op_in[0] + 2, 0, NULL); PyArray_Descr * op_dt[4] = {common_dt, PyArray_DescrFromType(NPY_BOOL), common_dt, common_dt}; NpyIter * iter; - int needs_api; NPY_BEGIN_THREADS_DEF; if (common_dt == NULL || op_dt[1] == NULL) { @@ -3418,61 +3425,94 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) goto fail; } - needs_api = NpyIter_IterationNeedsAPI(iter); - /* Get the result from the iterator object array */ ret = (PyObject*)NpyIter_GetOperandArray(iter)[0]; - NPY_BEGIN_THREADS_NDITER(iter); + npy_intp itemsize = common_dt->elsize; + + int has_ref = PyDataType_REFCHK(common_dt); + + NPY_ARRAYMETHOD_FLAGS transfer_flags = 0; + + npy_intp transfer_strides[2] = {itemsize, itemsize}; + npy_intp one = 1; + + if (has_ref || ((itemsize != 16) && (itemsize != 8) && (itemsize != 4) && + (itemsize != 2) && (itemsize != 1))) { + // The iterator has NPY_ITER_ALIGNED flag so no need to check alignment + // of the input arrays. + // + // There's also no need to set up a cast for y, since the iterator + // ensures both casts are identical. + if (PyArray_GetDTypeTransferFunction( + 1, itemsize, itemsize, common_dt, common_dt, 0, + &cast_info, &transfer_flags) != NPY_SUCCEED) { + goto fail; + } + } + + transfer_flags = PyArrayMethod_COMBINED_FLAGS( + transfer_flags, NpyIter_GetTransferFlags(iter)); + + if (!(transfer_flags & NPY_METH_REQUIRES_PYAPI)) { + NPY_BEGIN_THREADS_THRESHOLDED(NpyIter_GetIterSize(iter)); + } if (NpyIter_GetIterSize(iter) != 0) { NpyIter_IterNextFunc *iternext = NpyIter_GetIterNext(iter, NULL); npy_intp * innersizeptr = NpyIter_GetInnerLoopSizePtr(iter); char **dataptrarray = NpyIter_GetDataPtrArray(iter); + npy_intp *strides = NpyIter_GetInnerStrideArray(iter); do { - PyArray_Descr * dtx = NpyIter_GetDescrArray(iter)[2]; - PyArray_Descr * dty = NpyIter_GetDescrArray(iter)[3]; - int axswap = PyDataType_ISBYTESWAPPED(dtx); - int ayswap = PyDataType_ISBYTESWAPPED(dty); - PyArray_CopySwapFunc *copyswapx = dtx->f->copyswap; - PyArray_CopySwapFunc *copyswapy = dty->f->copyswap; - int native = (axswap == ayswap) && (axswap == 0) && !needs_api; npy_intp n = (*innersizeptr); - npy_intp itemsize = NpyIter_GetDescrArray(iter)[0]->elsize; - npy_intp cstride = NpyIter_GetInnerStrideArray(iter)[1]; - npy_intp xstride = NpyIter_GetInnerStrideArray(iter)[2]; - npy_intp ystride = NpyIter_GetInnerStrideArray(iter)[3]; char * dst = dataptrarray[0]; char * csrc = dataptrarray[1]; char * xsrc = dataptrarray[2]; char * ysrc = dataptrarray[3]; + // the iterator might mutate these pointers, + // so need to update them every iteration + npy_intp cstride = strides[1]; + npy_intp xstride = strides[2]; + npy_intp ystride = strides[3]; + /* constant sizes so compiler replaces memcpy */ - if (native && itemsize == 16) { + if (!has_ref && itemsize == 16) { INNER_WHERE_LOOP(16); } - else if (native && itemsize == 8) { + else if (!has_ref && itemsize == 8) { INNER_WHERE_LOOP(8); } - else if (native && itemsize == 4) { + else if (!has_ref && itemsize == 4) { INNER_WHERE_LOOP(4); } - else if (native && itemsize == 2) { + else if (!has_ref && itemsize == 2) { INNER_WHERE_LOOP(2); } - else if (native && itemsize == 1) { + else if (!has_ref && itemsize == 1) { INNER_WHERE_LOOP(1); } else { - /* copyswap is faster than memcpy even if we are native */ npy_intp i; for (i = 0; i < n; i++) { if (*csrc) { - copyswapx(dst, xsrc, axswap, ret); + char *args[2] = {xsrc, dst}; + + if (cast_info.func( + &cast_info.context, args, &one, + transfer_strides, cast_info.auxdata) < 0) { + goto fail; + } } else { - copyswapy(dst, ysrc, ayswap, ret); + char *args[2] = {ysrc, dst}; + + if (cast_info.func( + &cast_info.context, args, &one, + transfer_strides, cast_info.auxdata) < 0) { + goto fail; + } } dst += itemsize; xsrc += xstride; @@ -3489,6 +3529,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) Py_DECREF(arr); Py_DECREF(ax); Py_DECREF(ay); + NPY_cast_info_xfree(&cast_info); if (NpyIter_Deallocate(iter) != NPY_SUCCEED) { Py_DECREF(ret); @@ -3502,6 +3543,7 @@ PyArray_Where(PyObject *condition, PyObject *x, PyObject *y) Py_DECREF(arr); Py_XDECREF(ax); Py_XDECREF(ay); + NPY_cast_info_xfree(&cast_info); return NULL; }