From 9ccbf6e172c4420bfdcd048593ab87a9c3861287 Mon Sep 17 00:00:00 2001 From: Sayed Adel Date: Wed, 19 Aug 2020 19:30:24 +0200 Subject: [PATCH] ENH: Add runtime CPU dispatching support for einsum This patch doesn't cause any performance changes, it just aims to simplify the review process for numpy#17049, according to https://github.com/numpy/numpy/pull/17049#discussion_r47303 --- numpy/core/setup.py | 1 + numpy/core/src/multiarray/einsum.c.src | 1921 +---------------- .../core/src/multiarray/einsum.dispatch.c.src | 1873 ++++++++++++++++ numpy/core/src/multiarray/einsum_helpers.h | 52 + 4 files changed, 1940 insertions(+), 1907 deletions(-) create mode 100644 numpy/core/src/multiarray/einsum.dispatch.c.src create mode 100644 numpy/core/src/multiarray/einsum_helpers.h diff --git a/numpy/core/setup.py b/numpy/core/setup.py index aede12080017..e854fc0cae5b 100644 --- a/numpy/core/setup.py +++ b/numpy/core/setup.py @@ -853,6 +853,7 @@ def get_mathlib_info(*args): join('src', 'multiarray', 'dragon4.c'), join('src', 'multiarray', 'dtype_transfer.c'), join('src', 'multiarray', 'einsum.c.src'), + join('src', 'multiarray', 'einsum.dispatch.c.src'), join('src', 'multiarray', 'flagsobject.c'), join('src', 'multiarray', 'getset.c'), join('src', 'multiarray', 'hashdescr.c'), diff --git a/numpy/core/src/multiarray/einsum.c.src b/numpy/core/src/multiarray/einsum.c.src index 2538e05c626a..30254fb800ce 100644 --- a/numpy/core/src/multiarray/einsum.c.src +++ b/numpy/core/src/multiarray/einsum.c.src @@ -8,1917 +8,24 @@ * See LICENSE.txt for the license. */ -#define PY_SSIZE_T_CLEAN -#include "Python.h" -#include "structmember.h" - -#define NPY_NO_DEPRECATED_API NPY_API_VERSION -#define _MULTIARRAYMODULE -#include -#include -#include -#include - -#include - -#include "convert.h" -#include "common.h" -#include "ctors.h" - -#ifdef NPY_HAVE_SSE_INTRINSICS -#define EINSUM_USE_SSE1 1 -#else -#define EINSUM_USE_SSE1 0 -#endif - -#ifdef NPY_HAVE_SSE2_INTRINSICS -#define EINSUM_USE_SSE2 1 -#else -#define EINSUM_USE_SSE2 0 -#endif - -#if EINSUM_USE_SSE1 -#include -#endif - -#if EINSUM_USE_SSE2 -#include -#endif - -#define EINSUM_IS_SSE_ALIGNED(x) ((((npy_intp)x)&0xf) == 0) - -/********** PRINTF DEBUG TRACING **************/ -#define NPY_EINSUM_DBG_TRACING 0 - -#if NPY_EINSUM_DBG_TRACING -#define NPY_EINSUM_DBG_PRINT(s) printf("%s", s); -#define NPY_EINSUM_DBG_PRINT1(s, p1) printf(s, p1); -#define NPY_EINSUM_DBG_PRINT2(s, p1, p2) printf(s, p1, p2); -#define NPY_EINSUM_DBG_PRINT3(s, p1, p2, p3) printf(s); -#else -#define NPY_EINSUM_DBG_PRINT(s) -#define NPY_EINSUM_DBG_PRINT1(s, p1) -#define NPY_EINSUM_DBG_PRINT2(s, p1, p2) -#define NPY_EINSUM_DBG_PRINT3(s, p1, p2, p3) -#endif -/**********************************************/ - -/**begin repeat - * #name = byte, short, int, long, longlong, - * ubyte, ushort, uint, ulong, ulonglong, - * half, float, double, longdouble, - * cfloat, cdouble, clongdouble# - * #type = npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, - * npy_half, npy_float, npy_double, npy_longdouble, - * npy_cfloat, npy_cdouble, npy_clongdouble# - * #temptype = npy_byte, npy_short, npy_int, npy_long, npy_longlong, - * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, - * npy_float, npy_float, npy_double, npy_longdouble, - * npy_float, npy_double, npy_longdouble# - * #to = ,,,,, - * ,,,,, - * npy_float_to_half,,,, - * ,,# - * #from = ,,,,, - * ,,,,, - * npy_half_to_float,,,, - * ,,# - * #complex = 0*5, - * 0*5, - * 0*4, - * 1*3# - * #float32 = 0*5, - * 0*5, - * 0,1,0,0, - * 0*3# - * #float64 = 0*5, - * 0*5, - * 0,0,1,0, - * 0*3# - */ - -/**begin repeat1 - * #nop = 1, 2, 3, 1000# - * #noplabel = one, two, three, any# - */ -static void -@name@_sum_of_products_@noplabel@(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ -#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) - char *data0 = dataptr[0]; - npy_intp stride0 = strides[0]; -#endif -#if (@nop@ == 2 || @nop@ == 3) && !@complex@ - char *data1 = dataptr[1]; - npy_intp stride1 = strides[1]; -#endif -#if (@nop@ == 3) && !@complex@ - char *data2 = dataptr[2]; - npy_intp stride2 = strides[2]; -#endif -#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) - char *data_out = dataptr[@nop@]; - npy_intp stride_out = strides[@nop@]; -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_@noplabel@ (%d)\n", (int)count); - - while (count--) { -#if !@complex@ -# if @nop@ == 1 - *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) + - @from@(*(@type@ *)data_out)); - data0 += stride0; - data_out += stride_out; -# elif @nop@ == 2 - *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) * - @from@(*(@type@ *)data1) + - @from@(*(@type@ *)data_out)); - data0 += stride0; - data1 += stride1; - data_out += stride_out; -# elif @nop@ == 3 - *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) * - @from@(*(@type@ *)data1) * - @from@(*(@type@ *)data2) + - @from@(*(@type@ *)data_out)); - data0 += stride0; - data1 += stride1; - data2 += stride2; - data_out += stride_out; -# else - @temptype@ temp = @from@(*(@type@ *)dataptr[0]); - int i; - for (i = 1; i < nop; ++i) { - temp *= @from@(*(@type@ *)dataptr[i]); - } - *(@type@ *)dataptr[nop] = @to@(temp + - @from@(*(@type@ *)dataptr[i])); - for (i = 0; i <= nop; ++i) { - dataptr[i] += strides[i]; - } -# endif -#else /* complex */ -# if @nop@ == 1 - ((@temptype@ *)data_out)[0] = ((@temptype@ *)data0)[0] + - ((@temptype@ *)data_out)[0]; - ((@temptype@ *)data_out)[1] = ((@temptype@ *)data0)[1] + - ((@temptype@ *)data_out)[1]; - data0 += stride0; - data_out += stride_out; -# else -# if @nop@ <= 3 -#define _SUMPROD_NOP @nop@ -# else -#define _SUMPROD_NOP nop -# endif - @temptype@ re, im, tmp; - int i; - re = ((@temptype@ *)dataptr[0])[0]; - im = ((@temptype@ *)dataptr[0])[1]; - for (i = 1; i < _SUMPROD_NOP; ++i) { - tmp = re * ((@temptype@ *)dataptr[i])[0] - - im * ((@temptype@ *)dataptr[i])[1]; - im = re * ((@temptype@ *)dataptr[i])[1] + - im * ((@temptype@ *)dataptr[i])[0]; - re = tmp; - } - ((@temptype@ *)dataptr[_SUMPROD_NOP])[0] = re + - ((@temptype@ *)dataptr[_SUMPROD_NOP])[0]; - ((@temptype@ *)dataptr[_SUMPROD_NOP])[1] = im + - ((@temptype@ *)dataptr[_SUMPROD_NOP])[1]; - - for (i = 0; i <= _SUMPROD_NOP; ++i) { - dataptr[i] += strides[i]; - } -#undef _SUMPROD_NOP -# endif -#endif - } -} - -#if @nop@ == 1 - -static void -@name@_sum_of_products_contig_one(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @type@ *data_out = (@type@ *)dataptr[1]; - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_one (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: -#if !@complex@ - data_out[@i@] = @to@(@from@(data0[@i@]) + - @from@(data_out[@i@])); -#else - ((@temptype@ *)data_out + 2*@i@)[0] = - ((@temptype@ *)data0 + 2*@i@)[0] + - ((@temptype@ *)data_out + 2*@i@)[0]; - ((@temptype@ *)data_out + 2*@i@)[1] = - ((@temptype@ *)data0 + 2*@i@)[1] + - ((@temptype@ *)data_out + 2*@i@)[1]; -#endif -/**end repeat2**/ - case 0: - return; - } - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ -#if !@complex@ - data_out[@i@] = @to@(@from@(data0[@i@]) + - @from@(data_out[@i@])); -#else /* complex */ - ((@temptype@ *)data_out + 2*@i@)[0] = - ((@temptype@ *)data0 + 2*@i@)[0] + - ((@temptype@ *)data_out + 2*@i@)[0]; - ((@temptype@ *)data_out + 2*@i@)[1] = - ((@temptype@ *)data0 + 2*@i@)[1] + - ((@temptype@ *)data_out + 2*@i@)[1]; -#endif -/**end repeat2**/ - data0 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -#elif @nop@ == 2 && !@complex@ - -static void -@name@_sum_of_products_contig_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @type@ *data1 = (@type@ *)dataptr[1]; - @type@ *data_out = (@type@ *)dataptr[2]; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, b; -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, b; -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_two (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - data_out[@i@] = @to@(@from@(data0[@i@]) * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ - case 0: - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) && - EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); - b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); - _mm_store_ps(data_out+@i@, b); -/**end repeat2**/ - data0 += 8; - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) && - EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@)); - b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); - _mm_store_pd(data_out+@i@, b); -/**end repeat2**/ - data0 += 8; - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); - b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); - _mm_storeu_ps(data_out+@i@, b); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@)); - b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); - _mm_storeu_pd(data_out+@i@, b); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - data_out[@i@] = @to@(@from@(data0[@i@]) * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ -#endif - data0 += 8; - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -/* Some extra specializations for the two operand case */ -static void -@name@_sum_of_products_stride0_contig_outcontig_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @temptype@ value0 = @from@(*(@type@ *)dataptr[0]); - @type@ *data1 = (@type@ *)dataptr[1]; - @type@ *data_out = (@type@ *)dataptr[2]; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, b, value0_sse; -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, b, value0_sse; -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_stride0_contig_outcontig_two (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - data_out[@i@] = @to@(value0 * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ - case 0: - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - value0_sse = _mm_set_ps1(value0); - - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(value0_sse, _mm_load_ps(data1+@i@)); - b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); - _mm_store_ps(data_out+@i@, b); -/**end repeat2**/ - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - if (count > 0) { - goto finish_after_unrolled_loop; - } - else { - return; - } - } -#elif EINSUM_USE_SSE2 && @float64@ - value0_sse = _mm_set1_pd(value0); - - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(value0_sse, _mm_load_pd(data1+@i@)); - b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); - _mm_store_pd(data_out+@i@, b); -/**end repeat2**/ - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - if (count > 0) { - goto finish_after_unrolled_loop; - } - else { - return; - } - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(value0_sse, _mm_loadu_ps(data1+@i@)); - b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); - _mm_storeu_ps(data_out+@i@, b); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(value0_sse, _mm_loadu_pd(data1+@i@)); - b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); - _mm_storeu_pd(data_out+@i@, b); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - data_out[@i@] = @to@(value0 * - @from@(data1[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ -#endif - data1 += 8; - data_out += 8; - } - - /* Finish off the loop */ - if (count > 0) { - goto finish_after_unrolled_loop; - } -} - -static void -@name@_sum_of_products_contig_stride0_outcontig_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @temptype@ value1 = @from@(*(@type@ *)dataptr[1]); - @type@ *data_out = (@type@ *)dataptr[2]; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, b, value1_sse; -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, b, value1_sse; -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_stride0_outcontig_two (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - data_out[@i@] = @to@(@from@(data0[@i@])* - value1 + - @from@(data_out[@i@])); -/**end repeat2**/ - case 0: - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - value1_sse = _mm_set_ps1(value1); - - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(_mm_load_ps(data0+@i@), value1_sse); - b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); - _mm_store_ps(data_out+@i@, b); -/**end repeat2**/ - data0 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - value1_sse = _mm_set1_pd(value1); - - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(_mm_load_pd(data0+@i@), value1_sse); - b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); - _mm_store_pd(data_out+@i@, b); -/**end repeat2**/ - data0 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ -/**begin repeat2 - * #i = 0, 4# - */ - a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), value1_sse); - b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); - _mm_storeu_ps(data_out+@i@, b); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), value1_sse); - b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); - _mm_storeu_pd(data_out+@i@, b); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - data_out[@i@] = @to@(@from@(data0[@i@])* - value1 + - @from@(data_out[@i@])); -/**end repeat2**/ -#endif - data0 += 8; - data_out += 8; - } - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -static void -@name@_sum_of_products_contig_contig_outstride0_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @type@ *data1 = (@type@ *)dataptr[1]; - @temptype@ accum = 0; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, accum_sse = _mm_setzero_ps(); -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, accum_sse = _mm_setzero_pd(); -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_contig_outstride0_two (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - accum += @from@(data0[@i@]) * @from@(data1[@i@]); -/**end repeat2**/ - case 0: - *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum); - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - - _mm_prefetch(data0 + 512, _MM_HINT_T0); - _mm_prefetch(data1 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); - accum_sse = _mm_add_ps(accum_sse, a); -/**end repeat2**/ - data0 += 8; - data1 += 8; - } - - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - - _mm_prefetch(data0 + 512, _MM_HINT_T0); - _mm_prefetch(data1 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@)); - accum_sse = _mm_add_pd(accum_sse, a); -/**end repeat2**/ - data0 += 8; - data1 += 8; - } - - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ - _mm_prefetch(data0 + 512, _MM_HINT_T0); - _mm_prefetch(data1 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); - accum_sse = _mm_add_ps(accum_sse, a); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ - _mm_prefetch(data0 + 512, _MM_HINT_T0); - _mm_prefetch(data1 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@)); - accum_sse = _mm_add_pd(accum_sse, a); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - accum += @from@(data0[@i@]) * @from@(data1[@i@]); -/**end repeat2**/ -#endif - data0 += 8; - data1 += 8; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); -#elif EINSUM_USE_SSE2 && @float64@ - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); -#endif - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -static void -@name@_sum_of_products_stride0_contig_outstride0_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @temptype@ value0 = @from@(*(@type@ *)dataptr[0]); - @type@ *data1 = (@type@ *)dataptr[1]; - @temptype@ accum = 0; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, accum_sse = _mm_setzero_ps(); -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, accum_sse = _mm_setzero_pd(); -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_stride0_contig_outstride0_two (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - accum += @from@(data1[@i@]); -/**end repeat2**/ - case 0: - *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + value0 * accum); - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data1+@i@)); -/**end repeat2**/ - data1 += 8; - } - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data1)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data1+@i@)); -/**end repeat2**/ - data1 += 8; - } - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data1+@i@)); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data1+@i@)); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - accum += @from@(data1[@i@]); -/**end repeat2**/ -#endif - data1 += 8; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); -#elif EINSUM_USE_SSE2 && @float64@ - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); -#endif - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -static void -@name@_sum_of_products_contig_stride0_outstride0_two(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @temptype@ value1 = @from@(*(@type@ *)dataptr[1]); - @temptype@ accum = 0; - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, accum_sse = _mm_setzero_ps(); -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, accum_sse = _mm_setzero_pd(); -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_stride0_outstride0_two (%d)\n", - (int)count); +#include "einsum_helpers.h" -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: - accum += @from@(data0[@i@]); -/**end repeat2**/ - case 0: - *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum * value1); - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@)); -/**end repeat2**/ - data0 += 8; - } - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@)); -/**end repeat2**/ - data0 += 8; - } - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@)); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@)); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - accum += @from@(data0[@i@]); -/**end repeat2**/ -#endif - data0 += 8; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); -#elif EINSUM_USE_SSE2 && @float64@ - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); -#endif - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -#elif @nop@ == 3 && !@complex@ - -static void -@name@_sum_of_products_contig_three(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - @type@ *data0 = (@type@ *)dataptr[0]; - @type@ *data1 = (@type@ *)dataptr[1]; - @type@ *data2 = (@type@ *)dataptr[2]; - @type@ *data_out = (@type@ *)dataptr[3]; - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - data_out[@i@] = @to@(@from@(data0[@i@]) * - @from@(data1[@i@]) * - @from@(data2[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ - data0 += 8; - data1 += 8; - data2 += 8; - data_out += 8; - } - - /* Finish off the loop */ - -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - if (count-- == 0) { - return; - } - data_out[@i@] = @to@(@from@(data0[@i@]) * - @from@(data1[@i@]) * - @from@(data2[@i@]) + - @from@(data_out[@i@])); -/**end repeat2**/ -} - -#else /* @nop@ > 3 || @complex */ - -static void -@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr, - npy_intp const *NPY_UNUSED(strides), npy_intp count) -{ - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_@noplabel@ (%d)\n", - (int)count); - - while (count--) { -#if !@complex@ - @temptype@ temp = @from@(*(@type@ *)dataptr[0]); - int i; - for (i = 1; i < nop; ++i) { - temp *= @from@(*(@type@ *)dataptr[i]); - } - *(@type@ *)dataptr[nop] = @to@(temp + - @from@(*(@type@ *)dataptr[i])); - for (i = 0; i <= nop; ++i) { - dataptr[i] += sizeof(@type@); - } -#else /* complex */ -# if @nop@ <= 3 -# define _SUMPROD_NOP @nop@ -# else -# define _SUMPROD_NOP nop -# endif - @temptype@ re, im, tmp; - int i; - re = ((@temptype@ *)dataptr[0])[0]; - im = ((@temptype@ *)dataptr[0])[1]; - for (i = 1; i < _SUMPROD_NOP; ++i) { - tmp = re * ((@temptype@ *)dataptr[i])[0] - - im * ((@temptype@ *)dataptr[i])[1]; - im = re * ((@temptype@ *)dataptr[i])[1] + - im * ((@temptype@ *)dataptr[i])[0]; - re = tmp; - } - ((@temptype@ *)dataptr[_SUMPROD_NOP])[0] = re + - ((@temptype@ *)dataptr[_SUMPROD_NOP])[0]; - ((@temptype@ *)dataptr[_SUMPROD_NOP])[1] = im + - ((@temptype@ *)dataptr[_SUMPROD_NOP])[1]; - - for (i = 0; i <= _SUMPROD_NOP; ++i) { - dataptr[i] += sizeof(@type@); - } -# undef _SUMPROD_NOP -#endif - } -} - -#endif /* functions for various @nop@ */ - -#if @nop@ == 1 - -static void -@name@_sum_of_products_contig_outstride0_one(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ -#if @complex@ - @temptype@ accum_re = 0, accum_im = 0; - @temptype@ *data0 = (@temptype@ *)dataptr[0]; -#else - @temptype@ accum = 0; - @type@ *data0 = (@type@ *)dataptr[0]; -#endif - -#if EINSUM_USE_SSE1 && @float32@ - __m128 a, accum_sse = _mm_setzero_ps(); -#elif EINSUM_USE_SSE2 && @float64@ - __m128d a, accum_sse = _mm_setzero_pd(); -#endif - - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_outstride0_one (%d)\n", - (int)count); - -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat2 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: -#if !@complex@ - accum += @from@(data0[@i@]); -#else /* complex */ - accum_re += data0[2*@i@+0]; - accum_im += data0[2*@i@+1]; -#endif -/**end repeat2**/ - case 0: -#if @complex@ - ((@temptype@ *)dataptr[1])[0] += accum_re; - ((@temptype@ *)dataptr[1])[1] += accum_im; -#else - *((@type@ *)dataptr[1]) = @to@(accum + - @from@(*((@type@ *)dataptr[1]))); -#endif - return; - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - - _mm_prefetch(data0 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@)); -/**end repeat2**/ - data0 += 8; - } - - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#elif EINSUM_USE_SSE2 && @float64@ - /* Use aligned instructions if possible */ - if (EINSUM_IS_SSE_ALIGNED(data0)) { - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - - _mm_prefetch(data0 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@)); -/**end repeat2**/ - data0 += 8; - } - - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); - - /* Finish off the loop */ - goto finish_after_unrolled_loop; - } -#endif - - /* Unroll the loop by 8 */ - while (count >= 8) { - count -= 8; - -#if EINSUM_USE_SSE1 && @float32@ - _mm_prefetch(data0 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 4# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@)); -/**end repeat2**/ -#elif EINSUM_USE_SSE2 && @float64@ - _mm_prefetch(data0 + 512, _MM_HINT_T0); - -/**begin repeat2 - * #i = 0, 2, 4, 6# - */ - /* - * NOTE: This accumulation changes the order, so will likely - * produce slightly different results. - */ - accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@)); -/**end repeat2**/ -#else -/**begin repeat2 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ -# if !@complex@ - accum += @from@(data0[@i@]); -# else /* complex */ - accum_re += data0[2*@i@+0]; - accum_im += data0[2*@i@+1]; -# endif -/**end repeat2**/ -#endif - -#if !@complex@ - data0 += 8; -#else - data0 += 8*2; -#endif - } - -#if EINSUM_USE_SSE1 && @float32@ - /* Add the four SSE values and put in accum */ - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); - accum_sse = _mm_add_ps(a, accum_sse); - a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); - accum_sse = _mm_add_ps(a, accum_sse); - _mm_store_ss(&accum, accum_sse); -#elif EINSUM_USE_SSE2 && @float64@ - /* Add the two SSE2 values and put in accum */ - a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); - accum_sse = _mm_add_pd(a, accum_sse); - _mm_store_sd(&accum, accum_sse); -#endif - - /* Finish off the loop */ - goto finish_after_unrolled_loop; -} - -#endif /* @nop@ == 1 */ - -static void -@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ -#if @complex@ - @temptype@ accum_re = 0, accum_im = 0; -#else - @temptype@ accum = 0; -#endif - -#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) - char *data0 = dataptr[0]; - npy_intp stride0 = strides[0]; -#endif -#if (@nop@ == 2 || @nop@ == 3) && !@complex@ - char *data1 = dataptr[1]; - npy_intp stride1 = strides[1]; -#endif -#if (@nop@ == 3) && !@complex@ - char *data2 = dataptr[2]; - npy_intp stride2 = strides[2]; -#endif - - NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_outstride0_@noplabel@ (%d)\n", - (int)count); - - while (count--) { -#if !@complex@ -# if @nop@ == 1 - accum += @from@(*(@type@ *)data0); - data0 += stride0; -# elif @nop@ == 2 - accum += @from@(*(@type@ *)data0) * - @from@(*(@type@ *)data1); - data0 += stride0; - data1 += stride1; -# elif @nop@ == 3 - accum += @from@(*(@type@ *)data0) * - @from@(*(@type@ *)data1) * - @from@(*(@type@ *)data2); - data0 += stride0; - data1 += stride1; - data2 += stride2; -# else - @temptype@ temp = @from@(*(@type@ *)dataptr[0]); - int i; - for (i = 1; i < nop; ++i) { - temp *= @from@(*(@type@ *)dataptr[i]); - } - accum += temp; - for (i = 0; i < nop; ++i) { - dataptr[i] += strides[i]; - } -# endif -#else /* complex */ -# if @nop@ == 1 - accum_re += ((@temptype@ *)data0)[0]; - accum_im += ((@temptype@ *)data0)[1]; - data0 += stride0; -# else -# if @nop@ <= 3 -#define _SUMPROD_NOP @nop@ -# else -#define _SUMPROD_NOP nop -# endif - @temptype@ re, im, tmp; - int i; - re = ((@temptype@ *)dataptr[0])[0]; - im = ((@temptype@ *)dataptr[0])[1]; - for (i = 1; i < _SUMPROD_NOP; ++i) { - tmp = re * ((@temptype@ *)dataptr[i])[0] - - im * ((@temptype@ *)dataptr[i])[1]; - im = re * ((@temptype@ *)dataptr[i])[1] + - im * ((@temptype@ *)dataptr[i])[0]; - re = tmp; - } - accum_re += re; - accum_im += im; - for (i = 0; i < _SUMPROD_NOP; ++i) { - dataptr[i] += strides[i]; - } -#undef _SUMPROD_NOP -# endif -#endif - } - -#if @complex@ -# if @nop@ <= 3 - ((@temptype@ *)dataptr[@nop@])[0] += accum_re; - ((@temptype@ *)dataptr[@nop@])[1] += accum_im; -# else - ((@temptype@ *)dataptr[nop])[0] += accum_re; - ((@temptype@ *)dataptr[nop])[1] += accum_im; -# endif -#else -# if @nop@ <= 3 - *((@type@ *)dataptr[@nop@]) = @to@(accum + - @from@(*((@type@ *)dataptr[@nop@]))); -# else - *((@type@ *)dataptr[nop]) = @to@(accum + - @from@(*((@type@ *)dataptr[nop]))); -# endif -#endif - -} - -/**end repeat1**/ - -/**end repeat**/ - - -/* Do OR of ANDs for the boolean type */ - -/**begin repeat - * #nop = 1, 2, 3, 1000# - * #noplabel = one, two, three, any# - */ - -static void -bool_sum_of_products_@noplabel@(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ -#if (@nop@ <= 3) - char *data0 = dataptr[0]; - npy_intp stride0 = strides[0]; -#endif -#if (@nop@ == 2 || @nop@ == 3) - char *data1 = dataptr[1]; - npy_intp stride1 = strides[1]; -#endif -#if (@nop@ == 3) - char *data2 = dataptr[2]; - npy_intp stride2 = strides[2]; -#endif -#if (@nop@ <= 3) - char *data_out = dataptr[@nop@]; - npy_intp stride_out = strides[@nop@]; -#endif - - while (count--) { -#if @nop@ == 1 - *(npy_bool *)data_out = *(npy_bool *)data0 || - *(npy_bool *)data_out; - data0 += stride0; - data_out += stride_out; -#elif @nop@ == 2 - *(npy_bool *)data_out = (*(npy_bool *)data0 && - *(npy_bool *)data1) || - *(npy_bool *)data_out; - data0 += stride0; - data1 += stride1; - data_out += stride_out; -#elif @nop@ == 3 - *(npy_bool *)data_out = (*(npy_bool *)data0 && - *(npy_bool *)data1 && - *(npy_bool *)data2) || - *(npy_bool *)data_out; - data0 += stride0; - data1 += stride1; - data2 += stride2; - data_out += stride_out; -#else - npy_bool temp = *(npy_bool *)dataptr[0]; - int i; - for (i = 1; i < nop; ++i) { - temp = temp && *(npy_bool *)dataptr[i]; - } - *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i]; - for (i = 0; i <= nop; ++i) { - dataptr[i] += strides[i]; - } -#endif - } -} - -static void -bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ -#if (@nop@ <= 3) - char *data0 = dataptr[0]; -#endif -#if (@nop@ == 2 || @nop@ == 3) - char *data1 = dataptr[1]; -#endif -#if (@nop@ == 3) - char *data2 = dataptr[2]; -#endif -#if (@nop@ <= 3) - char *data_out = dataptr[@nop@]; -#endif - -#if (@nop@ <= 3) -/* This is placed before the main loop to make small counts faster */ -finish_after_unrolled_loop: - switch (count) { -/**begin repeat1 - * #i = 6, 5, 4, 3, 2, 1, 0# - */ - case @i@+1: -# if @nop@ == 1 - ((npy_bool *)data_out)[@i@] = ((npy_bool *)data0)[@i@] || - ((npy_bool *)data_out)[@i@]; -# elif @nop@ == 2 - ((npy_bool *)data_out)[@i@] = - (((npy_bool *)data0)[@i@] && - ((npy_bool *)data1)[@i@]) || - ((npy_bool *)data_out)[@i@]; -# elif @nop@ == 3 - ((npy_bool *)data_out)[@i@] = - (((npy_bool *)data0)[@i@] && - ((npy_bool *)data1)[@i@] && - ((npy_bool *)data2)[@i@]) || - ((npy_bool *)data_out)[@i@]; -# endif -/**end repeat1**/ - case 0: - return; - } -#endif - -/* Unroll the loop by 8 for fixed-size nop */ -#if (@nop@ <= 3) - while (count >= 8) { - count -= 8; -#else - while (count--) { -#endif - -# if @nop@ == 1 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 8*sizeof(npy_bool); - data_out += 8*sizeof(npy_bool); -# elif @nop@ == 2 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - *((npy_bool *)data_out + @i@) = - ((*((npy_bool *)data0 + @i@)) && - (*((npy_bool *)data1 + @i@))) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 8*sizeof(npy_bool); - data1 += 8*sizeof(npy_bool); - data_out += 8*sizeof(npy_bool); -# elif @nop@ == 3 -/**begin repeat1 - * #i = 0, 1, 2, 3, 4, 5, 6, 7# - */ - *((npy_bool *)data_out + @i@) = - ((*((npy_bool *)data0 + @i@)) && - (*((npy_bool *)data1 + @i@)) && - (*((npy_bool *)data2 + @i@))) || - (*((npy_bool *)data_out + @i@)); -/**end repeat1**/ - data0 += 8*sizeof(npy_bool); - data1 += 8*sizeof(npy_bool); - data2 += 8*sizeof(npy_bool); - data_out += 8*sizeof(npy_bool); -# else - npy_bool temp = *(npy_bool *)dataptr[0]; - int i; - for (i = 1; i < nop; ++i) { - temp = temp && *(npy_bool *)dataptr[i]; - } - *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i]; - for (i = 0; i <= nop; ++i) { - dataptr[i] += sizeof(npy_bool); - } -# endif - } - - /* If the loop was unrolled, we need to finish it off */ -#if (@nop@ <= 3) - goto finish_after_unrolled_loop; -#endif -} - -static void -bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, - npy_intp const *strides, npy_intp count) -{ - npy_bool accum = 0; - -#if (@nop@ <= 3) - char *data0 = dataptr[0]; - npy_intp stride0 = strides[0]; -#endif -#if (@nop@ == 2 || @nop@ == 3) - char *data1 = dataptr[1]; - npy_intp stride1 = strides[1]; -#endif -#if (@nop@ == 3) - char *data2 = dataptr[2]; - npy_intp stride2 = strides[2]; -#endif - - while (count--) { -#if @nop@ == 1 - accum = *(npy_bool *)data0 || accum; - data0 += stride0; -#elif @nop@ == 2 - accum = (*(npy_bool *)data0 && *(npy_bool *)data1) || accum; - data0 += stride0; - data1 += stride1; -#elif @nop@ == 3 - accum = (*(npy_bool *)data0 && - *(npy_bool *)data1 && - *(npy_bool *)data2) || accum; - data0 += stride0; - data1 += stride1; - data2 += stride2; -#else - npy_bool temp = *(npy_bool *)dataptr[0]; - int i; - for (i = 1; i < nop; ++i) { - temp = temp && *(npy_bool *)dataptr[i]; - } - accum = temp || accum; - for (i = 0; i <= nop; ++i) { - dataptr[i] += strides[i]; - } -#endif - } - -# if @nop@ <= 3 - *((npy_bool *)dataptr[@nop@]) = accum || *((npy_bool *)dataptr[@nop@]); -# else - *((npy_bool *)dataptr[nop]) = accum || *((npy_bool *)dataptr[nop]); -# endif -} - -/**end repeat**/ - -typedef void (*sum_of_products_fn)(int, char **, npy_intp const*, npy_intp); - -/* These tables need to match up with the type enum */ static sum_of_products_fn -_contig_outstride0_unary_specialization_table[NPY_NTYPES] = { -/**begin repeat - * #name = bool, - * byte, ubyte, - * short, ushort, - * int, uint, - * long, ulong, - * longlong, ulonglong, - * float, double, longdouble, - * cfloat, cdouble, clongdouble, - * object, string, unicode, void, - * datetime, timedelta, half# - * #use = 0, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, 1, - * 1, 1, 1, - * 0, 0, 0, 0, - * 0, 0, 1# - */ -#if @use@ - &@name@_sum_of_products_contig_outstride0_one, -#else - NULL, -#endif -/**end repeat**/ -}; /* End of _contig_outstride0_unary_specialization_table */ - -static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = { -/**begin repeat - * #name = bool, - * byte, ubyte, - * short, ushort, - * int, uint, - * long, ulong, - * longlong, ulonglong, - * float, double, longdouble, - * cfloat, cdouble, clongdouble, - * object, string, unicode, void, - * datetime, timedelta, half# - * #use = 0, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, 1, - * 0, 0, 0, - * 0, 0, 0, 0, - * 0, 0, 1# - */ -#if @use@ -{ - &@name@_sum_of_products_stride0_contig_outstride0_two, - &@name@_sum_of_products_stride0_contig_outcontig_two, - &@name@_sum_of_products_contig_stride0_outstride0_two, - &@name@_sum_of_products_contig_stride0_outcontig_two, - &@name@_sum_of_products_contig_contig_outstride0_two, -}, -#else - {NULL, NULL, NULL, NULL, NULL}, -#endif -/**end repeat**/ -}; /* End of _binary_specialization_table */ - -static sum_of_products_fn _outstride0_specialized_table[NPY_NTYPES][4] = { -/**begin repeat - * #name = bool, - * byte, ubyte, - * short, ushort, - * int, uint, - * long, ulong, - * longlong, ulonglong, - * float, double, longdouble, - * cfloat, cdouble, clongdouble, - * object, string, unicode, void, - * datetime, timedelta, half# - * #use = 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, 1, - * 1, 1, 1, - * 0, 0, 0, 0, - * 0, 0, 1# - */ -#if @use@ -{ - &@name@_sum_of_products_outstride0_any, - &@name@_sum_of_products_outstride0_one, - &@name@_sum_of_products_outstride0_two, - &@name@_sum_of_products_outstride0_three -}, -#else - {NULL, NULL, NULL, NULL}, -#endif -/**end repeat**/ -}; /* End of _outstride0_specialized_table */ - -static sum_of_products_fn _allcontig_specialized_table[NPY_NTYPES][4] = { -/**begin repeat - * #name = bool, - * byte, ubyte, - * short, ushort, - * int, uint, - * long, ulong, - * longlong, ulonglong, - * float, double, longdouble, - * cfloat, cdouble, clongdouble, - * object, string, unicode, void, - * datetime, timedelta, half# - * #use = 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, 1, - * 1, 1, 1, - * 0, 0, 0, 0, - * 0, 0, 1# - */ -#if @use@ -{ - &@name@_sum_of_products_contig_any, - &@name@_sum_of_products_contig_one, - &@name@_sum_of_products_contig_two, - &@name@_sum_of_products_contig_three -}, -#else - {NULL, NULL, NULL, NULL}, -#endif -/**end repeat**/ -}; /* End of _allcontig_specialized_table */ - -static sum_of_products_fn _unspecialized_table[NPY_NTYPES][4] = { -/**begin repeat - * #name = bool, - * byte, ubyte, - * short, ushort, - * int, uint, - * long, ulong, - * longlong, ulonglong, - * float, double, longdouble, - * cfloat, cdouble, clongdouble, - * object, string, unicode, void, - * datetime, timedelta, half# - * #use = 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, - * 1, 1, 1, - * 1, 1, 1, - * 0, 0, 0, 0, - * 0, 0, 1# - */ -#if @use@ +get_sum_of_products_function(int nop, int type_num, npy_intp itemsize, npy_intp const *fixed_strides) { - &@name@_sum_of_products_any, - &@name@_sum_of_products_one, - &@name@_sum_of_products_two, - &@name@_sum_of_products_three -}, -#else - {NULL, NULL, NULL, NULL}, -#endif -/**end repeat**/ -}; /* End of _unnspecialized_table */ - -static sum_of_products_fn -get_sum_of_products_function(int nop, int type_num, - npy_intp itemsize, npy_intp const *fixed_strides) -{ - int iop; - - if (type_num >= NPY_NTYPES) { - return NULL; - } - - /* contiguous reduction */ - if (nop == 1 && fixed_strides[0] == itemsize && fixed_strides[1] == 0) { - sum_of_products_fn ret = - _contig_outstride0_unary_specialization_table[type_num]; - if (ret != NULL) { - return ret; - } - } - - /* nop of 2 has more specializations */ - if (nop == 2) { - /* Encode the zero/contiguous strides */ - int code; - code = (fixed_strides[0] == 0) ? 0 : - (fixed_strides[0] == itemsize) ? 2*2*1 : 8; - code += (fixed_strides[1] == 0) ? 0 : - (fixed_strides[1] == itemsize) ? 2*1 : 8; - code += (fixed_strides[2] == 0) ? 0 : - (fixed_strides[2] == itemsize) ? 1 : 8; - if (code >= 2 && code < 7) { - sum_of_products_fn ret = - _binary_specialization_table[type_num][code-2]; - if (ret != NULL) { - return ret; - } - } - } - - /* Inner loop with an output stride of 0 */ - if (fixed_strides[nop] == 0) { - return _outstride0_specialized_table[type_num][nop <= 3 ? nop : 0]; - } - - /* Check for all contiguous */ - for (iop = 0; iop < nop + 1; ++iop) { - if (fixed_strides[iop] != itemsize) { - break; - } - } - - /* Contiguous loop */ - if (iop == nop + 1) { - return _allcontig_specialized_table[type_num][nop <= 3 ? nop : 0]; - } - - /* None of the above specializations caught it, general loops */ - return _unspecialized_table[type_num][nop <= 3 ? nop : 0]; + #ifndef NPY_DISABLE_OPTIMIZATION + /** + * Auto-generated config headers '*.dispatch.h' are overriding each other, + * which allows the possibility of 'race condition' if another config + * header has been involved in the scope. + * Therefore we tend to include the desired header close from the + * dispatching macros. + */ + #include "einsum.dispatch.h" + #endif + NPY_CPU_DISPATCH_CALL(return einsum_get_sum_of_products_function, + (nop, type_num, itemsize, fixed_strides)) } - - /* * Parses the subscripts for one operand into an output of 'ndim' * labels. The resulting 'op_labels' array will have: diff --git a/numpy/core/src/multiarray/einsum.dispatch.c.src b/numpy/core/src/multiarray/einsum.dispatch.c.src new file mode 100644 index 000000000000..de2b036bc0da --- /dev/null +++ b/numpy/core/src/multiarray/einsum.dispatch.c.src @@ -0,0 +1,1873 @@ +/* + * This file contains the implementation of the 'einsum' function, + * which provides an einstein-summation operation. + * + * Copyright (c) 2011 by Mark Wiebe (mwwiebe@gmail.com) + * The University of British Columbia + * + * See LICENSE.txt for the license. + */ +/** + * @targets baseline sse2 + */ +#include "einsum_helpers.h" + +#undef EINSUM_USE_SSE1 +#undef EINSUM_USE_SSE2 +#define EINSUM_USE_SSE1 defined(NPY_HAVE_SSE2) +#define EINSUM_USE_SSE2 defined(NPY_HAVE_SSE2) +#define EINSUM_IS_SSE_ALIGNED EINSUM_IS_ALIGNED + +/**begin repeat + * #name = byte, short, int, long, longlong, + * ubyte, ushort, uint, ulong, ulonglong, + * half, float, double, longdouble, + * cfloat, cdouble, clongdouble# + * #type = npy_byte, npy_short, npy_int, npy_long, npy_longlong, + * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, + * npy_half, npy_float, npy_double, npy_longdouble, + * npy_cfloat, npy_cdouble, npy_clongdouble# + * #temptype = npy_byte, npy_short, npy_int, npy_long, npy_longlong, + * npy_ubyte, npy_ushort, npy_uint, npy_ulong, npy_ulonglong, + * npy_float, npy_float, npy_double, npy_longdouble, + * npy_float, npy_double, npy_longdouble# + * #to = ,,,,, + * ,,,,, + * npy_float_to_half,,,, + * ,,# + * #from = ,,,,, + * ,,,,, + * npy_half_to_float,,,, + * ,,# + * #complex = 0*5, + * 0*5, + * 0*4, + * 1*3# + * #float32 = 0*5, + * 0*5, + * 0,1,0,0, + * 0*3# + * #float64 = 0*5, + * 0*5, + * 0,0,1,0, + * 0*3# + */ + +/**begin repeat1 + * #nop = 1, 2, 3, 1000# + * #noplabel = one, two, three, any# + */ +static void +@name@_sum_of_products_@noplabel@(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ +#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) + char *data0 = dataptr[0]; + npy_intp stride0 = strides[0]; +#endif +#if (@nop@ == 2 || @nop@ == 3) && !@complex@ + char *data1 = dataptr[1]; + npy_intp stride1 = strides[1]; +#endif +#if (@nop@ == 3) && !@complex@ + char *data2 = dataptr[2]; + npy_intp stride2 = strides[2]; +#endif +#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) + char *data_out = dataptr[@nop@]; + npy_intp stride_out = strides[@nop@]; +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_@noplabel@ (%d)\n", (int)count); + + while (count--) { +#if !@complex@ +# if @nop@ == 1 + *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) + + @from@(*(@type@ *)data_out)); + data0 += stride0; + data_out += stride_out; +# elif @nop@ == 2 + *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) * + @from@(*(@type@ *)data1) + + @from@(*(@type@ *)data_out)); + data0 += stride0; + data1 += stride1; + data_out += stride_out; +# elif @nop@ == 3 + *(@type@ *)data_out = @to@(@from@(*(@type@ *)data0) * + @from@(*(@type@ *)data1) * + @from@(*(@type@ *)data2) + + @from@(*(@type@ *)data_out)); + data0 += stride0; + data1 += stride1; + data2 += stride2; + data_out += stride_out; +# else + @temptype@ temp = @from@(*(@type@ *)dataptr[0]); + int i; + for (i = 1; i < nop; ++i) { + temp *= @from@(*(@type@ *)dataptr[i]); + } + *(@type@ *)dataptr[nop] = @to@(temp + + @from@(*(@type@ *)dataptr[i])); + for (i = 0; i <= nop; ++i) { + dataptr[i] += strides[i]; + } +# endif +#else /* complex */ +# if @nop@ == 1 + ((@temptype@ *)data_out)[0] = ((@temptype@ *)data0)[0] + + ((@temptype@ *)data_out)[0]; + ((@temptype@ *)data_out)[1] = ((@temptype@ *)data0)[1] + + ((@temptype@ *)data_out)[1]; + data0 += stride0; + data_out += stride_out; +# else +# if @nop@ <= 3 +#define _SUMPROD_NOP @nop@ +# else +#define _SUMPROD_NOP nop +# endif + @temptype@ re, im, tmp; + int i; + re = ((@temptype@ *)dataptr[0])[0]; + im = ((@temptype@ *)dataptr[0])[1]; + for (i = 1; i < _SUMPROD_NOP; ++i) { + tmp = re * ((@temptype@ *)dataptr[i])[0] - + im * ((@temptype@ *)dataptr[i])[1]; + im = re * ((@temptype@ *)dataptr[i])[1] + + im * ((@temptype@ *)dataptr[i])[0]; + re = tmp; + } + ((@temptype@ *)dataptr[_SUMPROD_NOP])[0] = re + + ((@temptype@ *)dataptr[_SUMPROD_NOP])[0]; + ((@temptype@ *)dataptr[_SUMPROD_NOP])[1] = im + + ((@temptype@ *)dataptr[_SUMPROD_NOP])[1]; + + for (i = 0; i <= _SUMPROD_NOP; ++i) { + dataptr[i] += strides[i]; + } +#undef _SUMPROD_NOP +# endif +#endif + } +} + +#if @nop@ == 1 + +static void +@name@_sum_of_products_contig_one(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @type@ *data_out = (@type@ *)dataptr[1]; + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_one (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: +#if !@complex@ + data_out[@i@] = @to@(@from@(data0[@i@]) + + @from@(data_out[@i@])); +#else + ((@temptype@ *)data_out + 2*@i@)[0] = + ((@temptype@ *)data0 + 2*@i@)[0] + + ((@temptype@ *)data_out + 2*@i@)[0]; + ((@temptype@ *)data_out + 2*@i@)[1] = + ((@temptype@ *)data0 + 2*@i@)[1] + + ((@temptype@ *)data_out + 2*@i@)[1]; +#endif +/**end repeat2**/ + case 0: + return; + } + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ +#if !@complex@ + data_out[@i@] = @to@(@from@(data0[@i@]) + + @from@(data_out[@i@])); +#else /* complex */ + ((@temptype@ *)data_out + 2*@i@)[0] = + ((@temptype@ *)data0 + 2*@i@)[0] + + ((@temptype@ *)data_out + 2*@i@)[0]; + ((@temptype@ *)data_out + 2*@i@)[1] = + ((@temptype@ *)data0 + 2*@i@)[1] + + ((@temptype@ *)data_out + 2*@i@)[1]; +#endif +/**end repeat2**/ + data0 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +#elif @nop@ == 2 && !@complex@ + +static void +@name@_sum_of_products_contig_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @type@ *data1 = (@type@ *)dataptr[1]; + @type@ *data_out = (@type@ *)dataptr[2]; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, b; +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, b; +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(@from@(data0[@i@]) * + @from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) && + EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); + b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); + _mm_store_ps(data_out+@i@, b); +/**end repeat2**/ + data0 += 8; + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1) && + EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); + _mm_store_pd(data_out+@i@, b); +/**end repeat2**/ + data0 += 8; + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); + b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); + _mm_storeu_ps(data_out+@i@, b); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); + _mm_storeu_pd(data_out+@i@, b); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + data_out[@i@] = @to@(@from@(data0[@i@]) * + @from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ +#endif + data0 += 8; + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +/* Some extra specializations for the two operand case */ +static void +@name@_sum_of_products_stride0_contig_outcontig_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @temptype@ value0 = @from@(*(@type@ *)dataptr[0]); + @type@ *data1 = (@type@ *)dataptr[1]; + @type@ *data_out = (@type@ *)dataptr[2]; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, b, value0_sse; +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, b, value0_sse; +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_stride0_contig_outcontig_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(value0 * + @from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + value0_sse = _mm_set_ps1(value0); + + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(value0_sse, _mm_load_ps(data1+@i@)); + b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); + _mm_store_ps(data_out+@i@, b); +/**end repeat2**/ + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + if (count > 0) { + goto finish_after_unrolled_loop; + } + else { + return; + } + } +#elif EINSUM_USE_SSE2 && @float64@ + value0_sse = _mm_set1_pd(value0); + + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data1) && EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(value0_sse, _mm_load_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); + _mm_store_pd(data_out+@i@, b); +/**end repeat2**/ + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + if (count > 0) { + goto finish_after_unrolled_loop; + } + else { + return; + } + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(value0_sse, _mm_loadu_ps(data1+@i@)); + b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); + _mm_storeu_ps(data_out+@i@, b); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(value0_sse, _mm_loadu_pd(data1+@i@)); + b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); + _mm_storeu_pd(data_out+@i@, b); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + data_out[@i@] = @to@(value0 * + @from@(data1[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ +#endif + data1 += 8; + data_out += 8; + } + + /* Finish off the loop */ + if (count > 0) { + goto finish_after_unrolled_loop; + } +} + +static void +@name@_sum_of_products_contig_stride0_outcontig_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @temptype@ value1 = @from@(*(@type@ *)dataptr[1]); + @type@ *data_out = (@type@ *)dataptr[2]; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, b, value1_sse; +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, b, value1_sse; +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_stride0_outcontig_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + data_out[@i@] = @to@(@from@(data0[@i@])* + value1 + + @from@(data_out[@i@])); +/**end repeat2**/ + case 0: + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + value1_sse = _mm_set_ps1(value1); + + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(_mm_load_ps(data0+@i@), value1_sse); + b = _mm_add_ps(a, _mm_load_ps(data_out+@i@)); + _mm_store_ps(data_out+@i@, b); +/**end repeat2**/ + data0 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + value1_sse = _mm_set1_pd(value1); + + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data_out)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(_mm_load_pd(data0+@i@), value1_sse); + b = _mm_add_pd(a, _mm_load_pd(data_out+@i@)); + _mm_store_pd(data_out+@i@, b); +/**end repeat2**/ + data0 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ +/**begin repeat2 + * #i = 0, 4# + */ + a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), value1_sse); + b = _mm_add_ps(a, _mm_loadu_ps(data_out+@i@)); + _mm_storeu_ps(data_out+@i@, b); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), value1_sse); + b = _mm_add_pd(a, _mm_loadu_pd(data_out+@i@)); + _mm_storeu_pd(data_out+@i@, b); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + data_out[@i@] = @to@(@from@(data0[@i@])* + value1 + + @from@(data_out[@i@])); +/**end repeat2**/ +#endif + data0 += 8; + data_out += 8; + } + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +static void +@name@_sum_of_products_contig_contig_outstride0_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @type@ *data1 = (@type@ *)dataptr[1]; + @temptype@ accum = 0; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, accum_sse = _mm_setzero_ps(); +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, accum_sse = _mm_setzero_pd(); +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_contig_outstride0_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data0[@i@]) * @from@(data1[@i@]); +/**end repeat2**/ + case 0: + *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum); + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + + _mm_prefetch(data0 + 512, _MM_HINT_T0); + _mm_prefetch(data1 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + a = _mm_mul_ps(_mm_load_ps(data0+@i@), _mm_load_ps(data1+@i@)); + accum_sse = _mm_add_ps(accum_sse, a); +/**end repeat2**/ + data0 += 8; + data1 += 8; + } + + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0) && EINSUM_IS_SSE_ALIGNED(data1)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + + _mm_prefetch(data0 + 512, _MM_HINT_T0); + _mm_prefetch(data1 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + a = _mm_mul_pd(_mm_load_pd(data0+@i@), _mm_load_pd(data1+@i@)); + accum_sse = _mm_add_pd(accum_sse, a); +/**end repeat2**/ + data0 += 8; + data1 += 8; + } + + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ + _mm_prefetch(data0 + 512, _MM_HINT_T0); + _mm_prefetch(data1 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + a = _mm_mul_ps(_mm_loadu_ps(data0+@i@), _mm_loadu_ps(data1+@i@)); + accum_sse = _mm_add_ps(accum_sse, a); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ + _mm_prefetch(data0 + 512, _MM_HINT_T0); + _mm_prefetch(data1 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + a = _mm_mul_pd(_mm_loadu_pd(data0+@i@), _mm_loadu_pd(data1+@i@)); + accum_sse = _mm_add_pd(accum_sse, a); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + accum += @from@(data0[@i@]) * @from@(data1[@i@]); +/**end repeat2**/ +#endif + data0 += 8; + data1 += 8; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#elif EINSUM_USE_SSE2 && @float64@ + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +static void +@name@_sum_of_products_stride0_contig_outstride0_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @temptype@ value0 = @from@(*(@type@ *)dataptr[0]); + @type@ *data1 = (@type@ *)dataptr[1]; + @temptype@ accum = 0; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, accum_sse = _mm_setzero_ps(); +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, accum_sse = _mm_setzero_pd(); +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_stride0_contig_outstride0_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data1[@i@]); +/**end repeat2**/ + case 0: + *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + value0 * accum); + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data1)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data1+@i@)); +/**end repeat2**/ + data1 += 8; + } + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data1)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data1+@i@)); +/**end repeat2**/ + data1 += 8; + } + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data1+@i@)); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data1+@i@)); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + accum += @from@(data1[@i@]); +/**end repeat2**/ +#endif + data1 += 8; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#elif EINSUM_USE_SSE2 && @float64@ + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +static void +@name@_sum_of_products_contig_stride0_outstride0_two(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @temptype@ value1 = @from@(*(@type@ *)dataptr[1]); + @temptype@ accum = 0; + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, accum_sse = _mm_setzero_ps(); +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, accum_sse = _mm_setzero_pd(); +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_stride0_outstride0_two (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: + accum += @from@(data0[@i@]); +/**end repeat2**/ + case 0: + *(@type@ *)dataptr[2] = @to@(@from@(*(@type@ *)dataptr[2]) + accum * value1); + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@)); +/**end repeat2**/ + data0 += 8; + } + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@)); +/**end repeat2**/ + data0 += 8; + } + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@)); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@)); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + accum += @from@(data0[@i@]); +/**end repeat2**/ +#endif + data0 += 8; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#elif EINSUM_USE_SSE2 && @float64@ + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +#elif @nop@ == 3 && !@complex@ + +static void +@name@_sum_of_products_contig_three(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + @type@ *data0 = (@type@ *)dataptr[0]; + @type@ *data1 = (@type@ *)dataptr[1]; + @type@ *data2 = (@type@ *)dataptr[2]; + @type@ *data_out = (@type@ *)dataptr[3]; + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + data_out[@i@] = @to@(@from@(data0[@i@]) * + @from@(data1[@i@]) * + @from@(data2[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ + data0 += 8; + data1 += 8; + data2 += 8; + data_out += 8; + } + + /* Finish off the loop */ + +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + if (count-- == 0) { + return; + } + data_out[@i@] = @to@(@from@(data0[@i@]) * + @from@(data1[@i@]) * + @from@(data2[@i@]) + + @from@(data_out[@i@])); +/**end repeat2**/ +} + +#else /* @nop@ > 3 || @complex */ + +static void +@name@_sum_of_products_contig_@noplabel@(int nop, char **dataptr, + npy_intp const *NPY_UNUSED(strides), npy_intp count) +{ + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_@noplabel@ (%d)\n", + (int)count); + + while (count--) { +#if !@complex@ + @temptype@ temp = @from@(*(@type@ *)dataptr[0]); + int i; + for (i = 1; i < nop; ++i) { + temp *= @from@(*(@type@ *)dataptr[i]); + } + *(@type@ *)dataptr[nop] = @to@(temp + + @from@(*(@type@ *)dataptr[i])); + for (i = 0; i <= nop; ++i) { + dataptr[i] += sizeof(@type@); + } +#else /* complex */ +# if @nop@ <= 3 +# define _SUMPROD_NOP @nop@ +# else +# define _SUMPROD_NOP nop +# endif + @temptype@ re, im, tmp; + int i; + re = ((@temptype@ *)dataptr[0])[0]; + im = ((@temptype@ *)dataptr[0])[1]; + for (i = 1; i < _SUMPROD_NOP; ++i) { + tmp = re * ((@temptype@ *)dataptr[i])[0] - + im * ((@temptype@ *)dataptr[i])[1]; + im = re * ((@temptype@ *)dataptr[i])[1] + + im * ((@temptype@ *)dataptr[i])[0]; + re = tmp; + } + ((@temptype@ *)dataptr[_SUMPROD_NOP])[0] = re + + ((@temptype@ *)dataptr[_SUMPROD_NOP])[0]; + ((@temptype@ *)dataptr[_SUMPROD_NOP])[1] = im + + ((@temptype@ *)dataptr[_SUMPROD_NOP])[1]; + + for (i = 0; i <= _SUMPROD_NOP; ++i) { + dataptr[i] += sizeof(@type@); + } +# undef _SUMPROD_NOP +#endif + } +} + +#endif /* functions for various @nop@ */ + +#if @nop@ == 1 + +static void +@name@_sum_of_products_contig_outstride0_one(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ +#if @complex@ + @temptype@ accum_re = 0, accum_im = 0; + @temptype@ *data0 = (@temptype@ *)dataptr[0]; +#else + @temptype@ accum = 0; + @type@ *data0 = (@type@ *)dataptr[0]; +#endif + +#if EINSUM_USE_SSE1 && @float32@ + __m128 a, accum_sse = _mm_setzero_ps(); +#elif EINSUM_USE_SSE2 && @float64@ + __m128d a, accum_sse = _mm_setzero_pd(); +#endif + + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_contig_outstride0_one (%d)\n", + (int)count); + +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat2 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: +#if !@complex@ + accum += @from@(data0[@i@]); +#else /* complex */ + accum_re += data0[2*@i@+0]; + accum_im += data0[2*@i@+1]; +#endif +/**end repeat2**/ + case 0: +#if @complex@ + ((@temptype@ *)dataptr[1])[0] += accum_re; + ((@temptype@ *)dataptr[1])[1] += accum_im; +#else + *((@type@ *)dataptr[1]) = @to@(accum + + @from@(*((@type@ *)dataptr[1]))); +#endif + return; + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + + _mm_prefetch(data0 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_load_ps(data0+@i@)); +/**end repeat2**/ + data0 += 8; + } + + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#elif EINSUM_USE_SSE2 && @float64@ + /* Use aligned instructions if possible */ + if (EINSUM_IS_SSE_ALIGNED(data0)) { + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + + _mm_prefetch(data0 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_load_pd(data0+@i@)); +/**end repeat2**/ + data0 += 8; + } + + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); + + /* Finish off the loop */ + goto finish_after_unrolled_loop; + } +#endif + + /* Unroll the loop by 8 */ + while (count >= 8) { + count -= 8; + +#if EINSUM_USE_SSE1 && @float32@ + _mm_prefetch(data0 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 4# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_ps(accum_sse, _mm_loadu_ps(data0+@i@)); +/**end repeat2**/ +#elif EINSUM_USE_SSE2 && @float64@ + _mm_prefetch(data0 + 512, _MM_HINT_T0); + +/**begin repeat2 + * #i = 0, 2, 4, 6# + */ + /* + * NOTE: This accumulation changes the order, so will likely + * produce slightly different results. + */ + accum_sse = _mm_add_pd(accum_sse, _mm_loadu_pd(data0+@i@)); +/**end repeat2**/ +#else +/**begin repeat2 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ +# if !@complex@ + accum += @from@(data0[@i@]); +# else /* complex */ + accum_re += data0[2*@i@+0]; + accum_im += data0[2*@i@+1]; +# endif +/**end repeat2**/ +#endif + +#if !@complex@ + data0 += 8; +#else + data0 += 8*2; +#endif + } + +#if EINSUM_USE_SSE1 && @float32@ + /* Add the four SSE values and put in accum */ + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(2,3,0,1)); + accum_sse = _mm_add_ps(a, accum_sse); + a = _mm_shuffle_ps(accum_sse, accum_sse, _MM_SHUFFLE(1,0,3,2)); + accum_sse = _mm_add_ps(a, accum_sse); + _mm_store_ss(&accum, accum_sse); +#elif EINSUM_USE_SSE2 && @float64@ + /* Add the two SSE2 values and put in accum */ + a = _mm_shuffle_pd(accum_sse, accum_sse, _MM_SHUFFLE2(0,1)); + accum_sse = _mm_add_pd(a, accum_sse); + _mm_store_sd(&accum, accum_sse); +#endif + + /* Finish off the loop */ + goto finish_after_unrolled_loop; +} + +#endif /* @nop@ == 1 */ + +static void +@name@_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ +#if @complex@ + @temptype@ accum_re = 0, accum_im = 0; +#else + @temptype@ accum = 0; +#endif + +#if (@nop@ == 1) || (@nop@ <= 3 && !@complex@) + char *data0 = dataptr[0]; + npy_intp stride0 = strides[0]; +#endif +#if (@nop@ == 2 || @nop@ == 3) && !@complex@ + char *data1 = dataptr[1]; + npy_intp stride1 = strides[1]; +#endif +#if (@nop@ == 3) && !@complex@ + char *data2 = dataptr[2]; + npy_intp stride2 = strides[2]; +#endif + + NPY_EINSUM_DBG_PRINT1("@name@_sum_of_products_outstride0_@noplabel@ (%d)\n", + (int)count); + + while (count--) { +#if !@complex@ +# if @nop@ == 1 + accum += @from@(*(@type@ *)data0); + data0 += stride0; +# elif @nop@ == 2 + accum += @from@(*(@type@ *)data0) * + @from@(*(@type@ *)data1); + data0 += stride0; + data1 += stride1; +# elif @nop@ == 3 + accum += @from@(*(@type@ *)data0) * + @from@(*(@type@ *)data1) * + @from@(*(@type@ *)data2); + data0 += stride0; + data1 += stride1; + data2 += stride2; +# else + @temptype@ temp = @from@(*(@type@ *)dataptr[0]); + int i; + for (i = 1; i < nop; ++i) { + temp *= @from@(*(@type@ *)dataptr[i]); + } + accum += temp; + for (i = 0; i < nop; ++i) { + dataptr[i] += strides[i]; + } +# endif +#else /* complex */ +# if @nop@ == 1 + accum_re += ((@temptype@ *)data0)[0]; + accum_im += ((@temptype@ *)data0)[1]; + data0 += stride0; +# else +# if @nop@ <= 3 +#define _SUMPROD_NOP @nop@ +# else +#define _SUMPROD_NOP nop +# endif + @temptype@ re, im, tmp; + int i; + re = ((@temptype@ *)dataptr[0])[0]; + im = ((@temptype@ *)dataptr[0])[1]; + for (i = 1; i < _SUMPROD_NOP; ++i) { + tmp = re * ((@temptype@ *)dataptr[i])[0] - + im * ((@temptype@ *)dataptr[i])[1]; + im = re * ((@temptype@ *)dataptr[i])[1] + + im * ((@temptype@ *)dataptr[i])[0]; + re = tmp; + } + accum_re += re; + accum_im += im; + for (i = 0; i < _SUMPROD_NOP; ++i) { + dataptr[i] += strides[i]; + } +#undef _SUMPROD_NOP +# endif +#endif + } + +#if @complex@ +# if @nop@ <= 3 + ((@temptype@ *)dataptr[@nop@])[0] += accum_re; + ((@temptype@ *)dataptr[@nop@])[1] += accum_im; +# else + ((@temptype@ *)dataptr[nop])[0] += accum_re; + ((@temptype@ *)dataptr[nop])[1] += accum_im; +# endif +#else +# if @nop@ <= 3 + *((@type@ *)dataptr[@nop@]) = @to@(accum + + @from@(*((@type@ *)dataptr[@nop@]))); +# else + *((@type@ *)dataptr[nop]) = @to@(accum + + @from@(*((@type@ *)dataptr[nop]))); +# endif +#endif + +} + +/**end repeat1**/ + +/**end repeat**/ + + +/* Do OR of ANDs for the boolean type */ + +/**begin repeat + * #nop = 1, 2, 3, 1000# + * #noplabel = one, two, three, any# + */ + +static void +bool_sum_of_products_@noplabel@(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ +#if (@nop@ <= 3) + char *data0 = dataptr[0]; + npy_intp stride0 = strides[0]; +#endif +#if (@nop@ == 2 || @nop@ == 3) + char *data1 = dataptr[1]; + npy_intp stride1 = strides[1]; +#endif +#if (@nop@ == 3) + char *data2 = dataptr[2]; + npy_intp stride2 = strides[2]; +#endif +#if (@nop@ <= 3) + char *data_out = dataptr[@nop@]; + npy_intp stride_out = strides[@nop@]; +#endif + + while (count--) { +#if @nop@ == 1 + *(npy_bool *)data_out = *(npy_bool *)data0 || + *(npy_bool *)data_out; + data0 += stride0; + data_out += stride_out; +#elif @nop@ == 2 + *(npy_bool *)data_out = (*(npy_bool *)data0 && + *(npy_bool *)data1) || + *(npy_bool *)data_out; + data0 += stride0; + data1 += stride1; + data_out += stride_out; +#elif @nop@ == 3 + *(npy_bool *)data_out = (*(npy_bool *)data0 && + *(npy_bool *)data1 && + *(npy_bool *)data2) || + *(npy_bool *)data_out; + data0 += stride0; + data1 += stride1; + data2 += stride2; + data_out += stride_out; +#else + npy_bool temp = *(npy_bool *)dataptr[0]; + int i; + for (i = 1; i < nop; ++i) { + temp = temp && *(npy_bool *)dataptr[i]; + } + *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i]; + for (i = 0; i <= nop; ++i) { + dataptr[i] += strides[i]; + } +#endif + } +} + +static void +bool_sum_of_products_contig_@noplabel@(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ +#if (@nop@ <= 3) + char *data0 = dataptr[0]; +#endif +#if (@nop@ == 2 || @nop@ == 3) + char *data1 = dataptr[1]; +#endif +#if (@nop@ == 3) + char *data2 = dataptr[2]; +#endif +#if (@nop@ <= 3) + char *data_out = dataptr[@nop@]; +#endif + +#if (@nop@ <= 3) +/* This is placed before the main loop to make small counts faster */ +finish_after_unrolled_loop: + switch (count) { +/**begin repeat1 + * #i = 6, 5, 4, 3, 2, 1, 0# + */ + case @i@+1: +# if @nop@ == 1 + ((npy_bool *)data_out)[@i@] = ((npy_bool *)data0)[@i@] || + ((npy_bool *)data_out)[@i@]; +# elif @nop@ == 2 + ((npy_bool *)data_out)[@i@] = + (((npy_bool *)data0)[@i@] && + ((npy_bool *)data1)[@i@]) || + ((npy_bool *)data_out)[@i@]; +# elif @nop@ == 3 + ((npy_bool *)data_out)[@i@] = + (((npy_bool *)data0)[@i@] && + ((npy_bool *)data1)[@i@] && + ((npy_bool *)data2)[@i@]) || + ((npy_bool *)data_out)[@i@]; +# endif +/**end repeat1**/ + case 0: + return; + } +#endif + +/* Unroll the loop by 8 for fixed-size nop */ +#if (@nop@ <= 3) + while (count >= 8) { + count -= 8; +#else + while (count--) { +#endif + +# if @nop@ == 1 +/**begin repeat1 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + *((npy_bool *)data_out + @i@) = (*((npy_bool *)data0 + @i@)) || + (*((npy_bool *)data_out + @i@)); +/**end repeat1**/ + data0 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# elif @nop@ == 2 +/**begin repeat1 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + *((npy_bool *)data_out + @i@) = + ((*((npy_bool *)data0 + @i@)) && + (*((npy_bool *)data1 + @i@))) || + (*((npy_bool *)data_out + @i@)); +/**end repeat1**/ + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# elif @nop@ == 3 +/**begin repeat1 + * #i = 0, 1, 2, 3, 4, 5, 6, 7# + */ + *((npy_bool *)data_out + @i@) = + ((*((npy_bool *)data0 + @i@)) && + (*((npy_bool *)data1 + @i@)) && + (*((npy_bool *)data2 + @i@))) || + (*((npy_bool *)data_out + @i@)); +/**end repeat1**/ + data0 += 8*sizeof(npy_bool); + data1 += 8*sizeof(npy_bool); + data2 += 8*sizeof(npy_bool); + data_out += 8*sizeof(npy_bool); +# else + npy_bool temp = *(npy_bool *)dataptr[0]; + int i; + for (i = 1; i < nop; ++i) { + temp = temp && *(npy_bool *)dataptr[i]; + } + *(npy_bool *)dataptr[nop] = temp || *(npy_bool *)dataptr[i]; + for (i = 0; i <= nop; ++i) { + dataptr[i] += sizeof(npy_bool); + } +# endif + } + + /* If the loop was unrolled, we need to finish it off */ +#if (@nop@ <= 3) + goto finish_after_unrolled_loop; +#endif +} + +static void +bool_sum_of_products_outstride0_@noplabel@(int nop, char **dataptr, + npy_intp const *strides, npy_intp count) +{ + npy_bool accum = 0; + +#if (@nop@ <= 3) + char *data0 = dataptr[0]; + npy_intp stride0 = strides[0]; +#endif +#if (@nop@ == 2 || @nop@ == 3) + char *data1 = dataptr[1]; + npy_intp stride1 = strides[1]; +#endif +#if (@nop@ == 3) + char *data2 = dataptr[2]; + npy_intp stride2 = strides[2]; +#endif + + while (count--) { +#if @nop@ == 1 + accum = *(npy_bool *)data0 || accum; + data0 += stride0; +#elif @nop@ == 2 + accum = (*(npy_bool *)data0 && *(npy_bool *)data1) || accum; + data0 += stride0; + data1 += stride1; +#elif @nop@ == 3 + accum = (*(npy_bool *)data0 && + *(npy_bool *)data1 && + *(npy_bool *)data2) || accum; + data0 += stride0; + data1 += stride1; + data2 += stride2; +#else + npy_bool temp = *(npy_bool *)dataptr[0]; + int i; + for (i = 1; i < nop; ++i) { + temp = temp && *(npy_bool *)dataptr[i]; + } + accum = temp || accum; + for (i = 0; i <= nop; ++i) { + dataptr[i] += strides[i]; + } +#endif + } + +# if @nop@ <= 3 + *((npy_bool *)dataptr[@nop@]) = accum || *((npy_bool *)dataptr[@nop@]); +# else + *((npy_bool *)dataptr[nop]) = accum || *((npy_bool *)dataptr[nop]); +# endif +} + +/**end repeat**/ + +typedef void (*sum_of_products_fn)(int, char **, npy_intp const*, npy_intp); + +/* These tables need to match up with the type enum */ +static sum_of_products_fn +_contig_outstride0_unary_specialization_table[NPY_NTYPES] = { +/**begin repeat + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 0, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# + */ +#if @use@ + &@name@_sum_of_products_contig_outstride0_one, +#else + NULL, +#endif +/**end repeat**/ +}; /* End of _contig_outstride0_unary_specialization_table */ + +static sum_of_products_fn _binary_specialization_table[NPY_NTYPES][5] = { +/**begin repeat + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 0, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 0, 0, 0, + * 0, 0, 0, 0, + * 0, 0, 1# + */ +#if @use@ +{ + &@name@_sum_of_products_stride0_contig_outstride0_two, + &@name@_sum_of_products_stride0_contig_outcontig_two, + &@name@_sum_of_products_contig_stride0_outstride0_two, + &@name@_sum_of_products_contig_stride0_outcontig_two, + &@name@_sum_of_products_contig_contig_outstride0_two, +}, +#else + {NULL, NULL, NULL, NULL, NULL}, +#endif +/**end repeat**/ +}; /* End of _binary_specialization_table */ + +static sum_of_products_fn _outstride0_specialized_table[NPY_NTYPES][4] = { +/**begin repeat + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# + */ +#if @use@ +{ + &@name@_sum_of_products_outstride0_any, + &@name@_sum_of_products_outstride0_one, + &@name@_sum_of_products_outstride0_two, + &@name@_sum_of_products_outstride0_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif +/**end repeat**/ +}; /* End of _outstride0_specialized_table */ + +static sum_of_products_fn _allcontig_specialized_table[NPY_NTYPES][4] = { +/**begin repeat + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# + */ +#if @use@ +{ + &@name@_sum_of_products_contig_any, + &@name@_sum_of_products_contig_one, + &@name@_sum_of_products_contig_two, + &@name@_sum_of_products_contig_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif +/**end repeat**/ +}; /* End of _allcontig_specialized_table */ + +static sum_of_products_fn _unspecialized_table[NPY_NTYPES][4] = { +/**begin repeat + * #name = bool, + * byte, ubyte, + * short, ushort, + * int, uint, + * long, ulong, + * longlong, ulonglong, + * float, double, longdouble, + * cfloat, cdouble, clongdouble, + * object, string, unicode, void, + * datetime, timedelta, half# + * #use = 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, + * 1, 1, 1, + * 1, 1, 1, + * 0, 0, 0, 0, + * 0, 0, 1# + */ +#if @use@ +{ + &@name@_sum_of_products_any, + &@name@_sum_of_products_one, + &@name@_sum_of_products_two, + &@name@_sum_of_products_three +}, +#else + {NULL, NULL, NULL, NULL}, +#endif +/**end repeat**/ +}; /* End of _unnspecialized_table */ + +NPY_NO_EXPORT sum_of_products_fn NPY_CPU_DISPATCH_CURFX(einsum_get_sum_of_products_function) +(int nop, int type_num, npy_intp itemsize,npy_intp const *fixed_strides) +{ + int iop; + + if (type_num >= NPY_NTYPES) { + return NULL; + } + + /* contiguous reduction */ + if (nop == 1 && fixed_strides[0] == itemsize && fixed_strides[1] == 0) { + sum_of_products_fn ret = + _contig_outstride0_unary_specialization_table[type_num]; + if (ret != NULL) { + return ret; + } + } + + /* nop of 2 has more specializations */ + if (nop == 2) { + /* Encode the zero/contiguous strides */ + int code; + code = (fixed_strides[0] == 0) ? 0 : + (fixed_strides[0] == itemsize) ? 2*2*1 : 8; + code += (fixed_strides[1] == 0) ? 0 : + (fixed_strides[1] == itemsize) ? 2*1 : 8; + code += (fixed_strides[2] == 0) ? 0 : + (fixed_strides[2] == itemsize) ? 1 : 8; + if (code >= 2 && code < 7) { + sum_of_products_fn ret = + _binary_specialization_table[type_num][code-2]; + if (ret != NULL) { + return ret; + } + } + } + + /* Inner loop with an output stride of 0 */ + if (fixed_strides[nop] == 0) { + return _outstride0_specialized_table[type_num][nop <= 3 ? nop : 0]; + } + + /* Check for all contiguous */ + for (iop = 0; iop < nop + 1; ++iop) { + if (fixed_strides[iop] != itemsize) { + break; + } + } + + /* Contiguous loop */ + if (iop == nop + 1) { + return _allcontig_specialized_table[type_num][nop <= 3 ? nop : 0]; + } + + /* None of the above specializations caught it, general loops */ + return _unspecialized_table[type_num][nop <= 3 ? nop : 0]; +} diff --git a/numpy/core/src/multiarray/einsum_helpers.h b/numpy/core/src/multiarray/einsum_helpers.h new file mode 100644 index 000000000000..ece0dca805b6 --- /dev/null +++ b/numpy/core/src/multiarray/einsum_helpers.h @@ -0,0 +1,52 @@ +#ifndef _NPY_EINSUM_HELPERS_H_ +#define _NPY_EINSUM_HELPERS_H_ + +#define PY_SSIZE_T_CLEAN +#include "Python.h" +#include "structmember.h" + +#define NPY_NO_DEPRECATED_API NPY_API_VERSION +#define _MULTIARRAYMODULE +#include +#include +#include +#include + +#include + +#include "simd/simd.h" +#include "convert.h" +#include "common.h" +#include "ctors.h" + +#define EINSUM_IS_ALIGNED(x) npy_is_aligned(x, NPY_SIMD_WIDTH) + +/********** PRINTF DEBUG TRACING **************/ +#define NPY_EINSUM_DBG_TRACING 0 + +#if NPY_EINSUM_DBG_TRACING +#define NPY_EINSUM_DBG_PRINT(s) printf("%s", s); +#define NPY_EINSUM_DBG_PRINT1(s, p1) printf(s, p1); +#define NPY_EINSUM_DBG_PRINT2(s, p1, p2) printf(s, p1, p2); +#define NPY_EINSUM_DBG_PRINT3(s, p1, p2, p3) printf(s); +#else +#define NPY_EINSUM_DBG_PRINT(s) +#define NPY_EINSUM_DBG_PRINT1(s, p1) +#define NPY_EINSUM_DBG_PRINT2(s, p1, p2) +#define NPY_EINSUM_DBG_PRINT3(s, p1, p2, p3) +#endif + +typedef void (*sum_of_products_fn)(int, char **, npy_intp const*, npy_intp); + +/** + * forward declarations according to configuration statements in + * the dispatch-able source 'einsum.dispatch.c.src', + * see 'npy_cpu_dispatch.h' for more clarification. + */ +#ifndef NPY_DISABLE_OPTIMIZATION + // auto-generated config header required by NPY_CPU_DISPATCH_DECLARE + #include "einsum.dispatch.h" +#endif +NPY_CPU_DISPATCH_DECLARE(NPY_NO_EXPORT sum_of_products_fn einsum_get_sum_of_products_function, + (int nop, int type_num, npy_intp itemsize, npy_intp const *fixed_strides)) +#endif // _NPY_EINSUM_HELPERS_H_