diff --git a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala index b691039f19..e4c405c003 100644 --- a/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala +++ b/spark/src/test/scala/org/apache/comet/parquet/CometParquetWriterSuite.scala @@ -447,17 +447,7 @@ class CometParquetWriterSuite extends CometTestBase { } } - private def writeWithCometNativeWriteExec( - inputPath: String, - outputPath: String, - num_partitions: Option[Int] = None): Option[SparkPlan] = { - val df = spark.read.parquet(inputPath) - - val plan = captureWritePlan( - path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path), - outputPath) - - // Count CometNativeWriteExec instances in the plan + private def assertHasCometNativeWriteExec(plan: SparkPlan): Unit = { var nativeWriteCount = 0 plan.foreach { case _: CometNativeWriteExec => @@ -474,6 +464,19 @@ class CometParquetWriterSuite extends CometTestBase { assert( nativeWriteCount == 1, s"Expected exactly one CometNativeWriteExec in the plan, but found $nativeWriteCount:\n${plan.treeString}") + } + + private def writeWithCometNativeWriteExec( + inputPath: String, + outputPath: String, + num_partitions: Option[Int] = None): Option[SparkPlan] = { + val df = spark.read.parquet(inputPath) + + val plan = captureWritePlan( + path => num_partitions.fold(df)(n => df.repartition(n)).write.parquet(path), + outputPath) + + assertHasCometNativeWriteExec(plan) Some(plan) } @@ -524,7 +527,10 @@ class CometParquetWriterSuite extends CometTestBase { SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Halifax") { val parquetDf = spark.read.parquet(inputPath) - parquetDf.write.parquet(outputPath) + + // Capture plan and verify CometNativeWriteExec is used + val plan = captureWritePlan(path => parquetDf.write.parquet(path), outputPath) + assertHasCometNativeWriteExec(plan) } // Verify round-trip: read with Spark and Comet, compare results