Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 7be5f11

Browse files
authored
Merge pull request #13520 from r-devulap/issue13512
BUG: exp, log AVX loops do not use steps
2 parents b82869e + 4b4d2ab commit 7be5f11

File tree

3 files changed

+57
-14
lines changed

3 files changed

+57
-14
lines changed

numpy/core/src/umath/loops.c.src

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1621,21 +1621,23 @@ FLOAT_@func@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSE
16211621
NPY_NO_EXPORT NPY_GCC_OPT_3 void
16221622
FLOAT_@func@_@isa@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(data))
16231623
{
1624+
if (!run_unary_@isa@_@func@_FLOAT(args, dimensions, steps)) {
1625+
UNARY_LOOP {
1626+
/*
1627+
* We use the AVX function to compute exp/log for scalar elements as well.
1628+
* This is needed to ensure the output of strided and non-strided
1629+
* cases match. But this worsens the performance of strided arrays.
1630+
* There is plan to fix this in a subsequent patch by using gather
1631+
* instructions for strided arrays in the AVX function.
1632+
*/
16241633
#if defined @CHK@ && defined NPY_HAVE_SSE2_INTRINSICS
1625-
@ISA@_@func@_FLOAT((npy_float*)args[1], (npy_float*)args[0], dimensions[0]);
1634+
@ISA@_@func@_FLOAT((npy_float *)op1, (npy_float *)ip1, 1);
16261635
#else
1627-
/*
1628-
* This is the path it would take if ISA was runtime detected, but not
1629-
* compiled for. It fixes the error on clang6.0 which fails to compile
1630-
* AVX512F version. Not sure if I like this idea, if during runtime it
1631-
* detects AXV512F, it will end up running the scalar version instead
1632-
* of AVX2.
1633-
*/
1634-
UNARY_LOOP {
1635-
const npy_float in1 = *(npy_float *)ip1;
1636-
*(npy_float *)op1 = @scalarf@(in1);
1637-
}
1636+
const npy_float in1 = *(npy_float *)ip1;
1637+
*(npy_float *)op1 = @scalarf@(in1);
16381638
#endif
1639+
}
1640+
}
16391641
}
16401642

16411643
/**end repeat1**/

numpy/core/src/umath/simd.inc.src

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,20 +122,36 @@ abs_ptrdiff(char *a, char *b)
122122

123123
/**begin repeat
124124
* #ISA = AVX2, AVX512F#
125+
* #isa = avx2, avx512f#
126+
* #REGISTER_SIZE = 32, 64#
125127
*/
126128

127129
/* prototypes */
128-
#if defined HAVE_ATTRIBUTE_TARGET_@ISA@_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
129130

130131
/**begin repeat1
131132
* #func = exp, log#
132133
*/
133134

135+
#if defined HAVE_ATTRIBUTE_TARGET_@ISA@_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
134136
static NPY_INLINE void
135137
@ISA@_@func@_FLOAT(npy_float *, npy_float *, const npy_intp n);
138+
#endif
136139

137-
/**end repeat1**/
140+
static NPY_INLINE int
141+
run_unary_@isa@_@func@_FLOAT(char **args, npy_intp *dimensions, npy_intp *steps)
142+
{
143+
#if defined HAVE_ATTRIBUTE_TARGET_@ISA@_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
144+
if (IS_BLOCKABLE_UNARY(sizeof(npy_float), @REGISTER_SIZE@)) {
145+
@ISA@_@func@_FLOAT((npy_float*)args[1], (npy_float*)args[0], dimensions[0]);
146+
return 1;
147+
}
148+
else
149+
return 0;
138150
#endif
151+
return 0;
152+
}
153+
154+
/**end repeat1**/
139155

140156
/**end repeat**/
141157

numpy/core/tests/test_ufunc.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1964,3 +1964,28 @@ def test_ufunc_types(ufunc):
19641964
assert r.dtype == np.dtype(t)
19651965
else:
19661966
assert res.dtype == np.dtype(out)
1967+
1968+
@pytest.mark.parametrize('ufunc', [getattr(np, x) for x in dir(np)
1969+
if isinstance(getattr(np, x), np.ufunc)])
1970+
def test_ufunc_noncontiguous(ufunc):
1971+
'''
1972+
Check that contiguous and non-contiguous calls to ufuncs
1973+
have the same results for values in range(9)
1974+
'''
1975+
for typ in ufunc.types:
1976+
# types is a list of strings like ii->i
1977+
if any(set('O?mM') & set(typ)):
1978+
# bool, object, datetime are too irregular for this simple test
1979+
continue
1980+
inp, out = typ.split('->')
1981+
args_c = [np.empty(6, t) for t in inp]
1982+
args_n = [np.empty(18, t)[::3] for t in inp]
1983+
for a in args_c:
1984+
a.flat = range(1,7)
1985+
for a in args_n:
1986+
a.flat = range(1,7)
1987+
with warnings.catch_warnings(record=True):
1988+
warnings.filterwarnings("always")
1989+
res_c = ufunc(*args_c)
1990+
res_n = ufunc(*args_n)
1991+
assert_equal(res_c, res_n)

0 commit comments

Comments
 (0)