Thanks for visiting codestin.com
Credit goes to docs.rs

arrayfire/core/
index.rs

1use super::array::Array;
2use super::defines::AfError;
3use super::error::HANDLE_ERROR;
4use super::seq::Seq;
5use super::util::{af_array, af_index_t, dim_t, HasAfEnum, IndexableType};
6
7use libc::{c_double, c_int, c_uint};
8use std::default::Default;
9use std::marker::PhantomData;
10use std::mem;
11
12extern "C" {
13    fn af_create_indexers(indexers: *mut af_index_t) -> c_int;
14    fn af_set_array_indexer(indexer: af_index_t, idx: af_array, dim: dim_t) -> c_int;
15    fn af_set_seq_indexer(
16        indexer: af_index_t,
17        idx: *const SeqInternal,
18        dim: dim_t,
19        is_batch: bool,
20    ) -> c_int;
21    fn af_release_indexers(indexers: af_index_t) -> c_int;
22
23    fn af_index(
24        out: *mut af_array,
25        input: af_array,
26        ndims: c_uint,
27        index: *const SeqInternal,
28    ) -> c_int;
29    fn af_lookup(out: *mut af_array, arr: af_array, indices: af_array, dim: c_uint) -> c_int;
30    fn af_assign_seq(
31        out: *mut af_array,
32        lhs: af_array,
33        ndims: c_uint,
34        indices: *const SeqInternal,
35        rhs: af_array,
36    ) -> c_int;
37    fn af_index_gen(
38        out: *mut af_array,
39        input: af_array,
40        ndims: dim_t,
41        indices: af_index_t,
42    ) -> c_int;
43    fn af_assign_gen(
44        out: *mut af_array,
45        lhs: af_array,
46        ndims: dim_t,
47        indices: af_index_t,
48        rhs: af_array,
49    ) -> c_int;
50}
51
52/// Struct to manage an array of resources of type `af_indexer_t`(ArrayFire C struct)
53///
54/// ## Sharing Across Threads
55///
56/// While sharing an Indexer object with other threads, just move it across threads. At the
57/// moment, one cannot share borrowed references across threads.
58///
59/// # Examples
60///
61/// Given below are examples illustrating correct and incorrect usage of Indexer struct.
62///
63/// <h3> Correct Usage </h3>
64///
65/// ```rust
66/// use arrayfire::{Array, Dim4, randu, index_gen, Indexer};
67///
68/// // Always be aware of the fact that, the `Seq` or `Array` objects
69/// // that we intend to use for indexing via `Indexer` have to outlive
70/// // the `Indexer` object created in this context.
71///
72/// let dims    = Dim4::new(&[1, 3, 1, 1]);
73/// let indices = [1u8, 0, 1];
74/// let idx     = Array::new(&indices, dims);
75/// let values  = [2.0f32, 5.0, 6.0];
76/// let arr     = Array::new(&values, dims);
77///
78/// let mut idxr = Indexer::default();
79///
80/// // `idx` is created much before idxr, thus will
81/// // stay in scope at least as long as idxr
82/// idxr.set_index(&idx, 0, None);
83///
84/// index_gen(&arr, idxr);
85/// ```
86///
87/// <h3> Incorrect Usage </h3>
88///
89/// ```rust,ignore
90/// // Say, you create an Array on the fly and try
91/// // to call set_index, it will throw the given below
92/// // error or something similar to that
93/// idxr.set_index(&Array::new(&[1, 0, 1], dims), 0, None);
94/// ```
95///
96/// ```text
97/// error: borrowed value does not live long enough
98///   --> <anon>:16:55
99///   |
100///16 | idxr.set_index(&Array::new(&[1, 0, 1], dims), 0, None);
101///   |                 ----------------------------          ^ temporary value dropped here while still borrowed
102///   |                 |
103///   |                 temporary value created here
104///...
105///19 | }
106///   | - temporary value needs to live until here
107///   |
108///   = note: consider using a `let` binding to increase its lifetime
109/// ```
110pub struct Indexer<'object> {
111    handle: af_index_t,
112    count: usize,
113    marker: PhantomData<&'object ()>,
114}
115
116unsafe impl<'object> Send for Indexer<'object> {}
117
118/// Trait bound indicating indexability
119///
120/// Any object to be able to be passed on to [Indexer::set_index()](./struct.Indexer.html#method.set_index) method  should implement this trait with appropriate implementation of `set` method.
121pub trait Indexable {
122    /// Set indexing object for a given dimension
123    ///
124    /// `is_batch` parameter is not used in most cases as it has been provided in
125    /// ArrayFire C-API to enable GFOR construct in ArrayFire C++ API. This type
126    /// of construct/idea is not exposed in rust wrapper yet. So, the user would
127    /// just need to pass `None` to this parameter while calling this function.
128    /// Since we can't have default default values and we wanted to keep this
129    /// parameter for future use cases, we just made it an `std::Option`.
130    ///
131    /// # Parameters
132    ///
133    /// - `idxr` is mutable reference to [Indexer](./struct.Indexer.html) object which will
134    ///   be modified to set `self` indexable along `dim` dimension.
135    /// - `dim` is the dimension along which `self` indexable will be used for indexing.
136    /// - `is_batch` is only used if `self` is [Seq](./struct.Seq.html) to indicate if indexing
137    ///   along `dim` is a batched operation.
138    fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>);
139}
140
141/// Enables [Array](./struct.Array.html) to be used to index another Array
142///
143/// This is used in functions [index_gen](./fn.index_gen.html) and
144/// [assign_gen](./fn.assign_gen.html)
145impl<T> Indexable for Array<T>
146where
147    T: HasAfEnum + IndexableType,
148{
149    fn set(&self, idxr: &mut Indexer, dim: u32, _is_batch: Option<bool>) {
150        unsafe {
151            let err_val = af_set_array_indexer(idxr.get(), self.get(), dim as dim_t);
152            HANDLE_ERROR(AfError::from(err_val));
153        }
154    }
155}
156
157/// Enables [Seq](./struct.Seq.html) to be used to index another Array
158///
159/// This is used in functions [index_gen](./fn.index_gen.html) and
160/// [assign_gen](./fn.assign_gen.html)
161impl<T> Indexable for Seq<T>
162where
163    c_double: From<T>,
164    T: Copy + IndexableType,
165{
166    fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>) {
167        unsafe {
168            let err_val = af_set_seq_indexer(
169                idxr.get(),
170                &SeqInternal::from_seq(self) as *const SeqInternal,
171                dim as dim_t,
172                match is_batch {
173                    Some(value) => value,
174                    None => false,
175                },
176            );
177            HANDLE_ERROR(AfError::from(err_val));
178        }
179    }
180}
181
182impl<'object> Default for Indexer<'object> {
183    fn default() -> Self {
184        unsafe {
185            let mut temp: af_index_t = std::ptr::null_mut();
186            let err_val = af_create_indexers(&mut temp as *mut af_index_t);
187            HANDLE_ERROR(AfError::from(err_val));
188            Self {
189                handle: temp,
190                count: 0,
191                marker: PhantomData,
192            }
193        }
194    }
195}
196
197impl<'object> Indexer<'object> {
198    /// Create a new Indexer object and set the dimension specific index objects later
199    #[deprecated(since = "3.7.0", note = "Use Indexer::default() instead")]
200    pub fn new() -> Self {
201        unsafe {
202            let mut temp: af_index_t = std::ptr::null_mut();
203            let err_val = af_create_indexers(&mut temp as *mut af_index_t);
204            HANDLE_ERROR(AfError::from(err_val));
205            Self {
206                handle: temp,
207                count: 0,
208                marker: PhantomData,
209            }
210        }
211    }
212
213    /// Set either [Array](./struct.Array.html) or [Seq](./struct.Seq.html) to index an Array along `idx` dimension
214    pub fn set_index<'s, T>(&'s mut self, idx: &'object T, dim: u32, is_batch: Option<bool>)
215    where
216        T: Indexable + 'object,
217    {
218        idx.set(self, dim, is_batch);
219        self.count += 1;
220    }
221
222    /// Get number of indexing objects set
223    pub fn len(&self) -> usize {
224        self.count
225    }
226
227    /// Check if any indexing objects are set
228    pub fn is_empty(&self) -> bool {
229        self.count == 0
230    }
231
232    /// Get native(ArrayFire) resource handle
233    pub unsafe fn get(&self) -> af_index_t {
234        self.handle
235    }
236}
237
238impl<'object> Drop for Indexer<'object> {
239    fn drop(&mut self) {
240        unsafe {
241            let ret_val = af_release_indexers(self.handle as af_index_t);
242            match ret_val {
243                0 => (),
244                _ => panic!("Failed to release indexers resource: {}", ret_val),
245            }
246        }
247    }
248}
249
250/// Indexes the `input` Array using `seqs` Sequences
251///
252/// # Examples
253///
254/// ```rust
255/// use arrayfire::{Dim4, Seq, index, randu, print};
256/// let dims = Dim4::new(&[5, 5, 1, 1]);
257/// let a = randu::<f32>(dims);
258/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
259/// let sub  = index(&a, seqs);
260/// println!("a(seq(1, 3, 1), span)");
261/// print(&sub);
262/// ```
263pub fn index<IO, T>(input: &Array<IO>, seqs: &[Seq<T>]) -> Array<IO>
264where
265    c_double: From<T>,
266    IO: HasAfEnum,
267    T: Copy + HasAfEnum + IndexableType,
268{
269    let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
270    unsafe {
271        let mut temp: af_array = std::ptr::null_mut();
272        let err_val = af_index(
273            &mut temp as *mut af_array,
274            input.get(),
275            seqs.len() as u32,
276            seqs.as_ptr() as *const SeqInternal,
277        );
278        HANDLE_ERROR(AfError::from(err_val));
279        temp.into()
280    }
281}
282
283/// Extract `row_num` row from `input` Array
284///
285/// # Examples
286///
287/// ```rust
288/// use arrayfire::{Dim4, randu, row, print};
289/// let dims = Dim4::new(&[5, 5, 1, 1]);
290/// let a = randu::<f32>(dims);
291/// println!("Grab last row of the random matrix");
292/// print(&a);
293/// print(&row(&a, 4));
294/// ```
295pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
296where
297    T: HasAfEnum,
298{
299    index(
300        input,
301        &[
302            Seq::new(row_num as f64, row_num as f64, 1.0),
303            Seq::default(),
304        ],
305    )
306}
307
308/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
309pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: i64)
310where
311    T: HasAfEnum,
312{
313    let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
314    if inout.dims().ndims() > 1 {
315        seqs.push(Seq::default());
316    }
317    assign_seq(inout, &seqs, new_row)
318}
319
320/// Get an Array with all rows from `first` to `last` in the `input` Array
321pub fn rows<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
322where
323    T: HasAfEnum,
324{
325    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
326    index(
327        input,
328        &[Seq::new(first as f64, last as f64, step), Seq::default()],
329    )
330}
331
332/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
333pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: i64, last: i64)
334where
335    T: HasAfEnum,
336{
337    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
338    let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
339    assign_seq(inout, &seqs, new_rows)
340}
341
342/// Extract `col_num` col from `input` Array
343///
344/// # Examples
345///
346/// ```rust
347/// use arrayfire::{Dim4, randu, col, print};
348/// let dims = Dim4::new(&[5, 5, 1, 1]);
349/// let a = randu::<f32>(dims);
350/// print(&a);
351/// println!("Grab last col of the random matrix");
352/// print(&col(&a, 4));
353/// ```
354pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
355where
356    T: HasAfEnum,
357{
358    index(
359        input,
360        &[
361            Seq::default(),
362            Seq::new(col_num as f64, col_num as f64, 1.0),
363        ],
364    )
365}
366
367/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
368pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
369where
370    T: HasAfEnum,
371{
372    let seqs = [
373        Seq::default(),
374        Seq::new(col_num as f64, col_num as f64, 1.0),
375    ];
376    assign_seq(inout, &seqs, new_col)
377}
378
379/// Get all cols from `first` to `last` in the `input` Array
380pub fn cols<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
381where
382    T: HasAfEnum,
383{
384    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
385    index(
386        input,
387        &[Seq::default(), Seq::new(first as f64, last as f64, step)],
388    )
389}
390
391/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
392pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: i64, last: i64)
393where
394    T: HasAfEnum,
395{
396    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
397    let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
398    assign_seq(inout, &seqs, new_cols)
399}
400
401/// Get `slice_num`^th slice from `input` Array
402///
403/// Slices indicate that the indexing is along 3rd dimension
404pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
405where
406    T: HasAfEnum,
407{
408    let seqs = [
409        Seq::default(),
410        Seq::default(),
411        Seq::new(slice_num as f64, slice_num as f64, 1.0),
412    ];
413    index(input, &seqs)
414}
415
416/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
417///
418/// Slices indicate that the indexing is along 3rd dimension
419pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
420where
421    T: HasAfEnum,
422{
423    let seqs = [
424        Seq::default(),
425        Seq::default(),
426        Seq::new(slice_num as f64, slice_num as f64, 1.0),
427    ];
428    assign_seq(inout, &seqs, new_slice)
429}
430
431/// Get slices from `first` to `last` in `input` Array
432///
433/// Slices indicate that the indexing is along 3rd dimension
434pub fn slices<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
435where
436    T: HasAfEnum,
437{
438    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
439    let seqs = [
440        Seq::default(),
441        Seq::default(),
442        Seq::new(first as f64, last as f64, step),
443    ];
444    index(input, &seqs)
445}
446
447/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
448///
449/// Slices indicate that the indexing is along 3rd dimension
450pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: i64, last: i64)
451where
452    T: HasAfEnum,
453{
454    let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
455    let seqs = [
456        Seq::default(),
457        Seq::default(),
458        Seq::new(first as f64, last as f64, step),
459    ];
460    assign_seq(inout, &seqs, new_slices)
461}
462
463/// Lookup(hash) an Array using another Array
464///
465/// Given a dimension `seq_dim`, `indices` are lookedup in `input` and returned as a new
466/// Array if found
467pub fn lookup<T, I>(input: &Array<T>, indices: &Array<I>, seq_dim: i32) -> Array<T>
468where
469    T: HasAfEnum,
470    I: HasAfEnum + IndexableType,
471{
472    unsafe {
473        let mut temp: af_array = std::ptr::null_mut();
474        let err_val = af_lookup(
475            &mut temp as *mut af_array,
476            input.get() as af_array,
477            indices.get() as af_array,
478            seq_dim as c_uint,
479        );
480        HANDLE_ERROR(AfError::from(err_val));
481        temp.into()
482    }
483}
484
485/// Assign(copy) content of an Array to another Array indexed by Sequences
486///
487/// Assign `rhs` to `lhs` after indexing `lhs`
488///
489/// # Examples
490///
491/// ```rust
492/// use arrayfire::{constant, Dim4, Seq, assign_seq, print};
493/// let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
494/// print(&a);
495/// // 2.0 2.0 2.0
496/// // 2.0 2.0 2.0
497/// // 2.0 2.0 2.0
498/// // 2.0 2.0 2.0
499/// // 2.0 2.0 2.0
500///
501/// let b    = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
502/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
503/// assign_seq(&mut a, seqs, &b);
504///
505/// print(&a);
506/// // 2.0 2.0 2.0
507/// // 1.0 1.0 1.0
508/// // 1.0 1.0 1.0
509/// // 1.0 1.0 1.0
510/// // 2.0 2.0 2.0
511/// ```
512pub fn assign_seq<T, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
513where
514    c_double: From<T>,
515    I: HasAfEnum,
516    T: Copy + IndexableType,
517{
518    let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
519    unsafe {
520        let mut temp: af_array = std::ptr::null_mut();
521        let err_val = af_assign_seq(
522            &mut temp as *mut af_array,
523            lhs.get() as af_array,
524            seqs.len() as c_uint,
525            seqs.as_ptr() as *const SeqInternal,
526            rhs.get() as af_array,
527        );
528        HANDLE_ERROR(AfError::from(err_val));
529
530        let modified = temp.into();
531        let _old_arr = mem::replace(lhs, modified);
532    }
533}
534
535/// Index an Array using any combination of Array's and Sequence's
536///
537/// # Examples
538///
539/// ```rust
540/// use arrayfire::{Array, Dim4, Seq, print, randu, index_gen, Indexer};
541/// let values: [f32; 3] = [1.0, 2.0, 3.0];
542/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
543/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
544/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
545/// // [5 3 1 1]
546/// //     0.0000     0.2190     0.3835
547/// //     0.1315     0.0470     0.5194
548/// //     0.7556     0.6789     0.8310
549/// //     0.4587     0.6793     0.0346
550/// //     0.5328     0.9347     0.0535
551///
552///
553/// let mut idxrs = Indexer::default();
554/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
555/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
556///
557/// let sub2 = index_gen(&a, idxrs);
558/// println!("a(indices, seq(0, 2, 1))"); print(&sub2);
559/// // [3 3 1 1]
560/// //     0.1315     0.0470     0.5194
561/// //     0.7556     0.6789     0.8310
562/// //     0.4587     0.6793     0.0346
563/// ```
564pub fn index_gen<T>(input: &Array<T>, indices: Indexer) -> Array<T>
565where
566    T: HasAfEnum,
567{
568    unsafe {
569        let mut temp: af_array = std::ptr::null_mut();
570        let err_val = af_index_gen(
571            &mut temp as *mut af_array,
572            input.get() as af_array,
573            indices.len() as dim_t,
574            indices.get() as af_index_t,
575        );
576        HANDLE_ERROR(AfError::from(err_val));
577        temp.into()
578    }
579}
580
581/// Assign an Array to another after indexing it using any combination of Array's and Sequence's
582///
583/// # Examples
584///
585/// ```rust
586/// use arrayfire::{Array, Dim4, Seq, print, randu, constant, Indexer, assign_gen};
587/// let values: [f32; 3] = [1.0, 2.0, 3.0];
588/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
589/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
590/// let mut a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
591/// // [5 3 1 1]
592/// //     0.0000     0.2190     0.3835
593/// //     0.1315     0.0470     0.5194
594/// //     0.7556     0.6789     0.8310
595/// //     0.4587     0.6793     0.0346
596/// //     0.5328     0.9347     0.0535
597///
598/// let b    = constant(2.0 as f32, Dim4::new(&[3, 3, 1, 1]));
599///
600/// let mut idxrs = Indexer::default();
601/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
602/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
603///
604/// assign_gen(&mut a, &idxrs, &b);
605/// println!("a(indices, seq(0, 2, 1))"); print(&a);
606/// // [5 3 1 1]
607/// //     0.0000     0.2190     0.3835
608/// //     2.0000     2.0000     2.0000
609/// //     2.0000     2.0000     2.0000
610/// //     2.0000     2.0000     2.0000
611/// //     0.5328     0.9347     0.0535
612/// ```
613pub fn assign_gen<T>(lhs: &mut Array<T>, indices: &Indexer, rhs: &Array<T>)
614where
615    T: HasAfEnum,
616{
617    unsafe {
618        let mut temp: af_array = std::ptr::null_mut();
619        let err_val = af_assign_gen(
620            &mut temp as *mut af_array,
621            lhs.get() as af_array,
622            indices.len() as dim_t,
623            indices.get() as af_index_t,
624            rhs.get() as af_array,
625        );
626        HANDLE_ERROR(AfError::from(err_val));
627
628        let modified = temp.into();
629        let _old_arr = mem::replace(lhs, modified);
630    }
631}
632
// C-layout triple of doubles describing a sequence; pointers to it are what the
// sequence-based FFI declarations in this module accept.
#[repr(C)]
struct SeqInternal {
    // First value of the sequence.
    begin: c_double,
    // Last value of the sequence.
    end: c_double,
    // Increment between consecutive values.
    step: c_double,
}
639
640impl SeqInternal {
641    fn from_seq<T>(s: &Seq<T>) -> Self
642    where
643        c_double: From<T>,
644        T: Copy + IndexableType,
645    {
646        Self {
647            begin: From::from(s.begin()),
648            end: From::from(s.end()),
649            step: From::from(s.step()),
650        }
651    }
652}
653
// Smoke tests for the indexing helpers. The `ANCHOR` / `ANCHOR_END` comments delimit named
// snippets; keep them intact when editing.
#[cfg(test)]
mod tests {
    use super::super::array::Array;
    use super::super::data::constant;
    use super::super::device::set_device;
    use super::super::dim4::Dim4;
    use super::super::index::{assign_gen, assign_seq, col, index, index_gen, row, Indexer};
    use super::super::index::{cols, rows};
    use super::super::random::randu;
    use super::super::seq::Seq;

