From 0e629a6e3ab64ada5a4a31001b41b62cbcacb1b4 Mon Sep 17 00:00:00 2001 From: comphead Date: Thu, 16 Apr 2026 09:35:33 -0700 Subject: [PATCH] feat: support `sort_array` --- docs/spark_expressions_support.md | 2 +- .../apache/comet/serde/QueryPlanSerde.scala | 3 +- .../scala/org/apache/comet/serde/arrays.scala | 30 ++++- .../expressions/array/sort_array.sql | 103 ++++++++++++++++++ 4 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 spark/src/test/resources/sql-tests/expressions/array/sort_array.sql diff --git a/docs/spark_expressions_support.md b/docs/spark_expressions_support.md index 420778375a..9d9e8f7017 100644 --- a/docs/spark_expressions_support.md +++ b/docs/spark_expressions_support.md @@ -105,7 +105,7 @@ - [ ] sequence - [ ] shuffle - [ ] slice -- [ ] sort_array +- [x] sort_array ### bitwise_funcs diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index b8278cce90..f490d82849 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -66,7 +66,8 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[ElementAt] -> CometElementAt, classOf[Flatten] -> CometFlatten, classOf[GetArrayItem] -> CometGetArrayItem, - classOf[Size] -> CometSize) + classOf[Size] -> CometSize, + classOf[SortArray] -> CometSortArray) private val conditionalExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(classOf[CaseWhen] -> CometCaseWhen, classOf[If] -> CometIf) diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala index f107d5b309..f1908f351c 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala @@ -21,7 +21,7 @@ package org.apache.comet.serde import scala.annotation.tailrec -import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size} +import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayContains, ArrayDistinct, ArrayExcept, ArrayFilter, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayMax, ArrayMin, ArrayRemove, ArrayRepeat, ArraysOverlap, ArrayUnion, Attribute, CreateArray, ElementAt, Expression, Flatten, GetArrayItem, IsNotNull, Literal, Reverse, Size, SortArray} import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -178,6 +178,34 @@ object CometArrayMin extends CometExpressionSerde[ArrayMin] { } } +object CometSortArray extends CometExpressionSerde[SortArray] { + + override def convert( + expr: SortArray, + inputs: Seq[Attribute], + binding: Boolean): Option[ExprOuterClass.Expr] = { + val ascending = expr.ascendingOrder match { + case Literal(v: Boolean, BooleanType) => v + case _ => + withInfo(expr, "ascending order must be a boolean literal") + return None + } + + val arrayExprProto = exprToProto(expr.base, inputs, binding) + val (desc, nullsFirst) = if (ascending) { + ("ASC", "NULLS FIRST") + } else { + ("DESC", "NULLS LAST") + } + val descExprProto = exprToProtoInternal(Literal(desc), inputs, binding) + val nullsFirstExprProto = exprToProtoInternal(Literal(nullsFirst), inputs, binding) + + val sortArrayScalarExpr = + scalarFunctionExprToProto("array_sort", arrayExprProto, descExprProto, nullsFirstExprProto) + optExprWithInfo(sortArrayScalarExpr, expr, expr.children: _*) + } +} + object CometArraysOverlap extends CometExpressionSerde[ArraysOverlap] { override def getSupportLevel(expr: ArraysOverlap): SupportLevel = diff --git a/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql new file mode 100644 index 0000000000..24e82a3a3d --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/array/sort_array.sql @@ -0,0 +1,103 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ===== INT arrays ===== + +statement +CREATE TABLE test_sort_array_int(arr array) USING parquet + +statement +INSERT INTO test_sort_array_int VALUES (array(3, 1, 2)), (array()), (NULL), (array(NULL, 3, 1, NULL, 2)), (array(-2147483648, 2147483647, 0)) + +-- ascending (default) +query +SELECT sort_array(arr) FROM test_sort_array_int + +-- descending +query +SELECT sort_array(arr, false) FROM test_sort_array_int + +-- literal arguments +query +SELECT sort_array(array(3, 1, 2)) + +query +SELECT sort_array(array(3, 1, 2), false) + +-- NULL array +query +SELECT sort_array(cast(NULL as array)) + +-- empty array +query +SELECT sort_array(array()) + +-- array with NULLs +query +SELECT sort_array(array(NULL, 3, NULL, 1, 2)) + +query +SELECT sort_array(array(NULL, 3, NULL, 1, 2), false) + +-- ===== STRING arrays ===== + +statement +CREATE TABLE test_sort_array_string(arr array) USING parquet + +statement +INSERT INTO test_sort_array_string VALUES (array('banana', 'apple', 'cherry')), (array()), (NULL), (array(NULL, 'b', NULL, 'a')) + +query +SELECT sort_array(arr) FROM test_sort_array_string + +query +SELECT sort_array(arr, false) FROM test_sort_array_string + +-- ===== DOUBLE arrays (NaN/Infinity handling) ===== + +statement +CREATE TABLE test_sort_array_double(arr array) USING parquet + +statement +INSERT INTO test_sort_array_double VALUES (array(3.0, 1.0, 2.0)), (NULL), (array(NULL, 2.0, NULL, 1.0)), (array(CAST('NaN' AS DOUBLE), 1.0, 2.0)), (array(CAST('Infinity' AS DOUBLE), CAST('-Infinity' AS DOUBLE), 0.0)) + +query +SELECT sort_array(arr) FROM test_sort_array_double + +query spark_answer_only +SELECT sort_array(arr, false) FROM test_sort_array_double + +-- ===== BIGINT arrays ===== + +query +SELECT sort_array(array(CAST(3 AS BIGINT), CAST(1 AS BIGINT), CAST(2 AS BIGINT))) + +query +SELECT sort_array(array(CAST(9223372036854775807 AS BIGINT), CAST(-9223372036854775808 AS BIGINT), CAST(0 AS BIGINT))) + +-- ===== BOOLEAN arrays ===== + +query +SELECT sort_array(array(true, false, true, false)) + +query +SELECT sort_array(array(true, false, true, false), false) + +-- ===== DECIMAL arrays ===== + +query +SELECT sort_array(array(CAST(3.14 AS DECIMAL(10,2)), CAST(1.41 AS DECIMAL(10,2)), CAST(2.72 AS DECIMAL(10,2))))