From 7aee0fa0805c55d0aed2d26dec8e1153757c69ff Mon Sep 17 00:00:00 2001 From: Sayed Adel Date: Sun, 9 May 2021 21:33:37 +0200 Subject: [PATCH] MAINT, SIMD: Hardened the AVX compile-time tests To avoid optimizing it out by the compiler so we make sure that the assembler is getting involved. --- numpy/distutils/checks/cpu_avx.c | 4 ++-- numpy/distutils/checks/cpu_avx2.c | 4 ++-- numpy/distutils/checks/cpu_avx512_clx.c | 5 +++-- numpy/distutils/checks/cpu_avx512_cnl.c | 7 ++++--- numpy/distutils/checks/cpu_avx512_icl.c | 5 +++-- numpy/distutils/checks/cpu_avx512_knl.c | 5 +++-- numpy/distutils/checks/cpu_avx512_knm.c | 6 +++--- numpy/distutils/checks/cpu_avx512_skx.c | 5 +++-- numpy/distutils/checks/cpu_avx512cd.c | 4 ++-- numpy/distutils/checks/cpu_avx512f.c | 4 ++-- numpy/distutils/checks/cpu_f16c.c | 6 +++--- numpy/distutils/checks/cpu_fma3.c | 5 +++-- numpy/distutils/checks/cpu_fma4.c | 5 +++-- 13 files changed, 36 insertions(+), 29 deletions(-) diff --git a/numpy/distutils/checks/cpu_avx.c b/numpy/distutils/checks/cpu_avx.c index 737c0d2e9492..cee4f36ab3f4 100644 --- a/numpy/distutils/checks/cpu_avx.c +++ b/numpy/distutils/checks/cpu_avx.c @@ -1,7 +1,7 @@ #include -int main(void) +int main(int argc, char **argv) { - __m256 a = _mm256_add_ps(_mm256_setzero_ps(), _mm256_setzero_ps()); + __m256 a = _mm256_add_ps(_mm256_loadu_ps((const float*)argv[argc-1]), _mm256_loadu_ps((const float*)argv[1])); return (int)_mm_cvtss_f32(_mm256_castps256_ps128(a)); } diff --git a/numpy/distutils/checks/cpu_avx2.c b/numpy/distutils/checks/cpu_avx2.c index dfb11fd79967..15b6c919b089 100644 --- a/numpy/distutils/checks/cpu_avx2.c +++ b/numpy/distutils/checks/cpu_avx2.c @@ -1,7 +1,7 @@ #include -int main(void) +int main(int argc, char **argv) { - __m256i a = _mm256_abs_epi16(_mm256_setzero_si256()); + __m256i a = _mm256_abs_epi16(_mm256_loadu_si256((const __m256i*)argv[argc-1])); return _mm_cvtsi128_si32(_mm256_castsi256_si128(a)); } diff --git a/numpy/distutils/checks/cpu_avx512_clx.c b/numpy/distutils/checks/cpu_avx512_clx.c index 71dad83a79f0..4baa8fea0475 100644 --- a/numpy/distutils/checks/cpu_avx512_clx.c +++ b/numpy/distutils/checks/cpu_avx512_clx.c @@ -1,8 +1,9 @@ #include -int main(void) +int main(int argc, char **argv) { /* VNNI */ - __m512i a = _mm512_dpbusd_epi32(_mm512_setzero_si512(), _mm512_setzero_si512(), _mm512_setzero_si512()); + __m512i a = _mm512_loadu_si512((const __m512i*)argv[argc-1]); + a = _mm512_dpbusd_epi32(a, _mm512_setzero_si512(), a); return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } diff --git a/numpy/distutils/checks/cpu_avx512_cnl.c b/numpy/distutils/checks/cpu_avx512_cnl.c index dfab4436d07e..f2ff3725ea93 100644 --- a/numpy/distutils/checks/cpu_avx512_cnl.c +++ b/numpy/distutils/checks/cpu_avx512_cnl.c @@ -1,10 +1,11 @@ #include -int main(void) +int main(int argc, char **argv) { + __m512i a = _mm512_loadu_si512((const __m512i*)argv[argc-1]); /* IFMA */ - __m512i a = _mm512_madd52hi_epu64(_mm512_setzero_si512(), _mm512_setzero_si512(), _mm512_setzero_si512()); + a = _mm512_madd52hi_epu64(a, a, _mm512_setzero_si512()); /* VMBI */ - a = _mm512_permutex2var_epi8(a, _mm512_setzero_si512(), _mm512_setzero_si512()); + a = _mm512_permutex2var_epi8(a, _mm512_setzero_si512(), a); return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } diff --git a/numpy/distutils/checks/cpu_avx512_icl.c b/numpy/distutils/checks/cpu_avx512_icl.c index cf2706b3b9ff..085b947e05bf 100644 --- a/numpy/distutils/checks/cpu_avx512_icl.c +++ b/numpy/distutils/checks/cpu_avx512_icl.c @@ -1,9 +1,10 @@ #include -int main(void) +int main(int argc, char **argv) { + __m512i a = _mm512_loadu_si512((const __m512i*)argv[argc-1]); /* VBMI2 */ - __m512i a = _mm512_shrdv_epi64(_mm512_setzero_si512(), _mm512_setzero_si512(), _mm512_setzero_si512()); + a = _mm512_shrdv_epi64(a, a, _mm512_setzero_si512()); /* BITLAG */ a = _mm512_popcnt_epi8(a); /* VPOPCNTDQ */ diff --git a/numpy/distutils/checks/cpu_avx512_knl.c b/numpy/distutils/checks/cpu_avx512_knl.c index 0699f37a6346..10ba52bcc5a7 100644 --- a/numpy/distutils/checks/cpu_avx512_knl.c +++ b/numpy/distutils/checks/cpu_avx512_knl.c @@ -1,10 +1,11 @@ #include -int main(void) +int main(int argc, char **argv) { int base[128]; + __m512d ad = _mm512_loadu_pd((const __m512d*)argv[argc-1]); /* ER */ - __m512i a = _mm512_castpd_si512(_mm512_exp2a23_pd(_mm512_setzero_pd())); + __m512i a = _mm512_castpd_si512(_mm512_exp2a23_pd(ad)); /* PF */ _mm512_mask_prefetch_i64scatter_pd(base, _mm512_cmpeq_epi64_mask(a, a), a, 1, _MM_HINT_T1); return base[0]; diff --git a/numpy/distutils/checks/cpu_avx512_knm.c b/numpy/distutils/checks/cpu_avx512_knm.c index db61b4bfa674..d03b0fe8beb3 100644 --- a/numpy/distutils/checks/cpu_avx512_knm.c +++ b/numpy/distutils/checks/cpu_avx512_knm.c @@ -1,9 +1,9 @@ #include -int main(void) +int main(int argc, char **argv) { - __m512i a = _mm512_setzero_si512(); - __m512 b = _mm512_setzero_ps(); + __m512i a = _mm512_loadu_si512((const __m512i*)argv[argc-1]); + __m512 b = _mm512_loadu_ps((const __m512*)argv[argc-2]); /* 4FMAPS */ b = _mm512_4fmadd_ps(b, b, b, b, b, NULL); diff --git a/numpy/distutils/checks/cpu_avx512_skx.c b/numpy/distutils/checks/cpu_avx512_skx.c index 1d5e15b5e5b8..04761876295f 100644 --- a/numpy/distutils/checks/cpu_avx512_skx.c +++ b/numpy/distutils/checks/cpu_avx512_skx.c @@ -1,9 +1,10 @@ #include -int main(void) +int main(int argc, char **argv) { + __m512i aa = _mm512_abs_epi32(_mm512_loadu_si512((const __m512i*)argv[argc-1])); /* VL */ - __m256i a = _mm256_abs_epi64(_mm256_setzero_si256()); + __m256i a = _mm256_abs_epi64(_mm512_extracti64x4_epi64(aa, 1)); /* DQ */ __m512i b = _mm512_broadcast_i32x8(a); /* BW */ diff --git a/numpy/distutils/checks/cpu_avx512cd.c b/numpy/distutils/checks/cpu_avx512cd.c index 61bef6b8270e..52f4c7f8be0d 100644 --- a/numpy/distutils/checks/cpu_avx512cd.c +++ b/numpy/distutils/checks/cpu_avx512cd.c @@ -1,7 +1,7 @@ #include -int main(void) +int main(int argc, char **argv) { - __m512i a = _mm512_lzcnt_epi32(_mm512_setzero_si512()); + __m512i a = _mm512_lzcnt_epi32(_mm512_loadu_si512((const __m512i*)argv[argc-1])); return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } diff --git a/numpy/distutils/checks/cpu_avx512f.c b/numpy/distutils/checks/cpu_avx512f.c index f60cc09dd094..22d861471ced 100644 --- a/numpy/distutils/checks/cpu_avx512f.c +++ b/numpy/distutils/checks/cpu_avx512f.c @@ -1,7 +1,7 @@ #include -int main(void) +int main(int argc, char **argv) { - __m512i a = _mm512_abs_epi32(_mm512_setzero_si512()); + __m512i a = _mm512_abs_epi32(_mm512_loadu_si512((const __m512i*)argv[argc-1])); return _mm_cvtsi128_si32(_mm512_castsi512_si128(a)); } diff --git a/numpy/distutils/checks/cpu_f16c.c b/numpy/distutils/checks/cpu_f16c.c index a5a343e2dd59..678c582e410c 100644 --- a/numpy/distutils/checks/cpu_f16c.c +++ b/numpy/distutils/checks/cpu_f16c.c @@ -1,9 +1,9 @@ #include #include -int main(void) +int main(int argc, char **argv) { - __m128 a = _mm_cvtph_ps(_mm_setzero_si128()); - __m256 a8 = _mm256_cvtph_ps(_mm_setzero_si128()); + __m128 a = _mm_cvtph_ps(_mm_loadu_si128((const __m128i*)argv[argc-1])); + __m256 a8 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)argv[argc-2])); return (int)(_mm_cvtss_f32(a) + _mm_cvtss_f32(_mm256_castps256_ps128(a8))); } diff --git a/numpy/distutils/checks/cpu_fma3.c b/numpy/distutils/checks/cpu_fma3.c index cf34c6cb1572..2f879c3b357f 100644 --- a/numpy/distutils/checks/cpu_fma3.c +++ b/numpy/distutils/checks/cpu_fma3.c @@ -1,8 +1,9 @@ #include #include -int main(void) +int main(int argc, char **argv) { - __m256 a = _mm256_fmadd_ps(_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()); + __m256 a = _mm256_loadu_ps((const float*)argv[argc-1]); + a = _mm256_fmadd_ps(a, a, a); return (int)_mm_cvtss_f32(_mm256_castps256_ps128(a)); } diff --git a/numpy/distutils/checks/cpu_fma4.c b/numpy/distutils/checks/cpu_fma4.c index 1ad717033e24..0ff17a483385 100644 --- a/numpy/distutils/checks/cpu_fma4.c +++ b/numpy/distutils/checks/cpu_fma4.c @@ -5,8 +5,9 @@ #include #endif -int main(void) +int main(int argc, char **argv) { - __m256 a = _mm256_macc_ps(_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()); + __m256 a = _mm256_loadu_ps((const float*)argv[argc-1]); + a = _mm256_macc_ps(a, a, a); return (int)_mm_cvtss_f32(_mm256_castps256_ps128(a)); }