diff --git a/.github/workflows/spark_sql_test.yml b/.github/workflows/spark_sql_test.yml
index d143ef83a0..3d7aa2e2f9 100644
--- a/.github/workflows/spark_sql_test.yml
+++ b/.github/workflows/spark_sql_test.yml
@@ -59,6 +59,10 @@ jobs:
           - {name: "sql_hive-1", args1: "", args2: "hive/testOnly * -- -l org.apache.spark.tags.ExtendedHiveTest -l org.apache.spark.tags.SlowHiveTest"}
           - {name: "sql_hive-2", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.ExtendedHiveTest"}
           - {name: "sql_hive-3", args1: "", args2: "hive/testOnly * -- -n org.apache.spark.tags.SlowHiveTest"}
+        # Skip sql_hive-1 for Spark 4.0 due to https://github.com/apache/datafusion-comet/issues/2946
+        exclude:
+          - spark-version: {short: '4.0', full: '4.0.1', java: 17}
+            module: {name: "sql_hive-1", args1: "", args2: "hive/testOnly * -- -l org.apache.spark.tags.ExtendedHiveTest -l org.apache.spark.tags.SlowHiveTest"}
       fail-fast: false
     name: spark-sql-${{ matrix.module.name }}/${{ matrix.os }}/spark-${{ matrix.spark-version.full }}/java-${{ matrix.spark-version.java }}
     runs-on: ${{ matrix.os }}
diff --git a/docs/source/user-guide/latest/configs.md b/docs/source/user-guide/latest/configs.md
index 13a9c752e3..1a273ad033 100644
--- a/docs/source/user-guide/latest/configs.md
+++ b/docs/source/user-guide/latest/configs.md
@@ -307,6 +307,7 @@ These settings can be used to determine which parts of the plan are accelerated
 | `spark.comet.expression.Signum.enabled` | Enable Comet acceleration for `Signum` | true |
 | `spark.comet.expression.Sin.enabled` | Enable Comet acceleration for `Sin` | true |
 | `spark.comet.expression.Sinh.enabled` | Enable Comet acceleration for `Sinh` | true |
+| `spark.comet.expression.Size.enabled` | Enable Comet acceleration for `Size` | true |
 | `spark.comet.expression.SortOrder.enabled` | Enable Comet acceleration for `SortOrder` | true |
 | `spark.comet.expression.SparkPartitionID.enabled` | Enable Comet acceleration for `SparkPartitionID` | true |
 | `spark.comet.expression.Sqrt.enabled` | Enable Comet acceleration for `Sqrt` | true |
diff --git a/native/spark-expr/src/array_funcs/mod.rs b/native/spark-expr/src/array_funcs/mod.rs
index cdfb1f4db4..063dd7a5aa 100644
--- a/native/spark-expr/src/array_funcs/mod.rs
+++ b/native/spark-expr/src/array_funcs/mod.rs
@@ -19,8 +19,10 @@ mod array_insert;
 mod array_repeat;
 mod get_array_struct_fields;
 mod list_extract;
+mod size;
 
 pub use array_insert::ArrayInsert;
 pub use array_repeat::spark_array_repeat;
 pub use get_array_struct_fields::GetArrayStructFields;
 pub use list_extract::ListExtract;
+pub use size::{spark_size, SparkSizeFunc};
diff --git a/native/spark-expr/src/array_funcs/size.rs b/native/spark-expr/src/array_funcs/size.rs
new file mode 100644
index 0000000000..9777553341
--- /dev/null
+++ b/native/spark-expr/src/array_funcs/size.rs
@@ -0,0 +1,419 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{
+    Array, ArrayRef, FixedSizeListArray, Int32Array, LargeListArray, ListArray, MapArray,
+};
+use arrow::datatypes::{DataType, Field};
+use datafusion::common::{exec_err, DataFusionError, Result as DataFusionResult, ScalarValue};
+use datafusion::logical_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Spark `size()` function that returns the number of elements in an array or
+/// the number of key/value pairs in a map.
+/// Returns -1 for a null input, matching Spark's default behavior
+/// (`spark.sql.legacy.sizeOfNull=true`).
+pub fn spark_size(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+    if args.len() != 1 {
+        return exec_err!("size function takes exactly one argument");
+    }
+
+    match &args[0] {
+        ColumnarValue::Array(array) => {
+            let result = spark_size_array(array)?;
+            Ok(ColumnarValue::Array(result))
+        }
+        ColumnarValue::Scalar(scalar) => {
+            let result = spark_size_scalar(scalar)?;
+            Ok(ColumnarValue::Scalar(result))
+        }
+    }
+}
+
+#[derive(Debug, Hash, Eq, PartialEq)]
+pub struct SparkSizeFunc {
+    signature: Signature,
+}
+
+impl Default for SparkSizeFunc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkSizeFunc {
+    pub fn new() -> Self {
+        use DataType::*;
+        Self {
+            signature: Signature::uniform(
+                1,
+                vec![
+                    List(Arc::new(Field::new("item", Null, true))),
+                    LargeList(Arc::new(Field::new("item", Null, true))),
+                    FixedSizeList(Arc::new(Field::new("item", Null, true)), -1),
+                    Map(Arc::new(Field::new("entries", Null, true)), false),
+                ],
+                Volatility::Immutable,
+            ),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkSizeFunc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "size"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DataFusionResult<DataType> {
+        Ok(DataType::Int32)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DataFusionResult<ColumnarValue> {
+        spark_size(&args.args)
+    }
+}
+
+fn spark_size_array(array: &ArrayRef) -> Result<ArrayRef, DataFusionError> {
+    let mut builder = Int32Array::builder(array.len());
+
+    match array.data_type() {
+        DataType::List(_) => {
+            let list_array = array
+                .as_any()
+                .downcast_ref::<ListArray>()
+                .ok_or_else(|| DataFusionError::Internal("Expected ListArray".to_string()))?;
+
+            for i in 0..list_array.len() {
+                if list_array.is_null(i) {
+                    builder.append_value(-1); // Spark behavior: return -1 for null
+                } else {
+                    let list_len = list_array.value(i).len() as i32;
+                    builder.append_value(list_len);
+                }
+            }
+        }
+        DataType::LargeList(_) => {
+            let list_array = array
+                .as_any()
+                .downcast_ref::<LargeListArray>()
+                .ok_or_else(|| DataFusionError::Internal("Expected LargeListArray".to_string()))?;
+
+            for i in 0..list_array.len() {
+                if list_array.is_null(i) {
+                    builder.append_value(-1); // Spark behavior: return -1 for null
+                } else {
+                    let list_len = list_array.value(i).len() as i32;
+                    builder.append_value(list_len);
+                }
+            }
+        }
+        DataType::FixedSizeList(_, size) => {
+            let fixed_list_array = array
+                .as_any()
+                .downcast_ref::<FixedSizeListArray>()
+                .ok_or_else(|| {
+                    DataFusionError::Internal("Expected FixedSizeListArray".to_string())
+                })?;
+
+            for i in 0..fixed_list_array.len() {
+                if fixed_list_array.is_null(i) {
+                    builder.append_value(-1); // Spark behavior: return -1 for null
+                } else {
+                    builder.append_value(*size);
+                }
+            }
+        }
+        DataType::Map(_, _) => {
+            let map_array = array
+                .as_any()
+                .downcast_ref::<MapArray>()
+                .ok_or_else(|| DataFusionError::Internal("Expected MapArray".to_string()))?;
+
+            for i in 0..map_array.len() {
+                if map_array.is_null(i) {
+                    builder.append_value(-1); // Spark behavior: return -1 for null
+                } else {
+                    let map_len = map_array.value_length(i);
+                    builder.append_value(map_len);
+                }
+            }
+        }
+        _ => {
+            return exec_err!(
+                "size function only supports arrays and maps, got: {:?}",
+                array.data_type()
+            );
+        }
+    }
+
+    Ok(Arc::new(builder.finish()))
+}
+
+fn spark_size_scalar(scalar: &ScalarValue) -> Result<ScalarValue, DataFusionError> {
+    match scalar {
+        ScalarValue::List(array) => {
+            // ScalarValue::List contains a ListArray with exactly one row.
+            // We need the length of that row's contents, not the row count.
+            if array.is_null(0) {
+                Ok(ScalarValue::Int32(Some(-1))) // Spark behavior: return -1 for null
+            } else {
+                let len = array.value(0).len() as i32;
+                Ok(ScalarValue::Int32(Some(len)))
+            }
+        }
+        ScalarValue::LargeList(array) => {
+            if array.is_null(0) {
+                Ok(ScalarValue::Int32(Some(-1)))
+            } else {
+                let len = array.value(0).len() as i32;
+                Ok(ScalarValue::Int32(Some(len)))
+            }
+        }
+        ScalarValue::FixedSizeList(array) => {
+            if array.is_null(0) {
+                Ok(ScalarValue::Int32(Some(-1)))
+            } else {
+                let len = array.value(0).len() as i32;
+                Ok(ScalarValue::Int32(Some(len)))
+            }
+        }
+        ScalarValue::Null => {
+            Ok(ScalarValue::Int32(Some(-1))) // Spark behavior: return -1 for null
+        }
+        _ => {
+            exec_err!(
+                "size function only supports arrays and maps, got: {:?}",
+                scalar
+            )
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Int32Array, ListArray, NullBufferBuilder};
+    use arrow::datatypes::{DataType, Field};
+    use std::sync::Arc;
+
+    #[test]
+    fn test_spark_size_array() {
+        // Create test data: [[1, 2, 3], [4, 5], null, []]
+        let value_data = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let value_offsets = arrow::buffer::OffsetBuffer::new(vec![0, 3, 5, 5, 5].into());
+        let field = Arc::new(Field::new("item", DataType::Int32, true));
+
+        let mut null_buffer = NullBufferBuilder::new(4);
+        null_buffer.append(true); // [1, 2, 3] - not null
+        null_buffer.append(true); // [4, 5] - not null
+        null_buffer.append(false); // null
+        null_buffer.append(true); // [] - not null but empty
+
+        let list_array = ListArray::try_new(
+            field,
+            value_offsets,
+            Arc::new(value_data),
+            null_buffer.finish(),
+        )
+        .unwrap();
+
+        let array_ref: ArrayRef = Arc::new(list_array);
+        let result = spark_size_array(&array_ref).unwrap();
+        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+
+        // Expected: [3, 2, -1, 0]
+        assert_eq!(result.value(0), 3); // [1, 2, 3] has 3 elements
+        assert_eq!(result.value(1), 2); // [4, 5] has 2 elements
+        assert_eq!(result.value(2), -1); // null returns -1
+        assert_eq!(result.value(3), 0); // [] has 0 elements
+    }
+
+    #[test]
+    fn test_spark_size_scalar() {
+        // Test non-null list with 3 elements
+        let values = Int32Array::from(vec![1, 2, 3]);
+        let field = Arc::new(Field::new("item", DataType::Int32, true));
+        let offsets = arrow::buffer::OffsetBuffer::new(vec![0, 3].into());
+        let list_array = ListArray::try_new(field, offsets, Arc::new(values), None).unwrap();
+        let scalar = ScalarValue::List(Arc::new(list_array));
+        let result = spark_size_scalar(&scalar).unwrap();
+        assert_eq!(result, ScalarValue::Int32(Some(3))); // The array [1, 2, 3] has 3 elements
+
+        // Test empty list
+        let empty_values = Int32Array::from(vec![] as Vec<i32>);
+        let field = Arc::new(Field::new("item", DataType::Int32, true));
+        let offsets = arrow::buffer::OffsetBuffer::new(vec![0, 0].into());
+        let empty_list_array =
+            ListArray::try_new(field, offsets, Arc::new(empty_values), None).unwrap();
+        let scalar = ScalarValue::List(Arc::new(empty_list_array));
+        let result = spark_size_scalar(&scalar).unwrap();
+        assert_eq!(result, ScalarValue::Int32(Some(0))); // Empty array has 0 elements
+
+        // Test null handling
+        let scalar = ScalarValue::Null;
+        let result = spark_size_scalar(&scalar).unwrap();
+        assert_eq!(result, ScalarValue::Int32(Some(-1)));
+    }
+
+    // TODO: Add map array test once Arrow MapArray API constraints are resolved
+    // Currently MapArray doesn't allow nulls in entries which makes testing complex
+    // The core size() implementation supports maps correctly
+    #[ignore]
+    #[test]
+    fn test_spark_size_map_array() {
+        use arrow::array::{MapArray, StringArray};
+
+        // Create a simpler test with maps:
+        // [{"key1": "value1", "key2": "value2"}, {"key3": "value3"}, {}, null]
+
+        // Create keys array for all entries (no nulls)
+        let keys = StringArray::from(vec!["key1", "key2", "key3"]);
+
+        // Create values array for all entries (no nulls)
+        let values = StringArray::from(vec!["value1", "value2", "value3"]);
+
+        // Create entry offsets: [0, 2, 3, 3, 3] representing:
+        // - Map 1: entries 0-1 (2 key-value pairs)
+        // - Map 2: entry 2 (1 key-value pair)
+        // - Map 3: no entries (empty map)
+        // - Map 4: null (handled by null buffer)
+        let entry_offsets = arrow::buffer::OffsetBuffer::new(vec![0, 2, 3, 3, 3].into());
+
+        let key_field = Arc::new(Field::new("key", DataType::Utf8, false));
+        let value_field = Arc::new(Field::new("value", DataType::Utf8, false)); // Make values non-nullable too
+
+        // Create the entries struct array
+        let entries = arrow::array::StructArray::new(
+            arrow::datatypes::Fields::from(vec![key_field, value_field]),
+            vec![Arc::new(keys), Arc::new(values)],
+            None, // No nulls in the entries struct array itself
+        );
+
+        // Create null buffer for the map array (fourth map is null)
+        let mut null_buffer = NullBufferBuilder::new(4);
+        null_buffer.append(true); // Map with 2 entries - not null
+        null_buffer.append(true); // Map with 1 entry - not null
+        null_buffer.append(true); // Empty map - not null
+        null_buffer.append(false); // null map
+
+        let map_data_type = DataType::Map(
+            Arc::new(Field::new(
+                "entries",
+                DataType::Struct(arrow::datatypes::Fields::from(vec![
+                    Field::new("key", DataType::Utf8, false),
+                    Field::new("value", DataType::Utf8, false), // Make values non-nullable too
+                ])),
+                false,
+            )),
+            false, // keys are not sorted
+        );
+
+        let map_field = Arc::new(Field::new("map", map_data_type, true));
+
+        let map_array = MapArray::new(
+            map_field,
+            entry_offsets,
+            entries,
+            null_buffer.finish(),
+            false, // keys are not sorted
+        );
+
+        let array_ref: ArrayRef = Arc::new(map_array);
+        let result = spark_size_array(&array_ref).unwrap();
+        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+
+        // Expected: [2, 1, 0, -1]
+        assert_eq!(result.value(0), 2); // Map with 2 key-value pairs
+        assert_eq!(result.value(1), 1); // Map with 1 key-value pair
+        assert_eq!(result.value(2), 0); // empty map has 0 pairs
+        assert_eq!(result.value(3), -1); // null map returns -1
+    }
+
+    #[test]
+    fn test_spark_size_fixed_size_list_array() {
+        use arrow::array::FixedSizeListArray;
+
+        // Create test data: fixed-size arrays of size 3
+        // [[1, 2, 3], [4, 5, 6], null]
+        let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 0, 0, 0]); // Last 3 values are for the null entry
+        let list_size = 3;
+
+        let mut null_buffer = NullBufferBuilder::new(3);
+        null_buffer.append(true); // [1, 2, 3] - not null
+        null_buffer.append(true); // [4, 5, 6] - not null
+        null_buffer.append(false); // null
+
+        let list_field = Arc::new(Field::new("item", DataType::Int32, true));
+
+        let fixed_list_array = FixedSizeListArray::new(
+            list_field,
+            list_size,
+            Arc::new(values),
+            null_buffer.finish(),
+        );
+
+        let array_ref: ArrayRef = Arc::new(fixed_list_array);
+        let result = spark_size_array(&array_ref).unwrap();
+        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+
+        // Expected: [3, 3, -1]
+        assert_eq!(result.value(0), 3); // Fixed-size list always has size 3
+        assert_eq!(result.value(1), 3); // Fixed-size list always has size 3
+        assert_eq!(result.value(2), -1); // null returns -1
+    }
+
+    #[test]
+    fn test_spark_size_large_list_array() {
+        use arrow::array::LargeListArray;
+
+        // Create test data: [[1, 2, 3, 4], [5], null, []]
+        let value_data = Int32Array::from(vec![1, 2, 3, 4, 5]);
+        let value_offsets = arrow::buffer::OffsetBuffer::new(vec![0i64, 4, 5, 5, 5].into());
+        let field = Arc::new(Field::new("item", DataType::Int32, true));
+
+        let mut null_buffer = NullBufferBuilder::new(4);
+        null_buffer.append(true); // [1, 2, 3, 4] - not null
+        null_buffer.append(true); // [5] - not null
+        null_buffer.append(false); // null
+        null_buffer.append(true); // [] - not null but empty
+
+        let large_list_array = LargeListArray::try_new(
+            field,
+            value_offsets,
+            Arc::new(value_data),
+            null_buffer.finish(),
+        )
+        .unwrap();
+
+        let array_ref: ArrayRef = Arc::new(large_list_array);
+        let result = spark_size_array(&array_ref).unwrap();
+        let result = result.as_any().downcast_ref::<Int32Array>().unwrap();
+
+        // Expected: [4, 1, -1, 0]
+        assert_eq!(result.value(0), 4); // [1, 2, 3, 4] has 4 elements
+        assert_eq!(result.value(1), 1); // [5] has 1 element
+        assert_eq!(result.value(2), -1); // null returns -1
+        assert_eq!(result.value(3), 0); // [] has 0 elements
+    }
+}
diff --git a/native/spark-expr/src/comet_scalar_funcs.rs b/native/spark-expr/src/comet_scalar_funcs.rs
index 56b12a9e48..8384a4646a 100644
--- a/native/spark-expr/src/comet_scalar_funcs.rs
+++ b/native/spark-expr/src/comet_scalar_funcs.rs
@@ -22,7 +22,7 @@ use crate::math_funcs::modulo_expr::spark_modulo;
 use crate::{
     spark_array_repeat, spark_ceil, spark_decimal_div, spark_decimal_integral_div, spark_floor,
     spark_isnan, spark_lpad, spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad,
-    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc,
+    spark_unhex, spark_unscaled_value, EvalMode, SparkBitwiseCount, SparkDateTrunc, SparkSizeFunc,
     SparkStringSpace,
 };
 use arrow::datatypes::DataType;
@@ -194,6 +194,7 @@ fn all_scalar_functions() -> Vec<Arc<ScalarUDF>> {
         Arc::new(ScalarUDF::new_from_impl(SparkBitwiseCount::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkDateTrunc::default())),
         Arc::new(ScalarUDF::new_from_impl(SparkStringSpace::default())),
+        Arc::new(ScalarUDF::new_from_impl(SparkSizeFunc::default())),
     ]
 }
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 83917d33fc..e50b1d80e6 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -63,7 +63,8 @@ object QueryPlanSerde extends Logging with CometExprShim {
     classOf[CreateArray] -> CometCreateArray,
     classOf[ElementAt] -> CometElementAt,
     classOf[Flatten] -> CometFlatten,
-    classOf[GetArrayItem] -> CometGetArrayItem)
+    classOf[GetArrayItem] -> CometGetArrayItem,
+    classOf[Size] -> CometSize)
 
   private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] =
     Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf)
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index f59234d402..5d989b4a35 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -21,7 +21,7 @@ package org.apache.comet.serde
 
 import scala.annotation.tailrec
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse}
+import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -543,6 +543,31 @@ object CometArrayFilter extends CometExpressionSerde[ArrayFilter] {
   }
 }
 
+object CometSize extends CometExpressionSerde[Size] {
+
+  override def getSupportLevel(expr: Size): SupportLevel = {
+    // TODO respect spark.sql.legacy.sizeOfNull
+    expr.child.dataType match {
+      case _: ArrayType => Compatible()
+      case _: MapType => Unsupported(Some("size does not support map inputs"))
+      case other =>
+        // this should be unreachable because Spark only supports map and array inputs
+        Unsupported(Some(s"Unsupported child data type: $other"))
+    }
+
+  }
+
+  override def convert(
+      expr: Size,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    val arrayExprProto = exprToProto(expr.child, inputs, binding)
+
+    val sizeScalarExpr = scalarFunctionExprToProto("size", arrayExprProto)
+    optExprWithInfo(sizeScalarExpr, expr)
+  }
+}
+
 trait ArraysBase {
 
   def isTypeSupported(dt: DataType): Boolean = {
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index 4d06baaa8d..9f908e741e 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -829,4 +829,46 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
       }
     }
   }
+
+  test("size with array input") {
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+        // Test size function with arrays built from columns (ensures native execution)
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(array(_2, _3, _4)) from t1 where _2 is not null order by _2, _3, _4"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(array(_1)) from t1 where _1 is not null order by _1"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(array(_2, _3)) from t1 where _2 is null order by _2, _3"))
+
+        // Test with conditional arrays (forces runtime evaluation)
+        checkSparkAnswerAndOperator(sql(
+          "SELECT size(case when _2 > 0 then array(_2, _3, _4) else array(_2) end) from t1 order by _2, _3, _4"))
+        checkSparkAnswerAndOperator(sql(
+          "SELECT size(case when _1 then array(_8, _9) else array(_8, _9, _10) end) from t1 order by _1, _8, _9, _10"))
+
+        // Test empty arrays using conditional logic to avoid constant folding
+        checkSparkAnswerAndOperator(sql(
+          "SELECT size(case when _2 < 0 then array(_2, _3) else array() end) from t1 order by _2, _3"))
+
+        // Test null arrays using conditional logic
+        checkSparkAnswerAndOperator(sql(
+          "SELECT size(case when _2 is null then cast(null as array<int>) else array(_2) end) from t1 order by _2"))
+
+        // Test with different data types using column references
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(array(_8, _9, _10)) from t1 where _8 is not null order by _8, _9, _10")
+        ) // string arrays
+        checkSparkAnswerAndOperator(
+          sql(
+            "SELECT size(array(_2, _3, _4, _5, _6)) from t1 where _2 is not null order by _2, _3, _4, _5, _6"
+          )
+        ) // int arrays
+      }
+    }
+  }
 }
diff --git a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
index 88c13391a6..9276a20348 100644
--- a/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometMapExpressionSuite.scala
@@ -125,4 +125,36 @@ class CometMapExpressionSuite extends CometTestBase {
     }
   }
 
+  test("fallback for size with map input") {
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+        // Use column references in maps to avoid constant folding
+        checkSparkAnswerAndFallbackReason(
+          sql("SELECT size(case when _2 < 0 then map(_8, _9) else map() end) from t1"),
+          "size does not support map inputs")
+      }
+    }
+  }
+
+  // fails with "map is not supported"
+  ignore("size with map input") {
+    withTempDir { dir =>
+      withTempView("t1") {
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllPrimitiveTypes(path, dictionaryEnabled = true, 100)
+        spark.read.parquet(path.toString).createOrReplaceTempView("t1")
+
+        // Use column references in maps to avoid constant folding
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(map(_8, _9, _10, _11)) from t1 where _8 is not null"))
+        checkSparkAnswerAndOperator(
+          sql("SELECT size(case when _2 < 0 then map(_8, _9) else map() end) from t1"))
+      }
+    }
+  }
+
 }
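
Not part of the patch: a minimal, hypothetical spark-shell sketch for trying the new expression end to end. The plugin and `spark.comet.*` config names are the standard Comet settings (plus the `spark.comet.expression.Size.enabled` flag added in configs.md above); the Parquet path, view name, and sample data are made up, and whether the whole plan stays native depends on the rest of the query.

// Illustrative only: assumes the Comet jar is on the driver/executor classpath.
import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  .master("local[*]")
  .config("spark.plugins", "org.apache.spark.CometPlugin")
  .config("spark.comet.enabled", "true")
  .config("spark.comet.exec.enabled", "true")
  // flag added by this PR; defaults to true, shown here for clarity
  .config("spark.comet.expression.Size.enabled", "true")
  .getOrCreate()

// Write a small Parquet file so the scan itself can be handled natively.
spark
  .range(10)
  .selectExpr("array(id, id + 1, id + 2) AS a")
  .write
  .mode("overwrite")
  .parquet("/tmp/size_demo")

spark.read.parquet("/tmp/size_demo").createOrReplaceTempView("t")

// size(a) on an array column should appear inside Comet operators in the
// physical plan when supported; with a map input, CometSize reports the
// fallback reason "size does not support map inputs" instead.
spark.sql("SELECT size(a) FROM t").explain()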