From f129d75863976a5b3432c79fd043e79db35a6318 Mon Sep 17 00:00:00 2001 From: Hameer Abbasi Date: Fri, 27 Apr 2018 11:16:18 +0200 Subject: [PATCH] BUG: Fix empty linalg.norm for ord=inf and ord=-inf. --- numpy/linalg/linalg.py | 12 ++++++------ numpy/linalg/tests/test_linalg.py | 4 ++++ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py index 5ee230f920d2..3f942c59ca97 100644 --- a/numpy/linalg/linalg.py +++ b/numpy/linalg/linalg.py @@ -2260,9 +2260,9 @@ def norm(x, ord=None, axis=None, keepdims=False): if len(axis) == 1: if ord == Inf: - return abs(x).max(axis=axis, keepdims=keepdims) + return abs(x).max(axis=axis, keepdims=keepdims, initial=0) elif ord == -Inf: - return abs(x).min(axis=axis, keepdims=keepdims) + return abs(x).min(axis=axis, keepdims=keepdims, initial=Inf) elif ord == 0: # Zero norm return (x != 0).astype(x.real.dtype).sum(axis=axis, keepdims=keepdims) @@ -2296,19 +2296,19 @@ def norm(x, ord=None, axis=None, keepdims=False): elif ord == 1: if col_axis > row_axis: col_axis -= 1 - ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).max(axis=col_axis, initial=0) elif ord == Inf: if row_axis > col_axis: row_axis -= 1 - ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).max(axis=row_axis, initial=0) elif ord == -1: if col_axis > row_axis: col_axis -= 1 - ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis) + ret = add.reduce(abs(x), axis=row_axis).min(axis=col_axis, initial=Inf) elif ord == -Inf: if row_axis > col_axis: row_axis -= 1 - ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis) + ret = add.reduce(abs(x), axis=col_axis).min(axis=row_axis, initial=Inf) elif ord in [None, 'fro', 'f']: ret = sqrt(add.reduce((x.conj() * x).real, axis=axis)) elif ord == 'nuc': diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py index 4a87330c71c1..b8e83a646748 100644 --- a/numpy/linalg/tests/test_linalg.py +++ b/numpy/linalg/tests/test_linalg.py @@ -1124,6 +1124,10 @@ def test_empty(self): assert_equal(norm([]), 0.0) assert_equal(norm(array([], dtype=self.dt)), 0.0) assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0) + assert_equal(norm([], ord=np.inf), 0.0) + assert_equal(norm([], ord=-np.inf), np.inf) + assert_equal(norm(np.empty((0, 1), dtype=self.dt), ord=np.inf), 0.0) + assert_equal(norm(np.empty((0, 1), dtype=self.dt), ord=-np.inf), np.inf) def test_vector_return_type(self): a = np.array([1, 0, 1])