ndarray/zip/mod.rs
1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::partial::Partial;
18use crate::AssignElem;
19use crate::IntoDimension;
20use crate::Layout;
21
22use crate::dimension;
23use crate::indexes::{indices, Indices};
24use crate::split_at::{SplitAt, SplitPreference};
25
26pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset};
27
/// Unwrap a `FoldWhile::Continue` value, or return early from the enclosing
/// function with the `FoldWhile::Done` value that ended the fold.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(value) => value,
            done @ FoldWhile::Done(_) => return done,
        }
    };
}
37
/// Broadcast an array so that it acts like a larger size and/or shape array.
///
/// See [broadcasting](ArrayBase#broadcasting) for more information.
trait Broadcast<E>
where E: IntoDimension
{
    /// The producer type returned by broadcasting; its dimension type is the
    /// one that `E` converts into.
    type Output: NdProducer<Dim = E::Dim>;
    /// Broadcast the array to the new dimensions `shape`.
    ///
    /// ***Panics*** if broadcasting isn’t possible.
    #[track_caller]
    fn broadcast_unwrap(self, shape: E) -> Self::Output;
    // NOTE(review): presumably seals the trait against outside implementations;
    // confirm against the definition of `private_decl!`.
    private_decl! {}
}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout
55{
56 let n = dim.ndim();
57 if dimension::is_layout_c(dim, strides) {
58 // effectively one-dimensional => C and F layout compatible
59 if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
60 Layout::one_dimensional()
61 } else {
62 Layout::c()
63 }
64 } else if n > 1 && dimension::is_layout_f(dim, strides) {
65 Layout::f()
66 } else if n > 1 {
67 if dim[0] > 1 && strides[0] == 1 {
68 Layout::fpref()
69 } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
70 Layout::cpref()
71 } else {
72 Layout::none()
73 }
74 } else {
75 Layout::none()
76 }
77}
78
79impl<A, D> LayoutRef<A, D>
80where D: Dimension
81{
82 pub(crate) fn layout_impl(&self) -> Layout
83 {
84 array_layout(self._dim(), self._strides())
85 }
86}
87
/// Broadcasting for array views: the output is again an `ArrayView` with the
/// same lifetime and element type, but with the dimension type that `E`
/// converts into.
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    fn broadcast_unwrap(self, shape: E) -> Self::Output
    {
        // Broadcast through the dereferenced view; the intermediate result is
        // annotated with an anonymous lifetime rather than `'a`.
        #[allow(clippy::needless_borrow)]
        let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension());
        // Rebuild the view from its pointer, dimension and strides so that the
        // returned value carries the lifetime `'a`.
        // NOTE(review): soundness presumably relies on `self` already being a
        // view valid for `'a` over the same data; confirm.
        unsafe { ArrayView::new(res.parts.ptr, res.parts.dim, res.parts.strides) }
    }
    private_impl! {}
}
102
/// Internal abstraction over the tuple of producers held by a `Zip`.
///
/// It bundles the per-producer operations (pointer access, striding,
/// splitting) into tuple-valued operations so that the traversal code can be
/// written once for every supported number of producers.
trait ZippableTuple: Sized
{
    /// Tuple of the items of all producers.
    type Item;
    /// Tuple of the pointers of all producers; can be advanced by `Self::Stride`.
    type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
    /// The dimension type shared by all producers.
    type Dim: Dimension;
    /// Tuple of the strides of all producers.
    type Stride: Copy;
    /// Return the tuple of the producers' base pointers.
    fn as_ptr(&self) -> Self::Ptr;
    /// Turn a tuple of pointers into the tuple of items.
    ///
    /// NOTE(review): presumably every pointer must refer to a valid element
    /// of its producer; confirm against the `NdProducer` contract.
    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
    /// Return the tuple of pointers for the index `i`.
    ///
    /// NOTE(review): presumably unchecked, so `i` must be in bounds; confirm.
    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
    /// Return the tuple of strides along the axis number `index`.
    fn stride_of(&self, index: usize) -> Self::Stride;
    /// Return the tuple of strides used when traversing the producers as
    /// contiguous memory.
    fn contiguous_stride(&self) -> Self::Stride;
    /// Split every producer along `axis` at `index` and return the two halves.
    fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
116
/// Lock step function application across several arrays or other producers.
///
/// Zip allows matching several producers to each other elementwise and applying
/// a function over all tuples of elements (one item from each input at
/// a time).
///
/// In general, the zip uses a tuple of producers
/// ([`NdProducer`] trait) that all have to be of the
/// same shape. The NdProducer implementation defines what its item type is
/// (for example if it's a shared reference, mutable reference or an array
/// view etc).
///
/// If all the input arrays are of the same memory layout the zip performs much
/// better and the compiler can usually vectorize the loop (if applicable).
///
/// The order in which elements are visited is not specified. The producers
/// don’t have to have the same item type.
///
/// The `Zip` has two methods for function application: `for_each` and
/// `fold_while`. The zip object can be split, which allows parallelization.
/// A read-only zip object (no mutable producers) can be cloned.
///
/// See also [`azip!()`], which offers a convenient shorthand
/// for common ways to use `Zip`.
///
/// ```
/// use ndarray::Zip;
/// use ndarray::Array2;
///
/// type M = Array2<f64>;
///
/// // Create four 2d arrays of the same size
/// let mut a = M::zeros((64, 32));
/// let b = M::from_elem(a.dim(), 1.);
/// let c = M::from_elem(a.dim(), 2.);
/// let d = M::from_elem(a.dim(), 3.);
///
/// // Example 1: Perform an elementwise arithmetic operation across
/// // the four arrays a, b, c, d.
///
/// Zip::from(&mut a)
///     .and(&b)
///     .and(&c)
///     .and(&d)
///     .for_each(|w, &x, &y, &z| {
///         *w += x + y * z;
///     });
///
/// // Example 2: Create a new array `totals` with one entry per row of `a`.
/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
/// //  entry in `totals` with the sum across each row.
/// //  This is possible because the producer for `totals` and the row producer
/// //  for `a` have the same shape and dimensionality.
/// //  The rows producer yields one array view (`row`) per iteration.
///
/// use ndarray::{Array1, Axis};
///
/// let mut totals = Array1::zeros(a.nrows());
///
/// Zip::from(&mut totals)
///     .and(a.rows())
///     .for_each(|totals, row| *totals = row.sum());
///
/// // Check the result against the built in `.sum_axis()` along axis 1.
/// assert_eq!(totals, a.sum_axis(Axis(1)));
///
///
/// // Example 3: Recreate Example 2 using map_collect to make a new array
///
/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
///
/// // Check the result against the previous example.
/// assert_eq!(totals, totals2);
/// ```
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D>
{
    /// The tuple of producers that are traversed in lock step.
    parts: Parts,
    /// The shape shared by all producers.
    dimension: D,
    /// The layout of the zip as a whole: the first producer's layout,
    /// intersected with the layout of every producer added afterwards.
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
202
203impl<P, D> Zip<(P,), D>
204where
205 D: Dimension,
206 P: NdProducer<Dim = D>,
207{
208 /// Create a new `Zip` from the input array or other producer `p`.
209 ///
210 /// The Zip will take the exact dimension of `p` and all inputs
211 /// must have the same dimensions (or be broadcast to them).
212 pub fn from<IP>(p: IP) -> Self
213 where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
214 {
215 let array = p.into_producer();
216 let dim = array.raw_dim();
217 let layout = array.layout();
218 Zip {
219 dimension: dim,
220 layout,
221 parts: (array,),
222 layout_tendency: layout.tendency(),
223 }
224 }
225}
226impl<P, D> Zip<(Indices<D>, P), D>
227where
228 D: Dimension + Copy,
229 P: NdProducer<Dim = D>,
230{
231 /// Create a new `Zip` with an index producer and the producer `p`.
232 ///
233 /// The Zip will take the exact dimension of `p` and all inputs
234 /// must have the same dimensions (or be broadcast to them).
235 ///
236 /// *Note:* Indexed zip has overhead.
237 pub fn indexed<IP>(p: IP) -> Self
238 where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
239 {
240 let array = p.into_producer();
241 let dim = array.raw_dim();
242 Zip::from(indices(dim)).and(array)
243 }
244}
245
/// Check that the producer `part` has the shape `dimension`.
///
/// The check goes through `ndassert!`; on a mismatch the message reports the
/// expected dimension and the producer's actual dimension.
#[inline]
fn zip_dimension_check<D, P>(dimension: &D, part: &P)
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    ndassert!(
        part.equal_dim(dimension),
        "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
        dimension,
        part.raw_dim()
    );
}
259
260impl<Parts, D> Zip<Parts, D>
261where D: Dimension
262{
263 /// Return a the number of element tuples in the Zip
264 pub fn size(&self) -> usize
265 {
266 self.dimension.size()
267 }
268
269 /// Return the length of `axis`
270 ///
271 /// ***Panics*** if `axis` is out of bounds.
272 #[track_caller]
273 fn len_of(&self, axis: Axis) -> usize
274 {
275 self.dimension[axis.index()]
276 }
277
278 fn prefer_f(&self) -> bool
279 {
280 !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
281 }
282
283 /// Return an *approximation* to the max stride axis; if
284 /// component arrays disagree, there may be no choice better than the
285 /// others.
286 fn max_stride_axis(&self) -> Axis
287 {
288 let i = if self.prefer_f() {
289 self.dimension
290 .slice()
291 .iter()
292 .rposition(|&len| len > 1)
293 .unwrap_or(self.dimension.ndim() - 1)
294 } else {
295 /* corder or default */
296 self.dimension
297 .slice()
298 .iter()
299 .position(|&len| len > 1)
300 .unwrap_or(0)
301 };
302 Axis(i)
303 }
304}
305
impl<P, D> Zip<P, D>
where D: Dimension
{
    /// Shared driver behind the public traversal methods.
    ///
    /// Folds `function` over all item tuples starting from `acc`, and stops
    /// early as soon as `function` returns a non-`Continue` value. Dispatches
    /// on dimensionality and layout to pick the traversal strategy.
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            // Zero-dimensional: exactly one item tuple, at the base pointers.
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Traverse all elements as a single stretch of `size()` elements, using
    /// the contiguous stride of each producer.
    ///
    /// Only valid when the zip's layout passes the `CORDER | FORDER` check
    /// (debug-asserted below).
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    ///
    /// Returns early (through `fold_while!`) with the first non-`Continue`
    /// result of `function`.
    ///
    /// NOTE(review): presumably the caller must guarantee that `ptr` offset by
    /// `strides * i` is a valid element for every `i < len`; confirm.
    unsafe fn inner<F, Acc>(
        &self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F,
    ) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple,
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }

