arrayfire/core/index.rs
1use super::array::Array;
2use super::defines::AfError;
3use super::error::HANDLE_ERROR;
4use super::seq::Seq;
5use super::util::{af_array, af_index_t, dim_t, HasAfEnum, IndexableType};
6
7use libc::{c_double, c_int, c_uint};
8use std::default::Default;
9use std::marker::PhantomData;
10use std::mem;
11
12extern "C" {
13 fn af_create_indexers(indexers: *mut af_index_t) -> c_int;
14 fn af_set_array_indexer(indexer: af_index_t, idx: af_array, dim: dim_t) -> c_int;
15 fn af_set_seq_indexer(
16 indexer: af_index_t,
17 idx: *const SeqInternal,
18 dim: dim_t,
19 is_batch: bool,
20 ) -> c_int;
21 fn af_release_indexers(indexers: af_index_t) -> c_int;
22
23 fn af_index(
24 out: *mut af_array,
25 input: af_array,
26 ndims: c_uint,
27 index: *const SeqInternal,
28 ) -> c_int;
29 fn af_lookup(out: *mut af_array, arr: af_array, indices: af_array, dim: c_uint) -> c_int;
30 fn af_assign_seq(
31 out: *mut af_array,
32 lhs: af_array,
33 ndims: c_uint,
34 indices: *const SeqInternal,
35 rhs: af_array,
36 ) -> c_int;
37 fn af_index_gen(
38 out: *mut af_array,
39 input: af_array,
40 ndims: dim_t,
41 indices: af_index_t,
42 ) -> c_int;
43 fn af_assign_gen(
44 out: *mut af_array,
45 lhs: af_array,
46 ndims: dim_t,
47 indices: af_index_t,
48 rhs: af_array,
49 ) -> c_int;
50}
51
52/// Struct to manage an array of resources of type `af_indexer_t`(ArrayFire C struct)
53///
54/// ## Sharing Across Threads
55///
56/// While sharing an Indexer object with other threads, just move it across threads. At the
57/// moment, one cannot share borrowed references across threads.
58///
59/// # Examples
60///
61/// Given below are examples illustrating correct and incorrect usage of Indexer struct.
62///
63/// <h3> Correct Usage </h3>
64///
65/// ```rust
66/// use arrayfire::{Array, Dim4, randu, index_gen, Indexer};
67///
68/// // Always be aware of the fact that, the `Seq` or `Array` objects
69/// // that we intend to use for indexing via `Indexer` have to outlive
70/// // the `Indexer` object created in this context.
71///
72/// let dims = Dim4::new(&[1, 3, 1, 1]);
73/// let indices = [1u8, 0, 1];
74/// let idx = Array::new(&indices, dims);
75/// let values = [2.0f32, 5.0, 6.0];
76/// let arr = Array::new(&values, dims);
77///
78/// let mut idxr = Indexer::default();
79///
80/// // `idx` is created much before idxr, thus will
81/// // stay in scope at least as long as idxr
82/// idxr.set_index(&idx, 0, None);
83///
84/// index_gen(&arr, idxr);
85/// ```
86///
87/// <h3> Incorrect Usage </h3>
88///
89/// ```rust,ignore
90/// // Say, you create an Array on the fly and try
91/// // to call set_index, it will throw the given below
92/// // error or something similar to that
93/// idxr.set_index(&Array::new(&[1, 0, 1], dims), 0, None);
94/// ```
95///
96/// ```text
97/// error: borrowed value does not live long enough
98/// --> <anon>:16:55
99/// |
100///16 | idxr.set_index(&Array::new(&[1, 0, 1], dims), 0, None);
101/// | ---------------------------- ^ temporary value dropped here while still borrowed
102/// | |
103/// | temporary value created here
104///...
105///19 | }
106/// | - temporary value needs to live until here
107/// |
108/// = note: consider using a `let` binding to increase its lifetime
109/// ```
110pub struct Indexer<'object> {
111 handle: af_index_t,
112 count: usize,
113 marker: PhantomData<&'object ()>,
114}
115
116unsafe impl<'object> Send for Indexer<'object> {}
117
118/// Trait bound indicating indexability
119///
120/// Any object to be able to be passed on to [Indexer::set_index()](./struct.Indexer.html#method.set_index) method should implement this trait with appropriate implementation of `set` method.
121pub trait Indexable {
122 /// Set indexing object for a given dimension
123 ///
124 /// `is_batch` parameter is not used in most cases as it has been provided in
125 /// ArrayFire C-API to enable GFOR construct in ArrayFire C++ API. This type
126 /// of construct/idea is not exposed in rust wrapper yet. So, the user would
127 /// just need to pass `None` to this parameter while calling this function.
128 /// Since we can't have default default values and we wanted to keep this
129 /// parameter for future use cases, we just made it an `std::Option`.
130 ///
131 /// # Parameters
132 ///
133 /// - `idxr` is mutable reference to [Indexer](./struct.Indexer.html) object which will
134 /// be modified to set `self` indexable along `dim` dimension.
135 /// - `dim` is the dimension along which `self` indexable will be used for indexing.
136 /// - `is_batch` is only used if `self` is [Seq](./struct.Seq.html) to indicate if indexing
137 /// along `dim` is a batched operation.
138 fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>);
139}
140
141/// Enables [Array](./struct.Array.html) to be used to index another Array
142///
143/// This is used in functions [index_gen](./fn.index_gen.html) and
144/// [assign_gen](./fn.assign_gen.html)
145impl<T> Indexable for Array<T>
146where
147 T: HasAfEnum + IndexableType,
148{
149 fn set(&self, idxr: &mut Indexer, dim: u32, _is_batch: Option<bool>) {
150 unsafe {
151 let err_val = af_set_array_indexer(idxr.get(), self.get(), dim as dim_t);
152 HANDLE_ERROR(AfError::from(err_val));
153 }
154 }
155}
156
157/// Enables [Seq](./struct.Seq.html) to be used to index another Array
158///
159/// This is used in functions [index_gen](./fn.index_gen.html) and
160/// [assign_gen](./fn.assign_gen.html)
161impl<T> Indexable for Seq<T>
162where
163 c_double: From<T>,
164 T: Copy + IndexableType,
165{
166 fn set(&self, idxr: &mut Indexer, dim: u32, is_batch: Option<bool>) {
167 unsafe {
168 let err_val = af_set_seq_indexer(
169 idxr.get(),
170 &SeqInternal::from_seq(self) as *const SeqInternal,
171 dim as dim_t,
172 match is_batch {
173 Some(value) => value,
174 None => false,
175 },
176 );
177 HANDLE_ERROR(AfError::from(err_val));
178 }
179 }
180}
181
182impl<'object> Default for Indexer<'object> {
183 fn default() -> Self {
184 unsafe {
185 let mut temp: af_index_t = std::ptr::null_mut();
186 let err_val = af_create_indexers(&mut temp as *mut af_index_t);
187 HANDLE_ERROR(AfError::from(err_val));
188 Self {
189 handle: temp,
190 count: 0,
191 marker: PhantomData,
192 }
193 }
194 }
195}
196
197impl<'object> Indexer<'object> {
198 /// Create a new Indexer object and set the dimension specific index objects later
199 #[deprecated(since = "3.7.0", note = "Use Indexer::default() instead")]
200 pub fn new() -> Self {
201 unsafe {
202 let mut temp: af_index_t = std::ptr::null_mut();
203 let err_val = af_create_indexers(&mut temp as *mut af_index_t);
204 HANDLE_ERROR(AfError::from(err_val));
205 Self {
206 handle: temp,
207 count: 0,
208 marker: PhantomData,
209 }
210 }
211 }
212
213 /// Set either [Array](./struct.Array.html) or [Seq](./struct.Seq.html) to index an Array along `idx` dimension
214 pub fn set_index<'s, T>(&'s mut self, idx: &'object T, dim: u32, is_batch: Option<bool>)
215 where
216 T: Indexable + 'object,
217 {
218 idx.set(self, dim, is_batch);
219 self.count += 1;
220 }
221
222 /// Get number of indexing objects set
223 pub fn len(&self) -> usize {
224 self.count
225 }
226
227 /// Check if any indexing objects are set
228 pub fn is_empty(&self) -> bool {
229 self.count == 0
230 }
231
232 /// Get native(ArrayFire) resource handle
233 pub unsafe fn get(&self) -> af_index_t {
234 self.handle
235 }
236}
237
238impl<'object> Drop for Indexer<'object> {
239 fn drop(&mut self) {
240 unsafe {
241 let ret_val = af_release_indexers(self.handle as af_index_t);
242 match ret_val {
243 0 => (),
244 _ => panic!("Failed to release indexers resource: {}", ret_val),
245 }
246 }
247 }
248}
249
250/// Indexes the `input` Array using `seqs` Sequences
251///
252/// # Examples
253///
254/// ```rust
255/// use arrayfire::{Dim4, Seq, index, randu, print};
256/// let dims = Dim4::new(&[5, 5, 1, 1]);
257/// let a = randu::<f32>(dims);
258/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
259/// let sub = index(&a, seqs);
260/// println!("a(seq(1, 3, 1), span)");
261/// print(&sub);
262/// ```
263pub fn index<IO, T>(input: &Array<IO>, seqs: &[Seq<T>]) -> Array<IO>
264where
265 c_double: From<T>,
266 IO: HasAfEnum,
267 T: Copy + HasAfEnum + IndexableType,
268{
269 let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
270 unsafe {
271 let mut temp: af_array = std::ptr::null_mut();
272 let err_val = af_index(
273 &mut temp as *mut af_array,
274 input.get(),
275 seqs.len() as u32,
276 seqs.as_ptr() as *const SeqInternal,
277 );
278 HANDLE_ERROR(AfError::from(err_val));
279 temp.into()
280 }
281}
282
283/// Extract `row_num` row from `input` Array
284///
285/// # Examples
286///
287/// ```rust
288/// use arrayfire::{Dim4, randu, row, print};
289/// let dims = Dim4::new(&[5, 5, 1, 1]);
290/// let a = randu::<f32>(dims);
291/// println!("Grab last row of the random matrix");
292/// print(&a);
293/// print(&row(&a, 4));
294/// ```
295pub fn row<T>(input: &Array<T>, row_num: i64) -> Array<T>
296where
297 T: HasAfEnum,
298{
299 index(
300 input,
301 &[
302 Seq::new(row_num as f64, row_num as f64, 1.0),
303 Seq::default(),
304 ],
305 )
306}
307
308/// Set `row_num`^th row in `inout` Array to a new Array `new_row`
309pub fn set_row<T>(inout: &mut Array<T>, new_row: &Array<T>, row_num: i64)
310where
311 T: HasAfEnum,
312{
313 let mut seqs = vec![Seq::new(row_num as f64, row_num as f64, 1.0)];
314 if inout.dims().ndims() > 1 {
315 seqs.push(Seq::default());
316 }
317 assign_seq(inout, &seqs, new_row)
318}
319
320/// Get an Array with all rows from `first` to `last` in the `input` Array
321pub fn rows<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
322where
323 T: HasAfEnum,
324{
325 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
326 index(
327 input,
328 &[Seq::new(first as f64, last as f64, step), Seq::default()],
329 )
330}
331
332/// Set rows from `first` to `last` in `inout` Array with rows from Array `new_rows`
333pub fn set_rows<T>(inout: &mut Array<T>, new_rows: &Array<T>, first: i64, last: i64)
334where
335 T: HasAfEnum,
336{
337 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
338 let seqs = [Seq::new(first as f64, last as f64, step), Seq::default()];
339 assign_seq(inout, &seqs, new_rows)
340}
341
342/// Extract `col_num` col from `input` Array
343///
344/// # Examples
345///
346/// ```rust
347/// use arrayfire::{Dim4, randu, col, print};
348/// let dims = Dim4::new(&[5, 5, 1, 1]);
349/// let a = randu::<f32>(dims);
350/// print(&a);
351/// println!("Grab last col of the random matrix");
352/// print(&col(&a, 4));
353/// ```
354pub fn col<T>(input: &Array<T>, col_num: i64) -> Array<T>
355where
356 T: HasAfEnum,
357{
358 index(
359 input,
360 &[
361 Seq::default(),
362 Seq::new(col_num as f64, col_num as f64, 1.0),
363 ],
364 )
365}
366
367/// Set `col_num`^th col in `inout` Array to a new Array `new_col`
368pub fn set_col<T>(inout: &mut Array<T>, new_col: &Array<T>, col_num: i64)
369where
370 T: HasAfEnum,
371{
372 let seqs = [
373 Seq::default(),
374 Seq::new(col_num as f64, col_num as f64, 1.0),
375 ];
376 assign_seq(inout, &seqs, new_col)
377}
378
379/// Get all cols from `first` to `last` in the `input` Array
380pub fn cols<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
381where
382 T: HasAfEnum,
383{
384 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
385 index(
386 input,
387 &[Seq::default(), Seq::new(first as f64, last as f64, step)],
388 )
389}
390
391/// Set cols from `first` to `last` in `inout` Array with cols from Array `new_cols`
392pub fn set_cols<T>(inout: &mut Array<T>, new_cols: &Array<T>, first: i64, last: i64)
393where
394 T: HasAfEnum,
395{
396 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
397 let seqs = [Seq::default(), Seq::new(first as f64, last as f64, step)];
398 assign_seq(inout, &seqs, new_cols)
399}
400
401/// Get `slice_num`^th slice from `input` Array
402///
403/// Slices indicate that the indexing is along 3rd dimension
404pub fn slice<T>(input: &Array<T>, slice_num: i64) -> Array<T>
405where
406 T: HasAfEnum,
407{
408 let seqs = [
409 Seq::default(),
410 Seq::default(),
411 Seq::new(slice_num as f64, slice_num as f64, 1.0),
412 ];
413 index(input, &seqs)
414}
415
416/// Set slice `slice_num` in `inout` Array to a new Array `new_slice`
417///
418/// Slices indicate that the indexing is along 3rd dimension
419pub fn set_slice<T>(inout: &mut Array<T>, new_slice: &Array<T>, slice_num: i64)
420where
421 T: HasAfEnum,
422{
423 let seqs = [
424 Seq::default(),
425 Seq::default(),
426 Seq::new(slice_num as f64, slice_num as f64, 1.0),
427 ];
428 assign_seq(inout, &seqs, new_slice)
429}
430
431/// Get slices from `first` to `last` in `input` Array
432///
433/// Slices indicate that the indexing is along 3rd dimension
434pub fn slices<T>(input: &Array<T>, first: i64, last: i64) -> Array<T>
435where
436 T: HasAfEnum,
437{
438 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
439 let seqs = [
440 Seq::default(),
441 Seq::default(),
442 Seq::new(first as f64, last as f64, step),
443 ];
444 index(input, &seqs)
445}
446
447/// Set `first` to `last` slices of `inout` Array to a new Array `new_slices`
448///
449/// Slices indicate that the indexing is along 3rd dimension
450pub fn set_slices<T>(inout: &mut Array<T>, new_slices: &Array<T>, first: i64, last: i64)
451where
452 T: HasAfEnum,
453{
454 let step: f64 = if first > last && last < 0 { -1.0 } else { 1.0 };
455 let seqs = [
456 Seq::default(),
457 Seq::default(),
458 Seq::new(first as f64, last as f64, step),
459 ];
460 assign_seq(inout, &seqs, new_slices)
461}
462
463/// Lookup(hash) an Array using another Array
464///
465/// Given a dimension `seq_dim`, `indices` are lookedup in `input` and returned as a new
466/// Array if found
467pub fn lookup<T, I>(input: &Array<T>, indices: &Array<I>, seq_dim: i32) -> Array<T>
468where
469 T: HasAfEnum,
470 I: HasAfEnum + IndexableType,
471{
472 unsafe {
473 let mut temp: af_array = std::ptr::null_mut();
474 let err_val = af_lookup(
475 &mut temp as *mut af_array,
476 input.get() as af_array,
477 indices.get() as af_array,
478 seq_dim as c_uint,
479 );
480 HANDLE_ERROR(AfError::from(err_val));
481 temp.into()
482 }
483}
484
485/// Assign(copy) content of an Array to another Array indexed by Sequences
486///
487/// Assign `rhs` to `lhs` after indexing `lhs`
488///
489/// # Examples
490///
491/// ```rust
492/// use arrayfire::{constant, Dim4, Seq, assign_seq, print};
493/// let mut a = constant(2.0 as f32, Dim4::new(&[5, 3, 1, 1]));
494/// print(&a);
495/// // 2.0 2.0 2.0
496/// // 2.0 2.0 2.0
497/// // 2.0 2.0 2.0
498/// // 2.0 2.0 2.0
499/// // 2.0 2.0 2.0
500///
501/// let b = constant(1.0 as f32, Dim4::new(&[3, 3, 1, 1]));
502/// let seqs = &[Seq::new(1.0, 3.0, 1.0), Seq::default()];
503/// assign_seq(&mut a, seqs, &b);
504///
505/// print(&a);
506/// // 2.0 2.0 2.0
507/// // 1.0 1.0 1.0
508/// // 1.0 1.0 1.0
509/// // 1.0 1.0 1.0
510/// // 2.0 2.0 2.0
511/// ```
512pub fn assign_seq<T, I>(lhs: &mut Array<I>, seqs: &[Seq<T>], rhs: &Array<I>)
513where
514 c_double: From<T>,
515 I: HasAfEnum,
516 T: Copy + IndexableType,
517{
518 let seqs: Vec<SeqInternal> = seqs.iter().map(|s| SeqInternal::from_seq(s)).collect();
519 unsafe {
520 let mut temp: af_array = std::ptr::null_mut();
521 let err_val = af_assign_seq(
522 &mut temp as *mut af_array,
523 lhs.get() as af_array,
524 seqs.len() as c_uint,
525 seqs.as_ptr() as *const SeqInternal,
526 rhs.get() as af_array,
527 );
528 HANDLE_ERROR(AfError::from(err_val));
529
530 let modified = temp.into();
531 let _old_arr = mem::replace(lhs, modified);
532 }
533}
534
535/// Index an Array using any combination of Array's and Sequence's
536///
537/// # Examples
538///
539/// ```rust
540/// use arrayfire::{Array, Dim4, Seq, print, randu, index_gen, Indexer};
541/// let values: [f32; 3] = [1.0, 2.0, 3.0];
542/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
543/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
544/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
545/// // [5 3 1 1]
546/// // 0.0000 0.2190 0.3835
547/// // 0.1315 0.0470 0.5194
548/// // 0.7556 0.6789 0.8310
549/// // 0.4587 0.6793 0.0346
550/// // 0.5328 0.9347 0.0535
551///
552///
553/// let mut idxrs = Indexer::default();
554/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
555/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
556///
557/// let sub2 = index_gen(&a, idxrs);
558/// println!("a(indices, seq(0, 2, 1))"); print(&sub2);
559/// // [3 3 1 1]
560/// // 0.1315 0.0470 0.5194
561/// // 0.7556 0.6789 0.8310
562/// // 0.4587 0.6793 0.0346
563/// ```
564pub fn index_gen<T>(input: &Array<T>, indices: Indexer) -> Array<T>
565where
566 T: HasAfEnum,
567{
568 unsafe {
569 let mut temp: af_array = std::ptr::null_mut();
570 let err_val = af_index_gen(
571 &mut temp as *mut af_array,
572 input.get() as af_array,
573 indices.len() as dim_t,
574 indices.get() as af_index_t,
575 );
576 HANDLE_ERROR(AfError::from(err_val));
577 temp.into()
578 }
579}
580
581/// Assign an Array to another after indexing it using any combination of Array's and Sequence's
582///
583/// # Examples
584///
585/// ```rust
586/// use arrayfire::{Array, Dim4, Seq, print, randu, constant, Indexer, assign_gen};
587/// let values: [f32; 3] = [1.0, 2.0, 3.0];
588/// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
589/// let seq4gen = Seq::new(0.0, 2.0, 1.0);
590/// let mut a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
591/// // [5 3 1 1]
592/// // 0.0000 0.2190 0.3835
593/// // 0.1315 0.0470 0.5194
594/// // 0.7556 0.6789 0.8310
595/// // 0.4587 0.6793 0.0346
596/// // 0.5328 0.9347 0.0535
597///
598/// let b = constant(2.0 as f32, Dim4::new(&[3, 3, 1, 1]));
599///
600/// let mut idxrs = Indexer::default();
601/// idxrs.set_index(&indices, 0, None); // 2nd parameter is indexing dimension
602/// idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd parameter indicates batch operation
603///
604/// assign_gen(&mut a, &idxrs, &b);
605/// println!("a(indices, seq(0, 2, 1))"); print(&a);
606/// // [5 3 1 1]
607/// // 0.0000 0.2190 0.3835
608/// // 2.0000 2.0000 2.0000
609/// // 2.0000 2.0000 2.0000
610/// // 2.0000 2.0000 2.0000
611/// // 0.5328 0.9347 0.0535
612/// ```
613pub fn assign_gen<T>(lhs: &mut Array<T>, indices: &Indexer, rhs: &Array<T>)
614where
615 T: HasAfEnum,
616{
617 unsafe {
618 let mut temp: af_array = std::ptr::null_mut();
619 let err_val = af_assign_gen(
620 &mut temp as *mut af_array,
621 lhs.get() as af_array,
622 indices.len() as dim_t,
623 indices.get() as af_index_t,
624 rhs.get() as af_array,
625 );
626 HANDLE_ERROR(AfError::from(err_val));
627
628 let modified = temp.into();
629 let _old_arr = mem::replace(lhs, modified);
630 }
631}
632
633#[repr(C)]
634struct SeqInternal {
635 begin: c_double,
636 end: c_double,
637 step: c_double,
638}
639
640impl SeqInternal {
641 fn from_seq<T>(s: &Seq<T>) -> Self
642 where
643 c_double: From<T>,
644 T: Copy + IndexableType,
645 {
646 Self {
647 begin: From::from(s.begin()),
648 end: From::from(s.end()),
649 step: From::from(s.step()),
650 }
651 }
652}
653
#[cfg(test)]
mod tests {
    use super::super::array::Array;
    use super::super::data::constant;
    use super::super::device::set_device;
    use super::super::dim4::Dim4;
    use super::super::index::{
        assign_gen, assign_seq, col, cols, index, index_gen, row, rows, Indexer,
    };
    use super::super::random::randu;
    use super::super::seq::Seq;

