Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f15a116

Browse files
BUG: Address interaction between SME and FPSR (#29223)
* BUG: Address interaction between SME and FPSR

  This is intended to resolve #28687.

  The root cause is an interaction between Arm Scalable Matrix Extension (SME) and the floating point status register (FPSR). As noted in Arm docs for FPSR, "On entry to or exit from Streaming SVE mode, FPSR.{IOC, DZC, OFC, UFC, IXC, IDC, QC} are set to 1 and the remaining bits are set to 0". This means that floating point status flags are all raised when SME is used, regardless of values or operations performed.

  These are manifesting now because Apple Silicon M4 supports SME and macOS 15.4 enables SME codepaths for Accelerate BLAS / LAPACK. However, SME / FPSR behavior is not specific to Apple Silicon M4 and will occur on non-Apple chips using SME as well.

  Changes add compile and runtime checks to determine whether BLAS / LAPACK might use SME (macOS / Accelerate only at the moment). If so, special handling of floating-point error (FPE) is added, which includes:
  - clearing FPE after some BLAS calls
  - short-circuiting FPE read after some BLAS calls

  All tests pass. Performance is similar.

  Another approach would have been to wrap all BLAS / LAPACK calls with save / restore FPE. However, it added a lot of overhead for the inner loops that utilize BLAS / LAPACK. Some benchmarks were 8x slower.

* add blas_supports_fpe and ifdef check

  Address the linker & linter failures.
1 parent d52b36e commit f15a116

File tree

8 files changed

+217
-8
lines changed

8 files changed

+217
-8
lines changed

numpy/_core/meson.build

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,7 @@ src_multiarray_umath_common = [
11171117
]
11181118
if have_blas
11191119
src_multiarray_umath_common += [
1120+
'src/common/blas_utils.c',
11201121
'src/common/cblasfuncs.c',
11211122
'src/common/python_xerbla.c',
11221123
]

numpy/_core/src/common/blas_utils.c

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
#include <stdbool.h>
2+
#include <stdio.h>
3+
#include <stdlib.h>
4+
5+
#ifdef __APPLE__
6+
#include <sys/sysctl.h>
7+
#endif
8+
9+
#include "numpy/numpyconfig.h" // NPY_VISIBILITY_HIDDEN
10+
#include "numpy/npy_math.h" // npy_get_floatstatus_barrier
11+
#include "blas_utils.h"
12+
13+
#if NPY_BLAS_CHECK_FPE_SUPPORT
14+
15+
/* Return whether we're running on macOS 15.4 or later.
 *
 * The OS version is read from the "kern.osproductversion" sysctl and
 * parsed as "<major>.<minor>[...]". The comparison is lexicographic on
 * (major, minor), so e.g. 16.0 counts as "later" than 15.4.
 *
 * Returns false when not compiled for an Apple target, or when the
 * version cannot be obtained or parsed (sysctl failure, allocation
 * failure, unexpected string format).
 */
static inline bool
is_macOS_version_15_4_or_later(void){
#if !defined(__APPLE__)
    return false;
#else
    char *osProductVersion = NULL;
    size_t size = 0;
    int major = 0, minor = 0;
    bool ret = false;

    // Query how large OS version string should be
    if(-1 == sysctlbyname("kern.osproductversion", NULL, &size, NULL, 0)){
        goto cleanup;
    }

    // One extra byte so the string can always be NUL-terminated below
    osProductVersion = malloc(size + 1);
    if(osProductVersion == NULL){
        goto cleanup;
    }

    // Get the OS version string
    if(-1 == sysctlbyname("kern.osproductversion", osProductVersion, &size, NULL, 0)){
        goto cleanup;
    }

    osProductVersion[size] = '\0';

    // Parse the version string
    if(2 > sscanf(osProductVersion, "%d.%d", &major, &minor)) {
        goto cleanup;
    }

    // Compare (major, minor) against (15, 4); a newer major release
    // qualifies regardless of its minor number.
    if(major > 15 || (major == 15 && minor >= 4)){
        ret = true;
    }

cleanup:
    // free(NULL) is a no-op, so no guard is needed
    free(osProductVersion);

    return ret;
#endif
}
58+
59+
/* Runtime check: may the linked BLAS / LAPACK execute ARM SME code?
 *
 * ARM Scalable Matrix Extension (SME) raises all floating-point error
 * flags when it's used, regardless of values or operations, so FPE state
 * is lost whenever SME runs and callers need special handling.
 *
 * NumPy does not use SME directly, but BLAS / LAPACK libraries can.
 * Only Apple's Accelerate is checked here; every other configuration
 * reports false.
 */
static inline bool
BLAS_can_use_ARM_SME(void)
{
#if defined(__APPLE__) && defined(__aarch64__) && defined(ACCELERATE_NEW_LAPACK)
    // Accelerate may take SME code paths for BLAS / LAPACK when both hold:
    //   - the OS is macOS 15.4+
    //   - the hardware is Apple silicon M4+ (reports FEAT_SME)
    if(is_macOS_version_15_4_or_later()){
        int sme_flag = 0;
        size_t flag_len = sizeof(sme_flag);

        // Ask the kernel whether the CPU implements SME; a failed query
        // is treated the same as "not supported".
        if(-1 != sysctlbyname("hw.optional.arm.FEAT_SME", &sme_flag, &flag_len, NULL, 0)
                && sme_flag != 0){
            return true;
        }
    }
#endif

    // By default, assume SME is not used
    return false;
}
95+
96+
/* Static variable to cache runtime check of BLAS FPE support.
97+
*/
98+
static bool blas_supports_fpe = true;
99+
100+
#endif // NPY_BLAS_CHECK_FPE_SUPPORT
101+
102+
103+
/* Report whether floating-point error flags can be trusted after BLAS calls.
 *
 * When the runtime check is compiled out, BLAS is always assumed to
 * support FPE; otherwise the cached value is returned.
 */
NPY_VISIBILITY_HIDDEN bool
npy_blas_supports_fpe(void)
{
    bool supported = true;
#if NPY_BLAS_CHECK_FPE_SUPPORT
    supported = blas_supports_fpe;
#endif
    return supported;
}
112+
113+
/* Initialize the BLAS environment.
 *
 * When the runtime check is compiled in, caches whether BLAS supports
 * FPE: it does not if BLAS may use ARM SME. Otherwise this is a no-op.
 */
NPY_VISIBILITY_HIDDEN void
npy_blas_init(void)
{
#if NPY_BLAS_CHECK_FPE_SUPPORT
    const bool sme_possible = BLAS_can_use_ARM_SME();
    blas_supports_fpe = !sme_possible;
#endif
}
120+
121+
/* Fetch the floating-point status to report after a BLAS call.
 *
 * If BLAS does not support FPE, the flags carry no information, so
 * "no flags set" (0) is returned without reading them. Otherwise the
 * result of npy_get_floatstatus_barrier() is returned.
 */
NPY_VISIBILITY_HIDDEN int
npy_get_floatstatus_after_blas(void)
{
#if NPY_BLAS_CHECK_FPE_SUPPORT
    // Rather than clearing the flags and then reading them back,
    // short-circuit and report a clean status.
    if(!blas_supports_fpe){
        return 0;
    }
#endif
    return npy_get_floatstatus_barrier((char *)NULL);
}

numpy/_core/src/common/blas_utils.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
#ifndef NUMPY_CORE_SRC_COMMON_BLAS_UTILS_H_
#define NUMPY_CORE_SRC_COMMON_BLAS_UTILS_H_

#include <stdbool.h>

#include "numpy/numpyconfig.h" // for NPY_VISIBILITY_HIDDEN

/* NPY_BLAS_CHECK_FPE_SUPPORT controls whether we need a runtime check
 * for floating-point error (FPE) support in BLAS.
 *
 * It is 1 only for Apple arm64 builds against the new Accelerate LAPACK
 * interface, and 0 everywhere else.
 */
#if defined(__APPLE__) && defined(__aarch64__) && defined(ACCELERATE_NEW_LAPACK)
#define NPY_BLAS_CHECK_FPE_SUPPORT 1
#else
#define NPY_BLAS_CHECK_FPE_SUPPORT 0
#endif

/* Initialize BLAS environment, if needed
 */
NPY_VISIBILITY_HIDDEN void
npy_blas_init(void);

/* Runtime check if BLAS supports floating-point errors.
 * true - BLAS supports FPE and one can rely on them to indicate errors
 * false - BLAS does not support FPE. Special handling needed for FPE state
 */
NPY_VISIBILITY_HIDDEN bool
npy_blas_supports_fpe(void);

/* If BLAS supports FPE, exactly the same as npy_get_floatstatus_barrier().
 * Otherwise, we can't rely on FPE state and need special handling.
 */
NPY_VISIBILITY_HIDDEN int
npy_get_floatstatus_after_blas(void);

#endif  /* NUMPY_CORE_SRC_COMMON_BLAS_UTILS_H_ */

numpy/_core/src/common/cblasfuncs.c

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "numpy/arrayobject.h"
1313
#include "numpy/npy_math.h"
1414
#include "numpy/ufuncobject.h"
15+
#include "blas_utils.h"
1516
#include "npy_cblas.h"
1617
#include "arraytypes.h"
1718
#include "common.h"
@@ -693,7 +694,7 @@ cblas_matrixproduct(int typenum, PyArrayObject *ap1, PyArrayObject *ap2,
693694
NPY_END_ALLOW_THREADS;
694695
}
695696

696-
int fpes = npy_get_floatstatus_barrier((char *) result);
697+
int fpes = npy_get_floatstatus_after_blas();
697698
if (fpes && PyUFunc_GiveFloatingpointErrors("dot", fpes) < 0) {
698699
goto fail;
699700
}

numpy/_core/src/multiarray/multiarraymodule.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ NPY_NO_EXPORT int NPY_NUMUSERTYPES = 0;
4343
#include "arraytypes.h"
4444
#include "arrayobject.h"
4545
#include "array_converter.h"
46+
#include "blas_utils.h"
4647
#include "hashdescr.h"
4748
#include "descriptor.h"
4849
#include "dragon4.h"
@@ -4781,6 +4782,10 @@ _multiarray_umath_exec(PyObject *m) {
47814782
return -1;
47824783
}
47834784

4785+
#if NPY_BLAS_CHECK_FPE_SUPPORT
4786+
npy_blas_init();
4787+
#endif
4788+
47844789
#if defined(MS_WIN64) && defined(__GNUC__)
47854790
PyErr_WarnEx(PyExc_Warning,
47864791
"Numpy built with MINGW-W64 on Windows 64 bits is experimental, " \

numpy/_core/src/umath/matmul.c.src

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717

1818

19+
#include "blas_utils.h"
1920
#include "npy_cblas.h"
2021
#include "arraytypes.h" /* For TYPE_dot functions */
2122

@@ -122,7 +123,7 @@ static inline void
122123
}
123124
}
124125

