Thanks to visit codestin.com
Credit goes to gist.github.com

Skip to content

Instantly share code, notes, and snippets.

@ewmoore
Created March 19, 2015 13:27
Show Gist options
  • Select an option

  • Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.

Select an option

Save ewmoore/e88ee9dd84c1d9d58892 to your computer and use it in GitHub Desktop.

Revisions

  1. ewmoore created this gist Mar 19, 2015.
    251 changes: 251 additions & 0 deletions ndit.c
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,251 @@
    #define NPY_NO_DEPRECATED_API NPY_API_VERSION
    #include <Python.h>
    #include <numpy/arrayobject.h>
    #include "numpy/npy_3kcompat.h"


    PyObject* sum_and_prod(PyArrayObject *arr, int *axis, int naxis)
    {
    NpyIter *iter_outer, *iter_inner;
    NpyIter_IterNextFunc *iternext_outer, *iternext_inner;
    char **dataptr_outer, **dataptr_inner;
    npy_intp *strideptr_outer, *strideptr_inner;
    npy_intp *innersizeptr_outer, *innersizeptr_inner;
    PyArrayObject *ops[2];
    PyObject *out;
    npy_uint32 outer_flags = 0;
    npy_uint32 op_flags_outer[2];
    int oa_ndim_outer;
    int *oa_axes_outer[2];
    int arr_axes_outer[NPY_MAXDIMS];
    int out_axes_outer[NPY_MAXDIMS];
    //PyArrayDescr *dtype_outer[2];
    npy_intp out_shape[NPY_MAXDIMS];
    int i,j,k,m,n,flag;

    npy_uint32 inner_flags = NPY_ITER_EXTERNAL_LOOP | NPY_ITER_BUFFERED | NPY_ITER_GROWINNER;
    npy_uint32 op_flags_inner[1];
    int oa_ndim_inner;
    int *oa_axes_inner[1];
    int arr_axes_inner[NPY_MAXDIMS];
    PyArray_Descr *dtype_inner[1];
    PyArray_Descr *double_dtype = PyArray_DescrFromType(NPY_DOUBLE);

    ops[0] = arr;
    op_flags_outer[0] = NPY_ITER_READONLY;
    op_flags_inner[0] = NPY_ITER_READONLY | NPY_ITER_NBO;
    //dtype_outer[0] = PyArray_DESCR(arr);
    dtype_inner[0] = double_dtype;
    op_flags_outer[1] = NPY_ITER_WRITEONLY;
    //dtype_outer[1] = double_dtype;

    if (naxis <= 0) {
    PyErr_SetString(PyExc_ValueError, "Need at at least 1 axis");
    Py_DECREF(double_dtype);
    }

    if (naxis > PyArray_NDIM(arr)) {
    PyErr_SetString(PyExc_ValueError, "Too many axes");
    Py_DECREF(double_dtype);
    }

    for (i = 0; i < naxis; ++i) {
    if (axis[i] < 0) {
    axis[i] += PyArray_NDIM(arr);
    }
    if (axis[i] >= naxis) {
    PyErr_SetString(PyExc_ValueError, "axis specified does not exist");
    Py_DECREF(double_dtype);
    return 0;
    }
    }

    oa_ndim_outer = PyArray_NDIM(arr) - naxis;
    oa_ndim_inner = naxis;
    i = 0;
    j = 0;
    k = 0;
    for (m = 0; m < PyArray_NDIM(arr); ++m) {
    flag = 0;
    for (n = 0; n < naxis; ++n) {
    if (m == axis[n]) {
    flag = 1;
    break;
    }
    }
    if (flag) {
    arr_axes_inner[i++] = m;
    } else {
    arr_axes_outer[j++] = m;
    out_axes_outer[k] = k;
    out_shape[k++] = PyArray_DIM(arr, m);
    }
    }
    out_shape[k] = 2;
    out = PyArray_Zeros(++k, out_shape, double_dtype, 0);
    if (!out) {
    Py_DECREF(double_dtype);
    return 0;
    }
    ops[1] = (PyArrayObject*)out;
    /*
    for (k = 0; k < PyArray_NDIM(arr); ++k) {
    printf("%d %d %d\n", arr_axes_inner[k], arr_axes_outer[k], out_axes_outer[k]);
    }
    */
    oa_axes_outer[0] = arr_axes_outer;
    oa_axes_outer[1] = out_axes_outer;
    iter_outer = NpyIter_AdvancedNew(2, ops, outer_flags, NPY_KEEPORDER,
    NPY_NO_CASTING, op_flags_outer, NULL,
    oa_ndim_outer, oa_axes_outer, NULL, -1);

    if (!iter_outer) {
    Py_DECREF(double_dtype);
    Py_DECREF(out);
    return 0;
    }

    iternext_outer = NpyIter_GetIterNext(iter_outer, NULL);
    if (!iternext_outer) {
    NpyIter_Deallocate(iter_outer);
    Py_DECREF(double_dtype);
    Py_DECREF(out);
    }
    dataptr_outer = NpyIter_GetDataPtrArray(iter_outer);
    strideptr_outer = NpyIter_GetInnerStrideArray(iter_outer);
    innersizeptr_outer = NpyIter_GetInnerLoopSizePtr(iter_outer);

    oa_axes_inner[0] = arr_axes_inner;
    iter_inner = NpyIter_AdvancedNew(1, ops, inner_flags, NPY_KEEPORDER,
    NPY_UNSAFE_CASTING, op_flags_inner,
    dtype_inner, oa_ndim_inner,
    oa_axes_inner, NULL, -1);
    if (!iter_inner) {
    NpyIter_Deallocate(iter_outer);
    Py_DECREF(double_dtype);
    Py_DECREF(out);
    }
    iternext_inner = NpyIter_GetIterNext(iter_inner, NULL);
    if (!iternext_inner) {
    NpyIter_Deallocate(iter_outer);
    NpyIter_Deallocate(iter_inner);
    Py_DECREF(double_dtype);
    Py_DECREF(out);
    }
    dataptr_inner = NpyIter_GetDataPtrArray(iter_inner);
    strideptr_inner = NpyIter_GetInnerStrideArray(iter_inner);
    innersizeptr_inner = NpyIter_GetInnerLoopSizePtr(iter_inner);

    do {
    double *sum_ptr = (double*)dataptr_outer[1];
    sum_ptr[0] = 0;
    //double *prod_ptr = (double*)(dataptr_outer[1] + strideptr_outer[1]);
    // this is clearly cheating...
    double *prod_ptr = (double*)(dataptr_outer[1] + PyArray_STRIDE(out, PyArray_NDIM(out)-1));
    prod_ptr[0] = 1;
    //printf("%ld, %ld\n", (long)sum_ptr, (long)prod_ptr);
    NpyIter_ResetBasePointers(iter_inner, dataptr_outer, NULL);
    do {
    for (k = 0; k < *innersizeptr_inner; ++k) {
    const double in = *((double*)(dataptr_inner[0] + k*strideptr_inner[0]));
    sum_ptr[0] += in;
    prod_ptr[0] *= in;
    }
    } while(iternext_inner(iter_inner));
    } while(iternext_outer(iter_outer));

    NpyIter_Deallocate(iter_outer);
    NpyIter_Deallocate(iter_inner);
    Py_DECREF(double_dtype);
    return out;
    }

    static PyObject* iter_test(PyObject *self, PyObject* args)
    {
    PyObject *o_arr;
    PyObject *axes;
    PyObject *out;
    PyArrayObject* a_arr;
    Py_ssize_t naxes;
    Py_ssize_t size = PyTuple_GET_SIZE(args);
    int *axis;
    int k;


    if (size != 2) {
    PyErr_SetString(PyExc_TypeError,
    "wrong # args, expected 2");
    return 0;
    }
    o_arr = PyTuple_GET_ITEM(args, 0);
    if (!PyArray_Check(o_arr)) {
    PyErr_SetString(PyExc_ValueError,
    "Expected an array");
    return 0;
    }
    a_arr = (PyArrayObject*)o_arr;

    axes = PyTuple_GET_ITEM(args, 1);
    if (!PyTuple_Check(axes)) {
    PyErr_SetString(PyExc_ValueError,
    "Expected a tuple");
    return 0;
    }
    // missing some error checking...
    naxes = PyTuple_GET_SIZE(axes);
    axis = malloc(naxes*sizeof(int));

    for (k = 0; k < naxes; ++k) {
    axis[k] = PyInt_AsLong(PyTuple_GET_ITEM(axes, k));
    }

    out = sum_and_prod(a_arr, axis, naxes);

    free(axis);
    return out;
    }


    PyMethodDef module_methods[] = {
    {"iter_test", &iter_test, METH_VARARGS, ""},
    {0} /* sentinel */
    };

    #if defined(NPY_PY3K)
    static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    "ndit",
    NULL,
    -1,
    module_methods,
    NULL,
    NULL,
    NULL,
    NULL
    };
    #endif

    #if defined(NPY_PY3K)
    PyMODINIT_FUNC PyInit_ndit(void) {
    #else
    PyMODINIT_FUNC initndit(void) {
    #endif

    PyObject *m;

    import_array();
    if (PyErr_Occurred()) {
    return;
    }

    /* Create module */
    #if defined(NPY_PY3K)
    m = PyModule_Create(&moduledef);
    #else
    m = Py_InitModule("ndit", module_methods);
    #endif

    if (!m) {
    return;
    }
    }
    12 changes: 12 additions & 0 deletions setup.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,12 @@
    from distutils.core import setup, Extension
    import numpy as np

    ext_modules = [Extension('ndit', sources=['ndit.c'])]

    setup(
    name = 'ndit',
    version = '1.0',
    include_dirs = [np.get_include()],
    ext_modules = ext_modules
    )