// arrayfire/ml/mod.rs
1use super::core::{
2 af_array, dim_t, AfError, Array, ConvGradientType, Dim4, HasAfEnum, RealFloating, HANDLE_ERROR,
3};
4
5use libc::{c_int, c_uint};
6
// Raw FFI bindings into the ArrayFire C library. Signatures must match the C
// headers exactly; each function writes its result handle through `out` and
// returns an ArrayFire error code (0 == success) that callers convert via
// `AfError::from`.
extern "C" {
    // Forward-pass 2D convolution (ML formulation). Each (count, pointer) pair
    // passes a per-dimension parameter list: `strides`, `paddings`, `dilations`.
    fn af_convolve2_nn(
        out: *mut af_array,
        signal: af_array,
        filter: af_array,
        stride_dims: c_uint,
        strides: *const dim_t,
        padding_dim: c_uint,
        paddings: *const dim_t,
        dilation_dim: c_uint,
        dilations: *const dim_t,
    ) -> c_int;

    // Backward-pass gradient of the 2D convolution above. `grad_type` selects
    // which gradient (w.r.t. signal/filter/etc.) is produced; it is the integer
    // value of a `ConvGradientType` variant. The stride/padding/dilation
    // arguments must match those used in the original forward call.
    fn af_convolve2_gradient_nn(
        out: *mut af_array,
        incoming_gradient: af_array,
        original_signal: af_array,
        original_filter: af_array,
        convolved_output: af_array,
        stride_dims: c_uint,
        strides: *const dim_t,
        padding_dims: c_uint,
        paddings: *const dim_t,
        dilation_dims: c_uint,
        dilations: *const dim_t,
        grad_type: c_uint,
    ) -> c_int;
}
35
36/// Convolution Integral for two dimensional data
37///
38/// This version of convolution is consistent with the machine learning formulation
39/// that will spatially convolve a filter on 2-dimensions against a signal. Multiple
40/// signals and filters can be batched against each other. Furthermore, the signals
41/// and filters can be multi-dimensional however their dimensions must match. Usually,
42/// this is the forward pass convolution in ML
43///
44/// Example:
45///
46/// Signals with dimensions: d0 x d1 x d2 x Ns
47///
48/// Filters with dimensions: d0 x d1 x d2 x Nf
49///
50/// Resulting Convolution: d0 x d1 x Nf x Ns
51///
52/// # Parameters
53///
54/// - `signal` is the input signal
55/// - `filter` is convolution filter
56/// - `strides` are distance between consecutive elements along each dimension for original convolution
57/// - `padding` specifies padding width along each dimension for original convolution
58/// - `dilation` specifies filter dilation along each dimension for original convolution
59///
60/// # Return Values
61///
62/// Convolved Array
63pub fn convolve2_nn<T>(
64 signal: &Array<T>,
65 filter: &Array<T>,
66 strides: Dim4,
67 padding: Dim4,
68 dilation: Dim4,
69) -> Array<T>
70where
71 T: HasAfEnum + RealFloating,
72{
73 unsafe {
74 let mut temp: af_array = std::ptr::null_mut();
75 let err_val = af_convolve2_nn(
76 &mut temp as *mut af_array,
77 signal.get(),
78 filter.get(),
79 strides.ndims() as c_uint,
80 strides.get().as_ptr() as *const dim_t,
81 padding.ndims() as c_uint,
82 padding.get().as_ptr() as *const dim_t,
83 dilation.ndims() as c_uint,
84 dilation.get().as_ptr() as *const dim_t,
85 );
86 HANDLE_ERROR(AfError::from(err_val));
87 temp.into()
88 }
89}
90
91/// Backward pass gradient of 2D convolution
92///
93/// # Parameters
94///
95/// - `incoming_gradient` gradients to be distributed in backwards pass
96/// - `original_signal` input signal to forward pass of convolution assumed structure of input is ( d0 x d1 x d2 x N )
97/// - `original_filter` input filter to forward pass of convolution assumed structure of input is ( d0 x d1 x d2 x N )
98/// - `convolved_output` output from forward pass of convolution
99/// - `strides` are distance between consecutive elements along each dimension for original convolution
100/// - `padding` specifies padding width along each dimension for original convolution
101/// - `dilation` specifies filter dilation along each dimension for original convolution
102/// - `grad_type` specifies which gradient to return
103///
104/// # Return Values
105///
106/// Gradient Array w.r.t input generated from [convolve2_nn](./fn.convolve2_nn.html)
107#[allow(clippy::too_many_arguments)]
108pub fn convolve2_gradient_nn<T>(
109 incoming_grad: &Array<T>,
110 original_signal: &Array<T>,
111 original_filter: &Array<T>,
112 convolved_output: &Array<T>,
113 strides: Dim4,
114 padding: Dim4,
115 dilation: Dim4,
116 grad_type: ConvGradientType,
117) -> Array<T>
118where
119 T: HasAfEnum + RealFloating,
120{
121 unsafe {
122 let mut temp: af_array = std::ptr::null_mut();
123 let err_val = af_convolve2_gradient_nn(
124 &mut temp as *mut af_array,
125 incoming_grad.get(),
126 original_signal.get(),
127 original_filter.get(),
128 convolved_output.get(),
129 strides.ndims() as c_uint,
130 strides.get().as_ptr() as *const dim_t,
131 padding.ndims() as c_uint,
132 padding.get().as_ptr() as *const dim_t,
133 dilation.ndims() as c_uint,
134 dilation.get().as_ptr() as *const dim_t,
135 grad_type as c_uint,
136 );
137 HANDLE_ERROR(AfError::from(err_val));
138 temp.into()
139 }
140}