    use crate::{dim4, seq, view};

    #[test]
    fn non_macro_seq_index() {
        set_device(0);
        // ANCHOR: non_macro_seq_index
        let dims = Dim4::new(&[5, 5, 1, 1]);
        let a = randu::<f32>(dims);
        //af_print!("a", a);
        //a
        //[5 5 1 1]
        //    0.3990     0.5160     0.8831     0.9107     0.6688
        //    0.6720     0.3932     0.0621     0.9159     0.8434
        //    0.5339     0.2706     0.7089     0.0231     0.1328
        //    0.1386     0.9455     0.9434     0.2330     0.2657
        //    0.7353     0.1587     0.1227     0.2220     0.2299

        // Index array using sequences
        let seqs = &[Seq::new(1u32, 3, 1), Seq::default()];
        let _sub = index(&a, seqs);
        //af_print!("a(seq(1,3,1), span)", sub);
        // [3 5 1 1]
        //     0.6720     0.3932     0.0621     0.9159     0.8434
        //     0.5339     0.2706     0.7089     0.0231     0.1328
        //     0.1386     0.9455     0.9434     0.2330     0.2657
        // ANCHOR_END: non_macro_seq_index
    }

    #[test]
    fn seq_index() {
        set_device(0);
        // ANCHOR: seq_index
        let dims = dim4!(5, 5, 1, 1);
        let a = randu::<f32>(dims);
        let first3 = seq!(1:3:1);
        let allindim2 = seq!();
        let _sub = view!(a[first3, allindim2]);
        // ANCHOR_END: seq_index
    }

