Thanks for visiting codestin.com
Credit goes to docs.rs

ndarray/numeric/
impl_numeric.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul, MulAssign, Sub};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17use crate::Slice;
18
19/// # Numerical Methods for Arrays
20impl<A, D> ArrayRef<A, D>
21where D: Dimension
22{
23    /// Return the sum of all elements in the array.
24    ///
25    /// ```
26    /// use ndarray::arr2;
27    ///
28    /// let a = arr2(&[[1., 2.],
29    ///                [3., 4.]]);
30    /// assert_eq!(a.sum(), 10.);
31    /// ```
32    pub fn sum(&self) -> A
33    where A: Clone + Add<Output = A> + num_traits::Zero
34    {
35        if let Some(slc) = self.as_slice_memory_order() {
36            return numeric_util::unrolled_fold(slc, A::zero, A::add);
37        }
38        let mut sum = A::zero();
39        for row in self.rows() {
40            if let Some(slc) = row.as_slice() {
41                sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
42            } else {
43                sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
44            }
45        }
46        sum
47    }
48
49    /// Returns the [arithmetic mean] x̅ of all elements in the array:
50    ///
51    /// ```text
52    ///     1   n
53    /// x̅ = ―   ∑ xᵢ
54    ///     n  i=1
55    /// ```
56    ///
57    /// If the array is empty, `None` is returned.
58    ///
59    /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
60    ///
61    /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
62    pub fn mean(&self) -> Option<A>
63    where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
64    {
65        let n_elements = self.len();
66        if n_elements == 0 {
67            None
68        } else {
69            let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
70            Some(self.sum() / n_elements)
71        }
72    }
73
74    /// Return the product of all elements in the array.
75    ///
76    /// ```
77    /// use ndarray::arr2;
78    ///
79    /// let a = arr2(&[[1., 2.],
80    ///                [3., 4.]]);
81    /// assert_eq!(a.product(), 24.);
82    /// ```
83    pub fn product(&self) -> A
84    where A: Clone + Mul<Output = A> + num_traits::One
85    {
86        if let Some(slc) = self.as_slice_memory_order() {
87            return numeric_util::unrolled_fold(slc, A::one, A::mul);
88        }
89        let mut sum = A::one();
90        for row in self.rows() {
91            if let Some(slc) = row.as_slice() {
92                sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
93            } else {
94                sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
95            }
96        }
97        sum
98    }
99
100    /// Return the cumulative product of elements along a given axis.
101    ///
102    /// ```
103    /// use ndarray::{arr2, Axis};
104    ///
105    /// let a = arr2(&[[1., 2., 3.],
106    ///                [4., 5., 6.]]);
107    ///
108    /// // Cumulative product along rows (axis 0)
109    /// assert_eq!(
110    ///     a.cumprod(Axis(0)),
111    ///     arr2(&[[1., 2., 3.],
112    ///           [4., 10., 18.]])
113    /// );
114    ///
115    /// // Cumulative product along columns (axis 1)
116    /// assert_eq!(
117    ///     a.cumprod(Axis(1)),
118    ///     arr2(&[[1., 2., 6.],
119    ///           [4., 20., 120.]])
120    /// );
121    /// ```
122    ///
123    /// **Panics** if `axis` is out of bounds.
124    #[track_caller]
125    pub fn cumprod(&self, axis: Axis) -> Array<A, D>
126    where
127        A: Clone + Mul<Output = A> + MulAssign,
128        D: Dimension + RemoveAxis,
129    {
130        if axis.0 >= self.ndim() {
131            panic!("axis is out of bounds for array of dimension");
132        }
133
134        let mut result = self.to_owned();
135        result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone());
136        result
137    }
138
139    /// Return variance of elements in the array.
140    ///
141    /// The variance is computed using the [Welford one-pass
142    /// algorithm](https://www.jstor.org/stable/1266577).
143    ///
144    /// The parameter `ddof` specifies the "delta degrees of freedom". For
145    /// example, to calculate the population variance, use `ddof = 0`, or to
146    /// calculate the sample variance, use `ddof = 1`.
147    ///
148    /// The variance is defined as:
149    ///
150    /// ```text
151    ///               1       n
152    /// variance = ――――――――   ∑ (xᵢ - x̅)²
153    ///            n - ddof  i=1
154    /// ```
155    ///
156    /// where
157    ///
158    /// ```text
159    ///     1   n
160    /// x̅ = ―   ∑ xᵢ
161    ///     n  i=1
162    /// ```
163    ///
164    /// and `n` is the length of the array.
165    ///
166    /// **Panics** if `ddof` is less than zero or greater than `n`
167    ///
168    /// # Example
169    ///
170    /// ```
171    /// use ndarray::array;
172    /// use approx::assert_abs_diff_eq;
173    ///
174    /// let a = array![1., -4.32, 1.14, 0.32];
175    /// let var = a.var(1.);
176    /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
177    /// ```
178    #[track_caller]
179    #[cfg(feature = "std")]
180    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
181    pub fn var(&self, ddof: A) -> A
182    where A: Float + FromPrimitive
183    {
184        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
185        let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
186        assert!(
187            !(ddof < zero || ddof > n),
188            "`ddof` must not be less than zero or greater than the length of \
189             the axis",
190        );
191        let dof = n - ddof;
192        let mut mean = A::zero();
193        let mut sum_sq = A::zero();
194        let mut i = 0;
195        self.for_each(|&x| {
196            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
197            let delta = x - mean;
198            mean = mean + delta / count;
199            sum_sq = (x - mean).mul_add(delta, sum_sq);
200            i += 1;
201        });
202        sum_sq / dof
203    }
204
205    /// Return standard deviation of elements in the array.
206    ///
207    /// The standard deviation is computed from the variance using
208    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
209    ///
210    /// The parameter `ddof` specifies the "delta degrees of freedom". For
211    /// example, to calculate the population standard deviation, use `ddof = 0`,
212    /// or to calculate the sample standard deviation, use `ddof = 1`.
213    ///
214    /// The standard deviation is defined as:
215    ///
216    /// ```text
217    ///               ⎛    1       n          ⎞
218    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
219    ///               ⎝ n - ddof  i=1         ⎠
220    /// ```
221    ///
222    /// where
223    ///
224    /// ```text
225    ///     1   n
226    /// x̅ = ―   ∑ xᵢ
227    ///     n  i=1
228    /// ```
229    ///
230    /// and `n` is the length of the array.
231    ///
232    /// **Panics** if `ddof` is less than zero or greater than `n`
233    ///
234    /// # Example
235    ///
236    /// ```
237    /// use ndarray::array;
238    /// use approx::assert_abs_diff_eq;
239    ///
240    /// let a = array![1., -4.32, 1.14, 0.32];
241    /// let stddev = a.std(1.);
242    /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
243    /// ```
244    #[track_caller]
245    #[cfg(feature = "std")]
246    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
247    pub fn std(&self, ddof: A) -> A
248    where A: Float + FromPrimitive
249    {
250        self.var(ddof).sqrt()
251    }
252
253    /// Return sum along `axis`.
254    ///
255    /// ```
256    /// use ndarray::{aview0, aview1, arr2, Axis};
257    ///
258    /// let a = arr2(&[[1., 2., 3.],
259    ///                [4., 5., 6.]]);
260    /// assert!(
261    ///     a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
262    ///     a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
263    ///
264    ///     a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
265    /// );
266    /// ```
267    ///
268    /// **Panics** if `axis` is out of bounds.
269    #[track_caller]
270    pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
271    where
272        A: Clone + Zero + Add<Output = A>,
273        D: RemoveAxis,
274    {
275        let min_stride_axis = self._dim().min_stride_axis(self._strides());
276        if axis == min_stride_axis {
277            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
278        } else {
279            let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
280            for subview in self.axis_iter(axis) {
281                res = res + &subview;
282            }
283            res
284        }
285    }
286
287    /// Return product along `axis`.
288    ///
289    /// The product of an empty array is 1.
290    ///
291    /// ```
292    /// use ndarray::{aview0, aview1, arr2, Axis};
293    ///
294    /// let a = arr2(&[[1., 2., 3.],
295    ///                [4., 5., 6.]]);
296    ///
297    /// assert!(
298    ///     a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
299    ///     a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
300    ///
301    ///     a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
302    /// );
303    /// ```
304    ///
305    /// **Panics** if `axis` is out of bounds.
306    #[track_caller]
307    pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
308    where
309        A: Clone + One + Mul<Output = A>,
310        D: RemoveAxis,
311    {
312        let min_stride_axis = self._dim().min_stride_axis(self._strides());
313        if axis == min_stride_axis {
314            crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
315        } else {
316            let mut res = Array::ones(self.raw_dim().remove_axis(axis));
317            for subview in self.axis_iter(axis) {
318                res = res * &subview;
319            }
320            res
321        }
322    }
323
324    /// Return mean along `axis`.
325    ///
326    /// Return `None` if the length of the axis is zero.
327    ///
328    /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
329    /// fails for the axis length.
330    ///
331    /// ```
332    /// use ndarray::{aview0, aview1, arr2, Axis};
333    ///
334    /// let a = arr2(&[[1., 2., 3.],
335    ///                [4., 5., 6.]]);
336    /// assert!(
337    ///     a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
338    ///     a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
339    ///
340    ///     a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
341    /// );
342    /// ```
343    #[track_caller]
344    pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
345    where
346        A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
347        D: RemoveAxis,
348    {
349        let axis_length = self.len_of(axis);
350        if axis_length == 0 {
351            None
352        } else {
353            let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
354            let sum = self.sum_axis(axis);
355            Some(sum / aview0(&axis_length))
356        }
357    }
358
359    /// Return variance along `axis`.
360    ///
361    /// The variance is computed using the [Welford one-pass
362    /// algorithm](https://www.jstor.org/stable/1266577).
363    ///
364    /// The parameter `ddof` specifies the "delta degrees of freedom". For
365    /// example, to calculate the population variance, use `ddof = 0`, or to
366    /// calculate the sample variance, use `ddof = 1`.
367    ///
368    /// The variance is defined as:
369    ///
370    /// ```text
371    ///               1       n
372    /// variance = ――――――――   ∑ (xᵢ - x̅)²
373    ///            n - ddof  i=1
374    /// ```
375    ///
376    /// where
377    ///
378    /// ```text
379    ///     1   n
380    /// x̅ = ―   ∑ xᵢ
381    ///     n  i=1
382    /// ```
383    ///
384    /// and `n` is the length of the axis.
385    ///
386    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
387    /// is out of bounds, or if `A::from_usize()` fails for any any of the
388    /// numbers in the range `0..=n`.
389    ///
390    /// # Example
391    ///
392    /// ```
393    /// use ndarray::{aview1, arr2, Axis};
394    ///
395    /// let a = arr2(&[[1., 2.],
396    ///                [3., 4.],
397    ///                [5., 6.]]);
398    /// let var = a.var_axis(Axis(0), 1.);
399    /// assert_eq!(var, aview1(&[4., 4.]));
400    /// ```
401    #[track_caller]
402    #[cfg(feature = "std")]
403    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
404    pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
405    where
406        A: Float + FromPrimitive,
407        D: RemoveAxis,
408    {
409        let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
410        let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
411        assert!(
412            !(ddof < zero || ddof > n),
413            "`ddof` must not be less than zero or greater than the length of \
414             the axis",
415        );
416        let dof = n - ddof;
417        let mut mean = Array::<A, _>::zeros(self._dim().remove_axis(axis));
418        let mut sum_sq = Array::<A, _>::zeros(self._dim().remove_axis(axis));
419        for (i, subview) in self.axis_iter(axis).enumerate() {
420            let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
421            azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
422                let delta = x - *mean;
423                *mean = *mean + delta / count;
424                *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
425            });
426        }
427        sum_sq.mapv_into(|s| s / dof)
428    }
429
430    /// Return standard deviation along `axis`.
431    ///
432    /// The standard deviation is computed from the variance using
433    /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
434    ///
435    /// The parameter `ddof` specifies the "delta degrees of freedom". For
436    /// example, to calculate the population standard deviation, use `ddof = 0`,
437    /// or to calculate the sample standard deviation, use `ddof = 1`.
438    ///
439    /// The standard deviation is defined as:
440    ///
441    /// ```text
442    ///               ⎛    1       n          ⎞
443    /// stddev = sqrt ⎜ ――――――――   ∑ (xᵢ - x̅)²⎟
444    ///               ⎝ n - ddof  i=1         ⎠
445    /// ```
446    ///
447    /// where
448    ///
449    /// ```text
450    ///     1   n
451    /// x̅ = ―   ∑ xᵢ
452    ///     n  i=1
453    /// ```
454    ///
455    /// and `n` is the length of the axis.
456    ///
457    /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
458    /// is out of bounds, or if `A::from_usize()` fails for any any of the
459    /// numbers in the range `0..=n`.
460    ///
461    /// # Example
462    ///
463    /// ```
464    /// use ndarray::{aview1, arr2, Axis};
465    ///
466    /// let a = arr2(&[[1., 2.],
467    ///                [3., 4.],
468    ///                [5., 6.]]);
469    /// let stddev = a.std_axis(Axis(0), 1.);
470    /// assert_eq!(stddev, aview1(&[2., 2.]));
471    /// ```
472    #[track_caller]
473    #[cfg(feature = "std")]
474    #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
475    pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
476    where
477        A: Float + FromPrimitive,
478        D: RemoveAxis,
479    {
480        self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
481    }
482
483    /// Calculates the (forward) finite differences of order `n`, along the `axis`.
484    /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
485    ///
486    /// For `n>=2`, the process is iterated:
487    /// ```
488    /// use ndarray::{array, Axis};
489    /// let arr = array![1.0, 2.0, 5.0];
490    /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
491    /// ```
492    /// **Panics** if `axis` is out of bounds
493    ///
494    /// **Panics** if `n` is too big / the array is to short:
495    /// ```should_panic
496    /// use ndarray::{array, Axis};
497    /// array![1.0, 2.0, 3.0].diff(10, Axis(0));
498    /// ```
499    pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
500    where A: Sub<A, Output = A> + Zero + Clone
501    {
502        if n == 0 {
503            return self.to_owned();
504        }
505        assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
506        assert!(
507            n < self.shape()[axis.0],
508            "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
509            n + 1,
510            self.shape()[axis.0]
511        );
512
513        let mut inp = self.to_owned();
514        let mut out = Array::zeros({
515            let mut inp_dim = self.raw_dim();
516            // inp_dim[axis.0] >= 1 as per the 2nd assertion.
517            inp_dim[axis.0] -= 1;
518            inp_dim
519        });
520        for _ in 0..n {
521            let head = inp.slice_axis(axis, Slice::from(..-1));
522            let tail = inp.slice_axis(axis, Slice::from(1..));
523
524            azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());
525
526            // feed the output as the input to the next iteration
527            std::mem::swap(&mut inp, &mut out);
528
529            // adjust the new output array width along `axis`.
530            // Current situation: width of `inp`: k, `out`: k+1
531            // needed width:               `inp`: k, `out`: k-1
532            // slice is possible, since k >= 1.
533            out.slice_axis_inplace(axis, Slice::from(..-2));
534        }
535        inp
536    }
537}