diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index e12792d1d1..17603ea491 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -361,6 +361,7 @@ jobs: org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite org.apache.comet.CometFuzzAggregateSuite + org.apache.spark.sql.comet.execution.arrow.CometArrowStreamSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index 41285dc8fd..04a6c19078 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -177,6 +177,7 @@ jobs: org.apache.spark.sql.CometToPrettyStringSuite org.apache.spark.sql.CometCollationSuite org.apache.comet.CometFuzzAggregateSuite + org.apache.spark.sql.comet.execution.arrow.CometArrowStreamSuite - name: "expressions" value: | org.apache.comet.CometExpressionSuite diff --git a/docs/source/contributor-guide/native_shuffle.md b/docs/source/contributor-guide/native_shuffle.md index 18e80a90c8..c46b8b45a8 100644 --- a/docs/source/contributor-guide/native_shuffle.md +++ b/docs/source/contributor-guide/native_shuffle.md @@ -69,8 +69,9 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ▼ ┌─────────────────────────────────────────────────────────────────────────────┐ │ CometNativeShuffleWriter │ -│ - Constructs protobuf operator plan │ -│ - Invokes native execution via CometExec.getCometIterator() │ +│ - Builds protobuf operator plan: ShuffleWriter(child = childNativeOp) │ +│ - Reads per-partition leaf iterators from CometNativeShuffleInputIterator │ +│ - Drives one CometExecIterator per partition │ └─────────────────────────────────────────────────────────────────────────────┘ │ ▼ (JNI) @@ -103,13 +104,14 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ### Scala Side -| Class | 
Location | Description | -| ------------------------------ | ------------------------------------------------ | --------------------------------------------------------------------------------------------- | -| `CometShuffleExchangeExec` | `.../shuffle/CometShuffleExchangeExec.scala` | Physical plan node. Validates types and partitioning, creates `CometShuffleDependency`. | -| `CometNativeShuffleWriter` | `.../shuffle/CometNativeShuffleWriter.scala` | Implements `ShuffleWriter`. Builds protobuf plan and invokes native execution. | -| `CometShuffleDependency` | `.../shuffle/CometShuffleDependency.scala` | Extends `ShuffleDependency`. Holds shuffle type, schema, and range partition bounds. | -| `CometBlockStoreShuffleReader` | `.../shuffle/CometBlockStoreShuffleReader.scala` | Reads shuffle blocks via `ShuffleBlockFetcherIterator`. Decodes Arrow IPC to `ColumnarBatch`. | -| `NativeBatchDecoderIterator` | `.../shuffle/NativeBatchDecoderIterator.scala` | Reads compressed Arrow IPC from input stream. Calls native decode via JNI. | +| Class | Location | Description | +| ------------------------------ | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------- | +| `CometShuffleExchangeExec` | `.../shuffle/CometShuffleExchangeExec.scala` | Physical plan node. Validates types and partitioning, creates `CometShuffleDependency`. | +| `CometNativeShuffleWriter` | `.../shuffle/CometNativeShuffleWriter.scala` | Implements `ShuffleWriter`. Builds the unified `ShuffleWriter(child = childNativeOp)` plan and runs it in one `CometExecIterator` per partition. | +| `CometShuffleDependency` | `.../shuffle/CometShuffleDependency.scala` | Extends `ShuffleDependency`. Holds shuffle type, schema, range partition bounds, and (native shuffle only) a `NativeShuffleSpec`. 
| +| `CometNativeShuffleInputRDD` | `.../shuffle/CometNativeShuffleInputRDD.scala` | Thin scheduling-anchor RDD on the native-shuffle path. `compute` returns a `CometNativeShuffleInputIterator` carrying per-partition leaf iterators. | +| `CometBlockStoreShuffleReader` | `.../shuffle/CometBlockStoreShuffleReader.scala` | Reads shuffle blocks via `ShuffleBlockFetcherIterator`. Decodes Arrow IPC to `ColumnarBatch`. | +| `NativeBatchDecoderIterator` | `.../shuffle/NativeBatchDecoderIterator.scala` | Reads compressed Arrow IPC from input stream. Calls native decode via JNI. | ### Rust Side @@ -123,11 +125,19 @@ Native shuffle (`CometExchange`) is selected when all of the following condition ### Write Path -1. **Plan construction**: `CometNativeShuffleWriter` builds a protobuf operator plan containing: - - A scan operator reading from the input iterator - - A `ShuffleWriter` operator with partitioning config and compression codec - -2. **Native execution**: `CometExec.getCometIterator()` executes the plan in Rust. +1. **Plan construction**: `CometNativeShuffleWriter` builds a protobuf operator tree with a + `ShuffleWriter` operator at the root and `childNativeOp` as its child. `childNativeOp` takes + one of two shapes: + - The child plan's `nativeOp` directly, when `CometShuffleExchangeExec`'s child is a + `CometNativeExec` subtree. The upstream operators run inside the same `CometExecIterator` + as the writer, with no JVM-to-native batch boundary between them. + - A synthetic `Scan("ShuffleWriterInput")` placeholder, when the dependency was built via the + convenience `prepareShuffleDependency(rdd, ...)` overload (used by + `CometCollectLimitExec` and `CometTakeOrderedAndProjectExec`, and also when the + exchange's child is a non-native `CometPlan` such as `CometSparkToColumnarExec`). Native + code reads `ColumnarBatch`es from the JVM input iterator via the Arrow C Stream Interface. + +2. **Native execution**: A single `CometExecIterator` per partition runs the unified plan. 3. 
**Partitioning**: `ShuffleWriterExec` receives batches and routes to the appropriate partitioner: - `MultiPartitionShuffleRepartitioner`: For hash/range/round-robin partitioning diff --git a/native/core/src/execution/operators/aligned_stream_reader.rs b/native/core/src/execution/operators/aligned_stream_reader.rs new file mode 100644 index 0000000000..c1d615a79f --- /dev/null +++ b/native/core/src/execution/operators/aligned_stream_reader.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{RecordBatch, RecordBatchOptions, StructArray}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; +use arrow::error::ArrowError; +use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi_stream::FFI_ArrowArrayStream; +use std::ffi::CStr; +use std::sync::Arc; + +/// C Stream Interface reader that calls [`arrow::array::ArrayData::align_buffers`] on every +/// imported batch before constructing typed arrays. Stock `ArrowArrayStreamReader` panics +/// when a JVM producer hands us a `Decimal128` buffer at an offset that is 8-byte but not +/// 16-byte aligned, which Java's allocator does not guarantee. Track upstream: +/// . 
+#[derive(Debug)] +pub struct AlignedArrowStreamReader { + stream: FFI_ArrowArrayStream, + schema: SchemaRef, +} + +impl AlignedArrowStreamReader { + /// # Safety + /// `raw` must point at a valid `FFI_ArrowArrayStream` whose ownership is being transferred + /// to this reader. The stream's release callback fires when the reader is dropped. + pub unsafe fn from_raw(raw: *mut FFI_ArrowArrayStream) -> Result { + let mut stream = FFI_ArrowArrayStream::from_raw(raw); + if stream.release.is_none() { + return Err(ArrowError::CDataInterface( + "input stream is already released".to_string(), + )); + } + let schema = read_schema(&mut stream)?; + Ok(Self { stream, schema }) + } + + pub fn schema(&self) -> SchemaRef { + Arc::clone(&self.schema) + } + + fn last_error(&mut self) -> Option { + let get = self.stream.get_last_error?; + let ptr = unsafe { get(&mut self.stream) }; + if ptr.is_null() { + return None; + } + Some( + unsafe { CStr::from_ptr(ptr) } + .to_string_lossy() + .into_owned(), + ) + } +} + +impl Iterator for AlignedArrowStreamReader { + type Item = Result; + + fn next(&mut self) -> Option { + let mut array = FFI_ArrowArray::empty(); + let ret = unsafe { self.stream.get_next.unwrap()(&mut self.stream, &mut array) }; + if ret != 0 { + let msg = self + .last_error() + .unwrap_or_else(|| format!("get_next returned {ret}")); + return Some(Err(ArrowError::CDataInterface(msg))); + } + if array.is_released() { + return None; + } + + let dt = DataType::Struct(self.schema.fields().clone()); + Some( + unsafe { from_ffi_and_data_type(array, dt) }.and_then(|mut data| { + data.align_buffers(); + let len = data.len(); + RecordBatch::try_new_with_options( + Arc::clone(&self.schema), + StructArray::from(data).into_parts().1, + &RecordBatchOptions::new().with_row_count(Some(len)), + ) + }), + ) + } +} + +fn read_schema(stream: &mut FFI_ArrowArrayStream) -> Result { + let mut schema = FFI_ArrowSchema::empty(); + let ret = unsafe { stream.get_schema.unwrap()(stream, &mut schema) }; 
+ if ret != 0 { + return Err(ArrowError::CDataInterface(format!( + "Cannot get schema from input stream. Error code: {ret}" + ))); + } + Ok(Arc::new(Schema::try_from(&schema)?)) +} diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs index 4b2c06575d..d68252bd9b 100644 --- a/native/core/src/execution/operators/mod.rs +++ b/native/core/src/execution/operators/mod.rs @@ -19,10 +19,12 @@ pub use crate::errors::ExecutionError; +pub use aligned_stream_reader::*; pub use copy::*; pub use iceberg_scan::*; pub use scan::*; +mod aligned_stream_reader; mod copy; mod expand; pub use expand::ExpandExec; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index e318d9e66b..2ef32f6a13 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -15,19 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::execution::operators::{copy_array, copy_or_unpack_array, CopyMode}; -use crate::{ - errors::CometError, - execution::{ - operators::ExecutionError, planner::TEST_EXEC_CONTEXT_ID, utils::SparkArrowConvert, - }, - jvm_bridge::JVMClasses, -}; -use arrow::array::{make_array, ArrayData, ArrayRef, RecordBatch, RecordBatchOptions}; +use crate::execution::operators::{copy_or_unpack_array, AlignedArrowStreamReader, CopyMode}; +use crate::{errors::CometError, execution::planner::TEST_EXEC_CONTEXT_ID}; +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; -use arrow::ffi::FFI_ArrowArray; -use arrow::ffi::FFI_ArrowSchema; use datafusion::common::{arrow_datafusion_err, DataFusionError, Result as DataFusionResult}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::metrics::{ @@ -40,8 +32,6 @@ use datafusion::{ }; use 
futures::Stream; use itertools::Itertools; -use jni::objects::{Global, JObject, JValue}; -use std::rc::Rc; use std::{ any::Any, pin::Pin, @@ -49,43 +39,34 @@ use std::{ task::{Context, Poll}, }; -/// ScanExec reads batches of data from Spark via JNI. The source of the scan could be a file -/// scan or the result of reading a broadcast or shuffle exchange. ScanExec isn't invoked -/// until the data is already available in the JVM. When CometExecIterator invokes -/// Native.executePlan, it passes in the memory addresses of the input batches. +/// `ScanExec` reads batches of data from Spark over the Arrow C Stream Interface. The +/// `input_source` is moved out of the JVM-exported `ArrowArrayStream` at plan-construction time; +/// dropping the reader (when this exec drops) fires the stream's release callback, which closes +/// the JVM-side `ArrowReader` and its `VectorSchemaRoot`. #[derive(Debug, Clone)] pub struct ScanExec { - /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM - /// environment `JNIEnv` from the execution context. + /// JVM execution-context id used to look up the `JNIEnv` for callbacks. pub exec_context_id: i64, - /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. - pub input_source: Option>>>, - /// A description of the input source for informational purposes + /// The C Stream Interface reader. `None` only in unit tests that seed input via + /// `set_input_batch`. + pub input_source: Option>>, pub input_source_description: String, - /// The data types of columns of the input batch. Converted from Spark schema. pub data_types: Vec, - /// Schema of first batch pub schema: SchemaRef, - /// The input batch of input data. Used to determine the schema of the input data. - /// It is also used in unit test to mock the input data from JVM. + /// Used in unit tests to mock the input batch; otherwise written by `pull_next` on each + /// poll. 
pub batch: Arc>>, - /// Cache of expensive-to-compute plan properties cache: Arc, - /// Metrics collector metrics: ExecutionPlanMetricsSet, - /// Baseline metrics baseline_metrics: BaselineMetrics, - /// Whether native code can assume ownership of batches that it receives - arrow_ffi_safe: bool, } impl ScanExec { pub fn new( exec_context_id: i64, - input_source: Option>>>, + input_source: Option>>, input_source_description: &str, data_types: Vec, - arrow_ffi_safe: bool, ) -> Result { let metrics_set = ExecutionPlanMetricsSet::default(); let baseline_metrics = BaselineMetrics::new(&metrics_set, 0); @@ -112,7 +93,6 @@ impl ScanExec { metrics: metrics_set, baseline_metrics, schema, - arrow_ffi_safe, }) } @@ -131,22 +111,18 @@ impl ScanExec { *self.batch.try_lock().unwrap() = Some(input); } - /// Pull next input batch from JVM. + /// Pull next input batch from the upstream `ArrowArrayStreamReader`. pub fn get_next_batch(&mut self) -> Result<(), CometError> { if self.input_source.is_none() { - // This is a unit test. We don't need to call JNI. + // This is a unit test. Input batches are seeded via `set_input_batch`. return Ok(()); } let mut timer = self.baseline_metrics.elapsed_compute().timer(); let mut current_batch = self.batch.try_lock().unwrap(); if current_batch.is_none() { - let next_batch = ScanExec::get_next( - self.exec_context_id, - self.input_source.as_ref().unwrap().as_obj(), - self.data_types.len(), - self.arrow_ffi_safe, - )?; + let next_batch = + ScanExec::pull_next(self.exec_context_id, self.input_source.as_ref().unwrap())?; *current_batch = Some(next_batch); } @@ -155,119 +131,35 @@ impl ScanExec { Ok(()) } - /// Invokes JNI call to get next batch. - fn get_next( + /// Pull the next `RecordBatch` from the stream and convert it to an `InputBatch`. Dictionary + /// columns are unpacked because Comet's downstream operators do not handle them. 
+ fn pull_next( exec_context_id: i64, - iter: &JObject, - num_cols: usize, - arrow_ffi_safe: bool, + reader: &Arc>, ) -> Result { if exec_context_id == TEST_EXEC_CONTEXT_ID { - // This is a unit test. We don't need to call JNI. + // Unit test path; input batches are seeded directly. return Ok(InputBatch::EOF); } - if iter.is_null() { - return Err(CometError::from(ExecutionError::GeneralError(format!( - "Null batch iterator object. Plan id: {exec_context_id}" - )))); - } - - JVMClasses::with_env(|env| { - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).has_next() -> i32)? - }; - - if num_rows == -1 { - return Ok(InputBatch::EOF); - } - - // fetch batch data from JVM via FFI - let (num_rows, array_addrs, schema_addrs) = - Self::allocate_and_fetch_batch(env, iter, num_cols)?; - - let mut inputs: Vec = Vec::with_capacity(num_cols); - - // Process each column - for i in 0..num_cols { - let array_ptr = array_addrs[i]; - let schema_ptr = schema_addrs[i]; - let array_data = ArrayData::from_spark((array_ptr, schema_ptr))?; - - // TODO: validate array input data - // array_data.validate_full()?; - - let array = make_array(array_data); - - let array = if arrow_ffi_safe { - // ownership of this array has been transferred to native - // but we still need to unpack dictionary arrays - copy_or_unpack_array(&array, &CopyMode::UnpackOrClone)? 
- } else { - // it is necessary to copy the array because the contents may be - // overwritten on the JVM side in the future - copy_array(&array) - }; - - inputs.push(array); - - // Drop the Arcs to avoid memory leak - unsafe { - Rc::from_raw(array_ptr as *const FFI_ArrowArray); - Rc::from_raw(schema_ptr as *const FFI_ArrowSchema); + let mut reader = reader + .try_lock() + .map_err(|_| CometError::Internal("AlignedArrowStreamReader contended".to_string()))?; + + let next = reader.next(); + match next { + None => Ok(InputBatch::EOF), + Some(Err(e)) => Err(CometError::from(e)), + Some(Ok(record_batch)) => { + let num_rows = record_batch.num_rows(); + let columns = record_batch.columns(); + let mut inputs: Vec = Vec::with_capacity(columns.len()); + for col in columns { + inputs.push(copy_or_unpack_array(col, &CopyMode::UnpackOrClone)?); } + Ok(InputBatch::new(inputs, Some(num_rows))) } - - Ok(InputBatch::new(inputs, Some(num_rows as usize))) - }) - } - - /// Allocates Arrow FFI structures and calls JNI to get the next batch data. - /// Returns the number of rows and the allocated array/schema addresses. 
- fn allocate_and_fetch_batch( - env: &mut jni::Env, - iter: &JObject, - num_cols: usize, - ) -> Result<(i32, Vec, Vec), CometError> { - let mut array_addrs = Vec::with_capacity(num_cols); - let mut schema_addrs = Vec::with_capacity(num_cols); - - for _ in 0..num_cols { - let arrow_array = Rc::new(FFI_ArrowArray::empty()); - let arrow_schema = Rc::new(FFI_ArrowSchema::empty()); - let (array_ptr, schema_ptr) = ( - Rc::into_raw(arrow_array) as i64, - Rc::into_raw(arrow_schema) as i64, - ); - - array_addrs.push(array_ptr); - schema_addrs.push(schema_ptr); } - - // Prepare the java array parameters - let long_array_addrs = env.new_long_array(num_cols)?; - let long_schema_addrs = env.new_long_array(num_cols)?; - - long_array_addrs.set_region(env, 0, &array_addrs)?; - long_schema_addrs.set_region(env, 0, &schema_addrs)?; - - let array_obj = JObject::from(long_array_addrs); - let schema_obj = JObject::from(long_schema_addrs); - - let array_obj = JValue::Object(array_obj.as_ref()); - let schema_obj = JValue::Object(schema_obj.as_ref()); - - let num_rows: i32 = unsafe { - jni_call!(env, - comet_batch_iterator(iter).next(array_obj, schema_obj) -> i32)? 
- }; - - // we already checked for end of results on call to has_next() so should always - // have a valid row count when calling next() - assert!(num_rows != -1); - - Ok((num_rows, array_addrs, schema_addrs)) } } diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index c6160bddd4..c09ed5a0ef 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -21,7 +21,9 @@ pub mod expression_registry; pub mod macros; pub mod operator_registry; +use crate::errors::CometError; use crate::execution::operators::init_csv_datasource_exec; +use crate::execution::operators::AlignedArrowStreamReader; use crate::execution::operators::IcebergScanExec; use crate::execution::{ expressions::list_positions::ListPositionsExpr, @@ -30,8 +32,9 @@ use crate::execution::{ planner::expression_registry::ExpressionRegistry, planner::operator_registry::OperatorRegistry, serde::to_arrow_datatype, - shuffle::ShuffleWriterExec, + shuffle::{SchemaAlignExec, ShuffleWriterExec}, }; +use crate::jvm_bridge::{jni_call, JVMClasses}; use arrow::compute::CastOptions; use arrow::datatypes::{DataType, Field, FieldRef, Schema, TimeUnit, DECIMAL128_MAX_PRECISION}; use datafusion::functions_aggregate::bit_and_or_xor::{bit_and_udaf, bit_or_udaf, bit_xor_udaf}; @@ -1161,45 +1164,12 @@ impl PhysicalPlanner { Arc::clone(&schema), )?, ); - let result_exprs: PhyExprResult = agg - .result_exprs - .iter() - .enumerate() - .map(|(idx, expr)| { - self.create_expr(expr, aggregate.schema()) - .map(|r| (r, format!("col_{idx}"))) - }) - .collect(); - if agg.result_exprs.is_empty() { - Ok(( - scans, - shuffle_scans, - Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), - )) - } else { - // For final aggregation, DF's hash aggregate exec doesn't support Spark's - // aggregate result expressions like `COUNT(col) + 1`, but instead relying - // on additional `ProjectionExec` to handle the case. 
Therefore, here we'll - // add a projection node on top of the aggregate node. - // - // Note that `result_exprs` should only be set for final aggregation on the - // Spark side. - let projection = Arc::new(ProjectionExec::try_new( - result_exprs?, - Arc::clone(&aggregate), - )?); - Ok(( - scans, - shuffle_scans, - Arc::new(SparkPlan::new_with_additional( - spark_plan.plan_id, - projection, - vec![child], - vec![aggregate], - )), - )) - } + Ok(( + scans, + shuffle_scans, + Arc::new(SparkPlan::new(spark_plan.plan_id, aggregate, vec![child])), + )) } OpStruct::Limit(limit) => { assert_eq!(children.len(), 1); @@ -1447,23 +1417,36 @@ impl PhysicalPlanner { return Err(GeneralError("No input for scan".to_string())); } - // Consumes the first input source for the scan - let input_source = - if self.exec_context_id == TEST_EXEC_CONTEXT_ID && inputs.is_empty() { - // For unit test, we will set input batch to scan directly by `set_input_batch`. - None - } else { - Some(inputs.remove(0)) - }; + // Consumes the first input source for the scan. The Java side passes an + // `org.apache.arrow.c.ArrowArrayStream` whose `memoryAddress` points at the C + // struct; native takes ownership via `AlignedArrowStreamReader::from_raw`. + let input_source = if self.exec_context_id == TEST_EXEC_CONTEXT_ID + && inputs.is_empty() + { + // For unit test, we will set input batch to scan directly by `set_input_batch`. + None + } else { + let java_stream = inputs.remove(0); + let address: i64 = JVMClasses::with_env(|env| -> Result { + let addr = unsafe { + jni_call!(env, arrow_array_stream(java_stream.as_obj()).memory_address() -> i64)? 
+ }; + Ok(addr) + })?; + let reader = unsafe { + AlignedArrowStreamReader::from_raw( + address as *mut arrow::ffi_stream::FFI_ArrowArrayStream, + ) + } + .map_err(|e| { + GeneralError(format!("Failed to import ArrowArrayStream from JVM: {e}")) + })?; + Some(Arc::new(std::sync::Mutex::new(reader))) + }; // The `ScanExec` operator will take actual arrays from Spark during execution - let scan = ScanExec::new( - self.exec_context_id, - input_source, - &scan.source, - data_types, - scan.arrow_ffi_safe, - )?; + let scan = + ScanExec::new(self.exec_context_id, input_source, &scan.source, data_types)?; Ok(( vec![scan.clone()], @@ -1515,9 +1498,14 @@ impl PhysicalPlanner { let (scans, shuffle_scans, child) = self.create_plan(&children[0], inputs, partition_count)?; + let writer_input = align_shuffle_writer_input( + Arc::clone(&child.native_plan), + &writer.expected_output_schema, + )?; + let partitioning = self.create_partitioning( writer.partitioning.as_ref().unwrap(), - child.native_plan.schema(), + writer_input.schema(), )?; let codec = match writer.codec.try_into() { @@ -1535,7 +1523,7 @@ impl PhysicalPlanner { let write_buffer_size = writer.write_buffer_size as usize; let shuffle_writer = Arc::new(ShuffleWriterExec::try_new( - Arc::clone(&child.native_plan), + writer_input, partitioning, codec, writer.output_data_file.clone(), @@ -3124,6 +3112,20 @@ fn convert_spark_types_to_arrow_schema( arrow_schema } +/// Wrap `child` in a `SchemaAlignExec` when its output drifts from what Spark catalyst +/// declared. See . +fn align_shuffle_writer_input( + child: Arc, + expected_proto: &[spark_operator::SparkStructField], +) -> Result, ExecutionError> { + if expected_proto.is_empty() { + return Ok(child); + } + let expected = convert_spark_types_to_arrow_schema(expected_proto); + SchemaAlignExec::try_new_or_passthrough(child, &expected) + .map_err(|e| ExecutionError::DataFusionError(e.to_string())) +} + /// Converts a protobuf PartitionValue to an iceberg Literal. 
/// fn partition_value_to_literal( @@ -3982,7 +3984,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4048,7 +4049,6 @@ mod tests { type_info: None, }], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4258,7 +4258,6 @@ mod tests { op_struct: Some(OpStruct::Scan(spark_operator::Scan { fields: vec![create_proto_datatype()], source: "".to_string(), - arrow_ffi_safe: false, })), } } @@ -4301,7 +4300,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4424,7 +4422,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; @@ -4907,7 +4904,6 @@ mod tests { }, ], source: "".to_string(), - arrow_ffi_safe: false, })), }; diff --git a/native/core/src/execution/utils.rs b/native/core/src/execution/utils.rs index 2fe6f8758f..6195e3f0ae 100644 --- a/native/core/src/execution/utils.rs +++ b/native/core/src/execution/utils.rs @@ -19,48 +19,15 @@ use crate::execution::operators::ExecutionError; use arrow::{ array::ArrayData, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{FFI_ArrowArray, FFI_ArrowSchema}, }; pub trait SparkArrowConvert { - /// Build Arrow Arrays from C data interface passed from Spark. - /// It accepts a tuple (ArrowArray address, ArrowSchema address). - fn from_spark(addresses: (i64, i64)) -> Result - where - Self: Sized; - /// Move Arrow Arrays to C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError>; } impl SparkArrowConvert for ArrayData { - fn from_spark(addresses: (i64, i64)) -> Result { - let (array_ptr, schema_ptr) = addresses; - - let array_ptr = array_ptr as *mut FFI_ArrowArray; - let schema_ptr = schema_ptr as *mut FFI_ArrowSchema; - - if array_ptr.is_null() || schema_ptr.is_null() { - return Err(ExecutionError::ArrowError( - "At least one of passed pointers is null".to_string(), - )); - }; - - // `ArrowArray` will convert raw pointers back to `Arc`. No worries - // about memory leak. 
- let mut ffi_array = unsafe { - let array_data = std::ptr::replace(array_ptr, FFI_ArrowArray::empty()); - let schema_data = std::ptr::replace(schema_ptr, FFI_ArrowSchema::empty()); - - from_ffi(array_data, &schema_data)? - }; - - // Align imported buffers from Java. - ffi_array.align_buffers(); - - Ok(ffi_array) - } - /// Move this ArrowData to pointers of Arrow C data interface. fn move_to_spark(&self, array: i64, schema: i64) -> Result<(), ExecutionError> { let array_ptr = array as *mut FFI_ArrowArray; diff --git a/native/jni-bridge/src/batch_iterator.rs b/native/jni-bridge/src/arrow_array_stream.rs similarity index 57% rename from native/jni-bridge/src/batch_iterator.rs rename to native/jni-bridge/src/arrow_array_stream.rs index addda133fa..2cfea73688 100644 --- a/native/jni-bridge/src/batch_iterator.rs +++ b/native/jni-bridge/src/arrow_array_stream.rs @@ -15,45 +15,38 @@ // specific language governing permissions and limitations // under the License. -use jni::signature::Primitive; use jni::{ errors::Result as JniResult, objects::{JClass, JMethodID}, - signature::ReturnType, + signature::{Primitive, ReturnType}, strings::JNIString, Env, }; -/// A struct that holds all the JNI methods and fields for JVM `CometBatchIterator` class. +/// A struct that holds all the JNI methods and fields for JVM `org.apache.arrow.c.ArrowArrayStream` +/// class. `memoryAddress()` is read once per partition so native can take ownership of the +/// underlying C struct via `AlignedArrowStreamReader::from_raw`. 
#[allow(dead_code)] // we need to keep references to Java items to prevent GC -pub struct CometBatchIterator<'a> { +pub struct ArrowArrayStream<'a> { pub class: JClass<'a>, - pub method_has_next: JMethodID, - pub method_has_next_ret: ReturnType, - pub method_next: JMethodID, - pub method_next_ret: ReturnType, + pub method_memory_address: JMethodID, + pub method_memory_address_ret: ReturnType, } -impl<'a> CometBatchIterator<'a> { - pub const JVM_CLASS: &'static str = "org/apache/comet/CometBatchIterator"; +impl<'a> ArrowArrayStream<'a> { + pub const JVM_CLASS: &'static str = "org/apache/arrow/c/ArrowArrayStream"; - pub fn new(env: &mut Env<'a>) -> JniResult> { + pub fn new(env: &mut Env<'a>) -> JniResult> { let class = env.find_class(JNIString::new(Self::JVM_CLASS))?; - Ok(CometBatchIterator { - class, - method_has_next: env.get_method_id( - JNIString::new(Self::JVM_CLASS), - jni::jni_str!("hasNext"), - jni::jni_sig!("()I"), - )?, - method_has_next_ret: ReturnType::Primitive(Primitive::Int), - method_next: env.get_method_id( + Ok(ArrowArrayStream { + method_memory_address: env.get_method_id( JNIString::new(Self::JVM_CLASS), - jni::jni_str!("next"), - jni::jni_sig!("([J[J)I"), + jni::jni_str!("memoryAddress"), + jni::jni_sig!("()J"), )?, - method_next_ret: ReturnType::Primitive(Primitive::Int), + method_memory_address_ret: ReturnType::Primitive(Primitive::Long), + class, }) } } diff --git a/native/jni-bridge/src/lib.rs b/native/jni-bridge/src/lib.rs index 490e80d076..8db4c07851 100644 --- a/native/jni-bridge/src/lib.rs +++ b/native/jni-bridge/src/lib.rs @@ -189,14 +189,14 @@ impl<'a> TryFrom> for BinaryWrapper<'a> { mod comet_exec; pub use comet_exec::*; -mod batch_iterator; +mod arrow_array_stream; mod comet_metric_node; mod comet_s3_credential_dispatcher; mod comet_task_memory_manager; mod comet_udf_bridge; mod shuffle_block_iterator; -use batch_iterator::CometBatchIterator; +use arrow_array_stream::ArrowArrayStream; pub use comet_metric_node::*; pub use 
comet_s3_credential_dispatcher::CometS3CredentialDispatcher; pub use comet_task_memory_manager::*; @@ -225,8 +225,9 @@ pub struct JVMClasses<'a> { pub comet_metric_node: CometMetricNode<'a>, /// The static CometExec class. Used for getting the subquery result. pub comet_exec: CometExec<'a>, - /// The CometBatchIterator class. Used for iterating over the batches. - pub comet_batch_iterator: CometBatchIterator<'a>, + /// The org.apache.arrow.c.ArrowArrayStream class. Used to get the C struct memory address + /// when importing a JVM-exported batch stream into native code. + pub arrow_array_stream: ArrowArrayStream<'a>, /// The CometShuffleBlockIterator class. Used for iterating over shuffle blocks. pub comet_shuffle_block_iterator: CometShuffleBlockIterator<'a>, /// The CometTaskMemoryManager used for interacting with JVM side to @@ -304,7 +305,7 @@ impl JVMClasses<'_> { throwable_get_cause_method, comet_metric_node: CometMetricNode::new(env).unwrap(), comet_exec: CometExec::new(env).unwrap(), - comet_batch_iterator: CometBatchIterator::new(env).unwrap(), + arrow_array_stream: ArrowArrayStream::new(env).unwrap(), comet_shuffle_block_iterator: CometShuffleBlockIterator::new(env).unwrap(), comet_task_memory_manager: CometTaskMemoryManager::new(env).unwrap(), comet_udf_bridge: { diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 7f50aa928c..a180d71309 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -85,8 +85,6 @@ message Scan { // is purely for informational purposes when viewing native query plans in // debug mode. 
string source = 2; - // Whether native code can assume ownership of batches that it receives - bool arrow_ffi_safe = 3; } message ShuffleScan { @@ -294,7 +292,10 @@ message Sort { message HashAggregate { repeated spark.spark_expression.Expr grouping_exprs = 1; repeated spark.spark_expression.AggExpr agg_exprs = 2; - repeated spark.spark_expression.Expr result_exprs = 3; + // Was result_exprs / apply_result_projection; now expressed as an explicit Projection + // op above HashAggregate (see CometBaseAggregate.doConvert, comet#4515). + reserved 3, 8; + reserved "result_exprs", "apply_result_projection"; AggregateMode mode = 5; // Per-expression modes for mixed-mode aggregates (e.g., PartialMerge + Partial). // When set, each entry corresponds to agg_exprs at the same index. @@ -327,6 +328,12 @@ message ShuffleWriter { // Size of the write buffer in bytes used when writing shuffle data to disk. // Larger values may improve write performance but use more memory. int32 write_buffer_size = 8; + // Spark-declared output schema of the writer's child. When the child is an inlined native + // subtree, the native planner casts the child's actual output to this schema before + // serializing to shuffle blocks, since there is no FFI boundary or ScanExec between them + // to absorb DataFusion-vs-Spark type drift. Empty when the child is a placeholder Scan; + // that path already has a cast point upstream. 
+ repeated SparkStructField expected_output_schema = 9; } message ParquetWriter { diff --git a/native/shuffle/src/lib.rs b/native/shuffle/src/lib.rs index dd3b900272..2263ae0dac 100644 --- a/native/shuffle/src/lib.rs +++ b/native/shuffle/src/lib.rs @@ -19,6 +19,7 @@ pub(crate) mod comet_partitioning; pub mod ipc; pub(crate) mod metrics; pub(crate) mod partitioners; +mod schema_align; mod shuffle_writer; mod spark_crc32c_hasher; pub mod spark_unsafe; @@ -26,5 +27,6 @@ pub(crate) mod writers; pub use comet_partitioning::CometPartitioning; pub use ipc::read_ipc_compressed; +pub use schema_align::SchemaAlignExec; pub use shuffle_writer::ShuffleWriterExec; pub use writers::{CompressionCodec, ShuffleBlockWriter}; diff --git a/native/shuffle/src/partitioners/multi_partition.rs b/native/shuffle/src/partitioners/multi_partition.rs index 7de9314f54..40f09496c0 100644 --- a/native/shuffle/src/partitioners/multi_partition.rs +++ b/native/shuffle/src/partitioners/multi_partition.rs @@ -22,10 +22,10 @@ use crate::partitioners::partitioned_batch_iterator::{ use crate::partitioners::ShufflePartitioner; use crate::writers::{BufBatchWriter, PartitionWriter}; use crate::{comet_partitioning, CometPartitioning, CompressionCodec, ShuffleBlockWriter}; -use arrow::array::{ArrayRef, RecordBatch}; +use arrow::array::{Array, ArrayData, ArrayRef, RecordBatch}; use arrow::datatypes::SchemaRef; use datafusion::common::utils::proxy::VecAllocExt; -use datafusion::common::DataFusionError; +use datafusion::common::{DataFusionError, HashSet}; use datafusion::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::metrics::Time; @@ -125,6 +125,55 @@ pub(crate) struct MultiPartitionShuffleRepartitioner { tracing_enabled: bool, /// Size of the write buffer in bytes write_buffer_size: usize, + /// Start addresses (as `usize`, since raw pointers are not `Send`) of the backing buffers + /// currently pinned by 
`buffered_batches`, so the spill reservation charges each distinct + /// allocation once rather than once per slice that references it. Cleared whenever the + /// buffered batches drain (spill / shuffle_write). See `count_new_buffers`. + pinned_buffers: HashSet, +} + +/// Sum of the capacities of the backing buffers reachable from `batch` whose start address is +/// not already in `seen` (recursing through child data: dictionary values, list children, and so +/// on). `seen` is kept across every buffered batch, so this returns the bytes a batch newly +/// pins, which is the memory the shuffle writer holds resident by buffering it. +/// +/// Cheaper measures do not match resident memory for the batches this writer sees. A partial +/// `HashAggregate` emits one group-values buffer sliced into batch_size chunks, and every +/// buffered chunk shares that one allocation: +/// +/// * `RecordBatch::get_array_memory_size()` charges a buffer's capacity once per array that +/// references it, counting the shared allocation once per chunk and overstating memory by the +/// chunk count. Reserving against that figure trips the memory limit on nearly every batch +/// and spills spuriously. +/// * the sum of `ArrayData::get_slice_memory_size()` charges only the live rows of each slice, +/// but holding a slice pins its whole backing allocation. The group-values `Vec` rounds +/// capacity up to the next power of two, so that figure undercounts resident memory and lets +/// the writer hold well past its limit before spilling. +/// +/// Counting each distinct allocation once, keyed by start address, is the measure that tracks +/// resident memory regardless of how arrays share or slice their buffers. 
+fn count_new_buffers(batch: &RecordBatch, seen: &mut HashSet) -> usize { + fn visit(data: &ArrayData, seen: &mut HashSet, total: &mut usize) { + for buffer in data.buffers() { + if seen.insert(buffer.data_ptr().as_ptr() as usize) { + *total += buffer.capacity(); + } + } + if let Some(nulls) = data.nulls() { + let inner = nulls.inner().inner(); + if seen.insert(inner.data_ptr().as_ptr() as usize) { + *total += inner.capacity(); + } + } + for child in data.child_data() { + visit(child, seen, total); + } + } + let mut total = 0; + for column in batch.columns() { + visit(&column.to_data(), seen, &mut total); + } + total } impl MultiPartitionShuffleRepartitioner { @@ -190,6 +239,7 @@ impl MultiPartitionShuffleRepartitioner { reservation, tracing_enabled, write_buffer_size, + pinned_buffers: HashSet::new(), }) } @@ -210,9 +260,6 @@ impl MultiPartitionShuffleRepartitioner { )); } - // Update data size metric - self.metrics.data_size.add(input.get_array_memory_size()); - // NOTE: in shuffle writer exec, the output_rows metrics represents the // number of rows those are written to output data file. self.metrics.baseline.record_output(input.num_rows()); @@ -398,7 +445,11 @@ impl MultiPartitionShuffleRepartitioner { partition_row_indices: &[u32], partition_starts: &[u32], ) -> datafusion::common::Result<()> { - let mut mem_growth: usize = input.get_array_memory_size(); + // Charge both the reservation and the data_size metric for the buffers this batch newly + // pins; `count_new_buffers` dedups buffers shared across already-buffered batches. 
+ let new_buffer_bytes = count_new_buffers(&input, &mut self.pinned_buffers); + self.metrics.data_size.add(new_buffer_bytes); + let mut mem_growth: usize = new_buffer_bytes; let buffered_partition_idx = self.buffered_batches.len() as u32; self.buffered_batches.push(input); @@ -517,6 +568,7 @@ impl MultiPartitionShuffleRepartitioner { } self.reservation.free(); + self.pinned_buffers.clear(); self.metrics.spill_count.add(1); self.metrics.spilled_bytes.add(spilled_bytes); Ok(()) @@ -560,6 +612,7 @@ impl ShufflePartitioner for MultiPartitionShuffleRepartitioner { let start_time = Instant::now(); let mut partitioned_batches = self.partitioned_batches(); + self.pinned_buffers.clear(); let num_output_partitions = self.partition_indices.len(); let mut offsets = vec![0; num_output_partitions + 1]; diff --git a/native/shuffle/src/schema_align.rs b/native/shuffle/src/schema_align.rs new file mode 100644 index 0000000000..f564afc024 --- /dev/null +++ b/native/shuffle/src/schema_align.rs @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! `SchemaAlignExec` reshapes a shuffle writer's input so each column's Arrow type and field-level +//! nullability match what Spark catalyst declared, casting where necessary. 
+//! +//! This concern is enclosed by shuffle on purpose: everywhere else in the native runtime, +//! return-type drift from DataFusion / `datafusion-spark` is self-healing. When a native plan's +//! output crosses back to the JVM and feeds another native plan, the consuming `ScanExec` casts +//! every imported column to the catalyst-declared type, so a wrong Arrow type never survives the +//! boundary. Shuffle is the lone exception, on two counts: +//! +//! 1. The writer hash-partitions on these columns, and Spark's hash differs by type (e.g. `Int32` +//! vs `Int64`), so a drifted type would route rows to the wrong partition. A read-side cast +//! cannot undo a wrong partition assignment, so the type must be corrected before partitioning. +//! 2. The shuffle read path (`ShuffleScanExec`) does not cast; it stamps the catalyst schema onto +//! the decoded block and errors on any mismatch. The schema is serialized into the block on +//! write and trusted on read. +//! +//! Both force the alignment to happen on the writer input. See +//! <https://github.com/apache/datafusion-comet/issues/4515> for the running list of mismatched +//! functions. + +use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions}; +use arrow::compute::{cast_with_options, CastOptions}; +use arrow::datatypes::{Field, Schema, SchemaRef}; +use datafusion::common::DataFusionError; +use datafusion::physical_expr::EquivalenceProperties; +use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; +use datafusion::{ + execution::TaskContext, + physical_plan::{ + DisplayAs, DisplayFormatType, ExecutionPlan, ExecutionPlanProperties, PlanProperties, + RecordBatchStream, SendableRecordBatchStream, + }, +}; +use futures::{Stream, StreamExt}; +use std::{ + any::Any, + collections::HashSet, + pin::Pin, + sync::{Arc, Mutex, OnceLock}, + task::{Context, Poll}, +}; + +/// Process-wide set of `(column, actual, expected)` signatures we have already warned about. 
+/// Each schema drift produces the same warning on every partition of every query that runs +/// the offending expression; deduping here keeps logs readable while still surfacing each +/// distinct mismatch once. +fn warn_dedup() -> &'static Mutex> { + static SET: OnceLock>> = OnceLock::new(); + SET.get_or_init(|| Mutex::new(HashSet::new())) +} + +/// Casts each column of `child`'s output to the data_type Spark catalyst declared, widening +/// nullability to `actual.nullable || expected.nullable`. See +/// <https://github.com/apache/datafusion-comet/issues/4515>. +#[derive(Debug)] +pub struct SchemaAlignExec { + child: Arc, + target_schema: SchemaRef, + column_actions: Arc>, + cache: Arc, +} + +#[derive(Debug, Clone)] +enum ColumnAction { + /// Pass the input column through unchanged. Any nullability/metadata difference is + /// absorbed when the batch is re-stamped via `RecordBatch::try_new_with_options`. + Passthrough, + /// Cast the input column to the target data_type. + Cast, +} + +impl SchemaAlignExec { + /// Build a SchemaAlignExec that aligns `child`'s output to `expected`. Returns + /// `Ok(child)` unchanged when no per-column reshape is needed; otherwise wraps `child` + /// in a SchemaAlignExec whose target schema preserves `expected`'s data_type and metadata + /// but widens nullability to `actual.nullable || expected.nullable`. 
+ pub fn try_new_or_passthrough( + child: Arc, + expected: &SchemaRef, + ) -> Result, DataFusionError> { + let actual = child.schema(); + if actual.fields().len() != expected.fields().len() { + return Err(DataFusionError::Plan(format!( + "SchemaAlignExec: expected {} fields, child produces {}", + expected.fields().len(), + actual.fields().len() + ))); + } + let mut needs_alignment = false; + let mut actions = Vec::with_capacity(actual.fields().len()); + let mut target_fields = Vec::with_capacity(actual.fields().len()); + for (idx, (actual_field, expected_field)) in actual + .fields() + .iter() + .zip(expected.fields().iter()) + .enumerate() + { + let action = if actual_field.data_type() == expected_field.data_type() { + ColumnAction::Passthrough + } else { + let signature = format!( + "{}|{:?}|{:?}", + expected_field.name(), + actual_field.data_type(), + expected_field.data_type() + ); + if warn_dedup().lock().unwrap().insert(signature) { + log::warn!( + "ShuffleWriter input schema mismatch on col[{idx}] '{}': child produced \ + {:?}, catalyst declared {:?}. 
Inserting a cast; please file the upstream \ + function bug at https://github.com/apache/datafusion-comet/issues/4515.", + expected_field.name(), + actual_field.data_type(), + expected_field.data_type() + ); + } + ColumnAction::Cast + }; + let target_nullable = actual_field.is_nullable() || expected_field.is_nullable(); + let field_changed = !matches!(action, ColumnAction::Passthrough) + || target_nullable != actual_field.is_nullable() + || expected_field.metadata() != actual_field.metadata() + || expected_field.name() != actual_field.name(); + if field_changed { + needs_alignment = true; + } + target_fields.push( + Field::new( + expected_field.name(), + expected_field.data_type().clone(), + target_nullable, + ) + .with_metadata(expected_field.metadata().clone()), + ); + actions.push(action); + } + if !needs_alignment { + return Ok(child); + } + let target_schema: SchemaRef = Arc::new(Schema::new(target_fields)); + let cache = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&target_schema)), + child.output_partitioning().clone(), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(Arc::new(Self { + child, + target_schema, + column_actions: Arc::new(actions), + cache, + })) + } +} + +impl DisplayAs for SchemaAlignExec { + fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match t { + DisplayFormatType::Default | DisplayFormatType::Verbose => { + write!(f, "CometSchemaAlignExec") + } + DisplayFormatType::TreeRender => unimplemented!(), + } + } +} + +impl ExecutionPlan for SchemaAlignExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + Arc::clone(&self.target_schema) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion::common::Result> { + // Rebuild PlanProperties from the new child since `output_partitioning` may differ. 
+ let new_child = Arc::clone(&children[0]); + let cache = Arc::new(PlanProperties::new( + EquivalenceProperties::new(Arc::clone(&self.target_schema)), + new_child.output_partitioning().clone(), + EmissionType::Incremental, + Boundedness::Bounded, + )); + Ok(Arc::new(Self { + child: new_child, + target_schema: Arc::clone(&self.target_schema), + column_actions: Arc::clone(&self.column_actions), + cache, + })) + } + + fn execute( + &self, + partition: usize, + context: Arc, + ) -> datafusion::common::Result { + let child_stream = self.child.execute(partition, context)?; + Ok(Box::pin(SchemaAlignStream { + child_stream, + target_schema: Arc::clone(&self.target_schema), + column_actions: Arc::clone(&self.column_actions), + })) + } + + fn properties(&self) -> &Arc { + &self.cache + } + + fn name(&self) -> &str { + "CometSchemaAlignExec" + } +} + +struct SchemaAlignStream { + child_stream: SendableRecordBatchStream, + target_schema: SchemaRef, + column_actions: Arc>, +} + +impl SchemaAlignStream { + fn align(&self, batch: RecordBatch) -> Result { + let mut columns: Vec = Vec::with_capacity(batch.num_columns()); + for (idx, action) in self.column_actions.iter().enumerate() { + let column = batch.column(idx); + let aligned = match action { + ColumnAction::Passthrough => Arc::clone(column), + ColumnAction::Cast => cast_with_options( + column, + self.target_schema.field(idx).data_type(), + &CastOptions::default(), + )?, + }; + columns.push(aligned); + } + let options = RecordBatchOptions::new().with_row_count(Some(batch.num_rows())); + RecordBatch::try_new_with_options(Arc::clone(&self.target_schema), columns, &options) + .map_err(DataFusionError::from) + } +} + +impl Stream for SchemaAlignStream { + type Item = datafusion::common::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.child_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => Poll::Ready(Some(self.align(batch))), + other => other, + } + } +} + +impl 
RecordBatchStream for SchemaAlignStream { + fn schema(&self) -> SchemaRef { + Arc::clone(&self.target_schema) + } +} diff --git a/native/shuffle/src/shuffle_writer.rs b/native/shuffle/src/shuffle_writer.rs index 8502c79624..7b6f4ca7e2 100644 --- a/native/shuffle/src/shuffle_writer.rs +++ b/native/shuffle/src/shuffle_writer.rs @@ -29,7 +29,7 @@ use datafusion::physical_expr::{EquivalenceProperties, Partitioning}; use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType}; use datafusion::physical_plan::EmptyRecordBatchStream; use datafusion::{ - arrow::{datatypes::SchemaRef, error::ArrowError}, + arrow::datatypes::SchemaRef, error::Result, execution::context::TaskContext, physical_plan::{ @@ -38,7 +38,7 @@ use datafusion::{ DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, SendableRecordBatchStream, }, }; -use futures::{StreamExt, TryFutureExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use std::{ any::Any, fmt, @@ -171,23 +171,23 @@ impl ExecutionPlan for ShuffleWriterExec { let input = self.input.execute(partition, Arc::clone(&context))?; let metrics = ShufflePartitionerMetrics::new(&self.metrics, 0); + // Propagate DataFusionError unchanged: the JNI bridge only downcasts a single + // `DataFusionError::External(SparkError)` layer, so any extra wrap here loses the + // typed exception (e.g. SparkArithmeticException on decimal overflow). 
Ok(Box::pin(RecordBatchStreamAdapter::new( self.schema(), - futures::stream::once( - external_shuffle( - input, - partition, - self.output_data_file.clone(), - self.output_index_file.clone(), - self.partitioning.clone(), - metrics, - context, - self.codec.clone(), - self.tracing_enabled, - self.write_buffer_size, - ) - .map_err(|e| ArrowError::ExternalError(Box::new(e))), - ) + futures::stream::once(external_shuffle( + input, + partition, + self.output_data_file.clone(), + self.output_index_file.clone(), + self.partitioning.clone(), + metrics, + context, + self.codec.clone(), + self.tracing_enabled, + self.write_buffer_size, + )) .try_flatten(), ))) } @@ -267,7 +267,7 @@ async fn external_shuffle( mod test { use super::*; use crate::{read_ipc_compressed, ShuffleBlockWriter}; - use arrow::array::{Array, StringArray, StringBuilder}; + use arrow::array::{Array, Int64Array, StringArray, StringBuilder}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use arrow::row::{RowConverter, SortField}; @@ -389,6 +389,60 @@ mod test { repartitioner.insert_batch(batch.clone()).await.unwrap(); } + #[tokio::test] + async fn shuffle_partitioner_charges_shared_buffer_once() { + // `insert_batch` slices a large batch into batch_size chunks that all share one backing + // buffer (the shape a partial HashAggregate's sliced emit hands the writer). The + // reservation and the data_size metric must charge that buffer once, not once per chunk; + // otherwise the chunk count multiplies it and a batch well under the memory limit spills + // spuriously and reports a wildly inflated data_size. 
+ let n = 16_384usize; + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int64, false)])); + let backing = RecordBatch::try_new( + Arc::clone(&schema), + vec![Arc::new(Int64Array::from_iter_values(0..n as i64))], + ) + .unwrap(); + let buffer_bytes = backing.get_array_memory_size(); + + let memory_limit = 512 * 1024; + let batch_size = 1024; // 16 chunks, all sharing the one backing buffer + let num_partitions = 2; + let runtime_env = create_runtime(memory_limit); + let metrics_set = ExecutionPlanMetricsSet::new(); + let metrics = ShufflePartitionerMetrics::new(&metrics_set, 0); + let data_size = metrics.data_size.clone(); + let spill_count = metrics.spill_count.clone(); + let dir = tempfile::tempdir().unwrap(); + let mut repartitioner = MultiPartitionShuffleRepartitioner::try_new( + 0, + dir.path().join("data.out").to_str().unwrap().to_string(), + dir.path().join("index.out").to_str().unwrap().to_string(), + backing.schema(), + CometPartitioning::Hash(vec![Arc::new(Column::new("a", 0))], num_partitions), + metrics, + runtime_env, + batch_size, + CompressionCodec::Lz4Frame, + false, + 1024 * 1024, + ) + .unwrap(); + + repartitioner.insert_batch(backing).await.unwrap(); + + assert!( + data_size.value() <= 2 * buffer_bytes, + "data_size {} should charge the shared buffer about once (~{buffer_bytes} bytes), not per chunk", + data_size.value() + ); + assert_eq!( + spill_count.value(), + 0, + "one buffer under the memory limit must not spill once per chunk" + ); + } + fn create_runtime(memory_limit: usize) -> Arc { Arc::new( RuntimeEnvBuilder::new() diff --git a/spark/pom.xml b/spark/pom.xml index d1613460ba..7dd3c6fe33 100644 --- a/spark/pom.xml +++ b/spark/pom.xml @@ -469,14 +469,13 @@ under the License. 
org.apache.arrow ${comet.shade.packageName}.arrow - - org/apache/arrow/c/jni/JniWrapper - org/apache/arrow/c/jni/PrivateData - org/apache/arrow/c/jni/CDataJniException - - org/apache/arrow/c/ArrayStreamExporter$ExportedArrayStreamPrivateData + + org/apache/arrow/c/** diff --git a/spark/src/main/java/org/apache/comet/CometBatchIterator.java b/spark/src/main/java/org/apache/comet/CometBatchIterator.java deleted file mode 100644 index 9b48a47c57..0000000000 --- a/spark/src/main/java/org/apache/comet/CometBatchIterator.java +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.comet; - -import scala.collection.Iterator; - -import org.apache.spark.sql.vectorized.ColumnarBatch; - -import org.apache.comet.vector.NativeUtil; - -/** - * Iterator for fetching batches from JVM to native code. Usually called via JNI from native - * ScanExec. - * - *

Batches are owned by the JVM. Native code can safely access the batch after calling `next` but - * the native code must not retain references to the batch because the next call to `hasNext` will - * signal to the JVM that the batch can be closed. - */ -public class CometBatchIterator { - private final Iterator input; - private final NativeUtil nativeUtil; - private ColumnarBatch previousBatch = null; - private ColumnarBatch currentBatch = null; - - CometBatchIterator(Iterator input, NativeUtil nativeUtil) { - this.input = input; - this.nativeUtil = nativeUtil; - } - - /** - * Fetch the next input batch and allow the previous batch to be closed (this may not happen - * immediately). - * - * @return Number of rows in next batch or -1 if no batches left. - */ - public int hasNext() { - - // release reference to previous batch - previousBatch = null; - - if (currentBatch == null) { - if (input.hasNext()) { - currentBatch = input.next(); - } - } - if (currentBatch == null) { - return -1; - } else { - return currentBatch.numRows(); - } - } - - /** - * Get the next batch of Arrow arrays. - * - * @param arrayAddrs The addresses of the ArrowArray structures. - * @param schemaAddrs The addresses of the ArrowSchema structures. - * @return the number of rows of the current batch. -1 if there is no more batch. 
- */ - public int next(long[] arrayAddrs, long[] schemaAddrs) { - if (currentBatch == null) { - return -1; - } - - // export the batch using the Arrow C Data Interface - int numRows = nativeUtil.exportBatch(arrayAddrs, schemaAddrs, currentBatch); - - // keep a reference to the exported batch so that it doesn't get garbage collected - // while the native code is still processing it - previousBatch = currentBatch; - - currentBatch = null; - - return numRows; - } -} diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 6140eca553..c684e17a92 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -47,8 +47,11 @@ import org.apache.comet.vector.NativeUtil * `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is * done). * - * @param inputs - * The input iterators producing sequence of batches of Arrow Arrays. + * @param inputObjects + * Already-built native input slots, in scan-input order. Each slot is either an + * org.apache.arrow.c.ArrowArrayStream (consumed natively via from_raw against its + * memoryAddress) or a CometShuffleBlockIterator (consumed via the JNI block-iteration + * protocol). * @param protobufQueryPlan * The serialized bytes of Spark execution plan. 
* @param numParts @@ -60,7 +63,7 @@ import org.apache.comet.vector.NativeUtil */ class CometExecIterator( val id: Long, - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode, @@ -79,14 +82,6 @@ class CometExecIterator( private val taskAttemptId = TaskContext.get().taskAttemptId private val taskCPUs = TaskContext.get().cpus() private val cometTaskMemoryManager = new CometTaskMemoryManager(id, taskAttemptId) - // Build a mixed array of iterators: CometShuffleBlockIterator for shuffle - // scan indices, CometBatchIterator for regular scan indices. - private val inputIterators: Array[Object] = inputs.zipWithIndex.map { - case (_, idx) if shuffleBlockIterators.contains(idx) => - shuffleBlockIterators(idx).asInstanceOf[Object] - case (iterator, _) => - new CometBatchIterator(iterator, nativeUtil).asInstanceOf[Object] - }.toArray private val plan = { val conf = SparkEnv.get.conf @@ -112,7 +107,7 @@ class CometExecIterator( nativeLib.createPlan( id, - inputIterators, + inputObjects, protobufQueryPlan, protobufSparkConfigs, numParts, diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala index 60fb65277e..4ae73565c6 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometDataWritingCommand.scala @@ -96,7 +96,6 @@ object CometDataWritingCommand extends CometOperatorSerde[DataWritingCommandExec val scanOp = OperatorOuterClass.Scan .newBuilder() .setSource(cmd.query.nodeName) - .setArrowFfiSafe(false) // Add fields from the query output schema val scanTypes = cmd.query.output.flatMap { attr => diff --git a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala index 
ff11b5d23b..b1834c5083 100644 --- a/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala +++ b/spark/src/main/scala/org/apache/comet/serde/operator/CometSink.scala @@ -40,9 +40,6 @@ import org.apache.comet.serde.QueryPlanSerde.{serializeDataType, supportedDataTy */ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { - /** Whether the data produced by the Comet operator is FFI safe */ - def isFfiSafe: Boolean = true - override def enabledConfig: Option[ConfigEntry[Boolean]] = None override def convert( @@ -65,7 +62,6 @@ abstract class CometSink[T <: SparkPlan] extends CometOperatorSerde[T] { } else { scanBuilder.setSource(source) } - scanBuilder.setArrowFfiSafe(isFfiSafe) val scanTypes = op.output.flatten { attr => serializeDataType(attr.dataType) diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala index 47eda98a11..caf639e792 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.comet +import org.apache.arrow.c.ArrowArrayStream import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD @@ -27,7 +28,7 @@ import org.apache.spark.sql.execution.ScalarSubquery import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.SerializableConfiguration -import org.apache.comet.CometExecIterator +import org.apache.comet.{CometExecIterator, CometRuntimeException, CometShuffleBlockIterator} import org.apache.comet.serde.OperatorOuterClass /** @@ -40,23 +41,21 @@ private[spark] class CometExecPartition( extends Partition /** - * Unified RDD for Comet native execution. + * Unified RDD for Comet native execution. 
Non-shuffle input slots are `RDD[ArrowArrayStream]` + * (consumed natively via the C Stream Interface); shuffle input slots are `CometShuffledBatchRDD` + * (consumed via `CometShuffleBlockIterator`). Slot order matches the scan-input order in the + * serialized native plan. * - * Solves the closure capture problem: instead of capturing all partitions' data in the closure - * (which gets serialized to every task), each Partition object carries only its own data. + * Solves the closure-capture problem: instead of capturing all partitions' data in the closure + * (which gets serialized to every task), each `CometExecPartition` carries only its own data. * - * Handles three cases: - * - With inputs + per-partition data: injects planning data into operator tree - * - With inputs + no per-partition data: just zips inputs (no injection overhead) - * - No inputs: uses numPartitions to create partitions - * - * NOTE: This RDD does not handle DPP (InSubqueryExec), which is resolved in - * CometIcebergNativeScanExec.serializedPartitionData before this RDD is created. It also handles - * ScalarSubquery expressions by registering them with CometScalarSubquery before execution. + * Does not handle DPP (InSubqueryExec), which is resolved in + * `CometIcebergNativeScanExec.serializedPartitionData` before this RDD is created. It does handle + * `ScalarSubquery` expressions by registering them with `CometScalarSubquery` before execution. 
*/ private[spark] class CometExecRDD( sc: SparkContext, - var inputRDDs: Seq[RDD[ColumnarBatch]], + var inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], @transient perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], @@ -97,9 +96,12 @@ private[spark] class CometExecRDD( override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { val partition = split.asInstanceOf[CometExecPartition] - val inputs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => - rdd.iterator(part, context) - } + val (inputObjects, shuffleBlockIters) = + CometExecRDD.resolveInputObjects( + inputRDDs, + partition.inputPartitions, + shuffleScanIndices, + context) // Only inject if we have per-partition planning data val actualPlan = if (commonByKey.nonEmpty) { @@ -111,18 +113,9 @@ private[spark] class CometExecRDD( serializedPlan } - // Create shuffle block iterators for inputs that are CometShuffledBatchRDD - val shuffleBlockIters = shuffleScanIndices.flatMap { idx => - inputRDDs(idx) match { - case rdd: CometShuffledBatchRDD => - Some(idx -> rdd.computeAsShuffleBlockIterator(partition.inputPartitions(idx), context)) - case _ => None - } - }.toMap - val it = new CometExecIterator( CometExec.newIterId, - inputs, + inputObjects, numOutputCols, actualPlan, nativeMetrics, @@ -163,13 +156,55 @@ private[spark] class CometExecRDD( object CometExecRDD { + /** + * Resolve the per-partition native input slots for `createPlan`, in scan-input order. A slot is + * either a `CometShuffleBlockIterator` (for slots in `shuffleScanIndices`, fed by a + * `CometShuffledBatchRDD` consumed via the JNI block-iteration protocol) or the single + * `ArrowArrayStream` exported by a non-shuffle `RDD[ArrowArrayStream]`. Returned alongside the + * subset that are shuffle-block iterators, which `CometExecIterator` needs to drive block + * iteration. 
Shared by [[CometExecRDD.compute]] and the native-shuffle path so both classify + * and resolve slots identically. + */ + def resolveInputObjects( + inputRDDs: Seq[RDD[_]], + inputPartitions: Array[Partition], + shuffleScanIndices: Set[Int], + context: TaskContext): (Array[Object], Map[Int, CometShuffleBlockIterator]) = { + val shuffleBlockIters = scala.collection.mutable.Map.empty[Int, CometShuffleBlockIterator] + val inputObjects: Array[Object] = inputRDDs + .zip(inputPartitions) + .zipWithIndex + .map { case ((rdd, part), idx) => + if (shuffleScanIndices.contains(idx)) { + rdd match { + case shuffleRDD: CometShuffledBatchRDD => + val it = shuffleRDD.computeAsShuffleBlockIterator(part, context) + shuffleBlockIters(idx) = it + it.asInstanceOf[Object] + case other => + throw new CometRuntimeException( + s"Slot $idx is marked as a shuffle scan but the input RDD is " + + s"${other.getClass.getName}, expected CometShuffledBatchRDD") + } + } else { + val streams = rdd.iterator(part, context).asInstanceOf[Iterator[ArrowArrayStream]] + if (!streams.hasNext) { + throw new CometRuntimeException(s"Empty ArrowArrayStream RDD partition for slot $idx") + } + streams.next().asInstanceOf[Object] + } + } + .toArray + (inputObjects, shuffleBlockIters.toMap) + } + /** * Creates an RDD for native execution with optional per-partition planning data. 
*/ // scalastyle:off def apply( sc: SparkContext, - inputRDDs: Seq[RDD[ColumnarBatch]], + inputRDDs: Seq[RDD[_]], commonByKey: Map[String, Array[Byte]], perPartitionByKey: Map[String, Array[Array[Byte]]], serializedPlan: Array[Byte], diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala index a2af60142b..8145e563f4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometExecUtils.scala @@ -25,6 +25,8 @@ import scala.reflect.ClassTag import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.vectorized.ColumnarBatch @@ -56,8 +58,14 @@ object CometExecUtils { // Serialize the plan once before mapping to avoid repeated serialization per partition val limitOp = CometExecUtils.getLimitNativePlan(outputAttribute, limit, offset).get val serializedPlan = CometExec.serializeNativePlan(limitOp) + val inputSchema = Utils.fromAttributes(outputAttribute) childPlan.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), outputAttribute.length, serializedPlan, numParts, idx) + CometExec.getCometIterator( + CometArrowStream.inputObjects(iter, inputSchema, "CometExecUtils-getNativeLimit"), + outputAttribute.length, + serializedPlan, + numParts, + idx) } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala index 622168bcc9..9c7f81e2ea 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala +++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometLocalTableScanExec.scala @@ -19,19 +19,24 @@ package org.apache.spark.sql.comet -import org.apache.spark.TaskContext +import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} -import org.apache.spark.sql.comet.CometLocalTableScanExec.createMetricsIterator -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.types.{DataType, NullType} import com.google.common.base.Objects -import org.apache.comet.{CometConf, ConfigEntry} +import org.apache.comet.{CometConf, ConfigEntry, DataTypeSupport} +import org.apache.comet.CometSparkSessionExtensions.withFallbackReason import org.apache.comet.serde.OperatorOuterClass.Operator import org.apache.comet.serde.operator.CometSink @@ -40,7 +45,8 @@ case class CometLocalTableScanExec( @transient rows: Seq[InternalRow], override val output: Seq[Attribute]) extends CometExec - with LeafExecNode { + with LeafExecNode + with CometNativeArrowSource { override lazy val metrics: Map[String, SQLMetric] = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -63,20 +69,34 @@ case class CometLocalTableScanExec( } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - val numInputRows = longMetric("numOutputRows") + private def countingRows( + iter: Iterator[InternalRow], + 
numOutputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numOutputRows.add(1) + row + } + } + + /** + * Build the per-partition `RowArrowReader`; the trait routes it to the JVM or native consumer. + */ + override protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] = { + val numOutputRows = longMetric("numOutputRows") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC to match native side expectations. See CometSparkToColumnarExec. - val timeZoneId = "UTC" - rdd.mapPartitionsInternal { sparkBatches => - val context = TaskContext.get() - val batches = CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - originalPlan.schema, - maxRecordsPerBatch, - timeZoneId, - context) - createMetricsIterator(batches, numInputRows) + val sparkSchema = originalPlan.schema + rdd.mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + consume( + "CometLocalTableScan", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numOutputRows), + maxRecordsPerBatch)) } } @@ -104,29 +124,36 @@ case class CometLocalTableScanExec( override def hashCode(): Int = Objects.hashCode(originalPlan, originalPlan.schema, output) } -object CometLocalTableScanExec extends CometSink[LocalTableScanExec] { - - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false +object CometLocalTableScanExec extends CometSink[LocalTableScanExec] with DataTypeSupport { override def enabledConfig: Option[ConfigEntry[Boolean]] = Some( CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED) - override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { - CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) + // ArrowWriter (used by 
RowArrowReader) handles NullType via Utils.toArrowType + NullWriter; + // other types off DataTypeSupport's allow list (TimeType, intervals, ...) have no ArrowWriter + // coverage and must fall back to Spark. + override def isTypeSupported( + dt: DataType, + name: String, + fallbackReasons: ListBuffer[String]): Boolean = dt match { + case _: NullType => true + case _ => super.isTypeSupported(dt, name, fallbackReasons) } - private def createMetricsIterator( - it: Iterator[ColumnarBatch], - numInputRows: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - override def hasNext: Boolean = it.hasNext - - override def next(): ColumnarBatch = { - val batch = it.next() - numInputRows.add(batch.numRows()) - batch - } + override def convert( + op: LocalTableScanExec, + builder: Operator.Builder, + childOp: Operator*): Option[Operator] = { + val fallbackReasons = new ListBuffer[String]() + if (!isSchemaSupported(op.schema, fallbackReasons)) { + withFallbackReason(op, fallbackReasons.mkString("; ")) + None + } else { + super.convert(op, builder, childOp: _*) } } + + override def createExec(nativeOp: Operator, op: LocalTableScanExec): CometNativeExec = { + CometScanWrapper(nativeOp, CometLocalTableScanExec(op, op.rows, op.output)) + } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala index 4fb8af39e8..4cfdd11361 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometNativeWriteExec.scala @@ -28,6 +28,8 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.io.{FileCommitProtocol, FileNameSpec} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import 
org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -212,7 +214,10 @@ case class CometNativeWriteExec( val execIterator = new CometExecIterator( CometExec.newIterId, - Seq(iter), + CometArrowStream.inputObjects( + iter, + CometUtils.fromAttributes(child.output), + "CometNativeWriteExec"), numOutputCols, planBytes, nativeMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala index efe6a97d40..48be16100f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometSparkToColumnarExec.scala @@ -20,14 +20,17 @@ package org.apache.spark.sql.comet import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag -import org.apache.spark.TaskContext +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource, RowArrowReader, SparkColumnarArrowReader} +import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types._ @@ -39,7 +42,8 @@ import org.apache.comet.serde.operator.CometSink case class CometSparkToColumnarExec(child: SparkPlan) extends RowToColumnarTransition - with CometPlan { + with CometPlan + with 
CometNativeArrowSource { override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = child.outputPartitioning @@ -69,72 +73,73 @@ case class CometSparkToColumnarExec(child: SparkPlan) sparkContext, "time converting Spark batches to Arrow batches")) - // The conversion happens in next(), so wrap the call to measure time spent. - private def createTimingIter( + private def countingBatches( iter: Iterator[ColumnarBatch], - numInputRows: SQLMetric, - numOutputBatches: SQLMetric, - conversionTime: SQLMetric): Iterator[ColumnarBatch] = { - new Iterator[ColumnarBatch] { - - override def hasNext: Boolean = { - iter.hasNext - } + numInputRows: SQLMetric): Iterator[ColumnarBatch] = new Iterator[ColumnarBatch] { + override def hasNext: Boolean = iter.hasNext + override def next(): ColumnarBatch = { + val batch = iter.next() + numInputRows += batch.numRows() + batch + } + } - override def next(): ColumnarBatch = { - val startNs = System.nanoTime() - val batch = iter.next() - conversionTime += System.nanoTime() - startNs - numInputRows += batch.numRows() - numOutputBatches += 1 - batch - } + private def countingRows( + iter: Iterator[InternalRow], + numInputRows: SQLMetric): Iterator[InternalRow] = new Iterator[InternalRow] { + override def hasNext: Boolean = iter.hasNext + override def next(): InternalRow = { + val row = iter.next() + numInputRows += 1 + row } } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { + /** + * Build the per-partition `ArrowReader` (columnar or row, depending on the child); the trait + * routes it to the JVM or native consumer. + * + * `numOutputBatches` is incremented from the reader's per-produced-batch callback rather than + * by counting input batches, so it stays accurate on the native path too (native drives + * `loadNextBatch`) and counts produced Arrow batches, not Spark input batches. 
+ */ + override protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] = { val numInputRows = longMetric("numInputRows") val numOutputBatches = longMetric("numOutputBatches") val conversionTime = longMetric("conversionTime") val maxRecordsPerBatch = CometConf.COMET_BATCH_SIZE.get(conf) - // Use UTC for Arrow schema timezone to match the native side, which always - // deserializes Timestamp as Timestamp(Microsecond, Some("UTC")). Spark's internal - // timestamp representation is always UTC microseconds, so the timezone here is - // purely schema metadata. Using session timezone would cause Arrow RowConverter - // schema mismatch errors in non-UTC sessions. See COMET-2720. - val timeZoneId = "UTC" - val schema = child.schema + val sparkSchema = child.schema + val onConversionNs: Long => Unit = ns => { + conversionTime += ns + numOutputBatches += 1 + } if (child.supportsColumnar) { - child - .executeColumnar() - .mapPartitionsInternal { sparkBatches => - val arrowBatches = - sparkBatches.flatMap { sparkBatch => - val context = TaskContext.get() - CometArrowConverters.columnarBatchToArrowBatchIter( - sparkBatch, - schema, - maxRecordsPerBatch, - timeZoneId, - context) - } - createTimingIter(arrowBatches, numInputRows, numOutputBatches, conversionTime) - } + val maxBatchInt = maxRecordsPerBatch.toInt + child.executeColumnar().mapPartitionsInternal { sparkBatches => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + consume( + "CometSparkColumnarToColumnar", + new SparkColumnarArrowReader( + _, + arrowSchema, + countingBatches(sparkBatches, numInputRows), + maxBatchInt, + onConversionNs)) + } } else { - child - .execute() - .mapPartitionsInternal { sparkBatches => - val context = TaskContext.get() - val arrowBatches = - CometArrowConverters.rowToArrowBatchIter( - sparkBatches, - schema, - maxRecordsPerBatch, - timeZoneId, - context) - createTimingIter(arrowBatches, 
numInputRows, numOutputBatches, conversionTime) - } + child.execute().mapPartitionsInternal { rowIter => + val arrowSchema = Utils.toArrowSchema(sparkSchema, CometArrowStream.NATIVE_TIMEZONE) + consume( + "CometSparkRowToColumnar", + new RowArrowReader( + _, + arrowSchema, + countingRows(rowIter, numInputRows), + maxRecordsPerBatch, + onConversionNs)) + } } } @@ -145,9 +150,6 @@ case class CometSparkToColumnarExec(child: SparkPlan) object CometSparkToColumnarExec extends CometSink[SparkPlan] with DataTypeSupport { - // uses CometArrowConverters, which re-uses arrays - override def isFfiSafe: Boolean = false - override def createExec( nativeOp: OperatorOuterClass.Operator, op: SparkPlan): CometNativeExec = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala index a66d1b58d6..09dd944d93 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometTakeOrderedAndProjectExec.scala @@ -24,7 +24,9 @@ import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.comet.execution.shuffle.{CometShuffledBatchRDD, CometShuffleExchangeExec} +import org.apache.spark.sql.comet.util.{Utils => CometUtils} import org.apache.spark.sql.execution.{SparkPlan, TakeOrderedAndProjectExec, UnaryExecNode, UnsafeRowSerializer} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.vectorized.ColumnarBatch @@ -140,8 +142,14 @@ case class CometTakeOrderedAndProjectExec( 
.get val serializedTopK = CometExec.serializeNativePlan(topK) val numOutputCols = child.output.length + val inputSchema = CometUtils.fromAttributes(child.output) childRDD.mapPartitionsWithIndexInternal { case (idx, iter) => - CometExec.getCometIterator(Seq(iter), numOutputCols, serializedTopK, numParts, idx) + CometExec.getCometIterator( + CometArrowStream.inputObjects(iter, inputSchema, "CometTakeOrderedAndProject-topK"), + numOutputCols, + serializedTopK, + numParts, + idx) } } @@ -163,9 +171,11 @@ case class CometTakeOrderedAndProjectExec( .get val serializedTopKAndProjection = CometExec.serializeNativePlan(topKAndProjection) val finalOutputLength = output.length + val finalInputSchema = CometUtils.fromAttributes(child.output) singlePartitionRDD.mapPartitionsInternal { iter => val it = CometExec.getCometIterator( - Seq(iter), + CometArrowStream + .inputObjects(iter, finalInputSchema, "CometTakeOrderedAndProject-final"), finalOutputLength, serializedTopKAndProjection, 1, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala new file mode 100644 index 0000000000..6cbba9e06e --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ColumnarBatchArrowReader.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import java.util.{ArrayList => JArrayList} + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.{FieldVector, VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.dictionary.DictionaryEncoder +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.{CometDictionaryVector, CometVector} + +/** + * `ArrowReader` over an iterator of Arrow-backed `ColumnarBatch`es. Each `loadNextBatch` unloads + * the source's `FieldVector`s into a transient `ArrowRecordBatch` (retains buffers), loads it + * into this reader's stable VSR via `loadFieldBuffers` (release-and-replace), then closes the + * source batch. The unload/load step decouples this reader's VSR ownership from whatever the + * source does with its own buffers. 
+ */ +private[comet] class ColumnarBatchArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch]) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!source.hasNext) { + return false + } + + val src = source.next() + var materialized: JArrayList[FieldVector] = null + try { + val sourceVectors = new JArrayList[FieldVector](src.numCols()) + var i = 0 + while (i < src.numCols()) { + val col = src.column(i).asInstanceOf[CometVector] + val fv = col match { + case d: CometDictionaryVector => + // Stable VSR was built from the logical (non-dict) schema, so a dict-encoded + // source's indices layout would mismatch the dest buffer count on load. Native + // unpacks downstream anyway via copy_or_unpack_array. + val indices = d.getValueVector + val dictionary = d.provider.lookup(indices.getField.getDictionary.getId) + val plain = DictionaryEncoder + .decode(indices, dictionary, allocator) + .asInstanceOf[FieldVector] + if (materialized == null) materialized = new JArrayList[FieldVector]() + materialized.add(plain) + plain + case _ => + col.getValueVector.asInstanceOf[FieldVector] + } + sourceVectors.add(fv) + i += 1 + } + val transient = new VectorSchemaRoot(sourceVectors) + transient.setRowCount(src.numRows()) + + val unloader = new VectorUnloader(transient) + val rb = unloader.getRecordBatch + try { + loadRecordBatch(rb) + } finally { + rb.close() + } + // Do not close `transient`. It shares FieldVectors with `src`; closing `src` below + // releases the producer-side refs. Closing `transient` would double-release. 
+ } finally { + if (materialized != null) { + var j = 0 + while (j < materialized.size()) { + try materialized.get(j).close() + catch { case _: Throwable => () } + j += 1 + } + } + src.close() + } + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala index 6d52078181..2d4fd71376 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala @@ -22,175 +22,56 @@ package org.apache.spark.sql.comet.execution.arrow import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema -import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} +import org.apache.spark.sql.vectorized.ColumnarBatch -import org.apache.comet.CometArrowAllocator import org.apache.comet.vector.NativeUtil +/** + * Convert a stream of Spark `InternalRow`s to a stream of independently-owned Arrow + * `ColumnarBatch`es: each emitted batch owns a fresh `VectorSchemaRoot` with newly allocated + * buffers and the consumer is responsible for closing it. + * + * This differs from [[RowArrowReader]], which reuses one stable `VectorSchemaRoot` + * (release-and-replace) so only one batch is valid at a time. Use this when multiple emitted + * batches must be alive simultaneously (e.g. tests that buffer several batches before consuming). + * Buffers come from the caller-provided `BufferAllocator`, whose lifecycle the caller owns. 
+ */ object CometArrowConverters extends Logging { - // This is similar how Spark converts internal row to Arrow format except that it is transforming - // the result batch to Comet's ColumnarBatch instead of serialized bytes. - // There's another big difference that Comet may consume the ColumnarBatch by exporting it to - // the native side. Hence, we need to: - // 1. reset the Arrow writer after the ColumnarBatch is consumed - // 2. close the allocator when the task is finished but not when the iterator is all consumed - // The reason for the second point is that when ColumnarBatch is exported to the native side, the - // exported process increases the reference count of the Arrow vectors. The reference count is - // only decreased when the native plan is done with the vectors, which is usually longer than - // all the ColumnarBatches are consumed. - - abstract private[sql] class ArrowBatchIterBase( - schema: StructType, - timeZoneId: String, - context: TaskContext) - extends Iterator[ColumnarBatch] - with AutoCloseable { - - protected val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) - // Reuse the same root allocator here. 
- protected val allocator: BufferAllocator = - CometArrowAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, Long.MaxValue) - protected val root: VectorSchemaRoot = VectorSchemaRoot.create(arrowSchema, allocator) - protected val arrowWriter: ArrowWriter = ArrowWriter.create(root) - - protected var currentBatch: ColumnarBatch = null - protected var closed: Boolean = false - Option(context).foreach { - _.addTaskCompletionListener[Unit] { _ => - close(true) - } - } - - override def close(): Unit = { - close(false) - } - - protected def close(closeAllocator: Boolean): Unit = { - try { - if (!closed) { - if (currentBatch != null) { - arrowWriter.reset() - currentBatch.close() - currentBatch = null - } - root.close() - closed = true - } - } finally { - // the allocator shall be closed when the task is finished - if (closeAllocator) { - allocator.close() - } - } - } - - override def next(): ColumnarBatch = { - currentBatch = nextBatch() - currentBatch - } - - protected def nextBatch(): ColumnarBatch - - } - - private[sql] class RowToArrowBatchIter( + /** + * Convert an iterator of Spark `InternalRow`s into an iterator of Arrow `ColumnarBatch`es. + * + * Each call to `next()` allocates a fresh `VectorSchemaRoot`, writes up to `maxRecordsPerBatch` + * rows into it, and emits a `ColumnarBatch` wrapping that root. The consumer must close every + * emitted batch. 
+ */ + def rowToArrowBatchIter( rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Long, timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { + allocator: BufferAllocator): Iterator[ColumnarBatch] = { + val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId) - override def hasNext: Boolean = rowIter.hasNext || { - close(false) - false - } + new Iterator[ColumnarBatch] { + override def hasNext: Boolean = rowIter.hasNext - override protected def nextBatch(): ColumnarBatch = { - if (rowIter.hasNext) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() + override def next(): ColumnarBatch = { + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val writer = ArrowWriter.create(root) var rowCount = 0L - while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { - val row = rowIter.next() - arrowWriter.write(row) + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) rowCount += 1 } - arrowWriter.finish() - NativeUtil.rootAsBatch(root) - } else { - null - } - } - } - - def rowToArrowBatchIter( - rowIter: Iterator[InternalRow], - schema: StructType, - maxRecordsPerBatch: Long, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - new RowToArrowBatchIter(rowIter, schema, maxRecordsPerBatch, timeZoneId, context) - } - - private[sql] class ColumnBatchToArrowBatchIter( - colBatch: ColumnarBatch, - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - context: TaskContext) - extends ArrowBatchIterBase(schema, timeZoneId, context) - with AutoCloseable { - - private var rowsProduced: Int = 0 - - override def hasNext: Boolean = rowsProduced < colBatch.numRows() || { - close(false) - false - } - - override protected def nextBatch(): ColumnarBatch = { - val rowsInBatch = colBatch.numRows() - if 
(rowsProduced < rowsInBatch) { - // the arrow writer shall be reset before writing the next batch - arrowWriter.reset() - val rowsToProduce = - if (maxRecordsPerBatch <= 0) rowsInBatch - rowsProduced - else Math.min(maxRecordsPerBatch, rowsInBatch - rowsProduced) - - for (columnIndex <- 0 until colBatch.numCols()) { - val column = colBatch.column(columnIndex) - val columnArray = new ColumnarArray(column, rowsProduced, rowsToProduce) - if (column.hasNull) { - arrowWriter.writeCol(columnArray, columnIndex) - } else { - arrowWriter.writeColNoNull(columnArray, columnIndex) - } - } - - rowsProduced += rowsToProduce - - arrowWriter.finish() + writer.finish() NativeUtil.rootAsBatch(root) - } else { - null } } } - - def columnarBatchToArrowBatchIter( - colBatch: ColumnarBatch, - schema: StructType, - maxRecordsPerBatch: Int, - timeZoneId: String, - context: TaskContext): Iterator[ColumnarBatch] = { - new ColumnBatchToArrowBatchIter(colBatch, schema, maxRecordsPerBatch, timeZoneId, context) - } } diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala new file mode 100644 index 0000000000..6415256a39 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometNativeArrowSource.scala @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +import org.apache.arrow.c.{ArrowArrayStream, Data} +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.{Field, FieldType, Schema} +import org.apache.spark.TaskContext +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.util.Utils +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometArrowAllocator +import org.apache.comet.vector.{CometDictionaryVector, CometVector, NativeUtil} + +/** + * A Comet operator that produces its output as Arrow data, consumable either as JVM + * `ColumnarBatch`es (`doExecuteColumnar`) or, when the consumer is a Comet native executor, + * directly as the Arrow C Stream Interface (`doExecuteAsArrowStream`), skipping the intermediate + * `RDD[ColumnarBatch]` layer. + * + * Implementors supply only [[mapToReaders]] (their source RDD + per-partition `ArrowReader`); the + * two execution paths here differ solely in whether each partition's reader is drained into + * `ColumnarBatch`es or exported as a stream. + */ +trait CometNativeArrowSource extends SparkPlan { + + /** + * Build this operator's per-partition `ArrowReader` and hand it to `consume`, returning the + * output RDD. 
`consume` is provided by this trait: `CometArrowStream.readerBatchIter` for the + * JVM columnar path, `CometArrowStream.stream` for the native C Stream path. + */ + protected def mapToReaders[T: ClassTag]( + consume: (String, BufferAllocator => ArrowReader) => Iterator[T]): RDD[T] + + override protected def doExecuteColumnar(): RDD[ColumnarBatch] = + mapToReaders(CometArrowStream.readerBatchIter) + + def doExecuteAsArrowStream(): RDD[ArrowArrayStream] = + mapToReaders(CometArrowStream.stream) +} + +object CometArrowStream extends Logging { + + /** + * Native side asserts `Timestamp(Microsecond, Some("UTC"))` regardless of session timezone; + * Spark's internal timestamp representation is always UTC microseconds anyway, and a non-UTC + * timezone here would only show up as schema metadata that breaks Arrow RowConverter + * validation. See COMET-2720. + */ + val NATIVE_TIMEZONE: String = "UTC" + + /** + * Wrap an `RDD[ColumnarBatch]` whose batches are Arrow-backed into an `RDD[ArrowArrayStream]`. + */ + def wrapColumnarBatchRDD( + rdd: RDD[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): RDD[ArrowArrayStream] = { + // Arrow `Schema` is not Serializable; only Spark's `StructType` is. Build the Arrow schema + // inside the per-task body so the closure cleaner doesn't try to ship a Schema across. + rdd.mapPartitionsInternal { batchIter => + val expected = Utils.toArrowSchema(sparkSchema, timeZoneId) + val (arrowSchema, iter) = reconcileStreamSchema(name, expected, batchIter) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, iter)) + } + } + + /** + * Wrap a single per-partition `Iterator[ColumnarBatch]` (Arrow-backed) and return the exported + * `ArrowArrayStream`. For callers outside `CometExecRDD` that hand a JNI input slot directly to + * a `CometExecIterator`. 
+ */ + def fromColumnarBatchIter( + iter: Iterator[ColumnarBatch], + sparkSchema: StructType, + timeZoneId: String, + name: String): ArrowArrayStream = { + val expected = Utils.toArrowSchema(sparkSchema, timeZoneId) + val (arrowSchema, reconciled) = reconcileStreamSchema(name, expected, iter) + stream(name, allocator => new ColumnarBatchArrowReader(allocator, arrowSchema, reconciled)) + .next() + } + + /** + * Build the `inputObjects` array that `CometExecIterator` / `CometExec.getCometIterator` pass + * to native `createPlan`, for the common case of a single scan input fed by one per-partition + * `Iterator[ColumnarBatch]`. The iterator is exported to one `ArrowArrayStream` (Arrow C + * Stream) and boxed as the lone element, using the native timezone. + */ + def inputObjects( + iter: Iterator[ColumnarBatch], + sparkSchema: StructType, + name: String): Array[Object] = + Array[Object](fromColumnarBatchIter(iter, sparkSchema, NATIVE_TIMEZONE, name)) + + /** + * Build the stream's advertised Arrow schema from the actual `CometVector` types in the first + * batch, not from `expected` (which derives from the consumer's Spark-declared types). Native + * operators like `ScanExec` already cast their input to the declared scan-input schema in + * `build_record_batch`, so the truthful schema lets that cast actually fire. Advertising + * `expected` instead silently mislabels Int32 buffers as Int64 (and similar) and corrupts on + * import. See PR #4393 width_bucket investigation. + * + * If the first batch's column types differ from `expected` in their `DataType` (timezone-only + * differences on `Timestamp` are ignored), log one warning naming the operator, column, and + * type drift; the cast happens transparently downstream in native. 
+ */ + private[arrow] def reconcileStreamSchema( + name: String, + expected: Schema, + iter: Iterator[ColumnarBatch]): (Schema, Iterator[ColumnarBatch]) = { + val buffered = iter.buffered + if (!buffered.hasNext) { + // Empty partition: keep the consumer-declared schema; consumer can still build its plan. + return (expected, buffered) + } + val first = buffered.head + val expectedFields = expected.getFields + val actualFields = (0 until first.numCols()).map { i => + val col = first.column(i).asInstanceOf[CometVector] + actualFieldOf(col, expectedFields.get(i)) + } + val mismatches = actualFields.zip(expectedFields.asScala).zipWithIndex.collect { + case ((actual, exp), idx) if actual.getType != exp.getType => + s"col[$idx] '${exp.getName}': expected ${exp.getType}, child produced ${actual.getType}" + } + if (mismatches.nonEmpty) { + logWarning( + s"CometArrowStream '$name' input schema mismatch: ${mismatches.mkString("; ")}. " + + "Native ScanExec will cast at the boundary. This usually means a DataFusion-Spark " + + "function declares a different return type than Spark catalyst.") + } + (new Schema(actualFields.asJava), buffered) + } + + /** + * The Arrow field that this column's buffers will look like once unloaded. For a + * `CometDictionaryVector`, [[ColumnarBatchArrowReader]] decodes it via + * `DictionaryEncoder.decode` before unloading, so the wire-level field is the dictionary's + * *value* type, not `Dictionary`. For everything else, use the underlying value + * vector's field. + * + * Field name and metadata come from `expected` so that consumers indexing by name keep working. + * Nullability is the union of the two: a CometVector that happens to hold no nulls in this + * batch can still be nullable per Spark's contract (the next batch may have one), and a column + * whose actual buffer carries validity bits must stay nullable even if Spark thought otherwise. 
+ * Taking only `raw.isNullable` here would advertise non-nullable when the next batch does carry + * a null and crash native validation. + */ + private def actualFieldOf(col: CometVector, expected: Field): Field = { + val raw = col match { + case d: CometDictionaryVector => + val indices = d.getValueVector + val dict = d.provider.lookup(indices.getField.getDictionary.getId) + dict.getVector.getField + case _ => col.getValueVector.getField + } + val nullable = expected.isNullable || raw.isNullable + val fieldType = + new FieldType(nullable, raw.getType, raw.getDictionary, expected.getMetadata) + new Field(expected.getName, fieldType, raw.getChildren) + } + + /** + * Allocate a child allocator, build a reader, export it as an `ArrowArrayStream`, and register + * task-completion cleanup. Returns a single-element iterator so this composes with + * `RDD.mapPartitionsInternal`. + * + * Close ordering: when native drops its `ArrowArrayStreamReader`, the C release callback fires + * synchronously into `ExportedArrayStreamPrivateData.close` -> `reader.close` -> the VSR's + * buffers are released. The task-completion listener registered here runs strictly after that + * (Spark fires listeners in reverse registration order, and the listener that drops the native + * plan is registered later by `CometExecIterator`), so `allocator.close` finds zero outstanding + * bytes. 
+ */ + def stream( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ArrowArrayStream] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + var reader: ArrowReader = null + var arrowStream: ArrowArrayStream = null + try { + reader = readerFactory(allocator) + arrowStream = ArrowArrayStream.allocateNew(allocator) + Data.exportArrayStream(allocator, reader, arrowStream) + } catch { + case t: Throwable => + // Roll back partial setup before rethrowing -- nothing has been registered with + // TaskContext yet, so without this the allocator (and possibly the reader/stream) leaks. + if (arrowStream != null) { + try arrowStream.close() + catch { case _: Throwable => () } + } + if (reader != null) { + try reader.close() + catch { case _: Throwable => () } + } + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + val streamRef = arrowStream + context.addTaskCompletionListener[Unit] { _ => + streamRef.close() + allocator.close() + } + } + Iterator.single(arrowStream) + } + + /** + * Drive an `ArrowReader` from a per-task body and emit `ColumnarBatch`es wrapping the reader's + * stable VSR. Lifecycle: the supplied factory builds the reader against a fresh child + * allocator; both close at task completion. This is the non-native consumer path + * (`doExecuteColumnar`) -- the native consumer path uses [[stream]] to export instead. 
+ */ + def readerBatchIter( + name: String, + readerFactory: BufferAllocator => ArrowReader): Iterator[ColumnarBatch] = { + val context = TaskContext.get() + val allocator = CometArrowAllocator.newChildAllocator(name, 0, Long.MaxValue) + val reader = + try readerFactory(allocator) + catch { + case t: Throwable => + try allocator.close() + catch { case _: Throwable => () } + throw t + } + if (context != null) { + context.addTaskCompletionListener[Unit] { _ => + reader.close() + allocator.close() + } + } + new Iterator[ColumnarBatch] { + // Lazily prefetch one batch so `hasNext` can answer without consuming. + private var loaded: Boolean = false + private var hasMore: Boolean = false + + private def ensureLoaded(): Unit = { + if (!loaded) { + hasMore = reader.loadNextBatch() + loaded = true + } + } + + override def hasNext: Boolean = { + ensureLoaded() + hasMore + } + + override def next(): ColumnarBatch = { + ensureLoaded() + if (!hasMore) { + throw new NoSuchElementException("No more batches") + } + loaded = false + NativeUtil.rootAsBatch(reader.getVectorSchemaRoot) + } + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala new file mode 100644 index 0000000000..e1829eb5c5 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/RowArrowReader.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.catalyst.InternalRow + +/** + * `ArrowReader` over an iterator of Spark `InternalRow`s, writing up to `maxRecordsPerBatch` rows + * per call into the reader's stable VSR via `ArrowWriter`. + * + * `ArrowWriter.create(root)` calls `vector.allocateNew()`, which releases any prior buffers and + * allocates fresh ones. This is required for FFI safety: previously-exported batches retain their + * buffers via the C release callback, so reusing those buffers in place would corrupt native + * consumers still holding the prior batch. 
+ */ +private[comet] class RowArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + rowIter: Iterator[InternalRow], + maxRecordsPerBatch: Long, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!rowIter.hasNext) { + return false + } + + val startNs = System.nanoTime() + val writer = ArrowWriter.create(getVectorSchemaRoot) + var rowCount = 0L + while (rowIter.hasNext && + (maxRecordsPerBatch <= 0 || rowCount < maxRecordsPerBatch)) { + writer.write(rowIter.next()) + rowCount += 1 + } + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala new file mode 100644 index 0000000000..0af940fab3 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/arrow/SparkColumnarArrowReader.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import org.apache.arrow.memory.BufferAllocator +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.arrow.vector.types.pojo.Schema +import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarBatch} + +/** + * `ArrowReader` over an iterator of Spark-side `ColumnarBatch`es (not Arrow-backed). Slices up to + * `maxRecordsPerBatch` rows per `loadNextBatch` from the current Spark batch into the reader's + * stable VSR via `ArrowWriter.writeCol`. Spark's `ColumnVector` implementations aren't Arrow + * buffers, so this reader necessarily copies element values into Arrow format. + */ +private[comet] class SparkColumnarArrowReader( + allocator: BufferAllocator, + arrowSchema: Schema, + source: Iterator[ColumnarBatch], + maxRecordsPerBatch: Int, + onConversionNs: Long => Unit = _ => ()) + extends ArrowReader(allocator) { + + private var current: ColumnarBatch = _ + private var rowsConsumedInCurrent: Int = 0 + + override protected def readSchema(): Schema = arrowSchema + + override def bytesRead(): Long = 0L + + override protected def closeReadSource(): Unit = () + + private def advanceToNonEmptyBatch(): Boolean = { + while (current == null || rowsConsumedInCurrent >= current.numRows()) { + if (current != null) { + // We don't own Spark ColumnarBatches; just drop the reference. 
+ current = null + rowsConsumedInCurrent = 0 + } + if (!source.hasNext) { + return false + } + current = source.next() + rowsConsumedInCurrent = 0 + } + true + } + + override def loadNextBatch(): Boolean = { + prepareLoadNextBatch() + + if (!advanceToNonEmptyBatch()) { + return false + } + + val startNs = System.nanoTime() + val rowsRemaining = current.numRows() - rowsConsumedInCurrent + val rowsToProduce = + if (maxRecordsPerBatch <= 0) rowsRemaining + else math.min(maxRecordsPerBatch, rowsRemaining) + + val writer = ArrowWriter.create(getVectorSchemaRoot) + var col = 0 + while (col < current.numCols()) { + val column = current.column(col) + val columnArray = new ColumnarArray(column, rowsConsumedInCurrent, rowsToProduce) + if (column.hasNull) { + writer.writeCol(columnArray, col) + } else { + writer.writeColNoNull(columnArray, col) + } + col += 1 + } + rowsConsumedInCurrent += rowsToProduce + + writer.finish() + onConversionNs(System.nanoTime() - startNs) + true + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala new file mode 100644 index 0000000000..0579e57bce --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleInputRDD.scala @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.shuffle + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.comet.CometExecRDD +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.CometShuffleBlockIterator + +/** + * Thin scheduling-anchor RDD for the native-shuffle path. Declares `OneToOneDependency` on each + * leaf input RDD (so the DAGScheduler triggers prior stages, broadcasts, etc.) and resolves the + * per-partition native input slots in `compute`, packaged into a + * [[CometNativeShuffleInputIterator]]. The iterator reports `hasNext = false`; + * [[CometNativeShuffleWriter]] downcasts it and reads those slots directly to drive the unified + * `ShuffleWriter(child = childNativeOp)` plan. + */ +private[shuffle] class CometNativeShuffleInputRDD( + sc: SparkContext, + var inputRDDs: Seq[RDD[_]], + numPartitionsParam: Int, + shuffleScanIndices: Set[Int]) + extends RDD[Product2[Int, ColumnarBatch]]( + sc, + inputRDDs.map(rdd => new OneToOneDependency(rdd))) { + + override protected def getPartitions: Array[Partition] = + (0 until numPartitionsParam).map { i => + // Resolve leaf-RDD partitions on the driver here (where their @transient fields are still + // populated). Stashing them on the partition lets `compute` avoid touching + // `leafRdd.partitions` on the executor, which would otherwise trigger getPartitions and + // hit the @transient-null trap (e.g. CometExecRDD.perPartitionByKey). 
+ val inputParts = inputRDDs.map(_.partitions(i)).toArray + new CometNativeShuffleInputPartition(i, inputParts) + }.toArray + + override def compute( + split: Partition, + context: TaskContext): Iterator[Product2[Int, ColumnarBatch]] = { + val partition = split.asInstanceOf[CometNativeShuffleInputPartition] + val (inputObjects, shuffleBlockIters) = + CometExecRDD.resolveInputObjects( + inputRDDs, + partition.inputPartitions, + shuffleScanIndices, + context) + new CometNativeShuffleInputIterator(partition.index, inputObjects, shuffleBlockIters) + } + + override def getPreferredLocations(split: Partition): Seq[String] = { + if (inputRDDs == null || inputRDDs.isEmpty) return Nil + val partition = split.asInstanceOf[CometNativeShuffleInputPartition] + val prefs = inputRDDs.zip(partition.inputPartitions).map { case (rdd, part) => + rdd.preferredLocations(part) + } + val intersection = prefs.reduce((a, b) => a.intersect(b)) + if (intersection.nonEmpty) intersection else prefs.flatten.distinct + } + + override def clearDependencies(): Unit = { + super.clearDependencies() + inputRDDs = null + } +} + +private[shuffle] class CometNativeShuffleInputPartition( + override val index: Int, + val inputPartitions: Array[Partition]) + extends Partition + +/** + * Iterator handed to [[CometNativeShuffleWriter.write]] via Spark's ShuffleMapTask. Reports no + * elements; the writer downcasts and reads `partitionIndex`, `inputObjects`, and + * `shuffleBlockIterators` directly to drive the unified native plan. `inputObjects` are the + * already-resolved native input slots (see [[CometExecRDD.resolveInputObjects]]). 
+ */ +private[shuffle] class CometNativeShuffleInputIterator( + val partitionIndex: Int, + val inputObjects: Array[Object], + val shuffleBlockIterators: Map[Int, CometShuffleBlockIterator]) + extends Iterator[Product2[Int, ColumnarBatch]] { + + override def hasNext: Boolean = false + + override def next(): Product2[Int, ColumnarBatch] = + throw new NoSuchElementException( + "CometNativeShuffleInputIterator should never be drained as an iterator. Reaching this " + + "code means a non-Comet ShuffleWriter is consuming the input, which is a bug.") +} diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala index f27d021ac4..8cb1eb86ad 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometNativeShuffleWriter.scala @@ -32,19 +32,26 @@ import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriteMetricsR import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} -import org.apache.spark.sql.comet.{CometExec, CometMetricNode} +import org.apache.spark.sql.comet.{CometExec, CometMetricNode, CometScalarSubquery, PlanDataInjector} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.types.StructField -import org.apache.comet.CometConf +import org.apache.comet.{CometConf, CometExecIterator} import org.apache.comet.serde.{OperatorOuterClass, PartitioningOuterClass, QueryPlanSerde} import org.apache.comet.serde.OperatorOuterClass.{CompressionCodec, Operator} -import 
org.apache.comet.serde.QueryPlanSerde.serializeDataType +import org.apache.comet.serde.operator.schema2Proto /** - * A [[ShuffleWriter]] that will delegate shuffle write to native shuffle. + * Drives the native shuffle write in a single [[CometExecIterator]] per partition. The plan is + * `ShuffleWriter(child = childNativeOp)`; leaf iterators come from a + * [[CometNativeShuffleInputIterator]]. `childNativeOp` is either a rich Comet native subtree + * (when fed by [[CometShuffleExchangeExec]] with a [[org.apache.spark.sql.comet.CometNativeExec]] + * child) or a synthetic `Scan("ShuffleWriterInput")` placeholder (the + * [[CometShuffleExchangeExec.prepareShuffleDependency]] convenience overload). Same handling + * either way. */ class CometNativeShuffleWriter[K, V]( + spec: NativeShuffleSpec, outputPartitioning: Partitioning, outputAttributes: Seq[Attribute], metrics: Map[String, SQLMetric], @@ -72,8 +79,31 @@ class CometNativeShuffleWriter[K, V]( val tempDataFilePath = Paths.get(tempDataFilename) val tempIndexFilePath = Paths.get(tempIndexFilename) - // Call native shuffle write - val nativePlan = getNativePlan(tempDataFilename, tempIndexFilename) + // The dep's _rdd is always a CometNativeShuffleInputRDD on this path. Pattern-match instead + // of asInstanceOf so a future RDD-layering change produces a clear error here rather than a + // bare ClassCastException deeper in the stack. 
+ val shuffleInputIter = inputs match { + case it: CometNativeShuffleInputIterator => it + case other => + throw new IllegalStateException( + "CometNativeShuffleWriter expects its input iterator to be a " + + "CometNativeShuffleInputIterator (produced by CometNativeShuffleInputRDD), got " + + s"${other.getClass.getName}") + } + val partitionIdx = shuffleInputIter.partitionIndex + val inputObjects = shuffleInputIter.inputObjects + val shuffleBlockIters = shuffleInputIter.shuffleBlockIterators + + val unifiedPlan = buildUnifiedPlan(tempDataFilename, tempIndexFilename) + val ctx = spec.execContext + val finalNativePlan = if (ctx.commonByKey.nonEmpty) { + val partitionDataByKey = ctx.perPartitionByKey.map { case (k, arr) => + k -> arr(partitionIdx) + } + PlanDataInjector.injectPlanData(unifiedPlan, ctx.commonByKey, partitionDataByKey) + } else { + unifiedPlan + } val detailedMetrics = Seq( "elapsed_compute", @@ -82,29 +112,48 @@ class CometNativeShuffleWriter[K, V]( "input_batches", "spill_count", "spilled_bytes") - - // Maps native metrics to SQL metrics val metricsOutputRows = new SQLMetric("outputRows") val metricsWriteTime = new SQLMetric("writeTime") - val nativeSQLMetrics = Map( + val shuffleWriterSQLMetrics = Map( "output_rows" -> metricsOutputRows, "data_size" -> metrics("dataSize"), "write_time" -> metricsWriteTime) ++ metrics.filterKeys(detailedMetrics.contains) - val nativeMetrics = CometMetricNode(nativeSQLMetrics) - // Getting rid of the fake partitionId - val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + // ShuffleWriter metrics at the root; child's metric tree underneath so the SQL UI's per-node + // breakdown matches what the split-driver flow showed. 
+ val nativeMetrics = CometMetricNode(shuffleWriterSQLMetrics, Seq(spec.childMetricNode)) - val cometIter = CometExec.getCometIterator( - Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), + // The leaf scans execute inside this writer's single plan rather than a separate native + // stage RDD, so the usual CometExecRDD.compute() bridge (operators.scala) never runs for + // them. Report their bytes/rows to the task's input metrics here instead. + if (ctx.hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) + } + + val cometIter = new CometExecIterator( + CometExec.newIterId, + inputObjects, outputAttributes.length, - nativePlan, + CometExec.serializeNativePlan(finalNativePlan), nativeMetrics, numParts, - context.partitionId(), - broadcastedHadoopConfForEncryption = None, - encryptedFilePaths = Seq.empty) + partitionIdx, + ctx.broadcastedHadoopConfForEncryption, + ctx.encryptedFilePaths, + shuffleBlockIters) + + // Register subqueries against the iterator id so native callbacks resolve them to values. 
+ ctx.subqueries.foreach { sub => + CometScalarSubquery.setSubquery(cometIter.id, sub) + } + Option(context).foreach { taskCtx => + taskCtx.addTaskCompletionListener[Unit] { _ => + ctx.subqueries.foreach { sub => + CometScalarSubquery.removeSubquery(cometIter.id, sub) + } + } + } while (cometIter.hasNext) { cometIter.next() @@ -134,7 +183,7 @@ class CometNativeShuffleWriter[K, V]( // Report spill metrics to Spark's task metrics so they appear in // Spark UI task summaries (not just SQL metrics) - val spilledBytes = nativeSQLMetrics.get("spilled_bytes").map(_.value).getOrElse(0L) + val spilledBytes = shuffleWriterSQLMetrics.get("spilled_bytes").map(_.value).getOrElse(0L) if (spilledBytes > 0) { context.taskMetrics().incMemoryBytesSpilled(spilledBytes) context.taskMetrics().incDiskBytesSpilled(spilledBytes) @@ -162,163 +211,149 @@ class CometNativeShuffleWriter[K, V]( case _ => false } - private def getNativePlan(dataFile: String, indexFile: String): Operator = { - val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") - val opBuilder = OperatorOuterClass.Operator.newBuilder() - - val scanTypes = outputAttributes.flatten { attr => - serializeDataType(attr.dataType) + /** + * Build the unified `ShuffleWriter(child = childNativeOp)` plan with the partitioning serde, + * compression settings, and output file paths. 
+ */ + private def buildUnifiedPlan(dataFile: String, indexFile: String): Operator = { + val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() + shuffleWriterBuilder.setOutputDataFile(dataFile) + shuffleWriterBuilder.setOutputIndexFile(indexFile) + + if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { + val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { + case "zstd" => CompressionCodec.Zstd + case "lz4" => CompressionCodec.Lz4 + case "snappy" => CompressionCodec.Snappy + case other => throw new UnsupportedOperationException(s"invalid codec: $other") + } + shuffleWriterBuilder.setCodec(codec) + } else { + shuffleWriterBuilder.setCodec(CompressionCodec.None) } - - if (scanTypes.length == outputAttributes.length) { - scanBuilder.addAllFields(scanTypes.asJava) - - val shuffleWriterBuilder = OperatorOuterClass.ShuffleWriter.newBuilder() - shuffleWriterBuilder.setOutputDataFile(dataFile) - shuffleWriterBuilder.setOutputIndexFile(indexFile) - - if (SparkEnv.get.conf.getBoolean("spark.shuffle.compress", true)) { - val codec = CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_CODEC.get() match { - case "zstd" => CompressionCodec.Zstd - case "lz4" => CompressionCodec.Lz4 - case "snappy" => CompressionCodec.Snappy - case other => throw new UnsupportedOperationException(s"invalid codec: $other") + shuffleWriterBuilder.setCompressionLevel( + CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) + shuffleWriterBuilder.setWriteBufferSize( + CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) + + outputPartitioning match { + case p if isSinglePartitioning(p) => + val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setSinglePartition(partitioning).build()) + case _: HashPartitioning => + val hashPartitioning = 
outputPartitioning.asInstanceOf[HashPartitioning] + val partitioning = PartitioningOuterClass.HashPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + val partitionExprs = hashPartitioning.expressions + .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) + + if (partitionExprs.length != hashPartitioning.expressions.length) { + throw new UnsupportedOperationException( + s"Partitioning $hashPartitioning is not supported.") } - shuffleWriterBuilder.setCodec(codec) - } else { - shuffleWriterBuilder.setCodec(CompressionCodec.None) - } - shuffleWriterBuilder.setCompressionLevel( - CometConf.COMET_EXEC_SHUFFLE_COMPRESSION_ZSTD_LEVEL.get) - shuffleWriterBuilder.setWriteBufferSize( - CometConf.COMET_SHUFFLE_WRITE_BUFFER_SIZE.get().min(Int.MaxValue).toInt) - outputPartitioning match { - case p if isSinglePartitioning(p) => - val partitioning = PartitioningOuterClass.SinglePartition.newBuilder() - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setSinglePartition(partitioning).build()) - case _: HashPartitioning => - val hashPartitioning = outputPartitioning.asInstanceOf[HashPartitioning] - - val partitioning = PartitioningOuterClass.HashPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) + partitioning.addAllHashExpression(partitionExprs.asJava) + + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setHashPartition(partitioning).build()) + case _: RangePartitioning => + val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] + val partitioning = PartitioningOuterClass.RangePartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + + // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering + // DataFusion will deduplicate identical 
sort expressions in LexOrdering, + // so we need to transform boundary rows to match the deduplicated structure + val seenExprs = mutable.HashSet[Expression]() + val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) + + rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => + if (seenExprs.contains(sortOrder.child)) { + deduplicationMap += (idx -> false) // Will be deduplicated by DataFusion + } else { + seenExprs += sortOrder.child + deduplicationMap += (idx -> true) // Will be kept by DataFusion + } + } - val partitionExprs = hashPartitioning.expressions + { + val orderingExprs = rangePartitioning.ordering .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - - if (partitionExprs.length != hashPartitioning.expressions.length) { + if (orderingExprs.length != rangePartitioning.ordering.length) { throw new UnsupportedOperationException( - s"Partitioning $hashPartitioning is not supported.") - } - - partitioning.addAllHashExpression(partitionExprs.asJava) - - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setHashPartition(partitioning).build()) - case _: RangePartitioning => - val rangePartitioning = outputPartitioning.asInstanceOf[RangePartitioning] - - val partitioning = PartitioningOuterClass.RangePartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - - // Detect duplicates by tracking expressions directly, similar to DataFusion's LexOrdering - // DataFusion will deduplicate identical sort expressions in LexOrdering, - // so we need to transform boundary rows to match the deduplicated structure - val seenExprs = mutable.HashSet[Expression]() - val deduplicationMap = mutable.ArrayBuffer[(Int, Boolean)]() // (originalIndex, isKept) - - rangePartitioning.ordering.zipWithIndex.foreach { case (sortOrder, idx) => - if (seenExprs.contains(sortOrder.child)) { - deduplicationMap += 
(idx -> false) // Will be deduplicated by DataFusion - } else { - seenExprs += sortOrder.child - deduplicationMap += (idx -> true) // Will be kept by DataFusion - } - } - - { - // Serialize the ordering expressions for comparisons - val orderingExprs = rangePartitioning.ordering - .flatMap(e => QueryPlanSerde.exprToProto(e, outputAttributes)) - if (orderingExprs.length != rangePartitioning.ordering.length) { - throw new UnsupportedOperationException( - s"Partitioning $rangePartitioning is not supported.") - } - partitioning.addAllSortOrders(orderingExprs.asJava) + s"Partitioning $rangePartitioning is not supported.") } + partitioning.addAllSortOrders(orderingExprs.asJava) + } - // Convert Spark's sequence of InternalRows that represent partitioning boundaries to - // sequences of Literals, where each outer entry represents a boundary row, and each - // internal entry is a value in that row. In other words, these are stored in row major - // order, not column major - val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) - - // Transform boundary rows to match DataFusion's deduplicated structure - val transformedBoundaryExprs: Seq[Seq[Literal]] = - rangePartitionBounds.get.map((row: InternalRow) => { - // For every InternalRow, map its values to Literals - val allLiterals = - row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => - Literal(value, valueType) - } - - // Keep only the literals that correspond to non-deduplicated expressions - allLiterals - .zip(deduplicationMap) - .filter(_._2._2) // Keep only where isKept = true - .map(_._1) // Extract the literal + val boundarySchema = rangePartitioning.ordering.flatMap(e => Some(e.dataType)) + + // rangePartitionBounds holds Spark InternalRows of partitioning boundaries: each row is a + // boundary, each entry a value in that row (row-major, not column-major). 
Convert to + // Literals and keep only the entries whose ordering expression survived deduplication, so + // the boundary shape matches DataFusion's deduplicated LexOrdering. + val transformedBoundaryExprs: Seq[Seq[Literal]] = + rangePartitionBounds.get.map((row: InternalRow) => { + val allLiterals = + row.toSeq(boundarySchema).zip(boundarySchema).map { case (value, valueType) => + Literal(value, valueType) + } + allLiterals + .zip(deduplicationMap) + .filter(_._2._2) + .map(_._1) + }) + + { + val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs + .map((rowLiterals: Seq[Literal]) => { + val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); + val serializedExprs = + rowLiterals.map(lit_value => + QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) + rowBuilder.addAllPartitionBounds(serializedExprs.asJava) + rowBuilder.build() }) + partitioning.addAllBoundaryRows(boundaryRows.asJava) + } - { - // Convert the sequences of Literals to a collection of serialized BoundaryRows - val boundaryRows: Seq[PartitioningOuterClass.BoundaryRow] = transformedBoundaryExprs - .map((rowLiterals: Seq[Literal]) => { - // Serialize each sequence of Literals as a BoundaryRow - val rowBuilder = PartitioningOuterClass.BoundaryRow.newBuilder(); - val serializedExprs = - rowLiterals.map(lit_value => - QueryPlanSerde.exprToProto(lit_value, outputAttributes).get) - rowBuilder.addAllPartitionBounds(serializedExprs.asJava) - rowBuilder.build() - }) - partitioning.addAllBoundaryRows(boundaryRows.asJava) - } + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRangePartition(partitioning).build()) - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRangePartition(partitioning).build()) + case _: RoundRobinPartitioning => + val partitioning = 
PartitioningOuterClass.RoundRobinPartition.newBuilder() + partitioning.setNumPartitions(outputPartitioning.numPartitions) + partitioning.setMaxHashColumns( + CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) - case _: RoundRobinPartitioning => - val partitioning = PartitioningOuterClass.RoundRobinPartition.newBuilder() - partitioning.setNumPartitions(outputPartitioning.numPartitions) - partitioning.setMaxHashColumns( - CometConf.COMET_EXEC_SHUFFLE_WITH_ROUND_ROBIN_PARTITIONING_MAX_HASH_COLUMNS.get()) + val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() + shuffleWriterBuilder.setPartitioning( + partitioningBuilder.setRoundRobinPartition(partitioning).build()) - val partitioningBuilder = PartitioningOuterClass.Partitioning.newBuilder() - shuffleWriterBuilder.setPartitioning( - partitioningBuilder.setRoundRobinPartition(partitioning).build()) + case _ => + throw new UnsupportedOperationException( + s"Partitioning $outputPartitioning is not supported.") + } - case _ => - throw new UnsupportedOperationException( - s"Partitioning $outputPartitioning is not supported.") - } + shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) - shuffleWriterBuilder.setTracingEnabled(CometConf.COMET_TRACING_ENABLED.get()) + // Used by the native planner to cast the inlined child's output when DataFusion's + // declared return type drifts from Spark catalyst (see comet#4515). 
+ val expectedFields = outputAttributes + .map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)) + .toArray + schema2Proto(expectedFields).foreach(shuffleWriterBuilder.addExpectedOutputSchema) - val shuffleWriterOpBuilder = OperatorOuterClass.Operator.newBuilder() - shuffleWriterOpBuilder - .setShuffleWriter(shuffleWriterBuilder) - .addChildren(opBuilder.setScan(scanBuilder).build()) - .build() - } else { - // There are unsupported scan type - throw new UnsupportedOperationException( - s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") - } + OperatorOuterClass.Operator + .newBuilder() + .setShuffleWriter(shuffleWriterBuilder) + .addChildren(spec.childNativeOp) + .build() } override def stop(success: Boolean): Option[MapStatus] = { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala index 2b74e5a168..2a05843007 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleDependency.scala @@ -28,11 +28,29 @@ import org.apache.spark.shuffle.ShuffleWriteProcessor import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.comet.{CometMetricNode, NativeExecContext} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types.StructType +import org.apache.comet.serde.OperatorOuterClass + +/** + * Bundle of context the native shuffle write path needs at task time. Co-populated for native + * shuffles only; consolidated into a single field on [[CometShuffleDependency]] so it cannot be + * partially set. 
+ */ +case class NativeShuffleSpec( + childNativeOp: OperatorOuterClass.Operator, + childMetricNode: CometMetricNode, + execContext: NativeExecContext) + /** * A [[ShuffleDependency]] that allows us to identify the shuffle dependency as a Comet shuffle. + * + * On the native-shuffle path, also carries a [[NativeShuffleSpec]] so + * [[CometNativeShuffleWriter]] can drive the unified `ShuffleWriter(child = childNativeOp)` plan + * in a single [[org.apache.comet.CometExecIterator]] per partition. `nativeShuffleSpec` is + * populated only when `shuffleType == CometNativeShuffle`. */ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( @transient private val _rdd: RDD[_ <: Product2[K, V]], @@ -49,7 +67,8 @@ class CometShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val outputAttributes: Seq[Attribute] = Seq.empty, val shuffleWriteMetrics: Map[String, SQLMetric] = Map.empty, val numParts: Int = 0, - val rangePartitionBounds: Option[Seq[InternalRow]] = None) + val rangePartitionBounds: Option[Seq[InternalRow]] = None, + val nativeShuffleSpec: Option[NativeShuffleSpec] = None) extends ShuffleDependency[K, V, C]( _rdd, partitioner, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index ee8b716ea3..565183278d 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -34,13 +34,14 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Exp import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.logical.Statistics import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, 
CometSinkPlaceHolder} +import org.apache.spark.sql.comet.{CometMetricNode, CometNativeExec, CometPlan, CometSinkPlaceHolder, NativeExecContext} +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.ShuffleQueryStageExec import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, ShuffleExchangeExec, ShuffleExchangeLike, ShuffleOrigin} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, ShortType, StringType, StructType, TimestampNTZType, TimestampType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType} import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.MutablePair import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator} @@ -101,13 +102,35 @@ case class CometShuffleExchangeExec( private lazy val serializer: Serializer = new UnsafeRowSerializer(child.output.size, longMetric("dataSize")) + /** + * Single-driver native-shuffle context, computed once and shared between [[inputRDD]] and + * [[shuffleDependency]]. `Some` only when `shuffleType == CometNativeShuffle` AND the child is + * a [[CometNativeExec]] subtree. Otherwise the dep is built via the + * [[CometShuffleExchangeExec.prepareShuffleDependency]] convenience overload (synthetic Scan + * placeholder). 
+ */ + @transient private lazy val nativeChildContext: Option[NativeExecContext] = child match { + case nativeChild: CometNativeExec if shuffleType == CometNativeShuffle => + Some(nativeChild.buildNativeContext()) + case _ => None + } + @transient lazy val inputRDD: RDD[_] = if (shuffleType == CometNativeShuffle) { - // CometNativeShuffle assumes that the input plan is Comet plan. - child.executeColumnar() + nativeChildContext match { + case Some(ctx) => + new CometNativeShuffleInputRDD( + sparkContext, + ctx.inputs, + ctx.numPartitions, + ctx.shuffleScanIndices) + case None => + // Non-native child (e.g. CometSparkToColumnarExec): no subtree to inline. The dep gets + // built via the convenience overload below; we just need a real RDD of batches. + child.executeColumnar() + } } else if (shuffleType == CometColumnarShuffle) { - // CometColumnarShuffle uses Spark's row-based execute() API. For Spark row-based plans, - // rows flow directly. For Comet native plans, their doExecute() wraps with ColumnarToRowExec - // to convert columnar batches to rows. + // Row-based shuffle. CometNativeExec.doExecute wraps columnar output with + // ColumnarToRowExec; non-Comet children flow through directly. child.execute() } else { throw new UnsupportedOperationException( @@ -149,12 +172,34 @@ case class CometShuffleExchangeExec( @transient lazy val shuffleDependency: ShuffleDependency[Int, _, _] = if (shuffleType == CometNativeShuffle) { - val dep = CometShuffleExchangeExec.prepareShuffleDependency( - inputRDD.asInstanceOf[RDD[ColumnarBatch]], - child.output, - outputPartitioning, - serializer, - metrics) + val dep = nativeChildContext match { + case Some(ctx) => + val nativeChild = child.asInstanceOf[CometNativeExec] + // RangePartitioner needs real rows for sampling. Reuse the precomputed context so we + // don't re-walk the SparkPlan tree or re-broadcast the encryption Hadoop conf. 
+ val samplingRDD: Option[RDD[ColumnarBatch]] = outputPartitioning match { + case _: RangePartitioning => Some(nativeChild.executeColumnarWithContext(ctx)) + case _ => None + } + CometShuffleExchangeExec.prepareNativeShuffleDependency( + inputRDD.asInstanceOf[CometNativeShuffleInputRDD], + samplingRDD, + child.output, + outputPartitioning, + serializer, + metrics, + NativeShuffleSpec( + nativeChild.nativeOp, + CometMetricNode.fromCometPlan(nativeChild), + ctx)) + case None => + CometShuffleExchangeExec.prepareShuffleDependency( + inputRDD.asInstanceOf[RDD[ColumnarBatch]], + child.output, + outputPartitioning, + serializer, + metrics) + } metrics("numPartitions").set(dep.partitioner.numPartitions) val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY) SQLMetrics.postDriverMetricUpdates( @@ -624,21 +669,112 @@ object CometShuffleExchangeExec } } + /** + * Build a Comet native shuffle dependency around an existing `RDD[ColumnarBatch]` of real + * batches. Used by [[org.apache.spark.sql.comet.CometCollectLimitExec]] and + * [[org.apache.spark.sql.comet.CometTakeOrderedAndProjectExec]] where the input is the result + * of a local-limit / topK transform and there is no separate child native subtree to inline. + * + * Implemented as a thin wrapper around [[prepareNativeShuffleDependency]]: synthesizes a + * `Scan("ShuffleWriterInput")` as the child native op (so the writer's plan is still + * `ShuffleWriter -> Scan`, consuming JVM batches via Arrow C Stream), wraps `rdd` as the single + * leaf input of a thin scheduling RDD, and supplies a minimal [[NativeExecContext]]. Lets the + * writer use one code path for both this case and the [[CometShuffleExchangeExec]] case. 
+ */ def prepareShuffleDependency( rdd: RDD[ColumnarBatch], outputAttributes: Seq[Attribute], outputPartitioning: Partitioning, serializer: Serializer, metrics: Map[String, SQLMetric]): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { - val numParts = rdd.getNumPartitions + + val scanBuilder = OperatorOuterClass.Scan.newBuilder().setSource("ShuffleWriterInput") + val scanTypes = outputAttributes.flatMap { attr => + QueryPlanSerde.serializeDataType(attr.dataType) + } + if (scanTypes.length != outputAttributes.length) { + throw new UnsupportedOperationException( + s"$outputAttributes contains unsupported data types for CometShuffleExchangeExec.") + } + scanBuilder.addAllFields(scanTypes.asJava) + val scanOp = OperatorOuterClass.Operator.newBuilder().setScan(scanBuilder).build() + + // Wrap the raw batches as an RDD[ArrowArrayStream] so the leaf reaches native via the Arrow C + // Stream Interface, matching how CometNativeExec.buildNativeContext feeds the native-child + // path. The synthetic Scan("ShuffleWriterInput") above is the native consumer. + val streamRDD = CometArrowStream.wrapColumnarBatchRDD( + rdd, + StructType( + outputAttributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))), + CometArrowStream.NATIVE_TIMEZONE, + "ShuffleWriterInput") + + val thinRDD = new CometNativeShuffleInputRDD( + rdd.sparkContext, + Seq(streamRDD), + rdd.getNumPartitions, + shuffleScanIndices = Set.empty) + + val ctx = NativeExecContext( + inputs = Seq(streamRDD), + numPartitions = rdd.getNumPartitions, + subqueries = Seq.empty, + broadcastedHadoopConfForEncryption = None, + encryptedFilePaths = Seq.empty, + commonByKey = Map.empty, + perPartitionByKey = Map.empty, + shuffleScanIndices = Set.empty, + hasScanInput = false) + + // The Scan placeholder has no per-operator metrics, so the metric tree for the unified plan + // is `shuffleWriterMetrics` at the root with one empty leaf for the Scan child. 
+ prepareNativeShuffleDependency( + thinRDD, + Some(rdd), + outputAttributes, + outputPartitioning, + serializer, + metrics, + NativeShuffleSpec(scanOp, CometMetricNode(Map.empty), ctx)) + } + + /** + * Build a Comet native shuffle dependency for the [[CometShuffleExchangeExec]] case where the + * shuffle is fed by a [[CometNativeExec]] child. The writer drives the unified + * `ShuffleWriter(child = childNativeOp)` plan in a single + * [[org.apache.comet.CometExecIterator]] per partition. The returned dep carries the + * [[NativeShuffleSpec]] so [[CometNativeShuffleWriter]] can reach the child's per-partition + * execution context, root native operator, and metric node at task time. + * + * @param thinRDD + * scheduling-anchor RDD whose `compute` returns a [[CometNativeShuffleInputIterator]]; + * produces no batches itself. + * @param samplingRDD + * regular columnar execution of the child, only required for [[RangePartitioning]] (sampling + * needs real rows). `None` for hash / single / round-robin. + */ + def prepareNativeShuffleDependency( + thinRDD: CometNativeShuffleInputRDD, + samplingRDD: Option[RDD[ColumnarBatch]], + outputAttributes: Seq[Attribute], + outputPartitioning: Partitioning, + serializer: Serializer, + metrics: Map[String, SQLMetric], + spec: NativeShuffleSpec): ShuffleDependency[Int, ColumnarBatch, ColumnarBatch] = { + val numParts = thinRDD.getNumPartitions // The code block below is mostly brought over from // ShuffleExchangeExec::prepareShuffleDependency val (partitioner, rangePartitionBounds) = outputPartitioning match { case rangePartitioning: RangePartitioning => + // Sampling needs real rows; use the dedicated samplingRDD (a regular columnar execution + // of the child). The thin RDD itself yields nothing. 
+ val samplingInput = samplingRDD.getOrElse( + throw new IllegalStateException( + "RangePartitioning requires a samplingRDD on the native-shuffle path")) // Extract only fields used for sorting to avoid collecting large fields that does not // affect sorting result when deciding partition bounds in RangePartitioner - val rddForSampling = rdd.mapPartitionsInternal { iter => + val rddForSampling = samplingInput.mapPartitionsInternal { iter => val projection = UnsafeProjection.create(rangePartitioning.ordering.map(_.child), outputAttributes) val mutablePair = new MutablePair[InternalRow, Null]() @@ -683,10 +819,8 @@ object CometShuffleExchangeExec None) } - val dependency = new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( - rdd.map( - (0, _) - ), // adding fake partitionId that is always 0 because ShuffleDependency requires it + new CometShuffleDependency[Int, ColumnarBatch, ColumnarBatch]( + thinRDD, serializer = serializer, shuffleWriterProcessor = ShuffleExchangeExec.createShuffleWriteProcessor(metrics), shuffleType = CometNativeShuffle, @@ -696,8 +830,8 @@ object CometShuffleExchangeExec outputAttributes = outputAttributes, shuffleWriteMetrics = metrics, numParts = numParts, - rangePartitionBounds = rangePartitionBounds) - dependency + rangePartitionBounds = rangePartitionBounds, + nativeShuffleSpec = Some(spec)) } /** diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala index c8f2199d53..bd69e91898 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleManager.scala @@ -231,6 +231,7 @@ class CometShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { case cometShuffleHandle: CometNativeShuffleHandle[K @unchecked, V @unchecked] => val dep = 
cometShuffleHandle.dependency.asInstanceOf[CometShuffleDependency[_, _, _]] new CometNativeShuffleWriter( + dep.nativeShuffleSpec.get, dep.outputPartitioning.get, dep.outputAttributes, dep.shuffleWriteMetrics, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index 8cbf7c9189..3122bcfc80 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.comet.execution.arrow.{CometArrowStream, CometNativeArrowSource} import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec import org.apache.spark.sql.comet.util.Utils import org.apache.spark.sql.execution._ @@ -311,13 +312,13 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, numParts: Int, partitionIdx: Int): CometExecIterator = { getCometIterator( - inputs, + inputObjects, numOutputCols, nativePlan, CometMetricNode(Map.empty), @@ -332,14 +333,14 @@ object CometExec { * executing the same plan across multiple partitions to avoid serializing the plan repeatedly. 
*/ def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, serializedPlan: Array[Byte], numParts: Int, partitionIdx: Int): CometExecIterator = { new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, serializedPlan, CometMetricNode(Map.empty), @@ -350,7 +351,7 @@ object CometExec { } def getCometIterator( - inputs: Seq[Iterator[ColumnarBatch]], + inputObjects: Array[Object], numOutputCols: Int, nativePlan: Operator, nativeMetrics: CometMetricNode, @@ -361,7 +362,7 @@ object CometExec { val bytes = serializeNativePlan(nativePlan) new CometExecIterator( newIterId, - inputs, + inputObjects, numOutputCols, bytes, nativeMetrics, @@ -381,6 +382,33 @@ object CometExec { } } +/** + * Per-partition execution context for a native subtree rooted at a [[CometNativeExec]] boundary. + * Built once on the driver from the SparkPlan tree, then consumed by either + * [[CometNativeExec.executeColumnarWithContext]] (to build a [[CometExecRDD]]) or the + * native-shuffle path (to drive [[CometNativeShuffleWriter]]). Captures broadcast partition + * alignment, plan-data, subqueries, and encryption options so each consumer doesn't re-walk the + * tree. + */ +private[comet] case class NativeExecContext( + inputs: Seq[RDD[_]], + numPartitions: Int, + subqueries: Seq[ScalarSubquery], + broadcastedHadoopConfForEncryption: Option[Broadcast[SerializableConfiguration]], + encryptedFilePaths: Seq[String], + commonByKey: Map[String, Array[Byte]], + perPartitionByKey: Map[String, Array[Array[Byte]]], + shuffleScanIndices: Set[Int], + hasScanInput: Boolean) { + // Catch shape divergence (e.g. broadcast scans with different partition counts after DPP + // filtering) at construction so consumers don't trip ArrayIndexOutOfBoundsException at + // partition idx access time. 
+ require( + perPartitionByKey.values.forall(_.length == numPartitions), + s"All per-partition arrays must have length $numPartitions, but found: " + + perPartitionByKey.map { case (key, arr) => s"$key -> ${arr.length}" }.mkString(", ")) +} + /** * A Comet native physical operator. */ @@ -419,171 +447,232 @@ abstract class CometNativeExec extends CometExec { runningSubqueries.clear() } - override def doExecuteColumnar(): RDD[ColumnarBatch] = { - serializedPlanOpt.plan match { - case None => - // This is in the middle of a native execution, it should not be executed directly. - throw new CometRuntimeException( - s"CometNativeExec should not be executed directly without a serialized plan: $this") - case Some(serializedPlan) => - val serializedPlanCopy = serializedPlan - // TODO: support native metrics for all operators. - val nativeMetrics = CometMetricNode.fromCometPlan(this) - - // Go over all the native scans, in order to see if they need encryption options. - // For each relation in a CometNativeScan generate a hadoopConf, - // for each file path in a relation associate with hadoopConf - // This is done per native plan, so only count scans until a comet input is reached. - val encryptionOptions = - mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] - foreachUntilCometInput(this) { - case scan: CometNativeScanExec => - // This creates a hadoopConf that brings in any SQLConf "spark.hadoop.*" configs and - // per-relation configs since different tables might have different decryption - // properties. - val hadoopConf = scan.relation.sparkSession.sessionState - .newHadoopConfWithOptions(scan.relation.options) - val encryptionEnabled = CometParquetUtils.encryptionEnabled(hadoopConf) - if (encryptionEnabled) { - // hadoopConf isn't serializable, so we have to do a broadcasted config. 
- val broadcastedConf = - scan.relation.sparkSession.sparkContext - .broadcast(new SerializableConfiguration(hadoopConf)) - - val optsTuple: (Broadcast[SerializableConfiguration], Seq[String]) = - (broadcastedConf, scan.relation.inputFiles.toSeq) - encryptionOptions += optsTuple - } - case _ => // no-op + override def doExecuteColumnar(): RDD[ColumnarBatch] = + executeColumnarWithContext(buildNativeContext()) + + /** + * Build a [[CometExecRDD]] from a precomputed [[NativeExecContext]]. Visible package-wide + * (`private[comet]`) so the native shuffle path can sample (RangePartitioning) without + * re-walking the SparkPlan tree and re-broadcasting the encryption Hadoop conf. + */ + private[comet] def executeColumnarWithContext(ctx: NativeExecContext): RDD[ColumnarBatch] = { + val serializedPlan = serializedPlanOpt.plan.getOrElse( + throw new CometRuntimeException( + s"CometNativeExec should not be executed directly without a serialized plan: $this")) + val nativeMetrics = CometMetricNode.fromCometPlan(this) + + new CometExecRDD( + sparkContext, + ctx.inputs, + ctx.commonByKey, + ctx.perPartitionByKey, + serializedPlan, + ctx.numPartitions, + output.length, + nativeMetrics, + ctx.subqueries, + ctx.broadcastedHadoopConfForEncryption, + ctx.encryptedFilePaths, + ctx.shuffleScanIndices) { + override def compute(split: Partition, context: TaskContext): Iterator[ColumnarBatch] = { + val res = super.compute(split, context) + if (ctx.hasScanInput) { + Option(context).foreach(nativeMetrics.reportScanInputMetrics) } - assert( - encryptionOptions.size <= 1, - "We expect one native scan that requires encryption reading in a Comet plan," + - " since we will broadcast one hadoopConf.") - // If this assumption changes in the future, you can look at the commit history of #2447 - // to see how there used to be a map of relations to broadcasted confs in case multiple - // relations in a single plan. The example that came up was UNION. 
See discussion at: - // https://github.com/apache/datafusion-comet/pull/2447#discussion_r2406118264 - val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = - encryptionOptions.headOption match { - case Some((conf, paths)) => (Some(conf), paths) - case None => (None, Seq.empty) - } + res + } + } + } - // Find planning data within this stage (stops at shuffle boundaries). - val (commonByKey, perPartitionByKey) = findAllPlanData(this) + /** + * Walk this CometNativeExec subtree once and gather everything needed to launch native + * execution. See [[NativeExecContext]] for the field set. + */ + private[comet] def buildNativeContext(): NativeExecContext = { + // Find native scans that need encryption: build a hadoopConf per relation, broadcast it once + // so executors can decrypt on read. Capped at one because we only broadcast one conf per + // CometExecIterator (see #2447 for history of the per-relation map approach). + val encryptionOptions = + mutable.ArrayBuffer.empty[(Broadcast[SerializableConfiguration], Seq[String])] + foreachUntilCometInput(this) { + case scan: CometNativeScanExec => + // Bring in any SQLConf "spark.hadoop.*" configs and the per-relation options, since + // different tables may have different decryption properties. + val hadoopConf = scan.relation.sparkSession.sessionState + .newHadoopConfWithOptions(scan.relation.options) + if (CometParquetUtils.encryptionEnabled(hadoopConf)) { + // hadoopConf isn't serializable, so ship it to executors via a broadcast. 
+ val broadcastedConf = scan.relation.sparkSession.sparkContext + .broadcast(new SerializableConfiguration(hadoopConf)) + encryptionOptions += ((broadcastedConf, scan.relation.inputFiles.toSeq)) + } + case _ => + } + assert( + encryptionOptions.size <= 1, + "We expect one native scan that requires encryption reading in a Comet plan," + + " since we will broadcast one hadoopConf.") + val (broadcastedHadoopConfForEncryption, encryptedFilePaths) = + encryptionOptions.headOption match { + case Some((conf, paths)) => (Some(conf), paths) + case None => (None, Seq.empty) + } - // Collect the input ColumnarBatches from the child operators and create a CometExecIterator - // to execute the native plan. - val sparkPlans = ArrayBuffer.empty[SparkPlan] - val inputs = ArrayBuffer.empty[RDD[ColumnarBatch]] + // Find planning data within this stage (stops at shuffle boundaries). + val (commonByKey, perPartitionByKey) = findAllPlanData(this) - foreachUntilCometInput(this)(sparkPlans += _) + // Collect the input batches from the child operators. Non-shuffle inputs become + // RDD[ArrowArrayStream] (one stream per partition, exported via the C Stream Interface + // for native consumption); shuffle inputs stay as CometShuffledBatchRDD. 
+ val sparkPlans = ArrayBuffer.empty[SparkPlan] + val inputs = ArrayBuffer.empty[RDD[_]] - // Find the first non broadcast plan - val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find { - case (_: CometBroadcastExchangeExec, _) => false - case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false - case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false - case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false - case _ => true - } + foreachUntilCometInput(this)(sparkPlans += _) - val containsBroadcastInput = sparkPlans.exists { - case _: CometBroadcastExchangeExec => true - case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true - case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true - case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true - case _ => false - } + // Find the first non broadcast plan + val firstNonBroadcastPlan = sparkPlans.zipWithIndex.find { + case (_: CometBroadcastExchangeExec, _) => false + case (BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _), _) => false + case (BroadcastQueryStageExec(_, _: ReusedExchangeExec, _), _) => false + case (ReusedExchangeExec(_, _: CometBroadcastExchangeExec), _) => false + case _ => true + } - // If the first non broadcast plan is not found, it means all the plans are broadcast plans. - // This is not expected, so throw an exception. 
- if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) { - throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") - } + val containsBroadcastInput = sparkPlans.exists { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true + case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true + case _ => false + } - // If the first non broadcast plan is found, we need to adjust the partition number of - // the broadcast plans to make sure they have the same partition number as the first non - // broadcast plan. - val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = - firstNonBroadcastPlan.get._1 match { - case plan: CometNativeExec => - (null, plan.outputPartitioning.numPartitions) - case plan => - val rdd = plan.executeColumnar() - (rdd, rdd.getNumPartitions) - } + // If the first non broadcast plan is not found, it means all the plans are broadcast plans. + // This is not expected, so throw an exception. + if (containsBroadcastInput && firstNonBroadcastPlan.isEmpty) { + throw new CometRuntimeException(s"Cannot find the first non broadcast plan: $this") + } - // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with - // same partition number. But for Comet, we need to zip them so we need to adjust the - // partition number of Broadcast RDDs to make sure they have the same partition number. 
- sparkPlans.zipWithIndex.foreach { case (plan, idx) => - plan match { - case c: CometBroadcastExchangeExec => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case BroadcastQueryStageExec( - _, - ReusedExchangeExec(_, c: CometBroadcastExchangeExec), - _) => - inputs += c.executeColumnar(firstNonBroadcastPlanNumPartitions) - case _: CometNativeExec => - // no-op - case _ if idx == firstNonBroadcastPlan.get._2 => - inputs += firstNonBroadcastPlanRDD - case _ => - val rdd = plan.executeColumnar() - if (rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { - throw new CometRuntimeException( - s"Partition number mismatch: ${rdd.getNumPartitions} != " + - s"$firstNonBroadcastPlanNumPartitions") - } else { - inputs += rdd - } - } - } + def isShuffleScanInput(plan: SparkPlan): Boolean = plan match { + case _: CometShuffleExchangeExec | _: ShuffleQueryStageExec | _: AQEShuffleReadExec => + true + case ReusedExchangeExec(_, _: CometShuffleExchangeExec) => true + case _ => false + } - if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) { - throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") - } + // The protobuf is the source of truth for whether a slot is a ShuffleScan or a regular + // Scan: `CometExchangeSink.shouldUseShuffleScan` only fires for AQE wrappers + // (`ShuffleQueryStageExec`), so a bare non-AQE `CometShuffleExchangeExec` always serializes + // as a regular Scan regardless of `COMET_SHUFFLE_DIRECT_READ_ENABLED`. Driving the JVM + // dispatch from `shuffleScanIndices` instead of the conf keeps the two aligned. 
+ val shuffleScanIndices = findShuffleScanIndices(nativeOp) - // Detect ShuffleScan indices for direct read in CometExecRDD - val shuffleScanIndices = findShuffleScanIndices(serializedPlanCopy) - - // Unified RDD creation - CometExecRDD handles all cases - val subqueries = collectSubqueries(this) - val hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec]) - new CometExecRDD( - sparkContext, - inputs.toSeq, - commonByKey, - perPartitionByKey, - serializedPlanCopy, - firstNonBroadcastPlanNumPartitions, - output.length, - nativeMetrics, - subqueries, - broadcastedHadoopConfForEncryption, - encryptedFilePaths, - shuffleScanIndices) { - override def compute( - split: Partition, - context: TaskContext): Iterator[ColumnarBatch] = { - val res = super.compute(split, context) - - // Report scan input metrics only when the native plan contains a scan. - if (hasScanInput) { - Option(context).foreach(nativeMetrics.reportScanInputMetrics) - } + def isBroadcastInput(plan: SparkPlan): Boolean = plan match { + case _: CometBroadcastExchangeExec => true + case BroadcastQueryStageExec(_, _: CometBroadcastExchangeExec, _) => true + case BroadcastQueryStageExec(_, _: ReusedExchangeExec, _) => true + case ReusedExchangeExec(_, _: CometBroadcastExchangeExec) => true + case _ => false + } + + // Unwrap any number of AQE / reuse wrappers to find a CometBroadcastExchangeExec, if + // present. Returns the unwrapped exchange for input wiring -- broadcast partition counts + // are coerced to match firstNonBroadcastPlanNumPartitions, so we always read from the + // underlying exchange directly. 
+ def asBroadcastExchange(plan: SparkPlan): Option[CometBroadcastExchangeExec] = + plan match { + case c: CometBroadcastExchangeExec => Some(c) + case BroadcastQueryStageExec(_, c: CometBroadcastExchangeExec, _) => Some(c) + case ReusedExchangeExec(_, c: CometBroadcastExchangeExec) => Some(c) + case BroadcastQueryStageExec( + _, + ReusedExchangeExec(_, c: CometBroadcastExchangeExec), + _) => + Some(c) + case _ => None + } + + def asArrowStreamRDD(plan: SparkPlan, partitionCount: Int, scanSlot: Int): RDD[_] = + plan match { + case s: CometNativeArrowSource => + s.doExecuteAsArrowStream() + case _ if asBroadcastExchange(plan).isDefined => + val c = asBroadcastExchange(plan).get + CometArrowStream.wrapColumnarBatchRDD( + c.executeColumnar(partitionCount), + c.schema, + CometArrowStream.NATIVE_TIMEZONE, + c.nodeName) + case _ if isShuffleScanInput(plan) && shuffleScanIndices.contains(scanSlot) => + // Direct-read shuffle: `CometShuffledBatchRDD` reaches native via + // CometShuffleBlockIterator. Other shuffle slots fall through and get wrapped. + plan.executeColumnar() + case _ => + CometArrowStream.wrapColumnarBatchRDD( + plan.executeColumnar(), + plan.schema, + CometArrowStream.NATIVE_TIMEZONE, + plan.nodeName) + } - res + // Walk-order: count how many non-CometNativeExec plans come before the firstNonBroadcast + // plan in `sparkPlans`. That's the slot index it will occupy in `inputs`, and therefore + // the protobuf scan-slot index whose Scan vs ShuffleScan classification governs whether + // it should be wrapped or direct-read. + val firstNonBroadcastSlot = sparkPlans + .take(firstNonBroadcastPlan.get._2) + .count(p => !p.isInstanceOf[CometNativeExec]) + + // If the first non broadcast plan is found, we need to adjust the partition number of + // the broadcast plans to make sure they have the same partition number as the first non + // broadcast plan. 
+ val (firstNonBroadcastPlanRDD, firstNonBroadcastPlanNumPartitions) = + firstNonBroadcastPlan.get._1 match { + case plan: CometNativeExec => + (null, plan.outputPartitioning.numPartitions) + case plan => + val rdd = asArrowStreamRDD(plan, 0, firstNonBroadcastSlot) + (rdd, rdd.getNumPartitions) + } + + // Spark doesn't need to zip Broadcast RDDs, so it doesn't schedule Broadcast RDDs with + // same partition number. But for Comet, we need to zip them so we need to adjust the + // partition number of Broadcast RDDs to make sure they have the same partition number. + sparkPlans.zipWithIndex.foreach { case (plan, idx) => + plan match { + case _: CometNativeExec => + // no-op + case _ if idx == firstNonBroadcastPlan.get._2 => + inputs += firstNonBroadcastPlanRDD + case _ => + // Each plan we add to `inputs` corresponds to the next protobuf scan slot, in + // walk order. `inputs.size` is the slot index this plan will occupy. + val scanSlot = inputs.size + val rdd = asArrowStreamRDD(plan, firstNonBroadcastPlanNumPartitions, scanSlot) + if (!isBroadcastInput(plan) && + rdd.getNumPartitions != firstNonBroadcastPlanNumPartitions) { + throw new CometRuntimeException( + s"Partition number mismatch: ${rdd.getNumPartitions} != " + + s"$firstNonBroadcastPlanNumPartitions") + } else { + inputs += rdd } - } + } } + + if (inputs.isEmpty && !sparkPlans.forall(_.isInstanceOf[CometNativeExec])) { + throw new CometRuntimeException(s"No input for CometNativeExec:\n $this") + } + + NativeExecContext( + inputs = inputs.toSeq, + numPartitions = firstNonBroadcastPlanNumPartitions, + subqueries = collectSubqueries(this), + broadcastedHadoopConfForEncryption = broadcastedHadoopConfForEncryption, + encryptedFilePaths = encryptedFilePaths, + commonByKey = commonByKey, + perPartitionByKey = perPartitionByKey, + shuffleScanIndices = shuffleScanIndices, + hasScanInput = sparkPlans.exists(_.isInstanceOf[CometNativeScanExec])) } /** @@ -623,11 +712,10 @@ abstract class CometNativeExec extends 
CometExec { } /** - * Walk the serialized protobuf plan depth-first to find which input indices correspond to + * Walk the protobuf operator tree depth-first to find which input indices correspond to * ShuffleScan vs Scan leaf nodes. Each Scan or ShuffleScan leaf consumes one input in order. */ - private def findShuffleScanIndices(planBytes: Array[Byte]): Set[Int] = { - val plan = OperatorOuterClass.Operator.parseFrom(planBytes) + private def findShuffleScanIndices(plan: OperatorOuterClass.Operator): Set[Int] = { var scanIndex = 0 val indices = mutable.Set.empty[Int] def walk(op: OperatorOuterClass.Operator): Unit = { @@ -1502,17 +1590,12 @@ trait CometBaseAggregate { if (aggregateExpressions.isEmpty) { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withFallbackReason( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - Some(builder.setHashAgg(hashAggBuilder).build()) + buildAggOp( + builder, + hashAggBuilder, + groupingExpressions.map(_.toAttribute), + resultExpressions, + aggregate) } else { // Validate mode combinations. 
We support: // - All Partial @@ -1583,18 +1666,6 @@ trait CometBaseAggregate { val hashAggBuilder = OperatorOuterClass.HashAggregate.newBuilder() hashAggBuilder.addAllGroupingExprs(groupingExprs.map(_.get).asJava) hashAggBuilder.addAllAggExprs(aggExprs.map(_.get).asJava) - if (mode == CometAggregateMode.Final) { - val attributes = groupingExpressions.map(_.toAttribute) ++ aggregateAttributes - val resultExprs = resultExpressions.map(exprToProto(_, attributes)) - if (resultExprs.exists(_.isEmpty)) { - withFallbackReason( - aggregate, - s"Unsupported result expressions found in: $resultExpressions", - resultExpressions: _*) - return None - } - hashAggBuilder.addAllResultExprs(resultExprs.map(_.get).asJava) - } hashAggBuilder.setModeValue(mode.getNumber) // Send per-expression modes and buffer offset for PartialMerge handling @@ -1613,7 +1684,19 @@ trait CometBaseAggregate { hashAggBuilder.setInitialInputBufferOffset(aggregate.initialInputBufferOffset) } - Some(builder.setHashAgg(hashAggBuilder).build()) + // Final aggregations may carry a result projection (e.g. `COUNT(col) + 1`) that + // catalyst encodes via `resultExpressions`. Partial / PartialMerge aggregates emit + // raw state buffers and never need it. + if (mode == CometAggregateMode.Final) { + buildAggOp( + builder, + hashAggBuilder, + groupingExpressions.map(_.toAttribute) ++ aggregateAttributes, + resultExpressions, + aggregate) + } else { + Some(builder.setHashAgg(hashAggBuilder).build()) + } } else { val allChildren: Seq[Expression] = groupingExpressions ++ aggregateExpressions ++ aggregateAttributes @@ -1624,6 +1707,51 @@ trait CometBaseAggregate { } + /** + * Serialize a HashAggregate, wrapping it in an explicit `Projection` op when Spark's declared + * output (`resultExpressions`) differs from the aggregate's natural output. 
DataFusion's hash + * aggregate emits only its natural shape (group keys + agg results), so any reshape catalyst + * declared - alias renames, `COUNT(col) + 1`, or empty output for catalyst-pruned EXISTS / + * row-existence-only subqueries - is expressed as a separate Projection above the HashAgg. Both + * ops share the caller's `plan_id` so the aggregate's native metrics roll up under the same + * Spark operator. + */ + private def buildAggOp( + builder: Operator.Builder, + hashAggBuilder: OperatorOuterClass.HashAggregate.Builder, + naturalOutput: Seq[Attribute], + resultExpressions: Seq[NamedExpression], + aggregate: BaseAggregateExec): Option[Operator] = { + if (resultExpressions.map(_.toAttribute) == naturalOutput) { + return Some(builder.setHashAgg(hashAggBuilder).build()) + } + val resultExprs = resultExpressions.map(exprToProto(_, naturalOutput)) + if (resultExprs.exists(_.isEmpty)) { + withFallbackReason( + aggregate, + s"Unsupported result expressions found in: $resultExpressions", + resultExpressions: _*) + return None + } + val planId = builder.getPlanId + val hashAggOp = OperatorOuterClass.Operator + .newBuilder() + .setPlanId(planId) + .addAllChildren(builder.getChildrenList) + .setHashAgg(hashAggBuilder) + .build() + val projection = OperatorOuterClass.Projection + .newBuilder() + .addAllProjectList(resultExprs.map(_.get).asJava) + Some( + OperatorOuterClass.Operator + .newBuilder() + .setPlanId(planId) + .addChildren(hashAggOp) + .setProjection(projection) + .build()) + } + /** * Find the first Comet partial aggregate in the plan. If it reaches a Spark HashAggregate with * partial or partial-merge mode, it will return None. 
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala index 71d75b8ed8..15e1e2c410 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala @@ -36,6 +36,7 @@ import org.apache.arrow.vector.util.VectorSchemaRootAppender import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.io.CompressionCodec +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.comet.execution.arrow.ArrowReaderIterator import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -205,6 +206,14 @@ object Utils extends CometTypeShim with Logging { }.asJava) } + /** + * Build a `StructType` from a sequence of Spark `Attribute`s. Avoids + * `StructType.fromAttributes` (removed in Spark 4) and `DataTypeUtils.fromAttributes` (only on + * 4) so the same call works across supported Spark versions. + */ + def fromAttributes(attributes: Seq[Attribute]): StructType = + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) + /** * Serializes a list of `ColumnarBatch` into an output stream. This method must be in `spark` * package because `ChunkedByteBufferOutputStream` is spark private class. 
As it uses Arrow diff --git a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala index 9c34b3a3ce..e30a1cf6b3 100644 --- a/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometNativeSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.sql.CometTestBase import org.apache.spark.sql.catalyst.expressions.PrettyAttribute import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometExec, CometExecUtils} -import org.apache.spark.sql.types.LongType +import org.apache.spark.sql.comet.execution.arrow.CometArrowStream +import org.apache.spark.sql.types.{LongType, StructField, StructType} import org.apache.spark.sql.vectorized.ColumnarBatch class CometNativeSuite extends CometTestBase { @@ -31,15 +32,16 @@ class CometNativeSuite extends CometTestBase { val rdd = spark.range(0, 1).rdd.map { value => val limitOp = CometExecUtils.getLimitNativePlan(Seq(PrettyAttribute("test", LongType)), 100).get - val cometIter = CometExec.getCometIterator( - Seq(new Iterator[ColumnarBatch] { + val arrowStream = CometArrowStream.fromColumnarBatchIter( + new Iterator[ColumnarBatch] { override def hasNext: Boolean = true override def next(): ColumnarBatch = throw new NullPointerException() - }), - 1, - limitOp, - 1, - 0) + }, + StructType(Seq(StructField("test", LongType, nullable = false))), + CometArrowStream.NATIVE_TIMEZONE, + "test-npe") + val cometIter = + CometExec.getCometIterator(Array(arrowStream.asInstanceOf[Object]), 1, limitOp, 1, 0) try { cometIter.next() } finally { diff --git a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala index cd0beb56cc..ae14c68207 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala 
@@ -2109,4 +2109,45 @@ class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + // Regression: Catalyst prunes `HashAggregateExec.resultExpressions` to + // empty for EXISTS / row-existence-only subqueries. The native HashAggregate's natural + // output (the grouping keys) then disagrees with the pruned JVM `output`, leaking through + // any boundary that derived its schema from `output`. The fix wraps the aggregate in an + // explicit Projection op when natural != declared. + // + // Surfaced upstream in `subquery/exists-subquery/exists-orderby-limit.sql` (query #19, + // an EXISTS over `max(...) GROUP BY state LIMIT 1 OFFSET 2`). The exact `EXISTS-in-WHERE` + // shape doesn't reproduce under CometTestBase's optimizer state, but `count(*)` over the + // same derived aggregate triggers the equivalent ColumnPruning path locally - we assert + // the inner HashAgg's resultExpressions actually got pruned, so a future Spark version + // that breaks the trigger fails the test loudly rather than passing silently. 
+ test("HashAggregate with catalyst-pruned resultExpressions returns 0-col output") { + withTempDir { dir => + val deptPath = new Path(dir.toURI.toString, "dept") + spark + .sql("""SELECT * FROM VALUES + | (10, 'CA'), (20, 'NY'), (30, 'TX'), + | (40, 'OR'), (50, 'NJ'), (70, 'FL') + |AS t(dept_id, state)""".stripMargin) + .write + .parquet(deptPath.toUri.toString) + withParquetTable(deptPath.toUri.toString, "dept") { + val sql = + """SELECT count(*) FROM ( + | SELECT max(dept_id) AS m FROM dept GROUP BY state LIMIT 1 OFFSET 2) sub""".stripMargin + val plan = spark.sql(sql).queryExecution.executedPlan + val pruned = collectWithSubqueries(plan) { + case a: org.apache.spark.sql.execution.aggregate.HashAggregateExec + if a.resultExpressions.isEmpty => + a + case a: CometHashAggregateExec if a.resultExpressions.isEmpty => a + } + assert( + pruned.nonEmpty, + s"Expected a HashAggregateExec with empty resultExpressions in:\n$plan") + checkSparkAnswerAndOperator(sql) + } + } + } + } diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala index a1460427c0..c1cfc5f7d2 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala @@ -3940,6 +3940,45 @@ class CometExecSuite extends CometTestBase { } } + test("CometLocalTableScanExec falls back when schema contains TimeType") { + assume( + org.apache.comet.CometSparkSessionExtensions.isSpark41Plus, + "TimeType requires Spark 4.1+") + // spark.sql.timeType.enabled defaults to Utils.isTesting; enable explicitly so the + // row encoder accepts TIME (matches Spark's own TimeFunctionsSuiteBase setup). 
+ withSQLConf( + "spark.sql.timeType.enabled" -> "true", + CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + // VALUES folds to a LocalRelation, exercising the CometLocalTableScanExec convert + // path; the TimeType column should drive the schema-level fallback. + val df = spark.sql("SELECT * FROM VALUES (TIME '12:34:56'), (TIME '01:02:03') AS t(c)") + checkSparkAnswer(df) + } + } + + test("CometLocalTableScanExec does not leak Arrow buffers (project consumer)") { + // Forces a CometNativeExec consumer over an ArrowArrayStream input. The producer must not + // leak the Arrow buffers it allocates per batch; if it does, the BaseAllocator + // leak detector fires inside the task completion listener. + withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select($"a" + 1)) + } + } + + test("CometLocalTableScanExec does not leak Arrow buffers (collect_list)") { + // Mirrors DataFrameAggregateSuite "collect functions" which is the test that + // surfaced the leak in CI. 
+ withSQLConf(CometConf.COMET_EXEC_LOCAL_TABLE_SCAN_ENABLED.key -> "true") { + val session = spark + import session.implicits._ + val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") + checkSparkAnswer(df.select(collect_list($"a"), collect_list($"b"))) + } + } + test("Native_datafusion reports correct files and bytes scanned") { val inputFiles = 2 diff --git a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala index b858fe5c83..a2aac7e6c7 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometNativeColumnarToRowSuite.scala @@ -492,9 +492,13 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan InternalRow(i, UTF8String.fromString(s"value_$i")) } - // Create batches using rowToArrowBatchIter which handles shading internally + // Each emitted batch needs independent Arrow buffers so the test can hold rows from + // earlier batches while later batches are consumed. CometArrowConverters allocates a + // fresh VSR per batch from the supplied allocator. 
+ val allocator = + org.apache.comet.CometArrowAllocator.newChildAllocator("c2r-test", 0, Long.MaxValue) val batchIter = CometArrowConverters - .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", null) + .rowToArrowBatchIter(rows.iterator, schema, rowsPerBatch, "UTC", allocator) val converter = new NativeColumnarToRowConverter(schema, rowsPerBatch) try { @@ -529,6 +533,7 @@ class CometNativeColumnarToRowSuite extends CometTestBase with AdaptiveSparkPlan "reused UnsafeRow object.") } finally { converter.close() + allocator.close() } } diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala index 0187aed8e5..fdcca8d351 100644 --- a/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/comet/CometTaskMetricsSuite.scala @@ -106,6 +106,57 @@ class CometTaskMetricsSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("native shuffle reports task input metrics for its scan child") { + // A native shuffle whose child subtree includes a CometNativeScanExec runs that scan inside + // the writer's single plan, so the usual CometExecRDD.compute() input-metric bridge never + // runs. CometNativeShuffleWriter must report the scan's bytes/rows itself; otherwise the + // ShuffleMapTask reports zero input metrics. 
+ withParquetTable((0 until 10000).map(i => (i, (i + 1).toLong)), "tbl") { + val shuffled = sql("SELECT * FROM tbl").repartition(4, $"_1") + + val cometShuffle = find(shuffled.queryExecution.executedPlan) { + case _: CometShuffleExchangeExec => true + case _ => false + } + assert(cometShuffle.isDefined, "CometShuffleExchangeExec not found in the plan") + assert( + cometShuffle.get.asInstanceOf[CometShuffleExchangeExec].shuffleType == CometNativeShuffle) + assert( + find(shuffled.queryExecution.executedPlan) { + case _: CometNativeScanExec => true + case _ => false + }.isDefined, + "expected a CometNativeScanExec child so the scan is embedded in the writer plan") + + val mapInputBytes = mutable.ArrayBuffer.empty[Long] + val mapInputRecords = mutable.ArrayBuffer.empty[Long] + spark.sparkContext.addSparkListener(new SparkListener { + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + if (taskEnd.taskType.contains("ShuffleMapTask")) { + val im = taskEnd.taskMetrics.inputMetrics + mapInputBytes.synchronized { mapInputBytes += im.bytesRead } + mapInputRecords.synchronized { mapInputRecords += im.recordsRead } + } + } + }) + + // Avoid receiving earlier taskEnd events + spark.sparkContext.listenerBus.waitUntilEmpty() + + shuffled.collect() + + spark.sparkContext.listenerBus.waitUntilEmpty() + + assert(mapInputRecords.nonEmpty, "no ShuffleMapTask metrics captured") + assert( + mapInputRecords.sum == 10000, + s"recordsRead across map tasks (${mapInputRecords.sum}) should equal the scanned row count") + assert( + mapInputBytes.sum > 0, + s"bytesRead across map tasks (${mapInputBytes.sum}) should be > 0") + } + } + test("native parquet write reports task-level output metrics") { withParquetTable((0 until 5000).map(i => (i, (i + 1).toLong)), "tbl") { withTempPath { dir => diff --git a/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala 
b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala new file mode 100644 index 0000000000..c423a49d2a --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowStreamSuite.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.spark.sql.comet.execution.arrow + +import scala.jdk.CollectionConverters._ + +import org.scalatest.funsuite.AnyFunSuite +import org.scalatest.matchers.should.Matchers + +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.{BigIntVector, IntVector} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.vectorized.ColumnarBatch + +import org.apache.comet.vector.{CometPlainVector, CometVector} + +/** + * Direct tests for [[CometArrowStream.reconcileStreamSchema]]. The end-to-end regression that + * motivated this (Spark Long vs DataFusion Int32 for `width_bucket`) lives in + * `CometMathExpressionSuite`, but that test only catches *one* function-level type drift. This + * suite covers the boundary contract independently of any specific function. 
+ */ +class CometArrowStreamSuite extends AnyFunSuite with Matchers { + + private def expectedSchema(types: (String, ArrowType)*): Schema = { + val fields = types.map { case (name, t) => + new Field(name, new FieldType(true, t, null), java.util.Collections.emptyList[Field]()) + } + new Schema(fields.asJava) + } + + private def batchOf(vectors: CometVector*): ColumnarBatch = { + val numRows = if (vectors.isEmpty) 0 else vectors.head.getValueVector.getValueCount + new ColumnarBatch(vectors.toArray, numRows) + } + + test("reconcileStreamSchema returns expected schema unchanged on empty iterator") { + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + val (returned, iter) = + CometArrowStream.reconcileStreamSchema("test", expected, Iterator.empty) + returned shouldBe expected + iter.hasNext shouldBe false + } + + test("reconcileStreamSchema returns expected schema when types match") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + val v = new BigIntVector("col_0", allocator) + v.allocateNew() + v.setSafe(0, 1L) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + + val (returned, iter) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + returned.getFields.get(0).getType shouldBe new ArrowType.Int(64, true) + iter.hasNext shouldBe true + iter.next() should be theSameInstanceAs batch + + cv.close() + } finally { + allocator.close() + } + } + + test("reconcileStreamSchema rebuilds schema from actual vector types when they differ") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + // Producer produced Int32 (e.g., DataFusion-Spark width_bucket pre-fix), consumer expects + // Int64 (Spark catalyst WidthBucket.dataType = LongType). The truthful schema is Int32 so + // native ScanExec's build_record_batch can cast at the boundary. 
+ val v = new IntVector("col_0", allocator) + v.allocateNew() + v.setSafe(0, 1) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) + + val (returned, iter) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + val returnedField = returned.getFields.get(0) + returnedField.getType shouldBe new ArrowType.Int(32, true) + // Names come from `expected` so name-indexed consumers keep working. + returnedField.getName shouldBe "c0" + iter.hasNext shouldBe true + iter.next() should be theSameInstanceAs batch + + cv.close() + } finally { + allocator.close() + } + } + + test( + "reconcileStreamSchema preserves nullability when expected is nullable but actual is not") { + val allocator = new RootAllocator(Integer.MAX_VALUE) + try { + // Spark catalyst declares the column nullable; the first batch happens to come from a + // vector whose Field reports non-nullable. Subsequent batches may carry nulls, so the + // wire schema must stay nullable or native validation rejects the next null with + // "declared as non-nullable but contains null values". + val v = new BigIntVector( + new Field( + "col_0", + new FieldType(false, new ArrowType.Int(64, true), null), + java.util.Collections.emptyList[Field]()), + allocator) + v.allocateNew() + v.setSafe(0, 1L) + v.setValueCount(1) + val cv = new CometPlainVector(v, false) + val batch = batchOf(cv) + val expected = expectedSchema("c0" -> new ArrowType.Int(64, true)) // nullable=true + + val (returned, _) = CometArrowStream + .reconcileStreamSchema("test", expected, Iterator.single(batch)) + + val returnedField = returned.getFields.get(0) + returnedField.isNullable shouldBe true + + cv.close() + } finally { + allocator.close() + } + } +}