125-
NPY_NO_EXPORT void
126+
static void
126127
@name@_gemv(void *ip1, npy_intp is1_m, npy_intp is1_n,
127128
void *ip2, npy_intp is2_n,
128129
void *op, npy_intp op_m,
@@ -158,7 +159,7 @@ NPY_NO_EXPORT void
158159
is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
159160
}
160161

161-
NPY_NO_EXPORT void
162+
static void
162163
@name@_matmul_matrixmatrix(void *ip1, npy_intp is1_m, npy_intp is1_n,
163164
void *ip2, npy_intp is2_n, npy_intp is2_p,
164165
void *op, npy_intp os_m, npy_intp os_p,
@@ -262,7 +263,7 @@ NPY_NO_EXPORT void
262263
* #IS_HALF = 0, 0, 0, 1, 0*13#
263264
*/
264265

265-
NPY_NO_EXPORT void
266+
static void
266267
@TYPE@_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
267268
void *_ip2, npy_intp is2_n, npy_intp is2_p,
268269
void *_op, npy_intp os_m, npy_intp os_p,
@@ -320,7 +321,7 @@ NPY_NO_EXPORT void
320321
}
321322

322323
/**end repeat**/
323-
NPY_NO_EXPORT void
324+
static void
324325
BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
325326
void *_ip2, npy_intp is2_n, npy_intp is2_p,
326327
void *_op, npy_intp os_m, npy_intp os_p,
@@ -359,7 +360,7 @@ BOOL_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
359360
}
360361
}
361362

