From ffa4048cc42b2d304b61906f7f8d0fa17f6901c4 Mon Sep 17 00:00:00 2001 From: Scott Schenkein Date: Wed, 3 Jun 2026 21:43:57 -0400 Subject: [PATCH] feat: rebalance associative bitwise / Add / Multiply chains (#4577) Follow-up to #4531 (deep And/Or chains). Protobuf's recursion limit (100) applies to any deeply nested BinaryExpr, so a long left-deep chain of other associative operators overflows the same way when the serialized plan is re-parsed. Extend the rebalancing (flattenAssociative + a balanced O(log n)-depth tree) to: - BitwiseAnd / BitwiseOr / BitwiseXor: always integral and exactly associative, so they reuse the existing createBalancedBinaryExpr directly. - Add / Multiply: gated via isAssociativeAndRebalanceable to integral types in LEGACY (wrapping, modular) eval mode -- the only exactly-associative case. Float isn't associative (Spark's ReorderAssociativeOperator excludes it too); ANSI/TRY make integer overflow position (which the grouping changes) observable; decimal precision grows per op. Those keep the existing left-deep serialization. Add and Multiply emit a MathExpr (eval_mode + return_type) rather than a BinaryExpr, so a new createBalancedMathExpr builds the balanced tree with the chain's uniform type and eval mode at every inner node. Tests mirror #4531: project 200-deep chains and assert Comet runs them natively with results matching Spark (which also verifies the associativity guarantee). Co-Authored-By: Claude Opus 4.8 (1M context) --- .../org/apache/comet/serde/arithmetic.scala | 131 +++++++++++++++--- .../org/apache/comet/serde/bitwise.scala | 30 ++-- .../apache/comet/CometExpressionSuite.scala | 34 +++++ 3 files changed, 167 insertions(+), 28 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala index 58e99f9c79..74eeab99ff 100644 --- a/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala +++ b/spark/src/main/scala/org/apache/comet/serde/arithmetic.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType, import org.apache.comet.CometSparkSessionExtensions.withFallbackReason import org.apache.comet.expressions.{CometCast, CometEvalMode} -import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType, serializeDataType} +import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, flattenAssociative, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType, serializeDataType} import org.apache.comet.shims.CometEvalModeUtil trait MathBase { @@ -83,6 +83,65 @@ trait MathBase { false } + /** + * True when an `Add` / `Multiply` chain of `dataType` in `evalMode` can be rebalanced without + * changing results. Only integral types in LEGACY (wrapping, modular) eval mode are exactly + * associative, so re-grouping the chain is a no-op on the value. Floating point is not + * associative (rounding differs by grouping -- Spark's own `ReorderAssociativeOperator` + * excludes it). ANSI / TRY make integer overflow observable (throw / null), and the grouping + * changes which intermediate overflows, so those are excluded too. Decimal is excluded because + * intermediate precision grows per operation. + */ + def isAssociativeAndRebalanceable(dataType: DataType, evalMode: EvalMode.Value): Boolean = + evalMode == EvalMode.LEGACY && (dataType match { + case _: ByteType | _: ShortType | _: IntegerType | _: LongType => true + case _ => false + }) + + /** + * Like [[QueryPlanSerde.createBalancedBinaryExpr]] but for `MathExpr`-shaped associative + * operators (`Add`, `Multiply`): each combined inner node carries the chain's `evalMode` and + * `returnType`. Rebalances a flattened chain into an `O(log n)`-depth tree so deep + * `a + b + ...` chains serialize to a shallow proto instead of a left-deep one that overflows + * protobuf's recursion limit when the plan is re-parsed. Only safe for exactly-associative + * chains -- callers gate via [[isAssociativeAndRebalanceable]]. The flattened leaves all share + * the chain's type (Spark coerces operands to it, with casts acting as flatten boundaries), so + * a single `returnType` / `evalMode` is correct for every inner node. + */ + def createBalancedMathExpr( + expr: Expression, + operands: Seq[Expression], + inputs: Seq[Attribute], + binding: Boolean, + dataType: DataType, + evalMode: EvalMode.Value, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val protos = operands.map(exprToProtoInternal(_, inputs, binding)) + if (protos.exists(_.isEmpty)) { + withFallbackReason(expr, operands: _*) + None + } else { + val returnType = serializeDataType(dataType) + val evalModeProto = evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode)) + val leaves = protos.map(_.get).toIndexedSeq + def build(slice: IndexedSeq[ExprOuterClass.Expr]): ExprOuterClass.Expr = { + if (slice.length == 1) slice.head + else { + val mid = slice.length / 2 + val mathBuilder = ExprOuterClass.MathExpr + .newBuilder() + .setLeft(build(slice.slice(0, mid))) + .setRight(build(slice.slice(mid, slice.length))) + .setEvalMode(evalModeProto) + returnType.foreach(mathBuilder.setReturnType) + f(ExprOuterClass.Expr.newBuilder(), mathBuilder.build()).build() + } + } + Some(build(leaves)) + } + } + } object CometAdd extends CometExpressionSerde[Add] with MathBase { @@ -95,15 +154,32 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase { withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}") return None } - createMathExpression( - expr, - expr.left, - expr.right, - inputs, - binding, - expr.dataType, - expr.evalMode, - (builder, mathExpr) => builder.setAdd(mathExpr)) + if (isAssociativeAndRebalanceable(expr.dataType, expr.evalMode)) { + // Rebalance deep `a + b + ...` chains (integral + LEGACY = exactly associative) so the + // proto stays shallow and doesn't overflow protobuf's recursion limit when re-parsed. + val operands = flattenAssociative( + expr, + { case _: Add => true; case _ => false }, + { case a: Add => (a.left, a.right) }) + createBalancedMathExpr( + expr, + operands, + inputs, + binding, + expr.dataType, + expr.evalMode, + (builder, mathExpr) => builder.setAdd(mathExpr)) + } else { + createMathExpression( + expr, + expr.left, + expr.right, + inputs, + binding, + expr.dataType, + expr.evalMode, + (builder, mathExpr) => builder.setAdd(mathExpr)) + } } } @@ -139,15 +215,32 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase { withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}") return None } - createMathExpression( - expr, - expr.left, - expr.right, - inputs, - binding, - expr.dataType, - expr.evalMode, - (builder, mathExpr) => builder.setMultiply(mathExpr)) + if (isAssociativeAndRebalanceable(expr.dataType, expr.evalMode)) { + // Rebalance deep `a * b * ...` chains (integral + LEGACY = exactly associative) so the + // proto stays shallow and doesn't overflow protobuf's recursion limit when re-parsed. + val operands = flattenAssociative( + expr, + { case _: Multiply => true; case _ => false }, + { case m: Multiply => (m.left, m.right) }) + createBalancedMathExpr( + expr, + operands, + inputs, + binding, + expr.dataType, + expr.evalMode, + (builder, mathExpr) => builder.setMultiply(mathExpr)) + } else { + createMathExpression( + expr, + expr.left, + expr.right, + inputs, + binding, + expr.dataType, + expr.evalMode, + (builder, mathExpr) => builder.setMultiply(mathExpr)) + } } } diff --git a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala index 7c05dc4349..db773cd541 100644 --- a/spark/src/main/scala/org/apache/comet/serde/bitwise.scala +++ b/spark/src/main/scala/org/apache/comet/serde/bitwise.scala @@ -29,10 +29,16 @@ object CometBitwiseAnd extends CometExpressionSerde[BitwiseAnd] { expr: BitwiseAnd, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + // Rebalance the (associative, always-integral) chain so deep `a & b & ...` produces a + // shallow proto instead of a left-deep one that overflows protobuf's recursion limit when + // the plan is re-parsed (see createBalancedBinaryExpr). + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: BitwiseAnd => true; case _ => false }, + { case b: BitwiseAnd => (b.left, b.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr)) @@ -56,10 +62,13 @@ object CometBitwiseOr extends CometExpressionSerde[BitwiseOr] { expr: BitwiseOr, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: BitwiseOr => true; case _ => false }, + { case b: BitwiseOr => (b.left, b.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr)) @@ -71,10 +80,13 @@ object CometBitwiseXor extends CometExpressionSerde[BitwiseXor] { expr: BitwiseXor, inputs: Seq[Attribute], binding: Boolean): Option[ExprOuterClass.Expr] = { - createBinaryExpr( + val operands = flattenAssociative( expr, - expr.left, - expr.right, + { case _: BitwiseXor => true; case _ => false }, + { case b: BitwiseXor => (b.left, b.right) }) + createBalancedBinaryExpr( + expr, + operands, inputs, binding, (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr)) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index ec87920c2c..b9d31fa718 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -3162,4 +3162,38 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("deep bitwise And/Or/Xor chains do not overflow the protobuf recursion limit") { + // Same protobuf-recursion-limit concern as the AND/OR case, for the (always-integral, + // exactly associative) bitwise operators: a left-deep chain of N serializes N levels deep + // and overflows the default limit (100) when re-parsed. Comet rebalances them. + val n = 200 + withParquetTable((0 until 100).map(i => (i, i.toLong)), "tbl") { + // Column-based operands (mix the column with a distinct literal) so the chain isn't + // constant-folded away before serialization. + val terms = (1 to n).map(i => col("_1") + lit(i)) + checkSparkAnswerAndOperator( + spark.table("tbl").select(terms.reduce((a, b) => a.bitwiseAND(b)).as("a"))) + checkSparkAnswerAndOperator( + spark.table("tbl").select(terms.reduce((a, b) => a.bitwiseOR(b)).as("o"))) + checkSparkAnswerAndOperator( + spark.table("tbl").select(terms.reduce((a, b) => a.bitwiseXOR(b)).as("x"))) + } + } + + test("deep integral Add/Multiply chains do not overflow the protobuf recursion limit") { + // Integral + non-ANSI (LEGACY, wrapping) Add/Multiply are exactly associative, so Comet + // rebalances the deep chain. (ANSI/TRY, float, and decimal are intentionally NOT rebalanced + // -- their result or overflow behaviour depends on grouping -- so force non-ANSI here.) + val n = 200 + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable((0 until 100).map(i => (i, i.toLong)), "tbl") { + val terms = (1 to n).map(i => col("_1") + lit(i)) + checkSparkAnswerAndOperator(spark.table("tbl").select(terms.reduce(_ + _).as("a"))) + // Wrapping (mod 2^32) multiply: the product overflows Int but wraps identically in + // Spark and Comet, so the rebalanced grouping yields the same value. + checkSparkAnswerAndOperator(spark.table("tbl").select(terms.reduce(_ * _).as("m"))) + } + } + } + }