ndarray/numeric/impl_numeric.rs
1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[cfg(feature = "std")]
10use num_traits::Float;
11use num_traits::One;
12use num_traits::{FromPrimitive, Zero};
13use std::ops::{Add, Div, Mul};
14
15use crate::imp_prelude::*;
16use crate::numeric_util;
17
18/// # Numerical Methods for Arrays
19impl<A, S, D> ArrayBase<S, D>
20where
21 S: Data<Elem = A>,
22 D: Dimension,
23{
24 /// Return the sum of all elements in the array.
25 ///
26 /// ```
27 /// use ndarray::arr2;
28 ///
29 /// let a = arr2(&[[1., 2.],
30 /// [3., 4.]]);
31 /// assert_eq!(a.sum(), 10.);
32 /// ```
33 pub fn sum(&self) -> A
34 where A: Clone + Add<Output = A> + num_traits::Zero
35 {
36 if let Some(slc) = self.as_slice_memory_order() {
37 return numeric_util::unrolled_fold(slc, A::zero, A::add);
38 }
39 let mut sum = A::zero();
40 for row in self.rows() {
41 if let Some(slc) = row.as_slice() {
42 sum = sum + numeric_util::unrolled_fold(slc, A::zero, A::add);
43 } else {
44 sum = sum + row.iter().fold(A::zero(), |acc, elt| acc + elt.clone());
45 }
46 }
47 sum
48 }
49
50 /// Returns the [arithmetic mean] x̅ of all elements in the array:
51 ///
52 /// ```text
53 /// 1 n
54 /// x̅ = ― ∑ xᵢ
55 /// n i=1
56 /// ```
57 ///
58 /// If the array is empty, `None` is returned.
59 ///
60 /// **Panics** if `A::from_usize()` fails to convert the number of elements in the array.
61 ///
62 /// [arithmetic mean]: https://en.wikipedia.org/wiki/Arithmetic_mean
63 pub fn mean(&self) -> Option<A>
64 where A: Clone + FromPrimitive + Add<Output = A> + Div<Output = A> + Zero
65 {
66 let n_elements = self.len();
67 if n_elements == 0 {
68 None
69 } else {
70 let n_elements = A::from_usize(n_elements).expect("Converting number of elements to `A` must not fail.");
71 Some(self.sum() / n_elements)
72 }
73 }
74
75 /// Return the product of all elements in the array.
76 ///
77 /// ```
78 /// use ndarray::arr2;
79 ///
80 /// let a = arr2(&[[1., 2.],
81 /// [3., 4.]]);
82 /// assert_eq!(a.product(), 24.);
83 /// ```
84 pub fn product(&self) -> A
85 where A: Clone + Mul<Output = A> + num_traits::One
86 {
87 if let Some(slc) = self.as_slice_memory_order() {
88 return numeric_util::unrolled_fold(slc, A::one, A::mul);
89 }
90 let mut sum = A::one();
91 for row in self.rows() {
92 if let Some(slc) = row.as_slice() {
93 sum = sum * numeric_util::unrolled_fold(slc, A::one, A::mul);
94 } else {
95 sum = sum * row.iter().fold(A::one(), |acc, elt| acc * elt.clone());
96 }
97 }
98 sum
99 }
100
101 /// Return variance of elements in the array.
102 ///
103 /// The variance is computed using the [Welford one-pass
104 /// algorithm](https://www.jstor.org/stable/1266577).
105 ///
106 /// The parameter `ddof` specifies the "delta degrees of freedom". For
107 /// example, to calculate the population variance, use `ddof = 0`, or to
108 /// calculate the sample variance, use `ddof = 1`.
109 ///
110 /// The variance is defined as:
111 ///
112 /// ```text
113 /// 1 n
114 /// variance = ―――――――― ∑ (xᵢ - x̅)²
115 /// n - ddof i=1
116 /// ```
117 ///
118 /// where
119 ///
120 /// ```text
121 /// 1 n
122 /// x̅ = ― ∑ xᵢ
123 /// n i=1
124 /// ```
125 ///
126 /// and `n` is the length of the array.
127 ///
128 /// **Panics** if `ddof` is less than zero or greater than `n`
129 ///
130 /// # Example
131 ///
132 /// ```
133 /// use ndarray::array;
134 /// use approx::assert_abs_diff_eq;
135 ///
136 /// let a = array![1., -4.32, 1.14, 0.32];
137 /// let var = a.var(1.);
138 /// assert_abs_diff_eq!(var, 6.7331, epsilon = 1e-4);
139 /// ```
140 #[track_caller]
141 #[cfg(feature = "std")]
142 pub fn var(&self, ddof: A) -> A
143 where A: Float + FromPrimitive
144 {
145 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
146 let n = A::from_usize(self.len()).expect("Converting length to `A` must not fail.");
147 assert!(
148 !(ddof < zero || ddof > n),
149 "`ddof` must not be less than zero or greater than the length of \
150 the axis",
151 );
152 let dof = n - ddof;
153 let mut mean = A::zero();
154 let mut sum_sq = A::zero();
155 let mut i = 0;
156 self.for_each(|&x| {
157 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
158 let delta = x - mean;
159 mean = mean + delta / count;
160 sum_sq = (x - mean).mul_add(delta, sum_sq);
161 i += 1;
162 });
163 sum_sq / dof
164 }
165
166 /// Return standard deviation of elements in the array.
167 ///
168 /// The standard deviation is computed from the variance using
169 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
170 ///
171 /// The parameter `ddof` specifies the "delta degrees of freedom". For
172 /// example, to calculate the population standard deviation, use `ddof = 0`,
173 /// or to calculate the sample standard deviation, use `ddof = 1`.
174 ///
175 /// The standard deviation is defined as:
176 ///
177 /// ```text
178 /// ⎛ 1 n ⎞
179 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
180 /// ⎝ n - ddof i=1 ⎠
181 /// ```
182 ///
183 /// where
184 ///
185 /// ```text
186 /// 1 n
187 /// x̅ = ― ∑ xᵢ
188 /// n i=1
189 /// ```
190 ///
191 /// and `n` is the length of the array.
192 ///
193 /// **Panics** if `ddof` is less than zero or greater than `n`
194 ///
195 /// # Example
196 ///
197 /// ```
198 /// use ndarray::array;
199 /// use approx::assert_abs_diff_eq;
200 ///
201 /// let a = array![1., -4.32, 1.14, 0.32];
202 /// let stddev = a.std(1.);
203 /// assert_abs_diff_eq!(stddev, 2.59483, epsilon = 1e-4);
204 /// ```
205 #[track_caller]
206 #[cfg(feature = "std")]
207 pub fn std(&self, ddof: A) -> A
208 where A: Float + FromPrimitive
209 {
210 self.var(ddof).sqrt()
211 }
212
213 /// Return sum along `axis`.
214 ///
215 /// ```
216 /// use ndarray::{aview0, aview1, arr2, Axis};
217 ///
218 /// let a = arr2(&[[1., 2., 3.],
219 /// [4., 5., 6.]]);
220 /// assert!(
221 /// a.sum_axis(Axis(0)) == aview1(&[5., 7., 9.]) &&
222 /// a.sum_axis(Axis(1)) == aview1(&[6., 15.]) &&
223 ///
224 /// a.sum_axis(Axis(0)).sum_axis(Axis(0)) == aview0(&21.)
225 /// );
226 /// ```
227 ///
228 /// **Panics** if `axis` is out of bounds.
229 #[track_caller]
230 pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
231 where
232 A: Clone + Zero + Add<Output = A>,
233 D: RemoveAxis,
234 {
235 let min_stride_axis = self.dim.min_stride_axis(&self.strides);
236 if axis == min_stride_axis {
237 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.sum())
238 } else {
239 let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
240 for subview in self.axis_iter(axis) {
241 res = res + &subview;
242 }
243 res
244 }
245 }
246
247 /// Return product along `axis`.
248 ///
249 /// The product of an empty array is 1.
250 ///
251 /// ```
252 /// use ndarray::{aview0, aview1, arr2, Axis};
253 ///
254 /// let a = arr2(&[[1., 2., 3.],
255 /// [4., 5., 6.]]);
256 ///
257 /// assert!(
258 /// a.product_axis(Axis(0)) == aview1(&[4., 10., 18.]) &&
259 /// a.product_axis(Axis(1)) == aview1(&[6., 120.]) &&
260 ///
261 /// a.product_axis(Axis(0)).product_axis(Axis(0)) == aview0(&720.)
262 /// );
263 /// ```
264 ///
265 /// **Panics** if `axis` is out of bounds.
266 #[track_caller]
267 pub fn product_axis(&self, axis: Axis) -> Array<A, D::Smaller>
268 where
269 A: Clone + One + Mul<Output = A>,
270 D: RemoveAxis,
271 {
272 let min_stride_axis = self.dim.min_stride_axis(&self.strides);
273 if axis == min_stride_axis {
274 crate::Zip::from(self.lanes(axis)).map_collect(|lane| lane.product())
275 } else {
276 let mut res = Array::ones(self.raw_dim().remove_axis(axis));
277 for subview in self.axis_iter(axis) {
278 res = res * &subview;
279 }
280 res
281 }
282 }
283
284 /// Return mean along `axis`.
285 ///
286 /// Return `None` if the length of the axis is zero.
287 ///
288 /// **Panics** if `axis` is out of bounds or if `A::from_usize()`
289 /// fails for the axis length.
290 ///
291 /// ```
292 /// use ndarray::{aview0, aview1, arr2, Axis};
293 ///
294 /// let a = arr2(&[[1., 2., 3.],
295 /// [4., 5., 6.]]);
296 /// assert!(
297 /// a.mean_axis(Axis(0)).unwrap() == aview1(&[2.5, 3.5, 4.5]) &&
298 /// a.mean_axis(Axis(1)).unwrap() == aview1(&[2., 5.]) &&
299 ///
300 /// a.mean_axis(Axis(0)).unwrap().mean_axis(Axis(0)).unwrap() == aview0(&3.5)
301 /// );
302 /// ```
303 #[track_caller]
304 pub fn mean_axis(&self, axis: Axis) -> Option<Array<A, D::Smaller>>
305 where
306 A: Clone + Zero + FromPrimitive + Add<Output = A> + Div<Output = A>,
307 D: RemoveAxis,
308 {
309 let axis_length = self.len_of(axis);
310 if axis_length == 0 {
311 None
312 } else {
313 let axis_length = A::from_usize(axis_length).expect("Converting axis length to `A` must not fail.");
314 let sum = self.sum_axis(axis);
315 Some(sum / aview0(&axis_length))
316 }
317 }
318
319 /// Return variance along `axis`.
320 ///
321 /// The variance is computed using the [Welford one-pass
322 /// algorithm](https://www.jstor.org/stable/1266577).
323 ///
324 /// The parameter `ddof` specifies the "delta degrees of freedom". For
325 /// example, to calculate the population variance, use `ddof = 0`, or to
326 /// calculate the sample variance, use `ddof = 1`.
327 ///
328 /// The variance is defined as:
329 ///
330 /// ```text
331 /// 1 n
332 /// variance = ―――――――― ∑ (xᵢ - x̅)²
333 /// n - ddof i=1
334 /// ```
335 ///
336 /// where
337 ///
338 /// ```text
339 /// 1 n
340 /// x̅ = ― ∑ xᵢ
341 /// n i=1
342 /// ```
343 ///
344 /// and `n` is the length of the axis.
345 ///
346 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
347 /// is out of bounds, or if `A::from_usize()` fails for any any of the
348 /// numbers in the range `0..=n`.
349 ///
350 /// # Example
351 ///
352 /// ```
353 /// use ndarray::{aview1, arr2, Axis};
354 ///
355 /// let a = arr2(&[[1., 2.],
356 /// [3., 4.],
357 /// [5., 6.]]);
358 /// let var = a.var_axis(Axis(0), 1.);
359 /// assert_eq!(var, aview1(&[4., 4.]));
360 /// ```
361 #[track_caller]
362 #[cfg(feature = "std")]
363 pub fn var_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
364 where
365 A: Float + FromPrimitive,
366 D: RemoveAxis,
367 {
368 let zero = A::from_usize(0).expect("Converting 0 to `A` must not fail.");
369 let n = A::from_usize(self.len_of(axis)).expect("Converting length to `A` must not fail.");
370 assert!(
371 !(ddof < zero || ddof > n),
372 "`ddof` must not be less than zero or greater than the length of \
373 the axis",
374 );
375 let dof = n - ddof;
376 let mut mean = Array::<A, _>::zeros(self.dim.remove_axis(axis));
377 let mut sum_sq = Array::<A, _>::zeros(self.dim.remove_axis(axis));
378 for (i, subview) in self.axis_iter(axis).enumerate() {
379 let count = A::from_usize(i + 1).expect("Converting index to `A` must not fail.");
380 azip!((mean in &mut mean, sum_sq in &mut sum_sq, &x in &subview) {
381 let delta = x - *mean;
382 *mean = *mean + delta / count;
383 *sum_sq = (x - *mean).mul_add(delta, *sum_sq);
384 });
385 }
386 sum_sq.mapv_into(|s| s / dof)
387 }
388
389 /// Return standard deviation along `axis`.
390 ///
391 /// The standard deviation is computed from the variance using
392 /// the [Welford one-pass algorithm](https://www.jstor.org/stable/1266577).
393 ///
394 /// The parameter `ddof` specifies the "delta degrees of freedom". For
395 /// example, to calculate the population standard deviation, use `ddof = 0`,
396 /// or to calculate the sample standard deviation, use `ddof = 1`.
397 ///
398 /// The standard deviation is defined as:
399 ///
400 /// ```text
401 /// ⎛ 1 n ⎞
402 /// stddev = sqrt ⎜ ―――――――― ∑ (xᵢ - x̅)²⎟
403 /// ⎝ n - ddof i=1 ⎠
404 /// ```
405 ///
406 /// where
407 ///
408 /// ```text
409 /// 1 n
410 /// x̅ = ― ∑ xᵢ
411 /// n i=1
412 /// ```
413 ///
414 /// and `n` is the length of the axis.
415 ///
416 /// **Panics** if `ddof` is less than zero or greater than `n`, if `axis`
417 /// is out of bounds, or if `A::from_usize()` fails for any any of the
418 /// numbers in the range `0..=n`.
419 ///
420 /// # Example
421 ///
422 /// ```
423 /// use ndarray::{aview1, arr2, Axis};
424 ///
425 /// let a = arr2(&[[1., 2.],
426 /// [3., 4.],
427 /// [5., 6.]]);
428 /// let stddev = a.std_axis(Axis(0), 1.);
429 /// assert_eq!(stddev, aview1(&[2., 2.]));
430 /// ```
431 #[track_caller]
432 #[cfg(feature = "std")]
433 pub fn std_axis(&self, axis: Axis, ddof: A) -> Array<A, D::Smaller>
434 where
435 A: Float + FromPrimitive,
436 D: RemoveAxis,
437 {
438 self.var_axis(axis, ddof).mapv_into(|x| x.sqrt())
439 }
440}