    /// Traverse a zip that did not take the contiguous path: choose the axis
    /// to run the inner loop over from the layout tendency.
    ///
    /// One-dimensional zips and zips with a non-negative tendency use the
    /// last axis; the others use the first axis.
    ///
    /// ***Panics*** if the zip is zero-dimensional.
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    //
    // The unrolled axis is set to length 1 in `self.dimension` (and is not
    // restored afterwards), so the index iteration below walks only the
    // remaining axes while `inner` covers the unrolled axis.
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    //
    // As in the C variant, the unrolled axis is set to length 1 in
    // `self.dimension` (and is not restored afterwards); the remaining axes
    // are advanced with `next_for_f`.
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Create an uninitialized array with the zip's shape, in f-order if the
    /// zip prefers f-order (see `prefer_f`) and in c-order otherwise.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
433
434impl<D, P1, P2> Zip<(P1, P2), D>
435where
436 D: Dimension,
437 P1: NdProducer<Dim = D>,
438 P1: NdProducer<Dim = D>,
439{
440 /// Debug assert traversal order is like c (including 1D case)
441 // Method placement: only used for binary Zip at the moment.
442 #[inline]
443 pub(crate) fn debug_assert_c_order(self) -> Self
444 {
445 debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
446 self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
447 "Assertion failed: traversal is not c-order or 1D for \
448 layout {:?}, tendency {}, dimension {:?}",
449 self.layout, self.layout_tendency, self.dimension);
450 self
451 }
452}
453
454/*
455trait Offset : Copy {
456 unsafe fn offset(self, off: isize) -> Self;
457 unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
458 self.offset(index as isize * stride)
459 }
460}
461
462impl<T> Offset for *mut T {
463 unsafe fn offset(self, off: isize) -> Self {
464 self.offset(off)
465 }
466}
467*/
468
/// A pointer (or tuple of pointers) that can be advanced by a number of
/// strides.
trait OffsetTuple
{
    /// The stride (or tuple of strides) matching `Self`.
    type Args;
    /// Return `self` moved forward by `index` steps of size `stride`.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T
{
    type Args = isize;
    /// Offset the pointer by `index * stride` elements (as with `pointer::offset`).
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self
    {
        let element_delta = stride * index as isize;
        self.offset(element_delta)
    }
}
483
484macro_rules! offset_impl {
485 ($([$($param:ident)*][ $($q:ident)*],)+) => {
486 $(
487 #[allow(non_snake_case)]
488 impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
489 type Args = ($($param::Stride,)*);
490 unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
491 let ($($param, )*) = self;
492 let ($($q, )*) = stride;
493 ($(Offset::stride_offset($param, $q, index),)*)
494 }
495 }
496 )+
497 };
498}
499
500offset_impl! {
501 [A ][ a],
502 [A B][ a b],
503 [A B C][ a b c],
504 [A B C D][ a b c d],
505 [A B C D E][ a b c d e],
506 [A B C D E F][ a b c d e f],
507}
508
/// Implement `ZippableTuple` for tuples of producers that share one dimension
/// type.
///
/// Each entry is a pair of bracketed identifier lists: `$p` names the tuple's
/// type parameters (also reused as value bindings), and `$q` provides a second
/// set of bindings for the one place where producers and pointers are
/// destructured at the same time.
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
            #[allow(non_snake_case)]
            impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
                type Item = ($($p::Item, )*);
                type Ptr = ($($p::Ptr, )*);
                type Dim = Dim;
                type Stride = ($($p::Stride,)* );

                // Strides of every producer along the axis number `index`.
                fn stride_of(&self, index: usize) -> Self::Stride {
                    let ($(ref $p,)*) = *self;
                    ($($p.stride_of(Axis(index)), )*)
                }

                // Contiguous strides of every producer.
                fn contiguous_stride(&self) -> Self::Stride {
                    let ($(ref $p,)*) = *self;
                    ($($p.contiguous_stride(), )*)
                }

                // Base pointers of every producer.
                fn as_ptr(&self) -> Self::Ptr {
                    let ($(ref $p,)*) = *self;
                    ($($p.as_ptr(), )*)
                }
                // Producers are bound to the `$q` names here because the `$p`
                // names are taken by the pointers.
                unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                    let ($(ref $q ,)*) = *self;
                    let ($($p,)*) = ptr;
                    ($($q.as_ref($p),)*)
                }

                // Pointers of every producer at the index `i`.
                unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                    let ($(ref $p,)*) = *self;
                    ($($p.uget_ptr(i), )*)
                }

                // Split every producer, then regroup the (left, right) pairs
                // into a tuple of left halves and a tuple of right halves.
                fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                    let ($($p,)*) = self;
                    let ($($p,)*) = (
                        $($p.split_at(axis, index), )*
                    );
                    (
                        ($($p.0,)*),
                        ($($p.1,)*)
                    )
                }
            }
        )+
    };
}

