1use crate::dimension::DimMax;
10use crate::Zip;
11use num_complex::Complex;
12
13pub trait ScalarOperand: 'static + Clone {}
35impl ScalarOperand for bool {}
36impl ScalarOperand for i8 {}
37impl ScalarOperand for u8 {}
38impl ScalarOperand for i16 {}
39impl ScalarOperand for u16 {}
40impl ScalarOperand for i32 {}
41impl ScalarOperand for u32 {}
42impl ScalarOperand for i64 {}
43impl ScalarOperand for u64 {}
44impl ScalarOperand for i128 {}
45impl ScalarOperand for u128 {}
46impl ScalarOperand for isize {}
47impl ScalarOperand for usize {}
48impl ScalarOperand for f32 {}
49impl ScalarOperand for f64 {}
50impl ScalarOperand for Complex<f32> {}
51impl ScalarOperand for Complex<f64> {}
52
// Generates, for one binary operator trait `$trt` with method `$mth`, the
// impls that let arrays combine with arrays and with scalars:
//
//   * owned  op owned   -> forwards to `owned op &rhs`
//   * owned  op &array  -> reuses the lhs storage when it already has the result shape
//   * &array op owned   -> reuses the rhs storage when it already has the result shape
//   * &array op &array  -> always allocates a new `Array`
//   * owned  op scalar  -> updates `self` in place
//   * &array op scalar  -> allocates a new `Array`
//
// `$operator` is the operator token used by the scalar impls, `$iop` is the
// compound-assignment token (accepted for the call sites but unused here), and
// `$doc` is the operation's name spliced into each impl's rustdoc.
macro_rules! impl_binary_op(
    ($trt:ident, $operator:tt, $mth:ident, $iop:tt, $doc:expr) => (
/// Perform elementwise
#[doc=$doc]
/// between `self` and `rhs`,
/// and return the result.
///
/// `self` must own its data (`S: DataOwned + DataMut`); the work is delegated
/// to the `self op &rhs` implementation.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        self.$mth(&rhs)
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and reference `rhs`,
/// and return the result.
///
/// `self` must own its data (`S: DataOwned + DataMut`). When `self` already has
/// the result shape its storage is reused and updated in place; otherwise a new
/// array is allocated.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: DataOwned<Elem=A> + DataMut,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = ArrayBase<S, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: no broadcasting needed, update `self` in place.
            let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            out.zip_mut_with_same_shape(rhs, clone_iopf(A::$mth));
            out
        } else {
            let (lhs_view, rhs_view) = self.broadcast_with(&rhs).unwrap();
            if lhs_view.shape() == self.shape() {
                // Broadcasting did not grow `self`, so its storage can hold the result.
                let mut out = self.into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&rhs_view, clone_iopf(A::$mth));
                out
            } else {
                // The broadcast shape differs from `self`'s: collect into fresh storage.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between reference `self` and `rhs`,
/// and return the result.
///
/// `rhs` must own its data (`S2: DataOwned + DataMut`). When `rhs` already has
/// the result shape its storage is reused and updated in place (with `self`
/// kept as the left operand); otherwise a new array is allocated.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, A, B, S, S2, D, E> $trt<ArrayBase<S2, E>> for &'a ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=B>,
    B: Clone,
    S: Data<Elem=A>,
    S2: DataOwned<Elem=B> + DataMut,
    D: Dimension,
    E: Dimension + DimMax<D>,
{
    type Output = ArrayBase<S2, <E as DimMax<D>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: ArrayBase<S2, E>) -> Self::Output
    {
        if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Identical shapes: update `rhs` in place, keeping `self` on the left.
            let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
            out.zip_mut_with_same_shape(self, clone_iopf_rev(A::$mth));
            out
        } else {
            let (rhs_view, lhs_view) = rhs.broadcast_with(self).unwrap();
            if rhs_view.shape() == rhs.shape() {
                // Broadcasting did not grow `rhs`, so its storage can hold the result.
                let mut out = rhs.into_dimensionality::<<E as DimMax<D>>::Output>().unwrap();
                out.zip_mut_with_same_shape(&lhs_view, clone_iopf_rev(A::$mth));
                out
            } else {
                // The broadcast shape differs from `rhs`'s: collect into fresh storage.
                Zip::from(&lhs_view).and(&rhs_view).map_collect_owned(clone_opf(A::$mth))
            }
        }
    }
}