    use crate::{dim4, seq, view};

    #[test]
    fn non_macro_seq_index() {
        set_device(0);
        // ANCHOR: non_macro_seq_index
        let dims = Dim4::new(&[5, 5, 1, 1]);
        let a = randu::<f32>(dims);
        //af_print!("a", a);
        //a
        //[5 5 1 1]
        //    0.3990     0.5160     0.8831     0.9107     0.6688
        //    0.6720     0.3932     0.0621     0.9159     0.8434
        //    0.5339     0.2706     0.7089     0.0231     0.1328
        //    0.1386     0.9455     0.9434     0.2330     0.2657
        //    0.7353     0.1587     0.1227     0.2220     0.2299

        // Index array using sequences
        let seqs = &[Seq::new(1u32, 3, 1), Seq::default()];
        let _sub = index(&a, seqs);
        //af_print!("a(seq(1,3,1), span)", sub);
        // [3 5 1 1]
        //     0.6720     0.3932     0.0621     0.9159     0.8434
        //     0.5339     0.2706     0.7089     0.0231     0.1328
        //     0.1386     0.9455     0.9434     0.2330     0.2657
        // ANCHOR_END: non_macro_seq_index
    }

    #[test]
    fn seq_index() {
        set_device(0);
        // ANCHOR: seq_index
        let dims = dim4!(5, 5, 1, 1);
        let a = randu::<f32>(dims);
        let first3 = seq!(1:3:1);
        let allindim2 = seq!();
        let _sub = view!(a[first3, allindim2]);
        // ANCHOR_END: seq_index
    }

