From 768b3e90f261c7aea58bdb98dc698b90deeeae34 Mon Sep 17 00:00:00 2001
From: Kazantsev Maksim
Date: Sun, 14 Dec 2025 16:24:01 +0400
Subject: [PATCH 1/4] impl map_from_entries

---
 native/core/src/execution/jni_api.rs          |  2 +
 .../apache/comet/serde/QueryPlanSerde.scala   |  3 +-
 .../scala/org/apache/comet/serde/maps.scala   | 29 +++++++++++-
 .../comet/CometMapExpressionSuite.scala       | 45 +++++++++++++++++++
 4 files changed, 77 insertions(+), 2 deletions(-)

diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index a24d993059..4f53cea3e6 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -46,6 +46,7 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd;
 use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha1::SparkSha1;
 use datafusion_spark::function::hash::sha2::SparkSha2;
+use datafusion_spark::function::map::map_from_entries::MapFromEntries;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::string::char::CharFunc;
 use datafusion_spark::function::string::concat::SparkConcat;
@@ -337,6 +338,7 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
+    session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
 }
 
 /// Prepares arrow arrays for output.
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 54df2f1688..a99cf3824b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -125,7 +125,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapKeys] -> CometMapKeys,
     classOf[MapEntries] -> CometMapEntries,
     classOf[MapValues] -> CometMapValues,
-    classOf[MapFromArrays] -> CometMapFromArrays)
+    classOf[MapFromArrays] -> CometMapFromArrays,
+    classOf[MapFromEntries] -> CometMapFromEntries)
 
   private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[CreateNamedStruct] -> CometCreateNamedStruct,
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 2e217f6af0..498aa3594c 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -19,9 +19,12 @@
 
 package org.apache.comet.serde
 
+import scala.annotation.tailrec
+
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ArrayType, MapType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType}
 
+import org.apache.comet.serde.CometArrayReverse.containsBinary
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
 
 object CometMapKeys extends CometExpressionSerde[MapKeys] {
@@ -89,3 +92,27 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
     optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
   }
 }
+
+object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
+  val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
+  val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"
+
+  private def containsBinary(dataType: DataType): Boolean = {
+    dataType match {
+      case BinaryType => true
+      case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
+      case ArrayType(elementType, _) => containsBinary(elementType)
+      case _ => false
+    }
+  }
+
+  override def getSupportLevel(expr: MapFromEntries): SupportLevel = {
+    if (containsBinary(expr.dataType.keyType)) {
+      return Incompatible(Some(keyUnsupportedReason))
+    }
+    if (containsBinary(expr.dataType.valueType)) {
+      return Incompatible(Some(valueUnsupportedReason))
+    }
+    Compatible(None)
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 88c13391a6..01b9744ed6 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -25,7 +25,9 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.BinaryType
 
+import org.apache.comet.serde.CometMapFromEntries
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
 class CometMapExpressionSuite extends CometTestBase {
@@ -125,4 +127,47 @@ class CometMapExpressionSuite extends CometTestBase {
     }
   }
 
+  test("map_from_entries") {
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaGenOptions =
+          SchemaGenOptions(
+            generateArray = true,
+            generateStruct = true,
+            primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType))
+        val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(
+          random,
+          spark,
+          filename,
+          100,
+          schemaGenOptions,
+          dataGenOptions)
+      }
+      val df = spark.read.parquet(filename)
+      df.createOrReplaceTempView("t1")
+      for (field <- df.schema.fieldNames) {
+        checkSparkAnswerAndOperator(
+          spark.sql(s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1"))
+      }
+    }
+  }
+
+  test("map_from_entries - fallback for binary type") {
+    val table = "t2"
+    withTable(table) {
+      sql(
+        s"create table $table using parquet as select cast(array() as array<binary>) as c1 from range(10)")
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map_from_entries(array(struct(c1, 0))) from $table"),
+        CometMapFromEntries.keyUnsupportedReason)
+      checkSparkAnswerAndFallbackReason(
+        sql(s"select map_from_entries(array(struct(0, c1))) from $table"),
+        CometMapFromEntries.valueUnsupportedReason)
+    }
+  }
+
 }

From c68c3428676b5d991e7ba9e13464bf2ce1ec84e8 Mon Sep 17 00:00:00 2001
From: Kazantsev Maksim
Date: Tue, 16 Dec 2025 16:10:43 +0400
Subject: [PATCH 2/4] Revert "impl map_from_entries"

This reverts commit 768b3e90f261c7aea58bdb98dc698b90deeeae34.
---
 native/core/src/execution/jni_api.rs          |  2 -
 .../apache/comet/serde/QueryPlanSerde.scala   |  3 +-
 .../scala/org/apache/comet/serde/maps.scala   | 29 +-----------
 .../comet/CometMapExpressionSuite.scala       | 45 -------------------
 4 files changed, 2 insertions(+), 77 deletions(-)

diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 4f53cea3e6..a24d993059 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -46,7 +46,6 @@ use datafusion_spark::function::datetime::date_add::SparkDateAdd;
 use datafusion_spark::function::datetime::date_sub::SparkDateSub;
 use datafusion_spark::function::hash::sha1::SparkSha1;
 use datafusion_spark::function::hash::sha2::SparkSha2;
-use datafusion_spark::function::map::map_from_entries::MapFromEntries;
 use datafusion_spark::function::math::expm1::SparkExpm1;
 use datafusion_spark::function::string::char::CharFunc;
 use datafusion_spark::function::string::concat::SparkConcat;
@@ -338,7 +337,6 @@ fn register_datafusion_spark_function(session_ctx: &SessionContext) {
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkSha1::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkConcat::default()));
     session_ctx.register_udf(ScalarUDF::new_from_impl(SparkBitwiseNot::default()));
-    session_ctx.register_udf(ScalarUDF::new_from_impl(MapFromEntries::default()));
 }
 
 /// Prepares arrow arrays for output.
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index a99cf3824b..54df2f1688 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -125,8 +125,7 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[MapKeys] -> CometMapKeys,
     classOf[MapEntries] -> CometMapEntries,
     classOf[MapValues] -> CometMapValues,
-    classOf[MapFromArrays] -> CometMapFromArrays,
-    classOf[MapFromEntries] -> CometMapFromEntries)
+    classOf[MapFromArrays] -> CometMapFromArrays)
 
   private val structExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
     classOf[CreateNamedStruct] -> CometCreateNamedStruct,
diff --git a/spark/src/main/scala/org/apache/comet/serde/maps.scala b/spark/src/main/scala/org/apache/comet/serde/maps.scala
index 498aa3594c..2e217f6af0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/maps.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/maps.scala
@@ -19,12 +19,9 @@
 
 package org.apache.comet.serde
 
-import scala.annotation.tailrec
-
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, MapType, StructType}
+import org.apache.spark.sql.types.{ArrayType, MapType}
 
-import org.apache.comet.serde.CometArrayReverse.containsBinary
 import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, optExprWithInfo, scalarFunctionExprToProto, scalarFunctionExprToProtoWithReturnType}
 
 object CometMapKeys extends CometExpressionSerde[MapKeys] {
@@ -92,27 +89,3 @@ object CometMapFromArrays extends CometExpressionSerde[MapFromArrays] {
     optExprWithInfo(mapFromArraysExpr, expr, expr.children: _*)
   }
 }
-
-object CometMapFromEntries extends CometScalarFunction[MapFromEntries]("map_from_entries") {
-  val keyUnsupportedReason = "Using BinaryType as Map keys is not allowed in map_from_entries"
-  val valueUnsupportedReason = "Using BinaryType as Map values is not allowed in map_from_entries"
-
-  private def containsBinary(dataType: DataType): Boolean = {
-    dataType match {
-      case BinaryType => true
-      case StructType(fields) => fields.exists(field => containsBinary(field.dataType))
-      case ArrayType(elementType, _) => containsBinary(elementType)
-      case _ => false
-    }
-  }
-
-  override def getSupportLevel(expr: MapFromEntries): SupportLevel = {
-    if (containsBinary(expr.dataType.keyType)) {
-      return Incompatible(Some(keyUnsupportedReason))
-    }
-    if (containsBinary(expr.dataType.valueType)) {
-      return Incompatible(Some(valueUnsupportedReason))
-    }
-    Compatible(None)
-  }
-}
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 01b9744ed6..88c13391a6 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -25,9 +25,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.BinaryType
 
-import org.apache.comet.serde.CometMapFromEntries
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
 class CometMapExpressionSuite extends CometTestBase {
@@ -127,47 +125,4 @@ class CometMapExpressionSuite extends CometTestBase {
     }
   }
 
-  test("map_from_entries") {
-    withTempDir { dir =>
-      val path = new Path(dir.toURI.toString, "test.parquet")
-      val filename = path.toString
-      val random = new Random(42)
-      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
-        val schemaGenOptions =
-          SchemaGenOptions(
-            generateArray = true,
-            generateStruct = true,
-            primitiveTypes = SchemaGenOptions.defaultPrimitiveTypes.filterNot(_ == BinaryType))
-        val dataGenOptions = DataGenOptions(allowNull = false, generateNegativeZero = false)
-        ParquetGenerator.makeParquetFile(
-          random,
-          spark,
-          filename,
-          100,
-          schemaGenOptions,
-          dataGenOptions)
-      }
-      val df = spark.read.parquet(filename)
-      df.createOrReplaceTempView("t1")
-      for (field <- df.schema.fieldNames) {
-        checkSparkAnswerAndOperator(
-          spark.sql(s"SELECT map_from_entries(array(struct($field as a, $field as b))) FROM t1"))
-      }
-    }
-  }
-
-  test("map_from_entries - fallback for binary type") {
-    val table = "t2"
-    withTable(table) {
-      sql(
-        s"create table $table using parquet as select cast(array() as array<binary>) as c1 from range(10)")
-      checkSparkAnswerAndFallbackReason(
-        sql(s"select map_from_entries(array(struct(c1, 0))) from $table"),
-        CometMapFromEntries.keyUnsupportedReason)
-      checkSparkAnswerAndFallbackReason(
-        sql(s"select map_from_entries(array(struct(0, c1))) from $table"),
-        CometMapFromEntries.valueUnsupportedReason)
-    }
-  }
-
 }

From 843063241c5beca92b53ea296c7066e11d2948e7 Mon Sep 17 00:00:00 2001
From: Kazantsev Maksim
Date: Sun, 5 Apr 2026 21:19:13 +0400
Subject: [PATCH 3/4] WIP

---
 .../src/array_funcs/array_distinct.rs         | 209 ++++++++++++++++++
 native/spark-expr/src/array_funcs/mod.rs      |   1 +
 2 files changed, 210 insertions(+)
 create mode 100644 native/spark-expr/src/array_funcs/array_distinct.rs

diff --git a/native/spark-expr/src/array_funcs/array_distinct.rs b/native/spark-expr/src/array_funcs/array_distinct.rs
new file mode 100644
index 0000000000..81a277798b
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/array_distinct.rs
@@ -0,0 +1,209 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{as_large_list_array, as_list_array, new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait, UInt32Array, UInt64Array};
+use arrow::datatypes::{DataType, Field, FieldRef};
+use datafusion::common::exec_err;
+use datafusion::common::utils::take_function_args;
+use datafusion::common::Result;
+use datafusion::functions::utils::make_scalar_function;
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility
+};
+use std::any::Any;
+use std::collections::HashSet;
+use std::sync::Arc;
+use arrow::buffer::OffsetBuffer;
+use arrow::compute::take;
+use arrow::row::{Row, RowConverter, SortField};
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkArrayDistinct {
+    signature: Signature,
+}
+
+impl Default for SparkArrayDistinct {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkArrayDistinct {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::uniform(
+                1,
+                vec![
+                    List(Arc::new(Field::new("item", Null, true))),
+                    LargeList(Arc::new(Field::new("item", Null, true))),
+                    FixedSizeList(Arc::new(Field::new("item", Null, true)), -1),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkArrayDistinct {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "array_distinct"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+        Ok(arg_types[0].clone())
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(array_distinct_inner, vec![])(&args.args)
+    }
+}
+
+fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array] = take_function_args("array_distinct", args)?;
+    match array.data_type() {
+        DataType::List(field) => {
+            let array = as_list_array(array);
+            general_array_distinct(array, field)
+        }
+        DataType::LargeList(field) => {
+            let array = as_large_list_array(array);
+            general_array_distinct(array, field)
+        }
+        _ => {
+            exec_err!("array_distinct function only supports arrays, got: {:?}", array.data_type())
+        }
+    }
+}
+
+fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
+    array: &GenericListArray<OffsetSize>,
+    field: &FieldRef,
+) -> Result<ArrayRef> {
+    if array.is_empty() {
+        return Ok(Arc::new(array.clone()) as ArrayRef);
+    }
+    let value_offsets = array.value_offsets();
+    let dt = array.value_type();
+    let mut offsets = Vec::with_capacity(array.len() + 1);
+    offsets.push(OffsetSize::usize_as(0));
+
+    let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
+
+    let first_offset = value_offsets[0].as_usize();
+    let visible_len = value_offsets[array.len()].as_usize() - first_offset;
+    let rows =
+        converter.convert_columns(&[array.values().slice(first_offset, visible_len)])?;
+
+    let mut indices: Vec<usize> = Vec::with_capacity(rows.num_rows());
+    let mut seen: HashSet<Row<'_>> = HashSet::new();
+
+    for i in 0..array.len() {
+        let last_offset = *offsets.last().unwrap();
+
+        if array.is_null(i) {
+            offsets.push(last_offset);
+            continue;
+        }
+
+        let start = value_offsets[i].as_usize() - first_offset;
+        let end = value_offsets[i + 1].as_usize() - first_offset;
+
+        seen.clear();
+        seen.reserve(end - start);
+
+        let mut seen_null = false;
+        let mut distinct_count: usize = 0;
+
+        for idx in start..end {
+            let abs_idx = idx + first_offset;
+
+            if array.values().is_null(abs_idx) {
+                if !seen_null {
+                    seen_null = true;
+                    indices.push(abs_idx);
+                    distinct_count += 1;
+                }
+            } else {
+                let row = rows.row(idx);
+                if seen.insert(row) {
+                    indices.push(abs_idx);
+                    distinct_count += 1;
+                }
+            }
+        }
+
+        offsets.push(last_offset + OffsetSize::usize_as(distinct_count));
+    }
+
+    let final_values = if indices.is_empty() {
+        new_empty_array(&dt)
+    } else if OffsetSize::IS_LARGE {
+        let indices =
+            UInt64Array::from(indices.into_iter().map(|i| i as u64).collect::<Vec<_>>());
+        take(array.values().as_ref(), &indices, None)?
+    } else {
+        let indices =
+            UInt32Array::from(indices.into_iter().map(|i| i as u32).collect::<Vec<_>>());
+        take(array.values().as_ref(), &indices, None)?
+    };
+
+    Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
+        Arc::clone(field),
+        OffsetBuffer::new(offsets.into()),
+        final_values,
+        array.nulls().cloned(),
+    )?))
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+    use arrow::array::{ArrayRef, Int32Array, ListArray, NullBufferBuilder};
+    use arrow::datatypes::{DataType, Field};
+    use crate::array_funcs::array_distinct::array_distinct_inner;
+
+    #[test]
+    fn test_spark_distinct() {
+        let values = Int32Array::from(vec![4, 1, 2, 1, 3, 4, 5, 6, 0, 0, 0]);
+        let value_offsets = arrow::buffer::OffsetBuffer::new(vec![0, 10].into());
+        let field = Arc::new(Field::new("item", DataType::Int32, true));
+        let mut null_buffer = NullBufferBuilder::new(1);
+        null_buffer.append(true);
+
+        let list_array = ListArray::try_new(
+            field,
+            value_offsets,
+            Arc::new(values),
+            null_buffer.finish(),
+        ).unwrap();
+
+        let array_ref: ArrayRef = Arc::new(list_array);
+        let result = array_distinct_inner(&[array_ref]).unwrap();
+        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
+
+        assert_eq!(result.value(0).len(), 7);
+    }
+}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index 3ef50a252f..980d9dc5df 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -19,6 +19,7 @@ mod array_insert;
 mod get_array_struct_fields;
 mod list_extract;
 mod size;
+mod array_distinct;
 
 pub use array_insert::ArrayInsert;
 pub use get_array_struct_fields::GetArrayStructFields;

From 17cbf4fcc5e016643bdf1f086edab33545d734be Mon Sep 17 00:00:00 2001
From: Kazantsev Maksim
Date: Sun, 12 Apr 2026 21:30:33 +0400
Subject: [PATCH 4/4] Feat: support Spark-compatible array_distinct

---
 .../src/array_funcs/array_distinct.rs         | 79 +++++--------------
 native/spark-expr/src/array_funcs/mod.rs      |  3 +-
 native/spark-expr/src/comet_scalar_funcs.rs   |  5 +-
 .../scala/org/apache/comet/serde/arrays.scala | 17 +----
 .../expressions/array/array_distinct.sql      | 43 +++++-----
 5 files changed, 49 insertions(+), 98 deletions(-)

diff --git a/native/spark-expr/src/array_funcs/array_distinct.rs b/native/spark-expr/src/array_funcs/array_distinct.rs
index 81a277798b..4f2e358cd8 100644
--- a/native/spark-expr/src/array_funcs/array_distinct.rs
+++ b/native/spark-expr/src/array_funcs/array_distinct.rs
@@ -15,21 +15,21 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::array::{as_large_list_array, as_list_array, new_empty_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait, UInt32Array, UInt64Array};
+use arrow::array::{
+    as_large_list_array, as_list_array, Array, ArrayRef, GenericListArray, OffsetSizeTrait,
+};
+use arrow::buffer::OffsetBuffer;
 use arrow::datatypes::{DataType, Field, FieldRef};
-use datafusion::common::exec_err;
+use arrow::row::{Row, RowConverter, SortField};
 use datafusion::common::utils::take_function_args;
 use datafusion::common::Result;
+use datafusion::common::{exec_err, HashSet};
 use datafusion::functions::utils::make_scalar_function;
 use datafusion::logical_expr::{
-    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
 };
 use std::any::Any;
-use std::collections::HashSet;
 use std::sync::Arc;
-use arrow::buffer::OffsetBuffer;
-use arrow::compute::take;
-use arrow::row::{Row, RowConverter, SortField};
 
 #[derive(Debug, Hash, Eq, PartialEq)]
 pub struct SparkArrayDistinct {
@@ -51,7 +51,6 @@ impl SparkArrayDistinct {
             vec![
                 List(Arc::new(Field::new("item", Null, true))),
                 LargeList(Arc::new(Field::new("item", Null, true))),
-                FixedSizeList(Arc::new(Field::new("item", Null, true)), -1),
             ],
             Volatility::Immutable,
         ),
@@ -93,7 +92,10 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
             general_array_distinct(array, field)
         }
         _ => {
-            exec_err!("array_distinct function only supports arrays, got: {:?}", array.data_type())
+            exec_err!(
+                "array_distinct function only supports arrays, got: {:?}",
+                array.data_type()
+            )
         }
     }
 }
@@ -105,7 +107,9 @@ fn general_array_distinct(
     if array.is_empty() {
         return Ok(Arc::new(array.clone()) as ArrayRef);
     }
+
     let value_offsets = array.value_offsets();
+    let original_data = array.values().to_data();
     let dt = array.value_type();
     let mut offsets = Vec::with_capacity(array.len() + 1);
     offsets.push(OffsetSize::usize_as(0));
@@ -114,11 +118,9 @@ fn general_array_distinct(
 
     let first_offset = value_offsets[0].as_usize();
     let visible_len = value_offsets[array.len()].as_usize() - first_offset;
-    let rows =
-        converter.convert_columns(&[array.values().slice(first_offset, visible_len)])?;
+    let rows = converter.convert_columns(&[array.values().slice(first_offset, visible_len)])?;
 
-    let mut indices: Vec<usize> = Vec::with_capacity(rows.num_rows());
-    let mut seen: HashSet<Row<'_>> = HashSet::new();
+    let mut mutable = arrow::array::MutableArrayData::new(vec![&original_data], false, visible_len);
 
     for i in 0..array.len() {
         let last_offset = *offsets.last().unwrap();
@@ -130,10 +132,9 @@ fn general_array_distinct(
 
         let start = value_offsets[i].as_usize() - first_offset;
         let end = value_offsets[i + 1].as_usize() - first_offset;
+        let array_len = end - start;
 
-        seen.clear();
-        seen.reserve(end - start);
-
+        let mut seen: HashSet<Row<'_>> = HashSet::with_capacity(array_len);
         let mut seen_null = false;
         let mut distinct_count: usize = 0;
 
@@ -143,13 +144,13 @@ fn general_array_distinct(
             if array.values().is_null(abs_idx) {
                 if !seen_null {
                     seen_null = true;
-                    indices.push(abs_idx);
+                    mutable.extend(0, abs_idx, abs_idx + 1);
                     distinct_count += 1;
                 }
             } else {
                 let row = rows.row(idx);
                 if seen.insert(row) {
-                    indices.push(abs_idx);
+                    mutable.extend(0, abs_idx, abs_idx + 1);
                     distinct_count += 1;
                 }
             }
@@ -158,17 +159,7 @@ fn general_array_distinct(
         offsets.push(last_offset + OffsetSize::usize_as(distinct_count));
     }
 
-    let final_values = if indices.is_empty() {
-        new_empty_array(&dt)
-    } else if OffsetSize::IS_LARGE {
-        let indices =
-            UInt64Array::from(indices.into_iter().map(|i| i as u64).collect::<Vec<_>>());
-        take(array.values().as_ref(), &indices, None)?
-    } else {
-        let indices =
-            UInt32Array::from(indices.into_iter().map(|i| i as u32).collect::<Vec<_>>());
-        take(array.values().as_ref(), &indices, None)?
-    };
+    let final_values = arrow::array::make_array(mutable.freeze());
 
     Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
         Arc::clone(field),
@@ -177,33 +168,3 @@ fn general_array_distinct(
         array.nulls().cloned(),
     )?))
 }
-
-#[cfg(test)]
-mod tests {
-    use std::sync::Arc;
-    use arrow::array::{ArrayRef, Int32Array, ListArray, NullBufferBuilder};
-    use arrow::datatypes::{DataType, Field};
-    use crate::array_funcs::array_distinct::array_distinct_inner;
-
-    #[test]
-    fn test_spark_distinct() {
-        let values = Int32Array::from(vec![4, 1, 2, 1, 3, 4, 5, 6, 0, 0, 0]);
-        let value_offsets = arrow::buffer::OffsetBuffer::new(vec![0, 10].into());
-        let field = Arc::new(Field::new("item", DataType::Int32, true));
-        let mut null_buffer = NullBufferBuilder::new(1);
-        null_buffer.append(true);
-
-        let list_array = ListArray::try_new(
-            field,
-            value_offsets,
-            Arc::new(values),
-            null_buffer.finish(),
-        ).unwrap();
-
-        let array_ref: ArrayRef = Arc::new(list_array);
-        let result = array_distinct_inner(&[array_ref]).unwrap();
-        let result = result.as_any().downcast_ref::<ListArray>().unwrap();
-
-        assert_eq!(result.value(0).len(), 7);
-    }
-}
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index 36f838da3d..849968cdd9 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -16,13 +16,14 @@
 // under the License.
 
 mod array_compact;
+mod array_distinct;
 mod array_insert;
 mod get_array_struct_fields;
 mod list_extract;
 mod size;
-mod array_distinct;
 
 pub use array_compact::SparkArrayCompact;
+pub use array_distinct::SparkArrayDistinct;
 pub use array_insert::ArrayInsert;
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index aae5a09095..05ee01dac2 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -23,8 +23,8 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_isnan,
     spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
-    spark_unscaled_value, EvalMode, SparkArrayCompact, SparkContains, SparkDateDiff,
-    SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
+    spark_unscaled_value, EvalMode, SparkArrayCompact, SparkArrayDistinct, SparkContains,
+    SparkDateDiff, SparkDateTrunc, SparkMakeDate, SparkSizeFunc,
 };
 use arrow::datatypes::DataType;
 use datafusion::common::{DataFusionError, Result as DataFusionResult};
@@ -206,6 +206,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkMakeDate::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkArrayDistinct::default())),
     ]
 }
 
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index f107d5b309..c23ee67be0 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -119,22 +119,7 @@ object CometArrayContains extends CometExpressionSerde[ArrayContains] {
   }
 }
 
-object CometArrayDistinct extends CometExpressionSerde[ArrayDistinct] {
-
-  override def getSupportLevel(expr: ArrayDistinct): SupportLevel =
-    Incompatible(Some("Output elements are sorted rather than preserving insertion order"))
-
-  override def convert(
-      expr: ArrayDistinct,
-      inputs: Seq[Attribute],
-      binding: Boolean): Option[ExprOuterClass.Expr] = {
-    val arrayExprProto = exprToProto(expr.children.head, inputs, binding)
-
-    val arrayDistinctScalarExpr =
-      scalarFunctionExprToProto("array_distinct", arrayExprProto)
-    optExprWithInfo(arrayDistinctScalarExpr, expr)
-  }
-}
+object CometArrayDistinct extends CometScalarFunction[ArrayDistinct]("array_distinct")
 
 object CometArrayIntersect extends CometExpressionSerde[ArrayIntersect] {
 
diff --git a/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql b/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql
index f9d63df075..77b8feef61 100644
--- a/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql
+++ b/spark/src/test/resources/sql-tests/expressions/array/array_distinct.sql
@@ -34,23 +34,26 @@ INSERT INTO test_array_distinct_int VALUES
   (array(0, -1, -1, 0, 1))
 
 -- column argument
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_int
 
 -- literal arguments
-query spark_answer_only
+query
 SELECT array_distinct(array(1, 2, 2, 3, 3))
 
+query
+SELECT array_distinct(array(3, 2, 2, 1, 1))
+
 -- all NULLs
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST(NULL AS INT), CAST(NULL AS INT)))
 
 -- NULL input
-query spark_answer_only
+query
 SELECT array_distinct(CAST(NULL AS array<int>))
 
 -- boundary values
-query spark_answer_only
+query
 SELECT array_distinct(array(-2147483648, 2147483647, -2147483648, 2147483647, 0))
 
 -- ===== LONG arrays =====
@@ -65,11 +68,11 @@ INSERT INTO test_array_distinct_long VALUES
   (array(NULL, 1, NULL, 2)),
   (array(-9223372036854775808, 9223372036854775807, -9223372036854775808))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_long
 
 -- boundary values
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST(-9223372036854775808 AS BIGINT), CAST(9223372036854775807 AS BIGINT), CAST(-9223372036854775808 AS BIGINT)))
 
 -- ===== STRING arrays =====
@@ -86,11 +89,11 @@ INSERT INTO test_array_distinct_string VALUES
   (array('', '', NULL, '')),
   (array('hello', 'world', 'hello'))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_string
 
 -- empty string and NULL distinction
-query spark_answer_only
+query
 SELECT array_distinct(array('', NULL, '', NULL, 'a'))
 
 -- ===== BOOLEAN arrays =====
@@ -105,7 +108,7 @@ INSERT INTO test_array_distinct_bool VALUES
   (NULL),
   (array(NULL, true, NULL, false))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_bool
 
 -- ===== DOUBLE arrays =====
@@ -119,23 +122,23 @@ INSERT INTO test_array_distinct_double VALUES
   (NULL),
   (array(NULL, 1.0, NULL, 2.0))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_double
 
 -- NaN deduplication
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS DOUBLE), CAST('NaN' AS DOUBLE), 1.0, 1.0))
 
 -- NaN with NULL
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS DOUBLE), NULL, CAST('NaN' AS DOUBLE), NULL, 1.0))
 
 -- Infinity
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), CAST('Infinity' AS DOUBLE), 0.0))
 
 -- negative zero
-query spark_answer_only
+query
 SELECT array_distinct(array(0.0, -0.0, 1.0))
 
 -- ===== FLOAT arrays =====
@@ -149,11 +152,11 @@ INSERT INTO test_array_distinct_float VALUES
   (NULL),
   (array(CAST(NULL AS FLOAT), CAST(1.0 AS FLOAT), CAST(NULL AS FLOAT)))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_float
 
 -- Float NaN deduplication
-query spark_answer_only
+query
 SELECT array_distinct(array(CAST('NaN' AS FLOAT), CAST('NaN' AS FLOAT), CAST(1.0 AS FLOAT)))
 
 -- ===== DECIMAL arrays =====
@@ -167,13 +170,13 @@ INSERT INTO test_array_distinct_decimal VALUES
   (NULL),
   (array(NULL, 1.10, NULL, 1.10))
 
-query spark_answer_only
+query
 SELECT array_distinct(arr) FROM test_array_distinct_decimal
 
 -- ===== Nested array (array of arrays) =====
-query spark_answer_only
+query
 SELECT array_distinct(array(array(1, 2), array(3, 4), array(1, 2), array(3, 4)))
 
-query spark_answer_only
+query
 SELECT array_distinct(array(array(1, 2), CAST(NULL AS array<int>), array(1, 2), CAST(NULL AS array<int>)))
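
--
A note on the final state of the series: PATCH 4/4 removes the module's inline
unit test along with the take()-based rebuild it exercised. The standalone
sketch below illustrates the technique the final kernel keeps: arrow's
RowConverter turns child values into opaque, hashable Row keys, so a single
code path can deduplicate any element type (nested types included) while
preserving first-occurrence order, unlike a sort-based distinct. The main()
harness and toy data are illustrative only and do not appear in the patches;
only arrow crate APIs are used.

    use std::collections::HashSet;
    use std::sync::Arc;

    use arrow::array::{Int32Array, ListArray};
    use arrow::buffer::OffsetBuffer;
    use arrow::datatypes::{DataType, Field};
    use arrow::row::{RowConverter, SortField};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // One list row: [4, 1, 2, 1, 3, 4] -- the same shape of data the
        // removed unit test used.
        let values = Int32Array::from(vec![4, 1, 2, 1, 3, 4]);
        let field = Arc::new(Field::new("item", DataType::Int32, true));
        let list = ListArray::try_new(
            field,
            OffsetBuffer::new(vec![0, 6].into()),
            Arc::new(values),
            None,
        )?;

        // Convert the child values into hashable rows. This is what makes
        // the dedup kernel agnostic to the element type.
        let converter = RowConverter::new(vec![SortField::new(DataType::Int32)])?;
        let rows = converter.convert_columns(&[list.values().clone()])?;

        // Keep the index of the first occurrence of each distinct row;
        // insertion order is preserved.
        let mut seen = HashSet::new();
        let kept: Vec<usize> = (0..rows.num_rows())
            .filter(|&i| seen.insert(rows.row(i)))
            .collect();

        assert_eq!(kept, vec![0, 1, 2, 4]); // values 4, 1, 2, 3
        println!("distinct indices: {kept:?}");
        Ok(())
    }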