Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 71 additions & 27 deletions datafusion/functions/src/string/repeat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::sync::Arc;
use crate::utils::utf8_to_str_type;
use arrow::array::{
Array, ArrayRef, AsArray, GenericStringArray, GenericStringBuilder, Int64Array,
OffsetSizeTrait, StringArrayType, StringViewArray,
StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder,
};
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};
Expand Down Expand Up @@ -96,6 +96,9 @@ impl ScalarUDFImpl for RepeatFunc {
}

// Determines the output type of `repeat` from the type of its first argument
// (the string being repeated).
//
// `arg_types[0]` is indexed directly, so an empty slice would panic here.
// NOTE(review): presumably the function signature guarantees at least one
// argument before this is called — confirm against the caller.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
// A `Utf8View` input yields a `Utf8View` output, so the view type is preserved.
if arg_types[0] == Utf8View {
return Ok(Utf8View);
}
// Every other input type is delegated to the shared helper; "repeat" is the
// function name handed to it (presumably for error reporting — verify in
// `utf8_to_str_type`).
utf8_to_str_type(&arg_types[0], "repeat")
}

Expand Down Expand Up @@ -131,13 +134,12 @@ impl ScalarUDFImpl for RepeatFunc {
};

let result = match string_scalar {
ScalarValue::Utf8(Some(s)) | ScalarValue::Utf8View(Some(s)) => {
ScalarValue::Utf8(Some(compute_repeat(
s,
count,
i32::MAX as usize,
)?))
}
ScalarValue::Utf8View(Some(s)) => ScalarValue::Utf8View(Some(
compute_repeat(s, count, i32::MAX as usize)?,
)),
ScalarValue::Utf8(Some(s)) => ScalarValue::Utf8(Some(
compute_repeat(s, count, i32::MAX as usize)?,
)),
ScalarValue::LargeUtf8(Some(s)) => ScalarValue::LargeUtf8(Some(
compute_repeat(s, count, i64::MAX as usize)?,
)),
Expand Down Expand Up @@ -188,26 +190,47 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
match string_array.data_type() {
Utf8View => {
let string_view_array = string_array.as_string_view();
repeat_impl::<i32, &StringViewArray>(
let (_, max_item_capacity) = calculate_capacities(
&string_view_array,
number_array,
i32::MAX as usize,
)?;
let builder = StringViewBuilder::with_capacity(string_array.len());
repeat_impl::<&StringViewArray, StringViewBuilder>(
&string_view_array,
number_array,
max_item_capacity,
builder,
)
}
Utf8 => {
let string_arr = string_array.as_string::<i32>();
repeat_impl::<i32, &GenericStringArray<i32>>(
let (total_capacity, max_item_capacity) =
calculate_capacities(&string_arr, number_array, i32::MAX as usize)?;
let builder = GenericStringBuilder::<i32>::with_capacity(
string_array.len(),
total_capacity,
);
repeat_impl::<&GenericStringArray<i32>, GenericStringBuilder<i32>>(
&string_arr,
number_array,
i32::MAX as usize,
max_item_capacity,
builder,
)
}
LargeUtf8 => {
let string_arr = string_array.as_string::<i64>();
repeat_impl::<i64, &GenericStringArray<i64>>(
let (total_capacity, max_item_capacity) =
calculate_capacities(&string_arr, number_array, i64::MAX as usize)?;
let builder = GenericStringBuilder::<i64>::with_capacity(
string_array.len(),
total_capacity,
);
repeat_impl::<&GenericStringArray<i64>, GenericStringBuilder<i64>>(
&string_arr,
number_array,
i64::MAX as usize,
max_item_capacity,
builder,
)
}
other => exec_err!(
Expand All @@ -217,17 +240,17 @@ fn repeat(string_array: &ArrayRef, count_array: &ArrayRef) -> Result<ArrayRef> {
}
}

fn repeat_impl<'a, T, S>(
fn calculate_capacities<'a, S>(
string_array: &S,
number_array: &Int64Array,
max_str_len: usize,
) -> Result<ArrayRef>
) -> Result<(usize, usize)>
where
T: OffsetSizeTrait,
S: StringArrayType<'a> + 'a,
S: StringArrayType<'a>,
{
let mut total_capacity = 0;
let mut max_item_capacity = 0;

string_array.iter().zip(number_array.iter()).try_for_each(
|(string, number)| -> Result<(), DataFusionError> {
match (string, number) {
Expand All @@ -249,9 +272,19 @@ where
},
)?;

let mut builder =
GenericStringBuilder::<T>::with_capacity(string_array.len(), total_capacity);
Ok((total_capacity, max_item_capacity))
}

fn repeat_impl<'a, S, B>(
string_array: &S,
number_array: &Int64Array,
max_item_capacity: usize,
mut builder: B,
) -> Result<ArrayRef>
where
S: StringArrayType<'a> + 'a,
B: StringLikeArrayBuilder,
{
// Reusable buffer to avoid allocations in string.repeat()
let mut buffer = Vec::<u8>::with_capacity(max_item_capacity);

Expand Down Expand Up @@ -308,8 +341,8 @@ where

#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;
use arrow::array::{Array, LargeStringArray, StringArray, StringViewArray};
use arrow::datatypes::DataType::{LargeUtf8, Utf8, Utf8View};

use datafusion_common::ScalarValue;
use datafusion_common::{Result, exec_err};
Expand Down Expand Up @@ -362,8 +395,8 @@ mod tests {
],
Ok(Some("PgPgPgPg")),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
Expand All @@ -373,8 +406,19 @@ mod tests {
],
Ok(None),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
vec![
ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("Pg")))),
ColumnarValue::Scalar(ScalarValue::Int64(None)),
],
Ok(None),
&str,
LargeUtf8,
LargeStringArray
);
test_function!(
RepeatFunc::new(),
Expand All @@ -384,8 +428,8 @@ mod tests {
],
Ok(None),
&str,
Utf8,
StringArray
Utf8View,
StringViewArray
);
test_function!(
RepeatFunc::new(),
Expand Down
24 changes: 24 additions & 0 deletions datafusion/sqllogictest/test_files/string/string_literal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,35 @@ SELECT repeat('foo', 3)
----
foofoofoo

query T
SELECT repeat(arrow_cast('foo', 'LargeUtf8'), 3)
----
foofoofoo

query T
SELECT repeat(arrow_cast('foo', 'Utf8View'), 3)
----
foofoofoo

query T
SELECT repeat(arrow_cast('foo', 'Dictionary(Int32, Utf8)'), 3)
----
foofoofoo

query T
SELECT arrow_typeof(repeat('foo', 3))
----
Utf8

query T
SELECT arrow_typeof(repeat(arrow_cast('foo', 'LargeUtf8'), 3))
----
LargeUtf8

query T
SELECT arrow_typeof(repeat(arrow_cast('foo', 'Utf8View'), 3))
----
Utf8View

query T
SELECT replace('foobar', 'bar', 'hello')
Expand Down