Created
March 19, 2015 13:27
-
-
Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.
Revisions
-
ewmoore created this gist
Mar 19, 2015 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,251 @@ #define NPY_NO_DEPRECATED_API NPY_API_VERSION #include <Python.h> #include <numpy/arrayobject.h> #include "numpy/npy_3kcompat.h" PyObject* sum_and_prod(PyArrayObject *arr, int *axis, int naxis) { NpyIter *iter_outer, *iter_inner; NpyIter_IterNextFunc *iternext_outer, *iternext_inner; char **dataptr_outer, **dataptr_inner; npy_intp *strideptr_outer, *strideptr_inner; npy_intp *innersizeptr_outer, *innersizeptr_inner; PyArrayObject *ops[2]; PyObject *out; npy_uint32 outer_flags = 0; npy_uint32 op_flags_outer[2]; int oa_ndim_outer; int *oa_axes_outer[2]; int arr_axes_outer[NPY_MAXDIMS]; int out_axes_outer[NPY_MAXDIMS]; //PyArrayDescr *dtype_outer[2]; npy_intp out_shape[NPY_MAXDIMS]; int i,j,k,m,n,flag; npy_uint32 inner_flags = NPY_ITER_EXTERNAL_LOOP | NPY_ITER_BUFFERED | NPY_ITER_GROWINNER; npy_uint32 op_flags_inner[1]; int oa_ndim_inner; int *oa_axes_inner[1]; int arr_axes_inner[NPY_MAXDIMS]; PyArray_Descr *dtype_inner[1]; PyArray_Descr *double_dtype = PyArray_DescrFromType(NPY_DOUBLE); ops[0] = arr; op_flags_outer[0] = NPY_ITER_READONLY; op_flags_inner[0] = NPY_ITER_READONLY | NPY_ITER_NBO; //dtype_outer[0] = PyArray_DESCR(arr); dtype_inner[0] = double_dtype; op_flags_outer[1] = NPY_ITER_WRITEONLY; //dtype_outer[1] = double_dtype; if (naxis <= 0) { PyErr_SetString(PyExc_ValueError, "Need at at least 1 axis"); Py_DECREF(double_dtype); } if (naxis > PyArray_NDIM(arr)) { PyErr_SetString(PyExc_ValueError, "Too many axes"); Py_DECREF(double_dtype); } for (i = 0; i < naxis; ++i) { if (axis[i] < 0) { axis[i] += PyArray_NDIM(arr); } if (axis[i] >= naxis) { PyErr_SetString(PyExc_ValueError, "axis specified does not exist"); Py_DECREF(double_dtype); return 0; } } oa_ndim_outer = PyArray_NDIM(arr) - naxis; oa_ndim_inner = naxis; i = 0; j = 0; k = 0; for (m = 0; m < PyArray_NDIM(arr); ++m) { flag = 0; for (n = 0; n < naxis; ++n) { if (m == axis[n]) { flag = 1; break; } } if (flag) { arr_axes_inner[i++] = m; } else { arr_axes_outer[j++] = m; out_axes_outer[k] = k; out_shape[k++] = PyArray_DIM(arr, m); } } out_shape[k] = 2; out = PyArray_Zeros(++k, out_shape, double_dtype, 0); if (!out) { Py_DECREF(double_dtype); return 0; } ops[1] = (PyArrayObject*)out; /* for (k = 0; k < PyArray_NDIM(arr); ++k) { printf("%d %d %d\n", arr_axes_inner[k], arr_axes_outer[k], out_axes_outer[k]); } */ oa_axes_outer[0] = arr_axes_outer; oa_axes_outer[1] = out_axes_outer; iter_outer = NpyIter_AdvancedNew(2, ops, outer_flags, NPY_KEEPORDER, NPY_NO_CASTING, op_flags_outer, NULL, oa_ndim_outer, oa_axes_outer, NULL, -1); if (!iter_outer) { Py_DECREF(double_dtype); Py_DECREF(out); return 0; } iternext_outer = NpyIter_GetIterNext(iter_outer, NULL); if (!iternext_outer) { NpyIter_Deallocate(iter_outer); Py_DECREF(double_dtype); Py_DECREF(out); } dataptr_outer = NpyIter_GetDataPtrArray(iter_outer); strideptr_outer = NpyIter_GetInnerStrideArray(iter_outer); innersizeptr_outer = NpyIter_GetInnerLoopSizePtr(iter_outer); oa_axes_inner[0] = arr_axes_inner; iter_inner = NpyIter_AdvancedNew(1, ops, inner_flags, NPY_KEEPORDER, NPY_UNSAFE_CASTING, op_flags_inner, dtype_inner, oa_ndim_inner, oa_axes_inner, NULL, -1); if (!iter_inner) { NpyIter_Deallocate(iter_outer); Py_DECREF(double_dtype); Py_DECREF(out); } iternext_inner = NpyIter_GetIterNext(iter_inner, NULL); if (!iternext_inner) { NpyIter_Deallocate(iter_outer); NpyIter_Deallocate(iter_inner); Py_DECREF(double_dtype); Py_DECREF(out); } dataptr_inner = NpyIter_GetDataPtrArray(iter_inner); strideptr_inner = NpyIter_GetInnerStrideArray(iter_inner); innersizeptr_inner = NpyIter_GetInnerLoopSizePtr(iter_inner); do { double *sum_ptr = (double*)dataptr_outer[1]; sum_ptr[0] = 0; //double *prod_ptr = (double*)(dataptr_outer[1] + strideptr_outer[1]); // this is clearly cheating... double *prod_ptr = (double*)(dataptr_outer[1] + PyArray_STRIDE(out, PyArray_NDIM(out)-1)); prod_ptr[0] = 1; //printf("%ld, %ld\n", (long)sum_ptr, (long)prod_ptr); NpyIter_ResetBasePointers(iter_inner, dataptr_outer, NULL); do { for (k = 0; k < *innersizeptr_inner; ++k) { const double in = *((double*)(dataptr_inner[0] + k*strideptr_inner[0])); sum_ptr[0] += in; prod_ptr[0] *= in; } } while(iternext_inner(iter_inner)); } while(iternext_outer(iter_outer)); NpyIter_Deallocate(iter_outer); NpyIter_Deallocate(iter_inner); Py_DECREF(double_dtype); return out; } static PyObject* iter_test(PyObject *self, PyObject* args) { PyObject *o_arr; PyObject *axes; PyObject *out; PyArrayObject* a_arr; Py_ssize_t naxes; Py_ssize_t size = PyTuple_GET_SIZE(args); int *axis; int k; if (size != 2) { PyErr_SetString(PyExc_TypeError, "wrong # args, expected 2"); return 0; } o_arr = PyTuple_GET_ITEM(args, 0); if (!PyArray_Check(o_arr)) { PyErr_SetString(PyExc_ValueError, "Expected an array"); return 0; } a_arr = (PyArrayObject*)o_arr; axes = PyTuple_GET_ITEM(args, 1); if (!PyTuple_Check(axes)) { PyErr_SetString(PyExc_ValueError, "Expected a tuple"); return 0; } // missing some error checking... naxes = PyTuple_GET_SIZE(axes); axis = malloc(naxes*sizeof(int)); for (k = 0; k < naxes; ++k) { axis[k] = PyInt_AsLong(PyTuple_GET_ITEM(axes, k)); } out = sum_and_prod(a_arr, axis, naxes); free(axis); return out; } PyMethodDef module_methods[] = { {"iter_test", &iter_test, METH_VARARGS, ""}, {0} /* sentinel */ }; #if defined(NPY_PY3K) static struct PyModuleDef moduledef = { PyModuleDef_HEAD_INIT, "ndit", NULL, -1, module_methods, NULL, NULL, NULL, NULL }; #endif #if defined(NPY_PY3K) PyMODINIT_FUNC PyInit_ndit(void) { #else PyMODINIT_FUNC initndit(void) { #endif PyObject *m; import_array(); if (PyErr_Occurred()) { return; } /* Create module */ #if defined(NPY_PY3K) m = PyModule_Create(&moduledef); #else m = Py_InitModule("ndit", module_methods); #endif if (!m) { return; } } This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,12 @@ from distutils.core import setup, Extension import numpy as np ext_modules = [Extension('ndit', sources=['ndit.c'])] setup( name = 'ndit', version = '1.0', include_dirs = [np.get_include()], ext_modules = ext_modules )