Thanks to visit codestin.com
Credit goes to docs.rs

ndarray/
tri.rs

1// Copyright 2014-2024 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use core::cmp::min;
10
11use num_traits::Zero;
12
13use crate::{
14    dimension::{is_layout_c, is_layout_f},
15    Array,
16    ArrayRef,
17    Axis,
18    Dimension,
19    Zip,
20};
21
22impl<A, D> ArrayRef<A, D>
23where
24    D: Dimension,
25    A: Clone + Zero,
26{
27    /// Upper triangular of an array.
28    ///
29    /// Return a copy of the array with elements below the *k*-th diagonal zeroed.
30    /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes.
31    /// For 0D and 1D arrays, `triu` will return an unchanged clone.
32    ///
33    /// See also [`ArrayRef::tril`]
34    ///
35    /// ```
36    /// use ndarray::array;
37    ///
38    /// let arr = array![
39    ///     [1, 2, 3],
40    ///     [4, 5, 6],
41    ///     [7, 8, 9]
42    /// ];
43    /// assert_eq!(
44    ///     arr.triu(0),
45    ///     array![
46    ///         [1, 2, 3],
47    ///         [0, 5, 6],
48    ///         [0, 0, 9]
49    ///     ]
50    /// );
51    /// ```
52    pub fn triu(&self, k: isize) -> Array<A, D>
53    {
54        if self.ndim() <= 1 {
55            return self.to_owned();
56        }
57
58        // Performance optimization for F-order arrays.
59        // C-order array check prevents infinite recursion in edge cases like [[1]].
60        // k-size check prevents underflow when k == isize::MIN
61        let n = self.ndim();
62        if is_layout_f(self._dim(), self._strides()) && !is_layout_c(self._dim(), self._strides()) && k > isize::MIN {
63            let mut x = self.view();
64            x.swap_axes(n - 2, n - 1);
65            let mut tril = x.tril(-k);
66            tril.swap_axes(n - 2, n - 1);
67
68            return tril;
69        }
70
71        let mut res = Array::zeros(self.raw_dim());
72        let ncols = self.len_of(Axis(n - 1));
73        let nrows = self.len_of(Axis(n - 2));
74        let indices = Array::from_iter(0..nrows);
75        Zip::from(self.rows())
76            .and(res.rows_mut())
77            .and_broadcast(&indices)
78            .for_each(|src, mut dst, row_num| {
79                let mut lower = match k >= 0 {
80                    true => row_num.saturating_add(k as usize),        // Avoid overflow
81                    false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0
82                };
83                lower = min(lower, ncols);
84                (*dst)
85                    .slice_mut(s![lower..])
86                    .assign(&(*src).slice(s![lower..]));
87            });
88
89        res
90    }
91
92    /// Lower triangular of an array.
93    ///
94    /// Return a copy of the array with elements above the *k*-th diagonal zeroed.
95    /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes.
96    /// For 0D and 1D arrays, `tril` will return an unchanged clone.
97    ///
98    /// See also [`ArrayRef::triu`]
99    ///
100    /// ```
101    /// use ndarray::array;
102    ///
103    /// let arr = array![
104    ///     [1, 2, 3],
105    ///     [4, 5, 6],
106    ///     [7, 8, 9]
107    /// ];
108    /// assert_eq!(
109    ///     arr.tril(0),
110    ///     array![
111    ///         [1, 0, 0],
112    ///         [4, 5, 0],
113    ///         [7, 8, 9]
114    ///     ]
115    /// );
116    /// ```
117    pub fn tril(&self, k: isize) -> Array<A, D>
118    {
119        if self.ndim() <= 1 {
120            return self.to_owned();
121        }
122
123        // Performance optimization for F-order arrays.
124        // C-order array check prevents infinite recursion in edge cases like [[1]].
125        // k-size check prevents underflow when k == isize::MIN
126        let n = self.ndim();
127        if is_layout_f(self._dim(), self._strides()) && !is_layout_c(self._dim(), self._strides()) && k > isize::MIN {
128            let mut x = self.view();
129            x.swap_axes(n - 2, n - 1);
130            let mut triu = x.triu(-k);
131            triu.swap_axes(n - 2, n - 1);
132
133            return triu;
134        }
135
136        let mut res = Array::zeros(self.raw_dim());
137        let ncols = self.len_of(Axis(n - 1));
138        let nrows = self.len_of(Axis(n - 2));
139        let indices = Array::from_iter(0..nrows);
140        Zip::from(self.rows())
141            .and(res.rows_mut())
142            .and_broadcast(&indices)
143            .for_each(|src, mut dst, row_num| {
144                // let row_num = i.into_dimension().last_elem();
145                let mut upper = match k >= 0 {
146                    true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow
147                    false => row_num.saturating_sub((k + 1).unsigned_abs()),      // Avoid underflow
148                };
149                upper = min(upper, ncols);
150                (*dst)
151                    .slice_mut(s![..upper])
152                    .assign(&(*src).slice(s![..upper]));
153            });
154
155        res
156    }
157}
158
#[cfg(test)]
mod tests
{
    // `isize::MIN` / `isize::MAX` below are the primitive's associated constants; the deprecated
    // `core::isize` module is intentionally not imported.
    use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder};
    use alloc::vec;

    #[test]
    fn test_keep_order()
    {
        let x = Array2::<f64>::ones((3, 3).f());
        let res = x.triu(0);
        assert!(dimension::is_layout_f(&res.parts.dim, &res.parts.strides));

        let res = x.tril(0);
        assert!(dimension::is_layout_f(&res.parts.dim, &res.parts.strides));
    }

    #[test]
    fn test_0d()
    {
        let x = Array0::<f64>::ones(());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array0::<f64>::ones(().f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    #[test]
    fn test_1d()
    {
        let x = array![1, 2, 3];
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);

        let x = Array1::<f64>::ones(3.f());
        let res = x.triu(0);
        assert_eq!(res, x);

        let res = x.tril(0);
        assert_eq!(res, x);
    }

    #[test]
    fn test_2d()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];

        // Upper
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        // Lower
        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);

        let x = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();

        // Upper
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]);

        // Lower
        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]);
    }

    #[test]
    fn test_2d_single()
    {
        let x = array![[1]];

        assert_eq!(x.triu(0), array![[1]]);
        assert_eq!(x.tril(0), array![[1]]);
        assert_eq!(x.triu(1), array![[0]]);
        assert_eq!(x.tril(1), array![[1]]);
        assert_eq!(x.triu(-1), array![[1]]);
        assert_eq!(x.tril(-1), array![[0]]);
    }

    #[test]
    fn test_3d()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        // Upper
        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        // Lower
        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );

        let x = Array3::from_shape_vec(
            (3, 3, 3).f(),
            vec![1, 10, 19, 4, 13, 22, 7, 16, 25, 2, 11, 20, 5, 14, 23, 8, 17, 26, 3, 12, 21, 6, 15, 24, 9, 18, 27],
        )
        .unwrap();

        // Upper
        let res = x.triu(0);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [0, 5, 6], [0, 0, 9]],
                [[10, 11, 12], [0, 14, 15], [0, 0, 18]],
                [[19, 20, 21], [0, 23, 24], [0, 0, 27]]
            ]
        );

        // Lower
        let res = x.tril(0);
        assert_eq!(
            res,
            array![
                [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
                [[10, 0, 0], [13, 14, 0], [16, 17, 18]],
                [[19, 0, 0], [22, 23, 0], [25, 26, 27]]
            ]
        );
    }

    #[test]
    fn test_off_axis()
    {
        let x = array![
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
            [[19, 20, 21], [22, 23, 24], [25, 26, 27]]
        ];

        let res = x.triu(1);
        assert_eq!(
            res,
            array![
                [[0, 2, 3], [0, 0, 6], [0, 0, 0]],
                [[0, 11, 12], [0, 0, 15], [0, 0, 0]],
                [[0, 20, 21], [0, 0, 24], [0, 0, 0]]
            ]
        );

        let res = x.triu(-1);
        assert_eq!(
            res,
            array![
                [[1, 2, 3], [4, 5, 6], [0, 8, 9]],
                [[10, 11, 12], [13, 14, 15], [0, 17, 18]],
                [[19, 20, 21], [22, 23, 24], [0, 26, 27]]
            ]
        );

        let res = x.tril(1);
        assert_eq!(
            res,
            array![
                [[1, 2, 0], [4, 5, 6], [7, 8, 9]],
                [[10, 11, 0], [13, 14, 15], [16, 17, 18]],
                [[19, 20, 0], [22, 23, 24], [25, 26, 27]]
            ]
        );

        let res = x.tril(-1);
        assert_eq!(
            res,
            array![
                [[0, 0, 0], [4, 0, 0], [7, 8, 0]],
                [[0, 0, 0], [13, 0, 0], [16, 17, 0]],
                [[0, 0, 0], [22, 0, 0], [25, 26, 0]]
            ]
        );
    }

    #[test]
    fn test_odd_shape()
    {
        let x = array![[1, 2, 3], [4, 5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]);

        let x = array![[1, 2], [3, 4], [5, 6]];
        let res = x.triu(0);
        assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]);

        let res = x.tril(0);
        assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]);
    }

    #[test]
    fn test_odd_k()
    {
        let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]];
        let z = Array2::zeros([3, 3]);
        assert_eq!(x.triu(isize::MIN), x);
        assert_eq!(x.tril(isize::MIN), z);
        assert_eq!(x.triu(isize::MAX), z);
        assert_eq!(x.tril(isize::MAX), x);

        // Same extremes through the F-order code path (where `k` gets negated).
        let xf = Array2::from_shape_vec((3, 3).f(), vec![1, 4, 7, 2, 5, 8, 3, 6, 9]).unwrap();
        assert_eq!(xf, x);
        assert_eq!(xf.triu(isize::MIN), x);
        assert_eq!(xf.tril(isize::MIN), z);
        assert_eq!(xf.triu(isize::MAX), z);
        assert_eq!(xf.tril(isize::MAX), x);
    }
}