diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index f42ae7c2d0d8..e935a27edb6c 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -993,9 +993,10 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, if (multi == NULL) { goto fail; } + dtype = PyArray_DESCR(mps[0]); + /* Set-up return array */ if (out == NULL) { - dtype = PyArray_DESCR(mps[0]); Py_INCREF(dtype); obj = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(ap), dtype, @@ -1032,7 +1033,6 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, */ flags |= NPY_ARRAY_ENSURECOPY; } - dtype = PyArray_DESCR(mps[0]); Py_INCREF(dtype); obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags); } @@ -1040,8 +1040,22 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, if (obj == NULL) { goto fail; } - elsize = PyArray_DESCR(obj)->elsize; + elsize = dtype->elsize; ret_data = PyArray_DATA(obj); + npy_intp transfer_strides[2] = {elsize, elsize}; + npy_intp one = 1; + NPY_ARRAYMETHOD_FLAGS transfer_flags = 0; + NPY_cast_info cast_info = {.func = NULL}; + if (PyDataType_REFCHK(dtype)) { + int is_aligned = IsUintAligned(obj); + PyArray_GetDTypeTransferFunction( + is_aligned, + dtype->elsize, + dtype->elsize, + dtype, + dtype, 0, &cast_info, + &transfer_flags); + } while (PyArray_MultiIter_NOTDONE(multi)) { mi = *((npy_intp *)PyArray_MultiIter_DATA(multi, n)); @@ -1074,12 +1088,22 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, break; } } - memmove(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize); + if (cast_info.func == NULL) { + /* We ensure memory doesn't overlap, so can use memcpy */ + memcpy(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize); + } + else { + char *args[2] = {PyArray_MultiIter_DATA(multi, mi), ret_data}; + if (cast_info.func(&cast_info.context, args, &one, + transfer_strides, cast_info.auxdata) < 0) { + goto fail; + } + } ret_data += elsize; PyArray_MultiIter_NEXT(multi); } - PyArray_INCREF(obj); + NPY_cast_info_xfree(&cast_info); Py_DECREF(multi); for (i = 0; i < n; i++) { Py_XDECREF(mps[i]); @@ -1095,6 +1119,7 @@ PyArray_Choose(PyArrayObject *ip, PyObject *op, PyArrayObject *out, return (PyObject *)obj; fail: + NPY_cast_info_xfree(&cast_info); Py_XDECREF(multi); for (i = 0; i < n; i++) { Py_XDECREF(mps[i]); diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 514d271f0f6b..869ebe4d8ac7 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -10031,3 +10031,14 @@ def test_argsort_int(N, dtype): arr = rnd.randint(low=minv, high=maxv, size=N, dtype=dtype) arr[N-1] = maxv assert_arg_sorted(arr, np.argsort(arr, kind='quick')) + + +@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts") +def test_gh_22683(): + b = 777.68760986 + a = np.array([b] * 10000, dtype=object) + refc_start = sys.getrefcount(b) + np.choose(np.zeros(10000, dtype=int), [a], out=a) + np.choose(np.zeros(10000, dtype=int), [a], out=a) + refc_end = sys.getrefcount(b) + assert refc_end - refc_start < 10