    #[test]
    fn non_macro_seq_assign() {
        set_device(0);
        // ANCHOR: non_macro_seq_assign
        let mut a = constant(2.0 as f32, dim4!(5, 3));
        //print(&a);
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0

        let b = constant(1.0 as f32, dim4!(3, 3));
        let seqs = [seq!(1:3:1), seq!()];
        assign_seq(&mut a, &seqs, &b);
        //print(&a);
        // 2.0 2.0 2.0
        // 1.0 1.0 1.0
        // 1.0 1.0 1.0
        // 1.0 1.0 1.0
        // 2.0 2.0 2.0
        // ANCHOR_END: non_macro_seq_assign
    }

    #[test]
    fn non_macro_seq_array_index() {
        set_device(0);
        // ANCHOR: non_macro_seq_array_index
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = Seq::new(0.0, 2.0, 1.0);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        //     0.5328     0.9347     0.0535

        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
        idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

        let _sub2 = index_gen(&a, idxrs);
        //println!("a(indices, seq(0, 2, 1))"); print(&sub2);
        // [3 3 1 1]
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        // ANCHOR_END: non_macro_seq_array_index
    }

    #[test]
    fn seq_array_index() {
        set_device(0);
        // ANCHOR: seq_array_index
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = seq!(0:2:1);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        let _sub2 = view!(a[indices, seq4gen]);
        // ANCHOR_END: seq_array_index
    }

