Thanks to visit codestin.com
Credit goes to docs.rs

ndarray/
impl_ops.rs

1// Copyright 2014-2016 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13/// Elements that can be used as direct operands in arithmetic with arrays.
14///
15/// For example, `f64` is a `ScalarOperand` which means that for an array `a`,
16/// arithmetic like `a + 1.0`, and, `a * 2.`, and `a += 3.` are allowed.
17///
18/// In the description below, let `A` be an array or array view,
19/// let `B` be an array with owned data,
20/// and let `C` be an array with mutable data.
21///
22/// `ScalarOperand` determines for which scalars `K` operations `&A @ K`, and `B @ K`,
23/// and `C @= K` are defined, as ***right hand side operands***, for applicable
24/// arithmetic operators (denoted `@`).
25///
26/// ***Left hand side*** scalar operands are not related to this trait
27/// (they need one `impl` per concrete scalar type); but they are still
28/// implemented for the same types, allowing operations
29/// `K @ &A`, and `K @ B` for primitive numeric types `K`.
30///
31/// This trait ***does not*** limit which elements can be stored in an array in general.
32/// Non-`ScalarOperand` types can still participate in arithmetic as array elements in
33/// in array-array operations.
34pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
macro_rules! impl_binary_op(
    // Generates the array-array and array-scalar impls of one binary operator.
    //
    // `$trt`/`$mth` name the `std::ops` trait and its method, `$operator` is the
    // operator token used by the scalar impls, and `$doc` names the operation in
    // the generated docs. `$iop` (the compound-assignment token) is accepted so
    // call sites stay uniform, but it is not used by this macro.
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, `self` is broadcast to their broadcast shape.
///
/// **Panics** if broadcasting isn’t possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        // Delegate to the `self @ &rhs` impl; `rhs` is only read.
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `self` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, both operands are broadcast to their broadcast shape;
/// `self` is reused for the result when its shape already equals that shape,
/// otherwise the data is cloned into a new array.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: update `self` in place, no broadcasting required.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // Only `rhs` had to be broadcast, so `self` can hold the result.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // `self` is too small for the result: collect into a new array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must be an `Array` or `ArcArray`.
///
/// If their shapes disagree, both operands are broadcast to their broadcast shape;
/// `rhs` is reused for the result when its shape already equals that shape,
/// otherwise the data is cloned into a new array.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: write `self @ rhs` into `rhs` in place
            // (operand order is preserved by `clone_iopf_rev`).
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // Only `self` had to be broadcast, so `rhs` can hold the result.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                // `rhs` is too small for the result: collect into a new array.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to their broadcast shape,
/// cloning the data if needed.
///
/// **Panics** if broadcasting isn’t possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: only the dimension type has to be converted.
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must be an `Array` or `ArcArray`.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
235
// Choose between two expressions by a commutativity tag:
// `Commute { x } or { y }` expands to `x`, while `Ordered { x } or { y }`
// expands to `y`. Only the selected expression is emitted.
macro_rules! if_commutative {
    (Ordered { $commuted:expr } or { $ordered:expr }) => ($ordered);
    (Commute { $commuted:expr } or { $ordered:expr }) => ($commuted);
}
245
macro_rules! impl_scalar_lhs_op {
    // Implements `scalar @ array` for one concrete scalar type.
    //
    // `$commutative` is `Commute` or `Ordered` (see `if_commutative!`). For a
    // `Commute` operator the `array @ scalar` impl is reused (`rhs @ self`),
    // which is only valid because the operator is commutative for `$scalar`.
    // `$doc` is accepted for uniformity but unused: these impls have no docs.
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
// these have no doc -- they are not visible in rustdoc
// Perform the elementwise operation
// between the scalar `self` and array `rhs`,
// and return the result (reusing the storage of `rhs`).
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
    where S: DataOwned<Elem=$scalar> + DataMut,
          D: Dimension,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            // Non-commutative: keep the scalar on the left of the operator.
            let mut rhs = rhs;
            rhs.map_inplace(move |elt| {
                *elt = self $operator *elt;
            });
            rhs
        }})
    }
}

