diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 522ccbc94c..7fa1482341 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -122,16 +122,14 @@ object CometConf extends ShimCometConf {
   val SCAN_AUTO = "auto"
 
   val COMET_NATIVE_SCAN_IMPL: ConfigEntry[String] = conf("spark.comet.scan.impl")
-    .category(CATEGORY_SCAN)
+    .category(CATEGORY_PARQUET)
     .doc(
-      "The implementation of Comet Native Scan to use. Available modes are " +
+      "The implementation of Comet's Parquet scan to use. Available scans are " +
        s"`$SCAN_NATIVE_DATAFUSION`, and `$SCAN_NATIVE_ICEBERG_COMPAT`. " +
-       s"`$SCAN_NATIVE_DATAFUSION` is a fully native implementation of scan based on " +
-       "DataFusion. " +
-       s"`$SCAN_NATIVE_ICEBERG_COMPAT` is the recommended native implementation that " +
-       "exposes apis to read parquet columns natively and supports complex types. " +
-       s"`$SCAN_AUTO` (default) chooses the best scan.")
-    .internal()
+       s"`$SCAN_NATIVE_DATAFUSION` is a fully native implementation, and " +
+       s"`$SCAN_NATIVE_ICEBERG_COMPAT` is a hybrid implementation that supports some " +
+       "additional features, such as row indexes and field ids. " +
+       s"`$SCAN_AUTO` (default) chooses the best available scan based on the scan schema.")
     .stringConf
     .transform(_.toLowerCase(Locale.ROOT))
     .checkValues(Set(SCAN_NATIVE_DATAFUSION, SCAN_NATIVE_ICEBERG_COMPAT, SCAN_AUTO))
@@ -549,6 +547,17 @@ object CometConf extends ShimCometConf {
       .booleanConf
       .createWithDefault(false)
 
+  val COMET_NATIVE_PHYSICAL_OPTIMIZER_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.exec.nativePhysicalOptimizer.enabled")
+      .category(CATEGORY_TESTING)
+      .doc(
+        "When enabled, Comet will run DataFusion's physical optimizer rules on the " +
+          "native query plan before execution. This can improve performance through " +
+          "additional optimizations such as projection pushdown, coalesce batches, " +
+          "filter pushdown, and limit pushdown. This feature is highly experimental.")
+      .booleanConf
+      .createWithDefault(true)
+
   val COMET_EXTENDED_EXPLAIN_FORMAT_VERBOSE = "verbose"
   val COMET_EXTENDED_EXPLAIN_FORMAT_FALLBACK = "fallback"

diff --git a/dev/benchmarks/comet-tpch.sh b/dev/benchmarks/comet-tpch.sh
index a748a02319..02aaca2f6a 100755
--- a/dev/benchmarks/comet-tpch.sh
+++ b/dev/benchmarks/comet-tpch.sh
@@ -42,6 +42,8 @@ $SPARK_HOME/bin/spark-submit \
     --conf spark.shuffle.manager=org.apache.spark.sql.comet.execution.shuffle.CometShuffleManager \
     --conf spark.comet.scan.impl=native_datafusion \
     --conf spark.comet.exec.replaceSortMergeJoin=true \
+    --conf spark.comet.datafusion.execution.parquet.pushdown_filters=true \
+    --conf spark.comet.exec.nativePhysicalOptimizer.enabled=true \
    --conf spark.comet.expression.Cast.allowIncompatible=true \
    --conf spark.hadoop.fs.s3a.impl=org.apache.hadoop.fs.s3a.S3AFileSystem \
    --conf spark.hadoop.fs.s3a.aws.credentials.provider=com.amazonaws.auth.DefaultAWSCredentialsProviderChain \
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index 146e0feb8e..677fa5fcf8 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -83,11 +83,11 @@ use crate::execution::tracing::{log_memory_usage, trace_begin, trace_end, with_t
 
 use crate::execution::spark_config::{
     SparkConfig, COMET_DEBUG_ENABLED, COMET_EXPLAIN_NATIVE_ENABLED, COMET_MAX_TEMP_DIRECTORY_SIZE,
-    COMET_TRACING_ENABLED,
+    COMET_NATIVE_PHYSICAL_OPTIMIZER_ENABLED, COMET_TRACING_ENABLED,
 };
 use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
 use datafusion_comet_proto::spark_operator::operator::OpStruct;
-use log::info;
+use log::{debug, info};
 use once_cell::sync::Lazy;
 #[cfg(feature = "jemalloc")]
 use tikv_jemalloc_ctl::{epoch, stats};
@@ -153,6 +153,8 @@ struct ExecutionContext {
     pub memory_pool_config: MemoryPoolConfig,
     /// Whether to log memory usage on each call to execute_plan
     pub tracing_enabled: bool,
+    /// Whether to run DataFusion's physical optimizer on the native plan
+    pub native_physical_optimizer_enabled: bool,
 }
 
 /// Accept serialized query plan and return the address of the native query plan.
@@ -192,6 +194,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
         let tracing_enabled = spark_config.get_bool(COMET_TRACING_ENABLED);
         let max_temp_directory_size =
             spark_config.get_u64(COMET_MAX_TEMP_DIRECTORY_SIZE, 100 * 1024 * 1024 * 1024);
+        let native_physical_optimizer_enabled =
+            spark_config.get_bool(COMET_NATIVE_PHYSICAL_OPTIMIZER_ENABLED);
 
         with_trace("createPlan", tracing_enabled, || {
             // Init JVM classes
@@ -246,6 +250,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
                 local_dirs_vec,
                 max_temp_directory_size,
                 task_cpus as usize,
+                &spark_config,
             )?;
             let plan_creation_time = start.elapsed();
 
@@ -286,6 +291,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
                 explain_native,
                 memory_pool_config,
                 tracing_enabled,
+                native_physical_optimizer_enabled,
             });
 
             Ok(Box::into_raw(exec_context) as i64)
@@ -300,6 +306,7 @@ fn prepare_datafusion_session_context(
     local_dirs: Vec<String>,
     max_temp_directory_size: u64,
     task_cpus: usize,
+    spark_config: &HashMap<String, String>,
 ) -> CometResult<SessionContext> {
     let paths = local_dirs.into_iter().map(PathBuf::from).collect();
     let disk_manager = DiskManagerBuilder::default()
@@ -308,10 +315,7 @@ fn prepare_datafusion_session_context(
     let mut rt_config = RuntimeEnvBuilder::new().with_disk_manager_builder(disk_manager);
     rt_config = rt_config.with_memory_pool(memory_pool);
 
-    // Get Datafusion configuration from Spark Execution context
-    // can be configured in Comet Spark JVM using Spark --conf parameters
-    // e.g: spark-shell --conf spark.datafusion.sql_parser.parse_float_as_decimal=true
-    let session_config = SessionConfig::new()
+    let mut session_config = SessionConfig::new()
         .with_target_partitions(task_cpus)
         // This DataFusion context is within the scope of an executing Spark Task. We want to set
         // its internal parallelism to the number of CPUs allocated to Spark Tasks. This can be
@@ -328,9 +332,55 @@ fn prepare_datafusion_session_context(
         &ScalarValue::Float64(Some(1.1)),
     );
 
+    // Pass through DataFusion configs from Spark.
+    // e.g: spark-shell --conf spark.comet.datafusion.sql_parser.parse_float_as_decimal=true
+    // becomes datafusion.sql_parser.parse_float_as_decimal=true
+    const SPARK_COMET_DF_PREFIX: &str = "spark.comet.datafusion.";
+    for (key, value) in spark_config {
+        if let Some(df_key) = key.strip_prefix(SPARK_COMET_DF_PREFIX) {
+            let df_key = format!("datafusion.{df_key}");
+            session_config = session_config.set_str(&df_key, value);
+        }
+    }
+
     let runtime = rt_config.build()?;
 
-    let mut session_ctx = SessionContext::new_with_config_rt(session_config, Arc::new(runtime));
+    // Only include physical optimizer rules that are compatible with
+    // Comet's execution model. Spark handles distribution, sorting,
+    // filter/projection pushdown, and join selection externally.
+    use datafusion::execution::SessionStateBuilder;
+    use datafusion::physical_optimizer::{
+        // aggregate_statistics::AggregateStatistics,
+        // combine_partial_final_agg::CombinePartialFinalAggregate,
+        // limit_pushdown::LimitPushdown,
+        // limit_pushdown_past_window::LimitPushPastWindows,
+        // limited_distinct_aggregation::LimitedDistinctAggregation,
+        // topk_aggregation::TopKAggregation,
+        // update_aggr_exprs::OptimizeAggregateOrder,
+        PhysicalOptimizerRule,
+    };
+
+    use crate::execution::physical_cse::PhysicalCommonSubexprEliminate;
+
+    let physical_optimizer_rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
+        // Arc::new(AggregateStatistics::new()),
+        // Arc::new(LimitedDistinctAggregation::new()),
+        // Arc::new(CombinePartialFinalAggregate::new()),
+        // Arc::new(OptimizeAggregateOrder::new()),
+        // Arc::new(TopKAggregation::new()),
+        // Arc::new(LimitPushPastWindows::new()),
+        // Arc::new(LimitPushdown::new()),
+        Arc::new(PhysicalCommonSubexprEliminate::new()),
+    ];
+
+    let state = SessionStateBuilder::new()
+        .with_config(session_config)
+        .with_runtime_env(Arc::new(runtime))
+        .with_default_features()
+        .with_physical_optimizer_rules(physical_optimizer_rules)
+        .build();
+
+    let mut session_ctx = SessionContext::new_with_state(state);
 
     datafusion::functions_nested::register_all(&mut session_ctx)?;
     register_datafusion_spark_function(&session_ctx);
@@ -495,9 +545,45 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
             let physical_plan_time = start.elapsed();
             exec_context.plan_creation_time += physical_plan_time;
-            exec_context.root_op = Some(Arc::clone(&root_op));
             exec_context.scans = scans;
 
+            let root_op = if exec_context.native_physical_optimizer_enabled {
+                let state = exec_context.session_ctx.state();
+                let optimizers = state.physical_optimizers();
+                let config = Arc::clone(state.config_options());
+                let mut optimized_plan = Arc::clone(&root_op.native_plan);
+
+                if exec_context.explain_native {
+                    let before =
+                        DisplayableExecutionPlan::new(optimized_plan.as_ref()).indent(true);
+                    info!("Comet native plan before DataFusion optimization:\n{before}");
+                }
+
+                let opt_start = std::time::Instant::now();
+                for optimizer in optimizers {
+                    optimized_plan = optimizer.optimize(optimized_plan, &config)?;
+                }
+                debug!("Comet physical optimization completed in {:?}", opt_start.elapsed());
+
+                if exec_context.explain_native {
+                    let after =
+                        DisplayableExecutionPlan::new(optimized_plan.as_ref()).indent(true);
+                    info!("Comet native plan after DataFusion optimization:\n{after}");
+                }
+
+                // Keep the original SparkPlan tree structure (for metrics)
+                // but replace the root's native plan with the optimized version
+                Arc::new(SparkPlan::new(
+                    root_op.plan_id,
+                    optimized_plan,
+                    root_op.children.clone(),
+                ))
+            } else {
+                root_op
+            };
+
+            exec_context.root_op = Some(Arc::clone(&root_op));
+
             if exec_context.explain_native {
                 let formatted_plan_str =
                     DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true);
diff --git a/native/core/src/execution/mod.rs b/native/core/src/execution/mod.rs
index 85fc672461..1fae6a04f5 100644
--- a/native/core/src/execution/mod.rs
+++ b/native/core/src/execution/mod.rs
@@ -21,6 +21,7 @@ pub mod expressions;
 pub mod jni_api;
 pub(crate) mod metrics;
 pub mod operators;
+pub(crate) mod physical_cse;
 pub(crate) mod planner;
 pub mod serde;
 pub mod shuffle;
diff --git a/native/core/src/execution/physical_cse.rs b/native/core/src/execution/physical_cse.rs
new file mode 100644
index 0000000000..f37e5e2f7d
--- /dev/null
+++ b/native/core/src/execution/physical_cse.rs
@@ -0,0 +1,611 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Physical common subexpression elimination (CSE) optimizer rule.
+//!
+//! Identifies repeated subexpressions within `ProjectionExec` and
+//! `AggregateExec` nodes and rewrites the plan to compute them once via an
+//! intermediate projection.
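+//!
+//! For example (this mirrors `test_cse_extracts_common_subexpr` below), a
+//! projection computing `(a + b) * 2` and `(a + b) * 3` is rewritten so that
+//! `a + b` is evaluated once in an intermediate projection:
+//!
+//! ```text
+//! ProjectionExec: x = __cse_0 * 2, y = __cse_0 * 3
+//!   ProjectionExec: a, b, __cse_0 = a + b
+//!     (input)
+//! ```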
+
+use std::collections::HashMap;
+use std::hash::{Hash, Hasher};
+use std::sync::Arc;
+
+use datafusion::common::config::ConfigOptions;
+use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRecursion};
+use datafusion::common::Result;
+use datafusion::physical_expr::aggregate::AggregateFunctionExpr;
+use datafusion::physical_expr::expressions::{Column, Literal};
+use datafusion::physical_expr::PhysicalExpr;
+use datafusion::physical_expr_common::physical_expr::is_volatile;
+use datafusion::physical_optimizer::PhysicalOptimizerRule;
+use datafusion::physical_plan::aggregates::{AggregateExec, PhysicalGroupBy};
+use datafusion::physical_plan::projection::{ProjectionExec, ProjectionExpr};
+use datafusion::physical_plan::ExecutionPlan;
+use log::debug;
+
+/// Needed because `Arc<dyn PhysicalExpr>` doesn't implement `Eq`/`Hash`
+/// directly — this delegates to the trait-object implementations.
+struct ExprKey(Arc<dyn PhysicalExpr>);
+
+impl PartialEq for ExprKey {
+    fn eq(&self, other: &Self) -> bool {
+        self.0.as_ref() == other.0.as_ref()
+    }
+}
+
+impl Eq for ExprKey {}
+
+impl Hash for ExprKey {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        self.0.as_ref().hash(state);
+    }
+}
+
+/// Physical optimizer rule that eliminates common subexpressions within
+/// `ProjectionExec` and `AggregateExec` nodes.
+#[derive(Debug)]
+pub struct PhysicalCommonSubexprEliminate;
+
+impl PhysicalCommonSubexprEliminate {
+    pub fn new() -> Self {
+        Self
+    }
+}
+
+impl PhysicalOptimizerRule for PhysicalCommonSubexprEliminate {
+    fn optimize(
+        &self,
+        plan: Arc<dyn ExecutionPlan>,
+        _config: &ConfigOptions,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let start = std::time::Instant::now();
+        let result = plan
+            .transform_up(|node| {
+                if node.as_any().downcast_ref::<ProjectionExec>().is_some() {
+                    try_optimize_projection(node)
+                } else if node.as_any().downcast_ref::<AggregateExec>().is_some() {
+                    try_optimize_aggregate(node)
+                } else {
+                    Ok(Transformed::no(node))
+                }
+            })
+            .data();
+        debug!("Physical CSE optimizer completed in {:?}", start.elapsed());
+        result
+    }
+
+    fn name(&self) -> &str {
+        "physical_common_subexpr_eliminate"
+    }
+
+    fn schema_check(&self) -> bool {
+        true
+    }
+}
+
+/// Columns and literals are too cheap to be worth extracting.
+fn is_trivial(expr: &Arc<dyn PhysicalExpr>) -> bool {
+    expr.as_any().downcast_ref::<Column>().is_some()
+        || expr.as_any().downcast_ref::<Literal>().is_some()
+}
+
+fn collect_subexprs(
+    expr: &Arc<dyn PhysicalExpr>,
+    counts: &mut HashMap<ExprKey, usize>,
+) {
+    if is_trivial(expr) || is_volatile(expr) {
+        return;
+    }
+    let key = ExprKey(Arc::clone(expr));
+    *counts.entry(key).or_insert(0) += 1;
+    for child in expr.children() {
+        collect_subexprs(child, counts);
+    }
+}
+
+fn find_common_subexprs(
+    exprs: &[Arc<dyn PhysicalExpr>],
+) -> Vec<Arc<dyn PhysicalExpr>> {
+    let mut counts: HashMap<ExprKey, usize> = HashMap::new();
+    for expr in exprs {
+        collect_subexprs(expr, &mut counts);
+    }
+    let common: Vec<Arc<dyn PhysicalExpr>> = counts
+        .into_iter()
+        .filter(|(_, count)| *count >= 2)
+        .map(|(key, _)| key.0)
+        .collect();
+
+    // After rewriting the larger CSE to a column reference, its children
+    // are no longer evaluated, so any smaller CSE nested inside it would
+    // produce an unused column in the intermediate projection.
+    let common_set: std::collections::HashSet<ExprKey> =
+        common.iter().map(|e| ExprKey(Arc::clone(e))).collect();
+
+    common
+        .into_iter()
+        .filter(|expr| {
+            !common_set.iter().any(|other| {
+                if other.0.as_ref() == expr.as_ref() {
+                    return false;
+                }
+                contains_subexpr(&other.0, expr)
+            })
+        })
+        .collect()
+}
+
+fn contains_subexpr(
+    haystack: &Arc<dyn PhysicalExpr>,
+    needle: &Arc<dyn PhysicalExpr>,
+) -> bool {
+    for child in haystack.children() {
+        if child.as_ref() == needle.as_ref() {
+            return true;
+        }
+        if contains_subexpr(child, needle) {
+            return true;
+        }
+    }
+    false
+}
+
+/// Replaces occurrences of any common subexpression in `expr` with a
+/// `Column` reference into the intermediate projection's schema.
+fn rewrite_expr(
+    expr: Arc<dyn PhysicalExpr>,
+    cse_map: &HashMap<ExprKey, (String, usize)>,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    expr.transform_down(|node| {
+        if is_trivial(&node) {
+            return Ok(Transformed::no(node));
+        }
+        let lookup = ExprKey(Arc::clone(&node));
+        if let Some((name, index)) = cse_map.get(&lookup) {
+            let col = Arc::new(Column::new(name, *index)) as Arc<dyn PhysicalExpr>;
+            // Jump skips recursing into children that are now behind a column ref
+            Ok(Transformed::new(col, true, TreeNodeRecursion::Jump))
+        } else {
+            Ok(Transformed::no(node))
+        }
+    })
+    .data()
+}
+
+fn try_optimize_projection(
+    node: Arc<dyn ExecutionPlan>,
+) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
+    let projection = node.as_any().downcast_ref::<ProjectionExec>().unwrap();
+    let proj_exprs = projection.expr();
+
+    let raw_exprs: Vec<Arc<dyn PhysicalExpr>> =
+        proj_exprs.iter().map(|pe| Arc::clone(&pe.expr)).collect();
+    let common = find_common_subexprs(&raw_exprs);
+
+    if common.is_empty() {
+        return Ok(Transformed::no(node));
+    }
+
+    let input = projection.input();
+    let input_schema = input.schema();
+    let num_input_cols = input_schema.fields().len();
+
+    let mut intermediate_exprs: Vec<ProjectionExpr> = Vec::new();
+    for (i, field) in input_schema.fields().iter().enumerate() {
+        intermediate_exprs.push(ProjectionExpr {
+            expr: Arc::new(Column::new(field.name(), i)),
+            alias: field.name().clone(),
+        });
+    }
+
+    let mut cse_map: HashMap<ExprKey, (String, usize)> = HashMap::new();
+    for (idx, cse_expr) in common.iter().enumerate() {
+        let cse_name = format!("__cse_{idx}");
+        let col_index = num_input_cols + idx;
+        intermediate_exprs.push(ProjectionExpr {
+            expr: Arc::clone(cse_expr),
+            alias: cse_name.clone(),
+        });
+        cse_map.insert(ExprKey(Arc::clone(cse_expr)), (cse_name, col_index));
+    }
+
+    let intermediate = Arc::new(ProjectionExec::try_new(
+        intermediate_exprs,
+        Arc::clone(input),
+    )?) as Arc<dyn ExecutionPlan>;
+
+    let mut new_proj_exprs: Vec<ProjectionExpr> = Vec::new();
+    for proj_expr in proj_exprs {
+        let rewritten = rewrite_expr(Arc::clone(&proj_expr.expr), &cse_map)?;
+        new_proj_exprs.push(ProjectionExpr {
+            expr: rewritten,
+            alias: proj_expr.alias.clone(),
+        });
+    }
+
+    let new_projection = Arc::new(ProjectionExec::try_new(new_proj_exprs, intermediate)?)
+        as Arc<dyn ExecutionPlan>;
+
+    debug!(
+        "Physical CSE: rewrote ProjectionExec, extracted {} common subexpression(s): [{}]",
+        common.len(),
+        common.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
+    );
+
+    Ok(Transformed::yes(new_projection))
+}
+
+fn try_optimize_aggregate(
+    node: Arc<dyn ExecutionPlan>,
+) -> Result<Transformed<Arc<dyn ExecutionPlan>>> {
+    let agg_exec = node.as_any().downcast_ref::<AggregateExec>().unwrap();
+
+    // Final/FinalPartitioned aggregates reference partial outputs, not
+    // the original column expressions, so CSE doesn't apply.
+    if !agg_exec.mode().is_first_stage() {
+        return Ok(Transformed::no(node));
+    }
+
+    let aggr_exprs = agg_exec.aggr_expr();
+    let all_args: Vec<Arc<dyn PhysicalExpr>> = aggr_exprs
+        .iter()
+        .flat_map(|agg_fn| agg_fn.expressions())
+        .collect();
+
+    let common = find_common_subexprs(&all_args);
+    if common.is_empty() {
+        return Ok(Transformed::no(node));
+    }
+
+    let input = agg_exec.input();
+    let input_schema = input.schema();
+    let num_input_cols = input_schema.fields().len();
+
+    let mut intermediate_exprs: Vec<ProjectionExpr> = Vec::new();
+    for (i, field) in input_schema.fields().iter().enumerate() {
+        intermediate_exprs.push(ProjectionExpr {
+            expr: Arc::new(Column::new(field.name(), i)),
+            alias: field.name().clone(),
+        });
+    }
+
+    let mut cse_map: HashMap<ExprKey, (String, usize)> = HashMap::new();
+    for (idx, cse_expr) in common.iter().enumerate() {
+        let cse_name = format!("__cse_{idx}");
+        let col_index = num_input_cols + idx;
+        intermediate_exprs.push(ProjectionExpr {
+            expr: Arc::clone(cse_expr),
+            alias: cse_name.clone(),
+        });
+        cse_map.insert(ExprKey(Arc::clone(cse_expr)), (cse_name, col_index));
+    }
+
+    let intermediate = Arc::new(ProjectionExec::try_new(
+        intermediate_exprs,
+        Arc::clone(input),
+    )?) as Arc<dyn ExecutionPlan>;
+    let intermediate_schema = intermediate.schema();
+
+    let mut new_aggr_exprs: Vec<Arc<AggregateFunctionExpr>> = Vec::new();
+    for agg_fn in aggr_exprs {
+        let old_args = agg_fn.expressions();
+        let mut new_args = Vec::with_capacity(old_args.len());
+        for arg in &old_args {
+            new_args.push(rewrite_expr(Arc::clone(arg), &cse_map)?);
+        }
+        let order_by_exprs: Vec<Arc<dyn PhysicalExpr>> = agg_fn
+            .order_bys()
+            .iter()
+            .map(|sort_expr| Arc::clone(&sort_expr.expr))
+            .collect();
+        let new_agg_fn = agg_fn
+            .with_new_expressions(new_args, order_by_exprs)
+            .ok_or_else(|| {
+                datafusion::common::DataFusionError::Internal(format!(
+                    "Failed to rewrite aggregate expression: {}",
+                    agg_fn.name()
+                ))
+            })?;
+        new_aggr_exprs.push(Arc::new(new_agg_fn));
+    }
+
+    let new_filters: Vec<Option<Arc<dyn PhysicalExpr>>> = agg_exec
+        .filter_expr()
+        .iter()
+        .map(|filter_opt| {
+            filter_opt
+                .as_ref()
+                .map(|f| rewrite_expr(Arc::clone(f), &cse_map))
+                .transpose()
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    let old_group_by = agg_exec.group_expr();
+    let new_group_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = old_group_by
+        .expr()
+        .iter()
+        .map(|(expr, alias)| Ok((rewrite_expr(Arc::clone(expr), &cse_map)?, alias.clone())))
+        .collect::<Result<Vec<_>>>()?;
+    let new_null_exprs: Vec<(Arc<dyn PhysicalExpr>, String)> = old_group_by
+        .null_expr()
+        .iter()
+        .map(|(expr, alias)| (Arc::clone(expr), alias.clone()))
+        .collect();
+    let new_group_by =
+        PhysicalGroupBy::new(new_group_exprs, new_null_exprs, old_group_by.groups().to_vec());
+
+    let new_agg = AggregateExec::try_new(
+        *agg_exec.mode(),
+        new_group_by,
+        new_aggr_exprs,
+        new_filters,
+        intermediate,
+        intermediate_schema,
+    )?;
+
+    debug!(
+        "Physical CSE: rewrote AggregateExec ({:?} mode), extracted {} common subexpression(s): [{}]",
+        agg_exec.mode(),
+        common.len(),
+        common.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(", ")
+    );
+
+    Ok(Transformed::yes(Arc::new(new_agg) as Arc<dyn ExecutionPlan>))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use datafusion::functions_aggregate::sum::sum_udaf;
+    use datafusion::logical_expr::Operator;
+    use datafusion::physical_expr::aggregate::AggregateExprBuilder;
+    use datafusion::physical_expr::expressions::{binary, col};
+    use datafusion::physical_plan::aggregates::AggregateMode;
+    use datafusion::physical_plan::empty::EmptyExec;
+
+    fn test_schema() -> Arc<Schema> {
+        Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+        ]))
+    }
+
+    #[test]
+    fn test_cse_extracts_common_subexpr() -> Result<()> {
+        let schema = test_schema();
+        let input: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
+
+        let a = col("a", &schema)?;
+        let b = col("b", &schema)?;
+
+        // (a + b) * 2, (a + b) * 3 — both share (a + b)
+        let a_plus_b_1 = binary(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema)?;
+        let a_plus_b_2 = binary(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema)?;
+
+        let two = Arc::new(Literal::new(datafusion::common::ScalarValue::Int32(Some(2))));
+        let three = Arc::new(Literal::new(datafusion::common::ScalarValue::Int32(Some(3))));
+
+        let expr_x = binary(a_plus_b_1, Operator::Multiply, two, &schema)?;
+        let expr_y = binary(a_plus_b_2, Operator::Multiply, three, &schema)?;
+
+        let projection = ProjectionExec::try_new(
+            vec![
+                ProjectionExpr {
+                    expr: expr_x,
+                    alias: "x".to_string(),
+                },
+                ProjectionExpr {
+                    expr: expr_y,
+                    alias: "y".to_string(),
+                },
+            ],
+            input,
+        )?;
+
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(projection);
+        let config = ConfigOptions::new();
+        let rule = PhysicalCommonSubexprEliminate::new();
+        let optimized = rule.optimize(plan, &config)?;
+
+        let top = optimized
+            .as_any()
+            .downcast_ref::<ProjectionExec>()
+            .expect("top should be ProjectionExec");
+        let intermediate = top
+            .input()
+            .as_any()
+            .downcast_ref::<ProjectionExec>()
+            .expect("intermediate should be ProjectionExec");
+
+        assert_eq!(intermediate.expr().len(), 3); // a, b, __cse_0
+        assert_eq!(intermediate.expr()[2].alias, "__cse_0");
+        assert_eq!(top.expr().len(), 2);
+        assert_eq!(top.expr()[0].alias, "x");
+        assert_eq!(top.expr()[1].alias, "y");
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_no_cse_when_no_common_subexpr() -> Result<()> {
+        let schema = test_schema();
+        let input: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
+
+        let a = col("a", &schema)?;
+        let b = col("b", &schema)?;
+
+        let projection = ProjectionExec::try_new(
+            vec![
+                ProjectionExpr {
+                    expr: a,
+                    alias: "a".to_string(),
+                },
+                ProjectionExpr {
+                    expr: b,
+                    alias: "b".to_string(),
+                },
+            ],
+            input,
+        )?;
+
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(projection);
+        let config = ConfigOptions::new();
+        let rule = PhysicalCommonSubexprEliminate::new();
+        let optimized = rule.optimize(Arc::clone(&plan), &config)?;
+
+        let top = optimized
+            .as_any()
+            .downcast_ref::<ProjectionExec>()
+            .expect("should be ProjectionExec");
+        assert!(top
+            .input()
+            .as_any()
+            .downcast_ref::<ProjectionExec>()
+            .is_none());
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregate_cse_extracts_common_subexpr() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Int64, false),
+            Field::new("c", DataType::Int64, false),
+        ]));
+        let input: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
+
+        let a = col("a", &schema)?;
+        let b = col("b", &schema)?;
+        let c = col("c", &schema)?;
+
+        let a_plus_b_1 = binary(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema)?;
+        let a_plus_b_2 = binary(Arc::clone(&a), Operator::Plus, Arc::clone(&b), &schema)?;
+
+        // sum(a + b) and sum((a + b) * c) — both share (a + b)
+        let agg1 = AggregateExprBuilder::new(sum_udaf(), vec![a_plus_b_1])
+            .schema(Arc::clone(&schema))
+            .alias("sum1")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let expr2 = binary(a_plus_b_2, Operator::Multiply, Arc::clone(&c), &schema)?;
+        let agg2 = AggregateExprBuilder::new(sum_udaf(), vec![expr2])
+            .schema(Arc::clone(&schema))
+            .alias("sum2")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let group_by = PhysicalGroupBy::new_single(vec![]);
+        let aggregate = AggregateExec::try_new(
+            AggregateMode::Partial,
+            group_by,
+            vec![Arc::new(agg1), Arc::new(agg2)],
+            vec![None, None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(aggregate);
+        let config = ConfigOptions::new();
+        let rule = PhysicalCommonSubexprEliminate::new();
+        let optimized = rule.optimize(plan, &config)?;
+
+        let top_agg = optimized
+            .as_any()
+            .downcast_ref::<AggregateExec>()
+            .expect("top should be AggregateExec");
+        let intermediate = top_agg
+            .input()
+            .as_any()
+            .downcast_ref::<ProjectionExec>()
+            .expect("intermediate should be ProjectionExec");
+
+        assert_eq!(intermediate.expr().len(), 4); // a, b, c, __cse_0
+        assert_eq!(intermediate.expr()[3].alias, "__cse_0");
+        assert_eq!(top_agg.aggr_expr().len(), 2);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregate_cse_skips_final_mode() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Int64, false),
+        ]));
+        let input: Arc<dyn ExecutionPlan> = Arc::new(EmptyExec::new(Arc::clone(&schema)));
+
+        let a = col("a", &schema)?;
+
+        let agg1 = AggregateExprBuilder::new(sum_udaf(), vec![Arc::clone(&a)])
+            .schema(Arc::clone(&schema))
+            .alias("sum1")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let agg2 = AggregateExprBuilder::new(sum_udaf(), vec![Arc::clone(&a)])
+            .schema(Arc::clone(&schema))
+            .alias("sum2")
+            .with_ignore_nulls(false)
+            .with_distinct(false)
+            .build()?;
+
+        let group_by = PhysicalGroupBy::new_single(vec![]);
+        let aggregate = AggregateExec::try_new(
+            AggregateMode::Final,
+            group_by,
+            vec![Arc::new(agg1), Arc::new(agg2)],
+            vec![None, None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let plan: Arc<dyn ExecutionPlan> = Arc::new(aggregate);
+        let config = ConfigOptions::new();
+        let rule = PhysicalCommonSubexprEliminate::new();
+        let optimized = rule.optimize(plan, &config)?;
+
+        let top_agg = optimized
+            .as_any()
+            .downcast_ref::<AggregateExec>()
+            .expect("should be AggregateExec");
+        assert!(
+            top_agg
+                .input()
+                .as_any()
+                .downcast_ref::<ProjectionExec>()
+                .is_none(),
+            "Final-mode aggregate should not have intermediate projection"
+        );
+
+        Ok(())
+    }
+}
diff --git a/native/core/src/execution/spark_config.rs b/native/core/src/execution/spark_config.rs
index 60ebb2ff8b..0c4c326eb3 100644
--- a/native/core/src/execution/spark_config.rs
+++ b/native/core/src/execution/spark_config.rs
@@ -21,6 +21,8 @@ pub(crate) const COMET_TRACING_ENABLED: &str = "spark.comet.tracing.enabled";
 pub(crate) const COMET_DEBUG_ENABLED: &str = "spark.comet.debug.enabled";
 pub(crate) const COMET_EXPLAIN_NATIVE_ENABLED: &str = "spark.comet.explain.native.enabled";
 pub(crate) const COMET_MAX_TEMP_DIRECTORY_SIZE: &str = "spark.comet.maxTempDirectorySize";
+pub(crate) const COMET_NATIVE_PHYSICAL_OPTIMIZER_ENABLED: &str =
+    "spark.comet.exec.nativePhysicalOptimizer.enabled";
 
 pub(crate) trait SparkConfig {
     fn get_bool(&self, name: &str) -> bool;
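
Usage sketch (illustrative, not part of the patch): the optimizer defaults to on, so the
main knobs added here are the kill switch and the `spark.comet.datafusion.` passthrough.
Assuming a Comet-enabled session (the plugin class and jar variable below are standard
Comet setup, not introduced by this diff), the flags combine as in comet-tpch.sh:

    $SPARK_HOME/bin/spark-submit \
      --jars $COMET_JAR \
      --conf spark.plugins=org.apache.spark.CometPlugin \
      --conf spark.comet.exec.nativePhysicalOptimizer.enabled=true \
      --conf spark.comet.datafusion.execution.parquet.pushdown_filters=true \
      ...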