Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit ad8ebeb

Browse files
committed
BUG: fix bugs in the handling of special floating-point values
(1) Fix the invalid exception raised by the new AVX version of exp. (2) Add special handling of +/-np.nan and +/-np.inf.
1 parent bf1e9b7 commit ad8ebeb

File tree

1 file changed

+26
-14
lines changed

1 file changed

+26
-14
lines changed

numpy/core/src/umath/simd.inc.src

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ abs_ptrdiff(char *a, char *b)
4040
return (a > b) ? (a - b) : (b - a);
4141
}
4242

43-
4443
/*
4544
* stride is equal to element size and input and destination are equal or
4645
* don't overlap within one register. The check of the steps against
@@ -133,7 +132,7 @@ abs_ptrdiff(char *a, char *b)
133132
*/
134133

135134
static void
136-
@ISA@_@func@_FLOAT(npy_float *, npy_float *, const npy_int n);
135+
@ISA@_@func@_FLOAT(npy_float *, npy_float *, const npy_intp n);
137136

138137
/**end repeat1**/
139138
#endif
@@ -1261,7 +1260,7 @@ static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ @vtype@
12611260
* #BYTES = 32, 64#
12621261
* #mask = __m256, __mmask16#
12631262
* #vsub = , _mask#
1264-
* #and_masks =_mm256_and_ps, _mm512_kand#
1263+
* #or_masks =_mm256_or_ps, _mm512_kor#
12651264
* #fmadd = avx2_fmadd,_mm512_fmadd_ps#
12661265
* #mask_to_int = _mm256_movemask_ps, #
12671266
* #full_mask= 0xFF, 0xFFFF#
@@ -1287,7 +1286,7 @@ static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ @vtype@
12871286

12881287
#if defined HAVE_ATTRIBUTE_TARGET_@ISA@_WITH_INTRINSICS
12891288
static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
1290-
@ISA@_exp_FLOAT(npy_float * op, npy_float * ip, const npy_int array_size)
1289+
@ISA@_exp_FLOAT(npy_float * op, npy_float * ip, const npy_intp array_size)
12911290
{
12921291
const npy_int num_lanes = @BYTES@/sizeof(npy_float);
12931292
npy_float xmax = 88.72283935546875f;
@@ -1312,21 +1311,24 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
13121311
@vtype@ poly, num_poly, denom_poly, quadrant;
13131312
@vtype@i exponent;
13141313

1315-
@mask@ xmax_mask, xmin_mask;
1314+
@mask@ xmax_mask, xmin_mask, nan_mask, inf_mask;
13161315
@mask@ load_mask = @isa@_get_full_load_mask();
1317-
npy_int num_remaining_elements = array_size;
1316+
npy_intp num_remaining_elements = array_size;
1317+
npy_intp set_overflow = 0;
13181318

13191319
while (num_remaining_elements > 0) {
13201320

13211321
if (num_remaining_elements < num_lanes)
13221322
load_mask = @isa@_get_partial_load_mask(num_remaining_elements,
13231323
num_lanes);
13241324
@vtype@ x = @isa@_masked_load(load_mask, ip);
1325+
13251326
xmax_mask = _mm@vsize@_cmp_ps@vsub@(x, _mm@vsize@_set1_ps(xmax), _CMP_GE_OQ);
13261327
xmin_mask = _mm@vsize@_cmp_ps@vsub@(x, _mm@vsize@_set1_ps(xmin), _CMP_LE_OQ);
1327-
1328-
x = @isa@_set_masked_lanes(x, zeros_f,
1329-
@and_masks@(xmax_mask,xmin_mask));
1328+
nan_mask = _mm@vsize@_cmp_ps@vsub@(x, x, _CMP_NEQ_UQ);
1329+
inf_mask = _mm@vsize@_cmp_ps@vsub@(x, inf, _CMP_EQ_OQ);
1330+
x = @isa@_set_masked_lanes(x, zeros_f, @or_masks@(
1331+
@or_masks@(nan_mask, xmin_mask), xmax_mask));
13301332

13311333
quadrant = _mm@vsize@_mul_ps(x, log2e);
13321334

@@ -1335,8 +1337,7 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
13351337
quadrant = _mm@vsize@_sub_ps(quadrant, cvt_magic);
13361338

13371339
/* Cody-Waite's range reduction algorithm */
1338-
x = @isa@_range_reduction(x, quadrant,
1339-
codyw_c1, codyw_c2, zeros_f);
1340+
x = @isa@_range_reduction(x, quadrant, codyw_c1, codyw_c2, zeros_f);
13401341

13411342
num_poly = @fmadd@(exp_p5, x, exp_p4);
13421343
num_poly = @fmadd@(num_poly, x, exp_p3);
@@ -1357,16 +1358,27 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
13571358
_mm@vsize@_add_epi32(
13581359
_mm@vsize@_castps_si@vsize@(poly), exponent));
13591360

1360-
/* elem > xmax; return inf, elem < xmin; return 0.0f */
1361+
/*
1362+
* elem >= xmax; return inf
1363+
* elem <= xmin; return 0.0f
1364+
* elem = +/- nan, return nan
1365+
*/
1366+
poly = @isa@_set_masked_lanes(poly, _mm@vsize@_set1_ps(NPY_NANF), nan_mask);
13611367
poly = @isa@_set_masked_lanes(poly, inf, xmax_mask);
13621368
poly = @isa@_set_masked_lanes(poly, zeros_f, xmin_mask);
13631369

13641370
@masked_store@(op, @cvtps_epi32@(load_mask), poly);
13651371

1372+
set_overflow += _mm_popcnt_u32(
1373+
@mask_to_int@(xmax_mask) ^ @mask_to_int@(inf_mask));
1374+
13661375
ip += num_lanes;
13671376
op += num_lanes;
13681377
num_remaining_elements -= num_lanes;
13691378
}
1379+
1380+
if (set_overflow)
1381+
_mm_setcsr(_mm_getcsr() | (0x1 << 3));
13701382
}
13711383

13721384
/*
@@ -1384,7 +1396,7 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
13841396
*/
13851397

13861398
static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
1387-
@ISA@_log_FLOAT(npy_float * op, npy_float * ip, const npy_int array_size)
1399+
@ISA@_log_FLOAT(npy_float * op, npy_float * ip, const npy_intp array_size)
13881400
{
13891401
const npy_int num_lanes = @BYTES@/sizeof(npy_float);
13901402

@@ -1410,7 +1422,7 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
14101422

14111423
@mask@ inf_nan_mask, sqrt2_mask, zero_mask, negx_mask;
14121424
@mask@ load_mask = @isa@_get_full_load_mask();
1413-
npy_int num_remaining_elements = array_size;
1425+
npy_intp num_remaining_elements = array_size;
14141426

14151427
while (num_remaining_elements > 0) {
14161428

0 commit comments

Comments
 (0)