Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 56eb28e

Browse files
committed
Merge pull request #4364 from argriffing/triu-broadcasting
ENH: tril and triu broadcasting
2 parents 9573f78 + e4c274f commit 56eb28e

3 files changed

Lines changed: 35 additions & 7 deletions

File tree

doc/release/1.9.0-notes.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ Dtype parameter added to `np.linspace` and `np.logspace`
8181
The returned data type from the `linspace` and `logspace` functions
8282
can now be specificed using the dtype parameter.
8383

84+
More general `np.triu` and `np.tril` broadcasting
85+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
86+
For arrays with `ndim` exceeding 2, these functions will now apply to the
87+
final two axes instead of raising an exception.
8488

8589
`tobytes` alias for `tostring` method
8690
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

numpy/lib/tests/test_twodim_base.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,16 +275,40 @@ def test_dtype(self):
275275
assert_array_equal(tri(3, dtype=bool), out.astype(bool))
276276

277277

278-
def test_tril_triu():
278+
def test_tril_triu_ndim2():
279279
for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
280280
a = np.ones((2, 2), dtype=dtype)
281281
b = np.tril(a)
282282
c = np.triu(a)
283-
assert_array_equal(b, [[1, 0], [1, 1]])
284-
assert_array_equal(c, b.T)
283+
yield assert_array_equal, b, [[1, 0], [1, 1]]
284+
yield assert_array_equal, c, b.T
285285
# should return the same dtype as the original array
286-
assert_equal(b.dtype, a.dtype)
287-
assert_equal(c.dtype, a.dtype)
286+
yield assert_equal, b.dtype, a.dtype
287+
yield assert_equal, c.dtype, a.dtype
288+
289+
def test_tril_triu_ndim3():
290+
for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
291+
a = np.array([
292+
[[1, 1], [1, 1]],
293+
[[1, 1], [1, 0]],
294+
[[1, 1], [0, 0]],
295+
], dtype=dtype)
296+
a_tril_desired = np.array([
297+
[[1, 0], [1, 1]],
298+
[[1, 0], [1, 0]],
299+
[[1, 0], [0, 0]],
300+
], dtype=dtype)
301+
a_triu_desired = np.array([
302+
[[1, 1], [0, 1]],
303+
[[1, 1], [0, 0]],
304+
[[1, 1], [0, 0]],
305+
], dtype=dtype)
306+
a_triu_observed = np.triu(a)
307+
a_tril_observed = np.tril(a)
308+
yield assert_array_equal, a_triu_observed, a_triu_desired
309+
yield assert_array_equal, a_tril_observed, a_tril_desired
310+
yield assert_equal, a_triu_observed.dtype, a.dtype
311+
yield assert_equal, a_tril_observed.dtype, a.dtype
288312

289313

290314
def test_mask_indices():

numpy/lib/twodim_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def tril(m, k=0):
430430
431431
"""
432432
m = asanyarray(m)
433-
out = multiply(tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype), m)
433+
out = multiply(tri(m.shape[-2], m.shape[-1], k=k, dtype=m.dtype), m)
434434
return out
435435

436436

@@ -457,7 +457,7 @@ def triu(m, k=0):
457457
458458
"""
459459
m = asanyarray(m)
460-
out = multiply((1 - tri(m.shape[0], m.shape[1], k - 1, dtype=m.dtype)), m)
460+
out = multiply((1 - tri(m.shape[-2], m.shape[-1], k - 1, dtype=m.dtype)), m)
461461
return out
462462

463463

0 commit comments

Comments
 (0)