    #[test]
    fn non_macro_seq_assign() {
        set_device(0);
        // ANCHOR: non_macro_seq_assign
        let mut a = constant(2.0 as f32, dim4!(5, 3));
        //print(&a);
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0
        // 2.0 2.0 2.0

        let b = constant(1.0 as f32, dim4!(3, 3));
        let seqs = [seq!(1:3:1), seq!()];
        assign_seq(&mut a, &seqs, &b);
        //print(&a);
        // 2.0 2.0 2.0
        // 1.0 1.0 1.0
        // 1.0 1.0 1.0
        // 1.0 1.0 1.0
        // 2.0 2.0 2.0
        // ANCHOR_END: non_macro_seq_assign
    }

    #[test]
    fn non_macro_seq_array_index() {
        set_device(0);
        // ANCHOR: non_macro_seq_array_index
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = Seq::new(0.0, 2.0, 1.0);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        //     0.5328     0.9347     0.0535

        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
        idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

        let _sub2 = index_gen(&a, idxrs);
        //println!("a(indices, seq(0, 2, 1))"); print(&sub2);
        // [3 3 1 1]
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        // ANCHOR_END: non_macro_seq_array_index
    }

    #[test]
    fn seq_array_index() {
        set_device(0);
        // ANCHOR: seq_array_index
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
        let seq4gen = seq!(0:2:1);
        let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
        let _sub2 = view!(a[indices, seq4gen]);
        // ANCHOR_END: seq_array_index
    }

