Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1883d84

Browse files
authored
BUG: Ensure summed weights returned by np.average always are correct class (numpy#30522)
1 parent e6e056a commit 1883d84

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

numpy/lib/_function_base_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def average(a, axis=None, weights=None, returned=False, *,
594594

595595
if returned:
596596
if scl.shape != avg_as_array.shape:
597-
scl = np.broadcast_to(scl, avg_as_array.shape).copy()
597+
scl = np.broadcast_to(scl, avg_as_array.shape, subok=True).copy()
598598
return avg, scl
599599
else:
600600
return avg

numpy/lib/tests/test_function_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,14 @@ class subclass(np.ndarray):
497497

498498
assert_equal(type(np.average(a)), subclass)
499499
assert_equal(type(np.average(a, weights=w)), subclass)
500+
# Ensure a possibly returned sum of weights is correct too.
501+
ra, rw = np.average(a, weights=w, returned=True)
502+
assert_equal(type(ra), subclass)
503+
assert_equal(type(rw), subclass)
504+
# Even if it needs to be broadcast.
505+
ra, rw = np.average(a, weights=w[0], axis=1, returned=True)
506+
assert_equal(type(ra), subclass)
507+
assert_equal(type(rw), subclass)
500508

501509
def test_upcasting(self):
502510
typs = [('i4', 'i4', 'f8'), ('i4', 'f4', 'f8'), ('f4', 'i4', 'f8'),

0 commit comments

Comments
 (0)