diff --git a/numpy/core/src/multiarray/iterators.c b/numpy/core/src/multiarray/iterators.c index 5099e3e193c8..4cb81fd00543 100644 --- a/numpy/core/src/multiarray/iterators.c +++ b/numpy/core/src/multiarray/iterators.c @@ -1436,232 +1436,183 @@ PyArray_Broadcast(PyArrayMultiIterObject *mit) return 0; } -/*NUMPY_API - * Get MultiIterator from array of Python objects and any additional - * - * PyObject **mps -- array of PyObjects - * int n - number of PyObjects in the array - * int nadd - number of additional arrays to include in the iterator. - * - * Returns a multi-iterator object. +static NPY_INLINE PyObject* +multiiter_wrong_number_of_args(void) +{ + return PyErr_Format(PyExc_ValueError, + "Need at least 1 and at most %d " + "array objects.", NPY_MAXARGS); +} + +/* + * Common implementation for all PyArrayMultiIterObject constructors. + * This function takes a pointer to an array of at most NPY_MAXARGS + * PyObject pointers, NULL terminated unless completely full, which + * must either hold array_like objects or multi iterators. */ -NPY_NO_EXPORT PyObject * -PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...) +static PyObject* +multiiter_new_impl(PyObject **args) { - va_list va; PyArrayMultiIterObject *multi; - PyObject *current; - PyObject *arr; - - int i, ntot, err=0; + int i; - ntot = n + nadd; - if (ntot < 1 || ntot > NPY_MAXARGS) { - PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); - return NULL; - } multi = PyArray_malloc(sizeof(PyArrayMultiIterObject)); if (multi == NULL) { return PyErr_NoMemory(); } PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type); + multi->numiter = 0; - for (i = 0; i < ntot; i++) { - multi->iters[i] = NULL; - } - multi->numiter = ntot; - multi->index = 0; + for (i = 0; i < NPY_MAXARGS; ++i) { + PyObject *obj = args[i]; + PyObject *arr; + PyArrayIterObject *it; - va_start(va, nadd); - for (i = 0; i < ntot; i++) { - if (i < n) { - current = mps[i]; - } - else { - current = va_arg(va, PyObject *); - } - arr = PyArray_FROM_O(current); - if (arr == NULL) { - err = 1; + if (obj == NULL) { break; } - else { - multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr); - if (multi->iters[i] == NULL) { - err = 1; - break; + if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { + PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj; + int j; + + if (multi->numiter + mit->numiter > NPY_MAXARGS) { + multiiter_wrong_number_of_args(); + goto fail; + } + for (j = 0; j < mit->numiter; ++j) { + arr = (PyObject *)mit->iters[j]->ao; + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if(it == NULL) { + goto fail; + } + multi->iters[multi->numiter++] = it; + } + } + else if (multi->numiter < NPY_MAXARGS) { + arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL); + if (arr == NULL) { + goto fail; } + it = (PyArrayIterObject *)PyArray_IterNew(arr); + if (it == NULL) { + Py_DECREF(arr); + goto fail; + } + multi->iters[multi->numiter++] = it; Py_DECREF(arr); } + else { + multiiter_wrong_number_of_args(); + goto fail; + } } - va_end(va); - if (!err && PyArray_Broadcast(multi) < 0) { - err = 1; + if (multi->numiter < 1) { + multiiter_wrong_number_of_args(); + goto fail; } - if (err) { - Py_DECREF(multi); - return NULL; + if (PyArray_Broadcast(multi) < 0) { + goto fail; } PyArray_MultiIter_RESET(multi); + return (PyObject *)multi; + +fail: + Py_DECREF(multi); + + return NULL; } /*NUMPY_API - * Get MultiIterator, + * Get MultiIterator from array of Python objects and any additional + * + * PyObject **mps - array of PyObjects + * int n - number of PyObjects in the array + * int nadd - number of additional arrays to include in the iterator. + * + * Returns a multi-iterator object. */ -NPY_NO_EXPORT PyObject * -PyArray_MultiIterNew(int n, ...) +NPY_NO_EXPORT PyObject* +PyArray_MultiIterFromObjects(PyObject **mps, int n, int nadd, ...) { + PyObject *args_impl[NPY_MAXARGS]; + int ntot = n + nadd; + int i; va_list va; - PyArrayMultiIterObject *multi; - PyObject *current; - PyObject *arr; - int i, err = 0; + if (ntot > NPY_MAXARGS) { + return multiiter_wrong_number_of_args(); + } - if (n < 1 || n > NPY_MAXARGS) { - PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); - return NULL; + for (i = 0; i < n; ++i) { + args_impl[i] = mps[i]; } - /* fprintf(stderr, "multi new...");*/ + va_start(va, nadd); + for (; i < ntot; ++i) { + args_impl[i] = va_arg(va, PyObject *); + } + va_end(va); - multi = PyArray_malloc(sizeof(PyArrayMultiIterObject)); - if (multi == NULL) { - return PyErr_NoMemory(); + if (ntot < NPY_MAXARGS) { + args_impl[ntot] = NULL; } - PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type); - for (i = 0; i < n; i++) { - multi->iters[i] = NULL; + return multiiter_new_impl(args_impl); +} + +/*NUMPY_API + * Get MultiIterator, + */ +NPY_NO_EXPORT PyObject* +PyArray_MultiIterNew(int n, ...) +{ + PyObject *args_impl[NPY_MAXARGS]; + int i; + va_list va; + + if (n > NPY_MAXARGS) { + return multiiter_wrong_number_of_args(); } - multi->numiter = n; - multi->index = 0; va_start(va, n); - for (i = 0; i < n; i++) { - current = va_arg(va, PyObject *); - arr = PyArray_FROM_O(current); - if (arr == NULL) { - err = 1; - break; - } - else { - multi->iters[i] = (PyArrayIterObject *)PyArray_IterNew(arr); - if (multi->iters[i] == NULL) { - err = 1; - break; - } - Py_DECREF(arr); - } + for (i = 0; i < n; ++i) { + args_impl[i] = va_arg(va, PyObject *); } va_end(va); - if (!err && PyArray_Broadcast(multi) < 0) { - err = 1; - } - if (err) { - Py_DECREF(multi); - return NULL; + if (n < NPY_MAXARGS) { + args_impl[n] = NULL; } - PyArray_MultiIter_RESET(multi); - return (PyObject *)multi; + + return multiiter_new_impl(args_impl); } -static PyObject * -arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, PyObject *kwds) +static PyObject* +arraymultiter_new(PyTypeObject *NPY_UNUSED(subtype), PyObject *args, + PyObject *kwds) { - - Py_ssize_t n = 0; - Py_ssize_t i, j, k; - PyArrayMultiIterObject *multi; - PyObject *arr; + PyObject *args_impl[NPY_MAXARGS]; + Py_ssize_t n = PyTuple_Size(args); + int i; if (kwds != NULL) { PyErr_SetString(PyExc_ValueError, "keyword arguments not accepted."); return NULL; } - - for (j = 0; j < PyTuple_Size(args); ++j) { - PyObject *obj = PyTuple_GET_ITEM(args, j); - - if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { - /* - * If obj is a multi-iterator, all its arrays will be added - * to the new multi-iterator. - */ - n += ((PyArrayMultiIterObject *)obj)->numiter; - } - else { - /* If not, will try to convert it to a single array */ - ++n; - } - } - if (n < 1 || n > NPY_MAXARGS) { - if (PyErr_Occurred()) { - return NULL; - } - PyErr_Format(PyExc_ValueError, - "Need at least 1 and at most %d " - "array objects.", NPY_MAXARGS); - return NULL; - } - - multi = PyArray_malloc(sizeof(PyArrayMultiIterObject)); - if (multi == NULL) { - return PyErr_NoMemory(); + if (n > NPY_MAXARGS) { + return multiiter_wrong_number_of_args(); } - PyObject_Init((PyObject *)multi, &PyArrayMultiIter_Type); - - multi->numiter = n; - multi->index = 0; - i = 0; - for (j = 0; j < PyTuple_GET_SIZE(args); ++j) { - PyObject *obj = PyTuple_GET_ITEM(args, j); - PyArrayIterObject *it; - - if (PyObject_IsInstance(obj, (PyObject *)&PyArrayMultiIter_Type)) { - PyArrayMultiIterObject *mit = (PyArrayMultiIterObject *)obj; - - for (k = 0; k < mit->numiter; ++k) { - arr = (PyObject *)mit->iters[k]->ao; - assert (arr != NULL); - it = (PyArrayIterObject *)PyArray_IterNew(arr); - if (it == NULL) { - goto fail; - } - multi->iters[i++] = it; - } - } - else { - arr = PyArray_FromAny(obj, NULL, 0, 0, 0, NULL); - if (arr == NULL) { - goto fail; - } - it = (PyArrayIterObject *)PyArray_IterNew(arr); - if (it == NULL) { - goto fail; - } - multi->iters[i++] = it; - Py_DECREF(arr); - } + for (i = 0; i < n; ++i) { + args_impl[i] = PyTuple_GET_ITEM(args, i); } - assert (i == n); - if (PyArray_Broadcast(multi) < 0) { - goto fail; + if (n < NPY_MAXARGS) { + args_impl[n] = NULL; } - PyArray_MultiIter_RESET(multi); - return (PyObject *)multi; - fail: - Py_DECREF(multi); - return NULL; + return multiiter_new_impl(args_impl); } static PyObject *