From efc9ff3a8a22d20fdce7df07fcca74cd028ddcf0 Mon Sep 17 00:00:00 2001 From: Eric Wieser Date: Mon, 4 Feb 2019 00:01:17 -0800 Subject: [PATCH] MAINT: Merge together the unary and binary type resolvers This merges `PyUFunc_SimpleUnaryOperationTypeResolver` and `PyUFunc_SimpleBinaryOperationTypeResolver` These are almost identical, save for using `ResultType` vs simply forcing a byte order. This comes at the cost of a handful of branches, which should be insignifcant compared to the rest of the ufunc overhead. --- numpy/core/code_generators/generate_umath.py | 18 +- numpy/core/src/umath/ufunc_type_resolution.c | 177 +++++-------------- numpy/core/src/umath/ufunc_type_resolution.h | 9 +- 3 files changed, 57 insertions(+), 147 deletions(-) diff --git a/numpy/core/code_generators/generate_umath.py b/numpy/core/code_generators/generate_umath.py index 0fac9b05eeff..7b84e8c3b96c 100644 --- a/numpy/core/code_generators/generate_umath.py +++ b/numpy/core/code_generators/generate_umath.py @@ -407,7 +407,7 @@ def english_upper(s): 'positive': Ufunc(1, 1, None, docstrings.get('numpy.core.umath.positive'), - 'PyUFunc_SimpleUnaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(ints+flts+timedeltaonly), TD(cmplx, f='pos'), TD(O, f='PyNumber_Positive'), @@ -415,7 +415,7 @@ def english_upper(s): 'sign': Ufunc(1, 1, None, docstrings.get('numpy.core.umath.sign'), - 'PyUFunc_SimpleUnaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(nobool_or_datetime), ), 'greater': @@ -491,28 +491,28 @@ def english_upper(s): 'maximum': Ufunc(2, 1, ReorderableNone, docstrings.get('numpy.core.umath.maximum'), - 'PyUFunc_SimpleBinaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(noobj), TD(O, f='npy_ObjectMax') ), 'minimum': Ufunc(2, 1, ReorderableNone, docstrings.get('numpy.core.umath.minimum'), - 'PyUFunc_SimpleBinaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(noobj), TD(O, f='npy_ObjectMin') ), 'fmax': Ufunc(2, 1, ReorderableNone, docstrings.get('numpy.core.umath.fmax'), - 'PyUFunc_SimpleBinaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(noobj), TD(O, f='npy_ObjectMax') ), 'fmin': Ufunc(2, 1, ReorderableNone, docstrings.get('numpy.core.umath.fmin'), - 'PyUFunc_SimpleBinaryOperationTypeResolver', + 'PyUFunc_SimpleUniformOperationTypeResolver', TD(noobj), TD(O, f='npy_ObjectMin') ), @@ -895,21 +895,21 @@ def english_upper(s): 'gcd' : Ufunc(2, 1, Zero, docstrings.get('numpy.core.umath.gcd'), - "PyUFunc_SimpleBinaryOperationTypeResolver", + "PyUFunc_SimpleUniformOperationTypeResolver", TD(ints), TD('O', f='npy_ObjectGCD'), ), 'lcm' : Ufunc(2, 1, None, docstrings.get('numpy.core.umath.lcm'), - "PyUFunc_SimpleBinaryOperationTypeResolver", + "PyUFunc_SimpleUniformOperationTypeResolver", TD(ints), TD('O', f='npy_ObjectLCM'), ), 'matmul' : Ufunc(2, 1, None, docstrings.get('numpy.core.umath.matmul'), - "PyUFunc_SimpleBinaryOperationTypeResolver", + "PyUFunc_SimpleUniformOperationTypeResolver", TD(notimes_or_obj), signature='(n?,k),(k,m?)->(n?,m?)', ), diff --git a/numpy/core/src/umath/ufunc_type_resolution.c b/numpy/core/src/umath/ufunc_type_resolution.c index a4a59faa92c0..c07934caaf8f 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.c +++ b/numpy/core/src/umath/ufunc_type_resolution.c @@ -12,6 +12,8 @@ #define _MULTIARRAYMODULE #define NPY_NO_DEPRECATED_API NPY_API_VERSION +#include + #include "Python.h" #include "npy_config.h" @@ -407,99 +409,6 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc, return 0; } -/* - * This function applies special type resolution rules for the case - * where all the functions have the pattern X->X, copying - * the input descr directly so that metadata is maintained. - * - * Note that a simpler linear search through the functions loop - * is still done, but switching to a simple array lookup for - * built-in types would be better at some point. - * - * Returns 0 on success, -1 on error. - */ -NPY_NO_EXPORT int -PyUFunc_SimpleUnaryOperationTypeResolver(PyUFuncObject *ufunc, - NPY_CASTING casting, - PyArrayObject **operands, - PyObject *type_tup, - PyArray_Descr **out_dtypes) -{ - int i, type_num1; - const char *ufunc_name = ufunc_get_name_cstr(ufunc); - - if (ufunc->nin != 1 || ufunc->nout != 1) { - PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured " - "to use unary operation type resolution but has " - "the wrong number of inputs or outputs", - ufunc_name); - return -1; - } - - /* - * Use the default type resolution if there's a custom data type - * or object arrays. - */ - type_num1 = PyArray_DESCR(operands[0])->type_num; - if (type_num1 >= NPY_NTYPES || type_num1 == NPY_OBJECT) { - return PyUFunc_DefaultTypeResolver(ufunc, casting, operands, - type_tup, out_dtypes); - } - - if (type_tup == NULL) { - /* Input types are the result type */ - out_dtypes[0] = ensure_dtype_nbo(PyArray_DESCR(operands[0])); - if (out_dtypes[0] == NULL) { - return -1; - } - out_dtypes[1] = out_dtypes[0]; - Py_INCREF(out_dtypes[1]); - } - else { - PyObject *item; - PyArray_Descr *dtype = NULL; - - /* - * If the type tuple isn't a single-element tuple, let the - * default type resolution handle this one. - */ - if (!PyTuple_Check(type_tup) || PyTuple_GET_SIZE(type_tup) != 1) { - return PyUFunc_DefaultTypeResolver(ufunc, casting, - operands, type_tup, out_dtypes); - } - - item = PyTuple_GET_ITEM(type_tup, 0); - - if (item == Py_None) { - PyErr_SetString(PyExc_ValueError, - "require data type in the type tuple"); - return -1; - } - else if (!PyArray_DescrConverter(item, &dtype)) { - return -1; - } - - out_dtypes[0] = ensure_dtype_nbo(dtype); - if (out_dtypes[0] == NULL) { - return -1; - } - out_dtypes[1] = out_dtypes[0]; - Py_INCREF(out_dtypes[1]); - } - - /* Check against the casting rules */ - if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { - for (i = 0; i < 2; ++i) { - Py_DECREF(out_dtypes[i]); - out_dtypes[i] = NULL; - } - return -1; - } - - return 0; -} - - NPY_NO_EXPORT int PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc, NPY_CASTING casting, @@ -508,7 +417,7 @@ PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc, PyArray_Descr **out_dtypes) { int ret; - ret = PyUFunc_SimpleUnaryOperationTypeResolver(ufunc, casting, operands, + ret = PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); if (ret < 0) { return ret; @@ -538,16 +447,15 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc, PyObject *type_tup, PyArray_Descr **out_dtypes) { - return PyUFunc_SimpleUnaryOperationTypeResolver(ufunc, + return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, NPY_UNSAFE_CASTING, operands, type_tup, out_dtypes); } - /* * This function applies special type resolution rules for the case - * where all the functions have the pattern XX->X, using - * PyArray_ResultType instead of a linear search to get the best + * where all of the types in the signature are the same, eg XX->X or XX->XX. + * It uses PyArray_ResultType instead of a linear search to get the best * loop. * * Note that a simpler linear search through the functions loop @@ -557,45 +465,52 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc, * Returns 0 on success, -1 on error. */ NPY_NO_EXPORT int -PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc, - NPY_CASTING casting, - PyArrayObject **operands, - PyObject *type_tup, - PyArray_Descr **out_dtypes) +PyUFunc_SimpleUniformOperationTypeResolver( + PyUFuncObject *ufunc, + NPY_CASTING casting, + PyArrayObject **operands, + PyObject *type_tup, + PyArray_Descr **out_dtypes) { - int i, type_num1, type_num2; const char *ufunc_name = ufunc_get_name_cstr(ufunc); - if (ufunc->nin != 2 || ufunc->nout != 1) { + if (ufunc->nin < 1) { PyErr_Format(PyExc_RuntimeError, "ufunc %s is configured " - "to use binary operation type resolution but has " - "the wrong number of inputs or outputs", + "to use uniform operation type resolution but has " + "no inputs", ufunc_name); return -1; } + int nop = ufunc->nin + ufunc->nout; /* - * Use the default type resolution if there's a custom data type - * or object arrays. + * There's a custom data type or an object array */ - type_num1 = PyArray_DESCR(operands[0])->type_num; - type_num2 = PyArray_DESCR(operands[1])->type_num; - if (type_num1 >= NPY_NTYPES || type_num2 >= NPY_NTYPES || - type_num1 == NPY_OBJECT || type_num2 == NPY_OBJECT) { + bool has_custom_or_object = false; + for (int iop = 0; iop < ufunc->nin; iop++) { + int type_num = PyArray_DESCR(operands[iop])->type_num; + if (type_num >= NPY_NTYPES || type_num == NPY_OBJECT) { + has_custom_or_object = true; + break; + } + } + + if (has_custom_or_object) { return PyUFunc_DefaultTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); } if (type_tup == NULL) { - /* Input types are the result type */ - out_dtypes[0] = PyArray_ResultType(2, operands, 0, NULL); + /* PyArray_ResultType forgets to force a byte order when n == 1 */ + if (ufunc->nin == 1){ + out_dtypes[0] = ensure_dtype_nbo(PyArray_DESCR(operands[0])); + } + else { + out_dtypes[0] = PyArray_ResultType(ufunc->nin, operands, 0, NULL); + } if (out_dtypes[0] == NULL) { return -1; } - out_dtypes[1] = out_dtypes[0]; - Py_INCREF(out_dtypes[1]); - out_dtypes[2] = out_dtypes[0]; - Py_INCREF(out_dtypes[2]); } else { PyObject *item; @@ -625,17 +540,19 @@ PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc, if (out_dtypes[0] == NULL) { return -1; } - out_dtypes[1] = out_dtypes[0]; - Py_INCREF(out_dtypes[1]); - out_dtypes[2] = out_dtypes[0]; - Py_INCREF(out_dtypes[2]); + } + + /* All types are the same - copy the first one to the rest */ + for (int iop = 1; iop < nop; iop++) { + out_dtypes[iop] = out_dtypes[0]; + Py_INCREF(out_dtypes[iop]); } /* Check against the casting rules */ if (PyUFunc_ValidateCasting(ufunc, casting, operands, out_dtypes) < 0) { - for (i = 0; i < 3; ++i) { - Py_DECREF(out_dtypes[i]); - out_dtypes[i] = NULL; + for (int iop = 0; iop < nop; iop++) { + Py_DECREF(out_dtypes[iop]); + out_dtypes[iop] = NULL; } return -1; } @@ -663,7 +580,7 @@ PyUFunc_AbsoluteTypeResolver(PyUFuncObject *ufunc, type_tup, out_dtypes); } else { - return PyUFunc_SimpleUnaryOperationTypeResolver(ufunc, casting, + return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); } } @@ -752,7 +669,7 @@ PyUFunc_AdditionTypeResolver(PyUFuncObject *ufunc, /* Use the default when datetime and timedelta are not involved */ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { - return PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting, + return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); } @@ -925,7 +842,7 @@ PyUFunc_SubtractionTypeResolver(PyUFuncObject *ufunc, /* Use the default when datetime and timedelta are not involved */ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { int ret; - ret = PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting, + ret = PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); if (ret < 0) { return ret; @@ -1088,7 +1005,7 @@ PyUFunc_MultiplicationTypeResolver(PyUFuncObject *ufunc, /* Use the default when datetime and timedelta are not involved */ if (!PyTypeNum_ISDATETIME(type_num1) && !PyTypeNum_ISDATETIME(type_num2)) { - return PyUFunc_SimpleBinaryOperationTypeResolver(ufunc, casting, + return PyUFunc_SimpleUniformOperationTypeResolver(ufunc, casting, operands, type_tup, out_dtypes); } diff --git a/numpy/core/src/umath/ufunc_type_resolution.h b/numpy/core/src/umath/ufunc_type_resolution.h index 78313b1ef6bf..7256fcf61678 100644 --- a/numpy/core/src/umath/ufunc_type_resolution.h +++ b/numpy/core/src/umath/ufunc_type_resolution.h @@ -8,13 +8,6 @@ PyUFunc_SimpleBinaryComparisonTypeResolver(PyUFuncObject *ufunc, PyObject *type_tup, PyArray_Descr **out_dtypes); -NPY_NO_EXPORT int -PyUFunc_SimpleUnaryOperationTypeResolver(PyUFuncObject *ufunc, - NPY_CASTING casting, - PyArrayObject **operands, - PyObject *type_tup, - PyArray_Descr **out_dtypes); - NPY_NO_EXPORT int PyUFunc_NegativeTypeResolver(PyUFuncObject *ufunc, NPY_CASTING casting, @@ -30,7 +23,7 @@ PyUFunc_OnesLikeTypeResolver(PyUFuncObject *ufunc, PyArray_Descr **out_dtypes); NPY_NO_EXPORT int -PyUFunc_SimpleBinaryOperationTypeResolver(PyUFuncObject *ufunc, +PyUFunc_SimpleUniformOperationTypeResolver(PyUFuncObject *ufunc, NPY_CASTING casting, PyArrayObject **operands, PyObject *type_tup,