1use super::core::{
2 af_array, AfError, Array, BinaryOp, Fromf64, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable,
3 HANDLE_ERROR,
4};
5
6use libc::{c_double, c_int, c_uint};
7
// Raw declarations of the ArrayFire C functions wrapped by this module.
//
// Every function reports its status through the `c_int` return value; the safe wrappers
// below convert it with `AfError::from` and pass it to `HANDLE_ERROR`. Results are written
// through the `*mut` out-parameters.
extern "C" {
    // Reductions along a single dimension `dim`; the `*_nan` variants also take the value
    // that replaces NaN elements.
    fn af_sum(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_sum_nan(out: *mut af_array, input: af_array, dim: c_int, nanval: c_double) -> c_int;
    fn af_product(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_product_nan(out: *mut af_array, input: af_array, dim: c_int, val: c_double) -> c_int;
    fn af_min(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_max(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_all_true(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_any_true(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_count(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    // Whole-array reductions: the scalar result is written as a real part (`r`) and an
    // imaginary part (`i`).
    fn af_sum_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_sum_nan_all(r: *mut c_double, i: *mut c_double, input: af_array, val: c_double) -> c_int;
    fn af_product_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_product_nan_all(
        r: *mut c_double,
        i: *mut c_double,
        input: af_array,
        val: c_double,
    ) -> c_int;
    fn af_min_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_max_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_all_true_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_any_true_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_count_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    // Reductions that also report the index of the selected element.
    fn af_imin(out: *mut af_array, idx: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_imax(out: *mut af_array, idx: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_imin_all(r: *mut c_double, i: *mut c_double, idx: *mut c_uint, input: af_array) -> c_int;
    fn af_imax_all(r: *mut c_double, i: *mut c_double, idx: *mut c_uint, input: af_array) -> c_int;
    // Cumulative sum, non-zero location and numerical differences.
    fn af_accum(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_where(out: *mut af_array, input: af_array) -> c_int;
    fn af_diff1(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_diff2(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    // Sorting and set operations; note these take an unsigned dimension.
    fn af_sort(out: *mut af_array, input: af_array, dim: c_uint, ascend: bool) -> c_int;
    fn af_sort_index(
        o: *mut af_array,
        i: *mut af_array,
        inp: af_array,
        d: c_uint,
        a: bool,
    ) -> c_int;
    fn af_set_unique(out: *mut af_array, input: af_array, is_sorted: bool) -> c_int;
    fn af_set_union(out: *mut af_array, first: af_array, second: af_array, is_unq: bool) -> c_int;
    fn af_set_intersect(out: *mut af_array, one: af_array, two: af_array, is_unq: bool) -> c_int;

    fn af_sort_by_key(
        out_keys: *mut af_array,
        out_vals: *mut af_array,
        in_keys: af_array,
        in_vals: af_array,
        dim: c_uint,
        ascend: bool,
    ) -> c_int;

    // Scans; `op` carries a `BinaryOp` discriminant.
    fn af_scan(out: *mut af_array, inp: af_array, dim: c_int, op: c_uint, inclusive: bool)
        -> c_int;
    fn af_scan_by_key(
        out: *mut af_array,
        key: af_array,
        inp: af_array,
        dim: c_int,
        op: c_uint,
        inclusive: bool,
    ) -> c_int;
    // Key-based reductions: reduced keys go to `keys_out`, reduced values to `vals_out`.
    fn af_all_true_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_any_true_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_count_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_max_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_min_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_product_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_product_by_key_nan(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
        nan_val: c_double,
    ) -> c_int;
    fn af_sum_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_sum_by_key_nan(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
        nan_val: c_double,
    ) -> c_int;
    // Max over ragged segments whose lengths are given by `ragged_len`.
    fn af_max_ragged(
        val_out: *mut af_array,
        idx_out: *mut af_array,
        input: af_array,
        ragged_len: af_array,
        dim: c_int,
    ) -> c_int;
}
144
// Generates `pub fn $fn_name(input, dim) -> Array<$out_type>` around a C reduction
// `$ffi_name(out, input, dim)`, converting the returned status through `HANDLE_ERROR`.
macro_rules! dim_reduce_func_def {
    ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
        #[doc=$doc_str]
        pub fn $fn_name<T>(input: &Array<T>, dim: i32) -> Array<$out_type>
        where
            T: HasAfEnum,
            $out_type: HasAfEnum,
        {
            let mut result: af_array = std::ptr::null_mut();
            // SAFETY: `result` is a writable slot for the output handle and `input.get()`
            // is a handle owned by `input`, which outlives the call.
            let status = unsafe { $ffi_name(&mut result, input.get(), dim) };
            HANDLE_ERROR(AfError::from(status));
            result.into()
        }
    };
}
162
163dim_reduce_func_def!(
164 "
165 Sum elements along a given dimension
166
167 # Parameters
168
169 - `input` - Input Array
170 - `dim` - Dimension along which the input Array will be reduced
171
172 # Return Values
173
174 Result Array after summing all elements along given dimension
175
176 # Examples
177
178 ```rust
179 use arrayfire::{Dim4, print, randu, sum};
180 let dims = Dim4::new(&[5, 3, 1, 1]);
181 let a = randu::<f32>(dims);
182 print(&a);
183 let b = sum(&a, 0);
184 print(&b);
185 let c = sum(&a, 1);
186 print(&c);
187 ```
188 ",
189 sum,
190 af_sum,
191 T::AggregateOutType
192);
193
194dim_reduce_func_def!(
195 "
196 Compute product of elements along a given dimension
197
198 # Parameters
199
200 - `input` - Input Array
201 - `dim` - Dimension along which the input Array will be reduced
202
203 # Return Values
204
205 Result Array after multiplying all elements along given dimension
206
207 # Examples
208
209 ```rust
210 use arrayfire::{Dim4, print, randu, product};
211 let dims = Dim4::new(&[5, 3, 1, 1]);
212 let a = randu::<f32>(dims);
213 print(&a);
214 let b = product(&a, 0);
215 print(&b);
216 let c = product(&a, 1);
217 print(&c);
218 ```
219 ",
220 product,
221 af_product,
222 T::ProductOutType
223);
224
225dim_reduce_func_def!(
226 "
227 Find minimum among elements of given dimension
228
229 # Parameters
230
231 - `input` - Input Array
232 - `dim` - Dimension along which the input Array will be reduced
233
234 # Return Values
235
236 Result Array after finding minimum among elements along a given dimension
237
238 # Examples
239
240 ```rust
241 use arrayfire::{Dim4, print, randu, min};
242 let dims = Dim4::new(&[5, 3, 1, 1]);
243 let a = randu::<f32>(dims);
244 print(&a);
245 let b = min(&a, 0);
246 print(&b);
247 let c = min(&a, 1);
248 print(&c);
249 ```
250 ",
251 min,
252 af_min,
253 T::InType
254);
255
256dim_reduce_func_def!(
257 "
258 Find maximum among elements of given dimension
259
260 # Parameters
261
262 - `input` - Input Array
263 - `dim` - Dimension along which the input Array will be reduced
264
265 # Return Values
266
267 Result Array after finding maximum among elements along a given dimension
268
269 # Examples
270
271 ```rust
272 use arrayfire::{Dim4, print, randu, max};
273 let dims = Dim4::new(&[5, 3, 1, 1]);
274 let a = randu::<f32>(dims);
275 print(&a);
276 let b = max(&a, 0);
277 print(&b);
278 let c = max(&a, 1);
279 print(&c);
280 ```
281 ",
282 max,
283 af_max,
284 T::InType
285);
286
287dim_reduce_func_def!(
288 "
289 Find if all of the values along a given dimension in the Array are true
290
291 # Parameters
292
293 - `input` - Input Array
294 - `dim` - Dimension along which the predicate is evaluated
295
296 # Return Values
297
298 Result Array that contains the result of `AND` operation of all elements along given dimension
299
300 # Examples
301
302 ```rust
303 use arrayfire::{Dim4, print, randu, all_true};
304 let dims = Dim4::new(&[5, 3, 1, 1]);
305 let a = randu::<f32>(dims);
306 print(&a);
307 let b = all_true(&a, 0);
308 print(&b);
309 let c = all_true(&a, 1);
310 print(&c);
311 ```
312 ",
313 all_true,
314 af_all_true,
315 bool
316);
317
318dim_reduce_func_def!(
319 "
320 Find if any of the values along a given dimension in the Array are true
321
322 # Parameters
323
324 - `input` - Input Array
325 - `dim` - Dimension along which the predicate is evaluated
326
327 # Return Values
328
329 Result Array that contains the result of `OR` operation of all elements along given dimension
330
331 # Examples
332
333 ```rust
334 use arrayfire::{Dim4, print, randu, any_true};
335 let dims = Dim4::new(&[5, 3, 1, 1]);
336 let a = randu::<f32>(dims);
337 print(&a);
338 let b = any_true(&a, 0);
339 print(&b);
340 let c = any_true(&a, 1);
341 print(&c);
342 ```
343 ",
344 any_true,
345 af_any_true,
346 bool
347);
348
349dim_reduce_func_def!(
350 "
351 Count number of non-zero elements along a given dimension
352
353 # Parameters
354
355 - `input` - Input Array
356 - `dim` - Dimension along which the non-zero elements are counted
357
358 # Return Values
359
360 Result Array with number of non-zero elements along a given dimension
361
362 # Examples
363
364 ```rust
365 use arrayfire::{Dim4, gt, print, randu, count};
366 let dims = Dim4::new(&[5, 3, 1, 1]);
367 let cnst: f32 = 0.5;
368 let a = gt(&randu::<f32>(dims), &cnst, false);
369 print(&a);
370 let b = count(&a, 0);
371 print(&b);
372 let c = count(&a, 1);
373 print(&c);
374 ```
375 ",
376 count,
377 af_count,
378 u32
379);
380
381dim_reduce_func_def!(
382 "
383 Perform exclusive sum of elements along a given dimension
384
385 # Parameters
386
387 - `input` - Input Array
388 - `dim` - Dimension along which the exclusive scan operation is carried out
389
390 # Return Values
391
392 Result Array with exclusive sums of input Array elements along a given dimension
393
394 # Examples
395
396 ```rust
397 use arrayfire::{Dim4, print, randu, accum};
398 let dims = Dim4::new(&[5, 3, 1, 1]);
399 let a = randu::<f32>(dims);
400 print(&a);
401 let b = accum(&a, 0);
402 print(&b);
403 let c = accum(&a, 1);
404 print(&c);
405 ```
406 ",
407 accum,
408 af_accum,
409 T::AggregateOutType
410);
411
412dim_reduce_func_def!(
413 "
414 Calculate first order numerical difference along a given dimension
415
416 # Parameters
417
418 - `input` - Input Array
419 - `dim` - Dimension along which first order difference is calculated
420
421 # Return Values
422
423 Result Array with first order difference values
424
425 # Examples
426
427 ```rust
428 use arrayfire::{Dim4, print, randu, diff1};
429 let dims = Dim4::new(&[5, 3, 1, 1]);
430 let a = randu::<f32>(dims);
431 print(&a);
432 let b = diff1(&a, 0);
433 print(&b);
434 let c = diff1(&a, 1);
435 print(&c);
436 ```
437 ",
438 diff1,
439 af_diff1,
440 T::InType
441);
442
443dim_reduce_func_def!(
444 "
445 Calculate second order numerical difference along a given dimension
446
447 # Parameters
448
449 - `input` - Input Array
450 - `dim` - Dimension along which second order difference is calculated
451
452 # Return Values
453
454 Result Array with second order difference values
455
456 # Examples
457
458 ```rust
459 use arrayfire::{Dim4, print, randu, diff2};
460 let dims = Dim4::new(&[5, 3, 1, 1]);
461 let a = randu::<f32>(dims);
462 print(&a);
463 let b = diff2(&a, 0);
464 print(&b);
465 let c = diff2(&a, 1);
466 print(&c);
467 ```
468 ",
469 diff2,
470 af_diff2,
471 T::InType
472);
473
474pub fn sum_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::AggregateOutType>
489where
490 T: HasAfEnum,
491 T::AggregateOutType: HasAfEnum,
492{
493 unsafe {
494 let mut temp: af_array = std::ptr::null_mut();
495 let err_val = af_sum_nan(&mut temp as *mut af_array, input.get(), dim, nanval);
496 HANDLE_ERROR(AfError::from(err_val));
497 temp.into()
498 }
499}
500
501pub fn product_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::ProductOutType>
515where
516 T: HasAfEnum,
517 T::ProductOutType: HasAfEnum,
518{
519 unsafe {
520 let mut temp: af_array = std::ptr::null_mut();
521 let err_val = af_product_nan(&mut temp as *mut af_array, input.get(), dim, nanval);
522 HANDLE_ERROR(AfError::from(err_val));
523 temp.into()
524 }
525}
526
// Generates a whole-array reduction returning `(real, imaginary)` components, each
// converted from `f64` into the base type of `T::$assoc_type`.
macro_rules! all_reduce_func_def {
    ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
        #[doc=$doc_str]
        pub fn $fn_name<T>(
            input: &Array<T>,
        ) -> (
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
        )
        where
            T: HasAfEnum,
            <T as HasAfEnum>::$assoc_type: HasAfEnum,
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            // SAFETY: `re` and `im` are writable `f64` locals and `input` keeps its handle
            // alive for the duration of the call.
            let status = unsafe { $ffi_name(&mut re, &mut im, input.get()) };
            HANDLE_ERROR(AfError::from(status));
            // The conversion target is inferred from the declared return type.
            (Fromf64::fromf64(re), Fromf64::fromf64(im))
        }
    };
}
555
556all_reduce_func_def!(
557 "
558 Sum all values of the Array
559
560 # Parameters
561
562 - `input` is the input Array
563
564 # Return Values
565
566 A tuple containing the summation result.
567
568 Note: For non-complex data type Arrays, second value of tuple is zero.
569
570 # Examples
571
572 ```rust
573 use arrayfire::{Dim4, print, randu, sum_all};
574 let dims = Dim4::new(&[5, 5, 1, 1]);
575 let a = randu::<f32>(dims);
576 print(&a);
577 println!(\"Result : {:?}\", sum_all(&a));
578 ```
579 ",
580 sum_all,
581 af_sum_all,
582 AggregateOutType
583);
584
585all_reduce_func_def!(
586 "
587 Product of all values of the Array
588
589 # Parameters
590
591 - `input` is the input Array
592
593 # Return Values
594
595 A tuple containing the product result.
596
597 Note: For non-complex data type Arrays, second value of tuple is zero.
598
599 # Examples
600
601 ```rust
602 use arrayfire::{Dim4, print, randu, product_all};
603 let dims = Dim4::new(&[5, 5, 1, 1]);
604 let a = randu::<f32>(dims);
605 print(&a);
606 let res = product_all(&a);
607 println!(\"Result : {:?}\", res);
608 ```
609 ",
610 product_all,
611 af_product_all,
612 ProductOutType
613);
614
615all_reduce_func_def!(
616 "
617 Find minimum among all values of the Array
618
619 # Parameters
620
621 - `input` is the input Array
622
623 # Return Values
624
625 A tuple containing the minimum value.
626
627 Note: For non-complex data type Arrays, second value of tuple is zero.
628
629 # Examples
630
631 ```rust
632 use arrayfire::{Dim4, print, randu, min_all};
633 let dims = Dim4::new(&[5, 5, 1, 1]);
634 let a = randu::<f32>(dims);
635 print(&a);
636 println!(\"Result : {:?}\", min_all(&a));
637 ```
638 ",
639 min_all,
640 af_min_all,
641 InType
642);
643
644all_reduce_func_def!(
645 "
646 Find maximum among all values of the Array
647
648 # Parameters
649
650 - `input` is the input Array
651
652 # Return Values
653
654 A tuple containing the maximum value.
655
656 Note: For non-complex data type Arrays, second value of tuple is zero.
657
658 # Examples
659
660 ```rust
661 use arrayfire::{Dim4, print, randu, max_all};
662 let dims = Dim4::new(&[5, 5, 1, 1]);
663 let a = randu::<f32>(dims);
664 print(&a);
665 println!(\"Result : {:?}\", max_all(&a));
666 ```
667 ",
668 max_all,
669 af_max_all,
670 InType
671);
672
// Generates a whole-array reduction whose two `f64` result components are both converted
// into the fixed output type `$out_type`.
macro_rules! all_reduce_func_def2 {
    ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type:ty) => {
        #[doc=$doc_str]
        pub fn $fn_name<T>(input: &Array<T>) -> ($out_type, $out_type)
        where
            T: HasAfEnum,
            $out_type: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            // SAFETY: `re` and `im` are writable `f64` locals and `input` keeps its handle
            // alive for the duration of the call.
            let status = unsafe { $ffi_name(&mut re, &mut im, input.get()) };
            HANDLE_ERROR(AfError::from(status));
            (
                <$out_type as Fromf64>::fromf64(re),
                <$out_type as Fromf64>::fromf64(im),
            )
        }
    };
}
693
694all_reduce_func_def2!(
695 "
696 Find if all values of Array are non-zero
697
698 # Parameters
699
700 - `input` is the input Array
701
702 # Return Values
703
704 A tuple containing the result of `AND` operation on all values of Array.
705
706 # Examples
707
708 ```rust
709 use arrayfire::{Dim4, print, randu, all_true_all};
710 let dims = Dim4::new(&[5, 5, 1, 1]);
711 let a = randu::<f32>(dims);
712 print(&a);
713 println!(\"Result : {:?}\", all_true_all(&a));
714 ```
715 ",
716 all_true_all,
717 af_all_true_all,
718 bool
719);
720
721all_reduce_func_def2!(
722 "
723 Find if any value of Array is non-zero
724
725 # Parameters
726
727 - `input` is the input Array
728
729 # Return Values
730
731 A tuple containing the result of `OR` operation on all values of Array.
732
733 # Examples
734
735 ```rust
736 use arrayfire::{Dim4, print, randu, any_true_all};
737 let dims = Dim4::new(&[5, 5, 1, 1]);
738 let a = randu::<f32>(dims);
739 print(&a);
740 println!(\"Result : {:?}\", any_true_all(&a));
741 ```
742 ",
743 any_true_all,
744 af_any_true_all,
745 bool
746);
747
748all_reduce_func_def2!(
749 "
750 Count number of non-zero values in the Array
751
752 # Parameters
753
754 - `input` is the input Array
755
756 # Return Values
757
758 A tuple containing the count of non-zero values in the Array.
759
760 # Examples
761
762 ```rust
763 use arrayfire::{Dim4, print, randu, count_all};
764 let dims = Dim4::new(&[5, 5, 1, 1]);
765 let a = randu::<f32>(dims);
766 print(&a);
767 println!(\"Result : {:?}\", count_all(&a));
768 ```
769 ",
770 count_all,
771 af_count_all,
772 u64
773);
774
775pub fn sum_nan_all<T>(
791 input: &Array<T>,
792 val: f64,
793) -> (
794 <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
795 <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
796)
797where
798 T: HasAfEnum,
799 <T as HasAfEnum>::AggregateOutType: HasAfEnum,
800 <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
801{
802 let mut real: f64 = 0.0;
803 let mut imag: f64 = 0.0;
804 unsafe {
805 let err_val = af_sum_nan_all(
806 &mut real as *mut c_double,
807 &mut imag as *mut c_double,
808 input.get(),
809 val,
810 );
811 HANDLE_ERROR(AfError::from(err_val));
812 }
813 (
814 <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(real),
815 <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag),
816 )
817}
818
819pub fn product_nan_all<T>(
835 input: &Array<T>,
836 val: f64,
837) -> (
838 <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
839 <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
840)
841where
842 T: HasAfEnum,
843 <T as HasAfEnum>::ProductOutType: HasAfEnum,
844 <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
845{
846 let mut real: f64 = 0.0;
847 let mut imag: f64 = 0.0;
848 unsafe {
849 let err_val = af_product_nan_all(
850 &mut real as *mut c_double,
851 &mut imag as *mut c_double,
852 input.get(),
853 val,
854 );
855 HANDLE_ERROR(AfError::from(err_val));
856 }
857 (
858 <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(real),
859 <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(imag),
860 )
861}
862
// Generates `pub fn $fn_name(input, dim) -> (values, indices)` around a C reduction that
// also reports the position of the selected element along `dim`.
macro_rules! dim_ireduce_func_def {
    ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ident) => {
        #[doc=$doc_str]
        pub fn $fn_name<T>(input: &Array<T>, dim: i32) -> (Array<T::$out_type>, Array<u32>)
        where
            T: HasAfEnum,
            T::$out_type: HasAfEnum,
        {
            let mut values: af_array = std::ptr::null_mut();
            let mut indices: af_array = std::ptr::null_mut();
            // SAFETY: both out-pointers refer to writable locals and `input` keeps its
            // handle alive for the duration of the call.
            let status = unsafe { $ffi_name(&mut values, &mut indices, input.get(), dim) };
            HANDLE_ERROR(AfError::from(status));
            (values.into(), indices.into())
        }
    };
}
883
884dim_ireduce_func_def!("
885 Find minimum value along given dimension and their corresponding indices
886
887 # Parameters
888
889 - `input` - Input Array
890 - `dim` - Dimension along which the input Array will be reduced
891
892 # Return Values
893
894 A tuple of Arrays: Array minimum values and Array containing their index along the reduced dimension.
895 ", imin, af_imin, InType);
896
897dim_ireduce_func_def!("
898 Find maximum value along given dimension and their corresponding indices
899
900 # Parameters
901
902 - `input` - Input Array
903 - `dim` - Dimension along which the input Array will be reduced
904
905 # Return Values
906
907 A tuple of Arrays: Array maximum values and Array containing their index along the reduced dimension.
908 ", imax, af_imax, InType);
909
// Generates a whole-array reduction returning `(real, imaginary, index)`, where the
// index locates the selected element and the components are converted from `f64`.
macro_rules! all_ireduce_func_def {
    ($doc_str: expr, $fn_name: ident, $ffi_name: ident, $assoc_type:ident) => {
        #[doc=$doc_str]
        pub fn $fn_name<T>(
            input: &Array<T>,
        ) -> (
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType,
            u32,
        )
        where
            T: HasAfEnum,
            <T as HasAfEnum>::$assoc_type: HasAfEnum,
            <<T as HasAfEnum>::$assoc_type as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            let mut position: u32 = 0;
            // SAFETY: all three out-pointers refer to writable locals of the declared C
            // types and `input` keeps its handle alive for the duration of the call.
            let status = unsafe { $ffi_name(&mut re, &mut im, &mut position, input.get()) };
            HANDLE_ERROR(AfError::from(status));
            // The conversion target is inferred from the declared return type.
            (Fromf64::fromf64(re), Fromf64::fromf64(im), position)
        }
    };
}
942
943all_ireduce_func_def!(
944 "
945 Find minimum and it's index in the whole Array
946
947 # Parameters
948
949 `input` - Input Array
950
951 # Return Values
952
953 A triplet with
954
955 * minimum element of Array in the first component.
956 * second component of value zero if Array is of non-complex type.
957 * index of minimum element in the third component.
958 ",
959 imin_all,
960 af_imin_all,
961 InType
962);
963all_ireduce_func_def!(
964 "
965 Find maximum and it's index in the whole Array
966
967 # Parameters
968
969 `input` - Input Array
970
971 # Return Values
972
973 A triplet with
974
975 - maximum element of Array in the first component.
976 - second component of value zero if Array is of non-complex type.
977 - index of maximum element in the third component.
978 ",
979 imax_all,
980 af_imax_all,
981 InType
982);
983
984pub fn locate<T: HasAfEnum>(input: &Array<T>) -> Array<u32> {
996 unsafe {
997 let mut temp: af_array = std::ptr::null_mut();
998 let err_val = af_where(&mut temp as *mut af_array, input.get());
999 HANDLE_ERROR(AfError::from(err_val));
1000 temp.into()
1001 }
1002}
1003
1004pub fn sort<T>(input: &Array<T>, dim: u32, ascending: bool) -> Array<T>
1019where
1020 T: HasAfEnum + RealNumber,
1021{
1022 unsafe {
1023 let mut temp: af_array = std::ptr::null_mut();
1024 let err_val = af_sort(&mut temp as *mut af_array, input.get(), dim, ascending);
1025 HANDLE_ERROR(AfError::from(err_val));
1026 temp.into()
1027 }
1028}
1029
1030pub fn sort_index<T>(input: &Array<T>, dim: u32, ascending: bool) -> (Array<T>, Array<u32>)
1047where
1048 T: HasAfEnum + RealNumber,
1049{
1050 unsafe {
1051 let mut temp: af_array = std::ptr::null_mut();
1052 let mut idx: af_array = std::ptr::null_mut();
1053 let err_val = af_sort_index(
1054 &mut temp as *mut af_array,
1055 &mut idx as *mut af_array,
1056 input.get(),
1057 dim,
1058 ascending,
1059 );
1060 HANDLE_ERROR(AfError::from(err_val));
1061 (temp.into(), idx.into())
1062 }
1063}
1064
1065pub fn sort_by_key<K, V>(
1084 keys: &Array<K>,
1085 vals: &Array<V>,
1086 dim: u32,
1087 ascending: bool,
1088) -> (Array<K>, Array<V>)
1089where
1090 K: HasAfEnum + RealNumber,
1091 V: HasAfEnum,
1092{
1093 unsafe {
1094 let mut temp: af_array = std::ptr::null_mut();
1095 let mut temp2: af_array = std::ptr::null_mut();
1096 let err_val = af_sort_by_key(
1097 &mut temp as *mut af_array,
1098 &mut temp2 as *mut af_array,
1099 keys.get(),
1100 vals.get(),
1101 dim,
1102 ascending,
1103 );
1104 HANDLE_ERROR(AfError::from(err_val));
1105 (temp.into(), temp2.into())
1106 }
1107}
1108
1109pub fn set_unique<T>(input: &Array<T>, is_sorted: bool) -> Array<T>
1121where
1122 T: HasAfEnum + RealNumber,
1123{
1124 unsafe {
1125 let mut temp: af_array = std::ptr::null_mut();
1126 let err_val = af_set_unique(&mut temp as *mut af_array, input.get(), is_sorted);
1127 HANDLE_ERROR(AfError::from(err_val));
1128 temp.into()
1129 }
1130}
1131
1132pub fn set_union<T>(first: &Array<T>, second: &Array<T>, is_unique: bool) -> Array<T>
1144where
1145 T: HasAfEnum + RealNumber,
1146{
1147 unsafe {
1148 let mut temp: af_array = std::ptr::null_mut();
1149 let err_val = af_set_union(
1150 &mut temp as *mut af_array,
1151 first.get(),
1152 second.get(),
1153 is_unique,
1154 );
1155 HANDLE_ERROR(AfError::from(err_val));
1156 temp.into()
1157 }
1158}
1159
1160pub fn set_intersect<T>(first: &Array<T>, second: &Array<T>, is_unique: bool) -> Array<T>
1172where
1173 T: HasAfEnum + RealNumber,
1174{
1175 unsafe {
1176 let mut temp: af_array = std::ptr::null_mut();
1177 let err_val = af_set_intersect(
1178 &mut temp as *mut af_array,
1179 first.get(),
1180 second.get(),
1181 is_unique,
1182 );
1183 HANDLE_ERROR(AfError::from(err_val));
1184 temp.into()
1185 }
1186}
1187
1188pub fn scan<T>(
1202 input: &Array<T>,
1203 dim: i32,
1204 op: BinaryOp,
1205 inclusive: bool,
1206) -> Array<T::AggregateOutType>
1207where
1208 T: HasAfEnum,
1209 T::AggregateOutType: HasAfEnum,
1210{
1211 unsafe {
1212 let mut temp: af_array = std::ptr::null_mut();
1213 let err_val = af_scan(
1214 &mut temp as *mut af_array,
1215 input.get(),
1216 dim,
1217 op as u32,
1218 inclusive,
1219 );
1220 HANDLE_ERROR(AfError::from(err_val));
1221 temp.into()
1222 }
1223}
1224
1225pub fn scan_by_key<K, V>(
1240 key: &Array<K>,
1241 input: &Array<V>,
1242 dim: i32,
1243 op: BinaryOp,
1244 inclusive: bool,
1245) -> Array<V::AggregateOutType>
1246where
1247 V: HasAfEnum,
1248 V::AggregateOutType: HasAfEnum,
1249 K: HasAfEnum + Scanable,
1250{
1251 unsafe {
1252 let mut temp: af_array = std::ptr::null_mut();
1253 let err_val = af_scan_by_key(
1254 &mut temp as *mut af_array,
1255 key.get(),
1256 input.get(),
1257 dim,
1258 op as u32,
1259 inclusive,
1260 );
1261 HANDLE_ERROR(AfError::from(err_val));
1262 temp.into()
1263 }
1264}
1265
// Generates a key-based reduction `pub fn $fn_name(keys, vals, dim) -> (keys, values)`
// around the C function `$ffi_name`.
macro_rules! dim_reduce_by_key_func_def {
    ($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
        #[doc=$brief_str]
        #[doc=$ex_str]
        ///
        /// # Parameters
        ///
        /// - `keys` - Array of keys that group the values
        /// - `vals` - Array of values to reduce
        /// - `dim` - Dimension along which the reduction happens
        ///
        /// # Return Values
        ///
        /// A pair of Arrays: the reduced keys and the reduced values.
        pub fn $fn_name<KeyType, ValueType>(
            keys: &Array<KeyType>,
            vals: &Array<ValueType>,
            dim: i32,
        ) -> (Array<KeyType>, Array<$out_type>)
        where
            KeyType: ReduceByKeyInput,
            ValueType: HasAfEnum,
            $out_type: HasAfEnum,
        {
            let mut reduced_keys: af_array = std::ptr::null_mut();
            let mut reduced_vals: af_array = std::ptr::null_mut();
            // SAFETY: both out-pointers refer to writable locals and the input Arrays keep
            // their handles alive for the duration of the call.
            let status = unsafe {
                $ffi_name(&mut reduced_keys, &mut reduced_vals, keys.get(), vals.get(), dim)
            };
            HANDLE_ERROR(AfError::from(status));
            (reduced_keys.into(), reduced_vals.into())
        }
    };
}
1301
1302dim_reduce_by_key_func_def!(
1303 "
1304Key based AND of elements along a given dimension
1305
1306All positive non-zero values are considered true, while negative and zero
1307values are considered as false.
1308",
1309 "
1310# Examples
1311```rust
1312use arrayfire::{Dim4, print, randu, all_true_by_key};
1313let dims = Dim4::new(&[5, 3, 1, 1]);
1314let vals = randu::<f32>(dims);
1315let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1316print(&vals);
1317print(&keys);
1318let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
1319print(&out_keys);
1320print(&out_vals);
1321```
1322",
1323 all_true_by_key,
1324 af_all_true_by_key,
1325 ValueType::AggregateOutType
1326);
1327
1328dim_reduce_by_key_func_def!(
1329 "
1330Key based OR of elements along a given dimension
1331
1332All positive non-zero values are considered true, while negative and zero
1333values are considered as false.
1334",
1335 "
1336# Examples
1337```rust
1338use arrayfire::{Dim4, print, randu, any_true_by_key};
1339let dims = Dim4::new(&[5, 3, 1, 1]);
1340let vals = randu::<f32>(dims);
1341let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
1342print(&vals);
1343print(&keys);
1344let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
1345print(&out_keys);
1346print(&out_vals);
1347```
1348",
1349 any_true_by_key,
1350 af_any_true_by_key,
1351 ValueType::AggregateOutType
1352);
1353
1354dim_reduce_by_key_func_def!(
1355 "Find total count of elements with similar keys along a given dimension",
1356 "",
1357 count_by_key,
1358 af_count_by_key,
1359 ValueType::AggregateOutType
1360);
1361
1362dim_reduce_by_key_func_def!(
1363 "Find maximum among values of similar keys along a given dimension",
1364 "",
1365 max_by_key,
1366 af_max_by_key,
1367 ValueType::AggregateOutType
1368);
1369
1370dim_reduce_by_key_func_def!(
1371 "Find minimum among values of similar keys along a given dimension",
1372 "",
1373 min_by_key,
1374 af_min_by_key,
1375 ValueType::AggregateOutType
1376);
1377
1378dim_reduce_by_key_func_def!(
1379 "Find product of all values with similar keys along a given dimension",
1380 "",
1381 product_by_key,
1382 af_product_by_key,
1383 ValueType::ProductOutType
1384);
1385
1386dim_reduce_by_key_func_def!(
1387 "Find sum of all values with similar keys along a given dimension",
1388 "",
1389 sum_by_key,
1390 af_sum_by_key,
1391 ValueType::AggregateOutType
1392);
1393
// Generates a key-based reduction that also takes the value substituted for `NAN`
// elements, around the C function `$ffi_name`.
macro_rules! dim_reduce_by_key_nan_func_def {
    ($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
        #[doc=$brief_str]
        #[doc=$ex_str]
        ///
        /// # Parameters
        ///
        /// - `keys` - Array of keys that group the values
        /// - `vals` - Array of values to reduce
        /// - `dim` - Dimension along which the reduction happens
        /// - `replace_value` - value used in place of `NAN` elements of `vals`
        ///
        /// # Return Values
        ///
        /// A pair of Arrays: the reduced keys and the reduced values.
        pub fn $fn_name<KeyType, ValueType>(
            keys: &Array<KeyType>,
            vals: &Array<ValueType>,
            dim: i32,
            replace_value: f64,
        ) -> (Array<KeyType>, Array<$out_type>)
        where
            KeyType: ReduceByKeyInput,
            ValueType: HasAfEnum,
            $out_type: HasAfEnum,
        {
            let mut reduced_keys: af_array = std::ptr::null_mut();
            let mut reduced_vals: af_array = std::ptr::null_mut();
            // SAFETY: both out-pointers refer to writable locals and the input Arrays keep
            // their handles alive for the duration of the call.
            let status = unsafe {
                $ffi_name(
                    &mut reduced_keys,
                    &mut reduced_vals,
                    keys.get(),
                    vals.get(),
                    dim,
                    replace_value,
                )
            };
            HANDLE_ERROR(AfError::from(status));
            (reduced_keys.into(), reduced_vals.into())
        }
    };
}
1432}
1433
1434dim_reduce_by_key_nan_func_def!(
1435 "Compute sum of all values with similar keys along a given dimension",
1436 "",
1437 sum_by_key_nan,
1438 af_sum_by_key_nan,
1439 ValueType::AggregateOutType
1440);
1441
1442dim_reduce_by_key_nan_func_def!(
1443 "Compute product of all values with similar keys along a given dimension",
1444 "",
1445 product_by_key_nan,
1446 af_product_by_key_nan,
1447 ValueType::ProductOutType
1448);
1449
1450pub fn max_ragged<T>(
1487 input: &Array<T>,
1488 ragged_len: &Array<u32>,
1489 dim: i32,
1490) -> (Array<T::InType>, Array<u32>)
1491where
1492 T: HasAfEnum,
1493 T::InType: HasAfEnum,
1494{
1495 unsafe {
1496 let mut out_vals: af_array = std::ptr::null_mut();
1497 let mut out_idxs: af_array = std::ptr::null_mut();
1498 let err_val = af_max_ragged(
1499 &mut out_vals as *mut af_array,
1500 &mut out_idxs as *mut af_array,
1501 input.get(),
1502 ragged_len.get(),
1503 dim,
1504 );
1505 HANDLE_ERROR(AfError::from(err_val));
1506 (out_vals.into(), out_idxs.into())
1507 }
1508}
1509
#[cfg(test)]
mod tests {
    use super::super::core::c32;
    use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all};
    use crate::core::set_device;
    use crate::randu;

    // Smoke test: whole-array reductions on complex and boolean matrices run without error.
    #[test]
    fn all_reduce_api() {
        set_device(0);
        let a = randu!(c32; 10, 10);
        println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));

        let b = randu!(bool; 10, 10);
        println!("reduction of bool matrix: {:?}", sum_all(&b));

        println!(
            "reduction of complex f32 matrix after replacing nan with {}: {:?}",
            1.0,
            product_nan_all(&a, 1.0)
        );

        println!(
            "reduction of bool matrix after replacing nan with {}: {:?}",
            0.0,
            sum_nan_all(&b, 0.0)
        );
    }

    // Smoke test: index-returning whole-array reductions on 1-D complex and u32 arrays.
    #[test]
    fn all_ireduce_api() {
        set_device(0);
        let a = randu!(c32; 10);
        println!("Reduction of complex f32 vector: {:?}", imin_all(&a));

        // `b` is a u32 vector, not a bool matrix as the old message claimed.
        let b = randu!(u32; 10);
        println!("Reduction of u32 vector: {:?}", imax_all(&b));
    }
}