diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index c8ac8adac0..a65ded05fb 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -302,6 +302,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -380,6 +381,9 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenSuite + org.apache.comet.CometCodegenSourceSuite + org.apache.comet.CometCodegenHOFSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 3c6953aade..a83d70f380 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -155,6 +155,7 @@ jobs: org.apache.comet.CometFuzzAggregateSuite org.apache.comet.CometFuzzIcebergSuite org.apache.comet.CometFuzzMathSuite + org.apache.comet.CometCodegenFuzzSuite org.apache.comet.DataGeneratorSuite - name: "shuffle" value: | @@ -232,6 +233,9 @@ jobs: org.apache.comet.expressions.conditional.CometIfSuite org.apache.comet.expressions.conditional.CometCoalesceSuite org.apache.comet.expressions.conditional.CometCaseWhenSuite + org.apache.comet.CometCodegenSuite + org.apache.comet.CometCodegenSourceSuite + org.apache.comet.CometCodegenHOFSuite - name: "sql" value: | org.apache.spark.sql.CometToPrettyStringSuite diff --git a/docs/source/user-guide/latest/iceberg.md b/docs/source/user-guide/latest/iceberg.md index 5c63ae9ad6..f22180ec77 100644 --- a/docs/source/user-guide/latest/iceberg.md +++ b/docs/source/user-guide/latest/iceberg.md @@ -146,6 +146,24 @@ The following scenarios will fall back to Spark's native Iceberg reader: - Dynamic Partition Pruning under Adaptive Query Execution (non-AQE DPP is supported); see [#3510](https://github.com/apache/datafusion-comet/issues/3510) +### Iceberg UDFs + +Iceberg ships several `ScalaUDF`s that surface in user queries and maintenance actions: + +- `IcebergSpark.registerBucketUDF` and `registerTruncateUDF` register `bucket(N, col)` and + `truncate(W, col)` for use in `SELECT` / `JOIN` / `WHERE` predicates that align with hidden + partitioning. +- `RewriteDataFiles` with `sort-strategy=zorder` builds a tree of per-type ordered-bytes UDFs + (`INT_ORDERED_BYTES`, `LONG_ORDERED_BYTES`, ..., `INTERLEAVE_BYTES`) over the sort key columns + during compaction. + +By default these UDFs cause the enclosing operator to fall back to Spark, which forces a +columnar-to-row roundtrip and demotes the surrounding shuffle from `CometExchange` to +`CometColumnarExchange`. Enabling the experimental +[Scala UDF and Java UDF Support](scala_java_udfs.md) feature +(`spark.comet.exec.scalaUDF.codegen.enabled=true`) routes these UDFs through native execution so +the project, exchange, and sort operators around them stay on the Comet path end-to-end. + ### Task input metrics The native Iceberg reader populates Spark's task-level `inputMetrics.bytesRead` (visible in the Spark UI Stages tab) using the `bytes_read` counter from iceberg-rust's `ScanMetrics`. This counter includes bytes read from both data files and delete files. diff --git a/docs/source/user-guide/latest/index.rst b/docs/source/user-guide/latest/index.rst index 314a0a51bd..9587b2ee03 100644 --- a/docs/source/user-guide/latest/index.rst +++ b/docs/source/user-guide/latest/index.rst @@ -43,6 +43,7 @@ to read more. Supported Data Types Supported Operators Supported Expressions + ScalaUDF and Java UDF Support Configuration Settings Compatibility Guide Understanding Comet Plans diff --git a/docs/source/user-guide/latest/scala_java_udfs.md b/docs/source/user-guide/latest/scala_java_udfs.md new file mode 100644 index 0000000000..e8163e494c --- /dev/null +++ b/docs/source/user-guide/latest/scala_java_udfs.md @@ -0,0 +1,61 @@ + + +# Scala UDF and Java UDF Support + +Comet executes Spark's Scala and Java [scalar user-defined functions (UDFs)](https://spark.apache.org/docs/latest/sql-ref-functions-udf-scalar.html) on the native Comet path. The presence of a UDF does not force the enclosing operator off the native path; surrounding native operators stay native. + +This page covers Spark's `ScalaUDF` (Scala `udf(...)`, `spark.udf.register(...)` over Scala or Java functional interfaces, and SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`). Other UDF kinds (Python / Pandas, Hive, aggregate) are out of scope and continue to fall back to Spark. + +This feature is experimental and disabled by default. + +## Configuration + +| Key | Default | Description | +| ------------------------------------------- | ------- | ------------------------------------------------------------------------------------------------------------------ | +| `spark.comet.exec.scalaUDF.codegen.enabled` | `false` | When `true`, eligible `ScalaUDF`s run on the Comet path. When `false`, the enclosing operator falls back to Spark. | + +## Supported + +- User functions registered via `udf(...)`, `spark.udf.register(...)` (Scala or Java functional interfaces), or SQL `CREATE FUNCTION ... AS 'com.example.MyUDF'`. +- Scalar input/output types: `Boolean`, `Byte`, `Short`, `Int`, `Long`, `Float`, `Double`, `Decimal`, `String`, `Binary`, `Date`, `Timestamp`, `TimestampNTZ`. +- Complex input/output types with arbitrary nesting: `ArrayType`, `StructType`, `MapType`. +- Composition with other Catalyst expressions inside the argument tree (e.g. `myUdf(upper(s))` runs as one native unit). +- Higher-order functions (`transform`, `filter`, `exists`, `aggregate`, `zip_with`, `map_filter`, `map_zip_with`, etc.) inside the argument tree. + +## Not supported + +- Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, the legacy `UserDefinedAggregateFunction`). +- Table UDFs and generators. +- Python `@udf` and Pandas `@pandas_udf`. +- Hive `GenericUDF` and `SimpleUDF`. +- `CalendarIntervalType`, `NullType`, and `UserDefinedType` arguments and return types. UDT-typed columns fall back to Spark; for native execution, store and read the underlying representation directly (e.g. write MLlib `Vector` outputs as `Struct, values: Array>` rather than `VectorUDT`). +- Trees whose total nested-field count (output plus all input columns the UDF tree references) exceeds `spark.sql.codegen.maxFields` (default 100). Comet refuses these at plan time and the operator falls back to Spark. + +When a UDF is rejected, the reason surfaces through Comet's standard fallback diagnostics; the query still runs on Spark. + +## Behavior + +- Non-deterministic expressions referenced from the argument tree (`rand`, `uuid`, `monotonically_increasing_id`) produce per-partition sequences consistent with Spark. +- `TaskContext.get()` inside the user function returns the driving Spark task's context. +- The user function must be closure-serializable; the same function that works with Spark's executor execution works here. + +## Known limitations + +- Each query containing a ScalaUDF pays a one-time codegen cost on its first batch and reuses the compiled kernel for subsequent batches, matching Spark's whole-stage codegen behavior. Bytecode is deduped JVM-wide via the same `CodeGenerator` cache, so structurally identical queries across a session share the compiled class. diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 75fdb03d0d..2cfdea93be 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -211,9 +211,9 @@ impl PhysicalPlanner { self } - /// Attach the Spark `TaskContext` global reference captured at `createPlan` time. Cloned - /// into every `JvmScalarUdfExpr` the planner builds so the JNI bridge can install it as - /// the thread-local on the Tokio worker driving the UDF. + /// Attach a propagated Spark `TaskContext` global reference. Called by the JNI `executePlan` + /// entry with whatever was captured at `createPlan` time. The planner clones this `Option` + /// into every `JvmScalarUdfExpr` it builds. pub fn with_task_context( mut self, task_context: Option>>>, @@ -742,6 +742,13 @@ impl PhysicalPlanner { to_arrow_datatype(udf.return_type.as_ref().ok_or_else(|| { GeneralError("JvmScalarUdf missing return_type".to_string()) })?); + // Invariant: task_context is propagated for every JvmScalarUdfExpr built during + // normal execution. The TEST_EXEC_CONTEXT_ID path is the only context in which + // task_context may legitimately be None (unit tests, direct native driver runs). + debug_assert!( + self.task_context.is_some() || self.exec_context_id == TEST_EXEC_CONTEXT_ID, + "task_context must be set for non-test execution" + ); Ok(Arc::new(JvmScalarUdfExpr::new( udf.class_name.clone(), args, diff --git a/native/spark-expr/src/jvm_udf/mod.rs b/native/spark-expr/src/jvm_udf/mod.rs index 4ed25de6ee..0e3968e60a 100644 --- a/native/spark-expr/src/jvm_udf/mod.rs +++ b/native/spark-expr/src/jvm_udf/mod.rs @@ -59,6 +59,10 @@ impl JvmScalarUdfExpr { return_nullable: bool, task_context: Option>>>, ) -> Self { + debug_assert!( + !class_name.is_empty(), + "JvmScalarUdfExpr requires a non-empty class name" + ); Self { class_name, args, @@ -120,10 +124,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { } fn evaluate(&self, batch: &RecordBatch) -> DFResult { - // Step 1: evaluate child expressions to get Arrow arrays. Scalar children - // (e.g. literal patterns) are sent as length-1 vectors rather than expanded - // to batch-row count, so the JVM bridge does not pay an O(rows) copy for - // values that never vary across the batch. + // Scalar children (e.g. literal patterns) are sent as length-1 vectors rather than + // expanded to batch-row count, so the JVM bridge does not pay an O(rows) copy for + // values that never vary across the batch. The JVM side gets `numRows` directly via + // the bridge so it doesn't need the scalar to carry batch length. let arrays: Vec = self .args .iter() @@ -133,7 +137,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { }) .collect::>()?; - // Step 2: allocate FFI structs on the Rust heap and collect their raw pointers. // The JVM writes into the out_array/out_schema slots and reads from the in_ slots. let in_ffi_arrays: Vec> = arrays .iter() @@ -157,7 +160,13 @@ impl PhysicalExpr for JvmScalarUdfExpr { .map(|b| b.as_ref() as *const FFI_ArrowSchema as i64) .collect(); - // Allocate output FFI slots. + debug_assert!(!self.class_name.is_empty(), "class_name must not be empty"); + debug_assert_eq!( + in_arr_ptrs.len(), + in_sch_ptrs.len(), + "input array and schema pointer counts must match" + ); + let mut out_array = Box::new(FFI_ArrowArray::empty()); let mut out_schema = Box::new(FFI_ArrowSchema::empty()); let out_arr_ptr = out_array.as_mut() as *mut FFI_ArrowArray as i64; @@ -166,7 +175,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { let class_name = self.class_name.clone(); let n_args = arrays.len(); - // Step 3: attach a JNI env for this thread and call the static bridge method. JVMClasses::with_env(|env| { let bridge = JVMClasses::get().comet_udf_bridge.as_ref().ok_or_else(|| { CometError::from(ExecutionError::GeneralError( @@ -176,12 +184,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { )) })?; - // Build the JVM String for the class name. let jclass_name = env .new_string(&class_name) .map_err(|e| CometError::JNI { source: e })?; - // Build the long[] arrays for input pointers. let in_arr_java = env .new_long_array(n_args) .map_err(|e| CometError::JNI { source: e })?; @@ -196,9 +202,10 @@ impl PhysicalExpr for JvmScalarUdfExpr { .set_region(env, 0, &in_sch_ptrs) .map_err(|e| CometError::JNI { source: e })?; - // Pass a null jobject when no TaskContext was propagated so the bridge's null-guard - // leaves the worker thread's current TaskContext.get() in place. The borrow must - // outlive `call_static_method_unchecked`. + // Resolve the TaskContext reference once before building the arg array so the + // borrow lives until `call_static_method_unchecked` returns. When no TaskContext + // was propagated, pass a null object so the bridge's null-guard leaves the thread- + // local alone. let null_task_context = JObject::null(); let task_context_ref: &JObject = match &self.task_context { Some(gref) => gref.as_obj(), @@ -229,7 +236,6 @@ impl PhysicalExpr for JvmScalarUdfExpr { Ok(()) })?; - // Step 4: import the result from the FFI slots filled by the JVM. // SAFETY: `*out_array` moves the FFI_ArrowArray out of the Box (the heap // allocation is freed by the move), and `from_ffi` wraps it in an Arc that // keeps the JVM-installed release callback alive until the resulting diff --git a/spark/src/main/java/org/apache/comet/codegen/CometBatchKernel.java b/spark/src/main/java/org/apache/comet/codegen/CometBatchKernel.java new file mode 100644 index 0000000000..a515cbe32d --- /dev/null +++ b/spark/src/main/java/org/apache/comet/codegen/CometBatchKernel.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Abstract base extended by the Janino-compiled batch kernel emitted by {@code + * CometBatchKernelCodegen}. The generated subclass extends {@code CometInternalRow} (so Spark's + * {@code BoundReference.genCode} can call {@code this.getUTF8String(ord)} directly) and carries + * typed input fields baked at codegen time, one per input column. Expression evaluation plus Arrow + * read/write fuse into one method per expression tree. + */ +public abstract class CometBatchKernel extends CometInternalRow { + + protected final Object[] references; + + protected CometBatchKernel(Object[] references) { + this.references = references; + } + + /** + * Run partition-dependent initialization. The generated subclass overrides this to execute + * statements collected via {@code CodegenContext.addPartitionInitializationStatement}, e.g. + * reseeding {@code Rand}'s {@code XORShiftRandom} from {@code seed + partitionIndex}. + * Deterministic expressions leave this as a no-op. + * + *

The caller invokes this before the first {@code process} call of each partition. The + * generated subclass is not thread-safe across concurrent {@code process} calls. The dispatcher + * allocates one per partition and serializes calls. + */ + public void init(int partitionIndex) {} + + /** + * Process one batch. + * + * @param inputs Arrow input vectors. Length and concrete classes match the schema the kernel was + * compiled against. + * @param output Arrow output vector. Caller allocates to the expression's {@code dataType}. + * @param numRows number of rows in this batch + */ + public abstract void process(ValueVector[] inputs, FieldVector output, int numRows); +} diff --git a/spark/src/main/scala/org/apache/comet/CometConf.scala b/spark/src/main/scala/org/apache/comet/CometConf.scala index 1a4eb3a8c8..fdd1ae2073 100644 --- a/spark/src/main/scala/org/apache/comet/CometConf.scala +++ b/spark/src/main/scala/org/apache/comet/CometConf.scala @@ -362,6 +362,17 @@ object CometConf extends ShimCometConf { .booleanConf .createWithDefault(false) + val COMET_SCALA_UDF_CODEGEN_ENABLED: ConfigEntry[Boolean] = + conf("spark.comet.exec.scalaUDF.codegen.enabled") + .category(CATEGORY_EXEC) + .doc("Experimental. Whether to route Spark `ScalaUDF` expressions through Comet's " + + "Arrow-direct codegen dispatcher. When enabled, a supported ScalaUDF is compiled into " + + "a per-batch kernel that reads and writes Arrow vectors directly from native " + + "execution. When disabled, plans containing a ScalaUDF fall back to Spark for the " + + "enclosing operator.") + .booleanConf + .createWithDefault(false) + val COMET_EXEC_SHUFFLE_WITH_HASH_PARTITIONING_ENABLED: ConfigEntry[Boolean] = conf("spark.comet.native.shuffle.partitioning.hash.enabled") .category(CATEGORY_SHUFFLE) diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometArrayData.scala b/spark/src/main/scala/org/apache/comet/codegen/CometArrayData.scala new file mode 100644 index 0000000000..308d1e9d96 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometArrayData.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Throwing-default `ArrayData` base for the codegen kernel. Subclasses override only the getters + * their element type needs. + * + * Consumer: per-column `InputArray_${path}` nested classes that back `getArray(ord)` plus the + * recursion for `Array>` and array-typed map keys / struct fields. + * + * `ArrayData` and `InternalRow` are sibling abstract classes, so a base aimed at one cannot serve + * the other. The shared `get(ordinal, dataType)` dispatch lives in + * [[CometSpecializedGettersDispatch]]. Mixes in [[CometInternalRowShim]] so Spark 4.x's + * `getVariant` / `getGeography` / `getGeometry` get throwing defaults. + */ +abstract class CometArrayData extends ArrayData with CometInternalRowShim { + + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + + override def get(ordinal: Int, dataType: DataType): AnyRef = + CometSpecializedGettersDispatch.get(this, ordinal, dataType) + + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + + override def getByte(ordinal: Int): Byte = unsupported("getByte") + + override def getShort(ordinal: Int): Short = unsupported("getShort") + + override def getInt(ordinal: Int): Int = unsupported("getInt") + + override def getLong(ordinal: Int): Long = unsupported("getLong") + + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + + override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + + override def getMap(ordinal: Int): MapData = unsupported("getMap") + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this array shape") + + override def update(i: Int, value: Any): Unit = unsupported("update") + + override def copy(): ArrayData = unsupported("copy") + + override def array: Array[Any] = unsupported("array") + + override def toString(): String = { + val n = + try numElements().toString + catch { + case _: Throwable => "?" + } + s"${getClass.getSimpleName}(numElements=$n)" + } + + override def numElements(): Int = unsupported("numElements") +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala new file mode 100644 index 0000000000..2795911da3 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegen.scala @@ -0,0 +1,494 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, HigherOrderFunction, LambdaFunction, Literal, NamedLambdaVariable, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.shims.CometExprTraitShim + +/** + * Compiles a bound [[Expression]] plus an Arrow input schema into a [[CometBatchKernel]] that + * fuses Arrow input reads, Spark expression evaluation, and Arrow output writes into one + * Janino-compiled method per `(expression, schema)` pair. + * + * The kernel compiles any bound Catalyst expression. The tree need not be rooted at a `ScalaUDF`. + * Today's only consumer is [[org.apache.comet.udf.codegen.CometScalaUDFCodegen]]. + * + * Constraints: one output vector per kernel, per-row scalar evaluation only (aggregate, window, + * generator are rejected by [[canHandle]]). + * + * Input- and output-side emission live in [[CometBatchKernelCodegenInput]] and + * [[CometBatchKernelCodegenOutput]]. This file owns the [[ArrowColumnSpec]] vocabulary, the + * [[canHandle]] / [[allocateOutput]] / [[compile]] / [[generateSource]] entry points, and + * cross-cutting kernel-shape decisions (NullIntolerant short-circuit, CSE variant). + * + * The generated kernel is the `InternalRow` that Spark's `BoundReference.genCode` reads from. See + * [[generateSource]] for how the wiring is set up. + */ +object CometBatchKernelCodegen extends Logging with CometExprTraitShim { + + /** + * Resolve an Arrow vector class by simple name through the codegen object's own classloader. + * Tests use this to refer to vector classes via the same classloader the codegen pattern- + * matches against, in case the test classpath ever diverges from the codegen's (e.g. through + * future shading rearrangement). + */ + def vectorClassBySimpleName(name: String): Class[_ <: ValueVector] = name match { + case "BitVector" => classOf[BitVector] + case "TinyIntVector" => classOf[TinyIntVector] + case "SmallIntVector" => classOf[SmallIntVector] + case "IntVector" => classOf[IntVector] + case "BigIntVector" => classOf[BigIntVector] + case "Float4Vector" => classOf[Float4Vector] + case "Float8Vector" => classOf[Float8Vector] + case "DecimalVector" => classOf[DecimalVector] + case "DateDayVector" => classOf[DateDayVector] + case "TimeStampMicroVector" => classOf[TimeStampMicroVector] + case "TimeStampMicroTZVector" => classOf[TimeStampMicroTZVector] + case "VarCharVector" => classOf[VarCharVector] + case "VarBinaryVector" => classOf[VarBinaryVector] + case other => throw new IllegalArgumentException(s"unknown Arrow vector class: $other") + } + + /** + * Type surface the kernel covers on both input and output sides. Recursive: complex types are + * supported when their children are. + */ + def isSupportedDataType(dt: DataType): Boolean = dt match { + case BooleanType | ByteType | ShortType | IntegerType | LongType => true + case FloatType | DoubleType => true + case _: DecimalType => true + case _: StringType | _: BinaryType => true + case DateType | TimestampType | TimestampNTZType => true + case ArrayType(inner, _) => isSupportedDataType(inner) + case st: StructType => st.fields.forall(f => isSupportedDataType(f.dataType)) + case mt: MapType => isSupportedDataType(mt.keyType) && isSupportedDataType(mt.valueType) + case _ => false + } + + /** + * Mirrors `WholeStageCodegenExec.numOfNestedFields` so [[canHandle]] can reuse + * `spark.sql.codegen.maxFields`. + */ + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case st: StructType => st.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case _ => 1 + } + + /** + * Plan-time predicate. `None` greenlights the serde to emit the codegen proto; `Some(reason)` + * forces a Spark fallback (typically `withInfo(...) + None`) so the operator falls back cleanly + * rather than crashing the Janino compile at execute time. + * + * Checks every `BoundReference`'s data type and the root `expr.dataType` against + * [[isSupportedDataType]], rejects aggregates / generators / `CodegenFallback` (other than + * HOFs, which are admitted), and gates total nested-field count on + * `spark.sql.codegen.maxFields`. + */ + def canHandle(boundExpr: Expression): Option[String] = { + if (!isSupportedDataType(boundExpr.dataType)) { + return Some(s"codegen dispatch: unsupported output type ${boundExpr.dataType}") + } + // Mirror WSCG's `spark.sql.codegen.maxFields` gate. Wide schemas blow the generated class's + // typed input field count, the typed-getter switch, and the constant pool. Refuse here so the + // operator falls back to Spark cleanly rather than tripping a Janino compile failure + // mid-execution (Comet has no recovery for that). + val maxFields = SQLConf.get.wholeStageMaxNumFields + val totalFields = numOfNestedFields(boundExpr.dataType) + + boundExpr.collect { case b: BoundReference => numOfNestedFields(b.dataType) }.sum + if (totalFields > maxFields) { + return Some( + s"codegen dispatch: too many nested fields ($totalFields > " + + s"spark.sql.codegen.maxFields=$maxFields)") + } + // HOFs are `CodegenFallback` but admitted: `CodegenFallback.doGenCode` emits one + // `((Expression) references[N]).eval(row)` call site per HOF. The kernel dispatches to the + // HOF's interpreted `eval`, which mutates `NamedLambdaVariable.value` per element and reads + // the input array through the kernel's typed Arrow getters. Per-task `boundExpr` isolation + // in `CometScalaUDFCodegen.kernelCache` prevents concurrent partitions from racing on the + // lambda variable's `AtomicReference`. See `CometCodegenHOFSuite`. + // + // Nondeterministic / stateful expressions are accepted: each cache entry holds one kernel + // instance with a single `init(partitionIndex)` call, so `Rand` / `MonotonicallyIncreasingID` + // state advances correctly across batches. + // + // `ExecSubqueryExpression` (`ScalarSubquery`, `InSubqueryExec`) is accepted: the surrounding + // Comet operator's inherited `SparkPlan.waitForSubqueries` populates the subquery's + // `result` field before evaluation. The closure serializer captures that value into the + // arg-0 bytes, and the dispatcher keys its compile cache on those bytes, so distinct subquery + // results produce distinct cache entries. + // + // `Unevaluable`: rejected by default. `isCodegenInertUnevaluable` exempts version-specific + // leaves that are `Unevaluable` but never invoked by codegen (e.g. Spark 4.0's + // `ResolvedCollation` in `Collate.collation`, where `Collate.genCode` delegates to its child). + boundExpr.find { + case _: org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction => true + case _: org.apache.spark.sql.catalyst.expressions.Generator => true + case _: HigherOrderFunction => false + case _: LambdaFunction => false + case _: NamedLambdaVariable => false + case _: CodegenFallback => true + case u: Unevaluable if isCodegenInertUnevaluable(u) => false + case _: Unevaluable => true + case _ => false + } match { + case Some(bad) => + return Some( + s"codegen dispatch: expression ${bad.getClass.getSimpleName} not supported " + + "(aggregate, generator, codegen-fallback, or unevaluable)") + case None => + } + val badRef = boundExpr.collectFirst { + case b: BoundReference if !isSupportedDataType(b.dataType) => + b + } + badRef.map(b => + s"codegen dispatch: unsupported input type ${b.dataType} at ordinal ${b.ordinal}") + } + + /** + * Allocate an Arrow output vector from a pre-built `Field`. Forwards to + * [[CometBatchKernelCodegenOutput.allocateOutput]]. + */ + def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector = + CometBatchKernelCodegenOutput.allocateOutput(field, numRows, estimatedBytes) + + /** + * Spark `DataType` to an Arrow `Field`, resolving mismatches between Arrow Java's default field + * labels and what Spark / Arrow Rust expect on the FFI boundary. + */ + def toFfiArrowField(name: String, dataType: DataType, nullable: Boolean): Field = + CometBatchKernelCodegenOutput.toFfiArrowField(name, dataType, nullable) + + def compile(boundExpr: Expression, inputSchema: Seq[ArrowColumnSpec]): CompiledKernel = { + val src = generateSource(boundExpr, inputSchema) + val (clazz, _) = + try { + CodeGenerator.compile(src.code) + } catch { + case t: Throwable => + logError( + s"CometBatchKernelCodegen: compile failed for ${boundExpr.getClass.getSimpleName}. " + + s"Generated source follows:\n${CodeFormatter.format(src.code)}", + t) + throw t + } + logInfo( + s"CometBatchKernelCodegen: compiled ${boundExpr.getClass.getSimpleName} " + + s"-> ${boundExpr.dataType} inputs=" + + inputSchema + .map(s => s"${s.vectorClass.getSimpleName}${if (s.nullable) "?" else ""}") + .mkString(",")) + // ScalaUDF embeds stateful `ExpressionEncoder` serializers via `ctx.addReferenceObj` that + // reuse internal `UnsafeRow` / `byte[]` buffers per `apply`. Each kernel instance needs its + // own copy. The closure regenerates the references array per call so the dispatcher can hand + // a fresh array to every kernel it allocates from this `CompiledKernel`. + val freshReferences: () => Array[Any] = () => + generateSource(boundExpr, inputSchema).references + CompiledKernel(clazz, freshReferences) + } + + /** + * Generate the Java source without compiling it. Tests assert on emitted source (null short- + * circuit present, non-nullable `isNullAt` returns literal `false`, etc.) without paying for + * Janino. + */ + def generateSource( + boundExpr: Expression, + inputSchema: Seq[ArrowColumnSpec]): GeneratedSource = { + canHandle(boundExpr).foreach(reason => + throw new IllegalArgumentException(s"CometBatchKernelCodegen.generateSource: $reason")) + val ctx = new CodegenContext + // `BoundReference.genCode` emits `${ctx.INPUT_ROW}.getUTF8String(ord)`. Aliasing `row` to + // `this` at the top of `process` routes those reads to the kernel's typed getters (final + // class, JIT devirtualizes + folds the switch). `row` rather than `this` because Spark's + // `splitExpressions` uses `INPUT_ROW` as the parameter name of helper methods it emits; + // `this` is a reserved keyword and Janino rejects it as a parameter name. + ctx.INPUT_ROW = "row" + + val baseClass = classOf[CometBatchKernel].getName + // Resolve Arrow class names at runtime so the generated source matches the method signature + // the running classloader sees. The packaged Comet jar relocates `org.apache.arrow` to + // `org.apache.comet.shaded.arrow` (see `spark/pom.xml`); `.getName` picks the right name + // regardless of whether we run against the shaded jar or the unshaded build output. + val valueVectorClass = classOf[ValueVector].getName + val fieldVectorClass = classOf[FieldVector].getName + + // `outputSetup` holds once-per-batch declarations (typed child-vector casts for complex + // outputs) that `emitOutputWriter` factors out of the per-row body. Scalar outputs return an + // empty string here. + // + // TODO(method-size): perRowBody is inlined inside process's for-loop and not split. + // Sufficiently deep trees can exceed Janino's 64KB method size. Wrap in + // ctx.splitExpressionsWithCurrentInputs when hit. + val (concreteOutClass, outputSetup, perRowBody) = { + // Class-field CSE. `generateExpressions` runs `subexpressionElimination` under the hood, + // populating `ctx.subexprFunctions` with per-row helper calls that write common subtree + // results into `addMutableState` fields. The returned `ExprCode` references those fields. + // `subexprFunctionsCode` is the concatenated helper invocation block, spliced into the + // per-row body by `defaultBody`. + val ev = if (SQLConf.get.subexpressionEliminationEnabled) { + ctx.generateExpressions(Seq(boundExpr), doSubexpressionElimination = true).head + } else { + boundExpr.genCode(ctx) + } + val subExprsCode = ctx.subexprFunctionsCode + val (cls, setup, snippet) = + CometBatchKernelCodegenOutput.emitOutputWriter(boundExpr.dataType, ev.value, ctx) + (cls, setup, defaultBody(boundExpr, ev, snippet, subExprsCode)) + } + + val typedFieldDecls = CometBatchKernelCodegenInput.emitInputFieldDecls(inputSchema) + val typedInputCasts = CometBatchKernelCodegenInput.emitInputCasts(inputSchema) + val decimalTypeByOrdinal = CometBatchKernelCodegenInput.decimalPrecisionByOrdinal(boundExpr) + val getters = + CometBatchKernelCodegenInput.emitTypedGetters(inputSchema, decimalTypeByOrdinal) + val nested = CometBatchKernelCodegenInput.emitNestedClasses(inputSchema) + val getArrayMethod = CometBatchKernelCodegenInput.emitGetArrayMethod(inputSchema) + val getStructMethod = CometBatchKernelCodegenInput.emitGetStructMethod(inputSchema) + val getMapMethod = CometBatchKernelCodegenInput.emitGetMapMethod(inputSchema) + + val codeBody = + s""" + |public java.lang.Object generate(Object[] references) { + | return new SpecificCometBatchKernel(references); + |} + | + |final class SpecificCometBatchKernel extends $baseClass { + | + | ${ctx.declareMutableStates()} + | + | $typedFieldDecls + | private int rowIdx; + | + | public SpecificCometBatchKernel(Object[] references) { + | super(references); + | ${ctx.initMutableStates()} + | } + | + | @Override + | public void init(int partitionIndex) { + | ${ctx.initPartition()} + | } + | + | $getters + | $getArrayMethod + | $getStructMethod + | $getMapMethod + | + | @Override + | public void process( + | $valueVectorClass[] inputs, + | $fieldVectorClass outRaw, + | int numRows) { + | $concreteOutClass output = ($concreteOutClass) outRaw; + | $typedInputCasts + | $outputSetup + | // Alias the kernel as `row` so Spark-generated `${ctx.INPUT_ROW}.method()` reads + | // resolve to the kernel's typed getters. Helper methods that Spark splits via + | // `splitExpressions` also take `InternalRow row` as a parameter; `this` flows + | // implicitly via INPUT_ROW. + | org.apache.spark.sql.catalyst.InternalRow row = this; + | for (int i = 0; i < numRows; i++) { + | this.rowIdx = i; + | $perRowBody + | } + | } + | + | ${ctx.declareAddedFunctions()} + | + |$nested + |} + """.stripMargin + + val code = CodeFormatter.stripOverlappingComments( + new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) + GeneratedSource(code.body, code, ctx.references.toArray) + } + + /** + * Per-row body. For `NullIntolerant` expressions where the entire tree propagates nulls, + * prepends a short-circuit on the union of input ordinals so the whole `ev.code` cost is + * skipped on null rows. Otherwise the standard shape: run `ev.code`, then `setNull` or write + * based on `ev.isNull`. + * + * `subExprsCode` is the CSE helper-invocation block. It must run before `ev.code`. Inside the + * short-circuit it lives in the else branch so null rows skip CSE too. + */ + private def defaultBody( + boundExpr: Expression, + ev: ExprCode, + writeSnippet: String, + subExprsCode: String): String = { + boundExpr match { + case _ if isNullIntolerant(boundExpr) && allNullIntolerant(boundExpr) => + // Every node from root to leaf is `NullIntolerant` or a leaf, so "any BoundReference null + // -> whole expression null". A non-null-propagating node like `coalesce` or `if` would + // make this incorrect (`coalesce(null, x)` is `x`); `allNullIntolerant` rejects those. + val inputOrdinals = + boundExpr.collect { case b: BoundReference => b.ordinal }.distinct + val nullCheck = + if (inputOrdinals.isEmpty) "false" + else inputOrdinals.map(ord => s"this.col$ord.isNull(i)").mkString(" || ") + s""" + |if ($nullCheck) { + | output.setNull(i); + |} else { + | $subExprsCode + | ${ev.code} + | $writeSnippet + |} + """.stripMargin + case _ => + // NonNullableOutputShortCircuit: when `nullable = false`, drop the `if (ev.isNull)` + // guard at source level rather than relying on JIT folding. + if (!boundExpr.nullable) { + s""" + |$subExprsCode + |${ev.code} + |$writeSnippet + """.stripMargin + } else { + s""" + |$subExprsCode + |${ev.code} + |if (${ev.isNull}) { + | output.setNull(i); + |} else { + | $writeSnippet + |} + """.stripMargin + } + } + } + + /** + * True iff every node in the tree propagates nulls (`NullIntolerant`, `BoundReference`, or + * `Literal`). Gates the [[defaultBody]] short-circuit, which is only correct when no node + * (`Coalesce`, `If`, `CaseWhen`, `Concat`, ...) breaks the propagation chain. + */ + private def allNullIntolerant(expr: Expression): Boolean = + !expr.exists { + case _: BoundReference | _: Literal => false + case other => !isNullIntolerant(other) + } + + /** + * Per-column compile-time invariants. The concrete Arrow vector class and the nullability flag + * are baked into the generated kernel and form part of the cache key: different vector classes + * or nullability produce different kernels. The dispatcher hardcodes top-level `nullable=true` + * (per-batch null density is not part of the cache key); tests reach the non-nullable codegen + * path by constructing specs directly. + */ + sealed trait ArrowColumnSpec { + def vectorClass: Class[_ <: ValueVector] + + def nullable: Boolean + } + + /** Scalar column: one Arrow vector class per row slot, no nested structure. */ + final case class ScalarColumnSpec(vectorClass: Class[_ <: ValueVector], nullable: Boolean) + extends ArrowColumnSpec + + /** + * Array column: an Arrow `ListVector` wrapping a child spec. `elementSparkType` lets the + * nested-class emitter pick the right read template, and the child carries the Arrow vector + * class. Nested arrays compose recursively. + */ + final case class ArrayColumnSpec( + nullable: Boolean, + elementSparkType: DataType, + element: ArrowColumnSpec) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[ListVector] + } + + /** + * Struct column: an Arrow `StructVector` over N typed children. Each [[StructFieldSpec]] + * carries the Spark name (cache-key identity), the Spark `DataType`, the child + * `ArrowColumnSpec`, and the per-field `nullable` bit (lets non-nullable fields elide their + * per-row null check). + */ + final case class StructColumnSpec(nullable: Boolean, fields: Seq[StructFieldSpec]) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[StructVector] + } + + /** One field entry on a [[StructColumnSpec]]. */ + final case class StructFieldSpec( + name: String, + sparkType: DataType, + nullable: Boolean, + child: ArrowColumnSpec) + + /** + * Map column: an Arrow `MapVector` (subclass of `ListVector`) whose data vector is a + * `StructVector` with key at child 0 and value at child 1. Nested keys and values compose + * recursively. The child specs' `nullable` field is unused on the read path. Output-side null + * guards for map values come from `MapType.valueContainsNull` on the Spark `DataType`. + */ + final case class MapColumnSpec( + nullable: Boolean, + keySparkType: DataType, + valueSparkType: DataType, + key: ArrowColumnSpec, + value: ArrowColumnSpec) + extends ArrowColumnSpec { + override def vectorClass: Class[_ <: ValueVector] = classOf[MapVector] + } + + /** + * Compiled kernel handle. `freshReferences` regenerates the references array per kernel + * allocation because `ScalaUDF` embeds stateful `ExpressionEncoder` serializers that cannot be + * shared. + */ + final case class CompiledKernel(factory: GeneratedClass, freshReferences: () => Array[Any]) { + def newInstance(): CometBatchKernel = + factory.generate(freshReferences()).asInstanceOf[CometBatchKernel] + } + + /** + * Output of [[generateSource]]. Tests inspect `body` to assert the shape of the generated + * source. See `CometCodegenSourceSuite`. + */ + final case class GeneratedSource(body: String, code: CodeAndComment, references: Array[Any]) + + object ArrowColumnSpec { + + /** Convenience constructor for the scalar case. */ + def apply(vectorClass: Class[_ <: ValueVector], nullable: Boolean): ArrowColumnSpec = + ScalarColumnSpec(vectorClass, nullable) + + /** Trait-level extractor that destructures only the scalar case. */ + def unapply(spec: ArrowColumnSpec): Option[(Class[_ <: ValueVector], Boolean)] = spec match { + case ScalarColumnSpec(c, n) => Some((c, n)) + case _ => None + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala new file mode 100644 index 0000000000..9a4f4bcc57 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenInput.scala @@ -0,0 +1,963 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import scala.collection.mutable + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression} +import org.apache.spark.sql.types._ + +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec} +import org.apache.comet.vector.CometPlainVector + +/** + * Input-side emitters for the codegen kernel: typed field declarations, per-batch input casts, + * top-level typed-getter switches, nested `InputArray_${path}` / `InputStruct_${path}` / + * `InputMap_${path}` classes per complex level. Paired with [[CometBatchKernelCodegenOutput]]. + * + * Path encoding. Each position in the spec tree has a unique path string used as a suffix on + * vector fields and nested classes. From a column ordinal: root `col${ord}`, array element + * `${P}_e`, struct field `fi` `${P}_f${fi}`, map key `${P}_k`, map value `${P}_v`. + * + * Nested-class composition. Each instance is allocated fresh per `getArray(i)` / `getStruct(i, + * n)` / `getMap(i)` call, with `final` slice fields. Matches Spark's `ColumnarRow` / + * `ColumnarArray` model: retain-by-reference consumers (e.g. `ArrayDistinct.nullSafeEval` + * stashing references in an `OpenHashSet`) get distinct identities, and JIT escape analysis + * usually scalarizes the allocation when the value is consumed locally. + */ +private[codegen] object CometBatchKernelCodegenInput { + + /** + * Primitive Arrow vector classes wrapped in [[CometPlainVector]] at input-cast time so per-row + * reads go through `Platform.get*` against a cached buffer address (JIT inlines to branchless + * reads). Decimal/VarChar/VarBinary stay on the typed Arrow field with cached buffer addresses + * for inline unsafe reads. + */ + private val primitiveArrowClasses: Set[Class[_]] = Set( + classOf[BitVector], + classOf[TinyIntVector], + classOf[SmallIntVector], + classOf[IntVector], + classOf[BigIntVector], + classOf[Float4Vector], + classOf[Float8Vector], + classOf[DateDayVector], + classOf[TimeStampMicroVector], + classOf[TimeStampMicroTZVector]) + private val cometPlainVectorName: String = classOf[CometPlainVector].getName + + /** Emit kernel typed-vector field declarations for every level of every input column. */ + def emitInputFieldDecls(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectVectorFieldDecls(path, spec, lines) + } + lines.mkString("\n ") + } + + /** + * Emit per-batch cast statements, recursing through complex types via `getDataVector` / etc. + */ + def emitInputCasts(inputSchema: Seq[ArrowColumnSpec]): String = { + val lines = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + val path = s"col$ord" + collectCasts(path, spec, s"inputs[$ord]", lines) + } + lines.mkString("\n ") + } + + /** + * Emit typed-getter overrides. Each switches on column ordinal. With the inlined constant + * ordinal from `BoundReference.genCode`, JIT folds the switch to one branch. + * + * `decimalTypeByOrdinal` lets the decimal getter specialize per ordinal: when only a + * `DecimalType(precision <= 18)` `BoundReference` reads the ordinal, the case skips the + * `BigDecimal` allocation and reads the unscaled long directly. + */ + def emitTypedGetters( + inputSchema: Seq[ArrowColumnSpec], + decimalTypeByOrdinal: Map[Int, Option[DecimalType]]): String = { + val withOrd = inputSchema.zipWithIndex + + val isNullCases = withOrd.map { case (spec, ord) => + if (!spec.nullable) { + s" case $ord: return false;" + } else { + // CometPlainVector exposes `isNullAt`; Arrow-typed fields expose `isNull`. Same semantics. + val method = spec.vectorClass match { + case cls if wrapsInCometPlainVector(cls) => "isNullAt" + case _ => "isNull" + } + s" case $ord: return this.col$ord.$method(this.rowIdx);" + } + } + + val booleanCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[BitVector] => + s" case $ord: return this.col$ord.getBoolean(this.rowIdx);" + } + val byteCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[TinyIntVector] => + s" case $ord: return this.col$ord.getByte(this.rowIdx);" + } + val shortCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[SmallIntVector] => + s" case $ord: return this.col$ord.getShort(this.rowIdx);" + } + val intCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if cls == classOf[IntVector] || cls == classOf[DateDayVector] => + s" case $ord: return this.col$ord.getInt(this.rowIdx);" + } + val longCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) + if cls == classOf[BigIntVector] || + cls == classOf[TimeStampMicroVector] || + cls == classOf[TimeStampMicroTZVector] => + s" case $ord: return this.col$ord.getLong(this.rowIdx);" + } + val floatCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float4Vector] => + s" case $ord: return this.col$ord.getFloat(this.rowIdx);" + } + val doubleCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[Float8Vector] => + s" case $ord: return this.col$ord.getDouble(this.rowIdx);" + } + val decimalCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[DecimalVector] => + val known = decimalTypeByOrdinal.getOrElse(ord, None) + val valueAddr = s"this.col${ord}_valueAddr" + val slowField = s"this.col$ord" + val fastPath = emitDecimalFastBodyUnsafe(valueAddr, "this.rowIdx", " ") + val slowPath = emitDecimalSlowBody(slowField, "this.rowIdx", " ") + val body = known match { + case Some(dt) if dt.precision <= Decimal.MAX_LONG_DIGITS => fastPath + case Some(_) => slowPath + case None => + s""" if (precision <= ${Decimal.MAX_LONG_DIGITS}) { + |$fastPath + | } else { + |$slowPath + | }""".stripMargin + } + s""" case $ord: { + |$body + | }""".stripMargin + } + val binaryCases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarBinaryVector] => + s""" case $ord: { + |${emitBinaryBodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + val utf8Cases = withOrd.collect { + case (ArrowColumnSpec(cls, _), ord) if cls == classOf[VarCharVector] => + s""" case $ord: { + |${emitUtf8BodyUnsafe( + s"this.col${ord}_valueAddr", + s"this.col${ord}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + } + + Seq( + emitOrdinalSwitch("public boolean isNullAt(int ordinal)", "isNullAt", isNullCases), + emitOrdinalSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + emitOrdinalSwitch("public byte getByte(int ordinal)", "getByte", byteCases), + emitOrdinalSwitch("public short getShort(int ordinal)", "getShort", shortCases), + emitOrdinalSwitch("public int getInt(int ordinal)", "getInt", intCases), + emitOrdinalSwitch("public long getLong(int ordinal)", "getLong", longCases), + emitOrdinalSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + emitOrdinalSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + emitOrdinalSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), + emitOrdinalSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + emitOrdinalSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + private def wrapsInCometPlainVector(cls: Class[_]): Boolean = + primitiveArrowClasses.contains(cls) + + private def emitOrdinalSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } + + private def emitDecimalSlowBody(field: String, idx: String, ind: String): String = { + val cont = ind + " " + s"""${ind}java.math.BigDecimal bd = $field.getObject($idx); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.apply(bd, precision, scale);""".stripMargin + } + + private def emitDecimalFastBodyUnsafe(valueAddr: String, idx: String, ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}long unscaled = org.apache.spark.unsafe.Platform.getLong(null, + |$cont$valueAddr + (long) $i * 16L); + |${ind}return org.apache.spark.sql.types.Decimal$$.MODULE$$ + |$cont.createUnsafe(unscaled, precision, scale);""".stripMargin + } + + private def emitUtf8BodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}return org.apache.spark.unsafe.types.UTF8String + |$cont.fromAddress(null, $valueAddr + s, e - s);""".stripMargin + } + + /** Parenthesize `idx` when it contains whitespace, to keep `(long) idx * 16L` well-formed. */ + private def castableIdx(idx: String): String = if (idx.contains(' ')) s"($idx)" else idx + + private def emitBinaryBodyUnsafe( + valueAddr: String, + offsetAddr: String, + idx: String, + ind: String): String = { + val cont = ind + " " + val i = castableIdx(idx) + s"""${ind}int s = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + (long) $i * 4L); + |${ind}int e = org.apache.spark.unsafe.Platform.getInt(null, + |$cont$offsetAddr + ((long) $i + 1L) * 4L); + |${ind}int len = e - s; + |${ind}byte[] out = new byte[len]; + |${ind}org.apache.spark.unsafe.Platform.copyMemory(null, $valueAddr + s, out, + |${cont}org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET, len); + |${ind}return out;""".stripMargin + } + + /** + * Per-ordinal map of the `DecimalType` observed on `BoundReference`s. Used by + * [[emitTypedGetters]] to emit a precision-specialized `getDecimal` case per ordinal. + */ + def decimalPrecisionByOrdinal(boundExpr: Expression): Map[Int, Option[DecimalType]] = { + boundExpr + .collect { + case b: BoundReference if b.dataType.isInstanceOf[DecimalType] => + b.ordinal -> b.dataType.asInstanceOf[DecimalType] + } + .groupBy(_._1) + .map { case (ord, pairs) => + val distinct = pairs.map(_._2).toSet + ord -> (if (distinct.size == 1) Some(distinct.head) else None) + } + } + + /** + * Emit nested classes for every complex level of every input column: `InputArray_${path}` for + * arrays, `InputStruct_${path}` for structs, `InputMap_${path}` plus `InputArray` views for the + * key/value slices for maps (Spark's `MapData.keyArray()` / `valueArray()` return `ArrayData`). + */ + def emitNestedClasses(inputSchema: Seq[ArrowColumnSpec]): String = { + val out = new mutable.ArrayBuffer[String]() + inputSchema.zipWithIndex.foreach { case (spec, ord) => + collectNestedClasses(s"col$ord", spec, out) + } + out.mkString("\n") + } + + /** + * Top-level `getArray(int ordinal)` switch. Each case reads `(start, length)` from the outer + * `ListVector` offsets and allocates a fresh `InputArray_col${ord}` view. + */ + def emitGetArrayMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: ArrayColumnSpec, ord) => + s""" case $ord: { + | int __idx = this.rowIdx; + | int __s = this.col$ord.getElementStartIndex(__idx); + | int __e = this.col$ord.getElementEndIndex(__idx); + | return new InputArray_col$ord(__s, __e - __s); + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getArray out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** Top-level `getMap(int ordinal)` switch when the schema has at least one map column. */ + def emitGetMapMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: MapColumnSpec, ord) => + s""" case $ord: { + | int __idx = this.rowIdx; + | int __s = this.col$ord.getElementStartIndex(__idx); + | int __e = this.col$ord.getElementEndIndex(__idx); + | return new InputMap_col$ord(__s, __e - __s); + | }""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getMap out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** Top-level `getStruct(int ordinal, int numFields)` switch when the schema has any struct. */ + def emitGetStructMethod(inputSchema: Seq[ArrowColumnSpec]): String = { + val cases = inputSchema.zipWithIndex.collect { case (_: StructColumnSpec, ord) => + s""" case $ord: return new InputStruct_col$ord(this.rowIdx);""".stripMargin + } + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields) { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "getStruct out of range: " + ordinal); + | } + | } + |""".stripMargin + } + } + + /** + * Scalar columns that need a cached data-buffer address for inline unsafe reads. + * `DecimalVector` uses it for the short-precision fast path (`Platform.getLong`); + * `VarCharVector` / `VarBinaryVector` use it as the base for `UTF8String.fromAddress` / + * `Platform.copyMemory`. + */ + private def needsValueAddrField(cls: Class[_]): Boolean = + cls == classOf[DecimalVector] || + cls == classOf[VarCharVector] || + cls == classOf[VarBinaryVector] + + /** Variable-width columns also cache the offset-buffer address for `Platform.getInt`. */ + private def needsOffsetAddrField(cls: Class[_]): Boolean = + cls == classOf[VarCharVector] || cls == classOf[VarBinaryVector] + + /** + * Java method name for the per-column null check. Primitive scalars wrapped in + * [[CometPlainVector]] expose `isNullAt`; Arrow typed fields expose `isNull`. Same semantics. + */ + private def nullCheckMethod(spec: ArrowColumnSpec): String = spec match { + case sc: ScalarColumnSpec if wrapsInCometPlainVector(sc.vectorClass) => "isNullAt" + case _ => "isNull" + } + + private def collectVectorFieldDecls( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + // Primitive scalars wrap in CometPlainVector for JIT-inlined Platform.get* against a + // cached buffer address. Decimal/VarChar/VarBinary stay on the Arrow typed field with + // cached data- (and offset-) buffer addresses for inline unsafe reads. + val fieldClass = + if (wrapsInCometPlainVector(sc.vectorClass)) cometPlainVectorName + else sc.vectorClass.getName + out += s"private $fieldClass $path;" + if (needsValueAddrField(sc.vectorClass)) { + out += s"private long ${path}_valueAddr;" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"private long ${path}_offsetAddr;" + } + case ar: ArrayColumnSpec => + out += s"private ${classOf[ListVector].getName} $path;" + collectVectorFieldDecls(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += s"private ${classOf[StructVector].getName} $path;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectVectorFieldDecls(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += s"private ${classOf[MapVector].getName} $path;" + // Key/value vectors live at `${P}_k_e` / `${P}_v_e` so the synthetic `InputArray_${P}_k` / + // `InputArray_${P}_v` classes (which follow the array-element convention of reading from + // `${path}_e`) resolve correctly. + collectVectorFieldDecls(s"${path}_k_e", mp.key, out) + collectVectorFieldDecls(s"${path}_v_e", mp.value, out) + } + + private def collectCasts( + path: String, + spec: ArrowColumnSpec, + source: String, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case sc: ScalarColumnSpec => + if (wrapsInCometPlainVector(sc.vectorClass)) { + // `useDecimal128 = true` matches Spark's 128-bit decimal storage. + out += s"this.$path = new $cometPlainVectorName($source, true);" + } else { + out += s"this.$path = (${sc.vectorClass.getName}) $source;" + } + if (needsValueAddrField(sc.vectorClass)) { + out += s"this.${path}_valueAddr = this.$path.getDataBuffer().memoryAddress();" + } + if (needsOffsetAddrField(sc.vectorClass)) { + out += s"this.${path}_offsetAddr = this.$path.getOffsetBuffer().memoryAddress();" + } + case ar: ArrayColumnSpec => + out += s"this.$path = (${classOf[ListVector].getName}) $source;" + collectCasts(s"${path}_e", ar.element, s"this.$path.getDataVector()", out) + case st: StructColumnSpec => + out += s"this.$path = (${classOf[StructVector].getName}) $source;" + st.fields.zipWithIndex.foreach { case (f, fi) => + collectCasts(s"${path}_f$fi", f.child, s"this.$path.getChildByOrdinal($fi)", out) + } + case mp: MapColumnSpec => + // MapVector's data vector is a StructVector with key at child 0 and value at child 1. + val structLocal = s"${path}__mapStruct" + out += s"this.$path = (${classOf[MapVector].getName}) $source;" + out += s"${classOf[StructVector].getName} $structLocal = " + + s"(${classOf[StructVector].getName}) this.$path.getDataVector();" + collectCasts(s"${path}_k_e", mp.key, s"$structLocal.getChildByOrdinal(0)", out) + collectCasts(s"${path}_v_e", mp.value, s"$structLocal.getChildByOrdinal(1)", out) + } + + private def collectNestedClasses( + path: String, + spec: ArrowColumnSpec, + out: mutable.ArrayBuffer[String]): Unit = spec match { + case _: ScalarColumnSpec => () + case ar: ArrayColumnSpec => + out += emitArrayClass(path, ar) + collectNestedClasses(s"${path}_e", ar.element, out) + case st: StructColumnSpec => + out += emitStructClass(path, st) + st.fields.zipWithIndex.foreach { case (f, fi) => + collectNestedClasses(s"${path}_f$fi", f.child, out) + } + case mp: MapColumnSpec => + out += emitMapClass(path) + // Emit `InputArray_${path}_k` / `InputArray_${path}_v` (the views returned by + // `keyArray()` / `valueArray()`). Each reads from `${classPath}_e` per the array-element + // convention, mapping to the key/value vector at `${path}_k_e` / `${path}_v_e`. + out += emitArrayClass( + s"${path}_k", + ArrayColumnSpec(nullable = true, elementSparkType = mp.keySparkType, element = mp.key)) + out += emitArrayClass( + s"${path}_v", + ArrayColumnSpec( + nullable = true, + elementSparkType = mp.valueSparkType, + element = mp.value)) + collectNestedClasses(s"${path}_k_e", mp.key, out) + collectNestedClasses(s"${path}_v_e", mp.value, out) + } + + /** + * Emit one `InputArray_${path}` nested class. Constructor takes `(startIdx, length)` and stores + * both in `final` fields. Map key/value arrays share this shape. + */ + private def emitArrayClass(path: String, spec: ArrayColumnSpec): String = { + val baseClassName = classOf[CometArrayData].getName + val elemPath = s"${path}_e" + val isNullAt = + s""" @Override + | public boolean isNullAt(int i) { + | return $elemPath.${nullCheckMethod(spec.element)}(startIndex + i); + | }""".stripMargin + val elementGetter = emitArrayElementGetter(path, spec) + s""" private final class InputArray_$path extends $baseClassName { + | private final int startIndex; + | private final int length; + | + | InputArray_$path(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + |$isNullAt + | + |$elementGetter + | } + |""".stripMargin + } + + /** + * Element-getter body for a nested array. Scalar -> direct typed read. Complex -> allocate a + * fresh inner view. + * + * Reference-typed getters (`getDecimal` / `getUTF8String` / `getBinary` / `getStruct` / + * `getArray` / `getMap`) prepend `if (isNullAt(i)) return null;` when the element is nullable, + * because Spark's `CodeGenerator.setArrayElement` only emits the caller-side `isNullAt` check + * for primitive elements (it relies on the source's getter to return null for reference types, + * matching `ColumnarArray.getBinary`). Without this guard, expressions like `Flatten.doGenCode` + * write empty bytes / garbage decimals where Spark expects null. + */ + private def emitArrayElementGetter(path: String, spec: ArrayColumnSpec): String = { + val elemPath = s"${path}_e" + val nullGuard = + if (spec.element.nullable) " if (isNullAt(i)) return null;\n" + else "" + spec.element match { + case _: ScalarColumnSpec => + emitArrayElementScalarGetter(spec.elementSparkType, elemPath, spec.element.nullable) + case _: ArrayColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.util.ArrayData getArray(int i) { + |$nullGuard int __idx = startIndex + i; + | int __s = $elemPath.getElementStartIndex(__idx); + | int __e = $elemPath.getElementEndIndex(__idx); + | return new InputArray_$elemPath(__s, __e - __s); + | }""".stripMargin + case _: StructColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.InternalRow getStruct(int i, int numFields) { + |$nullGuard return new InputStruct_$elemPath(startIndex + i); + | }""".stripMargin + case _: MapColumnSpec => + s""" @Override + | public org.apache.spark.sql.catalyst.util.MapData getMap(int i) { + |$nullGuard int __idx = startIndex + i; + | int __s = $elemPath.getElementStartIndex(__idx); + | int __e = $elemPath.getElementEndIndex(__idx); + | return new InputMap_$elemPath(__s, __e - __s); + | }""".stripMargin + } + } + + /** + * Scalar-element getter override. Only the getter matching the element type is overridden; + * other getters inherit the base class's `UnsupportedOperationException`. Reference-typed + * getters (Decimal / String / Binary) prepend the null guard documented on + * [[emitArrayElementGetter]]. + */ + private def emitArrayElementScalarGetter( + elemType: DataType, + childField: String, + elementNullable: Boolean): String = { + val nullGuard = + if (elementNullable) " if (isNullAt(i)) return null;\n" + else "" + elemType match { + case BooleanType => + s""" @Override + | public boolean getBoolean(int i) { + | return $childField.getBoolean(startIndex + i); + | }""".stripMargin + case ByteType => + s""" @Override + | public byte getByte(int i) { + | return $childField.getByte(startIndex + i); + | }""".stripMargin + case ShortType => + s""" @Override + | public short getShort(int i) { + | return $childField.getShort(startIndex + i); + | }""".stripMargin + case IntegerType | DateType => + s""" @Override + | public int getInt(int i) { + | return $childField.getInt(startIndex + i); + | }""".stripMargin + case LongType | TimestampType | TimestampNTZType => + s""" @Override + | public long getLong(int i) { + | return $childField.getLong(startIndex + i); + | }""".stripMargin + case FloatType => + s""" @Override + | public float getFloat(int i) { + | return $childField.getFloat(startIndex + i); + | }""".stripMargin + case DoubleType => + s""" @Override + | public double getDouble(int i) { + | return $childField.getDouble(startIndex + i); + | }""".stripMargin + case dt: DecimalType => + val body = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + emitDecimalFastBodyUnsafe(s"${childField}_valueAddr", "startIndex + i", " ") + } else { + emitDecimalSlowBody(childField, "startIndex + i", " ") + } + s""" @Override + | public org.apache.spark.sql.types.Decimal getDecimal( + | int i, int precision, int scale) { + |$nullGuard$body + | }""".stripMargin + case _: StringType => + s""" @Override + | public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i) { + |$nullGuard${emitUtf8BodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case BinaryType => + s""" @Override + | public byte[] getBinary(int i) { + |$nullGuard${emitBinaryBodyUnsafe( + s"${childField}_valueAddr", + s"${childField}_offsetAddr", + "startIndex + i", + " ")} + | }""".stripMargin + case other => + throw new UnsupportedOperationException( + s"nested ArrayData: unsupported element type $other") + } + } + + /** + * Emit one `InputStruct_${path}` nested class. Constructor takes `rowIdx` and stores it in a + * `final` field. Scalar getters switch on field ordinal. Complex getters allocate fresh inner + * views (offsets computed for array/map children, rowIdx passed through for struct children). + */ + private def emitStructClass(path: String, spec: StructColumnSpec): String = { + val baseClassName = classOf[CometInternalRow].getName + val isNullCases = spec.fields.zipWithIndex.map { + case (f, fi) if !f.nullable => + s" case $fi: return false;" + case (f, fi) => + s" case $fi: return ${path}_f$fi.${nullCheckMethod(f.child)}(this.rowIdx);" + } + val scalarGetters = emitStructScalarGetters(path, spec) + val complexGetters = emitStructComplexGetters(path, spec) + s""" private final class InputStruct_$path extends $baseClassName { + | private final int rowIdx; + | + | InputStruct_$path(int outerRowIdx) { + | this.rowIdx = outerRowIdx; + | } + | + | @Override + | public int numFields() { + | return ${spec.fields.length}; + | } + | + | @Override + | public boolean isNullAt(int ordinal) { + | switch (ordinal) { + |${isNullCases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "InputStruct_$path.isNullAt out of range: " + ordinal); + | } + | } + | + |$scalarGetters + |$complexGetters + | } + |""".stripMargin + } + + // Scalar-read body templates parameterized on row-index expression (`idx`), cached buffer + // addresses (`valueAddr`, `offsetAddr`) for unsafe reads, or the Arrow field for the decimal + // slow path. `ind` is the per-line indent. + // + // TODO(#4280, #4279): once offset-address caching and validity-bitmap byte cache land in + // CometPlainVector, replace the VarChar/VarBinary unsafe emitters with CometPlainVector reads. + + private def emitStructScalarGetters(path: String, spec: StructColumnSpec): String = { + val withOrd = spec.fields.zipWithIndex + val scalarOrd = withOrd.filter { case (f, _) => f.child.isInstanceOf[ScalarColumnSpec] } + + // For nullable reference-typed struct fields, prepend the null guard so `getX(ord)` returns + // null on null positions (Spark contract for reference types). Same rationale as the array + // element getter. + def nullGuardForCase(fi: Int, fieldNullable: Boolean): String = + if (fieldNullable) s" if (isNullAt($fi)) return null;\n" + else "" + + def fieldReadScalar(fi: Int, dt: DataType, fieldNullable: Boolean): String = { + val guard = nullGuardForCase(fi, fieldNullable) + dt match { + case BooleanType => + s" case $fi: return ${path}_f$fi.getBoolean(this.rowIdx);" + case ByteType => + s" case $fi: return ${path}_f$fi.getByte(this.rowIdx);" + case ShortType => + s" case $fi: return ${path}_f$fi.getShort(this.rowIdx);" + case IntegerType | DateType => + s" case $fi: return ${path}_f$fi.getInt(this.rowIdx);" + case LongType | TimestampType | TimestampNTZType => + s" case $fi: return ${path}_f$fi.getLong(this.rowIdx);" + case FloatType => + s" case $fi: return ${path}_f$fi.getFloat(this.rowIdx);" + case DoubleType => + s" case $fi: return ${path}_f$fi.getDouble(this.rowIdx);" + case BinaryType => + s""" case $fi: { + |$guard${emitBinaryBodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: StringType => + s""" case $fi: { + |$guard${emitUtf8BodyUnsafe( + s"${path}_f${fi}_valueAddr", + s"${path}_f${fi}_offsetAddr", + "this.rowIdx", + " ")} + | }""".stripMargin + case _: DecimalType => + throw new IllegalStateException("decimal handled separately") + case other => + throw new UnsupportedOperationException( + s"nested InputStruct getter: unsupported field type $other") + } + } + + val booleanCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BooleanType => + fieldReadScalar(fi, BooleanType, f.nullable) + } + val byteCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ByteType => + fieldReadScalar(fi, ByteType, f.nullable) + } + val shortCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == ShortType => + fieldReadScalar(fi, ShortType, f.nullable) + } + val intCases = scalarOrd.collect { + case (f, fi) if f.sparkType == IntegerType || f.sparkType == DateType => + fieldReadScalar(fi, IntegerType, f.nullable) + } + val longCases = scalarOrd.collect { + case (f, fi) + if f.sparkType == LongType || f.sparkType == TimestampType || + f.sparkType == TimestampNTZType => + fieldReadScalar(fi, LongType, f.nullable) + } + val floatCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == FloatType => + fieldReadScalar(fi, FloatType, f.nullable) + } + val doubleCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == DoubleType => + fieldReadScalar(fi, DoubleType, f.nullable) + } + val binaryCases = + scalarOrd.collect { + case (f, fi) if f.sparkType == BinaryType => + fieldReadScalar(fi, BinaryType, f.nullable) + } + val utf8Cases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[StringType] => + fieldReadScalar(fi, f.sparkType, f.nullable) + } + + val decimalCases = scalarOrd.collect { + case (f, fi) if f.sparkType.isInstanceOf[DecimalType] => + val dt = f.sparkType.asInstanceOf[DecimalType] + val field = s"${path}_f$fi" + val body = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + emitDecimalFastBodyUnsafe(s"${field}_valueAddr", "this.rowIdx", " ") + } else { + emitDecimalSlowBody(field, "this.rowIdx", " ") + } + val guard = nullGuardForCase(fi, f.nullable) + s""" case $fi: { + |$guard$body + | }""".stripMargin + } + + Seq( + structSwitch("public boolean getBoolean(int ordinal)", "getBoolean", booleanCases), + structSwitch("public byte getByte(int ordinal)", "getByte", byteCases), + structSwitch("public short getShort(int ordinal)", "getShort", shortCases), + structSwitch("public int getInt(int ordinal)", "getInt", intCases), + structSwitch("public long getLong(int ordinal)", "getLong", longCases), + structSwitch("public float getFloat(int ordinal)", "getFloat", floatCases), + structSwitch("public double getDouble(int ordinal)", "getDouble", doubleCases), + structSwitch( + "public org.apache.spark.sql.types.Decimal getDecimal(" + + "int ordinal, int precision, int scale)", + "getDecimal", + decimalCases), + structSwitch("public byte[] getBinary(int ordinal)", "getBinary", binaryCases), + structSwitch( + "public org.apache.spark.unsafe.types.UTF8String getUTF8String(int ordinal)", + "getUTF8String", + utf8Cases)).mkString + } + + private def emitStructComplexGetters(path: String, spec: StructColumnSpec): String = { + // Same null-guard rationale as `emitArrayElementGetter`. + def guardLine(fi: Int, fieldNullable: Boolean): String = + if (fieldNullable) s" if (isNullAt($fi)) return null;\n" + else "" + val getArrayCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[ArrayColumnSpec] => + val fieldPath = s"${path}_f$fi" + s""" case $fi: { + |${guardLine(fi, f.nullable)} int __idx = this.rowIdx; + | int __s = $fieldPath.getElementStartIndex(__idx); + | int __e = $fieldPath.getElementEndIndex(__idx); + | return new InputArray_$fieldPath(__s, __e - __s); + | }""".stripMargin + } + val getStructCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[StructColumnSpec] => + val fieldPath = s"${path}_f$fi" + if (f.nullable) { + s""" case $fi: { + |${guardLine( + fi, + f.nullable)} return new InputStruct_$fieldPath(this.rowIdx); + | }""".stripMargin + } else { + s" case $fi: return new InputStruct_$fieldPath(this.rowIdx);" + } + } + val getMapCases = spec.fields.zipWithIndex.collect { + case (f, fi) if f.child.isInstanceOf[MapColumnSpec] => + val fieldPath = s"${path}_f$fi" + s""" case $fi: { + |${guardLine(fi, f.nullable)} int __idx = this.rowIdx; + | int __s = $fieldPath.getElementStartIndex(__idx); + | int __e = $fieldPath.getElementEndIndex(__idx); + | return new InputMap_$fieldPath(__s, __e - __s); + | }""".stripMargin + } + Seq( + structSwitch( + "public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)", + "getArray", + getArrayCases), + structSwitch( + "public org.apache.spark.sql.catalyst.InternalRow getStruct(int ordinal, int numFields)", + "getStruct", + getStructCases), + structSwitch( + "public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)", + "getMap", + getMapCases)).mkString + } + + /** + * Emit one `InputMap_${path}` nested class. Constructor takes `(start, length)`; `keyArray()` / + * `valueArray()` allocate fresh `InputArray_${path}_k` / `InputArray_${path}_v` views. + */ + private def emitMapClass(path: String): String = { + val baseClassName = classOf[CometMapData].getName + val keyPath = s"${path}_k" + val valPath = s"${path}_v" + s""" private final class InputMap_$path extends $baseClassName { + | private final int startIndex; + | private final int length; + | + | InputMap_$path(int startIdx, int len) { + | this.startIndex = startIdx; + | this.length = len; + | } + | + | @Override + | public int numElements() { + | return length; + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData keyArray() { + | return new InputArray_$keyPath(this.startIndex, this.length); + | } + | + | @Override + | public org.apache.spark.sql.catalyst.util.ArrayData valueArray() { + | return new InputArray_$valPath(this.startIndex, this.length); + | } + | } + |""".stripMargin + } + + private def structSwitch(methodSig: String, label: String, cases: Seq[String]): String = { + if (cases.isEmpty) { + "" + } else { + s""" + | @Override + | $methodSig { + | switch (ordinal) { + |${cases.mkString("\n")} + | default: throw new UnsupportedOperationException( + | "$label out of range: " + ordinal); + | } + | } + """.stripMargin + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala new file mode 100644 index 0000000000..7a6b02237d --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometBatchKernelCodegenOutput.scala @@ -0,0 +1,394 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types._ + +import org.apache.comet.CometArrowAllocator + +/** + * Output-side emitters for the codegen kernel: [[allocateOutput]], [[emitOutputWriter]] + * (top-level write entry), [[emitWrite]] (recursive per-type write), the output vector-class + * lookup. Paired with [[CometBatchKernelCodegenInput]] on the read side. + */ +private[codegen] object CometBatchKernelCodegenOutput { + + /** + * Spark `DataType` to an Arrow `Field` with names Comet expects on FFI export. Spark's + * `Utils.toArrowField` names list children `"element"`; this rewrites them to `"item"`. Pair + * with the [[RenamedListVector]] / [[RenamedMapVector]] / [[RenamedStructVector]] subclasses in + * [[allocateOutput]], which pin `getField()` so the cached Field actually reaches export. + */ + def toFfiArrowField(name: String, dataType: DataType, nullable: Boolean): Field = + renameForArrowRustFfi(Utils.toArrowField(name, dataType, nullable, "UTC")) + + private def renameForArrowRustFfi(field: Field): Field = { + val children = field.getChildren.asScala + if (children.isEmpty) return field + field.getType match { + case _: ArrowType.List | _: ArrowType.LargeList | _: ArrowType.FixedSizeList => + val child = children.head + val renamedChild = renameForArrowRustFfi( + new Field("item", child.getFieldType, child.getChildren)) + new Field( + field.getName, + field.getFieldType, + java.util.Collections.singletonList(renamedChild)) + case _ => + val renamedChildren = children.map(renameForArrowRustFfi).toList.asJava + new Field(field.getName, field.getFieldType, renamedChildren) + } + } + + /** + * Allocate an Arrow output vector from a pre-built `Field`. Callers cache the Field per + * `(expression, schema)` and pass it on every batch. + * + * Complex top-level types route through a [[RenamedListVector]] / [[RenamedMapVector]] / + * [[RenamedStructVector]] (see those for the runtime-vs-export naming gap). + * + * `estimatedBytes` pre-sizes the data buffer for variable-length scalar outputs. Ignored for + * other root types, and not propagated into nested var-width children (their `allocateNew` runs + * through the parent's `allocateNew`, which resets child buffers). + * + * TODO(nested-varwidth-sizing): thread the estimate into nested var-width children. + * + * TODO(cached-write-buffer-addrs): cache buffer addresses at `process` setup and emit + * `Platform.putByte` / `Platform.copyMemory` for VarChar / VarBinary / Decimal scalar outputs, + * bypassing `setSafe`'s realloc check. Depends on pre-allocated buffers. + * + * Closes the vector on any failure so a partially-initialized tree doesn't leak buffers. + */ + def allocateOutput(field: Field, numRows: Int, estimatedBytes: Int): FieldVector = { + val vec: FieldVector = field.getType match { + case _: ArrowType.List | _: ArrowType.LargeList | _: ArrowType.FixedSizeList => + val v = new RenamedListVector(field, CometArrowAllocator) + v.initializeChildrenFromFields(field.getChildren) + v + case _: ArrowType.Map => + val v = new RenamedMapVector(field, CometArrowAllocator) + v.initializeChildrenFromFields(field.getChildren) + v + case _: ArrowType.Struct => + val v = new RenamedStructVector(field, CometArrowAllocator) + v.initializeChildrenFromFields(field.getChildren) + v + case _ => + field.createVector(CometArrowAllocator).asInstanceOf[FieldVector] + } + try { + vec.setInitialCapacity(numRows) + vec match { + case v: BaseVariableWidthVector if estimatedBytes > 0 => + v.allocateNew(estimatedBytes.toLong, numRows) + case _ => + vec.allocateNew() + } + vec + } catch { + case t: Throwable => + try vec.close() + catch { + case NonFatal(_) => () + } + throw t + } + } + + /** + * Pin `getField()` to the cached Field so FFI export carries the names Comet expects. + * `ListVector.getField` rebuilds child labels from the runtime data vector, which + * `addOrGetVector` hardcodes to `"$data$"`. Applied to `MapVector` and `StructVector` too + * because their `getField` recurses and can pick up a buried `ListVector`'s `"$data$"`. + */ + private final class RenamedListVector(exportField: Field, allocator: BufferAllocator) + extends ListVector(exportField, allocator, null) { + override def getField: Field = exportField + } + + private final class RenamedMapVector(exportField: Field, allocator: BufferAllocator) + extends MapVector(exportField, allocator, null) { + override def getField: Field = exportField + } + + private final class RenamedStructVector(exportField: Field, allocator: BufferAllocator) + extends StructVector(exportField, allocator, null) { + override def getField: Field = exportField + } + + /** + * Returns `(concreteVectorClassName, batchSetup, perRowSnippet)`. `output` is cast to the + * concrete class in `process`'s prelude so `emitWrite`'s complex-type branches can hoist child + * casts off `output` without re-casting per row. + */ + def emitOutputWriter( + dataType: DataType, + valueTerm: String, + ctx: CodegenContext): (String, String, String) = { + val cls = outputVectorClass(dataType) + val emit = emitWrite("output", "i", valueTerm, dataType, ctx) + (cls, emit.setup, emit.perRow) + } + + /** Concrete Arrow vector class name for the output type, used to cast `outRaw` once. */ + private def outputVectorClass(dataType: DataType): String = dataType match { + case BooleanType => classOf[BitVector].getName + case ByteType => classOf[TinyIntVector].getName + case ShortType => classOf[SmallIntVector].getName + case IntegerType => classOf[IntVector].getName + case LongType => classOf[BigIntVector].getName + case FloatType => classOf[Float4Vector].getName + case DoubleType => classOf[Float8Vector].getName + case _: DecimalType => classOf[DecimalVector].getName + case _: StringType => classOf[VarCharVector].getName + case BinaryType => classOf[VarBinaryVector].getName + case DateType => classOf[DateDayVector].getName + case TimestampType => classOf[TimeStampMicroTZVector].getName + case TimestampNTZType => classOf[TimeStampMicroVector].getName + case _: ArrayType => classOf[ListVector].getName + case _: StructType => classOf[StructVector].getName + case _: MapType => classOf[MapVector].getName + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.outputVectorClass: unsupported output type $other") + } + + /** + * Composable write emitter. Returns an [[OutputEmit]] whose `setup` declares once-per-batch + * typed child-vector casts and whose `perRow` writes `source` into `targetVec` at `idx`. + * `targetVec` is assumed pre-cast to the right Arrow class (root prelude or a parent's setup). + * + * Scalars emit `perRow` only. Complex types emit both. Inner setup bubbles up so deep child + * casts land at the batch prelude. + */ + private def emitWrite( + targetVec: String, + idx: String, + source: String, + dataType: DataType, + ctx: CodegenContext): OutputEmit = dataType match { + case BooleanType => + OutputEmit("", s"$targetVec.set($idx, $source ? 1 : 0);") + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | DateType | + TimestampType | TimestampNTZType => + // Spark codegen emits the matching primitive Java type; Arrow `set` overloads accept it. + OutputEmit("", s"$targetVec.set($idx, $source);") + case dt: DecimalType => + // DecimalOutputShortFastPath: precision <= 18 fits in a signed long, so pass the unscaled + // value to `setSafe(int, long)` and skip the BigDecimal allocation. + val write = + if (dt.precision <= Decimal.MAX_LONG_DIGITS) { + s"$targetVec.setSafe($idx, $source.toUnscaledLong());" + } else { + s"$targetVec.setSafe($idx, $source.toJavaBigDecimal());" + } + OutputEmit("", write) + case _: StringType => + // Utf8OutputOnHeapShortcut: when the UTF8String is on-heap (Spark's string functions + // allocate results on-heap), pass its backing byte[] directly to `setSafe`, skipping the + // `getBytes()` allocation. Off-heap falls back to `getBytes()`. + // + // TODO(utf8-unsafe-write): output-side equivalent of `UTF8String.fromAddress`. Coupled + // with `cached-write-buffer-addrs` and a pre-allocated buffer. + val bBase = ctx.freshName("utfBase") + val bLen = ctx.freshName("utfLen") + val bArr = ctx.freshName("utfArr") + OutputEmit( + "", + s"""Object $bBase = $source.getBaseObject(); + |int $bLen = $source.numBytes(); + |if ($bBase instanceof byte[]) { + | $targetVec.setSafe($idx, (byte[]) $bBase, + | (int) ($source.getBaseOffset() + | - org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET), + | $bLen); + |} else { + | byte[] $bArr = $source.getBytes(); + | $targetVec.setSafe($idx, $bArr, 0, $bArr.length); + |}""".stripMargin) + case BinaryType => + OutputEmit("", s"$targetVec.setSafe($idx, $source, 0, $source.length);") + case ArrayType(elementType, containsNull) => + // Spark's `doGenCode` for ArrayType produces an `ArrayData` value. Iterate elements, + // write each into the `ListVector`'s child, bracket with `startNewValue`/`endValue`. The + // element write recurses through `emitWrite` on the child vector so any supported scalar + // becomes a valid element. Nested complex types compose. `targetVec` is a `ListVector` at + // the call site, and only its data vector needs casting (in setup). + // + // NullableElementElision: when `containsNull == false` drop the `isNullAt` guard at + // source level rather than relying on JIT folding. + val childVar = ctx.freshName("outListChild") + val childClass = outputVectorClass(elementType) + val arrVar = ctx.freshName("arr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val elemSource = emitSpecializedGetterExpr(arrVar, jVar, elementType) + val inner = emitWrite(childVar, s"$childIdx + $jVar", elemSource, elementType, ctx) + val setup = + (s"$childClass $childVar = ($childClass) $targetVec.getDataVector();" +: + Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + val elementWrite = if (containsNull) { + s"""if ($arrVar.isNullAt($jVar)) { + | $childVar.setNull($childIdx + $jVar); + | } else { + | ${inner.perRow} + | }""".stripMargin + } else { + inner.perRow + } + val perRow = + s"""org.apache.spark.sql.catalyst.util.ArrayData $arrVar = $source; + |int $nVar = $arrVar.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | $elementWrite + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case st: StructType => + // Spark's `doGenCode` for StructType produces an `InternalRow`. Typed child-vector casts + // hoist to setup, and the per-row body references the hoisted names. + // + // For non-nullable fields, drop the `row.isNullAt($fi)` guard at source level so HotSpot + // emits a straight write path per field rather than a branch. + val rowVar = ctx.freshName("row") + val perField = st.fields.zipWithIndex.map { case (field, fi) => + val childVar = ctx.freshName("outStructChild") + val childClass = outputVectorClass(field.dataType) + val childDecl = + s"$childClass $childVar = ($childClass) $targetVec.getChildByOrdinal($fi);" + val fieldSource = emitSpecializedGetterExpr(rowVar, fi.toString, field.dataType) + val inner = emitWrite(childVar, idx, fieldSource, field.dataType, ctx) + val write = + if (!field.nullable) { + inner.perRow + } else { + s"""if ($rowVar.isNullAt($fi)) { + | $childVar.setNull($idx); + |} else { + | ${inner.perRow} + |}""".stripMargin + } + val perFieldSetup = (Seq(childDecl) ++ Seq(inner.setup).filter(_.nonEmpty)).mkString("\n") + (perFieldSetup, write) + } + val setup = perField.map(_._1).mkString("\n") + val perFieldWrites = perField.map(_._2).mkString("\n") + val perRow = + s"""org.apache.spark.sql.catalyst.InternalRow $rowVar = $source; + |$targetVec.setIndexDefined($idx); + |$perFieldWrites""".stripMargin + OutputEmit(setup, perRow) + case mt: MapType => + // Spark's `doGenCode` for MapType produces a `MapData`. Typed child-vector casts for the + // entries struct and the key/value children hoist to setup. + // + // Per-row: read keyArray/valueArray, open via `startNewValue(idx)`, write each pair into + // the entries struct (key always non-null per Spark/Arrow invariant, value guarded on + // `valueContainsNull`), close via `endValue(idx, n)`. + val entriesVar = ctx.freshName("outMapEntries") + val keyVar = ctx.freshName("outMapKey") + val valVar = ctx.freshName("outMapVal") + val mapSrc = ctx.freshName("mapSrc") + val keyArr = ctx.freshName("keyArr") + val valArr = ctx.freshName("valArr") + val nVar = ctx.freshName("n") + val childIdx = ctx.freshName("cidx") + val jVar = ctx.freshName("j") + val structClass = classOf[StructVector].getName + val keyClass = outputVectorClass(mt.keyType) + val valClass = outputVectorClass(mt.valueType) + val keySrcExpr = emitSpecializedGetterExpr(keyArr, jVar, mt.keyType) + val valSrcExpr = emitSpecializedGetterExpr(valArr, jVar, mt.valueType) + val keyEmit = emitWrite(keyVar, s"$childIdx + $jVar", keySrcExpr, mt.keyType, ctx) + val valEmit = emitWrite(valVar, s"$childIdx + $jVar", valSrcExpr, mt.valueType, ctx) + val setup = + (Seq( + s"$structClass $entriesVar = ($structClass) $targetVec.getDataVector();", + s"$keyClass $keyVar = ($keyClass) $entriesVar.getChildByOrdinal(0);", + s"$valClass $valVar = ($valClass) $entriesVar.getChildByOrdinal(1);") ++ + Seq(keyEmit.setup, valEmit.setup).filter(_.nonEmpty)).mkString("\n") + val valueWrite = if (mt.valueContainsNull) { + s"""if ($valArr.isNullAt($jVar)) { + | $valVar.setNull($childIdx + $jVar); + | } else { + | ${valEmit.perRow} + | }""".stripMargin + } else { + valEmit.perRow + } + val perRow = + s"""org.apache.spark.sql.catalyst.util.MapData $mapSrc = $source; + |org.apache.spark.sql.catalyst.util.ArrayData $keyArr = $mapSrc.keyArray(); + |org.apache.spark.sql.catalyst.util.ArrayData $valArr = $mapSrc.valueArray(); + |int $nVar = $mapSrc.numElements(); + |int $childIdx = $targetVec.startNewValue($idx); + |for (int $jVar = 0; $jVar < $nVar; $jVar++) { + | $entriesVar.setIndexDefined($childIdx + $jVar); + | ${keyEmit.perRow} + | $valueWrite + |} + |$targetVec.endValue($idx, $nVar);""".stripMargin + OutputEmit(setup, perRow) + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.emitWrite: unsupported output type $other") + } + + /** + * Java expression that reads a typed value out of a `SpecializedGetters` (both `ArrayData` and + * `InternalRow` implement it). Used by [[emitWrite]] to source each element/field for its + * recursive inner write. + */ + private def emitSpecializedGetterExpr(target: String, idx: String, elemType: DataType): String = + elemType match { + case BooleanType => s"$target.getBoolean($idx)" + case ByteType => s"$target.getByte($idx)" + case ShortType => s"$target.getShort($idx)" + case IntegerType | DateType => s"$target.getInt($idx)" + case LongType | TimestampType | TimestampNTZType => s"$target.getLong($idx)" + case FloatType => s"$target.getFloat($idx)" + case DoubleType => s"$target.getDouble($idx)" + case dt: DecimalType => s"$target.getDecimal($idx, ${dt.precision}, ${dt.scale})" + case _: StringType => s"$target.getUTF8String($idx)" + case BinaryType => s"$target.getBinary($idx)" + case ArrayType(_, _) => s"$target.getArray($idx)" + case _: MapType => s"$target.getMap($idx)" + case _: StructType => + val numFields = elemType.asInstanceOf[StructType].fields.length + s"$target.getStruct($idx, $numFields)" + case other => + throw new UnsupportedOperationException( + s"CometBatchKernelCodegen.emitSpecializedGetterExpr: unsupported type $other") + } + + /** `setup` is once-per-batch (typed child-vector casts); `perRow` runs per row. */ + private case class OutputEmit(setup: String, perRow: String) +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala b/spark/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala new file mode 100644 index 0000000000..77321fed9c --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometInternalRow.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +import org.apache.comet.shims.CometInternalRowShim + +/** + * Throwing-default `InternalRow` base for the codegen kernel. Subclasses override only the + * getters their input shape needs. Centralizing the throws absorbs forward-compat breakage when + * Spark adds abstract methods. + * + * Two consumers: the compiled kernel (`ctx.INPUT_ROW = "row"` aliases `this`) and per-column + * `InputStruct_${path}` nested classes that back `getStruct(ord, n)`. + */ +abstract class CometInternalRow extends InternalRow with CometInternalRowShim { + + override def numFields: Int = unsupported("numFields") + + override def getInterval(ordinal: Int): CalendarInterval = unsupported("getInterval") + + override def get(ordinal: Int, dataType: DataType): AnyRef = + CometSpecializedGettersDispatch.get(this, ordinal, dataType) + + override def isNullAt(ordinal: Int): Boolean = unsupported("isNullAt") + + override def getBoolean(ordinal: Int): Boolean = unsupported("getBoolean") + + override def getByte(ordinal: Int): Byte = unsupported("getByte") + + override def getShort(ordinal: Int): Short = unsupported("getShort") + + override def getInt(ordinal: Int): Int = unsupported("getInt") + + override def getLong(ordinal: Int): Long = unsupported("getLong") + + override def getFloat(ordinal: Int): Float = unsupported("getFloat") + + override def getDouble(ordinal: Int): Double = unsupported("getDouble") + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = + unsupported("getDecimal") + + override def getUTF8String(ordinal: Int): UTF8String = unsupported("getUTF8String") + + override def getBinary(ordinal: Int): Array[Byte] = unsupported("getBinary") + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = unsupported("getStruct") + + override def getArray(ordinal: Int): ArrayData = unsupported("getArray") + + override def getMap(ordinal: Int): MapData = unsupported("getMap") + + override def setNullAt(i: Int): Unit = unsupported("setNullAt") + + override def update(i: Int, value: Any): Unit = unsupported("update") + + override def copy(): InternalRow = unsupported("copy") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this row shape") +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometMapData.scala b/spark/src/main/scala/org/apache/comet/codegen/CometMapData.scala new file mode 100644 index 0000000000..ac8254e72d --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometMapData.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} + +/** + * Throwing-default `MapData` base for the codegen kernel. Per-column `InputMap_${path}` + * subclasses override `numElements`, `keyArray`, and `valueArray` (the latter two return + * `InputArray_*` views over the same backing key/value vectors). + * + * `MapData` does not extend `SpecializedGetters`, so this base does not mix in the row/array shim + * or delegate to [[CometSpecializedGettersDispatch]]. + */ +abstract class CometMapData extends MapData { + + override def keyArray(): ArrayData = unsupported("keyArray") + + override def valueArray(): ArrayData = unsupported("valueArray") + + override def copy(): MapData = unsupported("copy") + + protected def unsupported(method: String): Nothing = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: $method not implemented for this map shape") + + override def toString(): String = { + val n = + try numElements().toString + catch { + case _: Throwable => "?" + } + s"${getClass.getSimpleName}(numElements=$n)" + } + + override def numElements(): Int = unsupported("numElements") +} diff --git a/spark/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala b/spark/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala new file mode 100644 index 0000000000..2f81c58c06 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/codegen/CometSpecializedGettersDispatch.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.codegen + +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.types._ + +/** + * Shared `SpecializedGetters.get(ordinal, dataType)` dispatch used by [[CometInternalRow]] and + * [[CometArrayData]]. Spark codegen paths (notably `SafeProjection` for ScalaUDF struct args) and + * interpreted-eval fallbacks (`ArrayDistinct.nullSafeEval` etc.) call the generic `get` instead + * of the typed getter, so both kernel-side bases need a non-throwing implementation. + * + * For complex types, the typed getter allocates a fresh `InputStruct_*` / `InputArray_*` / + * `InputMap_*` per call (`ColumnarRow`-style), so retain-by-reference consumers like + * `OpenHashSet` get distinct identities. + */ +private[codegen] object CometSpecializedGettersDispatch { + + def get(g: SpecializedGetters, ordinal: Int, dataType: DataType): AnyRef = { + if (g.isNullAt(ordinal)) return null + dataType match { + case BooleanType => java.lang.Boolean.valueOf(g.getBoolean(ordinal)) + case ByteType => java.lang.Byte.valueOf(g.getByte(ordinal)) + case ShortType => java.lang.Short.valueOf(g.getShort(ordinal)) + case IntegerType | DateType => java.lang.Integer.valueOf(g.getInt(ordinal)) + case LongType | TimestampType | TimestampNTZType => + java.lang.Long.valueOf(g.getLong(ordinal)) + case FloatType => java.lang.Float.valueOf(g.getFloat(ordinal)) + case DoubleType => java.lang.Double.valueOf(g.getDouble(ordinal)) + case _: StringType => g.getUTF8String(ordinal) + case BinaryType => g.getBinary(ordinal) + case dt: DecimalType => g.getDecimal(ordinal, dt.precision, dt.scale) + case st: StructType => g.getStruct(ordinal, st.size) + case _: ArrayType => g.getArray(ordinal) + case _: MapType => g.getMap(ordinal) + case other => + throw new UnsupportedOperationException( + s"${g.getClass.getSimpleName}: get for dataType $other not implemented") + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala new file mode 100644 index 0000000000..bf636f7221 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/serde/CometScalaUDF.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.serde + +import org.apache.spark.SparkEnv +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSeq, BindReferences, Literal, ScalaUDF} +import org.apache.spark.sql.types.BinaryType + +import org.apache.comet.CometConf +import org.apache.comet.CometSparkSessionExtensions.withInfo +import org.apache.comet.codegen.CometBatchKernelCodegen +import org.apache.comet.serde.ExprOuterClass.Expr +import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, serializeDataType} +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * Routes scalar `ScalaUDF` (Scala and Java UDFs) through the codegen dispatcher. + * `ScalaUDF.doGenCode` emits compilable Java that invokes the user function via + * `ctx.addReferenceObj`; the dispatcher serializes the bound tree, the closure serializer carries + * the function reference across the wire, and the Janino-compiled kernel invokes it in a tight + * batch loop. + * + * Not covered: + * - Aggregate UDFs (`ScalaAggregator`, `TypedImperativeAggregate`, legacy UDAF). + * - Table UDFs and generators. + * - Python / Pandas UDFs. + * - Hive `GenericUDF` / `SimpleUDF`. + * + * Gated by [[CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED]]. When disabled, plans containing a + * `ScalaUDF` fall back to Spark for the enclosing operator. + */ +object CometScalaUDF extends CometExpressionSerde[ScalaUDF] { + + override def convert(expr: ScalaUDF, inputs: Seq[Attribute], binding: Boolean): Option[Expr] = { + if (!CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.get()) { + withInfo( + expr, + s"${CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key}=false; ScalaUDF has no native path " + + "so the plan falls back to Spark") + return None + } + + // Bind against only the AttributeReferences the tree actually reads, so ordinals align with + // the data args we ship. + val attrs = expr.collect { case a: AttributeReference => a }.distinct + val boundExpr = BindReferences.bindReference(expr, AttributeSeq(attrs)) + + // Gate at plan time. Surface the reason via withInfo rather than crashing Janino at execute. + CometBatchKernelCodegen.canHandle(boundExpr) match { + case Some(reason) => + withInfo(expr, reason) + return None + case None => + } + + // Serialize via Spark's closure serializer: respects the task context classloader (so user + // UDF jars are visible) and matches Spark's wire format. The bytes become arg 0 of the + // JvmScalarUdf proto and self-describe the expression so this works in cluster mode without + // executor-side driver registry state. + val serializer = SparkEnv.get.closureSerializer.newInstance() + val buffer = serializer.serialize(boundExpr) + val bytes = new Array[Byte](buffer.remaining()) + buffer.get(bytes) + val exprArg = exprToProtoInternal(Literal(bytes, BinaryType), inputs, binding) + .getOrElse(return None) + + val dataArgs = + attrs.map(a => exprToProtoInternal(a, inputs, binding).getOrElse(return None)) + val returnTypeProto = serializeDataType(expr.dataType).getOrElse(return None) + + val udfBuilder = ExprOuterClass.JvmScalarUdf + .newBuilder() + .setClassName(classOf[CometScalaUDFCodegen].getName) + .addArgs(exprArg) + dataArgs.foreach(udfBuilder.addArgs) + udfBuilder + .setReturnType(returnTypeProto) + .setReturnNullable(expr.nullable) + Some( + ExprOuterClass.Expr + .newBuilder() + .setJvmScalarUdf(udfBuilder.build()) + .build()) + } +} diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d85a2c30cb..b1703cb644 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -261,6 +261,7 @@ object QueryPlanSerde extends Logging with CometExprShim with CometTypeShim { classOf[MakeDecimal] -> CometMakeDecimal, classOf[MonotonicallyIncreasingID] -> CometMonotonicallyIncreasingId, classOf[ScalarSubquery] -> CometScalarSubquery, + classOf[ScalaUDF] -> CometScalaUDF, classOf[SparkPartitionID] -> CometSparkPartitionId, classOf[SortOrder] -> CometSortOrder, classOf[StaticInvoke] -> CometStaticInvoke, diff --git a/spark/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala b/spark/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala deleted file mode 100644 index 5e020ae74a..0000000000 --- a/spark/src/main/scala/org/apache/comet/udf/CometLambdaRegistry.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet.udf - -import java.util.UUID -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.sql.catalyst.expressions.Expression - -/** - * Thread-safe registry bridging plan-time Spark expressions to execution-time UDF lookup. At plan - * time the serde layer registers a lambda expression under a unique key; at execution time the - * UDF retrieves it by that key (passed as a scalar argument). - */ -object CometLambdaRegistry { - - private val registry = new ConcurrentHashMap[String, Expression]() - - def register(expression: Expression): String = { - val key = UUID.randomUUID().toString - registry.put(key, expression) - key - } - - def get(key: String): Expression = { - val expr = registry.get(key) - if (expr == null) { - throw new IllegalStateException( - s"Lambda expression not found in registry for key: $key. " + - "This indicates a lifecycle issue between plan creation and execution.") - } - expr - } - - def remove(key: String): Unit = { - registry.remove(key) - } - - // Visible for testing - def size(): Int = registry.size() -} diff --git a/spark/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala b/spark/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala new file mode 100644 index 0000000000..f575dd5b53 --- /dev/null +++ b/spark/src/main/scala/org/apache/comet/udf/codegen/CometScalaUDFCodegen.scala @@ -0,0 +1,326 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.udf.codegen + +import java.nio.ByteBuffer +import java.util.Collections +import java.util.concurrent.atomic.AtomicLong + +import scala.collection.mutable +import scala.util.control.NonFatal + +import org.apache.arrow.vector._ +import org.apache.arrow.vector.complex.{ListVector, MapVector, StructVector} +import org.apache.arrow.vector.types.pojo.Field +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.types.{BinaryType, DataType, StringType} + +import org.apache.comet.codegen.{CometBatchKernel, CometBatchKernelCodegen} +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} +import org.apache.comet.udf.CometUDF + +/** + * Arrow-direct codegen dispatcher. For each `(bound expression, input Arrow schema)` pair, + * compiles a specialized [[CometBatchKernel]] on first encounter, initializes it with the task's + * partition index, and caches the live instance. + * + * Arg 0 is a `VarBinaryVector` scalar carrying the closure-serialized bound `Expression` bytes; + * args 1..N are the data columns the `BoundReference`s read in ordinal order. + * + * Caching hierarchy, broadest scope on the left: + * {{{ + * +----------------------------+ +----------------------------+ +----------------------------+ + * | 1. JVM bytecode cache | | 2. Per-task dispatcher | | 3. Per-task kernel cache | + * | (Spark's CodeGenerator) | | (CometUdfBridge. | | (kernelCache field) | + * | | | INSTANCES) | | | + * +----------------------------+ +----------------------------+ +----------------------------+ + * | Key: generated Java | | Key: task + UDF class | | Key: bound expression + | + * | source | | | | input column shapes | + * | Value: compiled Java class | | Value: dispatcher object | | Value: ready-to-run kernel | + * | Scope: JVM, all queries | | Scope: one Spark task | | with state primed | + * | share it | | | | Scope: one Spark task | + * | Owner: Spark | | Owner: Comet | | (lives inside 2) | + * | | | | | Owner: Comet | + * +----------------------------+ +----------------------------+ +----------------------------+ + * }}} + * + * Stateful expressions (`Rand`, `MonotonicallyIncreasingID`) advance inside the per-task kernel + * across batches. + * + * `evaluate` runs under `this.synchronized` because DataFusion operators like `HashJoinExec` + * pipeline build/probe via `OnceAsync` (`tokio::spawn`), so multiple Tokio worker threads can + * call back into one task's dispatcher. The kernel's per-batch instance fields would race + * otherwise. + * + * TODO(udf-codegen-pool): if intra-task UDF parallelism shows up as a bottleneck, replace the + * per-key kernel instance with a pool and externalize per-partition counters. + */ +class CometScalaUDFCodegen extends CometUDF with Logging { + + /** + * Per-task cache keyed on serialized expression bytes plus per-column specs. The deserialized + * `boundExpr` carries mutable state (`NamedLambdaVariable.value` for HOFs, `Rand`'s + * `XORShiftRandom`) that must not be shared across concurrent tasks running the same query; + * keeping the cache per-task gives each task its own copy. Guarded by `this.synchronized`. + */ + private val kernelCache + : mutable.Map[CometScalaUDFCodegen.CacheKey, CometScalaUDFCodegen.CacheEntry] = + mutable.HashMap.empty + + override def evaluate(inputs: Array[ValueVector], numRows: Int): ValueVector = { + require( + inputs.length >= 1, + "CometScalaUDFCodegen requires at least 1 input (serialized expression), " + + s"got ${inputs.length}") + val exprVec = inputs(0).asInstanceOf[VarBinaryVector] + require( + exprVec.getValueCount >= 1 && !exprVec.isNull(0), + "CometScalaUDFCodegen requires non-null serialized expression bytes at arg 0") + val bytes = exprVec.get(0) + + // TODO(dict-encoded): kernels assume materialized inputs. Dict-encoded vectors would fail the + // cast in `specFor` below. Fix is to materialize at the dispatcher (via + // `CDataDictionaryProvider`) or widen `emitTypedGetters` with a dict-index + lookup path. + + val numDataCols = inputs.length - 1 + val dataCols = new Array[ValueVector](numDataCols) + val specs = new Array[ArrowColumnSpec](numDataCols) + var di = 0 + while (di < numDataCols) { + val v = inputs(di + 1) + dataCols(di) = v + specs(di) = specFor(v) + di += 1 + } + val n = numRows + val specsSeq = specs.toIndexedSeq + + val key = CometScalaUDFCodegen.CacheKey(ByteBuffer.wrap(bytes), specsSeq) + + // Cache lookup and `process` run under one lock to serialize concurrent Tokio callers that + // would otherwise race on the kernel's per-batch instance fields. + this.synchronized { + val entry = lookupOrCompile(key, bytes, specsSeq) + + val out = CometBatchKernelCodegen.allocateOutput( + entry.outputField, + n, + estimatedOutputBytes(entry.outputType, dataCols)) + try { + entry.kernel.process(dataCols, out, n) + out.setValueCount(n) + out + } catch { + case t: Throwable => + try out.close() + catch { + case NonFatal(_) => () + } + throw t + } + } + } + + private def lookupOrCompile( + key: CometScalaUDFCodegen.CacheKey, + bytes: Array[Byte], + specs: IndexedSeq[ArrowColumnSpec]): CometScalaUDFCodegen.CacheEntry = { + assert(Thread.holdsLock(this), "lookupOrCompile must run under this.synchronized") + kernelCache.get(key) match { + case Some(entry) => + CometScalaUDFCodegen.cacheHitCount.incrementAndGet() + entry + case None => + val loader = Option(Thread.currentThread().getContextClassLoader) + .getOrElse(classOf[Expression].getClassLoader) + val boundExpr = + try { + SparkEnv.get.closureSerializer + .newInstance() + .deserialize[Expression](ByteBuffer.wrap(bytes), loader) + } catch { + case NonFatal(t) => + logError( + "CometScalaUDFCodegen: closure-deserialize failed " + + s"(bytes=${bytes.length}, specs=$specs)", + t) + throw t + } + val compiled = CometBatchKernelCodegen.compile(boundExpr, specs) + val kernel = compiled.newInstance() + kernel.init(CometScalaUDFCodegen.currentPartitionIndex()) + val outputField = CometBatchKernelCodegen.toFfiArrowField( + "codegen_result", + boundExpr.dataType, + boundExpr.nullable) + val entry = + CometScalaUDFCodegen.CacheEntry(compiled, kernel, boundExpr.dataType, outputField) + kernelCache.put(key, entry) + CometScalaUDFCodegen.compileCount.incrementAndGet() + CometScalaUDFCodegen.recordCompiledSignature(specs, boundExpr.dataType) + entry + } + } + + /** + * Build the compile-time spec for one input Arrow vector. Recurses on complex types. + * + * Top-level `nullable=true` is hardcoded: the cache key does not specialize on per-batch null + * density. Schema-declared nullability still reaches the kernel via `BoundReference.nullable` + * embedded in `bytesKey`, so `BoundReference.doGenCode` elides its own `isNullAt` probe on + * non-null columns. `StructFieldSpec.nullable` reads `field.isNullable` from Arrow metadata, + * which is a schema property and therefore stable across batches. + */ + private def specFor(v: ValueVector): ArrowColumnSpec = v match { + case map: MapVector => + // MapVector extends ListVector, match it first. + val struct = map.getDataVector.asInstanceOf[StructVector] + val keyVec = struct.getChildByOrdinal(0).asInstanceOf[ValueVector] + val valueVec = struct.getChildByOrdinal(1).asInstanceOf[ValueVector] + MapColumnSpec( + nullable = true, + keySparkType = Utils.fromArrowField(keyVec.getField), + valueSparkType = Utils.fromArrowField(valueVec.getField), + key = specFor(keyVec), + value = specFor(valueVec)) + case list: ListVector => + val child = list.getDataVector + ArrayColumnSpec(nullable = true, Utils.fromArrowField(child.getField), specFor(child)) + case struct: StructVector => + val fieldSpecs = (0 until struct.size()).map { fi => + val childVec = struct.getChildByOrdinal(fi).asInstanceOf[ValueVector] + val field = struct.getField.getChildren.get(fi) + StructFieldSpec( + name = field.getName, + sparkType = Utils.fromArrowField(field), + nullable = field.isNullable, + child = specFor(childVec)) + } + StructColumnSpec(nullable = true, fieldSpecs) + case _: BitVector | _: TinyIntVector | _: SmallIntVector | _: IntVector | _: BigIntVector | + _: Float4Vector | _: Float8Vector | _: DecimalVector | _: VarCharVector | + _: VarBinaryVector | _: DateDayVector | _: TimeStampMicroVector | + _: TimeStampMicroTZVector => + ScalarColumnSpec(v.getClass.asInstanceOf[Class[_ <: ValueVector]], nullable = true) + case other => + throw new UnsupportedOperationException( + s"CometScalaUDFCodegen: unsupported Arrow vector ${other.getClass.getSimpleName}") + } + + /** + * Sum of variable-width input data buffer sizes as an upper bound for typical transform outputs + * (replace, upper, lower, substring, concat). Underestimates are still corrected by `setSafe`; + * this just reduces the odds of mid-loop reallocation. + */ + private def estimatedOutputBytes(outputType: DataType, dataCols: Array[ValueVector]): Int = { + outputType match { + case _: StringType | _: BinaryType => + var sum = 0 + var i = 0 + while (i < dataCols.length) { + dataCols(i) match { + case v: BaseVariableWidthVector => sum += v.getDataBuffer.writerIndex().toInt + case _ => // no size hint for fixed-width vector types + } + i += 1 + } + sum + case _ => -1 + } + } +} + +object CometScalaUDFCodegen { + + // JVM-wide counters across all per-task instances. Compile work is deduped JVM-wide via + // `CodeGenerator.compile`'s source cache. These track this dispatcher's per-task cache activity. + private val compileCount = new AtomicLong(0) + private val cacheHitCount = new AtomicLong(0) + + // Append-only set of distinct compiled-kernel signatures. Lets tests assert specialization + // shape (vector-class / dataType combinations the dispatcher emitted) and that composed + // subtrees fuse into one kernel. Per-task caches are dropped on completion, leaving no other + // place to observe the set across runs. + private val compiledSignatures = + Collections.synchronizedSet( + new java.util.HashSet[(IndexedSeq[Class[_ <: ValueVector]], DataType)]()) + + /** Snapshot of JVM-wide counters and distinct-signature count. */ + def stats(): DispatcherStats = + DispatcherStats(compileCount.get(), cacheHitCount.get(), compiledSignatures.size()) + + /** Reset counters; leaves the signature set intact. Tests only. */ + def resetStats(): Unit = { + compileCount.set(0) + cacheHitCount.set(0) + } + + /** + * Distinct compiled-kernel signatures: `(input vector classes in ordinal order, output Spark + * DataType)`. `ArrowColumnSpec.nullable` is intentionally omitted so the signature reflects + * what would specialize the kernel regardless of any future per-batch nullability variants. + */ + def snapshotCompiledSignatures(): Set[(IndexedSeq[Class[_ <: ValueVector]], DataType)] = { + import scala.jdk.CollectionConverters._ + compiledSignatures.synchronized { + compiledSignatures.iterator().asScala.toSet + } + } + + private[codegen] def recordCompiledSignature( + specs: IndexedSeq[ArrowColumnSpec], + outputType: DataType): Unit = { + compiledSignatures.add((specs.map(_.vectorClass), outputType)) + } + + /** + * Partition index for the kernel's `init`. Expressions whose `doGenCode` calls + * `addPartitionInitializationStatement` (`Rand`, `Randn`, `Uuid`) reseed mutable state from + * this. Falls back to 0 when the dispatcher is exercised outside a Spark task (unit tests). + */ + private def currentPartitionIndex(): Int = + Option(TaskContext.get()).map(_.partitionId()).getOrElse(0) + + /** + * Cache key: serialized expression bytes plus per-column compile-time invariants. `hashCode` + * walks `bytesKey` per lookup, so for large ScalaUDF closures it scales with closure size. + * + * TODO(perf-cache-key): if hot, options are a driver-precomputed hash piggybacked through the + * proto, per-instance last-key memoization, or a two-tier cache keyed on the generated source. + */ + final case class CacheKey(bytesKey: ByteBuffer, specs: IndexedSeq[ArrowColumnSpec]) + + /** Snapshot of dispatcher cache counters and current size. */ + final case class DispatcherStats(compileCount: Long, cacheHitCount: Long, cacheSize: Int) { + def hitRate: Double = + if (totalLookups == 0) 0.0 else cacheHitCount.toDouble / totalLookups.toDouble + + def totalLookups: Long = compileCount + cacheHitCount + } + + private case class CacheEntry( + compiled: CometBatchKernelCodegen.CompiledKernel, + kernel: CometBatchKernel, + outputType: DataType, + outputField: Field) +} diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala new file mode 100644 index 0000000000..2ae589a996 --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/CometExprTraitShim.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, NullIntolerant} + +/** + * Per-profile view of expression traits that shifted shape across Spark versions. Spark 3.x has a + * `NullIntolerant` marker trait and no scalar-expression `Stateful` concept (added in 4.x as a + * boolean method on `Expression`). Routing checks through one shim avoids version pattern matches + * in the codegen dispatcher. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.isInstanceOf[NullIntolerant] + + // Aggregate/window/generator stateful cases are rejected elsewhere in `canHandle`, so treating + // all scalar expressions as non-stateful here is conservative-correct on this profile. + def isStateful(expr: Expression): Boolean = false + + // No collation / `ResolvedCollation` concept in 3.x. + def isCodegenInertUnevaluable(expr: Expression): Boolean = false +} diff --git a/spark/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala b/spark/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..18b3a4e6b3 --- /dev/null +++ b/spark/src/main/spark-3.x/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +/** + * Per-profile shim mixed into `CometInternalRow` and `CometArrayData`. Spark 4.x adds abstract + * `SpecializedGetters` getters (`getVariant` in 4.0, `getGeography` and `getGeometry` in 4.1) + * that subclasses must implement; Spark 3.x has none, so this trait is empty. + */ +trait CometInternalRowShim diff --git a/spark/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala b/spark/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..b855fe3a91 --- /dev/null +++ b/spark/src/main/spark-4.0/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.VariantVal + +/** + * Throwing defaults for Spark 4.0 `SpecializedGetters` additions: `getVariant`. Mixed into + * `CometInternalRow` and `CometArrayData` so the codegen kernel's subclasses satisfy the + * abstract-method check at class-load time. 4.1 also adds `getGeography` / `getGeometry` (see the + * spark-4.1 shim). + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") +} diff --git a/spark/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala b/spark/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..ce4cb7c06f --- /dev/null +++ b/spark/src/main/spark-4.1/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing defaults for Spark 4.x `SpecializedGetters` additions: `getVariant` (4.0), + * `getGeography` and `getGeometry` (4.1). Mixed into `CometInternalRow` and `CometArrayData`. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/spark/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala b/spark/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala new file mode 100644 index 0000000000..ce4cb7c06f --- /dev/null +++ b/spark/src/main/spark-4.2/org/apache/comet/shims/CometInternalRowShim.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, VariantVal} + +/** + * Throwing defaults for Spark 4.x `SpecializedGetters` additions: `getVariant` (4.0), + * `getGeography` and `getGeometry` (4.1). Mixed into `CometInternalRow` and `CometArrayData`. + */ +trait CometInternalRowShim { + def getVariant(ordinal: Int): VariantVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getVariant not supported") + + def getGeography(ordinal: Int): GeographyVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeography not supported") + + def getGeometry(ordinal: Int): GeometryVal = + throw new UnsupportedOperationException( + s"${getClass.getSimpleName}: getGeometry not supported") +} diff --git a/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala b/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala new file mode 100644 index 0000000000..a9a3d26bba --- /dev/null +++ b/spark/src/main/spark-4.x/org/apache/comet/shims/CometExprTraitShim.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet.shims + +import org.apache.spark.sql.catalyst.expressions.{Expression, ResolvedCollation} + +/** + * Spark 4.x replaced the `NullIntolerant` marker trait with a boolean method on `Expression` and + * added a `stateful` boolean. Neither exists as a trait in 4.x. This shim routes the checks + * through the method form. + */ +trait CometExprTraitShim { + def isNullIntolerant(expr: Expression): Boolean = expr.nullIntolerant + def isStateful(expr: Expression): Boolean = expr.stateful + + // `ResolvedCollation` is an `Unevaluable` leaf living only in `Collate.collation` as a + // type-level marker. `Collate.genCode` passes through to its child and never invokes it. Spark + // 4.1 analyzes it away; 4.0 leaves it in the tree, so the dispatcher's `Unevaluable` guard + // would trip without this exemption. + def isCodegenInertUnevaluable(expr: Expression): Boolean = expr match { + case _: ResolvedCollation => true + case _ => false + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala index 63d8f44ebd..a0a5c48041 100644 --- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala @@ -247,7 +247,7 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp .withColumn("arrUnsupportedArgs", expr("array_insert(arr, idx, 1)")) checkSparkAnswerAndFallbackReasons( df.select("arrUnsupportedArgs"), - Set("scalaudf is not supported", "unsupported arguments for ArrayInsert")) + Set("ScalaUDF has no native path", "unsupported arguments for ArrayInsert")) } } } diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala b/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala new file mode 100644 index 0000000000..13334a5134 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenAssertions.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.arrow.vector.ValueVector +import org.apache.spark.sql.types.DataType + +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * Shared assertions for the codegen-dispatcher test suites. Mix in alongside `CometTestBase`. + */ +trait CometCodegenAssertions { + + /** Asserts the dispatcher actually ran during `f`, guarding against silent serde fallback. */ + protected def assertCodegenRan(f: => Unit): Unit = { + CometScalaUDFCodegen.resetStats() + f + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Asserts the composed subtree fused into one kernel signature, not N (one per sub-expression). + * Uses the JVM-wide signature set rather than `compileCount` because per-task `boundExpr` + * isolation makes multi-partition queries trip `compileCount > 1` even when the bytecode is + * shared. + */ + protected def assertOneKernelForSubtree(f: => Unit): Unit = { + CometScalaUDFCodegen.resetStats() + val sigsBefore = CometScalaUDFCodegen.snapshotCompiledSignatures() + f + val sigsAfter = CometScalaUDFCodegen.snapshotCompiledSignatures() + val grew = sigsAfter.size - sigsBefore.size + assert( + grew <= 1, + s"expected <= 1 new compiled-kernel signature for the composed subtree, grew by $grew; " + + s"new=${sigsAfter -- sigsBefore}") + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount + after.cacheHitCount >= 1, + s"expected codegen dispatcher activity, got $after") + } + + /** + * Asserts a kernel matching the given input Arrow vector classes and output type sits in the + * JVM-wide signature set. Pair with `assertCodegenRan` since the set is append-only. Compares + * by simple name to be robust to Arrow shading. + */ + protected def assertKernelSignaturePresent( + inputs: Seq[Class[_ <: ValueVector]], + output: DataType): Unit = { + val sigs = CometScalaUDFCodegen.snapshotCompiledSignatures() + val expectedNames = inputs.map(_.getSimpleName).toIndexedSeq + val present = sigs.exists { case (cached, dt) => + dt == output && cached.map(_.getSimpleName) == expectedNames + } + assert( + present, + s"expected kernel signature $expectedNames -> $output; " + + s"cache had ${sigs.map { case (c, d) => (c.map(_.getSimpleName), d) }}") + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala new file mode 100644 index 0000000000..c87d48352b --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenFuzzSuite.scala @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.io.File +import java.text.SimpleDateFormat + +import scala.util.Random + +import org.apache.commons.io.FileUtils +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ + +import org.apache.comet.DataTypeSupport.isComplexType +import org.apache.comet.testing.{DataGenOptions, FuzzDataGenerator, ParquetGenerator, SchemaGenOptions} + +/** + * Randomized end-to-end tests for the Arrow-direct codegen dispatcher: schema-driven coverage of + * every input vector class against random parquet, plus a decimal precision-scale sweep across + * the `Decimal.MAX_LONG_DIGITS=18` boundary at varying null densities. Extends [[CometTestBase]] + * (not [[CometFuzzTestBase]]) because the base's `shuffle` x `nativeC2R` cross-product is + * irrelevant for projection-only queries. + */ +class CometCodegenFuzzSuite + extends CometTestBase + with AdaptiveSparkPlanHelper + with CometCodegenAssertions { + + /** Random schema with primitives plus shallow arrays and structs. No maps, no deep nesting. */ + private var mixedTypesFilename: String = _ + + /** Random schema with deeply nested arrays / structs / maps. */ + private var nestedTypesFilename: String = _ + + /** Asia/Kathmandu has a non-zero minute offset (UTC+5:45); good for timezone edge cases. */ + private val defaultTimezone = "Asia/Kathmandu" + + override def beforeAll(): Unit = { + super.beforeAll() + val tempDir = System.getProperty("java.io.tmpdir") + val random = new Random(42) + val dataGenOptions = DataGenOptions( + generateNegativeZero = false, + baseDate = new SimpleDateFormat("YYYY-MM-DD hh:mm:ss") + .parse("2024-05-25 12:34:56") + .getTime) + + mixedTypesFilename = s"$tempDir/CometCodegenFuzzSuite_${System.currentTimeMillis()}.parquet" + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { + val schemaGenOptions = + SchemaGenOptions(generateArray = true, generateStruct = true) + ParquetGenerator.makeParquetFile( + random, + spark, + mixedTypesFilename, + 1000, + schemaGenOptions, + dataGenOptions) + } + + nestedTypesFilename = + s"$tempDir/CometCodegenFuzzSuite_nested_${System.currentTimeMillis()}.parquet" + withSQLConf( + CometConf.COMET_ENABLED.key -> "false", + SQLConf.SESSION_LOCAL_TIMEZONE.key -> defaultTimezone) { + val schemaGenOptions = + SchemaGenOptions(generateArray = true, generateStruct = true, generateMap = true) + val schema = FuzzDataGenerator.generateNestedSchema( + random, + numCols = 10, + minDepth = 2, + maxDepth = 4, + options = schemaGenOptions) + ParquetGenerator.makeParquetFile( + random, + spark, + nestedTypesFilename, + schema, + 1000, + dataGenOptions) + } + + spark.read.parquet(mixedTypesFilename).createOrReplaceTempView("t1") + spark.read.parquet(nestedTypesFilename).createOrReplaceTempView("t2") + } + + protected override def afterAll(): Unit = { + super.afterAll() + FileUtils.deleteDirectory(new File(mixedTypesFilename)) + FileUtils.deleteDirectory(new File(nestedTypesFilename)) + } + + private val RowCount: Int = 512 + private val nullDensities: Seq[Double] = Seq(0.0, 0.1, 0.5, 1.0) + // (precision, scale) shapes spanning both sides of `Decimal.MAX_LONG_DIGITS=18`: small short, + // boundary short with varying scale, just-past-boundary long, and max decimal128. + private val decimalShapes: Seq[(Int, Int)] = Seq((9, 2), (18, 0), (18, 9), (19, 0), (38, 10)) + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key, "true") + + /** + * Identity ScalaUDF for one of the 14 primitive types in + * [[org.apache.comet.testing.SchemaGenOptions.defaultPrimitiveTypes]]. Returns the registered + * name when the type maps to a known Scala arg, or `None` for shapes we choose not to probe. + * `BigDecimal` UDF args are encoded as `DecimalType(38, 18)`; Spark inserts an implicit cast + * around the call but the underlying column read still hits our kernel's `getDecimal` at the + * column's native precision. + */ + private def registerIdentityUdfFor(dt: DataType, name: String): Option[String] = dt match { + case _: BooleanType => spark.udf.register(name, (x: Boolean) => x); Some(name) + case _: ByteType => spark.udf.register(name, (x: Byte) => x); Some(name) + case _: ShortType => spark.udf.register(name, (x: Short) => x); Some(name) + case _: IntegerType => spark.udf.register(name, (x: Int) => x); Some(name) + case _: LongType => spark.udf.register(name, (x: Long) => x); Some(name) + case _: FloatType => spark.udf.register(name, (x: Float) => x); Some(name) + case _: DoubleType => spark.udf.register(name, (x: Double) => x); Some(name) + case _: DecimalType => + spark.udf.register(name, (x: java.math.BigDecimal) => x); Some(name) + case _: DateType => spark.udf.register(name, (x: java.sql.Date) => x); Some(name) + case _: TimestampType => + spark.udf.register(name, (x: java.sql.Timestamp) => x); Some(name) + case _: TimestampNTZType => + spark.udf.register(name, (x: java.time.LocalDateTime) => x); Some(name) + case _: StringType => spark.udf.register(name, (x: String) => x); Some(name) + case _: BinaryType => spark.udf.register(name, (x: Array[Byte]) => x); Some(name) + case _ => None + } + + /** + * Identity-Int UDF for the cardinality-based complex probe. One UDF covers every Array and Map + * column, regardless of element type. + * + * Avoids `Seq[T]` / `Map[K, V]` UDF arg materialization: Spark's `MapObjects.doGenCode` reads + * each element unconditionally and null-checks afterward, so on null positions of a + * dictionary-encoded primitive Arrow vector the garbage ID buffer feeds + * `dictionary.decodeToLong/decodeToFloat` and throws `ArrayIndexOutOfBoundsException`. Bug + * reproduces in pure Spark; `cardinality(col)` exercises `getArray`/`getMap` without entering + * the element deserializer. + */ + private lazy val cardinalityProbeUdf: String = { + val name = "sz_complex" + spark.udf.register(name, (i: Int) => i) + name + } + + test("identity ScalaUDF over every primitive column") { + val primitiveFields = + spark.table("t1").schema.fields.filterNot(f => isComplexType(f.dataType)) + assert(primitiveFields.nonEmpty, "expected at least one primitive column in random schema") + for (field <- primitiveFields) { + val udfName = s"id_${field.name}" + registerIdentityUdfFor(field.dataType, udfName) match { + case Some(_) => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName(${field.name}) FROM t1") + } + case None => + fail( + s"primitive column ${field.name}: ${field.dataType} not in identity UDF catalog; " + + "extend registerIdentityUdfFor") + } + } + } + + test("complex-probe ScalaUDF on every complex column") { + val complexFields = spark.table("t1").schema.fields.filter(f => isComplexType(f.dataType)) + assert(complexFields.nonEmpty, "expected at least one complex column in random schema") + for (field <- complexFields) { + probeComplexColumn(field, viewName = "t1") + } + } + + test("complex-probe ScalaUDF on top-level columns of deeply nested schema") { + for (field <- spark.table("t2").schema.fields) { + probeComplexColumn(field, viewName = "t2") + } + } + + /** + * Element-level fuzz for nested array reads: `ArrayMax.doGenCode` walks every element of every + * row, calling the kernel's nested element getter, the path the unsafe-getter optimization + * touches and which the cardinality probe deliberately skips. + */ + test("array_max element fuzz: every Array column") { + val arrayPrimitiveFields = spark.table("t1").schema.fields.filter { + case StructField(_, ArrayType(elemDt, _), _, _) if !isComplexType(elemDt) => true + case _ => false + } + assert( + arrayPrimitiveFields.nonEmpty, + "expected at least one Array column in random schema") + for (field <- arrayPrimitiveFields) { + val ArrayType(elemDt, _) = field.dataType: @unchecked + val udfName = s"id_arrmax_${field.name}" + registerIdentityUdfFor(elemDt, udfName) match { + case Some(_) => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName(array_max(${field.name})) FROM t1") + } + case None => + fail( + s"array column ${field.name} elem ${elemDt} not in identity UDF catalog; " + + "extend registerIdentityUdfFor") + } + } + } + + /** + * Map variant of the array element fuzz: `map_keys` / `map_values` produce arrays the kernel + * walks via `ArrayMax`, exercising the map's per-row offset chain (MapVector -> entries + * StructVector -> child) that the array test alone wouldn't catch. + */ + test("array_max element fuzz: map_keys / map_values on Map columns") { + val mapPrimitiveFields = spark.table("t2").schema.fields.filter { + case StructField(_, MapType(kDt, vDt, _), _, _) + if !isComplexType(kDt) && !isComplexType(vDt) => + true + case _ => false + } + for (field <- mapPrimitiveFields) { + val MapType(kDt, vDt, _) = field.dataType: @unchecked + registerIdentityUdfFor(kDt, s"id_mapk_${field.name}").foreach { udf => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udf(array_max(map_keys(${field.name}))) FROM t2") + } + } + registerIdentityUdfFor(vDt, s"id_mapv_${field.name}").foreach { udf => + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $udf(array_max(map_values(${field.name}))) FROM t2") + } + } + } + } + + /** + * Doubly-nested array element fuzz: `flatten(arr)` collapses `Array>` into `Array` + * (exercising the outer-array element getter that returns each inner ArrayData), then + * `array_max` walks the leaf X primitives. Closes the gap that the singly-nested + * `array_max(arr)` test alone leaves on doubly-nested primitive arrays. + */ + test("array_max element fuzz: flatten on Array> columns") { + val nestedArrayPrimitiveFields = spark.table("t2").schema.fields.filter { + case StructField(_, ArrayType(ArrayType(elemDt, _), _), _, _) if !isComplexType(elemDt) => + true + case _ => false + } + for (field <- nestedArrayPrimitiveFields) { + val ArrayType(ArrayType(elemDt, _), _) = field.dataType: @unchecked + val udfName = s"id_arrflat_${field.name}" + registerIdentityUdfFor(elemDt, udfName).foreach { _ => + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $udfName(array_max(flatten(${field.name}))) FROM t2") + } + } + } + } + + /** + * Element-level fuzz for `Array>`. `array_distinct` is a non-HOF unary expression + * that hashes each element to dedupe. Struct hashing is field-wise, so the kernel emits element + * reads on each struct's fields. `cardinality` consumes the result without materialization. + * Asserts the optimizer keeps `ArrayDistinct` so the coverage isn't vacuously folded. + */ + test("array_distinct element fuzz: Array> columns") { + val arrayStructFields = spark.table("t1").schema.fields.filter { + case StructField(_, ArrayType(_: StructType, _), _, _) => true + case _ => false + } + spark.udf.register("id_int_arrdistinct", (i: Int) => i) + for (field <- arrayStructFields) { + val q = s"SELECT id_int_arrdistinct(cardinality(array_distinct(${field.name}))) FROM t1" + val df = sql(q) + val plan = df.queryExecution.optimizedPlan.toString + val planLower = plan.toLowerCase + assert( + planLower.contains("array_distinct") || planLower.contains("arraydistinct"), + s"optimizer eliminated array_distinct on column ${field.name}; coverage would be " + + s"vacuous. plan=\n$plan") + assertCodegenRan { + checkSparkAnswerAndOperator(df) + } + } + } + + private def probeCardinality(accessor: String, viewName: String): Unit = { + assertCodegenRan { + checkSparkAnswerAndOperator( + s"SELECT $cardinalityProbeUdf(cardinality($accessor)) FROM $viewName") + } + } + + /** + * Top-level Array / Map produces a cardinality probe. Struct drills into each scalar child via + * `GetStructField`. Nested Array / Map sub-fields also get the cardinality probe (depth bound: + * deeper struct-of-struct nesting is skipped to keep the sweep finite). + */ + private def probeComplexColumn(field: StructField, viewName: String): Unit = { + field.dataType match { + case _: ArrayType | _: MapType => + probeCardinality(field.name, viewName) + + case st: StructType => + for (subField <- st.fields) { + val accessor = s"${field.name}.${subField.name}" + subField.dataType match { + case _: ArrayType | _: MapType => probeCardinality(accessor, viewName) + case dt if !isComplexType(dt) => + val udfName = s"id_${field.name}_${subField.name}" + registerIdentityUdfFor(dt, udfName).foreach { _ => + assertCodegenRan { + checkSparkAnswerAndOperator(s"SELECT $udfName($accessor) FROM $viewName") + } + } + case _ => // deeper struct nesting skipped + } + } + + case _ => + } + } + + /** Random `BigDecimal` values fitting `(precision, scale)`, with `nullDensity` of them null. */ + private def generateDecimals( + seed: Long, + precision: Int, + scale: Int, + nullDensity: Double): Seq[java.math.BigDecimal] = { + val rng = new Random(seed) + val intDigits = precision - scale + // `BigInt.apply(bits, rng)` samples uniformly on `[0, 2^bits - 1]`; bound to the decimal's + // integer-part range (10^intDigits - 1) so the result fits the schema. `BigInteger.bitLength` + // would overshoot slightly. Min with the exact max is cheap insurance. + val intMax = BigInt(10).pow(intDigits) - 1 + val bits = math.max(intMax.bitLength, 1) + (0 until RowCount).map { _ => + if (rng.nextDouble() < nullDensity) null + else { + val mag = BigInt(bits, rng).min(intMax) + val signed = if (rng.nextBoolean()) -mag else mag + new java.math.BigDecimal(signed.bigInteger, scale) + } + } + } + + private def withDecimalTable(decimalType: String, values: Seq[java.math.BigDecimal])( + f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + if (values.nonEmpty) { + val rows = values.map { v => + if (v == null) "(NULL)" else s"(${v.toPlainString})" + } + rows.grouped(64).foreach { batch => + sql(s"INSERT INTO t VALUES ${batch.mkString(", ")}") + } + } + f + } + } + + for { + density <- nullDensities + (precision, scale) <- decimalShapes + } { + test(s"decimal identity precision=$precision scale=$scale nullDensity=$density") { + spark.udf.register("dec_id_fuzz", (d: java.math.BigDecimal) => d) + val seed = ((precision * 31L) + scale) * 31L + density.hashCode + val values = generateDecimals(seed, precision, scale, density) + withDecimalTable(s"DECIMAL($precision, $scale)", values) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT dec_id_fuzz(d) FROM t")) + } + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenHOFSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenHOFSuite.scala new file mode 100644 index 0000000000..9b2511ce0d --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenHOFSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.spark.SparkConf +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +/** + * Higher-order function regression coverage for the codegen dispatcher. + * + * Spark's HOFs (`ArrayTransform`, `ArrayFilter`, `ArrayAggregate`, `ArrayExists`, `ZipWith`, + * `MapFilter`, etc.) all extend `CodegenFallback`. The dispatcher's `canHandle` admits them. + * `CodegenFallback.doGenCode` emits a single `((Expression) references[N]).eval(row)` call site + * per HOF. The kernel dispatches to `Expression.eval(InternalRow)`, which iterates the array, + * mutates `NamedLambdaVariable.value`'s `AtomicReference` per element, and recursively evaluates + * the lambda body. Lambda-body leaf reads resolve through the kernel's typed Arrow getters since + * the kernel is an `InternalRow`. + * + * Cost model: per-row interpreted-eval inside the HOF subtree. Surrounding native operators stay + * native. Surrounding non-HOF expressions stay codegen. + * + * Each Spark task gets its own `boundExpr` Java object. The dispatcher's compile cache lives on + * the per-task instance, not the companion, so concurrent partitions cannot race on a shared + * `NamedLambdaVariable.value`. The two-collects test below regresses this. + */ +class CometCodegenHOFSuite + extends CometTestBase + with AdaptiveSparkPlanHelper + with CometCodegenAssertions { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key, "true") + + private def withArrayIntTable(rows: String)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (a ARRAY) USING parquet") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + test("ArrayTransform inside identity ScalaUDF over Array") { + // Regresses the simplest HOF shape: `idArr(transform(a, x -> x + 1))`. Tree contains one + // CodegenFallback HOF. The kernel splices its interpreted-eval call site into the per-row + // body and the result ArrayData feeds the ListVector output writer. Null and empty rows + // exercise the HOF's null-on-null-arg path and the empty-iteration path. + spark.udf.register("idArr", (arr: Seq[Int]) => arr) + withArrayIntTable("(array(1, 2, 3)), (array(-5, 5)), (array()), (null)") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idArr(transform(a, x -> x + 1)) FROM t")) + } + } + } + + test("array_max over ArrayTransform inside identity ScalaUDF") { + // Regresses composed CodegenFallback subtrees: array_max consumes the ArrayData transform + // produces. Both run interpreted. The kernel splices both eval call sites into the same + // per-row body. Empty/null rows exercise array_max's null-on-empty path. + spark.udf.register("idIntBoxed", (i: java.lang.Integer) => i) + withArrayIntTable("(array(1, 2, 3)), (array(-5, 5)), (null), (array(0))") { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT idIntBoxed(array_max(transform(a, x -> x * 2))) FROM t")) + } + } + } + + test("array_max over ArrayFilter inside identity ScalaUDF") { + // Regresses ArrayFilter (distinct HOF class from ArrayTransform). Filter producing an + // empty array from non-empty input exercises array_max(emptyArray) downstream. + spark.udf.register("idIntBoxed", (i: java.lang.Integer) => i) + withArrayIntTable("(array(1, -1, 2)), (array(-5, -2)), (array()), (null)") { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT idIntBoxed(array_max(filter(a, x -> x > 0))) FROM t")) + } + } + } + + test("HOF query produces correct results across two collects (per-task isolation regression)") { + // Regresses the per-task `boundExpr` isolation. When the dispatcher's compile cache lived on + // the companion object, multiple tasks shared one `boundExpr` and concurrent partitions + // raced on `NamedLambdaVariable.value`'s `AtomicReference`, producing off-by-one element + // values. The fix moved the cache to the per-task instance so each task deserializes its own + // boundExpr. Two collects of the same query must each match Spark's interpreter. + spark.udf.register("idArr", (arr: Seq[Int]) => arr) + withArrayIntTable("(array(1, 2)), (array(3, 4)), (array(5))") { + val q = "SELECT idArr(transform(a, x -> x + 1)) FROM t" + checkSparkAnswerAndOperator(sql(q)) + checkSparkAnswerAndOperator(sql(q)) + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala new file mode 100644 index 0000000000..27a5830c6d --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSourceSuite.scala @@ -0,0 +1,1098 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Add, BoundReference, Coalesce, Concat, CreateArray, CreateMap, ElementAt, Expression, GetStructField, LeafExpression, Length, Literal, Nondeterministic, Rand, Size, Unevaluable, Upper} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.types._ + +import org.apache.comet.codegen.CometBatchKernelCodegen +import org.apache.comet.codegen.CometBatchKernelCodegen.{ArrayColumnSpec, ArrowColumnSpec, MapColumnSpec, ScalarColumnSpec, StructColumnSpec, StructFieldSpec} +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +// Resolve Arrow vector classes through the codegen object so tests see the same `Class` objects +// the codegen pattern-matches against, regardless of any future shading rearrangement. + +/** + * Generated-source inspection tests. These exercise `CometBatchKernelCodegen.generateSource` and + * assert on the emitted Java directly, without invoking Janino. The goal is to catch regressions + * in the optimizations we claim the dispatcher applies: + * + * - `NullIntolerant` short-circuit wraps `ev.code` in `if (any-input-null) { setNull } else { + * ev.code; write }`. + * - Non-nullable column declaration emits `return false;` from `isNullAt(ord)`, and a + * `BoundReference.nullable=false` (Catalyst sets this from schema-declared nullability) makes + * Spark's `doGenCode` skip emitting its own `row.isNullAt(ord)` probe entirely. + * - Zero-copy string reads route through `UTF8String.fromAddress`. + * + * These are the smallest durable tests that the claimed optimizations actually reach the + * generated Java, and they document the shapes future contributors should preserve. + */ +class CometCodegenSourceSuite extends AnyFunSuite { + + private val varCharVectorClass = + CometBatchKernelCodegen.vectorClassBySimpleName("VarCharVector") + + private val nullableString = ArrowColumnSpec(varCharVectorClass, nullable = true) + private val nonNullableString = ArrowColumnSpec(varCharVectorClass, nullable = false) + + private def gen( + expr: org.apache.spark.sql.catalyst.expressions.Expression, + specs: ArrowColumnSpec*): String = + CometBatchKernelCodegen.generateSource(expr, specs.toIndexedSeq).body + + test("non-nullable column emits literal-false isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + src.contains("case 0: return false;"), + s"expected non-nullable isNullAt to return literal false; got:\n$src") + } + + test("non-nullable BoundReference elides Spark's own isNullAt probe in the expression body") { + // When the BoundReference carries `nullable=false` (Catalyst sets this from schema-declared + // nullability), Spark's `doGenCode` skips the `row.isNullAt(ord)` branch at source level. + // The dispatcher does not derive runtime nullability anymore. The BoundReference's source + // flag is the sole signal, and schema-non-null columns get full elision for free. + val expr = Length(BoundReference(0, StringType, nullable = false)) + val src = gen(expr, nonNullableString) + assert( + !src.contains("row.isNullAt(0)"), + s"expected Spark's BoundReference null probe to be elided; got:\n$src") + } + + test("nullable column emits delegated isNullAt case") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("case 0: return this.col0.isNull(this.rowIdx);"), + s"expected nullable isNullAt to delegate to the Arrow vector; got:\n$src") + } + + test("VarCharVector getUTF8String uses zero-copy fromAddress") { + val expr = Length(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("org.apache.spark.unsafe.types.UTF8String"), + s"expected UTF8String reference; got:\n$src") + assert(src.contains(".fromAddress("), s"expected zero-copy fromAddress read; got:\n$src") + } + + test("NullIntolerant expression emits input-null short-circuit before ev.code") { + // Upper is NullIntolerant (null in -> null out). Expect the default body to prepend + // `if (this.col0.isNull(i)) { setNull; } else { ... }` so null rows skip the whole + // expression eval, not just the setNull write. + val expr = Upper(BoundReference(0, StringType, nullable = true)) + val src = gen(expr, nullableString) + assert( + src.contains("this.col0.isNull(i)"), + s"expected NullIntolerant short-circuit on input ordinal 0; got:\n$src") + assert( + src.contains("output.setNull(i);"), + s"expected setNull emission for short-circuited null rows; got:\n$src") + } + + test("NullIntolerant short-circuit emitted when every node is NullIntolerant") { + // Length(Upper(BoundReference)): Length is NullIntolerant, Upper is NullIntolerant, + // BoundReference is a leaf. Every path from a leaf to the root propagates nulls, so the + // short-circuit heuristic ("any input null -> output null") holds. + val expr = Length(Upper(BoundReference(0, StringType, nullable = true))) + val src = gen(expr, nullableString) + assert( + src.contains("if (this.col0.isNull(i))"), + s"expected short-circuit on col0 when every node is NullIntolerant; got:\n$src") + } + + test("NullIntolerant short-circuit skipped when a non-NullIntolerant node breaks the chain") { + // Concat is not NullIntolerant. Null in some args doesn't necessarily produce a null + // result. The short-circuit heuristic would be incorrect here (short-circuiting on c0 or c1 + // being null would skip evaluation, but Concat's null handling differs). Expect the + // default path without the `if (colX.isNull(i) || colY.isNull(i))` wrapper, letting Spark's + // own `ev.code` handle nulls correctly. + val nullable1 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val nullable2 = ArrowColumnSpec(varCharVectorClass, nullable = true) + val expr = Length( + Concat( + Seq( + BoundReference(0, StringType, nullable = true), + BoundReference(1, StringType, nullable = true)))) + val src = gen(expr, nullable1, nullable2) + assert( + !src.contains("this.col0.isNull(i) || this.col1.isNull(i)"), + "expected no pre-null short-circuit when Concat breaks the NullIntolerant chain; " + + s"got:\n$src") + } + + test("canHandle rejects CodegenFallback expressions") { + val expr = FakeCodegenFallback(BoundReference(0, StringType, nullable = true)) + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject CodegenFallback") + assert( + reason.get.contains("FakeCodegenFallback"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("canHandle accepts Nondeterministic expressions (per-partition kernel handles state)") { + // Each cache entry holds one kernel instance with `init(partitionIndex)` called once, so + // Rand / Uuid / etc. produce the expected per-partition sequences across batches. The + // previous canHandle rejection was conservative. With that caching in place, accepting + // Nondeterministic is correct. + val expr = FakeNondeterministic() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isEmpty, s"expected canHandle to accept Nondeterministic; got $reason") + } + + test("canHandle rejects Unevaluable expressions") { + val expr = FakeUnevaluable() + val reason = CometBatchKernelCodegen.canHandle(expr) + assert(reason.isDefined, "expected canHandle to reject Unevaluable") + assert( + reason.get.contains("FakeUnevaluable"), + s"expected reason to name the rejected expression class; got: ${reason.get}") + } + + test("CSE collapses a repeated subtree to one evaluation in the generated body") { + // `Add(Length(Upper(c0)), Length(Upper(c0)))` has `Length(Upper(c0))` as a common subtree. + // Length.doGenCode emits `$value.numChars()` on every Spark version the project targets, + // which makes it a stable activation marker. Upper's own doGenCode text drifts across + // versions (Spark 3.5 emits `UTF8String.toUpperCase()`, Spark 4 emits + // `CollationSupport.Upper.exec*` via collation-aware codegen), so we avoid it as a marker. + // When CSE fires, `Length(Upper(c0))` compiles into one `subExpr_*` helper whose body calls + // `numChars()` once. Both uses in the `Add` read the cached result from mutable state. + // Without CSE, each Add child would emit its own `numChars()` call. + val upperOrd0 = Upper(BoundReference(0, StringType, nullable = true)) + val lenUpper = Length(upperOrd0) + val expr = Add(lenUpper, lenUpper) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val occurrences = "\\.numChars\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 1, + "expected CSE to collapse repeated Length evaluation to 1 numChars() call, " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + // Additional proof: CSE emitted a `subExpr_` helper method. Without CSE the generator would + // have inlined the repeated subtree into the main body with no helper at all. + assert( + result.body.contains("subExpr_0(row)"), + s"expected CSE helper invocation; got:\n${CodeFormatter.format(result.code)}") + } + + test("CSE does not fire on non-deterministic expressions (regression guard)") { + // `Add(Rand(0), Rand(0))` is two structurally identical non-deterministic subtrees. CSE must + // not collapse them: each Rand call must produce an independent draw. Spark's CSE + // (`EquivalentExpressions.updateExprInMap`) filters non-deterministic expressions via + // `expr.deterministic`, so the two Rands stay separate. This test is a regression guard + // against Spark ever relaxing that check and against us accidentally applying CSE outside + // the `generateExpressions` path (which respects the filter). `Rand.doGenCode` emits one + // `$rng.nextDouble()` call per evaluation, so two Rands produce two `.nextDouble()` calls + // in the body. One-call output would indicate incorrect CSE. + val expr = Add(Rand(Literal(0L, LongType)), Rand(Literal(0L, LongType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty) + val occurrences = "\\.nextDouble\\(\\)".r.findAllIn(result.body).size + assert( + occurrences == 2, + "expected two independent Rand evaluations (no CSE on nondeterministic), " + + s"got $occurrences; src=\n${CodeFormatter.format(result.code)}") + } + + test("DecimalVector getDecimal specializes to unscaled-long fast path for short precision") { + // Mirrors Spark's `UnsafeRow.getDecimal` split at `Decimal.MAX_LONG_DIGITS` (18), done at + // codegen time rather than at runtime. The dispatcher reads the `BoundReference`'s + // `DecimalType` at source-generation time and emits only the fast-path branch when + // `precision <= 18`. The fast path reads the low 8 bytes of the 16-byte Arrow decimal128 + // slot directly as a signed long via `ArrowBuf.getLong` and wraps with + // `Decimal.createUnsafe`, avoiding the `BigDecimal` allocation `DecimalVector.getObject` + // would perform. For precision > 18 the generator emits only the slow-path branch + // (`getObject + Decimal.apply`); see the companion test below. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".createUnsafe("), + "expected Decimal.createUnsafe call on fast path; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains("Platform.getLong(") && + result.body.contains("this.col0_valueAddr"), + "expected unsafe Platform.getLong against cached valueAddr; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains(".getObject("), + "expected specialized fast path (no BigDecimal fallback branch in source); got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known short-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector getDecimal specializes to BigDecimal slow path for long precision") { + // Companion to the fast-path test. For `DecimalType(p, s)` with `p > 18`, the unscaled value + // can exceed 64 bits, so the generator emits only the `getObject + Decimal.apply` branch. + // The fast path markers must be absent so the generated source is minimal for this column. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".getObject(") && result.body.contains(".apply("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".createUnsafe("), + "expected no fast-path emission for long-precision column; got:\n" + + CodeFormatter.format(result.code)) + assert( + !result.body.contains("if (precision <= 18)"), + "expected no runtime precision branch for known long-precision column; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses unscaled-long fast path for short-precision output") { + // The output writer specializes on the root expression's DecimalType precision. For + // precision <= 18 the Decimal's unscaled long is passed directly to + // `DecimalVector.setSafe(int, long)`, avoiding the BigDecimal allocation that + // `toJavaBigDecimal()` performs. Use a simple expression that produces a DecimalType output: + // `BoundReference(0, DecimalType(18, 2))` has output type DecimalType(18, 2), which is what + // the generator specializes on. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(18, 2), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toUnscaledLong()"), + s"expected toUnscaledLong call on fast path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toJavaBigDecimal("), + "expected no BigDecimal allocation for short-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("DecimalVector setSafe uses BigDecimal slow path for long-precision output") { + // Companion to the fast-path output test. Precision > 18 can have unscaled values exceeding + // 64 bits, so the writer must fall back to the BigDecimal path. + val decimalVectorClass = CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector") + val spec = ArrowColumnSpec(decimalVectorClass, nullable = true) + val expr = BoundReference(0, DecimalType(38, 10), nullable = true) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(spec)) + assert( + result.body.contains(".toJavaBigDecimal("), + s"expected BigDecimal slow path; got:\n${CodeFormatter.format(result.code)}") + assert( + !result.body.contains(".toUnscaledLong()"), + "expected no unscaled-long write for long-precision output; got:\n" + + CodeFormatter.format(result.code)) + } + + test("VarCharVector setSafe uses on-heap UTF8String shortcut") { + // The UTF8String output writer avoids the `byte[] b = $value.getBytes()` allocation when + // the UTF8String is on-heap by passing its backing byte[] directly to + // `VarCharVector.setSafe(int, byte[], int, int)`. Spark's string functions allocate their + // result on-heap, so this path hits for typical string expressions. Off-heap fallback + // (for passthrough of zero-copy input reads) stays as the else branch. + // + // Markers: `getBaseObject()` (inspecting the backing), `instanceof byte[]` (the branch), + // and `Platform.BYTE_ARRAY_OFFSET` (the on-heap offset math). + val expr = Upper(BoundReference(0, StringType, nullable = true)) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + assert( + result.body.contains(".getBaseObject()"), + s"expected UTF8String.getBaseObject call; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("instanceof byte[]"), + s"expected on-heap instanceof branch; got:\n${CodeFormatter.format(result.code)}") + assert( + result.body.contains("Platform.BYTE_ARRAY_OFFSET"), + "expected on-heap offset math via Platform.BYTE_ARRAY_OFFSET; got:\n" + + CodeFormatter.format(result.code)) + assert( + result.body.contains(".getBytes()"), + s"expected off-heap getBytes fallback; got:\n${CodeFormatter.format(result.code)}") + } + + test("non-nullable root expression omits the `if (isNull)` branch in default body") { + // When the bound expression claims `nullable = false`, the default body drops the + // `if (ev.isNull) output.setNull(i);` guard entirely. `Length` on a non-nullable column is + // itself non-nullable (Length.nullable = child.nullable = false), so the writer goes + // straight to the setSafe/set call. This test uses a non-NullIntolerant-short-circuit + // shape by wrapping Length in Coalesce, so we exercise the default branch of defaultBody + // rather than the NullIntolerant one. Actually, Length is NullIntolerant, so the NI branch + // fires. Use an expression that's non-nullable but whose tree is not fully NullIntolerant + // to hit the default branch. `Coalesce(Seq(Length(col_non_null), Literal(0)))` has + // nullable=false (Coalesce is non-null when any child is) and Coalesce itself is not + // NullIntolerant, so the default branch runs. Assert `setNull` is absent. + val expr = Coalesce( + Seq(Length(BoundReference(0, StringType, nullable = false)), Literal(0, IntegerType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nonNullableString)) + assert( + !result.body.contains("output.setNull(i);"), + "expected no setNull for a non-nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("nullable root expression keeps the `if (isNull)` branch in default body") { + // Baseline: when the root expression is nullable, the setNull branch must still be emitted. + // Uses Coalesce with a nullable child so the Coalesce itself remains nullable. Guards the + // NonNullableOutputShortCircuit optimization against over-firing. + val expr = Coalesce( + Seq( + Length(BoundReference(0, StringType, nullable = true)), + BoundReference(1, IntegerType, nullable = true))) + val result = CometBatchKernelCodegen.generateSource( + expr, + IndexedSeq( + nullableString, + ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true))) + assert( + result.body.contains("output.setNull(i);"), + "expected setNull branch for a nullable root expression; got:\n" + + CodeFormatter.format(result.code)) + } + + test("ArrayType(StringType) output emits ListVector startNewValue/endValue recursion") { + // CreateArray over a BoundReference(StringType) produces ArrayType(StringType). emitWrite's + // ArrayType case should emit: + // - ListVector cast of output + // - child VarCharVector extraction via getDataVector + // - startNewValue + per-element loop + endValue + // - the per-element write recursing into the StringType case (which uses the UTF8 on-heap + // shortcut marker `instanceof byte[]`) + // Focus markers: ListVector cast, VarCharVector child cast, startNewValue, endValue, and + // the inner UTF8 shortcut branch. + val expr = + CreateArray( + Seq(BoundReference(0, StringType, nullable = true), Literal.create("x", StringType))) + val result = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(nullableString)) + val src = result.body + val formatted = CodeFormatter.format(result.code) + assert(src.contains("ListVector"), s"expected ListVector in emitted body; got:\n$formatted") + assert(src.contains(".startNewValue("), s"expected startNewValue call; got:\n$formatted") + assert(src.contains(".endValue("), s"expected endValue call; got:\n$formatted") + assert( + src.contains(".getDataVector()"), + s"expected child vector extraction; got:\n$formatted") + assert( + src.contains("instanceof byte[]"), + s"expected inner UTF8 on-heap shortcut for string elements; got:\n$formatted") + } + + test("MapType output emits MapVector startNewValue/endValue + per-pair writes") { + // CreateMap produces MapType(k, v). emitWrite's MapType case should emit: + // - MapVector cast of output + // - entries StructVector extraction + // - typed key / value child casts via getChildByOrdinal(0) / (1) + // - startNewValue / endValue bracketing + // - setIndexDefined on each struct entry + // - keyArray() / valueArray() retrieval from the MapData source + // Non-null literals here mean `valueContainsNull == false`, so the value-side null guard is + // elided. The existence and elision of the `isNullAt` guard are exercised by the dedicated + // [[NullableElementElision]] tests below. + val expr = CreateMap( + Seq( + Literal.create("a", StringType), + Literal(1, IntegerType), + Literal.create("b", StringType), + Literal(2, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + Seq( + "MapVector", + "StructVector", + ".startNewValue(", + ".endValue(", + ".setIndexDefined(", + ".keyArray()", + ".valueArray()").foreach { marker => + assert(src.contains(marker), s"expected $marker in MapType output emission; got:\n$src") + } + } + + test("ArrayType output elides isNullAt on the element loop when containsNull is false") { + // CreateArray over only-non-null Literals produces ArrayType(elementType, containsNull=false). + // The element write should drop the `arr.isNullAt(j)` guard at source level rather than + // relying on JIT folding. + val expr = CreateArray(Seq(Literal(1, IntegerType), Literal(2, IntegerType))) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq.empty).body + assert( + !src.contains(".isNullAt("), + s"expected no isNullAt in element loop when containsNull=false; got:\n$src") + assert(src.contains(".startNewValue("), s"expected startNewValue still emitted; got:\n$src") + } + + test("ArrayType output keeps isNullAt on the element loop when containsNull is true") { + // CreateArray with at least one nullable child produces containsNull=true. The element + // null-guard must survive. + val expr = + CreateArray(Seq(BoundReference(0, IntegerType, nullable = true), Literal(2, IntegerType))) + val intSpec = ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body + assert( + src.contains(".isNullAt("), + s"expected isNullAt in element loop when containsNull=true; got:\n$src") + } + + test("MapType output keeps value isNullAt when valueContainsNull is true") { + // ElementAt with safe-index selection produces a nullable Int. Wrapping the value column in + // a CreateMap with that nullable Int makes valueContainsNull=true. The value-side null-guard + // must survive. + val expr = + CreateMap( + Seq(Literal.create("a", StringType), BoundReference(0, IntegerType, nullable = true))) + val intSpec = ArrowColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(intSpec)).body + assert( + src.contains(".isNullAt("), + s"expected isNullAt on the value-write branch when valueContainsNull=true; got:\n$src") + } + + test("ArrayType(StringType) input emits InputArray_col0 nested class with UTF8 child getter") { + // Array input with string elements: the kernel must expose a `getArray(0)` that hands Spark's + // `doGenCode` an `ArrayData` view onto the Arrow `ListVector`'s child `VarCharVector`. + // Markers: the nested class declaration with a slice constructor, the typed child getter + // using `fromAddress`, and a `getArray` switch on the ordinal that allocates a fresh view. + val varCharChildSpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = varCharChildSpec) + val expr = Size(BoundReference(0, ArrayType(StringType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("class InputArray_col0"), + s"expected nested ArrayData class for array col0; got:\n$src") + assert( + src.contains("InputArray_col0(int startIdx, int len)"), + s"expected InputArray_col0 to take a slice via constructor; got:\n$src") + assert( + src.contains("getElementStartIndex(") && src.contains("getElementEndIndex("), + s"expected list-offset reads at the call site; got:\n$src") + assert( + src.contains("public org.apache.spark.unsafe.types.UTF8String getUTF8String(int i)"), + s"expected element-type-specific UTF8String getter; got:\n$src") + assert( + src.contains(".fromAddress("), + s"expected zero-copy UTF8 read inside the nested ArrayData; got:\n$src") + assert( + src.contains("public org.apache.spark.sql.catalyst.util.ArrayData getArray(int ordinal)"), + s"expected kernel-level getArray switch; got:\n$src") + assert( + src.contains("return new InputArray_col0("), + s"expected getArray to allocate a fresh InputArray_col0 view; got:\n$src") + } + + test("ArrayType(IntegerType) input emits primitive int getter in nested class") { + val intChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val arraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = intChildSpec) + val expr = Size(BoundReference(0, ArrayType(IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains("public int getInt(int i)"), + s"expected primitive int getter on nested array class; got:\n$src") + // Scalar-element fast path reads directly off the typed child vector. No BigDecimal / + // fromAddress scaffolding should leak in. + assert( + !src.contains(".fromAddress("), + s"int element getter should not wrap with UTF8 fromAddress; got:\n$src") + } + + test( + "ArrayType(DecimalType) short-precision input emits decimal128 fast-path via getLong in " + + "nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(10, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(10, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + // Fast path markers: reads the low 8 bytes of the decimal128 slot via getLong + createUnsafe. + // The slow path would go through getObject + Decimal.apply. + assert( + src.contains(".getLong(") && src.contains(".createUnsafe("), + s"expected decimal-input short-precision fast path in nested class; got:\n$src") + assert( + !src.contains(".getObject("), + s"short-precision decimal element should not use BigDecimal slow path; got:\n$src") + } + + test("ArrayType(DecimalType) long-precision input emits BigDecimal slow path in nested class") { + val decimalChildSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("DecimalVector"), + nullable = true) + val arraySpec = ArrayColumnSpec( + nullable = true, + elementSparkType = DecimalType(30, 2), + element = decimalChildSpec) + val expr = + ElementAt( + BoundReference(0, ArrayType(DecimalType(30, 2)), nullable = true), + Literal(1, IntegerType)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(arraySpec)).body + + assert( + src.contains(".getObject(") && src.contains("Decimal$.MODULE$"), + s"expected BigDecimal slow path for p>18 element; got:\n$src") + } + + private def generate(expr: Expression, specs: IndexedSeq[ArrowColumnSpec]): String = + CometBatchKernelCodegen.generateSource(expr, specs).body + + test("Array> emits outer + inner array classes with fresh inner allocation") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = innerArray) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputArray_col0_e "), + s"expected both outer and inner array classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_e("), + s"expected outer class to allocate a fresh inner array view per call; got:\n$src") + assert( + src.contains("public int getInt(int i)"), + s"expected innermost scalar getter for IntegerType element; got:\n$src") + } + + test("Array> emits array class allocating fresh InputStruct_col0_e") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + val outerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + element = innerStruct) + val elemType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val expr = Size(BoundReference(0, ArrayType(elemType), nullable = true)) + val src = generate(expr, IndexedSeq(outerArray)) + assert( + src.contains("class InputArray_col0 ") && src.contains("class InputStruct_col0_e "), + s"expected array-of-struct nested classes; got:\n$src") + assert( + src.contains("return new InputStruct_col0_e(startIndex + i)"), + s"expected array getStruct to allocate a fresh inner struct view; got:\n$src") + } + + test("Struct> emits outer + inner struct classes") { + val innerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "s", + StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray), + nullable = true, + innerStruct))) + val innerType = StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + val outerType = StructType(Seq(StructField("s", innerType, nullable = true)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputStruct_col0_f0 "), + s"expected outer + inner struct classes; got:\n$src") + assert( + src.contains("return new InputStruct_col0_f0(this.rowIdx)"), + s"expected outer struct getStruct to allocate a fresh inner struct view; got:\n$src") + assert( + src.contains("public int getInt(int ordinal)"), + s"expected innermost getInt on InputStruct_col0_f0; got:\n$src") + } + + test("Struct> emits struct class allocating fresh InputArray_col0_f0") { + val innerArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = true, innerArray))) + val structType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = true)).toArray) + val expr = Size(GetStructField(BoundReference(0, structType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("class InputStruct_col0 ") && src.contains("class InputArray_col0_f0 "), + s"expected struct-of-array nested classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_f0("), + s"expected struct getArray to allocate a fresh inner array view; got:\n$src") + } + + test("Map emits InputMap_col0 + keyArray / valueArray views") { + val keySpec = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueSpec = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = StringType, + valueSparkType = IntegerType, + key = keySpec, + value = valueSpec) + val expr = Size(BoundReference(0, MapType(StringType, IntegerType), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + assert( + src.contains("class InputMap_col0 "), + s"expected InputMap_col0 nested class; got:\n$src") + assert( + src.contains("class InputArray_col0_k ") && src.contains("class InputArray_col0_v "), + s"expected key/value array view classes; got:\n$src") + assert( + src.contains("return new InputArray_col0_k(this.startIndex, this.length)"), + s"expected keyArray to allocate a fresh view over the map slice; got:\n$src") + assert( + src.contains("return new InputArray_col0_v(this.startIndex, this.length)"), + s"expected valueArray to allocate a fresh view over the map slice; got:\n$src") + assert( + src.contains("public org.apache.spark.sql.catalyst.util.MapData getMap(int ordinal)"), + s"expected kernel-level getMap switch; got:\n$src") + assert( + src.contains("return new InputMap_col0("), + s"expected getMap to allocate a fresh InputMap_col0 view; got:\n$src") + } + + test("Map, Array> emits complex key and complex value views") { + val keyElem = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true) + val keyArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = IntegerType, element = keyElem) + val valueElem = ScalarColumnSpec(varCharVectorClass, nullable = true) + val valueArraySpec = + ArrayColumnSpec(nullable = true, elementSparkType = StringType, element = valueElem) + val mapSpec = MapColumnSpec( + nullable = true, + keySparkType = ArrayType(IntegerType), + valueSparkType = ArrayType(StringType), + key = keyArraySpec, + value = valueArraySpec) + val expr = Size( + BoundReference(0, MapType(ArrayType(IntegerType), ArrayType(StringType)), nullable = true)) + val src = CometBatchKernelCodegen.generateSource(expr, IndexedSeq(mapSpec)).body + // Full chain of nested classes should appear: top-level map view, the key/value array + // views, and the inner array classes for each complex key/value element. + Seq( + "class InputMap_col0 ", + "class InputArray_col0_k ", + "class InputArray_col0_v ", + "class InputArray_col0_k_e ", + "class InputArray_col0_v_e ").foreach { marker => + assert(src.contains(marker), s"expected $marker in emission; got:\n$src") + } + } + + /** + * Null-guard emission for nested reference-typed getters. Spark's + * `CodeGenerator.setArrayElement` only emits an `isNullAt` check before `update(i, getX(j))` + * for primitive elements. For reference types it relies on the source's `getX` to return null + * on null positions itself, matching `ColumnarArray.getBinary`. The emitter prepends `if + * (isNullAt(...)) return null;` when the element / field is nullable. + * + * Runtime regressions for the leaf reference types live in `CometCodegenSuite`; complex-type + * (Struct/Array/Map) coverage runs through HOFs in `CometCodegenHOFSuite`. + */ + private val nullableIntStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = true, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)))) + private val nullableIntStructType = + StructType(Seq(StructField("a", IntegerType, nullable = true)).toArray) + + private val nullableIntArray = ArrayColumnSpec( + nullable = true, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = true)) + + private val nullableIntStrMap = MapColumnSpec( + nullable = true, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = ScalarColumnSpec(varCharVectorClass, nullable = true)) + + test("nested array of nullable Struct emits null guard before allocating InputStruct view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = nullableIntStructType, + element = nullableIntStruct) + val expr = Size(BoundReference(0, ArrayType(nullableIntStructType), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputStruct_col0_e(startIndex + i)"), + s"expected null guard and InputStruct alloc on nullable Struct element; got:\n$src") + } + + test("nested array of non-nullable Struct elides null guard") { + // Fully non-nullable inner spec: outer struct nullable=false AND inner Int field + // nullable=false. Without the inner field also being non-nullable the inner + // primitive-Int getter wouldn't emit a guard anyway (we only guard reference types), but + // making everything non-nullable means the broad `!src.contains("if (isNullAt(...))")` + // assertion verifies "no guards anywhere" rather than passing because the inner happens + // to be a primitive we don't guard. + val nonNullableInner = StructColumnSpec( + nullable = false, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = false, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)))) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = nullableIntStructType, + element = nonNullableInner) + val expr = Size(BoundReference(0, ArrayType(nullableIntStructType), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputStruct_col0_e(startIndex + i)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;") && + !src.contains("if (isNullAt(0)) return null;"), + s"expected no null guard anywhere on fully non-nullable Struct element; got:\n$src") + } + + test( + "nested array of nullable inner Array emits null guard before allocating InputArray view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = nullableIntArray) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputArray_col0_e(__s, __e - __s)"), + s"expected null guard and InputArray alloc on nullable Array element; got:\n$src") + } + + test("nested array of non-nullable inner Array elides null guard") { + val nonNullableInner = ArrayColumnSpec( + nullable = false, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = ArrayType(IntegerType), + element = nonNullableInner) + val expr = Size(BoundReference(0, ArrayType(ArrayType(IntegerType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputArray_col0_e(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard on non-nullable inner Array element; got:\n$src") + } + + test("nested array of nullable Map emits null guard before allocating InputMap view") { + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(IntegerType, StringType), + element = nullableIntStrMap) + val expr = + Size(BoundReference(0, ArrayType(MapType(IntegerType, StringType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("if (isNullAt(i)) return null;") && + src.contains("new InputMap_col0_e(__s, __e - __s)"), + s"expected null guard and InputMap alloc on nullable Map element; got:\n$src") + } + + test("nested array of non-nullable Map elides null guard") { + val nonNullableMap = MapColumnSpec( + nullable = false, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = ScalarColumnSpec(varCharVectorClass, nullable = false)) + val outer = ArrayColumnSpec( + nullable = true, + elementSparkType = MapType(IntegerType, StringType), + element = nonNullableMap) + val expr = + Size(BoundReference(0, ArrayType(MapType(IntegerType, StringType)), nullable = true)) + val src = generate(expr, IndexedSeq(outer)) + assert( + src.contains("new InputMap_col0_e(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard on non-nullable Map element; got:\n$src") + } + + test("struct with nullable struct field emits null guard in getStruct(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("s", nullableIntStructType, nullable = true, nullableIntStruct))) + val outerType = + StructType(Seq(StructField("s", nullableIntStructType, nullable = true)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputStruct_col0_f0(this.rowIdx)"), + s"expected null guard and InputStruct alloc for nullable struct field; got:\n$src") + } + + test("struct with non-nullable struct field elides null guard") { + val nonNullableInner = StructColumnSpec( + nullable = false, + fields = Seq( + StructFieldSpec( + "a", + IntegerType, + nullable = false, + ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)))) + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("s", nullableIntStructType, nullable = false, nonNullableInner))) + val outerType = + StructType(Seq(StructField("s", nullableIntStructType, nullable = false)).toArray) + val expr = GetStructField( + GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("s")), + 0, + Some("a")) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputStruct_col0_f0(this.rowIdx)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;") && + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard anywhere on fully non-nullable struct field; got:\n$src") + } + + test("struct with nullable array field emits null guard in getArray(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = true, nullableIntArray))) + val outerType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = true)).toArray) + val expr = + Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputArray_col0_f0(__s, __e - __s)"), + s"expected null guard and InputArray alloc for nullable array field; got:\n$src") + } + + test("struct with non-nullable array field elides null guard") { + val nonNullableInner = ArrayColumnSpec( + nullable = false, + elementSparkType = IntegerType, + element = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = + Seq(StructFieldSpec("a", ArrayType(IntegerType), nullable = false, nonNullableInner))) + val outerType = + StructType(Seq(StructField("a", ArrayType(IntegerType), nullable = false)).toArray) + val expr = + Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("a"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputArray_col0_f0(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;") && + !src.contains("if (isNullAt(i)) return null;"), + s"expected no null guard anywhere on fully non-nullable array field; got:\n$src") + } + + test("struct with nullable map field emits null guard in getMap(ordinal) switch") { + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec( + "m", + MapType(IntegerType, StringType), + nullable = true, + nullableIntStrMap))) + val outerType = + StructType(Seq(StructField("m", MapType(IntegerType, StringType), nullable = true)).toArray) + val expr = Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("m"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("if (isNullAt(0)) return null;") && + src.contains("new InputMap_col0_f0(__s, __e - __s)"), + s"expected null guard and InputMap alloc for nullable map field; got:\n$src") + } + + test("struct with non-nullable map field elides null guard") { + val nonNullableMap = MapColumnSpec( + nullable = false, + keySparkType = IntegerType, + valueSparkType = StringType, + key = ScalarColumnSpec( + CometBatchKernelCodegen.vectorClassBySimpleName("IntVector"), + nullable = false), + value = ScalarColumnSpec(varCharVectorClass, nullable = false)) + val outerStruct = StructColumnSpec( + nullable = true, + fields = Seq( + StructFieldSpec("m", MapType(IntegerType, StringType), nullable = false, nonNullableMap))) + val outerType = StructType( + Seq(StructField("m", MapType(IntegerType, StringType), nullable = false)).toArray) + val expr = Size(GetStructField(BoundReference(0, outerType, nullable = true), 0, Some("m"))) + val src = generate(expr, IndexedSeq(outerStruct)) + assert( + src.contains("new InputMap_col0_f0(__s, __e - __s)"), + s"sanity: alloc still emitted; got:\n$src") + assert( + !src.contains("if (isNullAt(0)) return null;"), + s"expected no null guard on non-nullable map field; got:\n$src") + } + + test("CacheKey discriminates on ArrowColumnSpec.nullable") { + // Structural regression: same expression bytes and same Arrow vector class with different + // `nullable` must produce non-equal cache keys. The dispatcher today hardcodes `nullable=true` + // for top-level specs, so the two variants don't both arise from runtime data, but the case + // class equality contract still has to discriminate so that any future tiered cache or test + // construction can rely on it. The non-nullable variant's generated source emits a literal + // `false` from `isNullAt`, distinct codegen output that we never want to silently share with + // the nullable variant. + val bytes = java.nio.ByteBuffer.wrap(Array[Byte](1, 2, 3)) + val nullable = + IndexedSeq[ArrowColumnSpec](ArrowColumnSpec(varCharVectorClass, nullable = true)) + val nonNullable = + IndexedSeq[ArrowColumnSpec](ArrowColumnSpec(varCharVectorClass, nullable = false)) + val k1 = CometScalaUDFCodegen.CacheKey(bytes, nullable) + val k2 = CometScalaUDFCodegen.CacheKey(bytes, nonNullable) + assert( + k1 != k2, + "expected nullable=true and nullable=false specs to produce distinct cache keys") + assert( + k1.hashCode != k2.hashCode, + "case-class hashCode should also differ; identical hashCodes would degrade lookup but not " + + "equality, so the assertion is mainly a sanity check on Spec.hashCode") + } +} + +/** + * Minimal fake expressions for the `canHandle` rejection tests. Each opts into one of the marker + * traits whose presence forces a serde-level fallback. Bodies are unreachable; `canHandle` walks + * the tree structurally. + */ +private case class FakeCodegenFallback(child: Expression) + extends Expression + with CodegenFallback { + override def children: Seq[Expression] = Seq(child) + + override def nullable: Boolean = true + + override def dataType: DataType = StringType + + override def eval(input: InternalRow): Any = null + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): Expression = copy(child = newChildren.head) +} + +private case class FakeNondeterministic() extends LeafExpression with Nondeterministic { + override def nullable: Boolean = true + + override def dataType: DataType = IntegerType + + override protected def initializeInternal(partitionIndex: Int): Unit = {} + + override protected def evalInternal(input: InternalRow): Any = 0 + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = + throw new UnsupportedOperationException("test fake; never reaches codegen") +} + +private case class FakeUnevaluable() extends LeafExpression with Unevaluable { + override def nullable: Boolean = true + + override def dataType: DataType = IntegerType +} diff --git a/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala b/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala new file mode 100644 index 0000000000..2da8dfd4d9 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometCodegenSuite.scala @@ -0,0 +1,1159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import org.apache.arrow.vector._ +import org.apache.spark.{SparkConf, TaskContext} +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.api.java.UDF1 +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper +import org.apache.spark.sql.types._ + +import org.apache.comet.udf.codegen.CometScalaUDFCodegen + +/** + * End-to-end correctness for the Arrow-direct codegen dispatcher. Covers the scalar and complex + * type surface, composed UDF trees, subquery reuse, `TaskContext` propagation, per-task cache + * isolation, the `maxFields` plan-time gate, and regressions pinned from fuzz. + * + * Tests exercising fallback paths (config disabled, `maxFields` exceeded) use `checkSparkAnswer` + * rather than `checkSparkAnswerAndOperator` because ScalaUDF has no Comet-native path. Under + * fallback the project runs on the JVM Spark path. + */ +class CometCodegenSuite + extends CometTestBase + with AdaptiveSparkPlanHelper + with CometCodegenAssertions { + + override protected def sparkConf: SparkConf = + super.sparkConf + .set(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key, "true") + + private def withSubjects(values: String*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (s STRING) USING parquet") + val rows = values + .map(v => if (v == null) "(NULL)" else s"('${v.replace("'", "''")}')") + .mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + f + } + } + + private def withTwoStringCols(rows: (String, String)*)(f: => Unit): Unit = { + withTable("t") { + sql("CREATE TABLE t (c1 STRING, c2 STRING) USING parquet") + if (rows.nonEmpty) { + val tuples = rows.map { case (a, b) => + val av = if (a == null) "NULL" else s"'${a.replace("'", "''")}'" + val bv = if (b == null) "NULL" else s"'${b.replace("'", "''")}'" + s"($av, $bv)" + } + sql(s"INSERT INTO t VALUES ${tuples.mkString(", ")}") + } + f + } + } + + test("ScalaUDF over concat(c1, c2) suppresses the null short-circuit") { + // Concat is not NullIntolerant. The dispatcher's short-circuit guard inspects every node in + // the bound tree and must skip the whole-tree null short-circuit because one child is + // non-NullIntolerant. The kernel therefore delegates null handling to Spark's generated + // code (which handles Concat(null, x) = x correctly) rather than returning null for any + // null input. Without the guard, null inputs would produce null outputs even where Spark + // produces a non-null concatenation. + spark.udf.register("tag", (s: String) => if (s == null) "N" else s"[${s}]") + withTwoStringCols(("abc", "123"), ("abc", null), (null, "123"), (null, null), ("zz", "zz")) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT tag(concat(c1, c2)) FROM t")) + } + } + } + + test("disabled mode bypasses the dispatcher") { + // When the per-feature config is off, `CometScalaUDF.convert` returns None and the enclosing + // operator falls back to Spark. The dispatcher's counters must not move. + spark.udf.register("noopStr", (s: String) => s) + CometScalaUDFCodegen.resetStats() + withSQLConf(CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "false") { + withSubjects("disabled_1", null) { + checkSparkAnswer(sql("SELECT noopStr(s) FROM t")) + } + } + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount == 0 && after.cacheHitCount == 0, + s"expected no dispatcher activity under disabled config, got $after") + } + + test("schema exceeding spark.sql.codegen.maxFields falls back to Spark") { + // `CometBatchKernelCodegen.canHandle` mirrors WSCG's `spark.sql.codegen.maxFields` gate by + // counting nested input fields plus the output field and refusing once the total exceeds the + // configured cap. Comet has no mid-execution fallback, so the gate must fire at plan time + // (in the serde) rather than letting an oversized kernel reach Janino. With 5 input + // BoundReferences and a 1-field output we have 6 fields total. Setting `maxFields=3` ensures + // the gate fires here regardless of test ordering or future schema additions. + spark.udf.register( + "sumFiveInts", + (a: Int, b: Int, c: Int, d: Int, e: Int) => a + b + c + d + e) + withTable("t") { + sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT) USING parquet") + sql("INSERT INTO t VALUES (1, 2, 3, 4, 5), (10, 20, 30, 40, 50)") + CometScalaUDFCodegen.resetStats() + withSQLConf("spark.sql.codegen.maxFields" -> "3") { + checkSparkAnswer(sql("SELECT sumFiveInts(a, b, c, d, e) FROM t")) + } + val after = CometScalaUDFCodegen.stats() + assert( + after.compileCount == 0 && after.cacheHitCount == 0, + s"expected dispatcher fallback under maxFields=3, got $after") + } + } + + test("dispatcher caches the compiled kernel across batches of one query") { + // Within a single query, the dispatcher compiles a kernel for the (expression, schema) pair + // once and reuses it across every subsequent batch of the same shape. Force multiple batches + // by lowering the Comet batch size with a row count well above it, then assert at least one + // cache hit happened during the query. + // + // We deliberately do not assert cross-query cache reuse: Spark's analyzer produces a fresh + // `ScalaUDF` instance per query resolution, and the encoders embedded in that instance + // contain `AttributeReference`s with fresh `ExprId`s that our `BindReferences.bindReference` + // does not recurse into. The closure-serialized cache key bytes therefore drift across + // queries even when the registered function and schema are identical, so each new query of a + // ScalaUDF pays one compile up front and amortizes within itself. This is an acceptable + // amortization story (a few tens of milliseconds per query), not a behavior we can or do + // promise across queries. + spark.udf.register("kernelCacheMarker", (s: String) => if (s == null) null else s + "_kc") + val rows = (0 until 256).map(i => s"row_$i") + CometScalaUDFCodegen.resetStats() + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "32") { + withSubjects(rows: _*) { + checkSparkAnswerAndOperator(sql("SELECT kernelCacheMarker(s) FROM t")) + } + } + val stats = CometScalaUDFCodegen.stats() + assert(stats.compileCount >= 1, s"expected at least one compile during the query, got $stats") + assert( + stats.cacheHitCount >= 1, + s"expected at least one cache hit across batches of the same query, got $stats") + } + + test("per-partition kernel preserves Nondeterministic state across batches") { + // Wrap `monotonically_increasing_id()` as the argument of a ScalaUDF so the whole tree + // (including the stateful MonotonicallyIncreasingID child) routes through the dispatcher. + // Per-partition kernel caching means the id counter advances across batches within a + // partition. Without it, every batch would restart at 0 and the UDF output would disagree + // with Spark's. The UDF body is a trivial identity. We're testing state correctness of the + // Nondeterministic child across batches, not the UDF logic. + spark.udf.register("idPassthrough", (id: Long) => id) + val rows = (0 until 4096).map(i => s"row_$i") + withSubjects(rows: _*) { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT s, idPassthrough(monotonically_increasing_id()) FROM t")) + } + } + } + + test( + "same UDF over nullable and non-nullable columns gets distinct kernels with independent state") { + // Two columns, same type, different schema-declared nullability. Same UDF applied to each + // alongside a per-projection MonotonicallyIncreasingID. Each projection has its own MII + // child (different bytesKey), so each kernel must have its own counter advancing 0..N-1. + // If the dispatcher collapses them onto one kernel or shares state somehow, the counters + // would interleave and the output would diverge from Spark. + spark.udf.register("withId", (s: String, id: Long) => s"${s}_${id}") + withTempPath { dir => + import org.apache.spark.sql.Row + import org.apache.spark.sql.types.{StringType, StructField, StructType} + val schema = StructType( + Seq( + StructField("a", StringType, nullable = true), + StructField("b", StringType, nullable = false))) + val rows = (0 until 64).map(i => Row(s"a_$i", s"b_$i")) + val rdd = spark.sparkContext.parallelize(rows, numSlices = 1) + spark.createDataFrame(rdd, schema).write.parquet(dir.getCanonicalPath) + withTable("t") { + sql(s"CREATE TABLE t USING parquet LOCATION '${dir.getCanonicalPath}'") + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "8") { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT withId(a, monotonically_increasing_id()), " + + "withId(b, monotonically_increasing_id()) FROM t")) + } + } + } + } + } + + test("Nondeterministic state persists across nullability flips within a partition") { + // Regression guard against re-introducing per-batch nullability into the cache key. Force a + // single parquet file with `spark.range(numPartitions=1)`, large enough that batch size 8 + // produces many batches in one scan partition. Null density varies by row range. If the + // dispatcher ever started deriving spec nullability from runtime data again, the cache key + // would flip mid-partition, the kernel would be re-allocated, and MII's counter would reset + // across the flip. + spark.udf.register("idPair", (id: Long, s: String) => (id, s)) + withTempPath { dir => + spark + .range(0, 200, 1, numPartitions = 1) + .selectExpr("CASE WHEN id >= 16 AND id < 32 THEN NULL ELSE concat('row_', id) END AS s") + .write + .parquet(dir.getCanonicalPath) + withTable("t") { + sql(s"CREATE TABLE t USING parquet LOCATION '${dir.getCanonicalPath}'") + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "8") { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT idPair(monotonically_increasing_id(), s) FROM t")) + } + } + } + } + } + + test("Nondeterministic state persists across two ScalaUDFs in one task") { + // The dispatcher is one instance per task (keyed by `(taskAttemptId, udfClassName)` in + // CometUdfBridge), so a plan with two distinct ScalaUDFs shares one CometScalaUDFCodegen. + // Two distinct closure-serialized expressions hit two cache entries. Per batch the + // dispatcher is invoked once for each. Each cache entry must stash its own kernel instance, + // otherwise the two expressions would fight for a shared kernel slot and stateful state + // (MII counter) would reset on every flip. + // + // Small batch size forces multiple batches over a small table so the per-key flip happens + // several times within one task. + spark.udf.register("idA", (id: Long) => id) + spark.udf.register("idB", (id: Long) => -id) + val rows = (0 until 64).map(i => s"row_$i") + withSubjects(rows: _*) { + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "8") { + assertCodegenRan { + checkSparkAnswerAndOperator( + sql( + "SELECT s, " + + "idA(monotonically_increasing_id()) AS a, " + + "idB(monotonically_increasing_id()) AS b FROM t")) + } + } + } + } + + test("per-task cache isolates UDF state across sequential task runs in one session") { + // Regression guard for the cache-scoping invariant on CometUdfBridge: instances live for + // exactly one Spark task and are dropped on task completion, so a stateful kernel sees a + // fresh instance per task. The query has to actually route through the dispatcher for this + // to test anything, so wrap `monotonically_increasing_id()` in a ScalaUDF identity. Running + // it twice in one session must produce results matching Spark each time. Under a cache that + // outlived a task and got reused by the next one, the counter would continue from the + // previous run's final value and the second run's IDs would diverge from Spark. Under a + // cache that was keyed by Tokio worker thread rather than task attempt ID, worker reuse + // across tasks would cause the same leak whenever the second task happened to be polled by + // the same worker. Two `checkSparkAnswerAndOperator` calls are stronger than asserting + // first == second: equality alone could pass if both runs are wrong-but-consistent (e.g. + // `init(partitionIndex)` never fires); matching Spark on both runs rules that out and + // implies cross-run equality because Spark is deterministic on the same query. + spark.udf.register("idPassthrough", (id: Long) => id) + val rows = (0 until 2048).map(i => s"row_$i") + withSubjects(rows: _*) { + val q = "SELECT s, idPassthrough(monotonically_increasing_id()) AS mid FROM t" + checkSparkAnswerAndOperator(sql(q)) + checkSparkAnswerAndOperator(sql(q)) + } + } + + /** + * Scalar ScalaUDF smoke tests. These prove that user-registered UDFs route through the codegen + * dispatcher rather than forcing a whole-plan Spark fallback. Spark's `ScalaUDF.doGenCode` + * already emits compilable Java that calls the user function via `ctx.addReferenceObj`, so the + * dispatcher's compile path picks it up for free. Tests that user-registered UDFs route through + * the dispatcher rather than forcing whole-plan Spark fallback. + */ + + test("registered string ScalaUDF routes through dispatcher") { + spark.udf.register("shout", (s: String) => if (s == null) null else s.toUpperCase + "!") + withSubjects("Abc", "xyz", null, "mixed") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT shout(s) FROM t")) + } + } + } + + test("registered Java UDF1 routes through dispatcher") { + // Java API path: `spark.udf.register(name, UDF1<...>, returnType)`. Spark wraps the Java + // functional interface in a Scala function and produces a `ScalaUDF` expression at plan + // time, so the dispatcher handles it the same as a Scala-registered UDF. Sanity check that + // both registration paths land on the same routing code. + spark.udf.register( + "javaLen", + new UDF1[String, Integer] { + override def call(s: String): Integer = if (s == null) -1 else s.length + }, + IntegerType) + withSubjects("abc", "hello", null, "x") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT javaLen(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("multi-arg ScalaUDF over string + literal routes through dispatcher") { + spark.udf.register( + "prepend", + (prefix: String, s: String) => if (s == null) null else prefix + s) + withSubjects("one", "two", null) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT prepend('[', s) FROM t")) + } + } + } + + test("ScalaUDF as a child of a native Spark expression") { + // The ScalaUDF routes through the dispatcher as a sub-expression. The surrounding `length` + // runs through Comet's native scalar function path. This exercises the cross-boundary + // composition where a dispatcher-compiled kernel returns a UTF8String that a native Comet + // expression then consumes. + spark.udf.register("wrap", (s: String) => if (s == null) null else s"|$s|") + withSubjects("abc", "def", null) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT length(wrap(s)) FROM t")) + } + } + } + + test("composed ScalaUDFs outer(inner(s)) fuse into one kernel") { + // Two user UDFs stacked, both operating on String. The dispatcher binds the whole tree and + // Spark's codegen emits two `ctx.addReferenceObj` calls inside one generated method. Races + // on the `ExpressionEncoder` serializers in `references` would show up here since each UDF + // contributes its own stateful serializer. The `freshReferences` closure in `CompiledKernel` + // is what keeps this correct across partitions. + spark.udf.register("inner", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("outer", (s: String) => if (s == null) null else s"<$s>") + withSubjects("abc", null, "xyz", "MiXeD") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT outer(inner(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), StringType) + } + } + + test("ScalaUDFs of different types compose: isShort(len(s))") { + // Exercises an input type transition: String -> Int -> Boolean. Two user UDFs with + // different I/O type shapes in one tree, one Janino compile. + spark.udf.register("len", (s: String) => if (s == null) -1 else s.length) + spark.udf.register("isShort", (i: Int) => i < 5) + withSubjects("ab", "abcdef", null, "hi") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT isShort(len(s)) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BooleanType) + } + } + + test("three-deep ScalaUDF composition lvl3(lvl2(lvl1(s)))") { + // Three user UDFs stacked in one tree: String -> String -> String -> Int. The fused kernel + // carries three `ctx.addReferenceObj` calls. `assertOneKernelForSubtree` asserts that the + // whole chain collapses into a single compile rather than one per nesting level. + // Null handling through composed UDFs is covered by the other composition tests above. + spark.udf.register("lvl1", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lvl2", (s: String) => if (s == null) null else s.reverse) + spark.udf.register("lvl3", (s: String) => if (s == null) -1 else s.length) + withSubjects("abc", "hello world", "x") { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT lvl3(lvl2(lvl1(s))) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("multi-column ScalaUDF composition join(upperU(c1), lowerU(c2))") { + // One multi-arg user UDF consuming two other user UDFs, each on a different input column. + // The bound tree has two BoundReferences, and the kernel is specialized on two VarCharVector + // columns. `assertOneKernelForSubtree` asserts that the two-branch composition fuses into a + // single kernel rather than one per branch or one per UDF. + // Input rows intentionally exclude nulls (see note on the three-deep test above). + spark.udf.register("upperU", (s: String) => if (s == null) null else s.toUpperCase) + spark.udf.register("lowerU", (s: String) => if (s == null) null else s.toLowerCase) + spark.udf.register( + "joinU", + (a: String, b: String) => if (a == null || b == null) null else s"$a-$b") + withTwoStringCols(("Abc", "XYZ"), ("Foo", "bar"), ("baz", "Bar"), ("Hi", "Lo")) { + assertOneKernelForSubtree { + checkSparkAnswerAndOperator(sql("SELECT joinU(upperU(c1), lowerU(c2)) FROM t")) + } + assertKernelSignaturePresent( + Seq(classOf[VarCharVector], classOf[VarCharVector]), + StringType) + } + } + + /** + * Per-primitive identity-UDF coverage. Each entry registers a `T => T` UDF over a parquet + * column declared at `sqlType` and asserts the dispatcher compiled a kernel for the matching + * `(vector class, output type)` pair. Parquet-backed (rather than `spark.range`-cast) tables + * keep the column's Arrow vector class aligned with the UDF signature. + */ + private def withTypedCol(sqlType: String, valueLiterals: String*)(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (c $sqlType) USING parquet") + if (valueLiterals.nonEmpty) { + val rows = valueLiterals.map(v => s"($v)").mkString(", ") + sql(s"INSERT INTO t VALUES $rows") + } + f + } + } + + private case class IdentityUdfCase( + label: String, + sqlType: String, + values: Seq[String], + vec: Class[_ <: ValueVector], + output: DataType, + udfName: String, + register: () => Unit) + + private val identityScalarCases: Seq[IdentityUdfCase] = Seq( + IdentityUdfCase( + "Boolean", + "BOOLEAN", + Seq("TRUE", "FALSE", "TRUE"), + classOf[BitVector], + BooleanType, + "u_bool", + () => spark.udf.register("u_bool", (b: Boolean) => !b)), + IdentityUdfCase( + "Byte", + "TINYINT", + Seq("CAST(1 AS TINYINT)", "CAST(2 AS TINYINT)", "CAST(100 AS TINYINT)"), + classOf[TinyIntVector], + ByteType, + "u_byte", + () => spark.udf.register("u_byte", (b: Byte) => (b + 1).toByte)), + IdentityUdfCase( + "Short", + "SMALLINT", + Seq("CAST(1 AS SMALLINT)", "CAST(2 AS SMALLINT)", "CAST(30000 AS SMALLINT)"), + classOf[SmallIntVector], + ShortType, + "u_short", + () => spark.udf.register("u_short", (s: Short) => (s + 1).toShort)), + IdentityUdfCase( + "Int", + "INT", + Seq("1", "2", "100"), + classOf[IntVector], + IntegerType, + "u_int", + () => spark.udf.register("u_int", (i: Int) => i * 2)), + IdentityUdfCase( + "Long", + "BIGINT", + Seq("1", "2", "100"), + classOf[BigIntVector], + LongType, + "u_long", + () => spark.udf.register("u_long", (l: Long) => l + 1L)), + IdentityUdfCase( + "Float", + "FLOAT", + Seq("CAST(1.5 AS FLOAT)", "CAST(2.5 AS FLOAT)"), + classOf[Float4Vector], + FloatType, + "u_float", + () => spark.udf.register("u_float", (f: Float) => f * 1.5f)), + IdentityUdfCase( + "Double", + "DOUBLE", + Seq("1.5", "2.5", "100.0"), + classOf[Float8Vector], + DoubleType, + "u_double", + () => spark.udf.register("u_double", (d: Double) => d / 2.0)), + IdentityUdfCase( + "Date", + "DATE", + Seq("DATE'2024-01-01'", "DATE'2024-06-15'", "DATE'1970-01-01'"), + classOf[DateDayVector], + DateType, + "u_date", + () => + spark.udf.register( + "u_date", + (d: java.sql.Date) => + if (d == null) null else new java.sql.Date(d.getTime + 86400000L))), + IdentityUdfCase( + "Timestamp", + "TIMESTAMP", + Seq("TIMESTAMP'2024-01-01 12:00:00'", "TIMESTAMP'2024-06-15 23:59:59'"), + classOf[TimeStampMicroTZVector], + TimestampType, + "u_ts", + () => + spark.udf.register( + "u_ts", + (t: java.sql.Timestamp) => + if (t == null) null else new java.sql.Timestamp(t.getTime + 1000L))), + IdentityUdfCase( + "TimestampNTZ", + "TIMESTAMP_NTZ", + Seq("TIMESTAMP_NTZ'2024-01-01 12:00:00'", "TIMESTAMP_NTZ'2024-06-15 23:59:59'"), + classOf[TimeStampMicroVector], + TimestampNTZType, + "u_tsntz", + () => + spark.udf.register( + "u_tsntz", + (ldt: java.time.LocalDateTime) => if (ldt == null) null else ldt.plusDays(1)))) + + identityScalarCases.foreach { c => + test(s"identity ScalaUDF on ${c.label} routes through dispatcher") { + c.register() + withTypedCol(c.sqlType, c.values: _*) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql(s"SELECT ${c.udfName}(c) FROM t")) + } + assertKernelSignaturePresent(Seq(c.vec), c.output) + } + } + } + + test("ScalaUDF returning a different type than its input") { + // String -> Int output transition. Identity-loop above keeps input == output. This asserts + // the writer can switch types per the UDF's declared return. + spark.udf.register("codePoint", (s: String) => if (s == null) 0 else s.codePointAt(0)) + withSubjects("abc", "A", null, "!") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT codePoint(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), IntegerType) + } + } + + test("ScalaUDF returning BinaryType") { + // Binary output writer path, exercised here by a user UDF for the first time. Before this + // the writer only had direct-compile unit tests. + spark.udf.register("bytes", (s: String) => if (s == null) null else s.getBytes("UTF-8")) + withSubjects("abc", null, "hello") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT bytes(s) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarCharVector]), BinaryType) + } + } + + test("ScalaUDF on BinaryType") { + // Binary input getter path: VarBinaryVector with byte[] reads via Spark's `getBinary` getter. + spark.udf.register("blen", (b: Array[Byte]) => if (b == null) -1 else b.length) + withTable("t") { + sql("CREATE TABLE t (b BINARY) USING parquet") + sql("INSERT INTO t VALUES (CAST('abc' AS BINARY)), (CAST('hello' AS BINARY)), (NULL)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT blen(b) FROM t")) + } + assertKernelSignaturePresent(Seq(classOf[VarBinaryVector]), IntegerType) + } + } + + test("ScalaUDF returning ArrayType(StringType)") { + // First use of the ArrayType output path end-to-end. The UDF returns a `Seq[String]`, + // which Spark encodes as `ArrayType(StringType, containsNull = true)`. The dispatcher's + // canHandle accepts it (ArrayType is supported when its element type is supported), + // allocateOutput builds a ListVector with an inner VarCharVector, and emitWrite recurses + // into the StringType case for the per-element UTF8 on-heap shortcut. End-to-end answer + // matches Spark. + spark.udf.register( + "splitComma", + (s: String) => if (s == null) null else s.split(",", -1).toSeq) + withSubjects("a,b,c", "x", null, "", "one,,three") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT splitComma(s) FROM t")) + } + } + } + + test("ScalaUDF returning ArrayType(IntegerType)") { + // Exercises ArrayType output with a primitive element. emitWrite's ArrayType case + // recurses into the IntegerType case for the inner write. No byte[] allocation involved. + spark.udf.register( + "asLengths", + (s: String) => if (s == null) null else s.split(",").map(_.length).toSeq) + withSubjects("a,bb,ccc", null, "xyzzy") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT asLengths(s) FROM t")) + } + } + } + + test("zero-column ScalaUDF produces one row per input row") { + // Non-deterministic (so Spark doesn't constant-fold) with a deterministic body (so + // Spark-vs-Comet comparison stays honest). The expression has no `AttributeReference`, + // so the serde produces an empty data-arg list and the dispatcher has no data column to + // read the batch size from. Guards the `numRows` path through the JNI bridge. + import org.apache.spark.sql.functions.udf + val alwaysHello = udf(() => "hello").asNondeterministic() + spark.udf.register("helloU", alwaysHello) + withSubjects("a", "b", null, "c") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT helloU() FROM t")) + } + } + } + + /** + * Decimal end-to-end: the dispatcher's `getDecimal` specializes per `DecimalType.precision` at + * source-generation time. Two representative cases here; `CometCodegenFuzzSuite` sweeps every + * shape across the boundary at varying null densities. + */ + private def withDecimalTable(decimalType: String, values: Seq[String])(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (d $decimalType) USING parquet") + val rows = values.map(v => if (v == null) "(NULL)" else s"($v)").mkString(", ") + if (values.nonEmpty) sql(s"INSERT INTO t VALUES $rows") + f + } + } + + test("ScalaUDF over Decimal(18, 9) routes through the unscaled-long fast path") { + // Boundary precision (18 == `MAX_LONG_DIGITS`) with a non-zero scale exercises the fractional + // branch of the fast-path encoding. + spark.udf.register("decIdShort", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(18, 9)", + Seq("0.000000000", "1.123456789", "-1.123456789", "999999999.999999999", null)) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT decIdShort(d) FROM t")) + } + } + } + + test("ScalaUDF over Decimal(38, 10) routes through the BigDecimal slow path") { + spark.udf.register("decIdLong", (d: java.math.BigDecimal) => d) + withDecimalTable( + "DECIMAL(38, 10)", + Seq( + "0.0000000000", + "1.1234567890", + "-1.1234567890", + "9999999999999999999999999999.0000000000", + null)) { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT decIdLong(d) FROM t")) + } + } + } + + test("ScalaUDF sees TaskContext.partitionId() per partition") { + // Direct probe: register a ScalaUDF that reads TaskContext.partitionId() and returns it. + // Spark's own task thread has TaskContext set, so each partition's rows carry that + // partition's index. For the dispatcher to match Spark, the invocation thread must see a + // live TaskContext. With the `createPlan`-time TaskContext capture + bridge-side + // `TaskContext.setTaskContext` install (see `CometUdfBridge.evaluate` and + // `CometTaskContextShim`), Tokio workers see the propagated TaskContext and the UDF + // returns the real partitionId. Without that propagation, `TaskContext.get()` returns null + // on the Tokio thread and the sentinel (-1) leaks through, diverging from Spark. + spark.udf.register( + "pid", + (_: Long) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "pid(id) as p") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF sees TaskContext from fully-native parquet plan") { + // The `spark.range`-based test above runs through `CometSparkRowToColumnar`, which executes + // on a Spark task thread where TaskContext is live even without explicit propagation. The + // fully-native path through `CometNativeScan` runs the JVM UDF bridge on a Tokio worker + // thread where TaskContext.get() would otherwise be null. This test forces that path by + // sourcing from a Parquet table written as multiple files (so the native read produces + // multiple partitions) and asserting the UDF still sees the per-partition TaskContext via + // the `createPlan`-time capture + bridge-side install. + spark.udf.register( + "pidP", + (_: Int) => { + val tc = TaskContext.get() + if (tc != null) tc.partitionId() else -1 + }) + withTable("t") { + sql("CREATE TABLE t (x INT) USING parquet") + // Multiple INSERT statements -> multiple parquet files -> multiple read splits -> + // multiple partitions. + sql("INSERT INTO t VALUES (1), (2), (3), (4)") + sql("INSERT INTO t VALUES (5), (6), (7), (8)") + sql("INSERT INTO t VALUES (9), (10), (11), (12)") + sql("INSERT INTO t VALUES (13), (14), (15), (16)") + checkSparkAnswerAndOperator(sql("SELECT x, pidP(x) AS p FROM t")) + } + } + + test("Rand seeded per partition across a multi-partition table") { + // Rand.doGenCode registers an XORShiftRandom via ctx.addMutableState and seeds it via + // ctx.addPartitionInitializationStatement. That init statement runs inside our kernel's + // `init(int partitionIndex)`, called once per kernel allocation. Spark seeds + // `XORShiftRandom(seed + partitionIndex)` per partition, so different partitions produce + // different sequences for the same seed. Matching Spark across partitions requires the + // kernel to see the real partition index, which the dispatcher derives from + // `TaskContext.get().partitionId()`, live on this path thanks to the bridge-level + // TaskContext propagation. Composing with a ScalaUDF (identity on Double here) forces the + // tree through codegen dispatch so the Rand evaluation runs inside our kernel's init + // rather than via Spark's normal codegen. + spark.udf.register("dblId", (d: Double) => d) + val df = spark + .range(0, 1024, 1, numPartitions = 4) + .selectExpr("id", "dblId(rand(42)) as r") + checkSparkAnswerAndOperator(df) + } + + test("ScalaUDF composed with reused scalar subquery across projection and filter") { + // The same scalar subquery appears in two sites: the projection (which the dispatcher + // compiles into a fused kernel) and the filter (a separate operator). Each site holds its + // own `ScalarSubquery` expression instance with its own `@volatile result` field. Each + // surrounding operator's inherited `SparkPlan.waitForSubqueries` populates its instance's + // `result` before the dispatcher's bridge serializes the expression. The populated value + // travels through closure serialization into the cache key's bytes, so different subquery + // values compile distinct kernels. Exercises the full subquery-correctness invariant + // documented on `CometBatchKernelCodegen.canHandle`. + spark.udf.register("addOne", (i: Int) => i + 1) + withTable("t", "t2") { + sql("CREATE TABLE t (x INT) USING parquet") + sql("INSERT INTO t VALUES (1), (2), (3), (4), (5)") + sql("CREATE TABLE t2 (v INT) USING parquet") + sql("INSERT INTO t2 VALUES (2), (4)") + checkSparkAnswerAndOperator( + sql("SELECT addOne(x) + (SELECT max(v) FROM t2) AS r " + + "FROM t WHERE addOne(x) < (SELECT max(v) FROM t2) * 2")) + } + } + + /** + * ArrayType input. The dispatcher emits a nested `InputArray_col0` final class per array-typed + * input column; Spark's generated `getArray(ord)` resolves to our kernel's switch which returns + * the pre-allocated instance after resetting its start/length against the list's offsets. + * Element reads go through the typed child-vector field with no `ArrayData` copy or boxing. + * + * Each smoke test exercises the same serde/transport path at a different element type so the + * nested getter emitter's scalar-element cases are each covered: `StringType` (zero-copy + * `UTF8String.fromAddress`), `IntegerType` (primitive direct), and `DecimalType(p <= 18)` + * (decimal128 fast path). + */ + private def withArrayTable(colType: String, insertRows: String)(f: => Unit): Unit = { + withTable("t") { + sql(s"CREATE TABLE t (a $colType) USING parquet") + sql(s"INSERT INTO t VALUES $insertRows") + f + } + } + + test("ScalaUDF taking Seq[String] reads element by element") { + spark.udf.register( + "headOrNull", + (arr: Seq[String]) => if (arr == null || arr.isEmpty) null else arr.head) + withArrayTable( + "ARRAY", + "(array('a', 'b', 'c')), (array('x')), (null), (array()), (array('alone'))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT headOrNull(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[String] iterating all elements") { + spark.udf.register( + "concatArr", + (arr: Seq[String]) => if (arr == null) null else arr.mkString("|")) + withArrayTable( + "ARRAY", + "(array('one', 'two', 'three')), (array('solo')), (null), (array())") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT concatArr(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[Int] reads primitive elements") { + spark.udf.register("sumArr", (arr: Seq[Int]) => if (arr == null) -1 else arr.sum) + withArrayTable( + "ARRAY", + "(array(1, 2, 3)), (array(-5, 5)), (array()), (null), (array(42))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT sumArr(a) FROM t")) + } + } + } + + test("ScalaUDF taking Seq[BigDecimal] hits short-precision decimal fast path") { + // DecimalType(10, 2) is well inside p <= 18, so the nested-array `getDecimal` emits the + // unscaled-long fast path (see `emitNestedArrayElementGetter`). A `BigDecimal` UDF argument + // forces Spark's encoder to call `getDecimal(i, 10, 2)` on our nested ArrayData for each + // element, which exercises that code path end to end. + spark.udf.register( + "sumDecArr", + (arr: Seq[java.math.BigDecimal]) => + if (arr == null) null + else { + var acc = java.math.BigDecimal.ZERO + arr.foreach(v => if (v != null) acc = acc.add(v)) + acc + }) + withArrayTable( + "ARRAY", + "(array(1.23, 4.56)), (array(-9.99)), (null), (array())") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT sumDecArr(a) FROM t")) + } + } + } + + test("ScalaUDF composes with struct-field access reading Struct.age") { + // Keeps the UDF arg scalar (Int) but puts a `GetStructField` under it so the codegen + // dispatcher compiles the struct-input read path (`row.getStruct(0, 2).getInt(1)`). + spark.udf.register("doubleInt", (i: Int) => i * 2) + withTable("t") { + sql("CREATE TABLE t (s STRUCT) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'alice', 'age', 30)), " + + "(named_struct('name', 'bob', 'age', 42)), " + + "(null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT doubleInt(s.age) FROM t")) + } + } + } + + test("ScalaUDF taking full Struct value (case class arg)") { + // Case-class UDF arguments: test data must not include null top-level rows. + // `ScalaUDF.scalaConverter` applies Spark's `ExpressionEncoder.Deserializer` on every row + // to materialize the case-class instance. The generated deserializer has a + // `newInstance(NameAgePair)` step that throws `EXPRESSION_DECODING_FAILED` on a null input, + // independent of the dispatcher. Case-class UDF tests omit null top-level rows. Other + // tests with plain `Seq` / `Map` args can include nulls because the deserializer hands null + // to the UDF body which handles it. + spark.udf.register("fmtPair", (r: NameAgePair) => s"${r.name}:${r.age}") + withTable("t") { + sql("CREATE TABLE t (s STRUCT) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'alice', 'age', 30)), " + + "(named_struct('name', 'bob', 'age', 42))") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT fmtPair(s) FROM t")) + } + } + } + + test("ScalaUDF returning Struct (case class output)") { + spark.udf.register("makePair", (i: Int) => NameAgePair(s"n$i", i)) + withTypedCol("INT", "1", "2", "3") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT makePair(c) FROM t")) + } + } + } + + test("ScalaUDF taking Map") { + spark.udf.register("sumMap", (m: Map[String, Int]) => if (m == null) -1 else m.values.sum) + withTable("t") { + sql("CREATE TABLE t (m MAP) USING parquet") + sql("INSERT INTO t VALUES (map('a', 1, 'b', 2)), (map()), (null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT sumMap(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map (primitive key and value)") { + // Map with non-string keys: exercises the primitive-key element getter on the input side + // and the corresponding writer on the output side. Spark's encoder for `Map[Int, Int]` calls + // `getInt(0)` / `getInt(1)` on the entries struct, hitting the kernel's typed scalar getter + // for each side rather than the UTF8 path. + spark.udf.register( + "incValues", + (m: Map[Int, Int]) => if (m == null) null else m.map { case (k, v) => k -> (v + 1) }) + withTable("t") { + sql("CREATE TABLE t (m MAP) USING parquet") + sql("INSERT INTO t VALUES (map(1, 10, 2, 20)), (map()), (null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT incValues(m) FROM t")) + } + } + } + + test("ScalaUDF returning Map") { + spark.udf.register( + "singletonMap", + (s: String, i: Int) => if (s == null) null else Map(s -> i)) + withTable("t") { + sql("CREATE TABLE t (s STRING, i INT) USING parquet") + sql("INSERT INTO t VALUES ('a', 1), ('b', 2), (null, 3)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT singletonMap(s, i) FROM t")) + } + } + } + + test("ScalaUDF taking Map> exercises nested composition") { + spark.udf.register( + "totalLens", + (m: Map[String, Seq[Int]]) => if (m == null) -1 else m.values.flatten.sum) + withTable("t") { + sql("CREATE TABLE t (m MAP>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', array(1, 2, 3), 'b', array(10))), " + + "(map()), " + + "(null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT totalLens(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Array> (nested array input + output)") { + // Exercises nested-array input reads and nested-list output writes in one call: the inner + // `InputArray_col0_e` class on the input side and the recursive emitWrite on the output. + spark.udf.register( + "reverseRows", + (arr: Seq[Seq[Int]]) => if (arr == null) null else arr.map(_.reverse)) + withTable("t") { + sql("CREATE TABLE t (a ARRAY>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(array(array(1, 2, 3), array(4, 5))), " + + "(array(array())), " + + "(null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT reverseRows(a) FROM t")) + } + } + } + + test("ScalaUDF round-trips Struct>") { + // Struct with a complex field on both sides: input reads go through InputStruct_col0 + + // InputArray_col0_f1, output writes through StructVector + ListVector. + // Null top-level rows omitted - case-class arg. See the note on `fmtPair` above. + spark.udf.register( + "growItems", + (r: NameItems) => + if (r == null) null else NameItems(r.name, if (r.items == null) null else r.items :+ 0)) + withTable("t") { + sql("CREATE TABLE t (s STRUCT>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(named_struct('name', 'a', 'items', array(1, 2))), " + + "(named_struct('name', 'b', 'items', array()))") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT growItems(s) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map> (nested value both sides)") { + // Map input read goes through InputMap_col0 + InputArray_col0_v (the complex-value side); + // output write emits MapVector + entries Struct + per-value ListVector inside the map's + // entries struct. + spark.udf.register( + "sortValues", + (m: Map[String, Seq[Int]]) => + if (m == null) null + else m.map { case (k, v) => k -> (if (v == null) null else v.sorted) }) + withTable("t") { + sql("CREATE TABLE t (m MAP>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', array(3, 1, 2), 'b', array(10))), " + + "(map()), " + + "(null)") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT sortValues(m) FROM t")) + } + } + } + + test("ScalaUDF round-trips Map>") { + // Struct value inside a map, both sides. Null top-level rows omitted - the map value is a + // case class. See the note on `fmtPair` above. + spark.udf.register( + "tagValues", + (m: Map[String, XyPair]) => + if (m == null) null + else + m.map { case (k, v) => k -> (if (v == null) null else XyPair(v.x + 1, s"<${v.y}>")) }) + withTable("t") { + sql("CREATE TABLE t (m MAP>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(map('a', named_struct('x', 1, 'y', 'one'))), " + + "(map())") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT tagValues(m) FROM t")) + } + } + } + + test("array_distinct on Array> retains element identity across hash set") { + // Fuzz signal: cardinality(array_distinct(arr_of_struct)) returns 1 where Spark returns 2. + // Hypothesis: the kernel's InputStruct wrapper backing array_distinct's element reads is + // reused without resetting per-element state, so every hashed element looks identical and + // distinct collapses the array to a single entry. + spark.udf.register("idIntDistinct", (i: Int) => i) + withTable("t") { + sql("CREATE TABLE t (s ARRAY>) USING parquet") + sql( + "INSERT INTO t VALUES " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 1, 'b', 'x'))), " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 2, 'b', 'y'))), " + + "(array(named_struct('a', 1, 'b', 'x'), named_struct('a', 2, 'b', 'y'), " + + "named_struct('a', 1, 'b', 'x')))") + assertCodegenRan { + checkSparkAnswerAndOperator( + sql("SELECT idIntDistinct(cardinality(array_distinct(s))) FROM t")) + } + } + } + + test("array_max(flatten(arr)) on Array> with mixed null inner arrays") { + // Fuzz signal: array_max(flatten(arr)) returns empty byte arrays where Spark returns the + // actual max binary, with the empties sorting to the front of the output. Pattern points at + // cross-batch state pollution. Generate 100 rows of varied outer/inner shape, longer + // binaries, mixed nulls. Force multiple batches with a small batch size. + spark.udf.register("idBinFlat", (b: Array[Byte]) => b) + withSQLConf(CometConf.COMET_BATCH_SIZE.key -> "16") { + withTable("t") { + sql("CREATE TABLE t (a ARRAY>) USING parquet") + val rows = (0 until 100).map { i => + if (i % 11 == 0) { + "(NULL)" + } else { + val outerSize = (i % 5) + 1 + val inners = (0 until outerSize).map { j => + val pick = (i * 7 + j) % 13 + if (pick == 0) "array()" + else if (pick == 1) "NULL" + else { + val innerSize = ((i + j) % 4) + 1 + val bytes = (0 until innerSize).map { k => + val len = ((i + j + k) % 8) + 1 + val hex = (0 until len) + .map(b => f"${(i * 13 + j * 17 + k * 5 + b) & 0xff}%02x") + .mkString + s"X'$hex'" + } + "array(" + bytes.mkString(", ") + ")" + } + } + s"(array(${inners.mkString(", ")}))" + } + } + sql(s"INSERT INTO t VALUES ${rows.mkString(", ")}") + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idBinFlat(array_max(flatten(a))) FROM t")) + } + } + } + } + + /** + * Regressions for nested reference-typed getter null handling. Spark's + * `CodeGenerator.setArrayElement` only emits an `isNullAt` check before `array.update(i, + * getX(j))` for Java primitives. For reference-typed elements (Binary, String, Decimal, Struct, + * Array, Map) it relies on the source's `getX` to return `null` itself, matching + * `ColumnarArray.getBinary`. Without that contract, inner nulls become empty bytes / empty + * strings / garbage decimals / non-null shells in the flattened output. + */ + + test("array_max(flatten(arr)) on Array> with null inner Binary returns null") { + spark.udf.register("idBin", (b: Array[Byte]) => b) + withArrayTable( + "ARRAY>", + "(array(array(NULL))), " + + "(array(array(NULL, NULL))), " + + "(array(array(), array(NULL)))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idBin(array_max(flatten(a))) FROM t")) + } + } + } + + test("array_max(flatten(arr)) on Array> with null inner String returns null") { + spark.udf.register("idStr", (s: String) => s) + withArrayTable( + "ARRAY>", + "(array(array(NULL))), " + + "(array(array(NULL, NULL))), " + + "(array(array(), array(NULL)))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idStr(array_max(flatten(a))) FROM t")) + } + } + } + + test( + "array_max(flatten(arr)) on Array> with null inner Decimal " + + "(short-precision fast path)") { + spark.udf.register("idDec10", (d: java.math.BigDecimal) => d) + withArrayTable( + "ARRAY>", + "(array(array(CAST(NULL AS DECIMAL(10, 2))))), " + + "(array(array(" + + "CAST(NULL AS DECIMAL(10, 2)), CAST(NULL AS DECIMAL(10, 2))))), " + + "(array(array(), array(CAST(NULL AS DECIMAL(10, 2)))))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idDec10(array_max(flatten(a))) FROM t")) + } + } + } + + test( + "array_max(flatten(arr)) on Array> with null inner Decimal " + + "(long-precision slow path)") { + spark.udf.register("idDec30", (d: java.math.BigDecimal) => d) + withArrayTable( + "ARRAY>", + "(array(array(CAST(NULL AS DECIMAL(30, 2))))), " + + "(array(array(" + + "CAST(NULL AS DECIMAL(30, 2)), CAST(NULL AS DECIMAL(30, 2))))), " + + "(array(array(), array(CAST(NULL AS DECIMAL(30, 2)))))") { + assertCodegenRan { + checkSparkAnswerAndOperator(sql("SELECT idDec30(array_max(flatten(a))) FROM t")) + } + } + } + + // Runtime coverage for nullable nested `getStruct` / `getArray` / `getMap` element reads is + // exercised through HOFs in `CometCodegenHOFSuite`. Static emitter assertions live in + // `CometCodegenSourceSuite`. +} + +/** + * Case class used by the struct-input / struct-output smoke tests. Must be declared at file scope + * (not inside the test class) so Spark's TypeTag-based UDF encoder can resolve the Spark + * `StructType` schema from the Scala class. + */ +private case class NameAgePair(name: String, age: Int) + +private case class NameItems(name: String, items: Seq[Int]) + +private case class XyPair(x: Int, y: String) diff --git a/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala b/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala index 9622960932..53454d0034 100644 --- a/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometIcebergRewriteActionSuite.scala @@ -65,11 +65,11 @@ class CometIcebergRewriteActionSuite extends CometTestBase with CometIcebergTest } // Single-column zOrder is bit-pattern-equivalent to a natural sort (no second dimension to - // interleave with), so we expect the same ascending output as the sort test. The shuffle here - // is CometColumnarExchange rather than CometExchange because the z-value column is computed - // by a Spark Project (Iceberg's INTERLEAVE_BYTES / INT_ORDERED_BYTES are not recognised by - // Comet), so the path crosses a JVM-row boundary before the shuffle. - test("single-column zOrder rewrite runs scan, columnar exchange, and sort natively in Comet") { + // interleave with), so we expect the same ascending output as the sort test. Iceberg's + // `INT_ORDERED_BYTES` / `INTERLEAVE_BYTES` are `ScalaUDF`s that route through Comet's codegen + // dispatcher, so the project stays native and the shuffle picks `CometExchange` / + // `CometNativeShuffle` rather than the columnar-row roundtrip path. + test("single-column zOrder rewrite runs scan, native exchange, and sort natively in Comet") { runRewriteTest( RewriteCase( table = s"$catalog.db.zorder_test", @@ -77,7 +77,7 @@ class CometIcebergRewriteActionSuite extends CometTestBase with CometIcebergTest verifyDataAfter = assertSortedById, verifyPlans = { rewritePlans => assertReadsAreComet(rewritePlans) - assertOperator(rewritePlans, "CometColumnarExchange") + assertOperator(rewritePlans, "CometExchange") assertOperator(rewritePlans, "CometSort") })) } @@ -416,7 +416,8 @@ class CometIcebergRewriteActionSuite extends CometTestBase with CometIcebergTest s"spark.sql.catalog.$catalog.warehouse" -> warehouseDir.getAbsolutePath, CometConf.COMET_ENABLED.key -> "true", CometConf.COMET_EXEC_ENABLED.key -> "true", - CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true")(body) + CometConf.COMET_ICEBERG_NATIVE_ENABLED.key -> "true", + CometConf.COMET_SCALA_UDF_CODEGEN_ENABLED.key -> "true")(body) /** Creates an Iceberg table with `numFiles` separate appends, each producing one data file. */ private def createMultiFileTable(table: String, numFiles: Int): Unit = {