1use core::cmp::min;
10
11use num_traits::Zero;
12
13use crate::{
14 dimension::{is_layout_c, is_layout_f},
15 Array,
16 ArrayRef,
17 Axis,
18 Dimension,
19 Zip,
20};
21
22impl<A, D> ArrayRef<A, D>
23where
24 D: Dimension,
25 A: Clone + Zero,
26{
27 pub fn triu(&self, k: isize) -> Array<A, D>
53 {
54 if self.ndim() <= 1 {
55 return self.to_owned();
56 }
57
58 let n = self.ndim();
62 if is_layout_f(self._dim(), self._strides()) && !is_layout_c(self._dim(), self._strides()) && k > isize::MIN {
63 let mut x = self.view();
64 x.swap_axes(n - 2, n - 1);
65 let mut tril = x.tril(-k);
66 tril.swap_axes(n - 2, n - 1);
67
68 return tril;
69 }
70
71 let mut res = Array::zeros(self.raw_dim());
72 let ncols = self.len_of(Axis(n - 1));
73 let nrows = self.len_of(Axis(n - 2));
74 let indices = Array::from_iter(0..nrows);
75 Zip::from(self.rows())
76 .and(res.rows_mut())
77 .and_broadcast(&indices)
78 .for_each(|src, mut dst, row_num| {
79 let mut lower = match k >= 0 {
80 true => row_num.saturating_add(k as usize), false => row_num.saturating_sub(k.unsigned_abs()), };
83 lower = min(lower, ncols);
84 (*dst)
85 .slice_mut(s![lower..])
86 .assign(&(*src).slice(s![lower..]));
87 });
88
89 res
90 }
91
92 pub fn tril(&self, k: isize) -> Array<A, D>
118 {
119 if self.ndim() <= 1 {
120 return self.to_owned();
121 }
122
123 let n = self.ndim();
127 if is_layout_f(self._dim(), self._strides()) && !is_layout_c(self._dim(), self._strides()) && k > isize::MIN {
128 let mut x = self.view();
129 x.swap_axes(n - 2, n - 1);
130 let mut triu = x.triu(-k);
131 triu.swap_axes(n - 2, n - 1);
132
133 return triu;
134 }
135
136 let mut res = Array::zeros(self.raw_dim());
137 let ncols = self.len_of(Axis(n - 1));
138 let nrows = self.len_of(Axis(n - 2));
139 let indices = Array::from_iter(0..nrows);
140 Zip::from(self.rows())
141 .and(res.rows_mut())
142 .and_broadcast(&indices)
143 .for_each(|src, mut dst, row_num| {
144 let mut upper = match k >= 0 {
146 true => row_num.saturating_add(k as usize).saturating_add(1), false => row_num.saturating_sub((k + 1).unsigned_abs()), };
149 upper = min(upper, ncols);
150 (*dst)
151 .slice_mut(s![..upper])
152 .assign(&(*src).slice(s![..upper]));
153 });
154
155 res
156 }
157}
158
159#[cfg(test)]
160mod tests
161{
162 use core::isize;
163
164 use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
165 use alloc::vec;
166
167 #[test]
168 fn test_keep_order()
169 {
170 let x = Array2::<f64>::ones((3, 3).f());
171 let res = x.triu(0);
172 assert!(dimension::is_layout_f(&res.parts.dim, &res.parts.strides));
173
174 let res = x.tril(0);
175 assert!(dimension::is_layout_f(&res.parts.dim, &res.parts.strides));
176 }
177
178 #[test]
179 fn test_0d()
180 {
181 let x = Array0::<f64>::ones(());
182 let res = x.triu(0);
183 assert_eq!(res, x);
184
185 let res = x.tril(0);
186 assert_eq!(res, x);
187
188 let x = Array0::<f64>::ones(().f());
189 let res = x.triu(0);
190 assert_eq!(res, x);
191
192 let res = x.tril(0);
193 assert_eq!(res, x);
194 }
195
196 #[test]
197 fn test_1d()
198 {
199 let x = array![1, 2, 3];
200 let res = x.triu(0);
201 assert_eq!(res, x);
202
203 let res = x.triu(0);
204 assert_eq!(res, x);
205
206 let x = Array1::<f64>::ones(3.f());
207 let res = x.triu(0);
208 assert_eq!(res, x);
209
210 let res = x.triu(0);
211 assert_eq!(res, x);
212 }
213
214 #[test]
215 fn test_2d()
216 {
217 let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
218
219 let res = x.triu(0);
221 assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
222
223 let res = x.tril(0);
225 assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
226
227 let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
228
229 let res = x.triu(0);
231 assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);
232
233 let res = x.tril(0);
235 assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
236 }
237
238 #[test]
239 fn test_2d_single()
240 {
241 let x = array![[1]];
242
243 assert_eq!(x.triu(0), array![[1]]);
244 assert_eq!(x.tril(0), array![[1]]);
245 assert_eq!(x.triu(1), array![[0]]);
246 assert_eq!(x.tril(1), array![[1]]);
247 assert_eq!(x.triu(-1), array![[1]]);
248 assert_eq!(x.tril(-1), array![[0]]);
249 }
250
251 #[test]
252 fn test_3d()
253 {
254 let x = array![
255 [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
256 [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
257 [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
258 ];
259
260 let res = x.triu(0);
262 assert_eq!(
263 res,
264 array![
265 [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
266 [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
267 [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
268 ]
269 );
270
271 let res = x.tril(0);
273 assert_eq!(
274 res,
275 array![
276 [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
277 [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
278 [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
279 ]
280 );
281
282 let x = Array3::from_shape_vec(
283 (3, 3, 3).f(),
284 vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
285 )
286 .unwrap();
287
288 let res = x.triu(0);
290 assert_eq!(
291 res,
292 array![
293 [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
294 [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
295 [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
296 ]
297 );
298
299 let res = x.tril(0);
301 assert_eq!(
302 res,
303 array![
304 [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
305 [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
306 [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
307 ]
308 );
309 }
310
311 #[test]
312 fn test_off_axis()
313 {
314 let x = array![
315 [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
316 [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
317 [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
318 ];
319
320 let res = x.triu(1);
321 assert_eq!(
322 res,
323 array![
324 [[0, 2, 3], [0, 0, 6], [0, 0, 0]],
325 [[0, 11, 12], [0, 0, 15], [0, 0, 0]],
326 [[0, 20, 21], [0, 0, 24], [0, 0, 0]]
327 ]
328 );
329
330 let res = x.triu(-1);
331 assert_eq!(
332 res,
333 array![
334 [[1, 2, 3], [4, 5, 6], [0, 8, 9]],
335 [[10, 11, 12], [13, 14, 15], [0, 17, 18]],
336 [[19, 20, 21], [22, 23, 24], [0, 26, 27]]
337 ]
338 );
339 }
340
341 #[test]
342 fn test_odd_shape()
343 {
344 let x = array![[1, 2, 3], [4, 5, 6]];
345 let res = x.triu(0);
346 assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);
347
348 let res = x.tril(0);
349 assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);
350
351 let x = array![[1, 2], [3, 4], [5, 6]];
352 let res = x.triu(0);
353 assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);
354
355 let res = x.tril(0);
356 assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
357 }
358
359 #[test]
360 fn test_odd_k()
361 {
362 let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
363 let z = Array2::zeros([3, 3]);
364 assert_eq!(x.triu(isize::MIN), x);
365 assert_eq!(x.tril(isize::MIN), z);
366 assert_eq!(x.triu(isize::MAX), z);
367 assert_eq!(x.tril(isize::MAX), x);
368 }
369}