From 376337058c411eead1dbeaee334ed30086360fd7 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 21 May 2026 16:44:09 -0400 Subject: [PATCH 01/39] enable CometLocalTableScanExec by default --- spark/src/main/scala/org/apache/comet/CometConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index fdd1ae2073..faee23a8eb 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -273,7 +273,7 @@ object CometConf extends ShimCometConf { val COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("takeOrderedAndProject", defaultValue = true) val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] = - createExecEnabledConfig("localTableScan", defaultValue = false) + createExecEnabledConfig("localTableScan", defaultValue = true) val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled") From 810e5d5c38d106fae4a3bff6563137e3a5fcfd01 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 07:46:49 -0400 Subject: [PATCH 02/39] add NullType to toArrowType --- .../main/scala/org/apache/spark/sql/comet/util/Utils.scala | 1 + .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 783367c054..4605e641f1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -148,6 +148,7 @@ object Utils extends CometTypeShim with Logging { } case TimestampNTZType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE case dt if 
isTimeType(dt) => new ArrowType.Time(TimeUnit.NANOSECOND, 64) case _ => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 16601d056b..8bf00de20c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3925,6 +3925,13 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec handles NullType column") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val df = spark.sql("SELECT * FROM VALUES ('a', null), ('b', null) AS t(x, y)") + checkSparkAnswer(df) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 From 174c939540a0be46d184f8e8b9d57a52ceae722a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 08:02:09 -0400 Subject: [PATCH 03/39] add NullType to shuffles --- native/shuffle/src/spark_unsafe/row.rs | 15 ++++++++++-- .../shuffle/CometShuffleExchangeExec.scala | 6 ++--- .../exec/CometColumnarShuffleSuite.scala | 23 ++++--------------- .../comet/exec/CometNativeShuffleSuite.scala | 6 +++++ 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/native/shuffle/src/spark_unsafe/row.rs b/native/shuffle/src/spark_unsafe/row.rs index ec0903bc56..6ffe9d0b6e 100644 --- a/native/shuffle/src/spark_unsafe/row.rs +++ b/native/shuffle/src/spark_unsafe/row.rs @@ -28,8 +28,8 @@ use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BinaryDictionaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, - Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, StringDictionaryBuilder, - StructBuilder, TimestampMicrosecondBuilder, + Int64Builder, Int8Builder, ListBuilder, MapBuilder, NullBuilder, StringBuilder, + StringDictionaryBuilder, StructBuilder, TimestampMicrosecondBuilder, }, 
types::Int32Type, Array, ArrayRef, RecordBatch, RecordBatchOptions, @@ -267,6 +267,10 @@ pub(super) fn append_field( append_field_to_builder!(Date32Builder, |builder: &mut Date32Builder| builder .append_value(row.get_date(idx))); } + DataType::Null => { + let field_builder = get_field_builder!(struct_builder, NullBuilder, idx); + field_builder.append_null(); + } DataType::Timestamp(TimeUnit::Microsecond, _) => { append_field_to_builder!( TimestampMicrosecondBuilder, @@ -1148,6 +1152,12 @@ fn append_columns( .append_value(row.get_date(idx)) ); } + DataType::Null => { + let null_builder = downcast_builder_ref!(NullBuilder, builder); + for _ in row_start..row_end { + null_builder.append_null(); + } + } DataType::Timestamp(TimeUnit::Microsecond, _) => { append_column_to_builder!( TimestampMicrosecondBuilder, @@ -1252,6 +1262,7 @@ fn make_builders( } } DataType::Date32 => Box::new(Date32Builder::with_capacity(row_num)), + DataType::Null => Box::new(NullBuilder::new()), DataType::Timestamp(TimeUnit::Microsecond, _) => { Box::new(TimestampMicrosecondBuilder::with_capacity(row_num).with_data_type(dt.clone())) } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 493c20f8b7..16e7a8b774 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import 
org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -364,7 +364,7 @@ object CometShuffleExchangeExec def supportedSerializableDataType(dt: DataType): Boolean = dt match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | - _: TimestampNTZType | _: DecimalType | _: DateType => + _: TimestampNTZType | _: DecimalType | _: DateType | _: NullType => true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) @@ -487,7 +487,7 @@ object CometShuffleExchangeExec def supportedSerializableDataType(dt: DataType): Boolean = dt match { case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType | _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType | - _: TimestampNTZType | _: DecimalType | _: DateType => + _: TimestampNTZType | _: DecimalType | _: DateType | _: NullType => true case StructType(fields) => fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) && diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 86c6a6aa4b..70d427972a 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala 
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -22,14 +22,13 @@ package org.apache.comet.exec import java.nio.file.{Files, Paths} import scala.reflect.runtime.universe._ -import scala.util.Random import org.scalactic.source.Position import org.scalatest.Tag import org.apache.hadoop.fs.Path import org.apache.spark.{Partitioner, SparkConf} -import org.apache.spark.sql.{CometTestBase, DataFrame, RandomDataGenerator, Row} +import org.apache.spark.sql.{CometTestBase, DataFrame, Row} import org.apache.spark.sql.comet.execution.shuffle.{CometShuffleDependency, CometShuffleExchangeExec, CometShuffleManager} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec, ShuffleQueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec @@ -94,22 +93,10 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar """.stripMargin)) } - test("Fallback to Spark for unsupported input besides ordering") { - val dataGenerator = RandomDataGenerator - .forType( - dataType = NullType, - nullable = true, - new Random(System.nanoTime()), - validJulianDatetime = false) - .get - - val schema = new StructType() - .add("index", IntegerType, nullable = false) - .add("col", NullType, nullable = true) - val rdd = - spark.sparkContext.parallelize((1 to 20).map(i => Row(i, dataGenerator()))) - val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1) - checkSparkAnswer(df) + test("columnar shuffle with NullType passthrough column") { + val df = sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)") + val shuffled = df.repartition(2, $"x") + checkShuffleAnswer(shuffled, 1) } test("columnar shuffle on nested struct including nulls") { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index e0ef1df1f4..60637102f0 100644 --- 
a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -218,6 +218,12 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper } } + test("native shuffle with NullType passthrough column") { + val df = spark.sql("SELECT x, y FROM VALUES ('a', null), ('b', null), ('c', null) AS t(x, y)") + val shuffled = df.repartition(2, $"x") + checkShuffleAnswer(shuffled, 1) + } + test("fix: Comet native shuffle with binary data") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl") From 3790c1022b71cbe47cbde8e4553304373cfde1f4 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 09:12:23 -0400 Subject: [PATCH 04/39] fix windowexec test and nulltype. fix timetype issues --- .../sql/comet/CometLocalTableScanExec.scala | 32 +++++++++++++++++-- .../apache/comet/exec/CometExecSuite.scala | 18 +++++++++++ .../comet/exec/CometWindowExecSuite.scala | 10 +----- 3 files changed, 49 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 622168bcc9..0a836bc389 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.comet +import scala.collection.mutable.ListBuffer + import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -27,11 +29,13 @@ import org.apache.spark.sql.comet.CometLocalTableScanExec.createMetricsIterator import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import 
org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types.{DataType, NullType} import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects -import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport} +import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink @@ -104,7 +108,7 @@ case class CometLocalTableScanExec( override def hashCode(): Int = Objects.hashCode(originalPlan, originalPlan.schema, output) } -object CometLocalTableScanExec extends CometSink[LocalTableScanExec] { +object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTypeSupport { // uses CometArrowConverters, which re-uses arrays override def isFfiSafe: Boolean = false @@ -112,6 +116,30 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] { override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) + // CometArrowConverters / ArrowWriter support NullType (via Utils.toArrowType + + // NullWriter). Other types not on DataTypeSupport's allow list (e.g. TimeType, + // intervals) lack ArrowWriter coverage and must fall back to Spark. 
+ override def isTypeSupported( + dt: DataType, + name: String, + fallbackReasons: ListBuffer[String]): Boolean = dt match { + case _: NullType => true + case _ => super.isTypeSupported(dt, name, fallbackReasons) + } + + override def convert( + op: LocalTableScanExec, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = { + val fallbackReasons = new ListBuffer[String]() + if (!isSchemaSupported(op.schema, fallbackReasons)) { + withInfo(op, fallbackReasons.mkString("; ")) + None + } else { + super.convert(op, builder, childOp: _*) + } + } + override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 8bf00de20c..c3e903f883 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3932,6 +3932,24 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec handles NullType nested in struct/array/map") { + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + checkSparkAnswer( + spark.sql("SELECT named_struct('a', 1, 'b', null) AS s, array(null, null) AS a, " + + "map('k', null) AS m")) + } + } + + test("CometLocalTableScanExec falls back when schema contains TimeType") { + assume( + org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, + "TimeType requires Spark 4.1+") + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val df = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id") + checkSparkAnswer(df) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala index 544cd91bd2..a9fdc96231 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala @@ -108,15 +108,7 @@ class CometWindowExecSuite extends CometTestBase { val cometShuffles = collect(df2.queryExecution.executedPlan) { case _: CometShuffleExchangeExec => true } - if (shuffleMode == "jvm" || shuffleMode == "auto") { - assert(cometShuffles.length == 1) - } else { - // we fall back to Spark for shuffle because we do not support - // native shuffle with a LocalTableScan input, and we do not fall - // back to Comet columnar shuffle due to - // https://github.com/apache/datafusion-comet/issues/1248 - assert(cometShuffles.isEmpty) - } + assert(cometShuffles.length == 1) } } } From 18cd14b0940eeb9a95e1a22e681b5a1c84f5b650 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 10:03:50 -0400 Subject: [PATCH 05/39] Fix TimeType test. --- .../test/scala/org/apache/comet/exec/CometExecSuite.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index c3e903f883..71a30adecd 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3945,8 +3945,12 @@ class CometExecSuite extends CometTestBase { org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, "TimeType requires Spark 4.1+") withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { - val df = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id") - checkSparkAnswer(df) + // Spark 4.1's row encoder cannot serialize TIME columns to the JVM, so we cannot + // collect rows. 
count() exercises the LocalRelation -> scan path without materializing + // the TIME value, which is sufficient to verify the fallback (without the fallback the + // CometLocalTableScanExec ArrowWriter would crash on TimeType). + val cnt = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id").count() + assert(cnt == 1) } } From fc40d59a81abfc70c91ab1fb5d2453e00fa11e8e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 22 May 2026 16:59:33 -0400 Subject: [PATCH 06/39] fix null value type in map in native shuffle --- native/shuffle/src/spark_unsafe/list.rs | 8 +++++++- .../comet/exec/CometColumnarShuffleSuite.scala | 6 ++++++ .../org/apache/comet/exec/CometExecSuite.scala | 16 +++++++++------- .../comet/exec/CometNativeShuffleSuite.scala | 6 ++++++ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/native/shuffle/src/spark_unsafe/list.rs b/native/shuffle/src/spark_unsafe/list.rs index 3fea3fadeb..14f9feb843 100644 --- a/native/shuffle/src/spark_unsafe/list.rs +++ b/native/shuffle/src/spark_unsafe/list.rs @@ -24,7 +24,7 @@ use arrow::array::{ builder::{ ArrayBuilder, BinaryBuilder, BooleanBuilder, Date32Builder, Decimal128Builder, Float32Builder, Float64Builder, Int16Builder, Int32Builder, Int64Builder, Int8Builder, - ListBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, + ListBuilder, NullBuilder, StringBuilder, StructBuilder, TimestampMicrosecondBuilder, }, MapBuilder, }; @@ -393,6 +393,12 @@ pub fn append_to_builder( let builder = downcast_builder_ref!(Date32Builder, builder); array.append_dates_to_builder::(builder); } + DataType::Null => { + let builder = downcast_builder_ref!(NullBuilder, builder); + for _ in 0..array.get_num_elements() { + builder.append_null(); + } + } DataType::Binary => { add_values!( BinaryBuilder, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala index 70d427972a..b0be2b90ac 100644 --- 
a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala @@ -99,6 +99,12 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar checkShuffleAnswer(shuffled, 1) } + test("columnar shuffle with Map[_, NullType] column") { + val df = sql("SELECT id, map(id, null) AS m FROM VALUES (1), (2), (3) AS t(id)") + val shuffled = df.repartition(2, $"id") + checkShuffleAnswer(shuffled, 1) + } + test("columnar shuffle on nested struct including nulls") { Seq(10, 201).foreach { numPartitions => Seq("1.0", "10.0").foreach { ratio => diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 71a30adecd..8c8c19bb9c 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3944,13 +3944,15 @@ class CometExecSuite extends CometTestBase { assume( org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, "TimeType requires Spark 4.1+") - withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { - // Spark 4.1's row encoder cannot serialize TIME columns to the JVM, so we cannot - // collect rows. count() exercises the LocalRelation -> scan path without materializing - // the TIME value, which is sufficient to verify the fallback (without the fallback the - // CometLocalTableScanExec ArrowWriter would crash on TimeType). - val cnt = spark.sql("SELECT TIME '12:34:56' AS t, 1 AS id").count() - assert(cnt == 1) + // spark.sql.timeType.enabled defaults to Utils.isTesting; enable explicitly so the + // row encoder accepts TIME (matches Spark's own TimeFunctionsSuiteBase setup). 
+ withSQLConf( + "spark.sql.timeType.enabled" -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + // VALUES folds to a LocalRelation, exercising the CometLocalTableScanExec convert + // path; the TimeType column should drive the schema-level fallback. + val df = spark.sql("SELECT * FROM VALUES (TIME '12:34:56'), (TIME '01:02:03') AS t(c)") + checkSparkAnswer(df) } } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala index 60637102f0..b34e75d137 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeShuffleSuite.scala @@ -224,6 +224,12 @@ class CometNativeShuffleSuite extends CometTestBase with AdaptiveSparkPlanHelper checkShuffleAnswer(shuffled, 1) } + test("native shuffle with Map[_, NullType] column") { + val df = spark.sql("SELECT id, map(id, null) AS m FROM VALUES (1), (2), (3) AS t(id)") + val shuffled = df.repartition(2, $"id") + checkShuffleAnswer(shuffled, 1) + } + test("fix: Comet native shuffle with binary data") { withParquetTable((0 until 5).map(i => (i, (i + 1).toLong)), "tbl") { val df = sql("SELECT cast(cast(_1 as STRING) as BINARY) as binary, _2 FROM tbl") From 8c088a732cb9301e3d1832926c0c676094a65ba8 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 26 May 2026 08:29:31 -0400 Subject: [PATCH 07/39] avoid reuse in LocalTableScanExec --- .../sql/comet/CometLocalTableScanExec.scala | 41 +++++++------------ .../arrow/CometArrowConverters.scala | 33 +++++++++++++++ 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 0a836bc389..6ada259e82 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -25,7 +25,6 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.comet.CometLocalTableScanExec.createMetricsIterator import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -68,19 +67,24 @@ case class CometLocalTableScanExec( } override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numInputRows = longMetric("numOutputRows") + val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC to match native side expectations. See CometSparkToColumnarExec. - val timeZoneId = "UTC" - rdd.mapPartitionsInternal { sparkBatches => + val schema = originalPlan.schema + // Native side asserts Timestamp(Microsecond, Some("UTC")). See COMET-2720. + rdd.mapPartitionsInternal { rowIter => val context = TaskContext.get() - val batches = CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - originalPlan.schema, + // Non-Comet JVM consumers (e.g. Iceberg writers) may retain batches across next() + // calls, so each batch must own independent Arrow buffers. 
+ val batches = CometArrowConverters.rowToArrowBatchIterNoReuse( + rowIter, + schema, maxRecordsPerBatch, - timeZoneId, + "UTC", context) - createMetricsIterator(batches, numInputRows) + batches.map { batch => + numOutputRows.add(batch.numRows()) + batch + } } } @@ -110,9 +114,6 @@ case class CometLocalTableScanExec( object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTypeSupport { - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false - override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) @@ -143,18 +144,4 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTy override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) } - - private def createMetricsIterator( - it: Iterator[ColumnarBatch], - numInputRows: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - override def hasNext: Boolean = it.hasNext - - override def next(): ColumnarBatch = { - val batch = it.next() - numInputRows.add(batch.numRows()) - batch - } - } - } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 6d52078181..0e7d0d9a8d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -140,6 +140,39 @@ object CometArrowConverters extends Logging { new RowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) } + /** + * Use when the downstream consumer may retain batches across `next()` calls (e.g. non-Comet JVM + * columnar sinks like Iceberg writers). Each batch owns independent Arrow buffers. 
+ */ + def rowToArrowBatchIterNoReuse( + rowIter: Iterator[InternalRow], + schema: StructType, + maxRecordsPerBatch: Long, + timeZoneId: String, + context: TaskContext): Iterator[ColumnarBatch] = { + val arrowSchema = Utils.toArrowSchema(schema, timeZoneId) + val allocator = + CometArrowAllocator.newChildAllocator("rowToArrowBatchIterNoReuse", 0, Long.MaxValue) + Option(context).foreach(_.addTaskCompletionListener[Unit](_ => allocator.close())) + + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = rowIter.hasNext + + override def next(): ColumnarBatch = { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = ArrowWriter.create(root) + var rowCount = 0L + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) + rowCount += 1 + } + writer.finish() + NativeUtil.rootAsBatch(root) + } + } + } + private[sql] class ColumnBatchToArrowBatchIter( colBatch: ColumnarBatch, schema: StructType, From bd04fb4d3a0c9039d3e020493f5dc450ef1de839 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 12:18:56 -0400 Subject: [PATCH 08/39] Replace Comet's bespoke CometBatchIterator JNI input path with the canonical Arrow C Stream Interface (JVM Data.exportArrayStream <-> native ArrowArrayStreamReader), eliminating the per-batch FFI deep copy and the arrow_ffi_safe flag. 
--- native/core/src/execution/operators/scan.rs | 185 ++++------------- native/core/src/execution/planner.rs | 51 +++-- native/core/src/execution/utils.rs | 35 +--- ...atch_iterator.rs => arrow_array_stream.rs} | 39 ++-- native/jni-bridge/src/lib.rs | 11 +- native/proto/src/proto/operator.proto | 2 - .../org/apache/comet/CometBatchIterator.java | 93 --------- .../org/apache/comet/CometExecIterator.scala | 15 +- .../operator/CometDataWritingCommand.scala | 1 - .../comet/serde/operator/CometSink.scala | 4 - .../apache/spark/sql/comet/CometExecRDD.scala | 65 +++--- .../spark/sql/comet/CometExecUtils.scala | 15 +- .../sql/comet/CometLocalTableScanExec.scala | 66 +++++-- .../sql/comet/CometNativeWriteExec.scala | 10 +- .../sql/comet/CometSparkToColumnarExec.scala | 138 +++++++------ .../CometTakeOrderedAndProjectExec.scala | 23 ++- .../arrow/ColumnarBatchArrowReader.scala | 84 ++++++++ .../arrow/CometArrowConverters.scala | 186 ++++-------------- .../arrow/CometNativeArrowSource.scala | 184 +++++++++++++++++ .../execution/arrow/RowArrowReader.scala | 69 +++++++ .../arrow/SparkColumnarArrowReader.scala | 97 +++++++++ .../shuffle/CometNativeShuffleWriter.scala | 10 +- .../apache/spark/sql/comet/operators.scala | 114 ++++++++--- .../apache/spark/sql/comet/util/Utils.scala | 10 + .../org/apache/comet/CometNativeSuite.scala | 18 +- .../apache/comet/exec/CometExecSuite.scala | 23 +++ .../exec/CometNativeColumnarToRowSuite.scala | 9 +- 27 files changed, 924 insertions(+), 633 deletions(-) rename native/jni-bridge/src/{batch_iterator.rs => arrow_array_stream.rs} (57%) delete mode 100644 spark/src/main/java/org/apache/comet/CometBatchIterator.java create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala create 
mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index e318d9e66b..bad349e4d7 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,19 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::operators::{copy_array, copy_or_unpack_array, CopyMode}; -use crate::{ - errors::CometError, - execution::{ - operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, utils::SparkArrowConvert, - }, - jvm_bridge::JVMClasses, -}; -use arrow::array::{make_array, ArrayData, ArrayRef, RecordBatch, RecordBatchOptions}; +use crate::execution::operators::{copy_or_unpack_array, CopyMode}; +use crate::{errors::CometError, execution::planner::TEST_EXEC_CONTEXT_ID}; +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::ffi::FFI_ArrowArray; -use arrow::ffi::FFI_ArrowSchema; +use arrow::ffi_stream::ArrowArrayStreamReader; use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ @@ -40,8 +33,6 @@ use datafusion::{ }; use futures::Stream; use itertools::Itertools; -use jni::objects::{Global, JObject, JValue}; -use std::rc::Rc; use std::{ any::Any, pin::Pin, @@ -49,43 +40,34 @@ use std::{ task::{Context, Poll}, }; -/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file -/// scan or the result of reading a broadcast or shuffle exchange. ScanExec isn't invoked -/// until the data is already available in the JVM. 
When CometExecIterator invokes -/// Native.executePlan, it passes in the memory addresses of the input batches. +/// `ScanExec` reads batches of data from Spark over the Arrow C Stream Interface. The +/// `input_source` is moved out of the JVM-exported `ArrowArrayStream` at plan-construction time; +/// dropping the reader (when this exec drops) fires the stream's release callback, which closes +/// the JVM-side `ArrowReader` and its `VectorSchemaRoot`. #[derive(Debug, Clone)] pub struct ScanExec { - /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM - /// environment `JNIEnv` from the execution context. + /// JVM execution-context id used to look up the `JNIEnv` for callbacks. pub exec_context_id: i64, - /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. - pub input_source: Option>>>, - /// A description of the input source for informational purposes + /// The C Stream Interface reader. `None` only in unit tests that seed input via + /// `set_input_batch`. + pub input_source: Option>>, pub input_source_description: String, - /// The data types of columns of the input batch. Converted from Spark schema. pub data_types: Vec, - /// Schema of first batch pub schema: SchemaRef, - /// The input batch of input data. Used to determine the schema of the input data. - /// It is also used in unit test to mock the input data from JVM. + /// Used in unit tests to mock the input batch; otherwise written by `pull_next` on each + /// poll. 
pub batch: Arc>>, - /// Cache of expensive-to-compute plan properties cache: Arc, - /// Metrics collector metrics: ExecutionPlanMetricsSet, - /// Baseline metrics baseline_metrics: BaselineMetrics, - /// Whether native code can assume ownership of batches that it receives - arrow_ffi_safe: bool, } impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>>>, + input_source: Option>>, input_source_description: &str, data_types: Vec, - arrow_ffi_safe: bool, ) -> Result { let metrics_set = ExecutionPlanMetricsSet::default(); let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); @@ -112,7 +94,6 @@ impl ScanExec { metrics: metrics_set, baseline_metrics, schema, - arrow_ffi_safe, }) } @@ -131,22 +112,18 @@ impl ScanExec { *self.batch.try_lock().unwrap() = Some(input); } - /// Pull next input batch from JVM. + /// Pull next input batch from the upstream `ArrowArrayStreamReader`. pub fn get_next_batch(&mut self) -> Result<(), CometError> { if self.input_source.is_none() { - // This is a unit test. We don't need to call JNI. + // This is a unit test. Input batches are seeded via `set_input_batch`. return Ok(()); } let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut current_batch = self.batch.try_lock().unwrap(); if current_batch.is_none() { - let next_batch = ScanExec::get_next( - self.exec_context_id, - self.input_source.as_ref().unwrap().as_obj(), - self.data_types.len(), - self.arrow_ffi_safe, - )?; + let next_batch = + ScanExec::pull_next(self.exec_context_id, self.input_source.as_ref().unwrap())?; *current_batch = Some(next_batch); } @@ -155,119 +132,35 @@ impl ScanExec { Ok(()) } - /// Invokes JNI call to get next batch. - fn get_next( + /// Pull the next `RecordBatch` from the stream and convert it to an `InputBatch`. Dictionary + /// columns are unpacked because Comet's downstream operators do not handle them. 
+ fn pull_next( exec_context_id: i64, - iter: &JObject, - num_cols: usize, - arrow_ffi_safe: bool, + reader: &Arc>, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { - // This is a unit test. We don't need to call JNI. + // Unit test path; input batches are seeded directly. return Ok(InputBatch::EOF); } - if iter.is_null() { - return Err(CometError::from(ExecutionError::GeneralError(format!( - "Null batch iterator object. Plan id: {exec_context_id}" - )))); - } - - JVMClasses::with_env(|env| { - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).has_next() -> i32)? - }; - - if num_rows == -1 { - return Ok(InputBatch::EOF); - } - - // fetch batch data from JVM via FFI - let (num_rows, array_addrs, schema_addrs) = - Self::allocate_and_fetch_batch(env, iter, num_cols)?; - - let mut inputs: Vec = Vec::with_capacity(num_cols); - - // Process each column - for i in 0..num_cols { - let array_ptr = array_addrs[i]; - let schema_ptr = schema_addrs[i]; - let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; - - // TODO: validate array input data - // array_data.validate_full()?; - - let array = make_array(array_data); - - let array = if arrow_ffi_safe { - // ownership of this array has been transferred to native - // but we still need to unpack dictionary arrays - copy_or_unpack_array(&array, &CopyMode::UnpackOrClone)? 
- } else { - // it is necessary to copy the array because the contents may be - // overwritten on the JVM side in the future - copy_array(&array) - }; - - inputs.push(array); - - // Drop the Arcs to avoid memory leak - unsafe { - Rc::from_raw(array_ptr as *const FFI_ArrowArray); - Rc::from_raw(schema_ptr as *const FFI_ArrowSchema); + let mut reader = reader + .try_lock() + .map_err(|_| CometError::Internal("ArrowArrayStreamReader contended".to_string()))?; + + let next = reader.next(); + match next { + None => Ok(InputBatch::EOF), + Some(Err(e)) => Err(CometError::from(e)), + Some(Ok(record_batch)) => { + let num_rows = record_batch.num_rows(); + let columns = record_batch.columns(); + let mut inputs: Vec = Vec::with_capacity(columns.len()); + for col in columns { + inputs.push(copy_or_unpack_array(col, &CopyMode::UnpackOrClone)?); } + Ok(InputBatch::new(inputs, Some(num_rows))) } - - Ok(InputBatch::new(inputs, Some(num_rows as usize))) - }) - } - - /// Allocates Arrow FFI structures and calls JNI to get the next batch data. - /// Returns the number of rows and the allocated array/schema addresses. 
- fn allocate_and_fetch_batch( - env: &mut jni::Env, - iter: &JObject, - num_cols: usize, - ) -> Result<(i32, Vec, Vec), CometError> { - let mut array_addrs = Vec::with_capacity(num_cols); - let mut schema_addrs = Vec::with_capacity(num_cols); - - for _ in 0..num_cols { - let arrow_array = Rc::new(FFI_ArrowArray::empty()); - let arrow_schema = Rc::new(FFI_ArrowSchema::empty()); - let (array_ptr, schema_ptr) = ( - Rc::into_raw(arrow_array) as i64, - Rc::into_raw(arrow_schema) as i64, - ); - - array_addrs.push(array_ptr); - schema_addrs.push(schema_ptr); } - - // Prepare the java array parameters - let long_array_addrs = env.new_long_array(num_cols)?; - let long_schema_addrs = env.new_long_array(num_cols)?; - - long_array_addrs.set_region(env, 0, &array_addrs)?; - long_schema_addrs.set_region(env, 0, &schema_addrs)?; - - let array_obj = JObject::from(long_array_addrs); - let schema_obj = JObject::from(long_schema_addrs); - - let array_obj = JValue::Object(array_obj.as_ref()); - let schema_obj = JValue::Object(schema_obj.as_ref()); - - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)? 
- }; - - // we already checked for end of results on call to has_next() so should always - // have a valid row count when calling next() - assert!(num_rows != -1); - - Ok((num_rows, array_addrs, schema_addrs)) } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 542c3d9536..77b174ea8d 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,6 +21,7 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use crate::errors::CometError; use crate::execution::operators::init_csv_datasource_exec; use crate::execution::operators::IcebergScanExec; use crate::execution::{ @@ -32,6 +33,7 @@ use crate::execution::{ serde::to_arrow_datatype, shuffle::ShuffleWriterExec, }; +use crate::jvm_bridge::{jni_call, JVMClasses}; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; @@ -1447,23 +1449,36 @@ impl PhysicalPlanner { return Err(GeneralError("No input for scan".to_string())); } - // Consumes the first input source for the scan - let input_source = - if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { - // For unit test, we will set input batch to scan directly by `set_input_batch`. - None - } else { - Some(inputs.remove(0)) - }; + // Consumes the first input source for the scan. The Java side passes an + // `org.apache.arrow.c.ArrowArrayStream` whose `memoryAddress` points at the C + // struct; native takes ownership via `ArrowArrayStreamReader::from_raw`. + let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID + && inputs.is_empty() + { + // For unit test, we will set input batch to scan directly by `set_input_batch`. 
+ None + } else { + let java_stream = inputs.remove(0); + let address: i64 = JVMClasses::with_env(|env| -> Result { + let addr = unsafe { + jni_call!(env, arrow_array_stream(java_stream.as_obj()).memory_address() -> i64)? + }; + Ok(addr) + })?; + let reader = unsafe { + arrow::ffi_stream::ArrowArrayStreamReader::from_raw( + address as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + ) + } + .map_err(|e| { + GeneralError(format!("Failed to import ArrowArrayStream from JVM: {e}")) + })?; + Some(Arc::new(std::sync::Mutex::new(reader))) + }; // The `ScanExec` operator will take actual arrays from Spark during execution - let scan = ScanExec::new( - self.exec_context_id, - input_source, - &scan.source, - data_types, - scan.arrow_ffi_safe, - )?; + let scan = + ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?; Ok(( vec![scan.clone()], @@ -3980,7 +3995,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4046,7 +4060,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4256,7 +4269,6 @@ mod tests { op_struct: Some(OpStruct::Scan(spark_operator::Scan { fields: vec![create_proto_datatype()], source: "".to_string(), - arrow_ffi_safe: false, })), } } @@ -4299,7 +4311,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4422,7 +4433,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4905,7 +4915,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 2fe6f8758f..6195e3f0ae 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -19,48 +19,15 @@ use crate::execution::operators::ExecutionError; use arrow::{ array::ArrayData, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, }; pub trait SparkArrowConvert { - 
/// Build Arrow Arrays from C data interface passed from Spark. - /// It accepts a tuple (ArrowArray address, ArrowSchema address). - fn from_spark(addresses: (i64, i64)) -> Result - where - Self: Sized; - /// Move Arrow Arrays to C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>; } impl SparkArrowConvert for ArrayData { - fn from_spark(addresses: (i64, i64)) -> Result { - let (array_ptr, schema_ptr) = addresses; - - let array_ptr = array_ptr as *mut FFI_ArrowArray; - let schema_ptr = schema_ptr as *mut FFI_ArrowSchema; - - if array_ptr.is_null() || schema_ptr.is_null() { - return Err(ExecutionError::ArrowError( - "At least one of passed pointers is null".to_string(), - )); - }; - - // `ArrowArray` will convert raw pointers back to `Arc`. No worries - // about memory leak. - let mut ffi_array = unsafe { - let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty()); - let schema_data = std::ptr::replace(schema_ptr, FFI_ArrowSchema::empty()); - - from_ffi(array_data, &schema_data)? - }; - - // Align imported buffers from Java. - ffi_array.align_buffers(); - - Ok(ffi_array) - } - /// Move this ArrowData to pointers of Arrow C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { let array_ptr = array as *mut FFI_ArrowArray; diff --git a/native/jni-bridge/src/batch_iterator.rs b/native/jni-bridge/src/arrow_array_stream.rs similarity index 57% rename from native/jni-bridge/src/batch_iterator.rs rename to native/jni-bridge/src/arrow_array_stream.rs index addda133fa..0b285607ff 100644 --- a/native/jni-bridge/src/batch_iterator.rs +++ b/native/jni-bridge/src/arrow_array_stream.rs @@ -15,45 +15,38 @@ // specific language governing permissions and limitations // under the License. 
-use jni::signature::Primitive; use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, - signature::ReturnType, + signature::{Primitive, ReturnType}, strings::JNIString, Env, }; -/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class. +/// A struct that holds all the JNI methods and fields for JVM `org.apache.arrow.c.ArrowArrayStream` +/// class. `memoryAddress()` is read once per partition so native can take ownership of the +/// underlying C struct via `ArrowArrayStreamReader::from_raw`. #[allow(dead_code)] // we need to keep references to Java items to prevent GC -pub struct CometBatchIterator<'a> { +pub struct ArrowArrayStream<'a> { pub class: JClass<'a>, - pub method_has_next: JMethodID, - pub method_has_next_ret: ReturnType, - pub method_next: JMethodID, - pub method_next_ret: ReturnType, + pub method_memory_address: JMethodID, + pub method_memory_address_ret: ReturnType, } -impl<'a> CometBatchIterator<'a> { - pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator"; +impl<'a> ArrowArrayStream<'a> { + pub const JVM_CLASS: &'static str = "org/apache/arrow/c/ArrowArrayStream"; - pub fn new(env: &mut Env<'a>) -> JniResult> { + pub fn new(env: &mut Env<'a>) -> JniResult> { let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; - Ok(CometBatchIterator { - class, - method_has_next: env.get_method_id( - JNIString::new(Self::JVM_CLASS), - jni::jni_str!("hasNext"), - jni::jni_sig!("()I"), - )?, - method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_next: env.get_method_id( + Ok(ArrowArrayStream { + method_memory_address: env.get_method_id( JNIString::new(Self::JVM_CLASS), - jni::jni_str!("next"), - jni::jni_sig!("([J[J)I"), + jni::jni_str!("memoryAddress"), + jni::jni_sig!("()J"), )?, - method_next_ret: ReturnType::Primitive(Primitive::Int), + method_memory_address_ret: ReturnType::Primitive(Primitive::Long), + class, }) } } diff --git a/native/jni-bridge/src/lib.rs 
b/native/jni-bridge/src/lib.rs index d72323c961..c8bb7cd02d 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -189,13 +189,13 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { mod comet_exec; pub use comet_exec::*; -mod batch_iterator; +mod arrow_array_stream; mod comet_metric_node; mod comet_task_memory_manager; mod comet_udf_bridge; mod shuffle_block_iterator; -use batch_iterator::CometBatchIterator; +use arrow_array_stream::ArrowArrayStream; pub use comet_metric_node::*; pub use comet_task_memory_manager::*; use comet_udf_bridge::CometUdfBridge; @@ -223,8 +223,9 @@ pub struct JVMClasses<'a> { pub comet_metric_node: CometMetricNode<'a>, /// The static CometExec class. Used for getting the subquery result. pub comet_exec: CometExec<'a>, - /// The CometBatchIterator class. Used for iterating over the batches. - pub comet_batch_iterator: CometBatchIterator<'a>, + /// The org.apache.arrow.c.ArrowArrayStream class. Used to get the C struct memory address + /// when importing a JVM-exported batch stream into native code. + pub arrow_array_stream: ArrowArrayStream<'a>, /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. 
pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to @@ -300,7 +301,7 @@ impl JVMClasses<'_> { throwable_get_cause_method, comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), - comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + arrow_array_stream: ArrowArrayStream::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index ed1684b240..b65a215c78 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -85,8 +85,6 @@ message Scan { // is purely for informational purposes when viewing native query plans in // debug mode. string source = 2; - // Whether native code can assume ownership of batches that it receives - bool arrow_ffi_safe = 3; } message ShuffleScan { diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java deleted file mode 100644 index 9b48a47c57..0000000000 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import scala.collection.Iterator; - -import org.apache.spark.sql.vectorized.ColumnarBatch; - -import org.apache.comet.vector.NativeUtil; - -/** - * Iterator for fetching batches from JVM to native code. Usually called via JNI from native - * ScanExec. - * - *

Batches are owned by the JVM. Native code can safely access the batch after calling `next` but - * the native code must not retain references to the batch because the next call to `hasNext` will - * signal to the JVM that the batch can be closed. - */ -public class CometBatchIterator { - private final Iterator input; - private final NativeUtil nativeUtil; - private ColumnarBatch previousBatch = null; - private ColumnarBatch currentBatch = null; - - CometBatchIterator(Iterator input, NativeUtil nativeUtil) { - this.input = input; - this.nativeUtil = nativeUtil; - } - - /** - * Fetch the next input batch and allow the previous batch to be closed (this may not happen - * immediately). - * - * @return Number of rows in next batch or -1 if no batches left. - */ - public int hasNext() { - - // release reference to previous batch - previousBatch = null; - - if (currentBatch == null) { - if (input.hasNext()) { - currentBatch = input.next(); - } - } - if (currentBatch == null) { - return -1; - } else { - return currentBatch.numRows(); - } - } - - /** - * Get the next batch of Arrow arrays. - * - * @param arrayAddrs The addresses of the ArrowArray structures. - * @param schemaAddrs The addresses of the ArrowSchema structures. - * @return the number of rows of the current batch. -1 if there is no more batch. 
- */ - public int next(long[] arrayAddrs, long[] schemaAddrs) { - if (currentBatch == null) { - return -1; - } - - // export the batch using the Arrow C Data Interface - int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); - - // keep a reference to the exported batch so that it doesn't get garbage collected - // while the native code is still processing it - previousBatch = currentBatch; - - currentBatch = null; - - return numRows; - } -} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 6140eca553..d17735a560 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -60,7 +60,7 @@ import org.apache.comet.vector.NativeUtil */ class CometExecIterator( val id: Long, - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode, @@ -79,14 +79,11 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle - // scan indices, CometBatchIterator for regular scan indices. - private val inputIterators: Array[Object] = inputs.zipWithIndex.map { - case (_, idx) if shuffleBlockIterators.contains(idx) => - shuffleBlockIterators(idx).asInstanceOf[Object] - case (iterator, _) => - new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] - }.toArray + // Each input slot is either an org.apache.arrow.c.ArrowArrayStream (consumed natively via + // ArrowArrayStreamReader::from_raw against its memoryAddress) or a CometShuffleBlockIterator + // (consumed via the existing JNI block-iteration protocol). 
The slot index matches the scan + // input index in the serialized native plan. + private val inputIterators: Array[Object] = inputObjects private val plan = { val conf = SparkEnv.get.conf diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 69b9bd5f85..4a8ae4d2ac 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -96,7 +96,6 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec val scanOp = OperatorOuterClass.Scan .newBuilder() .setSource(cmd.query.nodeName) - .setArrowFfiSafe(false) // Add fields from the query output schema val scanTypes = cmd.query.output.flatMap { attr => diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index 845803d133..b7caeb43c2 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -40,9 +40,6 @@ import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataTy */ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { - /** Whether the data produced by the Comet operator is FFI safe */ - def isFfiSafe: Boolean = true - override def enabledConfig: Option[ConfigEntry[Boolean]] = None override def convert( @@ -65,7 +62,6 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { } else { scanBuilder.setSource(source) } - scanBuilder.setArrowFfiSafe(isFfiSafe) val scanTypes = op.output.flatten { attr => serializeDataType(attr.dataType) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 
47eda98a11..4b411d87f7 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.comet +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,7 +28,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometRuntimeException, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -40,23 +41,14 @@ private[spark] class CometExecPartition( extends Partition /** - * Unified RDD for Comet native execution. - * - * Solves the closure capture problem: instead of capturing all partitions' data in the closure - * (which gets serialized to every task), each Partition object carries only its own data. - * - * Handles three cases: - * - With inputs + per-partition data: injects planning data into operator tree - * - With inputs + no per-partition data: just zips inputs (no injection overhead) - * - No inputs: uses numPartitions to create partitions - * - * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in - * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles - * ScalarSubquery expressions by registering them with CometScalarSubquery before execution. + * Unified RDD for Comet native execution. Non-shuffle input slots are `RDD[ArrowArrayStream]` + * (consumed natively via the C Stream Interface); shuffle input slots are `CometShuffledBatchRDD` + * (consumed via `CometShuffleBlockIterator`). Slot order matches the scan-input order in the + * serialized native plan. 
*/ private[spark] class CometExecRDD( sc: SparkContext, - var inputRDDs: Seq[RDD[ColumnarBatch]], + var inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], @transient perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], @@ -97,9 +89,31 @@ private[spark] class CometExecRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometExecPartition] - val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => - rdd.iterator(part, context) - } + val shuffleBlockIters = scala.collection.mutable.Map.empty[Int, CometShuffleBlockIterator] + val inputObjects: Array[Object] = inputRDDs + .zip(partition.inputPartitions) + .zipWithIndex + .map { case ((rdd, part), idx) => + if (shuffleScanIndices.contains(idx)) { + rdd match { + case shuffleRDD: CometShuffledBatchRDD => + val it = shuffleRDD.computeAsShuffleBlockIterator(part, context) + shuffleBlockIters(idx) = it + it.asInstanceOf[Object] + case other => + throw new CometRuntimeException( + s"Slot $idx is marked as a shuffle scan but the input RDD is " + + s"${other.getClass.getName}, expected CometShuffledBatchRDD") + } + } else { + val streams = rdd.iterator(part, context).asInstanceOf[Iterator[ArrowArrayStream]] + if (!streams.hasNext) { + throw new CometRuntimeException(s"Empty ArrowArrayStream RDD partition for slot $idx") + } + streams.next().asInstanceOf[Object] + } + } + .toArray // Only inject if we have per-partition planning data val actualPlan = if (commonByKey.nonEmpty) { @@ -111,18 +125,9 @@ private[spark] class CometExecRDD( serializedPlan } - // Create shuffle block iterators for inputs that are CometShuffledBatchRDD - val shuffleBlockIters = shuffleScanIndices.flatMap { idx => - inputRDDs(idx) match { - case rdd: CometShuffledBatchRDD => - Some(idx -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(idx), context)) - case _ => None - } - }.toMap - val it = new 
CometExecIterator( CometExec.newIterId, - inputs, + inputObjects, numOutputCols, actualPlan, nativeMetrics, @@ -130,7 +135,7 @@ private[spark] class CometExecRDD( partition.index, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleBlockIters) + shuffleBlockIters.toMap) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -169,7 +174,7 @@ object CometExecRDD { // scalastyle:off def apply( sc: SparkContext, - inputRDDs: Seq[RDD[ColumnarBatch]], + inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index a2af60142b..e632190f0a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -25,6 +25,8 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.vectorized.ColumnarBatch @@ -56,8 +58,19 @@ object CometExecUtils { // Serialize the plan once before mapping to avoid repeated serialization per partition val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get val serializedPlan = CometExec.serializeNativePlan(limitOp) + val inputSchema = Utils.fromAttributes(outputAttribute) childPlan.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), outputAttribute.length, serializedPlan, numParts, idx) + val stream = 
CometArrowStream.fromColumnarBatchIter( + iter, + inputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometExecUtils-getNativeLimit") + CometExec.getCometIterator( + Array(stream.asInstanceOf[Object]), + outputAttribute.length, + serializedPlan, + numParts, + idx) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 6ada259e82..32b8933872 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -21,11 +21,12 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer -import org.apache.spark.TaskContext +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{DataType, NullType} @@ -43,7 +44,8 @@ case class CometLocalTableScanExec( @transient rows: Seq[InternalRow], override val output: Seq[Attribute]) extends CometExec - with LeafExecNode { + with LeafExecNode + with CometNativeArrowSource { override lazy val metrics: Map[String, SQLMetric] = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -66,25 +68,47 @@ case class CometLocalTableScanExec( } } + private def countingRows( + iter: Iterator[InternalRow], + numOutputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = 
iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numOutputRows.add(1) + row + } + } + override def doExecuteColumnar(): RDD[ColumnarBatch] = { val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - val schema = originalPlan.schema - // Native side asserts Timestamp(Microsecond, Some("UTC")). See COMET-2720. + val sparkSchema = originalPlan.schema + rdd.mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometLocalTableScan", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numOutputRows), + maxRecordsPerBatch)) + } + } + + override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { + val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) + val sparkSchema = originalPlan.schema + val numOutputRows = longMetric("numOutputRows") rdd.mapPartitionsInternal { rowIter => - val context = TaskContext.get() - // Non-Comet JVM consumers (e.g. Iceberg writers) may retain batches across next() - // calls, so each batch must own independent Arrow buffers. - val batches = CometArrowConverters.rowToArrowBatchIterNoReuse( - rowIter, - schema, - maxRecordsPerBatch, - "UTC", - context) - batches.map { batch => - numOutputRows.add(batch.numRows()) - batch - } + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometLocalTableScan", + allocator => + new RowArrowReader( + allocator, + arrowSchema, + countingRows(rowIter, numOutputRows), + maxRecordsPerBatch)) } } @@ -117,9 +141,9 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTy override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) - // CometArrowConverters / ArrowWriter support NullType (via Utils.toArrowType + - // NullWriter). 
Other types not on DataTypeSupport's allow list (e.g. TimeType, - // intervals) lack ArrowWriter coverage and must fall back to Spark. + // ArrowWriter (used by RowArrowReader) handles NullType via Utils.toArrowType + NullWriter; + // other types off DataTypeSupport's allow list (TimeType, intervals, ...) have no ArrowWriter + // coverage and must fall back to Spark. override def isTypeSupported( dt: DataType, name: String, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 4fb8af39e8..7ba281a666 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -28,6 +28,8 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -210,9 +212,15 @@ case class CometNativeWriteExec( modifiedNativeOp.writeTo(codedOutput) codedOutput.checkNoSpaceLeft() + val arrowStream = CometArrowStream.fromColumnarBatchIter( + iter, + CometUtils.fromAttributes(child.output), + CometArrowStream.NATIVE_TIMEZONE, + "CometNativeWriteExec") + val execIterator = new CometExecIterator( CometExec.newIterId, - Seq(iter), + Array(arrowStream.asInstanceOf[Object]), numOutputCols, planBytes, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index efe6a97d40..00e13bcbde 100644 --- 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -21,13 +21,14 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer -import org.apache.spark.TaskContext +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader, SparkColumnarArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types._ @@ -39,7 +40,8 @@ import org.apache.comet.serde.operator.CometSink case class CometSparkToColumnarExec(child: SparkPlan) extends RowToColumnarTransition - with CometPlan { + with CometPlan + with CometNativeArrowSource { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning @@ -69,72 +71,99 @@ case class CometSparkToColumnarExec(child: SparkPlan) sparkContext, "time converting Spark batches to Arrow batches")) - // The conversion happens in next(), so wrap the call to measure time spent. 
- private def createTimingIter( + private def countingBatches( iter: Iterator[ColumnarBatch], numInputRows: SQLMetric, - numOutputBatches: SQLMetric, - conversionTime: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { + numOutputBatches: SQLMetric): Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = iter.hasNext + override def next(): ColumnarBatch = { + val batch = iter.next() + numInputRows += batch.numRows() + numOutputBatches += 1 + batch + } + } - override def hasNext: Boolean = { - iter.hasNext - } + private def countingRows( + iter: Iterator[InternalRow], + numInputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numInputRows += 1 + row + } + } + + override def doExecuteColumnar(): RDD[ColumnarBatch] = { + val numInputRows = longMetric("numInputRows") + val numOutputBatches = longMetric("numOutputBatches") + val conversionTime = longMetric("conversionTime") + val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) + val sparkSchema = child.schema - override def next(): ColumnarBatch = { - val startNs = System.nanoTime() - val batch = iter.next() - conversionTime += System.nanoTime() - startNs - numInputRows += batch.numRows() - numOutputBatches += 1 - batch + if (child.supportsColumnar) { + val maxBatchInt = maxRecordsPerBatch.toInt + child.executeColumnar().mapPartitionsInternal { sparkBatches => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometSparkColumnarToColumnar", + new SparkColumnarArrowReader( + _, + arrowSchema, + countingBatches(sparkBatches, numInputRows, numOutputBatches), + maxBatchInt, + ns => conversionTime += ns)) + } + } else { + child.execute().mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, 
CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.readerBatchIter( + "CometSparkRowToColumnar", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numInputRows), + maxRecordsPerBatch, + ns => conversionTime += ns)) } } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { + override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC for Arrow schema timezone to match the native side, which always - // deserializes Timestamp as Timestamp(Microsecond, Some("UTC")). Spark's internal - // timestamp representation is always UTC microseconds, so the timezone here is - // purely schema metadata. Using session timezone would cause Arrow RowConverter - // schema mismatch errors in non-UTC sessions. See COMET-2720. - val timeZoneId = "UTC" - val schema = child.schema + val sparkSchema = child.schema if (child.supportsColumnar) { - child - .executeColumnar() - .mapPartitionsInternal { sparkBatches => - val arrowBatches = - sparkBatches.flatMap { sparkBatch => - val context = TaskContext.get() - CometArrowConverters.columnarBatchToArrowBatchIter( - sparkBatch, - schema, - maxRecordsPerBatch, - timeZoneId, - context) - } - createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) - } + val maxBatchInt = maxRecordsPerBatch.toInt + child.executeColumnar().mapPartitionsInternal { sparkBatches => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometSparkColumnarToColumnar", + allocator => + new SparkColumnarArrowReader( + allocator, + arrowSchema, + countingBatches(sparkBatches, numInputRows, numOutputBatches), + maxBatchInt, + ns => conversionTime += ns)) + } } else { - child - .execute() - .mapPartitionsInternal { sparkBatches => - val 
context = TaskContext.get() - val arrowBatches = - CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - schema, + child.execute().mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + CometArrowStream.stream( + "CometSparkRowToColumnar", + allocator => + new RowArrowReader( + allocator, + arrowSchema, + countingRows(rowIter, numInputRows), maxRecordsPerBatch, - timeZoneId, - context) - createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) - } + ns => conversionTime += ns)) + } } } @@ -145,9 +174,6 @@ case class CometSparkToColumnarExec(child: SparkPlan) object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport { - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false - override def createExec( nativeOp: OperatorOuterClass.Operator, op: SparkPlan): CometNativeExec = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index a66d1b58d6..e9b178dc6b 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -24,7 +24,9 @@ import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import 
org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -140,8 +142,19 @@ case class CometTakeOrderedAndProjectExec( .get val serializedTopK = CometExec.serializeNativePlan(topK) val numOutputCols = child.output.length + val inputSchema = CometUtils.fromAttributes(child.output) childRDD.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), numOutputCols, serializedTopK, numParts, idx) + val stream = CometArrowStream.fromColumnarBatchIter( + iter, + inputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometTakeOrderedAndProject-topK") + CometExec.getCometIterator( + Array(stream.asInstanceOf[Object]), + numOutputCols, + serializedTopK, + numParts, + idx) } } @@ -163,9 +176,15 @@ case class CometTakeOrderedAndProjectExec( .get val serializedTopKAndProjection = CometExec.serializeNativePlan(topKAndProjection) val finalOutputLength = output.length + val finalInputSchema = CometUtils.fromAttributes(child.output) singlePartitionRDD.mapPartitionsInternal { iter => + val stream = CometArrowStream.fromColumnarBatchIter( + iter, + finalInputSchema, + CometArrowStream.NATIVE_TIMEZONE, + "CometTakeOrderedAndProject-final") val it = CometExec.getCometIterator( - Seq(iter), + Array(stream.asInstanceOf[Object]), finalOutputLength, serializedTopKAndProjection, 1, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala new file mode 100644 index 0000000000..eaacc3968b --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import java.util.{ArrayList => JArrayList} + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.CometVector + +/** + * `ArrowReader` over an iterator of Arrow-backed `ColumnarBatch`es. Each `loadNextBatch` unloads + * the source's `FieldVector`s into a transient `ArrowRecordBatch` (retains buffers), loads it + * into this reader's stable VSR via `loadFieldBuffers` (release-and-replace), then closes the + * source batch. The unload/load step decouples this reader's VSR ownership from whatever the + * upstream does with its own buffers. 
+ */ +private[comet] class ColumnarBatchArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch]) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!source.hasNext) { + return false + } + + val src = source.next() + try { + val sourceVectors = new JArrayList[FieldVector](src.numCols()) + var i = 0 + while (i < src.numCols()) { + sourceVectors.add( + src.column(i).asInstanceOf[CometVector].getValueVector.asInstanceOf[FieldVector]) + i += 1 + } + val transient = new VectorSchemaRoot(sourceVectors) + transient.setRowCount(src.numRows()) + + val unloader = new VectorUnloader(transient) + val rb = unloader.getRecordBatch + try { + loadRecordBatch(rb) + } finally { + rb.close() + } + // Note: do not close `transient`. It shares FieldVectors with `src`; closing `src` below + // releases the producer-side refs. Closing `transient` would double-release. 
+ } finally { + src.close() + } + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 0e7d0d9a8d..32441029bb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -22,138 +22,41 @@ package org.apache.spark.sql.comet.execution.arrow import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} -import org.apache.comet.CometArrowAllocator import org.apache.comet.vector.NativeUtil +/** + * Convert Spark `InternalRow` / `ColumnarBatch` streams to a stream of independently-owned Arrow + * `ColumnarBatch`es. Each emitted batch owns a fresh `VectorSchemaRoot` with newly allocated + * buffers; the consumer is responsible for closing the batch. + * + * Buffers are allocated from the caller-provided `BufferAllocator`. The caller owns the + * allocator's lifecycle (typically a child allocator closed at task completion). When emitted + * batches reach `ColumnarBatchArrowReader.loadNextBatch`, ownership of their buffers is + * transferred (via `VectorUnloader` / `loadFieldBuffers`) to the reader's allocator, after which + * the source batch is closed and the producer's allocator returns to zero outstanding bytes. 
+ */ object CometArrowConverters extends Logging { - // This is similar how Spark converts internal row to Arrow format except that it is transforming - // the result batch to Comet's ColumnarBatch instead of serialized bytes. - // There's another big difference that Comet may consume the ColumnarBatch by exporting it to - // the native side. Hence, we need to: - // 1. reset the Arrow writer after the ColumnarBatch is consumed - // 2. close the allocator when the task is finished but not when the iterator is all consumed - // The reason for the second point is that when ColumnarBatch is exported to the native side, the - // exported process increases the reference count of the Arrow vectors. The reference count is - // only decreased when the native plan is done with the vectors, which is usually longer than - // all the ColumnarBatches are consumed. - - abstract private[sql] class ArrowBatchIterBase( - schema: StructType, - timeZoneId: String, - context: TaskContext) - extends Iterator[ColumnarBatch] - with AutoCloseable { - - protected val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) - // Reuse the same root allocator here. 
- protected val allocator: BufferAllocator = - CometArrowAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) - protected val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator) - protected val arrowWriter: ArrowWriter = ArrowWriter.create(root) - - protected var currentBatch: ColumnarBatch = null - protected var closed: Boolean = false - - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - close(true) - } - } - - override def close(): Unit = { - close(false) - } - - protected def close(closeAllocator: Boolean): Unit = { - try { - if (!closed) { - if (currentBatch != null) { - arrowWriter.reset() - currentBatch.close() - currentBatch = null - } - root.close() - closed = true - } - } finally { - // the allocator shall be closed when the task is finished - if (closeAllocator) { - allocator.close() - } - } - } - - override def next(): ColumnarBatch = { - currentBatch = nextBatch() - currentBatch - } - - protected def nextBatch(): ColumnarBatch - - } - - private[sql] class RowToArrowBatchIter( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { - - override def hasNext: Boolean = rowIter.hasNext || { - close(false) - false - } - - override protected def nextBatch(): ColumnarBatch = { - if (rowIter.hasNext) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() - var rowCount = 0L - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) - rowCount += 1 - } - arrowWriter.finish() - NativeUtil.rootAsBatch(root) - } else { - null - } - } - } - - def rowToArrowBatchIter( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - 
new RowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) - } /** - * Use when the downstream consumer may retain batches across `next()` calls (e.g. non-Comet JVM - * columnar sinks like Iceberg writers). Each batch owns independent Arrow buffers. + * Convert an iterator of Spark `InternalRow`s into an iterator of Arrow `ColumnarBatch`es. + * + * Each call to `next()` allocates a fresh `VectorSchemaRoot`, writes up to `maxRecordsPerBatch` + * rows into it, and emits a `ColumnarBatch` wrapping that root. The consumer must close every + * emitted batch. */ - def rowToArrowBatchIterNoReuse( + def rowToArrowBatchIter( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - val arrowSchema = Utils.toArrowSchema(schema, timeZoneId) - val allocator = - CometArrowAllocator.newChildAllocator("rowToArrowBatchIterNoReuse", 0, Long.MaxValue) - Option(context).foreach(_.addTaskCompletionListener[Unit](_ => allocator.close())) + allocator: BufferAllocator): Iterator[ColumnarBatch] = { + val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) new Iterator[ColumnarBatch] { override def hasNext: Boolean = rowIter.hasNext @@ -173,57 +76,46 @@ object CometArrowConverters extends Logging { } } - private[sql] class ColumnBatchToArrowBatchIter( + /** + * Slice a single Spark `ColumnarBatch` into one or more Arrow `ColumnarBatch`es of at most + * `maxRecordsPerBatch` rows each. Each emitted batch owns a fresh `VectorSchemaRoot`. 
+ */ + def columnarBatchToArrowBatchIter( colBatch: ColumnarBatch, schema: StructType, maxRecordsPerBatch: Int, timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { + allocator: BufferAllocator): Iterator[ColumnarBatch] = { + val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) + val totalRows = colBatch.numRows() - private var rowsProduced: Int = 0 + new Iterator[ColumnarBatch] { + private var rowsProduced: Int = 0 - override def hasNext: Boolean = rowsProduced < colBatch.numRows() || { - close(false) - false - } + override def hasNext: Boolean = rowsProduced < totalRows - override protected def nextBatch(): ColumnarBatch = { - val rowsInBatch = colBatch.numRows() - if (rowsProduced < rowsInBatch) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() + override def next(): ColumnarBatch = { val rowsToProduce = - if (maxRecordsPerBatch <= 0) rowsInBatch - rowsProduced - else Math.min(maxRecordsPerBatch, rowsInBatch - rowsProduced) + if (maxRecordsPerBatch <= 0) totalRows - rowsProduced + else math.min(maxRecordsPerBatch, totalRows - rowsProduced) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = ArrowWriter.create(root) for (columnIndex <- 0 until colBatch.numCols()) { val column = colBatch.column(columnIndex) val columnArray = new ColumnarArray(column, rowsProduced, rowsToProduce) if (column.hasNull) { - arrowWriter.writeCol(columnArray, columnIndex) + writer.writeCol(columnArray, columnIndex) } else { - arrowWriter.writeColNoNull(columnArray, columnIndex) + writer.writeColNoNull(columnArray, columnIndex) } } rowsProduced += rowsToProduce - - arrowWriter.finish() + writer.finish() NativeUtil.rootAsBatch(root) - } else { - null } } } - - def columnarBatchToArrowBatchIter( - colBatch: ColumnarBatch, - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = 
{ - new ColumnBatchToArrowBatchIter(colBatch, schema, maxRecordsPerBatch, timeZoneId, context) - } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala new file mode 100644 index 0000000000..14a7a9ed0c --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.c.{ArrowArrayStream, Data} +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.spark.TaskContext +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometArrowAllocator +import org.apache.comet.vector.NativeUtil + +/** + * Marker for Comet operators that can produce Arrow data destined for a Comet native executor + * directly as the C Stream Interface, skipping the intermediate `RDD[ColumnarBatch]` layer. + */ +trait CometNativeArrowSource extends SparkPlan { + def doExecuteAsArrowStream(): RDD[ArrowArrayStream] +} + +object CometArrowStream { + + /** + * Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone; + * Spark's internal timestamp representation is always UTC microseconds anyway, and a non-UTC + * timezone here would only show up as schema metadata that breaks Arrow RowConverter + * validation. See COMET-2720. + */ + val NATIVE_TIMEZONE: String = "UTC" + + /** + * Wrap an `RDD[ColumnarBatch]` whose batches are Arrow-backed into an `RDD[ArrowArrayStream]`. + */ + def wrapColumnarBatchRDD( + rdd: RDD[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): RDD[ArrowArrayStream] = { + // Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema + // inside the per-task body so the closure cleaner doesn't try to ship a Schema across. 
+ rdd.mapPartitionsInternal { batchIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, batchIter)) + } + } + + /** + * Wrap a single per-partition `Iterator[ColumnarBatch]` (Arrow-backed) and return the exported + * `ArrowArrayStream`. For callers outside `CometExecRDD` that hand a JNI input slot directly to + * a `CometExecIterator`. + */ + def fromColumnarBatchIter( + iter: Iterator[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): ArrowArrayStream = { + val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)).next() + } + + /** + * Allocate a child allocator, build a reader, export it as an `ArrowArrayStream`, and register + * task-completion cleanup. Returns a single-element iterator so this composes with + * `RDD.mapPartitionsInternal`. + * + * Close ordering: when native drops its `ArrowArrayStreamReader`, the C release callback fires + * synchronously into `ExportedArrayStreamPrivateData.close` -> `reader.close` -> the VSR's + * buffers are released. The task-completion listener registered here runs strictly after that + * (Spark fires listeners in reverse registration order, and the listener that drops the native + * plan is registered later by `CometExecIterator`), so `allocator.close` finds zero outstanding + * bytes. 
+ */ + def stream( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ArrowArrayStream] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + var reader: ArrowReader = null + var arrowStream: ArrowArrayStream = null + try { + reader = readerFactory(allocator) + arrowStream = ArrowArrayStream.allocateNew(allocator) + Data.exportArrayStream(allocator, reader, arrowStream) + } catch { + case t: Throwable => + // Roll back partial setup before rethrowing -- nothing has been registered with + // TaskContext yet, so without this the allocator (and possibly the reader/stream) leaks. + if (arrowStream != null) { + try arrowStream.close() + catch { case _: Throwable => () } + } + if (reader != null) { + try reader.close() + catch { case _: Throwable => () } + } + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + val streamRef = arrowStream + context.addTaskCompletionListener[Unit] { _ => + streamRef.close() + allocator.close() + } + } + Iterator.single(arrowStream) + } + + /** + * Drive an `ArrowReader` from a per-task body and emit `ColumnarBatch`es wrapping the reader's + * stable VSR. Lifecycle: the supplied factory builds the reader against a fresh child + * allocator; both close at task completion. This is the non-native consumer path + * (`doExecuteColumnar`) -- the native consumer path uses [[stream]] to export instead. 
+ */ + def readerBatchIter( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ColumnarBatch] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + val reader = + try readerFactory(allocator) + catch { + case t: Throwable => + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + context.addTaskCompletionListener[Unit] { _ => + reader.close() + allocator.close() + } + } + new Iterator[ColumnarBatch] { + // Lazily prefetch one batch so `hasNext` can answer without consuming. + private var loaded: Boolean = false + private var hasMore: Boolean = false + + private def ensureLoaded(): Unit = { + if (!loaded) { + hasMore = reader.loadNextBatch() + loaded = true + } + } + + override def hasNext: Boolean = { + ensureLoaded() + hasMore + } + + override def next(): ColumnarBatch = { + ensureLoaded() + if (!hasMore) { + throw new NoSuchElementException("No more batches") + } + loaded = false + NativeUtil.rootAsBatch(reader.getVectorSchemaRoot) + } + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala new file mode 100644 index 0000000000..e1829eb5c5 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.catalyst.InternalRow + +/** + * `ArrowReader` over an iterator of Spark `InternalRow`s, writing up to `maxRecordsPerBatch` rows + * per call into the reader's stable VSR via `ArrowWriter`. + * + * `ArrowWriter.create(root)` calls `vector.allocateNew()`, which releases any prior buffers and + * allocates fresh ones. This is required for FFI safety: previously-exported batches retain their + * buffers via the C release callback, so reusing those buffers in place would corrupt native + * consumers still holding the prior batch. 
+ */ +private[comet] class RowArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + rowIter: Iterator[InternalRow], + maxRecordsPerBatch: Long, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!rowIter.hasNext) { + return false + } + + val startNs = System.nanoTime() + val writer = ArrowWriter.create(getVectorSchemaRoot) + var rowCount = 0L + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) + rowCount += 1 + } + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala new file mode 100644 index 0000000000..0af940fab3 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} + +/** + * `ArrowReader` over an iterator of Spark-side `ColumnarBatch`es (not Arrow-backed). Slices up to + * `maxRecordsPerBatch` rows per `loadNextBatch` from the current Spark batch into the reader's + * stable VSR via `ArrowWriter.writeCol`. Spark's `ColumnVector` implementations aren't Arrow + * buffers, so this reader necessarily copies element values into Arrow format. + */ +private[comet] class SparkColumnarArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch], + maxRecordsPerBatch: Int, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + private var current: ColumnarBatch = _ + private var rowsConsumedInCurrent: Int = 0 + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + private def advanceToNonEmptyBatch(): Boolean = { + while (current == null || rowsConsumedInCurrent >= current.numRows()) { + if (current != null) { + // We don't own Spark ColumnarBatches; just drop the reference. 
+ current = null + rowsConsumedInCurrent = 0 + } + if (!source.hasNext) { + return false + } + current = source.next() + rowsConsumedInCurrent = 0 + } + true + } + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!advanceToNonEmptyBatch()) { + return false + } + + val startNs = System.nanoTime() + val rowsRemaining = current.numRows() - rowsConsumedInCurrent + val rowsToProduce = + if (maxRecordsPerBatch <= 0) rowsRemaining + else math.min(maxRecordsPerBatch, rowsRemaining) + + val writer = ArrowWriter.create(getVectorSchemaRoot) + var col = 0 + while (col < current.numCols()) { + val column = current.column(col) + val columnArray = new ColumnarArray(column, rowsConsumedInCurrent, rowsToProduce) + if (column.hasNull) { + writer.writeCol(columnArray, col) + } else { + writer.writeColNoNull(columnArray, col) + } + col += 1 + } + rowsConsumedInCurrent += rowsToProduce + + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..8a9fd9019c 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.metric.SQLMetric import 
org.apache.spark.sql.vectorized.ColumnarBatch @@ -96,8 +98,14 @@ class CometNativeShuffleWriter[K, V]( // Getting rid of the fake partitionId val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + val arrowStream = CometArrowStream.fromColumnarBatchIter( + newInputs.asInstanceOf[Iterator[ColumnarBatch]], + CometUtils.fromAttributes(outputAttributes), + CometArrowStream.NATIVE_TIMEZONE, + "CometNativeShuffleWriter") + val cometIter = CometExec.getCometIterator( - Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + Array(arrowStream.asInstanceOf[Object]), outputAttributes.length, nativePlan, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 7d5398ae62..aa4ffad19f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ @@ -311,13 +312,13 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, numParts: Int, partitionIdx: Int): CometExecIterator = { getCometIterator( - inputs, + inputObjects, numOutputCols, nativePlan, CometMetricNode(Map.empty), @@ -332,14 +333,14 @@ object CometExec { * executing the same plan across multiple partitions to avoid serializing the plan repeatedly. 
*/ def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, serializedPlan: Array[Byte], numParts: Int, partitionIdx: Int): CometExecIterator = { new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, serializedPlan, CometMetricNode(Map.empty), @@ -350,7 +351,7 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, nativeMetrics: CometMetricNode, @@ -361,7 +362,7 @@ object CometExec { val bytes = serializeNativePlan(nativePlan) new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, bytes, nativeMetrics, @@ -473,10 +474,11 @@ abstract class CometNativeExec extends CometExec { // Find planning data within this stage (stops at shuffle boundaries). val (commonByKey, perPartitionByKey) = findAllPlanData(this) - // Collect the input ColumnarBatches from the child operators and create a CometExecIterator - // to execute the native plan. + // Collect the input batches from the child operators. Non-shuffle inputs become + // RDD[ArrowArrayStream] (one stream per partition, exported via the C Stream Interface + // for native consumption); shuffle inputs stay as CometShuffledBatchRDD. 
val sparkPlans = ArrayBuffer.empty[SparkPlan] - val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]] + val inputs = ArrayBuffer.empty[RDD[_]] foreachUntilCometInput(this)(sparkPlans += _) @@ -503,15 +505,85 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") } + def isShuffleScanInput(plan: SparkPlan): Boolean = plan match { + case _: CometShuffleExchangeExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec => + true + case ReusedExchangeExec(_, _: CometShuffleExchangeExec) => true + case _ => false + } + + // The protobuf is the source of truth for whether a slot is a ShuffleScan or a regular + // Scan: `CometExchangeSink.shouldUseShuffleScan` only fires for AQE wrappers + // (`ShuffleQueryStageExec`), so a bare non-AQE `CometShuffleExchangeExec` always serializes + // as a regular Scan regardless of `COMET_SHUFFLE_DIRECT_READ_ENABLED`. Driving the JVM + // dispatch from `shuffleScanIndices` instead of the conf keeps the two aligned. + val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) + + def isBroadcastInput(plan: SparkPlan): Boolean = plan match { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true + case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true + case _ => false + } + + // Unwrap any number of AQE / reuse wrappers to find a CometBroadcastExchangeExec, if + // present. Returns the unwrapped exchange for input wiring -- broadcast partition counts + // are coerced to match firstNonBroadcastPlanNumPartitions, so we always read from the + // underlying exchange directly. 
+ def asBroadcastExchange(plan: SparkPlan): Option[CometBroadcastExchangeExec] = + plan match { + case c: CometBroadcastExchangeExec => Some(c) + case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => Some(c) + case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => Some(c) + case BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, c: CometBroadcastExchangeExec), + _) => + Some(c) + case _ => None + } + + def asArrowStreamRDD(plan: SparkPlan, partitionCount: Int, scanSlot: Int): RDD[_] = + plan match { + case s: CometNativeArrowSource => + s.doExecuteAsArrowStream() + case _ if asBroadcastExchange(plan).isDefined => + val c = asBroadcastExchange(plan).get + CometArrowStream.wrapColumnarBatchRDD( + c.executeColumnar(partitionCount), + c.schema, + CometArrowStream.NATIVE_TIMEZONE, + c.nodeName) + case _ if isShuffleScanInput(plan) && shuffleScanIndices.contains(scanSlot) => + // Direct-read shuffle: `CometShuffledBatchRDD` reaches native via + // CometShuffleBlockIterator. Other shuffle slots fall through and get wrapped. + plan.executeColumnar() + case _ => + CometArrowStream.wrapColumnarBatchRDD( + plan.executeColumnar(), + plan.schema, + CometArrowStream.NATIVE_TIMEZONE, + plan.nodeName) + } + // If the first non broadcast plan is found, we need to adjust the partition number of // the broadcast plans to make sure they have the same partition number as the first non // broadcast plan. + // Walk-order: count how many non-CometNativeExec plans come before the firstNonBroadcast + // plan in `sparkPlans`. That's the slot index it will occupy in `inputs`, and therefore + // the protobuf scan-slot index whose Scan vs ShuffleScan classification governs whether + // it should be wrapped or direct-read. 
+ val firstNonBroadcastSlot = sparkPlans + .take(firstNonBroadcastPlan.get._2) + .count(p => !p.isInstanceOf[CometNativeExec]) + val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = firstNonBroadcastPlan.get._1 match { case plan: CometNativeExec => (null, plan.outputPartitioning.numPartitions) case plan => - val rdd = plan.executeColumnar() + val rdd = asArrowStreamRDD(plan, 0, firstNonBroadcastSlot) (rdd, rdd.getNumPartitions) } @@ -520,24 +592,17 @@ abstract class CometNativeExec extends CometExec { // partition number of Broadcast RDDs to make sure they have the same partition number. sparkPlans.zipWithIndex.foreach { case (plan, idx) => plan match { - case c: CometBroadcastExchangeExec => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec( - _, - ReusedExchangeExec(_, c: CometBroadcastExchangeExec), - _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) case _: CometNativeExec => // no-op case _ if idx == firstNonBroadcastPlan.get._2 => inputs += firstNonBroadcastPlanRDD case _ => - val rdd = plan.executeColumnar() - if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { + // Each plan we add to `inputs` corresponds to the next protobuf scan slot, in + // walk order. `inputs.size` is the slot index this plan will occupy. 
+ val scanSlot = inputs.size + val rdd = asArrowStreamRDD(plan, firstNonBroadcastPlanNumPartitions, scanSlot) + if (!isBroadcastInput(plan) && + rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { throw new CometRuntimeException( s"Partition number mismatch: ${rdd.getNumPartitions} != " + s"$firstNonBroadcastPlanNumPartitions") @@ -551,9 +616,6 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") } - // Detect ShuffleScan indices for direct read in CometExecRDD - val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) - // Unified RDD creation - CometExecRDD handles all cases val subqueries = collectSubqueries(this) val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 4605e641f1..a645c7b17a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -205,6 +205,16 @@ object Utils extends CometTypeShim with Logging { }.asJava) } + /** + * Build a `StructType` from a sequence of Spark `Attribute`s. Avoids + * `StructType.fromAttributes` (removed in Spark 4) and `DataTypeUtils.fromAttributes` (only on + * 4) so the same call works across supported Spark versions. + */ + def fromAttributes( + attributes: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]): StructType = + StructType(attributes.map(a => + org.apache.spark.sql.types.StructField(a.name, a.dataType, a.nullable, a.metadata))) + /** * Serializes a list of `ColumnarBatch` into an output stream. This method must be in `spark` * package because `ChunkedByteBufferOutputStream` is spark private class. 
As it uses Arrow diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala index 9c34b3a3ce..e30a1cf6b3 100644 --- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.PrettyAttribute import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec, CometExecUtils} -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch class CometNativeSuite extends CometTestBase { @@ -31,15 +32,16 @@ class CometNativeSuite extends CometTestBase { val rdd = spark.range(0, 1).rdd.map { value => val limitOp = CometExecUtils.getLimitNativePlan(Seq(PrettyAttribute("test", LongType)), 100).get - val cometIter = CometExec.getCometIterator( - Seq(new Iterator[ColumnarBatch] { + val arrowStream = CometArrowStream.fromColumnarBatchIter( + new Iterator[ColumnarBatch] { override def hasNext: Boolean = true override def next(): ColumnarBatch = throw new NullPointerException() - }), - 1, - limitOp, - 1, - 0) + }, + StructType(Seq(StructField("test", LongType, nullable = false))), + CometArrowStream.NATIVE_TIMEZONE, + "test-npe") + val cometIter = + CometExec.getCometIterator(Array(arrowStream.asInstanceOf[Object]), 1, limitOp, 1, 0) try { cometIter.next() } finally { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 8c8c19bb9c..d6075dac55 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3956,6 +3956,29 
@@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec does not leak Arrow buffers (project consumer)") { + // Forces a CometNativeExec consumer over an ArrowArrayStream input. The producer must not + // leak the Arrow buffers it allocates per batch; if it does, the BaseAllocator + // leak detector fires inside the task completion listener. + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select($"a" + 1)) + } + } + + test("CometLocalTableScanExec does not leak Arrow buffers (collect_list)") { + // Mirrors DataFrameAggregateSuite "collect functions" which is the test that + // surfaced the leak in CI. + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select(collect_list($"a"), collect_list($"b"))) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala index b858fe5c83..a2aac7e6c7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala @@ -492,9 +492,13 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan InternalRow(i, UTF8String.fromString(s"value_$i")) } - // Create batches using rowToArrowBatchIter which handles shading internally + // Each emitted batch needs independent Arrow buffers so the test can hold rows from + // earlier batches while later batches are consumed. CometArrowConverters allocates a + // fresh VSR per batch from the supplied allocator. 
+ val allocator = + org.apache.comet.CometArrowAllocator.newChildAllocator("c2r-test", 0, Long.MaxValue) val batchIter = CometArrowConverters - .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", null) + .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", allocator) val converter = new NativeColumnarToRowConverter(schema, rowsPerBatch) try { @@ -529,6 +533,7 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan "reused UnsafeRow object.") } finally { converter.close() + allocator.close() } } From b6db9969a5e732e309bab0c17a595d534a32dd25 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 14:02:44 -0400 Subject: [PATCH 09/39] Unpack dictionaries. --- .../arrow/ColumnarBatchArrowReader.scala | 34 ++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala index eaacc3968b..2cb8746107 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -23,11 +23,12 @@ import java.util.{ArrayList => JArrayList} import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.dictionary.DictionaryEncoder import org.apache.arrow.vector.ipc.ArrowReader import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.vector.CometVector +import org.apache.comet.vector.{CometDictionaryVector, CometVector} /** * `ArrowReader` over an iterator of Arrow-backed `ColumnarBatch`es. 
Each `loadNextBatch` unloads @@ -56,12 +57,29 @@ private[comet] class ColumnarBatchArrowReader( } val src = source.next() + var materialized: JArrayList[FieldVector] = null try { val sourceVectors = new JArrayList[FieldVector](src.numCols()) var i = 0 while (i < src.numCols()) { - sourceVectors.add( - src.column(i).asInstanceOf[CometVector].getValueVector.asInstanceOf[FieldVector]) + val col = src.column(i).asInstanceOf[CometVector] + val fv = col match { + case d: CometDictionaryVector => + // Stable VSR was built from the logical (non-dict) schema, so a dict-encoded + // source's indices layout would mismatch the dest buffer count on load. Native + // unpacks downstream anyway via copy_or_unpack_array. + val indices = d.getValueVector + val dictionary = d.provider.lookup(indices.getField.getDictionary.getId) + val plain = DictionaryEncoder + .decode(indices, dictionary, allocator) + .asInstanceOf[FieldVector] + if (materialized == null) materialized = new JArrayList[FieldVector]() + materialized.add(plain) + plain + case _ => + col.getValueVector.asInstanceOf[FieldVector] + } + sourceVectors.add(fv) i += 1 } val transient = new VectorSchemaRoot(sourceVectors) @@ -74,9 +92,17 @@ private[comet] class ColumnarBatchArrowReader( } finally { rb.close() } - // Note: do not close `transient`. It shares FieldVectors with `src`; closing `src` below + // Do not close `transient`. It shares FieldVectors with `src`; closing `src` below // releases the producer-side refs. Closing `transient` would double-release. } finally { + if (materialized != null) { + var j = 0 + while (j < materialized.size()) { + try materialized.get(j).close() + catch { case _: Throwable => () } + j += 1 + } + } src.close() } true From cf7bb6ed003f5e2a8a617a9aab4ec45e249b5dc6 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 14:45:22 -0400 Subject: [PATCH 10/39] Fix shading issue. 
--- spark/pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/spark/pom.xml b/spark/pom.xml index 6d97ea831f..25c6c34e45 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -476,6 +476,7 @@ under the License. org/apache/arrow/c/jni/PrivateData org/apache/arrow/c/jni/CDataJniException + org/apache/arrow/c/ArrayStreamExporter org/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData From 82c9a1b357da4cc91d65b66f3707b56db104ee74 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 15:25:45 -0400 Subject: [PATCH 11/39] Try again to fix shading issue. --- spark/pom.xml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/spark/pom.xml b/spark/pom.xml index 25c6c34e45..ef1b07abf1 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -469,15 +469,13 @@ under the License. org.apache.arrow ${comet.shade.packageName}.arrow - - org/apache/arrow/c/jni/JniWrapper - org/apache/arrow/c/jni/PrivateData - org/apache/arrow/c/jni/CDataJniException - - org/apache/arrow/c/ArrayStreamExporter - org/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData + + org/apache/arrow/c/** From 6adf124c51dd1a7b7c1d2da82ee97489a4cfb32a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Wed, 27 May 2026 16:36:44 -0400 Subject: [PATCH 12/39] Fix alignment issue for FFI Decimal128 with ArrowArrayStreamReader --- .../operators/aligned_stream_reader.rs | 110 ++++++++++++++++++ native/core/src/execution/operators/mod.rs | 2 + native/core/src/execution/operators/scan.rs | 11 +- native/core/src/execution/planner.rs | 5 +- 4 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 native/core/src/execution/operators/aligned_stream_reader.rs diff --git a/native/core/src/execution/operators/aligned_stream_reader.rs b/native/core/src/execution/operators/aligned_stream_reader.rs new file mode 100644 index 0000000000..c1d615a79f --- /dev/null +++ b/native/core/src/execution/operators/aligned_stream_reader.rs @@ -0,0 +1,110 @@ +// 
Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{RecordBatch, RecordBatchOptions, StructArray}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::error::ArrowError; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use std::ffi::CStr; +use std::sync::Arc; + +/// C Stream Interface reader that calls [`arrow::array::ArrayData::align_buffers`] on every +/// imported batch before constructing typed arrays. Stock `ArrowArrayStreamReader` panics +/// when a JVM producer hands us a `Decimal128` buffer at an offset that is 8-byte but not +/// 16-byte aligned, which Java's allocator does not guarantee. Track upstream: +/// . +#[derive(Debug)] +pub struct AlignedArrowStreamReader { + stream: FFI_ArrowArrayStream, + schema: SchemaRef, +} + +impl AlignedArrowStreamReader { + /// # Safety + /// `raw` must point at a valid `FFI_ArrowArrayStream` whose ownership is being transferred + /// to this reader. The stream's release callback fires when the reader is dropped. 
+ pub unsafe fn from_raw(raw: *mut FFI_ArrowArrayStream) -> Result { + let mut stream = FFI_ArrowArrayStream::from_raw(raw); + if stream.release.is_none() { + return Err(ArrowError::CDataInterface( + "input stream is already released".to_string(), + )); + } + let schema = read_schema(&mut stream)?; + Ok(Self { stream, schema }) + } + + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn last_error(&mut self) -> Option { + let get = self.stream.get_last_error?; + let ptr = unsafe { get(&mut self.stream) }; + if ptr.is_null() { + return None; + } + Some( + unsafe { CStr::from_ptr(ptr) } + .to_string_lossy() + .into_owned(), + ) + } +} + +impl Iterator for AlignedArrowStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + let mut array = FFI_ArrowArray::empty(); + let ret = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) }; + if ret != 0 { + let msg = self + .last_error() + .unwrap_or_else(|| format!("get_next returned {ret}")); + return Some(Err(ArrowError::CDataInterface(msg))); + } + if array.is_released() { + return None; + } + + let dt = DataType::Struct(self.schema.fields().clone()); + Some( + unsafe { from_ffi_and_data_type(array, dt) }.and_then(|mut data| { + data.align_buffers(); + let len = data.len(); + RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + StructArray::from(data).into_parts().1, + &RecordBatchOptions::new().with_row_count(Some(len)), + ) + }), + ) + } +} + +fn read_schema(stream: &mut FFI_ArrowArrayStream) -> Result { + let mut schema = FFI_ArrowSchema::empty(); + let ret = unsafe { stream.get_schema.unwrap()(stream, &mut schema) }; + if ret != 0 { + return Err(ArrowError::CDataInterface(format!( + "Cannot get schema from input stream. 
Error code: {ret}" + ))); + } + Ok(Arc::new(Schema::try_from(&schema)?)) +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 4b2c06575d..d68252bd9b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -19,10 +19,12 @@ pub use crate::errors::ExecutionError; +pub use aligned_stream_reader::*; pub use copy::*; pub use iceberg_scan::*; pub use scan::*; +mod aligned_stream_reader; mod copy; mod expand; pub use expand::ExpandExec; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index bad349e4d7..2ef32f6a13 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,12 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::operators::{copy_or_unpack_array, CopyMode}; +use crate::execution::operators::{copy_or_unpack_array, AlignedArrowStreamReader, CopyMode}; use crate::{errors::CometError, execution::planner::TEST_EXEC_CONTEXT_ID}; use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::ffi_stream::ArrowArrayStreamReader; use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ @@ -50,7 +49,7 @@ pub struct ScanExec { pub exec_context_id: i64, /// The C Stream Interface reader. `None` only in unit tests that seed input via /// `set_input_batch`. 
- pub input_source: Option>>, + pub input_source: Option>>, pub input_source_description: String, pub data_types: Vec, pub schema: SchemaRef, @@ -65,7 +64,7 @@ pub struct ScanExec { impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>>, + input_source: Option>>, input_source_description: &str, data_types: Vec, ) -> Result { @@ -136,7 +135,7 @@ impl ScanExec { /// columns are unpacked because Comet's downstream operators do not handle them. fn pull_next( exec_context_id: i64, - reader: &Arc>, + reader: &Arc>, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { // Unit test path; input batches are seeded directly. @@ -145,7 +144,7 @@ impl ScanExec { let mut reader = reader .try_lock() - .map_err(|_| CometError::Internal("ArrowArrayStreamReader contended".to_string()))?; + .map_err(|_| CometError::Internal("AlignedArrowStreamReader contended".to_string()))?; let next = reader.next(); match next { diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 77b174ea8d..6213ce6b11 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -23,6 +23,7 @@ pub mod operator_registry; use crate::errors::CometError; use crate::execution::operators::init_csv_datasource_exec; +use crate::execution::operators::AlignedArrowStreamReader; use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::list_positions::ListPositionsExpr, @@ -1451,7 +1452,7 @@ impl PhysicalPlanner { // Consumes the first input source for the scan. The Java side passes an // `org.apache.arrow.c.ArrowArrayStream` whose `memoryAddress` points at the C - // struct; native takes ownership via `ArrowArrayStreamReader::from_raw`. + // struct; native takes ownership via `AlignedArrowStreamReader::from_raw`. 
let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { @@ -1466,7 +1467,7 @@ impl PhysicalPlanner { Ok(addr) })?; let reader = unsafe { - arrow::ffi_stream::ArrowArrayStreamReader::from_raw( + AlignedArrowStreamReader::from_raw( address as *mut arrow::ffi_stream::FFI_ArrowArrayStream, ) } From 0e080186363ea196b8b1fa654521e4961c29ed47 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 17:36:14 -0400 Subject: [PATCH 13/39] Fix schema mismatch in CometArrowStream. --- .../arrow/CometNativeArrowSource.scala | 78 +++++++++++- .../arrow/CometArrowStreamSuite.scala | 115 ++++++++++++++++++ 2 files changed, 187 insertions(+), 6 deletions(-) create mode 100644 spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala index 14a7a9ed0c..c6fa93cfc1 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.comet.execution.arrow +import scala.jdk.CollectionConverters._ + import org.apache.arrow.c.{ArrowArrayStream, Data} import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.{Field, Schema} import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.SparkPlan @@ -30,7 +34,7 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.CometArrowAllocator -import org.apache.comet.vector.NativeUtil +import 
org.apache.comet.vector.{CometDictionaryVector, CometVector, NativeUtil} /** * Marker for Comet operators that can produce Arrow data destined for a Comet native executor @@ -40,7 +44,7 @@ trait CometNativeArrowSource extends SparkPlan { def doExecuteAsArrowStream(): RDD[ArrowArrayStream] } -object CometArrowStream { +object CometArrowStream extends Logging { /** * Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone; @@ -61,8 +65,9 @@ object CometArrowStream { // Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema // inside the per-task body so the closure cleaner doesn't try to ship a Schema across. rdd.mapPartitionsInternal { batchIter => - val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) - stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, batchIter)) + val expected = Utils.toArrowSchema(sparkSchema, timeZoneId) + val (arrowSchema, iter) = reconcileStreamSchema(name, expected, batchIter) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)) } } @@ -76,8 +81,69 @@ object CometArrowStream { sparkSchema: StructType, timeZoneId: String, name: String): ArrowArrayStream = { - val arrowSchema = Utils.toArrowSchema(sparkSchema, timeZoneId) - stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)).next() + val expected = Utils.toArrowSchema(sparkSchema, timeZoneId) + val (arrowSchema, reconciled) = reconcileStreamSchema(name, expected, iter) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, reconciled)) + .next() + } + + /** + * Build the stream's advertised Arrow schema from the actual `CometVector` types in the first + * batch, not from `expected` (which derives from the consumer's Spark-declared types). 
Native + * operators like `ScanExec` already cast their input to the declared scan-input schema in + * `build_record_batch`, so the truthful schema lets that cast actually fire. Advertising + * `expected` instead silently mislabels Int32 buffers as Int64 (and similar) and corrupts on + * import. See PR #4393 width_bucket investigation. + * + * If the first batch's column types differ from `expected` in their `DataType` (timezone-only + * differences on `Timestamp` are ignored), log one warning naming the operator, column, and + * type drift; the cast happens transparently downstream in native. + */ + private[arrow] def reconcileStreamSchema( + name: String, + expected: Schema, + iter: Iterator[ColumnarBatch]): (Schema, Iterator[ColumnarBatch]) = { + val buffered = iter.buffered + if (!buffered.hasNext) { + // Empty partition: keep the consumer-declared schema; consumer can still build its plan. + return (expected, buffered) + } + val first = buffered.head + val expectedFields = expected.getFields + val actualFields = (0 until first.numCols()).map { i => + val col = first.column(i).asInstanceOf[CometVector] + actualFieldOf(col, expectedFields.get(i)) + } + val mismatches = actualFields.zip(expectedFields.asScala).zipWithIndex.collect { + case ((actual, exp), idx) if actual.getType != exp.getType => + s"col[$idx] '${exp.getName}': expected ${exp.getType}, child produced ${actual.getType}" + } + if (mismatches.nonEmpty) { + logWarning( + s"CometArrowStream '$name' input schema mismatch: ${mismatches.mkString("; ")}. " + + "Native ScanExec will cast at the boundary. This usually means a DataFusion-Spark " + + "function declares a different return type than Spark catalyst.") + } + (new Schema(actualFields.asJava), buffered) + } + + /** + * The Arrow field that this column's buffers will look like once unloaded. 
For a + * `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via + * `DictionaryEncoder.decode` before unloading, so the wire-level field is the dictionary's + * *value* type, not `Dictionary`. For everything else, use the underlying value + * vector's field. Field name / nullability / metadata come from `expected` so that consumers + * indexing by name keep working. + */ + private def actualFieldOf(col: CometVector, expected: Field): Field = { + val raw = col match { + case d: CometDictionaryVector => + val indices = d.getValueVector + val dict = d.provider.lookup(indices.getField.getDictionary.getId) + dict.getVector.getField + case _ => col.getValueVector.getField + } + new Field(expected.getName, raw.getFieldType, raw.getChildren) } /** diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala new file mode 100644 index 0000000000..030f00787c --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.jdk.CollectionConverters._ + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{BigIntVector, IntVector} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.{CometPlainVector, CometVector} + +/** + * Direct tests for [[CometArrowStream.reconcileStreamSchema]]. The end-to-end regression that + * motivated this (Spark Long vs DataFusion Int32 for `width_bucket`) lives in + * `CometMathExpressionSuite`, but that test only catches *one* function-level type drift. This + * suite covers the boundary contract independently of any specific function. + */ +class CometArrowStreamSuite extends AnyFunSuite with Matchers { + + private def expectedSchema(types: (String, ArrowType)*): Schema = { + val fields = types.map { case (name, t) => + new Field(name, new FieldType(true, t, null), java.util.Collections.emptyList[Field]()) + } + new Schema(fields.asJava) + } + + private def batchOf(vectors: CometVector*): ColumnarBatch = { + val numRows = if (vectors.isEmpty) 0 else vectors.head.getValueVector.getValueCount + new ColumnarBatch(vectors.toArray, numRows) + } + + test("reconcileStreamSchema returns expected schema unchanged on empty iterator") { + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + val (returned, iter) = + CometArrowStream.reconcileStreamSchema("test", expected, Iterator.empty) + returned shouldBe expected + iter.hasNext shouldBe false + } + + test("reconcileStreamSchema returns expected schema when types match") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + val v = new BigIntVector("col_0", allocator) + v.allocateNew() + v.setSafe(0, 1L) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = 
batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + + val (returned, iter) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + returned.getFields.get(0).getType shouldBe new ArrowType.Int(64, true) + iter.hasNext shouldBe true + iter.next() should be theSameInstanceAs batch + + cv.close() + } finally { + allocator.close() + } + } + + test("reconcileStreamSchema rebuilds schema from actual vector types when they differ") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + // Producer produced Int32 (e.g., DataFusion-Spark width_bucket pre-fix), consumer expects + // Int64 (Spark catalyst WidthBucket.dataType = LongType). The truthful schema is Int32 so + // native ScanExec's build_record_batch can cast at the boundary. + val v = new IntVector("col_0", allocator) + v.allocateNew() + v.setSafe(0, 1) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + + val (returned, iter) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + val returnedField = returned.getFields.get(0) + returnedField.getType shouldBe new ArrowType.Int(32, true) + // Names come from `expected` so name-indexed consumers keep working. + returnedField.getName shouldBe "c0" + iter.hasNext shouldBe true + iter.next() should be theSameInstanceAs batch + + cv.close() + } finally { + allocator.close() + } + } +} From a5046e3d0d623c7f3efd768f03a62001f8d07ff9 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 18:29:11 -0400 Subject: [PATCH 14/39] Fix nullability mismatch in CometArrowStreamSuite. 
--- .../arrow/CometNativeArrowSource.scala | 17 +++++++--- .../arrow/CometArrowStreamSuite.scala | 32 +++++++++++++++++++ 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala index c6fa93cfc1..92ffc61acd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -24,7 +24,7 @@ import scala.jdk.CollectionConverters._ import org.apache.arrow.c.{ArrowArrayStream, Data} import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.ipc.ArrowReader -import org.apache.arrow.vector.types.pojo.{Field, Schema} +import org.apache.arrow.vector.types.pojo.{Field, FieldType, Schema} import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD @@ -132,8 +132,14 @@ object CometArrowStream extends Logging { * `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via * `DictionaryEncoder.decode` before unloading, so the wire-level field is the dictionary's * *value* type, not `Dictionary`. For everything else, use the underlying value - * vector's field. Field name / nullability / metadata come from `expected` so that consumers - * indexing by name keep working. + * vector's field. + * + * Field name and metadata come from `expected` so that consumers indexing by name keep working. + * Nullability is the union of the two — a CometVector that happens to hold no nulls in this + * batch can still be nullable per Spark's contract (the next batch may have one), and a column + * whose actual buffer carries validity bits must stay nullable even if Spark thought otherwise. 
+ * Taking only `raw.isNullable` here would advertise non-nullable when the next batch does carry + * a null and crash native validation. */ private def actualFieldOf(col: CometVector, expected: Field): Field = { val raw = col match { @@ -143,7 +149,10 @@ object CometArrowStream extends Logging { dict.getVector.getField case _ => col.getValueVector.getField } - new Field(expected.getName, raw.getFieldType, raw.getChildren) + val nullable = expected.isNullable || raw.isNullable + val fieldType = + new FieldType(nullable, raw.getType, raw.getDictionary, expected.getMetadata) + new Field(expected.getName, fieldType, raw.getChildren) } /** diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala index 030f00787c..bf2e190a5b 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala @@ -112,4 +112,36 @@ class CometArrowStreamSuite extends AnyFunSuite with Matchers { allocator.close() } } + + test("reconcileStreamSchema preserves nullability when expected is nullable but actual is not") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + // Spark catalyst declares the column nullable; the first batch happens to come from a + // vector whose Field reports non-nullable. Subsequent batches may carry nulls, so the + // wire schema must stay nullable or native validation rejects the next null with + // "declared as non-nullable but contains null values". 
+ val v = new BigIntVector( + new Field( + "col_0", + new FieldType(false, new ArrowType.Int(64, true), null), + java.util.Collections.emptyList[Field]()), + allocator) + v.allocateNew() + v.setSafe(0, 1L) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) // nullable=true + + val (returned, _) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + val returnedField = returned.getFields.get(0) + returnedField.isNullable shouldBe true + + cv.close() + } finally { + allocator.close() + } + } } From 8f1c35af4ce23c347cff56d800c11d45dacccc4a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 18:30:33 -0400 Subject: [PATCH 15/39] Fix format. --- .../sql/comet/execution/arrow/CometArrowStreamSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala index bf2e190a5b..c423a49d2a 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala @@ -113,7 +113,8 @@ class CometArrowStreamSuite extends AnyFunSuite with Matchers { } } - test("reconcileStreamSchema preserves nullability when expected is nullable but actual is not") { + test( + "reconcileStreamSchema preserves nullability when expected is nullable but actual is not") { val allocator = new RootAllocator(Integer.MAX_VALUE) try { // Spark catalyst declares the column nullable; the first batch happens to come from a From 5c4121526a73be75fe3d02bb710bf1029ed1ce5a Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 20:21:47 -0400 Subject: [PATCH 16/39] Passes CometFuzzTestSuite, CometNativeShuffleSuite, 
CometExecSuite. --- .../shuffle/CometNativeShuffleWriter.scala | 341 +++++++++--------- .../shuffle/CometShuffleDependency.scala | 13 +- .../shuffle/CometShuffleExchangeExec.scala | 170 ++++++++- .../shuffle/CometShuffleManager.scala | 3 + .../apache/spark/sql/comet/operators.scala | 320 ++++++++-------- 5 files changed, 528 insertions(+), 319 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..68078e8507 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -32,19 +32,34 @@ import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsR import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometScalarSubquery, NativeExecContext, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExecIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} -import org.apache.comet.serde.QueryPlanSerde.serializeDataType /** - * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * A [[ShuffleWriter]] that drives the native shuffle write in a single [[CometExecIterator]] per + * partition. 
The unified plan it executes has [[OperatorOuterClass.ShuffleWriter]] at the root + * with `childNativeOp` as its only child. Leaf input iterators come from + * [[CometNativeShuffleInputIterator]] (constructed by [[CometNativeShuffleInputRDD.compute]]). + * + * Two flavors of `childNativeOp` are in use: + * - rich Comet native subtree (e.g. HashAgg / Filter / ShuffleScan), supplied by + * [[CometShuffleExchangeExec]] when its child is a + * [[org.apache.spark.sql.comet.CometNativeExec]]. + * - synthetic `Scan("ShuffleWriterInput")` placeholder, supplied by the convenience overload of + * [[CometShuffleExchangeExec.prepareShuffleDependency]] for callers that already hold an + * `RDD[ColumnarBatch]` of native-driven batches (e.g. + * [[org.apache.spark.sql.comet.CometCollectLimitExec]]). + * + * The writer treats both shapes identically. */ class CometNativeShuffleWriter[K, V]( + childNativeOp: Operator, + childMetricNode: CometMetricNode, + nativeContext: NativeExecContext, outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], metrics: Map[String, SQLMetric], @@ -72,8 +87,22 @@ class CometNativeShuffleWriter[K, V]( val tempDataFilePath = Paths.get(tempDataFilename) val tempIndexFilePath = Paths.get(tempIndexFilename) - // Call native shuffle write - val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + // Pull the per-partition leaf iterators and partition index from the iterator handed to us + // by Spark's ShuffleMapTask. CometNativeShuffleInputRDD.compute always returns this exact + // iterator type; no other RDD layers between produce a Product2[Int, ColumnarBatch]. 
+ val shuffleInputIter = inputs.asInstanceOf[CometNativeShuffleInputIterator] + val partitionIdx = shuffleInputIter.partitionIndex + val leafIterators = shuffleInputIter.leafIterators + val shuffleBlockIters = shuffleInputIter.shuffleBlockIterators + + val unifiedPlan = buildUnifiedPlan(tempDataFilename, tempIndexFilename) + val finalNativePlan = if (nativeContext.commonByKey.nonEmpty) { + val partitionDataByKey = + nativeContext.perPartitionByKey.map { case (k, arr) => k -> arr(partitionIdx) } + PlanDataInjector.injectPlanData(unifiedPlan, nativeContext.commonByKey, partitionDataByKey) + } else { + unifiedPlan + } val detailedMetrics = Seq( "elapsed_compute", @@ -82,29 +111,42 @@ class CometNativeShuffleWriter[K, V]( "input_batches", "spill_count", "spilled_bytes") - - // Maps native metrics to SQL metrics val metricsOutputRows = new SQLMetric("outputRows") val metricsWriteTime = new SQLMetric("writeTime") - val nativeSQLMetrics = Map( + val shuffleWriterSQLMetrics = Map( "output_rows" -> metricsOutputRows, "data_size" -> metrics("dataSize"), "write_time" -> metricsWriteTime) ++ metrics.filterKeys(detailedMetrics.contains) - val nativeMetrics = CometMetricNode(nativeSQLMetrics) - // Getting rid of the fake partitionId - val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + // ShuffleWriter metrics live at the root of the metric tree; the child operator's metric + // tree (rich subtree or empty leaf for the Scan placeholder) is attached underneath so the + // SQL UI sees the same per-node breakdown the split-driver flow produced. 
+ val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(childMetricNode)) - val cometIter = CometExec.getCometIterator( - Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + val cometIter = new CometExecIterator( + CometExec.newIterId, + leafIterators, outputAttributes.length, - nativePlan, + PlanDataInjector.serializeOperator(finalNativePlan), nativeMetrics, numParts, - context.partitionId(), - broadcastedHadoopConfForEncryption = None, - encryptedFilePaths = Seq.empty) + partitionIdx, + nativeContext.broadcastedHadoopConfForEncryption, + nativeContext.encryptedFilePaths, + shuffleBlockIters) + + // Register subqueries against the iterator id so native callbacks resolve them to values. + nativeContext.subqueries.foreach { sub => + CometScalarSubquery.setSubquery(cometIter.id, sub) + } + Option(context).foreach { taskCtx => + taskCtx.addTaskCompletionListener[Unit] { _ => + nativeContext.subqueries.foreach { sub => + CometScalarSubquery.removeSubquery(cometIter.id, sub) + } + } + } while (cometIter.hasNext) { cometIter.next() @@ -134,7 +176,7 @@ class CometNativeShuffleWriter[K, V]( // Report spill metrics to Spark's task metrics so they appear in // Spark UI task summaries (not just SQL metrics) - val spilledBytes = nativeSQLMetrics.get("spilled_bytes").map(_.value).getOrElse(0L) + val spilledBytes = shuffleWriterSQLMetrics.get("spilled_bytes").map(_.value).getOrElse(0L) if (spilledBytes > 0) { context.taskMetrics().incMemoryBytesSpilled(spilledBytes) context.taskMetrics().incDiskBytesSpilled(spilledBytes) @@ -162,163 +204,138 @@ class CometNativeShuffleWriter[K, V]( case _ => false } - private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) + /** + * Build the unified `ShuffleWriter(child = 
childNativeOp)` plan with the partitioning serde, + * compression settings, and output file paths. + */ + private def buildUnifiedPlan(dataFile: String, indexFile: String): Operator = { + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) + + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") + } + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) } - - if (scanTypes.length == outputAttributes.length) { - scanBuilder.addAllFields(scanTypes.asJava) - - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) - - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) + + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + 
case _: HashPartitioning => + val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + val partitionExprs = hashPartitioning.expressions + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) - } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) - - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's 
LexOrdering + // DataFusion will deduplicate identical sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) + } + } - val partitionExprs = hashPartitioning.expressions + { + val orderingExprs = rangePartitioning.ordering .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - - if (partitionExprs.length != hashPartitioning.expressions.length) { + if (orderingExprs.length != rangePartitioning.ordering.length) { throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } - - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) // Will be deduplicated by 
DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } - } - - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + s"Partitioning $rangePartitioning is not supported.") } + partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) + .map(_._1) + }) + + { 
+ val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - case _: RoundRobinPartitioning => - val partitioning = PartitioningOuterClass.RoundRobinPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - partitioning.setMaxHashColumns( - CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) + case _: RoundRobinPartitioning => + val partitioning = PartitioningOuterClass.RoundRobinPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + partitioning.setMaxHashColumns( + 
CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRoundRobinPartition(partitioning).build()) + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRoundRobinPartition(partitioning).build()) - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") + } - shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) + shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") - } + OperatorOuterClass.Operator + .newBuilder() + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(childNativeOp) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..9ec25c49ec 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -28,11 +28,19 @@ import org.apache.spark.shuffle.ShuffleWriteProcessor import org.apache.spark.sql.catalyst.InternalRow import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.{CometMetricNode, NativeExecContext} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. + * + * On the native-shuffle path, also carries the child plan's per-partition execution context, root + * native operator, and metric node so [[CometNativeShuffleWriter]] can drive the unified + * `ShuffleWriter(child = childNativeOp)` plan in a single [[org.apache.comet.CometExecIterator]] + * per partition. These three fields are populated only when `shuffleType == CometNativeShuffle`. */ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @transient private val _rdd: RDD[_ <: Product2[K, V]], @@ -49,7 +57,10 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + val nativeExecContext: Option[NativeExecContext] = None, + val childNativeOp: Option[OperatorOuterClass.Operator] = None, + val childMetricNode: Option[CometMetricNode] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 493c20f8b7..16bd40d402 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,7 
+34,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Exp import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder, NativeExecContext} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} @@ -101,9 +101,38 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Per-partition execution context for the child native subtree, computed once and shared + * between [[inputRDD]] (which uses it to wire DAGScheduler dependencies) and + * [[shuffleDependency]] (which threads it through to [[CometNativeShuffleWriter]] for + * single-iterator native execution). Only populated when `shuffleType == CometNativeShuffle` + * AND the child is a [[CometNativeExec]] subtree we can inline. When the child is a non-native + * Comet plan (e.g. [[org.apache.spark.sql.comet.CometSparkToColumnarExec]]), this stays `None` + * and the shuffle falls back to the legacy `Scan("ShuffleWriterInput") -> ShuffleWriter` plan + * via the convenience overload of `prepareShuffleDependency`. 
+ */ + @transient private lazy val nativeChildContext: Option[NativeExecContext] = child match { + case nativeChild: CometNativeExec if shuffleType == CometNativeShuffle => + Some(nativeChild.buildNativeContext()) + case _ => None + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. - child.executeColumnar() + nativeChildContext match { + case Some(ctx) => + // Single-driver path: thin scheduling anchor; CometNativeShuffleWriter drives the + // unified ShuffleWriter + child plan in a single CometExecIterator per partition. + new CometNativeShuffleInputRDD( + sparkContext, + ctx.inputs, + ctx.numPartitions, + ctx.shuffleScanIndices) + case None => + // Child is a Comet plan but not a CometNativeExec subtree (e.g. CometSparkToColumnarExec). + // No native subtree to inline; the writer's plan is `Scan("ShuffleWriterInput") -> + // ShuffleWriter` and JVM batches flow into native through Arrow C Stream Interface. + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec @@ -149,12 +178,36 @@ case class CometShuffleExchangeExec( @transient lazy val shuffleDependency: ShuffleDependency[Int, _, _] = if (shuffleType == CometNativeShuffle) { - val dep = CometShuffleExchangeExec.prepareShuffleDependency( - inputRDD.asInstanceOf[RDD[ColumnarBatch]], - child.output, - outputPartitioning, - serializer, - metrics) + val dep = nativeChildContext match { + case Some(ctx) => + // Single-driver path: child is a CometNativeExec subtree. RangePartitioning needs real + // rows to compute partition bounds; use a regular columnar execution of the child for + // sampling only. The actual shuffle still goes through the single-iterator path. 
+ val nativeChild = child.asInstanceOf[CometNativeExec] + val samplingRDD: Option[RDD[ColumnarBatch]] = outputPartitioning match { + case _: RangePartitioning => Some(child.executeColumnar()) + case _ => None + } + CometShuffleExchangeExec.prepareNativeShuffleDependency( + inputRDD.asInstanceOf[CometNativeShuffleInputRDD], + samplingRDD, + child.output, + outputPartitioning, + serializer, + metrics, + ctx, + nativeChild.nativeOp, + CometMetricNode.fromCometPlan(nativeChild)) + case None => + // Child is a non-native Comet plan; the writer falls back to its Scan-placeholder + // path via the convenience overload of prepareShuffleDependency. + CometShuffleExchangeExec.prepareShuffleDependency( + inputRDD.asInstanceOf[RDD[ColumnarBatch]], + child.output, + outputPartitioning, + serializer, + metrics) + } metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -623,21 +676,107 @@ object CometShuffleExchangeExec } } + /** + * Build a Comet native shuffle dependency around an existing `RDD[ColumnarBatch]` of real + * batches. Used by [[org.apache.spark.sql.comet.CometCollectLimitExec]] and + * [[org.apache.spark.sql.comet.CometTakeOrderedAndProjectExec]] where the input is the result + * of a local-limit / topK transform and there is no separate child native subtree to inline. + * + * Implemented as a thin wrapper around [[prepareNativeShuffleDependency]]: synthesizes a + * `Scan("ShuffleWriterInput")` as the child native op (so the writer's plan is still + * `ShuffleWriter -> Scan`, consuming JVM batches via Arrow C Stream), wraps `rdd` as the single + * leaf input of a thin scheduling RDD, and supplies a minimal [[NativeExecContext]]. Same wire + * shape as before; one writer code path for both this case and the [[CometShuffleExchangeExec]] + * case. 
+ */ def prepareShuffleDependency( rdd: RDD[ColumnarBatch], outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { - val numParts = rdd.getNumPartitions + + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatten { attr => + QueryPlanSerde.serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } + scanBuilder.addAllFields(scanTypes.asJava) + val scanOp = OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + + val thinRDD = new CometNativeShuffleInputRDD( + rdd.sparkContext, + Seq(rdd), + rdd.getNumPartitions, + shuffleScanIndices = Set.empty) + + val ctx = NativeExecContext( + inputs = Seq(rdd), + numPartitions = rdd.getNumPartitions, + subqueries = Seq.empty, + broadcastedHadoopConfForEncryption = None, + encryptedFilePaths = Seq.empty, + commonByKey = Map.empty, + perPartitionByKey = Map.empty, + shuffleScanIndices = Set.empty, + hasScanInput = false) + + // The Scan placeholder has no per-operator metrics, so the metric tree for the unified plan + // is `shuffleWriterMetrics` at the root with one empty leaf for the Scan child. + prepareNativeShuffleDependency( + thinRDD, + Some(rdd), + outputAttributes, + outputPartitioning, + serializer, + metrics, + ctx, + scanOp, + CometMetricNode(Map.empty)) + } + + /** + * Build a Comet native shuffle dependency for the [[CometShuffleExchangeExec]] case where the + * shuffle is fed by a [[CometNativeExec]] child. The writer drives the unified + * `ShuffleWriter(child = childNativeOp)` plan in a single + * [[org.apache.comet.CometExecIterator]] per partition. 
The returned dep carries the child's + * per-partition execution context, root native operator, and metric node so + * [[CometNativeShuffleWriter]] can reach them at task time. + * + * @param thinRDD + * scheduling-anchor RDD whose `compute` returns a [[CometNativeShuffleInputIterator]]; + * produces no batches itself. + * @param samplingRDD + * regular columnar execution of the child, only required for [[RangePartitioning]] (sampling + * needs real rows). `None` for hash / single / round-robin. + */ + def prepareNativeShuffleDependency( + thinRDD: CometNativeShuffleInputRDD, + samplingRDD: Option[RDD[ColumnarBatch]], + outputAttributes: Seq[Attribute], + outputPartitioning: Partitioning, + serializer: Serializer, + metrics: Map[String, SQLMetric], + nativeExecContext: NativeExecContext, + childNativeOp: OperatorOuterClass.Operator, + childMetricNode: CometMetricNode): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val numParts = thinRDD.getNumPartitions // The code block below is mostly brought over from // ShuffleExchangeExec::prepareShuffleDependency val (partitioner, rangePartitionBounds) = outputPartitioning match { case rangePartitioning: RangePartitioning => + // Sampling needs real rows; use the dedicated samplingRDD (a regular columnar execution + // of the child). The thin RDD itself yields nothing. 
+ val samplingInput = samplingRDD.getOrElse( + throw new IllegalStateException( + "RangePartitioning requires a samplingRDD on the native-shuffle path")) // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner - val rddForSampling = rdd.mapPartitionsInternal { iter => + val rddForSampling = samplingInput.mapPartitionsInternal { iter => val projection = UnsafeProjection.create(rangePartitioning.ordering.map(_.child), outputAttributes) val mutablePair = new MutablePair[InternalRow, Null]() @@ -683,9 +822,7 @@ object CometShuffleExchangeExec } val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( - rdd.map( - (0, _) - ), // adding fake partitionId that is always 0 because ShuffleDependency requires it + thinRDD, serializer = serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(metrics), shuffleType = CometNativeShuffle, @@ -695,7 +832,10 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) + rangePartitionBounds = rangePartitionBounds, + nativeExecContext = Some(nativeExecContext), + childNativeOp = Some(childNativeOp), + childMetricNode = Some(childMetricNode)) dependency } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index c8f2199d53..16a21bbfd9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -231,6 +231,9 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { case cometShuffleHandle: CometNativeShuffleHandle[K @unchecked, V @unchecked] => val dep = 
cometShuffleHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]] new CometNativeShuffleWriter( + dep.childNativeOp.get, + dep.childMetricNode.get, + dep.nativeExecContext.get, dep.outputPartitioning.get, dep.outputAttributes, dep.shuffleWriteMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 7d5398ae62..748362cd65 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -381,6 +381,29 @@ object CometExec { } } +/** + * Per-partition execution context for a native subtree rooted at a [[CometNativeExec]] boundary. + * + * Built once on the driver from the SparkPlan tree and consumed by either: + * - [[CometNativeExec.doExecuteColumnar]] to construct a [[CometExecRDD]], or + * - the native-shuffle path, where [[CometNativeShuffleWriter]] drives the same child plan with + * a `ShuffleWriter` operator as its root. + * + * The fields capture everything that depends on tree-walking the SparkPlan and aligning leaf + * input RDDs (broadcast partition counts, plan-data, subqueries, encryption options) so callers + * do not have to re-derive them. + */ +private[comet] case class NativeExecContext( + inputs: Seq[RDD[ColumnarBatch]], + numPartitions: Int, + subqueries: Seq[ScalarSubquery], + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]], + encryptedFilePaths: Seq[String], + commonByKey: Map[String, Array[Byte]], + perPartitionByKey: Map[String, Array[Array[Byte]]], + shuffleScanIndices: Set[Int], + hasScanInput: Boolean) + /** * A Comet native physical operator. 
*/ @@ -426,157 +449,30 @@ abstract class CometNativeExec extends CometExec { throw new CometRuntimeException( s"CometNativeExec should not be executed directly without a serialized plan: $this") case Some(serializedPlan) => - val serializedPlanCopy = serializedPlan // TODO: support native metrics for all operators. val nativeMetrics = CometMetricNode.fromCometPlan(this) + val ctx = buildNativeContext() - // Go over all the native scans, in order to see if they need encryption options. - // For each relation in a CometNativeScan generate a hadoopConf, - // for each file path in a relation associate with hadoopConf - // This is done per native plan, so only count scans until a comet input is reached. - val encryptionOptions = - mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] - foreachUntilCometInput(this) { - case scan: CometNativeScanExec => - // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and - // per-relation configs since different tables might have different decryption - // properties. - val hadoopConf = scan.relation.sparkSession.sessionState - .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) - if (encryptionEnabled) { - // hadoopConf isn't serializable, so we have to do a broadcasted config. 
- val broadcastedConf = - scan.relation.sparkSession.sparkContext - .broadcast(new SerializableConfiguration(hadoopConf)) - - val optsTuple: (Broadcast[SerializableConfiguration], Seq[String]) = - (broadcastedConf, scan.relation.inputFiles.toSeq) - encryptionOptions += optsTuple - } - case _ => // no-op - } - assert( - encryptionOptions.size <= 1, - "We expect one native scan that requires encryption reading in a Comet plan," + - " since we will broadcast one hadoopConf.") - // If this assumption changes in the future, you can look at the commit history of #2447 - // to see how there used to be a map of relations to broadcasted confs in case multiple - // relations in a single plan. The example that came up was UNION. See discussion at: - // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 - val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = - encryptionOptions.headOption match { - case Some((conf, paths)) => (Some(conf), paths) - case None => (None, Seq.empty) - } - - // Find planning data within this stage (stops at shuffle boundaries). - val (commonByKey, perPartitionByKey) = findAllPlanData(this) - - // Collect the input ColumnarBatches from the child operators and create a CometExecIterator - // to execute the native plan. 
- val sparkPlans = ArrayBuffer.empty[SparkPlan] - val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]] - - foreachUntilCometInput(this)(sparkPlans += _) - - // Find the first non broadcast plan - val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find { - case (_: CometBroadcastExchangeExec, _) => false - case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false - case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false - case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false - case _ => true - } - - val containsBroadcastInput = sparkPlans.exists { - case _: CometBroadcastExchangeExec => true - case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true - case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true - case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true - case _ => false - } - - // If the first non broadcast plan is not found, it means all the plans are broadcast plans. - // This is not expected, so throw an exception. - if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) { - throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") - } - - // If the first non broadcast plan is found, we need to adjust the partition number of - // the broadcast plans to make sure they have the same partition number as the first non - // broadcast plan. - val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = - firstNonBroadcastPlan.get._1 match { - case plan: CometNativeExec => - (null, plan.outputPartitioning.numPartitions) - case plan => - val rdd = plan.executeColumnar() - (rdd, rdd.getNumPartitions) - } - - // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with - // same partition number. But for Comet, we need to zip them so we need to adjust the - // partition number of Broadcast RDDs to make sure they have the same partition number. 
- sparkPlans.zipWithIndex.foreach { case (plan, idx) => - plan match { - case c: CometBroadcastExchangeExec => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec( - _, - ReusedExchangeExec(_, c: CometBroadcastExchangeExec), - _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case _: CometNativeExec => - // no-op - case _ if idx == firstNonBroadcastPlan.get._2 => - inputs += firstNonBroadcastPlanRDD - case _ => - val rdd = plan.executeColumnar() - if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { - throw new CometRuntimeException( - s"Partition number mismatch: ${rdd.getNumPartitions} != " + - s"$firstNonBroadcastPlanNumPartitions") - } else { - inputs += rdd - } - } - } - - if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) { - throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") - } - - // Detect ShuffleScan indices for direct read in CometExecRDD - val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) - - // Unified RDD creation - CometExecRDD handles all cases - val subqueries = collectSubqueries(this) - val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) new CometExecRDD( sparkContext, - inputs.toSeq, - commonByKey, - perPartitionByKey, - serializedPlanCopy, - firstNonBroadcastPlanNumPartitions, + ctx.inputs, + ctx.commonByKey, + ctx.perPartitionByKey, + serializedPlan, + ctx.numPartitions, output.length, nativeMetrics, - subqueries, - broadcastedHadoopConfForEncryption, - encryptedFilePaths, - shuffleScanIndices) { + ctx.subqueries, + ctx.broadcastedHadoopConfForEncryption, + ctx.encryptedFilePaths, + ctx.shuffleScanIndices) { override 
def compute( split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val res = super.compute(split, context) // Report scan input metrics only when the native plan contains a scan. - if (hasScanInput) { + if (ctx.hasScanInput) { Option(context).foreach(nativeMetrics.reportScanInputMetrics) } @@ -586,6 +482,149 @@ abstract class CometNativeExec extends CometExec { } } + /** + * Walk this CometNativeExec subtree once and gather everything needed to launch native + * execution: leaf input RDDs (with broadcast partition counts aligned to the first + * non-broadcast plan), per-partition planning data, subqueries, encryption options, and + * shuffle-scan indices. See [[NativeExecContext]] for the full set of fields. + * + * Used by [[doExecuteColumnar]] (CometExecRDD path) and by the native-shuffle path + * (CometShuffleExchangeExec) so both observe the same input alignment. + */ + private[comet] def buildNativeContext(): NativeExecContext = { + // Go over all the native scans, in order to see if they need encryption options. + // For each relation in a CometNativeScan generate a hadoopConf, + // for each file path in a relation associate with hadoopConf + // This is done per native plan, so only count scans until a comet input is reached. + val encryptionOptions = + mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] + foreachUntilCometInput(this) { + case scan: CometNativeScanExec => + // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and + // per-relation configs since different tables might have different decryption + // properties. + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) + if (encryptionEnabled) { + // hadoopConf isn't serializable, so we have to do a broadcasted config. 
+ val broadcastedConf = + scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + + val optsTuple: (Broadcast[SerializableConfiguration], Seq[String]) = + (broadcastedConf, scan.relation.inputFiles.toSeq) + encryptionOptions += optsTuple + } + case _ => // no-op + } + assert( + encryptionOptions.size <= 1, + "We expect one native scan that requires encryption reading in a Comet plan," + + " since we will broadcast one hadoopConf.") + // If this assumption changes in the future, you can look at the commit history of #2447 + // to see how there used to be a map of relations to broadcasted confs in case multiple + // relations in a single plan. The example that came up was UNION. See discussion at: + // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 + val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = + encryptionOptions.headOption match { + case Some((conf, paths)) => (Some(conf), paths) + case None => (None, Seq.empty) + } + + // Find planning data within this stage (stops at shuffle boundaries). + val (commonByKey, perPartitionByKey) = findAllPlanData(this) + + // Collect the input ColumnarBatches from the child operators and create a CometExecIterator + // to execute the native plan. 
+ val sparkPlans = ArrayBuffer.empty[SparkPlan] + val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]] + + foreachUntilCometInput(this)(sparkPlans += _) + + // Find the first non broadcast plan + val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find { + case (_: CometBroadcastExchangeExec, _) => false + case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false + case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false + case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false + case _ => true + } + + val containsBroadcastInput = sparkPlans.exists { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true + case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true + case _ => false + } + + // If the first non broadcast plan is not found, it means all the plans are broadcast plans. + // This is not expected, so throw an exception. + if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) { + throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") + } + + // If the first non broadcast plan is found, we need to adjust the partition number of + // the broadcast plans to make sure they have the same partition number as the first non + // broadcast plan. + val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = + firstNonBroadcastPlan.get._1 match { + case plan: CometNativeExec => + (null, plan.outputPartitioning.numPartitions) + case plan => + val rdd = plan.executeColumnar() + (rdd, rdd.getNumPartitions) + } + + // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with + // same partition number. But for Comet, we need to zip them so we need to adjust the + // partition number of Broadcast RDDs to make sure they have the same partition number. 
+ sparkPlans.zipWithIndex.foreach { case (plan, idx) => + plan match { + case c: CometBroadcastExchangeExec => + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) + case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) + case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) + case BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, c: CometBroadcastExchangeExec), + _) => + inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) + case _: CometNativeExec => + // no-op + case _ if idx == firstNonBroadcastPlan.get._2 => + inputs += firstNonBroadcastPlanRDD + case _ => + val rdd = plan.executeColumnar() + if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { + throw new CometRuntimeException( + s"Partition number mismatch: ${rdd.getNumPartitions} != " + + s"$firstNonBroadcastPlanNumPartitions") + } else { + inputs += rdd + } + } + } + + if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) { + throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") + } + + NativeExecContext( + inputs = inputs.toSeq, + numPartitions = firstNonBroadcastPlanNumPartitions, + subqueries = collectSubqueries(this), + broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, + encryptedFilePaths = encryptedFilePaths, + commonByKey = commonByKey, + perPartitionByKey = perPartitionByKey, + shuffleScanIndices = findShuffleScanIndices(nativeOp), + hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) + } + /** * Traverse the tree of Comet physical operators until reaching the input sources operators and * apply the given function to each operator. 
@@ -623,11 +662,10 @@ abstract class CometNativeExec extends CometExec { } /** - * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * Walk the protobuf operator tree depth-first to find which input indices correspond to * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. */ - private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { - val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + private def findShuffleScanIndices(plan: OperatorOuterClass.Operator): Set[Int] = { var scanIndex = 0 val indices = mutable.Set.empty[Int] def walk(op: OperatorOuterClass.Operator): Unit = { From 8c99fc5a9b4f5301994787debedd5c9c6a77e50f Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 20:29:02 -0400 Subject: [PATCH 17/39] Passes CometFuzzTestSuite, CometNativeShuffleSuite, CometExecSuite. --- .../shuffle/CometNativeShuffleInputRDD.scala | 119 ++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala new file mode 100644 index 0000000000..29b9745366 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometShuffleBlockIterator + +/** + * Thin RDD that anchors Spark scheduling for the native-shuffle path. Native execution itself is + * driven by [[CometNativeShuffleWriter]] using the unified `ShuffleWriter(child = childNativeOp)` + * plan. + * + * The RDD's role here is twofold: + * - declare `OneToOneDependency` on each leaf input RDD so the DAGScheduler walks the lineage + * and triggers prior stages, broadcast materialization, etc. + * - construct the per-partition leaf iterators (and shuffle-block iterators where applicable) + * in `compute`, packaged into a [[CometNativeShuffleInputIterator]] that the writer downcasts + * to extract the inputs it needs to feed the native plan. + * + * The iterator returned by `compute` always reports `hasNext = false`. Spark's `ShuffleMapTask` + * will hand it to `writer.write`; the writer ignores it as an iterator and reads its exposed + * fields directly. 
+ */ +private[shuffle] class CometNativeShuffleInputRDD( + sc: SparkContext, + var inputRDDs: Seq[RDD[ColumnarBatch]], + numPartitionsParam: Int, + shuffleScanIndices: Set[Int]) + extends RDD[Product2[Int, ColumnarBatch]]( + sc, + inputRDDs.map(rdd => new OneToOneDependency(rdd))) { + + override protected def getPartitions: Array[Partition] = + (0 until numPartitionsParam).map { i => + // Resolve leaf-RDD partitions on the driver here (where their @transient fields are still + // populated). Stashing them on the partition lets `compute` avoid touching + // `leafRdd.partitions` on the executor, which would otherwise trigger getPartitions and + // hit the @transient-null trap (e.g. CometExecRDD.perPartitionByKey). + val inputParts = inputRDDs.map(_.partitions(i)).toArray + new CometNativeShuffleInputPartition(i, inputParts) + }.toArray + + override def compute( + split: Partition, + context: TaskContext): Iterator[Product2[Int, ColumnarBatch]] = { + val partition = split.asInstanceOf[CometNativeShuffleInputPartition] + val leafIterators = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => + rdd.iterator(part, context) + } + val shuffleBlockIters: Map[Int, CometShuffleBlockIterator] = + shuffleScanIndices.flatMap { si => + inputRDDs(si) match { + case rdd: CometShuffledBatchRDD => + Some(si -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(si), context)) + case _ => None + } + }.toMap + new CometNativeShuffleInputIterator(partition.index, leafIterators, shuffleBlockIters) + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + if (inputRDDs == null || inputRDDs.isEmpty) return Nil + val partition = split.asInstanceOf[CometNativeShuffleInputPartition] + val prefs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => + rdd.preferredLocations(part) + } + val intersection = prefs.reduce((a, b) => a.intersect(b)) + if (intersection.nonEmpty) intersection else prefs.flatten.distinct + } + + override def 
clearDependencies(): Unit = { + super.clearDependencies() + inputRDDs = null + } +} + +private[shuffle] class CometNativeShuffleInputPartition( + override val index: Int, + val inputPartitions: Array[Partition]) + extends Partition + +/** + * Iterator handed to [[CometNativeShuffleWriter.write]] via Spark's ShuffleMapTask. Reports no + * elements; the writer downcasts and reads `partitionIndex`, `leafIterators`, and + * `shuffleBlockIterators` directly to drive the unified native plan. + */ +private[shuffle] class CometNativeShuffleInputIterator( + val partitionIndex: Int, + val leafIterators: Seq[Iterator[ColumnarBatch]], + val shuffleBlockIterators: Map[Int, CometShuffleBlockIterator]) + extends Iterator[Product2[Int, ColumnarBatch]] { + + override def hasNext: Boolean = false + + override def next(): Product2[Int, ColumnarBatch] = + throw new NoSuchElementException( + "CometNativeShuffleInputIterator does not produce elements; CometNativeShuffleWriter " + + "drives native execution via the iterator's exposed fields.") +} From 443a1c7866ff4dd80664425fecd06f4cbfa211e3 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 20:49:43 -0400 Subject: [PATCH 18/39] Cleanup, update docs. 
--- .../contributor-guide/native_shuffle.md | 38 ++++-- .../shuffle/CometNativeShuffleInputRDD.scala | 23 +--- .../shuffle/CometNativeShuffleWriter.scala | 72 +++++----- .../shuffle/CometShuffleDependency.scala | 22 ++- .../shuffle/CometShuffleExchangeExec.scala | 62 ++++----- .../shuffle/CometShuffleManager.scala | 4 +- .../apache/spark/sql/comet/operators.scala | 128 ++++++++---------- 7 files changed, 160 insertions(+), 189 deletions(-) diff --git a/docs/source/contributor-guide/native_shuffle.md b/docs/source/contributor-guide/native_shuffle.md index 18e80a90c8..c46b8b45a8 100644 --- a/docs/source/contributor-guide/native_shuffle.md +++ b/docs/source/contributor-guide/native_shuffle.md @@ -69,8 +69,9 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ▼ ┌─────────────────────────────────────────────────────────────────────────────┐ │ CometNativeShuffleWriter │ -│ - Constructs protobuf operator plan │ -│ - Invokes native execution via CometExec.getCometIterator() │ +│ - Builds protobuf operator plan: ShuffleWriter(child = childNativeOp) │ +│ - Reads per-partition leaf iterators from CometNativeShuffleInputIterator │ +│ - Drives one CometExecIterator per partition │ └─────────────────────────────────────────────────────────────────────────────┘ │ ▼ (JNI) @@ -103,13 +104,14 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ### Scala Side -| Class | Location | Description | -| ------------------------------ | ------------------------------------------------ | --------------------------------------------------------------------------------------------- | -| `CometShuffleExchangeExec` | `.../shuffle/CometShuffleExchangeExec.scala` | Physical plan node. Validates types and partitioning, creates `CometShuffleDependency`. | -| `CometNativeShuffleWriter` | `.../shuffle/CometNativeShuffleWriter.scala` | Implements `ShuffleWriter`. Builds protobuf plan and invokes native execution. 
| -| `CometShuffleDependency` | `.../shuffle/CometShuffleDependency.scala` | Extends `ShuffleDependency`. Holds shuffle type, schema, and range partition bounds. | -| `CometBlockStoreShuffleReader` | `.../shuffle/CometBlockStoreShuffleReader.scala` | Reads shuffle blocks via `ShuffleBlockFetcherIterator`. Decodes Arrow IPC to `ColumnarBatch`. | -| `NativeBatchDecoderIterator` | `.../shuffle/NativeBatchDecoderIterator.scala` | Reads compressed Arrow IPC from input stream. Calls native decode via JNI. | +| Class | Location | Description | +| ------------------------------ | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------- | +| `CometShuffleExchangeExec` | `.../shuffle/CometShuffleExchangeExec.scala` | Physical plan node. Validates types and partitioning, creates `CometShuffleDependency`. | +| `CometNativeShuffleWriter` | `.../shuffle/CometNativeShuffleWriter.scala` | Implements `ShuffleWriter`. Builds the unified `ShuffleWriter(child = childNativeOp)` plan and runs it in one `CometExecIterator` per partition. | +| `CometShuffleDependency` | `.../shuffle/CometShuffleDependency.scala` | Extends `ShuffleDependency`. Holds shuffle type, schema, range partition bounds, and (native shuffle only) a `NativeShuffleSpec`. | +| `CometNativeShuffleInputRDD` | `.../shuffle/CometNativeShuffleInputRDD.scala` | Thin scheduling-anchor RDD on the native-shuffle path. `compute` returns a `CometNativeShuffleInputIterator` carrying per-partition leaf iterators. | +| `CometBlockStoreShuffleReader` | `.../shuffle/CometBlockStoreShuffleReader.scala` | Reads shuffle blocks via `ShuffleBlockFetcherIterator`. Decodes Arrow IPC to `ColumnarBatch`. | +| `NativeBatchDecoderIterator` | `.../shuffle/NativeBatchDecoderIterator.scala` | Reads compressed Arrow IPC from input stream. Calls native decode via JNI. 
| ### Rust Side @@ -123,11 +125,19 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ### Write Path -1. **Plan construction**: `CometNativeShuffleWriter` builds a protobuf operator plan containing: - - A scan operator reading from the input iterator - - A `ShuffleWriter` operator with partitioning config and compression codec - -2. **Native execution**: `CometExec.getCometIterator()` executes the plan in Rust. +1. **Plan construction**: `CometNativeShuffleWriter` builds a protobuf operator tree with a + `ShuffleWriter` operator at the root and `childNativeOp` as its child. `childNativeOp` takes + one of two shapes: + - The child plan's `nativeOp` directly, when `CometShuffleExchangeExec`'s child is a + `CometNativeExec` subtree. The upstream operators run inside the same `CometExecIterator` + as the writer, with no JVM-to-native batch boundary between them. + - A synthetic `Scan("ShuffleWriterInput")` placeholder, when the dep was built via the + convenience `prepareShuffleDependency(rdd, ...)` overload (used by + `CometCollectLimitExec` and `CometTakeOrderedAndProjectExec`, or when the + exchange's child is a non-native `CometPlan` such as `CometSparkToColumnarExec`). Native + code reads `ColumnarBatch`es from the JVM input iterator via Arrow C Stream Interface. + +2. **Native execution**: A single `CometExecIterator` per partition runs the unified plan. 3. 
**Partitioning**: `ShuffleWriterExec` receives batches and routes to the appropriate partitioner: - `MultiPartitionShuffleRepartitioner`: For hash/range/round-robin partitioning diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala index 29b9745366..811c5bd3be 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala @@ -26,20 +26,11 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.CometShuffleBlockIterator /** - * Thin RDD that anchors Spark scheduling for the native-shuffle path. Native execution itself is - * driven by [[CometNativeShuffleWriter]] using the unified `ShuffleWriter(child = childNativeOp)` - * plan. - * - * The RDD's role here is twofold: - * - declare `OneToOneDependency` on each leaf input RDD so the DAGScheduler walks the lineage - * and triggers prior stages, broadcast materialization, etc. - * - construct the per-partition leaf iterators (and shuffle-block iterators where applicable) - * in `compute`, packaged into a [[CometNativeShuffleInputIterator]] that the writer downcasts - * to extract the inputs it needs to feed the native plan. - * - * The iterator returned by `compute` always reports `hasNext = false`. Spark's `ShuffleMapTask` - * will hand it to `writer.write`; the writer ignores it as an iterator and reads its exposed - * fields directly. + * Thin scheduling-anchor RDD for the native-shuffle path. Declares `OneToOneDependency` on each + * leaf input RDD (so the DAGScheduler triggers prior stages, broadcasts, etc.) and constructs + * per-partition leaf iterators in `compute`, packaged into a [[CometNativeShuffleInputIterator]]. 
+ * The iterator reports `hasNext = false`; [[CometNativeShuffleWriter]] downcasts it and reads the + * leaf iterators directly to drive the unified `ShuffleWriter(child = childNativeOp)` plan. */ private[shuffle] class CometNativeShuffleInputRDD( sc: SparkContext, @@ -114,6 +105,6 @@ private[shuffle] class CometNativeShuffleInputIterator( override def next(): Product2[Int, ColumnarBatch] = throw new NoSuchElementException( - "CometNativeShuffleInputIterator does not produce elements; CometNativeShuffleWriter " + - "drives native execution via the iterator's exposed fields.") + "CometNativeShuffleInputIterator should never be drained as an iterator. Reaching this " + + "code means a non-Comet ShuffleWriter is consuming the input, which is a bug.") } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 68078e8507..21395c58fd 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -32,7 +32,7 @@ import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsR import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometScalarSubquery, NativeExecContext, PlanDataInjector} +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometScalarSubquery, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.comet.{CometConf, CometExecIterator} @@ -40,26 +40,16 @@ import org.apache.comet.serde.{OperatorOuterClass, 
PartitioningOuterClass, Query import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} /** - * A [[ShuffleWriter]] that drives the native shuffle write in a single [[CometExecIterator]] per - * partition. The unified plan it executes has [[OperatorOuterClass.ShuffleWriter]] at the root - * with `childNativeOp` as its only child. Leaf input iterators come from - * [[CometNativeShuffleInputIterator]] (constructed by [[CometNativeShuffleInputRDD.compute]]). - * - * Two flavors of `childNativeOp` are in use: - * - rich Comet native subtree (e.g. HashAgg / Filter / ShuffleScan), supplied by - * [[CometShuffleExchangeExec]] when its child is a - * [[org.apache.spark.sql.comet.CometNativeExec]]. - * - synthetic `Scan("ShuffleWriterInput")` placeholder, supplied by the convenience overload of - * [[CometShuffleExchangeExec.prepareShuffleDependency]] for callers that already hold an - * `RDD[ColumnarBatch]` of native-driven batches (e.g. - * [[org.apache.spark.sql.comet.CometCollectLimitExec]]). - * - * The writer treats both shapes identically. + * Drives the native shuffle write in a single [[CometExecIterator]] per partition. The plan is + * `ShuffleWriter(child = childNativeOp)`; leaf iterators come from a + * [[CometNativeShuffleInputIterator]]. `childNativeOp` is either a rich Comet native subtree + * (when fed by [[CometShuffleExchangeExec]] with a [[org.apache.spark.sql.comet.CometNativeExec]] + * child) or a synthetic `Scan("ShuffleWriterInput")` placeholder (the + * [[CometShuffleExchangeExec.prepareShuffleDependency]] convenience overload). Same handling + * either way. 
*/ class CometNativeShuffleWriter[K, V]( - childNativeOp: Operator, - childMetricNode: CometMetricNode, - nativeContext: NativeExecContext, + spec: NativeShuffleSpec, outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], metrics: Map[String, SQLMetric], @@ -87,19 +77,28 @@ class CometNativeShuffleWriter[K, V]( val tempDataFilePath = Paths.get(tempDataFilename) val tempIndexFilePath = Paths.get(tempIndexFilename) - // Pull the per-partition leaf iterators and partition index from the iterator handed to us - // by Spark's ShuffleMapTask. CometNativeShuffleInputRDD.compute always returns this exact - // iterator type; no other RDD layers between produce a Product2[Int, ColumnarBatch]. - val shuffleInputIter = inputs.asInstanceOf[CometNativeShuffleInputIterator] + // The dep's _rdd is always a CometNativeShuffleInputRDD on this path. Pattern-match instead + // of asInstanceOf so a future RDD-layering change produces a clear error here rather than a + // bare ClassCastException deeper in the stack. 
+ val shuffleInputIter = inputs match { + case it: CometNativeShuffleInputIterator => it + case other => + throw new IllegalStateException( + "CometNativeShuffleWriter expects its input iterator to be a " + + "CometNativeShuffleInputIterator (produced by CometNativeShuffleInputRDD), got " + + s"${other.getClass.getName}") + } val partitionIdx = shuffleInputIter.partitionIndex val leafIterators = shuffleInputIter.leafIterators val shuffleBlockIters = shuffleInputIter.shuffleBlockIterators val unifiedPlan = buildUnifiedPlan(tempDataFilename, tempIndexFilename) - val finalNativePlan = if (nativeContext.commonByKey.nonEmpty) { - val partitionDataByKey = - nativeContext.perPartitionByKey.map { case (k, arr) => k -> arr(partitionIdx) } - PlanDataInjector.injectPlanData(unifiedPlan, nativeContext.commonByKey, partitionDataByKey) + val ctx = spec.execContext + val finalNativePlan = if (ctx.commonByKey.nonEmpty) { + val partitionDataByKey = ctx.perPartitionByKey.map { case (k, arr) => + k -> arr(partitionIdx) + } + PlanDataInjector.injectPlanData(unifiedPlan, ctx.commonByKey, partitionDataByKey) } else { unifiedPlan } @@ -119,30 +118,29 @@ class CometNativeShuffleWriter[K, V]( "write_time" -> metricsWriteTime) ++ metrics.filterKeys(detailedMetrics.contains) - // ShuffleWriter metrics live at the root of the metric tree; the child operator's metric - // tree (rich subtree or empty leaf for the Scan placeholder) is attached underneath so the - // SQL UI sees the same per-node breakdown the split-driver flow produced. - val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(childMetricNode)) + // ShuffleWriter metrics at the root; child's metric tree underneath so the SQL UI's per-node + // breakdown matches what the split-driver flow showed. 
+ val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(spec.childMetricNode)) val cometIter = new CometExecIterator( CometExec.newIterId, leafIterators, outputAttributes.length, - PlanDataInjector.serializeOperator(finalNativePlan), + CometExec.serializeNativePlan(finalNativePlan), nativeMetrics, numParts, partitionIdx, - nativeContext.broadcastedHadoopConfForEncryption, - nativeContext.encryptedFilePaths, + ctx.broadcastedHadoopConfForEncryption, + ctx.encryptedFilePaths, shuffleBlockIters) // Register subqueries against the iterator id so native callbacks resolve them to values. - nativeContext.subqueries.foreach { sub => + ctx.subqueries.foreach { sub => CometScalarSubquery.setSubquery(cometIter.id, sub) } Option(context).foreach { taskCtx => taskCtx.addTaskCompletionListener[Unit] { _ => - nativeContext.subqueries.foreach { sub => + ctx.subqueries.foreach { sub => CometScalarSubquery.removeSubquery(cometIter.id, sub) } } @@ -334,7 +332,7 @@ class CometNativeShuffleWriter[K, V]( OperatorOuterClass.Operator .newBuilder() .setShuffleWriter(shuffleWriterBuilder) - .addChildren(childNativeOp) + .addChildren(spec.childNativeOp) .build() } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 9ec25c49ec..2a05843007 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -34,13 +34,23 @@ import org.apache.spark.sql.types.StructType import org.apache.comet.serde.OperatorOuterClass +/** + * Bundle of context the native shuffle write path needs at task time. Co-populated for native + * shuffles only; consolidated into a single field on [[CometShuffleDependency]] so it cannot be + * partially set. 
+ */ +case class NativeShuffleSpec( + childNativeOp: OperatorOuterClass.Operator, + childMetricNode: CometMetricNode, + execContext: NativeExecContext) + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. * - * On the native-shuffle path, also carries the child plan's per-partition execution context, root - * native operator, and metric node so [[CometNativeShuffleWriter]] can drive the unified - * `ShuffleWriter(child = childNativeOp)` plan in a single [[org.apache.comet.CometExecIterator]] - * per partition. These three fields are populated only when `shuffleType == CometNativeShuffle`. + * On the native-shuffle path, also carries a [[NativeShuffleSpec]] so + * [[CometNativeShuffleWriter]] can drive the unified `ShuffleWriter(child = childNativeOp)` plan + * in a single [[org.apache.comet.CometExecIterator]] per partition. `nativeShuffleSpec` is + * populated only when `shuffleType == CometNativeShuffle`. */ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @transient private val _rdd: RDD[_ <: Product2[K, V]], @@ -58,9 +68,7 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, val rangePartitionBounds: Option[Seq[InternalRow]] = None, - val nativeExecContext: Option[NativeExecContext] = None, - val childNativeOp: Option[OperatorOuterClass.Operator] = None, - val childMetricNode: Option[CometMetricNode] = None) + val nativeShuffleSpec: Option[NativeShuffleSpec] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 16bd40d402..1470d637d9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -102,14 +102,11 @@ case class CometShuffleExchangeExec( new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) /** - * Per-partition execution context for the child native subtree, computed once and shared - * between [[inputRDD]] (which uses it to wire DAGScheduler dependencies) and - * [[shuffleDependency]] (which threads it through to [[CometNativeShuffleWriter]] for - * single-iterator native execution). Only populated when `shuffleType == CometNativeShuffle` - * AND the child is a [[CometNativeExec]] subtree we can inline. When the child is a non-native - * Comet plan (e.g. [[org.apache.spark.sql.comet.CometSparkToColumnarExec]]), this stays `None` - * and the shuffle falls back to the legacy `Scan("ShuffleWriterInput") -> ShuffleWriter` plan - * via the convenience overload of `prepareShuffleDependency`. + * Single-driver native-shuffle context, computed once and shared between [[inputRDD]] and + * [[shuffleDependency]]. `Some` only when `shuffleType == CometNativeShuffle` AND the child is + * a [[CometNativeExec]] subtree. Otherwise the dep is built via the + * [[CometShuffleExchangeExec.prepareShuffleDependency]] convenience overload (synthetic Scan + * placeholder). */ @transient private lazy val nativeChildContext: Option[NativeExecContext] = child match { case nativeChild: CometNativeExec if shuffleType == CometNativeShuffle => @@ -120,23 +117,19 @@ case class CometShuffleExchangeExec( @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { nativeChildContext match { case Some(ctx) => - // Single-driver path: thin scheduling anchor; CometNativeShuffleWriter drives the - // unified ShuffleWriter + child plan in a single CometExecIterator per partition. 
new CometNativeShuffleInputRDD( sparkContext, ctx.inputs, ctx.numPartitions, ctx.shuffleScanIndices) case None => - // Child is a Comet plan but not a CometNativeExec subtree (e.g. CometSparkToColumnarExec). - // No native subtree to inline; the writer's plan is `Scan("ShuffleWriterInput") -> - // ShuffleWriter` and JVM batches flow into native through Arrow C Stream Interface. + // Non-native child (e.g. CometSparkToColumnarExec): no subtree to inline. The dep gets + // built via the legacy convenience overload below; we just need a real RDD of batches. child.executeColumnar() } } else if (shuffleType == CometColumnarShuffle) { - // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, - // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec - // to convert columnar batches to rows. + // Row-based shuffle. CometNativeExec.doExecute wraps columnar output with + // ColumnarToRowExec; non-Comet children flow through directly. child.execute() } else { throw new UnsupportedOperationException( @@ -180,12 +173,11 @@ case class CometShuffleExchangeExec( if (shuffleType == CometNativeShuffle) { val dep = nativeChildContext match { case Some(ctx) => - // Single-driver path: child is a CometNativeExec subtree. RangePartitioning needs real - // rows to compute partition bounds; use a regular columnar execution of the child for - // sampling only. The actual shuffle still goes through the single-iterator path. val nativeChild = child.asInstanceOf[CometNativeExec] + // RangePartitioner needs real rows for sampling. Reuse the precomputed context so we + // don't re-walk the SparkPlan tree or re-broadcast the encryption Hadoop conf. 
val samplingRDD: Option[RDD[ColumnarBatch]] = outputPartitioning match { - case _: RangePartitioning => Some(child.executeColumnar()) + case _: RangePartitioning => Some(nativeChild.executeColumnarWithContext(ctx)) case _ => None } CometShuffleExchangeExec.prepareNativeShuffleDependency( @@ -195,12 +187,11 @@ case class CometShuffleExchangeExec( outputPartitioning, serializer, metrics, - ctx, - nativeChild.nativeOp, - CometMetricNode.fromCometPlan(nativeChild)) + NativeShuffleSpec( + nativeChild.nativeOp, + CometMetricNode.fromCometPlan(nativeChild), + ctx)) case None => - // Child is a non-native Comet plan; the writer falls back to its Scan-placeholder - // path via the convenience overload of prepareShuffleDependency. CometShuffleExchangeExec.prepareShuffleDependency( inputRDD.asInstanceOf[RDD[ColumnarBatch]], child.output, @@ -733,18 +724,16 @@ object CometShuffleExchangeExec outputPartitioning, serializer, metrics, - ctx, - scanOp, - CometMetricNode(Map.empty)) + NativeShuffleSpec(scanOp, CometMetricNode(Map.empty), ctx)) } /** * Build a Comet native shuffle dependency for the [[CometShuffleExchangeExec]] case where the * shuffle is fed by a [[CometNativeExec]] child. The writer drives the unified * `ShuffleWriter(child = childNativeOp)` plan in a single - * [[org.apache.comet.CometExecIterator]] per partition. The returned dep carries the child's - * per-partition execution context, root native operator, and metric node so - * [[CometNativeShuffleWriter]] can reach them at task time. + * [[org.apache.comet.CometExecIterator]] per partition. The returned dep carries the + * [[NativeShuffleSpec]] so [[CometNativeShuffleWriter]] can reach the child's per-partition + * execution context, root native operator, and metric node at task time. 
* * @param thinRDD * scheduling-anchor RDD whose `compute` returns a [[CometNativeShuffleInputIterator]]; @@ -760,9 +749,7 @@ object CometShuffleExchangeExec outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric], - nativeExecContext: NativeExecContext, - childNativeOp: OperatorOuterClass.Operator, - childMetricNode: CometMetricNode): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + spec: NativeShuffleSpec): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val numParts = thinRDD.getNumPartitions // The code block below is mostly brought over from @@ -821,7 +808,7 @@ object CometShuffleExchangeExec None) } - val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( + new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( thinRDD, serializer = serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(metrics), @@ -833,10 +820,7 @@ object CometShuffleExchangeExec shuffleWriteMetrics = metrics, numParts = numParts, rangePartitionBounds = rangePartitionBounds, - nativeExecContext = Some(nativeExecContext), - childNativeOp = Some(childNativeOp), - childMetricNode = Some(childMetricNode)) - dependency + nativeShuffleSpec = Some(spec)) } /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index 16a21bbfd9..bd69e91898 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -231,9 +231,7 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { case cometShuffleHandle: CometNativeShuffleHandle[K @unchecked, V @unchecked] => val dep = cometShuffleHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]] new CometNativeShuffleWriter( - 
dep.childNativeOp.get, - dep.childMetricNode.get, - dep.nativeExecContext.get, + dep.nativeShuffleSpec.get, dep.outputPartitioning.get, dep.outputAttributes, dep.shuffleWriteMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 748362cd65..fa7aa6f1b2 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -383,15 +383,11 @@ object CometExec { /** * Per-partition execution context for a native subtree rooted at a [[CometNativeExec]] boundary. - * - * Built once on the driver from the SparkPlan tree and consumed by either: - * - [[CometNativeExec.doExecuteColumnar]] to construct a [[CometExecRDD]], or - * - the native-shuffle path, where [[CometNativeShuffleWriter]] drives the same child plan with - * a `ShuffleWriter` operator as its root. - * - * The fields capture everything that depends on tree-walking the SparkPlan and aligning leaf - * input RDDs (broadcast partition counts, plan-data, subqueries, encryption options) so callers - * do not have to re-derive them. + * Built once on the driver from the SparkPlan tree, then consumed by either + * [[CometNativeExec.executeColumnarWithContext]] (to build a [[CometExecRDD]]) or the + * native-shuffle path (to drive [[CometNativeShuffleWriter]]). Captures broadcast partition + * alignment, plan-data, subqueries, and encryption options so each consumer doesn't re-walk the + * tree. */ private[comet] case class NativeExecContext( inputs: Seq[RDD[ColumnarBatch]], @@ -402,7 +398,15 @@ private[comet] case class NativeExecContext( commonByKey: Map[String, Array[Byte]], perPartitionByKey: Map[String, Array[Array[Byte]]], shuffleScanIndices: Set[Int], - hasScanInput: Boolean) + hasScanInput: Boolean) { + // Catch shape divergence (e.g. 
broadcast scans with different partition counts after DPP + // filtering) at construction so consumers don't trip ArrayIndexOutOfBoundsException at + // partition idx access time. + require( + perPartitionByKey.values.forall(_.length == numPartitions), + s"All per-partition arrays must have length $numPartitions, but found: " + + perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}" }.mkString(", ")) +} /** * A Comet native physical operator. @@ -442,90 +446,68 @@ abstract class CometNativeExec extends CometExec { runningSubqueries.clear() } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - serializedPlanOpt.plan match { - case None => - // This is in the middle of a native execution, it should not be executed directly. - throw new CometRuntimeException( - s"CometNativeExec should not be executed directly without a serialized plan: $this") - case Some(serializedPlan) => - // TODO: support native metrics for all operators. - val nativeMetrics = CometMetricNode.fromCometPlan(this) - val ctx = buildNativeContext() - - new CometExecRDD( - sparkContext, - ctx.inputs, - ctx.commonByKey, - ctx.perPartitionByKey, - serializedPlan, - ctx.numPartitions, - output.length, - nativeMetrics, - ctx.subqueries, - ctx.broadcastedHadoopConfForEncryption, - ctx.encryptedFilePaths, - ctx.shuffleScanIndices) { - override def compute( - split: Partition, - context: TaskContext): Iterator[ColumnarBatch] = { - val res = super.compute(split, context) - - // Report scan input metrics only when the native plan contains a scan. - if (ctx.hasScanInput) { - Option(context).foreach(nativeMetrics.reportScanInputMetrics) - } + override def doExecuteColumnar(): RDD[ColumnarBatch] = + executeColumnarWithContext(buildNativeContext()) - res - } + /** + * Build a [[CometExecRDD]] from a precomputed [[NativeExecContext]]. Public so the native + * shuffle path can sample (RangePartitioning) without re-walking the SparkPlan tree and + * re-broadcasting the encryption Hadoop conf. 
+ */ + private[comet] def executeColumnarWithContext(ctx: NativeExecContext): RDD[ColumnarBatch] = { + val serializedPlan = serializedPlanOpt.plan.getOrElse( + throw new CometRuntimeException( + s"CometNativeExec should not be executed directly without a serialized plan: $this")) + val nativeMetrics = CometMetricNode.fromCometPlan(this) + + new CometExecRDD( + sparkContext, + ctx.inputs, + ctx.commonByKey, + ctx.perPartitionByKey, + serializedPlan, + ctx.numPartitions, + output.length, + nativeMetrics, + ctx.subqueries, + ctx.broadcastedHadoopConfForEncryption, + ctx.encryptedFilePaths, + ctx.shuffleScanIndices) { + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + if (ctx.hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) } + res + } } } /** * Walk this CometNativeExec subtree once and gather everything needed to launch native - * execution: leaf input RDDs (with broadcast partition counts aligned to the first - * non-broadcast plan), per-partition planning data, subqueries, encryption options, and - * shuffle-scan indices. See [[NativeExecContext]] for the full set of fields. - * - * Used by [[doExecuteColumnar]] (CometExecRDD path) and by the native-shuffle path - * (CometShuffleExchangeExec) so both observe the same input alignment. + * execution. See [[NativeExecContext]] for the field set. */ private[comet] def buildNativeContext(): NativeExecContext = { - // Go over all the native scans, in order to see if they need encryption options. - // For each relation in a CometNativeScan generate a hadoopConf, - // for each file path in a relation associate with hadoopConf - // This is done per native plan, so only count scans until a comet input is reached. + // Find native scans that need encryption: build a hadoopConf per relation, broadcast it once + // so executors can decrypt on read. 
Capped at one because we only broadcast one conf per + // CometExecIterator (see #2447 for history of the per-relation map approach). val encryptionOptions = mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] foreachUntilCometInput(this) { case scan: CometNativeScanExec => - // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and - // per-relation configs since different tables might have different decryption - // properties. val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) - if (encryptionEnabled) { - // hadoopConf isn't serializable, so we have to do a broadcasted config. - val broadcastedConf = - scan.relation.sparkSession.sparkContext - .broadcast(new SerializableConfiguration(hadoopConf)) - - val optsTuple: (Broadcast[SerializableConfiguration], Seq[String]) = - (broadcastedConf, scan.relation.inputFiles.toSeq) - encryptionOptions += optsTuple + if (CometParquetUtils.encryptionEnabled(hadoopConf)) { + val broadcastedConf = scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + encryptionOptions += ((broadcastedConf, scan.relation.inputFiles.toSeq)) } - case _ => // no-op + case _ => } assert( encryptionOptions.size <= 1, "We expect one native scan that requires encryption reading in a Comet plan," + " since we will broadcast one hadoopConf.") - // If this assumption changes in the future, you can look at the commit history of #2447 - // to see how there used to be a map of relations to broadcasted confs in case multiple - // relations in a single plan. The example that came up was UNION. 
See discussion at: - // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = encryptionOptions.headOption match { case Some((conf, paths)) => (Some(conf), paths) From 07e7944faf291edbac10c81aaad2941f33169d41 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 28 May 2026 20:58:19 -0400 Subject: [PATCH 19/39] remove non-ascii --- .../sql/comet/execution/arrow/CometNativeArrowSource.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala index 92ffc61acd..727f761997 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -135,7 +135,7 @@ object CometArrowStream extends Logging { * vector's field. * * Field name and metadata come from `expected` so that consumers indexing by name keep working. - * Nullability is the union of the two — a CometVector that happens to hold no nulls in this + * Nullability is the union of the two: a CometVector that happens to hold no nulls in this * batch can still be nullable per Spark's contract (the next batch may have one), and a column * whose actual buffer carries validity bits must stay nullable even if Spark thought otherwise. * Taking only `raw.isNullable` here would advertise non-nullable when the next batch does carry From cc7c5bedf9665f6671038cdeac273c09ca5c7fa9 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 10:09:16 -0400 Subject: [PATCH 20/39] handle arrow type mismatches on child stream in native shuffle. 
--- native/core/src/execution/operators/mod.rs | 2 + .../src/execution/operators/schema_align.rs | 266 ++++++++++++++++++ native/core/src/execution/planner.rs | 27 +- native/proto/src/proto/operator.proto | 6 + .../shuffle/CometNativeShuffleWriter.scala | 9 + 5 files changed, 307 insertions(+), 3 deletions(-) create mode 100644 native/core/src/execution/operators/schema_align.rs diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 4b2c06575d..b679d35b19 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -32,6 +32,8 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; +mod schema_align; mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; +pub use schema_align::SchemaAlignExec; pub use shuffle_scan::ShuffleScanExec; diff --git a/native/core/src/execution/operators/schema_align.rs b/native/core/src/execution/operators/schema_align.rs new file mode 100644 index 0000000000..3d207e0202 --- /dev/null +++ b/native/core/src/execution/operators/schema_align.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
`SchemaAlignExec` reshapes its child's output so the per-column Arrow type and field-level +//! nullability match what Spark catalyst declared. Used between an inlined native subtree and +//! `ShuffleWriterExec` when the FFI deep-copy + `ScanExec` cast in `build_record_batch` are both +//! gone, so DataFusion / `datafusion-spark` return-type drift would otherwise be written into +//! shuffle blocks. See <https://github.com/apache/datafusion-comet/issues/4515> for the running +//! list of mismatched functions. + +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use datafusion::common::DataFusionError; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::{ + execution::TaskContext, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, + }, +}; +use futures::{Stream, StreamExt}; +use std::{ + any::Any, + collections::HashSet, + pin::Pin, + sync::{Arc, Mutex, OnceLock}, + task::{Context, Poll}, +}; + +/// Process-wide set of `(column, actual, expected)` signatures we have already warned about. +/// Each schema drift produces the same warning on every partition of every query that runs +/// the offending expression; deduping here keeps logs readable while still surfacing each +/// distinct mismatch once. +fn warn_dedup() -> &'static Mutex> { + static SET: OnceLock>> = OnceLock::new(); + SET.get_or_init(|| Mutex::new(HashSet::new())) +} + +#[derive(Debug)] +pub struct SchemaAlignExec { + child: Arc, + target_schema: SchemaRef, + column_actions: Arc>, + cache: Arc, +} + +#[derive(Debug, Clone)] +enum ColumnAction { + /// Pass the input column through unchanged. Any nullability/metadata difference is + /// absorbed when the batch is re-stamped via `RecordBatch::try_new_with_options`.
+ Passthrough, + /// Cast the input column to the target data_type. + Cast, +} + +impl SchemaAlignExec { + /// Build a SchemaAlignExec that aligns `child`'s output to `expected`. Returns + /// `Ok(child)` unchanged when no per-column reshape is needed; otherwise wraps `child` + /// in a SchemaAlignExec whose target schema preserves `expected`'s data_type and metadata + /// but widens nullability to `actual.nullable || expected.nullable` (matching the + /// reconciliation rule used at the FFI boundary on `main`). + pub fn try_new_or_passthrough( + child: Arc, + expected: &SchemaRef, + ) -> Result, DataFusionError> { + let actual = child.schema(); + if actual.fields().len() != expected.fields().len() { + return Err(DataFusionError::Plan(format!( + "SchemaAlignExec: expected {} fields, child produces {}", + expected.fields().len(), + actual.fields().len() + ))); + } + let mut needs_alignment = false; + let mut actions = Vec::with_capacity(actual.fields().len()); + let mut target_fields = Vec::with_capacity(actual.fields().len()); + for (idx, (actual_field, expected_field)) in actual + .fields() + .iter() + .zip(expected.fields().iter()) + .enumerate() + { + let action = if actual_field.data_type() == expected_field.data_type() { + ColumnAction::Passthrough + } else { + let signature = format!( + "{}|{:?}|{:?}", + expected_field.name(), + actual_field.data_type(), + expected_field.data_type() + ); + if warn_dedup().lock().unwrap().insert(signature) { + log::warn!( + "ShuffleWriter input schema mismatch on col[{idx}] '{}': child produced \ + {:?}, catalyst declared {:?}. 
Inserting a cast; please file the upstream \ + function bug at https://github.com/apache/datafusion-comet/issues/4515.", + expected_field.name(), + actual_field.data_type(), + expected_field.data_type() + ); + } + ColumnAction::Cast + }; + let target_nullable = actual_field.is_nullable() || expected_field.is_nullable(); + let field_changed = !matches!(action, ColumnAction::Passthrough) + || target_nullable != actual_field.is_nullable() + || expected_field.metadata() != actual_field.metadata() + || expected_field.name() != actual_field.name(); + if field_changed { + needs_alignment = true; + } + target_fields.push( + Field::new( + expected_field.name(), + expected_field.data_type().clone(), + target_nullable, + ) + .with_metadata(expected_field.metadata().clone()), + ); + actions.push(action); + } + if !needs_alignment { + return Ok(child); + } + let target_schema: SchemaRef = Arc::new(Schema::new(target_fields)); + let cache = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&target_schema)), + child.output_partitioning().clone(), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(Arc::new(Self { + child, + target_schema, + column_actions: Arc::new(actions), + cache, + })) + } +} + +impl DisplayAs for SchemaAlignExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CometSchemaAlignExec") + } + DisplayFormatType::TreeRender => unimplemented!(), + } + } +} + +impl ExecutionPlan for SchemaAlignExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.target_schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::common::Result> { + // Rebuild PlanProperties from the new child since `output_partitioning` may differ. 
+ let new_child = Arc::clone(&children[0]); + let cache = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&self.target_schema)), + new_child.output_partitioning().clone(), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(Arc::new(Self { + child: new_child, + target_schema: Arc::clone(&self.target_schema), + column_actions: Arc::clone(&self.column_actions), + cache, + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion::common::Result { + let child_stream = self.child.execute(partition, context)?; + Ok(Box::pin(SchemaAlignStream { + child_stream, + target_schema: Arc::clone(&self.target_schema), + column_actions: Arc::clone(&self.column_actions), + })) + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn name(&self) -> &str { + "CometSchemaAlignExec" + } +} + +struct SchemaAlignStream { + child_stream: SendableRecordBatchStream, + target_schema: SchemaRef, + column_actions: Arc>, +} + +impl SchemaAlignStream { + fn align(&self, batch: RecordBatch) -> Result { + let mut columns: Vec = Vec::with_capacity(batch.num_columns()); + for (idx, action) in self.column_actions.iter().enumerate() { + let column = batch.column(idx); + let aligned = match action { + ColumnAction::Passthrough => Arc::clone(column), + ColumnAction::Cast => cast_with_options( + column, + self.target_schema.field(idx).data_type(), + &CastOptions::default(), + )?, + }; + columns.push(aligned); + } + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(Arc::clone(&self.target_schema), columns, &options) + .map_err(DataFusionError::from) + } +} + +impl Stream for SchemaAlignStream { + type Item = datafusion::common::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.child_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(self.align(batch))), + other => other, + } + } +} + +impl 
RecordBatchStream for SchemaAlignStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.target_schema) + } +} diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c6160bddd4..dff8275f84 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -26,7 +26,9 @@ use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::list_positions::ListPositionsExpr, expressions::subquery::Subquery, - operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, + operators::{ + ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, SchemaAlignExec, ShuffleScanExec, + }, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, @@ -1515,9 +1517,14 @@ impl PhysicalPlanner { let (scans, shuffle_scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let writer_input = align_shuffle_writer_input( + Arc::clone(&child.native_plan), + &writer.expected_output_schema, + )?; + let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), - child.native_plan.schema(), + writer_input.schema(), )?; let codec = match writer.codec.try_into() { @@ -1535,7 +1542,7 @@ impl PhysicalPlanner { let write_buffer_size = writer.write_buffer_size as usize; let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( - Arc::clone(&child.native_plan), + writer_input, partitioning, codec, writer.output_data_file.clone(), @@ -3124,6 +3131,20 @@ fn convert_spark_types_to_arrow_schema( arrow_schema } +/// Wrap `child` in a `SchemaAlignExec` when its output drifts from what Spark catalyst +/// declared. See <https://github.com/apache/datafusion-comet/issues/4515>.
+fn align_shuffle_writer_input( + child: Arc, + expected_proto: &[spark_operator::SparkStructField], +) -> Result, ExecutionError> { + if expected_proto.is_empty() { + return Ok(child); + } + let expected = convert_spark_types_to_arrow_schema(expected_proto); + SchemaAlignExec::try_new_or_passthrough(child, &expected) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) +} + /// Converts a protobuf PartitionValue to an iceberg Literal. /// fn partition_value_to_literal( diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 7f50aa928c..da498087b9 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -327,6 +327,12 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // Spark-declared output schema of the writer's child. When the child is an inlined native + // subtree, the native planner casts the child's actual output to this schema before + // serializing to shuffle blocks, since there is no FFI boundary or ScanExec between them + // to absorb DataFusion-vs-Spark type drift. Empty when the child is a placeholder Scan; + // that path already has a cast point upstream. 
+ repeated SparkStructField expected_output_schema = 9; } message ParquetWriter { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 21395c58fd..9fe70f16b5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -34,10 +34,12 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometScalarSubquery, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructField import org.apache.comet.{CometConf, CometExecIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} +import org.apache.comet.serde.operator.schema2Proto /** * Drives the native shuffle write in a single [[CometExecIterator]] per partition. The plan is @@ -329,6 +331,13 @@ class CometNativeShuffleWriter[K, V]( shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) + // Used by the native planner to cast the inlined child's output when DataFusion's + // declared return type drifts from Spark catalyst (see comet#4515). 
+ val expectedFields = outputAttributes + .map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)) + .toArray + schema2Proto(expectedFields).foreach(shuffleWriterBuilder.addExpectedOutputSchema) + OperatorOuterClass.Operator .newBuilder() .setShuffleWriter(shuffleWriterBuilder) From c76a26399357576a3f699dde6eb767c58ea7018f Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 14:08:04 -0400 Subject: [PATCH 21/39] stash --- native/core/src/execution/planner.rs | 30 ++++- native/proto/src/proto/operator.proto | 6 + .../org/apache/comet/CometBatchIterator.java | 16 +++ .../org/apache/comet/CometExecIterator.scala | 21 ++++ .../apache/comet/rules/CometExecRule.scala | 24 ++++ .../apache/comet/rules/CometScanRule.scala | 11 ++ .../comet/serde/operator/CometSink.scala | 74 ++++++++++++ .../org/apache/comet/vector/NativeUtil.scala | 22 ++++ .../shuffle/CometShuffleExchangeExec.scala | 33 ++++++ .../shuffle/CometShuffledRowRDD.scala | 29 ++++- .../apache/spark/sql/comet/operators.scala | 48 +++++++- .../comet/exec/CometAggregateSuite.scala | 93 +++++++++++++++ .../apache/comet/exec/CometExecSuite.scala | 111 ++++++++++++++++++ 13 files changed, 515 insertions(+), 3 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index dff8275f84..bf90eb20cc 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1034,6 +1034,14 @@ impl PhysicalPlanner { )) } OpStruct::HashAgg(agg) => { + // [#4515 instrumentation] Every HashAgg construction with proto sizes. 
+ log::warn!( + "[#4515] OpStruct::HashAgg grouping_exprs.len={} agg_exprs.len={} result_exprs.len={} mode={}", + agg.grouping_exprs.len(), + agg.agg_exprs.len(), + agg.result_exprs.len(), + agg.mode + ); assert_eq!(children.len(), 1); let (scans, shuffle_scans, child) = self.create_plan(&children[0], inputs, partition_count)?; @@ -1173,7 +1181,17 @@ impl PhysicalPlanner { }) .collect(); - if agg.result_exprs.is_empty() { + if !agg.apply_result_projection { + // [#4515 instrumentation] Confirm whether a native HashAggregate is being + // built with empty result_exprs (catalyst-pruned EXISTS / count(*) subquery) + // and what its natural schema would be. + log::warn!( + "[#4515] HashAgg apply_result_projection=false: emitting aggregate natural schema={:?} grouping_exprs.len={} agg_exprs.len={} result_exprs.len={}", + aggregate.schema(), + agg.grouping_exprs.len(), + agg.agg_exprs.len(), + agg.result_exprs.len() + ); Ok(( scans, shuffle_scans, @@ -1443,6 +1461,16 @@ impl PhysicalPlanner { OpStruct::Scan(scan) => { let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); + // [#4515 instrumentation] Log every JVM-bridge Scan's declared column count. + // A 0-column scan paired with a JVM iterator producing batches with columns is + // the AIOOBE-on-exportBatch shape; a non-empty list confirms the inverse. + log::info!( + "[#4515] ScanExec source='{}' declared {} cols: {:?}", + scan.source, + data_types.len(), + data_types + ); + // If it is not test execution context for unit test, we should have at least one // input source if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index da498087b9..921e5bbf35 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -303,6 +303,12 @@ message HashAggregate { // Offset in the child's output where aggregate buffer attributes start. 
// Used by PartialMerge to locate state fields in the input. int32 initial_input_buffer_offset = 7; + // Whether the native planner should wrap the aggregate in a ProjectionExec using + // `result_exprs`. Disambiguates `result_exprs` empty-because-absent (no projection + // needed) from empty-because-catalyst-pruned-to-zero-cols (project to nothing). + // Without this bit, both wire as `repeated` length 0 and were collapsed to "no + // projection," leaking grouping keys when Spark intent was 0 columns. + bool apply_result_projection = 8; } message Limit { diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java index 9b48a47c57..1392f028c0 100644 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java @@ -38,6 +38,8 @@ public class CometBatchIterator { private final NativeUtil nativeUtil; private ColumnarBatch previousBatch = null; private ColumnarBatch currentBatch = null; + // [#4515 instrumentation] gate first-batch log per instance + private boolean loggedFirstBatch = false; CometBatchIterator(Iterator input, NativeUtil nativeUtil) { this.input = input; @@ -79,6 +81,20 @@ public int next(long[] arrayAddrs, long[] schemaAddrs) { return -1; } + // [#4515 instrumentation] Log first-batch shape per CometBatchIterator instance. 
+ if (!loggedFirstBatch) { + loggedFirstBatch = true; + org.slf4j.LoggerFactory.getLogger("[#4515]") + .warn( + "CometBatchIterator.next first batch: numCols={} numRows={} arrayAddrs.length={} schemaAddrs.length={} inputCls={} this={}", + currentBatch.numCols(), + currentBatch.numRows(), + arrayAddrs.length, + schemaAddrs.length, + input.getClass().getName(), + System.identityHashCode(this)); + } + // export the batch using the Arrow C Data Interface int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 6140eca553..98a9bc0684 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -92,6 +92,27 @@ class CometExecIterator( val conf = SparkEnv.get.conf val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs + // [#4515 instrumentation] Dump proto plan + per-input iterator class so we can correlate + // a 0-col Scan in the proto with the JVM iterator that will feed it. 
+ { + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + val parsed = + scala.util + .Try( + org.apache.comet.serde.OperatorOuterClass.Operator + .parseFrom(protobufQueryPlan) + .toString) + .toOption + .getOrElse("") + val itersDesc = inputIterators.zipWithIndex + .map { case (it, idx) => s" inputIterators[$idx] cls=${it.getClass.getName}" } + .mkString("\n") + log.warn( + s"CometExecIterator constructing plan id=$id partition=$partitionIndex " + + s"numParts=$numParts numOutputCols=$numOutputCols\n$itersDesc\n" + + s" proto plan:\n$parsed") + } + // serialize Comet related Spark configs in protobuf format val protobufSparkConfigs = CometExecIterator.serializeCometSQLConfs() diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index aeb7db40ad..6eb314bbb6 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -166,6 +166,16 @@ case class CometExecRule(session: SparkSession) val reverted = s.originalPlan.withNewChildren(Seq(s.child)).asInstanceOf[ShuffleExchangeExec] reverted.setTagValue(CometExecRule.SKIP_COMET_SHUFFLE_TAG, ()) + // [#4515 instrumentation] Log every revert: this is a primary place where + // vanilla ShuffleExchangeExec instances enter the plan post-Comet-rewrite. 
+ org.slf4j.LoggerFactory + .getLogger("[#4515]") + .warn( + s"revertRedundantColumnarShuffle revert: parentAgg=${op.getClass.getSimpleName} " + + s"(output=${op.output}) cometShuffle.output=${s.output} " + + s"reverted.output=${reverted.output} reverted.identityHash=${System + .identityHashCode(reverted)}\n" + + s" reverted tree:\n${reverted.treeString(verbose = true, addSuffix = false)}") logInfo( "Reverting Comet columnar shuffle to Spark shuffle between " + s"${op.getClass.getSimpleName} and ${s.child.getClass.getSimpleName} " + @@ -546,6 +556,20 @@ case class CometExecRule(session: SparkSession) |${sideBySide(plan.treeString, newPlan.treeString).mkString("\n")} |""".stripMargin) } + // [#4515 instrumentation] Dump the plan in/out of CometExecRule to correlate with the + // 0-col Scan synthesized later. Logs the operator-class diff so we can see what + // CometExecRule did or didn't replace, especially around ShuffleExchangeExec wrappers + // inside subqueries. + if (!newPlan.fastEquals(plan)) { + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + log.warn( + "CometExecRule rewrote plan:\n IN classes: " + + plan.collect { case p => p.getClass.getSimpleName }.mkString(",") + + "\n OUT classes: " + + newPlan.collect { case p => p.getClass.getSimpleName }.mkString(",") + + s"\n IN tree:\n${plan.treeString(verbose = true, addSuffix = false)}" + + s"\n OUT tree:\n${newPlan.treeString(verbose = true, addSuffix = false)}") + } newPlan } diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 7601fa1c6b..1f7f509db0 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -68,6 +68,17 @@ case class CometScanRule(session: SparkSession) |${sideBySide(plan.treeString, newPlan.treeString).mkString("\n")} |""".stripMargin) } + // [#4515 instrumentation] + if (!newPlan.fastEquals(plan)) 
{ + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + log.warn( + "CometScanRule rewrote plan:\n IN classes: " + + plan.collect { case p => p.getClass.getSimpleName }.mkString(",") + + "\n OUT classes: " + + newPlan.collect { case p => p.getClass.getSimpleName }.mkString(",") + + s"\n IN tree:\n${plan.treeString(verbose = true, addSuffix = false)}" + + s"\n OUT tree:\n${newPlan.treeString(verbose = true, addSuffix = false)}") + } newPlan } diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index 845803d133..8a1c576a2d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -49,6 +49,12 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { op: T, builder: Operator.Builder, childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // [#4515 instrumentation] + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + log.warn( + s"CometSink[${this.getClass.getSimpleName}].convert op=${op.getClass.getName} " + + s"simpleString='${op.simpleStringWithNodeId()}' output=${op.output} " + + s"output.size=${op.output.size}") val supportedTypes = op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) @@ -72,6 +78,53 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { } if (scanTypes.length == op.output.length) { + // [#4515 instrumentation] Log when we synthesize a Scan with zero declared columns. + // The runtime JVM iterator may still produce columns (subquery output shrunk by + // catalyst before serialization while the underlying RDD reflects the pre-shrink shape), + // tripping the column-count guard in NativeUtil.exportBatch. 
+ if (scanTypes.isEmpty) { + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + // scalastyle:off line.size.limit + val childInfo = op.children.zipWithIndex + .map { case (c, i) => + val canonOut = scala.util + .Try(c.canonicalized.output) + .toOption + .map(_.toString) + .getOrElse("") + s" child[$i] cls=${c.getClass.getName} simpleString='${c.simpleString( + 80)}' output=${c.output} outputSize=${c.output.size} identityHash=${System + .identityHashCode(c)} canonicalized.output=$canonOut" + } + .mkString("\n") + val opCanonOut = scala.util + .Try(op.canonicalized.output) + .toOption + .map(_.toString) + .getOrElse("") + val subqueryInfo = scala.util + .Try(op.subqueries.map(s => s"${s.getClass.getName}(output=${s.output}, prepared=?)")) + .toOption + .getOrElse(Nil) + .mkString("[", ", ", "]") + val callerStack = + new RuntimeException("[#4515] CometSink 0-col Scan caller").getStackTrace + .take(20) + .map(f => s" at ${f}") + .mkString("\n") + log.warn(s"CometSink synthesizing 0-col Scan for op=${op.getClass.getName}\n" + + s" simpleString='${op.simpleStringWithNodeId()}'\n" + + s" op.output=${op.output} op.outputSet=${op.outputSet} op.references=${op.references}\n" + + s" op.canonicalized.output=$opCanonOut\n" + + s" op.subqueries=$subqueryInfo\n" + + s" op identityHash=${System.identityHashCode(op)}\n" + + s" children classes=${op.children.map(_.getClass.getSimpleName).mkString("[", ",", "]")}\n" + + childInfo + "\n" + + s" caller stack:\n$callerStack\n" + + s" op tree:\n${op.treeString(verbose = true, addSuffix = false)}") + // scalastyle:on line.size.limit + } + scanBuilder.addAllFields(scanTypes.asJava) // Sink operators don't have children @@ -94,6 +147,27 @@ object CometExchangeSink extends CometSink[SparkPlan] { op: SparkPlan, builder: Operator.Builder, childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { + // [#4515 instrumentation] + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + val isVanillaSparkExchange = + 
op.getClass.getName == "org.apache.spark.sql.execution.exchange.ShuffleExchangeExec" + log.warn( + s"CometExchangeSink.convert op=${op.getClass.getName} " + + s"simpleString='${op.simpleStringWithNodeId()}' output=${op.output} " + + s"useShuffleScan=${shouldUseShuffleScan(op)} " + + s"children=${op.children.map(_.getClass.getSimpleName).mkString("[", ",", "]")}") + if (isVanillaSparkExchange) { + val callerStack = + new RuntimeException("[#4515] vanilla ShuffleExchangeExec caller").getStackTrace + .take(20) + .map(f => s" at ${f}") + .mkString("\n") + log.warn( + " vanilla ShuffleExchangeExec being processed by CometExchangeSink:\n" + + s" output=${op.output}\n" + + s" caller stack:\n$callerStack\n" + + s" op tree:\n${op.treeString(verbose = true, addSuffix = false)}") + } if (shouldUseShuffleScan(op)) { convertToShuffleScan(op, builder) } else { diff --git a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 4f027cd9e7..388f5114e6 100644 --- a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -114,6 +114,28 @@ class NativeUtil { batch: ColumnarBatch): Int = { val numRows = mutable.ArrayBuffer.empty[Int] + if (arrayAddrs.length != batch.numCols() || schemaAddrs.length != batch.numCols()) { + val schemaSummary = (0 until batch.numCols()) + .map { i => + val v = batch.column(i) match { + case cv: CometVector => cv.getValueVector + case _ => null + } + if (v != null) s"col[$i]: ${v.getField}" + else s"col[$i]: ${batch.column(i).getClass.getName}" + } + .mkString("; ") + val taskAttempt = Option(org.apache.spark.TaskContext.get()) + .map(c => s"stage=${c.stageId} task=${c.taskAttemptId} partition=${c.partitionId}") + .getOrElse("no-task") + throw new SparkException( + "CometBatchIterator column-count mismatch [#4515 instrumentation]: " + + s"native expected arrayAddrs=${arrayAddrs.length}, 
schemaAddrs=${schemaAddrs.length}; " + + s"JVM iterator produced batch.numCols=${batch.numCols()} ($taskAttempt). " + + s"Batch schema: $schemaSummary", + new RuntimeException("placeholder for exportBatch column-count mismatch")) + } + (0 until batch.numCols()).foreach { index => batch.column(index) match { case a: CometVector => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 1470d637d9..ba152c2f13 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -243,6 +243,23 @@ case class CometShuffleExchangeExec( * Comet returns RDD[ColumnarBatch] for columnar execution. */ protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { + // [#4515 instrumentation] Track every doExecuteColumnar call: which CometShuffleExchange + // instance, its output, and the call site. Helps confirm whether this PR's changes + // cause an extra EnsureRequirements-inserted vanilla Exchange to wrap us, and what + // RDD is being plumbed where. + { + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + val callerStack = new RuntimeException( + "[#4515] CometShuffleExchangeExec.doExecuteColumnar caller").getStackTrace + .take(15) + .map(f => s" at ${f}") + .mkString("\n") + log.warn( + s"CometShuffleExchangeExec.doExecuteColumnar this=${System.identityHashCode(this)} " + + s"shuffleType=$shuffleType outputPartitioning=$outputPartitioning " + + s"output=$output\n" + + s" caller stack:\n$callerStack") + } // Returns the same CometShuffledBatchRDD if this plan is used by multiple plans. 
if (cachedShuffleRDD == null) { cachedShuffleRDD = new CometShuffledBatchRDD(shuffleDependency, readMetrics) @@ -687,6 +704,22 @@ object CometShuffleExchangeExec serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + // [#4515 instrumentation] Track every placeholder-Scan ShuffleDependency we build, with + // outputAttributes (drives Scan declared schema) and call site. + { + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + val callerStack = + new RuntimeException("[#4515] prepareShuffleDependency caller").getStackTrace + .take(15) + .map(f => s" at ${f}") + .mkString("\n") + log.warn( + s"prepareShuffleDependency outputAttributes=$outputAttributes " + + s"outputAttributes.size=${outputAttributes.size} " + + s"outputPartitioning=$outputPartitioning rdd.numPartitions=${rdd.getNumPartitions}\n" + + s" caller stack:\n$callerStack") + } + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") val scanTypes = outputAttributes.flatten { attr => QueryPlanSerde.serializeDataType(attr.dataType) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index 7604910b06..e0e5c45b2a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -162,7 +162,34 @@ class CometShuffledBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val reader = createReader(split, context) // TODO: Reads IPC by native code - reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) + val raw = reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) + // [#4515 instrumentation] Peek the first decoded batch to confirm wire schema 
vs caller + // expectations. Wraps so we don't consume the iterator. + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + new Iterator[ColumnarBatch] { + private var logged = false + override def hasNext: Boolean = raw.hasNext + override def next(): ColumnarBatch = { + val b = raw.next() + if (!logged) { + logged = true + val schemaSummary = (0 until b.numCols()) + .map { i => + val v = b.column(i) match { + case cv: org.apache.comet.vector.CometVector => cv.getValueVector + case _ => null + } + if (v != null) s"col[$i]: ${v.getField}" + else s"col[$i]: ${b.column(i).getClass.getName}" + } + .mkString("; ") + log.warn(s"CometShuffledBatchRDD.compute first decoded batch: numCols=${b.numCols()} " + + s"numRows=${b.numRows()} stage=${context.stageId()} task=${context.taskAttemptId()} " + + s"partition=${split.index} schema=[$schemaSummary]") + } + b + } + } } override def clearDependencies(): Unit = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index fa7aa6f1b2..310767489d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -422,7 +422,15 @@ abstract class CometNativeExec extends CometExec { /** The Comet native operator */ def nativeOp: Operator - override protected def doPrepare(): Unit = prepareSubqueries(this) + override protected def doPrepare(): Unit = { + // [#4515 instrumentation] Track when subqueries are prepared for this CometNativeExec. 
+ val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + log.warn( + s"CometNativeExec.doPrepare this=${System.identityHashCode(this)} " + + s"cls=${this.getClass.getName} " + + s"originalPlan.cls=${originalPlan.getClass.getName}") + prepareSubqueries(this) + } override lazy val metrics: Map[String, SQLMetric] = CometMetricNode.baselineMetrics(sparkContext) @@ -563,6 +571,14 @@ abstract class CometNativeExec extends CometExec { // same partition number. But for Comet, we need to zip them so we need to adjust the // partition number of Broadcast RDDs to make sure they have the same partition number. sparkPlans.zipWithIndex.foreach { case (plan, idx) => + // [#4515 instrumentation] Log every JVM-side input plan we wire to native, so we can + // correlate the Scan's declared schema with the runtime plan whose RDD feeds it. + val log = org.slf4j.LoggerFactory.getLogger("[#4515]") + log.warn( + s"buildNativeContext binding input[$idx] cls=${plan.getClass.getName} " + + s"simpleString='${plan.simpleStringWithNodeId()}' output=${plan.output} " + + s"output.size=${plan.output.size} identityHash=${System.identityHashCode(plan)}\n" + + s" subtree:\n${plan.treeString(verbose = true, addSuffix = false)}") plan match { case c: CometBroadcastExchangeExec => inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) @@ -1516,6 +1532,25 @@ trait CometBaseAggregate { // If the aggregateExpressions is empty, we only want to build groupingExpressions, // and skip processing of aggregateExpressions. if (aggregateExpressions.isEmpty) { + // [#4515 instrumentation] Track HashAgg serializations with empty resultExpressions. + // Catalyst prunes resultExpressions for EXISTS / row-existence-only subqueries; the + // native side currently interprets empty result_exprs as "use aggregate natural + // schema", which leaks grouping keys into the output. 
+ org.slf4j.LoggerFactory + .getLogger("[#4515]") + .warn( + "HashAgg empty-aggExprs branch: " + + s"groupingExprs=${groupingExpressions} " + + s"resultExpressions=${resultExpressions} " + + s"resultExpressions.size=${resultExpressions.size} " + + s"aggregate.output=${aggregate.output} " + + s"aggregate.output.size=${aggregate.output.size} " + + s"modes(from aggExprs)=${modes} " + + s"sparkFinalMode=$sparkFinalMode " + + s"requiredChildDistribution=${aggregate.requiredChildDistribution} " + + s"isProjectionToEmpty=${resultExpressions.isEmpty && aggregate.output.isEmpty} " + + s"naturalEqualsIntent=${resultExpressions.map(_.toAttribute) == groupingExpressions + .map(_.toAttribute)}") val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes @@ -1528,6 +1563,14 @@ trait CometBaseAggregate { return None } hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + // Force native to apply the projection when Spark's expected output (which may + // be empty for catalyst-pruned EXISTS / row-existence-only subqueries) differs + // from the aggregate's natural grouping output. Without this, an empty proto + // result_exprs is indistinguishable from "no projection needed" and the natural + // grouping keys leak through. + val naturalEqualsIntent = + resultExpressions.map(_.toAttribute) == groupingExpressions.map(_.toAttribute) + hashAggBuilder.setApplyResultProjection(!naturalEqualsIntent) Some(builder.setHashAgg(hashAggBuilder).build()) } else { // Validate mode combinations. We support: @@ -1608,6 +1651,9 @@ trait CometBaseAggregate { return None } hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) + // Final aggs always need the projection (matches the existing condition under + // which we serialize result_exprs at all). See comet#4515. 
+ hashAggBuilder.setApplyResultProjection(true) } hashAggBuilder.setModeValue(mode.getNumber) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index cd0beb56cc..f098040878 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -2109,4 +2109,97 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // Regression for comet#4515: a HashAggregateExec whose `resultExpressions` (and therefore + // `output`) catalyst has pruned to empty must still produce 0-col batches at runtime. + // Catalyst prunes `resultExpressions=[]` for plans where the aggregate's column values are + // unused downstream - classically EXISTS subqueries that get rewritten into a literal-`1` + // wrapper. Before the fix, the native HashAggregate emitted its natural output (the + // grouping keys) regardless of the pruned JVM `output`, so any boundary that derived a + // schema from `output` (e.g. a wrapping vanilla Spark `ShuffleExchangeExec.output = + // child.output = []`, or a JVM-bridge Scan synthesized from the same `output`) declared + // 0 columns while the runtime RDD produced 1. The mismatch tripped + // `NativeUtil.exportBatch` with an ArrayIndexOutOfBoundsException on a length-0 + // schemaAddrs[]. + // + // The bug needs a specific catalyst optimizer state: `HashAggregateExec. + // resultExpressions.isEmpty`. Whether the optimizer reaches that state from a given SQL + // depends on Spark version, scan source (LocalTableScan vs Parquet/native), AQE state, + // and which Comet rules already fired - we observed it in the SQL-tests harness running + // `subquery/exists-subquery/exists-orderby-limit.sql` on Spark 4.0.2 with parquet-backed + // temp views, but not via `Dataset.collect` over `LocalTableScan`-backed temp views + // under `CometTestBase`. 
So the test below writes parquet (matching the harness's scan + // shape), tries several known triggers, runs whichever (if any) produces the bug shape + // under `checkSparkAnswer`, and skips cleanly otherwise. The upstream SQL-tests run + // remains the primary safety net for the harness-only path. + test("HashAggregate with catalyst-pruned resultExpressions returns 0-col output (#4515)") { + withTempDir { dir => + withTempView("emp", "dept") { + // Write parquet so Comet's native scan path (vs LocalTableScan) is the source - + // matches the SQL-tests harness setup that surfaced the bug. + val empPath = new java.io.File(dir, "emp").getAbsolutePath + val deptPath = new java.io.File(dir, "dept").getAbsolutePath + + spark + .sql("""SELECT * FROM VALUES + | (100, 'emp 1', 100.0D, 10), + | (200, 'emp 2', 200.0D, 10), + | (300, 'emp 3', 300.0D, 20), + | (400, 'emp 4', 400.0D, 30), + | (500, 'emp 5', 400.0D, NULL), + | (700, 'emp 7', 400.0D, 100), + | (800, 'emp 8', 150.0D, 70) + |AS t(id, emp_name, salary, dept_id)""".stripMargin) + .write + .parquet(empPath) + spark + .sql("""SELECT * FROM VALUES + | (10, 'CA'), (20, 'NY'), (30, 'TX'), + | (40, 'OR'), (50, 'NJ'), (70, 'FL') + |AS t(dept_id, state)""".stripMargin) + .write + .parquet(deptPath) + + spark.read.parquet(empPath).createOrReplaceTempView("emp") + spark.read.parquet(deptPath).createOrReplaceTempView("dept") + + val candidates = Seq( + // The original failing SQL from the harness - EXISTS with grouped agg + LIMIT/OFFSET. + """SELECT * FROM emp + |WHERE EXISTS ( + | SELECT max(dept.dept_id) FROM dept GROUP BY state LIMIT 1 OFFSET 2)""".stripMargin, + // Inline view + outer constant: ColumnPruning may strip the inner agg's output. + """SELECT 1 FROM ( + | SELECT max(dept_id) FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin, + // Scalar subquery returning a constant. 
+ """SELECT (SELECT 1 FROM dept GROUP BY state LIMIT 1 OFFSET 2)""".stripMargin, + // count(*) over a derived table: outer doesn't reference inner cols. + """SELECT count(*) FROM ( + | SELECT max(dept_id) AS m FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin) + + // Find a candidate whose plan has a HashAggregateExec (Spark or Comet) with empty + // resultExpressions. collectWithSubqueries traverses Subquery nodes too. + val triggering = candidates.find { sql => + val plan = spark.sql(sql).queryExecution.executedPlan + collectWithSubqueries(plan) { + case a: org.apache.spark.sql.execution.aggregate.HashAggregateExec + if a.resultExpressions.isEmpty => + a + case a: CometHashAggregateExec if a.resultExpressions.isEmpty => a + }.nonEmpty + } + + triggering match { + case Some(sql) => checkSparkAnswer(sql) + case None => + cancel( + "No candidate query produced a HashAggregateExec with empty resultExpressions " + + "in this environment. The catalyst-pruned shape that exercises #4515 only " + + "appears under specific optimizer/AQE state we couldn't reproduce here. The " + + "upstream SQL-tests run (subquery/exists-subquery/exists-orderby-limit.sql) " + + "covers this path.") + } + } + } + } + } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index 16601d056b..fded42d050 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3959,6 +3959,117 @@ class CometExecSuite extends CometTestBase { } } + // Repro for the Spark 3.5 SQL-tests failure on subquery/exists-subquery/exists-orderby-limit.sql + // (query #19) on branch opt_native_shuffle. Crashes with NativeUtil.exportBatch:132 AIOOBE + // on this branch but passes on main. The SQL file declares three CONFIG_DIM1 combos for + // codegen; the failure trace doesn't say which one fires it, so sweep all three. 
+ test("EXISTS subquery with GROUP BY + LIMIT + OFFSET") { + withTempView("emp", "dept") { + spark + .sql("""SELECT * FROM VALUES + | (100, 'emp 1', date '2005-01-01', 100.0D, 10), + | (200, 'emp 2', date '2003-01-01', 200.0D, 10), + | (300, 'emp 3', date '2002-01-01', 300.0D, 20), + | (400, 'emp 4', date '2005-01-01', 400.0D, 30), + | (500, 'emp 5', date '2001-01-01', 400.0D, NULL), + | (700, 'emp 7', date '2010-01-01', 400.0D, 100), + | (800, 'emp 8', date '2016-01-01', 150.0D, 70) + |AS t(id, emp_name, hiredate, salary, dept_id)""".stripMargin) + .createOrReplaceTempView("emp") + spark + .sql("""SELECT * FROM VALUES + | (10, 'dept 1', 'CA'), + | (20, 'dept 2', 'NY'), + | (30, 'dept 3', 'TX'), + | (40, 'dept 4 - unassigned', 'OR'), + | (50, 'dept 5 - unassigned', 'NJ'), + | (70, 'dept 7', 'FL') + |AS t(dept_id, dept_name, state)""".stripMargin) + .createOrReplaceTempView("dept") + + val configDims = Seq( + Map("spark.sql.codegen.wholeStage" -> "true"), + Map( + "spark.sql.codegen.wholeStage" -> "false", + "spark.sql.codegen.factoryMode" -> "CODEGEN_ONLY"), + Map( + "spark.sql.codegen.wholeStage" -> "false", + "spark.sql.codegen.factoryMode" -> "NO_CODEGEN")) + + // Mirror the SQL test order: queries #1 through #16 from the file run before #17 (the + // one whose CI failure we're chasing). Query #11's subquery is identical except for + // the OFFSET; #13 and #15 share similar shapes. Subquery materialization / AQE plan + // cache state from running them first may alter query #17's executed shape. 
+ val priorQueries = Seq( + // #1 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate""", + // #2 + """SELECT id, hiredate FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate DESC""", + // #3 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 1) ORDER BY hiredate""", + // #4 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 0) ORDER BY hiredate""", + // #5 + """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate""", + // #6 + """SELECT emp_name FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY state ORDER BY state)""", + // #7 + """SELECT count(*) FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY dept_id ORDER BY dept_id)""", + // #8 + """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 1) ORDER BY hiredate""", + // #9 + """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 0) ORDER BY hiredate""", + // #10 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 10 LIMIT 1)""", + // #11 - same subquery as #17 minus the OFFSET + """SELECT * FROM emp WHERE EXISTS (SELECT max(dept.dept_id) FROM dept GROUP BY state LIMIT 1)""", + // #12 + """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 100 LIMIT 1)""", + // #13 + """SELECT * FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) FROM dept WHERE dept.dept_id > 100 GROUP BY state LIMIT 1)""", + // #14 + """SELECT emp_name FROM emp 
WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY state ORDER BY state LIMIT 2 OFFSET 1)""", + // #15 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 10 LIMIT 1 OFFSET 2)""", + // #16 + """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > emp.dept_id LIMIT 1)""") + + for (dim <- configDims) { + // SQL-tests harness (dev/diffs/3.5.8.diff, 4.1.1.diff) sets only: + // spark.comet.enabled, spark.comet.exec.enabled, spark.comet.exec.shuffle.enabled, + // spark.comet.parquet.respectFilterPushdown, spark.shuffle.manager, + // spark.comet.memoryOverhead. + // It does NOT enable spark.comet.sparkToColumnar.enabled, but CometTestBase does. + // Force the harness shape so the reproducer matches the failing config. + val harnessConf = dim ++ Map(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "false") + withSQLConf(harnessConf.toSeq: _*) { + // scalastyle:off println + println(s"\n===== config dim: $harnessConf =====") + for ((q, idx) <- priorQueries.zipWithIndex) { + spark.sql(q).collect() + println(s"--- ran prior query #${idx + 1} ---") + } + val sql = """SELECT * + |FROM emp + |WHERE EXISTS (SELECT max(dept.dept_id) + | FROM dept + | GROUP BY state + | LIMIT 1 + | OFFSET 2)""".stripMargin + val df = spark.sql(sql) + println("--- query #17 initial executedPlan ---") + println(df.queryExecution.executedPlan) + val rows = df.collect() + println("--- query #17 final (post-AQE) executedPlan ---") + println(df.queryExecution.executedPlan) + println(s"--- ${rows.length} rows ---") + // scalastyle:on println + checkSparkAnswer(sql) + } + } + } + } + } case class BucketedTableTestSpec( From 1b55b97132f627b4e5257f44ca144f8d70c40494 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 14:21:14 -0400 Subject: [PATCH 22/39] refactor to handle on JVM side. 
--- native/core/src/execution/planner.rs | 63 ++-------- native/proto/src/proto/operator.proto | 14 +-- .../apache/spark/sql/comet/operators.scala | 113 ++++++++++++------ 3 files changed, 99 insertions(+), 91 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index bf90eb20cc..8abc19c1f5 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1036,10 +1036,9 @@ impl PhysicalPlanner { OpStruct::HashAgg(agg) => { // [#4515 instrumentation] Every HashAgg construction with proto sizes. log::warn!( - "[#4515] OpStruct::HashAgg grouping_exprs.len={} agg_exprs.len={} result_exprs.len={} mode={}", + "[#4515] OpStruct::HashAgg grouping_exprs.len={} agg_exprs.len={} mode={}", agg.grouping_exprs.len(), agg.agg_exprs.len(), - agg.result_exprs.len(), agg.mode ); assert_eq!(children.len(), 1); @@ -1171,55 +1170,19 @@ impl PhysicalPlanner { Arc::clone(&schema), )?, ); - let result_exprs: PhyExprResult = agg - .result_exprs - .iter() - .enumerate() - .map(|(idx, expr)| { - self.create_expr(expr, aggregate.schema()) - .map(|r| (r, format!("col_{idx}"))) - }) - .collect(); - if !agg.apply_result_projection { - // [#4515 instrumentation] Confirm whether a native HashAggregate is being - // built with empty result_exprs (catalyst-pruned EXISTS / count(*) subquery) - // and what its natural schema would be. - log::warn!( - "[#4515] HashAgg apply_result_projection=false: emitting aggregate natural schema={:?} grouping_exprs.len={} agg_exprs.len={} result_exprs.len={}", - aggregate.schema(), - agg.grouping_exprs.len(), - agg.agg_exprs.len(), - agg.result_exprs.len() - ); - Ok(( - scans, - shuffle_scans, - Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), - )) - } else { - // For final aggregation, DF's hash aggregate exec doesn't support Spark's - // aggregate result expressions like `COUNT(col) + 1`, but instead relying - // on additional `ProjectionExec` to handle the case. 
Therefore, here we'll - // add a projection node on top of the aggregate node. - // - // Note that `result_exprs` should only be set for final aggregation on the - // Spark side. - let projection = Arc::new(ProjectionExec::try_new( - result_exprs?, - Arc::clone(&aggregate), - )?); - Ok(( - scans, - shuffle_scans, - Arc::new(SparkPlan::new_with_additional( - spark_plan.plan_id, - projection, - vec![child], - vec![aggregate], - )), - )) - } + // The native HashAggregate emits its natural shape (group keys + agg + // results / state). Any post-aggregate projection Spark catalyst declares + // (`COUNT(col) + 1`, EXISTS-pruned-to-empty output, alias renames, etc.) is + // expressed as an explicit `OpStruct::Projection` op above the aggregate + // by the JVM serializer (see `CometBaseAggregate.doConvert`). Keeping that + // logic on the JVM side means only one place decides plan shape, and the + // native side stays a faithful executor. See comet#4515. + Ok(( + scans, + shuffle_scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), + )) } OpStruct::Limit(limit) => { assert_eq!(children.len(), 1); diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 921e5bbf35..aaff3cc4de 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,7 +294,13 @@ message Sort { message HashAggregate { repeated spark.spark_expression.Expr grouping_exprs = 1; repeated spark.spark_expression.AggExpr agg_exprs = 2; - repeated spark.spark_expression.Expr result_exprs = 3; + // Field 3 (`result_exprs`) and field 8 (`apply_result_projection`) were used to apply a + // post-aggregate projection inside the HashAggregate operator. The same effect is now + // expressed by emitting an explicit `Projection` op above the `HashAggregate` from the + // JVM serializer when needed (see `CometBaseAggregate.doConvert`). Reserved to avoid + // accidental reuse at incompatible semantics. 
+ reserved 3, 8; + reserved "result_exprs", "apply_result_projection"; AggregateMode mode = 5; // Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial). // When set, each entry corresponds to agg_exprs at the same index. @@ -303,12 +309,6 @@ message HashAggregate { // Offset in the child's output where aggregate buffer attributes start. // Used by PartialMerge to locate state fields in the input. int32 initial_input_buffer_offset = 7; - // Whether the native planner should wrap the aggregate in a ProjectionExec using - // `result_exprs`. Disambiguates `result_exprs` empty-because-absent (no projection - // needed) from empty-because-catalyst-pruned-to-zero-cols (project to nothing). - // Without this bit, both wire as `repeated` length 0 and were collapsed to "no - // projection," leaking grouping keys when Spark intent was 0 columns. - bool apply_result_projection = 8; } message Limit { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 310767489d..3240adcc01 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1553,25 +1553,48 @@ trait CometBaseAggregate { .map(_.toAttribute)}") val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None + // The native HashAggregate emits its natural shape (the grouping keys, since there + // are no aggregate functions). 
When Spark catalyst declares a different output - + // either column-renaming via aliases, or an entirely empty output for catalyst-pruned + // EXISTS / row-existence-only subqueries - we wrap the HashAggregate in an explicit + // Projection op so the native side reshapes accordingly. See comet#4515: an empty + // declared output paired with the natural grouping-key output crashed downstream + // boundaries that derived their schema from the declared output. + val naturalOutput = groupingExpressions.map(_.toAttribute) + val needsProjection = resultExpressions.map(_.toAttribute) != naturalOutput + if (needsProjection) { + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + // Build the inner HashAgg op carrying the original child operators from `builder`. + // Use a fresh outer builder for the Projection so it gets a single child (the + // HashAgg op), not the original children appended on top. Both ops share the same + // plan_id so the inner aggregate's native metrics roll up under the same Spark + // operator in the metric tree (otherwise they'd orphan against plan_id=0). 
+ val hashAggOp = OperatorOuterClass.Operator + .newBuilder() + .setPlanId(builder.getPlanId) + .addAllChildren(builder.getChildrenList) + .setHashAgg(hashAggBuilder) + .build() + val projectionBuilder = OperatorOuterClass.Projection.newBuilder() + projectionBuilder.addAllProjectList(resultExprs.map(_.get).asJava) + Some( + OperatorOuterClass.Operator + .newBuilder() + .setPlanId(builder.getPlanId) + .addChildren(hashAggOp) + .setProjection(projectionBuilder) + .build()) + } else { + Some(builder.setHashAgg(hashAggBuilder).build()) } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - // Force native to apply the projection when Spark's expected output (which may - // be empty for catalyst-pruned EXISTS / row-existence-only subqueries) differs - // from the aggregate's natural grouping output. Without this, an empty proto - // result_exprs is indistinguishable from "no projection needed" and the natural - // grouping keys leak through. - val naturalEqualsIntent = - resultExpressions.map(_.toAttribute) == groupingExpressions.map(_.toAttribute) - hashAggBuilder.setApplyResultProjection(!naturalEqualsIntent) - Some(builder.setHashAgg(hashAggBuilder).build()) } else { // Validate mode combinations. 
We support: // - All Partial @@ -1640,21 +1663,6 @@ trait CometBaseAggregate { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) - if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - // Final aggs always need the projection (matches the existing condition under - // which we serialize result_exprs at all). See comet#4515. - hashAggBuilder.setApplyResultProjection(true) - } hashAggBuilder.setModeValue(mode.getNumber) // Send per-expression modes and buffer offset for PartialMerge handling @@ -1673,7 +1681,44 @@ trait CometBaseAggregate { hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset) } - Some(builder.setHashAgg(hashAggBuilder).build()) + // Final aggregations may carry a result projection (e.g. `COUNT(col) + 1`) that + // catalyst encodes via `resultExpressions`. DataFusion's hash aggregate only emits + // its natural shape (group keys + agg results), so we wrap the HashAggregate in + // an explicit Projection op to apply Spark's result expressions. Partial / + // PartialMerge aggregates emit raw state buffers and never need the projection. + // See comet#4515. 
+ if (mode == CometAggregateMode.Final) { + val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes + val resultExprs = resultExpressions.map(exprToProto(_, attributes)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + // Inner HashAgg keeps the original input children from `builder`. Outer + // Projection uses a fresh builder so it has a single child (the HashAgg op). + // Both ops share the same plan_id so the aggregate's native metrics aggregate + // under the same Spark operator (else they'd orphan against plan_id=0). + val hashAggOp = OperatorOuterClass.Operator + .newBuilder() + .setPlanId(builder.getPlanId) + .addAllChildren(builder.getChildrenList) + .setHashAgg(hashAggBuilder) + .build() + val projectionBuilder = OperatorOuterClass.Projection.newBuilder() + projectionBuilder.addAllProjectList(resultExprs.map(_.get).asJava) + Some( + OperatorOuterClass.Operator + .newBuilder() + .setPlanId(builder.getPlanId) + .addChildren(hashAggOp) + .setProjection(projectionBuilder) + .build()) + } else { + Some(builder.setHashAgg(hashAggBuilder).build()) + } } else { val allChildren: Seq[Expression] = groupingExpressions ++ aggregateExpressions ++ aggregateAttributes From d736dd5b7194c4363fb8a62d17e37b56efb2d8a3 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 14:36:21 -0400 Subject: [PATCH 23/39] remove instrumentation. 
--- native/core/src/execution/planner.rs | 17 --- .../org/apache/comet/CometBatchIterator.java | 16 --- .../org/apache/comet/CometExecIterator.scala | 21 ---- .../apache/comet/rules/CometExecRule.scala | 24 ---- .../apache/comet/rules/CometScanRule.scala | 11 -- .../comet/serde/operator/CometSink.scala | 74 ------------ .../org/apache/comet/vector/NativeUtil.scala | 22 ---- .../shuffle/CometShuffleExchangeExec.scala | 33 ------ .../shuffle/CometShuffledRowRDD.scala | 29 +---- .../apache/spark/sql/comet/operators.scala | 37 +----- .../apache/comet/exec/CometExecSuite.scala | 111 ------------------ 11 files changed, 2 insertions(+), 393 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 8abc19c1f5..18c3acf2dd 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1034,13 +1034,6 @@ impl PhysicalPlanner { )) } OpStruct::HashAgg(agg) => { - // [#4515 instrumentation] Every HashAgg construction with proto sizes. - log::warn!( - "[#4515] OpStruct::HashAgg grouping_exprs.len={} agg_exprs.len={} mode={}", - agg.grouping_exprs.len(), - agg.agg_exprs.len(), - agg.mode - ); assert_eq!(children.len(), 1); let (scans, shuffle_scans, child) = self.create_plan(&children[0], inputs, partition_count)?; @@ -1424,16 +1417,6 @@ impl PhysicalPlanner { OpStruct::Scan(scan) => { let data_types = scan.fields.iter().map(to_arrow_datatype).collect_vec(); - // [#4515 instrumentation] Log every JVM-bridge Scan's declared column count. - // A 0-column scan paired with a JVM iterator producing batches with columns is - // the AIOOBE-on-exportBatch shape; a non-empty list confirms the inverse. 
- log::info!( - "[#4515] ScanExec source='{}' declared {} cols: {:?}", - scan.source, - data_types.len(), - data_types - ); - // If it is not test execution context for unit test, we should have at least one // input source if self.exec_context_id != TEST_EXEC_CONTEXT_ID && inputs.is_empty() { diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java index 1392f028c0..9b48a47c57 100644 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ b/spark/src/main/java/org/apache/comet/CometBatchIterator.java @@ -38,8 +38,6 @@ public class CometBatchIterator { private final NativeUtil nativeUtil; private ColumnarBatch previousBatch = null; private ColumnarBatch currentBatch = null; - // [#4515 instrumentation] gate first-batch log per instance - private boolean loggedFirstBatch = false; CometBatchIterator(Iterator input, NativeUtil nativeUtil) { this.input = input; @@ -81,20 +79,6 @@ public int next(long[] arrayAddrs, long[] schemaAddrs) { return -1; } - // [#4515 instrumentation] Log first-batch shape per CometBatchIterator instance. 
- if (!loggedFirstBatch) { - loggedFirstBatch = true; - org.slf4j.LoggerFactory.getLogger("[#4515]") - .warn( - "CometBatchIterator.next first batch: numCols={} numRows={} arrayAddrs.length={} schemaAddrs.length={} inputCls={} this={}", - currentBatch.numCols(), - currentBatch.numRows(), - arrayAddrs.length, - schemaAddrs.length, - input.getClass().getName(), - System.identityHashCode(this)); - } - // export the batch using the Arrow C Data Interface int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 98a9bc0684..6140eca553 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -92,27 +92,6 @@ class CometExecIterator( val conf = SparkEnv.get.conf val localDiskDirs = SparkEnv.get.blockManager.getLocalDiskDirs - // [#4515 instrumentation] Dump proto plan + per-input iterator class so we can correlate - // a 0-col Scan in the proto with the JVM iterator that will feed it. 
- { - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - val parsed = - scala.util - .Try( - org.apache.comet.serde.OperatorOuterClass.Operator - .parseFrom(protobufQueryPlan) - .toString) - .toOption - .getOrElse("") - val itersDesc = inputIterators.zipWithIndex - .map { case (it, idx) => s" inputIterators[$idx] cls=${it.getClass.getName}" } - .mkString("\n") - log.warn( - s"CometExecIterator constructing plan id=$id partition=$partitionIndex " + - s"numParts=$numParts numOutputCols=$numOutputCols\n$itersDesc\n" + - s" proto plan:\n$parsed") - } - // serialize Comet related Spark configs in protobuf format val protobufSparkConfigs = CometExecIterator.serializeCometSQLConfs() diff --git a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala index 6eb314bbb6..aeb7db40ad 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometExecRule.scala @@ -166,16 +166,6 @@ case class CometExecRule(session: SparkSession) val reverted = s.originalPlan.withNewChildren(Seq(s.child)).asInstanceOf[ShuffleExchangeExec] reverted.setTagValue(CometExecRule.SKIP_COMET_SHUFFLE_TAG, ()) - // [#4515 instrumentation] Log every revert: this is a primary place where - // vanilla ShuffleExchangeExec instances enter the plan post-Comet-rewrite. 
- org.slf4j.LoggerFactory - .getLogger("[#4515]") - .warn( - s"revertRedundantColumnarShuffle revert: parentAgg=${op.getClass.getSimpleName} " + - s"(output=${op.output}) cometShuffle.output=${s.output} " + - s"reverted.output=${reverted.output} reverted.identityHash=${System - .identityHashCode(reverted)}\n" + - s" reverted tree:\n${reverted.treeString(verbose = true, addSuffix = false)}") logInfo( "Reverting Comet columnar shuffle to Spark shuffle between " + s"${op.getClass.getSimpleName} and ${s.child.getClass.getSimpleName} " + @@ -556,20 +546,6 @@ case class CometExecRule(session: SparkSession) |${sideBySide(plan.treeString, newPlan.treeString).mkString("\n")} |""".stripMargin) } - // [#4515 instrumentation] Dump the plan in/out of CometExecRule to correlate with the - // 0-col Scan synthesized later. Logs the operator-class diff so we can see what - // CometExecRule did or didn't replace, especially around ShuffleExchangeExec wrappers - // inside subqueries. - if (!newPlan.fastEquals(plan)) { - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - log.warn( - "CometExecRule rewrote plan:\n IN classes: " + - plan.collect { case p => p.getClass.getSimpleName }.mkString(",") + - "\n OUT classes: " + - newPlan.collect { case p => p.getClass.getSimpleName }.mkString(",") + - s"\n IN tree:\n${plan.treeString(verbose = true, addSuffix = false)}" + - s"\n OUT tree:\n${newPlan.treeString(verbose = true, addSuffix = false)}") - } newPlan } diff --git a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala index 1f7f509db0..7601fa1c6b 100644 --- a/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala +++ b/spark/src/main/scala/org/apache/comet/rules/CometScanRule.scala @@ -68,17 +68,6 @@ case class CometScanRule(session: SparkSession) |${sideBySide(plan.treeString, newPlan.treeString).mkString("\n")} |""".stripMargin) } - // [#4515 instrumentation] - if (!newPlan.fastEquals(plan)) 
{ - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - log.warn( - "CometScanRule rewrote plan:\n IN classes: " + - plan.collect { case p => p.getClass.getSimpleName }.mkString(",") + - "\n OUT classes: " + - newPlan.collect { case p => p.getClass.getSimpleName }.mkString(",") + - s"\n IN tree:\n${plan.treeString(verbose = true, addSuffix = false)}" + - s"\n OUT tree:\n${newPlan.treeString(verbose = true, addSuffix = false)}") - } newPlan } diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index 8a1c576a2d..845803d133 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -49,12 +49,6 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { op: T, builder: Operator.Builder, childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { - // [#4515 instrumentation] - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - log.warn( - s"CometSink[${this.getClass.getSimpleName}].convert op=${op.getClass.getName} " + - s"simpleString='${op.simpleStringWithNodeId()}' output=${op.output} " + - s"output.size=${op.output.size}") val supportedTypes = op.output.forall(a => supportedDataType(a.dataType, allowComplex = true)) @@ -78,53 +72,6 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { } if (scanTypes.length == op.output.length) { - // [#4515 instrumentation] Log when we synthesize a Scan with zero declared columns. - // The runtime JVM iterator may still produce columns (subquery output shrunk by - // catalyst before serialization while the underlying RDD reflects the pre-shrink shape), - // tripping the column-count guard in NativeUtil.exportBatch. 
- if (scanTypes.isEmpty) { - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - // scalastyle:off line.size.limit - val childInfo = op.children.zipWithIndex - .map { case (c, i) => - val canonOut = scala.util - .Try(c.canonicalized.output) - .toOption - .map(_.toString) - .getOrElse("") - s" child[$i] cls=${c.getClass.getName} simpleString='${c.simpleString( - 80)}' output=${c.output} outputSize=${c.output.size} identityHash=${System - .identityHashCode(c)} canonicalized.output=$canonOut" - } - .mkString("\n") - val opCanonOut = scala.util - .Try(op.canonicalized.output) - .toOption - .map(_.toString) - .getOrElse("") - val subqueryInfo = scala.util - .Try(op.subqueries.map(s => s"${s.getClass.getName}(output=${s.output}, prepared=?)")) - .toOption - .getOrElse(Nil) - .mkString("[", ", ", "]") - val callerStack = - new RuntimeException("[#4515] CometSink 0-col Scan caller").getStackTrace - .take(20) - .map(f => s" at ${f}") - .mkString("\n") - log.warn(s"CometSink synthesizing 0-col Scan for op=${op.getClass.getName}\n" + - s" simpleString='${op.simpleStringWithNodeId()}'\n" + - s" op.output=${op.output} op.outputSet=${op.outputSet} op.references=${op.references}\n" + - s" op.canonicalized.output=$opCanonOut\n" + - s" op.subqueries=$subqueryInfo\n" + - s" op identityHash=${System.identityHashCode(op)}\n" + - s" children classes=${op.children.map(_.getClass.getSimpleName).mkString("[", ",", "]")}\n" + - childInfo + "\n" + - s" caller stack:\n$callerStack\n" + - s" op tree:\n${op.treeString(verbose = true, addSuffix = false)}") - // scalastyle:on line.size.limit - } - scanBuilder.addAllFields(scanTypes.asJava) // Sink operators don't have children @@ -147,27 +94,6 @@ object CometExchangeSink extends CometSink[SparkPlan] { op: SparkPlan, builder: Operator.Builder, childOp: OperatorOuterClass.Operator*): Option[OperatorOuterClass.Operator] = { - // [#4515 instrumentation] - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - val isVanillaSparkExchange = - 
op.getClass.getName == "org.apache.spark.sql.execution.exchange.ShuffleExchangeExec" - log.warn( - s"CometExchangeSink.convert op=${op.getClass.getName} " + - s"simpleString='${op.simpleStringWithNodeId()}' output=${op.output} " + - s"useShuffleScan=${shouldUseShuffleScan(op)} " + - s"children=${op.children.map(_.getClass.getSimpleName).mkString("[", ",", "]")}") - if (isVanillaSparkExchange) { - val callerStack = - new RuntimeException("[#4515] vanilla ShuffleExchangeExec caller").getStackTrace - .take(20) - .map(f => s" at ${f}") - .mkString("\n") - log.warn( - " vanilla ShuffleExchangeExec being processed by CometExchangeSink:\n" + - s" output=${op.output}\n" + - s" caller stack:\n$callerStack\n" + - s" op tree:\n${op.treeString(verbose = true, addSuffix = false)}") - } if (shouldUseShuffleScan(op)) { convertToShuffleScan(op, builder) } else { diff --git a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala index 388f5114e6..4f027cd9e7 100644 --- a/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala +++ b/spark/src/main/scala/org/apache/comet/vector/NativeUtil.scala @@ -114,28 +114,6 @@ class NativeUtil { batch: ColumnarBatch): Int = { val numRows = mutable.ArrayBuffer.empty[Int] - if (arrayAddrs.length != batch.numCols() || schemaAddrs.length != batch.numCols()) { - val schemaSummary = (0 until batch.numCols()) - .map { i => - val v = batch.column(i) match { - case cv: CometVector => cv.getValueVector - case _ => null - } - if (v != null) s"col[$i]: ${v.getField}" - else s"col[$i]: ${batch.column(i).getClass.getName}" - } - .mkString("; ") - val taskAttempt = Option(org.apache.spark.TaskContext.get()) - .map(c => s"stage=${c.stageId} task=${c.taskAttemptId} partition=${c.partitionId}") - .getOrElse("no-task") - throw new SparkException( - "CometBatchIterator column-count mismatch [#4515 instrumentation]: " + - s"native expected arrayAddrs=${arrayAddrs.length}, 
schemaAddrs=${schemaAddrs.length}; " + - s"JVM iterator produced batch.numCols=${batch.numCols()} ($taskAttempt). " + - s"Batch schema: $schemaSummary", - new RuntimeException("placeholder for exportBatch column-count mismatch")) - } - (0 until batch.numCols()).foreach { index => batch.column(index) match { case a: CometVector => diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index ba152c2f13..1470d637d9 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -243,23 +243,6 @@ case class CometShuffleExchangeExec( * Comet returns RDD[ColumnarBatch] for columnar execution. */ protected override def doExecuteColumnar(): RDD[ColumnarBatch] = { - // [#4515 instrumentation] Track every doExecuteColumnar call: which CometShuffleExchange - // instance, its output, and the call site. Helps confirm whether this PR's changes - // cause an extra EnsureRequirements-inserted vanilla Exchange to wrap us, and what - // RDD is being plumbed where. - { - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - val callerStack = new RuntimeException( - "[#4515] CometShuffleExchangeExec.doExecuteColumnar caller").getStackTrace - .take(15) - .map(f => s" at ${f}") - .mkString("\n") - log.warn( - s"CometShuffleExchangeExec.doExecuteColumnar this=${System.identityHashCode(this)} " + - s"shuffleType=$shuffleType outputPartitioning=$outputPartitioning " + - s"output=$output\n" + - s" caller stack:\n$callerStack") - } // Returns the same CometShuffledBatchRDD if this plan is used by multiple plans. 
if (cachedShuffleRDD == null) { cachedShuffleRDD = new CometShuffledBatchRDD(shuffleDependency, readMetrics) @@ -704,22 +687,6 @@ object CometShuffleExchangeExec serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { - // [#4515 instrumentation] Track every placeholder-Scan ShuffleDependency we build, with - // outputAttributes (drives Scan declared schema) and call site. - { - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - val callerStack = - new RuntimeException("[#4515] prepareShuffleDependency caller").getStackTrace - .take(15) - .map(f => s" at ${f}") - .mkString("\n") - log.warn( - s"prepareShuffleDependency outputAttributes=$outputAttributes " + - s"outputAttributes.size=${outputAttributes.size} " + - s"outputPartitioning=$outputPartitioning rdd.numPartitions=${rdd.getNumPartitions}\n" + - s" caller stack:\n$callerStack") - } - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") val scanTypes = outputAttributes.flatten { attr => QueryPlanSerde.serializeDataType(attr.dataType) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala index e0e5c45b2a..7604910b06 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffledRowRDD.scala @@ -162,34 +162,7 @@ class CometShuffledBatchRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val reader = createReader(split, context) // TODO: Reads IPC by native code - val raw = reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) - // [#4515 instrumentation] Peek the first decoded batch to confirm wire schema vs caller - // expectations. Wraps so we don't consume the iterator. 
- val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - new Iterator[ColumnarBatch] { - private var logged = false - override def hasNext: Boolean = raw.hasNext - override def next(): ColumnarBatch = { - val b = raw.next() - if (!logged) { - logged = true - val schemaSummary = (0 until b.numCols()) - .map { i => - val v = b.column(i) match { - case cv: org.apache.comet.vector.CometVector => cv.getValueVector - case _ => null - } - if (v != null) s"col[$i]: ${v.getField}" - else s"col[$i]: ${b.column(i).getClass.getName}" - } - .mkString("; ") - log.warn(s"CometShuffledBatchRDD.compute first decoded batch: numCols=${b.numCols()} " + - s"numRows=${b.numRows()} stage=${context.stageId()} task=${context.taskAttemptId()} " + - s"partition=${split.index} schema=[$schemaSummary]") - } - b - } - } + reader.read().asInstanceOf[Iterator[Product2[Int, ColumnarBatch]]].map(_._2) } override def clearDependencies(): Unit = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 3240adcc01..7ec8a2e898 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -422,15 +422,7 @@ abstract class CometNativeExec extends CometExec { /** The Comet native operator */ def nativeOp: Operator - override protected def doPrepare(): Unit = { - // [#4515 instrumentation] Track when subqueries are prepared for this CometNativeExec. 
- val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - log.warn( - s"CometNativeExec.doPrepare this=${System.identityHashCode(this)} " + - s"cls=${this.getClass.getName} " + - s"originalPlan.cls=${originalPlan.getClass.getName}") - prepareSubqueries(this) - } + override protected def doPrepare(): Unit = prepareSubqueries(this) override lazy val metrics: Map[String, SQLMetric] = CometMetricNode.baselineMetrics(sparkContext) @@ -571,14 +563,6 @@ abstract class CometNativeExec extends CometExec { // same partition number. But for Comet, we need to zip them so we need to adjust the // partition number of Broadcast RDDs to make sure they have the same partition number. sparkPlans.zipWithIndex.foreach { case (plan, idx) => - // [#4515 instrumentation] Log every JVM-side input plan we wire to native, so we can - // correlate the Scan's declared schema with the runtime plan whose RDD feeds it. - val log = org.slf4j.LoggerFactory.getLogger("[#4515]") - log.warn( - s"buildNativeContext binding input[$idx] cls=${plan.getClass.getName} " + - s"simpleString='${plan.simpleStringWithNodeId()}' output=${plan.output} " + - s"output.size=${plan.output.size} identityHash=${System.identityHashCode(plan)}\n" + - s" subtree:\n${plan.treeString(verbose = true, addSuffix = false)}") plan match { case c: CometBroadcastExchangeExec => inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) @@ -1532,25 +1516,6 @@ trait CometBaseAggregate { // If the aggregateExpressions is empty, we only want to build groupingExpressions, // and skip processing of aggregateExpressions. if (aggregateExpressions.isEmpty) { - // [#4515 instrumentation] Track HashAgg serializations with empty resultExpressions. - // Catalyst prunes resultExpressions for EXISTS / row-existence-only subqueries; the - // native side currently interprets empty result_exprs as "use aggregate natural - // schema", which leaks grouping keys into the output. 
- org.slf4j.LoggerFactory - .getLogger("[#4515]") - .warn( - "HashAgg empty-aggExprs branch: " + - s"groupingExprs=${groupingExpressions} " + - s"resultExpressions=${resultExpressions} " + - s"resultExpressions.size=${resultExpressions.size} " + - s"aggregate.output=${aggregate.output} " + - s"aggregate.output.size=${aggregate.output.size} " + - s"modes(from aggExprs)=${modes} " + - s"sparkFinalMode=$sparkFinalMode " + - s"requiredChildDistribution=${aggregate.requiredChildDistribution} " + - s"isProjectionToEmpty=${resultExpressions.isEmpty && aggregate.output.isEmpty} " + - s"naturalEqualsIntent=${resultExpressions.map(_.toAttribute) == groupingExpressions - .map(_.toAttribute)}") val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) // The native HashAggregate emits its natural shape (the grouping keys, since there diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index fded42d050..16601d056b 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3959,117 +3959,6 @@ class CometExecSuite extends CometTestBase { } } - // Repro for the Spark 3.5 SQL-tests failure on subquery/exists-subquery/exists-orderby-limit.sql - // (query #19) on branch opt_native_shuffle. Crashes with NativeUtil.exportBatch:132 AIOOBE - // on this branch but passes on main. The SQL file declares three CONFIG_DIM1 combos for - // codegen; the failure trace doesn't say which one fires it, so sweep all three. 
- test("EXISTS subquery with GROUP BY + LIMIT + OFFSET") { - withTempView("emp", "dept") { - spark - .sql("""SELECT * FROM VALUES - | (100, 'emp 1', date '2005-01-01', 100.0D, 10), - | (200, 'emp 2', date '2003-01-01', 200.0D, 10), - | (300, 'emp 3', date '2002-01-01', 300.0D, 20), - | (400, 'emp 4', date '2005-01-01', 400.0D, 30), - | (500, 'emp 5', date '2001-01-01', 400.0D, NULL), - | (700, 'emp 7', date '2010-01-01', 400.0D, 100), - | (800, 'emp 8', date '2016-01-01', 150.0D, 70) - |AS t(id, emp_name, hiredate, salary, dept_id)""".stripMargin) - .createOrReplaceTempView("emp") - spark - .sql("""SELECT * FROM VALUES - | (10, 'dept 1', 'CA'), - | (20, 'dept 2', 'NY'), - | (30, 'dept 3', 'TX'), - | (40, 'dept 4 - unassigned', 'OR'), - | (50, 'dept 5 - unassigned', 'NJ'), - | (70, 'dept 7', 'FL') - |AS t(dept_id, dept_name, state)""".stripMargin) - .createOrReplaceTempView("dept") - - val configDims = Seq( - Map("spark.sql.codegen.wholeStage" -> "true"), - Map( - "spark.sql.codegen.wholeStage" -> "false", - "spark.sql.codegen.factoryMode" -> "CODEGEN_ONLY"), - Map( - "spark.sql.codegen.wholeStage" -> "false", - "spark.sql.codegen.factoryMode" -> "NO_CODEGEN")) - - // Mirror the SQL test order: queries #1 through #16 from the file run before #17 (the - // one whose CI failure we're chasing). Query #11's subquery is identical except for - // the OFFSET; #13 and #15 share similar shapes. Subquery materialization / AQE plan - // cache state from running them first may alter query #17's executed shape. 
- val priorQueries = Seq( - // #1 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate""", - // #2 - """SELECT id, hiredate FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate DESC""", - // #3 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 1) ORDER BY hiredate""", - // #4 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 0) ORDER BY hiredate""", - // #5 - """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state) ORDER BY hiredate""", - // #6 - """SELECT emp_name FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY state ORDER BY state)""", - // #7 - """SELECT count(*) FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY dept_id ORDER BY dept_id)""", - // #8 - """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 1) ORDER BY hiredate""", - // #9 - """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_id FROM dept WHERE emp.dept_id = dept.dept_id ORDER BY state LIMIT 0) ORDER BY hiredate""", - // #10 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 10 LIMIT 1)""", - // #11 - same subquery as #17 minus the OFFSET - """SELECT * FROM emp WHERE EXISTS (SELECT max(dept.dept_id) FROM dept GROUP BY state LIMIT 1)""", - // #12 - """SELECT * FROM emp WHERE NOT EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 100 LIMIT 1)""", - // #13 - """SELECT * FROM emp WHERE NOT EXISTS (SELECT max(dept.dept_id) FROM dept WHERE dept.dept_id > 100 GROUP BY state LIMIT 1)""", - // #14 - """SELECT emp_name FROM emp 
WHERE NOT EXISTS (SELECT max(dept.dept_id) a FROM dept WHERE dept.dept_id = emp.dept_id GROUP BY state ORDER BY state LIMIT 2 OFFSET 1)""", - // #15 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > 10 LIMIT 1 OFFSET 2)""", - // #16 - """SELECT * FROM emp WHERE EXISTS (SELECT dept.dept_name FROM dept WHERE dept.dept_id > emp.dept_id LIMIT 1)""") - - for (dim <- configDims) { - // SQL-tests harness (dev/diffs/3.5.8.diff, 4.1.1.diff) sets only: - // spark.comet.enabled, spark.comet.exec.enabled, spark.comet.exec.shuffle.enabled, - // spark.comet.parquet.respectFilterPushdown, spark.shuffle.manager, - // spark.comet.memoryOverhead. - // It does NOT enable spark.comet.sparkToColumnar.enabled, but CometTestBase does. - // Force the harness shape so the reproducer matches the failing config. - val harnessConf = dim ++ Map(CometConf.COMET_SPARK_TO_ARROW_ENABLED.key -> "false") - withSQLConf(harnessConf.toSeq: _*) { - // scalastyle:off println - println(s"\n===== config dim: $harnessConf =====") - for ((q, idx) <- priorQueries.zipWithIndex) { - spark.sql(q).collect() - println(s"--- ran prior query #${idx + 1} ---") - } - val sql = """SELECT * - |FROM emp - |WHERE EXISTS (SELECT max(dept.dept_id) - | FROM dept - | GROUP BY state - | LIMIT 1 - | OFFSET 2)""".stripMargin - val df = spark.sql(sql) - println("--- query #17 initial executedPlan ---") - println(df.queryExecution.executedPlan) - val rows = df.collect() - println("--- query #17 final (post-AQE) executedPlan ---") - println(df.queryExecution.executedPlan) - println(s"--- ${rows.length} rows ---") - // scalastyle:on println - checkSparkAnswer(sql) - } - } - } - } - } case class BucketedTableTestSpec( From 214a75be2f7def766bc0c98158651c7ff01b5db6 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 15:04:28 -0400 Subject: [PATCH 24/39] cleanup. 
--- native/core/src/execution/planner.rs | 10 +- native/proto/src/proto/operator.proto | 7 +- .../apache/spark/sql/comet/operators.scala | 134 ++++++++---------- .../comet/exec/CometAggregateSuite.scala | 118 +++++---------- 4 files changed, 97 insertions(+), 172 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 18c3acf2dd..328ca65760 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1164,13 +1164,9 @@ impl PhysicalPlanner { )?, ); - // The native HashAggregate emits its natural shape (group keys + agg - // results / state). Any post-aggregate projection Spark catalyst declares - // (`COUNT(col) + 1`, EXISTS-pruned-to-empty output, alias renames, etc.) is - // expressed as an explicit `OpStruct::Projection` op above the aggregate - // by the JVM serializer (see `CometBaseAggregate.doConvert`). Keeping that - // logic on the JVM side means only one place decides plan shape, and the - // native side stays a faithful executor. See comet#4515. + // HashAggregate emits its natural shape (group keys + agg results); any + // post-aggregate projection is serialized as an explicit `OpStruct::Projection` + // op above by the JVM serializer (see `CometBaseAggregate.doConvert`) Ok(( scans, shuffle_scans, diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index aaff3cc4de..cc56331d74 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -294,11 +294,8 @@ message Sort { message HashAggregate { repeated spark.spark_expression.Expr grouping_exprs = 1; repeated spark.spark_expression.AggExpr agg_exprs = 2; - // Field 3 (`result_exprs`) and field 8 (`apply_result_projection`) were used to apply a - // post-aggregate projection inside the HashAggregate operator. 
The same effect is now - // expressed by emitting an explicit `Projection` op above the `HashAggregate` from the - // JVM serializer when needed (see `CometBaseAggregate.doConvert`). Reserved to avoid - // accidental reuse at incompatible semantics. + // Was result_exprs / apply_result_projection; now expressed as an explicit Projection + // op above HashAggregate (see CometBaseAggregate.doConvert, comet#4515). reserved 3, 8; reserved "result_exprs", "apply_result_projection"; AggregateMode mode = 5; diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 7ec8a2e898..cbd9dc4cbb 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1518,48 +1518,12 @@ trait CometBaseAggregate { if (aggregateExpressions.isEmpty) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - // The native HashAggregate emits its natural shape (the grouping keys, since there - // are no aggregate functions). When Spark catalyst declares a different output - - // either column-renaming via aliases, or an entirely empty output for catalyst-pruned - // EXISTS / row-existence-only subqueries - we wrap the HashAggregate in an explicit - // Projection op so the native side reshapes accordingly. See comet#4515: an empty - // declared output paired with the natural grouping-key output crashed downstream - // boundaries that derived their schema from the declared output. 
- val naturalOutput = groupingExpressions.map(_.toAttribute) - val needsProjection = resultExpressions.map(_.toAttribute) != naturalOutput - if (needsProjection) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - // Build the inner HashAgg op carrying the original child operators from `builder`. - // Use a fresh outer builder for the Projection so it gets a single child (the - // HashAgg op), not the original children appended on top. Both ops share the same - // plan_id so the inner aggregate's native metrics roll up under the same Spark - // operator in the metric tree (otherwise they'd orphan against plan_id=0). - val hashAggOp = OperatorOuterClass.Operator - .newBuilder() - .setPlanId(builder.getPlanId) - .addAllChildren(builder.getChildrenList) - .setHashAgg(hashAggBuilder) - .build() - val projectionBuilder = OperatorOuterClass.Projection.newBuilder() - projectionBuilder.addAllProjectList(resultExprs.map(_.get).asJava) - Some( - OperatorOuterClass.Operator - .newBuilder() - .setPlanId(builder.getPlanId) - .addChildren(hashAggOp) - .setProjection(projectionBuilder) - .build()) - } else { - Some(builder.setHashAgg(hashAggBuilder).build()) - } + buildAggOp( + builder, + hashAggBuilder, + groupingExpressions.map(_.toAttribute), + resultExpressions, + aggregate) } else { // Validate mode combinations. We support: // - All Partial @@ -1647,40 +1611,15 @@ trait CometBaseAggregate { } // Final aggregations may carry a result projection (e.g. `COUNT(col) + 1`) that - // catalyst encodes via `resultExpressions`. 
DataFusion's hash aggregate only emits - // its natural shape (group keys + agg results), so we wrap the HashAggregate in - // an explicit Projection op to apply Spark's result expressions. Partial / - // PartialMerge aggregates emit raw state buffers and never need the projection. - // See comet#4515. + // catalyst encodes via `resultExpressions`. Partial / PartialMerge aggregates emit + // raw state buffers and never need it. See comet#4515. if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withInfo( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - // Inner HashAgg keeps the original input children from `builder`. Outer - // Projection uses a fresh builder so it has a single child (the HashAgg op). - // Both ops share the same plan_id so the aggregate's native metrics aggregate - // under the same Spark operator (else they'd orphan against plan_id=0). 
- val hashAggOp = OperatorOuterClass.Operator - .newBuilder() - .setPlanId(builder.getPlanId) - .addAllChildren(builder.getChildrenList) - .setHashAgg(hashAggBuilder) - .build() - val projectionBuilder = OperatorOuterClass.Projection.newBuilder() - projectionBuilder.addAllProjectList(resultExprs.map(_.get).asJava) - Some( - OperatorOuterClass.Operator - .newBuilder() - .setPlanId(builder.getPlanId) - .addChildren(hashAggOp) - .setProjection(projectionBuilder) - .build()) + buildAggOp( + builder, + hashAggBuilder, + groupingExpressions.map(_.toAttribute) ++ aggregateAttributes, + resultExpressions, + aggregate) } else { Some(builder.setHashAgg(hashAggBuilder).build()) } @@ -1694,6 +1633,51 @@ trait CometBaseAggregate { } + /** + * Serialize a HashAggregate, wrapping it in an explicit `Projection` op when Spark's declared + * output (`resultExpressions`) differs from the aggregate's natural output. DataFusion's hash + * aggregate emits only its natural shape (group keys + agg results), so any reshape catalyst + * declared - alias renames, `COUNT(col) + 1`, or empty output for catalyst-pruned EXISTS / + * row-existence-only subqueries - is expressed as a separate Projection above the HashAgg. Both + * ops share the caller's `plan_id` so the aggregate's native metrics roll up under the same + * Spark operator. See comet#4515. 
+ */ + private def buildAggOp( + builder: Operator.Builder, + hashAggBuilder: OperatorOuterClass.HashAggregate.Builder, + naturalOutput: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + aggregate: BaseAggregateExec): Option[Operator] = { + if (resultExpressions.map(_.toAttribute) == naturalOutput) { + return Some(builder.setHashAgg(hashAggBuilder).build()) + } + val resultExprs = resultExpressions.map(exprToProto(_, naturalOutput)) + if (resultExprs.exists(_.isEmpty)) { + withInfo( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + val planId = builder.getPlanId + val hashAggOp = OperatorOuterClass.Operator + .newBuilder() + .setPlanId(planId) + .addAllChildren(builder.getChildrenList) + .setHashAgg(hashAggBuilder) + .build() + val projection = OperatorOuterClass.Projection + .newBuilder() + .addAllProjectList(resultExprs.map(_.get).asJava) + Some( + OperatorOuterClass.Operator + .newBuilder() + .setPlanId(planId) + .addChildren(hashAggOp) + .setProjection(projection) + .build()) + } + /** * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with * partial or partial-merge mode, it will return None. diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index f098040878..3e883d7a32 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -2109,95 +2109,43 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - // Regression for comet#4515: a HashAggregateExec whose `resultExpressions` (and therefore - // `output`) catalyst has pruned to empty must still produce 0-col batches at runtime. 
- // Catalyst prunes `resultExpressions=[]` for plans where the aggregate's column values are - // unused downstream - classically EXISTS subqueries that get rewritten into a literal-`1` - // wrapper. Before the fix, the native HashAggregate emitted its natural output (the - // grouping keys) regardless of the pruned JVM `output`, so any boundary that derived a - // schema from `output` (e.g. a wrapping vanilla Spark `ShuffleExchangeExec.output = - // child.output = []`, or a JVM-bridge Scan synthesized from the same `output`) declared - // 0 columns while the runtime RDD produced 1. The mismatch tripped - // `NativeUtil.exportBatch` with an ArrayIndexOutOfBoundsException on a length-0 - // schemaAddrs[]. + // Regression for comet#4515: catalyst prunes `HashAggregateExec.resultExpressions` to + // empty for EXISTS / row-existence-only subqueries. The native HashAggregate's natural + // output (the grouping keys) then disagrees with the pruned JVM `output`, leaking through + // any boundary that derived its schema from `output`. The fix wraps the aggregate in an + // explicit Projection op when natural != declared. // - // The bug needs a specific catalyst optimizer state: `HashAggregateExec. - // resultExpressions.isEmpty`. Whether the optimizer reaches that state from a given SQL - // depends on Spark version, scan source (LocalTableScan vs Parquet/native), AQE state, - // and which Comet rules already fired - we observed it in the SQL-tests harness running - // `subquery/exists-subquery/exists-orderby-limit.sql` on Spark 4.0.2 with parquet-backed - // temp views, but not via `Dataset.collect` over `LocalTableScan`-backed temp views - // under `CometTestBase`. So the test below writes parquet (matching the harness's scan - // shape), tries several known triggers, runs whichever (if any) produces the bug shape - // under `checkSparkAnswer`, and skips cleanly otherwise. The upstream SQL-tests run - // remains the primary safety net for the harness-only path. 
+ // Surfaced upstream in `subquery/exists-subquery/exists-orderby-limit.sql` (query #19, + // an EXISTS over `max(...) GROUP BY state LIMIT 1 OFFSET 2`). The exact `EXISTS-in-WHERE` + // shape doesn't reproduce under CometTestBase's optimizer state, but `count(*)` over the + // same derived aggregate triggers the equivalent ColumnPruning path locally - we assert + // the inner HashAgg's resultExpressions actually got pruned, so a future Spark version + // that breaks the trigger fails the test loudly rather than passing silently. test("HashAggregate with catalyst-pruned resultExpressions returns 0-col output (#4515)") { withTempDir { dir => - withTempView("emp", "dept") { - // Write parquet so Comet's native scan path (vs LocalTableScan) is the source - - // matches the SQL-tests harness setup that surfaced the bug. - val empPath = new java.io.File(dir, "emp").getAbsolutePath - val deptPath = new java.io.File(dir, "dept").getAbsolutePath - - spark - .sql("""SELECT * FROM VALUES - | (100, 'emp 1', 100.0D, 10), - | (200, 'emp 2', 200.0D, 10), - | (300, 'emp 3', 300.0D, 20), - | (400, 'emp 4', 400.0D, 30), - | (500, 'emp 5', 400.0D, NULL), - | (700, 'emp 7', 400.0D, 100), - | (800, 'emp 8', 150.0D, 70) - |AS t(id, emp_name, salary, dept_id)""".stripMargin) - .write - .parquet(empPath) - spark - .sql("""SELECT * FROM VALUES - | (10, 'CA'), (20, 'NY'), (30, 'TX'), - | (40, 'OR'), (50, 'NJ'), (70, 'FL') - |AS t(dept_id, state)""".stripMargin) - .write - .parquet(deptPath) - - spark.read.parquet(empPath).createOrReplaceTempView("emp") - spark.read.parquet(deptPath).createOrReplaceTempView("dept") - - val candidates = Seq( - // The original failing SQL from the harness - EXISTS with grouped agg + LIMIT/OFFSET. - """SELECT * FROM emp - |WHERE EXISTS ( - | SELECT max(dept.dept_id) FROM dept GROUP BY state LIMIT 1 OFFSET 2)""".stripMargin, - // Inline view + outer constant: ColumnPruning may strip the inner agg's output. 
- """SELECT 1 FROM ( - | SELECT max(dept_id) FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin, - // Scalar subquery returning a constant. - """SELECT (SELECT 1 FROM dept GROUP BY state LIMIT 1 OFFSET 2)""".stripMargin, - // count(*) over a derived table: outer doesn't reference inner cols. + val deptPath = new Path(dir.toURI.toString, "dept") + spark + .sql("""SELECT * FROM VALUES + | (10, 'CA'), (20, 'NY'), (30, 'TX'), + | (40, 'OR'), (50, 'NJ'), (70, 'FL') + |AS t(dept_id, state)""".stripMargin) + .write + .parquet(deptPath.toUri.toString) + withParquetTable(deptPath.toUri.toString, "dept") { + val sql = """SELECT count(*) FROM ( - | SELECT max(dept_id) AS m FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin) - - // Find a candidate whose plan has a HashAggregateExec (Spark or Comet) with empty - // resultExpressions. collectWithSubqueries traverses Subquery nodes too. - val triggering = candidates.find { sql => - val plan = spark.sql(sql).queryExecution.executedPlan - collectWithSubqueries(plan) { - case a: org.apache.spark.sql.execution.aggregate.HashAggregateExec - if a.resultExpressions.isEmpty => - a - case a: CometHashAggregateExec if a.resultExpressions.isEmpty => a - }.nonEmpty - } - - triggering match { - case Some(sql) => checkSparkAnswer(sql) - case None => - cancel( - "No candidate query produced a HashAggregateExec with empty resultExpressions " + - "in this environment. The catalyst-pruned shape that exercises #4515 only " + - "appears under specific optimizer/AQE state we couldn't reproduce here. 
The " + - "upstream SQL-tests run (subquery/exists-subquery/exists-orderby-limit.sql) " + - "covers this path.") - } + | SELECT max(dept_id) AS m FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin + val plan = spark.sql(sql).queryExecution.executedPlan + val pruned = collectWithSubqueries(plan) { + case a: org.apache.spark.sql.execution.aggregate.HashAggregateExec + if a.resultExpressions.isEmpty => + a + case a: CometHashAggregateExec if a.resultExpressions.isEmpty => a + } + assert( + pruned.nonEmpty, + s"Expected a HashAggregateExec with empty resultExpressions in:\n$plan") + checkSparkAnswerAndOperator(sql) } } } From e8e438f451b52e678ff43f01e259851eef37c232 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 15:12:40 -0400 Subject: [PATCH 25/39] cleanup. --- .../src/main/scala/org/apache/spark/sql/comet/operators.scala | 4 ++-- .../scala/org/apache/comet/exec/CometAggregateSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index cbd9dc4cbb..f419ca1b28 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -1612,7 +1612,7 @@ trait CometBaseAggregate { // Final aggregations may carry a result projection (e.g. `COUNT(col) + 1`) that // catalyst encodes via `resultExpressions`. Partial / PartialMerge aggregates emit - // raw state buffers and never need it. See comet#4515. + // raw state buffers and never need it. if (mode == CometAggregateMode.Final) { buildAggOp( builder, @@ -1640,7 +1640,7 @@ trait CometBaseAggregate { * declared - alias renames, `COUNT(col) + 1`, or empty output for catalyst-pruned EXISTS / * row-existence-only subqueries - is expressed as a separate Projection above the HashAgg. 
Both * ops share the caller's `plan_id` so the aggregate's native metrics roll up under the same - * Spark operator. See comet#4515. + * Spark operator. */ private def buildAggOp( builder: Operator.Builder, diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 3e883d7a32..ae14c68207 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -2109,7 +2109,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - // Regression for comet#4515: catalyst prunes `HashAggregateExec.resultExpressions` to + // Regression: Catalyst prunes `HashAggregateExec.resultExpressions` to // empty for EXISTS / row-existence-only subqueries. The native HashAggregate's natural // output (the grouping keys) then disagrees with the pruned JVM `output`, leaking through // any boundary that derived its schema from `output`. The fix wraps the aggregate in an @@ -2121,7 +2121,7 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // same derived aggregate triggers the equivalent ColumnPruning path locally - we assert // the inner HashAgg's resultExpressions actually got pruned, so a future Spark version // that breaks the trigger fails the test loudly rather than passing silently. - test("HashAggregate with catalyst-pruned resultExpressions returns 0-col output (#4515)") { + test("HashAggregate with catalyst-pruned resultExpressions returns 0-col output") { withTempDir { dir => val deptPath = new Path(dir.toURI.toString, "dept") spark From 1913208e7b137a95dd4b6002cf24ed7b21c86037 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 17:05:01 -0400 Subject: [PATCH 26/39] fix aggregation wrapping now that we don't have an extra CometExecIterator. 
--- native/shuffle/src/shuffle_writer.rs | 34 +++++++++---------- .../comet/exec/CometAggregateSuite.scala | 24 +++++++++++++ 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 8502c79624..b96917550f 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -29,7 +29,7 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::EmptyRecordBatchStream; use datafusion::{ - arrow::{datatypes::SchemaRef, error::ArrowError}, + arrow::datatypes::SchemaRef, error::Result, execution::context::TaskContext, physical_plan::{ @@ -38,7 +38,7 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use std::{ any::Any, fmt, @@ -171,23 +171,23 @@ impl ExecutionPlan for ShuffleWriterExec { let input = self.input.execute(partition, Arc::clone(&context))?; let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0); + // Propagate DataFusionError unchanged: the JNI bridge only downcasts a single + // `DataFusionError::External(SparkError)` layer, so any extra wrap here loses the + // typed exception (e.g. SparkArithmeticException on decimal overflow). 
Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::once( - external_shuffle( - input, - partition, - self.output_data_file.clone(), - self.output_index_file.clone(), - self.partitioning.clone(), - metrics, - context, - self.codec.clone(), - self.tracing_enabled, - self.write_buffer_size, - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))), - ) + futures::stream::once(external_shuffle( + input, + partition, + self.output_data_file.clone(), + self.output_index_file.clone(), + self.partitioning.clone(), + metrics, + context, + self.codec.clone(), + self.tracing_enabled, + self.write_buffer_size, + )) .try_flatten(), ))) } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index ae14c68207..3079db842f 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1817,6 +1817,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // make sure that the error message throws overflow exception only assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long overflow in ANSI mode") } } else { @@ -1831,6 +1835,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for 
Long underflow in ANSI mode") } } else { @@ -1870,6 +1878,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for decimal overflow in ANSI mode") } @@ -1893,6 +1905,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") } @@ -1910,6 +1926,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") } @@ -1951,6 +1971,10 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) + assert( + cometExc.isInstanceOf[ArithmeticException], + "expected ArithmeticException, got " + + s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => 
fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode") } From 327a6535d53ac07c076e1faa1f3610441d4d134e Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 17:44:23 -0400 Subject: [PATCH 27/39] Remove archeology comments. --- .../core/src/execution/operators/schema_align.rs | 16 +++++++++------- native/core/src/execution/planner.rs | 3 --- .../shuffle/CometShuffleExchangeExec.scala | 7 +++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/native/core/src/execution/operators/schema_align.rs b/native/core/src/execution/operators/schema_align.rs index 3d207e0202..58983d132c 100644 --- a/native/core/src/execution/operators/schema_align.rs +++ b/native/core/src/execution/operators/schema_align.rs @@ -16,11 +16,11 @@ // under the License. //! `SchemaAlignExec` reshapes its child's output so the per-column Arrow type and field-level -//! nullability match what Spark catalyst declared. Used between an inlined native subtree and -//! `ShuffleWriterExec` when the FFI deep-copy + `ScanExec` cast in `build_record_batch` are both -//! gone, so DataFusion / `datafusion-spark` return-type drift would otherwise be written into -//! shuffle blocks. See for the running -//! list of mismatched functions. +//! nullability match what Spark catalyst declared, casting where necessary. Sits between a native +//! subtree and `ShuffleWriterExec` so DataFusion / `datafusion-spark` return-type drift is caught +//! before it reaches shuffle blocks. See +//! for the running list of mismatched +//! functions. use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; @@ -53,6 +53,9 @@ fn warn_dedup() -> &'static Mutex> { SET.get_or_init(|| Mutex::new(HashSet::new())) } +/// Casts each column of `child`'s output to the data_type Spark catalyst declared, widening +/// nullability to `actual.nullable || expected.nullable`. See +/// . 
#[derive(Debug)] pub struct SchemaAlignExec { child: Arc, @@ -74,8 +77,7 @@ impl SchemaAlignExec { /// Build a SchemaAlignExec that aligns `child`'s output to `expected`. Returns /// `Ok(child)` unchanged when no per-column reshape is needed; otherwise wraps `child` /// in a SchemaAlignExec whose target schema preserves `expected`'s data_type and metadata - /// but widens nullability to `actual.nullable || expected.nullable` (matching the - /// reconciliation rule used at the FFI boundary on `main`). + /// but widens nullability to `actual.nullable || expected.nullable`. pub fn try_new_or_passthrough( child: Arc, expected: &SchemaRef, diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 328ca65760..91d77a35b4 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -1164,9 +1164,6 @@ impl PhysicalPlanner { )?, ); - // HashAggregate emits its natural shape (group keys + agg results); any - // post-aggregate projection is serialized as an explicit `OpStruct::Projection` - // op above by the JVM serializer (see `CometBaseAggregate.doConvert`) Ok(( scans, shuffle_scans, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 861ce6509e..05f3923714 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -124,7 +124,7 @@ case class CometShuffleExchangeExec( ctx.shuffleScanIndices) case None => // Non-native child (e.g. CometSparkToColumnarExec): no subtree to inline. The dep gets - // built via the legacy convenience overload below; we just need a real RDD of batches. + // built via the convenience overload below; we just need a real RDD of batches. 
child.executeColumnar() } } else if (shuffleType == CometColumnarShuffle) { @@ -676,9 +676,8 @@ object CometShuffleExchangeExec * Implemented as a thin wrapper around [[prepareNativeShuffleDependency]]: synthesizes a * `Scan("ShuffleWriterInput")` as the child native op (so the writer's plan is still * `ShuffleWriter -> Scan`, consuming JVM batches via Arrow C Stream), wraps `rdd` as the single - * leaf input of a thin scheduling RDD, and supplies a minimal [[NativeExecContext]]. Same wire - * shape as before; one writer code path for both this case and the [[CometShuffleExchangeExec]] - * case. + * leaf input of a thin scheduling RDD, and supplies a minimal [[NativeExecContext]]. Lets the + * writer use one code path for both this case and the [[CometShuffleExchangeExec]] case. */ def prepareShuffleDependency( rdd: RDD[ColumnarBatch], From 0b71de43014c683630730d0bf297187d979ba091 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 18:54:22 -0400 Subject: [PATCH 28/39] Undo stricter tests since they're not happy on Spark 3.x. --- .../comet/exec/CometAggregateSuite.scala | 33 +++++-------------- 1 file changed, 9 insertions(+), 24 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index 3079db842f..c441480cec 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1805,6 +1805,15 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (1 to 50).flatMap(_ => Seq((maxDec38_0, 1))) } + // SparkPlan.executeCollect wraps task failures in SparkException("Job aborted..."), so the + // typed exception thrown by Comet (SparkArithmeticException, which extends ArithmeticException) + // ends up in the cause chain rather than at the top. 
+ private def hasArithmeticInChain(t: Throwable): Boolean = + Iterator + .iterate[Throwable](t)(_.getCause) + .takeWhile(_ != null) + .exists(_.isInstanceOf[ArithmeticException]) + test("ANSI support - SUM function") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { @@ -1817,10 +1826,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { // make sure that the error message throws overflow exception only assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long overflow in ANSI mode") } } else { @@ -1835,10 +1840,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long underflow in ANSI mode") } } else { @@ -1878,10 +1879,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for decimal overflow in ANSI mode") } @@ -1905,10 +1902,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), 
Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long overflow with GROUP BY in ANSI mode") } @@ -1926,10 +1919,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for Long underflow with GROUP BY in ANSI mode") } @@ -1971,10 +1960,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { case (Some(sparkExc), Some(cometExc)) => assert(sparkExc.getMessage.contains("ARITHMETIC_OVERFLOW")) assert(cometExc.getMessage.contains("ARITHMETIC_OVERFLOW")) - assert( - cometExc.isInstanceOf[ArithmeticException], - "expected ArithmeticException, got " + - s"${cometExc.getClass.getName}: ${cometExc.getMessage}") case _ => fail("Exception should be thrown for decimal overflow with GROUP BY in ANSI mode") } From 1ecfd8a2001f9a1c5b6313fcca05745f768fdbfb Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 29 May 2026 18:55:19 -0400 Subject: [PATCH 29/39] Remove unintended change. 
--- .../org/apache/comet/exec/CometAggregateSuite.scala | 9 --------- 1 file changed, 9 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index c441480cec..ae14c68207 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala @@ -1805,15 +1805,6 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { (1 to 50).flatMap(_ => Seq((maxDec38_0, 1))) } - // SparkPlan.executeCollect wraps task failures in SparkException("Job aborted..."), so the - // typed exception thrown by Comet (SparkArithmeticException, which extends ArithmeticException) - // ends up in the cause chain rather than at the top. - private def hasArithmeticInChain(t: Throwable): Boolean = - Iterator - .iterate[Throwable](t)(_.getCause) - .takeWhile(_ != null) - .exists(_.isInstanceOf[ArithmeticException]) - test("ANSI support - SUM function") { Seq(true, false).foreach { ansiEnabled => withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiEnabled.toString) { From 6afbdcebce5f8f499cd1d238c71497dd26daa93b Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 09:15:31 -0400 Subject: [PATCH 30/39] add CometArrowStreamSuite to CI workflows --- .github/workflows/pr_build_linux.yml | 1 + .github/workflows/pr_build_macos.yml | 1 + 2 files changed, 2 insertions(+) diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 422232f546..ca0f5c01f2 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -335,6 +335,7 @@ jobs: org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite org.apache.comet.CometFuzzAggregateSuite + org.apache.spark.sql.comet.execution.arrow.CometArrowStreamSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git 
a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index d0a03eeb75..f148a44e14 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -177,6 +177,7 @@ jobs: org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite org.apache.comet.CometFuzzAggregateSuite + org.apache.spark.sql.comet.execution.arrow.CometArrowStreamSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite From deb697ad03147ddbf635c3e9c6724c1e6ac759c6 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 09:24:20 -0400 Subject: [PATCH 31/39] fix withInfo use --- .../org/apache/spark/sql/comet/CometLocalTableScanExec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 32b8933872..570848b70a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport} -import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink @@ -158,7 +158,7 @@ object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTy childOp: Operator*): Option[Operator] = { val fallbackReasons = new ListBuffer[String]() if (!isSchemaSupported(op.schema, fallbackReasons)) { - withInfo(op, fallbackReasons.mkString("; ")) + withFallbackReason(op, fallbackReasons.mkString("; ")) None } else { super.convert(op, builder, childOp: _*) From 
6f913f933a23fef039cd7cdbdd03b480e5343ee0 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 14:11:08 -0400 Subject: [PATCH 32/39] Don't enable LocalTableScan by default, cruft from #4393. --- spark/src/main/scala/org/apache/comet/CometConf.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 82ae6773c7..78ea0f0168 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -273,7 +273,7 @@ object CometConf extends ShimCometConf { val COMET_EXEC_TAKE_ORDERED_AND_PROJECT_ENABLED: ConfigEntry[Boolean] = createExecEnabledConfig("takeOrderedAndProject", defaultValue = true) val COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED: ConfigEntry[Boolean] = - createExecEnabledConfig("localTableScan", defaultValue = true) + createExecEnabledConfig("localTableScan", defaultValue = false) val COMET_NATIVE_COLUMNAR_TO_ROW_ENABLED: ConfigEntry[Boolean] = conf(s"$COMET_EXEC_CONFIG_PREFIX.columnarToRow.native.enabled") From d1623997f72c1a18a1c67f016de57f4d3f52fd1d Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 14:45:39 -0400 Subject: [PATCH 33/39] cleanup --- native/jni-bridge/src/arrow_array_stream.rs | 2 +- .../org/apache/comet/CometExecIterator.scala | 14 ++++++-------- .../org/apache/spark/sql/comet/CometExecRDD.scala | 7 +++++++ .../apache/spark/sql/comet/CometExecUtils.scala | 7 +------ .../spark/sql/comet/CometNativeWriteExec.scala | 11 ++++------- .../comet/CometTakeOrderedAndProjectExec.scala | 15 +++------------ .../arrow/ColumnarBatchArrowReader.scala | 2 +- .../execution/arrow/CometNativeArrowSource.scala | 12 ++++++++++++ .../shuffle/CometNativeShuffleWriter.scala | 10 +++++++--- .../org/apache/spark/sql/comet/operators.scala | 3 +++ .../org/apache/spark/sql/comet/util/Utils.scala | 7 +++---- 11 files changed, 48 insertions(+), 42 deletions(-) 
diff --git a/native/jni-bridge/src/arrow_array_stream.rs b/native/jni-bridge/src/arrow_array_stream.rs index 0b285607ff..2cfea73688 100644 --- a/native/jni-bridge/src/arrow_array_stream.rs +++ b/native/jni-bridge/src/arrow_array_stream.rs @@ -25,7 +25,7 @@ use jni::{ /// A struct that holds all the JNI methods and fields for JVM `org.apache.arrow.c.ArrowArrayStream` /// class. `memoryAddress()` is read once per partition so native can take ownership of the -/// underlying C struct via `ArrowArrayStreamReader::from_raw`. +/// underlying C struct via `AlignedArrowStreamReader::from_raw`. #[allow(dead_code)] // we need to keep references to Java items to prevent GC pub struct ArrowArrayStream<'a> { pub class: JClass<'a>, diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index d17735a560..c684e17a92 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -47,8 +47,11 @@ import org.apache.comet.vector.NativeUtil * `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is * done). * - * @param inputs - * The input iterators producing sequence of batches of Arrow Arrays. + * @param inputObjects + * Already-built native input slots, in scan-input order. Each slot is either an + * org.apache.arrow.c.ArrowArrayStream (consumed natively via from_raw against its + * memoryAddress) or a CometShuffleBlockIterator (consumed via the JNI block-iteration + * protocol). * @param protobufQueryPlan * The serialized bytes of Spark execution plan. 
* @param numParts @@ -79,11 +82,6 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - // Each input slot is either an org.apache.arrow.c.ArrowArrayStream (consumed natively via - // ArrowArrayStreamReader::from_raw against its memoryAddress) or a CometShuffleBlockIterator - // (consumed via the existing JNI block-iteration protocol). The slot index matches the scan - // input index in the serialized native plan. - private val inputIterators: Array[Object] = inputObjects private val plan = { val conf = SparkEnv.get.conf @@ -109,7 +107,7 @@ class CometExecIterator( nativeLib.createPlan( id, - inputIterators, + inputObjects, protobufQueryPlan, protobufSparkConfigs, numParts, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 4b411d87f7..7caa375cf5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -45,6 +45,13 @@ private[spark] class CometExecPartition( * (consumed natively via the C Stream Interface); shuffle input slots are `CometShuffledBatchRDD` * (consumed via `CometShuffleBlockIterator`). Slot order matches the scan-input order in the * serialized native plan. + * + * Solves the closure-capture problem: instead of capturing all partitions' data in the closure + * (which gets serialized to every task), each `CometExecPartition` carries only its own data. + * + * Does not handle DPP (InSubqueryExec), which is resolved in + * `CometIcebergNativeScanExec.serializedPartitionData` before this RDD is created. It does handle + * `ScalarSubquery` expressions by registering them with `CometScalarSubquery` before execution. 
*/ private[spark] class CometExecRDD( sc: SparkContext, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index e632190f0a..8145e563f4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -60,13 +60,8 @@ object CometExecUtils { val serializedPlan = CometExec.serializeNativePlan(limitOp) val inputSchema = Utils.fromAttributes(outputAttribute) childPlan.mapPartitionsWithIndexInternal { case (idx, iter) => - val stream = CometArrowStream.fromColumnarBatchIter( - iter, - inputSchema, - CometArrowStream.NATIVE_TIMEZONE, - "CometExecUtils-getNativeLimit") CometExec.getCometIterator( - Array(stream.asInstanceOf[Object]), + CometArrowStream.inputObjects(iter, inputSchema, "CometExecUtils-getNativeLimit"), outputAttribute.length, serializedPlan, numParts, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 7ba281a666..4cfdd11361 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -212,15 +212,12 @@ case class CometNativeWriteExec( modifiedNativeOp.writeTo(codedOutput) codedOutput.checkNoSpaceLeft() - val arrowStream = CometArrowStream.fromColumnarBatchIter( - iter, - CometUtils.fromAttributes(child.output), - CometArrowStream.NATIVE_TIMEZONE, - "CometNativeWriteExec") - val execIterator = new CometExecIterator( CometExec.newIterId, - Array(arrowStream.asInstanceOf[Object]), + CometArrowStream.inputObjects( + iter, + CometUtils.fromAttributes(child.output), + "CometNativeWriteExec"), numOutputCols, planBytes, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index e9b178dc6b..09dd944d93 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -144,13 +144,8 @@ case class CometTakeOrderedAndProjectExec( val numOutputCols = child.output.length val inputSchema = CometUtils.fromAttributes(child.output) childRDD.mapPartitionsWithIndexInternal { case (idx, iter) => - val stream = CometArrowStream.fromColumnarBatchIter( - iter, - inputSchema, - CometArrowStream.NATIVE_TIMEZONE, - "CometTakeOrderedAndProject-topK") CometExec.getCometIterator( - Array(stream.asInstanceOf[Object]), + CometArrowStream.inputObjects(iter, inputSchema, "CometTakeOrderedAndProject-topK"), numOutputCols, serializedTopK, numParts, @@ -178,13 +173,9 @@ case class CometTakeOrderedAndProjectExec( val finalOutputLength = output.length val finalInputSchema = CometUtils.fromAttributes(child.output) singlePartitionRDD.mapPartitionsInternal { iter => - val stream = CometArrowStream.fromColumnarBatchIter( - iter, - finalInputSchema, - CometArrowStream.NATIVE_TIMEZONE, - "CometTakeOrderedAndProject-final") val it = CometExec.getCometIterator( - Array(stream.asInstanceOf[Object]), + CometArrowStream + .inputObjects(iter, finalInputSchema, "CometTakeOrderedAndProject-final"), finalOutputLength, serializedTopKAndProjection, 1, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala index 2cb8746107..6cbba9e06e 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -35,7 +35,7 @@ import org.apache.comet.vector.{CometDictionaryVector, CometVector} * 
the source's `FieldVector`s into a transient `ArrowRecordBatch` (retains buffers), loads it * into this reader's stable VSR via `loadFieldBuffers` (release-and-replace), then closes the * source batch. The unload/load step decouples this reader's VSR ownership from whatever the - * upstream does with its own buffers. + * source does with its own buffers. */ private[comet] class ColumnarBatchArrowReader( allocator: BufferAllocator, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala index 727f761997..b6d904a189 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -87,6 +87,18 @@ object CometArrowStream extends Logging { .next() } + /** + * Build the `inputObjects` array that `CometExecIterator` / `CometExec.getCometIterator` pass + * to native `createPlan`, for the common case of a single scan input fed by one per-partition + * `Iterator[ColumnarBatch]`. The iterator is exported to one `ArrowArrayStream` (Arrow C + * Stream) and boxed as the lone element, using the native timezone. + */ + def inputObjects( + iter: Iterator[ColumnarBatch], + sparkSchema: StructType, + name: String): Array[Object] = + Array[Object](fromColumnarBatchIter(iter, sparkSchema, NATIVE_TIMEZONE, name)) + /** * Build the stream's advertised Arrow schema from the actual `CometVector` types in the first * batch, not from `expected` (which derives from the consumer's Spark-declared types). 
Native diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index be38aa0e92..5b549b992a 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -273,14 +273,14 @@ class CometNativeShuffleWriter[K, V]( // DataFusion will deduplicate identical sort expressions in LexOrdering, // so we need to transform boundary rows to match the deduplicated structure val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += (idx -> false) + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion } else { seenExprs += sortOrder.child - deduplicationMap += (idx -> true) + deduplicationMap += (idx -> true) // Will be kept by DataFusion } } @@ -296,6 +296,10 @@ class CometNativeShuffleWriter[K, V]( val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + // rangePartitionBounds holds Spark InternalRows of partitioning boundaries: each row is a + // boundary, each entry a value in that row (row-major, not column-major). Convert to + // Literals and keep only the entries whose ordering expression survived deduplication, so + // the boundary shape matches DataFusion's deduplicated LexOrdering. 
val transformedBoundaryExprs: Seq[Seq[Literal]] = rangePartitionBounds.get.map((row: InternalRow) => { val allLiterals = diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 93e668873b..3122bcfc80 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -496,9 +496,12 @@ abstract class CometNativeExec extends CometExec { mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] foreachUntilCometInput(this) { case scan: CometNativeScanExec => + // Bring in any SQLConf "spark.hadoop.*" configs and the per-relation options, since + // different tables may have different decryption properties. val hadoopConf = scan.relation.sparkSession.sessionState .newHadoopConfWithOptions(scan.relation.options) if (CometParquetUtils.encryptionEnabled(hadoopConf)) { + // hadoopConf isn't serializable, so ship it to executors via a broadcast. 
val broadcastedConf = scan.relation.sparkSession.sparkContext .broadcast(new SerializableConfiguration(hadoopConf)) encryptionOptions += ((broadcastedConf, scan.relation.inputFiles.toSeq)) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index dc54881681..15e1e2c410 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -36,6 +36,7 @@ import org.apache.arrow.vector.util.VectorSchemaRootAppender import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -210,10 +211,8 @@ object Utils extends CometTypeShim with Logging { * `StructType.fromAttributes` (removed in Spark 4) and `DataTypeUtils.fromAttributes` (only on * 4) so the same call works across supported Spark versions. */ - def fromAttributes( - attributes: Seq[org.apache.spark.sql.catalyst.expressions.Attribute]): StructType = - StructType(attributes.map(a => - org.apache.spark.sql.types.StructField(a.name, a.dataType, a.nullable, a.metadata))) + def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) /** * Serializes a list of `ColumnarBatch` into an output stream. 
This method must be in `spark` From 9f321573f77be559b8610ba0de28177733a12e95 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 15:20:07 -0400 Subject: [PATCH 34/39] cleanup --- .../apache/spark/sql/comet/CometExecRDD.scala | 75 ++++++++++++------- .../sql/comet/CometLocalTableScanExec.scala | 30 +++----- .../sql/comet/CometSparkToColumnarExec.scala | 70 ++++++----------- .../arrow/CometNativeArrowSource.scala | 26 ++++++- .../shuffle/CometNativeShuffleInputRDD.scala | 41 +++++----- .../shuffle/CometNativeShuffleWriter.scala | 13 +--- 6 files changed, 122 insertions(+), 133 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 7caa375cf5..caf639e792 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -96,31 +96,12 @@ private[spark] class CometExecRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometExecPartition] - val shuffleBlockIters = scala.collection.mutable.Map.empty[Int, CometShuffleBlockIterator] - val inputObjects: Array[Object] = inputRDDs - .zip(partition.inputPartitions) - .zipWithIndex - .map { case ((rdd, part), idx) => - if (shuffleScanIndices.contains(idx)) { - rdd match { - case shuffleRDD: CometShuffledBatchRDD => - val it = shuffleRDD.computeAsShuffleBlockIterator(part, context) - shuffleBlockIters(idx) = it - it.asInstanceOf[Object] - case other => - throw new CometRuntimeException( - s"Slot $idx is marked as a shuffle scan but the input RDD is " + - s"${other.getClass.getName}, expected CometShuffledBatchRDD") - } - } else { - val streams = rdd.iterator(part, context).asInstanceOf[Iterator[ArrowArrayStream]] - if (!streams.hasNext) { - throw new CometRuntimeException(s"Empty ArrowArrayStream RDD partition for slot $idx") - } - 
streams.next().asInstanceOf[Object] - } - } - .toArray + val (inputObjects, shuffleBlockIters) = + CometExecRDD.resolveInputObjects( + inputRDDs, + partition.inputPartitions, + shuffleScanIndices, + context) // Only inject if we have per-partition planning data val actualPlan = if (commonByKey.nonEmpty) { @@ -142,7 +123,7 @@ private[spark] class CometExecRDD( partition.index, broadcastedHadoopConfForEncryption, encryptedFilePaths, - shuffleBlockIters.toMap) + shuffleBlockIters) // Register ScalarSubqueries so native code can look them up subqueries.foreach(sub => CometScalarSubquery.setSubquery(it.id, sub)) @@ -175,6 +156,48 @@ private[spark] class CometExecRDD( object CometExecRDD { + /** + * Resolve the per-partition native input slots for `createPlan`, in scan-input order. A slot is + * either a `CometShuffleBlockIterator` (for slots in `shuffleScanIndices`, fed by a + * `CometShuffledBatchRDD` consumed via the JNI block-iteration protocol) or the single + * `ArrowArrayStream` exported by a non-shuffle `RDD[ArrowArrayStream]`. Returned alongside the + * subset that are shuffle-block iterators, which `CometExecIterator` needs to drive block + * iteration. Shared by [[CometExecRDD.compute]] and the native-shuffle path so both classify + * and resolve slots identically. 
+ */ + def resolveInputObjects( + inputRDDs: Seq[RDD[_]], + inputPartitions: Array[Partition], + shuffleScanIndices: Set[Int], + context: TaskContext): (Array[Object], Map[Int, CometShuffleBlockIterator]) = { + val shuffleBlockIters = scala.collection.mutable.Map.empty[Int, CometShuffleBlockIterator] + val inputObjects: Array[Object] = inputRDDs + .zip(inputPartitions) + .zipWithIndex + .map { case ((rdd, part), idx) => + if (shuffleScanIndices.contains(idx)) { + rdd match { + case shuffleRDD: CometShuffledBatchRDD => + val it = shuffleRDD.computeAsShuffleBlockIterator(part, context) + shuffleBlockIters(idx) = it + it.asInstanceOf[Object] + case other => + throw new CometRuntimeException( + s"Slot $idx is marked as a shuffle scan but the input RDD is " + + s"${other.getClass.getName}, expected CometShuffledBatchRDD") + } + } else { + val streams = rdd.iterator(part, context).asInstanceOf[Iterator[ArrowArrayStream]] + if (!streams.hasNext) { + throw new CometRuntimeException(s"Empty ArrowArrayStream RDD partition for slot $idx") + } + streams.next().asInstanceOf[Object] + } + } + .toArray + (inputObjects, shuffleBlockIters.toMap) + } + /** * Creates an RDD for native execution with optional per-partition planning data. 
*/ diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 570848b70a..9c7f81e2ea 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag -import org.apache.arrow.c.ArrowArrayStream +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} @@ -30,7 +32,6 @@ import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{DataType, NullType} -import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.base.Objects @@ -79,13 +80,17 @@ case class CometLocalTableScanExec( } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { + /** + * Build the per-partition `RowArrowReader`; the trait routes it to the JVM or native consumer. 
+ */ + override protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] = { val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) val sparkSchema = originalPlan.schema rdd.mapPartitionsInternal { rowIter => val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.readerBatchIter( + consume( "CometLocalTableScan", new RowArrowReader( _, @@ -95,23 +100,6 @@ case class CometLocalTableScanExec( } } - override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { - val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - val sparkSchema = originalPlan.schema - val numOutputRows = longMetric("numOutputRows") - rdd.mapPartitionsInternal { rowIter => - val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.stream( - "CometLocalTableScan", - allocator => - new RowArrowReader( - allocator, - arrowSchema, - countingRows(rowIter, numOutputRows), - maxRecordsPerBatch)) - } - } - override protected def stringArgs: Iterator[Any] = { if (rows.isEmpty) { Iterator("", output) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index 00e13bcbde..48be16100f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag -import org.apache.arrow.c.ArrowArrayStream +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -73,13 +75,11 @@ case class 
CometSparkToColumnarExec(child: SparkPlan) private def countingBatches( iter: Iterator[ColumnarBatch], - numInputRows: SQLMetric, - numOutputBatches: SQLMetric): Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + numInputRows: SQLMetric): Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { override def hasNext: Boolean = iter.hasNext override def next(): ColumnarBatch = { val batch = iter.next() numInputRows += batch.numRows() - numOutputBatches += 1 batch } } @@ -95,74 +95,50 @@ case class CometSparkToColumnarExec(child: SparkPlan) } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { + /** + * Build the per-partition `ArrowReader` (columnar or row, depending on the child); the trait + * routes it to the JVM or native consumer. + * + * `numOutputBatches` is incremented from the reader's per-produced-batch callback rather than + * by counting input batches, so it stays accurate on the native path too (native drives + * `loadNextBatch`) and counts produced Arrow batches, not Spark input batches. 
+ */ + override protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] = { val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) val sparkSchema = child.schema + val onConversionNs: Long => Unit = ns => { + conversionTime += ns + numOutputBatches += 1 + } if (child.supportsColumnar) { val maxBatchInt = maxRecordsPerBatch.toInt child.executeColumnar().mapPartitionsInternal { sparkBatches => val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.readerBatchIter( + consume( "CometSparkColumnarToColumnar", new SparkColumnarArrowReader( _, arrowSchema, - countingBatches(sparkBatches, numInputRows, numOutputBatches), + countingBatches(sparkBatches, numInputRows), maxBatchInt, - ns => conversionTime += ns)) + onConversionNs)) } } else { child.execute().mapPartitionsInternal { rowIter => val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.readerBatchIter( + consume( "CometSparkRowToColumnar", new RowArrowReader( _, arrowSchema, countingRows(rowIter, numInputRows), maxRecordsPerBatch, - ns => conversionTime += ns)) - } - } - } - - override def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = { - val numInputRows = longMetric("numInputRows") - val numOutputBatches = longMetric("numOutputBatches") - val conversionTime = longMetric("conversionTime") - val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - val sparkSchema = child.schema - - if (child.supportsColumnar) { - val maxBatchInt = maxRecordsPerBatch.toInt - child.executeColumnar().mapPartitionsInternal { sparkBatches => - val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.stream( - "CometSparkColumnarToColumnar", - allocator => - new 
SparkColumnarArrowReader( - allocator, - arrowSchema, - countingBatches(sparkBatches, numInputRows, numOutputBatches), - maxBatchInt, - ns => conversionTime += ns)) - } - } else { - child.execute().mapPartitionsInternal { rowIter => - val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) - CometArrowStream.stream( - "CometSparkRowToColumnar", - allocator => - new RowArrowReader( - allocator, - arrowSchema, - countingRows(rowIter, numInputRows), - maxRecordsPerBatch, - ns => conversionTime += ns)) + onConversionNs)) } } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala index b6d904a189..6415256a39 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.comet.execution.arrow import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag import org.apache.arrow.c.{ArrowArrayStream, Data} import org.apache.arrow.memory.BufferAllocator @@ -37,11 +38,30 @@ import org.apache.comet.CometArrowAllocator import org.apache.comet.vector.{CometDictionaryVector, CometVector, NativeUtil} /** - * Marker for Comet operators that can produce Arrow data destined for a Comet native executor - * directly as the C Stream Interface, skipping the intermediate `RDD[ColumnarBatch]` layer. + * A Comet operator that produces its output as Arrow data, consumable either as JVM + * `ColumnarBatch`es (`doExecuteColumnar`) or, when the consumer is a Comet native executor, + * directly as the Arrow C Stream Interface (`doExecuteAsArrowStream`), skipping the intermediate + * `RDD[ColumnarBatch]` layer. 
+ * + * Implementors supply only [[mapToReaders]] (their source RDD + per-partition `ArrowReader`); the + * two execution paths here differ solely in whether each partition's reader is drained into + * `ColumnarBatch`es or exported as a stream. */ trait CometNativeArrowSource extends SparkPlan { - def doExecuteAsArrowStream(): RDD[ArrowArrayStream] + + /** + * Build this operator's per-partition `ArrowReader` and hand it to `consume`, returning the + * output RDD. `consume` is provided by this trait: `CometArrowStream.readerBatchIter` for the + * JVM columnar path, `CometArrowStream.stream` for the native C Stream path. + */ + protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + mapToReaders(CometArrowStream.readerBatchIter) + + def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = + mapToReaders(CometArrowStream.stream) } object CometArrowStream extends Logging { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala index 44acc6390d..0579e57bce 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala @@ -21,16 +21,18 @@ package org.apache.spark.sql.comet.execution.shuffle import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.CometExecRDD import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.CometShuffleBlockIterator /** * Thin scheduling-anchor RDD for the native-shuffle path. Declares `OneToOneDependency` on each - * leaf input RDD (so the DAGScheduler triggers prior stages, broadcasts, etc.) 
and constructs - * per-partition leaf iterators in `compute`, packaged into a [[CometNativeShuffleInputIterator]]. - * The iterator reports `hasNext = false`; [[CometNativeShuffleWriter]] downcasts it and reads the - * leaf iterators directly to drive the unified `ShuffleWriter(child = childNativeOp)` plan. + * leaf input RDD (so the DAGScheduler triggers prior stages, broadcasts, etc.) and resolves the + * per-partition native input slots in `compute`, packaged into a + * [[CometNativeShuffleInputIterator]]. The iterator reports `hasNext = false`; + * [[CometNativeShuffleWriter]] downcasts it and reads those slots directly to drive the unified + * `ShuffleWriter(child = childNativeOp)` plan. */ private[shuffle] class CometNativeShuffleInputRDD( sc: SparkContext, @@ -55,23 +57,13 @@ private[shuffle] class CometNativeShuffleInputRDD( split: Partition, context: TaskContext): Iterator[Product2[Int, ColumnarBatch]] = { val partition = split.asInstanceOf[CometNativeShuffleInputPartition] - val shuffleBlockIters: Map[Int, CometShuffleBlockIterator] = - shuffleScanIndices.flatMap { si => - inputRDDs(si) match { - case rdd: CometShuffledBatchRDD => - Some(si -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(si), context)) - case _ => None - } - }.toMap - // Non-shuffle leaves are RDD[ArrowArrayStream] (one stream per partition); pull them lazily in - // the writer. Shuffle leaves reach native via the block iterators above, so leave an empty - // placeholder to keep slot indices aligned (the writer never drains those slots). 
- val leafIterators: Seq[Iterator[_]] = - inputRDDs.zip(partition.inputPartitions).zipWithIndex.map { case ((rdd, part), idx) => - if (shuffleScanIndices.contains(idx)) Iterator.empty - else rdd.iterator(part, context) - } - new CometNativeShuffleInputIterator(partition.index, leafIterators, shuffleBlockIters) + val (inputObjects, shuffleBlockIters) = + CometExecRDD.resolveInputObjects( + inputRDDs, + partition.inputPartitions, + shuffleScanIndices, + context) + new CometNativeShuffleInputIterator(partition.index, inputObjects, shuffleBlockIters) } override def getPreferredLocations(split: Partition): Seq[String] = { @@ -97,12 +89,13 @@ private[shuffle] class CometNativeShuffleInputPartition( /** * Iterator handed to [[CometNativeShuffleWriter.write]] via Spark's ShuffleMapTask. Reports no - * elements; the writer downcasts and reads `partitionIndex`, `leafIterators`, and - * `shuffleBlockIterators` directly to drive the unified native plan. + * elements; the writer downcasts and reads `partitionIndex`, `inputObjects`, and + * `shuffleBlockIterators` directly to drive the unified native plan. `inputObjects` are the + * already-resolved native input slots (see [[CometExecRDD.resolveInputObjects]]). 
*/ private[shuffle] class CometNativeShuffleInputIterator( val partitionIndex: Int, - val leafIterators: Seq[Iterator[_]], + val inputObjects: Array[Object], val shuffleBlockIterators: Map[Int, CometShuffleBlockIterator]) extends Iterator[Product2[Int, ColumnarBatch]] { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 5b549b992a..5248974c47 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -91,7 +91,7 @@ class CometNativeShuffleWriter[K, V]( s"${other.getClass.getName}") } val partitionIdx = shuffleInputIter.partitionIndex - val leafIterators = shuffleInputIter.leafIterators + val inputObjects = shuffleInputIter.inputObjects val shuffleBlockIters = shuffleInputIter.shuffleBlockIterators val unifiedPlan = buildUnifiedPlan(tempDataFilename, tempIndexFilename) @@ -124,17 +124,6 @@ class CometNativeShuffleWriter[K, V]( // breakdown matches what the split-driver flow showed. val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(spec.childMetricNode)) - // Each leaf input arrives already wrapped as an ArrowArrayStream (one per partition) by the - // arrow C Stream path in CometNativeExec.buildNativeContext; shuffle leaves arrive as a - // CometShuffleBlockIterator. Mirror CometExecRDD.compute: hand native the already-built - // stream/block object per slot. No re-wrapping here. 
- val inputObjects: Array[Object] = leafIterators.zipWithIndex.map { case (it, idx) => - shuffleBlockIters.get(idx) match { - case Some(blockIter) => blockIter.asInstanceOf[Object] - case None => it.next().asInstanceOf[Object] - } - }.toArray - val cometIter = new CometExecIterator( CometExec.newIterId, inputObjects, From bf63e3dde27a656687c934335b6477e57dce8809 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 15:26:58 -0400 Subject: [PATCH 35/39] cleanup --- .../arrow/CometArrowConverters.scala | 60 +++---------------- 1 file changed, 8 insertions(+), 52 deletions(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 32441029bb..2d4fd71376 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -26,20 +26,19 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.comet.vector.NativeUtil /** - * Convert Spark `InternalRow` / `ColumnarBatch` streams to a stream of independently-owned Arrow - * `ColumnarBatch`es. Each emitted batch owns a fresh `VectorSchemaRoot` with newly allocated - * buffers; the consumer is responsible for closing the batch. + * Convert a stream of Spark `InternalRow`s to a stream of independently-owned Arrow + * `ColumnarBatch`es: each emitted batch owns a fresh `VectorSchemaRoot` with newly allocated + * buffers and the consumer is responsible for closing it. * - * Buffers are allocated from the caller-provided `BufferAllocator`. 
The caller owns the - * allocator's lifecycle (typically a child allocator closed at task completion). When emitted - * batches reach `ColumnarBatchArrowReader.loadNextBatch`, ownership of their buffers is - * transferred (via `VectorUnloader` / `loadFieldBuffers`) to the reader's allocator, after which - * the source batch is closed and the producer's allocator returns to zero outstanding bytes. + * This differs from [[RowArrowReader]], which reuses one stable `VectorSchemaRoot` + * (release-and-replace) so only one batch is valid at a time. Use this when multiple emitted + * batches must be alive simultaneously (e.g. tests that buffer several batches before consuming). + * Buffers come from the caller-provided `BufferAllocator`, whose lifecycle the caller owns. */ object CometArrowConverters extends Logging { @@ -75,47 +74,4 @@ object CometArrowConverters extends Logging { } } } - - /** - * Slice a single Spark `ColumnarBatch` into one or more Arrow `ColumnarBatch`es of at most - * `maxRecordsPerBatch` rows each. Each emitted batch owns a fresh `VectorSchemaRoot`. 
- */ - def columnarBatchToArrowBatchIter( - colBatch: ColumnarBatch, - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - allocator: BufferAllocator): Iterator[ColumnarBatch] = { - val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) - val totalRows = colBatch.numRows() - - new Iterator[ColumnarBatch] { - private var rowsProduced: Int = 0 - - override def hasNext: Boolean = rowsProduced < totalRows - - override def next(): ColumnarBatch = { - val rowsToProduce = - if (maxRecordsPerBatch <= 0) totalRows - rowsProduced - else math.min(maxRecordsPerBatch, totalRows - rowsProduced) - - val root = VectorSchemaRoot.create(arrowSchema, allocator) - val writer = ArrowWriter.create(root) - - for (columnIndex <- 0 until colBatch.numCols()) { - val column = colBatch.column(columnIndex) - val columnArray = new ColumnarArray(column, rowsProduced, rowsToProduce) - if (column.hasNull) { - writer.writeCol(columnArray, columnIndex) - } else { - writer.writeColNoNull(columnArray, columnIndex) - } - } - - rowsProduced += rowsToProduce - writer.finish() - NativeUtil.rootAsBatch(root) - } - } - } } From 04c082569ac38f703712902442b952f4ae632273 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 16:16:46 -0400 Subject: [PATCH 36/39] Address PR feedback from #4507. 
--- .../sql/comet/execution/shuffle/CometShuffleExchangeExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 62691567c5..565183278d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -689,7 +689,7 @@ object CometShuffleExchangeExec metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val scanTypes = outputAttributes.flatten { attr => + val scanTypes = outputAttributes.flatMap { attr => QueryPlanSerde.serializeDataType(attr.dataType) } if (scanTypes.length != outputAttributes.length) { From 6aa6621ed7e780c1ef59b65ac5318d5f7399ab68 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Tue, 2 Jun 2026 19:39:36 -0400 Subject: [PATCH 37/39] Remove inadvertent test change brought over from #4393. 
--- .../org/apache/comet/exec/CometWindowExecSuite.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala index a9fdc96231..544cd91bd2 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala @@ -108,7 +108,15 @@ class CometWindowExecSuite extends CometTestBase { val cometShuffles = collect(df2.queryExecution.executedPlan) { case _: CometShuffleExchangeExec => true } - assert(cometShuffles.length == 1) + if (shuffleMode == "jvm" || shuffleMode == "auto") { + assert(cometShuffles.length == 1) + } else { + // we fall back to Spark for shuffle because we do not support + // native shuffle with a LocalTableScan input, and we do not fall + // back to Comet columnar shuffle due to + // https://github.com/apache/datafusion-comet/issues/1248 + assert(cometShuffles.isEmpty) + } } } } From b4cb826b858053c7ca4a2e35c05c06a62d16b0db Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Thu, 4 Jun 2026 18:01:22 -0400 Subject: [PATCH 38/39] Fix batch size calculation in native shuffle writer, and task metrics on the Scala side. 
--- .../src/partitioners/multi_partition.rs | 65 +++++++++++++++++-- native/shuffle/src/shuffle_writer.rs | 56 +++++++++++++++- .../shuffle/CometNativeShuffleWriter.scala | 7 ++ .../sql/comet/CometTaskMetricsSuite.scala | 51 +++++++++++++++ 4 files changed, 172 insertions(+), 7 deletions(-) diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 7de9314f54..40f09496c0 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -22,10 +22,10 @@ use crate::partitioners::partitioned_batch_iterator::{ use crate::partitioners::ShufflePartitioner; use crate::writers::{BufBatchWriter, PartitionWriter}; use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayData, ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion::common::utils::proxy::VecAllocExt; -use datafusion::common::DataFusionError; +use datafusion::common::{DataFusionError, HashSet}; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::metrics::Time; @@ -125,6 +125,55 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Start addresses (as `usize`, since raw pointers are not `Send`) of the backing buffers + /// currently pinned by `buffered_batches`, so the spill reservation charges each distinct + /// allocation once rather than once per slice that references it. Cleared whenever the + /// buffered batches drain (spill / shuffle_write). See `count_new_buffers`. 
+ pinned_buffers: HashSet, +} + +/// Sum of the capacities of the backing buffers reachable from `batch` whose start address is +/// not already in `seen` (recursing through child data: dictionary values, list children, and so +/// on). `seen` is kept across every buffered batch, so this returns the bytes a batch newly +/// pins, which is the memory the shuffle writer holds resident by buffering it. +/// +/// Cheaper measures do not match resident memory for the batches this writer sees. A partial +/// `HashAggregate` emits one group-values buffer sliced into batch_size chunks, and every +/// buffered chunk shares that one allocation: +/// +/// * `RecordBatch::get_array_memory_size()` charges a buffer's capacity once per array that +/// references it, counting the shared allocation once per chunk and overstating memory by the +/// chunk count. Reserving against that figure trips the memory limit on nearly every batch +/// and spills spuriously. +/// * the sum of `ArrayData::get_slice_memory_size()` charges only the live rows of each slice, +/// but holding a slice pins its whole backing allocation. The group-values `Vec` rounds +/// capacity up to the next power of two, so that figure undercounts resident memory and lets +/// the writer hold well past its limit before spilling. +/// +/// Counting each distinct allocation once, keyed by start address, is the measure that tracks +/// resident memory regardless of how arrays share or slice their buffers. 
+fn count_new_buffers(batch: &RecordBatch, seen: &mut HashSet) -> usize { + fn visit(data: &ArrayData, seen: &mut HashSet, total: &mut usize) { + for buffer in data.buffers() { + if seen.insert(buffer.data_ptr().as_ptr() as usize) { + *total += buffer.capacity(); + } + } + if let Some(nulls) = data.nulls() { + let inner = nulls.inner().inner(); + if seen.insert(inner.data_ptr().as_ptr() as usize) { + *total += inner.capacity(); + } + } + for child in data.child_data() { + visit(child, seen, total); + } + } + let mut total = 0; + for column in batch.columns() { + visit(&column.to_data(), seen, &mut total); + } + total } impl MultiPartitionShuffleRepartitioner { @@ -190,6 +239,7 @@ impl MultiPartitionShuffleRepartitioner { reservation, tracing_enabled, write_buffer_size, + pinned_buffers: HashSet::new(), }) } @@ -210,9 +260,6 @@ impl MultiPartitionShuffleRepartitioner { )); } - // Update data size metric - self.metrics.data_size.add(input.get_array_memory_size()); - // NOTE: in shuffle writer exec, the output_rows metrics represents the // number of rows those are written to output data file. self.metrics.baseline.record_output(input.num_rows()); @@ -398,7 +445,11 @@ impl MultiPartitionShuffleRepartitioner { partition_row_indices: &[u32], partition_starts: &[u32], ) -> datafusion::common::Result<()> { - let mut mem_growth: usize = input.get_array_memory_size(); + // Charge both the reservation and the data_size metric for the buffers this batch newly + // pins; `count_new_buffers` dedups buffers shared across already-buffered batches. 
+ let new_buffer_bytes = count_new_buffers(&input, &mut self.pinned_buffers); + self.metrics.data_size.add(new_buffer_bytes); + let mut mem_growth: usize = new_buffer_bytes; let buffered_partition_idx = self.buffered_batches.len() as u32; self.buffered_batches.push(input); @@ -517,6 +568,7 @@ impl MultiPartitionShuffleRepartitioner { } self.reservation.free(); + self.pinned_buffers.clear(); self.metrics.spill_count.add(1); self.metrics.spilled_bytes.add(spilled_bytes); Ok(()) @@ -560,6 +612,7 @@ impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { let start_time = Instant::now(); let mut partitioned_batches = self.partitioned_batches(); + self.pinned_buffers.clear(); let num_output_partitions = self.partition_indices.len(); let mut offsets = vec![0; num_output_partitions + 1]; diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index b96917550f..7b6f4ca7e2 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -267,7 +267,7 @@ async fn external_shuffle( mod test { use super::*; use crate::{read_ipc_compressed, ShuffleBlockWriter}; - use arrow::array::{Array, StringArray, StringBuilder}; + use arrow::array::{Array, Int64Array, StringArray, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; @@ -389,6 +389,60 @@ mod test { repartitioner.insert_batch(batch.clone()).await.unwrap(); } + #[tokio::test] + async fn shuffle_partitioner_charges_shared_buffer_once() { + // `insert_batch` slices a large batch into batch_size chunks that all share one backing + // buffer (the shape a partial HashAggregate's sliced emit hands the writer). The + // reservation and the data_size metric must charge that buffer once, not once per chunk; + // otherwise the chunk count multiplies it and a batch well under the memory limit spills + // spuriously and reports a wildly inflated data_size. 
+ let n = 16_384usize; + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let backing = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from_iter_values(0..n as i64))], + ) + .unwrap(); + let buffer_bytes = backing.get_array_memory_size(); + + let memory_limit = 512 * 1024; + let batch_size = 1024; // 16 chunks, all sharing the one backing buffer + let num_partitions = 2; + let runtime_env = create_runtime(memory_limit); + let metrics_set = ExecutionPlanMetricsSet::new(); + let metrics = ShufflePartitionerMetrics::new(&metrics_set, 0); + let data_size = metrics.data_size.clone(); + let spill_count = metrics.spill_count.clone(); + let dir = tempfile::tempdir().unwrap(); + let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new( + 0, + dir.path().join("data.out").to_str().unwrap().to_string(), + dir.path().join("index.out").to_str().unwrap().to_string(), + backing.schema(), + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), + metrics, + runtime_env, + batch_size, + CompressionCodec::Lz4Frame, + false, + 1024 * 1024, + ) + .unwrap(); + + repartitioner.insert_batch(backing).await.unwrap(); + + assert!( + data_size.value() <= 2 * buffer_bytes, + "data_size {} should charge the shared buffer about once (~{buffer_bytes} bytes), not per chunk", + data_size.value() + ); + assert_eq!( + spill_count.value(), + 0, + "one buffer under the memory limit must not spill once per chunk" + ); + } + fn create_runtime(memory_limit: usize) -> Arc { Arc::new( RuntimeEnvBuilder::new() diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index 5248974c47..8cb1eb86ad 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -124,6 +124,13 @@ class CometNativeShuffleWriter[K, V]( // breakdown matches what the split-driver flow showed. val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(spec.childMetricNode)) + // The leaf scans execute inside this writer's single plan rather than a separate native + // stage RDD, so the usual CometExecRDD.compute() bridge (operators.scala) never runs for + // them. Report their bytes/rows to the task's input metrics here instead. + if (ctx.hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) + } + val cometIter = new CometExecIterator( CometExec.newIterId, inputObjects, diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 0187aed8e5..fdcca8d351 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -106,6 +106,57 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("native shuffle reports task input metrics for its scan child") { + // A native shuffle whose child subtree includes a CometNativeScanExec runs that scan inside + // the writer's single plan, so the usual CometExecRDD.compute() input-metric bridge never + // runs. CometNativeShuffleWriter must report the scan's bytes/rows itself; otherwise the + // ShuffleMapTask reports zero input metrics. 
+ withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") { + val shuffled = sql("SELECT * FROM tbl").repartition(4, $"_1") + + val cometShuffle = find(shuffled.queryExecution.executedPlan) { + case _: CometShuffleExchangeExec => true + case _ => false + } + assert(cometShuffle.isDefined, "CometShuffleExchangeExec not found in the plan") + assert( + cometShuffle.get.asInstanceOf[CometShuffleExchangeExec].shuffleType == CometNativeShuffle) + assert( + find(shuffled.queryExecution.executedPlan) { + case _: CometNativeScanExec => true + case _ => false + }.isDefined, + "expected a CometNativeScanExec child so the scan is embedded in the writer plan") + + val mapInputBytes = mutable.ArrayBuffer.empty[Long] + val mapInputRecords = mutable.ArrayBuffer.empty[Long] + spark.sparkContext.addSparkListener(new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskType.contains("ShuffleMapTask")) { + val im = taskEnd.taskMetrics.inputMetrics + mapInputBytes.synchronized { mapInputBytes += im.bytesRead } + mapInputRecords.synchronized { mapInputRecords += im.recordsRead } + } + } + }) + + // Avoid receiving earlier taskEnd events + spark.sparkContext.listenerBus.waitUntilEmpty() + + shuffled.collect() + + spark.sparkContext.listenerBus.waitUntilEmpty() + + assert(mapInputRecords.nonEmpty, "no ShuffleMapTask metrics captured") + assert( + mapInputRecords.sum == 10000, + s"recordsRead across map tasks (${mapInputRecords.sum}) should equal the scanned row count") + assert( + mapInputBytes.sum > 0, + s"bytesRead across map tasks (${mapInputBytes.sum}) should be > 0") + } + } + test("native parquet write reports task-level output metrics") { withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") { withTempPath { dir => From e1a94919971991c871927d073be2dca5acdf82d3 Mon Sep 17 00:00:00 2001 From: Matt Butrovich Date: Fri, 5 Jun 2026 09:32:54 -0400 Subject: [PATCH 39/39] Move SchemaAlignExec under shuffle. 
--- native/core/src/execution/operators/mod.rs | 2 -- native/core/src/execution/planner.rs | 6 ++---- native/shuffle/src/lib.rs | 2 ++ .../operators => shuffle/src}/schema_align.rs | 21 +++++++++++++++---- 4 files changed, 21 insertions(+), 10 deletions(-) rename native/{core/src/execution/operators => shuffle/src}/schema_align.rs (89%) diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index cd7507ebdb..d68252bd9b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -34,8 +34,6 @@ pub use parquet_writer::ParquetWriterExec; mod csv_scan; pub mod projection; mod scan; -mod schema_align; mod shuffle_scan; pub use csv_scan::init_csv_datasource_exec; -pub use schema_align::SchemaAlignExec; pub use shuffle_scan::ShuffleScanExec; diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index f96923bd45..c09ed5a0ef 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -28,13 +28,11 @@ use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::list_positions::ListPositionsExpr, expressions::subquery::Subquery, - operators::{ - ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, SchemaAlignExec, ShuffleScanExec, - }, + operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec, ShuffleScanExec}, planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, - shuffle::ShuffleWriterExec, + shuffle::{SchemaAlignExec, ShuffleWriterExec}, }; use crate::jvm_bridge::{jni_call, JVMClasses}; use arrow::compute::CastOptions; diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs index dd3b900272..2263ae0dac 100644 --- a/native/shuffle/src/lib.rs +++ b/native/shuffle/src/lib.rs @@ -19,6 +19,7 @@ pub(crate) mod comet_partitioning; pub mod ipc; pub(crate) mod metrics; pub(crate) mod partitioners; 
+mod schema_align; mod shuffle_writer; mod spark_crc32c_hasher; pub mod spark_unsafe; @@ -26,5 +27,6 @@ pub(crate) mod writers; pub use comet_partitioning::CometPartitioning; pub use ipc::read_ipc_compressed; +pub use schema_align::SchemaAlignExec; pub use shuffle_writer::ShuffleWriterExec; pub use writers::{CompressionCodec, ShuffleBlockWriter}; diff --git a/native/core/src/execution/operators/schema_align.rs b/native/shuffle/src/schema_align.rs similarity index 89% rename from native/core/src/execution/operators/schema_align.rs rename to native/shuffle/src/schema_align.rs index 58983d132c..f564afc024 100644 --- a/native/core/src/execution/operators/schema_align.rs +++ b/native/shuffle/src/schema_align.rs @@ -15,10 +15,23 @@ // specific language governing permissions and limitations // under the License. -//! `SchemaAlignExec` reshapes its child's output so the per-column Arrow type and field-level -//! nullability match what Spark catalyst declared, casting where necessary. Sits between a native -//! subtree and `ShuffleWriterExec` so DataFusion / `datafusion-spark` return-type drift is caught -//! before it reaches shuffle blocks. See +//! `SchemaAlignExec` reshapes a shuffle writer's input so each column's Arrow type and field-level +//! nullability match what Spark catalyst declared, casting where necessary. +//! +//! This concern is enclosed by shuffle on purpose: everywhere else in the native runtime, +//! return-type drift from DataFusion / `datafusion-spark` is self-healing. When a native plan's +//! output crosses back to the JVM and feeds another native plan, the consuming `ScanExec` casts +//! every imported column to the catalyst-declared type, so a wrong Arrow type never survives the +//! boundary. Shuffle is the lone exception, on two counts: +//! +//! 1. The writer hash-partitions on these columns, and Spark's hash differs by type (e.g. `Int32` +//! vs `Int64`), so a drifted type would route rows to the wrong partition. A read-side cast +//! 
cannot undo a wrong partition assignment, so the type must be corrected before partitioning. +//! 2. The shuffle read path (`ShuffleScanExec`) does not cast; it stamps the catalyst schema onto +//! the decoded block and errors on any mismatch. The schema is serialized into the block on +//! write and trusted on read. +//! +//! Both force the alignment to happen on the writer input. See //! for the running list of mismatched //! functions.