diff --git a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala index 7c92b07bca..bf0ac324cf 100644 --- a/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala +++ b/spark/src/main/scala/org/apache/comet/rules/EliminateRedundantTransitions.scala @@ -22,7 +22,7 @@ package org.apache.comet.rules import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.sideBySide -import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometPlan, CometSparkToColumnarExec} +import org.apache.spark.sql.comet.{CometCollectLimitExec, CometColumnarToRowExec, CometNativeWriteExec, CometPlan, CometSparkToColumnarExec} import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, CometShuffleExchangeExec} import org.apache.spark.sql.execution.{ColumnarToRowExec, RowToColumnarExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.QueryStageExec @@ -80,6 +80,10 @@ case class EliminateRedundantTransitions(session: SparkSession) extends Rule[Spa // and CometSparkToColumnarExec sparkToColumnar.child } + // Remove unnecessary transition for native writes + // Write should be final operation in the plan + case ColumnarToRowExec(nativeWrite: CometNativeWriteExec) => + nativeWrite case c @ ColumnarToRowExec(child) if hasCometNativeChild(child) => val op = CometColumnarToRowExec(child) if (c.logicalLink.isEmpty) {