diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index d116d2f407..8c94411001 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -80,6 +80,7 @@ object CometExecRule { classOf[GenerateExec] -> CometExplodeExec, classOf[HashAggregateExec] -> CometHashAggregateExec, classOf[ObjectHashAggregateExec] -> CometObjectHashAggregateExec, + classOf[SortAggregateExec] -> CometSortAggregateExec, classOf[BroadcastHashJoinExec] -> CometBroadcastHashJoinExec, classOf[ShuffledHashJoinExec] -> CometHashJoinExec, classOf[SortMergeJoinExec] -> CometSortMergeJoinExec, @@ -149,8 +150,7 @@ case class CometExecRule(session: SparkSession) * Comet columnar shuffle. */ private def revertRedundantColumnarShuffle(plan: SparkPlan): SparkPlan = { - def isAggregate(p: SparkPlan): Boolean = - p.isInstanceOf[HashAggregateExec] || p.isInstanceOf[ObjectHashAggregateExec] + def isAggregate(p: SparkPlan): Boolean = p.isInstanceOf[BaseAggregateExec] def isRedundantShuffle(child: SparkPlan): Boolean = child match { case s: CometShuffleExchangeExec => @@ -858,9 +858,13 @@ case class CometExecRule(session: SparkSession) val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]] if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false - // ObjectHashAggregate has an extra shuffle-enabled guard in its convert method + // ObjectHashAggregate / SortAggregate carry TypedImperativeAggregate functions whose + // intermediate buffer formats differ between Spark and Comet, so the Partial->Final pair + // must travel via Comet shuffle. agg match { - case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false + case _: ObjectHashAggregateExec | _: SortAggregateExec + if !isCometShuffleEnabled(agg.conf) => + return false case _ => } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 8cbf7c9189..1b1fd6744f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec} -import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec} +import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -1405,6 +1405,58 @@ case class CometUnionExec( trait CometBaseAggregate { + /** + * Shared support-level check used by every aggregate serde: honor the unit-test knobs that + * selectively disable Comet conversion for partial or final aggregates. + */ + protected def baseAggregateSupportLevel(op: BaseAggregateExec): SupportLevel = { + if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { + return Unsupported(Some("Partial aggregates disabled via test config")) + } + if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && + op.aggregateExpressions.exists(_.mode == Final)) { + return Unsupported(Some("Final aggregates disabled via test config")) + } + Compatible() + } + + /** + * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet), + * the Spark-side output declares buffer columns as BinaryType (Spark serializes state to + * binary). However, the native Comet aggregate produces the actual state type (e.g., + * ArrayType(elementType) for CollectSet). This corrects the output schema to match the native + * state types so the shuffle exchange schema is consistent with the actual data. + * + * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a + * case branch here mapping it to the native state type. + */ + protected def adjustOutputForNativeState(op: BaseAggregateExec): Seq[Attribute] = { + val modes = op.aggregateExpressions.map(_.mode).distinct + if (modes != Seq(Partial)) { + return op.output + } + + val numGrouping = op.groupingExpressions.length + val output = op.output.toArray + + var bufferIdx = numGrouping + for (aggExpr <- op.aggregateExpressions) { + val aggFunc = aggExpr.aggregateFunction + val bufferAttrs = aggFunc.aggBufferAttributes + aggFunc match { + case cs: CollectSet => + val elementType = cs.children.head.dataType + val nativeStateType = ArrayType(elementType, containsNull = true) + output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) + case _ => + } + bufferIdx += bufferAttrs.length + } + + output.toSeq + } + def doConvert( aggregate: BaseAggregateExec, builder: Operator.Builder, @@ -1625,8 +1677,9 @@ trait CometBaseAggregate { } /** - * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with - * partial or partial-merge mode, it will return None. + * Find the first Comet partial aggregate in the plan. If it reaches a Spark BaseAggregateExec + * (HashAggregate / ObjectHashAggregate / SortAggregate) with partial or partial-merge mode, it + * will return None. */ private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = { def isPartialOrMerge(mode: AggregateMode): Boolean = @@ -1636,10 +1689,7 @@ trait CometBaseAggregate { case agg: CometHashAggregateExec if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => Some(agg) - case agg: HashAggregateExec - if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => - None - case agg: ObjectHashAggregateExec + case agg: BaseAggregateExec if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) => None case a: AQEShuffleReadExec => findCometPartialAgg(a.child) @@ -1656,19 +1706,8 @@ object CometHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) - override def getSupportLevel(op: HashAggregateExec): SupportLevel = { - // some unit tests need to disable partial or final hash aggregate support to test that - // CometExecRule does not allow mixed Spark/Comet aggregates - if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && - op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { - return Unsupported(Some("Partial aggregates disabled via test config")) - } - if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && - op.aggregateExpressions.exists(_.mode == Final)) { - return Unsupported(Some("Final aggregates disabled via test config")) - } - Compatible() - } + override def getSupportLevel(op: HashAggregateExec): SupportLevel = + baseAggregateSupportLevel(op) override def convert( aggregate: HashAggregateExec, @@ -1698,19 +1737,8 @@ object CometObjectHashAggregateExec override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_AGGREGATE_ENABLED) - override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = { - // Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution - // unit tests can selectively disable partial or final ObjectHashAggregateExec conversion. - if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && - op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) { - return Unsupported(Some("Partial aggregates disabled via test config")) - } - if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && - op.aggregateExpressions.exists(_.mode == Final)) { - return Unsupported(Some("Final aggregates disabled via test config")) - } - Compatible() - } + override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = + baseAggregateSupportLevel(op) override def convert( aggregate: ObjectHashAggregateExec, @@ -1739,42 +1767,49 @@ object CometObjectHashAggregateExec op.child, SerializedPlan(None)) } +} - /** - * For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet), - * the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to - * binary). However, the native Comet aggregate produces the actual state type (e.g., - * ArrayType(elementType) for CollectSet). This method corrects the output schema to match the - * native state types so the shuffle exchange schema is consistent with the actual data. - * - * NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a - * case branch here mapping it to the native state type. - */ - private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = { - // This adjustment only applies to pure-Partial aggregates (checked below). - val modes = op.aggregateExpressions.map(_.mode).distinct - if (modes != Seq(Partial)) { - return op.output - } +object CometSortAggregateExec + extends CometOperatorSerde[SortAggregateExec] + with CometBaseAggregate { - val numGrouping = op.groupingExpressions.length - val output = op.output.toArray + override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( + CometConf.COMET_EXEC_AGGREGATE_ENABLED) - var bufferIdx = numGrouping - for (aggExpr <- op.aggregateExpressions) { - val aggFunc = aggExpr.aggregateFunction - val bufferAttrs = aggFunc.aggBufferAttributes - aggFunc match { - case cs: CollectSet => - val elementType = cs.children.head.dataType - val nativeStateType = ArrayType(elementType, containsNull = true) - output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType) - case _ => - } - bufferIdx += bufferAttrs.length + override def getSupportLevel(op: SortAggregateExec): SupportLevel = + baseAggregateSupportLevel(op) + + override def convert( + aggregate: SortAggregateExec, + builder: Operator.Builder, + childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + + // SortAggregate is planned for TypedImperativeAggregate functions whose intermediate + // buffer formats differ between Spark and Comet (same risk as ObjectHashAggregate). + // Require Comet shuffle so a Partial->Final pair never spans the JVM/native boundary. + if (!isCometShuffleEnabled(aggregate.conf)) { + return None } - output.toSeq + doConvert(aggregate, builder, childOp: _*) + } + + override def createExec(nativeOp: Operator, op: SortAggregateExec): CometNativeExec = { + // Reuse CometHashAggregateExec as the wrapper. The native AggregateExec auto-detects + // Sorted input mode from the child's output ordering and produces output sorted by the + // grouping keys; CometExec.outputOrdering defaults to originalPlan.outputOrdering, which + // is SortAggregateExec's grouping-key ordering, so downstream operators that elided a + // sort against it still see a satisfying ordering without a dedicated wrapper class. + CometHashAggregateExec( + nativeOp, + op, + adjustOutputForNativeState(op), + op.groupingExpressions, + op.aggregateExpressions, + op.resultExpressions, + op.child.output, + op.child, + SerializedPlan(None)) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index cd0beb56cc..1ae80523b0 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.optimizer.EliminateSorts import org.apache.spark.sql.comet.CometHashAggregateExec import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.execution.aggregate.SortAggregateExec import org.apache.spark.sql.functions.{avg, col, count_distinct, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataTypes, StructField, StructType} @@ -2109,4 +2110,30 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("SortAggregate with collect_set is converted to native") { + // useObjectHashAggregateExec=false forces Spark to plan SortAggregateExec for + // TypedImperativeAggregate functions like collect_set. Comet converts those just like + // ObjectHashAggregateExec via the shared CometBaseAggregate path. + withSQLConf( + "spark.sql.execution.useObjectHashAggregateExec" -> "false", + CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + withTempView("tbl") { + Seq((1, "a"), (2, "a"), (1, "a"), (3, "b"), (4, "b"), (4, "b")) + .toDF("v", "g") + .createOrReplaceTempView("tbl") + val query = "SELECT g, sort_array(collect_set(v)) FROM tbl GROUP BY g ORDER BY g" + // Spark must actually plan a SortAggregateExec for this query; otherwise the test + // would pass without exercising the new code path. + withSQLConf(CometConf.COMET_ENABLED.key -> "false") { + val plan = stripAQEPlan(sql(query).queryExecution.executedPlan) + assert( + plan.find(_.isInstanceOf[SortAggregateExec]).isDefined, + s"Expected SortAggregateExec in Spark-only plan but got:\n$plan") + } + checkSparkAnswerAndOperator(sql(query)) + } + } + } + }