From dee4966725456603abbb993039d2232d527d7994 Mon Sep 17 00:00:00 2001
From: shekharrajak
Date: Sat, 27 Dec 2025 23:28:08 +0530
Subject: [PATCH] Add config to disable columnar shuffle for complex types

---
 .../scala/org/apache/comet/CometConf.scala    |  11 ++
 .../shuffle/CometShuffleExchangeExec.scala    |  46 +++++---
 .../exec/CometColumnarShuffleSuite.scala      | 102 +++++++++++++-----
 3 files changed, 117 insertions(+), 42 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index cccad53c53..23678e6365 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -376,6 +376,17 @@ object CometConf extends ShimCometConf {
       .intConf
       .createWithDefault(1)
 
+  val COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.columnar.shuffle.complexTypes.enabled")
+      .category(CATEGORY_SHUFFLE)
+      .doc(
+        "Whether to enable Comet columnar shuffle for complex types (struct, array, map). " +
+          "When disabled (the default), queries with complex types fall back to Spark " +
+          "shuffle, which typically performs better for these types. Enable this only if " +
+          "you need columnar shuffle for complex types and accept the performance tradeoffs.")
+      .booleanConf
+      .createWithDefault(false)
+
   val COMET_COLUMNAR_SHUFFLE_ASYNC_ENABLED: ConfigEntry[Boolean] =
     conf("spark.comet.columnar.shuffle.async.enabled")
       .category(CATEGORY_SHUFFLE)
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
index 2e6ab9aff9..964ece5506 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
 import com.google.common.base.Objects
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometConf.{COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
+import org.apache.comet.CometConf.{COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED, COMET_EXEC_SHUFFLE_ENABLED, COMET_SHUFFLE_MODE}
 import org.apache.comet.CometSparkSessionExtensions.{isCometShuffleManagerEnabled, withInfo}
 import org.apache.comet.serde.{Compatible, OperatorOuterClass, QueryPlanSerde, SupportLevel, Unsupported}
 import org.apache.comet.serde.operator.CometSink
@@ -403,23 +403,39 @@ object CometShuffleExchangeExec
    *
    * Comet columnar shuffle uses native code to convert Spark unsafe rows to Arrow batches, see
    * shuffle/row.rs
+   *
+   * Returns None if supported, or Some(reason) if not supported.
    */
-  def supportedSerializableDataType(dt: DataType): Boolean = dt match {
+  def supportedSerializableDataType(dt: DataType): Option[String] = dt match {
     case _: BooleanType | _: ByteType | _: ShortType | _: IntegerType | _: LongType |
         _: FloatType | _: DoubleType | _: StringType | _: BinaryType | _: TimestampType |
         _: TimestampNTZType | _: DecimalType | _: DateType =>
-      true
+      None
     case StructType(fields) =>
-      fields.nonEmpty && fields.forall(f => supportedSerializableDataType(f.dataType)) &&
-      // Java Arrow stream reader cannot work on duplicate field name
-      fields.map(f => f.name).distinct.length == fields.length &&
-      fields.nonEmpty
+      if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
+        Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
+      } else if (fields.isEmpty) {
+        Some("struct type with no fields is not supported")
+      } else if (fields.map(f => f.name).distinct.length != fields.length) {
+        // The Java Arrow stream reader cannot handle duplicate field names
+        Some("struct type with duplicate field names is not supported")
+      } else {
+        fields.flatMap(f => supportedSerializableDataType(f.dataType)).headOption
+      }
     case ArrayType(elementType, _) =>
-      supportedSerializableDataType(elementType)
+      if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
+        Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
+      } else {
+        supportedSerializableDataType(elementType)
+      }
     case MapType(keyType, valueType, _) =>
-      supportedSerializableDataType(keyType) && supportedSerializableDataType(valueType)
+      if (!COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.get(s.conf)) {
+        Some(s"${COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key} is not enabled")
+      } else {
+        supportedSerializableDataType(keyType).orElse(supportedSerializableDataType(valueType))
+      }
     case _ =>
-      false
+      Some(s"unsupported data type: $dt")
   }
 
   if (!isCometShuffleEnabledWithInfo(s)) {
@@ -444,9 +460,13 @@ object CometShuffleExchangeExec
     val inputs = s.child.output
 
     for (input <- inputs) {
-      if (!supportedSerializableDataType(input.dataType)) {
-        withInfo(s, s"unsupported shuffle data type ${input.dataType} for input $input")
-        return false
+      supportedSerializableDataType(input.dataType) match {
+        case Some(reason) =>
+          withInfo(
+            s,
+            s"unsupported data type ${input.dataType} for column ${input.name}: $reason")
+          return false
+        case None => // supported
       }
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
index 70479f0e34..076e9465cd 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometColumnarShuffleSuite.scala
@@ -112,12 +112,42 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     checkSparkAnswer(df)
   }
 
+  test("Fallback to Spark for complex types when config is disabled (default)") {
+    // https://github.com/apache/datafusion-comet/issues/2904
+    // By default, complex types should fall back to Spark shuffle for better performance
+    withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "false") {
+      // Test struct type
+      withParquetTable(Seq((1, (0, "1")), (2, (3, "3"))), "tbl") {
+        val df = sql("SELECT * FROM tbl").repartition(10, $"_1", $"_2")
+        // Expect 0 Comet shuffle exchanges since complex types are disabled
+        checkCometExchange(df, 0, false)
+        checkSparkAnswer(df)
+      }
+
+      // Test array type
+      withParquetTable((0 until 10).map(i => (Seq(i, i + 1), i + 1)), "tbl2") {
+        val df = sql("SELECT * FROM tbl2").repartition(10, $"_1", $"_2")
+        checkCometExchange(df, 0, false)
+        checkSparkAnswer(df)
+      }
+
+      // Test map type
+      withParquetTable((0 until 10).map(i => (Map(i -> i.toString), i + 1)), "tbl3") {
+        val df = sql("SELECT * FROM tbl3").repartition(10, $"_1", $"_2")
+        checkCometExchange(df, 0, false)
+        checkSparkAnswer(df)
+      }
+    }
+  }
+
   test("columnar shuffle on nested struct including nulls") {
     // https://github.com/apache/datafusion-comet/issues/1538
     assume(CometConf.COMET_NATIVE_SCAN_IMPL.get() != CometConf.SCAN_NATIVE_DATAFUSION)
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
-        withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable(
             (0 until 50).map(i =>
               (i, Seq((i + 1, i.toString), null, (i + 3, (i + 3).toString)), i + 1)),
@@ -137,7 +167,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
   test("columnar shuffle on struct including nulls") {
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
-        withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           val data: Seq[(Int, (Int, String))] =
             Seq((1, (0, "1")), (2, (3, "3")), (3, null))
           withParquetTable(data, "tbl") {
@@ -158,6 +190,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
         withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
           CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable((0 until 50).map(i => (Map(Seq(i, i + 1) -> i), i + 1)), "tbl") {
@@ -230,6 +263,7 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
         withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
           CometConf.COMET_EXEC_ENABLED.key -> execEnabled,
           CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable(
@@ -336,7 +370,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
   def columnarShuffleOnMapTest[K: TypeTag](num: Int, keys: Seq[K]): Unit = {
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
-        withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable(genTuples(num, keys), "tbl") {
             repartitionAndSort(numPartitions)
           }
@@ -451,7 +487,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
 
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
-        withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable(
             (0 until 50).map(i =>
               (
@@ -483,7 +521,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
     Seq("false", "true").foreach { _ =>
       Seq(10, 201).foreach { numPartitions =>
         Seq("1.0", "10.0").foreach { ratio =>
-          withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+          withSQLConf(
+            CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+            CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
             withParquetTable(
               (0 until 50).map(i => (Seq(Seq(i + 1), Seq(i + 2), Seq(i + 3)), i + 1)),
               "tbl") {
@@ -503,7 +543,9 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
   test("columnar shuffle on nested struct") {
     Seq(10, 201).foreach { numPartitions =>
       Seq("1.0", "10.0").foreach { ratio =>
-        withSQLConf(CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
+        withSQLConf(
+          CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true",
+          CometConf.COMET_SHUFFLE_PREFER_DICTIONARY_RATIO.key -> ratio) {
           withParquetTable(
             (0 until 50).map(i =>
               ((i, 2.toString, (i + 1).toLong, (3.toString, i + 1, (i + 2).toLong)), i + 1)),
@@ -871,29 +913,31 @@ abstract class CometColumnarShuffleSuite extends CometTestBase with AdaptiveSpar
   }
 
   test("columnar shuffle on null struct fields") {
-    withTempDir { dir =>
-      val testData = "{}\n"
-      val path = Paths.get(dir.toString, "test.json")
-      Files.write(path, testData.getBytes)
-
-      // Define the nested struct schema
-      val readSchema = StructType(
-        Array(
-          StructField(
-            "metaData",
-            StructType(
-              Array(StructField(
-                "format",
-                StructType(Array(StructField("provider", StringType, nullable = true))),
-                nullable = true))),
-            nullable = true)))
-
-      // Read JSON with custom schema and repartition, this will repartition rows that contain
-      // null struct fields.
-      val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
-      assert(df.count() == 1)
-      val row = df.collect()(0)
-      assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
+    withSQLConf(CometConf.COMET_COLUMNAR_SHUFFLE_COMPLEX_TYPES_ENABLED.key -> "true") {
+      withTempDir { dir =>
+        val testData = "{}\n"
+        val path = Paths.get(dir.toString, "test.json")
+        Files.write(path, testData.getBytes)
+
+        // Define the nested struct schema
+        val readSchema = StructType(
+          Array(
+            StructField(
+              "metaData",
+              StructType(
+                Array(StructField(
+                  "format",
+                  StructType(Array(StructField("provider", StringType, nullable = true))),
+                  nullable = true))),
+              nullable = true)))
+
+        // Read JSON with custom schema and repartition; this will repartition rows that contain
+        // null struct fields.
+        val df = spark.read.format("json").schema(readSchema).load(path.toString).repartition(2)
+        assert(df.count() == 1)
+        val row = df.collect()(0)
+        assert(row.getAs[org.apache.spark.sql.Row]("metaData") == null)
+      }
     }
   }
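
For anyone trying the change out, here is a minimal usage sketch, not part of the patch itself: it shows how the new spark.comet.columnar.shuffle.complexTypes.enabled flag would be set on a Spark session. The app name, DataFrame contents, and column names are illustrative assumptions; it also assumes a Spark build with Comet on the classpath and its extensions registered, with any additional Comet shuffle settings (shuffle manager, spark.comet.exec.shuffle.enabled) configured per the Comet docs.

    import org.apache.spark.sql.SparkSession

    // Hypothetical session setup; assumes Comet is on the classpath.
    val spark = SparkSession
      .builder()
      .appName("comet-complex-type-shuffle-sketch")
      .config("spark.sql.extensions", "org.apache.comet.CometSparkSessionExtensions")
      // Opt back in to Comet columnar shuffle for struct/array/map columns.
      // The patch's default is false, i.e. fall back to Spark shuffle.
      .config("spark.comet.columnar.shuffle.complexTypes.enabled", "true")
      .getOrCreate()

    import spark.implicits._

    // A DataFrame with a struct column; repartitioning by it forces a shuffle.
    val df = Seq((1, (0, "a")), (2, (3, "b"))).toDF("id", "nested")
    df.repartition(4, $"id", $"nested").collect()

    // The flag can also be flipped at runtime:
    spark.conf.set("spark.comet.columnar.shuffle.complexTypes.enabled", "false")

With the flag off, supportedSerializableDataType returns Some(reason) for struct, array, and map inputs, the reason surfaces in the plan via withInfo, and the exchange stays on Spark shuffle, which is what the new fallback test asserts with checkCometExchange(df, 0, false).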