362-
NPY_NO_EXPORT void
363+
static void
363364
OBJECT_matmul_inner_noblas(void *_ip1, npy_intp is1_m, npy_intp is1_n,
364365
void *_ip2, npy_intp is2_n, npy_intp is2_p,
365366
void *_op, npy_intp os_m, npy_intp os_p,
@@ -631,6 +632,11 @@ NPY_NO_EXPORT void
631632
#endif
632633
}
633634
#if @USEBLAS@ && defined(HAVE_CBLAS)
635+
#if NPY_BLAS_CHECK_FPE_SUPPORT
636+
if (!npy_blas_supports_fpe()) {
637+
npy_clear_floatstatus_barrier((char*)args);
638+
}
639+
#endif
634640
if (allocate_buffer) free(tmp_ip12op);
635641
#endif
636642
}
@@ -655,7 +661,7 @@ NPY_NO_EXPORT void
655661
* #prefix = c, z, 0#
656662
* #USE_BLAS = 1, 1, 0#
657663
*/
658-
NPY_NO_EXPORT void
664+
static void
659665
@name@_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
660666
char *op, npy_intp n, void *NPY_UNUSED(ignore))
661667
{
@@ -751,6 +757,7 @@ OBJECT_dotc(char *ip1, npy_intp is1, char *ip2, npy_intp is2, char *op, npy_intp
751757
* CFLOAT, CDOUBLE, CLONGDOUBLE, OBJECT#
752758
* #DOT = dot*15, dotc*4#
753759
* #CHECK_PYERR = 0*18, 1#
760+
* #CHECK_BLAS = 1*2, 0*13, 1*2, 0*2#
754761
*/
755762
NPY_NO_EXPORT void
756763
@TYPE@_vecdot(char **args, npy_intp const *dimensions, npy_intp const *steps,
@@ -774,6 +781,11 @@ NPY_NO_EXPORT void
774781
}
775782
#endif
776783
}
784+
#if @CHECK_BLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
785+
if (!npy_blas_supports_fpe()) {
786+
npy_clear_floatstatus_barrier((char*)args);
787+
}
788+
#endif
777789
}
778790
/**end repeat**/
779791