/// Perform elementwise
#[doc=$doc]
/// between references `self` and `rhs`,
/// and return the result as a new `Array`.
///
/// The two operands may be borrowed for different lifetimes.
///
/// If their shapes disagree, `self` and `rhs` are broadcast to a common shape.
///
/// **Panics** if broadcasting isn't possible.
impl<'a, 'b, A, B, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for &'b ArrayBase<S, D>
where
    A: Clone + $trt<B, Output=A>,
    B: Clone,
    S: Data<Elem=A>,
    S2: Data<Elem=B>,
    D: Dimension + DimMax<E>,
    E: Dimension,
{
    type Output = Array<A, <D as DimMax<E>>::Output>;
    #[track_caller]
    fn $mth(self, rhs: &'a ArrayBase<S2, E>) -> Self::Output {
        let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() {
            // Same shape: only the dimension type needs converting.
            let lhs = self.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            let rhs = rhs.view().into_dimensionality::<<D as DimMax<E>>::Output>().unwrap();
            (lhs, rhs)
        } else {
            self.broadcast_with(rhs).unwrap()
        };
        Zip::from(lhs).and(rhs).map_collect(clone_opf(A::$mth))
    }
}

/// Perform elementwise
#[doc=$doc]
/// between `self` and the scalar `x`,
/// and return the result (based on `self`).
///
/// `self` must own its data (`S: DataOwned + DataMut`); every element is
/// updated in place.
impl<A, S, D, B> $trt<B> for ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: DataOwned<Elem=A> + DataMut,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = ArrayBase<S, D>;
    fn $mth(mut self, x: B) -> ArrayBase<S, D> {
        self.map_inplace(move |elt| {
            *elt = elt.clone() $operator x.clone();
        });
        self
    }
}

