@@ -275,16 +275,40 @@ def test_dtype(self):
275275 assert_array_equal (tri (3 , dtype = bool ), out .astype (bool ))
276276
277277
278- def test_tril_triu ():
278+ def test_tril_triu_ndim2 ():
279279 for dtype in np .typecodes ['AllFloat' ] + np .typecodes ['AllInteger' ]:
280280 a = np .ones ((2 , 2 ), dtype = dtype )
281281 b = np .tril (a )
282282 c = np .triu (a )
283- assert_array_equal ( b , [[1 , 0 ], [1 , 1 ]])
284- assert_array_equal ( c , b .T )
283+ yield assert_array_equal , b , [[1 , 0 ], [1 , 1 ]]
284+ yield assert_array_equal , c , b .T
285285 # should return the same dtype as the original array
286- assert_equal (b .dtype , a .dtype )
287- assert_equal (c .dtype , a .dtype )
286+ yield assert_equal , b .dtype , a .dtype
287+ yield assert_equal , c .dtype , a .dtype
288+
289+ def test_tril_triu_ndim3 ():
290+ for dtype in np .typecodes ['AllFloat' ] + np .typecodes ['AllInteger' ]:
291+ a = np .array ([
292+ [[1 , 1 ], [1 , 1 ]],
293+ [[1 , 1 ], [1 , 0 ]],
294+ [[1 , 1 ], [0 , 0 ]],
295+ ], dtype = dtype )
296+ a_tril_desired = np .array ([
297+ [[1 , 0 ], [1 , 1 ]],
298+ [[1 , 0 ], [1 , 0 ]],
299+ [[1 , 0 ], [0 , 0 ]],
300+ ], dtype = dtype )
301+ a_triu_desired = np .array ([
302+ [[1 , 1 ], [0 , 1 ]],
303+ [[1 , 1 ], [0 , 0 ]],
304+ [[1 , 1 ], [0 , 0 ]],
305+ ], dtype = dtype )
306+ a_triu_observed = np .triu (a )
307+ a_tril_observed = np .tril (a )
308+ yield assert_array_equal , a_triu_observed , a_triu_desired
309+ yield assert_array_equal , a_tril_observed , a_tril_desired
310+ yield assert_equal , a_triu_observed .dtype , a .dtype
311+ yield assert_equal , a_tril_observed .dtype , a .dtype
288312
289313
290314def test_mask_indices ():
0 commit comments