From 27929a30b172afeb6a962defe717a5c2353e76de Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 27 Dec 2025 12:12:22 +0530 Subject: [PATCH 1/2] perf: Optimize contains expression with SIMD-based scalar pattern search (#2972) --- native/Cargo.lock | 2 + native/spark-expr/Cargo.toml | 2 + native/spark-expr/src/comet_scalar_funcs.rs | 5 +- .../spark-expr/src/string_funcs/contains.rs | 282 ++++++++++++++++++ native/spark-expr/src/string_funcs/mod.rs | 2 + .../apache/comet/CometExpressionSuite.scala | 19 +- .../CometStringExpressionBenchmark.scala | 1 + 7 files changed, 310 insertions(+), 3 deletions(-) create mode 100644 native/spark-expr/src/string_funcs/contains.rs diff --git a/native/Cargo.lock b/native/Cargo.lock index bf9a7ea2da..7369a97d6b 100644 --- a/native/Cargo.lock +++ b/native/Cargo.lock @@ -1872,6 +1872,7 @@ name = "datafusion-comet-spark-expr" version = "0.13.0" dependencies = [ "arrow", + "arrow-string", "base64", "chrono", "chrono-tz", @@ -1879,6 +1880,7 @@ dependencies = [ "datafusion", "futures", "hex", + "memchr", "num", "rand 0.9.2", "regex", diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index ea89c43204..a0476b2a32 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -28,9 +28,11 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } +arrow-string = "57.0.0" chrono = { workspace = true } datafusion = { workspace = true } chrono-tz = { workspace = true } +memchr = "2.7" num = { workspace = true } regex = { workspace = true } serde_json = "1.0" diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index 8384a4646a..2ff355369e 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -22,8 +22,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, 
spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, - spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc, - SparkStringSpace, + spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkContains, SparkDateTrunc, + SparkSizeFunc, SparkStringSpace, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -192,6 +192,7 @@ pub fn create_comet_physical_fun_with_eval_mode( fn all_scalar_functions() -> Vec> { vec![ Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())), + Arc::new(ScalarUDF::new_from_impl(SparkContains::default())), Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())), Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())), diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs new file mode 100644 index 0000000000..c4662ba9d3 --- /dev/null +++ b/native/spark-expr/src/string_funcs/contains.rs @@ -0,0 +1,282 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Optimized `contains` string function for Spark compatibility. +//! 
+//! This implementation is optimized for the common case where the pattern
+//! (second argument) is a scalar value. In this case, we use `memchr::memmem::Finder`
+//! which is SIMD-optimized and reuses a single finder instance across all rows.
+//!
+//! The DataFusion built-in `contains` function uses `make_scalar_function` which
+//! expands scalar values to arrays, losing the performance benefit of the optimized
+//! scalar path in arrow-rs.
+
+use arrow::array::{Array, ArrayRef, AsArray, BooleanArray};
+use arrow::datatypes::DataType;
+use arrow_string::like::contains as arrow_contains;
+use datafusion::common::{exec_err, Result, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use memchr::memmem::Finder;
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark-optimized contains function.
+///
+/// Returns true if the first string argument contains the second string argument.
+/// Optimized for the common case where the pattern is a scalar constant.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkContains {
+    signature: Signature,
+}
+
+impl Default for SparkContains {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkContains {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::variadic_any(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkContains {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "contains"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Boolean)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        if args.args.len() != 2 {
+            return exec_err!("contains function requires exactly 2 arguments");
+        }
+        spark_contains(&args.args[0], &args.args[1])
+    }
+}
+
+/// Execute the contains function with optimized scalar pattern handling.
+fn spark_contains(haystack: &ColumnarValue, needle: &ColumnarValue) -> Result<ColumnarValue> {
+    match (haystack, needle) {
+        // Case 1: Both are arrays - use arrow's contains directly
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Array(needle_array)) => {
+            let result = arrow_contains(haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Case 2: Haystack is array, needle is scalar - OPTIMIZED PATH
+        // This is the common case in SQL like: WHERE col CONTAINS 'pattern'
+        (ColumnarValue::Array(haystack_array), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_with_scalar_pattern(haystack_array, needle_scalar)?;
+            Ok(ColumnarValue::Array(result))
+        }
+
+        // Case 3: Haystack is scalar, needle is array - less common
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Array(needle_array)) => {
+            // Convert scalar to array and use arrow's contains
+            let haystack_array = haystack_scalar.to_array_of_size(needle_array.len())?;
+            let result = arrow_contains(&haystack_array, needle_array)?;
+            Ok(ColumnarValue::Array(Arc::new(result)))
+        }
+
+        // Case 4: Both are scalars - compute single result
+        (ColumnarValue::Scalar(haystack_scalar), ColumnarValue::Scalar(needle_scalar)) => {
+            let result = contains_scalar_scalar(haystack_scalar, needle_scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+/// Optimized contains for array haystack with scalar needle pattern.
+/// Uses memchr's SIMD-optimized Finder for efficient repeated searches.
+fn contains_with_scalar_pattern(
+    haystack_array: &ArrayRef,
+    needle_scalar: &ScalarValue,
+) -> Result<ArrayRef> {
+    // Handle null needle
+    if needle_scalar.is_null() {
+        return Ok(Arc::new(BooleanArray::new_null(haystack_array.len())));
+    }
+
+    // Extract the needle string
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    // Create a reusable Finder for efficient SIMD-optimized searching
+    let finder = Finder::new(needle_str.as_bytes());
+
+    match haystack_array.data_type() {
+        DataType::Utf8 => {
+            let array = haystack_array.as_string::<i32>();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        DataType::LargeUtf8 => {
+            let array = haystack_array.as_string::<i64>();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        DataType::Utf8View => {
+            let array = haystack_array.as_string_view();
+            let result: BooleanArray = array
+                .iter()
+                .map(|opt_haystack| opt_haystack.map(|h| finder.find(h.as_bytes()).is_some()))
+                .collect();
+            Ok(Arc::new(result))
+        }
+        other => exec_err!(
+            "contains function requires string type for haystack, got {:?}",
+            other
+        ),
+    }
+}
+
+/// Contains for two scalar values.
+fn contains_scalar_scalar(
+    haystack_scalar: &ScalarValue,
+    needle_scalar: &ScalarValue,
+) -> Result<ScalarValue> {
+    // Handle nulls
+    if haystack_scalar.is_null() || needle_scalar.is_null() {
+        return Ok(ScalarValue::Boolean(None));
+    }
+
+    let haystack_str = match haystack_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for haystack, got {:?}",
+                haystack_scalar.data_type()
+            )
+        }
+    };
+
+    let needle_str = match needle_scalar {
+        ScalarValue::Utf8(Some(s))
+        | ScalarValue::LargeUtf8(Some(s))
+        | ScalarValue::Utf8View(Some(s)) => s.as_str(),
+        _ => {
+            return exec_err!(
+                "contains function requires string type for needle, got {:?}",
+                needle_scalar.data_type()
+            )
+        }
+    };
+
+    Ok(ScalarValue::Boolean(Some(
+        haystack_str.contains(needle_str),
+    )))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::StringArray;
+
+    #[test]
+    fn test_contains_array_scalar() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+            Some("testing"),
+            None,
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        assert!(bool_array.value(0)); // "hello world" contains "world"
+        assert!(!bool_array.value(1)); // "foo bar" does not contain "world"
+        assert!(!bool_array.value(2)); // "testing" does not contain "world"
+        assert!(bool_array.is_null(3)); // null input => null output
+    }
+
+    #[test]
+    fn test_contains_scalar_scalar() {
+        let haystack = ScalarValue::Utf8(Some("hello world".to_string()));
+        let needle = ScalarValue::Utf8(Some("world".to_string()));
+
+        let result = contains_scalar_scalar(&haystack, &needle).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(true)));
+
+        let needle_not_found = ScalarValue::Utf8(Some("xyz".to_string()));
+        let result = contains_scalar_scalar(&haystack, &needle_not_found).unwrap();
+        assert_eq!(result, ScalarValue::Boolean(Some(false)));
+    }
+
+    #[test]
+    fn test_contains_null_needle() {
+        let haystack = Arc::new(StringArray::from(vec![
+            Some("hello world"),
+            Some("foo bar"),
+        ])) as ArrayRef;
+        let needle = ScalarValue::Utf8(None);
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Null needle should produce null results
+        assert!(bool_array.is_null(0));
+        assert!(bool_array.is_null(1));
+    }
+
+    #[test]
+    fn test_contains_empty_needle() {
+        let haystack = Arc::new(StringArray::from(vec![Some("hello world"), Some("")])) as ArrayRef;
+        let needle = ScalarValue::Utf8(Some("".to_string()));
+
+        let result = contains_with_scalar_pattern(&haystack, &needle).unwrap();
+        let bool_array = result.as_any().downcast_ref::<BooleanArray>().unwrap();
+
+        // Empty string is contained in any string
+        assert!(bool_array.value(0));
+        assert!(bool_array.value(1));
+    }
+}
diff --git a/native/spark-expr/src/string_funcs/mod.rs b/native/spark-expr/src/string_funcs/mod.rs index aac8204e29..abdd0cc89b 100644 --- a/native/spark-expr/src/string_funcs/mod.rs +++ b/native/spark-expr/src/string_funcs/mod.rs @@ -15,8 +15,10 @@ // specific language governing permissions and limitations // under the License.
+mod contains; mod string_space; mod substring; +pub use contains::SparkContains; pub use string_space::SparkStringSpace; pub use substring::SubstringExpr; diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 0352da7850..93b184ad7f 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1107,7 +1107,24 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { // Filter rows that contains 'rose' in 'name' column val queryContains = sql(s"select id from $table where contains (name, 'rose')") - checkAnswer(queryContains, Row(5) :: Nil) + checkSparkAnswerAndOperator(queryContains) + + // Additional test cases for optimized contains implementation + // Test with empty pattern (should match all non-null rows) + val queryEmptyPattern = sql(s"select id from $table where contains (name, '')") + checkSparkAnswerAndOperator(queryEmptyPattern) + + // Test with pattern not found + val queryNotFound = sql(s"select id from $table where contains (name, 'xyz')") + checkSparkAnswerAndOperator(queryNotFound) + + // Test with pattern at start + val queryStart = sql(s"select id from $table where contains (name, 'James')") + checkSparkAnswerAndOperator(queryStart) + + // Test with pattern at end + val queryEnd = sql(s"select id from $table where contains (name, 'Smith')") + checkSparkAnswerAndOperator(queryEnd) } } diff --git a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala index 41eabb8513..c96cd83438 100644 --- a/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala +++ b/spark/src/test/scala/org/apache/spark/sql/benchmark/CometStringExpressionBenchmark.scala @@ -72,6 +72,7 @@ object 
CometStringExpressionBenchmark extends CometBenchmarkBase { StringExprConfig("initCap", "select initCap(c1) from parquetV1Table"), StringExprConfig("trim", "select trim(c1) from parquetV1Table"), StringExprConfig("concatws", "select concat_ws(' ', c1, c1) from parquetV1Table"), + StringExprConfig("contains", "select contains(c1, '123') from parquetV1Table"), StringExprConfig("length", "select length(c1) from parquetV1Table"), StringExprConfig("repeat", "select repeat(c1, 3) from parquetV1Table"), StringExprConfig("reverse", "select reverse(c1) from parquetV1Table"), From a4492975524ccda0ae1ecc1a87a8f81830ff4272 Mon Sep 17 00:00:00 2001 From: shekharrajak Date: Sat, 27 Dec 2025 23:47:16 +0530 Subject: [PATCH 2/2] Use arrow re-export instead of direct arrow-string dependency --- native/spark-expr/Cargo.toml | 1 - native/spark-expr/src/string_funcs/contains.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/native/spark-expr/Cargo.toml b/native/spark-expr/Cargo.toml index a0476b2a32..7621c0c974 100644 --- a/native/spark-expr/Cargo.toml +++ b/native/spark-expr/Cargo.toml @@ -28,7 +28,6 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } -arrow-string = "57.0.0" chrono = { workspace = true } datafusion = { workspace = true } chrono-tz = { workspace = true } diff --git a/native/spark-expr/src/string_funcs/contains.rs b/native/spark-expr/src/string_funcs/contains.rs index c4662ba9d3..9925319880 100644 --- a/native/spark-expr/src/string_funcs/contains.rs +++ b/native/spark-expr/src/string_funcs/contains.rs @@ -27,7 +27,7 @@ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray}; use arrow::datatypes::DataType; -use arrow_string::like::contains as arrow_contains; +use arrow::compute::kernels::comparison::contains as arrow_contains; use datafusion::common::{exec_err, Result, ScalarValue}; use datafusion::logical_expr::{ ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,