@@ -789,7 +801,7 @@ NPY_NO_EXPORT void
789801
* #step1 = &oneF, &oneD#
790802
* #step0 = &zeroF, &zeroD#
791803
*/
792-
NPY_NO_EXPORT void
804+
static void
793805
@name@_vecmat_via_gemm(void *ip1, npy_intp is1_n,
794806
void *ip2, npy_intp is2_n, npy_intp is2_m,
795807
void *op, npy_intp os_m,
@@ -880,6 +892,11 @@ NPY_NO_EXPORT void
880892
#endif
881893
}
882894
}
895+
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
896+
if (!npy_blas_supports_fpe()) {
897+
npy_clear_floatstatus_barrier((char*)args);
898+
}
899+
#endif
883900
}
884901
/**end repeat**/
885902

@@ -945,5 +962,10 @@ NPY_NO_EXPORT void
945962
#endif
946963
}
947964
}
965+
#if @USEBLAS@ && NPY_BLAS_CHECK_FPE_SUPPORT
966+
if (!npy_blas_supports_fpe()) {
967+
npy_clear_floatstatus_barrier((char*)args);
968+
}
969+
#endif
948970
}
949971
/**end repeat**/

numpy/_core/tests/test_multiarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from numpy.exceptions import AxisError, ComplexWarning
3232
from numpy.lib.recfunctions import repack_fields
3333
from numpy.testing import (
34+
BLAS_SUPPORTS_FPE,
3435
HAS_REFCOUNT,
3536
IS_64BIT,
3637
IS_PYPY,
@@ -3363,6 +3364,11 @@ def test_dot(self):
33633364
@pytest.mark.parametrize("dtype", [np.half, np.double, np.longdouble])
33643365
@pytest.mark.skipif(IS_WASM, reason="no wasm fp exception support")
33653366
def test_dot_errstate(self, dtype):
3367+
# Some dtypes use BLAS for 'dot' operation and
3368+
# not all BLAS support floating-point errors.
3369+
if not BLAS_SUPPORTS_FPE and dtype == np.double:
3370+
pytest.skip("BLAS does not support FPE")
3371+
33663372
a = np.array([1, 1], dtype=dtype)
33673373
b = np.array([-np.inf, np.inf], dtype=dtype)
33683374

numpy/testing/_private/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
'assert_no_gc_cycles', 'break_cycles', 'HAS_LAPACK64', 'IS_PYSTON',
4343
'IS_MUSL', 'check_support_sve', 'NOGIL_BUILD',
4444
'IS_EDITABLE', 'IS_INSTALLED', 'NUMPY_ROOT', 'run_threaded', 'IS_64BIT',
45+
'BLAS_SUPPORTS_FPE',
4546
]
4647

4748

@@ -89,6 +90,15 @@ class KnownFailureException(Exception):
8990
IS_PYPY = sys.implementation.name == 'pypy'
9091
IS_PYSTON = hasattr(sys, "pyston_version_info")
9192
HAS_REFCOUNT = getattr(sys, 'getrefcount', None) is not None and not IS_PYSTON
93+
# Whether floating-point error flags can be relied on after BLAS calls.
# Only Accelerate on Apple silicon is flagged: that is the one
# configuration where BLAS may use ARM SME, which raises every FP status
# flag regardless of the values computed. Both platform conditions must
# hold, matching the C-side `__APPLE__ && __aarch64__` gate.
BLAS_SUPPORTS_FPE = True
if platform.system() == 'Darwin' and platform.machine() == 'arm64':
    try:
        blas = np.__config__.CONFIG['Build Dependencies']['blas']
        if blas['name'] == 'accelerate':
            BLAS_SUPPORTS_FPE = False
    except KeyError:
        # Build metadata has no BLAS entry; keep the default.
        pass
101+
92102
HAS_LAPACK64 = numpy.linalg._umath_linalg._ilp64
93103

94104
IS_MUSL = False

0 commit comments

Comments
 (0)