ndarray/numeric/impl_numeric.rs
1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul, MulAssign, Sub};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17use crate::Slice;
18
19/// # Numerical Methods for Arrays
20impl<A, D> ArrayRef<A, D>
21where D: Dimension
22{
23 /// Return the sum of all elements in the array.
24 ///
25 /// ```
26 /// use ndarray::arr2;
27 ///
28 /// let a = arr2(&[[1., 2.],
29 /// [3., 4.]]);
30 /// assert_eq!(a.sum(), 10.);
31 /// ```
32 pub fn sum(&self) -> A
33 where A: Clone + Add<Output = A> + num_traits::Zero
34 {
35 if let Some(slc) = self.as_slice_memory_order() {
36 return numeric_util::unrolled_fold(slc, A::zero, A::add);
37 }
38 let mut sum = A::zero();
39 for row in self.rows() {
40 if let Some(slc) = row.as_slice() {
41 sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
42 } else {
43 sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
44 }
45 }
46 sum
47 }
48
49 /// Returns the [arithmetic mean] x̅ of all elements in the array:
50 ///
51 /// ```text
52 /// 1 n
53 /// x̅ = ― ∑ xᵢ
54 /// n i=1
55 /// ```
56 ///
57 /// If the array is empty, `None` is returned.
58 ///
59 /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
60 ///
61 /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
62 pub fn mean(&self) -> Option<A>
63 where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
64 {
65 let n_elements = self.len();
66 if n_elements == 0 {
67 None
68 } else {
69 let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
70 Some(self.sum() / n_elements)
71 }
72 }
73
74 /// Return the product of all elements in the array.
75 ///
76 /// ```
77 /// use ndarray::arr2;
78 ///
79 /// let a = arr2(&[[1., 2.],
80 /// [3., 4.]]);
81 /// assert_eq!(a.product(), 24.);
82 /// ```
83 pub fn product(&self) -> A
84 where A: Clone + Mul<Output = A> + num_traits::One
85 {
86 if let Some(slc) = self.as_slice_memory_order() {
87 return numeric_util::unrolled_fold(slc, A::one, A::mul);
88 }
89 let mut sum = A::one();
90 for row in self.rows() {
91 if let Some(slc) = row.as_slice() {
92 sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
93 } else {
94 sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
95 }
96 }
97 sum
98 }
99
100 /// Return the cumulative product of elements along a given axis.
101 ///
102 /// ```
103 /// use ndarray::{arr2, Axis};
104 ///
105 /// let a = arr2(&[[1., 2., 3.],
106 /// [4., 5., 6.]]);
107 ///
108 /// // Cumulative product along rows (axis 0)
109 /// assert_eq!(
110 /// a.cumprod(Axis(0)),
111 /// arr2(&[[1., 2., 3.],
112 /// [4., 10., 18.]])
113 /// );
114 ///
115 /// // Cumulative product along columns (axis 1)
116 /// assert_eq!(
117 /// a.cumprod(Axis(1)),
118 /// arr2(&[[1., 2., 6.],
119 /// [4., 20., 120.]])
120 /// );
121 /// ```
122 ///
123 /// **Panics** if `axis` is out of bounds.
124 #[track_caller]
125 pub fn cumprod(&self, axis: Axis) -> Array<A, D>
126 where
127 A: Clone + Mul<Output = A> + MulAssign,
128 D: Dimension + RemoveAxis,
129 {
130 if axis.0 >= self.ndim() {
131 panic!("axis is out of bounds for array of dimension");
132 }
133
134 let mut result = self.to_owned();
135 result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone());
136 result
137 }
138
139 /// Return variance of elements in the array.
140 ///
141 /// The variance is computed using the [Welford one-pass
142 /// algorithm](https://www.jstor.org/stable/1266577).
143 ///
144 /// The parameter `ddof` specifies the "delta degrees of freedom". For
145 /// example, to calculate the population variance, use `ddof = 0`, or to
146 /// calculate the sample variance, use `ddof = 1`.
147 ///
148 /// The variance is defined as:
149 ///
150 /// ```text
151 /// 1 n
152 /// variance = ―――――――― ∑ (xᵢ - x̅)²
153 /// n - ddof i=1
154 /// ```
155 ///
156 /// where
157 ///
158 /// ```text
159 /// 1 n
160 /// x̅ = ― ∑ xᵢ
161 /// n i=1
162 /// ```
163 ///
164 /// and `n` is the length of the array.
165 ///
166 /// **Panics** if `ddof` is less than zero or greater than `n`
167 ///
168 /// # Example
169 ///
170 /// ```
171 /// use ndarray::array;
172 /// use approx::assert_abs_diff_eq;
173 ///
174 /// let a = array![1., -4.32, 1.14, 0.32];
175 /// let var = a.var(1.);
176 /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
177 /// ```
178 #[track_caller]
179 #[cfg(feature = "std")]
180 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
181 pub fn var(&self, ddof: A) -> A
182 where A: Float + FromPrimitive
183 {
184 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
185 let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
186 assert!(
187 !(ddof < zero || ddof > n),
188 "`ddof` must not be less than zero or greater than the length of \
189 the axis",
190 );
191 let dof = n - ddof;
192 let mut mean = A::zero();
193 let mut sum_sq = A::zero();
194 let mut i = 0;
195 self.for_each(|&x| {
196 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
197 let delta = x - mean;
198 mean = mean + delta / count;
199 sum_sq = (x - mean).mul_add(delta, sum_sq);
200 i += 1;
201 });
202 sum_sq / dof
203 }
204
205 /// Return standard deviation of elements in the array.
206 ///
207 /// The standard deviation is computed from the variance using
208 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
209 ///
210 /// The parameter `ddof` specifies the "delta degrees of freedom". For
211 /// example, to calculate the population standard deviation, use `ddof = 0`,
212 /// or to calculate the sample standard deviation, use `ddof = 1`.
213 ///
214 /// The standard deviation is defined as:
215 ///
216 /// ```text
217 /// ⎛ 1 n ⎞
218 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
219 /// ⎝ n - ddof i=1 ⎠
220 /// ```
221 ///
222 /// where
223 ///
224 /// ```text
225 /// 1 n
226 /// x̅ = ― ∑ xᵢ
227 /// n i=1
228 /// ```
229 ///
230 /// and `n` is the length of the array.
231 ///
232 /// **Panics** if `ddof` is less than zero or greater than `n`
233 ///
234 /// # Example
235 ///
236 /// ```
237 /// use ndarray::array;
238 /// use approx::assert_abs_diff_eq;
239 ///
240 /// let a = array![1., -4.32, 1.14, 0.32];
241 /// let stddev = a.std(1.);
242 /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
243 /// ```
244 #[track_caller]
245 #[cfg(feature = "std")]
246 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
247 pub fn std(&self, ddof: A) -> A
248 where A: Float + FromPrimitive
249 {
250 self.var(ddof).sqrt()
251 }
252
253 /// Return sum along `axis`.
254 ///
255 /// ```
256 /// use ndarray::{aview0, aview1, arr2, Axis};
257 ///
258 /// let a = arr2(&[[1., 2., 3.],
259 /// [4., 5., 6.]]);
260 /// assert!(
261 /// a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
262 /// a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
263 ///
264 /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
265 /// );
266 /// ```
267 ///
268 /// **Panics** if `axis` is out of bounds.
269 #[track_caller]
270 pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
271 where
272 A: Clone + Zero + Add<Output = A>,
273 D: RemoveAxis,
274 {
275 let min_stride_axis = self._dim().min_stride_axis(self._strides());
276 if axis == min_stride_axis {
277 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
278 } else {
279 let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
280 for subview in self.axis_iter(axis) {
281 res = res + &subview;
282 }
283 res
284 }
285 }
286
287 /// Return product along `axis`.
288 ///
289 /// The product of an empty array is 1.
290 ///
291 /// ```
292 /// use ndarray::{aview0, aview1, arr2, Axis};
293 ///
294 /// let a = arr2(&[[1., 2., 3.],
295 /// [4., 5., 6.]]);
296 ///
297 /// assert!(
298 /// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
299 /// a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
300 ///
301 /// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
302 /// );
303 /// ```
304 ///
305 /// **Panics** if `axis` is out of bounds.
306 #[track_caller]
307 pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
308 where
309 A: Clone + One + Mul<Output = A>,
310 D: RemoveAxis,
311 {
312 let min_stride_axis = self._dim().min_stride_axis(self._strides());
313 if axis == min_stride_axis {
314 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
315 } else {
316 let mut res = Array::ones(self.raw_dim().remove_axis(axis));
317 for subview in self.axis_iter(axis) {
318 res = res * &subview;
319 }
320 res
321 }
322 }
323
324 /// Return mean along `axis`.
325 ///
326 /// Return `None` if the length of the axis is zero.
327 ///
328 /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
329 /// fails for the axis length.
330 ///
331 /// ```
332 /// use ndarray::{aview0, aview1, arr2, Axis};
333 ///
334 /// let a = arr2(&[[1., 2., 3.],
335 /// [4., 5., 6.]]);
336 /// assert!(
337 /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
338 /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
339 ///
340 /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
341 /// );
342 /// ```
343 #[track_caller]
344 pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
345 where
346 A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
347 D: RemoveAxis,
348 {
349 let axis_length = self.len_of(axis);
350 if axis_length == 0 {
351 None
352 } else {
353 let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
354 let sum = self.sum_axis(axis);
355 Some(sum / aview0(&axis_length))
356 }
357 }
358
359 /// Return variance along `axis`.
360 ///
361 /// The variance is computed using the [Welford one-pass
362 /// algorithm](https://www.jstor.org/stable/1266577).
363 ///
364 /// The parameter `ddof` specifies the "delta degrees of freedom". For
365 /// example, to calculate the population variance, use `ddof = 0`, or to
366 /// calculate the sample variance, use `ddof = 1`.
367 ///
368 /// The variance is defined as:
369 ///
370 /// ```text
371 /// 1 n
372 /// variance = ―――――――― ∑ (xᵢ - x̅)²
373 /// n - ddof i=1
374 /// ```
375 ///
376 /// where
377 ///
378 /// ```text
379 /// 1 n
380 /// x̅ = ― ∑ xᵢ
381 /// n i=1
382 /// ```
383 ///
384 /// and `n` is the length of the axis.
385 ///
386 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
387 /// is out of bounds, or if `A::from_usize()` fails for any any of the
388 /// numbers in the range `0..=n`.
389 ///
390 /// # Example
391 ///
392 /// ```
393 /// use ndarray::{aview1, arr2, Axis};
394 ///
395 /// let a = arr2(&[[1., 2.],
396 /// [3., 4.],
397 /// [5., 6.]]);
398 /// let var = a.var_axis(Axis(0), 1.);
399 /// assert_eq!(var, aview1(&[4., 4.]));
400 /// ```
401 #[track_caller]
402 #[cfg(feature = "std")]
403 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
404 pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
405 where
406 A: Float + FromPrimitive,
407 D: RemoveAxis,
408 {
409 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
410 let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
411 assert!(
412 !(ddof < zero || ddof > n),
413 "`ddof` must not be less than zero or greater than the length of \
414 the axis",
415 );
416 let dof = n - ddof;
417 let mut mean = Array::<A, _>::zeros(self._dim().remove_axis(axis));
418 let mut sum_sq = Array::<A, _>::zeros(self._dim().remove_axis(axis));
419 for (i, subview) in self.axis_iter(axis).enumerate() {
420 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
421 azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
422 let delta = x - *mean;
423 *mean = *mean + delta / count;
424 *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
425 });
426 }
427 sum_sq.mapv_into(|s| s / dof)
428 }
429
430 /// Return standard deviation along `axis`.
431 ///
432 /// The standard deviation is computed from the variance using
433 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
434 ///
435 /// The parameter `ddof` specifies the "delta degrees of freedom". For
436 /// example, to calculate the population standard deviation, use `ddof = 0`,
437 /// or to calculate the sample standard deviation, use `ddof = 1`.
438 ///
439 /// The standard deviation is defined as:
440 ///
441 /// ```text
442 /// ⎛ 1 n ⎞
443 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
444 /// ⎝ n - ddof i=1 ⎠
445 /// ```
446 ///
447 /// where
448 ///
449 /// ```text
450 /// 1 n
451 /// x̅ = ― ∑ xᵢ
452 /// n i=1
453 /// ```
454 ///
455 /// and `n` is the length of the axis.
456 ///
457 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
458 /// is out of bounds, or if `A::from_usize()` fails for any any of the
459 /// numbers in the range `0..=n`.
460 ///
461 /// # Example
462 ///
463 /// ```
464 /// use ndarray::{aview1, arr2, Axis};
465 ///
466 /// let a = arr2(&[[1., 2.],
467 /// [3., 4.],
468 /// [5., 6.]]);
469 /// let stddev = a.std_axis(Axis(0), 1.);
470 /// assert_eq!(stddev, aview1(&[2., 2.]));
471 /// ```
472 #[track_caller]
473 #[cfg(feature = "std")]
474 #[cfg_attr(docsrs, doc(cfg(feature = "std")))]
475 pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
476 where
477 A: Float + FromPrimitive,
478 D: RemoveAxis,
479 {
480 self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
481 }
482
483 /// Calculates the (forward) finite differences of order `n`, along the `axis`.
484 /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]`
485 ///
486 /// For `n>=2`, the process is iterated:
487 /// ```
488 /// use ndarray::{array, Axis};
489 /// let arr = array![1.0, 2.0, 5.0];
490 /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0)))
491 /// ```
492 /// **Panics** if `axis` is out of bounds
493 ///
494 /// **Panics** if `n` is too big / the array is to short:
495 /// ```should_panic
496 /// use ndarray::{array, Axis};
497 /// array![1.0, 2.0, 3.0].diff(10, Axis(0));
498 /// ```
499 pub fn diff(&self, n: usize, axis: Axis) -> Array<A, D>
500 where A: Sub<A, Output = A> + Zero + Clone
501 {
502 if n == 0 {
503 return self.to_owned();
504 }
505 assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis);
506 assert!(
507 n < self.shape()[axis.0],
508 "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}",
509 n + 1,
510 self.shape()[axis.0]
511 );
512
513 let mut inp = self.to_owned();
514 let mut out = Array::zeros({
515 let mut inp_dim = self.raw_dim();
516 // inp_dim[axis.0] >= 1 as per the 2nd assertion.
517 inp_dim[axis.0] -= 1;
518 inp_dim
519 });
520 for _ in 0..n {
521 let head = inp.slice_axis(axis, Slice::from(..-1));
522 let tail = inp.slice_axis(axis, Slice::from(1..));
523
524 azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone());
525
526 // feed the output as the input to the next iteration
527 std::mem::swap(&mut inp, &mut out);
528
529 // adjust the new output array width along `axis`.
530 // Current situation: width of `inp`: k, `out`: k+1
531 // needed width: `inp`: k, `out`: k-1
532 // slice is possible, since k >= 1.
533 out.slice_axis_inplace(axis, Slice::from(..-2));
534 }
535 inp
536 }
537}