diff --git a/native/core/src/execution/expressions/arithmetic.rs b/native/core/src/execution/expressions/arithmetic.rs
index 71fe85ef52..a9749678db 100644
--- a/native/core/src/execution/expressions/arithmetic.rs
+++ b/native/core/src/execution/expressions/arithmetic.rs
@@ -23,7 +23,7 @@ macro_rules! arithmetic_expr_builder {
     ($builder_name:ident, $expr_type:ident, $operator:expr) => {
         pub struct $builder_name;
 
-        impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name {
+        impl $crate::execution::planner::expression_registry::ExpressionBuilder for $builder_name {
             fn build(
                 &self,
                 spark_expr: &datafusion_comet_proto::spark_expression::Expr,
@@ -61,7 +61,8 @@ use crate::execution::{
     expressions::extract_expr,
     operators::ExecutionError,
     planner::{
-        from_protobuf_eval_mode, traits::ExpressionBuilder, BinaryExprOptions, PhysicalPlanner,
+        expression_registry::ExpressionBuilder, from_protobuf_eval_mode, BinaryExprOptions,
+        PhysicalPlanner,
     },
 };
 
diff --git a/native/core/src/execution/operators/mod.rs b/native/core/src/execution/operators/mod.rs
index b01f7857be..33b9be9434 100644
--- a/native/core/src/execution/operators/mod.rs
+++ b/native/core/src/execution/operators/mod.rs
@@ -31,6 +31,7 @@ pub use expand::ExpandExec;
 mod iceberg_scan;
 mod parquet_writer;
 pub use parquet_writer::ParquetWriterExec;
+pub mod projection;
 mod scan;
 
 /// Error returned during executing operators.
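Only the trait path changes in `arithmetic.rs`: the impls that `arithmetic_expr_builder!` generates now reference `expression_registry` instead of `traits`. For orientation, the macro is invoked once per operator, along these lines (a hypothetical invocation; the builder name and operator value are illustrative, not copied from the file):

```rust
// Hypothetical: expands to `pub struct AddBuilder` plus an
// `ExpressionBuilder` impl that lowers `Add` protobuf expressions.
arithmetic_expr_builder!(AddBuilder, Add, datafusion::logical_expr::Operator::Plus);
```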
diff --git a/native/core/src/execution/operators/projection.rs b/native/core/src/execution/operators/projection.rs
new file mode 100644
index 0000000000..6ba1bb5d59
--- /dev/null
+++ b/native/core/src/execution/operators/projection.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Projection operator builder
+
+use std::sync::Arc;
+
+use datafusion::physical_plan::projection::ProjectionExec;
+use datafusion_comet_proto::spark_operator::Operator;
+use jni::objects::GlobalRef;
+
+use crate::{
+    execution::{
+        operators::{ExecutionError, ScanExec},
+        planner::{operator_registry::OperatorBuilder, PhysicalPlanner},
+        spark_plan::SparkPlan,
+    },
+    extract_op,
+};
+
+/// Builder for Projection operators
+pub struct ProjectionBuilder;
+
+impl OperatorBuilder for ProjectionBuilder {
+    fn build(
+        &self,
+        spark_plan: &Operator,
+        inputs: &mut Vec<Arc<GlobalRef>>,
+        partition_count: usize,
+        planner: &PhysicalPlanner,
+    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+        let project = extract_op!(spark_plan, Projection);
+        let children = &spark_plan.children;
+
+        assert_eq!(children.len(), 1);
+        let (scans, child) = planner.create_plan(&children[0], inputs, partition_count)?;
+
+        // Create projection expressions
+        let exprs: Result<Vec<_>, _> = project
+            .project_list
+            .iter()
+            .enumerate()
+            .map(|(idx, expr)| {
+                planner
+                    .create_expr(expr, child.schema())
+                    .map(|r| (r, format!("col_{idx}")))
+            })
+            .collect();
+
+        let projection = Arc::new(ProjectionExec::try_new(
+            exprs?,
+            Arc::clone(&child.native_plan),
+        )?);
+
+        Ok((
+            scans,
+            Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])),
+        ))
+    }
+}
diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs
index 269ded1e48..cc92310475 100644
--- a/native/core/src/execution/planner.rs
+++ b/native/core/src/execution/planner.rs
@@ -18,7 +18,8 @@
 //! Converts Spark physical plan to DataFusion physical plan
 
 pub mod expression_registry;
-pub mod traits;
+pub mod macros;
+pub mod operator_registry;
 
 use crate::execution::operators::IcebergScanExec;
 use crate::{
@@ -27,6 +28,7 @@ use crate::{
         expressions::subquery::Subquery,
         operators::{ExecutionError, ExpandExec, ParquetWriterExec, ScanExec},
         planner::expression_registry::ExpressionRegistry,
+        planner::operator_registry::OperatorRegistry,
         serde::to_arrow_datatype,
         shuffle::ShuffleWriterExec,
     },
@@ -861,29 +863,19 @@ impl PhysicalPlanner {
         inputs: &mut Vec<Arc<GlobalRef>>,
         partition_count: usize,
     ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+        // Try the modular registry first; it handles any registered operator type.
+        if OperatorRegistry::global().can_handle(spark_plan) {
+            return OperatorRegistry::global().create_plan(
+                spark_plan,
+                inputs,
+                partition_count,
+                self,
+            );
+        }
+
+        // Fall back to the original monolithic match for the remaining operators.
         let children = &spark_plan.children;
         match spark_plan.op_struct.as_ref().unwrap() {
-            OpStruct::Projection(project) => {
-                assert_eq!(children.len(), 1);
-                let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
-                let exprs: PhyExprResult = project
-                    .project_list
-                    .iter()
-                    .enumerate()
-                    .map(|(idx, expr)| {
-                        self.create_expr(expr, child.schema())
-                            .map(|r| (r, format!("col_{idx}")))
-                    })
-                    .collect();
-                let projection = Arc::new(ProjectionExec::try_new(
-                    exprs?,
-                    Arc::clone(&child.native_plan),
-                )?);
-                Ok((
-                    scans,
-                    Arc::new(SparkPlan::new(spark_plan.plan_id, projection, vec![child])),
-                ))
-            }
             OpStruct::Filter(filter) => {
                 assert_eq!(children.len(), 1);
                 let (scans, child) = self.create_plan(&children[0], inputs, partition_count)?;
@@ -1634,6 +1626,10 @@ impl PhysicalPlanner {
                     Arc::new(SparkPlan::new(spark_plan.plan_id, window_agg, vec![child])),
                 ))
             }
+            _ => Err(GeneralError(format!(
+                "Unsupported or unregistered operator type: {:?}",
+                spark_plan.op_struct
+            ))),
         }
     }
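The new control flow in `create_plan` is a registry-first dispatch with the old `match` as fallback, so operators can be migrated one at a time: registering a builder automatically diverts that operator type away from the monolithic match. Distilled to a standalone sketch (generic placeholder names, not Comet's types):

```rust
use std::collections::HashMap;

// Distilled shape of the change to `create_plan`: consult a registry of
// handlers first, then fall back to the legacy match statement.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum OpKind {
    Projection,
    Filter,
}

type Handler = fn(&str) -> String;

struct Registry {
    handlers: HashMap<OpKind, Handler>,
}

impl Registry {
    fn can_handle(&self, kind: OpKind) -> bool {
        self.handlers.contains_key(&kind)
    }
}

fn create_plan(registry: &Registry, kind: OpKind, input: &str) -> Result<String, String> {
    // Registry first: migrated operators short-circuit here.
    if registry.can_handle(kind) {
        return Ok(registry.handlers[&kind](input));
    }
    // Fallback: the monolithic match keeps handling everything else.
    match kind {
        OpKind::Filter => Ok(format!("Filter({input})")),
        other => Err(format!("unsupported operator: {other:?}")),
    }
}

fn main() {
    let mut handlers: HashMap<OpKind, Handler> = HashMap::new();
    handlers.insert(OpKind::Projection, |input| format!("Projection({input})"));
    let registry = Registry { handlers };

    // Projection goes through the registry; Filter still uses the fallback.
    assert_eq!(create_plan(&registry, OpKind::Projection, "scan").unwrap(), "Projection(scan)");
    assert_eq!(create_plan(&registry, OpKind::Filter, "scan").unwrap(), "Filter(scan)");
}
```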
diff --git a/native/core/src/execution/planner/expression_registry.rs b/native/core/src/execution/planner/expression_registry.rs
index f97cb984b1..227484ca87 100644
--- a/native/core/src/execution/planner/expression_registry.rs
+++ b/native/core/src/execution/planner/expression_registry.rs
@@ -25,7 +25,90 @@ use datafusion::physical_expr::PhysicalExpr;
 use datafusion_comet_proto::spark_expression::{expr::ExprStruct, Expr};
 
 use crate::execution::operators::ExecutionError;
-use crate::execution::planner::traits::{ExpressionBuilder, ExpressionType};
+
+/// Trait for building physical expressions from Spark protobuf expressions
+pub trait ExpressionBuilder: Send + Sync {
+    /// Build a DataFusion physical expression from a Spark protobuf expression
+    fn build(
+        &self,
+        spark_expr: &Expr,
+        input_schema: SchemaRef,
+        planner: &super::PhysicalPlanner,
+    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError>;
+}
+
+/// Enum to identify different expression types for registry dispatch
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum ExpressionType {
+    // Arithmetic expressions
+    Add,
+    Subtract,
+    Multiply,
+    Divide,
+    IntegralDivide,
+    Remainder,
+    UnaryMinus,
+
+    // Comparison expressions
+    Eq,
+    Neq,
+    Lt,
+    LtEq,
+    Gt,
+    GtEq,
+    EqNullSafe,
+    NeqNullSafe,
+
+    // Logical expressions
+    And,
+    Or,
+    Not,
+
+    // Null checks
+    IsNull,
+    IsNotNull,
+
+    // Bitwise operations
+    BitwiseAnd,
+    BitwiseOr,
+    BitwiseXor,
+    BitwiseShiftLeft,
+    BitwiseShiftRight,
+
+    // Other expressions
+    Bound,
+    Unbound,
+    Literal,
+    Cast,
+    CaseWhen,
+    In,
+    If,
+    Substring,
+    Like,
+    Rlike,
+    CheckOverflow,
+    ScalarFunc,
+    NormalizeNanAndZero,
+    Subquery,
+    BloomFilterMightContain,
+    CreateNamedStruct,
+    GetStructField,
+    ToJson,
+    ToPrettyString,
+    ListExtract,
+    GetArrayStructFields,
+    ArrayInsert,
+    Rand,
+    Randn,
+    SparkPartitionId,
+    MonotonicallyIncreasingId,
+
+    // Time functions
+    Hour,
+    Minute,
+    Second,
+    TruncTimestamp,
+}
 
 /// Registry for expression builders
 pub struct ExpressionRegistry {
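With the trait co-located with the registry, builders that the `binary_expr_builder!`/`unary_expr_builder!` macros don't cover implement it by hand. A minimal sketch, assuming `extract_expr!` follows the same calling convention as `extract_op!` below and that the protobuf `IsNull` variant wraps a message exposing a `child` field:

```rust
use std::sync::Arc;

use arrow::datatypes::SchemaRef;
use datafusion::physical_expr::{expressions::IsNullExpr, PhysicalExpr};
use datafusion_comet_proto::spark_expression::Expr;

use crate::{
    execution::{
        operators::ExecutionError,
        planner::{expression_registry::ExpressionBuilder, PhysicalPlanner},
    },
    extract_expr,
};

// Illustrative hand-written builder; not part of this patch.
pub struct IsNullBuilder;

impl ExpressionBuilder for IsNullBuilder {
    fn build(
        &self,
        spark_expr: &Expr,
        input_schema: SchemaRef,
        planner: &PhysicalPlanner,
    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError> {
        // The registry guarantees only `IsNull` expressions arrive here.
        let expr = extract_expr!(spark_expr, IsNull);
        // Recursively build the child (assumes a `child` field on the message).
        let child = planner.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
        Ok(Arc::new(IsNullExpr::new(child)))
    }
}
```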
diff --git a/native/core/src/execution/planner/traits.rs b/native/core/src/execution/planner/macros.rs
similarity index 55%
rename from native/core/src/execution/planner/traits.rs
rename to native/core/src/execution/planner/macros.rs
index 3f3467d0d0..9d9ccf35da 100644
--- a/native/core/src/execution/planner/traits.rs
+++ b/native/core/src/execution/planner/macros.rs
@@ -15,17 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Core traits for the modular planner framework
-
-use std::sync::Arc;
-
-use arrow::datatypes::SchemaRef;
-use datafusion::physical_expr::PhysicalExpr;
-use datafusion_comet_proto::spark_expression::Expr;
-use jni::objects::GlobalRef;
-
-use crate::execution::operators::ScanExec;
-use crate::execution::{operators::ExecutionError, spark_plan::SparkPlan};
+//! Core macros for the modular planner framework
 
 /// Macro to extract a specific expression variant, panicking if called with wrong type.
 /// This should be used in expression builders where the registry guarantees the correct
@@ -48,13 +38,34 @@ macro_rules! extract_expr {
     };
 }
 
+/// Macro to extract a specific operator variant, panicking if called with the wrong type.
+/// This should be used in operator builders where the registry guarantees the correct
+/// operator type has been routed to the builder.
+#[macro_export]
+macro_rules! extract_op {
+    ($spark_operator:expr, $variant:ident) => {
+        match $spark_operator
+            .op_struct
+            .as_ref()
+            .expect("operator struct must be present")
+        {
+            datafusion_comet_proto::spark_operator::operator::OpStruct::$variant(op) => op,
+            other => panic!(
+                "{} builder called with wrong operator type: {:?}",
+                stringify!($variant),
+                other
+            ),
+        }
+    };
+}
+
 /// Macro to generate binary expression builders with minimal boilerplate
 #[macro_export]
 macro_rules! binary_expr_builder {
     ($builder_name:ident, $expr_type:ident, $operator:expr) => {
         pub struct $builder_name;
 
-        impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name {
+        impl $crate::execution::planner::expression_registry::ExpressionBuilder for $builder_name {
             fn build(
                 &self,
                 spark_expr: &datafusion_comet_proto::spark_expression::Expr,
@@ -84,7 +95,7 @@ macro_rules! unary_expr_builder {
     ($builder_name:ident, $expr_type:ident, $expr_constructor:expr) => {
         pub struct $builder_name;
 
-        impl $crate::execution::planner::traits::ExpressionBuilder for $builder_name {
+        impl $crate::execution::planner::expression_registry::ExpressionBuilder for $builder_name {
             fn build(
                 &self,
                 spark_expr: &datafusion_comet_proto::spark_expression::Expr,
@@ -101,120 +112,3 @@ macro_rules! unary_expr_builder {
         }
     };
 }
-
-/// Trait for building physical expressions from Spark protobuf expressions
-pub trait ExpressionBuilder: Send + Sync {
-    /// Build a DataFusion physical expression from a Spark protobuf expression
-    fn build(
-        &self,
-        spark_expr: &Expr,
-        input_schema: SchemaRef,
-        planner: &super::PhysicalPlanner,
-    ) -> Result<Arc<dyn PhysicalExpr>, ExecutionError>;
-}
-
-/// Trait for building physical operators from Spark protobuf operators
-#[allow(dead_code)]
-pub trait OperatorBuilder: Send + Sync {
-    /// Build a Spark plan from a protobuf operator
-    fn build(
-        &self,
-        spark_plan: &datafusion_comet_proto::spark_operator::Operator,
-        inputs: &mut Vec<Arc<GlobalRef>>,
-        partition_count: usize,
-        planner: &super::PhysicalPlanner,
-    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError>;
-}
-
-/// Enum to identify different expression types for registry dispatch
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-pub enum ExpressionType {
-    // Arithmetic expressions
-    Add,
-    Subtract,
-    Multiply,
-    Divide,
-    IntegralDivide,
-    Remainder,
-    UnaryMinus,
-
-    // Comparison expressions
-    Eq,
-    Neq,
-    Lt,
-    LtEq,
-    Gt,
-    GtEq,
-    EqNullSafe,
-    NeqNullSafe,
-
-    // Logical expressions
-    And,
-    Or,
-    Not,
-
-    // Null checks
-    IsNull,
-    IsNotNull,
-
-    // Bitwise operations
-    BitwiseAnd,
-    BitwiseOr,
-    BitwiseXor,
-    BitwiseShiftLeft,
-    BitwiseShiftRight,
-
-    // Other expressions
-    Bound,
-    Unbound,
-    Literal,
-    Cast,
-    CaseWhen,
-    In,
-    If,
-    Substring,
-    Like,
-    Rlike,
-    CheckOverflow,
-    ScalarFunc,
-    NormalizeNanAndZero,
-    Subquery,
-    BloomFilterMightContain,
-    CreateNamedStruct,
-    GetStructField,
-    ToJson,
-    ToPrettyString,
-    ListExtract,
-    GetArrayStructFields,
-    ArrayInsert,
-    Rand,
-    Randn,
-    SparkPartitionId,
-    MonotonicallyIncreasingId,
-
-    // Time functions
-    Hour,
-    Minute,
-    Second,
-    TruncTimestamp,
-}
-
-/// Enum to identify different operator types for registry dispatch
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
-#[allow(dead_code)]
-pub enum OperatorType {
-    Scan,
-    NativeScan,
-    IcebergScan,
-    Projection,
-    Filter,
-    HashAgg,
-    Limit,
-    Sort,
-    ShuffleWriter,
-    ParquetWriter,
-    Expand,
-    SortMergeJoin,
-    HashJoin,
-    Window,
-}
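`extract_op!` is used at the top of each operator builder, as `ProjectionBuilder` does above. A sketch of a further builder written against this API (hypothetical `FilterBuilder`; the `predicate` field name is assumed by analogy with the existing `OpStruct::Filter` arm, and the real Comet lowering may involve extra details such as batch copying):

```rust
use std::sync::Arc;

use datafusion::physical_plan::filter::FilterExec;
use datafusion_comet_proto::spark_operator::Operator;
use jni::objects::GlobalRef;

use crate::{
    execution::{
        operators::{ExecutionError, ScanExec},
        planner::{operator_registry::OperatorBuilder, PhysicalPlanner},
        spark_plan::SparkPlan,
    },
    extract_op,
};

/// Hypothetical builder for a not-yet-migrated Filter operator.
pub struct FilterBuilder;

impl OperatorBuilder for FilterBuilder {
    fn build(
        &self,
        spark_plan: &Operator,
        inputs: &mut Vec<Arc<GlobalRef>>,
        partition_count: usize,
        planner: &PhysicalPlanner,
    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
        // The registry routed a Filter here, so extraction cannot panic.
        let filter = extract_op!(spark_plan, Filter);

        assert_eq!(spark_plan.children.len(), 1);
        let (scans, child) =
            planner.create_plan(&spark_plan.children[0], inputs, partition_count)?;

        // Assumed field name: the Filter message is taken to carry a
        // `predicate` expression, mirroring the legacy match arm.
        let predicate =
            planner.create_expr(filter.predicate.as_ref().unwrap(), child.schema())?;
        let filter_exec = Arc::new(FilterExec::try_new(
            predicate,
            Arc::clone(&child.native_plan),
        )?);

        Ok((
            scans,
            Arc::new(SparkPlan::new(spark_plan.plan_id, filter_exec, vec![child])),
        ))
    }
}
```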
diff --git a/native/core/src/execution/planner/operator_registry.rs b/native/core/src/execution/planner/operator_registry.rs
new file mode 100644
index 0000000000..e4899280b7
--- /dev/null
+++ b/native/core/src/execution/planner/operator_registry.rs
@@ -0,0 +1,155 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Registry for operator builders using the modular pattern
+
+use std::{
+    collections::HashMap,
+    sync::{Arc, OnceLock},
+};
+
+use datafusion_comet_proto::spark_operator::Operator;
+use jni::objects::GlobalRef;
+
+use super::PhysicalPlanner;
+use crate::execution::{
+    operators::{ExecutionError, ScanExec},
+    spark_plan::SparkPlan,
+};
+
+/// Trait for building physical operators from Spark protobuf operators
+pub trait OperatorBuilder: Send + Sync {
+    /// Build a Spark plan from a protobuf operator
+    fn build(
+        &self,
+        spark_plan: &datafusion_comet_proto::spark_operator::Operator,
+        inputs: &mut Vec<Arc<GlobalRef>>,
+        partition_count: usize,
+        planner: &PhysicalPlanner,
+    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError>;
+}
+
+/// Enum to identify different operator types for registry dispatch
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+pub enum OperatorType {
+    Scan,
+    NativeScan,
+    IcebergScan,
+    Projection,
+    Filter,
+    HashAgg,
+    Limit,
+    Sort,
+    ShuffleWriter,
+    ParquetWriter,
+    Expand,
+    SortMergeJoin,
+    HashJoin,
+    Window,
+}
+
+/// Global registry of operator builders
+pub struct OperatorRegistry {
+    builders: HashMap<OperatorType, Box<dyn OperatorBuilder>>,
+}
+
+impl OperatorRegistry {
+    /// Create a new empty registry
+    fn new() -> Self {
+        Self {
+            builders: HashMap::new(),
+        }
+    }
+
+    /// Get the global singleton instance of the operator registry
+    pub fn global() -> &'static OperatorRegistry {
+        static REGISTRY: OnceLock<OperatorRegistry> = OnceLock::new();
+        REGISTRY.get_or_init(|| {
+            let mut registry = OperatorRegistry::new();
+            registry.register_all_operators();
+            registry
+        })
+    }
+
+    /// Check if the registry can handle a given operator
+    pub fn can_handle(&self, spark_operator: &Operator) -> bool {
+        get_operator_type(spark_operator)
+            .map(|op_type| self.builders.contains_key(&op_type))
+            .unwrap_or(false)
+    }
+
+    /// Create a Spark plan using the registered builder for this operator type
+    pub fn create_plan(
+        &self,
+        spark_operator: &Operator,
+        inputs: &mut Vec<Arc<GlobalRef>>,
+        partition_count: usize,
+        planner: &PhysicalPlanner,
+    ) -> Result<(Vec<ScanExec>, Arc<SparkPlan>), ExecutionError> {
+        let operator_type = get_operator_type(spark_operator).ok_or_else(|| {
+            ExecutionError::GeneralError(format!(
+                "Unsupported operator type: {:?}",
+                spark_operator.op_struct
+            ))
+        })?;
+
+        let builder = self.builders.get(&operator_type).ok_or_else(|| {
+            ExecutionError::GeneralError(format!(
+                "No builder registered for operator type: {:?}",
+                operator_type
+            ))
+        })?;
+
+        builder.build(spark_operator, inputs, partition_count, planner)
+    }
+
+    /// Register all operator builders
+    fn register_all_operators(&mut self) {
+        self.register_projection_operators();
+    }
+
+    /// Register projection operators
+    fn register_projection_operators(&mut self) {
+        use crate::execution::operators::projection::ProjectionBuilder;
+
+        self.builders
+            .insert(OperatorType::Projection, Box::new(ProjectionBuilder));
+    }
+}
+
+/// Extract the operator type from a Spark operator
+fn get_operator_type(spark_operator: &Operator) -> Option<OperatorType> {
+    use datafusion_comet_proto::spark_operator::operator::OpStruct;
+
+    match spark_operator.op_struct.as_ref()? {
+        OpStruct::Projection(_) => Some(OperatorType::Projection),
+        OpStruct::Filter(_) => Some(OperatorType::Filter),
+        OpStruct::HashAgg(_) => Some(OperatorType::HashAgg),
+        OpStruct::Limit(_) => Some(OperatorType::Limit),
+        OpStruct::Sort(_) => Some(OperatorType::Sort),
+        OpStruct::Scan(_) => Some(OperatorType::Scan),
+        OpStruct::NativeScan(_) => Some(OperatorType::NativeScan),
+        OpStruct::IcebergScan(_) => Some(OperatorType::IcebergScan),
+        OpStruct::ShuffleWriter(_) => Some(OperatorType::ShuffleWriter),
+        OpStruct::ParquetWriter(_) => Some(OperatorType::ParquetWriter),
+        OpStruct::Expand(_) => Some(OperatorType::Expand),
+        OpStruct::SortMergeJoin(_) => Some(OperatorType::SortMergeJoin),
+        OpStruct::HashJoin(_) => Some(OperatorType::HashJoin),
+        OpStruct::Window(_) => Some(OperatorType::Window),
+        OpStruct::Explode(_) => None, // Not yet in OperatorType enum
+    }
+}
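With this structure, migrating another operator out of the monolithic match is a localized two-step change: add a builder module (like the hypothetical `FilterBuilder` sketched earlier) and register it. A sketch of the registration side, as it would appear inside `operator_registry.rs` (the `register_filter_operators` method and the `operators::filter` module are hypothetical, mirroring the projection wiring above):

```rust
impl OperatorRegistry {
    /// Register all operator builders
    fn register_all_operators(&mut self) {
        self.register_projection_operators();
        self.register_filter_operators(); // hypothetical next migration
    }

    /// Hypothetical registration step, mirroring `register_projection_operators`.
    /// `FilterBuilder` is the earlier sketch, assumed to live in
    /// `crate::execution::operators::filter`.
    fn register_filter_operators(&mut self) {
        use crate::execution::operators::filter::FilterBuilder;

        self.builders
            .insert(OperatorType::Filter, Box::new(FilterBuilder));
    }
}
```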