Skip to content

Commit 21ebf69

Browse files
committed
feat: Implement spark_translate function to improve performance of translate expression
1 parent dca45ea commit 21ebf69

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ apache-rat-*.jar
1818
venv
1919
dev/release/comet-rm/workdir
2020
spark/benchmarks
21+
.DS_Store

native/spark-expr/src/comet_scalar_funcs.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use crate::hash_funcs::*;
1919
use crate::math_funcs::abs::abs;
2020
use crate::math_funcs::checked_arithmetic::{checked_add, checked_div, checked_mul, checked_sub};
2121
use crate::math_funcs::modulo_expr::spark_modulo;
22+
use crate::string_funcs::spark_translate;
2223
use crate::{
2324
spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
2425
spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
@@ -181,6 +182,10 @@ pub fn create_comet_physical_fun_with_eval_mode(
181182
let func = Arc::new(abs);
182183
make_comet_scalar_udf!("abs", func, without data_type)
183184
}
185+
"translate" => {
186+
let func = Arc::new(spark_translate);
187+
make_comet_scalar_udf!("translate", func, without data_type)
188+
}
184189
_ => registry.udf(fun_name).map_err(|e| {
185190
DataFusionError::Execution(format!(
186191
"Function {fun_name} not found in the registry: {e}",

native/spark-expr/src/string_funcs/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
mod string_space;
1919
mod substring;
20+
mod translate;
2021

2122
pub use string_space::SparkStringSpace;
2223
pub use substring::SubstringExpr;
24+
pub use translate::spark_translate;
Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,222 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::builder::GenericStringBuilder;
19+
use arrow::array::cast::as_dictionary_array;
20+
use arrow::array::types::Int32Type;
21+
use arrow::array::{make_array, Array, DictionaryArray, OffsetSizeTrait};
22+
use arrow::datatypes::DataType;
23+
use datafusion::common::{cast::as_generic_string_array, DataFusionError, ScalarValue};
24+
use datafusion::physical_plan::ColumnarValue;
25+
use std::collections::HashMap;
26+
use std::sync::Arc;
27+
28+
/// Spark-compatible `translate(input, from, to)` scalar function.
///
/// Each character of `input` that appears in `from` is replaced by the
/// character at the same position in `to`; characters of `from` with no
/// counterpart in `to` are deleted. Supports `Utf8`, `LargeUtf8`, and
/// dictionary-encoded string arrays; `from` and `to` must be non-null
/// `Utf8` scalars.
///
/// Returns `DataFusionError::Internal` for any other argument shape or
/// array data type.
pub fn spark_translate(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    match args {
        [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Utf8(Some(from))), ColumnarValue::Scalar(ScalarValue::Utf8(Some(to)))] =>
        {
            // Build the char -> action table once; it is shared across all rows.
            let translation_map = build_translation_map(from, to);

            match array.data_type() {
                DataType::Utf8 => translate_array_internal::<i32>(array, &translation_map),
                DataType::LargeUtf8 => translate_array_internal::<i64>(array, &translation_map),
                DataType::Dictionary(_, value_type) => {
                    // NOTE(review): this unconditionally casts to Int32 keys even
                    // though the matched key type is ignored — confirm upstream
                    // always produces Int32-keyed dictionaries, otherwise this
                    // cast panics at runtime.
                    let dict = as_dictionary_array::<Int32Type>(array);
                    // Translate only the (deduplicated) dictionary values, then
                    // re-attach the original keys — cheaper than expanding the
                    // dictionary first.
                    let col = if value_type.as_ref() == &DataType::Utf8 {
                        translate_array_internal::<i32>(dict.values(), &translation_map)?
                    } else {
                        translate_array_internal::<i64>(dict.values(), &translation_map)?
                    };
                    // `col` is always the Array variant here, so the row-count
                    // argument of 0 is ignored by `to_array`.
                    let values = col.to_array(0)?;
                    let result = DictionaryArray::try_new(dict.keys().clone(), values)?;
                    Ok(ColumnarValue::Array(make_array(result.into())))
                }
                other => Err(DataFusionError::Internal(format!(
                    "Unsupported data type {other:?} for function translate",
                ))),
            }
        }
        other => Err(DataFusionError::Internal(format!(
            "Unsupported arguments {other:?} for function translate",
        ))),
    }
}
58+
59+
/// Per-character action derived from the `from`/`to` arguments of `translate`.
//
// `Debug`/`PartialEq`/`Eq` are derived so the action table can be inspected
// and asserted on in tests; the enum is two words at most, so `Copy` is free.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TranslateAction {
    /// Substitute the matched character with this one.
    Replace(char),
    /// Drop the matched character (it has no counterpart in `to`).
    Delete,
}
64+
65+
fn build_translation_map(from: &str, to: &str) -> HashMap<char, TranslateAction> {
66+
let from_chars: Vec<char> = from.chars().collect();
67+
let to_chars: Vec<char> = to.chars().collect();
68+
69+
let mut map = HashMap::with_capacity(from_chars.len());
70+
71+
for (i, from_char) in from_chars.into_iter().enumerate() {
72+
// Only insert the first occurrence of each character to match Spark behaviour
73+
if !map.contains_key(&from_char) {
74+
if i < to_chars.len() {
75+
map.insert(from_char, TranslateAction::Replace(to_chars[i]));
76+
} else {
77+
map.insert(from_char, TranslateAction::Delete);
78+
}
79+
}
80+
}
81+
82+
map
83+
}
84+
85+
fn translate_array_internal<T: OffsetSizeTrait>(
86+
array: &Arc<dyn Array>,
87+
translation_map: &HashMap<char, TranslateAction>,
88+
) -> Result<ColumnarValue, DataFusionError> {
89+
let string_array = as_generic_string_array::<T>(array)?;
90+
91+
let estimated_capacity = string_array.len();
92+
let mut builder = GenericStringBuilder::<T>::with_capacity(
93+
estimated_capacity,
94+
string_array.value_data().len(),
95+
);
96+
97+
let mut buffer = String::new();
98+
99+
for string in string_array.iter() {
100+
match string {
101+
Some(s) => {
102+
buffer.clear();
103+
translate_string(&mut buffer, s, translation_map);
104+
builder.append_value(&buffer);
105+
}
106+
None => builder.append_null(),
107+
}
108+
}
109+
110+
Ok(ColumnarValue::Array(Arc::new(builder.finish())))
111+
}
112+
113+
#[inline]
114+
fn translate_string(
115+
buffer: &mut String,
116+
input: &str,
117+
translation_map: &HashMap<char, TranslateAction>,
118+
) {
119+
buffer.reserve(input.len());
120+
121+
for ch in input.chars() {
122+
match translation_map.get(&ch) {
123+
Some(TranslateAction::Replace(replacement)) => buffer.push(*replacement),
124+
Some(TranslateAction::Delete) => {}
125+
None => buffer.push(ch),
126+
}
127+
}
128+
}
129+
130+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::StringArray;

    /// Runs `spark_translate` over `input` with scalar `from`/`to` arguments
    /// and returns the result as a `StringArray`.
    fn run_translate(input: StringArray, from: &str, to: &str) -> StringArray {
        let args = [
            ColumnarValue::Array(Arc::new(input) as Arc<dyn Array>),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(from.to_string()))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(to.to_string()))),
        ];
        match spark_translate(&args).unwrap() {
            ColumnarValue::Array(arr) => {
                arr.as_any().downcast_ref::<StringArray>().unwrap().clone()
            }
            _ => panic!("Expected array result"),
        }
    }

    #[test]
    fn test_translate_basic() {
        let input = StringArray::from(vec![Some("Spark SQL"), Some("hello"), None, Some("")]);
        let result = run_translate(input, "SL", "12");

        assert_eq!(result.value(0), "1park 1Q2");
        assert_eq!(result.value(1), "hello");
        assert!(result.is_null(2));
        assert_eq!(result.value(3), "");
    }

    #[test]
    fn test_translate_with_delete() {
        // When `from` is longer than `to`, extra characters in `from` should be deleted
        let result = run_translate(StringArray::from(vec![Some("abcdef")]), "abcd", "XY");

        // 'a' -> 'X', 'b' -> 'Y', 'c' -> deleted, 'd' -> deleted
        assert_eq!(result.value(0), "XYef");
    }

    #[test]
    fn test_translate_unicode() {
        let result = run_translate(StringArray::from(vec![Some("苹果手机")]), "苹", "1");

        assert_eq!(result.value(0), "1果手机");
    }

    #[test]
    fn test_translate_duplicate_from_chars() {
        // Only the first occurrence of each character in `from` should be used
        let result = run_translate(StringArray::from(vec![Some("aaa")]), "aaa", "xyz");

        // All 'a' should map to 'x' (first mapping wins)
        assert_eq!(result.value(0), "xxx");
    }
}

0 commit comments

Comments
 (0)