diff --git a/native/spark-expr/src/array_funcs/array_distinct.rs b/native/spark-expr/src/array_funcs/array_distinct.rs new file mode 100644 index 0000000000..4f2e358cd8 --- /dev/null +++ b/native/spark-expr/src/array_funcs/array_distinct.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::array::{ + as_large_list_array, as_list_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{DataType, Field, FieldRef}; +use arrow::row::{Row, RowConverter, SortField}; +use datafusion::common::utils::take_function_args; +use datafusion::common::Result; +use datafusion::common::{exec_err, HashSet}; +use datafusion::functions::utils::make_scalar_function; +use datafusion::logical_expr::{ + ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, +}; +use std::any::Any; +use std::sync::Arc; + +#[derive(Debug, Hash, Eq, PartialEq)] +pub struct SparkArrayDistinct { + signature: Signature, +} + +impl Default for SparkArrayDistinct { + fn default() -> Self { + Self::new() + } +} + +impl SparkArrayDistinct { + pub fn new() -> Self { + use DataType::*; + Self { + signature: Signature::uniform( + 1, + vec![ + List(Arc::new(Field::new("item", Null, true))), + LargeList(Arc::new(Field::new("item", Null, true))), + ], + Volatility::Immutable, + ), + } + } +} + +impl ScalarUDFImpl for SparkArrayDistinct { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_distinct" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_distinct_inner, vec![])(&args.args) + } +} + +fn array_distinct_inner(args: &[ArrayRef]) -> Result { + let [array] = take_function_args("array_distinct", args)?; + match array.data_type() { + DataType::List(field) => { + let array = as_list_array(array); + general_array_distinct(array, field) + } + DataType::LargeList(field) => { + let array = as_large_list_array(array); + general_array_distinct(array, field) + } + _ => { + exec_err!( + "array_distinct function only support arrays, got: {:?}", + array.data_type() + ) + } + } +} + +fn 
general_array_distinct( + array: &GenericListArray, + field: &FieldRef, +) -> Result { + if array.is_empty() { + return Ok(Arc::new(array.clone()) as ArrayRef); + } + + let value_offsets = array.value_offsets(); + let original_data = array.values().to_data(); + let dt = array.value_type(); + let mut offsets = Vec::with_capacity(array.len() + 1); + offsets.push(OffsetSize::usize_as(0)); + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + + let first_offset = value_offsets[0].as_usize(); + let visible_len = value_offsets[array.len()].as_usize() - first_offset; + let rows = converter.convert_columns(&[array.values().slice(first_offset, visible_len)])?; + + let mut mutable = arrow::array::MutableArrayData::new(vec![&original_data], false, visible_len); + + for i in 0..array.len() { + let last_offset = *offsets.last().unwrap(); + + if array.is_null(i) { + offsets.push(last_offset); + continue; + } + + let start = value_offsets[i].as_usize() - first_offset; + let end = value_offsets[i + 1].as_usize() - first_offset; + let array_len = end - start; + + let mut seen: HashSet> = HashSet::with_capacity(array_len); + let mut seen_null = false; + let mut distinct_count: usize = 0; + + for idx in start..end { + let abs_idx = idx + first_offset; + + if array.values().is_null(abs_idx) { + if !seen_null { + seen_null = true; + mutable.extend(0, abs_idx, abs_idx + 1); + distinct_count += 1; + } + } else { + let row = rows.row(idx); + if seen.insert(row) { + mutable.extend(0, abs_idx, abs_idx + 1); + distinct_count += 1; + } + } + } + + offsets.push(last_offset + OffsetSize::usize_as(distinct_count)); + } + + let final_values = arrow::array::make_array(mutable.freeze()); + + Ok(Arc::new(GenericListArray::::try_new( + Arc::clone(field), + OffsetBuffer::new(offsets.into()), + final_values, + array.nulls().cloned(), + )?)) +} diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs index 2bd1b9631b..849968cdd9 100644 --- 
a/native/spark-expr/src/array_funcs/mod.rs +++ b/native/spark-expr/src/array_funcs/mod.rs @@ -16,12 +16,14 @@ // under the License. mod array_compact; +mod array_distinct; mod array_insert; mod get_array_struct_fields; mod list_extract; mod size; pub use array_compact::SparkArrayCompact; +pub use array_distinct::SparkArrayDistinct; pub use array_insert::ArrayInsert; pub use get_array_struct_fields::GetArrayStructFields; pub use list_extract::ListExtract; diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs index aae5a09095..05ee01dac2 100644 --- a/native/spark-expr/src/comet_scalar_funcs.rs +++ b/native/spark-expr/src/comet_scalar_funcs.rs @@ -23,8 +23,8 @@ use crate::math_funcs::modulo_expr::spark_modulo; use crate::{ spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex, - spark_unscaled_value, EvalMode, SparkArrayCompact, SparkContains, SparkDateDiff, - SparkDateTrunc, SparkMakeDate, SparkSizeFunc, + spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayDistinct, SparkContains, + SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc, }; use arrow::datatypes::DataType; use datafusion::common::{DataFusionError, Result as DataFusionResult}; @@ -206,6 +206,7 @@ fn all_scalar_functions() -> Vec> { Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())), Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())), Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())), + Arc::new(ScalarUDF::new_from_impl(SparkArrayDistinct::default())), ] } diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index f107d5b309..c23ee67be0 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -119,22 +119,7 @@ object 
CometArrayContains extends CometExpressionSerde[ArrayContains] { } } -object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] { - - override def getSupportLevel(expr: ArrayDistinct): SupportLevel = - Incompatible(Some("Output elements are sorted rather than preserving insertion order")) - - override def convert( - expr: ArrayDistinct, - inputs: Seq[Attribute], - binding: Boolean): Option[ExprOuterClass.Expr] = { - val arrayExprProto = exprToProto(expr.children.head, inputs, binding) - - val arrayDistinctScalarExpr = - scalarFunctionExprToProto("array_distinct", arrayExprProto) - optExprWithInfo(arrayDistinctScalarExpr, expr) - } -} +object CometArrayDistinct extends CometScalarFunction[ArrayDistinct]("array_distinct") object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] { diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql b/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql index f9d63df075..77b8feef61 100644 --- a/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql +++ b/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql @@ -34,23 +34,26 @@ INSERT INTO test_array_distinct_int VALUES (array(0, -1, -1, 0, 1)) -- column argument -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_int -- literal arguments -query spark_answer_only +query SELECT array_distinct(array(1, 2, 2, 3, 3)) +query +SELECT array_distinct(array(3, 2, 2, 1, 1)) + -- all NULLs -query spark_answer_only +query SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT))) -- NULL input -query spark_answer_only +query SELECT array_distinct(CAST(NULL AS array)) -- boundary values -query spark_answer_only +query SELECT array_distinct(array(-2147483648, 2147483647, -2147483648, 2147483647, 0)) -- ===== LONG arrays ===== @@ -65,11 +68,11 @@ INSERT INTO test_array_distinct_long VALUES (array(NULL, 1, NULL, 2)), (array(-9223372036854775808, 
9223372036854775807, -9223372036854775808)) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_long -- boundary values -query spark_answer_only +query SELECT array_distinct(array(CAST(-9223372036854775808 AS BIGINT), CAST(9223372036854775807 AS BIGINT), CAST(-9223372036854775808 AS BIGINT))) -- ===== STRING arrays ===== @@ -86,11 +89,11 @@ INSERT INTO test_array_distinct_string VALUES (array('', '', NULL, '')), (array('hello', 'world', 'hello')) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_string -- empty string and NULL distinction -query spark_answer_only +query SELECT array_distinct(array('', NULL, '', NULL, 'a')) -- ===== BOOLEAN arrays ===== @@ -105,7 +108,7 @@ INSERT INTO test_array_distinct_bool VALUES (NULL), (array(NULL, true, NULL, false)) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_bool -- ===== DOUBLE arrays ===== @@ -119,23 +122,23 @@ INSERT INTO test_array_distinct_double VALUES (NULL), (array(NULL, 1.0, NULL, 2.0)) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_double -- NaN deduplication -query spark_answer_only +query SELECT array_distinct(array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), 1.0, 1.0)) -- NaN with NULL -query spark_answer_only +query SELECT array_distinct(array(CAST('NaN' AS DOUBLE), NULL, CAST('NaN' AS DOUBLE), NULL, 1.0)) -- Infinity -query spark_answer_only +query SELECT array_distinct(array(CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('Infinity' AS DOUBLE), 0.0)) -- negative zero -query spark_answer_only +query SELECT array_distinct(array(0.0, -0.0, 1.0)) -- ===== FLOAT arrays ===== @@ -149,11 +152,11 @@ INSERT INTO test_array_distinct_float VALUES (NULL), (array(CAST(NULL AS FLOAT), CAST(1.0 AS FLOAT), CAST(NULL AS FLOAT))) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_float -- Float NaN deduplication -query spark_answer_only 
+query SELECT array_distinct(array(CAST('NaN' AS FLOAT), CAST('NaN' AS FLOAT), CAST(1.0 AS FLOAT))) -- ===== DECIMAL arrays ===== @@ -167,13 +170,13 @@ INSERT INTO test_array_distinct_decimal VALUES (NULL), (array(NULL, 1.10, NULL, 1.10)) -query spark_answer_only +query SELECT array_distinct(arr) FROM test_array_distinct_decimal -- ===== Nested array (array of arrays) ===== -query spark_answer_only +query SELECT array_distinct(array(array(1, 2), array(3, 4), array(1, 2), array(3, 4))) -query spark_answer_only +query SELECT array_distinct(array(array(1, 2), CAST(NULL AS array<int>), array(1, 2), CAST(NULL AS array<int>)))