Thanks for visiting codestin.com
Credit goes to docs.rs

ndarray/zip/
mod.rs

1// Copyright 2017 bluss and ndarray developers.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9#[macro_use]
10mod zipmacro;
11mod ndproducer;
12
13#[cfg(feature = "rayon")]
14use std::mem::MaybeUninit;
15
16use crate::imp_prelude::*;
17use crate::partial::Partial;
18use crate::AssignElem;
19use crate::IntoDimension;
20use crate::Layout;
21
22use crate::dimension;
23use crate::indexes::{indices, Indices};
24use crate::split_at::{SplitAt, SplitPreference};
25
26pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset};
27
/// Return if the expression is a break value.
///
/// Evaluates to the inner value when the expression is
/// `FoldWhile::Continue(x)`; otherwise (i.e. `FoldWhile::Done`) it
/// returns the value from the *enclosing function*, short-circuiting
/// the fold.
macro_rules! fold_while {
    ($e:expr) => {
        match $e {
            FoldWhile::Continue(x) => x,
            // `Done(_)`: early-exit from the caller with the break value.
            x => return x,
        }
    };
}
37
/// Broadcast an array so that it acts like a larger size and/or shape array.
///
/// See [broadcasting](ArrayBase#broadcasting) for more information.
trait Broadcast<E>
where E: IntoDimension
{
    /// The producer type after broadcasting; its dimension is the
    /// target dimension `E::Dim`.
    type Output: NdProducer<Dim = E::Dim>;
    /// Broadcast the array to the new dimensions `shape`.
    ///
    /// ***Panics*** if broadcasting isn’t possible.
    #[track_caller]
    fn broadcast_unwrap(self, shape: E) -> Self::Output;
    // Crate-internal sealing macro (defined elsewhere in the crate);
    // presumably prevents downstream implementations of this trait.
    private_decl! {}
}
52
53/// Compute `Layout` hints for array shape dim, strides
54fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout
55{
56    let n = dim.ndim();
57    if dimension::is_layout_c(dim, strides) {
58        // effectively one-dimensional => C and F layout compatible
59        if n <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1 {
60            Layout::one_dimensional()
61        } else {
62            Layout::c()
63        }
64    } else if n > 1 && dimension::is_layout_f(dim, strides) {
65        Layout::f()
66    } else if n > 1 {
67        if dim[0] > 1 && strides[0] == 1 {
68            Layout::fpref()
69        } else if dim[n - 1] > 1 && strides[n - 1] == 1 {
70            Layout::cpref()
71        } else {
72            Layout::none()
73        }
74    } else {
75        Layout::none()
76    }
77}
78
impl<A, D> LayoutRef<A, D>
where D: Dimension
{
    /// Compute the `Layout` hints for this array's shape and strides.
    pub(crate) fn layout_impl(&self) -> Layout
    {
        array_layout(self._dim(), self._strides())
    }
}
87
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
    E: IntoDimension,
    D: Dimension,
{
    type Output = ArrayView<'a, A, E::Dim>;
    // Delegate to the inherent `broadcast_unwrap`, then rebuild the view
    // so the result carries the original lifetime 'a.
    fn broadcast_unwrap(self, shape: E) -> Self::Output
    {
        #[allow(clippy::needless_borrow)]
        let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension());
        // SAFETY: `res` views the same data as `self`, which is valid for 'a;
        // we only re-wrap its raw parts with the longer lifetime.
        unsafe { ArrayView::new(res.parts.ptr, res.parts.dim, res.parts.strides) }
    }
    private_impl! {}
}
102
/// Internal abstraction over the tuple of producers inside a `Zip`, so
/// the traversal code can handle any arity uniformly.
/// Implemented for tuples of `NdProducer`s by `zipt_impl!` below.
trait ZippableTuple: Sized
{
    /// Tuple of the producers' item types.
    type Item;
    /// Tuple of the producers' element pointers.
    type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
    /// Dimension type shared by all producers.
    type Dim: Dimension;
    /// Tuple of the producers' stride types.
    type Stride: Copy;
    /// Pointers to the first element of each producer.
    fn as_ptr(&self) -> Self::Ptr;
    /// Convert a tuple of element pointers into a tuple of items.
    unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
    /// Pointers for the (unchecked) multidimensional index `i`.
    unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
    /// Per-producer stride along axis `index`.
    fn stride_of(&self, index: usize) -> Self::Stride;
    /// Per-producer stride for contiguous (flat) traversal.
    fn contiguous_stride(&self) -> Self::Stride;
    /// Split every producer along `axis` at `index`, yielding two tuples.
    fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
116
117/// Lock step function application across several arrays or other producers.
118///
119/// Zip allows matching several producers to each other elementwise and applying
120/// a function over all tuples of elements (one item from each input at
121/// a time).
122///
123/// In general, the zip uses a tuple of producers
124/// ([`NdProducer`] trait) that all have to be of the
125/// same shape. The NdProducer implementation defines what its item type is
126/// (for example if it's a shared reference, mutable reference or an array
127/// view etc).
128///
129/// If all the input arrays are of the same memory layout the zip performs much
130/// better and the compiler can usually vectorize the loop (if applicable).
131///
132/// The order elements are visited is not specified. The producers don’t have to
133/// have the same item type.
134///
135/// The `Zip` has two methods for function application: `for_each` and
136/// `fold_while`. The zip object can be split, which allows parallelization.
137/// A read-only zip object (no mutable producers) can be cloned.
138///
139/// See also the [`azip!()`] which offers a convenient shorthand
140/// to common ways to use `Zip`.
141///
142/// ```
143/// use ndarray::Zip;
144/// use ndarray::Array2;
145///
146/// type M = Array2<f64>;
147///
148/// // Create four 2d arrays of the same size
149/// let mut a = M::zeros((64, 32));
150/// let b = M::from_elem(a.dim(), 1.);
151/// let c = M::from_elem(a.dim(), 2.);
152/// let d = M::from_elem(a.dim(), 3.);
153///
154/// // Example 1: Perform an elementwise arithmetic operation across
155/// // the four arrays a, b, c, d.
156///
157/// Zip::from(&mut a)
158///     .and(&b)
159///     .and(&c)
160///     .and(&d)
161///     .for_each(|w, &x, &y, &z| {
162///         *w += x + y * z;
163///     });
164///
165/// // Example 2: Create a new array `totals` with one entry per row of `a`.
166/// //  Use Zip to traverse the rows of `a` and assign to the corresponding
167/// //  entry in `totals` with the sum across each row.
168/// //  This is possible because the producer for `totals` and the row producer
169/// //  for `a` have the same shape and dimensionality.
170/// //  The rows producer yields one array view (`row`) per iteration.
171///
172/// use ndarray::{Array1, Axis};
173///
174/// let mut totals = Array1::zeros(a.nrows());
175///
176/// Zip::from(&mut totals)
177///     .and(a.rows())
178///     .for_each(|totals, row| *totals = row.sum());
179///
180/// // Check the result against the built in `.sum_axis()` along axis 1.
181/// assert_eq!(totals, a.sum_axis(Axis(1)));
182///
183///
184/// // Example 3: Recreate Example 2 using map_collect to make a new array
185///
186/// let totals2 = Zip::from(a.rows()).map_collect(|row| row.sum());
187///
188/// // Check the result against the previous example.
189/// assert_eq!(totals, totals2);
190/// ```
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D>
{
    /// Tuple of the producers being zipped.
    parts: Parts,
    /// Common shape all parts have been checked (or broadcast) to.
    dimension: D,
    /// Intersection of the layouts of all parts.
    layout: Layout,
    /// The sum of the layout tendencies of the parts;
    /// positive for c- and negative for f-layout preference.
    layout_tendency: i32,
}
202
203impl<P, D> Zip<(P,), D>
204where
205    D: Dimension,
206    P: NdProducer<Dim = D>,
207{
208    /// Create a new `Zip` from the input array or other producer `p`.
209    ///
210    /// The Zip will take the exact dimension of `p` and all inputs
211    /// must have the same dimensions (or be broadcast to them).
212    pub fn from<IP>(p: IP) -> Self
213    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
214    {
215        let array = p.into_producer();
216        let dim = array.raw_dim();
217        let layout = array.layout();
218        Zip {
219            dimension: dim,
220            layout,
221            parts: (array,),
222            layout_tendency: layout.tendency(),
223        }
224    }
225}
226impl<P, D> Zip<(Indices<D>, P), D>
227where
228    D: Dimension + Copy,
229    P: NdProducer<Dim = D>,
230{
231    /// Create a new `Zip` with an index producer and the producer `p`.
232    ///
233    /// The Zip will take the exact dimension of `p` and all inputs
234    /// must have the same dimensions (or be broadcast to them).
235    ///
236    /// *Note:* Indexed zip has overhead.
237    pub fn indexed<IP>(p: IP) -> Self
238    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
239    {
240        let array = p.into_producer();
241        let dim = array.raw_dim();
242        Zip::from(indices(dim)).and(array)
243    }
244}
245
/// Assert that `part`'s shape equals the zip's common `dimension`.
///
/// Panics (via `ndassert!`) on mismatch, reporting both shapes.
#[inline]
fn zip_dimension_check<D, P>(dimension: &D, part: &P)
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    ndassert!(
        part.equal_dim(dimension),
        "Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
        dimension,
        part.raw_dim()
    );
}
259
impl<Parts, D> Zip<Parts, D>
where D: Dimension
{
    /// Return the number of element tuples in the Zip
    pub fn size(&self) -> usize
    {
        self.dimension.size()
    }

    /// Return the length of `axis`
    ///
    /// ***Panics*** if `axis` is out of bounds.
    #[track_caller]
    fn len_of(&self, axis: Axis) -> usize
    {
        self.dimension[axis.index()]
    }

    /// Return whether f-order traversal is preferred: not c-order, and
    /// either f-order or a negative (f-leaning) layout tendency.
    fn prefer_f(&self) -> bool
    {
        !self.layout.is(Layout::CORDER) && (self.layout.is(Layout::FORDER) || self.layout_tendency < 0)
    }

    /// Return an *approximation* to the max stride axis; if
    /// component arrays disagree, there may be no choice better than the
    /// others.
    fn max_stride_axis(&self) -> Axis
    {
        let i = if self.prefer_f() {
            // f preference: search from the back for an axis with len > 1.
            self.dimension
                .slice()
                .iter()
                .rposition(|&len| len > 1)
                .unwrap_or(self.dimension.ndim() - 1)
        } else {
            /* corder or default */
            // search from the front for an axis with len > 1.
            self.dimension
                .slice()
                .iter()
                .position(|&len| len > 1)
                .unwrap_or(0)
        };
        Axis(i)
    }
}
305
impl<P, D> Zip<P, D>
where D: Dimension
{
    /// Run the fold-while loop over all element tuples, dispatching to
    /// the fastest traversal available: a single call for 0-d, one flat
    /// loop when all parts share a c/f-contiguous layout, and a strided
    /// traversal otherwise.
    fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        if self.dimension.ndim() == 0 {
            // Zero-dimensional: exactly one element tuple.
            function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
        } else if self.layout.is(Layout::CORDER | Layout::FORDER) {
            self.for_each_core_contiguous(acc, function)
        } else {
            self.for_each_core_strided(acc, function)
        }
    }

    /// Fold over all parts as one flat stretch of `size` elements
    /// using each part's contiguous stride.
    fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
        let size = self.dimension.size();
        let ptrs = self.parts.as_ptr();
        let inner_strides = self.parts.contiguous_stride();
        // SAFETY: all parts are c/f-contiguous (asserted above), so `size`
        // elements at the contiguous strides stay in bounds.
        unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) }
    }

    /// The innermost loop of the Zip for_each methods
    ///
    /// Run the fold while operation on a stretch of elements with constant strides
    ///
    /// `ptr`: base pointer for the first element in this stretch
    /// `strides`: strides for the elements in this stretch
    /// `len`: number of elements
    /// `function`: closure
    unsafe fn inner<F, Acc>(
        &self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F,
    ) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple,
    {
        let mut i = 0;
        while i < len {
            let p = ptr.stride_offset(strides, i);
            // fold_while! returns early from `inner` on FoldWhile::Done.
            acc = fold_while!(function(acc, self.parts.as_ref(p)));
            i += 1;
        }
        FoldWhile::Continue(acc)
    }

    /// Strided traversal: pick which axis to unroll based on the
    /// layout tendency (c-leaning unrolls the last axis, f-leaning the first).
    fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        if n == 0 {
            // ndim == 0 is handled by the contiguous path in for_each_core.
            panic!("Unreachable: ndim == 0 is contiguous")
        }
        if n == 1 || self.layout_tendency >= 0 {
            self.for_each_core_strided_c(acc, function)
        } else {
            self.for_each_core_strided_f(acc, function)
        }
    }

    // Non-contiguous but preference for C - unroll over Axis(ndim - 1)
    fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let n = self.dimension.ndim();
        let unroll_axis = n - 1;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unrolled axis to length 1 so the index iteration
        // below only steps through the remaining (outer) axes.
        self.dimension[unroll_axis] = 1;
        let mut index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        while let Some(index) = index_ {
            unsafe {
                let ptr = self.parts.uget_ptr(&index);
                acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
            }

            index_ = self.dimension.next_for(index);
        }
        FoldWhile::Continue(acc)
    }

    // Non-contiguous but preference for F - unroll over Axis(0)
    fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
    where
        F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
        P: ZippableTuple<Dim = D>,
    {
        let unroll_axis = 0;
        let inner_len = self.dimension[unroll_axis];
        // Collapse the unrolled axis to length 1; the remaining axes are
        // stepped in f-order via next_for_f below.
        self.dimension[unroll_axis] = 1;
        let index_ = self.dimension.first_index();
        let inner_strides = self.parts.stride_of(unroll_axis);
        // Loop unrolled over closest axis
        if let Some(mut index) = index_ {
            loop {
                unsafe {
                    let ptr = self.parts.uget_ptr(&index);
                    acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
                }

                if !self.dimension.next_for_f(&mut index) {
                    break;
                }
            }
        }
        FoldWhile::Continue(acc)
    }

    /// Allocate an uninitialized array matching this zip's shape, in
    /// f-order when the zip prefers f traversal and c-order otherwise.
    #[cfg(feature = "rayon")]
    pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
    {
        let is_f = self.prefer_f();
        Array::uninit(self.dimension.clone().set_f(is_f))
    }
}
433
434impl<D, P1, P2> Zip<(P1, P2), D>
435where
436    D: Dimension,
437    P1: NdProducer<Dim = D>,
438    P1: NdProducer<Dim = D>,
439{
440    /// Debug assert traversal order is like c (including 1D case)
441    // Method placement: only used for binary Zip at the moment.
442    #[inline]
443    pub(crate) fn debug_assert_c_order(self) -> Self
444    {
445        debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
446                      self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
447                      "Assertion failed: traversal is not c-order or 1D for \
448                      layout {:?}, tendency {}, dimension {:?}",
449                      self.layout, self.layout_tendency, self.dimension);
450        self
451    }
452}
453
454/*
455trait Offset : Copy {
456    unsafe fn offset(self, off: isize) -> Self;
457    unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
458        self.offset(index as isize * stride)
459    }
460}
461
462impl<T> Offset for *mut T {
463    unsafe fn offset(self, off: isize) -> Self {
464        self.offset(off)
465    }
466}
467*/
468
/// Offset a pointer (or tuple of pointers) by `index` steps of a stride.
trait OffsetTuple
{
    /// Stride argument: one stride per pointer in the tuple.
    type Args;
    /// Return the pointer(s) advanced by `index * stride` elements.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T
{
    type Args = isize;
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self
    {
        // Signed arithmetic: the stride may be negative.
        let elements = stride * index as isize;
        self.offset(elements)
    }
}
483
// Implement `OffsetTuple` elementwise for tuples of `Offset` pointers.
// `$param` are the tuple's type parameters (reused as value bindings)
// and `$q` the matching per-pointer stride bindings.
macro_rules! offset_impl {
    ($([$($param:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
            type Args = ($($param::Stride,)*);
            unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
                let ($($param, )*) = self;
                let ($($q, )*) = stride;
                // Offset each pointer by its own stride, all at the same index.
                ($(Offset::stride_offset($param, $q, index),)*)
            }
        }
        )+
    };
}
499
// Instantiate `OffsetTuple` for pointer tuples of arity 1 through 6.
offset_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
508
// Implement `ZippableTuple` for tuples of `NdProducer`s: every method
// simply applies the corresponding producer method elementwise across
// the tuple. `$p` and `$q` are parallel sets of bindings.
macro_rules! zipt_impl {
    ($([$($p:ident)*][ $($q:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
            type Item = ($($p::Item, )*);
            type Ptr = ($($p::Ptr, )*);
            type Dim = Dim;
            type Stride = ($($p::Stride,)* );

            fn stride_of(&self, index: usize) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.stride_of(Axis(index)), )*)
            }

            fn contiguous_stride(&self) -> Self::Stride {
                let ($(ref $p,)*) = *self;
                ($($p.contiguous_stride(), )*)
            }

            fn as_ptr(&self) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.as_ptr(), )*)
            }
            unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
                // `$q` binds the producers, `$p` the matching pointers.
                let ($(ref $q ,)*) = *self;
                let ($($p,)*) = ptr;
                ($($q.as_ref($p),)*)
            }

            unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
                let ($(ref $p,)*) = *self;
                ($($p.uget_ptr(i), )*)
            }

            fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
                let ($($p,)*) = self;
                // Split each producer, then regroup the halves into
                // a left tuple (.0) and a right tuple (.1).
                let ($($p,)*) = (
                    $($p.split_at(axis, index), )*
                );
                (
                    ($($p.0,)*),
                    ($($p.1,)*)
                )
            }
        }
        )+
    };
}
558
// Instantiate `ZippableTuple` for producer tuples of arity 1 through 6.
zipt_impl! {
    [A ][ a],
    [A B][ a b],
    [A B C][ a b c],
    [A B C D][ a b c d],
    [A B C D E][ a b c d e],
    [A B C D E F][ a b c d e f],
}
567
// Generate the public `Zip` API for each producer arity.
//
// `$notlast` is `true` while another producer can still be appended
// (arity below the maximum); it gates the `and*`/`map_*` methods and the
// `collect_with_partial` impl via `expand_if!`. `$p` are the producer
// type parameters.
macro_rules! map_impl {
    ($([$notlast:ident $($p:ident)*],)+) => {
        $(
        #[allow(non_snake_case)]
        impl<D, $($p),*> Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            /// Apply a function to all elements of the input arrays,
            /// visiting elements in lock step.
            pub fn for_each<F>(mut self, mut function: F)
                where F: FnMut($($p::Item),*)
            {
                self.for_each_core((), move |(), args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function($($p),*))
                });
            }

            /// Apply a fold function to all elements of the input arrays,
            /// visiting elements in lock step.
            ///
            /// # Example
            ///
            /// The expression `tr(AᵀB)` can be more efficiently computed as
            /// the equivalent expression `∑ᵢⱼ(A∘B)ᵢⱼ` (i.e. the sum of the
            /// elements of the entry-wise product). It would be possible to
            /// evaluate this expression by first computing the entry-wise
            /// product, `A∘B`, and then computing the elementwise sum of that
            /// product, but it's possible to do this in a single loop (and
            /// avoid an extra heap allocation if `A` and `B` can't be
            /// consumed) by using `Zip`:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            ///
            /// let a = array![[1, 5], [3, 7]];
            /// let b = array![[2, 4], [8, 6]];
            ///
            /// // Without using `Zip`. This involves two loops and an extra
            /// // heap allocation for the result of `&a * &b`.
            /// let sum_prod_nonzip = (&a * &b).sum();
            /// // Using `Zip`. This is a single loop without any heap allocations.
            /// let sum_prod_zip = Zip::from(&a).and(&b).fold(0, |acc, a, b| acc + a * b);
            ///
            /// assert_eq!(sum_prod_nonzip, sum_prod_zip);
            /// ```
            pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
            where
                F: FnMut(Acc, $($p::Item),*) -> Acc,
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    FoldWhile::Continue(function(acc, $($p),*))
                }).into_inner()
            }

            /// Apply a fold function to the input arrays while the return
            /// value is `FoldWhile::Continue`, visiting elements in lock step.
            ///
            pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
                -> FoldWhile<Acc>
                where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
            {
                self.for_each_core(acc, move |acc, args| {
                    let ($($p,)*) = args;
                    function(acc, $($p),*)
                })
            }

            /// Tests if every element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for all elements.
            /// Returns `true` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).all(|&a, &b| a * a == b));
            /// ```
            pub fn all<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits via FoldWhile::Done on the first failure.
                !self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Continue(())
                    } else {
                        FoldWhile::Done(())
                    }
                }).is_done()
            }

            /// Tests if at least one element of the iterator matches a predicate.
            ///
            /// Returns `true` if `predicate` evaluates to `true` for at least one element.
            /// Returns `false` if the input arrays are empty.
            ///
            /// Example:
            ///
            /// ```
            /// use ndarray::{array, Zip};
            /// let a = array![1, 2, 3];
            /// let b = array![1, 4, 9];
            /// assert!(Zip::from(&a).and(&b).any(|&a, &b| a == b));
            /// assert!(!Zip::from(&a).and(&b).any(|&a, &b| a - 1 == b));
            /// ```
            pub fn any<F>(mut self, mut predicate: F) -> bool
                where F: FnMut($($p::Item),*) -> bool
            {
                // Short-circuits via FoldWhile::Done on the first match.
                self.for_each_core((), move |_, args| {
                    let ($($p,)*) = args;
                    if predicate($($p),*) {
                        FoldWhile::Done(())
                    } else {
                        FoldWhile::Continue(())
                    }
                }).is_done()
            }

            // Methods below are only generated while another producer
            // can still be appended (arity below the maximum).
            expand_if!(@bool [$notlast]

            /// Include the producer `p` in the Zip.
            ///
            /// ***Panics*** if `p`’s shape doesn’t match the Zip’s exactly.
            #[track_caller]
            pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                let part = p.into_producer();
                zip_dimension_check(&self.dimension, &part);
                self.build_and(part)
            }

            /// Include the producer `p` in the Zip.
            ///
            /// ## Safety
            ///
            /// The caller must ensure that the producer's shape is equal to the Zip's shape.
            /// Uses assertions when debug assertions are enabled.
            #[allow(unused)]
            pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
                where P: IntoNdProducer<Dim=D>,
            {
                #[cfg(debug_assertions)]
                {
                    self.and(p)
                }
                #[cfg(not(debug_assertions))]
                {
                    self.build_and(p.into_producer())
                }
            }

            /// Include the producer `p` in the Zip, broadcasting if needed.
            ///
            /// If their shapes disagree, `p` is broadcast to the shape of `self`.
            ///
            /// ***Panics*** if broadcasting isn’t possible.
            #[track_caller]
            pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
                -> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
                where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
                      D2: Dimension,
            {
                let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
                self.build_and(part)
            }

            // Append `part` to the tuple and merge its layout information.
            fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
                where P: NdProducer<Dim=D>,
            {
                let part_layout = part.layout();
                let ($($p,)*) = self.parts;
                Zip {
                    parts: ($($p,)* part, ),
                    layout: self.layout.intersect(part_layout),
                    dimension: self.dimension,
                    layout_tendency: self.layout_tendency + part_layout.tendency(),
                }
            }

            /// Map and collect the results into a new array, which has the same size as the
            /// inputs.
            ///
            /// If all inputs are c- or f-order respectively, that is preserved in the output.
            pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
                self.map_collect_owned(f)
            }

            pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
                -> ArrayBase<S, D>
                where
                    S: DataOwned<Elem = R>,
            {
                // safe because: all elements are written before the array is completed

                let shape = self.dimension.clone().set_f(self.prefer_f());
                let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
                    // Use partial to count the number of filled elements, and can drop the right
                    // number of elements on unwinding (if it happens during apply/collect).
                    unsafe {
                        let output_view = output.into_raw_view_mut().cast::<R>();
                        self.and(output_view)
                            .collect_with_partial(f)
                            .release_ownership();
                    }
                });
                unsafe {
                    output.assume_init()
                }
            }

            /// Map and assign the results into the producer `into`, which should have the same
            /// size as the other inputs.
            ///
            /// The producer should have assignable items as dictated by the `AssignElem` trait,
            /// for example `&mut R`.
            pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
                where Q: IntoNdProducer<Dim=D>,
                      Q::Item: AssignElem<R>
            {
                self.and(into)
                    .for_each(move |$($p, )* output_| {
                        output_.assign_elem(f($($p ),*));
                    });
            }

            );

            /// Split the `Zip` evenly in two.
            ///
            /// It will be split in the way that best preserves element locality.
            pub fn split(self) -> (Self, Self) {
                debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
                debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
                SplitPreference::split(self)
            }
        }

        expand_if!(@bool [$notlast]
            // For collect; Last producer is a RawViewMut
            #[allow(non_snake_case)]
            impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
                where D: Dimension,
                      $($p: NdProducer<Dim=D> ,)*
                      PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
            {
                /// The inner workings of map_collect and par_map_collect
                ///
                /// Apply the function and collect the results into the output (last producer)
                /// which should be a raw array view; a Partial that owns the written
                /// elements is returned.
                ///
                /// Elements will be overwritten in place (in the sense of std::ptr::write).
                ///
                /// ## Safety
                ///
                /// The last producer is a RawArrayViewMut and must be safe to write into.
                /// The producer must be c- or f-contig and have the same layout tendency
                /// as the whole Zip.
                ///
                /// The returned Partial's proxy ownership of the elements must be handled,
                /// before the array the raw view points to realizes its ownership.
                pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
                    where F: FnMut($($p::Item,)* ) -> R
                {
                    // Get the last producer; and make a Partial that aliases its data pointer
                    let (.., ref output) = &self.parts;

                    // debug assert that the output is contiguous in the memory layout we need
                    if cfg!(debug_assertions) {
                        let out_layout = output.layout();
                        assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
                        assert!(
                            (self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
                            (self.layout_tendency >= 0 && out_layout.tendency() >= 0),
                            "layout tendency violation for self layout {:?}, output layout {:?},\
                            output shape {:?}",
                            self.layout, out_layout, output.raw_dim());
                    }

                    let mut partial = Partial::new(output.as_ptr());

                    // Apply the mapping function on this zip
                    // if we panic with unwinding; Partial will drop the written elements.
                    let partial_len = &mut partial.len;
                    self.for_each(move |$($p,)* output_elem: *mut R| {
                        output_elem.write(f($($p),*));
                        // Only track length when R has a destructor to run.
                        if std::mem::needs_drop::<R>() {
                            *partial_len += 1;
                        }
                    });

                    partial
                }
            }
        );

        impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn can_split(&self) -> bool { self.size() > 1 }

            fn split_preference(&self) -> (Axis, usize) {
                // Always split in a way that preserves layout (if any)
                let axis = self.max_stride_axis();
                let index = self.len_of(axis) / 2;
                (axis, index)
            }
        }

        impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
            where D: Dimension,
                  $($p: NdProducer<Dim=D> ,)*
        {
            fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
                // Split both the producer tuple and the shape; layout
                // information is carried over unchanged to both halves.
                let (p1, p2) = self.parts.split_at(axis, index);
                let (d1, d2) = self.dimension.split_at(axis, index);
                (Zip {
                    dimension: d1,
                    layout: self.layout,
                    parts: p1,
                    layout_tendency: self.layout_tendency,
                },
                Zip {
                    dimension: d2,
                    layout: self.layout,
                    parts: p2,
                    layout_tendency: self.layout_tendency,
                })
            }

        }

        )+
    };
}
910
// Instantiate the Zip API for 1 through 6 producers. The flag is `true`
// while another producer can still be appended (so the `and*`/`map_*`
// methods are generated) and `false` at the maximum arity.
map_impl! {
    [true P1],
    [true P1 P2],
    [true P1 P2 P3],
    [true P1 P2 P3 P4],
    [true P1 P2 P3 P4 P5],
    [false P1 P2 P3 P4 P5 P6],
}
919
/// Value controlling the execution of `.fold_while` on `Zip`.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T>
{
    /// Continue folding with this value
    Continue(T),
    /// Fold is complete and will return this value
    Done(T),
}

impl<T> FoldWhile<T>
{
    /// Return the inner value
    pub fn into_inner(self) -> T
    {
        // Both variants carry the accumulator; unwrap either one.
        match self {
            FoldWhile::Continue(value) => value,
            FoldWhile::Done(value) => value,
        }
    }

    /// Return true if it is `Done`, false if `Continue`
    pub fn is_done(&self) -> bool
    {
        matches!(self, FoldWhile::Done(_))
    }
}