diff --git a/datafusion/functions/src/string/split_part.rs b/datafusion/functions/src/string/split_part.rs index 1994c65bcf326..7e382868c4f23 100644 --- a/datafusion/functions/src/string/split_part.rs +++ b/datafusion/functions/src/string/split_part.rs @@ -15,13 +15,15 @@ // specific language governing permissions and limitations // under the License. +use crate::strings::{ + BulkNullStringArrayBuilder, GenericStringArrayBuilder, StringViewArrayBuilder, +}; use crate::utils::utf8_to_str_type; use arrow::array::{ - Array, ArrayRef, AsArray, ByteView, GenericStringBuilder, Int64Array, - StringArrayType, StringLikeArrayBuilder, StringViewArray, StringViewBuilder, + Array, ArrayRef, AsArray, ByteView, Int64Array, StringArrayType, StringViewArray, make_view, new_null_array, }; -use arrow::buffer::ScalarBuffer; +use arrow::buffer::{NullBuffer, ScalarBuffer}; use arrow::datatypes::DataType; use datafusion_common::ScalarValue; use datafusion_common::cast::as_int64_array; @@ -167,7 +169,7 @@ impl ScalarUDFImpl for SplitPartFunc { let result = match args[0].data_type() { DataType::Utf8View => split_part_for_delimiter_type!( &args[0].as_string_view(), - StringViewBuilder::with_capacity(inferred_length) + StringViewArrayBuilder::with_capacity(inferred_length) ), DataType::Utf8 => { let str_arr = &args[0].as_string::(); @@ -176,7 +178,7 @@ impl ScalarUDFImpl for SplitPartFunc { // pre-allocating the full input data size. split_part_for_delimiter_type!( str_arr, - GenericStringBuilder::::with_capacity( + GenericStringArrayBuilder::::with_capacity( inferred_length, inferred_length, ) @@ -187,7 +189,7 @@ impl ScalarUDFImpl for SplitPartFunc { // Conservative under-estimate; see Utf8 comment above. split_part_for_delimiter_type!( str_arr, - GenericStringBuilder::::with_capacity( + GenericStringArrayBuilder::::with_capacity( inferred_length, inferred_length, ) @@ -293,7 +295,7 @@ fn split_part_scalar( arr, delimiter, position, - GenericStringBuilder::::with_capacity(arr.len(), arr.len()), + GenericStringArrayBuilder::::with_capacity(arr.len(), arr.len()), ) } DataType::LargeUtf8 => { @@ -303,7 +305,7 @@ fn split_part_scalar( arr, delimiter, position, - GenericStringBuilder::::with_capacity(arr.len(), arr.len()), + GenericStringArrayBuilder::::with_capacity(arr.len(), arr.len()), ) } other => exec_err!("Unsupported string type {other:?} for split_part"), @@ -323,7 +325,7 @@ fn split_part_scalar_impl<'a, S, B>( ) -> Result where S: StringArrayType<'a> + Copy, - B: StringLikeArrayBuilder, + B: BulkNullStringArrayBuilder, { if delimiter.is_empty() { // PostgreSQL: empty delimiter treats input as a single field, @@ -367,16 +369,31 @@ where fn map_strings<'a, S, B, F>(string_array: S, mut builder: B, f: F) -> Result where S: StringArrayType<'a> + Copy, - B: StringLikeArrayBuilder, + B: BulkNullStringArrayBuilder, F: Fn(&'a str) -> Option<&'a str>, { - for string in string_array.iter() { - match string { - Some(s) => builder.append_value(f(s).unwrap_or("")), - None => builder.append_null(), + let item_len = string_array.len(); + let nulls = string_array.nulls().cloned(); + + if let Some(ref n) = nulls { + for i in 0..item_len { + if n.is_null(i) { + builder.append_placeholder(); + } else { + // SAFETY: `n.is_null(i)` was false in the branch above. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(f(s).unwrap_or("")); + } + } + } else { + for i in 0..item_len { + // SAFETY: no null buffer means every index is valid. + let s = unsafe { string_array.value_unchecked(i) }; + builder.append_value(f(s).unwrap_or("")); } } - Ok(Arc::new(builder.finish()) as ArrayRef) + + builder.finish(nulls) } /// Finds the `n`th (0-based) split part using a pre-built `memmem::Finder`. @@ -543,58 +560,82 @@ fn split_part_impl<'a, StringArrType, DelimiterArrType, B>( where StringArrType: StringArrayType<'a>, DelimiterArrType: StringArrayType<'a>, - B: StringLikeArrayBuilder, + B: BulkNullStringArrayBuilder, { - for ((string, delimiter), n) in string_array - .iter() - .zip(delimiter_array.iter()) - .zip(n_array.iter()) - { - match (string, delimiter, n) { - (Some(string), Some(delimiter), Some(n)) => { - let result = match n.cmp(&0) { - std::cmp::Ordering::Greater => { - let idx: usize = (n - 1).try_into().map_err(|_| { - exec_datafusion_err!( - "split_part index {n} exceeds maximum supported value" - ) - })?; - if delimiter.is_empty() { - // Match PostgreSQL's behavior: empty delimiter - // treats input as a single field, so only position - // 1 returns data. - (n == 1).then_some(string) - } else { - split_nth(string, delimiter, idx) - } - } - std::cmp::Ordering::Less => { - let idx: usize = - (n.unsigned_abs() - 1).try_into().map_err(|_| { - exec_datafusion_err!( - "split_part index {n} exceeds minimum supported value" - ) - })?; - if delimiter.is_empty() { - // Match PostgreSQL's behavior: empty delimiter - // treats input as a single field, so only position - // -1 returns data. - (n == -1).then_some(string) - } else { - rsplit_nth(string, delimiter, idx) - } - } - std::cmp::Ordering::Equal => { - return exec_err!("field position must not be zero"); - } - }; - builder.append_value(result.unwrap_or("")); + let nulls = NullBuffer::union_many([ + string_array.nulls(), + delimiter_array.nulls(), + n_array.nulls(), + ]); + + if let Some(ref n) = nulls { + for i in 0..string_array.len() { + if n.is_null(i) { + builder.append_placeholder(); + continue; } - _ => builder.append_null(), + + // SAFETY: the union null buffer is valid at `i`, so each input is valid. + let string = unsafe { string_array.value_unchecked(i) }; + let delimiter = unsafe { delimiter_array.value_unchecked(i) }; + let position = unsafe { n_array.value_unchecked(i) }; + append_split_part(string, delimiter, position, &mut builder)?; + } + } else { + for i in 0..string_array.len() { + // SAFETY: no input has a null buffer, so every index is valid. + let string = unsafe { string_array.value_unchecked(i) }; + let delimiter = unsafe { delimiter_array.value_unchecked(i) }; + let position = unsafe { n_array.value_unchecked(i) }; + append_split_part(string, delimiter, position, &mut builder)?; } } - Ok(Arc::new(builder.finish()) as ArrayRef) + builder.finish(nulls) +} + +#[inline] +fn append_split_part( + string: &str, + delimiter: &str, + n: i64, + builder: &mut B, +) -> Result<()> { + let result = match n.cmp(&0) { + std::cmp::Ordering::Greater => { + let idx: usize = (n - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds maximum supported value" + ) + })?; + if delimiter.is_empty() { + // Match PostgreSQL's behavior: empty delimiter treats input + // as a single field, so only position 1 returns data. + (n == 1).then_some(string) + } else { + split_nth(string, delimiter, idx) + } + } + std::cmp::Ordering::Less => { + let idx: usize = (n.unsigned_abs() - 1).try_into().map_err(|_| { + exec_datafusion_err!( + "split_part index {n} exceeds minimum supported value" + ) + })?; + if delimiter.is_empty() { + // Match PostgreSQL's behavior: empty delimiter treats input + // as a single field, so only position -1 returns data. + (n == -1).then_some(string) + } else { + rsplit_nth(string, delimiter, idx) + } + } + std::cmp::Ordering::Equal => { + return exec_err!("field position must not be zero"); + } + }; + builder.append_value(result.unwrap_or("")); + Ok(()) } #[cfg(test)]