Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 112 additions & 19 deletions spark/src/main/scala/org/apache/comet/serde/arithmetic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.types.{ByteType, DataType, DecimalType, DoubleType,

import org.apache.comet.CometSparkSessionExtensions.withFallbackReason
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType, serializeDataType}
import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProtoInternal, flattenAssociative, optExprWithFallbackReason, scalarFunctionExprToProtoWithReturnType, serializeDataType}
import org.apache.comet.shims.CometEvalModeUtil

trait MathBase {
Expand Down Expand Up @@ -83,6 +83,65 @@ trait MathBase {
false
}

/**
 * Decides whether an `Add` / `Multiply` chain may be re-grouped (rebalanced) without altering
 * its result.
 *
 * The answer is `true` only for the integral types (byte, short, int, long) evaluated in
 * LEGACY eval mode, where arithmetic wraps (modular) and is therefore exactly associative.
 * Everything else is rejected:
 *   - floating point: rounding depends on grouping (Spark's own `ReorderAssociativeOperator`
 *     excludes it as well);
 *   - ANSI / TRY: integer overflow is observable (throw / null) and the grouping determines
 *     which intermediate result overflows;
 *   - decimal: intermediate precision grows with every operation.
 *
 * @param dataType
 *   result type of the chain
 * @param evalMode
 *   Spark eval mode of the chain
 * @return
 *   `true` when the chain is integral and evaluated in `EvalMode.LEGACY`, `false` otherwise
 */
def isAssociativeAndRebalanceable(dataType: DataType, evalMode: EvalMode.Value): Boolean =
  dataType match {
    case _: ByteType | _: ShortType | _: IntegerType | _: LongType =>
      // Integral: associativity then hinges solely on wrapping (LEGACY) semantics.
      evalMode == EvalMode.LEGACY
    case _ =>
      false
  }

/**
 * Like [[QueryPlanSerde.createBalancedBinaryExpr]] but for `MathExpr`-shaped associative
 * operators (`Add`, `Multiply`): each combined inner node carries the chain's `evalMode` and
 * `returnType`. Rebalances a flattened chain into an `O(log n)`-depth tree so deep
 * `a + b + ...` chains serialize to a shallow proto instead of a left-deep one that overflows
 * protobuf's recursion limit when the plan is re-parsed. Only safe for exactly-associative
 * chains -- callers gate via [[isAssociativeAndRebalanceable]]. The flattened leaves all share
 * the chain's type (Spark coerces operands to it, with casts acting as flatten boundaries), so
 * a single `returnType` / `evalMode` is correct for every inner node.
 *
 * @param expr
 *   the root expression of the chain; fallback reasons are attached to it
 * @param operands
 *   the flattened leaves of the chain, in left-to-right order
 * @param inputs
 *   input attributes forwarded to `exprToProtoInternal` for every operand
 * @param binding
 *   binding flag forwarded to `exprToProtoInternal` for every operand
 * @param dataType
 *   result type of the chain, serialized once and set on every inner node when available
 * @param evalMode
 *   Spark eval mode of the chain, set on every inner node
 * @param f
 *   wraps a built `MathExpr` into the operator-specific field of an `Expr` builder
 * @return
 *   the balanced proto tree, or `None` when `operands` is empty or any operand cannot be
 *   serialized (a fallback reason is recorded on `expr` in both cases)
 */
def createBalancedMathExpr(
    expr: Expression,
    operands: Seq[Expression],
    inputs: Seq[Attribute],
    binding: Boolean,
    dataType: DataType,
    evalMode: EvalMode.Value,
    f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
    : Option[ExprOuterClass.Expr] = {
  if (operands.isEmpty) {
    // Without this guard `build` would split an empty slice into two empty slices forever and
    // die with a StackOverflowError; fall back instead.
    withFallbackReason(expr, "Cannot build a balanced math expression from an empty operand list")
    None
  } else {
    // Serialize every operand (no short-circuit) before deciding whether to fall back.
    val protos = operands.map(exprToProtoInternal(_, inputs, binding))
    val leaves = protos.flatten.toIndexedSeq
    if (leaves.length != protos.length) {
      // At least one operand could not be serialized.
      withFallbackReason(expr, operands: _*)
      None
    } else {
      val returnType = serializeDataType(dataType)
      val evalModeProto = evalModeToProto(CometEvalModeUtil.fromSparkEvalMode(evalMode))

      // Recursively halve the leaves so the resulting tree has logarithmic depth. `slice` is
      // never empty here: it starts non-empty and a slice of length >= 2 always splits into
      // two non-empty halves.
      def build(slice: IndexedSeq[ExprOuterClass.Expr]): ExprOuterClass.Expr =
        if (slice.length == 1) {
          slice.head
        } else {
          val (lower, upper) = slice.splitAt(slice.length / 2)
          val mathBuilder = ExprOuterClass.MathExpr
            .newBuilder()
            .setLeft(build(lower))
            .setRight(build(upper))
            .setEvalMode(evalModeProto)
          returnType.foreach(mathBuilder.setReturnType)
          f(ExprOuterClass.Expr.newBuilder(), mathBuilder.build()).build()
        }

      Some(build(leaves))
    }
  }
}

}

object CometAdd extends CometExpressionSerde[Add] with MathBase {
Expand All @@ -95,15 +154,32 @@ object CometAdd extends CometExpressionSerde[Add] with MathBase {
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
createMathExpression(
expr,
expr.left,
expr.right,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setAdd(mathExpr))
if (isAssociativeAndRebalanceable(expr.dataType, expr.evalMode)) {
// Rebalance deep `a + b + ...` chains (integral + LEGACY = exactly associative) so the
// proto stays shallow and doesn't overflow protobuf's recursion limit when re-parsed.
val operands = flattenAssociative(
expr,
{ case _: Add => true; case _ => false },
{ case a: Add => (a.left, a.right) })
createBalancedMathExpr(
expr,
operands,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setAdd(mathExpr))
} else {
createMathExpression(
expr,
expr.left,
expr.right,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setAdd(mathExpr))
}
}
}

Expand Down Expand Up @@ -139,15 +215,32 @@ object CometMultiply extends CometExpressionSerde[Multiply] with MathBase {
withFallbackReason(expr, s"Unsupported datatype ${expr.left.dataType}")
return None
}
createMathExpression(
expr,
expr.left,
expr.right,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setMultiply(mathExpr))
if (isAssociativeAndRebalanceable(expr.dataType, expr.evalMode)) {
// Rebalance deep `a * b * ...` chains (integral + LEGACY = exactly associative) so the
// proto stays shallow and doesn't overflow protobuf's recursion limit when re-parsed.
val operands = flattenAssociative(
expr,
{ case _: Multiply => true; case _ => false },
{ case m: Multiply => (m.left, m.right) })
createBalancedMathExpr(
expr,
operands,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setMultiply(mathExpr))
} else {
createMathExpression(
expr,
expr.left,
expr.right,
inputs,
binding,
expr.dataType,
expr.evalMode,
(builder, mathExpr) => builder.setMultiply(mathExpr))
}
}
}

