From deebf2c7b731f2397e0f21f838a348dad55dd8da Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Sun, 20 Apr 2025 22:29:52 +0900 Subject: [PATCH 1/4] ENH: np.linalg.inv: support noerr parameter --- numpy/linalg/_linalg.py | 37 ++++++++++++++++++++++++++----- numpy/linalg/tests/test_linalg.py | 15 +++++++++++++ 2 files changed, 47 insertions(+), 5 deletions(-) diff --git a/numpy/linalg/_linalg.py b/numpy/linalg/_linalg.py index 1301f1cb7e9a..0b66d87b16fd 100644 --- a/numpy/linalg/_linalg.py +++ b/numpy/linalg/_linalg.py @@ -491,12 +491,12 @@ def tensorinv(a, ind=2): # Matrix inversion -def _unary_dispatcher(a): +def _unary_dispatcher(a, *, noerr=None): return (a,) @array_function_dispatch(_unary_dispatcher) -def inv(a): +def inv(a, *, noerr=False): """ Compute the inverse of a matrix. @@ -507,6 +507,10 @@ def inv(a): ---------- a : (..., M, M) array_like Matrix to be inverted. + noerr : bool, optional + If True, do not raise a LinAlgError when a matrix is singular. + Instead, return a matrix with NaN values for the singular matrices. + Default is False. Returns ------- @@ -579,6 +583,25 @@ def inv(a): [-0.5 , 0.625, 0.25 ], [ 0. , 0. , 1. ]]) + Using the `noerr` parameter to handle singular matrices in a stack: + + >>> a = np.array([ + ... [[1.0, 0.0], [0.0, 1.0]], # invertible + ... [[1.0, 1.0], [1.0, 1.0]], # singular + ... [[2.0, 1.0], [1.0, 2.0]] # invertible + ... ]) + >>> # Without noerr, a LinAlgError is raised + >>> try: + ... inv(a) + ... except np.linalg.LinAlgError: + ... print("LinAlgError raised") + LinAlgError raised + >>> # With noerr=True, NaN values are returned for singular matrices + >>> result = inv(a, noerr=True) + >>> # Check which matrices were singular + >>> np.isnan(result).any(axis=(1, 2)) + array([False, True, False]) + To detect ill-conditioned matrices, you can use `numpy.linalg.cond` to compute its *condition number* [1]_. The larger the condition number, the more ill-conditioned the matrix is. As a rule of thumb, if the condition @@ -605,9 +628,13 @@ def inv(a): t, result_t = _commonType(a) signature = 'D->D' if isComplexType(t) else 'd->d' - with errstate(call=_raise_linalgerror_singular, invalid='call', - over='ignore', divide='ignore', under='ignore'): - ainv = _umath_linalg.inv(a, signature=signature) + if noerr: + with errstate(all='ignore'): + ainv = _umath_linalg.inv(a, signature=signature) + else: + with errstate(call=_raise_linalgerror_singular, invalid='call', + over='ignore', divide='ignore', under='ignore'): + ainv = _umath_linalg.inv(a, signature=signature) return wrap(ainv.astype(result_t, copy=False)) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 1a79629814e9..20015b17f012 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -569,6 +569,21 @@ class ArraySubclass(np.ndarray): assert_equal(a.shape, res.shape) assert_(isinstance(res, ArraySubclass)) + def test_noerr(self): + # test noerr=True case + a = np.array([ + [[1.0, 0.0], [0.0, 1.0]], # invertible + [[1.0, 1.0], [1.0, 1.0]], # singular + [[2.0, 1.0], [1.0, 2.0]] # invertible + ]) + + result = linalg.inv(a, noerr=True) + + assert_almost_equal(result[0], np.array([[1.0, 0.0], [0.0, 1.0]])) + assert_almost_equal(result[2], np.array([[2/3, -1/3], [-1/3, 2/3]])) + + assert_(np.isnan(result[1]).all()) + class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): From 0c5cd88e3f433a267f72c8576ba86d94394d0968 Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Sun, 20 Apr 2025 23:19:34 +0900 Subject: [PATCH 2/4] FIX: linter --- numpy/linalg/tests/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 20015b17f012..65f805f8c8f2 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -580,7 +580,7 @@ def test_noerr(self): result = linalg.inv(a, noerr=True) assert_almost_equal(result[0], np.array([[1.0, 0.0], [0.0, 1.0]])) - assert_almost_equal(result[2], np.array([[2/3, -1/3], [-1/3, 2/3]])) + assert_almost_equal(result[2], np.array([[2 / 3, -1 / 3], [-1 / 3, 2 / 3]])) assert_(np.isnan(result[1]).all()) From 13fe6a4f9397f0baacb1f64b387eaacfd4b3a432 Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Thu, 24 Apr 2025 01:26:06 +0900 Subject: [PATCH 3/4] ENH: add testcase --- numpy/linalg/tests/test_linalg.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 65f805f8c8f2..f63dbc95c04b 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -580,9 +580,11 @@ def test_noerr(self): result = linalg.inv(a, noerr=True) assert_almost_equal(result[0], np.array([[1.0, 0.0], [0.0, 1.0]])) + assert_(np.isnan(result[1]).all()) assert_almost_equal(result[2], np.array([[2 / 3, -1 / 3], [-1 / 3, 2 / 3]])) - assert_(np.isnan(result[1]).all()) + with assert_raises(np.linalg.LinAlgError): + linalg.inv(a, noerr=False) class EigvalsCases(LinalgSquareTestCase, LinalgGeneralizedSquareTestCase): From cb8ffa007903fab26efe9e0c65f252a6437930b8 Mon Sep 17 00:00:00 2001 From: koki watanabe Date: Fri, 25 Apr 2025 01:06:48 +0900 Subject: [PATCH 4/4] FIX: test function --- numpy/linalg/tests/test_linalg.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index f63dbc95c04b..6b8f2565916d 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -579,10 +579,16 @@ def test_noerr(self): result = linalg.inv(a, noerr=True) - assert_almost_equal(result[0], np.array([[1.0, 0.0], [0.0, 1.0]])) - assert_(np.isnan(result[1]).all()) - assert_almost_equal(result[2], np.array([[2 / 3, -1 / 3], [-1 / 3, 2 / 3]])) + assert_allclose( + result, + [ + [[1.0, 0.0], [0.0, 1.0]], + [[np.nan, np.nan], [np.nan, np.nan]], + [[2.0 / 3, -1.0 / 3], [-1.0 / 3, 2.0 / 3]], + ] + ) + # test noerr=False case with assert_raises(np.linalg.LinAlgError): linalg.inv(a, noerr=False)