@@ -2541,7 +2541,7 @@ def lstsq(a, b, rcond=None):
2541
2541
return wrap (x ), wrap (resids ), rank , s
2542
2542
2543
2543
2544
- def _multi_svd_norm (x , row_axis , col_axis , op ):
2544
+ def _multi_svd_norm (x , row_axis , col_axis , op , initial = None ):
2545
2545
"""Compute a function of the singular values of the 2-D matrices in `x`.
2546
2546
2547
2547
This is a private utility function used by `numpy.linalg.norm()`.
@@ -2565,7 +2565,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
2565
2565
2566
2566
"""
2567
2567
y = moveaxis (x , (row_axis , col_axis ), (- 2 , - 1 ))
2568
- result = op (svd (y , compute_uv = False ), axis = - 1 )
2568
+ result = op (svd (y , compute_uv = False ), axis = - 1 , initial = initial )
2569
2569
return result
2570
2570
2571
2571
@@ -2763,7 +2763,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
2763
2763
2764
2764
if len (axis ) == 1 :
2765
2765
if ord == inf :
2766
- return abs (x ).max (axis = axis , keepdims = keepdims )
2766
+ return abs (x ).max (axis = axis , keepdims = keepdims , initial = 0 )
2767
2767
elif ord == - inf :
2768
2768
return abs (x ).min (axis = axis , keepdims = keepdims )
2769
2769
elif ord == 0 :
@@ -2797,17 +2797,17 @@ def norm(x, ord=None, axis=None, keepdims=False):
2797
2797
if row_axis == col_axis :
2798
2798
raise ValueError ('Duplicate axes given.' )
2799
2799
if ord == 2 :
2800
- ret = _multi_svd_norm (x , row_axis , col_axis , amax )
2800
+ ret = _multi_svd_norm (x , row_axis , col_axis , amax , 0 )
2801
2801
elif ord == - 2 :
2802
2802
ret = _multi_svd_norm (x , row_axis , col_axis , amin )
2803
2803
elif ord == 1 :
2804
2804
if col_axis > row_axis :
2805
2805
col_axis -= 1
2806
- ret = add .reduce (abs (x ), axis = row_axis ).max (axis = col_axis )
2806
+ ret = add .reduce (abs (x ), axis = row_axis ).max (axis = col_axis , initial = 0 )
2807
2807
elif ord == inf :
2808
2808
if row_axis > col_axis :
2809
2809
row_axis -= 1
2810
- ret = add .reduce (abs (x ), axis = col_axis ).max (axis = row_axis )
2810
+ ret = add .reduce (abs (x ), axis = col_axis ).max (axis = row_axis , initial = 0 )
2811
2811
elif ord == - 1 :
2812
2812
if col_axis > row_axis :
2813
2813
col_axis -= 1
@@ -2819,7 +2819,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
2819
2819
elif ord in [None , 'fro' , 'f' ]:
2820
2820
ret = sqrt (add .reduce ((x .conj () * x ).real , axis = axis ))
2821
2821
elif ord == 'nuc' :
2822
- ret = _multi_svd_norm (x , row_axis , col_axis , sum )
2822
+ ret = _multi_svd_norm (x , row_axis , col_axis , sum , 0 )
2823
2823
else :
2824
2824
raise ValueError ("Invalid norm order for matrices." )
2825
2825
if keepdims :
0 commit comments