/// Perform elementwise
#[doc=$doc]
/// between the reference `self` and the scalar `x`,
/// and return the result as a new `Array`.
impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
    where A: Clone + $trt<B, Output=A>,
          S: Data<Elem=A>,
          D: Dimension,
          B: ScalarOperand,
{
    type Output = Array<A, D>;
    fn $mth(self, x: B) -> Self::Output {
        self.map(move |elt| elt.clone() $operator x.clone())
    }
}
    );
);
235
// Chooses between two expressions based on whether the operator being
// implemented is commutative: `Commute` expands to the first expression,
// `Ordered` to the second. Only the chosen expression is emitted, so the
// other one is never type-checked.
macro_rules! if_commutative {
    (Ordered { $swapped:expr } or { $in_order:expr }) => {
        $in_order
    };
    (Commute { $swapped:expr } or { $in_order:expr }) => {
        $swapped
    };
}
245
// Implements `scalar op array` for one concrete scalar type `$scalar`.
//
// With `$commutative == Commute` the call is forwarded to the existing
// `array op scalar` impl with the operands swapped. With `Ordered` the scalar
// stays on the left of `$operator` for every element. `$doc` is accepted for
// symmetry with `impl_binary_op!` but is not used; these impls carry no
// rustdoc of their own.
macro_rules! impl_scalar_lhs_op {
    ($scalar:ty, $commutative:ident, $operator:tt, $trt:ident, $mth:ident, $doc:expr) => (
impl<S, D> $trt<ArrayBase<S, D>> for $scalar
where
    S: DataOwned<Elem = $scalar> + DataMut,
    D: Dimension,
{
    type Output = ArrayBase<S, D>;
    fn $mth(self, rhs: ArrayBase<S, D>) -> ArrayBase<S, D> {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {{
            // Reuse the storage of `rhs`: each element becomes `self op element`.
            let mut out = rhs;
            out.map_inplace(move |item| *item = self $operator *item);
            out
        }})
    }
}

impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
where
    S: Data<Elem = $scalar>,
    D: Dimension,
{
    type Output = Array<$scalar, D>;
    fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
        if_commutative!($commutative {
            rhs.$mth(self)
        } or {
            // Every scalar type this macro is invoked with is `Copy`.
            rhs.map(move |item| self $operator *item)
        })
    }
}
    );
}
290
291mod arithmetic_ops
292{
293 use super::*;
294 use crate::imp_prelude::*;
295
296 use std::ops::*;
297
298 fn clone_opf<A: Clone, B: Clone, C>(f: impl Fn(A, B) -> C) -> impl FnMut(&A, &B) -> C
299 {
300 move |x, y| f(x.clone(), y.clone())
301 }
302
303 fn clone_iopf<A: Clone, B: Clone>(f: impl Fn(A, B) -> A) -> impl FnMut(&mut A, &B)
304 {
305 move |x, y| *x = f(x.clone(), y.clone())
306 }
307
308 fn clone_iopf_rev<A: Clone, B: Clone>(f: impl Fn(A, B) -> B) -> impl FnMut(&mut B, &A)
309 {
310 move |x, y| *x = f(y.clone(), x.clone())
311 }
312
313 impl_binary_op!(Add, +, add, +=, "addition");
314 impl_binary_op!(Sub, -, sub, -=, "subtraction");
315 impl_binary_op!(Mul, *, mul, *=, "multiplication");
316 impl_binary_op!(Div, /, div, /=, "division");
317 impl_binary_op!(Rem, %, rem, %=, "remainder");
318 impl_binary_op!(BitAnd, &, bitand, &=, "bit and");
319 impl_binary_op!(BitOr, |, bitor, |=, "bit or");
320 impl_binary_op!(BitXor, ^, bitxor, ^=, "bit xor");
321 impl_binary_op!(Shl, <<, shl, <<=, "left shift");
322 impl_binary_op!(Shr, >>, shr, >>=, "right shift");
323
324 macro_rules! all_scalar_ops {
325 ($int_scalar:ty) => (
326 impl_scalar_lhs_op!($int_scalar, Commute, +, Add, add, "addition");
327 impl_scalar_lhs_op!($int_scalar, Ordered, -, Sub, sub, "subtraction");
328 impl_scalar_lhs_op!($int_scalar, Commute, *, Mul, mul, "multiplication");
329 impl_scalar_lhs_op!($int_scalar, Ordered, /, Div, div, "division");
330 impl_scalar_lhs_op!($int_scalar, Ordered, %, Rem, rem, "remainder");
331 impl_scalar_lhs_op!($int_scalar, Commute, &, BitAnd, bitand, "bit and");
332 impl_scalar_lhs_op!($int_scalar, Commute, |, BitOr, bitor, "bit or");
333 impl_scalar_lhs_op!($int_scalar, Commute, ^, BitXor, bitxor, "bit xor");
334 impl_scalar_lhs_op!($int_scalar, Ordered, <<, Shl, shl, "left shift");
335 impl_scalar_lhs_op!($int_scalar, Ordered, >>, Shr, shr, "right shift");
336 );
337 }
338 all_scalar_ops!(i8);
339 all_scalar_ops!(u8);
340 all_scalar_ops!(i16);
341 all_scalar_ops!(u16);
342 all_scalar_ops!(i32);
343 all_scalar_ops!(u32);
344 all_scalar_ops!(i64);
345 all_scalar_ops!(u64);
346 all_scalar_ops!(isize);
347 all_scalar_ops!(usize);
348 all_scalar_ops!(i128);
349 all_scalar_ops!(u128);
350
351 impl_scalar_lhs_op!(bool, Commute, &, BitAnd, bitand, "bit and");
352 impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
353 impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
354
355 impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
356 impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
357 impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
358 impl_scalar_lhs_op!(f32, Ordered, /, Div, div, "division");
359 impl_scalar_lhs_op!(f32, Ordered, %, Rem, rem, "remainder");
360
361 impl_scalar_lhs_op!(f64, Commute, +, Add, add, "addition");
362 impl_scalar_lhs_op!(f64, Ordered, -, Sub, sub, "subtraction");
363 impl_scalar_lhs_op!(f64, Commute, *, Mul, mul, "multiplication");
364 impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
365 impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
366
367 impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
368 impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
369 impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
370 impl_scalar_lhs_op!(Complex<f32>, Ordered, /, Div, div, "division");
371
372 impl_scalar_lhs_op!(Complex<f64>, Commute, +, Add, add, "addition");
373 impl_scalar_lhs_op!(Complex<f64>, Ordered, -, Sub, sub, "subtraction");
374 impl_scalar_lhs_op!(Complex<f64>, Commute, *, Mul, mul, "multiplication");
375 impl_scalar_lhs_op!(Complex<f64>, Ordered, /, Div, div, "division");
376
377 impl<A, S, D> Neg for ArrayBase<S, D>
378 where
379 A: Clone + Neg<Output = A>,
380 S: DataOwned<Elem = A> + DataMut,
381 D: Dimension,
382 {
383 type Output = Self;
384 fn neg(mut self) -> Self
386 {
387 self.map_inplace(|elt| {
388 *elt = -elt.clone();
389 });
390 self
391 }
392 }
393
394 impl<'a, A, S, D> Neg for &'a ArrayBase<S, D>
395 where
396 &'a A: 'a + Neg<Output = A>,
397 S: Data<Elem = A>,
398 D: Dimension,
399 {
400 type Output = Array<A, D>;
401 fn neg(self) -> Array<A, D>
404 {
405 self.map(Neg::neg)
406 }
407 }
408
409 impl<A, S, D> Not for ArrayBase<S, D>
410 where
411 A: Clone + Not<Output = A>,
412 S: DataOwned<Elem = A> + DataMut,
413 D: Dimension,
414 {
415 type Output = Self;
416 fn not(mut self) -> Self
418 {
419 self.map_inplace(|elt| {
420 *elt = !elt.clone();
421 });
422 self
423 }
424 }
425
426 impl<'a, A, S, D> Not for &'a ArrayBase<S, D>
427 where
428 &'a A: 'a + Not<Output = A>,
429 S: Data<Elem = A>,
430 D: Dimension,
431 {
432 type Output = Array<A, D>;
433 fn not(self) -> Array<A, D>
436 {
437 self.map(Not::not)
438 }
439 }
440}
441
442mod assign_ops
443{
444 use super::*;
445 use crate::imp_prelude::*;
446
447 macro_rules! impl_assign_op {
448 ($trt:ident, $method:ident, $doc:expr) => {
449 use std::ops::$trt;
450
451 #[doc=$doc]
452 impl<'a, A, S, S2, D, E> $trt<&'a ArrayBase<S2, E>> for ArrayBase<S, D>
456 where
457 A: Clone + $trt<A>,
458 S: DataMut<Elem = A>,
459 S2: Data<Elem = A>,
460 D: Dimension,
461 E: Dimension,
462 {
463 #[track_caller]
464 fn $method(&mut self, rhs: &ArrayBase<S2, E>) {
465 self.zip_mut_with(rhs, |x, y| {
466 x.$method(y.clone());
467 });
468 }
469 }
470
471 #[doc=$doc]
472 impl<A, S, D> $trt<A> for ArrayBase<S, D>
473 where
474 A: ScalarOperand + $trt<A>,
475 S: DataMut<Elem = A>,
476 D: Dimension,
477 {
478 fn $method(&mut self, rhs: A) {
479 self.map_inplace(move |elt| {
480 elt.$method(rhs.clone());
481 });
482 }
483 }
484 };
485 }
486
487 impl_assign_op!(
488 AddAssign,
489 add_assign,
490 "Perform `self += rhs` as elementwise addition (in place).\n"
491 );
492 impl_assign_op!(
493 SubAssign,
494 sub_assign,
495 "Perform `self -= rhs` as elementwise subtraction (in place).\n"
496 );
497 impl_assign_op!(
498 MulAssign,
499 mul_assign,
500 "Perform `self *= rhs` as elementwise multiplication (in place).\n"
501 );
502 impl_assign_op!(
503 DivAssign,
504 div_assign,
505 "Perform `self /= rhs` as elementwise division (in place).\n"
506 );
507 impl_assign_op!(
508 RemAssign,
509 rem_assign,
510 "Perform `self %= rhs` as elementwise remainder (in place).\n"
511 );
512 impl_assign_op!(
513 BitAndAssign,
514 bitand_assign,
515 "Perform `self &= rhs` as elementwise bit and (in place).\n"
516 );
517 impl_assign_op!(
518 BitOrAssign,
519 bitor_assign,
520 "Perform `self |= rhs` as elementwise bit or (in place).\n"
521 );
522 impl_assign_op!(
523 BitXorAssign,
524 bitxor_assign,
525 "Perform `self ^= rhs` as elementwise bit xor (in place).\n"
526 );
527 impl_assign_op!(
528 ShlAssign,
529 shl_assign,
530 "Perform `self <<= rhs` as elementwise left shift (in place).\n"
531 );
532 impl_assign_op!(
533 ShrAssign,
534 shr_assign,
535 "Perform `self >>= rhs` as elementwise right shift (in place).\n"
536 );
537}