1use super::core::{
2 af_array, dim_t, AfError, Array, FloatingPoint, HasAfEnum, SparseFormat, HANDLE_ERROR,
3};
4
5use libc::{c_int, c_uint, c_void};
6
// Raw declarations of ArrayFire's sparse-array C functions used by the safe
// wrappers in this file. Each function returns a `c_int` status code, which
// the wrappers convert with `AfError::from` and hand to `HANDLE_ERROR`.
// Results are written through the `*mut` out-parameters.
extern "C" {
    // Builds a sparse array from existing value/row-index/column-index arrays.
    // `stype` receives `SparseFormat as c_uint` from the callers in this file.
    fn af_create_sparse_array(
        out: *mut af_array,
        nRows: dim_t,
        nCols: dim_t,
        vals: af_array,
        rowIdx: af_array,
        colIdx: af_array,
        stype: c_uint,
    ) -> c_int;

    // Builds a sparse array from raw memory buffers. The pointers carry no
    // length information: how many elements are read depends on `nNZ`,
    // `nRows`/`nCols` and the storage format, so callers must size the buffers.
    // `aftype` receives `T::get_af_dtype() as c_uint` from `sparse_from_host`.
    // NOTE(review): `src` looks like ArrayFire's `af_source` (0 = device,
    // 1 = host); `sparse_from_host` passes 1. Confirm against the C headers.
    fn af_create_sparse_array_from_ptr(
        out: *mut af_array,
        nRows: dim_t,
        nCols: dim_t,
        nNZ: dim_t,
        values: *const c_void,
        rowIdx: *const c_int,
        colIdx: *const c_int,
        aftype: c_uint,
        stype: c_uint,
        src: c_uint,
    ) -> c_int;

    // Converts a dense array into a sparse array of storage format `stype`.
    fn af_create_sparse_array_from_dense(
        out: *mut af_array,
        dense: af_array,
        stype: c_uint,
    ) -> c_int;

    // Converts `input` to the storage format `dstStrge`.
    fn af_sparse_convert_to(out: *mut af_array, input: af_array, dstStrge: c_uint) -> c_int;

    // Expands a sparse array back into a dense array.
    fn af_sparse_to_dense(out: *mut af_array, sparse: af_array) -> c_int;

    // Writes the values, row indices, column indices and storage format of
    // `input` into the four out-parameters in a single call.
    fn af_sparse_get_info(
        vals: *mut af_array,
        rIdx: *mut af_array,
        cIdx: *mut af_array,
        stype: *mut c_uint,
        input: af_array,
    ) -> c_int;

    // Individual accessors for the pieces returned by `af_sparse_get_info`.
    fn af_sparse_get_values(out: *mut af_array, input: af_array) -> c_int;

    fn af_sparse_get_row_idx(out: *mut af_array, input: af_array) -> c_int;

    fn af_sparse_get_col_idx(out: *mut af_array, input: af_array) -> c_int;

    // Writes the non-zero count of `input` as a `dim_t`.
    fn af_sparse_get_nnz(out: *mut dim_t, input: af_array) -> c_int;

    // Writes the storage-format discriminant of `input`.
    fn af_sparse_get_storage(out: *mut c_uint, input: af_array) -> c_int;
}
59
60pub fn sparse<T>(
84 rows: u64,
85 cols: u64,
86 values: &Array<T>,
87 row_indices: &Array<i32>,
88 col_indices: &Array<i32>,
89 format: SparseFormat,
90) -> Array<T>
91where
92 T: HasAfEnum + FloatingPoint,
93{
94 unsafe {
95 let mut temp: af_array = std::ptr::null_mut();
96 let err_val = af_create_sparse_array(
97 &mut temp as *mut af_array,
98 rows as dim_t,
99 cols as dim_t,
100 values.get(),
101 row_indices.get(),
102 col_indices.get(),
103 format as c_uint,
104 );
105 HANDLE_ERROR(AfError::from(err_val));
106 temp.into()
107 }
108}
109
110pub fn sparse_from_host<T>(
135 rows: u64,
136 cols: u64,
137 nzz: u64,
138 values: &[T],
139 row_indices: &[i32],
140 col_indices: &[i32],
141 format: SparseFormat,
142) -> Array<T>
143where
144 T: HasAfEnum + FloatingPoint,
145{
146 let aftype = T::get_af_dtype();
147 unsafe {
148 let mut temp: af_array = std::ptr::null_mut();
149 let err_val = af_create_sparse_array_from_ptr(
150 &mut temp as *mut af_array,
151 rows as dim_t,
152 cols as dim_t,
153 nzz as dim_t,
154 values.as_ptr() as *const c_void,
155 row_indices.as_ptr() as *const c_int,
156 col_indices.as_ptr() as *const c_int,
157 aftype as c_uint,
158 format as c_uint,
159 1,
160 );
161 HANDLE_ERROR(AfError::from(err_val));
162 temp.into()
163 }
164}
165
166pub fn sparse_from_dense<T>(dense: &Array<T>, format: SparseFormat) -> Array<T>
177where
178 T: HasAfEnum + FloatingPoint,
179{
180 unsafe {
181 let mut temp: af_array = std::ptr::null_mut();
182 let err_val = af_create_sparse_array_from_dense(
183 &mut temp as *mut af_array,
184 dense.get(),
185 format as c_uint,
186 );
187 HANDLE_ERROR(AfError::from(err_val));
188 temp.into()
189 }
190}
191
192pub fn sparse_convert_to<T>(input: &Array<T>, format: SparseFormat) -> Array<T>
203where
204 T: HasAfEnum + FloatingPoint,
205{
206 unsafe {
207 let mut temp: af_array = std::ptr::null_mut();
208 let err_val =
209 af_sparse_convert_to(&mut temp as *mut af_array, input.get(), format as c_uint);
210 HANDLE_ERROR(AfError::from(err_val));
211 temp.into()
212 }
213}
214
215pub fn sparse_to_dense<T>(input: &Array<T>) -> Array<T>
225where
226 T: HasAfEnum + FloatingPoint,
227{
228 unsafe {
229 let mut temp: af_array = std::ptr::null_mut();
230 let err_val = af_sparse_to_dense(&mut temp as *mut af_array, input.get());
231 HANDLE_ERROR(AfError::from(err_val));
232 temp.into()
233 }
234}
235
236pub fn sparse_get_info<T>(input: &Array<T>) -> (Array<T>, Array<i32>, Array<i32>, SparseFormat)
246where
247 T: HasAfEnum + FloatingPoint,
248{
249 unsafe {
250 let mut val: af_array = std::ptr::null_mut();
251 let mut row: af_array = std::ptr::null_mut();
252 let mut col: af_array = std::ptr::null_mut();
253 let mut stype: u32 = 0;
254 let err_val = af_sparse_get_info(
255 &mut val as *mut af_array,
256 &mut row as *mut af_array,
257 &mut col as *mut af_array,
258 &mut stype as *mut c_uint,
259 input.get(),
260 );
261 HANDLE_ERROR(AfError::from(err_val));
262 (
263 val.into(),
264 row.into(),
265 col.into(),
266 SparseFormat::from(stype),
267 )
268 }
269}
270
271pub fn sparse_get_values<T>(input: &Array<T>) -> Array<T>
281where
282 T: HasAfEnum + FloatingPoint,
283{
284 unsafe {
285 let mut val: af_array = std::ptr::null_mut();
286 let err_val = af_sparse_get_values(&mut val as *mut af_array, input.get());
287 HANDLE_ERROR(AfError::from(err_val));
288 val.into()
289 }
290}
291
292pub fn sparse_get_row_indices<T>(input: &Array<T>) -> Array<i32>
302where
303 T: HasAfEnum + FloatingPoint,
304{
305 unsafe {
306 let mut val: af_array = std::ptr::null_mut();
307 let err_val = af_sparse_get_row_idx(&mut val as *mut af_array, input.get());
308 HANDLE_ERROR(AfError::from(err_val));
309 val.into()
310 }
311}
312
313pub fn sparse_get_col_indices<T>(input: &Array<T>) -> Array<i32>
323where
324 T: HasAfEnum + FloatingPoint,
325{
326 unsafe {
327 let mut val: af_array = std::ptr::null_mut();
328 let err_val = af_sparse_get_col_idx(&mut val as *mut af_array, input.get());
329 HANDLE_ERROR(AfError::from(err_val));
330 val.into()
331 }
332}
333
334pub fn sparse_get_nnz<T: HasAfEnum>(input: &Array<T>) -> i64 {
344 let mut count: i64 = 0;
345 unsafe {
346 let err_val = af_sparse_get_nnz(&mut count as *mut dim_t, input.get());
347 HANDLE_ERROR(AfError::from(err_val));
348 }
349 count
350}
351
352pub fn sparse_get_format<T: HasAfEnum>(input: &Array<T>) -> SparseFormat {
362 let mut stype: u32 = 0;
363 unsafe {
364 let err_val = af_sparse_get_storage(&mut stype as *mut c_uint, input.get());
365 HANDLE_ERROR(AfError::from(err_val));
366 }
367 SparseFormat::from(stype)
368}