Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, Comet
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat
Expand Down Expand Up @@ -80,6 +80,7 @@ object CometExecRule {
classOf[GenerateExec] -> CometExplodeExec,
classOf[HashAggregateExec] -> CometHashAggregateExec,
classOf[ObjectHashAggregateExec] -> CometObjectHashAggregateExec,
classOf[SortAggregateExec] -> CometSortAggregateExec,
classOf[BroadcastHashJoinExec] -> CometBroadcastHashJoinExec,
classOf[ShuffledHashJoinExec] -> CometHashJoinExec,
classOf[SortMergeJoinExec] -> CometSortMergeJoinExec,
Expand Down Expand Up @@ -149,8 +150,7 @@ case class CometExecRule(session: SparkSession)
* Comet columnar shuffle.
*/
private def revertRedundantColumnarShuffle(plan: SparkPlan): SparkPlan = {
def isAggregate(p: SparkPlan): Boolean =
p.isInstanceOf[HashAggregateExec] || p.isInstanceOf[ObjectHashAggregateExec]
def isAggregate(p: SparkPlan): Boolean = p.isInstanceOf[BaseAggregateExec]

def isRedundantShuffle(child: SparkPlan): Boolean = child match {
case s: CometShuffleExchangeExec =>
Expand Down Expand Up @@ -858,9 +858,13 @@ case class CometExecRule(session: SparkSession)
val serde = handler.get.asInstanceOf[CometOperatorSerde[SparkPlan]]
if (!isOperatorEnabled(serde, agg.asInstanceOf[SparkPlan])) return false

// ObjectHashAggregate has an extra shuffle-enabled guard in its convert method
// ObjectHashAggregate / SortAggregate carry TypedImperativeAggregate functions whose
// intermediate buffer formats differ between Spark and Comet, so the Partial->Final pair
// must travel via Comet shuffle.
agg match {
case _: ObjectHashAggregateExec if !isCometShuffleEnabled(agg.conf) => return false
case _: ObjectHashAggregateExec | _: SortAggregateExec
if !isCometShuffleEnabled(agg.conf) =>
return false
case _ =>
}

Expand Down
163 changes: 99 additions & 64 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
import org.apache.spark.sql.comet.util.Utils
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AQEShuffleReadExec, BroadcastQueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.aggregate.{BaseAggregateExec, HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, HashJoin, ShuffledHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
Expand Down Expand Up @@ -1405,6 +1405,58 @@ case class CometUnionExec(

trait CometBaseAggregate {

/**
 * Support-level check shared by the aggregate serdes. It honors the two test-only switches
 * that turn off Comet conversion for partial (Partial / PartialMerge) or final aggregates.
 *
 * @param op
 *   the Spark aggregate being considered for conversion
 * @return
 *   `Unsupported` with a reason when the matching switch is off, otherwise `Compatible()`
 */
protected def baseAggregateSupportLevel(op: BaseAggregateExec): SupportLevel = {
  // Declared as defs so each scan over the aggregate expressions only runs when the
  // corresponding switch is actually turned off (same short-circuiting as a plain `&&`).
  def hasPartialStage: Boolean =
    op.aggregateExpressions.exists(e => e.mode == Partial || e.mode == PartialMerge)
  def hasFinalStage: Boolean =
    op.aggregateExpressions.exists(_.mode == Final)

  if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) && hasPartialStage) {
    Unsupported(Some("Partial aggregates disabled via test config"))
  } else if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) && hasFinalStage) {
    Unsupported(Some("Final aggregates disabled via test config"))
  } else {
    Compatible()
  }
}

/**
 * Rewrites the declared output schema of a pure-Partial aggregate so that it matches what the
 * native aggregate emits.
 *
 * Spark declares the buffer column of a TypedImperativeAggregate function (such as CollectSet)
 * as BinaryType because it serializes the state to binary, whereas the native Comet aggregate
 * produces the real state type (ArrayType(elementType) for CollectSet). Patching the attribute
 * keeps the shuffle exchange schema consistent with the data that actually flows through it.
 *
 * Aggregates whose expressions are not all in Partial mode are returned unchanged.
 *
 * NOTE: when another TypedImperativeAggregate function (e.g., CollectList) gains native
 * support, add a case below that maps it to its native state type.
 *
 * @param op
 *   the Spark aggregate whose output attributes are inspected
 * @return
 *   `op.output`, with the first buffer column of every CollectSet retyped to an array
 */
protected def adjustOutputForNativeState(op: BaseAggregateExec): Seq[Attribute] = {
  val isPurePartial = op.aggregateExpressions.map(_.mode).distinct == Seq(Partial)
  if (!isPurePartial) {
    op.output
  } else {
    val functions = op.aggregateExpressions.map(_.aggregateFunction)

    // Output layout is: grouping columns first, then each function's buffer columns in order.
    // offsets(i) is therefore the index of the first buffer column owned by functions(i).
    val offsets = functions.scanLeft(op.groupingExpressions.length) { (offset, fn) =>
      offset + fn.aggBufferAttributes.length
    }

    val patched = op.output.toArray
    // zip drops the trailing offset (one past the last buffer column), which is not needed.
    functions.zip(offsets).foreach {
      case (cs: CollectSet, slot) =>
        val stateType = ArrayType(cs.children.head.dataType, containsNull = true)
        patched(slot) = patched(slot).withDataType(stateType)
      case _ =>
    }

    patched.toSeq
  }
}

def doConvert(
aggregate: BaseAggregateExec,
builder: Operator.Builder,
Expand Down Expand Up @@ -1625,8 +1677,9 @@ trait CometBaseAggregate {
}

/**
* Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with
* partial or partial-merge mode, it will return None.
* Find the first Comet partial aggregate in the plan. If it reaches a Spark BaseAggregateExec
* (HashAggregate / ObjectHashAggregate / SortAggregate) with partial or partial-merge mode, it
* will return None.
*/
private def findCometPartialAgg(plan: SparkPlan): Option[CometHashAggregateExec] = {
def isPartialOrMerge(mode: AggregateMode): Boolean =
Expand All @@ -1636,10 +1689,7 @@ trait CometBaseAggregate {
case agg: CometHashAggregateExec
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
Some(agg)
case agg: HashAggregateExec
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
None
case agg: ObjectHashAggregateExec
case agg: BaseAggregateExec
if agg.aggregateExpressions.forall(e => isPartialOrMerge(e.mode)) =>
None
case a: AQEShuffleReadExec => findCometPartialAgg(a.child)
Expand All @@ -1656,19 +1706,8 @@ object CometHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: HashAggregateExec): SupportLevel = {
// some unit tests need to disable partial or final hash aggregate support to test that
// CometExecRule does not allow mixed Spark/Comet aggregates
if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
return Unsupported(Some("Partial aggregates disabled via test config"))
}
if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(_.mode == Final)) {
return Unsupported(Some("Final aggregates disabled via test config"))
}
Compatible()
}
override def getSupportLevel(op: HashAggregateExec): SupportLevel =
baseAggregateSupportLevel(op)

override def convert(
aggregate: HashAggregateExec,
Expand Down Expand Up @@ -1698,19 +1737,8 @@ object CometObjectHashAggregateExec
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel = {
// Mirror the same test-knobs as CometHashAggregateExec so that mixed-execution
// unit tests can selectively disable partial or final ObjectHashAggregateExec conversion.
if (!CometConf.COMET_ENABLE_PARTIAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(expr => expr.mode == Partial || expr.mode == PartialMerge)) {
return Unsupported(Some("Partial aggregates disabled via test config"))
}
if (!CometConf.COMET_ENABLE_FINAL_HASH_AGGREGATE.get(op.conf) &&
op.aggregateExpressions.exists(_.mode == Final)) {
return Unsupported(Some("Final aggregates disabled via test config"))
}
Compatible()
}
override def getSupportLevel(op: ObjectHashAggregateExec): SupportLevel =
baseAggregateSupportLevel(op)

override def convert(
aggregate: ObjectHashAggregateExec,
Expand Down Expand Up @@ -1739,42 +1767,49 @@ object CometObjectHashAggregateExec
op.child,
SerializedPlan(None))
}
}

/**
* For Partial mode aggregates containing TypedImperativeAggregate functions (like CollectSet),
* the Spark-side output declares buffer columns as BinaryType (since Spark serializes state to
* binary). However, the native Comet aggregate produces the actual state type (e.g.,
* ArrayType(elementType) for CollectSet). This method corrects the output schema to match the
* native state types so the shuffle exchange schema is consistent with the actual data.
*
* NOTE: If a new TypedImperativeAggregate function (e.g., CollectList) is added natively, add a
* case branch here mapping it to the native state type.
*/
private def adjustOutputForNativeState(op: ObjectHashAggregateExec): Seq[Attribute] = {
// This adjustment only applies to pure-Partial aggregates (checked below).
val modes = op.aggregateExpressions.map(_.mode).distinct
if (modes != Seq(Partial)) {
return op.output
}
object CometSortAggregateExec
extends CometOperatorSerde[SortAggregateExec]
with CometBaseAggregate {

val numGrouping = op.groupingExpressions.length
val output = op.output.toArray
override def enabledConfig: Option[ConfigEntry[Boolean]] = Some(
CometConf.COMET_EXEC_AGGREGATE_ENABLED)

var bufferIdx = numGrouping
for (aggExpr <- op.aggregateExpressions) {
val aggFunc = aggExpr.aggregateFunction
val bufferAttrs = aggFunc.aggBufferAttributes
aggFunc match {
case cs: CollectSet =>
val elementType = cs.children.head.dataType
val nativeStateType = ArrayType(elementType, containsNull = true)
output(bufferIdx) = output(bufferIdx).withDataType(nativeStateType)
case _ =>
}
bufferIdx += bufferAttrs.length
override def getSupportLevel(op: SortAggregateExec): SupportLevel =
baseAggregateSupportLevel(op)

override def convert(
aggregate: SortAggregateExec,
builder: Operator.Builder,
childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = {

// SortAggregate is planned for TypedImperativeAggregate functions whose intermediate
// buffer formats differ between Spark and Comet (same risk as ObjectHashAggregate).
// Require Comet shuffle so a Partial->Final pair never spans the JVM/native boundary.
if (!isCometShuffleEnabled(aggregate.conf)) {
return None
}

output.toSeq
doConvert(aggregate, builder, childOp: _*)
}

override def createExec(nativeOp: Operator, op: SortAggregateExec): CometNativeExec = {
  // No dedicated wrapper class exists for sort aggregates: CometHashAggregateExec is reused.
  // Per the original author's note, the native AggregateExec detects Sorted input mode from
  // the child's output ordering and emits rows ordered by the grouping keys, while
  // CometExec.outputOrdering falls back to originalPlan.outputOrdering (here the
  // SortAggregateExec grouping-key ordering). Downstream operators that dropped a sort
  // because of that ordering therefore still see it satisfied.
  val nativeOutput = adjustOutputForNativeState(op)
  val input = op.child

  CometHashAggregateExec(
    nativeOp,
    op,
    nativeOutput,
    op.groupingExpressions,
    op.aggregateExpressions,
    op.resultExpressions,
    input.output,
    input,
    SerializedPlan(None))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.catalyst.optimizer.EliminateSorts
import org.apache.spark.sql.comet.CometHashAggregateExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.SortAggregateExec
import org.apache.spark.sql.functions.{avg, col, count_distinct, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
Expand Down Expand Up @@ -2109,4 +2110,30 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("SortAggregate with collect_set is converted to native") {
  val rows = Seq((1, "a"), (2, "a"), (1, "a"), (3, "b"), (4, "b"), (4, "b"))
  val query = "SELECT g, sort_array(collect_set(v)) FROM tbl GROUP BY g ORDER BY g"

  // With useObjectHashAggregateExec=false Spark plans SortAggregateExec for
  // TypedImperativeAggregate functions such as collect_set; Comet is expected to convert it
  // through the same shared CometBaseAggregate path used for ObjectHashAggregateExec.
  withSQLConf(
    "spark.sql.execution.useObjectHashAggregateExec" -> "false",
    CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
    CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") {
    withTempView("tbl") {
      rows.toDF("v", "g").createOrReplaceTempView("tbl")

      // Sanity check with Comet off: the query must really be planned as SortAggregateExec,
      // otherwise the assertions below would pass without touching the new conversion path.
      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
        val plan = stripAQEPlan(sql(query).queryExecution.executedPlan)
        val sortAggregate = plan.collectFirst { case agg: SortAggregateExec => agg }
        assert(
          sortAggregate.isDefined,
          s"Expected SortAggregateExec in Spark-only plan but got:\n$plan")
      }

      checkSparkAnswerAndOperator(sql(query))
    }
  }
}

}