    #[test]
    fn non_macro_seq_array_assign() {
        set_device(0);
        // ANCHOR: non_macro_seq_array_assign
        let values: [f32; 3] = [1.0, 2.0, 3.0];
        let indices = Array::new(&values, dim4!(3, 1, 1, 1));
        let seq4gen = seq!(0:2:1);
        let mut a = randu::<f32>(dim4!(5, 3, 1, 1));
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     0.1315     0.0470     0.5194
        //     0.7556     0.6789     0.8310
        //     0.4587     0.6793     0.0346
        //     0.5328     0.9347     0.0535

        let b = constant(2.0 as f32, dim4!(3, 3, 1, 1));

        let mut idxrs = Indexer::default();
        idxrs.set_index(&indices, 0, None); // 2nd arg is indexing dimension
        idxrs.set_index(&seq4gen, 1, Some(false)); // 3rd arg indicates batch operation

        // `assign_gen` updates `a` in place and returns nothing.
        assign_gen(&mut a, &idxrs, &b);
        //println!("a(indices, seq(0, 2, 1))"); print(&a);
        // [5 3 1 1]
        //     0.0000     0.2190     0.3835
        //     2.0000     2.0000     2.0000
        //     2.0000     2.0000     2.0000
        //     2.0000     2.0000     2.0000
        //     0.5328     0.9347     0.0535
        // ANCHOR_END: non_macro_seq_array_assign
    }

