Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0dbee7e

Browse files
authored
Merge pull request #28343 from carlosgmartin/fix_norm_empty_matrix
BUG: Fix `linalg.norm` to handle empty matrices correctly.
2 parents eedb1f8 + 9bcf80c commit 0dbee7e

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
* The vector norm ``ord=inf`` and the matrix norms ``ord={1, 2, inf, 'nuc'}`` now always returns zero for empty arrays. Empty arrays have at least one axis of size zero. This affects `np.linalg.norm`, `np.linalg.vector_norm`, and `np.linalg.matrix_norm`. Previously, NumPy would raises errors or return zero depending on the shape of the array.

numpy/linalg/_linalg.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2541,7 +2541,7 @@ def lstsq(a, b, rcond=None):
25412541
return wrap(x), wrap(resids), rank, s
25422542

25432543

2544-
def _multi_svd_norm(x, row_axis, col_axis, op):
2544+
def _multi_svd_norm(x, row_axis, col_axis, op, initial=None):
25452545
"""Compute a function of the singular values of the 2-D matrices in `x`.
25462546
25472547
This is a private utility function used by `numpy.linalg.norm()`.
@@ -2565,7 +2565,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
25652565
25662566
"""
25672567
y = moveaxis(x, (row_axis, col_axis), (-2, -1))
2568-
result = op(svd(y, compute_uv=False), axis=-1)
2568+
result = op(svd(y, compute_uv=False), axis=-1, initial=initial)
25692569
return result
25702570

25712571

@@ -2763,7 +2763,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
27632763

27642764
if len(axis) == 1:
27652765
if ord == inf:
2766-
return abs(x).max(axis=axis, keepdims=keepdims)
2766+
return abs(x).max(axis=axis, keepdims=keepdims, initial=0)
27672767
elif ord == -inf:
27682768
return abs(x).min(axis=axis, keepdims=keepdims)
27692769
elif ord == 0:
@@ -2797,17 +2797,17 @@ def norm(x, ord=None, axis=None, keepdims=False):
27972797
if row_axis == col_axis:
27982798
raise ValueError('Duplicate axes given.')
27992799
if ord == 2:
2800-
ret = _multi_svd_norm(x, row_axis, col_axis, amax)
2800+
ret = _multi_svd_norm(x, row_axis, col_axis, amax, 0)
28012801
elif ord == -2:
28022802
ret = _multi_svd_norm(x, row_axis, col_axis, amin)
28032803
elif ord == 1:
28042804
if col_axis > row_axis:
28052805
col_axis -= 1
2806-
ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis)
2806+
ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis, initial=0)
28072807
elif ord == inf:
28082808
if row_axis > col_axis:
28092809
row_axis -= 1
2810-
ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis)
2810+
ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis, initial=0)
28112811
elif ord == -1:
28122812
if col_axis > row_axis:
28132813
col_axis -= 1
@@ -2819,7 +2819,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
28192819
elif ord in [None, 'fro', 'f']:
28202820
ret = sqrt(add.reduce((x.conj() * x).real, axis=axis))
28212821
elif ord == 'nuc':
2822-
ret = _multi_svd_norm(x, row_axis, col_axis, sum)
2822+
ret = _multi_svd_norm(x, row_axis, col_axis, sum, 0)
28232823
else:
28242824
raise ValueError("Invalid norm order for matrices.")
28252825
if keepdims:

numpy/linalg/tests/test_linalg.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2372,6 +2372,16 @@ def test_matrix_norm():
23722372
assert_almost_equal(actual, np.array([[14.2828]]), double_decimal=3)
23732373

23742374

2375+
def test_matrix_norm_empty():
2376+
for shape in [(0, 2), (2, 0), (0, 0)]:
2377+
for dtype in [np.float64, np.float32, np.int32]:
2378+
x = np.zeros(shape, dtype)
2379+
assert_equal(np.linalg.matrix_norm(x, ord="fro"), 0)
2380+
assert_equal(np.linalg.matrix_norm(x, ord="nuc"), 0)
2381+
assert_equal(np.linalg.matrix_norm(x, ord=1), 0)
2382+
assert_equal(np.linalg.matrix_norm(x, ord=2), 0)
2383+
assert_equal(np.linalg.matrix_norm(x, ord=np.inf), 0)
2384+
23752385
def test_vector_norm():
23762386
x = np.arange(9).reshape((3, 3))
23772387
actual = np.linalg.vector_norm(x)
@@ -2388,3 +2398,11 @@ def test_vector_norm():
23882398
expected = np.full((1, 1), 14.2828, dtype='float64')
23892399
assert_equal(actual.shape, expected.shape)
23902400
assert_almost_equal(actual, expected, double_decimal=3)
2401+
2402+
2403+
def test_vector_norm_empty():
2404+
for dtype in [np.float64, np.float32, np.int32]:
2405+
x = np.zeros(0, dtype)
2406+
assert_equal(np.linalg.vector_norm(x, ord=1), 0)
2407+
assert_equal(np.linalg.vector_norm(x, ord=2), 0)
2408+
assert_equal(np.linalg.vector_norm(x, ord=np.inf), 0)

0 commit comments

Comments
 (0)