Thanks for visiting codestin.com
Credit goes to docs.rs

arrow2/compute/arithmetics/decimal/add.rs

1//! Defines the addition arithmetic kernels for [`PrimitiveArray`] representing decimals.
2use crate::{
3    array::PrimitiveArray,
4    compute::{
5        arithmetics::{ArrayAdd, ArrayCheckedAdd, ArraySaturatingAdd},
6        arity::{binary, binary_checked},
7        utils::{check_same_len, combine_validities},
8    },
9};
10use crate::{
11    datatypes::DataType,
12    error::{Error, Result},
13};
14
15use super::{adjusted_precision_scale, get_parameters, max_value, number_digits};
16
17/// Adds two decimal [`PrimitiveArray`] with the same precision and scale.
18/// # Error
19/// Errors if the precision and scale are different.
20/// # Panic
21/// This function panics iff the added numbers result in a number larger than
22/// the possible number for the precision.
23///
24/// # Examples
25/// ```
26/// use arrow2::compute::arithmetics::decimal::add;
27/// use arrow2::array::PrimitiveArray;
28/// use arrow2::datatypes::DataType;
29///
30/// let a = PrimitiveArray::from([Some(1i128), Some(1i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
31/// let b = PrimitiveArray::from([Some(1i128), Some(2i128), None, Some(2i128)]).to(DataType::Decimal(5, 2));
32///
33/// let result = add(&a, &b);
34/// let expected = PrimitiveArray::from([Some(2i128), Some(3i128), None, Some(4i128)]).to(DataType::Decimal(5, 2));
35///
36/// assert_eq!(result, expected);
37/// ```
38pub fn add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
39    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
40
41    let max = max_value(precision);
42    let op = move |a, b| {
43        let res: i128 = a + b;
44
45        assert!(
46            res.abs() <= max,
47            "Overflow in addition presented for precision {precision}"
48        );
49
50        res
51    };
52
53    binary(lhs, rhs, lhs.data_type().clone(), op)
54}
55
56/// Saturated addition of two decimal primitive arrays with the same precision
57/// and scale. If the precision and scale is different, then an
58/// InvalidArgumentError is returned. If the result from the sum is larger than
59/// the possible number with the selected precision then the resulted number in
60/// the arrow array is the maximum number for the selected precision.
61///
62/// # Examples
63/// ```
64/// use arrow2::compute::arithmetics::decimal::saturating_add;
65/// use arrow2::array::PrimitiveArray;
66/// use arrow2::datatypes::DataType;
67///
68/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
69/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
70///
71/// let result = saturating_add(&a, &b);
72/// let expected = PrimitiveArray::from([Some(99999i128), Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
73///
74/// assert_eq!(result, expected);
75/// ```
76pub fn saturating_add(
77    lhs: &PrimitiveArray<i128>,
78    rhs: &PrimitiveArray<i128>,
79) -> PrimitiveArray<i128> {
80    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
81
82    let max = max_value(precision);
83    let op = move |a, b| {
84        let res: i128 = a + b;
85
86        if res.abs() > max {
87            if res > 0 {
88                max
89            } else {
90                -max
91            }
92        } else {
93            res
94        }
95    };
96
97    binary(lhs, rhs, lhs.data_type().clone(), op)
98}
99
100/// Checked addition of two decimal primitive arrays with the same precision
101/// and scale. If the precision and scale is different, then an
102/// InvalidArgumentError is returned. If the result from the sum is larger than
103/// the possible number with the selected precision (overflowing), then the
104/// validity for that index is changed to None
105///
106/// # Examples
107/// ```
108/// use arrow2::compute::arithmetics::decimal::checked_add;
109/// use arrow2::array::PrimitiveArray;
110/// use arrow2::datatypes::DataType;
111///
112/// let a = PrimitiveArray::from([Some(99000i128), Some(11100i128), None, Some(22200i128)]).to(DataType::Decimal(5, 2));
113/// let b = PrimitiveArray::from([Some(01000i128), Some(22200i128), None, Some(11100i128)]).to(DataType::Decimal(5, 2));
114///
115/// let result = checked_add(&a, &b);
116/// let expected = PrimitiveArray::from([None, Some(33300i128), None, Some(33300i128)]).to(DataType::Decimal(5, 2));
117///
118/// assert_eq!(result, expected);
119/// ```
120pub fn checked_add(lhs: &PrimitiveArray<i128>, rhs: &PrimitiveArray<i128>) -> PrimitiveArray<i128> {
121    let (precision, _) = get_parameters(lhs.data_type(), rhs.data_type()).unwrap();
122
123    let max = max_value(precision);
124    let op = move |a, b| {
125        let result: i128 = a + b;
126
127        if result.abs() > max {
128            None
129        } else {
130            Some(result)
131        }
132    };
133
134    binary_checked(lhs, rhs, lhs.data_type().clone(), op)
135}
136
137// Implementation of ArrayAdd trait for PrimitiveArrays
138impl ArrayAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
139    fn add(&self, rhs: &PrimitiveArray<i128>) -> Self {
140        add(self, rhs)
141    }
142}
143
144// Implementation of ArrayCheckedAdd trait for PrimitiveArrays
145impl ArrayCheckedAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
146    fn checked_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
147        checked_add(self, rhs)
148    }
149}
150
151// Implementation of ArraySaturatingAdd trait for PrimitiveArrays
152impl ArraySaturatingAdd<PrimitiveArray<i128>> for PrimitiveArray<i128> {
153    fn saturating_add(&self, rhs: &PrimitiveArray<i128>) -> Self {
154        saturating_add(self, rhs)
155    }
156}
157
158/// Adaptive addition of two decimal primitive arrays with different precision
159/// and scale. If the precision and scale is different, then the smallest scale
160/// and precision is adjusted to the largest precision and scale. If during the
161/// addition one of the results is larger than the max possible value, the
162/// result precision is changed to the precision of the max value
163///
164/// ```nocode
165/// 11111.11   -> 7, 2
166/// 11111.111  -> 8, 3
167/// ------------------
168/// 22222.221  -> 8, 3
169/// ```
170/// # Examples
171/// ```
172/// use arrow2::compute::arithmetics::decimal::adaptive_add;
173/// use arrow2::array::PrimitiveArray;
174/// use arrow2::datatypes::DataType;
175///
176/// let a = PrimitiveArray::from([Some(11111_11i128)]).to(DataType::Decimal(7, 2));
177/// let b = PrimitiveArray::from([Some(11111_111i128)]).to(DataType::Decimal(8, 3));
178/// let result = adaptive_add(&a, &b).unwrap();
179/// let expected = PrimitiveArray::from([Some(22222_221i128)]).to(DataType::Decimal(8, 3));
180///
181/// assert_eq!(result, expected);
182/// ```
183pub fn adaptive_add(
184    lhs: &PrimitiveArray<i128>,
185    rhs: &PrimitiveArray<i128>,
186) -> Result<PrimitiveArray<i128>> {
187    check_same_len(lhs, rhs)?;
188
189    let (lhs_p, lhs_s, rhs_p, rhs_s) =
190        if let (DataType::Decimal(lhs_p, lhs_s), DataType::Decimal(rhs_p, rhs_s)) =
191            (lhs.data_type(), rhs.data_type())
192        {
193            (*lhs_p, *lhs_s, *rhs_p, *rhs_s)
194        } else {
195            return Err(Error::InvalidArgumentError(
196                "Incorrect data type for the array".to_string(),
197            ));
198        };
199
200    // The resulting precision is mutable because it could change while
201    // looping through the iterator
202    let (mut res_p, res_s, diff) = adjusted_precision_scale(lhs_p, lhs_s, rhs_p, rhs_s);
203
204    let shift = 10i128.pow(diff as u32);
205    let mut max = max_value(res_p);
206
207    let values = lhs
208        .values()
209        .iter()
210        .zip(rhs.values().iter())
211        .map(|(l, r)| {
212            // Based on the array's scales one of the arguments in the sum has to be shifted
213            // to the left to match the final scale
214            let res = if lhs_s > rhs_s {
215                l + r * shift
216            } else {
217                l * shift + r
218            };
219
220            // The precision of the resulting array will change if one of the
221            // sums during the iteration produces a value bigger than the
222            // possible value for the initial precision
223
224            //  99.9999 -> 6, 4
225            //  00.0001 -> 6, 4
226            // -----------------
227            // 100.0000 -> 7, 4
228            if res.abs() > max {
229                res_p = number_digits(res);
230                max = max_value(res_p);
231            }
232            res
233        })
234        .collect::<Vec<_>>();
235
236    let validity = combine_validities(lhs.validity(), rhs.validity());
237
238    Ok(PrimitiveArray::<i128>::new(
239        DataType::Decimal(res_p, res_s),
240        values.into(),
241        validity,
242    ))
243}