    #[test]
    fn non_macro_seq_array_assign() {
        set_device(0);
        // ANCHOR: non_macro_seq_array_assign
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, dim4!(3, 1, 1, 1));
        let seq4gen = seq!(0:2:1);
        let mut a = randu::<f32>(dim4!(5, 3, 1, 1));
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        //     0.5328     0.9347     0.0535

        let b = constant(2.0 as f32, dim4!(3, 3, 1, 1));

        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
        idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

        // `assign_gen` returns nothing; it updates `a` in place.
        assign_gen(&mut a, &idxrs, &b);
        //println!("a(indices, seq(0, 2, 1))"); print(&a);
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     2.0000     2.0000     2.0000
        //     2.0000     2.0000     2.0000
        //     2.0000     2.0000     2.0000
        //     0.5328     0.9347     0.0535
        // ANCHOR_END: non_macro_seq_array_assign
    }

    #[test]
    fn setrow() {
        set_device(0);
        // ANCHOR: setrow
        let a = randu::<f32>(dim4!(5, 5, 1, 1));
        //print(&a);
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _r = row(&a, 4);
        // [1 5 1 1]
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _c = col(&a, 4);
        // [5 1 1 1]
        //     0.6755
        //     0.6105
        //     0.5232
        //     0.5567
        //     0.7896
        // ANCHOR_END: setrow
    }

    #[test]
    fn get_row() {
        set_device(0);
        // ANCHOR: get_row
        let a = randu::<f32>(dim4!(5, 5));
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _r = row(&a, -1);
        // [1 5 1 1]
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _c = col(&a, -1);
        // [5 1 1 1]
        //     0.6755
        //     0.6105
        //     0.5232
        //     0.5567
        //     0.7896
        // ANCHOR_END: get_row
    }

    #[test]
    fn get_rows() {
        set_device(0);
        // ANCHOR: get_rows
        let a = randu::<f32>(dim4!(5, 5));
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _r = rows(&a, -1, -2);
        // [2 5 1 1]
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _c = cols(&a, -1, -3);
        // [5 3 1 1]
        //     0.1583     0.3636     0.6755
        //     0.3712     0.4165     0.6105
        //     0.3543     0.5814     0.5232
        //     0.6450     0.8962     0.5567
        //     0.9675     0.3712     0.7896
        // ANCHOR_END: get_rows
    }
}
869        //     0.3543     0.5814     0.5232
870        //     0.6450     0.8962     0.5567
871        //     0.9675     0.3712     0.7896
872        // ANCHOR_END: get_rows
873    }
874}