diff --git a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala index 0a5a2770b4..fc8f776ca7 100644 --- a/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/CometAggregateExpressionSerde.scala @@ -49,6 +49,20 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] { */ def getSupportLevel(expr: T): SupportLevel = Compatible(None) + /** + * Indicates whether this aggregate function supports "Spark partial / Comet final" mixed + * execution. This requires the intermediate buffer format to be compatible between Spark and + * Comet. + * + * Only aggregates with simple, compatible intermediate buffers should return true. Aggregates + * with complex buffers or those with known incompatibilities (e.g., decimal overflow handling + * differences) should return false. + * + * @return + * true if the aggregate can safely run with Spark partial and Comet final, false otherwise + */ + def supportsSparkPartialCometFinal: Boolean = false + /** * Convert a Spark expression into a protocol buffer representation that can be passed into * native code. diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index e50b1d80e6..f6c12a8a11 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -253,6 +253,23 @@ object QueryPlanSerde extends Logging with CometExprShim { classOf[VariancePop] -> CometVariancePop, classOf[VarianceSamp] -> CometVarianceSamp) + /** + * Checks if the given aggregate function supports "Spark partial / Comet final" mixed + * execution. This is used to determine if Comet can process a final aggregate even when the + * partial aggregate was performed by Spark. + * + * @param fn + * The aggregate function to check + * @return + * true if the aggregate supports mixed execution, false otherwise + */ + def aggSupportsMixedExecution(fn: AggregateFunction): Boolean = { + aggrSerdeMap.get(fn.getClass) match { + case Some(handler) => handler.supportsSparkPartialCometFinal + case None => false + } + } + def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match { case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType | diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala index a05efaebbc..30a70f4673 100644 --- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala +++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala @@ -34,6 +34,9 @@ import org.apache.comet.shims.CometEvalModeUtil object CometMin extends CometAggregateExpressionSerde[Min] { + // Min has a simple intermediate buffer (single value) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Min, @@ -81,6 +84,9 @@ object CometMin extends CometAggregateExpressionSerde[Min] { object CometMax extends CometAggregateExpressionSerde[Max] { + // Max has a simple intermediate buffer (single value) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Max, @@ -127,6 +133,10 @@ object CometMax extends CometAggregateExpressionSerde[Max] { } object CometCount extends CometAggregateExpressionSerde[Count] { + + // Count has a simple intermediate buffer (single Long) compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, expr: Count, @@ -317,6 +327,11 @@ object CometLast extends CometAggregateExpressionSerde[Last] { } object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { + + // BitAnd has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitAnd: BitAndAgg, @@ -351,6 +366,11 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] { } object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { + + // BitOr has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitOr: BitOrAgg, @@ -385,6 +405,11 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] { } object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] { + + // BitXor has a simple intermediate buffer (single integral value) + // compatible between Spark and Comet + override def supportsSparkPartialCometFinal: Boolean = true + override def convert( aggExpr: AggregateExpression, bitXor: BitXorAgg, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 0a435e5b7a..9929024a6d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -56,7 +56,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, with import org.apache.comet.parquet.CometParquetUtils import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported} import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator} -import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType} +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, aggSupportsMixedExecution, exprToProto, supportedSortType} import org.apache.comet.serde.operator.CometSink /** @@ -1067,11 +1067,14 @@ trait CometBaseAggregate { val modes = aggregate.aggregateExpressions.map(_.mode).distinct // In distinct aggregates there can be a combination of modes val multiMode = modes.size > 1 - // For a final mode HashAggregate, we only need to transform the HashAggregate - // if there is Comet partial aggregation. + // For a final mode HashAggregate, check if there is Comet partial aggregation. + // If not, we can still proceed if all aggregates support mixed execution + // (Spark partial / Comet final). See https://github.com/apache/datafusion-comet/issues/2894 val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty + val allSupportMixedExecution = aggregate.aggregateExpressions.forall(expr => + aggSupportsMixedExecution(expr.aggregateFunction)) - if (multiMode || sparkFinalMode) { + if (multiMode || (sparkFinalMode && !allSupportMixedExecution)) { return None } diff --git a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala index cf6f8918f4..a343f7fb93 100644 --- a/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/rules/CometExecRuleSuite.scala @@ -157,7 +157,9 @@ class CometExecRuleSuite extends CometTestBase { } } - test("CometExecRule should not allow Spark partial and Comet final hash aggregate") { + test("CometExecRule should not allow Spark partial and Comet final for unsafe aggregates") { + // https://github.com/apache/datafusion-comet/issues/2894 + // SUM is not safe for mixed execution due to potential overflow handling differences withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data") @@ -173,7 +175,7 @@ class CometExecRuleSuite extends CometTestBase { CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { val transformedPlan = applyCometExecRule(sparkPlan) - // if the partial aggregate cannot be converted to Comet, then neither should be + // SUM is not safe for mixed execution, so both partial and final should fall back assert( countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount) assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0) @@ -181,6 +183,34 @@ class CometExecRuleSuite extends CometTestBase { } } + test("CometExecRule should allow Spark partial and Comet final for safe aggregates") { + // https://github.com/apache/datafusion-comet/issues/2894 + // MIN, MAX, COUNT are safe for mixed execution (simple intermediate buffer) + withTempView("test_data") { + createTestDataFrame.createOrReplaceTempView("test_data") + + val sparkPlan = + createSparkPlan( + spark, + "SELECT MIN(id), MAX(id), COUNT(*) FROM test_data GROUP BY (id % 3)") + + // Count original Spark operators (should be 2: partial + final) + val originalHashAggCount = countOperators(sparkPlan, classOf[HashAggregateExec]) + assert(originalHashAggCount == 2) + + withSQLConf( + CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val transformedPlan = applyCometExecRule(sparkPlan) + + // MIN, MAX, COUNT support mixed execution, so final should run in Comet + // Partial stays in Spark (1 HashAggregateExec), final runs in Comet (1 CometHashAggregateExec) + assert(countOperators(transformedPlan, classOf[HashAggregateExec]) == 1) + assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 1) + } + } + } + test("CometExecRule should apply broadcast exchange transformations") { withTempView("test_data") { createTestDataFrame.createOrReplaceTempView("test_data")