Expand Down
30 changes: 21 additions & 9 deletions spark/src/main/scala/org/apache/comet/serde/bitwise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@ object CometBitwiseAnd extends CometExpressionSerde[BitwiseAnd] {
expr: BitwiseAnd,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
// Rebalance the (associative, always-integral) chain so deep `a & b & ...` produces a
// shallow proto instead of a left-deep one that overflows protobuf's recursion limit when
// the plan is re-parsed (see createBalancedBinaryExpr).
val operands = flattenAssociative(
expr,
expr.left,
expr.right,
{ case _: BitwiseAnd => true; case _ => false },
{ case b: BitwiseAnd => (b.left, b.right) })
createBalancedBinaryExpr(
expr,
operands,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
Expand All @@ -56,10 +62,13 @@ object CometBitwiseOr extends CometExpressionSerde[BitwiseOr] {
expr: BitwiseOr,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
val operands = flattenAssociative(
expr,
expr.left,
expr.right,
{ case _: BitwiseOr => true; case _ => false },
{ case b: BitwiseOr => (b.left, b.right) })
createBalancedBinaryExpr(
expr,
operands,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
Expand All @@ -71,10 +80,13 @@ object CometBitwiseXor extends CometExpressionSerde[BitwiseXor] {
expr: BitwiseXor,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
createBinaryExpr(
val operands = flattenAssociative(
expr,
expr.left,
expr.right,
{ case _: BitwiseXor => true; case _ => false },
{ case b: BitwiseXor => (b.left, b.right) })
createBalancedBinaryExpr(
expr,
operands,
inputs,
binding,
(builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
Expand Down
34 changes: 34 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3162,4 +3162,38 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("deep bitwise And/Or/Xor chains do not overflow the protobuf recursion limit") {
  // The bitwise operators are always integral and exactly associative. A left-deep chain of N
  // operands serializes N levels deep and would exceed protobuf's default recursion limit
  // (100) when re-parsed, so Comet rebalances the chain; 200 operands exercises that path.
  val chainLength = 200
  val rows = (0 until 100).map(i => (i, i.toLong))
  withParquetTable(rows, "tbl") {
    // Every operand mixes the column with a distinct literal so the chain is not
    // constant-folded away before serialization.
    val operands = (1 to chainLength).map(i => col("_1") + lit(i))
    val andChain = operands.reduce(_.bitwiseAND(_)).as("a")
    val orChain = operands.reduce(_.bitwiseOR(_)).as("o")
    val xorChain = operands.reduce(_.bitwiseXOR(_)).as("x")
    Seq(andChain, orChain, xorChain).foreach { projection =>
      checkSparkAnswerAndOperator(spark.table("tbl").select(projection))
    }
  }
}

test("deep integral Add/Multiply chains do not overflow the protobuf recursion limit") {
  // Only integral Add/Multiply in non-ANSI (LEGACY, wrapping) mode are exactly associative, so
  // only those chains are rebalanced by Comet. ANSI/TRY, float and decimal are deliberately
  // left alone because their result or overflow behaviour depends on grouping -- hence ANSI is
  // switched off explicitly for this test.
  val chainLength = 200
  withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
    withParquetTable((0 until 100).map(i => (i, i.toLong)), "tbl") {
      val operands = (1 to chainLength).map(i => col("_1") + lit(i))
      val sumChain = operands.reduce((acc, next) => acc + next).as("a")
      // The product overflows Int but wraps (mod 2^32) identically in Spark and Comet, so the
      // rebalanced grouping must produce the same value.
      val productChain = operands.reduce((acc, next) => acc * next).as("m")
      checkSparkAnswerAndOperator(spark.table("tbl").select(sumChain))
      checkSparkAnswerAndOperator(spark.table("tbl").select(productChain))
    }
  }
}

}