    #[test]
    fn setrow() {
        set_device(0);
        // ANCHOR: setrow
        let a = randu::<f32>(dim4!(5, 5, 1, 1));
        //print(&a);
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _r = row(&a, 4);
        // [1 5 1 1]
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _c = col(&a, 4);
        // [5 1 1 1]
        //     0.6755
        //     0.6105
        //     0.5232
        //     0.5567
        //     0.7896
        // ANCHOR_END: setrow
    }

    #[test]
    fn get_row() {
        set_device(0);
        // ANCHOR: get_row
        let a = randu::<f32>(dim4!(5, 5));
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _r = row(&a, -1);
        // [1 5 1 1]
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        let _c = col(&a, -1);
        // [5 1 1 1]
        //     0.6755
        //     0.6105
        //     0.5232
        //     0.5567
        //     0.7896
        // ANCHOR_END: get_row
    }

    #[test]
    fn get_rows() {
        set_device(0);
        // ANCHOR: get_rows
        let a = randu::<f32>(dim4!(5, 5));
        // [5 5 1 1]
        //     0.6010     0.5497     0.1583     0.3636     0.6755
        //     0.0278     0.2864     0.3712     0.4165     0.6105
        //     0.9806     0.3410     0.3543     0.5814     0.5232
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        // `first > last` with both indices negative walks backwards (step -1),
        // so the last row/column comes first in the result.
        let _r = rows(&a, -1, -2);
        // [2 5 1 1]
        //     0.0655     0.4105     0.9675     0.3712     0.7896
        //     0.2126     0.7509     0.6450     0.8962     0.5567
        let _c = cols(&a, -1, -3);
        // [5 3 1 1]
        //     0.6755     0.3636     0.1583
        //     0.6105     0.4165     0.3712
        //     0.5232     0.5814     0.3543
        //     0.5567     0.8962     0.6450
        //     0.7896     0.3712     0.9675
        // ANCHOR_END: get_rows
    }
}