arrow2/compute/arithmetics/decimal/add.rs
1//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals.
2use crate::{
3 array::PrimitiveArray,
4 compute::{
5 arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd},
6 arity::{binary, binary_checked},
7 utils::{check_same_len, combine_validities},
8 },
9};
10use crate::{
11 datatypes::DataType,
12 error::{Error, Result},
13};
14
15use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
16
17/// Adds two decimal [`PrimitiveArray`] with the same precision and scale.
18/// # Error
19/// Errors if the precision and scale are different.
20/// # Panic
21/// This function panics iff the added numbers result in a number larger than
22/// the possible number for the precision.
23///
24/// # Examples
25/// ```
26/// use arrow2::compute::arithmetics::decimal::add;
27/// use arrow2::array::PrimitiveArray;
28/// use arrow2::datatypes::DataType;
29///
30/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
31/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
32///
33/// let result = add(&a, &b);
34/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2));
35///
36/// assert_eq!(result, expected);
37/// ```
38pub fn add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
39 let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
40
41 let max = max_value(precision);
42 let op = move |a, b| {
43 let res: i128 = a + b;
44
45 assert!(
46 res.abs() <= max,
47 "Overflow in addition presented for precision {precision}"
48 );
49
50 res
51 };
52
53 binary(lhs, rhs, lhs.data_type().clone(), op)
54}
55
56/// Saturated addition of two decimal primitive arrays with the same precision
57/// and scale. If the precision and scale is different, then an
58/// InvalidArgumentError is returned. If the result from the sum is larger than
59/// the possible number with the selected precision then the resulted number in
60/// the arrow array is the maximum number for the selected precision.
61///
62/// # Examples
63/// ```
64/// use arrow2::compute::arithmetics::decimal::saturating_add;
65/// use arrow2::array::PrimitiveArray;
66/// use arrow2::datatypes::DataType;
67///
68/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
69/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
70///
71/// let result = saturating_add(&a, &b);
72/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
73///
74/// assert_eq!(result, expected);
75/// ```
76pub fn saturating_add(
77 lhs: &PrimitiveArray<i128>,
78 rhs: &PrimitiveArray<i128>,
79) -> PrimitiveArray<i128> {
80 let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
81
82 let max = max_value(precision);
83 let op = move |a, b| {
84 let res: i128 = a + b;
85
86 if res.abs() > max {
87 if res > 0 {
88 max
89 } else {
90 -max
91 }
92 } else {
93 res
94 }
95 };
96
97 binary(lhs, rhs, lhs.data_type().clone(), op)
98}
99
100/// Checked addition of two decimal primitive arrays with the same precision
101/// and scale. If the precision and scale is different, then an
102/// InvalidArgumentError is returned. If the result from the sum is larger than
103/// the possible number with the selected precision (overflowing), then the
104/// validity for that index is changed to None
105///
106/// # Examples
107/// ```
108/// use arrow2::compute::arithmetics::decimal::checked_add;
109/// use arrow2::array::PrimitiveArray;
110/// use arrow2::datatypes::DataType;
111///
112/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
113/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
114///
115/// let result = checked_add(&a, &b);
116/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
117///
118/// assert_eq!(result, expected);
119/// ```
120pub fn checked_add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
121 let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
122
123 let max = max_value(precision);
124 let op = move |a, b| {
125 let result: i128 = a + b;
126
127 if result.abs() > max {
128 None
129 } else {
130 Some(result)
131 }
132 };
133
134 binary_checked(lhs, rhs, lhs.data_type().clone(), op)
135}
136
137// Implementation of ArrayAdd trait for PrimitiveArrays
138impl ArrayAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
139 fn add(&self, rhs: &PrimitiveArray<i128>) -> Self {
140 add(self, rhs)
141 }
142}
143
144// Implementation of ArrayCheckedAdd trait for PrimitiveArrays
145impl ArrayCheckedAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
146 fn checked_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
147 checked_add(self, rhs)
148 }
149}
150
151// Implementation of ArraySaturatingAdd trait for PrimitiveArrays
152impl ArraySaturatingAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
153 fn saturating_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
154 saturating_add(self, rhs)
155 }
156}
157
158/// Adaptive addition of two decimal primitive arrays with different precision
159/// and scale. If the precision and scale is different, then the smallest scale
160/// and precision is adjusted to the largest precision and scale. If during the
161/// addition one of the results is larger than the max possible value, the
162/// result precision is changed to the precision of the max value
163///
164/// ```nocode
165/// 11111.11 -> 7, 2
166/// 11111.111 -> 8, 3
167/// ------------------
168/// 22222.221 -> 8, 3
169/// ```
170/// # Examples
171/// ```
172/// use arrow2::compute::arithmetics::decimal::adaptive_add;
173/// use arrow2::array::PrimitiveArray;
174/// use arrow2::datatypes::DataType;
175///
176/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2));
177/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3));
178/// let result = adaptive_add(&a, &b).unwrap();
179/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3));
180///
181/// assert_eq!(result, expected);
182/// ```
183pub fn adaptive_add(
184 lhs: &PrimitiveArray<i128>,
185 rhs: &PrimitiveArray<i128>,
186) -> Result<PrimitiveArray<i128>> {
187 check_same_len(lhs, rhs)?;
188
189 let (lhs_p, lhs_s, rhs_p, rhs_s) =
190 if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
191 (lhs.data_type(), rhs.data_type())
192 {
193 (*lhs_p, *lhs_s, *rhs_p, *rhs_s)
194 } else {
195 return Err(Error::InvalidArgumentError(
196 "Incorrect data type for the array".to_string(),
197 ));
198 };
199
200 // The resulting precision is mutable because it could change while
201 // looping through the iterator
202 let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);
203
204 let shift = 10i128.pow(diff as u32);
205 let mut max = max_value(res_p);
206
207 let values = lhs
208 .values()
209 .iter()
210 .zip(rhs.values().iter())
211 .map(|(l, r)| {
212 // Based on the array's scales one of the arguments in the sum has to be shifted
213 // to the left to match the final scale
214 let res = if lhs_s > rhs_s {
215 l + r * shift
216 } else {
217 l * shift + r
218 };
219
220 // The precision of the resulting array will change if one of the
221 // sums during the iteration produces a value bigger than the
222 // possible value for the initial precision
223
224 // 99.9999 -> 6, 4
225 // 00.0001 -> 6, 4
226 // -----------------
227 // 100.0000 -> 7, 4
228 if res.abs() > max {
229 res_p = number_digits(res);
230 max = max_value(res_p);
231 }
232 res
233 })
234 .collect::<Vec<_>>();
235
236 let validity = combine_validities(lhs.validity(), rhs.validity());
237
238 Ok(PrimitiveArray::<i128>::new(
239 DataType::Decimal(res_p, res_s),
240 values.into(),
241 validity,
242 ))
243}