// Tuples of one to six producers.
zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
567
/// Generate the public `Zip` methods and the splitting trait impls for one
/// tuple arity per entry.
///
/// Each entry is `[$notlast $p...]`: `$p...` are the producer type parameters
/// and `$notlast` is a flag handed to `expand_if!`, which guards the items
/// that add one more producer to the zip.
macro_rules! map_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
            #[allow(non_snake_case)]
            impl<D, $($p),*> Zip<($($p,)*), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
            {
                /// Apply a function to all elements of the input arrays,
                /// visiting elements in lock step.
                pub fn for_each<F>(mut self, mut function: F)
                    where F: FnMut($($p::Item),*)
                {
                    self.for_each_core((), move |(), args| {
                        let ($($p,)*) = args;
                        FoldWhile::Continue(function($($p),*))
                    });
                }

                /// Apply a fold function to all elements of the input arrays,
                /// visiting elements in lock step.
                ///
                /// # Example
                ///
                /// The expression `tr(AᵀB)` can be more efficiently computed as
                /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
                /// elements of the entry-wise product). It would be possible to
                /// evaluate this expression by first computing the entry-wise
                /// product, `A∘B`, and then computing the elementwise sum of that
                /// product, but it's possible to do this in a single loop (and
                /// avoid an extra heap allocation if `A` and `B` can't be
                /// consumed) by using `Zip`:
                ///
                /// ```
                /// use ndarray::{array, Zip};
                ///
                /// let a = array![[1, 5], [3, 7]];
                /// let b = array![[2, 4], [8, 6]];
                ///
                /// // Without using `Zip`. This involves two loops and an extra
                /// // heap allocation for the result of `&a * &b`.
                /// let sum_prod_nonzip = (&a * &b).sum();
                /// // Using `Zip`. This is a single loop without any heap allocations.
                /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
                ///
                /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
                /// ```
                pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
                where
                    F: FnMut(Acc, $($p::Item),*) -> Acc,
                {
                    self.for_each_core(acc, move |acc, args| {
                        let ($($p,)*) = args;
                        FoldWhile::Continue(function(acc, $($p),*))
                    }).into_inner()
                }

                /// Apply a fold function to the input arrays while the return
                /// value is `FoldWhile::Continue`, visiting elements in lock step.
                ///
                /// Returns the last value produced, still wrapped in `FoldWhile`.
                pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
                    -> FoldWhile<Acc>
                    where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
                {
                    self.for_each_core(acc, move |acc, args| {
                        let ($($p,)*) = args;
                        function(acc, $($p),*)
                    })
                }

                /// Tests if every element of the iterator matches a predicate.
                ///
                /// Returns `true` if `predicate` evaluates to `true` for all elements.
                /// Returns `true` if the input arrays are empty.
                ///
                /// Example:
                ///
                /// ```
                /// use ndarray::{array, Zip};
                /// let a = array![1, 2, 3];
                /// let b = array![1, 4, 9];
                /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
                /// ```
                pub fn all<F>(mut self, mut predicate: F) -> bool
                    where F: FnMut($($p::Item),*) -> bool
                {
                    // The fold is `Done` exactly when some element failed the predicate.
                    !self.for_each_core((), move |_, args| {
                        let ($($p,)*) = args;
                        if predicate($($p),*) {
                            FoldWhile::Continue(())
                        } else {
                            FoldWhile::Done(())
                        }
                    }).is_done()
                }

                /// Tests if at least one element of the iterator matches a predicate.
                ///
                /// Returns `true` if `predicate` evaluates to `true` for at least one element.
                /// Returns `false` if the input arrays are empty.
                ///
                /// Example:
                ///
                /// ```
                /// use ndarray::{array, Zip};
                /// let a = array![1, 2, 3];
                /// let b = array![1, 4, 9];
                /// assert!(Zip::from(&a).and(&b).any(|&a, &b| a == b));
                /// assert!(!Zip::from(&a).and(&b).any(|&a, &b| a - 1 == b));
                /// ```
                pub fn any<F>(mut self, mut predicate: F) -> bool
                    where F: FnMut($($p::Item),*) -> bool
                {
                    // The fold is `Done` exactly when some element matched the predicate.
                    self.for_each_core((), move |_, args| {
                        let ($($p,)*) = args;
                        if predicate($($p),*) {
                            FoldWhile::Done(())
                        } else {
                            FoldWhile::Continue(())
                        }
                    }).is_done()
                }

                expand_if!(@bool [$notlast]

                    /// Include the producer `p` in the Zip.
                    ///
                    /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
                    #[track_caller]
                    pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                        where P: IntoNdProducer<Dim=D>,
                    {
                        let part = p.into_producer();
                        zip_dimension_check(&self.dimension, &part);
                        self.build_and(part)
                    }

                    /// Include the producer `p` in the Zip.
                    ///
                    /// ## Safety
                    ///
                    /// The caller must ensure that the producer's shape is equal to the Zip's shape.
                    /// Uses assertions when debug assertions are enabled.
                    #[allow(unused)]
                    pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                        where P: IntoNdProducer<Dim=D>,
                    {
                        #[cfg(debug_assertions)]
                        {
                            self.and(p)
                        }
                        #[cfg(not(debug_assertions))]
                        {
                            self.build_and(p.into_producer())
                        }
                    }

                    /// Include the producer `p` in the Zip, broadcasting if needed.
                    ///
                    /// If their shapes disagree, `p` is broadcast to the shape of `self`.
                    ///
                    /// ***Panics*** if broadcasting isn’t possible.
                    #[track_caller]
                    pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
                        -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
                        where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
                              D2: Dimension,
                    {
                        let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
                        self.build_and(part)
                    }

                    /// Append `part` to the tuple of producers without checking its
                    /// shape; the zip's layout is intersected with the part's layout
                    /// and the part's tendency is added to the layout tendency.
                    fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
                        where P: NdProducer<Dim=D>,
                    {
                        let part_layout = part.layout();
                        let ($($p,)*) = self.parts;
                        Zip {
                            parts: ($($p,)* part, ),
                            layout: self.layout.intersect(part_layout),
                            dimension: self.dimension,
                            layout_tendency: self.layout_tendency + part_layout.tendency(),
                        }
                    }

                    /// Map and collect the results into a new array, which has the same size as the
                    /// inputs.
                    ///
                    /// If all inputs are c- or f-order respectively, that is preserved in the output.
                    pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                        self.map_collect_owned(f)
                    }

                    /// Like `map_collect`, but generic over the owned storage `S` of
                    /// the returned array.
                    pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
                        -> ArrayBase<S, D>
                    where
                        S: DataOwned<Elem = R>,
                    {
                        // safe because: all elements are written before the array is completed

                        // The output is allocated in f-order when the zip prefers f-order.
                        let shape = self.dimension.clone().set_f(self.prefer_f());
                        let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
                            // Use partial to count the number of filled elements, and can drop the right
                            // number of elements on unwinding (if it happens during apply/collect).
                            unsafe {
                                let output_view = output.into_raw_view_mut().cast::<R>();
                                self.and(output_view)
                                    .collect_with_partial(f)
                                    .release_ownership();
                            }
                        });
                        unsafe {
                            output.assume_init()
                        }
                    }

                    /// Map and assign the results into the producer `into`, which should have the same
                    /// size as the other inputs.
                    ///
                    /// The producer should have assignable items as dictated by the `AssignElem` trait,
                    /// for example `&mut R`.
                    pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
                        where Q: IntoNdProducer<Dim=D>,
                              Q::Item: AssignElem<R>
                    {
                        self.and(into)
                            .for_each(move |$($p, )* output_| {
                                output_.assign_elem(f($($p ),*));
                            });
                    }

                );

                /// Split the `Zip` evenly in two.
                ///
                /// It will be split in the way that best preserves element locality.
                ///
                /// Splitting an empty or single-element zip is caught by debug assertions.
                pub fn split(self) -> (Self, Self) {
                    debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
                    debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
                    SplitPreference::split(self)
                }
            }

            expand_if!(@bool [$notlast]
                // For collect; Last producer is a RawViewMut
                #[allow(non_snake_case)]
                impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
                    where D: Dimension,
                          $($p: NdProducer<Dim=D> ,)*
                          PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
                {
                    /// The inner workings of map_collect and par_map_collect
                    ///
                    /// Apply the function and collect the results into the output (last producer)
                    /// which should be a raw array view; a Partial that owns the written
                    /// elements is returned.
                    ///
                    /// Elements will be overwritten in place (in the sense of std::ptr::write).
                    ///
                    /// ## Safety
                    ///
                    /// The last producer is a RawArrayViewMut and must be safe to write into.
                    /// The producer must be c- or f-contig and have the same layout tendency
                    /// as the whole Zip.
                    ///
                    /// The returned Partial's proxy ownership of the elements must be handled,
                    /// before the array the raw view points to realizes its ownership.
                    pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
                        where F: FnMut($($p::Item,)* ) -> R
                    {
                        // Get the last producer; and make a Partial that aliases its data pointer
                        let (.., ref output) = &self.parts;

                        // debug assert that the output is contiguous in the memory layout we need
                        if cfg!(debug_assertions) {
                            let out_layout = output.layout();
                            assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
                            assert!(
                                (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
                                (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
                                "layout tendency violation for self layout {:?}, output layout {:?},\
                                output shape {:?}",
                                self.layout, out_layout, output.raw_dim());
                        }

                        let mut partial = Partial::new(output.as_ptr());

                        // Apply the mapping function on this zip
                        // if we panic with unwinding; Partial will drop the written elements.
                        let partial_len = &mut partial.len;
                        self.for_each(move |$($p,)* output_elem: *mut R| {
                            output_elem.write(f($($p),*));
                            // The written-element count is only maintained for
                            // element types that need dropping.
                            if std::mem::needs_drop::<R>() {
                                *partial_len += 1;
                            }
                        });

                        partial
                    }
                }
            );

            impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
            {
                // A zip can be split as long as it has more than one element tuple.
                fn can_split(&self) -> bool { self.size() > 1 }

                // Prefer splitting the approximate max stride axis in the middle.
                fn split_preference(&self) -> (Axis, usize) {
                    // Always split in a way that preserves layout (if any)
                    let axis = self.max_stride_axis();
                    let index = self.len_of(axis) / 2;
                    (axis, index)
                }
            }

            impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
            {
                // Split the producers and the dimension at the same place; both
                // halves keep the layout and layout tendency of the original zip.
                fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
                    let (p1, p2) = self.parts.split_at(axis, index);
                    let (d1, d2) = self.dimension.split_at(axis, index);
                    (Zip {
                        dimension: d1,
                        layout: self.layout,
                        parts: p1,
                        layout_tendency: self.layout_tendency,
                    },
                    Zip {
                        dimension: d2,
                        layout: self.layout,
                        parts: p2,
                        layout_tendency: self.layout_tendency,
                    })
                }

            }

        )+
    };
}

// One entry per supported number of producers (one to six). The leading flag
// is bound to `$notlast`; only the six-producer entry passes `false`.
// NOTE(review): presumably `false` makes `expand_if!` drop the guarded items
// for that entry; confirm against the definition of `expand_if!`.
map_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}
919
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T>
{
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}

impl<T> FoldWhile<T>
{
    /// Unwrap the carried value, regardless of the variant
    pub fn into_inner(self) -> T
    {
        match self {
            FoldWhile::Continue(value) => value,
            FoldWhile::Done(value) => value,
        }
    }

    /// Return true if it is `Done`, false if `Continue`
    pub fn is_done(&self) -> bool
    {
        matches!(self, FoldWhile::Done(_))
    }
}