// Perform the elementwise operation
// between the scalar `self` and array `rhs`,
// and return the result as a new `Array`.
impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
    where S: Data<Elem=$scalar>,
          D: Dimension,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            rhs.map(move |elt| self.clone() $operator elt.clone())
        })
    }
}
    );
}
290
291mod arithmetic_ops
292{
293    use super::*;
294    use crate::imp_prelude::*;
295
296    use std::ops::*;
297
298    fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C
299    {
300        move |x, y| f(x.clone(), y.clone())
301    }
302
303    fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B)
304    {
305        move |x, y| *x = f(x.clone(), y.clone())
306    }
307
308    fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A)
309    {
310        move |x, y| *x = f(y.clone(), x.clone())
311    }
312
313    impl_binary_op!(Add, +, add, +=, "addition");
314    impl_binary_op!(Sub, -, sub, -=, "subtraction");
315    impl_binary_op!(Mul, *, mul, *=, "multiplication");
316    impl_binary_op!(Div, /, div, /=, "division");
317    impl_binary_op!(Rem, %, rem, %=, "remainder");
318    impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
319    impl_binary_op!(BitOr, |, bitor, |=, "bit or");
320    impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
321    impl_binary_op!(Shl, <<, shl, <<=, "left shift");
322    impl_binary_op!(Shr, >>, shr, >>=, "right shift");
323
324    macro_rules! all_scalar_ops {
325        ($int_scalar:ty) => (
326            impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
327            impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
328            impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
329            impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
330            impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
331            impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
332            impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
333            impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
334            impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
335            impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
336        );
337    }
338    all_scalar_ops!(i8);
339    all_scalar_ops!(u8);
340    all_scalar_ops!(i16);
341    all_scalar_ops!(u16);
342    all_scalar_ops!(i32);
343    all_scalar_ops!(u32);
344    all_scalar_ops!(i64);
345    all_scalar_ops!(u64);
346    all_scalar_ops!(isize);
347    all_scalar_ops!(usize);
348    all_scalar_ops!(i128);
349    all_scalar_ops!(u128);
350
351    impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
352    impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
353    impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
354
355    impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
356    impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
357    impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
358    impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
359    impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
360
361    impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
362    impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
363    impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
364    impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
365    impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
366
367    impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
368    impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
369    impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
370    impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
371
372    impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
373    impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
374    impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
375    impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
376
377    impl<A, S, D> Neg for ArrayBase<S, D>
378    where
379        A: Clone + Neg<Output = A>,
380        S: DataOwned<Elem = A> + DataMut,
381        D: Dimension,
382    {
383        type Output = Self;
384        /// Perform an elementwise negation of `self` and return the result.
385        fn neg(mut self) -> Self
386        {
387            self.map_inplace(|elt| {
388                *elt = -elt.clone();
389            });
390            self
391        }
392    }
393
394    impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
395    where
396        &'a A: 'a + Neg<Output = A>,
397        S: Data<Elem = A>,
398        D: Dimension,
399    {
400        type Output = Array<A, D>;
401        /// Perform an elementwise negation of reference `self` and return the
402        /// result as a new `Array`.
403        fn neg(self) -> Array<A, D>
404        {
405            self.map(Neg::neg)
406        }
407    }
408
409    impl<A, S, D> Not for ArrayBase<S, D>
410    where
411        A: Clone + Not<Output = A>,
412        S: DataOwned<Elem = A> + DataMut,
413        D: Dimension,
414    {
415        type Output = Self;
416        /// Perform an elementwise unary not of `self` and return the result.
417        fn not(mut self) -> Self
418        {
419            self.map_inplace(|elt| {
420                *elt = !elt.clone();
421            });
422            self
423        }
424    }
425
426    impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
427    where
428        &'a A: 'a + Not<Output = A>,
429        S: Data<Elem = A>,
430        D: Dimension,
431    {
432        type Output = Array<A, D>;
433        /// Perform an elementwise unary not of reference `self` and return the
434        /// result as a new `Array`.
435        fn not(self) -> Array<A, D>
436        {
437            self.map(Not::not)
438        }
439    }
440}
441
442mod assign_ops
443{
444    use super::*;
445    use crate::imp_prelude::*;
446
447    macro_rules! impl_assign_op {
448        ($trt:ident, $method:ident, $doc:expr) => {
449            use std::ops::$trt;
450
451            #[doc=$doc]
452            /// If their shapes disagree, `rhs` is broadcast to the shape of `self`.
453            ///
454            /// **Panics** if broadcasting isn’t possible.
455            impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
456            where
457                A: Clone + $trt<A>,
458                S: DataMut<Elem = A>,
459                S2: Data<Elem = A>,
460                D: Dimension,
461                E: Dimension,
462            {
463                #[track_caller]
464                fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
465                    self.zip_mut_with(rhs, |x, y| {
466                        x.$method(y.clone());
467                    });
468                }
469            }
470
471            #[doc=$doc]
472            impl<A, S, D> $trt<A> for ArrayBase<S, D>
473            where
474                A: ScalarOperand + $trt<A>,
475                S: DataMut<Elem = A>,
476                D: Dimension,
477            {
478                fn $method(&mut self, rhs: A) {
479                    self.map_inplace(move |elt| {
480                        elt.$method(rhs.clone());
481                    });
482                }
483            }
484        };
485    }
486
487    impl_assign_op!(
488        AddAssign,
489        add_assign,
490        "Perform `self += rhs` as elementwise addition (in place).\n"
491    );
492    impl_assign_op!(
493        SubAssign,
494        sub_assign,
495        "Perform `self -= rhs` as elementwise subtraction (in place).\n"
496    );
497    impl_assign_op!(
498        MulAssign,
499        mul_assign,
500        "Perform `self *= rhs` as elementwise multiplication (in place).\n"
501    );
502    impl_assign_op!(
503        DivAssign,
504        div_assign,
505        "Perform `self /= rhs` as elementwise division (in place).\n"
506    );
507    impl_assign_op!(
508        RemAssign,
509        rem_assign,
510        "Perform `self %= rhs` as elementwise remainder (in place).\n"
511    );
512    impl_assign_op!(
513        BitAndAssign,
514        bitand_assign,
515        "Perform `self &= rhs` as elementwise bit and (in place).\n"
516    );
517    impl_assign_op!(
518        BitOrAssign,
519        bitor_assign,
520        "Perform `self |= rhs` as elementwise bit or (in place).\n"
521    );
522    impl_assign_op!(
523        BitXorAssign,
524        bitxor_assign,
525        "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
526    );
527    impl_assign_op!(
528        ShlAssign,
529        shl_assign,
530        "Perform `self <<= rhs` as elementwise left shift (in place).\n"
531    );
532    impl_assign_op!(
533        ShrAssign,
534        shr_assign,
535        "Perform `self >>= rhs` as elementwise right shift (in place).\n"
536    );
537}