From a9605dc251776985f1e6a49f12f91b61dbe9daa1 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Tue, 6 Aug 2024 12:01:51 +0200 Subject: [PATCH 01/48] Make iterators covariant in element type The internal Baseiter type underlies most of the ndarray iterators, and it used `*mut A` for element type A. Update it to use `NonNull` which behaves identically except it's guaranteed to be non-null and is covariant w.r.t the parameter A. Add compile test from the issue. Fixes #1290 --- src/impl_owned_array.rs | 4 ++-- src/impl_views/conversions.rs | 8 ++++---- src/iterators/into_iter.rs | 5 ++--- src/iterators/mod.rs | 23 ++++++++++++++--------- tests/iterators.rs | 34 +++++++++++++++++++++++++++++++--- 5 files changed, 53 insertions(+), 21 deletions(-) diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index db176210c..97ed43a47 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -3,7 +3,7 @@ use alloc::vec::Vec; use std::mem; use std::mem::MaybeUninit; -#[allow(unused_imports)] +#[allow(unused_imports)] // Needed for Rust 1.64 use rawpointer::PointerExt; use crate::imp_prelude::*; @@ -907,7 +907,7 @@ where D: Dimension // iter is a raw pointer iterator traversing the array in memory order now with the // sorted axes. - let mut iter = Baseiter::new(self_.ptr.as_ptr(), self_.dim, self_.strides); + let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides); let mut dropped_elements = 0; let mut last_ptr = data_ptr; diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index ef6923a56..1dd7d97f2 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -199,7 +199,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } } @@ -209,7 +209,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } } @@ -220,7 +220,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } #[inline] @@ -262,7 +262,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr.as_ptr(), self.dim, self.strides) } + unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } } #[inline] diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs index fcc2e4b8c..e03c642ba 100644 --- a/src/iterators/into_iter.rs +++ b/src/iterators/into_iter.rs @@ -33,16 +33,15 @@ impl IntoIter where D: Dimension { /// Create a new by-value iterator that consumes `array` - pub(crate) fn new(mut array: Array) -> Self + pub(crate) fn new(array: Array) -> Self { unsafe { let array_head_ptr = array.ptr; - let ptr = array.as_mut_ptr(); let mut array_data = array.data; let data_len = array_data.release_all_elements(); debug_assert!(data_len >= array.dim.size()); let has_unreachable_elements = array.dim.size() != data_len; - let inner = Baseiter::new(ptr, array.dim, array.strides); + let inner = Baseiter::new(array_head_ptr, array.dim, array.strides); IntoIter { array_data, diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index d49ffe2d0..6978117ca 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -19,6 +19,10 @@ use alloc::vec::Vec; use std::iter::FromIterator; use std::marker::PhantomData; use std::ptr; +use std::ptr::NonNull; + +#[allow(unused_imports)] // Needed for Rust 1.64 +use rawpointer::PointerExt; use crate::Ix1; @@ -38,7 +42,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; #[derive(Debug)] pub struct Baseiter { - ptr: *mut A, + ptr: NonNull, dim: D, strides: D, index: Option, @@ -50,7 +54,7 @@ impl Baseiter /// to be correct to avoid performing an unsafe pointer offset while /// iterating. #[inline] - pub unsafe fn new(ptr: *mut A, len: D, stride: D) -> Baseiter + pub unsafe fn new(ptr: NonNull, len: D, stride: D) -> Baseiter { Baseiter { ptr, @@ -74,7 +78,7 @@ impl Iterator for Baseiter }; let offset = D::stride_offset(&index, &self.strides); self.index = self.dim.next_for(index); - unsafe { Some(self.ptr.offset(offset)) } + unsafe { Some(self.ptr.offset(offset).as_ptr()) } } fn size_hint(&self) -> (usize, Option) @@ -99,7 +103,7 @@ impl Iterator for Baseiter let mut i = 0; let i_end = len - elem_index; while i < i_end { - accum = g(accum, row_ptr.offset(i as isize * stride)); + accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr()); i += 1; } } @@ -140,12 +144,12 @@ impl DoubleEndedIterator for Baseiter Some(ix) => ix, }; self.dim[0] -= 1; - let offset = <_>::stride_offset(&self.dim, &self.strides); + let offset = Ix1::stride_offset(&self.dim, &self.strides); if index == self.dim { self.index = None; } - unsafe { Some(self.ptr.offset(offset)) } + unsafe { Some(self.ptr.offset(offset).as_ptr()) } } fn nth_back(&mut self, n: usize) -> Option<*mut A> @@ -154,11 +158,11 @@ impl DoubleEndedIterator for Baseiter let len = self.dim[0] - index[0]; if n < len { self.dim[0] -= n + 1; - let offset = <_>::stride_offset(&self.dim, &self.strides); + let offset = Ix1::stride_offset(&self.dim, &self.strides); if index == self.dim { self.index = None; } - unsafe { Some(self.ptr.offset(offset)) } + unsafe { Some(self.ptr.offset(offset).as_ptr()) } } else { self.index = None; None @@ -178,7 +182,8 @@ impl DoubleEndedIterator for Baseiter accum = g( accum, self.ptr - .offset(Ix1::stride_offset(&self.dim, &self.strides)), + .offset(Ix1::stride_offset(&self.dim, &self.strides)) + .as_ptr(), ); } } diff --git a/tests/iterators.rs b/tests/iterators.rs index 23175fd40..908b64d15 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -1,6 +1,4 @@ -#![allow( - clippy::many_single_char_names, clippy::deref_addrof, clippy::unreadable_literal, clippy::many_single_char_names -)] +#![allow(clippy::deref_addrof, clippy::unreadable_literal)] use ndarray::prelude::*; use ndarray::{arr3, indices, s, Slice, Zip}; @@ -1055,3 +1053,33 @@ impl Drop for DropCount<'_> self.drops.set(self.drops.get() + 1); } } + +#[test] +fn test_impl_iter_compiles() +{ + // Requires that the iterators are covariant in the element type + + // base case: std + fn slice_iter_non_empty_indices<'s, 'a>(array: &'a Vec<&'s str>) -> impl Iterator + 'a + { + array + .iter() + .enumerate() + .filter(|(_index, elem)| !elem.is_empty()) + .map(|(index, _elem)| index) + } + + let _ = slice_iter_non_empty_indices; + + // ndarray case + fn array_iter_non_empty_indices<'s, 'a>(array: &'a Array<&'s str, Ix1>) -> impl Iterator + 'a + { + array + .iter() + .enumerate() + .filter(|(_index, elem)| !elem.is_empty()) + .map(|(index, _elem)| index) + } + + let _ = array_iter_non_empty_indices; +} From 00e15460ad6d99077f882e92857a9b431478244e Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Tue, 6 Aug 2024 12:26:23 +0200 Subject: [PATCH 02/48] Convert Baseiter to use NonNull throughout Complete the transition to using NonNull as the raw pointer type by using it as Baseiter's iterator element type. --- src/data_repr.rs | 5 ---- src/impl_owned_array.rs | 14 +++++----- src/iterators/chunks.rs | 4 +-- src/iterators/into_iter.rs | 4 +-- src/iterators/mod.rs | 55 +++++++++++++++++++++----------------- src/iterators/windows.rs | 2 +- 6 files changed, 43 insertions(+), 41 deletions(-) diff --git a/src/data_repr.rs b/src/data_repr.rs index a24cd7789..4041c192b 100644 --- a/src/data_repr.rs +++ b/src/data_repr.rs @@ -59,11 +59,6 @@ impl OwnedRepr self.ptr.as_ptr() } - pub(crate) fn as_ptr_mut(&self) -> *mut A - { - self.ptr.as_ptr() - } - pub(crate) fn as_nonnull_mut(&mut self) -> NonNull { self.ptr diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index 97ed43a47..44ac12dd4 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -1,5 +1,6 @@ #[cfg(not(feature = "std"))] use alloc::vec::Vec; +use core::ptr::NonNull; use std::mem; use std::mem::MaybeUninit; @@ -435,7 +436,7 @@ where D: Dimension // "deconstruct" self; the owned repr releases ownership of all elements and we // carry on with raw view methods let data_len = self.data.len(); - let data_ptr = self.data.as_nonnull_mut().as_ptr(); + let data_ptr = self.data.as_nonnull_mut(); unsafe { // Safety: self.data releases ownership of the elements. Any panics below this point @@ -866,8 +867,9 @@ where D: Dimension /// /// This is an internal function for use by move_into and IntoIter only, safety invariants may need /// to be upheld across the calls from those implementations. -pub(crate) unsafe fn drop_unreachable_raw(mut self_: RawArrayViewMut, data_ptr: *mut A, data_len: usize) -where D: Dimension +pub(crate) unsafe fn drop_unreachable_raw( + mut self_: RawArrayViewMut, data_ptr: NonNull, data_len: usize, +) where D: Dimension { let self_len = self_.len(); @@ -878,7 +880,7 @@ where D: Dimension } sort_axes_in_default_order(&mut self_); // with uninverted axes this is now the element with lowest address - let array_memory_head_ptr = self_.ptr.as_ptr(); + let array_memory_head_ptr = self_.ptr; let data_end_ptr = data_ptr.add(data_len); debug_assert!(data_ptr <= array_memory_head_ptr); debug_assert!(array_memory_head_ptr <= data_end_ptr); @@ -917,7 +919,7 @@ where D: Dimension // should now be dropped. This interval may be empty, then we just skip this loop. while last_ptr != elem_ptr { debug_assert!(last_ptr < data_end_ptr); - std::ptr::drop_in_place(last_ptr); + std::ptr::drop_in_place(last_ptr.as_mut()); last_ptr = last_ptr.add(1); dropped_elements += 1; } @@ -926,7 +928,7 @@ where D: Dimension } while last_ptr < data_end_ptr { - std::ptr::drop_in_place(last_ptr); + std::ptr::drop_in_place(last_ptr.as_mut()); last_ptr = last_ptr.add(1); dropped_elements += 1; } diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 465428968..9e2f08e1e 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -204,7 +204,7 @@ impl_iterator! { fn item(&mut self, ptr) { unsafe { - ArrayView::new_( + ArrayView::new( ptr, self.chunk.clone(), self.inner_strides.clone()) @@ -226,7 +226,7 @@ impl_iterator! { fn item(&mut self, ptr) { unsafe { - ArrayViewMut::new_( + ArrayViewMut::new( ptr, self.chunk.clone(), self.inner_strides.clone()) diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs index e03c642ba..9374608cb 100644 --- a/src/iterators/into_iter.rs +++ b/src/iterators/into_iter.rs @@ -61,7 +61,7 @@ impl Iterator for IntoIter #[inline] fn next(&mut self) -> Option { - self.inner.next().map(|p| unsafe { p.read() }) + self.inner.next().map(|p| unsafe { p.as_ptr().read() }) } fn size_hint(&self) -> (usize, Option) @@ -91,7 +91,7 @@ where D: Dimension while let Some(_) = self.next() {} unsafe { - let data_ptr = self.array_data.as_ptr_mut(); + let data_ptr = self.array_data.as_nonnull_mut(); let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), self.inner.strides.clone()); debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}", self.data_len, self.inner.dim.size()); diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 6978117ca..e7321d15b 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -38,7 +38,7 @@ use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; /// Base for iterators over all axes. /// -/// Iterator element type is `*mut A`. +/// Iterator element type is `NonNull`. #[derive(Debug)] pub struct Baseiter { @@ -67,10 +67,10 @@ impl Baseiter impl Iterator for Baseiter { - type Item = *mut A; + type Item = NonNull; #[inline] - fn next(&mut self) -> Option<*mut A> + fn next(&mut self) -> Option { let index = match self.index { None => return None, @@ -78,7 +78,7 @@ impl Iterator for Baseiter }; let offset = D::stride_offset(&index, &self.strides); self.index = self.dim.next_for(index); - unsafe { Some(self.ptr.offset(offset).as_ptr()) } + unsafe { Some(self.ptr.offset(offset)) } } fn size_hint(&self) -> (usize, Option) @@ -88,7 +88,7 @@ impl Iterator for Baseiter } fn fold(mut self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, *mut A) -> Acc + where G: FnMut(Acc, Self::Item) -> Acc { let ndim = self.dim.ndim(); debug_assert_ne!(ndim, 0); @@ -103,7 +103,7 @@ impl Iterator for Baseiter let mut i = 0; let i_end = len - elem_index; while i < i_end { - accum = g(accum, row_ptr.offset(i as isize * stride).as_ptr()); + accum = g(accum, row_ptr.offset(i as isize * stride)); i += 1; } } @@ -137,7 +137,7 @@ impl ExactSizeIterator for Baseiter impl DoubleEndedIterator for Baseiter { #[inline] - fn next_back(&mut self) -> Option<*mut A> + fn next_back(&mut self) -> Option { let index = match self.index { None => return None, @@ -149,10 +149,10 @@ impl DoubleEndedIterator for Baseiter self.index = None; } - unsafe { Some(self.ptr.offset(offset).as_ptr()) } + unsafe { Some(self.ptr.offset(offset)) } } - fn nth_back(&mut self, n: usize) -> Option<*mut A> + fn nth_back(&mut self, n: usize) -> Option { let index = self.index?; let len = self.dim[0] - index[0]; @@ -162,7 +162,7 @@ impl DoubleEndedIterator for Baseiter if index == self.dim { self.index = None; } - unsafe { Some(self.ptr.offset(offset).as_ptr()) } + unsafe { Some(self.ptr.offset(offset)) } } else { self.index = None; None @@ -170,7 +170,7 @@ impl DoubleEndedIterator for Baseiter } fn rfold(mut self, init: Acc, mut g: G) -> Acc - where G: FnMut(Acc, *mut A) -> Acc + where G: FnMut(Acc, Self::Item) -> Acc { let mut accum = init; if let Some(index) = self.index { @@ -182,8 +182,7 @@ impl DoubleEndedIterator for Baseiter accum = g( accum, self.ptr - .offset(Ix1::stride_offset(&self.dim, &self.strides)) - .as_ptr(), + .offset(Ix1::stride_offset(&self.dim, &self.strides)), ); } } @@ -231,7 +230,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> #[inline] fn next(&mut self) -> Option<&'a A> { - self.inner.next().map(|p| unsafe { &*p }) + self.inner.next().map(|p| unsafe { p.as_ref() }) } fn size_hint(&self) -> (usize, Option) @@ -242,7 +241,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBase<'a, A, D> fn fold(self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &*ptr)) } + unsafe { self.inner.fold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } @@ -251,13 +250,13 @@ impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> #[inline] fn next_back(&mut self) -> Option<&'a A> { - self.inner.next_back().map(|p| unsafe { &*p }) + self.inner.next_back().map(|p| unsafe { p.as_ref() }) } fn rfold(self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &*ptr)) } + unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, ptr.as_ref())) } } } @@ -651,7 +650,7 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> #[inline] fn next(&mut self) -> Option<&'a mut A> { - self.inner.next().map(|p| unsafe { &mut *p }) + self.inner.next().map(|mut p| unsafe { p.as_mut() }) } fn size_hint(&self) -> (usize, Option) @@ -662,7 +661,10 @@ impl<'a, A, D: Dimension> Iterator for ElementsBaseMut<'a, A, D> fn fold(self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.fold(init, move |acc, ptr| g(acc, &mut *ptr)) } + unsafe { + self.inner + .fold(init, move |acc, mut ptr| g(acc, ptr.as_mut())) + } } } @@ -671,13 +673,16 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> #[inline] fn next_back(&mut self) -> Option<&'a mut A> { - self.inner.next_back().map(|p| unsafe { &mut *p }) + self.inner.next_back().map(|mut p| unsafe { p.as_mut() }) } fn rfold(self, init: Acc, mut g: G) -> Acc where G: FnMut(Acc, Self::Item) -> Acc { - unsafe { self.inner.rfold(init, move |acc, ptr| g(acc, &mut *ptr)) } + unsafe { + self.inner + .rfold(init, move |acc, mut ptr| g(acc, ptr.as_mut())) + } } } @@ -753,7 +758,7 @@ where D: Dimension { self.iter .next() - .map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } fn size_hint(&self) -> (usize, Option) @@ -777,7 +782,7 @@ impl<'a, A> DoubleEndedIterator for LanesIter<'a, A, Ix1> { self.iter .next_back() - .map(|ptr| unsafe { ArrayView::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + .map(|ptr| unsafe { ArrayView::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } } @@ -805,7 +810,7 @@ where D: Dimension { self.iter .next() - .map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } fn size_hint(&self) -> (usize, Option) @@ -829,7 +834,7 @@ impl<'a, A> DoubleEndedIterator for LanesIterMut<'a, A, Ix1> { self.iter .next_back() - .map(|ptr| unsafe { ArrayViewMut::new_(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) + .map(|ptr| unsafe { ArrayViewMut::new(ptr, Ix1(self.inner_len), Ix1(self.inner_stride as Ix)) }) } } diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 453ef5024..1c2ab6a85 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -115,7 +115,7 @@ impl_iterator! { fn item(&mut self, ptr) { unsafe { - ArrayView::new_( + ArrayView::new( ptr, self.window.clone(), self.strides.clone()) From 7e5762b984faa4ecf25520eb92ec4ef53126f73a Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 18:22:19 +0200 Subject: [PATCH 03/48] ci: Run all checks as pull request tests --- .github/workflows/ci.yaml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 88c9a6c2d..5a2807284 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,6 +1,9 @@ on: pull_request: merge_group: + push: + branches: + - master name: Continuous integration @@ -86,7 +89,7 @@ jobs: - run: ./scripts/all-tests.sh "$FEATURES" ${{ matrix.rust }} cross_test: - if: ${{ github.event_name == 'merge_group' }} + #if: ${{ github.event_name == 'merge_group' }} runs-on: ubuntu-latest strategy: matrix: @@ -110,7 +113,7 @@ jobs: - run: ./scripts/cross-tests.sh "docs" ${{ matrix.rust }} ${{ matrix.target }} cargo-careful: - if: ${{ github.event_name == 'merge_group' }} + #if: ${{ github.event_name == 'merge_group' }} runs-on: ubuntu-latest name: cargo-careful steps: @@ -124,7 +127,7 @@ jobs: - run: cargo careful test -Zcareful-sanitizer --features="$FEATURES" docs: - if: ${{ github.event_name == 'merge_group' }} + #if: ${{ github.event_name == 'merge_group' }} runs-on: ubuntu-latest strategy: matrix: From 6a8fb964d9152bfa108a7fba5d41b5a060f8b36c Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 18:24:08 +0200 Subject: [PATCH 04/48] ci: Check for warnings in cargo doc --- .github/workflows/ci.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 5a2807284..f36591741 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -134,12 +134,14 @@ jobs: rust: - stable name: docs/${{ matrix.rust }} + env: + RUSTDOCFLAGS: "-Dwarnings" steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} - - run: cargo doc + - run: cargo doc --no-deps --all-features conclusion: needs: From 2d258bc605cdde8cde516e471a997b5f9c713b7a Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 18:28:49 +0200 Subject: [PATCH 05/48] Fix rustdoc warnings everywhere --- src/dimension/broadcast.rs | 4 ++-- src/doc/crate_feature_flags.rs | 3 +++ src/doc/ndarray_for_numpy_users/mod.rs | 4 +++- src/doc/ndarray_for_numpy_users/rk_step.rs | 1 + src/slice.rs | 3 +++ 5 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/dimension/broadcast.rs b/src/dimension/broadcast.rs index d277cfea2..fb9fc1a0c 100644 --- a/src/dimension/broadcast.rs +++ b/src/dimension/broadcast.rs @@ -41,8 +41,8 @@ pub trait DimMax } /// Dimensions of the same type remain unchanged when co_broadcast. -/// So you can directly use D as the resulting type. -/// (Instead of >::BroadcastOutput) +/// So you can directly use `D` as the resulting type. +/// (Instead of `>::BroadcastOutput`) impl DimMax for D { type Output = D; diff --git a/src/doc/crate_feature_flags.rs b/src/doc/crate_feature_flags.rs index c0fc4c0f5..fc2c2bd49 100644 --- a/src/doc/crate_feature_flags.rs +++ b/src/doc/crate_feature_flags.rs @@ -30,3 +30,6 @@ //! - Enable the ``threading`` feature in the matrixmultiply package //! //! [`parallel`]: crate::parallel + +#[cfg(doc)] +use crate::parallel::par_azip; diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index 5ac15e300..eba96cdd0 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -654,7 +654,7 @@ //! convert `f32` array to `i32` array with ["saturating" conversion][sat_conv]; care needed because it can be a lossy conversion or result in non-finite values! See [the reference for information][as_typecast]. //! //! -//! +//!
//! //! [as_conv]: https://doc.rust-lang.org/rust-by-example/types/cast.html //! [sat_conv]: https://blog.rust-lang.org/2020/07/16/Rust-1.45.0.html#fixing-unsoundness-in-casts @@ -677,6 +677,8 @@ //! [.column()]: ArrayBase::column //! [.column_mut()]: ArrayBase::column_mut //! [concatenate()]: crate::concatenate() +//! [concatenate!]: crate::concatenate! +//! [stack!]: crate::stack! //! [::default()]: ArrayBase::default //! [.diag()]: ArrayBase::diag //! [.dim()]: ArrayBase::dim diff --git a/src/doc/ndarray_for_numpy_users/rk_step.rs b/src/doc/ndarray_for_numpy_users/rk_step.rs index 0448e0705..c882a3d00 100644 --- a/src/doc/ndarray_for_numpy_users/rk_step.rs +++ b/src/doc/ndarray_for_numpy_users/rk_step.rs @@ -169,6 +169,7 @@ //! ``` //! //! [`.scaled_add()`]: crate::ArrayBase::scaled_add +//! [`azip!()`]: crate::azip! //! //! ### SciPy license //! diff --git a/src/slice.rs b/src/slice.rs index 9e6acc449..e6c237a92 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -7,7 +7,10 @@ // except according to those terms. use crate::dimension::slices_intersect; use crate::error::{ErrorKind, ShapeError}; +#[cfg(doc)] +use crate::s; use crate::{ArrayViewMut, DimAdd, Dimension, Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn}; + #[cfg(not(feature = "std"))] use alloc::vec::Vec; use std::convert::TryFrom; From f07b2fe6f6b7c2bdc6a859f57ad1b51d6025eafd Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 08:17:14 +0200 Subject: [PATCH 06/48] Set doc, doctest = false for test crates --- crates/blas-tests/Cargo.toml | 2 ++ crates/numeric-tests/Cargo.toml | 2 ++ crates/serialization-tests/Cargo.toml | 2 ++ 3 files changed, 6 insertions(+) diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index 33323ceac..0dbd9fd12 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -7,6 +7,8 @@ edition = "2018" [lib] test = false +doc = false +doctest = false [dependencies] ndarray = { workspace = true, features = ["approx"] } diff --git a/crates/numeric-tests/Cargo.toml b/crates/numeric-tests/Cargo.toml index 09fe14dbb..214612258 100644 --- a/crates/numeric-tests/Cargo.toml +++ b/crates/numeric-tests/Cargo.toml @@ -7,6 +7,8 @@ edition = "2018" [lib] test = false +doc = false +doctest = false [dependencies] ndarray = { workspace = true, features = ["approx"] } diff --git a/crates/serialization-tests/Cargo.toml b/crates/serialization-tests/Cargo.toml index 8e7056b88..be7c4c17b 100644 --- a/crates/serialization-tests/Cargo.toml +++ b/crates/serialization-tests/Cargo.toml @@ -7,6 +7,8 @@ edition = "2018" [lib] test = false +doc = false +doctest = false [dependencies] ndarray = { workspace = true, features = ["serde"] } From 9f1b35dd81cfd5341517259bc3ac65883fadf21a Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH 07/48] blas-tests: Fix to use blas feature Lost in the recent workspace refactor. --- crates/blas-tests/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index 0dbd9fd12..91c6daaa6 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -11,7 +11,7 @@ doc = false doctest = false [dependencies] -ndarray = { workspace = true, features = ["approx"] } +ndarray = { workspace = true, features = ["approx", "blas"] } blas-src = { version = "0.10", optional = true } openblas-src = { version = "0.10", optional = true } From 2ca801c1c3c66159acafb41fad9c7ad6fb375402 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH 08/48] ndarray-gen: Add simple internal interface for building matrices --- Cargo.toml | 2 + crates/ndarray-gen/Cargo.toml | 9 +++ crates/ndarray-gen/README.md | 4 + crates/ndarray-gen/src/array_builder.rs | 97 +++++++++++++++++++++++++ crates/ndarray-gen/src/lib.rs | 11 +++ 5 files changed, 123 insertions(+) create mode 100644 crates/ndarray-gen/Cargo.toml create mode 100644 crates/ndarray-gen/README.md create mode 100644 crates/ndarray-gen/src/array_builder.rs create mode 100644 crates/ndarray-gen/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index 4c34a11bc..a19ae00a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -85,6 +85,7 @@ members = [ default-members = [ ".", "ndarray-rand", + "crates/ndarray-gen", "crates/numeric-tests", "crates/serialization-tests", # exclude blas-tests that depends on BLAS install @@ -93,6 +94,7 @@ default-members = [ [workspace.dependencies] ndarray = { version = "0.16", path = "." } ndarray-rand = { path = "ndarray-rand" } +ndarray-gen = { path = "crates/ndarray-gen" } num-integer = { version = "0.1.39", default-features = false } num-traits = { version = "0.2", default-features = false } diff --git a/crates/ndarray-gen/Cargo.toml b/crates/ndarray-gen/Cargo.toml new file mode 100644 index 000000000..f06adc48a --- /dev/null +++ b/crates/ndarray-gen/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "ndarray-gen" +version = "0.1.0" +edition = "2018" +publish = false + +[dependencies] +ndarray = { workspace = true } +num-traits = { workspace = true } diff --git a/crates/ndarray-gen/README.md b/crates/ndarray-gen/README.md new file mode 100644 index 000000000..7dd02320c --- /dev/null +++ b/crates/ndarray-gen/README.md @@ -0,0 +1,4 @@ + +## ndarray-gen + +Array generation functions, used for testing. diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs new file mode 100644 index 000000000..a021e5252 --- /dev/null +++ b/crates/ndarray-gen/src/array_builder.rs @@ -0,0 +1,97 @@ +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use ndarray::Array; +use ndarray::Dimension; +use ndarray::IntoDimension; +use ndarray::Order; + +use num_traits::Num; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub struct ArrayBuilder +{ + dim: D, + memory_order: Order, + generator: ElementGenerator, +} + +/// How to generate elements +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ElementGenerator +{ + Sequential, + Zero, +} + +impl Default for ArrayBuilder +{ + fn default() -> Self + { + Self::new(D::zeros(D::NDIM.unwrap_or(1))) + } +} + +impl ArrayBuilder +where D: Dimension +{ + pub fn new(dim: impl IntoDimension) -> Self + { + ArrayBuilder { + dim: dim.into_dimension(), + memory_order: Order::C, + generator: ElementGenerator::Sequential, + } + } + + pub fn memory_order(mut self, order: Order) -> Self + { + self.memory_order = order; + self + } + + pub fn generator(mut self, generator: ElementGenerator) -> Self + { + self.generator = generator; + self + } + + pub fn build(self) -> Array + where T: Num + Clone + { + let mut current = T::zero(); + let size = self.dim.size(); + let use_zeros = self.generator == ElementGenerator::Zero; + Array::from_iter((0..size).map(|_| { + let ret = current.clone(); + if !use_zeros { + current = ret.clone() + T::one(); + } + ret + })) + .into_shape_with_order((self.dim, self.memory_order)) + .unwrap() + } +} + +#[test] +fn test_order() +{ + let (m, n) = (12, 13); + let c = ArrayBuilder::new((m, n)) + .memory_order(Order::C) + .build::(); + let f = ArrayBuilder::new((m, n)) + .memory_order(Order::F) + .build::(); + + assert_eq!(c.shape(), &[m, n]); + assert_eq!(f.shape(), &[m, n]); + assert_eq!(c.strides(), &[n as isize, 1]); + assert_eq!(f.strides(), &[1, m as isize]); +} diff --git a/crates/ndarray-gen/src/lib.rs b/crates/ndarray-gen/src/lib.rs new file mode 100644 index 000000000..ceecf2fae --- /dev/null +++ b/crates/ndarray-gen/src/lib.rs @@ -0,0 +1,11 @@ +// Copyright 2024 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +/// Build ndarray arrays for test purposes + +pub mod array_builder; From 27e347c010896cb6e6dde5205af2fbec5c756694 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH 09/48] blas: Update layout logic for gemm We compute A B -> C with matrices A, B, C With the blas (cblas) interface it supports matrices that adhere to certain criteria. They should be contiguous on one dimension (stride=1). We glance a little at how numpy does this to try to catch all cases. In short, we accept A, B contiguous on either axis (row or column major). We use the case where C is (weakly) row major, but if it is column major we transpose A, B, C => A^t, B^t, C^t so that we are back to the C row major case. (Weakly = contiguous with stride=1 on that inner dimension, but stride for the other dimension can be larger; to differentiate from strictly whole array contiguous.) Minor change to the gemv function, no functional change, only updating due to the refactoring of blas layout functions. Fixes #1278 --- Cargo.toml | 4 +- crates/blas-tests/Cargo.toml | 2 + crates/blas-tests/tests/oper.rs | 53 +++-- src/linalg/impl_linalg.rs | 358 +++++++++++++++++++++----------- 4 files changed, 278 insertions(+), 139 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index a19ae00a4..ac1960242 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,7 +47,7 @@ rawpointer = { version = "0.2" } defmac = "0.2" quickcheck = { workspace = true } approx = { workspace = true, default-features = true } -itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } +itertools = { workspace = true } [features] default = ["std"] @@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"] portable-atomic-critical-section = ["portable-atomic/critical-section"] + [target.'cfg(not(target_has_atomic = "ptr"))'.dependencies] portable-atomic = { version = "1.6.0" } portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] } @@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false } quickcheck = { version = "1.0", default-features = false } rand = { version = "0.8.0", features = ["small_rng"] } rand_distr = { version = "0.4.0" } +itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } [profile.bench] debug = true diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index 91c6daaa6..05a656000 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -12,6 +12,7 @@ doctest = false [dependencies] ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } blas-src = { version = "0.10", optional = true } openblas-src = { version = "0.10", optional = true } @@ -23,6 +24,7 @@ defmac = "0.2" approx = { workspace = true } num-traits = { workspace = true } num-complex = { workspace = true } +itertools = { workspace = true } [features] # Just for making an example and to help testing, , multiple different possible diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index 3ed81915e..931aabea9 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -9,10 +9,13 @@ use ndarray::prelude::*; use ndarray::linalg::general_mat_mul; use ndarray::linalg::general_mat_vec_mul; +use ndarray::Order; use ndarray::{Data, Ix, LinalgScalar}; +use ndarray_gen::array_builder::ArrayBuilder; use approx::assert_relative_eq; use defmac::defmac; +use itertools::iproduct; use num_complex::Complex32; use num_complex::Complex64; @@ -243,7 +246,14 @@ fn gen_mat_mul() let sizes = vec![ (4, 4, 4), (8, 8, 8), - (17, 15, 16), + (10, 10, 10), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), (4, 17, 3), (17, 3, 22), (19, 18, 2), @@ -251,24 +261,41 @@ fn gen_mat_mul() (15, 16, 17), (67, 63, 62), ]; - // test different strides - for &s1 in &[1, 2, -1, -2] { - for &s2 in &[1, 2, -1, -2] { - for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + + // test different strides and memory orders + for (&s1, &s2) in iproduct!(strides, strides) { + for &(m, k, n) in &sizes { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build(); + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + let mut answer = c.clone(); { - let a = a.slice(s![..;s1, ..;s2]); - let b = b.slice(s![..;s2, ..;s2]); - let mut cv = c.slice_mut(s![..;s1, ..;s2]); + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. + av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } - let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv; + let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv; answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); - general_mat_mul(alpha, &a, &b, beta, &mut cv); + general_mat_mul(alpha, &av, &bv, beta, &mut cv); } assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7); } diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index f3bedae71..e7813455d 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -25,8 +25,6 @@ use num_complex::{Complex32 as c32, Complex64 as c64}; #[cfg(feature = "blas")] use libc::c_int; #[cfg(feature = "blas")] -use std::cmp; -#[cfg(feature = "blas")] use std::mem::swap; #[cfg(feature = "blas")] @@ -388,8 +386,9 @@ fn mat_mul_impl
( { // size cutoff for using BLAS let cut = GEMM_BLAS_CUTOFF; - let ((mut m, a), (_, mut n)) = (lhs.dim(), rhs.dim()); - if !(m > cut || n > cut || a > cut) + let ((mut m, k), (k2, mut n)) = (lhs.dim(), rhs.dim()); + debug_assert_eq!(k, k2); + if !(m > cut || n > cut || k > cut) || !(same_type::() || same_type::() || same_type::() @@ -397,32 +396,74 @@ fn mat_mul_impl( { return mat_mul_general(alpha, lhs, rhs, beta, c); } - { - // Use `c` for c-order and `f` for an f-order matrix - // We can handle c * c, f * f generally and - // c * f and f * c if the `f` matrix is square. - let mut lhs_ = lhs.view(); - let mut rhs_ = rhs.view(); - let mut c_ = c.view_mut(); - let lhs_s0 = lhs_.strides()[0]; - let rhs_s0 = rhs_.strides()[0]; - let both_f = lhs_s0 == 1 && rhs_s0 == 1; - let mut lhs_trans = CblasNoTrans; - let mut rhs_trans = CblasNoTrans; - if both_f { - // A^t B^t = C^t => B A = C - let lhs_t = lhs_.reversed_axes(); - lhs_ = rhs_.reversed_axes(); - rhs_ = lhs_t; - c_ = c_.reversed_axes(); - swap(&mut m, &mut n); - } else if lhs_s0 == 1 && m == a { - lhs_ = lhs_.reversed_axes(); - lhs_trans = CblasTrans; - } else if rhs_s0 == 1 && a == n { - rhs_ = rhs_.reversed_axes(); - rhs_trans = CblasTrans; + + #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block + 'blas_block: loop { + let mut a = lhs.view(); + let mut b = rhs.view(); + let mut c = c.view_mut(); + + let c_layout = get_blas_compatible_layout(&c); + let c_layout_is_c = matches!(c_layout, Some(MemoryOrder::C)); + let c_layout_is_f = matches!(c_layout, Some(MemoryOrder::F)); + + // Compute A B -> C + // we require for BLAS compatibility that: + // A, B are contiguous (stride=1) in their fastest dimension. + // C is c-contiguous in one dimension (stride=1 in Axis(1)) + // + // If C is f-contiguous, use transpose equivalency + // to translate to the C-contiguous case: + // A^t B^t = C^t => B A = C + + let (a_layout, b_layout) = + match (get_blas_compatible_layout(&a), get_blas_compatible_layout(&b)) { + (Some(a_layout), Some(b_layout)) if c_layout_is_c => { + // normal case + (a_layout, b_layout) + }, + (Some(a_layout), Some(b_layout)) if c_layout_is_f => { + // Transpose equivalency + // A^t B^t = C^t => B A = C + // + // A^t becomes the new B + // B^t becomes the new A + let a_t = a.reversed_axes(); + a = b.reversed_axes(); + b = a_t; + c = c.reversed_axes(); + // Assign (n, k, m) -> (m, k, n) effectively + swap(&mut m, &mut n); + + // Continue using the already computed memory layouts + (b_layout.opposite(), a_layout.opposite()) + }, + _otherwise => { + break 'blas_block; + } + }; + + let a_trans; + let b_trans; + let lda; // Stride of a + let ldb; // Stride of b + + if let MemoryOrder::C = a_layout { + lda = blas_stride(&a, 0); + a_trans = CblasNoTrans; + } else { + lda = blas_stride(&a, 1); + a_trans = CblasTrans; + } + + if let MemoryOrder::C = b_layout { + ldb = blas_stride(&b, 0); + b_trans = CblasNoTrans; + } else { + ldb = blas_stride(&b, 1); + b_trans = CblasTrans; } + let ldc = blas_stride(&c, 0); macro_rules! gemm_scalar_cast { (f32, $var:ident) => { @@ -441,44 +482,25 @@ fn mat_mul_impl( macro_rules! gemm { ($ty:tt, $gemm:ident) => { - if blas_row_major_2d::<$ty, _>(&lhs_) - && blas_row_major_2d::<$ty, _>(&rhs_) - && blas_row_major_2d::<$ty, _>(&c_) - { - let (m, k) = match lhs_trans { - CblasNoTrans => lhs_.dim(), - _ => { - let (rows, cols) = lhs_.dim(); - (cols, rows) - } - }; - let n = match rhs_trans { - CblasNoTrans => rhs_.raw_dim()[1], - _ => rhs_.raw_dim()[0], - }; - // adjust strides, these may [1, 1] for column matrices - let lhs_stride = cmp::max(lhs_.strides()[0] as blas_index, k as blas_index); - let rhs_stride = cmp::max(rhs_.strides()[0] as blas_index, n as blas_index); - let c_stride = cmp::max(c_.strides()[0] as blas_index, n as blas_index); - + if same_type::() { // gemm is C ← αA^Op B^Op + βC // Where Op is notrans/trans/conjtrans unsafe { blas_sys::$gemm( CblasRowMajor, - lhs_trans, - rhs_trans, + a_trans, + b_trans, m as blas_index, // m, rows of Op(a) n as blas_index, // n, cols of Op(b) k as blas_index, // k, cols of Op(a) gemm_scalar_cast!($ty, alpha), // alpha - lhs_.ptr.as_ptr() as *const _, // a - lhs_stride, // lda - rhs_.ptr.as_ptr() as *const _, // b - rhs_stride, // ldb + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb gemm_scalar_cast!($ty, beta), // beta - c_.ptr.as_ptr() as *mut _, // c - c_stride, // ldc + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc ); } return; @@ -490,6 +512,7 @@ fn mat_mul_impl( gemm!(c32, cblas_cgemm); gemm!(c64, cblas_zgemm); + break 'blas_block; } mat_mul_general(alpha, lhs, rhs, beta, c) } @@ -693,46 +716,51 @@ unsafe fn general_mat_vec_mul_impl( #[cfg(feature = "blas")] macro_rules! gemv { ($ty:ty, $gemv:ident) => { - if let Some(layout) = blas_layout::<$ty, _>(&a) { - if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { - // Determine stride between rows or columns. Note that the stride is - // adjusted to at least `k` or `m` to handle the case of a matrix with a - // trivial (length 1) dimension, since the stride for the trivial dimension - // may be arbitrary. - let a_trans = CblasNoTrans; - let a_stride = match layout { - CBLAS_LAYOUT::CblasRowMajor => { - a.strides()[0].max(k as isize) as blas_index - } - CBLAS_LAYOUT::CblasColMajor => { - a.strides()[1].max(m as isize) as blas_index - } - }; - - // Low addr in memory pointers required for x, y - let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); - let x_ptr = x.ptr.as_ptr().sub(x_offset); - let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); - let y_ptr = y.ptr.as_ptr().sub(y_offset); - - let x_stride = x.strides()[0] as blas_index; - let y_stride = y.strides()[0] as blas_index; - - blas_sys::$gemv( - layout, - a_trans, - m as blas_index, // m, rows of Op(a) - k as blas_index, // n, cols of Op(a) - cast_as(&alpha), // alpha - a.ptr.as_ptr() as *const _, // a - a_stride, // lda - x_ptr as *const _, // x - x_stride, - cast_as(&beta), // beta - y_ptr as *mut _, // y - y_stride, - ); - return; + if same_type::() { + if let Some(layout) = get_blas_compatible_layout(&a) { + if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { + // Determine stride between rows or columns. Note that the stride is + // adjusted to at least `k` or `m` to handle the case of a matrix with a + // trivial (length 1) dimension, since the stride for the trivial dimension + // may be arbitrary. + let a_trans = CblasNoTrans; + + let (a_stride, cblas_layout) = match layout { + MemoryOrder::C => { + (a.strides()[0].max(k as isize) as blas_index, + CBLAS_LAYOUT::CblasRowMajor) + } + MemoryOrder::F => { + (a.strides()[1].max(m as isize) as blas_index, + CBLAS_LAYOUT::CblasColMajor) + } + }; + + // Low addr in memory pointers required for x, y + let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); + let x_ptr = x.ptr.as_ptr().sub(x_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); + let y_ptr = y.ptr.as_ptr().sub(y_offset); + + let x_stride = x.strides()[0] as blas_index; + let y_stride = y.strides()[0] as blas_index; + + blas_sys::$gemv( + cblas_layout, + a_trans, + m as blas_index, // m, rows of Op(a) + k as blas_index, // n, cols of Op(a) + cast_as(&alpha), // alpha + a.ptr.as_ptr() as *const _, // a + a_stride, // lda + x_ptr as *const _, // x + x_stride, + cast_as(&beta), // beta + y_ptr as *mut _, // y + y_stride, + ); + return; + } } } }; @@ -834,6 +862,7 @@ where } #[cfg(feature = "blas")] +#[derive(Copy, Clone)] enum MemoryOrder { C, @@ -841,29 +870,15 @@ enum MemoryOrder } #[cfg(feature = "blas")] -fn blas_row_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, -{ - if !same_type::() { - return false; - } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) -} - -#[cfg(feature = "blas")] -fn blas_column_major_2d(a: &ArrayBase) -> bool -where - S: Data, - A: 'static, - S::Elem: 'static, +impl MemoryOrder { - if !same_type::() { - return false; + fn opposite(self) -> Self + { + match self { + MemoryOrder::C => MemoryOrder::F, + MemoryOrder::F => MemoryOrder::C, + } } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) } #[cfg(feature = "blas")] @@ -893,20 +908,71 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool true } +/// Get BLAS compatible layout if any (C or F, preferring the former) +#[cfg(feature = "blas")] +fn get_blas_compatible_layout(a: &ArrayBase) -> Option +where S: Data +{ + if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) { + Some(MemoryOrder::C) + } else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) { + Some(MemoryOrder::F) + } else { + None + } +} + +/// `a` should be blas compatible. +/// axis: 0 or 1. +/// +/// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_layout(a: &ArrayBase) -> Option +fn blas_stride(a: &ArrayBase, axis: usize) -> blas_index +where S: Data +{ + debug_assert!(axis <= 1); + let other_axis = 1 - axis; + let len_this = a.shape()[axis]; + let len_other = a.shape()[other_axis]; + let stride = a.strides()[axis]; + + // if current axis has length == 1, then stride does not matter for ndarray + // but for BLAS we need a stride that makes sense, i.e. it's >= the other axis + + // cast: a should already be blas compatible + (if len_this <= 1 { + Ord::max(stride, len_other as isize) + } else { + stride + }) as blas_index +} + +#[cfg(test)] +#[cfg(feature = "blas")] +fn blas_row_major_2d(a: &ArrayBase) -> bool where S: Data, A: 'static, S::Elem: 'static, { - if blas_row_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasRowMajor) - } else if blas_column_major_2d::(a) { - Some(CBLAS_LAYOUT::CblasColMajor) - } else { - None + if !same_type::() { + return false; + } + is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) +} + +#[cfg(test)] +#[cfg(feature = "blas")] +fn blas_column_major_2d(a: &ArrayBase) -> bool +where + S: Data, + A: 'static, + S::Elem: 'static, +{ + if !same_type::() { + return false; } + is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) } #[cfg(test)] @@ -964,4 +1030,46 @@ mod blas_tests assert!(!blas_row_major_2d::(&m)); assert!(blas_column_major_2d::(&m)); } + + #[test] + fn blas_row_major_2d_skip_rows_ok() + { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![..;2, ..]); + assert!(blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_row_major_2d_skip_columns_fail() + { + let m: Array2 = Array2::zeros((5, 5)); + let mv = m.slice(s![.., ..;2]); + assert!(!blas_row_major_2d::(&mv)); + assert!(!blas_column_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_columns_ok() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![.., ..;2]); + assert!(blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn blas_col_major_2d_skip_rows_fail() + { + let m: Array2 = Array2::zeros((5, 5).f()); + let mv = m.slice(s![..;2, ..]); + assert!(!blas_column_major_2d::(&mv)); + assert!(!blas_row_major_2d::(&mv)); + } + + #[test] + fn test() + { + //WIP test that stride is larger than other dimension + } } From 01bb218ada456c80d937710ce2d3c997db96bb18 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH 10/48] blas: Fix to skip array with too short stride If we have a matrix of dimension say 5 x 5, BLAS requires the leading stride to be >= 5. Smaller cases are possible for read-only array views in ndarray(broadcasting and custom strides). In this case we mark the array as not BLAS compatible --- src/linalg/impl_linalg.rs | 39 ++++++++++++++++++++++++++++++++++----- 1 file changed, 34 insertions(+), 5 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index e7813455d..778fcaabd 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -863,6 +863,7 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] +#[cfg_attr(test, derive(PartialEq, Eq, Debug))] enum MemoryOrder { C, @@ -887,24 +888,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; - let (inner_stride, outer_dim) = match order { - MemoryOrder::C => (s1, n), - MemoryOrder::F => (s0, m), + let (inner_stride, outer_stride, inner_dim, outer_dim) = match order { + MemoryOrder::C => (s1, s0, m, n), + MemoryOrder::F => (s0, s1, n, m), }; + if !(inner_stride == 1 || outer_dim == 1) { return false; } + if s0 < 1 || s1 < 1 { return false; } + if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize) || (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize) { return false; } + + // leading stride must >= the dimension (no broadcasting/aliasing) + if inner_dim > 1 && (outer_stride as usize) < outer_dim { + return false; + } + if m > blas_index::MAX as usize || n > blas_index::MAX as usize { return false; } + true } @@ -1068,8 +1079,26 @@ mod blas_tests } #[test] - fn test() + fn blas_too_short_stride() { - //WIP test that stride is larger than other dimension + // leading stride must be longer than the other dimension + // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS. + + const N: usize = 5; + const MAXSTRIDE: usize = N + 2; + let mut data = [0; MAXSTRIDE * N]; + let mut iter = 0..data.len(); + data.fill_with(|| iter.next().unwrap()); + + for stride in 1..=MAXSTRIDE { + let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap(); + eprintln!("{:?}", m); + + if stride < N { + assert_eq!(get_blas_compatible_layout(&m), None); + } else { + assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C)); + } + } } } From 56cac34de7ec11705b4c76721ab4290e25e601d7 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 11:18:44 +0200 Subject: [PATCH 11/48] ci: Run ndarray tests with feature blas --- scripts/all-tests.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 6f1fdf73a..b8c9b5849 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -19,6 +19,7 @@ cargo test -v --features "$FEATURES" $QC_FEAT cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FEAT --lib --tests # BLAS tests +cargo test -p ndarray --lib -v --features blas cargo test -p blas-tests -v --features blas-tests/openblas-system cargo test -p numeric-tests -v --features numeric-tests/test_blas From e65bd0d43d716d60b69e6942a314cb5a012e322b Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 7 Aug 2024 15:49:07 +0200 Subject: [PATCH 12/48] tests: Refactor to use ArrayBuilder more places --- Cargo.toml | 3 +- crates/blas-tests/tests/oper.rs | 64 ++++++++-------------- crates/ndarray-gen/Cargo.toml | 2 +- crates/ndarray-gen/src/lib.rs | 1 + tests/oper.rs | 95 ++++++++++++++------------------- 5 files changed, 66 insertions(+), 99 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ac1960242..34d298dda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ defmac = "0.2" quickcheck = { workspace = true } approx = { workspace = true, default-features = true } itertools = { workspace = true } +ndarray-gen = { workspace = true } [features] default = ["std"] @@ -93,7 +94,7 @@ default-members = [ ] [workspace.dependencies] -ndarray = { version = "0.16", path = "." } +ndarray = { version = "0.16", path = ".", default-features = false } ndarray-rand = { path = "ndarray-rand" } ndarray-gen = { path = "crates/ndarray-gen" } diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index 931aabea9..9361a59a5 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -18,6 +18,7 @@ use defmac::defmac; use itertools::iproduct; use num_complex::Complex32; use num_complex::Complex64; +use num_traits::Num; #[test] fn mat_vec_product_1d() @@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis() assert_eq!(a.t().dot(&b), ans); } -fn range_mat(m: Ix, n: Ix) -> Array2 +fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 -{ - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } fn range_mat_complex(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() - .map(|&f| Complex32::new(f, 0.)) + ArrayBuilder::new((m, n)).build() } fn range_mat_complex64(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() - .map(|&f| Complex64::new(f, 0.)) + ArrayBuilder::new((m, n)).build() } fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) + ArrayBuilder::new(m).build() } fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } // simple, slow, correct (hopefully) mat mul @@ -163,8 +147,8 @@ where fn mat_mul_order() { let (m, n, k) = (50, 50, 50); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut af = Array::zeros(a.dim().f()); let mut bf = Array::zeros(b.dim().f()); af.assign(&a); @@ -183,7 +167,7 @@ fn mat_mul_order() fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); + let a = range_mat::(m, n); let x1 = 1.; let x = Array::from(vec![x1]); let b0 = x.broadcast((n, k)).unwrap(); @@ -203,8 +187,8 @@ fn mat_mul_broadcast() fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut rev = Array::zeros(b.dim()); let mut rev = rev.slice_mut(s![..;-1, ..]); rev.assign(&b); @@ -233,8 +217,8 @@ fn mat_mut_zero_len() } } }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); } @@ -307,11 +291,11 @@ fn gen_mat_mul() #[test] fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); + let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(1.0, &a, &x, 1.0, &mut y); assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); @@ -393,11 +377,8 @@ fn gen_mat_vec_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); let (m, k) = a.dim(); let b = range1_mat64(k); let mut c = range1_mat64(m); @@ -438,11 +419,8 @@ fn vec_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); let (m, n) = b.dim(); let a = range1_mat64(m); let mut c = range1_mat64(n); diff --git a/crates/ndarray-gen/Cargo.toml b/crates/ndarray-gen/Cargo.toml index f06adc48a..6818e4b65 100644 --- a/crates/ndarray-gen/Cargo.toml +++ b/crates/ndarray-gen/Cargo.toml @@ -5,5 +5,5 @@ edition = "2018" publish = false [dependencies] -ndarray = { workspace = true } +ndarray = { workspace = true, default-features = false } num-traits = { workspace = true } diff --git a/crates/ndarray-gen/src/lib.rs b/crates/ndarray-gen/src/lib.rs index ceecf2fae..7f9ca89fc 100644 --- a/crates/ndarray-gen/src/lib.rs +++ b/crates/ndarray-gen/src/lib.rs @@ -1,3 +1,4 @@ +#![no_std] // Copyright 2024 bluss and ndarray developers. // // Licensed under the Apache License, Version 2.0 Array2 +fn range_mat(m: Ix, n: Ix) -> Array2 { - Array::linspace(0., (m * n) as f32 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() -} - -fn range_mat64(m: Ix, n: Ix) -> Array2 -{ - Array::linspace(0., (m * n) as f64 - 1., m * n) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } #[cfg(feature = "approx")] fn range1_mat64(m: Ix) -> Array1 { - Array::linspace(0., m as f64 - 1., m) + ArrayBuilder::new(m).build() } fn range_i32(m: Ix, n: Ix) -> Array2 { - Array::from_iter(0..(m * n) as i32) - .into_shape_with_order((m, n)) - .unwrap() + ArrayBuilder::new((m, n)).build() } // simple, slow, correct (hopefully) mat mul @@ -332,8 +325,8 @@ where fn mat_mul() { let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -351,8 +344,8 @@ fn mat_mul() assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 5, 11); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -370,8 +363,8 @@ fn mat_mul() assert_eq!(ab, af.dot(&bf)); let (m, n, k) = (10, 8, 1); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut b = b / 4.; { let mut c = b.column_mut(0); @@ -395,8 +388,8 @@ fn mat_mul() fn mat_mul_order() { let (m, n, k) = (8, 8, 8); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut af = Array::zeros(a.dim().f()); let mut bf = Array::zeros(b.dim().f()); af.assign(&a); @@ -415,8 +408,8 @@ fn mat_mul_order() fn mat_mul_shape_mismatch() { let (m, k, k2, n) = (8, 8, 9, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); a.dot(&b); } @@ -426,9 +419,9 @@ fn mat_mul_shape_mismatch() fn mat_mul_shape_mismatch_2() { let (m, k, k2, n) = (8, 8, 8, 8); - let a = range_mat(m, k); - let b = range_mat(k2, n); - let mut c = range_mat(m, n + 1); + let a = range_mat::(m, k); + let b = range_mat::(k2, n); + let mut c = range_mat::(m, n + 1); general_mat_mul(1., &a, &b, 1., &mut c); } @@ -438,7 +431,7 @@ fn mat_mul_shape_mismatch_2() fn mat_mul_broadcast() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); + let a = range_mat::(m, n); let x1 = 1.; let x = Array::from(vec![x1]); let b0 = x.broadcast((n, k)).unwrap(); @@ -458,8 +451,8 @@ fn mat_mul_broadcast() fn mat_mul_rev() { let (m, n, k) = (16, 16, 16); - let a = range_mat(m, n); - let b = range_mat(n, k); + let a = range_mat::(m, n); + let b = range_mat::(n, k); let mut rev = Array::zeros(b.dim()); let mut rev = rev.slice_mut(s![..;-1, ..]); rev.assign(&b); @@ -488,8 +481,8 @@ fn mat_mut_zero_len() } } }); - mat_mul_zero_len!(range_mat); - mat_mul_zero_len!(range_mat64); + mat_mul_zero_len!(range_mat::); + mat_mul_zero_len!(range_mat::); mat_mul_zero_len!(range_i32); } @@ -528,9 +521,9 @@ fn scaled_add_2() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); - let c = range_mat64(n, q); + let c = range_mat::(n, q); { let mut av = a.slice_mut(s![..;s1, ..;s2]); @@ -570,7 +563,7 @@ fn scaled_add_3() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n, q) in &sizes { - let mut a = range_mat64(m, k); + let mut a = range_mat::(m, k); let mut answer = a.clone(); let cdim = if n == 1 { vec![q] } else { vec![n, q] }; let cslice: Vec = if n == 1 { @@ -582,7 +575,7 @@ fn scaled_add_3() ] }; - let c = range_mat64(n, q).into_shape_with_order(cdim).unwrap(); + let c = range_mat::(n, q).into_shape_with_order(cdim).unwrap(); { let mut av = a.slice_mut(s![..;s1, ..;s2]); @@ -619,9 +612,9 @@ fn gen_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k, n) in &sizes { - let a = range_mat64(m, k); - let b = range_mat64(k, n); - let mut c = range_mat64(m, n); + let a = range_mat::(m, k); + let b = range_mat::(k, n); + let mut c = range_mat::(m, n); let mut answer = c.clone(); { @@ -645,11 +638,11 @@ fn gen_mat_mul() #[test] fn gemm_64_1_f() { - let a = range_mat64(64, 64).reversed_axes(); + let a = range_mat::(64, 64).reversed_axes(); let (m, n) = a.dim(); // m x n times n x 1 == m x 1 - let x = range_mat64(n, 1); - let mut y = range_mat64(m, 1); + let x = range_mat::(n, 1); + let mut y = range_mat::(m, 1); let answer = reference_mat_mul(&a, &x) + &y; general_mat_mul(1.0, &a, &x, 1.0, &mut y); approx::assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7); @@ -728,11 +721,8 @@ fn gen_mat_vec_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, k) in &sizes { - for &rev in &[false, true] { - let mut a = range_mat64(m, k); - if rev { - a = a.reversed_axes(); - } + for order in [Order::C, Order::F] { + let a = ArrayBuilder::new((m, k)).memory_order(order).build(); let (m, k) = a.dim(); let b = range1_mat64(k); let mut c = range1_mat64(m); @@ -794,11 +784,8 @@ fn vec_mat_mul() for &s1 in &[1, 2, -1, -2] { for &s2 in &[1, 2, -1, -2] { for &(m, n) in &sizes { - for &rev in &[false, true] { - let mut b = range_mat64(m, n); - if rev { - b = b.reversed_axes(); - } + for order in [Order::C, Order::F] { + let b = ArrayBuilder::new((m, n)).memory_order(order).build(); let (m, n) = b.dim(); let a = range1_mat64(m); let mut c = range1_mat64(n); From b2955cb9ebd647963b285e1a3e31e7bedc8beacc Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 16:05:14 +0200 Subject: [PATCH 13/48] blas: Simplify layout logic for gemm Using cblas we can simplify this further to a more satisfying translation (from ndarray to BLAS), much simpler logic. Avoids creating and handling an extra layer of array views. --- crates/blas-tests/tests/oper.rs | 2 +- src/linalg/impl_linalg.rs | 133 ++++++++++++++++---------------- 2 files changed, 67 insertions(+), 68 deletions(-) diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index 9361a59a5..f1e1bc42b 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -253,7 +253,7 @@ fn gen_mat_mul() for &(m, k, n) in &sizes { for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); - let a = ArrayBuilder::new((m, k)).memory_order(ord1).build(); + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5; let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 778fcaabd..243dc783b 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -24,13 +24,11 @@ use num_complex::{Complex32 as c32, Complex64 as c64}; #[cfg(feature = "blas")] use libc::c_int; -#[cfg(feature = "blas")] -use std::mem::swap; #[cfg(feature = "blas")] use cblas_sys as blas_sys; #[cfg(feature = "blas")] -use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT}; +use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT}; /// len of vector before we use blas #[cfg(feature = "blas")] @@ -377,8 +375,8 @@ use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] fn mat_mul_impl( alpha: A, - lhs: &ArrayView2<'_, A>, - rhs: &ArrayView2<'_, A>, + a: &ArrayView2<'_, A>, + b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, ) where @@ -386,7 +384,7 @@ fn mat_mul_impl( { // size cutoff for using BLAS let cut = GEMM_BLAS_CUTOFF; - let ((mut m, k), (k2, mut n)) = (lhs.dim(), rhs.dim()); + let ((m, k), (k2, n)) = (a.dim(), b.dim()); debug_assert_eq!(k, k2); if !(m > cut || n > cut || k > cut) || !(same_type::() @@ -394,76 +392,48 @@ fn mat_mul_impl( || same_type::() || same_type::()) { - return mat_mul_general(alpha, lhs, rhs, beta, c); + return mat_mul_general(alpha, a, b, beta, c); } #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block 'blas_block: loop { - let mut a = lhs.view(); - let mut b = rhs.view(); - let mut c = c.view_mut(); - - let c_layout = get_blas_compatible_layout(&c); - let c_layout_is_c = matches!(c_layout, Some(MemoryOrder::C)); - let c_layout_is_f = matches!(c_layout, Some(MemoryOrder::F)); - // Compute A B -> C - // we require for BLAS compatibility that: - // A, B are contiguous (stride=1) in their fastest dimension. - // C is c-contiguous in one dimension (stride=1 in Axis(1)) + // We require for BLAS compatibility that: + // A, B, C are contiguous (stride=1) in their fastest dimension, + // but it can be either first or second axis (either rowmajor/"c" or colmajor/"f"). // - // If C is f-contiguous, use transpose equivalency - // to translate to the C-contiguous case: - // A^t B^t = C^t => B A = C - - let (a_layout, b_layout) = - match (get_blas_compatible_layout(&a), get_blas_compatible_layout(&b)) { - (Some(a_layout), Some(b_layout)) if c_layout_is_c => { - // normal case - (a_layout, b_layout) + // The "normal case" is CblasRowMajor for cblas. + // Select CblasRowMajor, CblasColMajor to fit C's memory order. + // + // Apply transpose to A, B as needed if they differ from the normal case. + // If C is CblasColMajor then transpose both A, B (again!) + + let (a_layout, a_axis, b_layout, b_axis, c_layout) = + match (get_blas_compatible_layout(a), + get_blas_compatible_layout(b), + get_blas_compatible_layout(c)) + { + (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => { + (a_layout, a_layout.lead_axis(), + b_layout, b_layout.lead_axis(), c_layout) }, - (Some(a_layout), Some(b_layout)) if c_layout_is_f => { - // Transpose equivalency - // A^t B^t = C^t => B A = C - // - // A^t becomes the new B - // B^t becomes the new A - let a_t = a.reversed_axes(); - a = b.reversed_axes(); - b = a_t; - c = c.reversed_axes(); - // Assign (n, k, m) -> (m, k, n) effectively - swap(&mut m, &mut n); - - // Continue using the already computed memory layouts - (b_layout.opposite(), a_layout.opposite()) + (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => { + // CblasColMajor is the "other case" + // Mark a, b as having layouts opposite of what they were detected as, which + // ends up with the correct transpose setting w.r.t col major + (a_layout.opposite(), a_layout.lead_axis(), + b_layout.opposite(), b_layout.lead_axis(), c_layout) }, - _otherwise => { - break 'blas_block; - } + _ => break 'blas_block, }; - let a_trans; - let b_trans; - let lda; // Stride of a - let ldb; // Stride of b + let a_trans = a_layout.to_cblas_transpose(); + let lda = blas_stride(&a, a_axis); - if let MemoryOrder::C = a_layout { - lda = blas_stride(&a, 0); - a_trans = CblasNoTrans; - } else { - lda = blas_stride(&a, 1); - a_trans = CblasTrans; - } + let b_trans = b_layout.to_cblas_transpose(); + let ldb = blas_stride(&b, b_axis); - if let MemoryOrder::C = b_layout { - ldb = blas_stride(&b, 0); - b_trans = CblasNoTrans; - } else { - ldb = blas_stride(&b, 1); - b_trans = CblasTrans; - } - let ldc = blas_stride(&c, 0); + let ldc = blas_stride(&c, c_layout.lead_axis()); macro_rules! gemm_scalar_cast { (f32, $var:ident) => { @@ -487,7 +457,7 @@ fn mat_mul_impl( // Where Op is notrans/trans/conjtrans unsafe { blas_sys::$gemm( - CblasRowMajor, + c_layout.to_cblas_layout(), a_trans, b_trans, m as blas_index, // m, rows of Op(a) @@ -507,14 +477,15 @@ fn mat_mul_impl( } }; } + gemm!(f32, cblas_sgemm); gemm!(f64, cblas_dgemm); - gemm!(c32, cblas_cgemm); gemm!(c64, cblas_zgemm); + break 'blas_block; } - mat_mul_general(alpha, lhs, rhs, beta, c) + mat_mul_general(alpha, a, b, beta, c) } /// C ← α A B + β C @@ -873,6 +844,18 @@ enum MemoryOrder #[cfg(feature = "blas")] impl MemoryOrder { + #[inline] + /// Axis of leading stride (opposite of contiguous axis) + fn lead_axis(self) -> usize + { + match self { + MemoryOrder::C => 0, + MemoryOrder::F => 1, + } + } + + /// Get opposite memory order + #[inline] fn opposite(self) -> Self { match self { @@ -880,6 +863,22 @@ impl MemoryOrder MemoryOrder::F => MemoryOrder::C, } } + + fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE + { + match self { + MemoryOrder::C => CblasNoTrans, + MemoryOrder::F => CblasTrans, + } + } + + fn to_cblas_layout(self) -> CBLAS_LAYOUT + { + match self { + MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor, + MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor, + } + } } #[cfg(feature = "blas")] From 844cfcb601f61b7fea35c0611149f85d3aef32f9 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Thu, 8 Aug 2024 19:32:59 +0200 Subject: [PATCH 14/48] blas: Test that matrix multiply calls BLAS Add a crate with a mock blas implementation, so that we can assert that cblas_sgemm etc are called (depending on memory layout). --- Cargo.toml | 5 +- crates/blas-mock-tests/Cargo.toml | 18 ++++ crates/blas-mock-tests/src/lib.rs | 100 +++++++++++++++++++++++ crates/blas-mock-tests/tests/use-blas.rs | 88 ++++++++++++++++++++ scripts/all-tests.sh | 1 + 5 files changed, 210 insertions(+), 2 deletions(-) create mode 100644 crates/blas-mock-tests/Cargo.toml create mode 100644 crates/blas-mock-tests/src/lib.rs create mode 100644 crates/blas-mock-tests/tests/use-blas.rs diff --git a/Cargo.toml b/Cargo.toml index 34d298dda..50faacf19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ approx = { workspace = true, optional = true } rayon = { version = "1.10.0", optional = true } # Use via the `blas` crate feature -cblas-sys = { version = "0.1.4", optional = true, default-features = false } +cblas-sys = { workspace = true, optional = true } libc = { version = "0.2.82", optional = true } matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm"] } @@ -90,7 +90,7 @@ default-members = [ "crates/ndarray-gen", "crates/numeric-tests", "crates/serialization-tests", - # exclude blas-tests that depends on BLAS install + # exclude blas-tests and blas-mock-tests that activate "blas" feature ] [workspace.dependencies] @@ -106,6 +106,7 @@ quickcheck = { version = "1.0", default-features = false } rand = { version = "0.8.0", features = ["small_rng"] } rand_distr = { version = "0.4.0" } itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } +cblas-sys = { version = "0.1.4", default-features = false } [profile.bench] debug = true diff --git a/crates/blas-mock-tests/Cargo.toml b/crates/blas-mock-tests/Cargo.toml new file mode 100644 index 000000000..a12b78580 --- /dev/null +++ b/crates/blas-mock-tests/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "blas-mock-tests" +version = "0.1.0" +edition = "2018" +publish = false + +[lib] +test = false +doc = false +doctest = false + +[dependencies] +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } +cblas-sys = { workspace = true } + +[dev-dependencies] +itertools = { workspace = true } diff --git a/crates/blas-mock-tests/src/lib.rs b/crates/blas-mock-tests/src/lib.rs new file mode 100644 index 000000000..11fc5975e --- /dev/null +++ b/crates/blas-mock-tests/src/lib.rs @@ -0,0 +1,100 @@ +//! Mock interfaces to BLAS + +use core::cell::RefCell; +use core::ffi::{c_double, c_float, c_int}; +use std::thread_local; + +use cblas_sys::{c_double_complex, c_float_complex, CBLAS_LAYOUT, CBLAS_TRANSPOSE}; + +thread_local! { + /// This counter is incremented every time a gemm function is called + pub static CALL_COUNT: RefCell = RefCell::new(0); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_sgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_float, + a: *const c_float, + lda: c_int, + b: *const c_float, + ldb: c_int, + beta: c_float, + c: *mut c_float, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_dgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: c_double, + a: *const c_double, + lda: c_int, + b: *const c_double, + ldb: c_int, + beta: c_double, + c: *mut c_double, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_cgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_float_complex, + a: *const c_float_complex, + lda: c_int, + b: *const c_float_complex, + ldb: c_int, + beta: *const c_float_complex, + c: *mut c_float_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} + +#[rustfmt::skip] +#[no_mangle] +#[allow(unused)] +pub unsafe extern "C" fn cblas_zgemm( + layout: CBLAS_LAYOUT, + transa: CBLAS_TRANSPOSE, + transb: CBLAS_TRANSPOSE, + m: c_int, + n: c_int, + k: c_int, + alpha: *const c_double_complex, + a: *const c_double_complex, + lda: c_int, + b: *const c_double_complex, + ldb: c_int, + beta: *const c_double_complex, + c: *mut c_double_complex, + ldc: c_int +) { + CALL_COUNT.with(|ctx| *ctx.borrow_mut() += 1); +} diff --git a/crates/blas-mock-tests/tests/use-blas.rs b/crates/blas-mock-tests/tests/use-blas.rs new file mode 100644 index 000000000..217508af6 --- /dev/null +++ b/crates/blas-mock-tests/tests/use-blas.rs @@ -0,0 +1,88 @@ +extern crate ndarray; + +use ndarray::prelude::*; + +use blas_mock_tests::CALL_COUNT; +use ndarray::linalg::general_mat_mul; +use ndarray::Order; +use ndarray_gen::array_builder::ArrayBuilder; + +use itertools::iproduct; + +#[test] +fn test_gen_mat_mul_uses_blas() +{ + let alpha = 1.0; + let beta = 0.0; + + let sizes = vec![ + (8, 8, 8), + (10, 10, 10), + (8, 8, 1), + (1, 10, 10), + (10, 1, 10), + (10, 10, 1), + (1, 10, 1), + (10, 1, 1), + (1, 1, 10), + (4, 17, 3), + (17, 3, 22), + (19, 18, 2), + (16, 17, 15), + (15, 16, 17), + (67, 63, 62), + ]; + let strides = &[1, 2, -1, -2]; + let cf_order = [Order::C, Order::F]; + + // test different strides and memory orders + for &(m, k, n) in &sizes { + for (&s1, &s2) in iproduct!(strides, strides) { + for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { + println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); + + let a = ArrayBuilder::new((m, k)).memory_order(ord1).build(); + let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); + let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); + + { + let av; + let bv; + let mut cv; + + if s1 != 1 || s2 != 1 { + av = a.slice(s![..;s1, ..;s2]); + bv = b.slice(s![..;s2, ..;s2]); + cv = c.slice_mut(s![..;s1, ..;s2]); + } else { + // different stride cases for slicing versus not sliced (for axes of + // len=1); so test not sliced here. + av = a.view(); + bv = b.view(); + cv = c.view_mut(); + } + + let pre_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + general_mat_mul(alpha, &av, &bv, beta, &mut cv); + let after_count = CALL_COUNT.with(|ctx| *ctx.borrow()); + let ncalls = after_count - pre_count; + debug_assert!(ncalls <= 1); + + let always_uses_blas = s1 == 1 && s2 == 1; + + if always_uses_blas { + assert_eq!(ncalls, 1, "Contiguous arrays should use blas, orders={:?}", (ord1, ord2, ord3)); + } + + let should_use_blas = av.strides().iter().all(|&s| s > 0) + && bv.strides().iter().all(|&s| s > 0) + && cv.strides().iter().all(|&s| s > 0) + && av.strides().iter().any(|&s| s == 1) + && bv.strides().iter().any(|&s| s == 1) + && cv.strides().iter().any(|&s| s == 1); + assert_eq!(should_use_blas, ncalls > 0); + } + } + } + } +} diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index b8c9b5849..4ececbcbd 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -20,6 +20,7 @@ cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FE # BLAS tests cargo test -p ndarray --lib -v --features blas +cargo test -p blas-mock-tests -v cargo test -p blas-tests -v --features blas-tests/openblas-system cargo test -p numeric-tests -v --features numeric-tests/test_blas From 700b4ddaae1b97551781f7fb8924cef2f2eb50db Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 15/48] scripts: Fix off by one in makechangelog Use git's tformat to get a newline on the last entry, and then we include the last commit hash in the listing too. --- scripts/makechangelog.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/makechangelog.sh b/scripts/makechangelog.sh index 8bb6f2c2f..535280804 100755 --- a/scripts/makechangelog.sh +++ b/scripts/makechangelog.sh @@ -8,7 +8,7 @@ # Will produce some duplicates for PRs integrated using rebase, # but those will not occur with current merge queue. -git log --first-parent --pretty="format:%H" "$@" | while read commit_sha +git log --first-parent --pretty="tformat:%H" "$@" | while IFS= read -r commit_sha do gh api "/repos/:owner/:repo/commits/${commit_sha}/pulls" \ -q ".[] | \"- \(.title) by [@\(.user.login)](\(.user.html_url)) [#\(.number)](\(.html_url))\"" From 05789c1e6d2b7a61c24e8cbfff593a88aa688869 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 16/48] Use resolver=2 in the workspace It separates feature selection for different dependency classes (dev, build, regular), which makes sense even if it seems to have no impact at the moment. --- Cargo.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Cargo.toml b/Cargo.toml index 50faacf19..77f8a01b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ keywords = ["array", "data-structure", "multidimensional", "matrix", "blas"] categories = ["data-structures", "science"] exclude = ["docgen/images/*"] +resolver = "2" [lib] name = "ndarray" From 7226d3983f1bce5ad2149c315b96089d07ef467f Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 17/48] blas: Run blas-mock-tests in cross compiler tests --- crates/blas-mock-tests/Cargo.toml | 4 ++-- scripts/cross-tests.sh | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/blas-mock-tests/Cargo.toml b/crates/blas-mock-tests/Cargo.toml index a12b78580..39ef9cf99 100644 --- a/crates/blas-mock-tests/Cargo.toml +++ b/crates/blas-mock-tests/Cargo.toml @@ -10,9 +10,9 @@ doc = false doctest = false [dependencies] -ndarray = { workspace = true, features = ["approx", "blas"] } -ndarray-gen = { workspace = true } cblas-sys = { workspace = true } [dev-dependencies] +ndarray = { workspace = true, features = ["approx", "blas"] } +ndarray-gen = { workspace = true } itertools = { workspace = true } diff --git a/scripts/cross-tests.sh b/scripts/cross-tests.sh index 683a901d8..80b37c339 100755 --- a/scripts/cross-tests.sh +++ b/scripts/cross-tests.sh @@ -11,3 +11,4 @@ QC_FEAT=--features=ndarray-rand/quickcheck cross build -v --features="$FEATURES" $QC_FEAT --target=$TARGET cross test -v --no-fail-fast --features="$FEATURES" $QC_FEAT --target=$TARGET +cross test -v -p blas-mock-tests From 453eae38a4ca63b1cb6c6d3a435a10e69b97974c Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 18/48] blas: Refactor and simplify gemm call further Further clarify transpose logic by putting it into BlasOrder methods. --- src/linalg/impl_linalg.rs | 124 ++++++++++++++++++-------------------- 1 file changed, 57 insertions(+), 67 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 243dc783b..965cefc4d 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -28,7 +28,7 @@ use libc::c_int; #[cfg(feature = "blas")] use cblas_sys as blas_sys; #[cfg(feature = "blas")] -use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT}; +use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE}; /// len of vector before we use blas #[cfg(feature = "blas")] @@ -400,40 +400,33 @@ fn mat_mul_impl( // Compute A B -> C // We require for BLAS compatibility that: // A, B, C are contiguous (stride=1) in their fastest dimension, - // but it can be either first or second axis (either rowmajor/"c" or colmajor/"f"). + // but they can be either row major/"c" or col major/"f". // // The "normal case" is CblasRowMajor for cblas. - // Select CblasRowMajor, CblasColMajor to fit C's memory order. + // Select CblasRowMajor / CblasColMajor to fit C's memory order. // - // Apply transpose to A, B as needed if they differ from the normal case. + // Apply transpose to A, B as needed if they differ from the row major case. // If C is CblasColMajor then transpose both A, B (again!) - let (a_layout, a_axis, b_layout, b_axis, c_layout) = - match (get_blas_compatible_layout(a), - get_blas_compatible_layout(b), - get_blas_compatible_layout(c)) + let (a_layout, b_layout, c_layout) = + if let (Some(a_layout), Some(b_layout), Some(c_layout)) = + (get_blas_compatible_layout(a), + get_blas_compatible_layout(b), + get_blas_compatible_layout(c)) { - (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => { - (a_layout, a_layout.lead_axis(), - b_layout, b_layout.lead_axis(), c_layout) - }, - (Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => { - // CblasColMajor is the "other case" - // Mark a, b as having layouts opposite of what they were detected as, which - // ends up with the correct transpose setting w.r.t col major - (a_layout.opposite(), a_layout.lead_axis(), - b_layout.opposite(), b_layout.lead_axis(), c_layout) - }, - _ => break 'blas_block, + (a_layout, b_layout, c_layout) + } else { + break 'blas_block; }; - let a_trans = a_layout.to_cblas_transpose(); - let lda = blas_stride(&a, a_axis); + let cblas_layout = c_layout.to_cblas_layout(); + let a_trans = a_layout.to_cblas_transpose_for(cblas_layout); + let lda = blas_stride(&a, a_layout); - let b_trans = b_layout.to_cblas_transpose(); - let ldb = blas_stride(&b, b_axis); + let b_trans = b_layout.to_cblas_transpose_for(cblas_layout); + let ldb = blas_stride(&b, b_layout); - let ldc = blas_stride(&c, c_layout.lead_axis()); + let ldc = blas_stride(&c, c_layout); macro_rules! gemm_scalar_cast { (f32, $var:ident) => { @@ -457,7 +450,7 @@ fn mat_mul_impl( // Where Op is notrans/trans/conjtrans unsafe { blas_sys::$gemm( - c_layout.to_cblas_layout(), + cblas_layout, a_trans, b_trans, m as blas_index, // m, rows of Op(a) @@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl( // may be arbitrary. let a_trans = CblasNoTrans; - let (a_stride, cblas_layout) = match layout { - MemoryOrder::C => { - (a.strides()[0].max(k as isize) as blas_index, - CBLAS_LAYOUT::CblasRowMajor) - } - MemoryOrder::F => { - (a.strides()[1].max(m as isize) as blas_index, - CBLAS_LAYOUT::CblasColMajor) - } - }; + let a_stride = blas_stride(&a, layout); + let cblas_layout = layout.to_cblas_layout(); // Low addr in memory pointers required for x, y let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); @@ -835,61 +820,66 @@ where #[cfg(feature = "blas")] #[derive(Copy, Clone)] #[cfg_attr(test, derive(PartialEq, Eq, Debug))] -enum MemoryOrder +enum BlasOrder { C, F, } #[cfg(feature = "blas")] -impl MemoryOrder +impl BlasOrder { - #[inline] - /// Axis of leading stride (opposite of contiguous axis) - fn lead_axis(self) -> usize + fn transpose(self) -> Self { match self { - MemoryOrder::C => 0, - MemoryOrder::F => 1, + Self::C => Self::F, + Self::F => Self::C, } } - /// Get opposite memory order #[inline] - fn opposite(self) -> Self + /// Axis of leading stride (opposite of contiguous axis) + fn get_blas_lead_axis(self) -> usize { match self { - MemoryOrder::C => MemoryOrder::F, - MemoryOrder::F => MemoryOrder::C, + Self::C => 0, + Self::F => 1, } } - fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE + fn to_cblas_layout(self) -> CBLAS_LAYOUT { match self { - MemoryOrder::C => CblasNoTrans, - MemoryOrder::F => CblasTrans, + Self::C => CBLAS_LAYOUT::CblasRowMajor, + Self::F => CBLAS_LAYOUT::CblasColMajor, } } - fn to_cblas_layout(self) -> CBLAS_LAYOUT + /// When using cblas_sgemm (etc) with C matrix using `for_layout`, + /// how should this `self` matrix be transposed + fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE { - match self { - MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor, - MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor, + let effective_order = match for_layout { + CBLAS_LAYOUT::CblasRowMajor => self, + CBLAS_LAYOUT::CblasColMajor => self.transpose(), + }; + + match effective_order { + Self::C => CblasNoTrans, + Self::F => CblasTrans, } } } #[cfg(feature = "blas")] -fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool +fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool { let (m, n) = dim.into_pattern(); let s0 = stride[0] as isize; let s1 = stride[1] as isize; let (inner_stride, outer_stride, inner_dim, outer_dim) = match order { - MemoryOrder::C => (s1, s0, m, n), - MemoryOrder::F => (s0, s1, n, m), + BlasOrder::C => (s1, s0, m, n), + BlasOrder::F => (s0, s1, n, m), }; if !(inner_stride == 1 || outer_dim == 1) { @@ -920,13 +910,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] -fn get_blas_compatible_layout(a: &ArrayBase) -> Option +fn get_blas_compatible_layout(a: &ArrayBase) -> Option where S: Data { - if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) { - Some(MemoryOrder::C) - } else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) { - Some(MemoryOrder::F) + if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { + Some(BlasOrder::C) + } else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) { + Some(BlasOrder::F) } else { None } @@ -937,10 +927,10 @@ where S: Data /// /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_stride(a: &ArrayBase, axis: usize) -> blas_index +fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index where S: Data { - debug_assert!(axis <= 1); + let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; let len_this = a.shape()[axis]; let len_other = a.shape()[other_axis]; @@ -968,7 +958,7 @@ where if !same_type::() { return false; } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) + is_blas_2d(&a.dim, &a.strides, BlasOrder::C) } #[cfg(test)] @@ -982,7 +972,7 @@ where if !same_type::() { return false; } - is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) + is_blas_2d(&a.dim, &a.strides, BlasOrder::F) } #[cfg(test)] @@ -1096,7 +1086,7 @@ mod blas_tests if stride < N { assert_eq!(get_blas_compatible_layout(&m), None); } else { - assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C)); + assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C)); } } } From 0153a37c5a832a99d694131464803617540c841d Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 19/48] blas: Simplify control flow in matrix multiply --- src/linalg/impl_linalg.rs | 148 ++++++++++++++++---------------------- 1 file changed, 62 insertions(+), 86 deletions(-) diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 965cefc4d..7472d8292 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -371,32 +371,15 @@ where #[cfg(not(feature = "blas"))] use self::mat_mul_general as mat_mul_impl; -#[rustfmt::skip] #[cfg(feature = "blas")] -fn mat_mul_impl( - alpha: A, - a: &ArrayView2<'_, A>, - b: &ArrayView2<'_, A>, - beta: A, - c: &mut ArrayViewMut2<'_, A>, -) where - A: LinalgScalar, +fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) +where A: LinalgScalar { - // size cutoff for using BLAS - let cut = GEMM_BLAS_CUTOFF; let ((m, k), (k2, n)) = (a.dim(), b.dim()); debug_assert_eq!(k, k2); - if !(m > cut || n > cut || k > cut) - || !(same_type::() - || same_type::() - || same_type::() - || same_type::()) + if (m > GEMM_BLAS_CUTOFF || n > GEMM_BLAS_CUTOFF || k > GEMM_BLAS_CUTOFF) + && (same_type::() || same_type::() || same_type::() || same_type::()) { - return mat_mul_general(alpha, a, b, beta, c); - } - - #[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block - 'blas_block: loop { // Compute A B -> C // We require for BLAS compatibility that: // A, B, C are contiguous (stride=1) in their fastest dimension, @@ -408,75 +391,68 @@ fn mat_mul_impl( // Apply transpose to A, B as needed if they differ from the row major case. // If C is CblasColMajor then transpose both A, B (again!) - let (a_layout, b_layout, c_layout) = - if let (Some(a_layout), Some(b_layout), Some(c_layout)) = - (get_blas_compatible_layout(a), - get_blas_compatible_layout(b), - get_blas_compatible_layout(c)) - { - (a_layout, b_layout, c_layout) - } else { - break 'blas_block; - }; - - let cblas_layout = c_layout.to_cblas_layout(); - let a_trans = a_layout.to_cblas_transpose_for(cblas_layout); - let lda = blas_stride(&a, a_layout); - - let b_trans = b_layout.to_cblas_transpose_for(cblas_layout); - let ldb = blas_stride(&b, b_layout); - - let ldc = blas_stride(&c, c_layout); - - macro_rules! gemm_scalar_cast { - (f32, $var:ident) => { - cast_as(&$var) - }; - (f64, $var:ident) => { - cast_as(&$var) - }; - (c32, $var:ident) => { - &$var as *const A as *const _ - }; - (c64, $var:ident) => { - &$var as *const A as *const _ - }; - } + if let (Some(a_layout), Some(b_layout), Some(c_layout)) = + (get_blas_compatible_layout(a), get_blas_compatible_layout(b), get_blas_compatible_layout(c)) + { + let cblas_layout = c_layout.to_cblas_layout(); + let a_trans = a_layout.to_cblas_transpose_for(cblas_layout); + let lda = blas_stride(&a, a_layout); + + let b_trans = b_layout.to_cblas_transpose_for(cblas_layout); + let ldb = blas_stride(&b, b_layout); + + let ldc = blas_stride(&c, c_layout); + + macro_rules! gemm_scalar_cast { + (f32, $var:ident) => { + cast_as(&$var) + }; + (f64, $var:ident) => { + cast_as(&$var) + }; + (c32, $var:ident) => { + &$var as *const A as *const _ + }; + (c64, $var:ident) => { + &$var as *const A as *const _ + }; + } - macro_rules! gemm { - ($ty:tt, $gemm:ident) => { - if same_type::() { - // gemm is C ← αA^Op B^Op + βC - // Where Op is notrans/trans/conjtrans - unsafe { - blas_sys::$gemm( - cblas_layout, - a_trans, - b_trans, - m as blas_index, // m, rows of Op(a) - n as blas_index, // n, cols of Op(b) - k as blas_index, // k, cols of Op(a) - gemm_scalar_cast!($ty, alpha), // alpha - a.ptr.as_ptr() as *const _, // a - lda, // lda - b.ptr.as_ptr() as *const _, // b - ldb, // ldb - gemm_scalar_cast!($ty, beta), // beta - c.ptr.as_ptr() as *mut _, // c - ldc, // ldc - ); + macro_rules! gemm { + ($ty:tt, $gemm:ident) => { + if same_type::() { + // gemm is C ← αA^Op B^Op + βC + // Where Op is notrans/trans/conjtrans + unsafe { + blas_sys::$gemm( + cblas_layout, + a_trans, + b_trans, + m as blas_index, // m, rows of Op(a) + n as blas_index, // n, cols of Op(b) + k as blas_index, // k, cols of Op(a) + gemm_scalar_cast!($ty, alpha), // alpha + a.ptr.as_ptr() as *const _, // a + lda, // lda + b.ptr.as_ptr() as *const _, // b + ldb, // ldb + gemm_scalar_cast!($ty, beta), // beta + c.ptr.as_ptr() as *mut _, // c + ldc, // ldc + ); + } + return; } - return; - } - }; - } + }; + } - gemm!(f32, cblas_sgemm); - gemm!(f64, cblas_dgemm); - gemm!(c32, cblas_cgemm); - gemm!(c64, cblas_zgemm); + gemm!(f32, cblas_sgemm); + gemm!(f64, cblas_dgemm); + gemm!(c32, cblas_cgemm); + gemm!(c64, cblas_zgemm); - break 'blas_block; + unreachable!() // we checked above that A is one of f32, f64, c32, c64 + } } mat_mul_general(alpha, a, b, beta, c) } From 876ad012f048bd2063746499102d045e11fdb8e5 Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Fri, 9 Aug 2024 21:46:25 +0200 Subject: [PATCH 20/48] blas: test with more than one pattern in data Implement a checkerboard pattern in input data just to test with some another kind of input. --- crates/blas-tests/tests/oper.rs | 16 ++++++++++------ crates/ndarray-gen/src/array_builder.rs | 17 ++++++++--------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index f1e1bc42b..a9dca7e83 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -12,6 +12,7 @@ use ndarray::linalg::general_mat_vec_mul; use ndarray::Order; use ndarray::{Data, Ix, LinalgScalar}; use ndarray_gen::array_builder::ArrayBuilder; +use ndarray_gen::array_builder::ElementGenerator; use approx::assert_relative_eq; use defmac::defmac; @@ -230,7 +231,6 @@ fn gen_mat_mul() let sizes = vec![ (4, 4, 4), (8, 8, 8), - (10, 10, 10), (8, 8, 1), (1, 10, 10), (10, 1, 10), @@ -241,19 +241,23 @@ fn gen_mat_mul() (4, 17, 3), (17, 3, 22), (19, 18, 2), - (16, 17, 15), (15, 16, 17), - (67, 63, 62), + (67, 50, 62), ]; let strides = &[1, 2, -1, -2]; let cf_order = [Order::C, Order::F]; + let generator = [ElementGenerator::Sequential, ElementGenerator::Checkerboard]; // test different strides and memory orders - for (&s1, &s2) in iproduct!(strides, strides) { + for (&s1, &s2, &gen) in iproduct!(strides, strides, &generator) { for &(m, k, n) in &sizes { for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) { - println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3); - let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5; + println!("Case s1={}, s2={}, gen={:?}, orders={:?}, {:?}, {:?}", s1, s2, gen, ord1, ord2, ord3); + let a = ArrayBuilder::new((m, k)) + .memory_order(ord1) + .generator(gen) + .build() + * 0.5; let b = ArrayBuilder::new((k, n)).memory_order(ord2).build(); let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build(); diff --git a/crates/ndarray-gen/src/array_builder.rs b/crates/ndarray-gen/src/array_builder.rs index a021e5252..9351aadc5 100644 --- a/crates/ndarray-gen/src/array_builder.rs +++ b/crates/ndarray-gen/src/array_builder.rs @@ -26,6 +26,7 @@ pub struct ArrayBuilder pub enum ElementGenerator { Sequential, + Checkerboard, Zero, } @@ -64,16 +65,14 @@ where D: Dimension pub fn build(self) -> Array where T: Num + Clone { - let mut current = T::zero(); + let zero = T::zero(); let size = self.dim.size(); - let use_zeros = self.generator == ElementGenerator::Zero; - Array::from_iter((0..size).map(|_| { - let ret = current.clone(); - if !use_zeros { - current = ret.clone() + T::one(); - } - ret - })) + (match self.generator { + ElementGenerator::Sequential => + Array::from_iter(core::iter::successors(Some(zero), |elt| Some(elt.clone() + T::one())).take(size)), + ElementGenerator::Checkerboard => Array::from_iter([T::one(), zero].iter().cycle().take(size).cloned()), + ElementGenerator::Zero => Array::zeros(size), + }) .into_shape_with_order((self.dim, self.memory_order)) .unwrap() } From 1df6c32d73d148df1e9477e0f1d7a45dad4c4de8 Mon Sep 17 00:00:00 2001 From: akern40 Date: Sun, 11 Aug 2024 11:15:36 -0400 Subject: [PATCH 21/48] Fix infinite recursion, overflow, and off-by-one error in triu/tril (#1418) * Fixes infinite recursion and off-by-one error * Avoids overflow using saturating arithmetic * Removes unused import * Fixes bug for isize::MAX for triu * Fix formatting * Uses broadcast indices to remove D::Smaller: Copy trait bound --- src/tri.rs | 183 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 129 insertions(+), 54 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 4eab9e105..b7d297fcc 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -6,18 +6,25 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use core::cmp::{max, min}; +use core::cmp::min; use num_traits::Zero; -use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip}; +use crate::{ + dimension::{is_layout_c, is_layout_f}, + Array, + ArrayBase, + Axis, + Data, + Dimension, + Zip, +}; impl ArrayBase where S: Data, D: Dimension, A: Clone + Zero, - D::Smaller: Copy, { /// Upper triangular of an array. /// @@ -30,38 +37,56 @@ where /// ``` /// use ndarray::array; /// - /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; - /// let res = arr.triu(0); - /// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.triu(0), + /// array![ + /// [1, 2, 3], + /// [0, 5, 6], + /// [0, 0, 9] + /// ] + /// ); /// ``` pub fn triu(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } - match is_layout_f(&self.dim, &self.strides) { - true => { - let n = self.ndim(); - let mut x = self.view(); - x.swap_axes(n - 2, n - 1); - let mut tril = x.tril(-k); - tril.swap_axes(n - 2, n - 1); - - tril - } - false => { - let mut res = Array::zeros(self.raw_dim()); - Zip::indexed(self.rows()) - .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); - let lower = max(row_num as isize + k, 0); - dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); - }); - - res - } + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. + // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.tril(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; } + + let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(n - 1)); + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) + .and(res.rows_mut()) + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { + let mut lower = match k >= 0 { + true => row_num.saturating_add(k as usize), // Avoid overflow + false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 + }; + lower = min(lower, ncols); + dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + }); + + res } /// Lower triangular of an array. @@ -75,45 +100,65 @@ where /// ``` /// use ndarray::array; /// - /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; - /// let res = arr.tril(0); - /// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.tril(0), + /// array![ + /// [1, 0, 0], + /// [4, 5, 0], + /// [7, 8, 9] + /// ] + /// ); /// ``` pub fn tril(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } - match is_layout_f(&self.dim, &self.strides) { - true => { - let n = self.ndim(); - let mut x = self.view(); - x.swap_axes(n - 2, n - 1); - let mut tril = x.triu(-k); - tril.swap_axes(n - 2, n - 1); - - tril - } - false => { - let mut res = Array::zeros(self.raw_dim()); - let ncols = self.len_of(Axis(self.ndim() - 1)) as isize; - Zip::indexed(self.rows()) - .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); - let upper = min(row_num as isize + k, ncols) + 1; - dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); - }); - - res - } + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. + // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.triu(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; } + + let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(n - 1)); + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) + .and(res.rows_mut()) + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { + // let row_num = i.into_dimension().last_elem(); + let mut upper = match k >= 0 { + true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow + false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow + }; + upper = min(upper, ncols); + dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + }); + + res } } #[cfg(test)] mod tests { + use core::isize; + use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; use alloc::vec; @@ -188,6 +233,19 @@ mod tests assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); } + #[test] + fn test_2d_single() + { + let x = array![[1]]; + + assert_eq!(x.triu(0), array![[1]]); + assert_eq!(x.tril(0), array![[1]]); + assert_eq!(x.triu(1), array![[0]]); + assert_eq!(x.tril(1), array![[1]]); + assert_eq!(x.triu(-1), array![[1]]); + assert_eq!(x.tril(-1), array![[0]]); + } + #[test] fn test_3d() { @@ -285,8 +343,25 @@ mod tests let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]); + let x = array![[1, 2], [3, 4], [5, 6]]; let res = x.triu(0); assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]); + + let res = x.tril(0); + assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]); + } + + #[test] + fn test_odd_k() + { + let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + let z = Array2::zeros([3, 3]); + assert_eq!(x.triu(isize::MIN), x); + assert_eq!(x.tril(isize::MIN), z); + assert_eq!(x.triu(isize::MAX), z); + assert_eq!(x.tril(isize::MAX), x); } } From 6f77377d7d508550bf516e54c142cee3ab243aeb Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 14 Aug 2024 19:29:05 +0200 Subject: [PATCH 22/48] 0.16.1 --- Cargo.toml | 2 +- RELEASES.md | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 77f8a01b7..5c7217025 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "ndarray" -version = "0.16.0" +version = "0.16.1" edition = "2018" rust-version = "1.64" authors = [ diff --git a/RELEASES.md b/RELEASES.md index 04c7c9250..8b4786666 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -1,3 +1,12 @@ +Version 0.16.1 (2024-08-14) +=========================== + +- Refactor and simplify BLAS gemm call further by [@bluss](https://github.com/bluss) [#1421](https://github.com/rust-ndarray/ndarray/pull/1421) +- Fix infinite recursion and off-by-one error in triu/tril by [@akern40](https://github.com/akern40) [#1418](https://github.com/rust-ndarray/ndarray/pull/1418) +- Fix using BLAS for all compatible cases of memory layout by [@bluss](https://github.com/bluss) [#1419](https://github.com/rust-ndarray/ndarray/pull/1419) +- Use PR check instead of Merge Queue, and check rustdoc by [@bluss](https://github.com/bluss) [#1420](https://github.com/rust-ndarray/ndarray/pull/1420) +- Make iterators covariant in element type by [@bluss](https://github.com/bluss) [#1417](https://github.com/rust-ndarray/ndarray/pull/1417) + Version 0.16.0 (2024-08-03) =========================== From 1304f9d1bfa26f9c85da2c2ac192435fa9fed16b Mon Sep 17 00:00:00 2001 From: Ulrik Sverdrup Date: Wed, 21 Aug 2024 17:36:56 +0200 Subject: [PATCH 23/48] Fix uniqueness in last_mut() Last mut did not ensure the array was unique before calling uget_mut. The required properties were protected by a debug assertion, but a clear bug in release mode. Adding tests that would have caught this. --- src/impl_methods.rs | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 589a5c83c..f506204b6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -406,6 +406,7 @@ where if self.is_empty() { None } else { + self.ensure_unique(); let mut index = self.raw_dim(); for ax in 0..index.ndim() { index[ax] -= 1; @@ -3081,6 +3082,7 @@ mod tests { use super::*; use crate::arr3; + use defmac::defmac; #[test] fn test_flatten() @@ -3107,4 +3109,45 @@ mod tests let flattened = array.into_flat(); assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); } + + #[test] + fn test_first_last() + { + let first = 2; + let last = 3; + + defmac!(assert_first mut array => { + assert_eq!(array.first().copied(), Some(first)); + assert_eq!(array.first_mut().copied(), Some(first)); + }); + defmac!(assert_last mut array => { + assert_eq!(array.last().copied(), Some(last)); + assert_eq!(array.last_mut().copied(), Some(last)); + }); + + let base = Array::from_vec(vec![first, last]); + let a = base.clone(); + assert_first!(a); + + let a = base.clone(); + assert_last!(a); + + let a = CowArray::from(base.view()); + assert_first!(a); + let a = CowArray::from(base.view()); + assert_last!(a); + + let a = CowArray::from(base.clone()); + assert_first!(a); + let a = CowArray::from(base.clone()); + assert_last!(a); + + let a = ArcArray::from(base.clone()); + let _a2 = a.clone(); + assert_last!(a); + + let a = ArcArray::from(base.clone()); + let _a2 = a.clone(); + assert_first!(a); + } } From 7843a3bc3b1f00c4804346c6b637754c26bac1cf Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Wed, 21 Aug 2024 23:12:19 -0400 Subject: [PATCH 24/48] Adds vscode editor settings to gitignore Necessary for allowing format-on-save to use nightly for this repo only. --- .gitignore | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.gitignore b/.gitignore index 1e7caa9ea..dd9ffd9fe 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,6 @@ +# Rust items Cargo.lock target/ + +# Editor settings +.vscode From 5dc62e60f4477592f22bca4e9cfc166602a91051 Mon Sep 17 00:00:00 2001 From: benliepert Date: Fri, 6 Sep 2024 08:45:37 -0400 Subject: [PATCH 25/48] Tweak documentation for into_raw_vec_and_offset (#1432) --- src/impl_owned_array.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index 44ac12dd4..bb970f876 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -79,7 +79,7 @@ where D: Dimension /// Return a vector of the elements in the array, in the way they are /// stored internally, and the index in the vector corresponding to the - /// logically first element of the array (or 0 if the array is empty). + /// logically first element of the array (or None if the array is empty). /// /// If the array is in standard memory layout, the logical element order /// of the array (`.iter()` order) and of the returned vector will be the same. From c7ebd35b79e82074977da2c67c2a25cba8ae6bcc Mon Sep 17 00:00:00 2001 From: Philip Trauth Date: Mon, 16 Sep 2024 16:52:45 +0200 Subject: [PATCH 26/48] Removed typo Just removed double and --- src/impl_methods.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index f506204b6..4a00ea000 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -1030,7 +1030,7 @@ where } /// Along `axis`, select arbitrary subviews corresponding to `indices` - /// and and copy them into a new array. + /// and copy them into a new array. /// /// **Panics** if `axis` or an element of `indices` is out of bounds. /// From fce60345ddbf576d6e5cd8c4b0c8104d00c9e326 Mon Sep 17 00:00:00 2001 From: Johann Carl Meyer <32302462+johann-cm@users.noreply.github.com> Date: Thu, 26 Sep 2024 15:10:33 +0200 Subject: [PATCH 27/48] Add `diff` method as an equivalent to `numpy.diff` (#1437) * implement forward finite differneces on arrays * implement tests for the method * remove some heap allocations --- src/numeric/impl_numeric.rs | 58 ++++++++++++++++++++++++++++++- tests/numeric.rs | 68 +++++++++++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 1 deletion(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 7306fc727..6c67b9135 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -10,10 +10,11 @@ use num_traits::Float; use num_traits::One; use num_traits::{FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul}; +use std::ops::{Add, Div, Mul, Sub}; use crate::imp_prelude::*; use crate::numeric_util; +use crate::Slice; /// # Numerical Methods for Arrays impl ArrayBase @@ -437,4 +438,59 @@ where { self.var_axis(axis, ddof).mapv_into(|x| x.sqrt()) } + + /// Calculates the (forward) finite differences of order `n`, along the `axis`. + /// For the 1D-case, `n==1`, this means: `diff[i] == arr[i+1] - arr[i]` + /// + /// For `n>=2`, the process is iterated: + /// ``` + /// use ndarray::{array, Axis}; + /// let arr = array![1.0, 2.0, 5.0]; + /// assert_eq!(arr.diff(2, Axis(0)), arr.diff(1, Axis(0)).diff(1, Axis(0))) + /// ``` + /// **Panics** if `axis` is out of bounds + /// + /// **Panics** if `n` is too big / the array is to short: + /// ```should_panic + /// use ndarray::{array, Axis}; + /// array![1.0, 2.0, 3.0].diff(10, Axis(0)); + /// ``` + pub fn diff(&self, n: usize, axis: Axis) -> Array + where A: Sub + Zero + Clone + { + if n == 0 { + return self.to_owned(); + } + assert!(axis.0 < self.ndim(), "The array has only ndim {}, but `axis` {:?} is given.", self.ndim(), axis); + assert!( + n < self.shape()[axis.0], + "The array must have length at least `n+1`=={} in the direction of `axis`. It has length {}", + n + 1, + self.shape()[axis.0] + ); + + let mut inp = self.to_owned(); + let mut out = Array::zeros({ + let mut inp_dim = self.raw_dim(); + // inp_dim[axis.0] >= 1 as per the 2nd assertion. + inp_dim[axis.0] -= 1; + inp_dim + }); + for _ in 0..n { + let head = inp.slice_axis(axis, Slice::from(..-1)); + let tail = inp.slice_axis(axis, Slice::from(1..)); + + azip!((o in &mut out, h in head, t in tail) *o = t.clone() - h.clone()); + + // feed the output as the input to the next iteration + std::mem::swap(&mut inp, &mut out); + + // adjust the new output array width along `axis`. + // Current situation: width of `inp`: k, `out`: k+1 + // needed width: `inp`: k, `out`: k-1 + // slice is possible, since k >= 1. + out.slice_axis_inplace(axis, Slice::from(..-2)); + } + inp + } } diff --git a/tests/numeric.rs b/tests/numeric.rs index f6de146c9..2395366b0 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -336,3 +336,71 @@ fn std_axis_empty_axis() assert_eq!(v.shape(), &[2]); v.mapv(|x| assert!(x.is_nan())); } + +#[test] +fn diff_1d_order1() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + let expected = array![1.0, 2.0, 3.0]; + assert_eq!(data.diff(1, Axis(0)), expected); +} + +#[test] +fn diff_1d_order2() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + assert_eq!( + data.diff(2, Axis(0)), + data.diff(1, Axis(0)).diff(1, Axis(0)) + ); +} + +#[test] +fn diff_1d_order3() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + assert_eq!( + data.diff(3, Axis(0)), + data.diff(1, Axis(0)).diff(1, Axis(0)).diff(1, Axis(0)) + ); +} + +#[test] +fn diff_2d_order1_ax0() +{ + let data = array![ + [1.0, 2.0, 4.0, 7.0], + [1.0, 3.0, 6.0, 6.0], + [1.5, 3.5, 5.5, 5.5] + ]; + let expected = array![[0.0, 1.0, 2.0, -1.0], [0.5, 0.5, -0.5, -0.5]]; + assert_eq!(data.diff(1, Axis(0)), expected); +} + +#[test] +fn diff_2d_order1_ax1() +{ + let data = array![ + [1.0, 2.0, 4.0, 7.0], + [1.0, 3.0, 6.0, 6.0], + [1.5, 3.5, 5.5, 5.5] + ]; + let expected = array![[1.0, 2.0, 3.0], [2.0, 3.0, 0.0], [2.0, 2.0, 0.0]]; + assert_eq!(data.diff(1, Axis(1)), expected); +} + +#[test] +#[should_panic] +fn diff_panic_n_too_big() +{ + let data = array![1.0, 2.0, 4.0, 7.0]; + data.diff(10, Axis(0)); +} + +#[test] +#[should_panic] +fn diff_panic_axis_out_of_bounds() +{ + let data = array![1, 2, 4, 7]; + data.diff(1, Axis(2)); +} From f1153bfc31639398c2e013bd6c6fb885eab83d1c Mon Sep 17 00:00:00 2001 From: XXMA16 Date: Sun, 6 Oct 2024 21:47:23 +0300 Subject: [PATCH 28/48] Ignore Jetbrains IDE config folder Signed-off-by: XXMA16 --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index dd9ffd9fe..e9b5ca25b 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ target/ # Editor settings .vscode +.idea From 492b2742073cf531635d701ced4e01a827038f69 Mon Sep 17 00:00:00 2001 From: akern40 Date: Thu, 24 Oct 2024 08:26:53 -0400 Subject: [PATCH 29/48] Fixes no_std + approx combination (#1448) Fixes no_std + approx combination These two features can coexist; fixing them included: - Slightly altering tests to avoid `std` fns - Adding `feature = "std"` on some "approx" tests - Adding a line to the test script to catch this in the future --- crates/serialization-tests/tests/serialize.rs | 12 ++++++------ scripts/all-tests.sh | 2 ++ tests/azip.rs | 2 +- tests/numeric.rs | 2 ++ 4 files changed, 11 insertions(+), 7 deletions(-) diff --git a/crates/serialization-tests/tests/serialize.rs b/crates/serialization-tests/tests/serialize.rs index 6e6fb4d64..478eb20ef 100644 --- a/crates/serialization-tests/tests/serialize.rs +++ b/crates/serialization-tests/tests/serialize.rs @@ -45,13 +45,13 @@ fn serial_many_dim_serde() { // Test a sliced array. - let mut a = ArcArray::linspace(0., 31., 32) + let mut a = ArcArray::from_iter(0..32) .into_shape_with_order((2, 2, 2, 4)) .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); let serial = serde_json::to_string(&a).unwrap(); println!("Encode {:?} => {:?}", a, serial); - let res = serde_json::from_str::>(&serial); + let res = serde_json::from_str::>(&serial); println!("{:?}", res); assert_eq!(a, res.unwrap()); } @@ -160,7 +160,7 @@ fn serial_many_dim_serde_msgpack() { // Test a sliced array. - let mut a = ArcArray::linspace(0., 31., 32) + let mut a = ArcArray::from_iter(0..32) .into_shape_with_order((2, 2, 2, 4)) .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); @@ -171,7 +171,7 @@ fn serial_many_dim_serde_msgpack() .unwrap(); let mut deserializer = rmp_serde::Deserializer::new(&buf[..]); - let a_de: ArcArray = serde::Deserialize::deserialize(&mut deserializer).unwrap(); + let a_de: ArcArray = serde::Deserialize::deserialize(&mut deserializer).unwrap(); assert_eq!(a, a_de); } @@ -215,14 +215,14 @@ fn serial_many_dim_ron() { // Test a sliced array. - let mut a = ArcArray::linspace(0., 31., 32) + let mut a = ArcArray::from_iter(0..32) .into_shape_with_order((2, 2, 2, 4)) .unwrap(); a.slice_collapse(s![..;-1, .., .., ..2]); let a_s = ron_serialize(&a).unwrap(); - let a_de: ArcArray = ron_deserialize(&a_s).unwrap(); + let a_de: ArcArray = ron_deserialize(&a_s).unwrap(); assert_eq!(a, a_de); } diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index 4ececbcbd..b9af6b65a 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -13,6 +13,8 @@ cargo build -v --no-default-features # ndarray with no features cargo test -p ndarray -v --no-default-features +# ndarray with no_std-compatible features +cargo test -p ndarray -v --no-default-features --features approx # all with features cargo test -v --features "$FEATURES" $QC_FEAT # all with features and release (ignore test crates which is already optimized) diff --git a/tests/azip.rs b/tests/azip.rs index a4bb6ffac..d1ab5ba2a 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -232,7 +232,7 @@ fn test_azip3_slices() *a += b / 10.; *c = a.sin(); }); - let res = Array::linspace(0., 3.1, 32).mapv_into(f32::sin); + let res = Array::from_iter(0..32).mapv(|x| f32::sin(x as f32 / 10.)); assert_abs_diff_eq!(res, ArrayView::from(&c), epsilon = 1e-4); } diff --git a/tests/numeric.rs b/tests/numeric.rs index 2395366b0..839aba58e 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -163,6 +163,7 @@ fn std_empty_arr() #[test] #[cfg(feature = "approx")] +#[cfg(feature = "std")] fn var_axis() { use ndarray::{aview0, aview2}; @@ -222,6 +223,7 @@ fn var_axis() #[test] #[cfg(feature = "approx")] +#[cfg(feature = "std")] fn std_axis() { use ndarray::aview2; From fd3ce5d5ee8f2774d445d9b7071821a8bc8e30f2 Mon Sep 17 00:00:00 2001 From: akern40 Date: Tue, 29 Oct 2024 22:51:52 -0400 Subject: [PATCH 30/48] Adds Miri to CI/CD (#1446) This is carefully constructed to allow Miri to test most of `ndarray` without slowing down CI/CD very badly; as a result, it skips a number of slow tests. See #1446 for a list. It also excludes `blas` because Miri cannot call `cblas_gemm`, and it excludes `rayon` because it considers the still-running thread pool to be a leak. `rayon` can be re-added when rust-lang/miri#1371 is resolved. --- .github/workflows/ci.yaml | 12 ++++++++++++ ndarray-rand/tests/tests.rs | 2 ++ scripts/miri-tests.sh | 18 ++++++++++++++++++ src/dimension/mod.rs | 1 + tests/array.rs | 1 + tests/azip.rs | 2 +- tests/dimension.rs | 1 + tests/iterators.rs | 1 + tests/oper.rs | 5 +++++ 9 files changed, 42 insertions(+), 1 deletion(-) create mode 100755 scripts/miri-tests.sh diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index f36591741..0f517fac9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -88,6 +88,17 @@ jobs: run: sudo apt-get install libopenblas-dev gfortran - run: ./scripts/all-tests.sh "$FEATURES" ${{ matrix.rust }} + miri: + runs-on: ubuntu-latest + name: miri + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + with: + components: miri + - uses: Swatinem/rust-cache@v2 + - run: ./scripts/miri-tests.sh + cross_test: #if: ${{ github.event_name == 'merge_group' }} runs-on: ubuntu-latest @@ -149,6 +160,7 @@ jobs: - format # should format be required? - nostd - tests + - miri - cross_test - cargo-careful - docs diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index 2db040310..e39347c0c 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -57,6 +57,7 @@ fn oversampling_without_replacement_should_panic() } quickcheck! { + #[cfg_attr(miri, ignore)] // Takes an insufferably long time fn oversampling_with_replacement_is_fine(m: u8, n: u8) -> TestResult { let (m, n) = (m as usize, n as usize); let a = Array::random((m, n), Uniform::new(0., 2.)); @@ -86,6 +87,7 @@ quickcheck! { #[cfg(feature = "quickcheck")] quickcheck! { + #[cfg_attr(miri, ignore)] // This takes *forever* with Miri fn sampling_behaves_as_expected(m: u8, n: u8, strategy: SamplingStrategy) -> TestResult { let (m, n) = (m as usize, n as usize); let a = Array::random((m, n), Uniform::new(0., 2.)); diff --git a/scripts/miri-tests.sh b/scripts/miri-tests.sh new file mode 100755 index 000000000..0100f3e6a --- /dev/null +++ b/scripts/miri-tests.sh @@ -0,0 +1,18 @@ +#!/bin/sh + +set -x +set -e + +# We rely on layout-dependent casts, which should be covered with #[repr(transparent)] +# This should catch if we missed that +RUSTFLAGS="-Zrandomize-layout" + +# Miri reports a stacked borrow violation deep within rayon, in a crate called crossbeam-epoch +# The crate has a PR to fix this: https://github.com/crossbeam-rs/crossbeam/pull/871 +# but using Miri's tree borrow mode may resolve it for now. +# Disabled until we can figure out a different rayon issue: https://github.com/rust-lang/miri/issues/1371 +# MIRIFLAGS="-Zmiri-tree-borrows" + +# General tests +# Note that we exclude blas feature because Miri can't do cblas_gemm +cargo miri test -v -p ndarray -p ndarray-rand --features approx,serde diff --git a/src/dimension/mod.rs b/src/dimension/mod.rs index 601f0dc43..eb07252b2 100644 --- a/src/dimension/mod.rs +++ b/src/dimension/mod.rs @@ -1020,6 +1020,7 @@ mod test } quickcheck! { + #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines // FIXME: This test is extremely slow, even with i16 values, investigate fn arith_seq_intersect_correct( first1: i8, len1: i8, step1: i8, diff --git a/tests/array.rs b/tests/array.rs index 696904dab..ac38fdd03 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2629,6 +2629,7 @@ mod array_cow_tests }); } + #[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] fn test_clone_from() { diff --git a/tests/azip.rs b/tests/azip.rs index d1ab5ba2a..96be9d913 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -216,7 +216,7 @@ fn test_azip2_sum() } #[test] -#[cfg(feature = "approx")] +#[cfg(all(feature = "approx", feature = "std"))] fn test_azip3_slices() { use approx::assert_abs_diff_eq; diff --git a/tests/dimension.rs b/tests/dimension.rs index 6a9207e4c..fe53d96b3 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -323,6 +323,7 @@ fn test_array_view() } #[test] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[cfg(feature = "std")] #[allow(clippy::cognitive_complexity)] fn test_all_ndindex() diff --git a/tests/iterators.rs b/tests/iterators.rs index 908b64d15..bdfd3ee50 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -971,6 +971,7 @@ fn test_into_iter_2d() assert_eq!(v, [1, 3, 2, 4]); } +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] fn test_into_iter_sliced() { diff --git a/tests/oper.rs b/tests/oper.rs index 5e3e669d0..401913e2b 100644 --- a/tests/oper.rs +++ b/tests/oper.rs @@ -502,6 +502,7 @@ fn scaled_add() } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] fn scaled_add_2() { @@ -540,6 +541,7 @@ fn scaled_add_2() } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] fn scaled_add_3() { @@ -592,6 +594,7 @@ fn scaled_add_3() } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] #[test] fn gen_mat_mul() { @@ -681,6 +684,7 @@ fn gen_mat_mul_i32() #[cfg(feature = "approx")] #[test] +#[cfg_attr(miri, ignore)] // Takes too long fn gen_mat_vec_mul() { use approx::assert_relative_eq; @@ -746,6 +750,7 @@ fn gen_mat_vec_mul() } #[cfg(feature = "approx")] +#[cfg_attr(miri, ignore)] // Very slow on CI/CD machines #[test] fn vec_mat_mul() { From 9c703ac8a7f86dce8b0b5949731b2bf364230851 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Sat, 30 Nov 2024 12:57:46 -0500 Subject: [PATCH 31/48] Fixing lifetime elisions and minor clippy complaints --- crates/ndarray-gen/src/lib.rs | 1 - examples/bounds_check_elim.rs | 2 +- ndarray-rand/tests/tests.rs | 2 +- src/argument_traits.rs | 8 +++---- src/array_serde.rs | 4 ++-- src/arraytraits.rs | 4 ++-- src/data_traits.rs | 24 ++++++++++---------- src/dimension/axes.rs | 4 ++-- src/dimension/ndindex.rs | 6 ++--- src/impl_cow.rs | 2 +- src/impl_views/constructors.rs | 4 ++-- src/impl_views/indexing.rs | 2 +- src/impl_views/splitting.rs | 2 +- src/iterators/mod.rs | 40 +++++++++++++++++----------------- src/lib.rs | 2 +- src/split_at.rs | 2 +- tests/azip.rs | 12 +++++----- 17 files changed, 60 insertions(+), 61 deletions(-) diff --git a/crates/ndarray-gen/src/lib.rs b/crates/ndarray-gen/src/lib.rs index 7f9ca89fc..09440e68d 100644 --- a/crates/ndarray-gen/src/lib.rs +++ b/crates/ndarray-gen/src/lib.rs @@ -8,5 +8,4 @@ // except according to those terms. /// Build ndarray arrays for test purposes - pub mod array_builder; diff --git a/examples/bounds_check_elim.rs b/examples/bounds_check_elim.rs index e6b57c719..f1a91cca0 100644 --- a/examples/bounds_check_elim.rs +++ b/examples/bounds_check_elim.rs @@ -57,7 +57,7 @@ pub fn test1d_single_mut(a: &mut Array1, i: usize) -> f64 #[no_mangle] pub fn test1d_len_of(a: &Array1) -> f64 { - let a = &*a; + let a = a; let mut sum = 0.; for i in 0..a.len_of(Axis(0)) { sum += a[i]; diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index e39347c0c..d38e8636e 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -122,7 +122,7 @@ fn sampling_works(a: &Array2, strategy: SamplingStrategy, axis: Axis, n_sam let samples = a.sample_axis(axis, n_samples, strategy); samples .axis_iter(axis) - .all(|lane| is_subset(&a, &lane, axis)) + .all(|lane| is_subset(a, &lane, axis)) } // Check if, when sliced along `axis`, there is at least one lane in `a` equal to `b` diff --git a/src/argument_traits.rs b/src/argument_traits.rs index de8ac7f99..c4e85186a 100644 --- a/src/argument_traits.rs +++ b/src/argument_traits.rs @@ -11,7 +11,7 @@ pub trait AssignElem } /// Assignable element, simply `*self = input`. -impl<'a, T> AssignElem for &'a mut T +impl AssignElem for &mut T { fn assign_elem(self, input: T) { @@ -20,7 +20,7 @@ impl<'a, T> AssignElem for &'a mut T } /// Assignable element, simply `self.set(input)`. -impl<'a, T> AssignElem for &'a Cell +impl AssignElem for &Cell { fn assign_elem(self, input: T) { @@ -29,7 +29,7 @@ impl<'a, T> AssignElem for &'a Cell } /// Assignable element, simply `self.set(input)`. -impl<'a, T> AssignElem for &'a MathCell +impl AssignElem for &MathCell { fn assign_elem(self, input: T) { @@ -39,7 +39,7 @@ impl<'a, T> AssignElem for &'a MathCell /// Assignable element, the item in the MaybeUninit is overwritten (prior value, if any, is not /// read or dropped). -impl<'a, T> AssignElem for &'a mut MaybeUninit +impl AssignElem for &mut MaybeUninit { fn assign_elem(self, input: T) { diff --git a/src/array_serde.rs b/src/array_serde.rs index 31b613d4c..50d9c2905 100644 --- a/src/array_serde.rs +++ b/src/array_serde.rs @@ -98,7 +98,7 @@ where // private iterator wrapper struct Sequence<'a, A, D>(Iter<'a, A, D>); -impl<'a, A, D> Serialize for Sequence<'a, A, D> +impl Serialize for Sequence<'_, A, D> where A: Serialize, D: Dimension + Serialize, @@ -162,7 +162,7 @@ impl<'de> Deserialize<'de> for ArrayField { struct ArrayFieldVisitor; - impl<'de> Visitor<'de> for ArrayFieldVisitor + impl Visitor<'_> for ArrayFieldVisitor { type Value = ArrayField; diff --git a/src/arraytraits.rs b/src/arraytraits.rs index e68b5d56a..d7a00fcfe 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -128,7 +128,7 @@ where /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. #[allow(clippy::unconditional_recursion)] // false positive -impl<'a, A, B, S, S2, D> PartialEq<&'a ArrayBase> for ArrayBase +impl PartialEq<&ArrayBase> for ArrayBase where A: PartialEq, S: Data, @@ -144,7 +144,7 @@ where /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. #[allow(clippy::unconditional_recursion)] // false positive -impl<'a, A, B, S, S2, D> PartialEq> for &'a ArrayBase +impl PartialEq> for &ArrayBase where A: PartialEq, S: Data, diff --git a/src/data_traits.rs b/src/data_traits.rs index f43bfb4ef..fc2fe4bfa 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -407,7 +407,7 @@ where A: Clone } } -unsafe impl<'a, A> RawData for ViewRepr<&'a A> +unsafe impl RawData for ViewRepr<&A> { type Elem = A; @@ -420,7 +420,7 @@ unsafe impl<'a, A> RawData for ViewRepr<&'a A> private_impl! {} } -unsafe impl<'a, A> Data for ViewRepr<&'a A> +unsafe impl Data for ViewRepr<&A> { fn into_owned(self_: ArrayBase) -> Array where @@ -437,7 +437,7 @@ unsafe impl<'a, A> Data for ViewRepr<&'a A> } } -unsafe impl<'a, A> RawDataClone for ViewRepr<&'a A> +unsafe impl RawDataClone for ViewRepr<&A> { unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) { @@ -445,7 +445,7 @@ unsafe impl<'a, A> RawDataClone for ViewRepr<&'a A> } } -unsafe impl<'a, A> RawData for ViewRepr<&'a mut A> +unsafe impl RawData for ViewRepr<&mut A> { type Elem = A; @@ -458,7 +458,7 @@ unsafe impl<'a, A> RawData for ViewRepr<&'a mut A> private_impl! {} } -unsafe impl<'a, A> RawDataMut for ViewRepr<&'a mut A> +unsafe impl RawDataMut for ViewRepr<&mut A> { #[inline] fn try_ensure_unique(_: &mut ArrayBase) @@ -475,7 +475,7 @@ unsafe impl<'a, A> RawDataMut for ViewRepr<&'a mut A> } } -unsafe impl<'a, A> Data for ViewRepr<&'a mut A> +unsafe impl Data for ViewRepr<&mut A> { fn into_owned(self_: ArrayBase) -> Array where @@ -492,7 +492,7 @@ unsafe impl<'a, A> Data for ViewRepr<&'a mut A> } } -unsafe impl<'a, A> DataMut for ViewRepr<&'a mut A> {} +unsafe impl DataMut for ViewRepr<&mut A> {} /// Array representation trait. /// @@ -533,7 +533,7 @@ pub unsafe trait DataOwned: Data pub unsafe trait DataShared: Clone + Data + RawDataClone {} unsafe impl DataShared for OwnedArcRepr {} -unsafe impl<'a, A> DataShared for ViewRepr<&'a A> {} +unsafe impl DataShared for ViewRepr<&A> {} unsafe impl DataOwned for OwnedRepr { @@ -571,7 +571,7 @@ unsafe impl DataOwned for OwnedArcRepr } } -unsafe impl<'a, A> RawData for CowRepr<'a, A> +unsafe impl RawData for CowRepr<'_, A> { type Elem = A; @@ -587,7 +587,7 @@ unsafe impl<'a, A> RawData for CowRepr<'a, A> private_impl! {} } -unsafe impl<'a, A> RawDataMut for CowRepr<'a, A> +unsafe impl RawDataMut for CowRepr<'_, A> where A: Clone { #[inline] @@ -615,7 +615,7 @@ where A: Clone } } -unsafe impl<'a, A> RawDataClone for CowRepr<'a, A> +unsafe impl RawDataClone for CowRepr<'_, A> where A: Clone { unsafe fn clone_with_ptr(&self, ptr: NonNull) -> (Self, NonNull) @@ -681,7 +681,7 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> } } -unsafe impl<'a, A> DataMut for CowRepr<'a, A> where A: Clone {} +unsafe impl DataMut for CowRepr<'_, A> where A: Clone {} unsafe impl<'a, A> DataOwned for CowRepr<'a, A> { diff --git a/src/dimension/axes.rs b/src/dimension/axes.rs index 45b7a75f0..c7aaff149 100644 --- a/src/dimension/axes.rs +++ b/src/dimension/axes.rs @@ -60,7 +60,7 @@ pub struct AxisDescription copy_and_clone!(AxisDescription); copy_and_clone!(['a, D] Axes<'a, D>); -impl<'a, D> Iterator for Axes<'a, D> +impl Iterator for Axes<'_, D> where D: Dimension { /// Description of the axis, its length and its stride. @@ -99,7 +99,7 @@ where D: Dimension } } -impl<'a, D> DoubleEndedIterator for Axes<'a, D> +impl DoubleEndedIterator for Axes<'_, D> where D: Dimension { fn next_back(&mut self) -> Option diff --git a/src/dimension/ndindex.rs b/src/dimension/ndindex.rs index 7bc2c54ef..ca2a3ea69 100644 --- a/src/dimension/ndindex.rs +++ b/src/dimension/ndindex.rs @@ -255,7 +255,7 @@ unsafe impl NdIndex for [Ix; N] } } -impl<'a> IntoDimension for &'a [Ix] +impl IntoDimension for &[Ix] { type Dim = IxDyn; fn into_dimension(self) -> Self::Dim @@ -264,7 +264,7 @@ impl<'a> IntoDimension for &'a [Ix] } } -unsafe impl<'a> NdIndex for &'a IxDyn +unsafe impl NdIndex for &IxDyn { fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { @@ -276,7 +276,7 @@ unsafe impl<'a> NdIndex for &'a IxDyn } } -unsafe impl<'a> NdIndex for &'a [Ix] +unsafe impl NdIndex for &[Ix] { fn index_checked(&self, dim: &IxDyn, strides: &IxDyn) -> Option { diff --git a/src/impl_cow.rs b/src/impl_cow.rs index f064ce7bd..4843e305b 100644 --- a/src/impl_cow.rs +++ b/src/impl_cow.rs @@ -11,7 +11,7 @@ use crate::imp_prelude::*; /// Methods specific to `CowArray`. /// /// ***See also all methods for [`ArrayBase`]*** -impl<'a, A, D> CowArray<'a, A, D> +impl CowArray<'_, A, D> where D: Dimension { /// Returns `true` iff the array is the view (borrowed) variant. diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index 15f2b9b6b..d0089057d 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -230,7 +230,7 @@ where D: Dimension } /// Private array view methods -impl<'a, A, D> ArrayView<'a, A, D> +impl ArrayView<'_, A, D> where D: Dimension { /// Create a new `ArrayView` @@ -254,7 +254,7 @@ where D: Dimension } } -impl<'a, A, D> ArrayViewMut<'a, A, D> +impl ArrayViewMut<'_, A, D> where D: Dimension { /// Create a new `ArrayView` diff --git a/src/impl_views/indexing.rs b/src/impl_views/indexing.rs index 2b72c2142..827313478 100644 --- a/src/impl_views/indexing.rs +++ b/src/impl_views/indexing.rs @@ -100,7 +100,7 @@ pub trait IndexLonger unsafe fn uget(self, index: I) -> Self::Output; } -impl<'a, 'b, I, A, D> IndexLonger for &'b ArrayView<'a, A, D> +impl<'a, I, A, D> IndexLonger for &ArrayView<'a, A, D> where I: NdIndex, D: Dimension, diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 6d6ea275b..58d0a7556 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -11,7 +11,7 @@ use crate::slice::MultiSliceArg; use num_complex::Complex; /// Methods for read-only array views. -impl<'a, A, D> ArrayView<'a, A, D> +impl ArrayView<'_, A, D> where D: Dimension { /// Split the array view along `axis` and return one view strictly before the diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index e7321d15b..01fff14f5 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -260,7 +260,7 @@ impl<'a, A> DoubleEndedIterator for ElementsBase<'a, A, Ix1> } } -impl<'a, A, D> ExactSizeIterator for ElementsBase<'a, A, D> +impl ExactSizeIterator for ElementsBase<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -503,7 +503,7 @@ impl<'a, A> DoubleEndedIterator for Iter<'a, A, Ix1> } } -impl<'a, A, D> ExactSizeIterator for Iter<'a, A, D> +impl ExactSizeIterator for Iter<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -534,7 +534,7 @@ impl<'a, A, D: Dimension> Iterator for IndexedIter<'a, A, D> } } -impl<'a, A, D> ExactSizeIterator for IndexedIter<'a, A, D> +impl ExactSizeIterator for IndexedIter<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -635,7 +635,7 @@ impl<'a, A> DoubleEndedIterator for IterMut<'a, A, Ix1> } } -impl<'a, A, D> ExactSizeIterator for IterMut<'a, A, D> +impl ExactSizeIterator for IterMut<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -686,7 +686,7 @@ impl<'a, A> DoubleEndedIterator for ElementsBaseMut<'a, A, Ix1> } } -impl<'a, A, D> ExactSizeIterator for ElementsBaseMut<'a, A, D> +impl ExactSizeIterator for ElementsBaseMut<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -717,7 +717,7 @@ impl<'a, A, D: Dimension> Iterator for IndexedIterMut<'a, A, D> } } -impl<'a, A, D> ExactSizeIterator for IndexedIterMut<'a, A, D> +impl ExactSizeIterator for IndexedIterMut<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -767,7 +767,7 @@ where D: Dimension } } -impl<'a, A, D> ExactSizeIterator for LanesIter<'a, A, D> +impl ExactSizeIterator for LanesIter<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -776,7 +776,7 @@ where D: Dimension } } -impl<'a, A> DoubleEndedIterator for LanesIter<'a, A, Ix1> +impl DoubleEndedIterator for LanesIter<'_, A, Ix1> { fn next_back(&mut self) -> Option { @@ -819,7 +819,7 @@ where D: Dimension } } -impl<'a, A, D> ExactSizeIterator for LanesIterMut<'a, A, D> +impl ExactSizeIterator for LanesIterMut<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -828,7 +828,7 @@ where D: Dimension } } -impl<'a, A> DoubleEndedIterator for LanesIterMut<'a, A, Ix1> +impl DoubleEndedIterator for LanesIterMut<'_, A, Ix1> { fn next_back(&mut self) -> Option { @@ -1079,7 +1079,7 @@ where D: Dimension } } -impl<'a, A, D> DoubleEndedIterator for AxisIter<'a, A, D> +impl DoubleEndedIterator for AxisIter<'_, A, D> where D: Dimension { fn next_back(&mut self) -> Option @@ -1088,7 +1088,7 @@ where D: Dimension } } -impl<'a, A, D> ExactSizeIterator for AxisIter<'a, A, D> +impl ExactSizeIterator for AxisIter<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -1169,7 +1169,7 @@ where D: Dimension } } -impl<'a, A, D> DoubleEndedIterator for AxisIterMut<'a, A, D> +impl DoubleEndedIterator for AxisIterMut<'_, A, D> where D: Dimension { fn next_back(&mut self) -> Option @@ -1178,7 +1178,7 @@ where D: Dimension } } -impl<'a, A, D> ExactSizeIterator for AxisIterMut<'a, A, D> +impl ExactSizeIterator for AxisIterMut<'_, A, D> where D: Dimension { fn len(&self) -> usize @@ -1187,7 +1187,7 @@ where D: Dimension } } -impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> +impl NdProducer for AxisIter<'_, A, D> { type Item = ::Item; type Dim = Ix1; @@ -1246,7 +1246,7 @@ impl<'a, A, D: Dimension> NdProducer for AxisIter<'a, A, D> private_impl! {} } -impl<'a, A, D: Dimension> NdProducer for AxisIterMut<'a, A, D> +impl NdProducer for AxisIterMut<'_, A, D> { type Item = ::Item; type Dim = Ix1; @@ -1555,12 +1555,12 @@ unsafe impl TrustedIterator for Linspace {} unsafe impl TrustedIterator for Geomspace {} #[cfg(feature = "std")] unsafe impl TrustedIterator for Logspace {} -unsafe impl<'a, A, D> TrustedIterator for Iter<'a, A, D> {} -unsafe impl<'a, A, D> TrustedIterator for IterMut<'a, A, D> {} +unsafe impl TrustedIterator for Iter<'_, A, D> {} +unsafe impl TrustedIterator for IterMut<'_, A, D> {} unsafe impl TrustedIterator for std::iter::Cloned where I: TrustedIterator {} unsafe impl TrustedIterator for std::iter::Map where I: TrustedIterator {} -unsafe impl<'a, A> TrustedIterator for slice::Iter<'a, A> {} -unsafe impl<'a, A> TrustedIterator for slice::IterMut<'a, A> {} +unsafe impl TrustedIterator for slice::Iter<'_, A> {} +unsafe impl TrustedIterator for slice::IterMut<'_, A> {} unsafe impl TrustedIterator for ::std::ops::Range {} // FIXME: These indices iter are dubious -- size needs to be checked up front. unsafe impl TrustedIterator for IndicesIter where D: Dimension {} diff --git a/src/lib.rs b/src/lib.rs index f52f25e5e..b163f16a5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1498,7 +1498,7 @@ pub enum CowRepr<'a, A> Owned(OwnedRepr), } -impl<'a, A> CowRepr<'a, A> +impl CowRepr<'_, A> { /// Returns `true` iff the data is the `View` variant. pub fn is_view(&self) -> bool diff --git a/src/split_at.rs b/src/split_at.rs index 4af1403c0..5dee44b63 100644 --- a/src/split_at.rs +++ b/src/split_at.rs @@ -35,7 +35,7 @@ where D: Dimension } } -impl<'a, A, D> SplitAt for ArrayViewMut<'a, A, D> +impl SplitAt for ArrayViewMut<'_, A, D> where D: Dimension { fn split_at(self, axis: Axis, index: usize) -> (Self, Self) diff --git a/tests/azip.rs b/tests/azip.rs index 96be9d913..9d8bebab7 100644 --- a/tests/azip.rs +++ b/tests/azip.rs @@ -118,7 +118,7 @@ fn test_zip_collect_drop() struct Recorddrop<'a>((usize, usize), &'a RefCell>); - impl<'a> Drop for Recorddrop<'a> + impl Drop for Recorddrop<'_> { fn drop(&mut self) { @@ -470,9 +470,9 @@ fn test_zip_all() let b = Array::::ones(62); let mut c = Array::::ones(62); c[5] = 0.0; - assert_eq!(true, Zip::from(&a).and(&b).all(|&x, &y| x + y == 1.0)); - assert_eq!(false, Zip::from(&a).and(&b).all(|&x, &y| x == y)); - assert_eq!(false, Zip::from(&a).and(&c).all(|&x, &y| x + y == 1.0)); + assert!(Zip::from(&a).and(&b).all(|&x, &y| x + y == 1.0)); + assert!(!Zip::from(&a).and(&b).all(|&x, &y| x == y)); + assert!(!Zip::from(&a).and(&c).all(|&x, &y| x + y == 1.0)); } #[test] @@ -480,6 +480,6 @@ fn test_zip_all_empty_array() { let a = Array::::zeros(0); let b = Array::::ones(0); - assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| true)); - assert_eq!(true, Zip::from(&a).and(&b).all(|&_x, &_y| false)); + assert!(Zip::from(&a).and(&b).all(|&_x, &_y| true)); + assert!(Zip::from(&a).and(&b).all(|&_x, &_y| false)); } From 4e61c87a7dcefac784deae749e0dd982800a9379 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Sat, 30 Nov 2024 12:58:36 -0500 Subject: [PATCH 32/48] Changing CI to account for BLAS requiring MSRV > 1.64 --- .github/workflows/ci.yaml | 14 ++++++++++++++ README.rst | 8 ++++++++ scripts/all-tests.sh | 5 +++-- scripts/blas-integ-tests.sh | 11 +++++++++++ 4 files changed, 36 insertions(+), 2 deletions(-) create mode 100755 scripts/blas-integ-tests.sh diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 0f517fac9..c910b32e0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -88,6 +88,20 @@ jobs: run: sudo apt-get install libopenblas-dev gfortran - run: ./scripts/all-tests.sh "$FEATURES" ${{ matrix.rust }} + blas-msrv: + runs-on: ubuntu-latest + name: blas-msrv + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@master + with: + toolchain: 1.67.0 # BLAS MSRV + - uses: rui314/setup-mold@v1 + - uses: Swatinem/rust-cache@v2 + - name: Install openblas + run: sudo apt-get install libopenblas-dev gfortran + - run: ./scripts/blas-integ-tests.sh "$FEATURES" 1.67.0 + miri: runs-on: ubuntu-latest name: miri diff --git a/README.rst b/README.rst index abac4c18e..ef6577f13 100644 --- a/README.rst +++ b/README.rst @@ -156,6 +156,14 @@ there is no tight coupling to the ``blas-src`` version, so version selection is 0.13 0.2.0 0.6.0 =========== ============ ================ ============== +------------ +BLAS on MSRV +------------ + +Although ``ndarray`` currently maintains an MSRV of 1.64.0, this is separate from the MSRV (either stated or real) of the various BLAS providers. +As of the time of writing, ``openblas`` currently supports MSRV of 1.67.0. +So, while ``ndarray`` and ``openblas-src`` are compatible, they can only work together with toolchains 1.67.0 or above. + Recent Changes -------------- diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index b9af6b65a..e98b90df1 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -23,8 +23,9 @@ cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FE # BLAS tests cargo test -p ndarray --lib -v --features blas cargo test -p blas-mock-tests -v -cargo test -p blas-tests -v --features blas-tests/openblas-system -cargo test -p numeric-tests -v --features numeric-tests/test_blas +if [ "$CHANNEL" != "1.64.0" ]; then + ./scripts/blas-integ-tests.sh "$FEATURES" $CHANNEL +fi # Examples cargo test --examples diff --git a/scripts/blas-integ-tests.sh b/scripts/blas-integ-tests.sh new file mode 100755 index 000000000..5192d67e3 --- /dev/null +++ b/scripts/blas-integ-tests.sh @@ -0,0 +1,11 @@ +#!/bin/sh + +set -x +set -e + +FEATURES=$1 +CHANNEL=$2 + +# BLAS tests +cargo test -p blas-tests -v --features blas-tests/openblas-system +cargo test -p numeric-tests -v --features numeric-tests/test_blas From d5f32ec06e27d8705dcf5da4e6596abf51188909 Mon Sep 17 00:00:00 2001 From: akern40 Date: Thu, 19 Dec 2024 21:28:07 -0800 Subject: [PATCH 33/48] Pin openblas to >=0.10.11 in order to fix blas-compatible MSRV to 0.71.1 (#1465) --- .github/workflows/ci.yaml | 8 ++++++-- README.rst | 4 ++-- crates/blas-tests/Cargo.toml | 2 +- crates/numeric-tests/Cargo.toml | 2 +- scripts/blas-integ-tests.sh | 3 +-- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c910b32e0..ae74aeb45 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -12,6 +12,8 @@ env: HOST: x86_64-unknown-linux-gnu FEATURES: "test docs" RUSTFLAGS: "-D warnings" + MSRV: 1.64.0 + BLAS_MSRV: 1.71.1 jobs: clippy: @@ -95,12 +97,14 @@ jobs: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: 1.67.0 # BLAS MSRV + toolchain: 1.71.1 # BLAS MSRV - uses: rui314/setup-mold@v1 - uses: Swatinem/rust-cache@v2 - name: Install openblas run: sudo apt-get install libopenblas-dev gfortran - - run: ./scripts/blas-integ-tests.sh "$FEATURES" 1.67.0 + - run: cargo tree -p blas-tests -i openblas-src -F blas-tests/openblas-system + - run: cargo tree -p blas-tests -i openblas-build -F blas-tests/openblas-system + - run: ./scripts/blas-integ-tests.sh $BLAS_MSRV miri: runs-on: ubuntu-latest diff --git a/README.rst b/README.rst index ef6577f13..49558b1c1 100644 --- a/README.rst +++ b/README.rst @@ -161,8 +161,8 @@ BLAS on MSRV ------------ Although ``ndarray`` currently maintains an MSRV of 1.64.0, this is separate from the MSRV (either stated or real) of the various BLAS providers. -As of the time of writing, ``openblas`` currently supports MSRV of 1.67.0. -So, while ``ndarray`` and ``openblas-src`` are compatible, they can only work together with toolchains 1.67.0 or above. +As of the time of writing, ``openblas`` currently supports MSRV of 1.71.1. +So, while ``ndarray`` and ``openblas-src`` are compatible, they can only work together with toolchains 1.71.1 or above. Recent Changes -------------- diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index 05a656000..ff556873a 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -15,7 +15,7 @@ ndarray = { workspace = true, features = ["approx", "blas"] } ndarray-gen = { workspace = true } blas-src = { version = "0.10", optional = true } -openblas-src = { version = "0.10", optional = true } +openblas-src = { version = ">=0.10.11", optional = true } netlib-src = { version = "0.8", optional = true } blis-src = { version = "0.2", features = ["system"], optional = true } diff --git a/crates/numeric-tests/Cargo.toml b/crates/numeric-tests/Cargo.toml index 214612258..93a182e66 100644 --- a/crates/numeric-tests/Cargo.toml +++ b/crates/numeric-tests/Cargo.toml @@ -19,7 +19,7 @@ rand = { workspace = true } rand_distr = { workspace = true } blas-src = { optional = true, version = "0.10", default-features = false, features = ["openblas"] } -openblas-src = { optional = true, version = "0.10", default-features = false, features = ["cblas", "system"] } +openblas-src = { optional = true, version = ">=0.10.11", default-features = false, features = ["cblas", "system"] } [dev-dependencies] num-traits = { workspace = true } diff --git a/scripts/blas-integ-tests.sh b/scripts/blas-integ-tests.sh index 5192d67e3..fec938b83 100755 --- a/scripts/blas-integ-tests.sh +++ b/scripts/blas-integ-tests.sh @@ -3,8 +3,7 @@ set -x set -e -FEATURES=$1 -CHANNEL=$2 +CHANNEL=$1 # BLAS tests cargo test -p blas-tests -v --features blas-tests/openblas-system From c7391e99073b40daef5de0563cbcc32aea0facb0 Mon Sep 17 00:00:00 2001 From: akern40 Date: Sun, 2 Feb 2025 23:54:45 -0500 Subject: [PATCH 34/48] Simplify features and make documentation call out feature gates (#1479) * Makes use of the nightly `doc_cfg` feature to automatically mark feature-gated items as requiring that feature. This is possible thanks to the fact that docs.rs runs on nightly. While this may not be stabilized (and therefore may eventually reverse), I think it's extremely useful to users and only requires small additional configurations that would be easy to remove in the future. * Adds appropriate arguments to CI/CD and removes serde-1, test, and docs features * Fixes clippy complaining about `return None` instead of question marks --- .github/workflows/ci.yaml | 10 +++++----- Cargo.toml | 13 ++++--------- src/array_approx.rs | 6 ++---- src/arraytraits.rs | 1 + src/error.rs | 1 + src/impl_constructors.rs | 4 ++++ src/impl_methods.rs | 5 +---- src/iterators/mod.rs | 5 +---- src/lib.rs | 8 +++++++- src/linalg_traits.rs | 5 +++++ src/numeric/impl_float_maths.rs | 1 + src/numeric/impl_numeric.rs | 4 ++++ src/parallel/impl_par_methods.rs | 2 -- src/partial.rs | 3 +++ 14 files changed, 39 insertions(+), 29 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ae74aeb45..1a1ee6415 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -10,7 +10,7 @@ name: Continuous integration env: CARGO_TERM_COLOR: always HOST: x86_64-unknown-linux-gnu - FEATURES: "test docs" + FEATURES: "approx,serde,rayon" RUSTFLAGS: "-D warnings" MSRV: 1.64.0 BLAS_MSRV: 1.71.1 @@ -30,7 +30,7 @@ jobs: toolchain: ${{ matrix.rust }} components: clippy - uses: Swatinem/rust-cache@v2 - - run: cargo clippy --features docs + - run: cargo clippy --features approx,serde,rayon format: runs-on: ubuntu-latest @@ -139,7 +139,7 @@ jobs: - uses: Swatinem/rust-cache@v2 - name: Install cross run: cargo install cross - - run: ./scripts/cross-tests.sh "docs" ${{ matrix.rust }} ${{ matrix.target }} + - run: ./scripts/cross-tests.sh "approx,serde,rayon" ${{ matrix.rust }} ${{ matrix.target }} cargo-careful: #if: ${{ github.event_name == 'merge_group' }} @@ -161,10 +161,10 @@ jobs: strategy: matrix: rust: - - stable + - nightly # This is what docs.rs runs on, and is needed for the feature flags name: docs/${{ matrix.rust }} env: - RUSTDOCFLAGS: "-Dwarnings" + RUSTDOCFLAGS: "-Dwarnings --cfg docsrs" steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master diff --git a/Cargo.toml b/Cargo.toml index 5c7217025..3d1c1dde6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,14 +59,6 @@ default = ["std"] blas = ["dep:cblas-sys", "dep:libc"] serde = ["dep:serde"] -# Old name for the serde feature -serde-1 = ["dep:serde"] - -# These features are used for testing -test = [] - -# This feature is used for docs -docs = ["approx", "serde", "rayon"] std = ["num-traits/std", "matrixmultiply/std"] rayon = ["dep:rayon", "std"] @@ -121,5 +113,8 @@ opt-level = 2 no-dev-version = true tag-name = "{{version}}" +# Config specific to docs.rs [package.metadata.docs.rs] -features = ["docs"] +features = ["approx", "serde", "rayon"] +# Define the configuration attribute `docsrs` +rustdoc-args = ["--cfg", "docsrs"] diff --git a/src/array_approx.rs b/src/array_approx.rs index 493864c7e..c6fd174d1 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -1,4 +1,5 @@ #[cfg(feature = "approx")] +#[cfg_attr(docsrs, doc(cfg(feature = "approx")))] mod approx_methods { use crate::imp_prelude::*; @@ -10,8 +11,6 @@ mod approx_methods { /// A test for equality that uses the elementwise absolute difference to compute the /// approximate equality of two arrays. - /// - /// **Requires crate feature `"approx"`** pub fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool where A: ::approx::AbsDiffEq, @@ -23,8 +22,6 @@ mod approx_methods /// A test for equality that uses an elementwise relative comparison if the values are far /// apart; and the absolute difference otherwise. - /// - /// **Requires crate feature `"approx"`** pub fn relative_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool where A: ::approx::RelativeEq, @@ -192,4 +189,5 @@ macro_rules! impl_approx_traits { } #[cfg(feature = "approx")] +#[cfg_attr(docsrs, doc(cfg(feature = "approx")))] impl_approx_traits!(approx, "**Requires crate feature `\"approx\"`.**"); diff --git a/src/arraytraits.rs b/src/arraytraits.rs index d7a00fcfe..62f95df4a 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -316,6 +316,7 @@ where } #[cfg(feature = "serde")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] // Use version number so we can add a packed format later. pub const ARRAY_FORMAT_VERSION: u8 = 1u8; diff --git a/src/error.rs b/src/error.rs index eb7395ad8..e19c32075 100644 --- a/src/error.rs +++ b/src/error.rs @@ -81,6 +81,7 @@ impl PartialEq for ShapeError } #[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] impl Error for ShapeError {} impl fmt::Display for ShapeError diff --git a/src/impl_constructors.rs b/src/impl_constructors.rs index 260937a90..c1e5b1b8b 100644 --- a/src/impl_constructors.rs +++ b/src/impl_constructors.rs @@ -99,6 +99,7 @@ where S: DataOwned /// assert!(array == arr1(&[0.0, 0.25, 0.5, 0.75, 1.0])) /// ``` #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn linspace(start: A, end: A, n: usize) -> Self where A: Float { @@ -117,6 +118,7 @@ where S: DataOwned /// assert!(array == arr1(&[0., 1., 2., 3., 4.])) /// ``` #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn range(start: A, end: A, step: A) -> Self where A: Float { @@ -145,6 +147,7 @@ where S: DataOwned /// # } /// ``` #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn logspace(base: A, start: A, end: A, n: usize) -> Self where A: Float { @@ -179,6 +182,7 @@ where S: DataOwned /// # example().unwrap(); /// ``` #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn geomspace(start: A, end: A, n: usize) -> Option where A: Float { diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 4a00ea000..3da63b936 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -2299,10 +2299,7 @@ where let dim = dim.into_dimension(); // Note: zero strides are safe precisely because we return an read-only view - let broadcast_strides = match upcast(&dim, &self.dim, &self.strides) { - Some(st) => st, - None => return None, - }; + let broadcast_strides = upcast(&dim, &self.dim, &self.strides)?; unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) } } diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 01fff14f5..e0da8f6c9 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -139,10 +139,7 @@ impl DoubleEndedIterator for Baseiter #[inline] fn next_back(&mut self) -> Option { - let index = match self.index { - None => return None, - Some(ix) => ix, - }; + let index = self.index?; self.dim[0] -= 1; let offset = Ix1::stride_offset(&self.dim, &self.strides); if index == self.dim { diff --git a/src/lib.rs b/src/lib.rs index b163f16a5..f0b64028f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ #![doc(test(attr(allow(unused_variables))))] #![doc(test(attr(allow(deprecated))))] #![cfg_attr(not(feature = "std"), no_std)] +// Enable the doc_cfg nightly feature for including feature gate flags in the documentation +#![cfg_attr(docsrs, feature(doc_cfg))] //! The `ndarray` crate provides an *n*-dimensional container for general elements //! and for numerics. @@ -120,7 +122,7 @@ extern crate std; #[cfg(feature = "blas")] extern crate cblas_sys; -#[cfg(feature = "docs")] +#[cfg(docsrs)] pub mod doc; #[cfg(target_has_atomic = "ptr")] @@ -148,6 +150,7 @@ use crate::iterators::{ElementsBase, ElementsBaseMut, Iter, IterMut}; pub use crate::arraytraits::AsArray; pub use crate::linalg_traits::LinalgScalar; #[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub use crate::linalg_traits::NdFloat; pub use crate::stacking::{concatenate, stack}; @@ -189,9 +192,11 @@ mod layout; mod linalg_traits; mod linspace; #[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub use crate::linspace::{linspace, range, Linspace}; mod logspace; #[cfg(feature = "std")] +#[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub use crate::logspace::{logspace, Logspace}; mod math_cell; mod numeric_util; @@ -1587,6 +1592,7 @@ where // parallel methods #[cfg(feature = "rayon")] +#[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] pub mod parallel; mod impl_1d; diff --git a/src/linalg_traits.rs b/src/linalg_traits.rs index 65d264c40..ec1aebbe7 100644 --- a/src/linalg_traits.rs +++ b/src/linalg_traits.rs @@ -39,7 +39,10 @@ impl LinalgScalar for T where T: 'static + Copy + Zero + One + Add ArrayBase where A: 'static + Float, diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 6c67b9135..a8a008395 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -140,6 +140,7 @@ where /// ``` #[track_caller] #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn var(&self, ddof: A) -> A where A: Float + FromPrimitive { @@ -205,6 +206,7 @@ where /// ``` #[track_caller] #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn std(&self, ddof: A) -> A where A: Float + FromPrimitive { @@ -361,6 +363,7 @@ where /// ``` #[track_caller] #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn var_axis(&self, axis: Axis, ddof: A) -> Array where A: Float + FromPrimitive, @@ -431,6 +434,7 @@ where /// ``` #[track_caller] #[cfg(feature = "std")] + #[cfg_attr(docsrs, doc(cfg(feature = "std")))] pub fn std_axis(&self, axis: Axis, ddof: A) -> Array where A: Float + FromPrimitive, diff --git a/src/parallel/impl_par_methods.rs b/src/parallel/impl_par_methods.rs index c6af4e8f3..7f01ea32f 100644 --- a/src/parallel/impl_par_methods.rs +++ b/src/parallel/impl_par_methods.rs @@ -8,8 +8,6 @@ use crate::parallel::prelude::*; use crate::partial::Partial; /// # Parallel methods -/// -/// These methods require crate feature `rayon`. impl ArrayBase where S: DataMut, diff --git a/src/partial.rs b/src/partial.rs index 99aba75a8..4509e77dc 100644 --- a/src/partial.rs +++ b/src/partial.rs @@ -37,6 +37,7 @@ impl Partial } #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] pub(crate) fn stub() -> Self { Self { @@ -46,6 +47,7 @@ impl Partial } #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] pub(crate) fn is_stub(&self) -> bool { self.ptr.is_null() @@ -60,6 +62,7 @@ impl Partial } #[cfg(feature = "rayon")] + #[cfg_attr(docsrs, doc(cfg(feature = "rayon")))] /// Merge if they are in order (left to right) and contiguous. /// Skips merge if T does not need drop. pub(crate) fn try_merge(mut left: Self, right: Self) -> Self From 41bace11a7ad456911ce62bb7b012d6332ec0af1 Mon Sep 17 00:00:00 2001 From: akern40 Date: Fri, 28 Feb 2025 19:22:52 -0500 Subject: [PATCH 35/48] Uses a simple fix to enable arraybase to be covariant. (#1480) See rust-lang/rust#115799 and rust-lang/rust#57440 for more details. --- src/lib.rs | 6 +++--- tests/variance.rs | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) create mode 100644 tests/variance.rs diff --git a/src/lib.rs b/src/lib.rs index f0b64028f..9ba3b6728 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1282,15 +1282,15 @@ pub type Ixs = isize; // may change in the future. // // [`.offset()`]: https://doc.rust-lang.org/stable/std/primitive.pointer.html#method.offset-1 -pub struct ArrayBase -where S: RawData +pub struct ArrayBase::Elem> +where S: RawData { /// Data buffer / ownership information. (If owned, contains the data /// buffer; if borrowed, contains the lifetime and mutability.) data: S, /// A non-null pointer into the buffer held by `data`; may point anywhere /// in its range. If `S: Data`, this pointer must be aligned. - ptr: std::ptr::NonNull, + ptr: std::ptr::NonNull, /// The lengths of the axes. dim: D, /// The element count stride per axis. To be parsed as `isize`. diff --git a/tests/variance.rs b/tests/variance.rs new file mode 100644 index 000000000..e72805ff7 --- /dev/null +++ b/tests/variance.rs @@ -0,0 +1,14 @@ +use ndarray::{Array1, ArrayView1}; + +fn arrayview_covariant<'a: 'b, 'b>(x: ArrayView1<'a, f64>) -> ArrayView1<'b, f64> +{ + x +} + +#[test] +fn test_covariance() +{ + let x = Array1::zeros(2); + let shorter_view = arrayview_covariant(x.view()); + assert_eq!(shorter_view[0], 0.0); +} From ee6c45e33796f8e792ab7763484b068c35538bce Mon Sep 17 00:00:00 2001 From: akern40 Date: Sun, 16 Mar 2025 18:07:40 -0400 Subject: [PATCH 36/48] Tries to stabilize MSRV CI/CD. (#1485) * Tries to stabilize MSRV CI/CD. With the new MSRV-aware resolver available, I am trying to stabilize our CI/CD, which frequently breaks due to dependency updates. This takes three steps to do so: 1. Add Cargo.lock, so that builds on the CI/CD are deterministic 2. Add a regular (weekly) job that checks against the latest dependencies across both stable and MSRV versions of rustc 3. Simplify the current CI/CD and revert to a single MSRV, rather than using a BLAS-specific MSRV. --- .github/workflows/ci.yaml | 20 +- .github/workflows/latest-deps.yaml | 66 ++ .gitignore | 1 - Cargo.lock | 1321 ++++++++++++++++++++++++++++ scripts/all-tests.sh | 2 +- scripts/blas-integ-tests.sh | 2 - 6 files changed, 1405 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/latest-deps.yaml create mode 100644 Cargo.lock diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1a1ee6415..6ebdc8432 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,6 +16,18 @@ env: BLAS_MSRV: 1.71.1 jobs: + pass-msrv: + runs-on: ubuntu-latest + name: Pass MSRV values to other jobs + outputs: + MSRV: ${{ env.MSRV }} + BLAS_MSRV: ${{ env.BLAS_MSRV }} + steps: + - name: Pass MSRV + run: | + echo "MSRV=${{ env.MSRV }}" >> $GITHUB_OUTPUT + echo "BLAS_MSRV=${{ env.BLAS_MSRV }}" >> $GITHUB_OUTPUT + clippy: runs-on: ubuntu-latest strategy: @@ -70,13 +82,14 @@ jobs: tests: runs-on: ubuntu-latest + needs: pass-msrv strategy: matrix: rust: - stable - beta - nightly - - 1.64.0 # MSRV + - ${{ needs.pass-msrv.outputs.MSRV }} name: tests/${{ matrix.rust }} steps: @@ -89,15 +102,16 @@ jobs: - name: Install openblas run: sudo apt-get install libopenblas-dev gfortran - run: ./scripts/all-tests.sh "$FEATURES" ${{ matrix.rust }} - + blas-msrv: runs-on: ubuntu-latest name: blas-msrv + needs: pass-msrv steps: - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: - toolchain: 1.71.1 # BLAS MSRV + toolchain: ${{ needs.pass-msrv.outputs.BLAS_MSRV }} - uses: rui314/setup-mold@v1 - uses: Swatinem/rust-cache@v2 - name: Install openblas diff --git a/.github/workflows/latest-deps.yaml b/.github/workflows/latest-deps.yaml new file mode 100644 index 000000000..f2f3d8486 --- /dev/null +++ b/.github/workflows/latest-deps.yaml @@ -0,0 +1,66 @@ +name: Check Latest Dependencies +on: + schedule: + # Chosen so that it runs right before the international date line experiences the weekend. + # Since we're open source, that means globally we should be aware of it right when we have the most + # time to fix it. + # + # Sorry if this ruins your weekend, future maintainer... + - cron: '0 12 * * FRI' + workflow_dispatch: # For running manually + +env: + CARGO_TERM_COLOR: always + HOST: x86_64-unknown-linux-gnu + FEATURES: "approx,serde,rayon" + RUSTFLAGS: "-D warnings" + MSRV: 1.64.0 + BLAS_MSRV: 1.71.0 + +jobs: + latest_deps_stable: + runs-on: ubuntu-latest + name: Check Latest Dependencies on Stable + steps: + - name: Check Out Repo + uses: actions/checkout@v4 + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - name: Setup Mold Linker + uses: rui314/setup-mold@v1 + - name: Setup Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install openblas + run: sudo apt-get install libopenblas-dev gfortran + - name: Ensure latest dependencies + run: cargo update + - name: Run Tests + run: ./scripts/all-tests.sh "$FEATURES" stable + + latest_deps_msrv: + runs-on: ubuntu-latest + name: Check Latest Dependencies on MSRV (${{ env.MSRV }}) + steps: + - name: Check Out Repo + uses: actions/checkout@v4 + - name: Install Stable Rust for Update + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + - name: Setup Mold Linker + uses: rui314/setup-mold@v1 + - name: Setup Rust Cache + uses: Swatinem/rust-cache@v2 + - name: Install openblas + run: sudo apt-get install libopenblas-dev gfortran + - name: Ensure latest dependencies + # The difference is here between this and `latest_deps_stable` + run: CARGO_RESOLVER_INCOMPATIBLE_RUST_VERSIONS="fallback" cargo update + - name: Install MSRV Rust for Test + uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ env.MSRV }} + - name: Run Tests + run: ./scripts/all-tests.sh "$FEATURES" $MSRV diff --git a/.gitignore b/.gitignore index e9b5ca25b..ef4ee42f5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,4 @@ # Rust items -Cargo.lock target/ # Editor settings diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 000000000..472b171f8 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,1321 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" + +[[package]] +name = "anyhow" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfed56ad506cb2c684a14971b8861fdc3baaaae314b9e5f9bb532cbe3ba7a4f" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" +dependencies = [ + "serde", +] + +[[package]] +name = "blas-mock-tests" +version = "0.1.0" +dependencies = [ + "cblas-sys", + "itertools", + "ndarray", + "ndarray-gen", +] + +[[package]] +name = "blas-src" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95e83dc868db96e69795c0213143095f03de9dd3252f205d4ac716e4076a7e0" +dependencies = [ + "blis-src", + "netlib-src", + "openblas-src", +] + +[[package]] +name = "blas-tests" +version = "0.1.0" +dependencies = [ + "approx", + "blas-src", + "blis-src", + "defmac", + "itertools", + "ndarray", + "ndarray-gen", + "netlib-src", + "num-complex", + "num-traits", + "openblas-src", +] + +[[package]] +name = "blis-src" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc119b6761ce8b063102502af49043051f81a9bdf242ae06d12e9ea0d92b727a" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cblas-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" +dependencies = [ + "libc", +] + +[[package]] +name = "cc" +version = "1.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "critical-section" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "790eea4361631c5e7d22598ecd5723ff611904e3344ce8720784c93e3d83d40b" + +[[package]] +name = "crossbeam-channel" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "defmac" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5592fca31e96d8a748d03080b58be78c5383617aa4bd89e69f30607d8769891" + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "errno" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "filetime" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.59.0", +] + +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.11.0+wasi-snapshot-preview1", +] + +[[package]] +name = "getrandom" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +dependencies = [ + "cfg-if", + "libc", + "wasi 0.13.3+wasi-0.2.2", + "windows-targets 0.52.6", +] + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "idna" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "279259b0ac81c89d11c290495fdcfa96ea3643b7df311c138b6fe8ca5237f0f8" +dependencies = [ + "idna_mapping", + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "idna_mapping" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5422cc5bc64289a77dbb45e970b86b5e9a04cb500abc7240505aedc1bf40f38" +dependencies = [ + "unicode-joining-type", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + +[[package]] +name = "libc" +version = "0.2.171" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" + +[[package]] +name = "libm" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" + +[[package]] +name = "libredox" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" +dependencies = [ + "bitflags", + "libc", + "redox_syscall", +] + +[[package]] +name = "linux-raw-sys" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" + +[[package]] +name = "log" +version = "0.4.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "num_cpus", + "once_cell", + "rawpointer", + "thread-tree", +] + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "miniz_oxide" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" +dependencies = [ + "adler2", +] + +[[package]] +name = "native-tls" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dab59f8e050d5df8e4dd87d9206fb6f65a483e20ac9fda365ade4fab353196c" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "ndarray" +version = "0.16.1" +dependencies = [ + "approx", + "cblas-sys", + "defmac", + "itertools", + "libc", + "matrixmultiply", + "ndarray-gen", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "quickcheck", + "rawpointer", + "rayon", + "serde", +] + +[[package]] +name = "ndarray-gen" +version = "0.1.0" +dependencies = [ + "ndarray", + "num-traits", +] + +[[package]] +name = "ndarray-rand" +version = "0.15.0" +dependencies = [ + "ndarray", + "quickcheck", + "rand", + "rand_distr", + "rand_isaac", +] + +[[package]] +name = "netlib-src" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39f41f36bb4d46906d5a72da5b73a804d9de1a7282eb7c89617201acda7b8212" +dependencies = [ + "cmake", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "numeric-tests" +version = "0.1.0" +dependencies = [ + "approx", + "blas-src", + "ndarray", + "ndarray-rand", + "num-complex", + "num-traits", + "openblas-src", + "rand", + "rand_distr", +] + +[[package]] +name = "once_cell" +version = "1.20.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" + +[[package]] +name = "openblas-build" +version = "0.10.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ca8f8c64eb5b43f5538059ccbc71391420bba14d987d7e8ab99ed62ed33e26b" +dependencies = [ + "anyhow", + "cc", + "flate2", + "native-tls", + "tar", + "thiserror 2.0.12", + "ureq", +] + +[[package]] +name = "openblas-src" +version = "0.10.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "252f22774417be65f908a20f7721a97e33a253acad4f28370408b7f1baea0629" +dependencies = [ + "dirs", + "openblas-build", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "openssl" +version = "0.10.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e14130c6a98cd258fdcb0fb6d744152343ff729cbfcb28c656a9d12b999fbcd" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "percent-encoding" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" + +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + +[[package]] +name = "portable-atomic" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +dependencies = [ + "critical-section", +] + +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.94" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quickcheck" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" +dependencies = [ + "rand", +] + +[[package]] +name = "quote" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.15", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rand_isaac" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28" +dependencies = [ + "rand_core", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.15", + "libredox", + "thiserror 1.0.69", +] + +[[package]] +name = "rmp" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f55e5fa1446c4d5dd1f5daeed2a4fe193071771a2636274d0d7a3b082aa7ad6" +dependencies = [ + "byteorder", + "num-traits", +] + +[[package]] +name = "rmp-serde" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ce7d70c926fe472aed493b902010bccc17fa9f7284145cb8772fd22fdb052d8" +dependencies = [ + "byteorder", + "rmp", + "serde", +] + +[[package]] +name = "ron" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" +dependencies = [ + "base64 0.21.7", + "bitflags", + "serde", + "serde_derive", +] + +[[package]] +name = "rustix" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7178faa4b75a30e269c71e61c353ce2748cf3d76f0c44c393f4e60abf49b825" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustls-native-certs" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5bfb394eeed242e909609f56089eecfe5fda225042e8b171791b9c95f5931e5" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dce314e5fee3f39953d46bb63bb8a46d40c2f8fb7cc5a3b6cab2bde9721d6e50" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "rustls-pki-types" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c" + +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "schannel" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f29ebaa345f945cec9fbbc532eb307f0fdad8161f281b6369539c8d84876b3d" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.140" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "serialization-tests" +version = "0.1.0" +dependencies = [ + "ndarray", + "rmp", + "rmp-serde", + "ron", + "serde", + "serde_json", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" + +[[package]] +name = "syn" +version = "2.0.100" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488960f40a3fd53d72c2a29a58722561dee8afdd175bd88e3db4677d7b2ba600" +dependencies = [ + "fastrand", + "getrandom 0.3.1", + "once_cell", + "rustix", + "windows-sys 0.59.0", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +dependencies = [ + "thiserror-impl 2.0.12", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread-tree" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffbd370cb847953a25954d9f63e14824a36113f8c72eecf6eccef5dc4b45d630" +dependencies = [ + "crossbeam-channel", +] + +[[package]] +name = "tinyvec" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09b3661f17e86524eccd4371ab0429194e0d7c008abb45f7a7495b1719463c71" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" + +[[package]] +name = "unicode-ident" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" + +[[package]] +name = "unicode-joining-type" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22f8cb47ccb8bc750808755af3071da4a10dcd147b68fc874b7ae4b12543f6f5" + +[[package]] +name = "unicode-normalization" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5033c97c4262335cded6d6fc3e5c18ab755e1a3dc96376350f3d8e9f009ad956" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "ureq" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b74fc6b57825be3373f7054754755f03ac3a8f5d70015ccad699ba2029956f4a" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls-native-certs", + "url", +] + +[[package]] +name = "url" +version = "2.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasi" +version = "0.13.3+wasi-0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +dependencies = [ + "wit-bindgen-rt", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen-rt" +version = "0.33.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +dependencies = [ + "bitflags", +] + +[[package]] +name = "xattr" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d65cbf2f12c15564212d48f4e3dfb87923d25d611f2aed18f4cb23f0413d89e" +dependencies = [ + "libc", + "rustix", +] + +[[package]] +name = "zerocopy" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd97444d05a4328b90e75e503a34bad781f14e28a823ad3557f0750df1ebcbc6" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6352c01d0edd5db859a63e2605f4ea3183ddbd15e2c4a9e7d32184df75e4f154" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/scripts/all-tests.sh b/scripts/all-tests.sh index e98b90df1..612a7d758 100755 --- a/scripts/all-tests.sh +++ b/scripts/all-tests.sh @@ -23,7 +23,7 @@ cargo test -v -p ndarray -p ndarray-rand --release --features "$FEATURES" $QC_FE # BLAS tests cargo test -p ndarray --lib -v --features blas cargo test -p blas-mock-tests -v -if [ "$CHANNEL" != "1.64.0" ]; then +if [[ -z "${MSRV}" ]] && [ "$CHANNEL" != "$MSRV" ]; then ./scripts/blas-integ-tests.sh "$FEATURES" $CHANNEL fi diff --git a/scripts/blas-integ-tests.sh b/scripts/blas-integ-tests.sh index fec938b83..3d769e0af 100755 --- a/scripts/blas-integ-tests.sh +++ b/scripts/blas-integ-tests.sh @@ -3,8 +3,6 @@ set -x set -e -CHANNEL=$1 - # BLAS tests cargo test -p blas-tests -v --features blas-tests/openblas-system cargo test -p numeric-tests -v --features numeric-tests/test_blas From 9fc4110b572cbf7e80c50e4498f30d39b73e7970 Mon Sep 17 00:00:00 2001 From: akern40 Date: Sun, 16 Mar 2025 19:54:21 -0400 Subject: [PATCH 37/48] Bump rand to 0.9.0 and rand_distr to 0.5.0 (#1486) --- Cargo.toml | 4 ++-- crates/numeric-tests/tests/accuracy.rs | 28 +++++++++++++------------- ndarray-rand/Cargo.toml | 2 +- ndarray-rand/benches/bench.rs | 2 +- ndarray-rand/src/lib.rs | 19 ++++++++--------- ndarray-rand/tests/tests.rs | 24 +++++++++++----------- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3d1c1dde6..98326a598 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -96,8 +96,8 @@ num-traits = { version = "0.2", default-features = false } num-complex = { version = "0.4", default-features = false } approx = { version = "0.5", default-features = false } quickcheck = { version = "1.0", default-features = false } -rand = { version = "0.8.0", features = ["small_rng"] } -rand_distr = { version = "0.4.0" } +rand = { version = "0.9.0", features = ["small_rng"] } +rand_distr = { version = "0.5.0" } itertools = { version = "0.13.0", default-features = false, features = ["use_std"] } cblas-sys = { version = "0.1.4", default-features = false } diff --git a/crates/numeric-tests/tests/accuracy.rs b/crates/numeric-tests/tests/accuracy.rs index c594f020d..db10d57cd 100644 --- a/crates/numeric-tests/tests/accuracy.rs +++ b/crates/numeric-tests/tests/accuracy.rs @@ -86,7 +86,7 @@ where #[test] fn accurate_eye_f32() { - let rng = &mut SmallRng::from_entropy(); + let rng = &mut SmallRng::from_os_rng(); for i in 0..20 { let eye = Array::eye(i); for j in 0..20 { @@ -99,8 +99,8 @@ fn accurate_eye_f32() } // pick a few random sizes for _ in 0..10 { - let i = rng.gen_range(15..512); - let j = rng.gen_range(15..512); + let i = rng.random_range(15..512); + let j = rng.random_range(15..512); println!("Testing size {} by {}", i, j); let a = gen::(Ix2(i, j), rng); let eye = Array::eye(i); @@ -114,7 +114,7 @@ fn accurate_eye_f32() #[test] fn accurate_eye_f64() { - let rng = &mut SmallRng::from_entropy(); + let rng = &mut SmallRng::from_os_rng(); let abs_tol = 1e-15; for i in 0..20 { let eye = Array::eye(i); @@ -128,8 +128,8 @@ fn accurate_eye_f64() } // pick a few random sizes for _ in 0..10 { - let i = rng.gen_range(15..512); - let j = rng.gen_range(15..512); + let i = rng.random_range(15..512); + let j = rng.random_range(15..512); println!("Testing size {} by {}", i, j); let a = gen::(Ix2(i, j), rng); let eye = Array::eye(i); @@ -172,9 +172,9 @@ fn random_matrix_mul( ) -> (Array2, Array2) where A: LinalgScalar { - let m = rng.gen_range(15..128); - let k = rng.gen_range(15..128); - let n = rng.gen_range(15..512); + let m = rng.random_range(15..128); + let k = rng.random_range(15..128); + let n = rng.random_range(15..512); let a = generator(Ix2(m, k), rng); let b = generator(Ix2(n, k), rng); let c = if use_general { @@ -209,7 +209,7 @@ where A: fmt::Debug, { // pick a few random sizes - let mut rng = SmallRng::from_entropy(); + let mut rng = SmallRng::from_os_rng(); for i in 0..20 { let (c, reference) = random_matrix_mul(&mut rng, i > 10, use_general, gen::); @@ -241,7 +241,7 @@ where A: fmt::Debug, { // pick a few random sizes - let mut rng = SmallRng::from_entropy(); + let mut rng = SmallRng::from_os_rng(); for i in 0..20 { let (c, reference) = random_matrix_mul(&mut rng, i > 10, true, gen_complex::); @@ -259,10 +259,10 @@ where fn accurate_mul_with_column_f64() { // pick a few random sizes - let rng = &mut SmallRng::from_entropy(); + let rng = &mut SmallRng::from_os_rng(); for i in 0..10 { - let m = rng.gen_range(1..128); - let k = rng.gen_range(1..350); + let m = rng.random_range(1..128); + let k = rng.random_range(1..350); let a = gen::(Ix2(m, k), rng); let b_owner = gen::(Ix2(k, k), rng); let b_row_col; diff --git a/ndarray-rand/Cargo.toml b/ndarray-rand/Cargo.toml index b58e752a5..72b959020 100644 --- a/ndarray-rand/Cargo.toml +++ b/ndarray-rand/Cargo.toml @@ -21,7 +21,7 @@ rand_distr = { workspace = true } quickcheck = { workspace = true, optional = true } [dev-dependencies] -rand_isaac = "0.3.0" +rand_isaac = "0.4.0" quickcheck = { workspace = true } [package.metadata.release] diff --git a/ndarray-rand/benches/bench.rs b/ndarray-rand/benches/bench.rs index 0e5eb2ff7..364eca9f4 100644 --- a/ndarray-rand/benches/bench.rs +++ b/ndarray-rand/benches/bench.rs @@ -13,7 +13,7 @@ use test::Bencher; fn uniform_f32(b: &mut Bencher) { let m = 100; - b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.))); + b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.).unwrap())); } #[bench] diff --git a/ndarray-rand/src/lib.rs b/ndarray-rand/src/lib.rs index 6671ab334..795e246d4 100644 --- a/ndarray-rand/src/lib.rs +++ b/ndarray-rand/src/lib.rs @@ -29,10 +29,10 @@ //! that the items are not compatible (e.g. that a type doesn't implement a //! necessary trait). -use crate::rand::distributions::{Distribution, Uniform}; +use crate::rand::distr::{Distribution, Uniform}; use crate::rand::rngs::SmallRng; use crate::rand::seq::index; -use crate::rand::{thread_rng, Rng, SeedableRng}; +use crate::rand::{rng, Rng, SeedableRng}; use ndarray::{Array, Axis, RemoveAxis, ShapeBuilder}; use ndarray::{ArrayBase, Data, DataOwned, Dimension, RawData}; @@ -71,8 +71,8 @@ where /// Create an array with shape `dim` with elements drawn from /// `distribution` using the default RNG. /// - /// ***Panics*** if creation of the RNG fails or if the number of elements - /// overflows usize. + /// ***Panics*** if creation of the RNG fails, the number of elements + /// overflows usize, or the axis has zero length. /// /// ``` /// use ndarray::Array; @@ -80,7 +80,7 @@ where /// use ndarray_rand::rand_distr::Uniform; /// /// # fn main() { - /// let a = Array::random((2, 5), Uniform::new(0., 10.)); + /// let a = Array::random((2, 5), Uniform::new(0., 10.).unwrap()); /// println!("{:8.4}", a); /// // Example Output: /// // [[ 8.6900, 6.9824, 3.8922, 6.5861, 2.4890], @@ -95,7 +95,8 @@ where /// Create an array with shape `dim` with elements drawn from /// `distribution`, using a specific Rng `rng`. /// - /// ***Panics*** if the number of elements overflows usize. + /// ***Panics*** if the number of elements overflows usize + /// or the axis has zero length. /// /// ``` /// use ndarray::Array; @@ -110,7 +111,7 @@ where /// let mut rng = Isaac64Rng::seed_from_u64(seed); /// /// // Generate a random array using `rng` - /// let a = Array::random_using((2, 5), Uniform::new(0., 10.), &mut rng); + /// let a = Array::random_using((2, 5), Uniform::new(0., 10.).unwrap(), &mut rng); /// println!("{:8.4}", a); /// // Example Output: /// // [[ 8.6900, 6.9824, 3.8922, 6.5861, 2.4890], @@ -270,7 +271,7 @@ where { let indices: Vec<_> = match strategy { SamplingStrategy::WithReplacement => { - let distribution = Uniform::from(0..self.len_of(axis)); + let distribution = Uniform::new(0, self.len_of(axis)).unwrap(); (0..n_samples).map(|_| distribution.sample(rng)).collect() } SamplingStrategy::WithoutReplacement => index::sample(rng, self.len_of(axis), n_samples).into_vec(), @@ -308,5 +309,5 @@ impl Arbitrary for SamplingStrategy fn get_rng() -> SmallRng { - SmallRng::from_rng(thread_rng()).expect("create SmallRng from thread_rng failed") + SmallRng::from_rng(&mut rng()) } diff --git a/ndarray-rand/tests/tests.rs b/ndarray-rand/tests/tests.rs index d38e8636e..5d322551a 100644 --- a/ndarray-rand/tests/tests.rs +++ b/ndarray-rand/tests/tests.rs @@ -1,6 +1,6 @@ use ndarray::{Array, Array2, ArrayView1, Axis}; #[cfg(feature = "quickcheck")] -use ndarray_rand::rand::{distributions::Distribution, thread_rng}; +use ndarray_rand::rand::{distr::Distribution, rng}; use ndarray::ShapeBuilder; use ndarray_rand::rand_distr::Uniform; @@ -13,7 +13,7 @@ fn test_dim() let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { - let a = Array::random((m, n), Uniform::new(0., 2.)); + let a = Array::random((m, n), Uniform::new(0., 2.).unwrap()); assert_eq!(a.shape(), &[m, n]); assert!(a.iter().all(|x| *x < 2.)); assert!(a.iter().all(|x| *x >= 0.)); @@ -28,7 +28,7 @@ fn test_dim_f() let (mm, nn) = (5, 5); for m in 0..mm { for n in 0..nn { - let a = Array::random((m, n).f(), Uniform::new(0., 2.)); + let a = Array::random((m, n).f(), Uniform::new(0., 2.).unwrap()); assert_eq!(a.shape(), &[m, n]); assert!(a.iter().all(|x| *x < 2.)); assert!(a.iter().all(|x| *x >= 0.)); @@ -41,7 +41,7 @@ fn test_dim_f() fn sample_axis_on_view() { let m = 5; - let a = Array::random((m, 4), Uniform::new(0., 2.)); + let a = Array::random((m, 4), Uniform::new(0., 2.).unwrap()); let _samples = a .view() .sample_axis(Axis(0), m, SamplingStrategy::WithoutReplacement); @@ -52,7 +52,7 @@ fn sample_axis_on_view() fn oversampling_without_replacement_should_panic() { let m = 5; - let a = Array::random((m, 4), Uniform::new(0., 2.)); + let a = Array::random((m, 4), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), m + 1, SamplingStrategy::WithoutReplacement); } @@ -60,7 +60,7 @@ quickcheck! { #[cfg_attr(miri, ignore)] // Takes an insufferably long time fn oversampling_with_replacement_is_fine(m: u8, n: u8) -> TestResult { let (m, n) = (m as usize, n as usize); - let a = Array::random((m, n), Uniform::new(0., 2.)); + let a = Array::random((m, n), Uniform::new(0., 2.).unwrap()); // Higher than the length of both axes let n_samples = m + n + 1; @@ -90,12 +90,12 @@ quickcheck! { #[cfg_attr(miri, ignore)] // This takes *forever* with Miri fn sampling_behaves_as_expected(m: u8, n: u8, strategy: SamplingStrategy) -> TestResult { let (m, n) = (m as usize, n as usize); - let a = Array::random((m, n), Uniform::new(0., 2.)); - let mut rng = &mut thread_rng(); + let a = Array::random((m, n), Uniform::new(0., 2.).unwrap()); + let mut rng = &mut rng(); // We don't want to deal with sampling from 0-length axes in this test if m != 0 { - let n_row_samples = Uniform::from(1..m+1).sample(&mut rng); + let n_row_samples = Uniform::new(1, m+1).unwrap().sample(&mut rng); if !sampling_works(&a, strategy.clone(), Axis(0), n_row_samples) { return TestResult::failed(); } @@ -105,7 +105,7 @@ quickcheck! { // We don't want to deal with sampling from 0-length axes in this test if n != 0 { - let n_col_samples = Uniform::from(1..n+1).sample(&mut rng); + let n_col_samples = Uniform::new(1, n+1).unwrap().sample(&mut rng); if !sampling_works(&a, strategy, Axis(1), n_col_samples) { return TestResult::failed(); } @@ -136,7 +136,7 @@ fn is_subset(a: &Array2, b: &ArrayView1, axis: Axis) -> bool fn sampling_without_replacement_from_a_zero_length_axis_should_panic() { let n = 5; - let a = Array::random((0, n), Uniform::new(0., 2.)); + let a = Array::random((0, n), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithoutReplacement); } @@ -145,6 +145,6 @@ fn sampling_without_replacement_from_a_zero_length_axis_should_panic() fn sampling_with_replacement_from_a_zero_length_axis_should_panic() { let n = 5; - let a = Array::random((0, n), Uniform::new(0., 2.)); + let a = Array::random((0, n), Uniform::new(0., 2.).unwrap()); let _samples = a.sample_axis(Axis(0), 1, SamplingStrategy::WithReplacement); } From 7ebe4c81fa11398714eb217dca60889dbd7b9003 Mon Sep 17 00:00:00 2001 From: akern40 Date: Mon, 17 Mar 2025 20:44:44 -0400 Subject: [PATCH 38/48] Forgot the Cargo.lock for #1476 --- Cargo.lock | 46 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 472b171f8..d0530aff0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -484,7 +484,7 @@ version = "0.15.0" dependencies = [ "ndarray", "quickcheck", - "rand", + "rand 0.9.0", "rand_distr", "rand_isaac", ] @@ -547,7 +547,7 @@ dependencies = [ "num-complex", "num-traits", "openblas-src", - "rand", + "rand 0.9.0", "rand_distr", ] @@ -688,7 +688,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "588f6378e4dd99458b60ec275b4477add41ce4fa9f64dcba6f15adccb19b50d6" dependencies = [ - "rand", + "rand 0.8.5", ] [[package]] @@ -706,19 +706,28 @@ version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ - "libc", + "rand_core 0.6.4", +] + +[[package]] +name = "rand" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3779b94aeb87e8bd4e834cee3650289ee9e0d5677f976ecdb6d219e5f4f6cd94" +dependencies = [ "rand_chacha", - "rand_core", + "rand_core 0.9.3", + "zerocopy", ] [[package]] name = "rand_chacha" -version = "0.3.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", ] [[package]] @@ -730,23 +739,32 @@ dependencies = [ "getrandom 0.2.15", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.1", +] + [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand", + "rand 0.9.0", ] [[package]] name = "rand_isaac" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fac4373cd91b4f55722c553fb0f286edbb81ef3ff6eec7b99d1898a4110a0b28" +checksum = "3382fc9f0aad4f2e2a56b53d9133c8c810b4dbf21e7e370e24346161a5b2c7bd" dependencies = [ - "rand_core", + "rand_core 0.9.3", ] [[package]] From 1866e91fd1b91e68a4b384c31960974cd728985a Mon Sep 17 00:00:00 2001 From: HuiSeomKim <126950833+NewBornRustacean@users.noreply.github.com> Date: Tue, 18 Mar 2025 09:46:47 +0900 Subject: [PATCH 39/48] Add dot product support for ArrayD (#1483) Uses the implementations from `Array1` and `Array2`, tests found in `crates/blas-tests/tests/dyn.rs`. --- crates/blas-tests/tests/dyn.rs | 80 ++++++++++++++++++++++++++++++++++ src/linalg/impl_linalg.rs | 66 +++++++++++++++++++++++++++- 2 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 crates/blas-tests/tests/dyn.rs diff --git a/crates/blas-tests/tests/dyn.rs b/crates/blas-tests/tests/dyn.rs new file mode 100644 index 000000000..6c0fd975e --- /dev/null +++ b/crates/blas-tests/tests/dyn.rs @@ -0,0 +1,80 @@ +extern crate blas_src; +use ndarray::{linalg::Dot, Array1, Array2, ArrayD, Ix1, Ix2}; + +#[test] +fn test_arrayd_dot_2d() +{ + let mat1 = ArrayD::from_shape_vec(vec![3, 2], vec![3.0; 6]).unwrap(); + let mat2 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + + let result = mat1.dot(&mat2); + + // Verify the result is correct + assert_eq!(result.ndim(), 2); + assert_eq!(result.shape(), &[3, 3]); + + // Compare with Array2 implementation + let mat1_2d = Array2::from_shape_vec((3, 2), vec![3.0; 6]).unwrap(); + let mat2_2d = Array2::from_shape_vec((2, 3), vec![1.0; 6]).unwrap(); + let expected = mat1_2d.dot(&mat2_2d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} + +#[test] +fn test_arrayd_dot_1d() +{ + // Test 1D array dot product + let vec1 = ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).unwrap(); + let vec2 = ArrayD::from_shape_vec(vec![3], vec![4.0, 5.0, 6.0]).unwrap(); + + let result = vec1.dot(&vec2); + + // Verify scalar result + assert_eq!(result.ndim(), 0); + assert_eq!(result.shape(), &[]); + assert_eq!(result[[]], 32.0); // 1*4 + 2*5 + 3*6 +} + +#[test] +#[should_panic(expected = "Dot product for ArrayD is only supported for 1D and 2D arrays")] +fn test_arrayd_dot_3d() +{ + // Test that 3D arrays are not supported + let arr1 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![2, 2, 2], vec![1.0; 8]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +#[should_panic(expected = "ndarray: inputs 2 × 3 and 4 × 5 are not compatible for matrix multiplication")] +fn test_arrayd_dot_incompatible_dims() +{ + // Test arrays with incompatible dimensions + let arr1 = ArrayD::from_shape_vec(vec![2, 3], vec![1.0; 6]).unwrap(); + let arr2 = ArrayD::from_shape_vec(vec![4, 5], vec![1.0; 20]).unwrap(); + + let _result = arr1.dot(&arr2); // Should panic +} + +#[test] +fn test_arrayd_dot_matrix_vector() +{ + // Test matrix-vector multiplication + let mat = ArrayD::from_shape_vec(vec![3, 2], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec = ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).unwrap(); + + let result = mat.dot(&vec); + + // Verify result + assert_eq!(result.ndim(), 1); + assert_eq!(result.shape(), &[3]); + + // Compare with Array2 implementation + let mat_2d = Array2::from_shape_vec((3, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let vec_1d = Array1::from_vec(vec![1.0, 2.0]); + let expected = mat_2d.dot(&vec_1d); + + assert_eq!(result.into_dimensionality::().unwrap(), expected); +} diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 7472d8292..e05740378 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -14,8 +14,11 @@ use crate::numeric_util; use crate::{LinalgScalar, Zip}; +#[cfg(not(feature = "std"))] +use alloc::vec; #[cfg(not(feature = "std"))] use alloc::vec::Vec; + use std::any::TypeId; use std::mem::MaybeUninit; @@ -353,7 +356,7 @@ where /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// - /// **Panics** if broadcasting isn’t possible. + /// **Panics** if broadcasting isn't possible. #[track_caller] pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) where @@ -1067,3 +1070,64 @@ mod blas_tests } } } + +/// Dot product for dynamic-dimensional arrays (`ArrayD`). +/// +/// For one-dimensional arrays, computes the vector dot product, which is the sum +/// of the elementwise products (no conjugation of complex operands). +/// Both arrays must have the same length. +/// +/// For two-dimensional arrays, performs matrix multiplication. The array shapes +/// must be compatible in the following ways: +/// - If `self` is *M* × *N*, then `rhs` must be *N* × *K* for matrix-matrix multiplication +/// - If `self` is *M* × *N* and `rhs` is *N*, returns a vector of length *M* +/// - If `self` is *M* and `rhs` is *M* × *N*, returns a vector of length *N* +/// - If both arrays are one-dimensional of length *N*, returns a scalar +/// +/// **Panics** if: +/// - The arrays have dimensions other than 1 or 2 +/// - The array shapes are incompatible for the operation +/// - For vector dot product: the vectors have different lengths +/// +impl Dot> for ArrayBase +where + S: Data, + S2: Data, + A: LinalgScalar, +{ + type Output = Array; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + match (self.ndim(), rhs.ndim()) { + (1, 1) => { + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + ArrayD::from_elem(vec![], result) + } + (2, 2) => { + // Matrix-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (2, 1) => { + // Matrix-vector multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + (1, 2) => { + // Vector-matrix multiplication + let a = self.view().into_dimensionality::().unwrap(); + let b = rhs.view().into_dimensionality::().unwrap(); + let result = a.dot(&b); + result.into_dimensionality::().unwrap() + } + _ => panic!("Dot product for ArrayD is only supported for 1D and 2D arrays"), + } + } +} From 5a25737c88ac432bbab1bf60319ba0afdf372e20 Mon Sep 17 00:00:00 2001 From: akern40 Date: Mon, 17 Mar 2025 22:37:40 -0400 Subject: [PATCH 40/48] Adds an array reference type (#1440) See #1440 for more information, especially [this comment](https://github.com/rust-ndarray/ndarray/pull/1440#issuecomment-2412404938). --- .gitignore | 3 + crates/blas-tests/tests/oper.rs | 2 +- examples/axis_ops.rs | 2 +- examples/convo.rs | 2 +- examples/functions_and_traits.rs | 178 +++++++ src/alias_asref.rs | 359 ++++++++++++++ src/aliases.rs | 36 +- src/array_approx.rs | 95 +++- src/arrayformat.rs | 83 +++- src/arraytraits.rs | 151 +++++- src/data_traits.rs | 33 +- src/doc/ndarray_for_numpy_users/mod.rs | 97 ++-- src/doc/ndarray_for_numpy_users/rk_step.rs | 4 +- src/free_functions.rs | 26 +- src/impl_1d.rs | 8 +- src/impl_2d.rs | 83 +++- src/impl_clone.rs | 17 +- src/impl_cow.rs | 8 +- src/impl_dyn.rs | 52 +- src/impl_internal_constructors.rs | 18 +- src/impl_methods.rs | 548 +++++++++++++-------- src/impl_ops.rs | 177 +++++++ src/impl_owned_array.rs | 22 +- src/impl_raw_views.rs | 55 ++- src/impl_ref_types.rs | 370 ++++++++++++++ src/impl_special_element_types.rs | 5 +- src/impl_views/constructors.rs | 2 +- src/impl_views/conversions.rs | 20 +- src/impl_views/indexing.rs | 20 +- src/impl_views/splitting.rs | 2 +- src/iterators/chunks.rs | 20 +- src/iterators/into_iter.rs | 6 +- src/iterators/lanes.rs | 4 +- src/iterators/mod.rs | 30 +- src/iterators/windows.rs | 6 +- src/layout/mod.rs | 2 +- src/lib.rs | 294 +++++++++-- src/linalg/impl_linalg.rs | 205 ++++---- src/math_cell.rs | 2 +- src/numeric/impl_float_maths.rs | 6 +- src/numeric/impl_numeric.rs | 6 +- src/parallel/impl_par_methods.rs | 5 +- src/parallel/mod.rs | 4 +- src/prelude.rs | 17 +- src/shape_builder.rs | 2 +- src/slice.rs | 12 +- src/tri.rs | 24 +- src/zip/mod.rs | 13 +- src/zip/ndproducer.rs | 61 ++- tests/array.rs | 8 + 50 files changed, 2554 insertions(+), 651 deletions(-) create mode 100644 examples/functions_and_traits.rs create mode 100644 src/alias_asref.rs create mode 100644 src/impl_ref_types.rs diff --git a/.gitignore b/.gitignore index ef4ee42f5..c1885550c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,6 @@ target/ # Editor settings .vscode .idea + +# Apple details +**/.DS_Store diff --git a/crates/blas-tests/tests/oper.rs b/crates/blas-tests/tests/oper.rs index a9dca7e83..f604ae091 100644 --- a/crates/blas-tests/tests/oper.rs +++ b/crates/blas-tests/tests/oper.rs @@ -280,7 +280,7 @@ fn gen_mat_mul() cv = c.view_mut(); } - let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv; + let answer_part: Array = alpha * reference_mat_mul(&av, &bv) + beta * &cv; answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part); general_mat_mul(alpha, &av, &bv, beta, &mut cv); diff --git a/examples/axis_ops.rs b/examples/axis_ops.rs index 3a54a52fb..7f80a637f 100644 --- a/examples/axis_ops.rs +++ b/examples/axis_ops.rs @@ -13,7 +13,7 @@ use ndarray::prelude::*; /// it corresponds to their order in memory. /// /// Errors if array has a 0-stride axis -fn regularize(a: &mut Array) -> Result<(), &'static str> +fn regularize(a: &mut ArrayRef) -> Result<(), &'static str> where D: Dimension, A: ::std::fmt::Debug, diff --git a/examples/convo.rs b/examples/convo.rs index a59795e12..79e8ab6b6 100644 --- a/examples/convo.rs +++ b/examples/convo.rs @@ -14,7 +14,7 @@ type Kernel3x3 = [[A; 3]; 3]; #[inline(never)] #[cfg(feature = "std")] -fn conv_3x3(a: &ArrayView2<'_, F>, out: &mut ArrayViewMut2<'_, F>, kernel: &Kernel3x3) +fn conv_3x3(a: &ArrayRef2, out: &mut ArrayRef2, kernel: &Kernel3x3) where F: Float { let (n, m) = a.dim(); diff --git a/examples/functions_and_traits.rs b/examples/functions_and_traits.rs new file mode 100644 index 000000000..dc8f73da4 --- /dev/null +++ b/examples/functions_and_traits.rs @@ -0,0 +1,178 @@ +//! Examples of how to write functions and traits that operate on `ndarray` types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], the owner of the layout that describes an array in memory; +//! this includes [`ndarray::Array`], [`ndarray::ArcArray`], [`ndarray::ArrayView`], +//! [`ndarray::RawArrayView`], and other variants. +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array. +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array. +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind. +//! +//! Below, we illustrate how to write functions and traits for most variants of these types. + +use ndarray::{ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; + +/// Take an array with the most basic requirements. +/// +/// This function takes its data as owning. It is very rare that a user will need to specifically +/// take a reference to an `ArrayBase`, rather than to one of the other four types. +#[rustfmt::skip] +fn takes_base_raw(arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsRef` + takes_rawref(arr.as_ref()); // Caller uses `.as_ref` + takes_rawref_asref(&arr); // Implementor uses `.as_ref` + takes_layout(arr.as_ref()); // Caller uses `.as_ref` + takes_layout_asref(&arr); // Implementor uses `.as_ref` + + arr +} + +/// Similar to above, but allow us to read the underlying data. +#[rustfmt::skip] +fn takes_base_raw_mut(mut arr: ArrayBase) -> ArrayBase +{ + // These skip from a possibly-raw array to `RawRef` and `LayoutRef`, and so must go through `AsMut` + takes_rawref_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_rawref_asmut(&mut arr); // Implementor uses `.as_mut` + takes_layout_mut(arr.as_mut()); // Caller uses `.as_mut` + takes_layout_asmut(&mut arr); // Implementor uses `.as_mut` + + arr +} + +/// Now take an array whose data is safe to read. +#[allow(dead_code)] +fn takes_base(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw(arr); + + // No need for AsRef, since data is safe + takes_arrref(&arr); + takes_rawref(&arr); + takes_rawref_asref(&arr); + takes_layout(&arr); + takes_layout_asref(&arr); + + arr +} + +/// Now, an array whose data is safe to read and that we can mutate. +/// +/// Notice that we include now a trait bound on `D: Dimension`; this is necessary in order +/// for the `ArrayBase` to dereference to an `ArrayRef` (or to any of the other types). +#[allow(dead_code)] +fn takes_base_mut(mut arr: ArrayBase) -> ArrayBase +{ + // Raw call + arr = takes_base_raw_mut(arr); + + // No need for AsMut, since data is safe + takes_arrref_mut(&mut arr); + takes_rawref_mut(&mut arr); + takes_rawref_asmut(&mut arr); + takes_layout_mut(&mut arr); + takes_layout_asmut(&mut arr); + + arr +} + +/// Now for new stuff: we want to read (but not alter) any array whose data is safe to read. +/// +/// This is probably the most common functionality that one would want to write. +/// As we'll see below, calling this function is very simple for `ArrayBase`. +fn takes_arrref(arr: &ArrayRef) +{ + // No need for AsRef, since data is safe + takes_rawref(arr); + takes_rawref_asref(arr); + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Now we want any array whose data is safe to mutate. +/// +/// **Importantly**, any array passed to this function is guaranteed to uniquely point to its data. +/// As a result, passing a shared array to this function will **silently** un-share the array. +#[allow(dead_code)] +fn takes_arrref_mut(arr: &mut ArrayRef) +{ + // Immutable call + takes_arrref(arr); + + // No need for AsMut, since data is safe + takes_rawref_mut(arr); + takes_rawref_asmut(arr); + takes_layout_mut(arr); + takes_rawref_asmut(arr); +} + +/// Now, we no longer care about whether we can safely read data. +/// +/// This is probably the rarest type to deal with, since `LayoutRef` can access and modify an array's +/// shape and strides, and even do in-place slicing. As a result, `RawRef` is only for functionality +/// that requires unsafe data access, something that `LayoutRef` can't do. +/// +/// Writing functions and traits that deal with `RawRef`s and `LayoutRef`s can be done two ways: +/// 1. Directly on the types; calling these functions on arrays whose data are not known to be safe +/// to dereference (i.e., raw array views or `ArrayBase`) must explicitly call `.as_ref()`. +/// 2. Via a generic with `: AsRef>`; doing this will allow direct calling for all `ArrayBase` and +/// `ArrayRef` instances. +/// We'll demonstrate #1 here for both immutable and mutable references, then #2 directly below. +#[allow(dead_code)] +fn takes_rawref(arr: &RawRef) +{ + takes_layout(arr); + takes_layout_asref(arr); +} + +/// Mutable, directly take `RawRef` +#[allow(dead_code)] +fn takes_rawref_mut(arr: &mut RawRef) +{ + takes_layout(arr); + takes_layout_asmut(arr); +} + +/// Immutable, take a generic that implements `AsRef` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asref(_arr: &T) +where T: AsRef> +{ + takes_layout(_arr.as_ref()); + takes_layout_asref(_arr.as_ref()); +} + +/// Mutable, take a generic that implements `AsMut` to `RawRef` +#[allow(dead_code)] +fn takes_rawref_asmut(_arr: &mut T) +where T: AsMut> +{ + takes_layout_mut(_arr.as_mut()); + takes_layout_asmut(_arr.as_mut()); +} + +/// Finally, there's `LayoutRef`: this type provides read and write access to an array's *structure*, but not its *data*. +/// +/// Practically, this means that functions that only read/modify an array's shape or strides, +/// such as checking dimensionality or slicing, should take `LayoutRef`. +/// +/// Like `RawRef`, functions can be written either directly on `LayoutRef` or as generics with `: AsRef>>`. +#[allow(dead_code)] +fn takes_layout(_arr: &LayoutRef) {} + +/// Mutable, directly take `LayoutRef` +#[allow(dead_code)] +fn takes_layout_mut(_arr: &mut LayoutRef) {} + +/// Immutable, take a generic that implements `AsRef` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asref>, A, D>(_arr: &T) {} + +/// Mutable, take a generic that implements `AsMut` to `LayoutRef` +#[allow(dead_code)] +fn takes_layout_asmut>, A, D>(_arr: &mut T) {} + +fn main() {} diff --git a/src/alias_asref.rs b/src/alias_asref.rs new file mode 100644 index 000000000..ab78af605 --- /dev/null +++ b/src/alias_asref.rs @@ -0,0 +1,359 @@ +use crate::{ + iter::Axes, + ArrayBase, + Axis, + AxisDescription, + Dimension, + NdIndex, + RawArrayView, + RawData, + RawDataMut, + Slice, + SliceArg, +}; + +/// Functions coming from RawRef +impl, D: Dimension> ArrayBase +{ + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view(); + /// let p = a.get_ptr((0, 1)).unwrap(); + /// + /// assert_eq!(unsafe { *p }, 2.); + /// ``` + pub fn get_ptr(&self, index: I) -> Option<*const A> + where I: NdIndex + { + self.as_raw_ref().get_ptr(index) + } + + /// Return a raw pointer to the element at `index`, or return `None` + /// if the index is out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let mut a = arr2(&[[1., 2.], [3., 4.]]); + /// + /// let v = a.raw_view_mut(); + /// let p = a.get_mut_ptr((0, 1)).unwrap(); + /// + /// unsafe { + /// *p = 5.; + /// } + /// + /// assert_eq!(a.get((0, 1)), Some(&5.)); + /// ``` + pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> + where + S: RawDataMut, + I: NdIndex, + { + self.as_raw_ref_mut().get_mut_ptr(index) + } + + /// Return a pointer to the first element in the array. + /// + /// Raw access to array elements needs to follow the strided indexing + /// scheme: an element at multi-index *I* in an array with strides *S* is + /// located at offset + /// + /// *Σ0 ≤ k < d Ik × Sk* + /// + /// where *d* is `self.ndim()`. + #[inline(always)] + pub fn as_ptr(&self) -> *const A + { + self.as_raw_ref().as_ptr() + } + + /// Return a raw view of the array. + #[inline] + pub fn raw_view(&self) -> RawArrayView + { + self.as_raw_ref().raw_view() + } +} + +/// Functions coming from LayoutRef +impl ArrayBase +{ + /// Slice the array in place without changing the number of dimensions. + /// + /// In particular, if an axis is sliced with an index, the axis is + /// collapsed, as in [`.collapse_axis()`], rather than removed, as in + /// [`.slice_move()`] or [`.index_axis_move()`]. + /// + /// [`.collapse_axis()`]: Self::collapse_axis + /// [`.slice_move()`]: Self::slice_move + /// [`.index_axis_move()`]: Self::index_axis_move + /// + /// See [*Slicing*](#slicing) for full documentation. + /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). + /// + /// **Panics** in the following cases: + /// + /// - if an index is out of bounds + /// - if a step size is zero + /// - if [`NewAxis`](`crate::SliceInfoElem::NewAxis`) is in `info`, e.g. if `NewAxis` was + /// used in the [`s!`] macro + /// - if `D` is `IxDyn` and `info` does not match the number of array axes + #[track_caller] + pub fn slice_collapse(&mut self, info: I) + where I: SliceArg + { + self.as_layout_ref_mut().slice_collapse(info); + } + + /// Slice the array in place along the specified axis. + /// + /// **Panics** if an index is out of bounds or step size is zero.
+ /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn slice_axis_inplace(&mut self, axis: Axis, indices: Slice) + { + self.as_layout_ref_mut().slice_axis_inplace(axis, indices); + } + + /// Slice the array in place, with a closure specifying the slice for each + /// axis. + /// + /// This is especially useful for code which is generic over the + /// dimensionality of the array. + /// + /// **Panics** if an index is out of bounds or step size is zero. + #[track_caller] + pub fn slice_each_axis_inplace(&mut self, f: F) + where F: FnMut(AxisDescription) -> Slice + { + self.as_layout_ref_mut().slice_each_axis_inplace(f); + } + + /// Selects `index` along the axis, collapsing the axis into length one. + /// + /// **Panics** if `axis` or `index` is out of bounds. + #[track_caller] + pub fn collapse_axis(&mut self, axis: Axis, index: usize) + { + self.as_layout_ref_mut().collapse_axis(axis, index); + } + + /// Return `true` if the array data is laid out in contiguous “C order” in + /// memory (where the last index is the most rapidly varying). + /// + /// Return `false` otherwise, i.e. the array is possibly not + /// contiguous in memory, it has custom strides, etc. + pub fn is_standard_layout(&self) -> bool + { + self.as_layout_ref().is_standard_layout() + } + + /// Return true if the array is known to be contiguous. + pub(crate) fn is_contiguous(&self) -> bool + { + self.as_layout_ref().is_contiguous() + } + + /// Return an iterator over the length and stride of each axis. + pub fn axes(&self) -> Axes<'_, D> + { + self.as_layout_ref().axes() + } + + /* + /// Return the axis with the least stride (by absolute value) + pub fn min_stride_axis(&self) -> Axis { + self.dim.min_stride_axis(&self.strides) + } + */ + + /// Return the axis with the greatest stride (by absolute value), + /// preferring axes with len > 1. + pub fn max_stride_axis(&self) -> Axis + { + self.as_layout_ref().max_stride_axis() + } + + /// Reverse the stride of `axis`. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn invert_axis(&mut self, axis: Axis) + { + self.as_layout_ref_mut().invert_axis(axis); + } + + /// Swap axes `ax` and `bx`. + /// + /// This does not move any data, it just adjusts the array’s dimensions + /// and strides. + /// + /// **Panics** if the axes are out of bounds. + /// + /// ``` + /// use ndarray::arr2; + /// + /// let mut a = arr2(&[[1., 2., 3.]]); + /// a.swap_axes(0, 1); + /// assert!( + /// a == arr2(&[[1.], [2.], [3.]]) + /// ); + /// ``` + #[track_caller] + pub fn swap_axes(&mut self, ax: usize, bx: usize) + { + self.as_layout_ref_mut().swap_axes(ax, bx); + } + + /// If possible, merge in the axis `take` to `into`. + /// + /// Returns `true` iff the axes are now merged. + /// + /// This method merges the axes if movement along the two original axes + /// (moving fastest along the `into` axis) can be equivalently represented + /// as movement along one (merged) axis. Merging the axes preserves this + /// order in the merged axis. If `take` and `into` are the same axis, then + /// the axis is "merged" if its length is ≤ 1. + /// + /// If the return value is `true`, then the following hold: + /// + /// * The new length of the `into` axis is the product of the original + /// lengths of the two axes. + /// + /// * The new length of the `take` axis is 0 if the product of the original + /// lengths of the two axes is 0, and 1 otherwise. + /// + /// If the return value is `false`, then merging is not possible, and the + /// original shape and strides have been preserved. + /// + /// Note that the ordering constraint means that if it's possible to merge + /// `take` into `into`, it's usually not possible to merge `into` into + /// `take`, and vice versa. + /// + /// ``` + /// use ndarray::Array3; + /// use ndarray::Axis; + /// + /// let mut a = Array3::::zeros((2, 3, 4)); + /// assert!(a.merge_axes(Axis(1), Axis(2))); + /// assert_eq!(a.shape(), &[2, 1, 12]); + /// ``` + /// + /// ***Panics*** if an axis is out of bounds. + #[track_caller] + pub fn merge_axes(&mut self, take: Axis, into: Axis) -> bool + { + self.as_layout_ref_mut().merge_axes(take, into) + } + + /// Return the total number of elements in the array. + pub fn len(&self) -> usize + { + self.as_layout_ref().len() + } + + /// Return the length of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn len_of(&self, axis: Axis) -> usize + { + self.as_layout_ref().len_of(axis) + } + + /// Return whether the array has any elements + pub fn is_empty(&self) -> bool + { + self.as_layout_ref().is_empty() + } + + /// Return the number of dimensions (axes) in the array + pub fn ndim(&self) -> usize + { + self.as_layout_ref().ndim() + } + + /// Return the shape of the array in its “pattern” form, + /// an integer in the one-dimensional case, tuple in the n-dimensional cases + /// and so on. + pub fn dim(&self) -> D::Pattern + { + self.as_layout_ref().dim() + } + + /// Return the shape of the array as it's stored in the array. + /// + /// This is primarily useful for passing to other `ArrayBase` + /// functions, such as when creating another array of the same + /// shape and dimensionality. + /// + /// ``` + /// use ndarray::Array; + /// + /// let a = Array::from_elem((2, 3), 5.); + /// + /// // Create an array of zeros that's the same shape and dimensionality as `a`. + /// let b = Array::::zeros(a.raw_dim()); + /// ``` + pub fn raw_dim(&self) -> D + { + self.as_layout_ref().raw_dim() + } + + /// Return the shape of the array as a slice. + /// + /// Note that you probably don't want to use this to create an array of the + /// same shape as another array because creating an array with e.g. + /// [`Array::zeros()`](ArrayBase::zeros) using a shape of type `&[usize]` + /// results in a dynamic-dimensional array. If you want to create an array + /// that has the same shape and dimensionality as another array, use + /// [`.raw_dim()`](ArrayBase::raw_dim) instead: + /// + /// ```rust + /// use ndarray::{Array, Array2}; + /// + /// let a = Array2::::zeros((3, 4)); + /// let shape = a.shape(); + /// assert_eq!(shape, &[3, 4]); + /// + /// // Since `a.shape()` returned `&[usize]`, we get an `ArrayD` instance: + /// let b = Array::zeros(shape); + /// assert_eq!(a.clone().into_dyn(), b); + /// + /// // To get the same dimension type, use `.raw_dim()` instead: + /// let c = Array::zeros(a.raw_dim()); + /// assert_eq!(a, c); + /// ``` + pub fn shape(&self) -> &[usize] + { + self.as_layout_ref().shape() + } + + /// Return the strides of the array as a slice. + pub fn strides(&self) -> &[isize] + { + self.as_layout_ref().strides() + } + + /// Return the stride of `axis`. + /// + /// The axis should be in the range `Axis(` 0 .. *n* `)` where *n* is the + /// number of dimensions (axes) of the array. + /// + /// ***Panics*** if the axis is out of bounds. + #[track_caller] + pub fn stride_of(&self, axis: Axis) -> isize + { + self.as_layout_ref().stride_of(axis) + } +} diff --git a/src/aliases.rs b/src/aliases.rs index 5df0c95ec..7f897304b 100644 --- a/src/aliases.rs +++ b/src/aliases.rs @@ -2,7 +2,7 @@ //! use crate::dimension::Dim; -use crate::{ArcArray, Array, ArrayView, ArrayViewMut, Ix, IxDynImpl}; +use crate::{ArcArray, Array, ArrayRef, ArrayView, ArrayViewMut, Ix, IxDynImpl, LayoutRef}; /// Create a zero-dimensional index #[allow(non_snake_case)] @@ -123,6 +123,40 @@ pub type Array6
= Array; /// dynamic-dimensional array pub type ArrayD = Array; +/// zero-dimensional array reference +pub type ArrayRef0 = ArrayRef; +/// one-dimensional array reference +pub type ArrayRef1 = ArrayRef; +/// two-dimensional array reference +pub type ArrayRef2 = ArrayRef; +/// three-dimensional array reference +pub type ArrayRef3 = ArrayRef; +/// four-dimensional array reference +pub type ArrayRef4 = ArrayRef; +/// five-dimensional array reference +pub type ArrayRef5 = ArrayRef; +/// six-dimensional array reference +pub type ArrayRef6 = ArrayRef; +/// dynamic-dimensional array reference +pub type ArrayRefD = ArrayRef; + +/// zero-dimensional layout reference +pub type LayoutRef0 = LayoutRef; +/// one-dimensional layout reference +pub type LayoutRef1 = LayoutRef; +/// two-dimensional layout reference +pub type LayoutRef2 = LayoutRef; +/// three-dimensional layout reference +pub type LayoutRef3 = LayoutRef; +/// four-dimensional layout reference +pub type LayoutRef4 = LayoutRef; +/// five-dimensional layout reference +pub type LayoutRef5 = LayoutRef; +/// six-dimensional layout reference +pub type LayoutRef6 = LayoutRef; +/// dynamic-dimensional layout reference +pub type LayoutRefD = LayoutRef; + /// zero-dimensional array view pub type ArrayView0<'a, A> = ArrayView<'a, A, Ix0>; /// one-dimensional array view diff --git a/src/array_approx.rs b/src/array_approx.rs index c6fd174d1..958f6f6ba 100644 --- a/src/array_approx.rs +++ b/src/array_approx.rs @@ -4,29 +4,24 @@ mod approx_methods { use crate::imp_prelude::*; - impl ArrayBase - where - S: Data, - D: Dimension, + impl ArrayRef { /// A test for equality that uses the elementwise absolute difference to compute the /// approximate equality of two arrays. - pub fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool + pub fn abs_diff_eq(&self, other: &ArrayRef, epsilon: A::Epsilon) -> bool where - A: ::approx::AbsDiffEq, + A: ::approx::AbsDiffEq, A::Epsilon: Clone, - S2: Data, { >::abs_diff_eq(self, other, epsilon) } /// A test for equality that uses an elementwise relative comparison if the values are far /// apart; and the absolute difference otherwise. - pub fn relative_eq(&self, other: &ArrayBase, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool + pub fn relative_eq(&self, other: &ArrayRef, epsilon: A::Epsilon, max_relative: A::Epsilon) -> bool where - A: ::approx::RelativeEq, + A: ::approx::RelativeEq, A::Epsilon: Clone, - S2: Data, { >::relative_eq(self, other, epsilon, max_relative) } @@ -41,12 +36,10 @@ macro_rules! impl_approx_traits { use $approx::{AbsDiffEq, RelativeEq, UlpsEq}; #[doc = $doc] - impl AbsDiffEq> for ArrayBase + impl AbsDiffEq> for ArrayRef where A: AbsDiffEq, A::Epsilon: Clone, - S: Data, - S2: Data, D: Dimension, { type Epsilon = A::Epsilon; @@ -55,7 +48,7 @@ macro_rules! impl_approx_traits { A::default_epsilon() } - fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + fn abs_diff_eq(&self, other: &ArrayRef, epsilon: A::Epsilon) -> bool { if self.shape() != other.shape() { return false; } @@ -67,13 +60,31 @@ macro_rules! impl_approx_traits { } #[doc = $doc] - impl RelativeEq> for ArrayBase + impl AbsDiffEq> for ArrayBase where - A: RelativeEq, + A: AbsDiffEq, A::Epsilon: Clone, S: Data, S2: Data, D: Dimension, + { + type Epsilon = A::Epsilon; + + fn default_epsilon() -> A::Epsilon { + A::default_epsilon() + } + + fn abs_diff_eq(&self, other: &ArrayBase, epsilon: A::Epsilon) -> bool { + (**self).abs_diff_eq(other, epsilon) + } + } + + #[doc = $doc] + impl RelativeEq> for ArrayRef + where + A: RelativeEq, + A::Epsilon: Clone, + D: Dimension, { fn default_max_relative() -> A::Epsilon { A::default_max_relative() @@ -81,7 +92,7 @@ macro_rules! impl_approx_traits { fn relative_eq( &self, - other: &ArrayBase, + other: &ArrayRef, epsilon: A::Epsilon, max_relative: A::Epsilon, ) -> bool { @@ -96,13 +107,34 @@ macro_rules! impl_approx_traits { } #[doc = $doc] - impl UlpsEq> for ArrayBase + impl RelativeEq> for ArrayBase where - A: UlpsEq, + A: RelativeEq, A::Epsilon: Clone, S: Data, S2: Data, D: Dimension, + { + fn default_max_relative() -> A::Epsilon { + A::default_max_relative() + } + + fn relative_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_relative: A::Epsilon, + ) -> bool { + (**self).relative_eq(other, epsilon, max_relative) + } + } + + #[doc = $doc] + impl UlpsEq> for ArrayRef + where + A: UlpsEq, + A::Epsilon: Clone, + D: Dimension, { fn default_max_ulps() -> u32 { A::default_max_ulps() @@ -110,7 +142,7 @@ macro_rules! impl_approx_traits { fn ulps_eq( &self, - other: &ArrayBase, + other: &ArrayRef, epsilon: A::Epsilon, max_ulps: u32, ) -> bool { @@ -124,6 +156,29 @@ macro_rules! impl_approx_traits { } } + #[doc = $doc] + impl UlpsEq> for ArrayBase + where + A: UlpsEq, + A::Epsilon: Clone, + S: Data, + S2: Data, + D: Dimension, + { + fn default_max_ulps() -> u32 { + A::default_max_ulps() + } + + fn ulps_eq( + &self, + other: &ArrayBase, + epsilon: A::Epsilon, + max_ulps: u32, + ) -> bool { + (**self).ulps_eq(other, epsilon, max_ulps) + } + } + #[cfg(test)] mod tests { use crate::prelude::*; diff --git a/src/arrayformat.rs b/src/arrayformat.rs index 1a3b714c3..7e5e1b1c9 100644 --- a/src/arrayformat.rs +++ b/src/arrayformat.rs @@ -6,7 +6,10 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. use super::{ArrayBase, ArrayView, Axis, Data, Dimension, NdProducer}; -use crate::aliases::{Ix1, IxDyn}; +use crate::{ + aliases::{Ix1, IxDyn}, + ArrayRef, +}; use alloc::format; use std::fmt; @@ -112,13 +115,12 @@ fn format_with_overflow( Ok(()) } -fn format_array( - array: &ArrayBase, f: &mut fmt::Formatter<'_>, format: F, fmt_opt: &FormatOptions, +fn format_array( + array: &ArrayRef, f: &mut fmt::Formatter<'_>, format: F, fmt_opt: &FormatOptions, ) -> fmt::Result where F: FnMut(&A, &mut fmt::Formatter<'_>) -> fmt::Result + Clone, D: Dimension, - S: Data, { // Cast into a dynamically dimensioned view // This is required to be able to use `index_axis` for the recursive case @@ -174,6 +176,18 @@ where /// The array is shown in multiline style. impl fmt::Display for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `Display` and apply the formatting parameters +/// used to each element. +/// +/// The array is shown in multiline style. +impl fmt::Display for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -188,6 +202,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::Debug for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `Debug` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::Debug for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -216,6 +242,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::LowerExp for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array reference using `LowerExp` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::LowerExp for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -230,6 +268,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::UpperExp for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `UpperExp` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::UpperExp for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -237,12 +287,25 @@ where S: Data format_array(self, f, <_>::fmt, &fmt_opt) } } + /// Format the array using `LowerHex` and apply the formatting parameters used /// to each element. /// /// The array is shown in multiline style. impl fmt::LowerHex for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `LowerHex` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::LowerHex for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -257,6 +320,18 @@ where S: Data /// The array is shown in multiline style. impl fmt::Binary for ArrayBase where S: Data +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result + { + (**self).fmt(f) + } +} + +/// Format the array using `Binary` and apply the formatting parameters used +/// to each element. +/// +/// The array is shown in multiline style. +impl fmt::Binary for ArrayRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/src/arraytraits.rs b/src/arraytraits.rs index 62f95df4a..5068cd6c2 100644 --- a/src/arraytraits.rs +++ b/src/arraytraits.rs @@ -19,6 +19,7 @@ use std::{iter::FromIterator, slice}; use crate::imp_prelude::*; use crate::Arc; +use crate::LayoutRef; use crate::{ dimension, iter::{Iter, IterMut}, @@ -37,11 +38,10 @@ pub(crate) fn array_out_of_bounds() -> ! } #[inline(always)] -pub fn debug_bounds_check(_a: &ArrayBase, _index: &I) +pub fn debug_bounds_check(_a: &LayoutRef, _index: &I) where D: Dimension, I: NdIndex, - S: Data, { debug_bounds_check!(_a, *_index); } @@ -49,15 +49,15 @@ where /// Access the element at **index**. /// /// **Panics** if index is out of bounds. -impl Index for ArrayBase +impl Index for ArrayRef where D: Dimension, I: NdIndex, - S: Data, { - type Output = S::Elem; + type Output = A; + #[inline] - fn index(&self, index: I) -> &S::Elem + fn index(&self, index: I) -> &Self::Output { debug_bounds_check!(self, index); unsafe { @@ -73,14 +73,13 @@ where /// Access the element at **index** mutably. /// /// **Panics** if index is out of bounds. -impl IndexMut for ArrayBase +impl IndexMut for ArrayRef where D: Dimension, I: NdIndex, - S: DataMut, { #[inline] - fn index_mut(&mut self, index: I) -> &mut S::Elem + fn index_mut(&mut self, index: I) -> &mut A { debug_bounds_check!(self, index); unsafe { @@ -93,16 +92,48 @@ where } } +/// Access the element at **index**. +/// +/// **Panics** if index is out of bounds. +impl Index for ArrayBase +where + D: Dimension, + I: NdIndex, + S: Data, +{ + type Output = S::Elem; + + #[inline] + fn index(&self, index: I) -> &S::Elem + { + Index::index(&**self, index) + } +} + +/// Access the element at **index** mutably. +/// +/// **Panics** if index is out of bounds. +impl IndexMut for ArrayBase +where + D: Dimension, + I: NdIndex, + S: DataMut, +{ + #[inline] + fn index_mut(&mut self, index: I) -> &mut S::Elem + { + IndexMut::index_mut(&mut (**self), index) + } +} + /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. -impl PartialEq> for ArrayBase +impl PartialEq> for ArrayRef where A: PartialEq, - S: Data, - S2: Data, D: Dimension, { - fn eq(&self, rhs: &ArrayBase) -> bool + fn eq(&self, rhs: &ArrayRef) -> bool { if self.shape() != rhs.shape() { return false; @@ -125,6 +156,54 @@ where } } +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +impl PartialEq<&ArrayRef> for ArrayRef +where + A: PartialEq, + D: Dimension, +{ + fn eq(&self, rhs: &&ArrayRef) -> bool + { + *self == **rhs + } +} + +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +impl PartialEq> for &ArrayRef +where + A: PartialEq, + D: Dimension, +{ + fn eq(&self, rhs: &ArrayRef) -> bool + { + **self == *rhs + } +} + +impl Eq for ArrayRef +where + D: Dimension, + A: Eq, +{ +} + +/// Return `true` if the array shapes and all elements of `self` and +/// `rhs` are equal. Return `false` otherwise. +impl PartialEq> for ArrayBase +where + A: PartialEq, + S: Data, + S2: Data, + D: Dimension, +{ + fn eq(&self, rhs: &ArrayBase) -> bool + { + PartialEq::eq(&**self, &**rhs) + } +} + /// Return `true` if the array shapes and all elements of `self` and /// `rhs` are equal. Return `false` otherwise. #[allow(clippy::unconditional_recursion)] // false positive @@ -216,6 +295,32 @@ where S: DataOwned } } +impl<'a, A, D> IntoIterator for &'a ArrayRef +where D: Dimension +{ + type Item = &'a A; + + type IntoIter = Iter<'a, A, D>; + + fn into_iter(self) -> Self::IntoIter + { + self.iter() + } +} + +impl<'a, A, D> IntoIterator for &'a mut ArrayRef +where D: Dimension +{ + type Item = &'a mut A; + + type IntoIter = IterMut<'a, A, D>; + + fn into_iter(self) -> Self::IntoIter + { + self.iter_mut() + } +} + impl<'a, S, D> IntoIterator for &'a ArrayBase where D: Dimension, @@ -268,11 +373,10 @@ where D: Dimension } } -impl hash::Hash for ArrayBase +impl hash::Hash for ArrayRef where D: Dimension, - S: Data, - S::Elem: hash::Hash, + A: hash::Hash, { // Note: elements are hashed in the logical order fn hash(&self, state: &mut H) @@ -294,6 +398,19 @@ where } } +impl hash::Hash for ArrayBase +where + D: Dimension, + S: Data, + S::Elem: hash::Hash, +{ + // Note: elements are hashed in the logical order + fn hash(&self, state: &mut H) + { + (**self).hash(state) + } +} + // NOTE: ArrayBase keeps an internal raw pointer that always // points into the storage. This is Sync & Send as long as we // follow the usual inherited mutability rules, as we do with @@ -464,7 +581,7 @@ where D: Dimension { let data = OwnedArcRepr(Arc::new(arr.data)); // safe because: equivalent unmoved data, ptr and dims remain valid - unsafe { ArrayBase::from_data_ptr(data, arr.ptr).with_strides_dim(arr.strides, arr.dim) } + unsafe { ArrayBase::from_data_ptr(data, arr.layout.ptr).with_strides_dim(arr.layout.strides, arr.layout.dim) } } } diff --git a/src/data_traits.rs b/src/data_traits.rs index fc2fe4bfa..4266e4017 100644 --- a/src/data_traits.rs +++ b/src/data_traits.rs @@ -23,7 +23,7 @@ use std::mem::MaybeUninit; use std::mem::{self, size_of}; use std::ptr::NonNull; -use crate::{ArcArray, Array, ArrayBase, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; +use crate::{ArcArray, Array, ArrayBase, ArrayRef, CowRepr, Dimension, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr}; /// Array representation trait. /// @@ -251,7 +251,7 @@ where A: Clone if Arc::get_mut(&mut self_.data.0).is_some() { return; } - if self_.dim.size() <= self_.data.0.len() / 2 { + if self_.layout.dim.size() <= self_.data.0.len() / 2 { // Clone only the visible elements if the current view is less than // half of backing data. *self_ = self_.to_owned().into_shared(); @@ -260,13 +260,13 @@ where A: Clone let rcvec = &mut self_.data.0; let a_size = mem::size_of::() as isize; let our_off = if a_size != 0 { - (self_.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size + (self_.layout.ptr.as_ptr() as isize - rcvec.as_ptr() as isize) / a_size } else { 0 }; let rvec = Arc::make_mut(rcvec); unsafe { - self_.ptr = rvec.as_nonnull_mut().offset(our_off); + self_.layout.ptr = rvec.as_nonnull_mut().offset(our_off); } } @@ -286,7 +286,9 @@ unsafe impl Data for OwnedArcRepr Self::ensure_unique(&mut self_); let data = Arc::try_unwrap(self_.data.0).ok().unwrap(); // safe because data is equivalent - unsafe { ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) } + unsafe { + ArrayBase::from_data_ptr(data, self_.layout.ptr).with_strides_dim(self_.layout.strides, self_.layout.dim) + } } fn try_into_owned_nocopy(self_: ArrayBase) -> Result, ArrayBase> @@ -295,13 +297,14 @@ unsafe impl Data for OwnedArcRepr match Arc::try_unwrap(self_.data.0) { Ok(owned_data) => unsafe { // Safe because the data is equivalent. - Ok(ArrayBase::from_data_ptr(owned_data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + Ok(ArrayBase::from_data_ptr(owned_data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, Err(arc_data) => unsafe { // Safe because the data is equivalent; we're just // reconstructing `self_`. - Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.ptr) - .with_strides_dim(self_.strides, self_.dim)) + Err(ArrayBase::from_data_ptr(OwnedArcRepr(arc_data), self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, } } @@ -598,11 +601,11 @@ where A: Clone { match array.data { CowRepr::View(_) => { - let owned = array.to_owned(); + let owned = ArrayRef::to_owned(array); array.data = CowRepr::Owned(owned.data); - array.ptr = owned.ptr; - array.dim = owned.dim; - array.strides = owned.strides; + array.layout.ptr = owned.layout.ptr; + array.layout.dim = owned.layout.dim; + array.layout.strides = owned.layout.strides; } CowRepr::Owned(_) => {} } @@ -663,7 +666,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> CowRepr::View(_) => self_.to_owned(), CowRepr::Owned(data) => unsafe { // safe because the data is equivalent so ptr, dims remain valid - ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim) + ArrayBase::from_data_ptr(data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim) }, } } @@ -675,7 +679,8 @@ unsafe impl<'a, A> Data for CowRepr<'a, A> CowRepr::View(_) => Err(self_), CowRepr::Owned(data) => unsafe { // safe because the data is equivalent so ptr, dims remain valid - Ok(ArrayBase::from_data_ptr(data, self_.ptr).with_strides_dim(self_.strides, self_.dim)) + Ok(ArrayBase::from_data_ptr(data, self_.layout.ptr) + .with_strides_dim(self_.layout.strides, self_.layout.dim)) }, } } diff --git a/src/doc/ndarray_for_numpy_users/mod.rs b/src/doc/ndarray_for_numpy_users/mod.rs index eba96cdd0..bb6b7ae83 100644 --- a/src/doc/ndarray_for_numpy_users/mod.rs +++ b/src/doc/ndarray_for_numpy_users/mod.rs @@ -322,7 +322,7 @@ //! //! //! -//! [`mat1.dot(&mat2)`][matrix-* dot] +//! [`mat1.dot(&mat2)`][dot-2-2] //! //! //! @@ -336,7 +336,7 @@ //! //! //! -//! [`mat.dot(&vec)`][matrix-* dot] +//! [`mat.dot(&vec)`][dot-2-1] //! //! //! @@ -350,7 +350,7 @@ //! //! //! -//! [`vec.dot(&mat)`][vec-* dot] +//! [`vec.dot(&mat)`][dot-1-2] //! //! //! @@ -364,7 +364,7 @@ //! //! //! -//! [`vec1.dot(&vec2)`][vec-* dot] +//! [`vec1.dot(&vec2)`][dot-1-1] //! //! //! @@ -670,22 +670,22 @@ //! `a[:,4]` | [`a.column(4)`][.column()] or [`a.column_mut(4)`][.column_mut()] | view (or mutable view) of column 4 in a 2-D array //! `a.shape[0] == a.shape[1]` | [`a.is_square()`][.is_square()] | check if the array is square //! -//! [.abs_diff_eq()]: ArrayBase#impl-AbsDiffEq> -//! [.assign()]: ArrayBase::assign -//! [.axis_iter()]: ArrayBase::axis_iter -//! [.ncols()]: ArrayBase::ncols -//! [.column()]: ArrayBase::column -//! [.column_mut()]: ArrayBase::column_mut +//! [.abs_diff_eq()]: ArrayRef#impl-AbsDiffEq%3CArrayRef%3CB,+D%3E%3E +//! [.assign()]: ArrayRef::assign +//! [.axis_iter()]: ArrayRef::axis_iter +//! [.ncols()]: LayoutRef::ncols +//! [.column()]: ArrayRef::column +//! [.column_mut()]: ArrayRef::column_mut //! [concatenate()]: crate::concatenate() //! [concatenate!]: crate::concatenate! //! [stack!]: crate::stack! //! [::default()]: ArrayBase::default -//! [.diag()]: ArrayBase::diag -//! [.dim()]: ArrayBase::dim +//! [.diag()]: ArrayRef::diag +//! [.dim()]: LayoutRef::dim //! [::eye()]: ArrayBase::eye -//! [.fill()]: ArrayBase::fill -//! [.fold()]: ArrayBase::fold -//! [.fold_axis()]: ArrayBase::fold_axis +//! [.fill()]: ArrayRef::fill +//! [.fold()]: ArrayRef::fold +//! [.fold_axis()]: ArrayRef::fold_axis //! [::from_elem()]: ArrayBase::from_elem //! [::from_iter()]: ArrayBase::from_iter //! [::from_diag()]: ArrayBase::from_diag @@ -694,48 +694,51 @@ //! [::from_shape_vec_unchecked()]: ArrayBase::from_shape_vec_unchecked //! [::from_vec()]: ArrayBase::from_vec //! [.index()]: ArrayBase#impl-Index -//! [.indexed_iter()]: ArrayBase::indexed_iter +//! [.indexed_iter()]: ArrayRef::indexed_iter //! [.insert_axis()]: ArrayBase::insert_axis -//! [.is_empty()]: ArrayBase::is_empty -//! [.is_square()]: ArrayBase::is_square -//! [.iter()]: ArrayBase::iter -//! [.len()]: ArrayBase::len -//! [.len_of()]: ArrayBase::len_of +//! [.is_empty()]: LayoutRef::is_empty +//! [.is_square()]: LayoutRef::is_square +//! [.iter()]: ArrayRef::iter +//! [.len()]: LayoutRef::len +//! [.len_of()]: LayoutRef::len_of //! [::linspace()]: ArrayBase::linspace //! [::logspace()]: ArrayBase::logspace //! [::geomspace()]: ArrayBase::geomspace -//! [.map()]: ArrayBase::map -//! [.map_axis()]: ArrayBase::map_axis -//! [.map_inplace()]: ArrayBase::map_inplace -//! [.mapv()]: ArrayBase::mapv -//! [.mapv_inplace()]: ArrayBase::mapv_inplace +//! [.map()]: ArrayRef::map +//! [.map_axis()]: ArrayRef::map_axis +//! [.map_inplace()]: ArrayRef::map_inplace +//! [.mapv()]: ArrayRef::mapv +//! [.mapv_inplace()]: ArrayRef::mapv_inplace //! [.mapv_into()]: ArrayBase::mapv_into -//! [matrix-* dot]: ArrayBase::dot-1 -//! [.mean()]: ArrayBase::mean -//! [.mean_axis()]: ArrayBase::mean_axis -//! [.ndim()]: ArrayBase::ndim +//! [dot-2-2]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-1-1]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-1-2]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [dot-2-1]: ArrayRef#impl-Dot>>-for-ArrayRef> +//! [.mean()]: ArrayRef::mean +//! [.mean_axis()]: ArrayRef::mean_axis +//! [.ndim()]: LayoutRef::ndim //! [::ones()]: ArrayBase::ones -//! [.outer_iter()]: ArrayBase::outer_iter +//! [.outer_iter()]: ArrayRef::outer_iter //! [::range()]: ArrayBase::range -//! [.raw_dim()]: ArrayBase::raw_dim +//! [.raw_dim()]: LayoutRef::raw_dim //! [.reversed_axes()]: ArrayBase::reversed_axes -//! [.row()]: ArrayBase::row -//! [.row_mut()]: ArrayBase::row_mut -//! [.nrows()]: ArrayBase::nrows -//! [.sum()]: ArrayBase::sum -//! [.slice()]: ArrayBase::slice -//! [.slice_axis()]: ArrayBase::slice_axis -//! [.slice_collapse()]: ArrayBase::slice_collapse +//! [.row()]: ArrayRef::row +//! [.row_mut()]: ArrayRef::row_mut +//! [.nrows()]: LayoutRef::nrows +//! [.sum()]: ArrayRef::sum +//! [.slice()]: ArrayRef::slice +//! [.slice_axis()]: ArrayRef::slice_axis +//! [.slice_collapse()]: LayoutRef::slice_collapse //! [.slice_move()]: ArrayBase::slice_move -//! [.slice_mut()]: ArrayBase::slice_mut -//! [.shape()]: ArrayBase::shape +//! [.slice_mut()]: ArrayRef::slice_mut +//! [.shape()]: LayoutRef::shape //! [stack()]: crate::stack() -//! [.strides()]: ArrayBase::strides -//! [.index_axis()]: ArrayBase::index_axis -//! [.sum_axis()]: ArrayBase::sum_axis -//! [.t()]: ArrayBase::t -//! [vec-* dot]: ArrayBase::dot -//! [.for_each()]: ArrayBase::for_each +//! [.strides()]: LayoutRef::strides +//! [.index_axis()]: ArrayRef::index_axis +//! [.sum_axis()]: ArrayRef::sum_axis +//! [.t()]: ArrayRef::t +//! [vec-* dot]: ArrayRef::dot +//! [.for_each()]: ArrayRef::for_each //! [::zeros()]: ArrayBase::zeros //! [`Zip`]: crate::Zip diff --git a/src/doc/ndarray_for_numpy_users/rk_step.rs b/src/doc/ndarray_for_numpy_users/rk_step.rs index c882a3d00..820d6cdfb 100644 --- a/src/doc/ndarray_for_numpy_users/rk_step.rs +++ b/src/doc/ndarray_for_numpy_users/rk_step.rs @@ -122,7 +122,7 @@ //! //! * Use [`c.mul_add(h, t)`](f64::mul_add) instead of `t + c * h`. This is //! faster and reduces the floating-point error. It might also be beneficial -//! to use [`.scaled_add()`] or a combination of +//! to use [`.scaled_add()`](crate::ArrayRef::scaled_add) or a combination of //! [`azip!()`] and [`.mul_add()`](f64::mul_add) on the arrays in //! some places, but that's not demonstrated in the example below. //! @@ -168,7 +168,7 @@ //! # fn main() { let _ = rk_step::, ArrayViewMut1<'_, f64>)>; } //! ``` //! -//! [`.scaled_add()`]: crate::ArrayBase::scaled_add +//! [`.scaled_add()`]: crate::ArrayRef::scaled_add //! [`azip!()`]: crate::azip! //! //! ### SciPy license diff --git a/src/free_functions.rs b/src/free_functions.rs index 5659d7024..c1889cec8 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -14,8 +14,8 @@ use std::compile_error; use std::mem::{forget, size_of}; use std::ptr::NonNull; -use crate::imp_prelude::*; use crate::{dimension, ArcArray1, ArcArray2}; +use crate::{imp_prelude::*, LayoutRef}; /// Create an **[`Array`]** with one, two, three, four, five, or six dimensions. /// @@ -106,10 +106,12 @@ pub const fn aview0(x: &A) -> ArrayView0<'_, A> { ArrayBase { data: ViewRepr::new(), - // Safe because references are always non-null. - ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) }, - dim: Ix0(), - strides: Ix0(), + layout: LayoutRef { + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(x as *const A as *mut A) }, + dim: Ix0(), + strides: Ix0(), + }, } } @@ -144,10 +146,12 @@ pub const fn aview1(xs: &[A]) -> ArrayView1<'_, A> } ArrayBase { data: ViewRepr::new(), - // Safe because references are always non-null. - ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }, - dim: Ix1(xs.len()), - strides: Ix1(1), + layout: LayoutRef { + // Safe because references are always non-null. + ptr: unsafe { NonNull::new_unchecked(xs.as_ptr() as *mut A) }, + dim: Ix1(xs.len()), + strides: Ix1(1), + }, } } @@ -200,9 +204,7 @@ pub const fn aview2(xs: &[[A; N]]) -> ArrayView2<'_, A> }; ArrayBase { data: ViewRepr::new(), - ptr, - dim, - strides, + layout: LayoutRef { ptr, dim, strides }, } } diff --git a/src/impl_1d.rs b/src/impl_1d.rs index e49fdd731..bd34ba2ca 100644 --- a/src/impl_1d.rs +++ b/src/impl_1d.rs @@ -15,14 +15,11 @@ use crate::imp_prelude::*; use crate::low_level_util::AbortIfPanic; /// # Methods For 1-D Arrays -impl ArrayBase -where S: RawData +impl ArrayRef { /// Return an vector with the elements of the one-dimensional array. pub fn to_vec(&self) -> Vec - where - A: Clone, - S: Data, + where A: Clone { if let Some(slc) = self.as_slice() { slc.to_vec() @@ -34,7 +31,6 @@ where S: RawData /// Rotate the elements of the array by 1 element towards the front; /// the former first element becomes the last. pub(crate) fn rotate1_front(&mut self) - where S: DataMut { // use swapping to keep all elements initialized (as required by owned storage) let mut lane_iter = self.iter_mut(); diff --git a/src/impl_2d.rs b/src/impl_2d.rs index c2e9725ac..b6379e67b 100644 --- a/src/impl_2d.rs +++ b/src/impl_2d.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods For 2-D Arrays -impl ArrayBase -where S: RawData +impl ArrayRef { /// Return an array view of row `index`. /// @@ -24,7 +23,6 @@ where S: RawData /// ``` #[track_caller] pub fn row(&self, index: Ix) -> ArrayView1<'_, A> - where S: Data { self.index_axis(Axis(0), index) } @@ -41,11 +39,13 @@ where S: RawData /// ``` #[track_caller] pub fn row_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where S: DataMut { self.index_axis_mut(Axis(0), index) } +} +impl LayoutRef +{ /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. /// /// ``` @@ -67,7 +67,10 @@ where S: RawData { self.len_of(Axis(0)) } +} +impl ArrayRef +{ /// Return an array view of column `index`. /// /// **Panics** if `index` is out of bounds. @@ -79,7 +82,6 @@ where S: RawData /// ``` #[track_caller] pub fn column(&self, index: Ix) -> ArrayView1<'_, A> - where S: Data { self.index_axis(Axis(1), index) } @@ -96,11 +98,13 @@ where S: RawData /// ``` #[track_caller] pub fn column_mut(&mut self, index: Ix) -> ArrayViewMut1<'_, A> - where S: DataMut { self.index_axis_mut(Axis(1), index) } +} +impl LayoutRef +{ /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. /// /// ``` @@ -144,3 +148,70 @@ where S: RawData m == n } } + +impl ArrayBase +{ + /// Return the number of rows (length of `Axis(0)`) in the two-dimensional array. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.nrows(), 3); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(m, array.nrows()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(m, array.len_of(Axis(0))); + /// ``` + pub fn nrows(&self) -> usize + { + self.as_layout_ref().nrows() + } + + /// Return the number of columns (length of `Axis(1)`) in the two-dimensional array. + /// + /// ``` + /// use ndarray::{array, Axis}; + /// + /// let array = array![[1., 2.], + /// [3., 4.], + /// [5., 6.]]; + /// assert_eq!(array.ncols(), 2); + /// + /// // equivalent ways of getting the dimensions + /// // get nrows, ncols by using dim: + /// let (m, n) = array.dim(); + /// assert_eq!(n, array.ncols()); + /// // get length of any particular axis with .len_of() + /// assert_eq!(n, array.len_of(Axis(1))); + /// ``` + pub fn ncols(&self) -> usize + { + self.as_layout_ref().ncols() + } + + /// Return true if the array is square, false otherwise. + /// + /// # Examples + /// Square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2.], [3., 4.]]; + /// assert!(array.is_square()); + /// ``` + /// Not square: + /// ``` + /// use ndarray::array; + /// let array = array![[1., 2., 5.], [3., 4., 6.]]; + /// assert!(!array.is_square()); + /// ``` + pub fn is_square(&self) -> bool + { + self.as_layout_ref().is_square() + } +} diff --git a/src/impl_clone.rs b/src/impl_clone.rs index d65f6c338..402437941 100644 --- a/src/impl_clone.rs +++ b/src/impl_clone.rs @@ -7,6 +7,7 @@ // except according to those terms. use crate::imp_prelude::*; +use crate::LayoutRef; use crate::RawDataClone; impl Clone for ArrayBase @@ -15,12 +16,14 @@ impl Clone for ArrayBase { // safe because `clone_with_ptr` promises to provide equivalent data and ptr unsafe { - let (data, ptr) = self.data.clone_with_ptr(self.ptr); + let (data, ptr) = self.data.clone_with_ptr(self.layout.ptr); ArrayBase { data, - ptr, - dim: self.dim.clone(), - strides: self.strides.clone(), + layout: LayoutRef { + ptr, + dim: self.layout.dim.clone(), + strides: self.layout.strides.clone(), + }, } } } @@ -31,9 +34,9 @@ impl Clone for ArrayBase fn clone_from(&mut self, other: &Self) { unsafe { - self.ptr = self.data.clone_from_with_ptr(&other.data, other.ptr); - self.dim.clone_from(&other.dim); - self.strides.clone_from(&other.strides); + self.layout.ptr = self.data.clone_from_with_ptr(&other.data, other.layout.ptr); + self.layout.dim.clone_from(&other.layout.dim); + self.layout.strides.clone_from(&other.layout.strides); } } } diff --git a/src/impl_cow.rs b/src/impl_cow.rs index 4843e305b..0ecc3c44b 100644 --- a/src/impl_cow.rs +++ b/src/impl_cow.rs @@ -33,7 +33,10 @@ where D: Dimension fn from(view: ArrayView<'a, A, D>) -> CowArray<'a, A, D> { // safe because equivalent data - unsafe { ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr).with_strides_dim(view.strides, view.dim) } + unsafe { + ArrayBase::from_data_ptr(CowRepr::View(view.data), view.ptr) + .with_strides_dim(view.layout.strides, view.layout.dim) + } } } @@ -44,7 +47,8 @@ where D: Dimension { // safe because equivalent data unsafe { - ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.ptr).with_strides_dim(array.strides, array.dim) + ArrayBase::from_data_ptr(CowRepr::Owned(array.data), array.layout.ptr) + .with_strides_dim(array.layout.strides, array.layout.dim) } } } diff --git a/src/impl_dyn.rs b/src/impl_dyn.rs index b86c5dd69..409fe991a 100644 --- a/src/impl_dyn.rs +++ b/src/impl_dyn.rs @@ -10,8 +10,7 @@ use crate::imp_prelude::*; /// # Methods for Dynamic-Dimensional Arrays -impl ArrayBase -where S: Data +impl LayoutRef { /// Insert new array axis of length 1 at `axis`, modifying the shape and /// strides in-place. @@ -58,7 +57,56 @@ where S: Data self.dim = self.dim.remove_axis(axis); self.strides = self.strides.remove_axis(axis); } +} + +impl ArrayBase +{ + /// Insert new array axis of length 1 at `axis`, modifying the shape and + /// strides in-place. + /// + /// **Panics** if the axis is out of bounds. + /// + /// ``` + /// use ndarray::{Axis, arr2, arr3}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 3]); + /// + /// a.insert_axis_inplace(Axis(1)); + /// assert_eq!(a, arr3(&[[[1, 2, 3]], [[4, 5, 6]]]).into_dyn()); + /// assert_eq!(a.shape(), &[2, 1, 3]); + /// ``` + #[track_caller] + pub fn insert_axis_inplace(&mut self, axis: Axis) + { + self.as_mut().insert_axis_inplace(axis) + } + /// Collapses the array to `index` along the axis and removes the axis, + /// modifying the shape and strides in-place. + /// + /// **Panics** if `axis` or `index` is out of bounds. + /// + /// ``` + /// use ndarray::{Axis, arr1, arr2}; + /// + /// let mut a = arr2(&[[1, 2, 3], [4, 5, 6]]).into_dyn(); + /// assert_eq!(a.shape(), &[2, 3]); + /// + /// a.index_axis_inplace(Axis(1), 1); + /// assert_eq!(a, arr1(&[2, 5]).into_dyn()); + /// assert_eq!(a.shape(), &[2]); + /// ``` + #[track_caller] + pub fn index_axis_inplace(&mut self, axis: Axis, index: usize) + { + self.as_mut().index_axis_inplace(axis, index) + } +} + +impl ArrayBase +where S: Data +{ /// Remove axes of length 1 and return the modified array. /// /// If the array has more the one dimension, the result array will always diff --git a/src/impl_internal_constructors.rs b/src/impl_internal_constructors.rs index adb4cbd35..7f95339d5 100644 --- a/src/impl_internal_constructors.rs +++ b/src/impl_internal_constructors.rs @@ -8,7 +8,7 @@ use std::ptr::NonNull; -use crate::imp_prelude::*; +use crate::{imp_prelude::*, LayoutRef}; // internal "builder-like" methods impl ArrayBase @@ -27,9 +27,11 @@ where S: RawData { let array = ArrayBase { data, - ptr, - dim: Ix1(0), - strides: Ix1(1), + layout: LayoutRef { + ptr, + dim: Ix1(0), + strides: Ix1(1), + }, }; debug_assert!(array.pointer_is_inbounds()); array @@ -58,9 +60,11 @@ where debug_assert_eq!(strides.ndim(), dim.ndim()); ArrayBase { data: self.data, - ptr: self.ptr, - dim, - strides, + layout: LayoutRef { + ptr: self.layout.ptr, + dim, + strides, + }, } } } diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 3da63b936..d2f04ef1f 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -38,7 +38,10 @@ use crate::math_cell::MathCell; use crate::order::Order; use crate::shape_builder::ShapeArg; use crate::zip::{IntoNdProducer, Zip}; +use crate::ArrayRef; use crate::AxisDescription; +use crate::LayoutRef; +use crate::RawRef; use crate::{arraytraits, DimMax}; use crate::iter::{ @@ -62,10 +65,7 @@ use crate::stacking::concatenate; use crate::{NdIndex, Slice, SliceInfoElem}; /// # Methods For All Array Types -impl ArrayBase -where - S: RawData, - D: Dimension, +impl LayoutRef { /// Return the total number of elements in the array. pub fn len(&self) -> usize @@ -173,20 +173,20 @@ where // strides are reinterpreted as isize self.strides[axis.index()] as isize } +} +impl ArrayRef +{ /// Return a read-only view of the array pub fn view(&self) -> ArrayView<'_, A, D> - where S: Data { - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } /// Return a read-write view of the array pub fn view_mut(&mut self) -> ArrayViewMut<'_, A, D> - where S: DataMut { - self.ensure_unique(); unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } } @@ -198,7 +198,6 @@ where /// The view acts "as if" the elements are temporarily in cells, and elements /// can be changed through shared references using the regular cell methods. pub fn cell_view(&mut self) -> ArrayView<'_, MathCell, D> - where S: DataMut { self.view_mut().into_cell_view() } @@ -234,9 +233,7 @@ where /// # assert_eq!(arr, owned); /// ``` pub fn to_owned(&self) -> Array - where - A: Clone, - S: Data, + where A: Clone { if let Some(slc) = self.as_slice_memory_order() { unsafe { Array::from_shape_vec_unchecked(self.dim.clone().strides(self.strides.clone()), slc.to_vec()) } @@ -244,6 +241,50 @@ where self.map(A::clone) } } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ + /// Return an uniquely owned copy of the array. + /// + /// If the input array is contiguous, then the output array will have the same + /// memory layout. Otherwise, the layout of the output array is unspecified. + /// If you need a particular layout, you can allocate a new array with the + /// desired memory layout and [`.assign()`](ArrayRef::assign) the data. + /// Alternatively, you can collectan iterator, like this for a result in + /// standard layout: + /// + /// ``` + /// # use ndarray::prelude::*; + /// # let arr = Array::from_shape_vec((2, 2).f(), vec![1, 2, 3, 4]).unwrap(); + /// # let owned = { + /// Array::from_shape_vec(arr.raw_dim(), arr.iter().cloned().collect()).unwrap() + /// # }; + /// # assert!(owned.is_standard_layout()); + /// # assert_eq!(arr, owned); + /// ``` + /// + /// or this for a result in column-major (Fortran) layout: + /// + /// ``` + /// # use ndarray::prelude::*; + /// # let arr = Array::from_shape_vec((2, 2), vec![1, 2, 3, 4]).unwrap(); + /// # let owned = { + /// Array::from_shape_vec(arr.raw_dim().f(), arr.t().iter().cloned().collect()).unwrap() + /// # }; + /// # assert!(owned.t().is_standard_layout()); + /// # assert_eq!(arr, owned); + /// ``` + pub fn to_owned(&self) -> Array + where + A: Clone, + S: Data, + { + (**self).to_owned() + } /// Return a shared ownership (copy on write) array, cloning the array /// elements if necessary. @@ -305,7 +346,10 @@ where { S::into_shared(self) } +} +impl ArrayRef +{ /// Returns a reference to the first element of the array, or `None` if it /// is empty. /// @@ -322,7 +366,6 @@ where /// assert_eq!(b.first(), None); /// ``` pub fn first(&self) -> Option<&A> - where S: Data { if self.is_empty() { None @@ -347,7 +390,6 @@ where /// assert_eq!(b.first_mut(), None); /// ``` pub fn first_mut(&mut self) -> Option<&mut A> - where S: DataMut { if self.is_empty() { None @@ -372,7 +414,6 @@ where /// assert_eq!(b.last(), None); /// ``` pub fn last(&self) -> Option<&A> - where S: Data { if self.is_empty() { None @@ -401,12 +442,10 @@ where /// assert_eq!(b.last_mut(), None); /// ``` pub fn last_mut(&mut self) -> Option<&mut A> - where S: DataMut { if self.is_empty() { None } else { - self.ensure_unique(); let mut index = self.raw_dim(); for ax in 0..index.ndim() { index[ax] -= 1; @@ -422,9 +461,8 @@ where /// /// Iterator element type is `&A`. pub fn iter(&self) -> Iter<'_, A, D> - where S: Data { - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); self.view().into_iter_() } @@ -435,7 +473,6 @@ where /// /// Iterator element type is `&mut A`. pub fn iter_mut(&mut self) -> IterMut<'_, A, D> - where S: DataMut { self.view_mut().into_iter_() } @@ -449,7 +486,6 @@ where /// /// See also [`Zip::indexed`] pub fn indexed_iter(&self) -> IndexedIter<'_, A, D> - where S: Data { IndexedIter::new(self.view().into_elements_base()) } @@ -461,7 +497,6 @@ where /// /// Iterator element type is `(D::Pattern, &mut A)`. pub fn indexed_iter_mut(&mut self) -> IndexedIterMut<'_, A, D> - where S: DataMut { IndexedIterMut::new(self.view_mut().into_elements_base()) } @@ -475,9 +510,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice(&self, info: I) -> ArrayView<'_, A, I::OutDim> - where - I: SliceArg, - S: Data, + where I: SliceArg { self.view().slice_move(info) } @@ -491,9 +524,7 @@ where /// (**Panics** if `D` is `IxDyn` and `info` does not match the number of array axes.) #[track_caller] pub fn slice_mut(&mut self, info: I) -> ArrayViewMut<'_, A, I::OutDim> - where - I: SliceArg, - S: DataMut, + where I: SliceArg { self.view_mut().slice_move(info) } @@ -523,13 +554,17 @@ where /// ``` #[track_caller] pub fn multi_slice_mut<'a, M>(&'a mut self, info: M) -> M::Output - where - M: MultiSliceArg<'a, A, D>, - S: DataMut, + where M: MultiSliceArg<'a, A, D> { info.multi_slice_move(self.view_mut()) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Slice the array, possibly changing the number of dimensions. /// /// See [*Slicing*](#slicing) for full documentation. @@ -557,8 +592,8 @@ where // Slice the axis in-place to update the `dim`, `strides`, and `ptr`. self.slice_axis_inplace(Axis(old_axis), Slice { start, end, step }); // Copy the sliced dim and stride to corresponding axis. - new_dim[new_axis] = self.dim[old_axis]; - new_strides[new_axis] = self.strides[old_axis]; + new_dim[new_axis] = self.layout.dim[old_axis]; + new_strides[new_axis] = self.layout.strides[old_axis]; old_axis += 1; new_axis += 1; } @@ -585,16 +620,19 @@ where // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(new_strides, new_dim) } } +} +impl LayoutRef +{ /// Slice the array in place without changing the number of dimensions. /// /// In particular, if an axis is sliced with an index, the axis is /// collapsed, as in [`.collapse_axis()`], rather than removed, as in /// [`.slice_move()`] or [`.index_axis_move()`]. /// - /// [`.collapse_axis()`]: Self::collapse_axis - /// [`.slice_move()`]: Self::slice_move - /// [`.index_axis_move()`]: Self::index_axis_move + /// [`.collapse_axis()`]: LayoutRef::collapse_axis + /// [`.slice_move()`]: ArrayBase::slice_move + /// [`.index_axis_move()`]: ArrayBase::index_axis_move /// /// See [*Slicing*](#slicing) for full documentation. /// See also [`s!`], [`SliceArg`], and [`SliceInfo`](crate::SliceInfo). @@ -630,7 +668,10 @@ where }); debug_assert_eq!(axis, self.ndim()); } +} +impl ArrayRef +{ /// Return a view of the array, sliced along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -638,7 +679,6 @@ where #[track_caller] #[must_use = "slice_axis returns an array view with the sliced result"] pub fn slice_axis(&self, axis: Axis, indices: Slice) -> ArrayView<'_, A, D> - where S: Data { let mut view = self.view(); view.slice_axis_inplace(axis, indices); @@ -652,13 +692,15 @@ where #[track_caller] #[must_use = "slice_axis_mut returns an array view with the sliced result"] pub fn slice_axis_mut(&mut self, axis: Axis, indices: Slice) -> ArrayViewMut<'_, A, D> - where S: DataMut { let mut view_mut = self.view_mut(); view_mut.slice_axis_inplace(axis, indices); view_mut } +} +impl LayoutRef +{ /// Slice the array in place along the specified axis. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -671,9 +713,15 @@ where unsafe { self.ptr = self.ptr.offset(offset); } - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Slice the array in place along the specified axis, then return the sliced array. /// /// **Panics** if an index is out of bounds or step size is zero.
@@ -684,7 +732,10 @@ where self.slice_axis_inplace(axis, indices); self } +} +impl ArrayRef +{ /// Return a view of a slice of the array, with a closure specifying the /// slice for each axis. /// @@ -694,9 +745,7 @@ where /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis(&self, f: F) -> ArrayView<'_, A, D> - where - F: FnMut(AxisDescription) -> Slice, - S: Data, + where F: FnMut(AxisDescription) -> Slice { let mut view = self.view(); view.slice_each_axis_inplace(f); @@ -712,15 +761,16 @@ where /// **Panics** if an index is out of bounds or step size is zero. #[track_caller] pub fn slice_each_axis_mut(&mut self, f: F) -> ArrayViewMut<'_, A, D> - where - F: FnMut(AxisDescription) -> Slice, - S: DataMut, + where F: FnMut(AxisDescription) -> Slice { let mut view = self.view_mut(); view.slice_each_axis_inplace(f); view } +} +impl LayoutRef +{ /// Slice the array in place, with a closure specifying the slice for each /// axis. /// @@ -743,7 +793,10 @@ where ) } } +} +impl ArrayRef +{ /// Return a reference to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -763,13 +816,14 @@ where /// ); /// ``` pub fn get(&self, index: I) -> Option<&A> - where - S: Data, - I: NdIndex, + where I: NdIndex { unsafe { self.get_ptr(index).map(|ptr| &*ptr) } } +} +impl RawRef +{ /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -791,17 +845,21 @@ where .index_checked(&self.dim, &self.strides) .map(move |offset| unsafe { ptr.as_ptr().offset(offset) as *const _ }) } +} +impl ArrayRef +{ /// Return a mutable reference to the element at `index`, or return `None` /// if the index is out of bounds. pub fn get_mut(&mut self, index: I) -> Option<&mut A> - where - S: DataMut, - I: NdIndex, + where I: NdIndex { unsafe { self.get_mut_ptr(index).map(|ptr| &mut *ptr) } } +} +impl RawRef +{ /// Return a raw pointer to the element at `index`, or return `None` /// if the index is out of bounds. /// @@ -820,9 +878,7 @@ where /// assert_eq!(a.get((0, 1)), Some(&5.)); /// ``` pub fn get_mut_ptr(&mut self, index: I) -> Option<*mut A> - where - S: RawDataMut, - I: NdIndex, + where I: NdIndex { // const and mut are separate to enforce &mutness as well as the // extra code in as_mut_ptr @@ -831,7 +887,10 @@ where .index_checked(&self.dim, &self.strides) .map(move |offset| unsafe { ptr.offset(offset) }) } +} +impl ArrayRef +{ /// Perform *unchecked* array indexing. /// /// Return a reference to the element at `index`. @@ -843,9 +902,7 @@ where /// The caller must ensure that the index is in-bounds. #[inline] pub unsafe fn uget(&self, index: I) -> &A - where - S: Data, - I: NdIndex, + where I: NdIndex { arraytraits::debug_bounds_check(self, &index); let off = index.index_unchecked(&self.strides); @@ -868,11 +925,9 @@ where /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) #[inline] pub unsafe fn uget_mut(&mut self, index: I) -> &mut A - where - S: DataMut, - I: NdIndex, + where I: NdIndex { - debug_assert!(self.data.is_unique()); + // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index); let off = index.index_unchecked(&self.strides); &mut *self.ptr.as_ptr().offset(off) @@ -885,9 +940,7 @@ where /// ***Panics*** if an index is out of bounds. #[track_caller] pub fn swap(&mut self, index1: I, index2: I) - where - S: DataMut, - I: NdIndex, + where I: NdIndex { let ptr = self.as_mut_ptr(); let offset1 = index1.index_checked(&self.dim, &self.strides); @@ -918,11 +971,9 @@ where /// 2. the data is uniquely held by the array. (This property is guaranteed /// for `Array` and `ArrayViewMut`, but not for `ArcArray` or `CowArray`.) pub unsafe fn uswap(&mut self, index1: I, index2: I) - where - S: DataMut, - I: NdIndex, + where I: NdIndex { - debug_assert!(self.data.is_unique()); + // debug_assert!(self.data.is_unique()); arraytraits::debug_bounds_check(self, &index1); arraytraits::debug_bounds_check(self, &index2); let off1 = index1.index_unchecked(&self.strides); @@ -933,7 +984,6 @@ where // `get` for zero-dimensional arrays // panics if dimension is not zero. otherwise an element is always present. fn get_0d(&self) -> &A - where S: Data { assert!(self.ndim() == 0); unsafe { &*self.as_ptr() } @@ -962,9 +1012,7 @@ where /// ``` #[track_caller] pub fn index_axis(&self, axis: Axis, index: usize) -> ArrayView<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { self.view().index_axis_move(axis, index) } @@ -995,16 +1043,20 @@ where /// ``` #[track_caller] pub fn index_axis_mut(&mut self, axis: Axis, index: usize) -> ArrayViewMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { self.view_mut().index_axis_move(axis, index) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Collapses the array to `index` along the axis and removes the axis. /// - /// See [`.index_axis()`](Self::index_axis) and [*Subviews*](#subviews) for full documentation. + /// See [`.index_axis()`](ArrayRef::index_axis) and [*Subviews*](#subviews) for full documentation. /// /// **Panics** if `axis` or `index` is out of bounds. #[track_caller] @@ -1012,12 +1064,15 @@ where where D: RemoveAxis { self.collapse_axis(axis, index); - let dim = self.dim.remove_axis(axis); - let strides = self.strides.remove_axis(axis); + let dim = self.layout.dim.remove_axis(axis); + let strides = self.layout.strides.remove_axis(axis); // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(strides, dim) } } +} +impl LayoutRef +{ /// Selects `index` along the axis, collapsing the axis into length one. /// /// **Panics** if `axis` or `index` is out of bounds. @@ -1026,9 +1081,12 @@ where { let offset = dimension::do_collapse_axis(&mut self.dim, &self.strides, axis.index(), index); self.ptr = unsafe { self.ptr.offset(offset) }; - debug_assert!(self.pointer_is_inbounds()); + // debug_assert!(self.pointer_is_inbounds()); } +} +impl ArrayRef +{ /// Along `axis`, select arbitrary subviews corresponding to `indices` /// and copy them into a new array. /// @@ -1054,7 +1112,6 @@ where pub fn select(&self, axis: Axis, indices: &[Ix]) -> Array where A: Clone, - S: Data, D: RemoveAxis, { if self.ndim() == 1 { @@ -1116,7 +1173,6 @@ where /// } /// ``` pub fn rows(&self) -> Lanes<'_, A, D::Smaller> - where S: Data { let mut n = self.ndim(); if n == 0 { @@ -1130,7 +1186,6 @@ where /// /// Iterator element is `ArrayView1
` (1D read-write array view). pub fn rows_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { let mut n = self.ndim(); if n == 0 { @@ -1166,7 +1221,6 @@ where /// } /// ``` pub fn columns(&self) -> Lanes<'_, A, D::Smaller> - where S: Data { Lanes::new(self.view(), Axis(0)) } @@ -1176,7 +1230,6 @@ where /// /// Iterator element is `ArrayView1` (1D read-write array view). pub fn columns_mut(&mut self) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { LanesMut::new(self.view_mut(), Axis(0)) } @@ -1210,7 +1263,6 @@ where /// assert_eq!(inner2.into_iter().next().unwrap(), aview1(&[0, 1, 2])); /// ``` pub fn lanes(&self, axis: Axis) -> Lanes<'_, A, D::Smaller> - where S: Data { Lanes::new(self.view(), axis) } @@ -1220,7 +1272,6 @@ where /// /// Iterator element is `ArrayViewMut1` (1D read-write array view). pub fn lanes_mut(&mut self, axis: Axis) -> LanesMut<'_, A, D::Smaller> - where S: DataMut { LanesMut::new(self.view_mut(), axis) } @@ -1233,9 +1284,7 @@ where /// Iterator element is `ArrayView` (read-only array view). #[allow(deprecated)] pub fn outer_iter(&self) -> AxisIter<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { self.view().into_outer_iter() } @@ -1248,9 +1297,7 @@ where /// Iterator element is `ArrayViewMut` (read-write array view). #[allow(deprecated)] pub fn outer_iter_mut(&mut self) -> AxisIterMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { self.view_mut().into_outer_iter() } @@ -1272,9 +1319,7 @@ where /// #[track_caller] pub fn axis_iter(&self, axis: Axis) -> AxisIter<'_, A, D::Smaller> - where - S: Data, - D: RemoveAxis, + where D: RemoveAxis { AxisIter::new(self.view(), axis) } @@ -1288,9 +1333,7 @@ where /// **Panics** if `axis` is out of bounds. #[track_caller] pub fn axis_iter_mut(&mut self, axis: Axis) -> AxisIterMut<'_, A, D::Smaller> - where - S: DataMut, - D: RemoveAxis, + where D: RemoveAxis { AxisIterMut::new(self.view_mut(), axis) } @@ -1323,7 +1366,6 @@ where /// ``` #[track_caller] pub fn axis_chunks_iter(&self, axis: Axis, size: usize) -> AxisChunksIter<'_, A, D> - where S: Data { AxisChunksIter::new(self.view(), axis, size) } @@ -1336,7 +1378,6 @@ where /// **Panics** if `axis` is out of bounds or if `size` is zero. #[track_caller] pub fn axis_chunks_iter_mut(&mut self, axis: Axis, size: usize) -> AxisChunksIterMut<'_, A, D> - where S: DataMut { AxisChunksIterMut::new(self.view_mut(), axis, size) } @@ -1354,9 +1395,7 @@ where /// number of array axes.) #[track_caller] pub fn exact_chunks(&self, chunk_size: E) -> ExactChunks<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { ExactChunks::new(self.view(), chunk_size) } @@ -1395,9 +1434,7 @@ where /// ``` #[track_caller] pub fn exact_chunks_mut(&mut self, chunk_size: E) -> ExactChunksMut<'_, A, D> - where - E: IntoDimension, - S: DataMut, + where E: IntoDimension { ExactChunksMut::new(self.view_mut(), chunk_size) } @@ -1410,9 +1447,7 @@ where /// This is essentially equivalent to [`.windows_with_stride()`] with unit stride. #[track_caller] pub fn windows(&self, window_size: E) -> Windows<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { Windows::new(self.view(), window_size) } @@ -1433,13 +1468,13 @@ where /// `window_size`. /// /// Note that passing a stride of only ones is similar to - /// calling [`ArrayBase::windows()`]. + /// calling [`ArrayRef::windows()`]. /// /// **Panics** if any dimension of `window_size` or `stride` is zero.
/// (**Panics** if `D` is `IxDyn` and `window_size` or `stride` does not match the /// number of array axes.) /// - /// This is the same illustration found in [`ArrayBase::windows()`], + /// This is the same illustration found in [`ArrayRef::windows()`], /// 2×2 windows in a 3×4 array, but now with a (1, 2) stride: /// /// ```text @@ -1463,9 +1498,7 @@ where /// ``` #[track_caller] pub fn windows_with_stride(&self, window_size: E, stride: E) -> Windows<'_, A, D> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { Windows::new_with_stride(self.view(), window_size, stride) } @@ -1492,7 +1525,6 @@ where /// } /// ``` pub fn axis_windows(&self, axis: Axis, window_size: usize) -> AxisWindows<'_, A, D> - where S: Data { let axis_index = axis.index(); @@ -1510,31 +1542,35 @@ where AxisWindows::new(self.view(), axis, window_size) } - // Return (length, stride) for diagonal - fn diag_params(&self) -> (Ix, Ixs) - { - /* empty shape has len 1 */ - let len = self.dim.slice().iter().cloned().min().unwrap_or(1); - let stride = self.strides().iter().sum(); - (len, stride) - } - /// Return a view of the diagonal elements of the array. /// /// The diagonal is simply the sequence indexed by *(0, 0, .., 0)*, /// *(1, 1, ..., 1)* etc as long as all axes have elements. pub fn diag(&self) -> ArrayView1<'_, A> - where S: Data { self.view().into_diag() } /// Return a read-write view over the diagonal elements of the array. pub fn diag_mut(&mut self) -> ArrayViewMut1<'_, A> - where S: DataMut { self.view_mut().into_diag() } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ + // Return (length, stride) for diagonal + fn diag_params(&self) -> (Ix, Ixs) + { + /* empty shape has len 1 */ + let len = self.layout.dim.slice().iter().cloned().min().unwrap_or(1); + let stride = self.strides().iter().sum(); + (len, stride) + } /// Return the diagonal as a one-dimensional array. pub fn into_diag(self) -> ArrayBase @@ -1560,14 +1596,17 @@ where /// Make the array unshared. /// /// This method is mostly only useful with unsafe code. - fn ensure_unique(&mut self) + pub(crate) fn ensure_unique(&mut self) where S: DataMut { debug_assert!(self.pointer_is_inbounds()); S::ensure_unique(self); debug_assert!(self.pointer_is_inbounds()); } +} +impl LayoutRef +{ /// Return `true` if the array data is laid out in contiguous “C order” in /// memory (where the last index is the most rapidly varying). /// @@ -1583,7 +1622,10 @@ where { D::is_contiguous(&self.dim, &self.strides) } +} +impl ArrayRef +{ /// Return a standard-layout array containing the data, cloning if /// necessary. /// @@ -1607,9 +1649,7 @@ where /// assert!(cow_owned.is_standard_layout()); /// ``` pub fn as_standard_layout(&self) -> CowArray<'_, A, D> - where - S: Data, - A: Clone, + where A: Clone { if self.is_standard_layout() { CowArray::from(self.view()) @@ -1625,7 +1665,10 @@ where } } } +} +impl RawRef +{ /// Return a pointer to the first element in the array. /// /// Raw access to array elements needs to follow the strided indexing @@ -1641,6 +1684,19 @@ where self.ptr.as_ptr() as *const A } + /// Return a mutable pointer to the first element in the array reference. + #[inline(always)] + pub fn as_mut_ptr(&mut self) -> *mut A + { + self.ptr.as_ptr() + } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Return a mutable pointer to the first element in the array. /// /// This method attempts to unshare the data. If `S: DataMut`, then the @@ -1656,9 +1712,12 @@ where where S: RawDataMut { self.try_ensure_unique(); // for ArcArray - self.ptr.as_ptr() + self.layout.ptr.as_ptr() } +} +impl RawRef +{ /// Return a raw view of the array. #[inline] pub fn raw_view(&self) -> RawArrayView @@ -1666,6 +1725,19 @@ where unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) } } + /// Return a raw mutable view of the array. + #[inline] + pub fn raw_view_mut(&mut self) -> RawArrayViewMut + { + unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + } +} + +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Return a raw mutable view of the array. /// /// This method attempts to unshare the data. If `S: DataMut`, then the @@ -1675,7 +1747,7 @@ where where S: RawDataMut { self.try_ensure_unique(); // for ArcArray - unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + unsafe { RawArrayViewMut::new(self.layout.ptr, self.layout.dim.clone(), self.layout.strides.clone()) } } /// Return a raw mutable view of the array. @@ -1688,13 +1760,54 @@ where RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) } + /// Return the array’s data as a slice, if it is contiguous and in standard order. + /// Return `None` otherwise. + pub fn as_slice_mut(&mut self) -> Option<&mut [A]> + where S: DataMut + { + if self.is_standard_layout() { + self.ensure_unique(); + unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } + } else { + None + } + } + + /// Return the array’s data as a slice if it is contiguous, + /// return `None` otherwise. + /// + /// In the contiguous case, in order to return a unique reference, this + /// method unshares the data if necessary, but it preserves the existing + /// strides. + pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> + where S: DataMut + { + self.try_as_slice_memory_order_mut().ok() + } + + /// Return the array’s data as a slice if it is contiguous, otherwise + /// return `self` in the `Err` variant. + pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> + where S: DataMut + { + if self.is_contiguous() { + self.ensure_unique(); + let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); + unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } + } else { + Err(self) + } + } +} + +impl ArrayRef +{ /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// /// If this function returns `Some(_)`, then the element order in the slice /// corresponds to the logical order of the array’s elements. pub fn as_slice(&self) -> Option<&[A]> - where S: Data { if self.is_standard_layout() { unsafe { Some(slice::from_raw_parts(self.ptr.as_ptr(), self.len())) } @@ -1706,10 +1819,8 @@ where /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. pub fn as_slice_mut(&mut self) -> Option<&mut [A]> - where S: DataMut { if self.is_standard_layout() { - self.ensure_unique(); unsafe { Some(slice::from_raw_parts_mut(self.ptr.as_ptr(), self.len())) } } else { None @@ -1722,7 +1833,6 @@ where /// If this function returns `Some(_)`, then the elements in the slice /// have whatever order the elements have in memory. pub fn as_slice_memory_order(&self) -> Option<&[A]> - where S: Data { if self.is_contiguous() { let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); @@ -1739,7 +1849,6 @@ where /// method unshares the data if necessary, but it preserves the existing /// strides. pub fn as_slice_memory_order_mut(&mut self) -> Option<&mut [A]> - where S: DataMut { self.try_as_slice_memory_order_mut().ok() } @@ -1747,10 +1856,8 @@ where /// Return the array’s data as a slice if it is contiguous, otherwise /// return `self` in the `Err` variant. pub(crate) fn try_as_slice_memory_order_mut(&mut self) -> Result<&mut [A], &mut Self> - where S: DataMut { if self.is_contiguous() { - self.ensure_unique(); let offset = offset_from_low_addr_ptr_to_logical_ptr(&self.dim, &self.strides); unsafe { Ok(slice::from_raw_parts_mut(self.ptr.sub(offset).as_ptr(), self.len())) } } else { @@ -1817,7 +1924,6 @@ where where E: ShapeArg, A: Clone, - S: Data, { let (shape, order) = new_shape.into_shape_and_order(); self.to_shape_order(shape, order.unwrap_or(Order::RowMajor)) @@ -1827,7 +1933,6 @@ where where E: Dimension, A: Clone, - S: Data, { let len = self.dim.size(); if size_of_shape_checked(&shape) != Ok(len) { @@ -1861,7 +1966,13 @@ where Ok(CowArray::from(Array::from_shape_trusted_iter_unchecked(shape, view.into_iter(), A::clone))) } } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Transform the array into `shape`; any shape with the same number of /// elements is accepted, but the source array must be contiguous. /// @@ -1918,8 +2029,8 @@ where where E: Dimension { let shape = shape.into_dimension(); - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { - return Err(error::incompatible_shapes(&self.dim, &shape)); + if size_of_shape_checked(&shape) != Ok(self.layout.dim.size()) { + return Err(error::incompatible_shapes(&self.layout.dim, &shape)); } // Check if contiguous, then we can change shape @@ -1963,8 +2074,8 @@ where where E: IntoDimension { let shape = shape.into_dimension(); - if size_of_shape_checked(&shape) != Ok(self.dim.size()) { - return Err(error::incompatible_shapes(&self.dim, &shape)); + if size_of_shape_checked(&shape) != Ok(self.layout.dim.size()) { + return Err(error::incompatible_shapes(&self.layout.dim, &shape)); } // Check if contiguous, if not => copy all, else just adapt strides unsafe { @@ -2092,7 +2203,10 @@ where unsafe { ArrayBase::from_shape_vec_unchecked(shape, v) } } } +} +impl ArrayRef +{ /// Flatten the array to a one-dimensional array. /// /// The array is returned as a `CowArray`; a view if possible, otherwise an owned array. @@ -2105,9 +2219,7 @@ where /// assert_eq!(flattened, arr1(&[1, 2, 3, 4, 5, 6, 7, 8])); /// ``` pub fn flatten(&self) -> CowArray<'_, A, Ix1> - where - A: Clone, - S: Data, + where A: Clone { self.flatten_with_order(Order::RowMajor) } @@ -2128,13 +2240,17 @@ where /// assert_eq!(flattened, arr1(&[1, 3, 5, 7, 2, 4, 6, 8])); /// ``` pub fn flatten_with_order(&self, order: Order) -> CowArray<'_, A, Ix1> - where - A: Clone, - S: Data, + where A: Clone { self.to_shape((self.len(), order)).unwrap() } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Flatten the array to a one-dimensional array, consuming the array. /// /// If possible, no copy is made, and the new array use the same memory as the original array. @@ -2169,7 +2285,8 @@ where { // safe because new dims equivalent unsafe { - ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(self.strides.into_dyn(), self.dim.into_dyn()) + ArrayBase::from_data_ptr(self.data, self.layout.ptr) + .with_strides_dim(self.layout.strides.into_dyn(), self.layout.dim.into_dyn()) } } @@ -2195,14 +2312,14 @@ where unsafe { if D::NDIM == D2::NDIM { // safe because D == D2 - let dim = unlimited_transmute::(self.dim); - let strides = unlimited_transmute::(self.strides); - return Ok(ArrayBase::from_data_ptr(self.data, self.ptr).with_strides_dim(strides, dim)); + let dim = unlimited_transmute::(self.layout.dim); + let strides = unlimited_transmute::(self.layout.strides); + return Ok(ArrayBase::from_data_ptr(self.data, self.layout.ptr).with_strides_dim(strides, dim)); } else if D::NDIM.is_none() || D2::NDIM.is_none() { // one is dynamic dim // safe because dim, strides are equivalent under a different type - if let Some(dim) = D2::from_dimension(&self.dim) { - if let Some(strides) = D2::from_dimension(&self.strides) { + if let Some(dim) = D2::from_dimension(&self.layout.dim) { + if let Some(strides) = D2::from_dimension(&self.layout.strides) { return Ok(self.with_strides_dim(strides, dim)); } } @@ -2210,7 +2327,10 @@ where } Err(ShapeError::from_kind(ErrorKind::IncompatibleShape)) } +} +impl ArrayRef +{ /// Act like a larger size and/or shape array by *broadcasting* /// into a larger shape, if possible. /// @@ -2241,9 +2361,7 @@ where /// ); /// ``` pub fn broadcast(&self, dim: E) -> Option> - where - E: IntoDimension, - S: Data, + where E: IntoDimension { /// Return new stride when trying to grow `from` into shape `to` /// @@ -2308,12 +2426,10 @@ where /// /// Return `ShapeError` if their shapes can not be broadcast together. #[allow(clippy::type_complexity)] - pub(crate) fn broadcast_with<'a, 'b, B, S2, E>( - &'a self, other: &'b ArrayBase, + pub(crate) fn broadcast_with<'a, 'b, B, E>( + &'a self, other: &'b ArrayRef, ) -> Result<(ArrayView<'a, A, DimMaxOf>, ArrayView<'b, B, DimMaxOf>), ShapeError> where - S: Data, - S2: Data, D: Dimension + DimMax, E: Dimension, { @@ -2339,7 +2455,10 @@ where }; Ok((view1, view2)) } +} +impl LayoutRef +{ /// Swap axes `ax` and `bx`. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2362,7 +2481,13 @@ where self.dim.slice_mut().swap(ax, bx); self.strides.slice_mut().swap(ax, bx); } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Permute the axes. /// /// This does not move any data, it just adjusts the array’s dimensions @@ -2402,8 +2527,8 @@ where let mut new_dim = usage_counts; // reuse to avoid an allocation let mut new_strides = D::zeros(self.ndim()); { - let dim = self.dim.slice(); - let strides = self.strides.slice(); + let dim = self.layout.dim.slice(); + let strides = self.layout.strides.slice(); for (new_axis, &axis) in axes.slice().iter().enumerate() { new_dim[new_axis] = dim[axis]; new_strides[new_axis] = strides[axis]; @@ -2419,22 +2544,27 @@ where /// while retaining the same data. pub fn reversed_axes(mut self) -> ArrayBase { - self.dim.slice_mut().reverse(); - self.strides.slice_mut().reverse(); + self.layout.dim.slice_mut().reverse(); + self.layout.strides.slice_mut().reverse(); self } +} +impl ArrayRef +{ /// Return a transposed view of the array. /// /// This is a shorthand for `self.view().reversed_axes()`. /// /// See also the more general methods `.reversed_axes()` and `.swap_axes()`. pub fn t(&self) -> ArrayView<'_, A, D> - where S: Data { self.view().reversed_axes() } +} +impl LayoutRef +{ /// Return an iterator over the length and stride of each axis. pub fn axes(&self) -> Axes<'_, D> { @@ -2511,7 +2641,13 @@ where { merge_axes(&mut self.dim, &mut self.strides, take, into) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Insert new array axis at `axis` and return the result. /// /// ``` @@ -2539,8 +2675,8 @@ where assert!(axis.index() <= self.ndim()); // safe because a new axis of length one does not affect memory layout unsafe { - let strides = self.strides.insert_axis(axis); - let dim = self.dim.insert_axis(axis); + let strides = self.layout.strides.insert_axis(axis); + let dim = self.layout.dim.insert_axis(axis); self.with_strides_dim(strides, dim) } } @@ -2562,18 +2698,18 @@ where { self.data._is_pointer_inbounds(self.as_ptr()) } +} +impl ArrayRef +{ /// Perform an elementwise assigment to `self` from `rhs`. /// /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. /// /// **Panics** if broadcasting isn’t possible. #[track_caller] - pub fn assign(&mut self, rhs: &ArrayBase) - where - S: DataMut, - A: Clone, - S2: Data, + pub fn assign(&mut self, rhs: &ArrayRef) + where A: Clone { self.zip_mut_with(rhs, |x, y| x.clone_from(y)); } @@ -2587,7 +2723,6 @@ where #[track_caller] pub fn assign_to

(&self, to: P) where - S: Data, P: IntoNdProducer, P::Item: AssignElem, A: Clone, @@ -2597,17 +2732,13 @@ where /// Perform an elementwise assigment to `self` from element `x`. pub fn fill(&mut self, x: A) - where - S: DataMut, - A: Clone, + where A: Clone { self.map_inplace(move |elt| elt.clone_from(&x)); } - pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayBase, mut f: F) + pub(crate) fn zip_mut_with_same_shape(&mut self, rhs: &ArrayRef, mut f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2630,10 +2761,8 @@ where // zip two arrays where they have different layout or strides #[inline(always)] - fn zip_mut_with_by_rows(&mut self, rhs: &ArrayBase, mut f: F) + fn zip_mut_with_by_rows(&mut self, rhs: &ArrayRef, mut f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2649,9 +2778,7 @@ where } fn zip_mut_with_elem(&mut self, rhs_elem: &B, mut f: F) - where - S: DataMut, - F: FnMut(&mut A, &B), + where F: FnMut(&mut A, &B) { self.map_inplace(move |elt| f(elt, rhs_elem)); } @@ -2664,10 +2791,8 @@ where /// **Panics** if broadcasting isn’t possible. #[track_caller] #[inline] - pub fn zip_mut_with(&mut self, rhs: &ArrayBase, f: F) + pub fn zip_mut_with(&mut self, rhs: &ArrayRef, f: F) where - S: DataMut, - S2: Data, E: Dimension, F: FnMut(&mut A, &B), { @@ -2690,13 +2815,12 @@ where where F: FnMut(B, &'a A) -> B, A: 'a, - S: Data, { if let Some(slc) = self.as_slice_memory_order() { slc.iter().fold(init, f) } else { let mut v = self.view(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + move_min_stride_axis_to_last(&mut v.layout.dim, &mut v.layout.strides); v.into_elements_base().fold(init, f) } } @@ -2723,7 +2847,6 @@ where where F: FnMut(&'a A) -> B, A: 'a, - S: Data, { unsafe { if let Some(slc) = self.as_slice_memory_order() { @@ -2748,7 +2871,6 @@ where where F: FnMut(&'a mut A) -> B, A: 'a, - S: DataMut, { let dim = self.dim.clone(); if self.is_contiguous() { @@ -2781,11 +2903,16 @@ where where F: FnMut(A) -> B, A: Clone, - S: Data, { self.map(move |x| f(x.clone())) } +} +impl ArrayBase +where + S: RawData, + D: Dimension, +{ /// Call `f` by **v**alue on each element, update the array with the new values /// and return it. /// @@ -2813,7 +2940,7 @@ where /// Elements are visited in arbitrary order. /// /// [`mapv_into`]: ArrayBase::mapv_into - /// [`mapv`]: ArrayBase::mapv + /// [`mapv`]: ArrayRef::mapv pub fn mapv_into_any(self, mut f: F) -> Array where S: DataMut, @@ -2841,13 +2968,15 @@ where self.mapv(f) } } +} +impl ArrayRef +{ /// Modify the array in place by calling `f` by mutable reference on each element. /// /// Elements are visited in arbitrary order. pub fn map_inplace<'a, F>(&'a mut self, f: F) where - S: DataMut, A: 'a, F: FnMut(&'a mut A), { @@ -2855,7 +2984,7 @@ where Ok(slc) => slc.iter_mut().for_each(f), Err(arr) => { let mut v = arr.view_mut(); - move_min_stride_axis_to_last(&mut v.dim, &mut v.strides); + move_min_stride_axis_to_last(&mut v.layout.dim, &mut v.layout.strides); v.into_elements_base().for_each(f); } } @@ -2884,7 +3013,6 @@ where /// ``` pub fn mapv_inplace(&mut self, mut f: F) where - S: DataMut, F: FnMut(A) -> A, A: Clone, { @@ -2898,7 +3026,6 @@ where where F: FnMut(&'a A), A: 'a, - S: Data, { self.fold((), move |(), elt| f(elt)) } @@ -2917,7 +3044,6 @@ where D: RemoveAxis, F: FnMut(&B, &A) -> B, B: Clone, - S: Data, { let mut res = Array::from_elem(self.raw_dim().remove_axis(axis), init); for subview in self.axis_iter(axis) { @@ -2940,7 +3066,6 @@ where D: RemoveAxis, F: FnMut(ArrayView1<'a, A>) -> B, A: 'a, - S: Data, { if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); @@ -2966,7 +3091,6 @@ where D: RemoveAxis, F: FnMut(ArrayViewMut1<'a, A>) -> B, A: 'a, - S: DataMut, { if self.len_of(axis) == 0 { let new_dim = self.dim.remove_axis(axis); @@ -2975,7 +3099,13 @@ where Zip::from(self.lanes_mut(axis)).map_collect(mapping) } } +} +impl ArrayBase +where + S: DataOwned + DataMut, + D: Dimension, +{ /// Remove the `index`th elements along `axis` and shift down elements from higher indexes. /// /// Note that this "removes" the elements by swapping them around to the end of the axis and @@ -2988,7 +3118,6 @@ where /// ***Panics*** if `axis` is out of bounds
/// ***Panics*** if not `index < self.len_of(axis)`. pub fn remove_index(&mut self, axis: Axis, index: usize) - where S: DataOwned + DataMut { assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index()); @@ -2998,7 +3127,10 @@ where // then slice the axis in place to cut out the removed final element self.slice_axis_inplace(axis, Slice::new(0, Some(-1), 1)); } +} +impl ArrayRef +{ /// Iterates over pairs of consecutive elements along the axis. /// /// The first argument to the closure is an element, and the second @@ -3028,9 +3160,7 @@ where /// ); /// ``` pub fn accumulate_axis_inplace(&mut self, axis: Axis, mut f: F) - where - F: FnMut(&A, &mut A), - S: DataMut, + where F: FnMut(&A, &mut A) { if self.len_of(axis) <= 1 { return; diff --git a/src/impl_ops.rs b/src/impl_ops.rs index 46ea18a7c..53f49cc43 100644 --- a/src/impl_ops.rs +++ b/src/impl_ops.rs @@ -72,6 +72,7 @@ where E: Dimension, { type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: ArrayBase) -> Self::Output { @@ -100,8 +101,37 @@ where E: Dimension, { type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: &ArrayBase) -> Self::Output + { + self.$mth(&**rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between `self` and reference `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S, D, E> $trt<&'a ArrayRef> for ArrayBase +where + A: Clone + $trt, + B: Clone, + S: DataOwned + DataMut, + D: Dimension + DimMax, + E: Dimension, +{ + type Output = ArrayBase>::Output>; + + #[track_caller] + fn $mth(self, rhs: &ArrayRef) -> Self::Output { if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let mut out = self.into_dimensionality::<>::Output>().unwrap(); @@ -141,6 +171,36 @@ where E: Dimension + DimMax, { type Output = ArrayBase>::Output>; + + #[track_caller] + fn $mth(self, rhs: ArrayBase) -> Self::Output + where + { + (&**self).$mth(rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between reference `self` and `rhs`, +/// and return the result. +/// +/// `rhs` must be an `Array` or `ArcArray`. +/// +/// If their shapes disagree, `self` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, S2, D, E> $trt> for &'a ArrayRef +where + A: Clone + $trt, + B: Clone, + S2: DataOwned + DataMut, + D: Dimension, + E: Dimension + DimMax, +{ + type Output = ArrayBase>::Output>; + #[track_caller] fn $mth(self, rhs: ArrayBase) -> Self::Output where @@ -181,8 +241,33 @@ where E: Dimension, { type Output = Array>::Output>; + #[track_caller] fn $mth(self, rhs: &'a ArrayBase) -> Self::Output { + (&**self).$mth(&**rhs) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between references `self` and `rhs`, +/// and return the result as a new `Array`. +/// +/// If their shapes disagree, `self` and `rhs` is broadcast to their broadcast shape, +/// cloning the data if needed. +/// +/// **Panics** if broadcasting isn’t possible. +impl<'a, A, B, D, E> $trt<&'a ArrayRef> for &'a ArrayRef +where + A: Clone + $trt, + B: Clone, + D: Dimension + DimMax, + E: Dimension, +{ + type Output = Array>::Output>; + + #[track_caller] + fn $mth(self, rhs: &'a ArrayRef) -> Self::Output { let (lhs, rhs) = if self.ndim() == rhs.ndim() && self.shape() == rhs.shape() { let lhs = self.view().into_dimensionality::<>::Output>().unwrap(); let rhs = rhs.view().into_dimensionality::<>::Output>().unwrap(); @@ -226,6 +311,23 @@ impl<'a, A, S, D, B> $trt for &'a ArrayBase B: ScalarOperand, { type Output = Array; + + fn $mth(self, x: B) -> Self::Output { + (&**self).$mth(x) + } +} + +/// Perform elementwise +#[doc=$doc] +/// between the reference `self` and the scalar `x`, +/// and return the result as a new `Array`. +impl<'a, A, D, B> $trt for &'a ArrayRef + where A: Clone + $trt, + D: Dimension, + B: ScalarOperand, +{ + type Output = Array; + fn $mth(self, x: B) -> Self::Output { self.map(move |elt| elt.clone() $operator x.clone()) } @@ -277,7 +379,21 @@ impl<'a, S, D> $trt<&'a ArrayBase> for $scalar D: Dimension, { type Output = Array<$scalar, D>; + fn $mth(self, rhs: &ArrayBase) -> Self::Output { + self.$mth(&**rhs) + } +} + +// Perform elementwise +// between the scalar `self` and array `rhs`, +// and return the result as a new `Array`. +impl<'a, D> $trt<&'a ArrayRef<$scalar, D>> for $scalar + where D: Dimension +{ + type Output = Array<$scalar, D>; + + fn $mth(self, rhs: &ArrayRef<$scalar, D>) -> Self::Output { if_commutative!($commutative { rhs.$mth(self) } or { @@ -381,6 +497,7 @@ mod arithmetic_ops D: Dimension, { type Output = Self; + /// Perform an elementwise negation of `self` and return the result. fn neg(mut self) -> Self { @@ -398,6 +515,22 @@ mod arithmetic_ops D: Dimension, { type Output = Array; + + /// Perform an elementwise negation of reference `self` and return the + /// result as a new `Array`. + fn neg(self) -> Array + { + (&**self).neg() + } + } + + impl<'a, A, D> Neg for &'a ArrayRef + where + &'a A: 'a + Neg, + D: Dimension, + { + type Output = Array; + /// Perform an elementwise negation of reference `self` and return the /// result as a new `Array`. fn neg(self) -> Array @@ -413,6 +546,7 @@ mod arithmetic_ops D: Dimension, { type Output = Self; + /// Perform an elementwise unary not of `self` and return the result. fn not(mut self) -> Self { @@ -430,6 +564,22 @@ mod arithmetic_ops D: Dimension, { type Output = Array; + + /// Perform an elementwise unary not of reference `self` and return the + /// result as a new `Array`. + fn not(self) -> Array + { + (&**self).not() + } + } + + impl<'a, A, D> Not for &'a ArrayRef + where + &'a A: 'a + Not, + D: Dimension, + { + type Output = Array; + /// Perform an elementwise unary not of reference `self` and return the /// result as a new `Array`. fn not(self) -> Array @@ -462,6 +612,22 @@ mod assign_ops { #[track_caller] fn $method(&mut self, rhs: &ArrayBase) { + (**self).$method(&**rhs) + } + } + + #[doc=$doc] + /// If their shapes disagree, `rhs` is broadcast to the shape of `self`. + /// + /// **Panics** if broadcasting isn’t possible. + impl<'a, A, D, E> $trt<&'a ArrayRef> for ArrayRef + where + A: Clone + $trt
, + D: Dimension, + E: Dimension, + { + #[track_caller] + fn $method(&mut self, rhs: &ArrayRef) { self.zip_mut_with(rhs, |x, y| { x.$method(y.clone()); }); @@ -474,6 +640,17 @@ mod assign_ops A: ScalarOperand + $trt, S: DataMut, D: Dimension, + { + fn $method(&mut self, rhs: A) { + (**self).$method(rhs) + } + } + + #[doc=$doc] + impl $trt for ArrayRef + where + A: ScalarOperand + $trt, + D: Dimension, { fn $method(&mut self, rhs: A) { self.map_inplace(move |elt| { diff --git a/src/impl_owned_array.rs b/src/impl_owned_array.rs index bb970f876..023e9ebb4 100644 --- a/src/impl_owned_array.rs +++ b/src/impl_owned_array.rs @@ -849,7 +849,7 @@ where D: Dimension 0 }; debug_assert!(data_to_array_offset >= 0); - self.ptr = self + self.layout.ptr = self .data .reserve(len_to_append) .offset(data_to_array_offset); @@ -880,7 +880,7 @@ pub(crate) unsafe fn drop_unreachable_raw( } sort_axes_in_default_order(&mut self_); // with uninverted axes this is now the element with lowest address - let array_memory_head_ptr = self_.ptr; + let array_memory_head_ptr = self_.layout.ptr; let data_end_ptr = data_ptr.add(data_len); debug_assert!(data_ptr <= array_memory_head_ptr); debug_assert!(array_memory_head_ptr <= data_end_ptr); @@ -897,19 +897,19 @@ pub(crate) unsafe fn drop_unreachable_raw( // As an optimization, the innermost axis is removed if it has stride 1, because // we then have a long stretch of contiguous elements we can skip as one. let inner_lane_len; - if self_.ndim() > 1 && self_.strides.last_elem() == 1 { - self_.dim.slice_mut().rotate_right(1); - self_.strides.slice_mut().rotate_right(1); - inner_lane_len = self_.dim[0]; - self_.dim[0] = 1; - self_.strides[0] = 1; + if self_.ndim() > 1 && self_.layout.strides.last_elem() == 1 { + self_.layout.dim.slice_mut().rotate_right(1); + self_.layout.strides.slice_mut().rotate_right(1); + inner_lane_len = self_.layout.dim[0]; + self_.layout.dim[0] = 1; + self_.layout.strides[0] = 1; } else { inner_lane_len = 1; } // iter is a raw pointer iterator traversing the array in memory order now with the // sorted axes. - let mut iter = Baseiter::new(self_.ptr, self_.dim, self_.strides); + let mut iter = Baseiter::new(self_.layout.ptr, self_.layout.dim, self_.layout.strides); let mut dropped_elements = 0; let mut last_ptr = data_ptr; @@ -948,7 +948,7 @@ where if a.ndim() <= 1 { return; } - sort_axes1_impl(&mut a.dim, &mut a.strides); + sort_axes1_impl(&mut a.layout.dim, &mut a.layout.strides); } fn sort_axes1_impl(adim: &mut D, astrides: &mut D) @@ -988,7 +988,7 @@ where if a.ndim() <= 1 { return; } - sort_axes2_impl(&mut a.dim, &mut a.strides, &mut b.dim, &mut b.strides); + sort_axes2_impl(&mut a.layout.dim, &mut a.layout.strides, &mut b.layout.dim, &mut b.layout.strides); } fn sort_axes2_impl(adim: &mut D, astrides: &mut D, bdim: &mut D, bstrides: &mut D) diff --git a/src/impl_raw_views.rs b/src/impl_raw_views.rs index 5132b1158..5bb2a0e42 100644 --- a/src/impl_raw_views.rs +++ b/src/impl_raw_views.rs @@ -98,10 +98,10 @@ where D: Dimension pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayView::new(self.ptr, self.dim, self.strides) + ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -113,23 +113,23 @@ where D: Dimension pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { assert!(index <= self.len_of(axis)); - let left_ptr = self.ptr.as_ptr(); + let left_ptr = self.layout.ptr.as_ptr(); let right_ptr = if index == self.len_of(axis) { - self.ptr.as_ptr() + self.layout.ptr.as_ptr() } else { - let offset = stride_offset(index, self.strides.axis(axis)); + let offset = stride_offset(index, self.layout.strides.axis(axis)); // The `.offset()` is safe due to the guarantees of `RawData`. - unsafe { self.ptr.as_ptr().offset(offset) } + unsafe { self.layout.ptr.as_ptr().offset(offset) } }; - let mut dim_left = self.dim.clone(); + let mut dim_left = self.layout.dim.clone(); dim_left.set_axis(axis, index); - let left = unsafe { Self::new_(left_ptr, dim_left, self.strides.clone()) }; + let left = unsafe { Self::new_(left_ptr, dim_left, self.layout.strides.clone()) }; - let mut dim_right = self.dim; + let mut dim_right = self.layout.dim; let right_len = dim_right.axis(axis) - index; dim_right.set_axis(axis, right_len); - let right = unsafe { Self::new_(right_ptr, dim_right, self.strides) }; + let right = unsafe { Self::new_(right_ptr, dim_right, self.layout.strides) }; (left, right) } @@ -152,8 +152,8 @@ where D: Dimension mem::size_of::(), "size mismatch in raw view cast" ); - let ptr = self.ptr.cast::(); - unsafe { RawArrayView::new(ptr, self.dim, self.strides) } + let ptr = self.layout.ptr.cast::(); + unsafe { RawArrayView::new(ptr, self.layout.dim, self.layout.strides) } } } @@ -172,11 +172,11 @@ where D: Dimension ); assert_eq!(mem::align_of::>(), mem::align_of::()); - let dim = self.dim.clone(); + let dim = self.layout.dim.clone(); // Double the strides. In the zero-sized element case and for axes of // length <= 1, we leave the strides as-is to avoid possible overflow. - let mut strides = self.strides.clone(); + let mut strides = self.layout.strides.clone(); if mem::size_of::() != 0 { for ax in 0..strides.ndim() { if dim[ax] > 1 { @@ -185,7 +185,7 @@ where D: Dimension } } - let ptr_re: *mut T = self.ptr.as_ptr().cast(); + let ptr_re: *mut T = self.layout.ptr.as_ptr().cast(); let ptr_im: *mut T = if self.is_empty() { // In the empty case, we can just reuse the existing pointer since // it won't be dereferenced anyway. It is not safe to offset by @@ -308,7 +308,7 @@ where D: Dimension #[inline] pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Converts to a read-only view of the array. @@ -323,10 +323,10 @@ where D: Dimension pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayView::new(self.ptr, self.dim, self.strides) + ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Converts to a mutable view of the array. @@ -341,10 +341,10 @@ where D: Dimension pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> { debug_assert!( - is_aligned(self.ptr.as_ptr()), + is_aligned(self.layout.ptr.as_ptr()), "The pointer must be aligned." ); - ArrayViewMut::new(self.ptr, self.dim, self.strides) + ArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } /// Split the array view along `axis` and return one array pointer strictly @@ -356,7 +356,12 @@ where D: Dimension pub fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) { let (left, right) = self.into_raw_view().split_at(axis, index); - unsafe { (Self::new(left.ptr, left.dim, left.strides), Self::new(right.ptr, right.dim, right.strides)) } + unsafe { + ( + Self::new(left.layout.ptr, left.layout.dim, left.layout.strides), + Self::new(right.layout.ptr, right.layout.dim, right.layout.strides), + ) + } } /// Cast the raw pointer of the raw array view to a different type @@ -377,8 +382,8 @@ where D: Dimension mem::size_of::(), "size mismatch in raw view cast" ); - let ptr = self.ptr.cast::(); - unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) } + let ptr = self.layout.ptr.cast::(); + unsafe { RawArrayViewMut::new(ptr, self.layout.dim, self.layout.strides) } } } @@ -392,8 +397,8 @@ where D: Dimension let Complex { re, im } = self.into_raw_view().split_complex(); unsafe { Complex { - re: RawArrayViewMut::new(re.ptr, re.dim, re.strides), - im: RawArrayViewMut::new(im.ptr, im.dim, im.strides), + re: RawArrayViewMut::new(re.layout.ptr, re.layout.dim, re.layout.strides), + im: RawArrayViewMut::new(im.layout.ptr, im.layout.dim, im.layout.strides), } } } diff --git a/src/impl_ref_types.rs b/src/impl_ref_types.rs new file mode 100644 index 000000000..d93a996bf --- /dev/null +++ b/src/impl_ref_types.rs @@ -0,0 +1,370 @@ +//! Implementations that connect arrays to their reference types. +//! +//! `ndarray` has four kinds of array types that users may interact with: +//! 1. [`ArrayBase`], which represents arrays which own their layout (shape and strides) +//! 2. [`ArrayRef`], which represents a read-safe, uniquely-owned look at an array +//! 3. [`RawRef`], which represents a read-unsafe, possibly-shared look at an array +//! 4. [`LayoutRef`], which represents a look at an array's underlying structure, +//! but does not allow data reading of any kind +//! +//! These types are connected through a number of `Deref` and `AsRef` implementations. +//! 1. `ArrayBase` dereferences to `ArrayRef` when `S: Data` +//! 2. `ArrayBase` mutably dereferences to `ArrayRef` when `S: DataMut`, and ensures uniqueness +//! 3. `ArrayRef` mutably dereferences to `RawRef` +//! 4. `RawRef` mutably dereferences to `LayoutRef` +//! This chain works very well for arrays whose data is safe to read and is uniquely held. +//! Because raw views do not meet `S: Data`, they cannot dereference to `ArrayRef`; furthermore, +//! technical limitations of Rust's compiler means that `ArrayBase` cannot have multiple `Deref` implementations. +//! In addition, shared-data arrays do not want to go down the `Deref` path to get to methods on `RawRef` +//! or `LayoutRef`, since that would unecessarily ensure their uniqueness. +//! +//! To mitigate these problems, `ndarray` also provides `AsRef` and `AsMut` implementations as follows: +//! 1. `ArrayBase` implements `AsRef` to `RawRef` and `LayoutRef` when `S: RawData` +//! 2. `ArrayBase` implements `AsMut` to `RawRef` when `S: RawDataMut` +//! 3. `ArrayBase` implements `AsRef` and `AsMut` to `LayoutRef` unconditionally +//! 4. `ArrayRef` implements `AsRef` and `AsMut` to `RawRef` and `LayoutRef` unconditionally +//! 5. `RawRef` implements `AsRef` and `AsMut` to `LayoutRef` +//! 6. `RawRef` and `LayoutRef` implement `AsRef` and `AsMut` to themselves +//! +//! This allows users to write a single method or trait implementation that takes `T: AsRef>` +//! or `T: AsRef>` and have that functionality work on any of the relevant array types. + +use alloc::borrow::ToOwned; +use core::{ + borrow::{Borrow, BorrowMut}, + ops::{Deref, DerefMut}, +}; + +use crate::{Array, ArrayBase, ArrayRef, Data, DataMut, Dimension, LayoutRef, RawData, RawDataMut, RawRef}; + +// D1: &ArrayBase -> &ArrayRef when data is safe to read +impl Deref for ArrayBase +where S: Data +{ + type Target = ArrayRef; + + fn deref(&self) -> &Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(&self.layout as *const LayoutRef).cast::>() } + } +} + +// D2: &mut ArrayBase -> &mut ArrayRef when data is safe to read; ensure uniqueness +impl DerefMut for ArrayBase +where + S: DataMut, + D: Dimension, +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + self.ensure_unique(); + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(&mut self.layout as *mut LayoutRef).cast::>() } + } +} + +// D3: &ArrayRef -> &RawRef +impl Deref for ArrayRef +{ + type Target = RawRef; + + fn deref(&self) -> &Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(self as *const ArrayRef).cast::>() } + } +} + +// D4: &mut ArrayRef -> &mut RawRef +impl DerefMut for ArrayRef +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(self as *mut ArrayRef).cast::>() } + } +} + +// D5: &RawRef -> &LayoutRef +impl Deref for RawRef +{ + type Target = LayoutRef; + + fn deref(&self) -> &Self::Target + { + &self.0 + } +} + +// D5: &mut RawRef -> &mut LayoutRef +impl DerefMut for RawRef +{ + fn deref_mut(&mut self) -> &mut Self::Target + { + &mut self.0 + } +} + +// A1: &ArrayBase -AR-> &RawRef +impl AsRef> for ArrayBase +where S: RawData +{ + fn as_ref(&self) -> &RawRef + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &*(&self.layout as *const LayoutRef).cast::>() } + } +} + +// A2: &mut ArrayBase -AM-> &mut RawRef +impl AsMut> for ArrayBase +where S: RawDataMut +{ + fn as_mut(&mut self) -> &mut RawRef + { + // SAFETY: + // - The pointer is aligned because neither type uses repr(align) + // - It is "dereferencable" because it comes from a reference + // - For the same reason, it is initialized + // - The cast is valid because ArrayRef uses #[repr(transparent)] + unsafe { &mut *(&mut self.layout as *mut LayoutRef).cast::>() } + } +} + +// A3: &ArrayBase -AR-> &LayoutRef +impl AsRef> for ArrayBase +where S: RawData +{ + fn as_ref(&self) -> &LayoutRef + { + &self.layout + } +} + +// A3: &mut ArrayBase -AM-> &mut LayoutRef +impl AsMut> for ArrayBase +where S: RawData +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + &mut self.layout + } +} + +// A4: &ArrayRef -AR-> &RawRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &RawRef + { + self + } +} + +// A4: &mut ArrayRef -AM-> &mut RawRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut RawRef + { + self + } +} + +// A4: &ArrayRef -AR-> &LayoutRef +impl AsRef> for ArrayRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A4: &mut ArrayRef -AM-> &mut LayoutRef +impl AsMut> for ArrayRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +// A5: &RawRef -AR-> &LayoutRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A5: &mut RawRef -AM-> &mut LayoutRef +impl AsMut> for RawRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +// A6: &RawRef -AR-> &RawRef +impl AsRef> for RawRef +{ + fn as_ref(&self) -> &RawRef + { + self + } +} + +// A6: &mut RawRef -AM-> &mut RawRef +impl AsMut> for RawRef +{ + fn as_mut(&mut self) -> &mut RawRef + { + self + } +} + +// A6: &LayoutRef -AR-> &LayoutRef +impl AsRef> for LayoutRef +{ + fn as_ref(&self) -> &LayoutRef + { + self + } +} + +// A6: &mut LayoutRef -AR-> &mut LayoutRef +impl AsMut> for LayoutRef +{ + fn as_mut(&mut self) -> &mut LayoutRef + { + self + } +} + +/// # Safety +/// +/// Usually the pointer would be bad to just clone, as we'd have aliasing +/// and completely separated references to the same data. However, it is +/// impossible to read the data behind the pointer from a LayoutRef (this +/// is a safety invariant that *must* be maintained), and therefore we can +/// Clone and Copy as desired. +impl Clone for LayoutRef +{ + fn clone(&self) -> Self + { + Self { + dim: self.dim.clone(), + strides: self.strides.clone(), + ptr: self.ptr, + } + } +} + +impl Copy for LayoutRef {} + +impl Borrow> for ArrayBase +where S: RawData +{ + fn borrow(&self) -> &RawRef + { + self.as_ref() + } +} + +impl BorrowMut> for ArrayBase +where S: RawDataMut +{ + fn borrow_mut(&mut self) -> &mut RawRef + { + self.as_mut() + } +} + +impl Borrow> for ArrayBase +where S: Data +{ + fn borrow(&self) -> &ArrayRef + { + self + } +} + +impl BorrowMut> for ArrayBase +where + S: DataMut, + D: Dimension, +{ + fn borrow_mut(&mut self) -> &mut ArrayRef + { + self + } +} + +impl ToOwned for ArrayRef +where + A: Clone, + D: Dimension, +{ + type Owned = Array; + + fn to_owned(&self) -> Self::Owned + { + self.to_owned() + } + + fn clone_into(&self, target: &mut Array) + { + target.zip_mut_with(self, |tgt, src| tgt.clone_from(src)); + } +} + +/// Shortcuts for the various as_ref calls +impl ArrayBase +where S: RawData +{ + /// Cheaply convert a reference to the array to an &LayoutRef + pub fn as_layout_ref(&self) -> &LayoutRef + { + self.as_ref() + } + + /// Cheaply and mutably convert a reference to the array to an &LayoutRef + pub fn as_layout_ref_mut(&mut self) -> &mut LayoutRef + { + self.as_mut() + } + + /// Cheaply convert a reference to the array to an &RawRef + pub fn as_raw_ref(&self) -> &RawRef + { + self.as_ref() + } + + /// Cheaply and mutably convert a reference to the array to an &RawRef + pub fn as_raw_ref_mut(&mut self) -> &mut RawRef + where S: RawDataMut + { + self.as_mut() + } +} diff --git a/src/impl_special_element_types.rs b/src/impl_special_element_types.rs index e430b20bc..42b524bc2 100644 --- a/src/impl_special_element_types.rs +++ b/src/impl_special_element_types.rs @@ -9,6 +9,7 @@ use std::mem::MaybeUninit; use crate::imp_prelude::*; +use crate::LayoutRef; use crate::RawDataSubst; /// Methods specific to arrays with `MaybeUninit` elements. @@ -35,9 +36,7 @@ where { let ArrayBase { data, - ptr, - dim, - strides, + layout: LayoutRef { ptr, dim, strides }, } = self; // "transmute" from storage of MaybeUninit to storage of A diff --git a/src/impl_views/constructors.rs b/src/impl_views/constructors.rs index d0089057d..29b7c13d7 100644 --- a/src/impl_views/constructors.rs +++ b/src/impl_views/constructors.rs @@ -225,7 +225,7 @@ where D: Dimension pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D> where 'a: 'b { - unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } diff --git a/src/impl_views/conversions.rs b/src/impl_views/conversions.rs index 1dd7d97f2..efd876f7a 100644 --- a/src/impl_views/conversions.rs +++ b/src/impl_views/conversions.rs @@ -29,13 +29,13 @@ where D: Dimension pub fn reborrow<'b>(self) -> ArrayView<'b, A, D> where 'a: 'b { - unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Return the array’s data as a slice, if it is contiguous and in standard order. /// Return `None` otherwise. /// - /// Note that while the method is similar to [`ArrayBase::as_slice()`], this method transfers + /// Note that while the method is similar to [`ArrayRef::as_slice()`], this method transfers /// the view's lifetime to the slice, so it is a bit more powerful. pub fn to_slice(&self) -> Option<&'a [A]> { @@ -50,7 +50,7 @@ where D: Dimension /// Return `None` otherwise. /// /// Note that while the method is similar to - /// [`ArrayBase::as_slice_memory_order()`], this method transfers the view's + /// [`ArrayRef::as_slice_memory_order()`], this method transfers the view's /// lifetime to the slice, so it is a bit more powerful. pub fn to_slice_memory_order(&self) -> Option<&'a [A]> { @@ -66,7 +66,7 @@ where D: Dimension #[inline] pub(crate) fn into_raw_view(self) -> RawArrayView { - unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -199,7 +199,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -209,7 +209,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } } @@ -220,7 +220,7 @@ where D: Dimension #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] @@ -250,19 +250,19 @@ where D: Dimension // Convert into a read-only view pub(crate) fn into_view(self) -> ArrayView<'a, A, D> { - unsafe { ArrayView::new(self.ptr, self.dim, self.strides) } + unsafe { ArrayView::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } /// Converts to a mutable raw array view. pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut { - unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) } + unsafe { RawArrayViewMut::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] pub(crate) fn into_base_iter(self) -> Baseiter { - unsafe { Baseiter::new(self.ptr, self.dim, self.strides) } + unsafe { Baseiter::new(self.layout.ptr, self.layout.dim, self.layout.strides) } } #[inline] diff --git a/src/impl_views/indexing.rs b/src/impl_views/indexing.rs index 827313478..2879e7416 100644 --- a/src/impl_views/indexing.rs +++ b/src/impl_views/indexing.rs @@ -60,7 +60,7 @@ pub trait IndexLonger /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get + /// [1]: ArrayRef::get /// /// **Panics** if index is out of bounds. #[track_caller] @@ -68,15 +68,15 @@ pub trait IndexLonger /// Get a reference of a element through the view. /// - /// This method is like `ArrayBase::get` but with a longer lifetime (matching + /// This method is like `ArrayRef::get` but with a longer lifetime (matching /// the array view); which we can only do for the array view and not in the /// `Index` trait. /// /// See also [the `get` method][1] (and [`get_mut`][2]) which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get - /// [2]: ArrayBase::get_mut + /// [1]: ArrayRef::get + /// [2]: ArrayRef::get_mut /// /// **Panics** if index is out of bounds. #[track_caller] @@ -90,7 +90,7 @@ pub trait IndexLonger /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget + /// [1]: ArrayRef::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. /// @@ -116,7 +116,7 @@ where /// See also [the `get` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get + /// [1]: ArrayRef::get /// /// **Panics** if index is out of bounds. #[track_caller] @@ -139,7 +139,7 @@ where /// See also [the `uget` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget + /// [1]: ArrayRef::uget /// /// **Note:** only unchecked for non-debug builds of ndarray. unsafe fn uget(self, index: I) -> &'a A @@ -165,7 +165,7 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get_mut + /// [1]: ArrayRef::get_mut /// /// **Panics** if index is out of bounds. #[track_caller] @@ -186,7 +186,7 @@ where /// See also [the `get_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::get_mut + /// [1]: ArrayRef::get_mut /// fn get(mut self, index: I) -> Option<&'a mut A> { @@ -205,7 +205,7 @@ where /// See also [the `uget_mut` method][1] which works for all arrays and array /// views. /// - /// [1]: ArrayBase::uget_mut + /// [1]: ArrayRef::uget_mut /// /// **Note:** only unchecked for non-debug builds of ndarray. unsafe fn uget(mut self, index: I) -> &'a mut A diff --git a/src/impl_views/splitting.rs b/src/impl_views/splitting.rs index 58d0a7556..42b12b159 100644 --- a/src/impl_views/splitting.rs +++ b/src/impl_views/splitting.rs @@ -157,7 +157,7 @@ where D: Dimension /// [`MultiSliceArg`], [`s!`], [`SliceArg`](crate::SliceArg), and /// [`SliceInfo`](crate::SliceInfo). /// - /// [`.multi_slice_mut()`]: ArrayBase::multi_slice_mut + /// [`.multi_slice_mut()`]: ArrayRef::multi_slice_mut /// /// **Panics** if any of the following occur: /// diff --git a/src/iterators/chunks.rs b/src/iterators/chunks.rs index 9e2f08e1e..4dd99f002 100644 --- a/src/iterators/chunks.rs +++ b/src/iterators/chunks.rs @@ -27,7 +27,7 @@ impl_ndproducer! { /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more +/// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. //#[derive(Debug)] pub struct ExactChunks<'a, A, D> @@ -59,10 +59,10 @@ impl<'a, A, D: Dimension> ExactChunks<'a, A, D> a.shape() ); for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; + a.layout.dim[i] /= chunk[i]; } - let inner_strides = a.strides.clone(); - a.strides *= &chunk; + let inner_strides = a.layout.strides.clone(); + a.layout.strides *= &chunk; ExactChunks { base: a, @@ -93,7 +93,7 @@ where /// Exact chunks iterator. /// -/// See [`.exact_chunks()`](ArrayBase::exact_chunks) for more +/// See [`.exact_chunks()`](crate::ArrayRef::exact_chunks) for more /// information. pub struct ExactChunksIter<'a, A, D> { @@ -126,7 +126,7 @@ impl_ndproducer! { /// Exact chunks producer and iterable. /// -/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) +/// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. //#[derive(Debug)] pub struct ExactChunksMut<'a, A, D> @@ -158,10 +158,10 @@ impl<'a, A, D: Dimension> ExactChunksMut<'a, A, D> a.shape() ); for i in 0..a.ndim() { - a.dim[i] /= chunk[i]; + a.layout.dim[i] /= chunk[i]; } - let inner_strides = a.strides.clone(); - a.strides *= &chunk; + let inner_strides = a.layout.strides.clone(); + a.layout.strides *= &chunk; ExactChunksMut { base: a, @@ -237,7 +237,7 @@ impl_iterator! { /// Exact chunks iterator. /// -/// See [`.exact_chunks_mut()`](ArrayBase::exact_chunks_mut) +/// See [`.exact_chunks_mut()`](crate::ArrayRef::exact_chunks_mut) /// for more information. pub struct ExactChunksIterMut<'a, A, D> { diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs index 9374608cb..b51315a0f 100644 --- a/src/iterators/into_iter.rs +++ b/src/iterators/into_iter.rs @@ -39,9 +39,9 @@ where D: Dimension let array_head_ptr = array.ptr; let mut array_data = array.data; let data_len = array_data.release_all_elements(); - debug_assert!(data_len >= array.dim.size()); - let has_unreachable_elements = array.dim.size() != data_len; - let inner = Baseiter::new(array_head_ptr, array.dim, array.strides); + debug_assert!(data_len >= array.layout.dim.size()); + let has_unreachable_elements = array.layout.dim.size() != data_len; + let inner = Baseiter::new(array_head_ptr, array.layout.dim, array.layout.strides); IntoIter { array_data, diff --git a/src/iterators/lanes.rs b/src/iterators/lanes.rs index 11c83d002..0f9678872 100644 --- a/src/iterators/lanes.rs +++ b/src/iterators/lanes.rs @@ -23,7 +23,7 @@ impl_ndproducer! { } } -/// See [`.lanes()`](ArrayBase::lanes) +/// See [`.lanes()`](crate::ArrayRef::lanes) /// for more information. pub struct Lanes<'a, A, D> { @@ -92,7 +92,7 @@ where D: Dimension } } -/// See [`.lanes_mut()`](ArrayBase::lanes_mut) +/// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. pub struct LanesMut<'a, A, D> { diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index e0da8f6c9..55a9920a8 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -333,7 +333,7 @@ pub enum ElementsRepr /// /// Iterator element type is `&'a A`. /// -/// See [`.iter()`](ArrayBase::iter) for more information. +/// See [`.iter()`](crate::ArrayRef::iter) for more information. #[derive(Debug)] pub struct Iter<'a, A, D> { @@ -352,7 +352,7 @@ pub struct ElementsBase<'a, A, D> /// /// Iterator element type is `&'a mut A`. /// -/// See [`.iter_mut()`](ArrayBase::iter_mut) for more information. +/// See [`.iter_mut()`](crate::ArrayRef::iter_mut) for more information. #[derive(Debug)] pub struct IterMut<'a, A, D> { @@ -382,12 +382,12 @@ impl<'a, A, D: Dimension> ElementsBaseMut<'a, A, D> /// An iterator over the indexes and elements of an array. /// -/// See [`.indexed_iter()`](ArrayBase::indexed_iter) for more information. +/// See [`.indexed_iter()`](crate::ArrayRef::indexed_iter) for more information. #[derive(Clone)] pub struct IndexedIter<'a, A, D>(ElementsBase<'a, A, D>); /// An iterator over the indexes and elements of an array (mutable). /// -/// See [`.indexed_iter_mut()`](ArrayBase::indexed_iter_mut) for more information. +/// See [`.indexed_iter_mut()`](crate::ArrayRef::indexed_iter_mut) for more information. pub struct IndexedIterMut<'a, A, D>(ElementsBaseMut<'a, A, D>); impl<'a, A, D> IndexedIter<'a, A, D> @@ -726,7 +726,7 @@ where D: Dimension /// An iterator that traverses over all axes but one, and yields a view for /// each lane along that axis. /// -/// See [`.lanes()`](ArrayBase::lanes) for more information. +/// See [`.lanes()`](crate::ArrayRef::lanes) for more information. pub struct LanesIter<'a, A, D> { inner_len: Ix, @@ -789,7 +789,7 @@ impl DoubleEndedIterator for LanesIter<'_, A, Ix1> /// An iterator that traverses over all dimensions but the innermost, /// and yields each inner row (mutable). /// -/// See [`.lanes_mut()`](ArrayBase::lanes_mut) +/// See [`.lanes_mut()`](crate::ArrayRef::lanes_mut) /// for more information. pub struct LanesIterMut<'a, A, D> { @@ -1004,8 +1004,8 @@ where D: Dimension /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.outer_iter()`](ArrayBase::outer_iter) -/// or [`.axis_iter()`](ArrayBase::axis_iter) +/// See [`.outer_iter()`](crate::ArrayRef::outer_iter) +/// or [`.axis_iter()`](crate::ArrayRef::axis_iter) /// for more information. #[derive(Debug)] pub struct AxisIter<'a, A, D> @@ -1105,8 +1105,8 @@ where D: Dimension /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.outer_iter_mut()`](ArrayBase::outer_iter_mut) -/// or [`.axis_iter_mut()`](ArrayBase::axis_iter_mut) +/// See [`.outer_iter_mut()`](crate::ArrayRef::outer_iter_mut) +/// or [`.axis_iter_mut()`](crate::ArrayRef::axis_iter_mut) /// for more information. pub struct AxisIterMut<'a, A, D> { @@ -1311,7 +1311,7 @@ impl NdProducer for AxisIterMut<'_, A, D> /// /// Iterator element type is `ArrayView<'a, A, D>`. /// -/// See [`.axis_chunks_iter()`](ArrayBase::axis_chunks_iter) for more information. +/// See [`.axis_chunks_iter()`](crate::ArrayRef::axis_chunks_iter) for more information. pub struct AxisChunksIter<'a, A, D> { iter: AxisIterCore, @@ -1369,7 +1369,7 @@ fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: u let mut inner_dim = v.dim.clone(); inner_dim[axis] = size; - let mut partial_chunk_dim = v.dim; + let mut partial_chunk_dim = v.layout.dim; partial_chunk_dim[axis] = chunk_remainder; let partial_chunk_index = n_whole_chunks; @@ -1378,8 +1378,8 @@ fn chunk_iter_parts(v: ArrayView<'_, A, D>, axis: Axis, size: u end: iter_len, stride, inner_dim, - inner_strides: v.strides, - ptr: v.ptr.as_ptr(), + inner_strides: v.layout.strides, + ptr: v.layout.ptr.as_ptr(), }; (iter, partial_chunk_index, partial_chunk_dim) @@ -1493,7 +1493,7 @@ macro_rules! chunk_iter_impl { /// /// Iterator element type is `ArrayViewMut<'a, A, D>`. /// -/// See [`.axis_chunks_iter_mut()`](ArrayBase::axis_chunks_iter_mut) +/// See [`.axis_chunks_iter_mut()`](crate::ArrayRef::axis_chunks_iter_mut) /// for more information. pub struct AxisChunksIterMut<'a, A, D> { diff --git a/src/iterators/windows.rs b/src/iterators/windows.rs index 1c2ab6a85..afdaaa895 100644 --- a/src/iterators/windows.rs +++ b/src/iterators/windows.rs @@ -9,7 +9,7 @@ use crate::Slice; /// Window producer and iterable /// -/// See [`.windows()`](ArrayBase::windows) for more +/// See [`.windows()`](crate::ArrayRef::windows) for more /// information. pub struct Windows<'a, A, D> { @@ -91,7 +91,7 @@ where /// Window iterator. /// -/// See [`.windows()`](ArrayBase::windows) for more +/// See [`.windows()`](crate::ArrayRef::windows) for more /// information. pub struct WindowsIter<'a, A, D> { @@ -129,7 +129,7 @@ send_sync_read_only!(WindowsIter); /// Window producer and iterable /// -/// See [`.axis_windows()`](ArrayBase::axis_windows) for more +/// See [`.axis_windows()`](crate::ArrayRef::axis_windows) for more /// information. pub struct AxisWindows<'a, A, D> { diff --git a/src/layout/mod.rs b/src/layout/mod.rs index 026688d63..36853848e 100644 --- a/src/layout/mod.rs +++ b/src/layout/mod.rs @@ -1,6 +1,6 @@ mod layoutfmt; -// Layout it a bitset used for internal layout description of +// Layout is a bitset used for internal layout description of // arrays, producers and sets of producers. // The type is public but users don't interact with it. #[doc(hidden)] diff --git a/src/lib.rs b/src/lib.rs index 9ba3b6728..77bdc9313 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,12 +31,17 @@ //! dimensions, then an element in the array is accessed by using that many indices. //! Each dimension is also called an *axis*. //! +//! To get started, functionality is provided in the following core types: //! - **[`ArrayBase`]**: //! The *n*-dimensional array type itself.
//! It is used to implement both the owned arrays and the views; see its docs //! for an overview of all array features.
//! - The main specific array type is **[`Array`]**, which owns //! its elements. +//! - A reference type, **[`ArrayRef`]**, that contains most of the functionality +//! for reading and writing to arrays. +//! - A reference type, **[`LayoutRef`]**, that contains most of the functionality +//! for reading and writing to array layouts: their shape and strides. //! //! ## Highlights //! @@ -62,8 +67,8 @@ //! - Performance: //! + Prefer higher order methods and arithmetic operations on arrays first, //! then iteration, and as a last priority using indexed algorithms. -//! + The higher order functions like [`.map()`](ArrayBase::map), -//! [`.map_inplace()`](ArrayBase::map_inplace), [`.zip_mut_with()`](ArrayBase::zip_mut_with), +//! + The higher order functions like [`.map()`](ArrayRef::map), +//! [`.map_inplace()`](ArrayRef::map_inplace), [`.zip_mut_with()`](ArrayRef::zip_mut_with), //! [`Zip`] and [`azip!()`](azip) are the most efficient ways //! to perform single traversal and lock step traversal respectively. //! + Performance of an operation depends on the memory layout of the array @@ -163,6 +168,7 @@ pub use crate::shape_builder::{Shape, ShapeArg, ShapeBuilder, StrideShape}; mod macro_utils; #[macro_use] mod private; +mod impl_ref_types; mod aliases; #[macro_use] mod itertools; @@ -312,7 +318,7 @@ pub type Ixs = isize; /// data (shared ownership). /// Sharing requires that it uses copy-on-write for mutable operations. /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](Self::view_mut) or [`get_mut()`](Self::get_mut), +/// [`view_mut()`](ArrayRef::view_mut) or [`get_mut()`](ArrayRef::get_mut), /// will break sharing and require a clone of the data (if it is not uniquely held). /// /// ## `CowArray` @@ -336,9 +342,9 @@ pub type Ixs = isize; /// Please see the documentation for the respective array view for an overview /// of methods specific to array views: [`ArrayView`], [`ArrayViewMut`]. /// -/// A view is created from an array using [`.view()`](ArrayBase::view), -/// [`.view_mut()`](ArrayBase::view_mut), using -/// slicing ([`.slice()`](ArrayBase::slice), [`.slice_mut()`](ArrayBase::slice_mut)) or from one of +/// A view is created from an array using [`.view()`](ArrayRef::view), +/// [`.view_mut()`](ArrayRef::view_mut), using +/// slicing ([`.slice()`](ArrayRef::slice), [`.slice_mut()`](ArrayRef::slice_mut)) or from one of /// the many iterators that yield array views. /// /// You can also create an array view from a regular slice of data not @@ -480,12 +486,12 @@ pub type Ixs = isize; /// [`.columns()`][gc], [`.columns_mut()`][gcm], /// [`.lanes(axis)`][l], [`.lanes_mut(axis)`][lm]. /// -/// [gr]: Self::rows -/// [grm]: Self::rows_mut -/// [gc]: Self::columns -/// [gcm]: Self::columns_mut -/// [l]: Self::lanes -/// [lm]: Self::lanes_mut +/// [gr]: ArrayRef::rows +/// [grm]: ArrayRef::rows_mut +/// [gc]: ArrayRef::columns +/// [gcm]: ArrayRef::columns_mut +/// [l]: ArrayRef::lanes +/// [lm]: ArrayRef::lanes_mut /// /// Yes, for 2D arrays `.rows()` and `.outer_iter()` have about the same /// effect: @@ -511,10 +517,10 @@ pub type Ixs = isize; /// [`.slice_collapse()`] panics on `NewAxis` elements and behaves like /// [`.collapse_axis()`] by preserving the number of dimensions. /// -/// [`.slice()`]: Self::slice -/// [`.slice_mut()`]: Self::slice_mut +/// [`.slice()`]: ArrayRef::slice +/// [`.slice_mut()`]: ArrayRef::slice_mut /// [`.slice_move()`]: Self::slice_move -/// [`.slice_collapse()`]: Self::slice_collapse +/// [`.slice_collapse()`]: LayoutRef::slice_collapse /// /// When slicing arrays with generic dimensionality, creating an instance of /// [`SliceInfo`] to pass to the multi-axis slicing methods like [`.slice()`] @@ -523,17 +529,17 @@ pub type Ixs = isize; /// or to create a view and then slice individual axes of the view using /// methods such as [`.slice_axis_inplace()`] and [`.collapse_axis()`]. /// -/// [`.slice_each_axis()`]: Self::slice_each_axis -/// [`.slice_each_axis_mut()`]: Self::slice_each_axis_mut +/// [`.slice_each_axis()`]: ArrayRef::slice_each_axis +/// [`.slice_each_axis_mut()`]: ArrayRef::slice_each_axis_mut /// [`.slice_each_axis_inplace()`]: Self::slice_each_axis_inplace /// [`.slice_axis_inplace()`]: Self::slice_axis_inplace -/// [`.collapse_axis()`]: Self::collapse_axis +/// [`.collapse_axis()`]: LayoutRef::collapse_axis /// /// It's possible to take multiple simultaneous *mutable* slices with /// [`.multi_slice_mut()`] or (for [`ArrayViewMut`] only) /// [`.multi_slice_move()`]. /// -/// [`.multi_slice_mut()`]: Self::multi_slice_mut +/// [`.multi_slice_mut()`]: ArrayRef::multi_slice_mut /// [`.multi_slice_move()`]: ArrayViewMut#method.multi_slice_move /// /// ``` @@ -632,16 +638,16 @@ pub type Ixs = isize; /// Methods for selecting an individual subview take two arguments: `axis` and /// `index`. /// -/// [`.axis_iter()`]: Self::axis_iter -/// [`.axis_iter_mut()`]: Self::axis_iter_mut -/// [`.fold_axis()`]: Self::fold_axis -/// [`.index_axis()`]: Self::index_axis -/// [`.index_axis_inplace()`]: Self::index_axis_inplace -/// [`.index_axis_mut()`]: Self::index_axis_mut +/// [`.axis_iter()`]: ArrayRef::axis_iter +/// [`.axis_iter_mut()`]: ArrayRef::axis_iter_mut +/// [`.fold_axis()`]: ArrayRef::fold_axis +/// [`.index_axis()`]: ArrayRef::index_axis +/// [`.index_axis_inplace()`]: LayoutRef::index_axis_inplace +/// [`.index_axis_mut()`]: ArrayRef::index_axis_mut /// [`.index_axis_move()`]: Self::index_axis_move -/// [`.collapse_axis()`]: Self::collapse_axis -/// [`.outer_iter()`]: Self::outer_iter -/// [`.outer_iter_mut()`]: Self::outer_iter_mut +/// [`.collapse_axis()`]: LayoutRef::collapse_axis +/// [`.outer_iter()`]: ArrayRef::outer_iter +/// [`.outer_iter_mut()`]: ArrayRef::outer_iter_mut /// /// ``` /// @@ -747,7 +753,7 @@ pub type Ixs = isize; /// Arrays support limited *broadcasting*, where arithmetic operations with /// array operands of different sizes can be carried out by repeating the /// elements of the smaller dimension array. See -/// [`.broadcast()`](Self::broadcast) for a more detailed +/// [`.broadcast()`](ArrayRef::broadcast) for a more detailed /// description. /// /// ``` @@ -1048,9 +1054,9 @@ pub type Ixs = isize; /// `&[A]` | `ArrayView` | [`::from_shape()`](ArrayView#method.from_shape) /// `&mut [A]` | `ArrayViewMut1
` | [`::from()`](ArrayViewMut#method.from) /// `&mut [A]` | `ArrayViewMut` | [`::from_shape()`](ArrayViewMut#method.from_shape) -/// `&ArrayBase` | `Vec` | [`.to_vec()`](Self::to_vec) +/// `&ArrayBase` | `Vec` | [`.to_vec()`](ArrayRef::to_vec) /// `Array` | `Vec` | [`.into_raw_vec()`](Array#method.into_raw_vec)[1](#into_raw_vec) -/// `&ArrayBase` | `&[A]` | [`.as_slice()`](Self::as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](Self::as_slice_memory_order)[3](#req_contig) +/// `&ArrayBase` | `&[A]` | [`.as_slice()`](ArrayRef::as_slice)[2](#req_contig_std), [`.as_slice_memory_order()`](ArrayRef::as_slice_memory_order)[3](#req_contig) /// `&mut ArrayBase` | `&mut [A]` | [`.as_slice_mut()`](Self::as_slice_mut)[2](#req_contig_std), [`.as_slice_memory_order_mut()`](Self::as_slice_memory_order_mut)[3](#req_contig) /// `ArrayView` | `&[A]` | [`.to_slice()`](ArrayView#method.to_slice)[2](#req_contig_std) /// `ArrayViewMut` | `&mut [A]` | [`.into_slice()`](ArrayViewMut#method.into_slice)[2](#req_contig_std) @@ -1074,9 +1080,9 @@ pub type Ixs = isize; /// [.into_owned()]: Self::into_owned /// [.into_shared()]: Self::into_shared /// [.to_owned()]: Self::to_owned -/// [.map()]: Self::map -/// [.view()]: Self::view -/// [.view_mut()]: Self::view_mut +/// [.map()]: ArrayRef::map +/// [.view()]: ArrayRef::view +/// [.view_mut()]: ArrayRef::view_mut /// /// ### Conversions from Nested `Vec`s/`Array`s /// @@ -1277,6 +1283,9 @@ pub type Ixs = isize; // implementation since `ArrayBase` doesn't implement `Drop` and `&mut // ArrayBase` is `!UnwindSafe`, but the implementation must not call // methods/functions on the array while it violates the constraints. +// Critically, this includes calling `DerefMut`; as a result, methods/functions +// that temporarily violate these must not rely on the `DerefMut` implementation +// for access to the underlying `ptr`, `strides`, or `dim`. // // Users of the `ndarray` crate cannot rely on these constraints because they // may change in the future. @@ -1288,6 +1297,112 @@ where S: RawData /// Data buffer / ownership information. (If owned, contains the data /// buffer; if borrowed, contains the lifetime and mutability.) data: S, + /// The dimension, strides, and pointer to inside of `data` + layout: LayoutRef, +} + +/// A reference to the layout of an *n*-dimensional array. +/// +/// This type can be used to read and write to the layout of an array; +/// that is to say, its shape and strides. It does not provide any read +/// or write access to the array's underlying data. It is generic on two +/// types: `D`, its dimensionality, and `A`, the element type of its data. +/// +/// ## Example +/// Say we wanted to write a function that provides the aspect ratio +/// of any 2D array: the ratio of its width (number of columns) to its +/// height (number of rows). We would write that as follows: +/// ```rust +/// use ndarray::{LayoutRef2, array}; +/// +/// fn aspect_ratio(layout: &T) -> (usize, usize) +/// where T: AsRef> +/// { +/// let layout = layout.as_ref(); +/// (layout.ncols(), layout.nrows()) +/// } +/// +/// let arr = array![[1, 2], [3, 4]]; +/// assert_eq!(aspect_ratio(&arr), (2, 2)); +/// ``` +/// Similarly, new traits that provide functions that only depend on +/// or alter the layout of an array should do so via a blanket +/// implementation. Lets write a trait that both provides the aspect ratio +/// and lets users cut down arrays to a desired aspect ratio. +/// For simplicity, we'll panic if the user provides an aspect ratio +/// where either element is larger than the array's size. +/// ```rust +/// use ndarray::{LayoutRef2, array, s}; +/// +/// trait Ratioable { +/// fn aspect_ratio(&self) -> (usize, usize) +/// where Self: AsRef>; +/// +/// fn cut_to_ratio(&mut self, ratio: (usize, usize)) +/// where Self: AsMut>; +/// } +/// +/// impl Ratioable for T +/// where T: AsRef> + AsMut> +/// { +/// fn aspect_ratio(&self) -> (usize, usize) +/// { +/// let layout = self.as_ref(); +/// (layout.ncols(), layout.nrows()) +/// } +/// +/// fn cut_to_ratio(&mut self, ratio: (usize, usize)) +/// { +/// let layout = self.as_mut(); +/// layout.slice_collapse(s![..ratio.1, ..ratio.0]); +/// } +/// } +/// +/// let mut arr = array![[1, 2, 3], [4, 5, 6]]; +/// assert_eq!(arr.aspect_ratio(), (3, 2)); +/// arr.cut_to_ratio((2, 2)); +/// assert_eq!(arr, array![[1, 2], [4, 5]]); +/// ``` +/// Continue reading for why we use `AsRef` instead of taking `&LayoutRef` directly. +/// +/// ## Writing Functions +/// Writing functions that accept `LayoutRef` is not as simple as taking +/// a `&LayoutRef` argument, as the above examples show. This is because +/// `LayoutRef` can be obtained either cheaply or expensively, depending +/// on the method used. `LayoutRef` can be obtained from all kinds of arrays +/// -- [owned](Array), [shared](ArcArray), [viewed](ArrayView), [referenced](ArrayRef), +/// and [raw referenced](RawRef) -- via `.as_ref()`. Critically, this way of +/// obtaining a `LayoutRef` is cheap, as it does not guarantee that the +/// underlying data is uniquely held. +/// +/// However, `LayoutRef`s can be obtained a second way: they sit at the bottom +/// of a "deref chain" going from shared arrays, through `ArrayRef`, through +/// `RawRef`, and finally to `LayoutRef`. As a result, `LayoutRef`s can also +/// be obtained via auto-dereferencing. When requesting a mutable reference -- +/// `&mut LayoutRef` -- the `deref_mut` to `ArrayRef` triggers a (possibly +/// expensive) guarantee that the data is uniquely held (see [`ArrayRef`] +/// for more information). +/// +/// To help users avoid this error cost, functions that operate on `LayoutRef`s +/// should take their parameters as a generic type `T: AsRef>`, +/// as the above examples show. This aids the caller in two ways: they can pass +/// their arrays by reference (`&arr`) instead of explicitly calling `as_ref`, +/// and they will avoid paying a performance penalty for mutating the shape. +// +// # Safety for Implementors +// +// Despite carrying around a `ptr`, maintainers of `LayoutRef` +// must *guarantee* that the pointer is *never* dereferenced. +// No read access can be used when handling a `LayoutRef`, and +// the `ptr` can *never* be exposed to the user. +// +// The reason the pointer is included here is because some methods +// which alter the layout / shape / strides of an array must also +// alter the offset of the pointer. This is allowed, as it does not +// cause a pointer deref. +#[derive(Debug)] +pub struct LayoutRef +{ /// A non-null pointer into the buffer held by `data`; may point anywhere /// in its range. If `S: Data`, this pointer must be aligned. ptr: std::ptr::NonNull, @@ -1297,6 +1412,91 @@ where S: RawData strides: D, } +/// A reference to an *n*-dimensional array whose data is safe to read and write. +/// +/// This type's relationship to [`ArrayBase`] can be thought of a bit like the +/// relationship between [`Vec`] and [`std::slice`]: it represents a look into the +/// array, and is the [`Deref`](std::ops::Deref) target for owned, shared, and viewed +/// arrays. Most functionality is implemented on `ArrayRef`, and most functions +/// should take `&ArrayRef` instead of `&ArrayBase`. +/// +/// ## Relationship to Views +/// `ArrayRef` and [`ArrayView`] are very similar types: they both represent a +/// "look" into an array. There is one key difference: views have their own +/// shape and strides, while `ArrayRef` just points to the shape and strides of +/// whatever array it came from. +/// +/// As an example, let's write a function that takes an array, trims it +/// down to a square in-place, and then returns the sum: +/// ```rust +/// use std::cmp; +/// use std::ops::Add; +/// +/// use ndarray::{ArrayRef2, array, s}; +/// use num_traits::Zero; +/// +/// fn square_and_sum(arr: &mut ArrayRef2) -> A +/// where A: Clone + Add + Zero +/// { +/// let side_len = cmp::min(arr.nrows(), arr.ncols()); +/// arr.slice_collapse(s![..side_len, ..side_len]); +/// arr.sum() +/// } +/// +/// let mut arr = array![ +/// [ 1, 2, 3], +/// [ 4, 5, 6], +/// [ 7, 8, 9], +/// [10, 11, 12] +/// ]; +/// // Take a view of the array, excluding the first column +/// let mut view = arr.slice_mut(s![.., 1..]); +/// let sum_view = square_and_sum(&mut view); +/// assert_eq!(sum_view, 16); +/// assert_eq!(view.ncols(), 2usize); // The view has changed shape... +/// assert_eq!(view.nrows(), 2usize); +/// assert_eq!(arr.ncols(), 3usize); // ... but the original array has not +/// assert_eq!(arr.nrows(), 4usize); +/// +/// let sum_all = square_and_sum(&mut arr); +/// assert_eq!(sum_all, 45); +/// assert_eq!(arr.ncols(), 3usize); // Now the original array has changed shape +/// assert_eq!(arr.nrows(), 3usize); // because we passed it directly to the function +/// ``` +/// Critically, we can call the same function on both the view and the array itself. +/// We can see that, because the view has its own shape and strides, "squaring" it does +/// not affect the shape of the original array. Those only change when we pass the array +/// itself into the function. +/// +/// Also notice that the output of `slice_mut` is a *view*, not an `ArrayRef`. +/// This is where the analogy to `Vec`/`slice` breaks down a bit: due to limitations of +/// the Rust language, `ArrayRef` *cannot* have a different shape / stride from the +/// array from which it is dereferenced. So slicing still produces an `ArrayView`, +/// not an `ArrayRef`. +/// +/// ## Uniqueness +/// `ndarray` has copy-on-write shared data; see [`ArcArray`], for example. +/// When a copy-on-write array is passed to a function that takes `ArrayRef` as mutable +/// (i.e., `&mut ArrayRef`, like above), that array will be un-shared when it is dereferenced +/// into `ArrayRef`. In other words, having a `&mut ArrayRef` guarantees that the underlying +/// data is un-shared and safe to write to. +#[repr(transparent)] +pub struct ArrayRef(LayoutRef); + +/// A reference to an *n*-dimensional array whose data is not safe to read or write. +/// +/// This type is similar to [`ArrayRef`] but does not guarantee that its data is safe +/// to read or write; i.e., the underlying data may come from a shared array or be otherwise +/// unsafe to dereference. This type should be used sparingly and with extreme caution; +/// most of its methods either provide pointers or return [`RawArrayView`], both of +/// which tend to be full of unsafety. +/// +/// For the few times when this type is appropriate, it has the same `AsRef` semantics +/// as [`LayoutRef`]; see [its documentation on writing functions](LayoutRef#writing-functions) +/// for information on how to properly handle functionality on this type. +#[repr(transparent)] +pub struct RawRef(LayoutRef); + /// An array where the data has shared ownership and is copy on write. /// /// The `ArcArray` is parameterized by `A` for the element type and `D` for @@ -1305,8 +1505,8 @@ where S: RawData /// It can act as both an owner as the data as well as a shared reference (view /// like). /// Calling a method for mutating elements on `ArcArray`, for example -/// [`view_mut()`](ArrayBase::view_mut) or -/// [`get_mut()`](ArrayBase::get_mut), will break sharing and +/// [`view_mut()`](ArrayRef::view_mut) or +/// [`get_mut()`](ArrayRef::get_mut), will break sharing and /// require a clone of the data (if it is not uniquely held). /// /// `ArcArray` uses atomic reference counting like `Arc`, so it is `Send` and @@ -1533,14 +1733,12 @@ mod impl_internal_constructors; mod impl_constructors; mod impl_methods; +mod alias_asref; mod impl_owned_array; mod impl_special_element_types; /// Private Methods -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef { #[inline] fn broadcast_unwrap(&self, dim: E) -> ArrayView<'_, A, E> @@ -1553,11 +1751,7 @@ where D: Dimension, E: Dimension, { - panic!( - "ndarray: could not broadcast array from shape: {:?} to: {:?}", - from.slice(), - to.slice() - ) + panic!("ndarray: could not broadcast array from shape: {:?} to: {:?}", from.slice(), to.slice()) } match self.broadcast(dim.clone()) { @@ -1579,12 +1773,18 @@ where strides.slice_mut().copy_from_slice(self.strides.slice()); unsafe { ArrayView::new(ptr, dim, strides) } } +} +impl ArrayBase +where + S: Data, + D: Dimension, +{ /// Remove array axis `axis` and return the result. fn try_remove_axis(self, axis: Axis) -> ArrayBase { - let d = self.dim.try_remove_axis(axis); - let s = self.strides.try_remove_axis(axis); + let d = self.layout.dim.try_remove_axis(axis); + let s = self.layout.strides.try_remove_axis(axis); // safe because new dimension, strides allow access to a subset of old data unsafe { self.with_strides_dim(s, d) } } diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index e05740378..0f28cac1d 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -11,6 +11,8 @@ use crate::imp_prelude::*; #[cfg(feature = "blas")] use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr; use crate::numeric_util; +use crate::ArrayRef1; +use crate::ArrayRef2; use crate::{LinalgScalar, Zip}; @@ -43,8 +45,7 @@ const GEMM_BLAS_CUTOFF: usize = 7; #[allow(non_camel_case_types)] type blas_index = c_int; // blas index type -impl ArrayBase -where S: Data +impl ArrayRef { /// Perform dot product or matrix multiplication of arrays `self` and `rhs`. /// @@ -70,10 +71,8 @@ where S: Data Dot::dot(self, rhs) } - fn dot_generic(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_generic(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { debug_assert_eq!(self.len(), rhs.len()); assert!(self.len() == rhs.len()); @@ -92,19 +91,15 @@ where S: Data } #[cfg(not(feature = "blas"))] - fn dot_impl(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_impl(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { self.dot_generic(rhs) } #[cfg(feature = "blas")] - fn dot_impl(&self, rhs: &ArrayBase) -> A - where - S2: Data, - A: LinalgScalar, + fn dot_impl(&self, rhs: &ArrayRef) -> A + where A: LinalgScalar { // Use only if the vector is large enough to be worth it if self.len() >= DOT_BLAS_CUTOFF { @@ -168,14 +163,64 @@ pub trait Dot /// /// For two-dimensional arrays: a rectangular array. type Output; + fn dot(&self, rhs: &Rhs) -> Self::Output; } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +macro_rules! impl_dots { + ( + $shape1:ty, + $shape2:ty + ) => { + impl Dot> for ArrayBase + where + S: Data, + S2: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + Dot::dot(&**self, &**rhs) + } + } + + impl Dot> for ArrayBase + where + S: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayRef) -> Self::Output + { + (**self).dot(rhs) + } + } + + impl Dot> for ArrayRef + where + S: Data, + A: LinalgScalar, + { + type Output = as Dot>>::Output; + + fn dot(&self, rhs: &ArrayBase) -> Self::Output + { + self.dot(&**rhs) + } + } + }; +} + +impl_dots!(Ix1, Ix1); +impl_dots!(Ix1, Ix2); +impl_dots!(Ix2, Ix1); +impl_dots!(Ix2, Ix2); + +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = A; @@ -188,17 +233,14 @@ where /// *Note:* If enabled, uses blas `dot` for elements of `f32, f64` when memory /// layout allows. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> A + fn dot(&self, rhs: &ArrayRef) -> A { self.dot_impl(rhs) } } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array; @@ -212,14 +254,13 @@ where /// /// **Panics** if shapes are incompatible. #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array + fn dot(&self, rhs: &ArrayRef) -> Array { - rhs.t().dot(self) + (*rhs.t()).dot(self) } } -impl ArrayBase -where S: Data +impl ArrayRef { /// Perform matrix multiplication of rectangular arrays `self` and `rhs`. /// @@ -258,14 +299,12 @@ where S: Data } } -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array2; - fn dot(&self, b: &ArrayBase) -> Array2 + + fn dot(&self, b: &ArrayRef) -> Array2 { let a = self.view(); let b = b.view(); @@ -321,15 +360,13 @@ fn general_dot_shape_error(m: usize, k: usize, k2: usize, n: usize, c1: usize, c /// Return a result array with shape *M*. /// /// **Panics** if shapes are incompatible. -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array; + #[track_caller] - fn dot(&self, rhs: &ArrayBase) -> Array + fn dot(&self, rhs: &ArrayRef) -> Array { let ((m, a), n) = (self.dim(), rhs.dim()); if a != n { @@ -345,10 +382,8 @@ where } } -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef +where D: Dimension { /// Perform the operation `self += alpha * rhs` efficiently, where /// `alpha` is a scalar and `rhs` is another array. This operation is @@ -358,10 +393,8 @@ where /// /// **Panics** if broadcasting isn't possible. #[track_caller] - pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayBase) + pub fn scaled_add(&mut self, alpha: A, rhs: &ArrayRef) where - S: DataMut, - S2: Data, A: LinalgScalar, E: Dimension, { @@ -369,13 +402,13 @@ where } } -// mat_mul_impl uses ArrayView arguments to send all array kinds into +// mat_mul_impl uses ArrayRef arguments to send all array kinds into // the same instantiated implementation. #[cfg(not(feature = "blas"))] use self::mat_mul_general as mat_mul_impl; #[cfg(feature = "blas")] -fn mat_mul_impl(alpha: A, a: &ArrayView2<'_, A>, b: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>) +fn mat_mul_impl(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) where A: LinalgScalar { let ((m, k), (k2, n)) = (a.dim(), b.dim()); @@ -461,9 +494,8 @@ where A: LinalgScalar } /// C ← α A B + β C -fn mat_mul_general( - alpha: A, lhs: &ArrayView2<'_, A>, rhs: &ArrayView2<'_, A>, beta: A, c: &mut ArrayViewMut2<'_, A>, -) where A: LinalgScalar +fn mat_mul_general(alpha: A, lhs: &ArrayRef2, rhs: &ArrayRef2, beta: A, c: &mut ArrayRef2) +where A: LinalgScalar { let ((m, k), (_, n)) = (lhs.dim(), rhs.dim()); @@ -595,13 +627,8 @@ fn mat_mul_general( /// layout allows. The default matrixmultiply backend is otherwise used for /// `f32, f64` for all memory layouts. #[track_caller] -pub fn general_mat_mul( - alpha: A, a: &ArrayBase, b: &ArrayBase, beta: A, c: &mut ArrayBase, -) where - S1: Data, - S2: Data, - S3: DataMut, - A: LinalgScalar, +pub fn general_mat_mul(alpha: A, a: &ArrayRef2, b: &ArrayRef2, beta: A, c: &mut ArrayRef2) +where A: LinalgScalar { let ((m, k), (k2, n)) = (a.dim(), b.dim()); let (m2, n2) = c.dim(); @@ -624,13 +651,8 @@ pub fn general_mat_mul( /// layout allows. #[track_caller] #[allow(clippy::collapsible_if)] -pub fn general_mat_vec_mul( - alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: &mut ArrayBase, -) where - S1: Data, - S2: Data, - S3: DataMut, - A: LinalgScalar, +pub fn general_mat_vec_mul(alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: &mut ArrayRef1) +where A: LinalgScalar { unsafe { general_mat_vec_mul_impl(alpha, a, x, beta, y.raw_view_mut()) } } @@ -644,12 +666,9 @@ pub fn general_mat_vec_mul( /// The caller must ensure that the raw view is valid for writing. /// the destination may be uninitialized iff beta is zero. #[allow(clippy::collapsible_else_if)] -unsafe fn general_mat_vec_mul_impl( - alpha: A, a: &ArrayBase, x: &ArrayBase, beta: A, y: RawArrayViewMut, -) where - S1: Data, - S2: Data, - A: LinalgScalar, +unsafe fn general_mat_vec_mul_impl( + alpha: A, a: &ArrayRef2, x: &ArrayRef1, beta: A, y: RawArrayViewMut, +) where A: LinalgScalar { let ((m, k), k2) = (a.dim(), x.dim()); let m2 = y.dim(); @@ -661,7 +680,7 @@ unsafe fn general_mat_vec_mul_impl( ($ty:ty, $gemv:ident) => { if same_type::() { if let Some(layout) = get_blas_compatible_layout(&a) { - if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y) { + if blas_compat_1d::<$ty, _>(&x) && blas_compat_1d::<$ty, _>(&y.as_ref()) { // Determine stride between rows or columns. Note that the stride is // adjusted to at least `k` or `m` to handle the case of a matrix with a // trivial (length 1) dimension, since the stride for the trivial dimension @@ -674,8 +693,8 @@ unsafe fn general_mat_vec_mul_impl( // Low addr in memory pointers required for x, y let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides); let x_ptr = x.ptr.as_ptr().sub(x_offset); - let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.dim, &y.strides); - let y_ptr = y.ptr.as_ptr().sub(y_offset); + let y_offset = offset_from_low_addr_ptr_to_logical_ptr(&y.layout.dim, &y.layout.strides); + let y_ptr = y.layout.ptr.as_ptr().sub(y_offset); let x_stride = x.strides()[0] as blas_index; let y_stride = y.strides()[0] as blas_index; @@ -724,11 +743,8 @@ unsafe fn general_mat_vec_mul_impl( /// /// The kronecker product of a LxN matrix A and a MxR matrix B is a (L*M)x(N*R) /// matrix K formed by the block multiplication A_ij * B. -pub fn kron(a: &ArrayBase, b: &ArrayBase) -> Array -where - S1: Data, - S2: Data, - A: LinalgScalar, +pub fn kron(a: &ArrayRef2, b: &ArrayRef2) -> Array +where A: LinalgScalar { let dimar = a.shape()[0]; let dimac = a.shape()[1]; @@ -777,13 +793,12 @@ fn complex_array(z: Complex) -> [A; 2] } #[cfg(feature = "blas")] -fn blas_compat_1d(a: &ArrayBase) -> bool +fn blas_compat_1d(a: &RawRef) -> bool where - S: RawData, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } if a.len() > blas_index::MAX as usize { @@ -889,8 +904,7 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool /// Get BLAS compatible layout if any (C or F, preferring the former) #[cfg(feature = "blas")] -fn get_blas_compatible_layout(a: &ArrayBase) -> Option -where S: Data +fn get_blas_compatible_layout(a: &ArrayRef) -> Option { if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) { Some(BlasOrder::C) @@ -906,8 +920,7 @@ where S: Data /// /// Return leading stride (lda, ldb, ldc) of array #[cfg(feature = "blas")] -fn blas_stride(a: &ArrayBase, order: BlasOrder) -> blas_index -where S: Data +fn blas_stride(a: &ArrayRef, order: BlasOrder) -> blas_index { let axis = order.get_blas_lead_axis(); let other_axis = 1 - axis; @@ -928,13 +941,12 @@ where S: Data #[cfg(test)] #[cfg(feature = "blas")] -fn blas_row_major_2d(a: &ArrayBase) -> bool +fn blas_row_major_2d(a: &ArrayRef2) -> bool where - S: Data, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } is_blas_2d(&a.dim, &a.strides, BlasOrder::C) @@ -942,13 +954,12 @@ where #[cfg(test)] #[cfg(feature = "blas")] -fn blas_column_major_2d(a: &ArrayBase) -> bool +fn blas_column_major_2d(a: &ArrayRef2) -> bool where - S: Data, A: 'static, - S::Elem: 'static, + B: 'static, { - if !same_type::() { + if !same_type::() { return false; } is_blas_2d(&a.dim, &a.strides, BlasOrder::F) diff --git a/src/math_cell.rs b/src/math_cell.rs index 6ed1ed71f..629e5575d 100644 --- a/src/math_cell.rs +++ b/src/math_cell.rs @@ -7,7 +7,7 @@ use std::ops::{Deref, DerefMut}; /// A transparent wrapper of [`Cell`](std::cell::Cell) which is identical in every way, except /// it will implement arithmetic operators as well. /// -/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayBase::cell_view). +/// The purpose of `MathCell` is to be used from [.cell_view()](crate::ArrayRef::cell_view). /// The `MathCell` derefs to `Cell`, so all the cell's methods are available. #[repr(transparent)] #[derive(Default)] diff --git a/src/numeric/impl_float_maths.rs b/src/numeric/impl_float_maths.rs index 7a88364e3..7012a8b93 100644 --- a/src/numeric/impl_float_maths.rs +++ b/src/numeric/impl_float_maths.rs @@ -55,10 +55,9 @@ macro_rules! binary_ops { /// Element-wise math functions for any array type that contains float number. #[cfg(feature = "std")] #[cfg_attr(docsrs, doc(cfg(feature = "std")))] -impl ArrayBase +impl ArrayRef where A: 'static + Float, - S: Data, D: Dimension, { boolean_ops! { @@ -144,10 +143,9 @@ where } } -impl ArrayBase +impl ArrayRef where A: 'static + PartialOrd + Clone, - S: Data, D: Dimension, { /// Limit the values for each element, similar to NumPy's `clip` function. diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index a8a008395..27c5687ee 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -17,10 +17,8 @@ use crate::numeric_util; use crate::Slice; /// # Numerical Methods for Arrays -impl ArrayBase -where - S: Data, - D: Dimension, +impl ArrayRef +where D: Dimension { /// Return the sum of all elements in the array. /// diff --git a/src/parallel/impl_par_methods.rs b/src/parallel/impl_par_methods.rs index 7f01ea32f..189436c3d 100644 --- a/src/parallel/impl_par_methods.rs +++ b/src/parallel/impl_par_methods.rs @@ -1,5 +1,5 @@ use crate::AssignElem; -use crate::{Array, ArrayBase, DataMut, Dimension, IntoNdProducer, NdProducer, Zip}; +use crate::{Array, ArrayRef, Dimension, IntoNdProducer, NdProducer, Zip}; use super::send_producer::SendProducer; use crate::parallel::par::ParallelSplits; @@ -8,9 +8,8 @@ use crate::parallel::prelude::*; use crate::partial::Partial; /// # Parallel methods -impl ArrayBase +impl ArrayRef where - S: DataMut, D: Dimension, A: Send + Sync, { diff --git a/src/parallel/mod.rs b/src/parallel/mod.rs index 0c84baa91..2eef69307 100644 --- a/src/parallel/mod.rs +++ b/src/parallel/mod.rs @@ -19,8 +19,8 @@ //! //! The following other parallelized methods exist: //! -//! - [`ArrayBase::par_map_inplace()`] -//! - [`ArrayBase::par_mapv_inplace()`] +//! - [`ArrayRef::par_map_inplace()`](crate::ArrayRef::par_map_inplace) +//! - [`ArrayRef::par_mapv_inplace()`](crate::ArrayRef::par_mapv_inplace) //! - [`Zip::par_for_each()`] (all arities) //! - [`Zip::par_map_collect()`] (all arities) //! - [`Zip::par_map_assign_into()`] (all arities) diff --git a/src/prelude.rs b/src/prelude.rs index acf39da1a..072eb4825 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -18,11 +18,26 @@ //! ``` #[doc(no_inline)] -pub use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, CowArray, RawArrayView, RawArrayViewMut}; +pub use crate::{ + ArcArray, + Array, + ArrayBase, + ArrayRef, + ArrayView, + ArrayViewMut, + CowArray, + LayoutRef, + RawArrayView, + RawArrayViewMut, + RawRef, +}; #[doc(no_inline)] pub use crate::{Axis, Dim, Dimension}; +#[doc(no_inline)] +pub use crate::{ArrayRef0, ArrayRef1, ArrayRef2, ArrayRef3, ArrayRef4, ArrayRef5, ArrayRef6, ArrayRefD}; + #[doc(no_inline)] pub use crate::{Array0, Array1, Array2, Array3, Array4, Array5, Array6, ArrayD}; diff --git a/src/shape_builder.rs b/src/shape_builder.rs index cd790a25f..b9a4b0ab6 100644 --- a/src/shape_builder.rs +++ b/src/shape_builder.rs @@ -210,7 +210,7 @@ where D: Dimension /// This is an argument conversion trait that is used to accept an array shape and /// (optionally) an ordering argument. /// -/// See for example [`.to_shape()`](crate::ArrayBase::to_shape). +/// See for example [`.to_shape()`](crate::ArrayRef::to_shape). pub trait ShapeArg { type Dim: Dimension; diff --git a/src/slice.rs b/src/slice.rs index e6c237a92..e2ce1e727 100644 --- a/src/slice.rs +++ b/src/slice.rs @@ -430,7 +430,7 @@ unsafe impl SliceArg for [SliceInfoElem] /// `SliceInfo` instance can still be used to slice an array with dimension /// `IxDyn` as long as the number of axes matches. /// -/// [`.slice()`]: crate::ArrayBase::slice +/// [`.slice()`]: crate::ArrayRef::slice #[derive(Debug)] pub struct SliceInfo { @@ -521,7 +521,7 @@ where } /// Returns the number of dimensions of the input array for - /// [`.slice()`](crate::ArrayBase::slice). + /// [`.slice()`](crate::ArrayRef::slice). /// /// If `Din` is a fixed-size dimension type, then this is equivalent to /// `Din::NDIM.unwrap()`. Otherwise, the value is calculated by iterating @@ -536,7 +536,7 @@ where } /// Returns the number of dimensions after calling - /// [`.slice()`](crate::ArrayBase::slice) (including taking + /// [`.slice()`](crate::ArrayRef::slice) (including taking /// subviews). /// /// If `Dout` is a fixed-size dimension type, then this is equivalent to @@ -755,10 +755,10 @@ impl_slicenextdim!((), NewAxis, Ix0, Ix1); /// panic. Without the `NewAxis`, i.e. `s![0..4;2, 6, 1..5]`, /// [`.slice_collapse()`] would result in an array of shape `[2, 1, 4]`. /// -/// [`.slice()`]: crate::ArrayBase::slice -/// [`.slice_mut()`]: crate::ArrayBase::slice_mut +/// [`.slice()`]: crate::ArrayRef::slice +/// [`.slice_mut()`]: crate::ArrayRef::slice_mut /// [`.slice_move()`]: crate::ArrayBase::slice_move -/// [`.slice_collapse()`]: crate::ArrayBase::slice_collapse +/// [`.slice_collapse()`]: crate::LayoutRef::slice_collapse /// /// See also [*Slicing*](crate::ArrayBase#slicing). /// diff --git a/src/tri.rs b/src/tri.rs index b7d297fcc..6e3b90b5b 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -13,16 +13,14 @@ use num_traits::Zero; use crate::{ dimension::{is_layout_c, is_layout_f}, Array, - ArrayBase, + ArrayRef, Axis, - Data, Dimension, Zip, }; -impl ArrayBase +impl ArrayRef where - S: Data, D: Dimension, A: Clone + Zero, { @@ -32,7 +30,7 @@ where /// For arrays with `ndim` exceeding 2, `triu` will apply to the final two axes. /// For 0D and 1D arrays, `triu` will return an unchanged clone. /// - /// See also [`ArrayBase::tril`] + /// See also [`ArrayRef::tril`] /// /// ``` /// use ndarray::array; @@ -83,7 +81,9 @@ where false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 }; lower = min(lower, ncols); - dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + (*dst) + .slice_mut(s![lower..]) + .assign(&(*src).slice(s![lower..])); }); res @@ -95,7 +95,7 @@ where /// For arrays with `ndim` exceeding 2, `tril` will apply to the final two axes. /// For 0D and 1D arrays, `tril` will return an unchanged clone. /// - /// See also [`ArrayBase::triu`] + /// See also [`ArrayRef::triu`] /// /// ``` /// use ndarray::array; @@ -127,10 +127,10 @@ where if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { let mut x = self.view(); x.swap_axes(n - 2, n - 1); - let mut tril = x.triu(-k); - tril.swap_axes(n - 2, n - 1); + let mut triu = x.triu(-k); + triu.swap_axes(n - 2, n - 1); - return tril; + return triu; } let mut res = Array::zeros(self.raw_dim()); @@ -147,7 +147,9 @@ where false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow }; upper = min(upper, ncols); - dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + (*dst) + .slice_mut(s![..upper]) + .assign(&(*src).slice(s![..upper])); }); res diff --git a/src/zip/mod.rs b/src/zip/mod.rs index b58752f66..640a74d1b 100644 --- a/src/zip/mod.rs +++ b/src/zip/mod.rs @@ -76,10 +76,8 @@ fn array_layout(dim: &D, strides: &D) -> Layout } } -impl ArrayBase -where - S: RawData, - D: Dimension, +impl LayoutRef +where D: Dimension { pub(crate) fn layout_impl(&self) -> Layout { @@ -96,8 +94,8 @@ where fn broadcast_unwrap(self, shape: E) -> Self::Output { #[allow(clippy::needless_borrow)] - let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension()); - unsafe { ArrayView::new(res.ptr, res.dim, res.strides) } + let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension()); + unsafe { ArrayView::new(res.layout.ptr, res.layout.dim, res.layout.strides) } } private_impl! {} } @@ -762,7 +760,8 @@ macro_rules! map_impl { pub(crate) fn map_collect_owned(self, f: impl FnMut($($p::Item,)* ) -> R) -> ArrayBase - where S: DataOwned + where + S: DataOwned, { // safe because: all elements are written before the array is completed diff --git a/src/zip/ndproducer.rs b/src/zip/ndproducer.rs index 1d1b3391b..82f3f43a7 100644 --- a/src/zip/ndproducer.rs +++ b/src/zip/ndproducer.rs @@ -1,4 +1,5 @@ use crate::imp_prelude::*; +use crate::ArrayRef; use crate::Layout; use crate::NdIndex; #[cfg(not(feature = "std"))] @@ -156,6 +157,34 @@ where } } +/// An array reference is an n-dimensional producer of element references +/// (like ArrayView). +impl<'a, A: 'a, D> IntoNdProducer for &'a ArrayRef +where D: Dimension +{ + type Item = &'a A; + type Dim = D; + type Output = ArrayView<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view() + } +} + +/// A mutable array reference is an n-dimensional producer of mutable element +/// references (like ArrayViewMut). +impl<'a, A: 'a, D> IntoNdProducer for &'a mut ArrayRef +where D: Dimension +{ + type Item = &'a mut A; + type Dim = D; + type Output = ArrayViewMut<'a, A, D>; + fn into_producer(self) -> Self::Output + { + self.view_mut() + } +} + /// A slice is a one-dimensional producer impl<'a, A: 'a> IntoNdProducer for &'a [A] { @@ -239,7 +268,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn raw_dim(&self) -> Self::Dim { - self.raw_dim() + (***self).raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -249,7 +278,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ + (**self).as_ptr() as _ } fn layout(&self) -> Layout @@ -269,7 +298,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayView<'a, A, D> fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) + (**self).stride_of(axis) } #[inline(always)] @@ -295,7 +324,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn raw_dim(&self) -> Self::Dim { - self.raw_dim() + (***self).raw_dim() } fn equal_dim(&self, dim: &Self::Dim) -> bool @@ -305,7 +334,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn as_ptr(&self) -> *mut A { - self.as_ptr() as _ + (**self).as_ptr() as _ } fn layout(&self) -> Layout @@ -325,7 +354,7 @@ impl<'a, A, D: Dimension> NdProducer for ArrayViewMut<'a, A, D> fn stride_of(&self, axis: Axis) -> isize { - self.stride_of(axis) + (**self).stride_of(axis) } #[inline(always)] @@ -356,17 +385,17 @@ impl NdProducer for RawArrayView fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) + self.layout.dim.equal(dim) } fn as_ptr(&self) -> *const A { - self.as_ptr() + self.as_ptr() as _ } fn layout(&self) -> Layout { - self.layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *const A) -> *const A @@ -376,7 +405,10 @@ impl NdProducer for RawArrayView unsafe fn uget_ptr(&self, i: &Self::Dim) -> *const A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + self.layout + .ptr + .as_ptr() + .offset(i.index_unchecked(&self.layout.strides)) } fn stride_of(&self, axis: Axis) -> isize @@ -412,7 +444,7 @@ impl NdProducer for RawArrayViewMut fn equal_dim(&self, dim: &Self::Dim) -> bool { - self.dim.equal(dim) + self.layout.dim.equal(dim) } fn as_ptr(&self) -> *mut A @@ -422,7 +454,7 @@ impl NdProducer for RawArrayViewMut fn layout(&self) -> Layout { - self.layout_impl() + AsRef::>::as_ref(self).layout_impl() } unsafe fn as_ref(&self, ptr: *mut A) -> *mut A @@ -432,7 +464,10 @@ impl NdProducer for RawArrayViewMut unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A { - self.ptr.as_ptr().offset(i.index_unchecked(&self.strides)) + self.layout + .ptr + .as_ptr() + .offset(i.index_unchecked(&self.layout.strides)) } fn stride_of(&self, axis: Axis) -> isize diff --git a/tests/array.rs b/tests/array.rs index ac38fdd03..f1426625c 100644 --- a/tests/array.rs +++ b/tests/array.rs @@ -2820,3 +2820,11 @@ fn test_split_complex_invert_axis() assert_eq!(cmplx.re, a.mapv(|z| z.re)); assert_eq!(cmplx.im, a.mapv(|z| z.im)); } + +#[test] +fn test_slice_assign() +{ + let mut a = array![0, 1, 2, 3, 4]; + *a.slice_mut(s![1..3]) += 1; + assert_eq!(a, array![0, 2, 3, 3, 4]); +} From a0f0317315f25326f12a7419a4ee8f5756a9a76a Mon Sep 17 00:00:00 2001 From: akern40 Date: Mon, 17 Mar 2025 23:18:44 -0400 Subject: [PATCH 41/48] Fixes CI for checking against latest dependencies (#1490) Also runs latest-deps CI when someone changes the latest-deps.yaml configuration --- .github/workflows/ci.yaml | 2 ++ .github/workflows/latest-deps.yaml | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 6ebdc8432..13fb9d0d6 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -1,5 +1,7 @@ on: pull_request: + paths-ignore: + - '.github/workflows/latest-deps.yaml' merge_group: push: branches: diff --git a/.github/workflows/latest-deps.yaml b/.github/workflows/latest-deps.yaml index f2f3d8486..3b28169ec 100644 --- a/.github/workflows/latest-deps.yaml +++ b/.github/workflows/latest-deps.yaml @@ -8,6 +8,9 @@ on: # Sorry if this ruins your weekend, future maintainer... - cron: '0 12 * * FRI' workflow_dispatch: # For running manually + pull_request: + paths: + - '.github/workflows/latest-deps.yaml' env: CARGO_TERM_COLOR: always @@ -41,7 +44,7 @@ jobs: latest_deps_msrv: runs-on: ubuntu-latest - name: Check Latest Dependencies on MSRV (${{ env.MSRV }}) + name: Check Latest Dependencies on MSRV steps: - name: Check Out Repo uses: actions/checkout@v4 From 549249217eda3b75c2ba6ac8ad566aecabd34091 Mon Sep 17 00:00:00 2001 From: akern40 Date: Mon, 17 Mar 2025 23:22:17 -0400 Subject: [PATCH 42/48] Meshgrid implementation (#1477) Adds a Numpy-equivalent `meshgrid` function --- src/free_functions.rs | 408 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) diff --git a/src/free_functions.rs b/src/free_functions.rs index c1889cec8..a2ad6137c 100644 --- a/src/free_functions.rs +++ b/src/free_functions.rs @@ -9,6 +9,7 @@ use alloc::vec; #[cfg(not(feature = "std"))] use alloc::vec::Vec; +use meshgrid_impl::Meshgrid; #[allow(unused_imports)] use std::compile_error; use std::mem::{forget, size_of}; @@ -45,6 +46,8 @@ use crate::{imp_prelude::*, LayoutRef}; /// /// This macro uses `vec![]`, and has the same ownership semantics; /// elements are moved into the resulting `Array`. +/// If running with `no_std`, this may require that you `use alloc::vec` +/// before being able to use the `array!` macro. /// /// Use `array![...].into_shared()` to create an `ArcArray`. /// @@ -336,3 +339,408 @@ pub fn rcarr3(xs: &[[[A; M]; N]]) -> A { arr3(xs).into_shared() } + +/// The indexing order for [`meshgrid`]; see there for more details. +/// +/// Controls whether the first argument to `meshgrid` will fill the rows or columns of the outputs. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum MeshIndex +{ + /// Cartesian indexing. + /// + /// The first argument of `meshgrid` will repeat over the columns of the output. + /// + /// Note: this is the default in `numpy`. + XY, + /// Matrix indexing. + /// + /// The first argument of `meshgrid` will repeat over the rows of the output. + IJ, +} + +mod meshgrid_impl +{ + use super::MeshIndex; + use crate::extension::nonnull::nonnull_debug_checked_from_ptr; + use crate::{ + ArrayBase, + ArrayRef1, + ArrayView, + ArrayView2, + ArrayView3, + ArrayView4, + ArrayView5, + ArrayView6, + Axis, + Data, + Dim, + IntoDimension, + Ix1, + LayoutRef1, + }; + + /// Construct the correct strides for the `idx`-th entry into meshgrid + fn construct_strides( + arr: &LayoutRef1, idx: usize, indexing: MeshIndex, + ) -> <[usize; N] as IntoDimension>::Dim + where [usize; N]: IntoDimension + { + let mut ret = [0; N]; + if idx < 2 && indexing == MeshIndex::XY { + ret[1 - idx] = arr.stride_of(Axis(0)) as usize; + } else { + ret[idx] = arr.stride_of(Axis(0)) as usize; + } + Dim(ret) + } + + /// Construct the correct shape for the `idx`-th entry into meshgrid + fn construct_shape( + arrays: [&LayoutRef1; N], indexing: MeshIndex, + ) -> <[usize; N] as IntoDimension>::Dim + where [usize; N]: IntoDimension + { + let mut ret = arrays.map(|a| a.len()); + if indexing == MeshIndex::XY { + ret.swap(0, 1); + } + Dim(ret) + } + + /// A trait to encapsulate static dispatch for [`meshgrid`](super::meshgrid); see there for more details. + /// + /// The inputs should always be some sort of 1D array. + /// The outputs should always be ND arrays where N is the number of inputs. + /// + /// Where possible, this trait tries to return array views rather than allocating additional memory. + pub trait Meshgrid + { + type Output; + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output; + } + + macro_rules! meshgrid_body { + ($count:literal, $indexing:expr, $(($arr:expr, $idx:literal)),+) => { + { + let shape = construct_shape([$($arr),+], $indexing); + ( + $({ + let strides = construct_strides::<_, $count>($arr, $idx, $indexing); + unsafe { ArrayView::new(nonnull_debug_checked_from_ptr($arr.as_ptr() as *mut A), shape, strides) } + }),+ + ) + } + }; + } + + impl<'a, 'b, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1) + { + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(2, indexing, (arrays.0, 0), (arrays.1, 1)) + } + } + + impl<'a, 'b, S1, S2, A: 'b + 'a> Meshgrid for (&'a ArrayBase, &'b ArrayBase) + where + S1: Data, + S2: Data, + { + type Output = (ArrayView2<'a, A>, ArrayView2<'b, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1), indexing) + } + } + + impl<'a, 'b, 'c, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1) + { + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(3, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2)) + } + } + + impl<'a, 'b, 'c, S1, S2, S3, A: 'b + 'a + 'c> Meshgrid + for (&'a ArrayBase, &'b ArrayBase, &'c ArrayBase) + where + S1: Data, + S2: Data, + S3: Data, + { + type Output = (ArrayView3<'a, A>, ArrayView3<'b, A>, ArrayView3<'c, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2), indexing) + } + } + + impl<'a, 'b, 'c, 'd, A> Meshgrid for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1) + { + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(4, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3)) + } + } + + impl<'a, 'b, 'c, 'd, S1, S2, S3, S4, A: 'a + 'b + 'c + 'd> Meshgrid + for (&'a ArrayBase, &'b ArrayBase, &'c ArrayBase, &'d ArrayBase) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + { + type Output = (ArrayView4<'a, A>, ArrayView4<'b, A>, ArrayView4<'c, A>, ArrayView4<'d, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3), indexing) + } + } + + impl<'a, 'b, 'c, 'd, 'e, A> Meshgrid + for (&'a ArrayRef1, &'b ArrayRef1, &'c ArrayRef1, &'d ArrayRef1, &'e ArrayRef1) + { + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(5, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4)) + } + } + + impl<'a, 'b, 'c, 'd, 'e, S1, S2, S3, S4, S5, A: 'a + 'b + 'c + 'd + 'e> Meshgrid + for ( + &'a ArrayBase, + &'b ArrayBase, + &'c ArrayBase, + &'d ArrayBase, + &'e ArrayBase, + ) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + S5: Data, + { + type Output = (ArrayView5<'a, A>, ArrayView5<'b, A>, ArrayView5<'c, A>, ArrayView5<'d, A>, ArrayView5<'e, A>); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4), indexing) + } + } + + impl<'a, 'b, 'c, 'd, 'e, 'f, A> Meshgrid + for ( + &'a ArrayRef1, + &'b ArrayRef1, + &'c ArrayRef1, + &'d ArrayRef1, + &'e ArrayRef1, + &'f ArrayRef1, + ) + { + type Output = ( + ArrayView6<'a, A>, + ArrayView6<'b, A>, + ArrayView6<'c, A>, + ArrayView6<'d, A>, + ArrayView6<'e, A>, + ArrayView6<'f, A>, + ); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + meshgrid_body!(6, indexing, (arrays.0, 0), (arrays.1, 1), (arrays.2, 2), (arrays.3, 3), (arrays.4, 4), (arrays.5, 5)) + } + } + + impl<'a, 'b, 'c, 'd, 'e, 'f, S1, S2, S3, S4, S5, S6, A: 'a + 'b + 'c + 'd + 'e + 'f> Meshgrid + for ( + &'a ArrayBase, + &'b ArrayBase, + &'c ArrayBase, + &'d ArrayBase, + &'e ArrayBase, + &'f ArrayBase, + ) + where + S1: Data, + S2: Data, + S3: Data, + S4: Data, + S5: Data, + S6: Data, + { + type Output = ( + ArrayView6<'a, A>, + ArrayView6<'b, A>, + ArrayView6<'c, A>, + ArrayView6<'d, A>, + ArrayView6<'e, A>, + ArrayView6<'f, A>, + ); + + fn meshgrid(arrays: Self, indexing: MeshIndex) -> Self::Output + { + Meshgrid::meshgrid((&**arrays.0, &**arrays.1, &**arrays.2, &**arrays.3, &**arrays.4, &**arrays.5), indexing) + } + } +} + +/// Create coordinate matrices from coordinate vectors. +/// +/// Given an N-tuple of 1D coordinate vectors, return an N-tuple of ND coordinate arrays. +/// This is particularly useful for computing the outputs of functions with N arguments over +/// regularly spaced grids. +/// +/// The `indexing` argument can be controlled by [`MeshIndex`] to support both Cartesian and +/// matrix indexing. In the two-dimensional case, inputs of length `N` and `M` will create +/// output arrays of size `(M, N)` when using [`MeshIndex::XY`] and size `(N, M)` when using +/// [`MeshIndex::IJ`]. +/// +/// # Example +/// ``` +/// use ndarray::{array, meshgrid, MeshIndex}; +/// +/// let arr1 = array![1, 2]; +/// let arr2 = array![3, 4]; +/// let arr3 = array![5, 6]; +/// +/// // Cartesian indexing +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::XY); +/// assert_eq!(res1, array![ +/// [1, 2], +/// [1, 2], +/// ]); +/// assert_eq!(res2, array![ +/// [3, 3], +/// [4, 4], +/// ]); +/// +/// // Matrix indexing +/// let (res1, res2) = meshgrid((&arr1, &arr2), MeshIndex::IJ); +/// assert_eq!(res1, array![ +/// [1, 1], +/// [2, 2], +/// ]); +/// assert_eq!(res2, array![ +/// [3, 4], +/// [3, 4], +/// ]); +/// +/// let (_, _, res3) = meshgrid((&arr1, &arr2, &arr3), MeshIndex::XY); +/// assert_eq!(res3, array![ +/// [[5, 6], +/// [5, 6]], +/// [[5, 6], +/// [5, 6]], +/// ]); +/// ``` +pub fn meshgrid(arrays: T, indexing: MeshIndex) -> T::Output +{ + Meshgrid::meshgrid(arrays, indexing) +} + +#[cfg(test)] +mod tests +{ + use super::s; + use crate::{meshgrid, Axis, MeshIndex}; + #[cfg(not(feature = "std"))] + use alloc::vec; + + #[test] + fn test_meshgrid2() + { + let x = array![1, 2, 3]; + let y = array![4, 5, 6, 7]; + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]); + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6], [7, 7, 7]]); + + let (xx, yy) = meshgrid((&x, &y), MeshIndex::IJ); + assert_eq!(xx, array![[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3]]); + assert_eq!(yy, array![[4, 5, 6, 7], [4, 5, 6, 7], [4, 5, 6, 7]]); + } + + #[test] + fn test_meshgrid3() + { + let x = array![1, 2, 3]; + let y = array![4, 5, 6, 7]; + let z = array![-1, -2]; + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::XY); + assert_eq!(xx, array![ + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + [[1, 1], [2, 2], [3, 3]], + ]); + assert_eq!(yy, array![ + [[4, 4], [4, 4], [4, 4]], + [[5, 5], [5, 5], [5, 5]], + [[6, 6], [6, 6], [6, 6]], + [[7, 7], [7, 7], [7, 7]], + ]); + assert_eq!(zz, array![ + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2]], + ]); + + let (xx, yy, zz) = meshgrid((&x, &y, &z), MeshIndex::IJ); + assert_eq!(xx, array![ + [[1, 1], [1, 1], [1, 1], [1, 1]], + [[2, 2], [2, 2], [2, 2], [2, 2]], + [[3, 3], [3, 3], [3, 3], [3, 3]], + ]); + assert_eq!(yy, array![ + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + [[4, 4], [5, 5], [6, 6], [7, 7]], + ]); + assert_eq!(zz, array![ + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + [[-1, -2], [-1, -2], [-1, -2], [-1, -2]], + ]); + } + + #[test] + fn test_meshgrid_from_offset() + { + let x = array![1, 2, 3]; + let x = x.slice(s![1..]); + let y = array![4, 5, 6]; + let y = y.slice(s![1..]); + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[2, 3], [2, 3]]); + assert_eq!(yy, array![[5, 5], [6, 6]]); + } + + #[test] + fn test_meshgrid_neg_stride() + { + let x = array![1, 2, 3]; + let x = x.slice(s![..;-1]); + assert!(x.stride_of(Axis(0)) < 0); // Setup for test + let y = array![4, 5, 6]; + let (xx, yy) = meshgrid((&x, &y), MeshIndex::XY); + assert_eq!(xx, array![[3, 2, 1], [3, 2, 1], [3, 2, 1]]); + assert_eq!(yy, array![[4, 4, 4], [5, 5, 5], [6, 6, 6]]); + } +} From 3a4b9c7769d4fc3f7b2aaca46612014982fdbfb6 Mon Sep 17 00:00:00 2001 From: akern40 Date: Tue, 25 Mar 2025 19:13:14 +0000 Subject: [PATCH 43/48] Changes Dot impl to be on ArrayRef (#1494) Also adds an accelerate option to the blas-tests crate --- Cargo.lock | 7 +++++++ crates/blas-tests/Cargo.toml | 1 + src/linalg/impl_linalg.rs | 10 +++------- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d0530aff0..d1a513a74 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "accelerate-src" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "415ed64958754dbe991900f3940677e6a7eefb4d7367afd70d642677b0c7d19d" + [[package]] name = "adler2" version = "2.0.0" @@ -66,6 +72,7 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b95e83dc868db96e69795c0213143095f03de9dd3252f205d4ac716e4076a7e0" dependencies = [ + "accelerate-src", "blis-src", "netlib-src", "openblas-src", diff --git a/crates/blas-tests/Cargo.toml b/crates/blas-tests/Cargo.toml index ff556873a..08acc7fa5 100644 --- a/crates/blas-tests/Cargo.toml +++ b/crates/blas-tests/Cargo.toml @@ -34,3 +34,4 @@ openblas-cache = ["blas-src", "blas-src/openblas", "openblas-src/cache"] netlib = ["blas-src", "blas-src/netlib"] netlib-system = ["blas-src", "blas-src/netlib", "netlib-src/system"] blis-system = ["blas-src", "blas-src/blis", "blis-src/system"] +accelerate = ["blas-src", "blas-src/accelerate"] diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 0f28cac1d..d34fd9156 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1099,16 +1099,12 @@ mod blas_tests /// - The arrays have dimensions other than 1 or 2 /// - The array shapes are incompatible for the operation /// - For vector dot product: the vectors have different lengths -/// -impl Dot> for ArrayBase -where - S: Data, - S2: Data, - A: LinalgScalar, +impl Dot> for ArrayRef +where A: LinalgScalar { type Output = Array; - fn dot(&self, rhs: &ArrayBase) -> Self::Output + fn dot(&self, rhs: &ArrayRef) -> Self::Output { match (self.ndim(), rhs.ndim()) { (1, 1) => { From 2a5cae1dc0f01ec31a336c575434e3a88cdcba25 Mon Sep 17 00:00:00 2001 From: HuiSeomKim <126950833+NewBornRustacean@users.noreply.github.com> Date: Thu, 27 Mar 2025 08:55:03 +0900 Subject: [PATCH 44/48] Add cumprod (#1491) --- src/numeric/impl_numeric.rs | 41 +++++++++++++++++++++- tests/numeric.rs | 70 +++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 1 deletion(-) diff --git a/src/numeric/impl_numeric.rs b/src/numeric/impl_numeric.rs index 27c5687ee..ae82a482a 100644 --- a/src/numeric/impl_numeric.rs +++ b/src/numeric/impl_numeric.rs @@ -10,7 +10,7 @@ use num_traits::Float; use num_traits::One; use num_traits::{FromPrimitive, Zero}; -use std::ops::{Add, Div, Mul, Sub}; +use std::ops::{Add, Div, Mul, MulAssign, Sub}; use crate::imp_prelude::*; use crate::numeric_util; @@ -97,6 +97,45 @@ where D: Dimension sum } + /// Return the cumulative product of elements along a given axis. + /// + /// ``` + /// use ndarray::{arr2, Axis}; + /// + /// let a = arr2(&[[1., 2., 3.], + /// [4., 5., 6.]]); + /// + /// // Cumulative product along rows (axis 0) + /// assert_eq!( + /// a.cumprod(Axis(0)), + /// arr2(&[[1., 2., 3.], + /// [4., 10., 18.]]) + /// ); + /// + /// // Cumulative product along columns (axis 1) + /// assert_eq!( + /// a.cumprod(Axis(1)), + /// arr2(&[[1., 2., 6.], + /// [4., 20., 120.]]) + /// ); + /// ``` + /// + /// **Panics** if `axis` is out of bounds. + #[track_caller] + pub fn cumprod(&self, axis: Axis) -> Array + where + A: Clone + Mul + MulAssign, + D: Dimension + RemoveAxis, + { + if axis.0 >= self.ndim() { + panic!("axis is out of bounds for array of dimension"); + } + + let mut result = self.to_owned(); + result.accumulate_axis_inplace(axis, |prev, curr| *curr *= prev.clone()); + result + } + /// Return variance of elements in the array. /// /// The variance is computed using the [Welford one-pass diff --git a/tests/numeric.rs b/tests/numeric.rs index 839aba58e..7e6964812 100644 --- a/tests/numeric.rs +++ b/tests/numeric.rs @@ -75,6 +75,76 @@ fn sum_mean_prod_empty() assert_eq!(a, None); } +#[test] +fn test_cumprod_1d() +{ + let a = array![1, 2, 3, 4]; + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![1, 2, 6, 24]); +} + +#[test] +fn test_cumprod_2d() +{ + let a = array![[1, 2], [3, 4]]; + + let result_axis0 = a.cumprod(Axis(0)); + assert_eq!(result_axis0, array![[1, 2], [3, 8]]); + + let result_axis1 = a.cumprod(Axis(1)); + assert_eq!(result_axis1, array![[1, 2], [3, 12]]); +} + +#[test] +fn test_cumprod_3d() +{ + let a = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]]; + + let result_axis0 = a.cumprod(Axis(0)); + assert_eq!(result_axis0, array![[[1, 2], [3, 4]], [[5, 12], [21, 32]]]); + + let result_axis1 = a.cumprod(Axis(1)); + assert_eq!(result_axis1, array![[[1, 2], [3, 8]], [[5, 6], [35, 48]]]); + + let result_axis2 = a.cumprod(Axis(2)); + assert_eq!(result_axis2, array![[[1, 2], [3, 12]], [[5, 30], [7, 56]]]); +} + +#[test] +fn test_cumprod_empty() +{ + // For 2D empty array + let b: Array2 = Array2::zeros((0, 0)); + let result_axis0 = b.cumprod(Axis(0)); + assert_eq!(result_axis0, Array2::zeros((0, 0))); + let result_axis1 = b.cumprod(Axis(1)); + assert_eq!(result_axis1, Array2::zeros((0, 0))); +} + +#[test] +fn test_cumprod_1_element() +{ + // For 1D array with one element + let a = array![5]; + let result = a.cumprod(Axis(0)); + assert_eq!(result, array![5]); + + // For 2D array with one element + let b = array![[5]]; + let result_axis0 = b.cumprod(Axis(0)); + let result_axis1 = b.cumprod(Axis(1)); + assert_eq!(result_axis0, array![[5]]); + assert_eq!(result_axis1, array![[5]]); +} + +#[test] +#[should_panic(expected = "axis is out of bounds for array of dimension")] +fn test_cumprod_axis_out_of_bounds() +{ + let a = array![[1, 2], [3, 4]]; + let _result = a.cumprod(Axis(2)); +} + #[test] #[cfg(feature = "std")] fn var() From 4e2a70f186560292bb73e030b67d54bf6faf1a9f Mon Sep 17 00:00:00 2001 From: akern40 Date: Sun, 30 Mar 2025 00:35:18 -0400 Subject: [PATCH 45/48] Allows benchmarks that do not use linspace to run on no_std (#1495) --- benches/bench1.rs | 5 +++++ benches/construct.rs | 2 ++ benches/higher-order.rs | 5 +++++ benches/iter.rs | 6 ++++++ benches/numeric.rs | 1 + src/linalg/impl_linalg.rs | 1 - 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/benches/bench1.rs b/benches/bench1.rs index 33185844a..c07b8e3d9 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -982,6 +982,7 @@ fn dot_extended(bench: &mut test::Bencher) const MEAN_SUM_N: usize = 127; +#[cfg(feature = "std")] fn range_mat(m: Ix, n: Ix) -> Array2 { assert!(m * n != 0); @@ -990,6 +991,7 @@ fn range_mat(m: Ix, n: Ix) -> Array2 .unwrap() } +#[cfg(feature = "std")] #[bench] fn mean_axis0(bench: &mut test::Bencher) { @@ -997,6 +999,7 @@ fn mean_axis0(bench: &mut test::Bencher) bench.iter(|| a.mean_axis(Axis(0))); } +#[cfg(feature = "std")] #[bench] fn mean_axis1(bench: &mut test::Bencher) { @@ -1004,6 +1007,7 @@ fn mean_axis1(bench: &mut test::Bencher) bench.iter(|| a.mean_axis(Axis(1))); } +#[cfg(feature = "std")] #[bench] fn sum_axis0(bench: &mut test::Bencher) { @@ -1011,6 +1015,7 @@ fn sum_axis0(bench: &mut test::Bencher) bench.iter(|| a.sum_axis(Axis(0))); } +#[cfg(feature = "std")] #[bench] fn sum_axis1(bench: &mut test::Bencher) { diff --git a/benches/construct.rs b/benches/construct.rs index 278174388..380d87799 100644 --- a/benches/construct.rs +++ b/benches/construct.rs @@ -19,6 +19,7 @@ fn zeros_f64(bench: &mut Bencher) bench.iter(|| Array::::zeros((128, 128))) } +#[cfg(feature = "std")] #[bench] fn map_regular(bench: &mut test::Bencher) { @@ -28,6 +29,7 @@ fn map_regular(bench: &mut test::Bencher) bench.iter(|| a.map(|&x| 2. * x)); } +#[cfg(feature = "std")] #[bench] fn map_stride(bench: &mut test::Bencher) { diff --git a/benches/higher-order.rs b/benches/higher-order.rs index 9cc3bd961..1b4e8340c 100644 --- a/benches/higher-order.rs +++ b/benches/higher-order.rs @@ -12,6 +12,7 @@ const N: usize = 1024; const X: usize = 64; const Y: usize = 16; +#[cfg(feature = "std")] #[bench] fn map_regular(bench: &mut Bencher) { @@ -26,6 +27,7 @@ pub fn double_array(mut a: ArrayViewMut2<'_, f64>) a *= 2.0; } +#[cfg(feature = "std")] #[bench] fn map_stride_double_f64(bench: &mut Bencher) { @@ -38,6 +40,7 @@ fn map_stride_double_f64(bench: &mut Bencher) }); } +#[cfg(feature = "std")] #[bench] fn map_stride_f64(bench: &mut Bencher) { @@ -48,6 +51,7 @@ fn map_stride_f64(bench: &mut Bencher) bench.iter(|| av.map(|&x| 2. * x)); } +#[cfg(feature = "std")] #[bench] fn map_stride_u32(bench: &mut Bencher) { @@ -59,6 +63,7 @@ fn map_stride_u32(bench: &mut Bencher) bench.iter(|| av.map(|&x| 2 * x)); } +#[cfg(feature = "std")] #[bench] fn fold_axis(bench: &mut Bencher) { diff --git a/benches/iter.rs b/benches/iter.rs index 77f511745..154ee4eaf 100644 --- a/benches/iter.rs +++ b/benches/iter.rs @@ -45,6 +45,7 @@ fn iter_sum_2d_transpose(bench: &mut Bencher) bench.iter(|| a.iter().sum::()); } +#[cfg(feature = "std")] #[bench] fn iter_filter_sum_2d_u32(bench: &mut Bencher) { @@ -55,6 +56,7 @@ fn iter_filter_sum_2d_u32(bench: &mut Bencher) bench.iter(|| b.iter().filter(|&&x| x < 75).sum::()); } +#[cfg(feature = "std")] #[bench] fn iter_filter_sum_2d_f32(bench: &mut Bencher) { @@ -65,6 +67,7 @@ fn iter_filter_sum_2d_f32(bench: &mut Bencher) bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } +#[cfg(feature = "std")] #[bench] fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) { @@ -76,6 +79,7 @@ fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) bench.iter(|| b.iter().filter(|&&x| x < 75).sum::()); } +#[cfg(feature = "std")] #[bench] fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) { @@ -87,6 +91,7 @@ fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::()); } +#[cfg(feature = "std")] #[bench] fn iter_rev_step_by_contiguous(bench: &mut Bencher) { @@ -98,6 +103,7 @@ fn iter_rev_step_by_contiguous(bench: &mut Bencher) }); } +#[cfg(feature = "std")] #[bench] fn iter_rev_step_by_discontiguous(bench: &mut Bencher) { diff --git a/benches/numeric.rs b/benches/numeric.rs index e2ffa1b84..ceb57fbd7 100644 --- a/benches/numeric.rs +++ b/benches/numeric.rs @@ -9,6 +9,7 @@ const N: usize = 1024; const X: usize = 64; const Y: usize = 16; +#[cfg(feature = "std")] #[bench] fn clip(bench: &mut Bencher) { diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index d34fd9156..0bbc0b026 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -1071,7 +1071,6 @@ mod blas_tests for stride in 1..=MAXSTRIDE { let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap(); - eprintln!("{:?}", m); if stride < N { assert_eq!(get_blas_compatible_layout(&m), None); From 2324d2a49cb19d848b3aa8629d63e73095d783b1 Mon Sep 17 00:00:00 2001 From: HuiSeomKim <126950833+NewBornRustacean@users.noreply.github.com> Date: Mon, 7 Apr 2025 02:31:05 +0900 Subject: [PATCH 46/48] Add partition(similar to numpy.partition) (#1498) * fn partition --- src/impl_methods.rs | 192 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 192 insertions(+) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index d2f04ef1f..42d843781 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3184,6 +3184,81 @@ impl ArrayRef f(&*prev, &mut *curr) }); } + + /// Return a partitioned copy of the array. + /// + /// Creates a copy of the array and partially sorts it around the k-th element along the given axis. + /// The k-th element will be in its sorted position, with: + /// - All elements smaller than the k-th element to its left + /// - All elements equal or greater than the k-th element to its right + /// - The ordering within each partition is undefined + /// + /// # Parameters + /// + /// * `kth` - Index to partition by. The k-th element will be in its sorted position. + /// * `axis` - Axis along which to partition. + /// + /// # Returns + /// + /// A new array of the same shape and type as the input array, with elements partitioned. + /// + /// # Examples + /// + /// ``` + /// use ndarray::prelude::*; + /// + /// let a = array![7, 1, 5, 2, 6, 0, 3, 4]; + /// let p = a.partition(3, Axis(0)); + /// + /// // The element at position 3 is now 3, with smaller elements to the left + /// // and greater elements to the right + /// assert_eq!(p[3], 3); + /// assert!(p.slice(s![..3]).iter().all(|&x| x <= 3)); + /// assert!(p.slice(s![4..]).iter().all(|&x| x >= 3)); + /// ``` + pub fn partition(&self, kth: usize, axis: Axis) -> Array + where + A: Clone + Ord + num_traits::Zero, + D: Dimension, + { + // Bounds checking + let axis_len = self.len_of(axis); + if kth >= axis_len { + panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len); + } + + let mut result = self.to_owned(); + + // Check if the first lane is contiguous + let is_contiguous = result + .lanes_mut(axis) + .into_iter() + .next() + .unwrap() + .is_contiguous(); + + if is_contiguous { + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + lane.as_slice_mut().unwrap().select_nth_unstable(kth); + }); + } else { + let mut temp_vec = vec![A::zero(); axis_len]; + + Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + Zip::from(&mut temp_vec).and(&lane).for_each(|dest, src| { + *dest = src.clone(); + }); + + temp_vec.select_nth_unstable(kth); + + Zip::from(&mut lane).and(&temp_vec).for_each(|dest, src| { + *dest = src.clone(); + }); + }); + } + + result + } } /// Transmute from A to B. @@ -3277,4 +3352,121 @@ mod tests let _a2 = a.clone(); assert_first!(a); } + + #[test] + fn test_partition_1d() + { + // Test partitioning a 1D array + let array = arr1(&[3, 1, 4, 1, 5, 9, 2, 6]); + let result = array.partition(3, Axis(0)); + // After partitioning, the element at index 3 should be in its final sorted position + assert!(result.slice(s![..3]).iter().all(|&x| x <= result[3])); + assert!(result.slice(s![4..]).iter().all(|&x| x >= result[3])); + } + + #[test] + fn test_partition_2d() + { + // Test partitioning a 2D array along both axes + let array = arr2(&[[3, 1, 4], [1, 5, 9], [2, 6, 5]]); + + // Partition along axis 0 (rows) + let result0 = array.partition(1, Axis(0)); + // After partitioning along axis 0, each column should have its middle element in the correct position + assert!(result0[[0, 0]] <= result0[[1, 0]] && result0[[2, 0]] >= result0[[1, 0]]); + assert!(result0[[0, 1]] <= result0[[1, 1]] && result0[[2, 1]] >= result0[[1, 1]]); + assert!(result0[[0, 2]] <= result0[[1, 2]] && result0[[2, 2]] >= result0[[1, 2]]); + + // Partition along axis 1 (columns) + let result1 = array.partition(1, Axis(1)); + // After partitioning along axis 1, each row should have its middle element in the correct position + assert!(result1[[0, 0]] <= result1[[0, 1]] && result1[[0, 2]] >= result1[[0, 1]]); + assert!(result1[[1, 0]] <= result1[[1, 1]] && result1[[1, 2]] >= result1[[1, 1]]); + assert!(result1[[2, 0]] <= result1[[2, 1]] && result1[[2, 2]] >= result1[[2, 1]]); + } + + #[test] + fn test_partition_3d() + { + // Test partitioning a 3D array + let array = arr3(&[[[3, 1], [4, 1]], [[5, 9], [2, 6]]]); + + // Partition along axis 0 + let result = array.partition(0, Axis(0)); + // After partitioning, each 2x2 slice should have its first element in the correct position + assert!(result[[0, 0, 0]] <= result[[1, 0, 0]]); + assert!(result[[0, 0, 1]] <= result[[1, 0, 1]]); + assert!(result[[0, 1, 0]] <= result[[1, 1, 0]]); + assert!(result[[0, 1, 1]] <= result[[1, 1, 1]]); + } + + #[test] + #[should_panic] + fn test_partition_invalid_kth() + { + let a = array![1, 2, 3, 4]; + // This should panic because kth=4 is out of bounds + let _ = a.partition(4, Axis(0)); + } + + #[test] + #[should_panic] + fn test_partition_invalid_axis() + { + let a = array![1, 2, 3, 4]; + // This should panic because axis=1 is out of bounds for a 1D array + let _ = a.partition(0, Axis(1)); + } + + #[test] + fn test_partition_contiguous_or_not() + { + // Test contiguous case (C-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Partition along axis 0 (contiguous) + let p_axis0 = a.partition(1, Axis(0)); + + // For each column, verify the partitioning: + // - First row should be <= middle row (kth element) + // - Last row should be >= middle row (kth element) + for col in 0..3 { + let kth = p_axis0[[1, col]]; + assert!(p_axis0[[0, col]] <= kth, + "Column {}: First row {} should be <= middle row {}", + col, p_axis0[[0, col]], kth); + assert!(p_axis0[[2, col]] >= kth, + "Column {}: Last row {} should be >= middle row {}", + col, p_axis0[[2, col]], kth); + } + + // Test non-contiguous case (F-order) + let a = array![ + [7, 1, 5], + [2, 6, 0], + [3, 4, 8] + ]; + + // Make array non-contiguous by transposing + let a = a.t().to_owned(); + + // Partition along axis 1 (non-contiguous) + let p_axis1 = a.partition(1, Axis(1)); + + // For each row, verify the partitioning: + // - First column should be <= middle column + // - Last column should be >= middle column + for row in 0..3 { + assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]], + "Row {}: First column {} should be <= middle column {}", + row, p_axis1[[row, 0]], p_axis1[[row, 1]]); + assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]], + "Row {}: Last column {} should be >= middle column {}", + row, p_axis1[[row, 2]], p_axis1[[row, 1]]); + } + } } From da115c919ec76314109ff280bbba9fb9e6056a65 Mon Sep 17 00:00:00 2001 From: akern40 Date: Thu, 10 Apr 2025 00:00:46 -0400 Subject: [PATCH 47/48] Fix partition on empty arrays (#1502) Closes #1501 --- src/impl_methods.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index 42d843781..ea9c9a0d5 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -3229,16 +3229,22 @@ impl ArrayRef let mut result = self.to_owned(); + // Must guarantee that the array isn't empty before checking for contiguity + if result.shape().iter().any(|s| *s == 0) { + return result; + } + // Check if the first lane is contiguous let is_contiguous = result .lanes_mut(axis) .into_iter() .next() + // This unwrap shouldn't cause panics because the array isn't empty .unwrap() .is_contiguous(); if is_contiguous { - Zip::from(result.lanes_mut(axis)).for_each(|mut lane| { + result.lanes_mut(axis).into_iter().for_each(|mut lane| { lane.as_slice_mut().unwrap().select_nth_unstable(kth); }); } else { From 8bd70b0c51e6ff3a6b6c2df94b4bcf9547dc9128 Mon Sep 17 00:00:00 2001 From: HuiSeomKim <126950833+NewBornRustacean@users.noreply.github.com> Date: Sun, 20 Apr 2025 22:24:21 +0900 Subject: [PATCH 48/48] add test case for partition on empty array (#1504) * add test case for empty array * return early when the array has zero lenth dims --------- Co-authored-by: Adam Kern --- src/impl_methods.rs | 141 +++++++++++++++++++++++++++++--------------- 1 file changed, 93 insertions(+), 48 deletions(-) diff --git a/src/impl_methods.rs b/src/impl_methods.rs index ea9c9a0d5..9a1741be6 100644 --- a/src/impl_methods.rs +++ b/src/impl_methods.rs @@ -576,11 +576,7 @@ where pub fn slice_move(mut self, info: I) -> ArrayBase where I: SliceArg { - assert_eq!( - info.in_ndim(), - self.ndim(), - "The input dimension of `info` must match the array to be sliced.", - ); + assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let out_ndim = info.out_ndim(); let mut new_dim = I::OutDim::zeros(out_ndim); let mut new_strides = I::OutDim::zeros(out_ndim); @@ -648,11 +644,7 @@ impl LayoutRef pub fn slice_collapse(&mut self, info: I) where I: SliceArg { - assert_eq!( - info.in_ndim(), - self.ndim(), - "The input dimension of `info` must match the array to be sliced.", - ); + assert_eq!(info.in_ndim(), self.ndim(), "The input dimension of `info` must match the array to be sliced.",); let mut axis = 0; info.as_ref().iter().for_each(|&ax_info| match ax_info { SliceInfoElem::Slice { start, end, step } => { @@ -1120,8 +1112,7 @@ impl ArrayRef // bounds check the indices first if let Some(max_index) = indices.iter().cloned().max() { if max_index >= axis_len { - panic!("ndarray: index {} is out of bounds in array of len {}", - max_index, self.len_of(axis)); + panic!("ndarray: index {} is out of bounds in array of len {}", max_index, self.len_of(axis)); } } // else: indices empty is ok let view = self.view().into_dimensionality::().unwrap(); @@ -1530,10 +1521,7 @@ impl ArrayRef ndassert!( axis_index < self.ndim(), - concat!( - "Window axis {} does not match array dimension {} ", - "(with array of shape {:?})" - ), + concat!("Window axis {} does not match array dimension {} ", "(with array of shape {:?})"), axis_index, self.ndim(), self.shape() @@ -3119,8 +3107,7 @@ where /// ***Panics*** if not `index < self.len_of(axis)`. pub fn remove_index(&mut self, axis: Axis, index: usize) { - assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", - index, axis.index()); + assert!(index < self.len_of(axis), "index {} must be less than length of Axis({})", index, axis.index()); let (_, mut tail) = self.view_mut().split_at(axis, index); // shift elements to the front Zip::from(tail.lanes_mut(axis)).for_each(|mut lane| lane.rotate1_front()); @@ -3193,15 +3180,16 @@ impl ArrayRef /// - All elements equal or greater than the k-th element to its right /// - The ordering within each partition is undefined /// + /// Empty arrays (i.e., those with any zero-length axes) are considered partitioned already, + /// and will be returned unchanged. + /// + /// **Panics** if `k` is out of bounds for a non-zero axis length. + /// /// # Parameters /// /// * `kth` - Index to partition by. The k-th element will be in its sorted position. /// * `axis` - Axis along which to partition. /// - /// # Returns - /// - /// A new array of the same shape and type as the input array, with elements partitioned. - /// /// # Examples /// /// ``` @@ -3221,19 +3209,19 @@ impl ArrayRef A: Clone + Ord + num_traits::Zero, D: Dimension, { - // Bounds checking - let axis_len = self.len_of(axis); - if kth >= axis_len { - panic!("partition index {} is out of bounds for axis of length {}", kth, axis_len); - } - let mut result = self.to_owned(); - // Must guarantee that the array isn't empty before checking for contiguity - if result.shape().iter().any(|s| *s == 0) { + // Return early if the array has zero-length dimensions + if self.shape().iter().any(|s| *s == 0) { return result; } + // Bounds checking. Panics if kth is out of bounds + let axis_len = self.len_of(axis); + if kth >= axis_len { + panic!("Partition index {} is out of bounds for axis {} of length {}", kth, axis.0, axis_len); + } + // Check if the first lane is contiguous let is_contiguous = result .lanes_mut(axis) @@ -3428,11 +3416,7 @@ mod tests fn test_partition_contiguous_or_not() { // Test contiguous case (C-order) - let a = array![ - [7, 1, 5], - [2, 6, 0], - [3, 4, 8] - ]; + let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; // Partition along axis 0 (contiguous) let p_axis0 = a.partition(1, Axis(0)); @@ -3442,20 +3426,24 @@ mod tests // - Last row should be >= middle row (kth element) for col in 0..3 { let kth = p_axis0[[1, col]]; - assert!(p_axis0[[0, col]] <= kth, + assert!( + p_axis0[[0, col]] <= kth, "Column {}: First row {} should be <= middle row {}", - col, p_axis0[[0, col]], kth); - assert!(p_axis0[[2, col]] >= kth, + col, + p_axis0[[0, col]], + kth + ); + assert!( + p_axis0[[2, col]] >= kth, "Column {}: Last row {} should be >= middle row {}", - col, p_axis0[[2, col]], kth); + col, + p_axis0[[2, col]], + kth + ); } // Test non-contiguous case (F-order) - let a = array![ - [7, 1, 5], - [2, 6, 0], - [3, 4, 8] - ]; + let a = array![[7, 1, 5], [2, 6, 0], [3, 4, 8]]; // Make array non-contiguous by transposing let a = a.t().to_owned(); @@ -3467,12 +3455,69 @@ mod tests // - First column should be <= middle column // - Last column should be >= middle column for row in 0..3 { - assert!(p_axis1[[row, 0]] <= p_axis1[[row, 1]], + assert!( + p_axis1[[row, 0]] <= p_axis1[[row, 1]], "Row {}: First column {} should be <= middle column {}", - row, p_axis1[[row, 0]], p_axis1[[row, 1]]); - assert!(p_axis1[[row, 2]] >= p_axis1[[row, 1]], + row, + p_axis1[[row, 0]], + p_axis1[[row, 1]] + ); + assert!( + p_axis1[[row, 2]] >= p_axis1[[row, 1]], "Row {}: Last column {} should be >= middle column {}", - row, p_axis1[[row, 2]], p_axis1[[row, 1]]); + row, + p_axis1[[row, 2]], + p_axis1[[row, 1]] + ); } } + + #[test] + fn test_partition_empty() + { + // Test 1D empty array + let empty1d = Array1::::zeros(0); + let result1d = empty1d.partition(0, Axis(0)); + assert_eq!(result1d.len(), 0); + + // Test 1D empty array with kth out of bounds + let result1d_out_of_bounds = empty1d.partition(1, Axis(0)); + assert_eq!(result1d_out_of_bounds.len(), 0); + + // Test 2D empty array + let empty2d = Array2::::zeros((0, 3)); + let result2d = empty2d.partition(0, Axis(0)); + assert_eq!(result2d.shape(), &[0, 3]); + + // Test 2D empty array with zero columns + let empty2d_cols = Array2::::zeros((2, 0)); + let result2d_cols = empty2d_cols.partition(0, Axis(1)); + assert_eq!(result2d_cols.shape(), &[2, 0]); + + // Test 3D empty array + let empty3d = Array3::::zeros((0, 2, 3)); + let result3d = empty3d.partition(0, Axis(0)); + assert_eq!(result3d.shape(), &[0, 2, 3]); + + // Test 3D empty array with zero in middle dimension + let empty3d_mid = Array3::::zeros((2, 0, 3)); + let result3d_mid = empty3d_mid.partition(0, Axis(1)); + assert_eq!(result3d_mid.shape(), &[2, 0, 3]); + + // Test 4D empty array + let empty4d = Array4::::zeros((0, 2, 3, 4)); + let result4d = empty4d.partition(0, Axis(0)); + assert_eq!(result4d.shape(), &[0, 2, 3, 4]); + + // Test empty array with non-zero dimensions in other axes + let empty_mixed = Array2::::zeros((0, 5)); + let result_mixed = empty_mixed.partition(0, Axis(0)); + assert_eq!(result_mixed.shape(), &[0, 5]); + + // Test empty array with negative strides + let arr = Array2::::zeros((3, 3)); + let empty_slice = arr.slice(s![0..0, ..]); + let result_slice = empty_slice.partition(0, Axis(0)); + assert_eq!(result_slice.shape(), &[0, 3]); + } }