Thanks to visit codestin.com
Credit goes to docs.rs

arrayfire/algorithm/
mod.rs

1use super::core::{
2    af_array, AfError, Array, BinaryOp, Fromf64, HasAfEnum, RealNumber, ReduceByKeyInput, Scanable,
3    HANDLE_ERROR,
4};
5
6use libc::{c_double, c_int, c_uint};
7
// Raw declarations of the ArrayFire C routines this module calls. Every routine returns a
// `c_int` status; the wrappers visible in this file pass it through `AfError::from` and
// then to `HANDLE_ERROR`.
extern "C" {
    // Reductions along one dimension; the result handle is written through `out`.
    fn af_sum(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_sum_nan(out: *mut af_array, input: af_array, dim: c_int, nanval: c_double) -> c_int;
    fn af_product(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_product_nan(out: *mut af_array, input: af_array, dim: c_int, val: c_double) -> c_int;
    fn af_min(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_max(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_all_true(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_any_true(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_count(out: *mut af_array, input: af_array, dim: c_int) -> c_int;

    // Whole-array reductions; the result is written as two doubles through `r` and `i`
    // (the wrappers in this file name them real and imaginary).
    fn af_sum_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_sum_nan_all(r: *mut c_double, i: *mut c_double, input: af_array, val: c_double) -> c_int;
    fn af_product_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_product_nan_all(
        r: *mut c_double,
        i: *mut c_double,
        input: af_array,
        val: c_double,
    ) -> c_int;
    fn af_min_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_max_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_all_true_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_any_true_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;
    fn af_count_all(r: *mut c_double, i: *mut c_double, input: af_array) -> c_int;

    // Min/max variants that additionally report an index (`idx`).
    fn af_imin(out: *mut af_array, idx: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_imax(out: *mut af_array, idx: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_imin_all(r: *mut c_double, i: *mut c_double, idx: *mut c_uint, input: af_array) -> c_int;
    fn af_imax_all(r: *mut c_double, i: *mut c_double, idx: *mut c_uint, input: af_array) -> c_int;

    // Scan, locate and numerical-difference routines.
    fn af_accum(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_where(out: *mut af_array, input: af_array) -> c_int;
    fn af_diff1(out: *mut af_array, input: af_array, dim: c_int) -> c_int;
    fn af_diff2(out: *mut af_array, input: af_array, dim: c_int) -> c_int;

    // Sorting and set operations. Note that the sort routines take `dim` as `c_uint`,
    // unlike the reductions above which take `c_int`.
    fn af_sort(out: *mut af_array, input: af_array, dim: c_uint, ascend: bool) -> c_int;
    fn af_sort_index(
        o: *mut af_array,
        i: *mut af_array,
        inp: af_array,
        d: c_uint,
        a: bool,
    ) -> c_int;
    fn af_set_unique(out: *mut af_array, input: af_array, is_sorted: bool) -> c_int;
    fn af_set_union(out: *mut af_array, first: af_array, second: af_array, is_unq: bool) -> c_int;
    fn af_set_intersect(out: *mut af_array, one: af_array, two: af_array, is_unq: bool) -> c_int;

    fn af_sort_by_key(
        out_keys: *mut af_array,
        out_vals: *mut af_array,
        in_keys: af_array,
        in_vals: af_array,
        dim: c_uint,
        ascend: bool,
    ) -> c_int;

    // Generalised scans; `op` selects the binary operation as an unsigned integer code.
    fn af_scan(out: *mut af_array, inp: af_array, dim: c_int, op: c_uint, inclusive: bool)
        -> c_int;
    fn af_scan_by_key(
        out: *mut af_array,
        key: af_array,
        inp: af_array,
        dim: c_int,
        op: c_uint,
        inclusive: bool,
    ) -> c_int;

    // Reductions by key: each takes a keys array and a values array and writes
    // both a keys output and a values output.
    fn af_all_true_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_any_true_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_count_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_max_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_min_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_product_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_product_by_key_nan(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
        nan_val: c_double,
    ) -> c_int;
    fn af_sum_by_key(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
    ) -> c_int;
    fn af_sum_by_key_nan(
        keys_out: *mut af_array,
        vals_out: *mut af_array,
        keys: af_array,
        vals: af_array,
        dim: c_int,
        nan_val: c_double,
    ) -> c_int;

    // Ragged maximum: takes an extra `ragged_len` array and writes value and index outputs.
    fn af_max_ragged(
        val_out: *mut af_array,
        idx_out: *mut af_array,
        input: af_array,
        ragged_len: af_array,
        dim: c_int,
    ) -> c_int;
}
144
// Generates a public wrapper that forwards an Array and a dimension to one ArrayFire FFI
// routine and returns the Array it produces.
//
// Arguments, in order:
// - doc string attached to the generated function
// - name of the generated Rust function
// - name of the `extern "C"` routine to call
// - element type of the returned Array
macro_rules! dim_reduce_func_def {
    ($docs:expr, $wrapper:ident, $ffi:ident, $elem:ty) => {
        #[doc = $docs]
        pub fn $wrapper<T>(input: &Array<T>, dim: i32) -> Array<$elem>
        where
            T: HasAfEnum,
            $elem: HasAfEnum,
        {
            let mut handle: af_array = std::ptr::null_mut();
            // SAFETY: `&mut handle` points at a local that outlives the call.
            // NOTE(review): presumes `input.get()` yields a handle that stays valid while
            // `input` is borrowed - confirm against `Array`.
            unsafe {
                let status = $ffi(&mut handle, input.get(), dim);
                HANDLE_ERROR(AfError::from(status));
            }
            handle.into()
        }
    };
}
162
// Expands to `pub fn sum<T>(input: &Array<T>, dim: i32) -> Array<T::AggregateOutType>`,
// a wrapper over the `af_sum` FFI routine.
dim_reduce_func_def!(
    "
    Sum elements along a given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the input Array will be reduced

    # Return Values

    Result Array after summing all elements along given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, sum};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = sum(&a, 0);
    print(&b);
    let c = sum(&a, 1);
    print(&c);
    ```
    ",
    sum,
    af_sum,
    T::AggregateOutType
);
193
// Expands to `pub fn product<T>(input: &Array<T>, dim: i32) -> Array<T::ProductOutType>`,
// a wrapper over the `af_product` FFI routine.
dim_reduce_func_def!(
    "
    Compute product of elements along a given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the input Array will be reduced

    # Return Values

    Result Array after multiplying all elements along given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, product};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = product(&a, 0);
    print(&b);
    let c = product(&a, 1);
    print(&c);
    ```
    ",
    product,
    af_product,
    T::ProductOutType
);
224
// Expands to `pub fn min<T>(input: &Array<T>, dim: i32) -> Array<T::InType>`,
// a wrapper over the `af_min` FFI routine.
dim_reduce_func_def!(
    "
    Find minimum among elements of given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the input Array will be reduced

    # Return Values

    Result Array after finding minimum among elements along a given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, min};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = min(&a, 0);
    print(&b);
    let c = min(&a, 1);
    print(&c);
    ```
    ",
    min,
    af_min,
    T::InType
);
255
// Expands to `pub fn max<T>(input: &Array<T>, dim: i32) -> Array<T::InType>`,
// a wrapper over the `af_max` FFI routine.
dim_reduce_func_def!(
    "
    Find maximum among elements of given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the input Array will be reduced

    # Return Values

    Result Array after finding maximum among elements along a given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, max};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = max(&a, 0);
    print(&b);
    let c = max(&a, 1);
    print(&c);
    ```
    ",
    max,
    af_max,
    T::InType
);
286
// Expands to `pub fn all_true<T>(input: &Array<T>, dim: i32) -> Array<bool>`,
// a wrapper over the `af_all_true` FFI routine.
dim_reduce_func_def!(
    "
    Find if all of the values along a given dimension in the Array are true

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the predicate is evaluated

    # Return Values

    Result Array that contains the result of `AND` operation of all elements along given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, all_true};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = all_true(&a, 0);
    print(&b);
    let c = all_true(&a, 1);
    print(&c);
    ```
    ",
    all_true,
    af_all_true,
    bool
);
317
// Expands to `pub fn any_true<T>(input: &Array<T>, dim: i32) -> Array<bool>`,
// a wrapper over the `af_any_true` FFI routine.
dim_reduce_func_def!(
    "
    Find if any of the values along a given dimension in the Array are true

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the predicate is evaluated

    # Return Values

    Result Array that contains the result of `OR` operation of all elements along given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, any_true};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = any_true(&a, 0);
    print(&b);
    let c = any_true(&a, 1);
    print(&c);
    ```
    ",
    any_true,
    af_any_true,
    bool
);
348
// Expands to `pub fn count<T>(input: &Array<T>, dim: i32) -> Array<u32>`,
// a wrapper over the `af_count` FFI routine.
dim_reduce_func_def!(
    "
    Count number of non-zero elements along a given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which the non-zero elements are counted

    # Return Values

    Result Array with number of non-zero elements along a given dimension

    # Examples

    ```rust
    use arrayfire::{Dim4, gt, print, randu, count};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let cnst: f32 = 0.5;
    let a = gt(&randu::<f32>(dims), &cnst, false);
    print(&a);
    let b = count(&a, 0);
    print(&b);
    let c = count(&a, 1);
    print(&c);
    ```
    ",
    count,
    af_count,
    u32
);
380
381dim_reduce_func_def!(
382    "
383    Perform exclusive sum of elements along a given dimension
384
385    # Parameters
386
387    - `input` - Input Array
388    - `dim`   - Dimension along which the exclusive scan operation is carried out
389
390    # Return Values
391
392    Result Array with exclusive sums of input Array elements along a given dimension
393
394    # Examples
395
396    ```rust
397    use arrayfire::{Dim4, print, randu, accum};
398    let dims = Dim4::new(&[5, 3, 1, 1]);
399    let a = randu::<f32>(dims);
400    print(&a);
401    let b = accum(&a, 0);
402    print(&b);
403    let c = accum(&a, 1);
404    print(&c);
405    ```
406    ",
407    accum,
408    af_accum,
409    T::AggregateOutType
410);
411
// Expands to `pub fn diff1<T>(input: &Array<T>, dim: i32) -> Array<T::InType>`,
// a wrapper over the `af_diff1` FFI routine.
dim_reduce_func_def!(
    "
    Calculate first order numerical difference along a given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which first order difference is calculated

    # Return Values

    Result Array with first order difference values

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, diff1};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = diff1(&a, 0);
    print(&b);
    let c = diff1(&a, 1);
    print(&c);
    ```
    ",
    diff1,
    af_diff1,
    T::InType
);
442
// Expands to `pub fn diff2<T>(input: &Array<T>, dim: i32) -> Array<T::InType>`,
// a wrapper over the `af_diff2` FFI routine.
dim_reduce_func_def!(
    "
    Calculate second order numerical difference along a given dimension

    # Parameters

    - `input` - Input Array
    - `dim`   - Dimension along which second order difference is calculated

    # Return Values

    Result Array with second order difference values

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, diff2};
    let dims = Dim4::new(&[5, 3, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let b = diff2(&a, 0);
    print(&b);
    let c = diff2(&a, 1);
    print(&c);
    ```
    ",
    diff2,
    af_diff2,
    T::InType
);
473
474/// Sum along specific dimension using user specified value instead of `NAN` values
475///
476/// Sum values of the `input` Array along `dim` dimension after replacing any `NAN` values in the
477/// Array with the value of the parameter `nanval`.
478///
479/// # Parameters
480///
481/// - `input` is the input Array
482/// - `dim` is reduction dimension
483/// - `nanval` is value with which all the `NAN` values of Array are replaced with
484///
485/// # Return Values
486///
487/// Array that is reduced along given dimension via addition operation
488pub fn sum_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::AggregateOutType>
489where
490    T: HasAfEnum,
491    T::AggregateOutType: HasAfEnum,
492{
493    unsafe {
494        let mut temp: af_array = std::ptr::null_mut();
495        let err_val = af_sum_nan(&mut temp as *mut af_array, input.get(), dim, nanval);
496        HANDLE_ERROR(AfError::from(err_val));
497        temp.into()
498    }
499}
500
501/// Product of elements along specific dimension using user specified value instead of `NAN` values
502///
503/// Compute product of the values of the `input` Array along `dim` dimension after replacing any `NAN` values in the Array with `nanval` value.
504///
505/// # Parameters
506///
507/// - `input` is the input Array
508/// - `dim` is reduction dimension
509/// - `nanval` is value with which all the `NAN` values of Array are replaced with
510///
511/// # Return Values
512///
513/// Array that is reduced along given dimension via multiplication operation
514pub fn product_nan<T>(input: &Array<T>, dim: i32, nanval: f64) -> Array<T::ProductOutType>
515where
516    T: HasAfEnum,
517    T::ProductOutType: HasAfEnum,
518{
519    unsafe {
520        let mut temp: af_array = std::ptr::null_mut();
521        let err_val = af_product_nan(&mut temp as *mut af_array, input.get(), dim, nanval);
522        HANDLE_ERROR(AfError::from(err_val));
523        temp.into()
524    }
525}
526
// Generates a whole-Array reduction wrapper that reports its result as a pair of scalars.
//
// Arguments, in order: doc string, generated function name, FFI routine name, and the name
// of the `HasAfEnum` associated type whose `BaseType` is the type of both returned scalars.
// The FFI routine fills two `f64` slots (real and imaginary parts), each of which is then
// converted through `Fromf64`.
macro_rules! all_reduce_func_def {
    ($docs:expr, $wrapper:ident, $ffi:ident, $assoc:ident) => {
        #[doc = $docs]
        pub fn $wrapper<T>(
            input: &Array<T>,
        ) -> (
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType,
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType,
        )
        where
            T: HasAfEnum,
            <T as HasAfEnum>::$assoc: HasAfEnum,
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            // SAFETY: both out-pointers refer to locals that outlive the call.
            // NOTE(review): presumes `input.get()` yields a handle valid while `input`
            // is borrowed.
            unsafe {
                let status = $ffi(&mut re, &mut im, input.get());
                HANDLE_ERROR(AfError::from(status));
            }
            // The target type of each conversion is inferred from the return type.
            (Fromf64::fromf64(re), Fromf64::fromf64(im))
        }
    };
}
555
// Expands to `pub fn sum_all<T>(input: &Array<T>)`, a wrapper over `af_sum_all` that returns
// a (real, imaginary) pair typed as the `BaseType` of `T::AggregateOutType`.
all_reduce_func_def!(
    "
    Sum all values of the Array

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the summation result.

    Note: For non-complex data type Arrays, second value of tuple is zero.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, sum_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", sum_all(&a));
    ```
    ",
    sum_all,
    af_sum_all,
    AggregateOutType
);
584
// Expands to `pub fn product_all<T>(input: &Array<T>)`, a wrapper over `af_product_all` that
// returns a (real, imaginary) pair typed as the `BaseType` of `T::ProductOutType`.
all_reduce_func_def!(
    "
    Product of all values of the Array

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the product result.

    Note: For non-complex data type Arrays, second value of tuple is zero.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, product_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    let res = product_all(&a);
    println!(\"Result : {:?}\", res);
    ```
    ",
    product_all,
    af_product_all,
    ProductOutType
);
614
// Expands to `pub fn min_all<T>(input: &Array<T>)`, a wrapper over `af_min_all` that returns
// a (real, imaginary) pair typed as the `BaseType` of `T::InType`.
all_reduce_func_def!(
    "
    Find minimum among all values of the Array

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the minimum value.

    Note: For non-complex data type Arrays, second value of tuple is zero.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, min_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", min_all(&a));
    ```
    ",
    min_all,
    af_min_all,
    InType
);
643
// Expands to `pub fn max_all<T>(input: &Array<T>)`, a wrapper over `af_max_all` that returns
// a (real, imaginary) pair typed as the `BaseType` of `T::InType`.
all_reduce_func_def!(
    "
    Find maximum among all values of the Array

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the maximum value.

    Note: For non-complex data type Arrays, second value of tuple is zero.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, max_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", max_all(&a));
    ```
    ",
    max_all,
    af_max_all,
    InType
);
672
// Generates a whole-Array reduction wrapper whose two returned scalars have a fixed,
// caller-named type instead of one derived from the input element type.
//
// Arguments, in order: doc string, generated function name, FFI routine name, and the
// concrete scalar type used for both tuple components.
macro_rules! all_reduce_func_def2 {
    ($docs:expr, $wrapper:ident, $ffi:ident, $scalar:ty) => {
        #[doc = $docs]
        pub fn $wrapper<T>(input: &Array<T>) -> ($scalar, $scalar)
        where
            T: HasAfEnum,
            $scalar: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            // SAFETY: both out-pointers refer to locals that outlive the call.
            // NOTE(review): presumes `input.get()` yields a handle valid while `input`
            // is borrowed.
            unsafe {
                let status = $ffi(&mut re, &mut im, input.get());
                HANDLE_ERROR(AfError::from(status));
            }
            (
                <$scalar as Fromf64>::fromf64(re),
                <$scalar as Fromf64>::fromf64(im),
            )
        }
    };
}
693
// Expands to `pub fn all_true_all<T>(input: &Array<T>) -> (bool, bool)`,
// a wrapper over the `af_all_true_all` FFI routine.
all_reduce_func_def2!(
    "
    Find if all values of Array are non-zero

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the result of `AND` operation on all values of Array.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, all_true_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", all_true_all(&a));
    ```
    ",
    all_true_all,
    af_all_true_all,
    bool
);
720
// Expands to `pub fn any_true_all<T>(input: &Array<T>) -> (bool, bool)`,
// a wrapper over the `af_any_true_all` FFI routine.
all_reduce_func_def2!(
    "
    Find if any value of Array is non-zero

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the result of `OR` operation on all values of Array.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, any_true_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", any_true_all(&a));
    ```
    ",
    any_true_all,
    af_any_true_all,
    bool
);
747
// Expands to `pub fn count_all<T>(input: &Array<T>) -> (u64, u64)`,
// a wrapper over the `af_count_all` FFI routine.
all_reduce_func_def2!(
    "
    Count number of non-zero values in the Array

    # Parameters

    - `input` is the input Array

    # Return Values

    A tuple containing the count of non-zero values in the Array.

    # Examples

    ```rust
    use arrayfire::{Dim4, print, randu, count_all};
    let dims = Dim4::new(&[5, 5, 1, 1]);
    let a = randu::<f32>(dims);
    print(&a);
    println!(\"Result : {:?}\", count_all(&a));
    ```
    ",
    count_all,
    af_count_all,
    u64
);
774
775/// Sum all values using user provided value for `NAN`
776///
777/// Sum all the values of the `input` Array after replacing any `NAN` values with `val`.
778///
779/// # Parameters
780///
781/// - `input` is the input Array
782/// - `val` is the val that replaces all `NAN` values of the Array before reduction operation is
783/// performed.
784///
785/// # Return Values
786///
787/// A tuple of summation result.
788///
789/// Note: For non-complex data type Arrays, second value of tuple is zero.
790pub fn sum_nan_all<T>(
791    input: &Array<T>,
792    val: f64,
793) -> (
794    <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
795    <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType,
796)
797where
798    T: HasAfEnum,
799    <T as HasAfEnum>::AggregateOutType: HasAfEnum,
800    <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
801{
802    let mut real: f64 = 0.0;
803    let mut imag: f64 = 0.0;
804    unsafe {
805        let err_val = af_sum_nan_all(
806            &mut real as *mut c_double,
807            &mut imag as *mut c_double,
808            input.get(),
809            val,
810        );
811        HANDLE_ERROR(AfError::from(err_val));
812    }
813    (
814        <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(real),
815        <<T as HasAfEnum>::AggregateOutType as HasAfEnum>::BaseType::fromf64(imag),
816    )
817}
818
819/// Product of all values using user provided value for `NAN`
820///
821/// Compute the product of all the values of the `input` Array after replacing any `NAN` values with `val`
822///
823/// # Parameters
824///
825/// - `input` is the input Array
826/// - `val` is the val that replaces all `NAN` values of the Array before reduction operation is
827/// performed.
828///
829/// # Return Values
830///
831/// A tuple of product result.
832///
833/// Note: For non-complex data type Arrays, second value of tuple is zero.
834pub fn product_nan_all<T>(
835    input: &Array<T>,
836    val: f64,
837) -> (
838    <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
839    <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType,
840)
841where
842    T: HasAfEnum,
843    <T as HasAfEnum>::ProductOutType: HasAfEnum,
844    <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
845{
846    let mut real: f64 = 0.0;
847    let mut imag: f64 = 0.0;
848    unsafe {
849        let err_val = af_product_nan_all(
850            &mut real as *mut c_double,
851            &mut imag as *mut c_double,
852            input.get(),
853            val,
854        );
855        HANDLE_ERROR(AfError::from(err_val));
856    }
857    (
858        <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(real),
859        <<T as HasAfEnum>::ProductOutType as HasAfEnum>::BaseType::fromf64(imag),
860    )
861}
862
// Generates a per-dimension reduction wrapper that returns both the reduced values and an
// Array of `u32` indices.
//
// Arguments, in order: doc string, generated function name, FFI routine name, and the name
// of the `HasAfEnum` associated type used as the element type of the values Array.
macro_rules! dim_ireduce_func_def {
    ($docs:expr, $wrapper:ident, $ffi:ident, $assoc:ident) => {
        #[doc = $docs]
        pub fn $wrapper<T>(input: &Array<T>, dim: i32) -> (Array<T::$assoc>, Array<u32>)
        where
            T: HasAfEnum,
            T::$assoc: HasAfEnum,
        {
            let mut values: af_array = std::ptr::null_mut();
            let mut indices: af_array = std::ptr::null_mut();
            // SAFETY: both out-pointers refer to locals that outlive the call.
            // NOTE(review): presumes `input.get()` yields a handle valid while `input`
            // is borrowed.
            unsafe {
                let status = $ffi(&mut values, &mut indices, input.get(), dim);
                HANDLE_ERROR(AfError::from(status));
            }
            (values.into(), indices.into())
        }
    };
}
883
// Expands to `pub fn imin<T>(input: &Array<T>, dim: i32) -> (Array<T::InType>, Array<u32>)`,
// a wrapper over the `af_imin` FFI routine.
dim_ireduce_func_def!("
    Find minimum value along given dimension and their corresponding indices

    # Parameters

    - `input` - Input Array
    - `dim` - Dimension along which the input Array will be reduced

    # Return Values

    A tuple of Arrays: Array minimum values and Array containing their index along the reduced dimension.
    ", imin, af_imin, InType);
896
// Expands to `pub fn imax<T>(input: &Array<T>, dim: i32) -> (Array<T::InType>, Array<u32>)`,
// a wrapper over the `af_imax` FFI routine.
dim_ireduce_func_def!("
    Find maximum value along given dimension and their corresponding indices

    # Parameters

    - `input` - Input Array
    - `dim` - Dimension along which the input Array will be reduced

    # Return Values

    A tuple of Arrays: Array maximum values and Array containing their index along the reduced dimension.
    ", imax, af_imax, InType);
909
// Generates a whole-Array reduction wrapper that returns a pair of scalars plus a `u32`
// index.
//
// Arguments, in order: doc string, generated function name, FFI routine name, and the name
// of the `HasAfEnum` associated type whose `BaseType` is the type of the two scalars.
// The FFI routine fills two `f64` slots and one unsigned index slot.
macro_rules! all_ireduce_func_def {
    ($docs:expr, $wrapper:ident, $ffi:ident, $assoc:ident) => {
        #[doc = $docs]
        pub fn $wrapper<T>(
            input: &Array<T>,
        ) -> (
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType,
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType,
            u32,
        )
        where
            T: HasAfEnum,
            <T as HasAfEnum>::$assoc: HasAfEnum,
            <<T as HasAfEnum>::$assoc as HasAfEnum>::BaseType: HasAfEnum + Fromf64,
        {
            let (mut re, mut im) = (0.0_f64, 0.0_f64);
            let mut index: u32 = 0;
            // SAFETY: all three out-pointers refer to locals that outlive the call.
            // NOTE(review): presumes `input.get()` yields a handle valid while `input`
            // is borrowed.
            unsafe {
                let status = $ffi(&mut re, &mut im, &mut index, input.get());
                HANDLE_ERROR(AfError::from(status));
            }
            // The target type of each conversion is inferred from the return type.
            (Fromf64::fromf64(re), Fromf64::fromf64(im), index)
        }
    };
}
942
943all_ireduce_func_def!(
944    "
945    Find minimum and it's index in the whole Array
946
947    # Parameters
948
949    `input` - Input Array
950
951    # Return Values
952
953    A triplet with
954
955      * minimum element of Array in the first component.
956      * second component of value zero if Array is of non-complex type.
957      * index of minimum element in the third component.
958    ",
959    imin_all,
960    af_imin_all,
961    InType
962);
963all_ireduce_func_def!(
964    "
965    Find maximum and it's index in the whole Array
966
967    # Parameters
968
969    `input` - Input Array
970
971    # Return Values
972
973    A triplet with
974
975      - maximum element of Array in the first component.
976      - second component of value zero if Array is of non-complex type.
977      - index of maximum element in the third component.
978    ",
979    imax_all,
980    af_imax_all,
981    InType
982);
983
984/// Locate the indices of non-zero elements.
985///
986/// The locations are provided by flattening the input into a linear array.
987///
988/// # Parameters
989///
990/// - `input` - Input Array
991///
992/// # Return Values
993///
994/// Array of indices where the input Array has non-zero values.
995pub fn locate<T: HasAfEnum>(input: &Array<T>) -> Array<u32> {
996    unsafe {
997        let mut temp: af_array = std::ptr::null_mut();
998        let err_val = af_where(&mut temp as *mut af_array, input.get());
999        HANDLE_ERROR(AfError::from(err_val));
1000        temp.into()
1001    }
1002}
1003
1004/// Sort the values in input Arrays
1005///
1006/// Sort an multidimensional Array along a given dimension
1007///
1008/// # Parameters
1009///
1010/// - `input` - Input Array
1011/// - `dim` - Dimension along which to sort
1012/// - `ascending` - Sorted output will have ascending values if
1013///                 ```True``` and descending order otherwise.
1014///
1015/// # Return Values
1016///
1017/// Sorted Array.
1018pub fn sort<T>(input: &Array<T>, dim: u32, ascending: bool) -> Array<T>
1019where
1020    T: HasAfEnum + RealNumber,
1021{
1022    unsafe {
1023        let mut temp: af_array = std::ptr::null_mut();
1024        let err_val = af_sort(&mut temp as *mut af_array, input.get(), dim, ascending);
1025        HANDLE_ERROR(AfError::from(err_val));
1026        temp.into()
1027    }
1028}
1029
1030/// Sort the values in input Arrays
1031///
1032/// # Parameters
1033///
1034/// - `input` - Input Array
1035/// - `dim` - Dimension along which to sort
1036/// - `ascending` - Sorted output will have ascending values if
1037///                 ```True``` and descending order otherwise.
1038///
1039/// # Return Values
1040///
1041/// A tuple of Arrays.
1042///
1043/// The first Array contains the keys based on sorted values.
1044///
1045/// The second Array contains the original indices of the sorted values.
1046pub fn sort_index<T>(input: &Array<T>, dim: u32, ascending: bool) -> (Array<T>, Array<u32>)
1047where
1048    T: HasAfEnum + RealNumber,
1049{
1050    unsafe {
1051        let mut temp: af_array = std::ptr::null_mut();
1052        let mut idx: af_array = std::ptr::null_mut();
1053        let err_val = af_sort_index(
1054            &mut temp as *mut af_array,
1055            &mut idx as *mut af_array,
1056            input.get(),
1057            dim,
1058            ascending,
1059        );
1060        HANDLE_ERROR(AfError::from(err_val));
1061        (temp.into(), idx.into())
1062    }
1063}
1064
1065/// Sort the values in input Arrays
1066///
1067/// Sort an multidimensional Array based on keys
1068///
1069/// # Parameters
1070///
1071/// - `keys` - Array with key values
1072/// - `vals` - Array with input values
1073/// - `dim` - Dimension along which to sort
1074/// - `ascending` - Sorted output will have ascending values if ```True``` and descending order otherwise.
1075///
1076/// # Return Values
1077///
1078/// A tuple of Arrays.
1079///
1080/// The first Array contains the keys based on sorted values.
1081///
1082/// The second Array contains the sorted values.
1083pub fn sort_by_key<K, V>(
1084    keys: &Array<K>,
1085    vals: &Array<V>,
1086    dim: u32,
1087    ascending: bool,
1088) -> (Array<K>, Array<V>)
1089where
1090    K: HasAfEnum + RealNumber,
1091    V: HasAfEnum,
1092{
1093    unsafe {
1094        let mut temp: af_array = std::ptr::null_mut();
1095        let mut temp2: af_array = std::ptr::null_mut();
1096        let err_val = af_sort_by_key(
1097            &mut temp as *mut af_array,
1098            &mut temp2 as *mut af_array,
1099            keys.get(),
1100            vals.get(),
1101            dim,
1102            ascending,
1103        );
1104        HANDLE_ERROR(AfError::from(err_val));
1105        (temp.into(), temp2.into())
1106    }
1107}
1108
1109/// Find unique values from a Set
1110///
1111/// # Parameters
1112///
1113/// - `input` - Input Array
1114/// - `is_sorted` - is a boolean variable. If ```True``
1115///                 indicates, the `input` Array is sorted.
1116///
1117/// # Return Values
1118///
1119/// An Array of unique values from the input Array.
1120pub fn set_unique<T>(input: &Array<T>, is_sorted: bool) -> Array<T>
1121where
1122    T: HasAfEnum + RealNumber,
1123{
1124    unsafe {
1125        let mut temp: af_array = std::ptr::null_mut();
1126        let err_val = af_set_unique(&mut temp as *mut af_array, input.get(), is_sorted);
1127        HANDLE_ERROR(AfError::from(err_val));
1128        temp.into()
1129    }
1130}
1131
1132/// Find union of two sets
1133///
1134/// # Parameters
1135///
1136/// - `first` is one of the input sets
1137/// - `second` is the other of the input sets
1138/// - `is_unique` is a boolean value indicates if the input sets are unique
1139///
1140/// # Return Values
1141///
1142/// An Array with union of the input sets
1143pub fn set_union<T>(first: &Array<T>, second: &Array<T>, is_unique: bool) -> Array<T>
1144where
1145    T: HasAfEnum + RealNumber,
1146{
1147    unsafe {
1148        let mut temp: af_array = std::ptr::null_mut();
1149        let err_val = af_set_union(
1150            &mut temp as *mut af_array,
1151            first.get(),
1152            second.get(),
1153            is_unique,
1154        );
1155        HANDLE_ERROR(AfError::from(err_val));
1156        temp.into()
1157    }
1158}
1159
1160/// Find intersection of two sets
1161///
1162/// # Parameters
1163///
1164/// - `first` is one of the input sets
1165/// - `second` is the other of the input sets
1166/// - `is_unique` is a boolean value indicates if the input sets are unique
1167///
1168/// # Return Values
1169///
1170/// An Array with intersection of the input sets
1171pub fn set_intersect<T>(first: &Array<T>, second: &Array<T>, is_unique: bool) -> Array<T>
1172where
1173    T: HasAfEnum + RealNumber,
1174{
1175    unsafe {
1176        let mut temp: af_array = std::ptr::null_mut();
1177        let err_val = af_set_intersect(
1178            &mut temp as *mut af_array,
1179            first.get(),
1180            second.get(),
1181            is_unique,
1182        );
1183        HANDLE_ERROR(AfError::from(err_val));
1184        temp.into()
1185    }
1186}
1187
1188/// Generalized scan
1189///
1190/// # Parameters
1191///
1192/// - `input` is the data on which scan is to be performed
1193/// - `dim` is the dimension along which scan operation is to be performed
1194/// - `op` takes value of [BinaryOp](./enum.BinaryOp.html) enum indicating
1195///    the type of scan operation
1196/// - `inclusive` says if inclusive/exclusive scan is to be performed
1197///
1198/// # Return Values
1199///
1200/// Output Array of scanned input
1201pub fn scan<T>(
1202    input: &Array<T>,
1203    dim: i32,
1204    op: BinaryOp,
1205    inclusive: bool,
1206) -> Array<T::AggregateOutType>
1207where
1208    T: HasAfEnum,
1209    T::AggregateOutType: HasAfEnum,
1210{
1211    unsafe {
1212        let mut temp: af_array = std::ptr::null_mut();
1213        let err_val = af_scan(
1214            &mut temp as *mut af_array,
1215            input.get(),
1216            dim,
1217            op as u32,
1218            inclusive,
1219        );
1220        HANDLE_ERROR(AfError::from(err_val));
1221        temp.into()
1222    }
1223}
1224
1225/// Generalized scan by key
1226///
1227/// # Parameters
1228///
1229/// - `key` is the key Array
1230/// - `input` is the data on which scan is to be performed
1231/// - `dim` is the dimension along which scan operation is to be performed
1232/// - `op` takes value of [BinaryOp](./enum.BinaryOp.html) enum indicating
1233///    the type of scan operation
1234/// - `inclusive` says if inclusive/exclusive scan is to be performed
1235///
1236/// # Return Values
1237///
1238/// Output Array of scanned input
1239pub fn scan_by_key<K, V>(
1240    key: &Array<K>,
1241    input: &Array<V>,
1242    dim: i32,
1243    op: BinaryOp,
1244    inclusive: bool,
1245) -> Array<V::AggregateOutType>
1246where
1247    V: HasAfEnum,
1248    V::AggregateOutType: HasAfEnum,
1249    K: HasAfEnum + Scanable,
1250{
1251    unsafe {
1252        let mut temp: af_array = std::ptr::null_mut();
1253        let err_val = af_scan_by_key(
1254            &mut temp as *mut af_array,
1255            key.get(),
1256            input.get(),
1257            dim,
1258            op as u32,
1259            inclusive,
1260        );
1261        HANDLE_ERROR(AfError::from(err_val));
1262        temp.into()
1263    }
1264}
1265
// Generates a public key-based reduction function that forwards to an FFI
// function.
//
// Macro arguments:
// - `$brief_str` - summary text placed at the top of the generated docs
// - `$ex_str`    - text appended at the end of the generated docs
// - `$fn_name`   - name of the generated public function
// - `$ffi_name`  - FFI function the generated function calls
// - `$out_type`  - element type of the returned values Array
macro_rules! dim_reduce_by_key_func_def {
    ($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
        #[doc=$brief_str]
        ///
        /// # Parameters
        ///
        /// - `keys` - key Array
        /// - `vals` - value Array
        /// - `dim` - Dimension along which the input Array is reduced
        ///
        /// # Return Values
        ///
        /// Tuple of Arrays, with output keys and values after reduction
        ///
        #[doc=$ex_str]
        pub fn $fn_name<KeyType, ValueType>(
            keys: &Array<KeyType>,
            vals: &Array<ValueType>,
            dim: i32,
        ) -> (Array<KeyType>, Array<$out_type>)
        where
            KeyType: ReduceByKeyInput,
            ValueType: HasAfEnum,
            $out_type: HasAfEnum,
        {
            // Null-initialised output handles, written by the FFI call.
            let mut reduced_keys: af_array = std::ptr::null_mut();
            let mut reduced_vals: af_array = std::ptr::null_mut();
            unsafe {
                let status = $ffi_name(
                    &mut reduced_keys,
                    &mut reduced_vals,
                    keys.get(),
                    vals.get(),
                    dim,
                );
                HANDLE_ERROR(AfError::from(status));
                (reduced_keys.into(), reduced_vals.into())
            }
        }
    };
}
1301
// Expands to the public function `all_true_by_key`, which forwards to the
// `af_all_true_by_key` FFI function; the reduced values Array has element
// type `ValueType::AggregateOutType`. The first string is the summary text of
// the generated documentation, the second is appended as its examples.
dim_reduce_by_key_func_def!(
    "
Key based AND of elements along a given dimension

All positive non-zero values are considered true, while negative and zero
values are considered as false.
",
    "
# Examples
```rust
use arrayfire::{Dim4, print, randu, all_true_by_key};
let dims = Dim4::new(&[5, 3, 1, 1]);
let vals = randu::<f32>(dims);
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
print(&vals);
print(&keys);
let (out_keys, out_vals) = all_true_by_key(&keys, &vals, 0);
print(&out_keys);
print(&out_vals);
```
",
    all_true_by_key,
    af_all_true_by_key,
    ValueType::AggregateOutType
);
1327
// Expands to the public function `any_true_by_key`, which forwards to the
// `af_any_true_by_key` FFI function; the reduced values Array has element
// type `ValueType::AggregateOutType`. The first string is the summary text of
// the generated documentation, the second is appended as its examples.
dim_reduce_by_key_func_def!(
    "
Key based OR of elements along a given dimension

All positive non-zero values are considered true, while negative and zero
values are considered as false.
",
    "
# Examples
```rust
use arrayfire::{Dim4, print, randu, any_true_by_key};
let dims = Dim4::new(&[5, 3, 1, 1]);
let vals = randu::<f32>(dims);
let keys = randu::<u32>(Dim4::new(&[5, 1, 1, 1]));
print(&vals);
print(&keys);
let (out_keys, out_vals) = any_true_by_key(&keys, &vals, 0);
print(&out_keys);
print(&out_vals);
```
",
    any_true_by_key,
    af_any_true_by_key,
    ValueType::AggregateOutType
);
1353
// Expands to the public function `count_by_key`, which forwards to the
// `af_count_by_key` FFI function. No examples text is supplied (empty string).
dim_reduce_by_key_func_def!(
    "Find total count of elements with similar keys along a given dimension",
    "",
    count_by_key,
    af_count_by_key,
    ValueType::AggregateOutType
);
1361
// Expands to the public function `max_by_key`, which forwards to the
// `af_max_by_key` FFI function. No examples text is supplied (empty string).
dim_reduce_by_key_func_def!(
    "Find maximum among values of similar keys along a given dimension",
    "",
    max_by_key,
    af_max_by_key,
    ValueType::AggregateOutType
);
1369
// Expands to the public function `min_by_key`, which forwards to the
// `af_min_by_key` FFI function. No examples text is supplied (empty string).
dim_reduce_by_key_func_def!(
    "Find minimum among values of similar keys along a given dimension",
    "",
    min_by_key,
    af_min_by_key,
    ValueType::AggregateOutType
);
1377
// Expands to the public function `product_by_key`, which forwards to the
// `af_product_by_key` FFI function. Unlike the other key-based reductions in
// this group, the reduced values use `ValueType::ProductOutType`.
dim_reduce_by_key_func_def!(
    "Find product of all values with similar keys along a given dimension",
    "",
    product_by_key,
    af_product_by_key,
    ValueType::ProductOutType
);
1385
// Expands to the public function `sum_by_key`, which forwards to the
// `af_sum_by_key` FFI function. No examples text is supplied (empty string).
dim_reduce_by_key_func_def!(
    "Find sum of all values with similar keys along a given dimension",
    "",
    sum_by_key,
    af_sum_by_key,
    ValueType::AggregateOutType
);
1393
// Generates a public key-based reduction function that takes an extra NaN
// replacement value and forwards to an FFI function.
//
// Macro arguments:
// - `$brief_str` - summary text placed at the top of the generated docs
// - `$ex_str`    - text appended at the end of the generated docs
// - `$fn_name`   - name of the generated public function
// - `$ffi_name`  - FFI function the generated function calls
// - `$out_type`  - element type of the returned values Array
macro_rules! dim_reduce_by_key_nan_func_def {
    ($brief_str: expr, $ex_str: expr, $fn_name: ident, $ffi_name: ident, $out_type: ty) => {
        #[doc=$brief_str]
        ///
        /// This version of the reduction can replace all NaN values in the
        /// input with a user provided value before performing the reduction
        /// operation.
        ///
        /// # Parameters
        ///
        /// - `keys` - key Array
        /// - `vals` - value Array
        /// - `dim` - Dimension along which the input Array is reduced
        /// - `replace_value` - value used in place of NaN values in the input
        ///
        /// # Return Values
        ///
        /// Tuple of Arrays, with output keys and values after reduction
        ///
        #[doc=$ex_str]
        pub fn $fn_name<KeyType, ValueType>(
            keys: &Array<KeyType>,
            vals: &Array<ValueType>,
            dim: i32,
            replace_value: f64,
        ) -> (Array<KeyType>, Array<$out_type>)
        where
            KeyType: ReduceByKeyInput,
            ValueType: HasAfEnum,
            $out_type: HasAfEnum,
        {
            unsafe {
                // Null-initialised output handles, written by the FFI call.
                let mut out_keys: af_array = std::ptr::null_mut();
                let mut out_vals: af_array = std::ptr::null_mut();
                let err_val = $ffi_name(
                    &mut out_keys as *mut af_array,
                    &mut out_vals as *mut af_array,
                    keys.get(),
                    vals.get(),
                    dim,
                    replace_value,
                );
                HANDLE_ERROR(AfError::from(err_val));
                (out_keys.into(), out_vals.into())
            }
        }
    };
}
1433
// Expands to the public function `sum_by_key_nan`, which forwards to the
// `af_sum_by_key_nan` FFI function and takes an extra `replace_value`
// argument. No examples text is supplied (empty string).
dim_reduce_by_key_nan_func_def!(
    "Compute sum of all values with similar keys along a given dimension",
    "",
    sum_by_key_nan,
    af_sum_by_key_nan,
    ValueType::AggregateOutType
);
1441
// Expands to the public function `product_by_key_nan`, which forwards to the
// `af_product_by_key_nan` FFI function and takes an extra `replace_value`
// argument. The reduced values use `ValueType::ProductOutType`.
dim_reduce_by_key_nan_func_def!(
    "Compute product of all values with similar keys along a given dimension",
    "",
    product_by_key_nan,
    af_product_by_key_nan,
    ValueType::ProductOutType
);
1449
1450/// Max reduction along given axis as per ragged lengths provided
1451///
1452/// # Parameters
1453///
1454/// - `input` contains the input values to be reduced
1455/// - `ragged_len` array containing number of elements to use when reducing along `dim`
1456/// - `dim` is the dimension along which the max operation occurs
1457///
1458/// # Return Values
1459///
1460/// Tuple of Arrays:
1461/// - First element: An Array containing the maximum ragged values in `input` along `dim`
1462///                  according to `ragged_len`
1463/// - Second Element: An Array containing the locations of the maximum ragged values in
1464///                   `input` along `dim` according to `ragged_len`
1465///
1466/// # Examples
1467/// ```rust
1468/// use arrayfire::{Array, dim4, print, randu, max_ragged};
1469/// let vals: [f32; 6] = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
1470/// let rlens: [u32; 2] = [9, 2];
1471/// let varr = Array::new(&vals, dim4![3, 2]);
1472/// let rarr = Array::new(&rlens, dim4![1, 2]);
1473/// print(&varr);
1474/// // 1 4
1475/// // 2 5
1476/// // 3 6
1477/// print(&rarr); // numbers of elements to participate in reduction along given axis
1478/// // 9 2
1479/// let (out, idx) = max_ragged(&varr, &rarr, 0);
1480/// print(&out);
1481/// // 3 5
1482/// print(&idx);
1483/// // 2 1 //Since 3 is max element for given length 9 along first column
1484///        //Since 5 is max element for given length 2 along second column
1485/// ```
1486pub fn max_ragged<T>(
1487    input: &Array<T>,
1488    ragged_len: &Array<u32>,
1489    dim: i32,
1490) -> (Array<T::InType>, Array<u32>)
1491where
1492    T: HasAfEnum,
1493    T::InType: HasAfEnum,
1494{
1495    unsafe {
1496        let mut out_vals: af_array = std::ptr::null_mut();
1497        let mut out_idxs: af_array = std::ptr::null_mut();
1498        let err_val = af_max_ragged(
1499            &mut out_vals as *mut af_array,
1500            &mut out_idxs as *mut af_array,
1501            input.get(),
1502            ragged_len.get(),
1503            dim,
1504        );
1505        HANDLE_ERROR(AfError::from(err_val));
1506        (out_vals.into(), out_idxs.into())
1507    }
1508}
1509
#[cfg(test)]
mod tests {
    use super::super::core::c32;
    use super::{imax_all, imin_all, product_nan_all, sum_all, sum_nan_all};
    use crate::core::set_device;
    use crate::randu;

    // Smoke test: calls the whole-array reductions on a complex and a bool
    // input and prints the results. Nothing is asserted on the values.
    #[test]
    fn all_reduce_api() {
        set_device(0);
        let a = randu!(c32; 10, 10);
        println!("Reduction of complex f32 matrix: {:?}", sum_all(&a));

        let b = randu!(bool; 10, 10);
        println!("reduction of bool matrix: {:?}", sum_all(&b));

        println!(
            "reduction of complex f32 matrix after replacing nan with {}: {:?}",
            1.0,
            product_nan_all(&a, 1.0)
        );

        println!(
            "reduction of bool matrix after replacing nan with {}: {:?}",
            0.0,
            sum_nan_all(&b, 0.0)
        );
    }

    // Smoke test: calls the whole-array indexed min/max reductions and prints
    // the results. Nothing is asserted on the values.
    #[test]
    fn all_ireduce_api() {
        set_device(0);
        let a = randu!(c32; 10);
        println!(
            "Indexed min reduction of complex f32 array: {:?}",
            imin_all(&a)
        );

        // `b` holds u32 values; the label used to say "bool matrix".
        let b = randu!(u32; 10);
        println!("Indexed max reduction of u32 array: {:?}", imax_all(&b));
    }
}