1use super::core::{
2 af_array, dim_t, AfError, Array, CovarianceComputable, HasAfEnum, MedianComputable,
3 RealFloating, RealNumber, TopkFn, VarianceBias, HANDLE_ERROR,
4};
5
6use libc::{c_double, c_int, c_uint};
7
8extern "C" {
9 fn af_mean(out: *mut af_array, arr: af_array, dim: dim_t) -> c_int;
10 fn af_median(out: *mut af_array, arr: af_array, dim: dim_t) -> c_int;
11
12 fn af_mean_weighted(out: *mut af_array, arr: af_array, wts: af_array, dim: dim_t) -> c_int;
13 fn af_var_weighted(out: *mut af_array, arr: af_array, wts: af_array, dim: dim_t) -> c_int;
14
15 fn af_mean_all(real: *mut c_double, imag: *mut c_double, arr: af_array) -> c_int;
16 fn af_median_all(real: *mut c_double, imag: *mut c_double, arr: af_array) -> c_int;
17
18 fn af_mean_all_weighted(
19 real: *mut c_double,
20 imag: *mut c_double,
21 arr: af_array,
22 wts: af_array,
23 ) -> c_int;
24 fn af_var_all_weighted(
25 real: *mut c_double,
26 imag: *mut c_double,
27 arr: af_array,
28 wts: af_array,
29 ) -> c_int;
30
31 fn af_corrcoef(real: *mut c_double, imag: *mut c_double, X: af_array, Y: af_array) -> c_int;
32 fn af_topk(
33 vals: *mut af_array,
34 idxs: *mut af_array,
35 arr: af_array,
36 k: c_int,
37 dim: c_int,
38 order: c_uint,
39 ) -> c_int;
40
41 fn af_meanvar(
42 mean: *mut af_array,
43 var: *mut af_array,
44 input: af_array,
45 weights: af_array,
46 bias: c_uint,
47 dim: dim_t,
48 ) -> c_int;
49 fn af_var_v2(out: *mut af_array, arr: af_array, bias_kind: c_uint, dim: dim_t) -> c_int;
50 fn af_cov_v2(out: *mut af_array, X: af_array, Y: af_array, bias_kind: c_uint) -> c_int;
51 fn af_stdev_v2(out: *mut af_array, arr: af_array, bias_kind: c_uint, dim: dim_t) -> c_int;
52 fn af_var_all_v2(
53 real: *mut c_double,
54 imag: *mut c_double,
55 arr: af_array,
56 bias_kind: c_uint,
57 ) -> c_int;
58 fn af_stdev_all_v2(
59 real: *mut c_double,
60 imag: *mut c_double,
61 arr: af_array,
62 bias_kind: c_uint,
63 ) -> c_int;
64}
65
66pub fn median<T>(input: &Array<T>, dim: i64) -> Array<T>
78where
79 T: HasAfEnum + MedianComputable,
80{
81 unsafe {
82 let mut temp: af_array = std::ptr::null_mut();
83 let err_val = af_median(&mut temp as *mut af_array, input.get(), dim);
84 HANDLE_ERROR(AfError::from(err_val));
85 temp.into()
86 }
87}
88
89macro_rules! stat_func_def {
90 ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
91 #[doc=$doc_str]
92 pub fn $fn_name<T>(input: &Array<T>, dim: i64) -> Array<T::MeanOutType>
103 where
104 T: HasAfEnum,
105 T::MeanOutType: HasAfEnum,
106 {
107 unsafe {
108 let mut temp: af_array = std::ptr::null_mut();
109 let err_val = $ffi_fn(&mut temp as *mut af_array, input.get(), dim);
110 HANDLE_ERROR(AfError::from(err_val));
111 temp.into()
112 }
113 }
114 };
115}
116
117stat_func_def!("Mean along specified dimension", mean, af_mean);
118
119macro_rules! stat_wtd_func_def {
120 ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
121 #[doc=$doc_str]
122 pub fn $fn_name<T, W>(
134 input: &Array<T>,
135 weights: &Array<W>,
136 dim: i64,
137 ) -> Array<T::MeanOutType>
138 where
139 T: HasAfEnum,
140 T::MeanOutType: HasAfEnum,
141 W: HasAfEnum + RealFloating,
142 {
143 unsafe {
144 let mut temp: af_array = std::ptr::null_mut();
145 let err_val = $ffi_fn(&mut temp as *mut af_array,input.get(), weights.get(), dim);
146 HANDLE_ERROR(AfError::from(err_val));
147 temp.into()
148 }
149 }
150 };
151}
152
153stat_wtd_func_def!(
154 "Weighted mean along specified dimension",
155 mean_weighted,
156 af_mean_weighted
157);
158stat_wtd_func_def!(
159 "Weight variance along specified dimension",
160 var_weighted,
161 af_var_weighted
162);
163
164pub fn var_v2<T>(arr: &Array<T>, bias_kind: VarianceBias, dim: i64) -> Array<T::MeanOutType>
178where
179 T: HasAfEnum,
180 T::MeanOutType: HasAfEnum,
181{
182 unsafe {
183 let mut temp: af_array = std::ptr::null_mut();
184 let err_val = af_var_v2(
185 &mut temp as *mut af_array,
186 arr.get(),
187 bias_kind as c_uint,
188 dim,
189 );
190 HANDLE_ERROR(AfError::from(err_val));
191 temp.into()
192 }
193}
194
195#[deprecated(since = "3.8.0", note = "Please use var_v2 API")]
207pub fn var<T>(arr: &Array<T>, isbiased: bool, dim: i64) -> Array<T::MeanOutType>
208where
209 T: HasAfEnum,
210 T::MeanOutType: HasAfEnum,
211{
212 var_v2(
213 arr,
214 if isbiased {
215 VarianceBias::SAMPLE
216 } else {
217 VarianceBias::POPULATION
218 },
219 dim,
220 )
221}
222
223pub fn cov_v2<T>(x: &Array<T>, y: &Array<T>, bias_kind: VarianceBias) -> Array<T::MeanOutType>
237where
238 T: HasAfEnum + CovarianceComputable,
239 T::MeanOutType: HasAfEnum,
240{
241 unsafe {
242 let mut temp: af_array = std::ptr::null_mut();
243 let err_val = af_cov_v2(
244 &mut temp as *mut af_array,
245 x.get(),
246 y.get(),
247 bias_kind as c_uint,
248 );
249 HANDLE_ERROR(AfError::from(err_val));
250 temp.into()
251 }
252}
253
254#[deprecated(since = "3.8.0", note = "Please use cov_v2 API")]
266pub fn cov<T>(x: &Array<T>, y: &Array<T>, isbiased: bool) -> Array<T::MeanOutType>
267where
268 T: HasAfEnum + CovarianceComputable,
269 T::MeanOutType: HasAfEnum,
270{
271 cov_v2(
272 x,
273 y,
274 if isbiased {
275 VarianceBias::SAMPLE
276 } else {
277 VarianceBias::POPULATION
278 },
279 )
280}
281
282pub fn var_all_v2<T: HasAfEnum>(input: &Array<T>, bias_kind: VarianceBias) -> (f64, f64) {
295 let mut real: f64 = 0.0;
296 let mut imag: f64 = 0.0;
297 unsafe {
298 let err_val = af_var_all_v2(
299 &mut real as *mut c_double,
300 &mut imag as *mut c_double,
301 input.get(),
302 bias_kind as c_uint,
303 );
304 HANDLE_ERROR(AfError::from(err_val));
305 }
306 (real, imag)
307}
308
309#[deprecated(since = "3.8.0", note = "Please use var_all_v2 API")]
320pub fn var_all<T: HasAfEnum>(input: &Array<T>, isbiased: bool) -> (f64, f64) {
321 var_all_v2(
322 input,
323 if isbiased {
324 VarianceBias::SAMPLE
325 } else {
326 VarianceBias::POPULATION
327 },
328 )
329}
330
331macro_rules! stat_all_func_def {
332 ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
333 #[doc=$doc_str]
334 pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> (f64, f64) {
343 let mut real: f64 = 0.0;
344 let mut imag: f64 = 0.0;
345 unsafe {
346 let err_val = $ffi_fn(
347 &mut real as *mut c_double,
348 &mut imag as *mut c_double,
349 input.get(),
350 );
351 HANDLE_ERROR(AfError::from(err_val));
352 }
353 (real, imag)
354 }
355 };
356}
357
358stat_all_func_def!("Compute mean of all data", mean_all, af_mean_all);
359
360pub fn median_all<T>(input: &Array<T>) -> (f64, f64)
370where
371 T: HasAfEnum + MedianComputable,
372{
373 let mut real: f64 = 0.0;
374 let mut imag: f64 = 0.0;
375 unsafe {
376 let err_val = af_median_all(
377 &mut real as *mut c_double,
378 &mut imag as *mut c_double,
379 input.get(),
380 );
381 HANDLE_ERROR(AfError::from(err_val));
382 }
383 (real, imag)
384}
385
386macro_rules! stat_wtd_all_func_def {
387 ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
388 #[doc=$doc_str]
389 pub fn $fn_name<T, W>(input: &Array<T>, weights: &Array<W>) -> (f64, f64)
399 where
400 T: HasAfEnum,
401 W: HasAfEnum + RealFloating,
402 {
403 let mut real: f64 = 0.0;
404 let mut imag: f64 = 0.0;
405 unsafe {
406 let err_val = $ffi_fn(
407 &mut real as *mut c_double,
408 &mut imag as *mut c_double,
409 input.get(),
410 weights.get(),
411 );
412 HANDLE_ERROR(AfError::from(err_val));
413 }
414 (real, imag)
415 }
416 };
417}
418
419stat_wtd_all_func_def!(
420 "Compute weighted mean of all data",
421 mean_all_weighted,
422 af_mean_all_weighted
423);
424stat_wtd_all_func_def!(
425 "Compute weighted variance of all data",
426 var_all_weighted,
427 af_var_all_weighted
428);
429
430pub fn corrcoef<T>(x: &Array<T>, y: &Array<T>) -> (f64, f64)
440where
441 T: HasAfEnum + RealNumber,
442{
443 let mut real: f64 = 0.0;
444 let mut imag: f64 = 0.0;
445 unsafe {
446 let err_val = af_corrcoef(
447 &mut real as *mut c_double,
448 &mut imag as *mut c_double,
449 x.get(),
450 y.get(),
451 );
452 HANDLE_ERROR(AfError::from(err_val));
453 }
454 (real, imag)
455}
456
457pub fn topk<T>(input: &Array<T>, k: u32, dim: i32, order: TopkFn) -> (Array<T>, Array<u32>)
481where
482 T: HasAfEnum,
483{
484 unsafe {
485 let mut t0: af_array = std::ptr::null_mut();
486 let mut t1: af_array = std::ptr::null_mut();
487 let err_val = af_topk(
488 &mut t0 as *mut af_array,
489 &mut t1 as *mut af_array,
490 input.get(),
491 k as c_int,
492 dim as c_int,
493 order as c_uint,
494 );
495 HANDLE_ERROR(AfError::from(err_val));
496 (t0.into(), t1.into())
497 }
498}
499
500pub fn meanvar<T, W>(
517 input: &Array<T>,
518 weights: &Array<W>,
519 bias: VarianceBias,
520 dim: i64,
521) -> (Array<T::MeanOutType>, Array<T::MeanOutType>)
522where
523 T: HasAfEnum,
524 T::MeanOutType: HasAfEnum,
525 W: HasAfEnum + RealFloating,
526{
527 unsafe {
528 let mut mean: af_array = std::ptr::null_mut();
529 let mut var: af_array = std::ptr::null_mut();
530 let err_val = af_meanvar(
531 &mut mean as *mut af_array,
532 &mut var as *mut af_array,
533 input.get(),
534 weights.get(),
535 bias as c_uint,
536 dim,
537 );
538 HANDLE_ERROR(AfError::from(err_val));
539 (mean.into(), var.into())
540 }
541}
542
543pub fn stdev_v2<T>(input: &Array<T>, bias_kind: VarianceBias, dim: i64) -> Array<T::MeanOutType>
558where
559 T: HasAfEnum,
560 T::MeanOutType: HasAfEnum,
561{
562 unsafe {
563 let mut temp: af_array = std::ptr::null_mut();
564 let err_val = af_stdev_v2(
565 &mut temp as *mut af_array,
566 input.get(),
567 bias_kind as c_uint,
568 dim,
569 );
570 HANDLE_ERROR(AfError::from(err_val));
571 temp.into()
572 }
573}
574
575#[deprecated(since = "3.8.0", note = "Please use stdev_v2 API")]
587pub fn stdev<T>(input: &Array<T>, dim: i64) -> Array<T::MeanOutType>
588where
589 T: HasAfEnum,
590 T::MeanOutType: HasAfEnum,
591{
592 stdev_v2(input, VarianceBias::POPULATION, dim)
593}
594
595pub fn stdev_all_v2<T: HasAfEnum>(input: &Array<T>, bias_kind: VarianceBias) -> (f64, f64) {
608 let mut real: f64 = 0.0;
609 let mut imag: f64 = 0.0;
610 unsafe {
611 let err_val = af_stdev_all_v2(
612 &mut real as *mut c_double,
613 &mut imag as *mut c_double,
614 input.get(),
615 bias_kind as c_uint,
616 );
617 HANDLE_ERROR(AfError::from(err_val));
618 }
619 (real, imag)
620}
621
622pub fn stdev_all<T: HasAfEnum>(input: &Array<T>) -> (f64, f64) {
632 stdev_all_v2(input, VarianceBias::POPULATION)
633}