Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,20 @@ trait CometAggregateExpressionSerde[T <: AggregateFunction] {
*/
def getSupportLevel(expr: T): SupportLevel = Compatible(None)

/**
* Indicates whether this aggregate function supports "Spark partial / Comet final" mixed
* execution. This requires the intermediate buffer format to be compatible between Spark and
* Comet.
*
* Only aggregates with simple, compatible intermediate buffers should return true. Aggregates
* with complex buffers or those with known incompatibilities (e.g., decimal overflow handling
* differences) should return false.
*
* @return
* true if the aggregate can safely run with Spark partial and Comet final, false otherwise
*/
def supportsSparkPartialCometFinal: Boolean = false

/**
* Convert a Spark expression into a protocol buffer representation that can be passed into
* native code.
Expand Down
17 changes: 17 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,23 @@ object QueryPlanSerde extends Logging with CometExprShim {
classOf[VariancePop] -> CometVariancePop,
classOf[VarianceSamp] -> CometVarianceSamp)

/**
 * Checks if the given aggregate function supports "Spark partial / Comet final" mixed
 * execution. This is used to determine if Comet can process a final aggregate even when the
 * partial aggregate was performed by Spark, which requires the intermediate buffer format to
 * be compatible between the two engines.
 *
 * Unknown aggregate functions (no registered serde handler) are conservatively treated as
 * unsupported.
 *
 * @param fn
 *   The aggregate function to check
 * @return
 *   true if the aggregate supports mixed execution, false otherwise
 */
def aggSupportsMixedExecution(fn: AggregateFunction): Boolean =
  // Option.exists is the idiomatic form of `match { Some(h) => p(h); None => false }`
  aggrSerdeMap.get(fn.getClass).exists(_.supportsSparkPartialCometFinal)

def supportedDataType(dt: DataType, allowComplex: Boolean = false): Boolean = dt match {
case _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType |
_: DoubleType | _: StringType | _: BinaryType | _: TimestampType | _: TimestampNTZType |
Expand Down
25 changes: 25 additions & 0 deletions spark/src/main/scala/org/apache/comet/serde/aggregates.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ import org.apache.comet.shims.CometEvalModeUtil

object CometMin extends CometAggregateExpressionSerde[Min] {

// Min has a simple intermediate buffer (single value) compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Min,
Expand Down Expand Up @@ -81,6 +84,9 @@ object CometMin extends CometAggregateExpressionSerde[Min] {

object CometMax extends CometAggregateExpressionSerde[Max] {

// Max has a simple intermediate buffer (single value) compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Max,
Expand Down Expand Up @@ -127,6 +133,10 @@ object CometMax extends CometAggregateExpressionSerde[Max] {
}

object CometCount extends CometAggregateExpressionSerde[Count] {

// Count has a simple intermediate buffer (single Long) compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
expr: Count,
Expand Down Expand Up @@ -317,6 +327,11 @@ object CometLast extends CometAggregateExpressionSerde[Last] {
}

object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {

// BitAnd has a simple intermediate buffer (single integral value)
// compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitAnd: BitAndAgg,
Expand Down Expand Up @@ -351,6 +366,11 @@ object CometBitAndAgg extends CometAggregateExpressionSerde[BitAndAgg] {
}

object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {

// BitOr has a simple intermediate buffer (single integral value)
// compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitOr: BitOrAgg,
Expand Down Expand Up @@ -385,6 +405,11 @@ object CometBitOrAgg extends CometAggregateExpressionSerde[BitOrAgg] {
}

object CometBitXOrAgg extends CometAggregateExpressionSerde[BitXorAgg] {

// BitXor has a simple intermediate buffer (single integral value)
// compatible between Spark and Comet
override def supportsSparkPartialCometFinal: Boolean = true

override def convert(
aggExpr: AggregateExpression,
bitXor: BitXorAgg,
Expand Down
11 changes: 7 additions & 4 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleEnabled, with
import org.apache.comet.parquet.CometParquetUtils
import org.apache.comet.serde.{CometOperatorSerde, Compatible, Incompatible, OperatorOuterClass, SupportLevel, Unsupported}
import org.apache.comet.serde.OperatorOuterClass.{AggregateMode => CometAggregateMode, Operator}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, supportedSortType}
import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, aggSupportsMixedExecution, exprToProto, supportedSortType}
import org.apache.comet.serde.operator.CometSink

/**
Expand Down Expand Up @@ -1067,11 +1067,14 @@ trait CometBaseAggregate {
val modes = aggregate.aggregateExpressions.map(_.mode).distinct
// In distinct aggregates there can be a combination of modes
val multiMode = modes.size > 1
// For a final mode HashAggregate, we only need to transform the HashAggregate
// if there is Comet partial aggregation.
// For a final mode HashAggregate, check if there is Comet partial aggregation.
// If not, we can still proceed if all aggregates support mixed execution
// (Spark partial / Comet final). See https://github.com/apache/datafusion-comet/issues/2894
val sparkFinalMode = modes.contains(Final) && findCometPartialAgg(aggregate.child).isEmpty
val allSupportMixedExecution = aggregate.aggregateExpressions.forall(expr =>
aggSupportsMixedExecution(expr.aggregateFunction))

if (multiMode || sparkFinalMode) {
if (multiMode || (sparkFinalMode && !allSupportMixedExecution)) {
return None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ class CometExecRuleSuite extends CometTestBase {
}
}

test("CometExecRule should not allow Spark partial and Comet final hash aggregate") {
test("CometExecRule should not allow Spark partial and Comet final for unsafe aggregates") {
// https://github.com/apache/datafusion-comet/issues/2894
// SUM is not safe for mixed execution due to potential overflow handling differences
withTempView("test_data") {
createTestDataFrame.createOrReplaceTempView("test_data")

Expand All @@ -173,14 +175,42 @@ class CometExecRuleSuite extends CometTestBase {
CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
val transformedPlan = applyCometExecRule(sparkPlan)

// if the partial aggregate cannot be converted to Comet, then neither should be
// SUM is not safe for mixed execution, so both partial and final should fall back
assert(
countOperators(transformedPlan, classOf[HashAggregateExec]) == originalHashAggCount)
assert(countOperators(transformedPlan, classOf[CometHashAggregateExec]) == 0)
}
}
}

test("CometExecRule should allow Spark partial and Comet final for safe aggregates") {
  // https://github.com/apache/datafusion-comet/issues/2894
  // MIN, MAX, and COUNT carry simple intermediate buffers that are compatible between
  // Spark and Comet, so they are safe for "Spark partial / Comet final" mixed execution.
  withTempView("test_data") {
    createTestDataFrame.createOrReplaceTempView("test_data")

    val plan =
      createSparkPlan(
        spark,
        "SELECT MIN(id), MAX(id), COUNT(*) FROM test_data GROUP BY (id % 3)")

    // The untransformed Spark plan contains both stages: partial + final.
    val sparkHashAggCount = countOperators(plan, classOf[HashAggregateExec])
    assert(sparkHashAggCount == 2)

    withSQLConf(
      CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.key -> "false",
      CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
      val transformed = applyCometExecRule(plan)

      // With partial Comet aggregation disabled, the partial stage stays in Spark
      // (one HashAggregateExec) while the final stage — safe for mixed execution —
      // is converted to Comet (one CometHashAggregateExec).
      assert(countOperators(transformed, classOf[HashAggregateExec]) == 1)
      assert(countOperators(transformed, classOf[CometHashAggregateExec]) == 1)
    }
  }
}

test("CometExecRule should apply broadcast exchange transformations") {
withTempView("test_data") {
createTestDataFrame.createOrReplaceTempView("test_data")
Expand Down
Loading