diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b1aa850284aee..9c8cdd5fa5dd8 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -84,7 +84,7 @@ use datafusion_expr::expr::{ }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; -use datafusion_expr::utils::split_conjunction; +use datafusion_expr::utils::{split_conjunction, split_projection}; use datafusion_expr::{ Analyze, BinaryExpr, DescribeTable, DmlStatement, Explain, ExplainFormat, Extension, FetchType, Filter, JoinType, Operator, RecursiveQuery, SkipType, StringifiedPlan, @@ -455,25 +455,8 @@ impl DefaultPhysicalPlanner { ) -> Result> { let exec_node: Arc = match node { // Leaves (no children) - LogicalPlan::TableScan(TableScan { - source, - projection, - filters, - fetch, - .. - }) => { - let source = source_as_provider(source)?; - // Remove all qualifiers from the scan as the provider - // doesn't know (nor should care) how the relation was - // referred to in the query - let filters = unnormalize_cols(filters.iter().cloned()); - let filters_vec = filters.into_iter().collect::>(); - let opts = ScanArgs::default() - .with_projection(projection.as_deref()) - .with_filters(Some(&filters_vec)) - .with_limit(*fetch); - let res = source.scan_with_args(session_state, opts).await?; - Arc::clone(res.plan()) + LogicalPlan::TableScan(scan) => { + self.plan_table_scan(scan, session_state).await? } LogicalPlan::Values(Values { values, schema }) => { let exprs = values @@ -1725,6 +1708,108 @@ impl DefaultPhysicalPlanner { )) } } + + /// Plan a TableScan node, wrapping with ProjectionExec as needed. + /// + /// This method handles projection pushdown by: + /// 1. Computing which columns the scan needs to produce + /// 2. Creating the scan with minimal required columns + /// 3. Applying any remainder projection (for complex expressions) + async fn plan_table_scan( + &self, + scan: &TableScan, + session_state: &SessionState, + ) -> Result> { + let provider = source_as_provider(&scan.source)?; + let source_schema = scan.source.schema(); + + // Remove qualifiers from filters + let filters: Vec = unnormalize_cols(scan.filters.iter().cloned()); + + // Compute required column indices and remainder projection + let split = split_projection(&scan.projection, &source_schema)?; + + // Create the scan + let scan_args = ScanArgs::default() + .with_projection(split.column_indices.as_deref()) + .with_filters(if filters.is_empty() { + None + } else { + Some(&filters) + }) + .with_limit(scan.fetch); + + let scan_result = provider.scan_with_args(session_state, scan_args).await?; + let mut plan: Arc = Arc::clone(scan_result.plan()); + + // Wrap with ProjectionExec if remainder projection is needed + if let Some(ref proj_exprs) = split.remainder { + let scan_df_schema = DFSchema::try_from(plan.schema().as_ref().clone())?; + let unnormalized_proj_exprs: Vec = + unnormalize_cols(proj_exprs.iter().cloned()); + plan = self.create_projection_exec( + &unnormalized_proj_exprs, + plan, + &scan_df_schema, + session_state, + )?; + } + + Ok(plan) + } + + /// Creates a ProjectionExec from logical expressions, handling async UDF expressions. + /// + /// If the expressions contain async UDFs, wraps them with `AsyncFuncExec`. + fn create_projection_exec( + &self, + exprs: &[Expr], + input: Arc, + input_dfschema: &DFSchema, + session_state: &SessionState, + ) -> Result> { + let physical_exprs: Vec<(Arc, String)> = exprs + .iter() + .map(|e| { + let physical = + self.create_physical_expr(e, input_dfschema, session_state)?; + let name = e.schema_name().to_string(); + Ok((physical, name)) + }) + .collect::>>()?; + + let num_input_columns = input.schema().fields().len(); + let input_schema = input.schema(); + + match self.try_plan_async_exprs( + num_input_columns, + PlannedExprResult::ExprWithName(physical_exprs), + input_schema.as_ref(), + )? { + PlanAsyncExpr::Sync(PlannedExprResult::ExprWithName(physical_exprs)) => { + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new(proj_exprs, input)?)) + } + PlanAsyncExpr::Async( + async_map, + PlannedExprResult::ExprWithName(physical_exprs), + ) => { + let async_exec = AsyncFuncExec::try_new(async_map.async_exprs, input)?; + let proj_exprs: Vec = physical_exprs + .into_iter() + .map(|(expr, alias)| ProjectionExpr { expr, alias }) + .collect(); + Ok(Arc::new(ProjectionExec::try_new( + proj_exprs, + Arc::new(async_exec), + )?)) + } + _ => internal_err!("Unexpected PlanAsyncExpressions variant"), + } + } } /// Expand and align a GROUPING SET expression. diff --git a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs index ca1eaa1f958ea..794e5aeed9619 100644 --- a/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs +++ b/datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs @@ -184,8 +184,8 @@ impl TableProvider for CustomProvider { filters: &[Expr], _: Option, ) -> Result> { - let empty = Vec::new(); - let projection = projection.unwrap_or(&empty); + // None means "all columns", Some(empty) means "no columns" + let select_all_columns = projection.is_none() || !projection.unwrap().is_empty(); match &filters[0] { Expr::BinaryExpr(BinaryExpr { right, .. }) => { let int_value = match &**right { @@ -215,9 +215,10 @@ impl TableProvider for CustomProvider { }; Ok(Arc::new(CustomPlan::new( - match projection.is_empty() { - true => Arc::new(Schema::empty()), - false => self.zero_batch.schema(), + if select_all_columns { + self.zero_batch.schema() + } else { + Arc::new(Schema::empty()) }, match int_value { 0 => vec![self.zero_batch.clone()], @@ -227,9 +228,10 @@ impl TableProvider for CustomProvider { ))) } _ => Ok(Arc::new(CustomPlan::new( - match projection.is_empty() { - true => Arc::new(Schema::empty()), - false => self.zero_batch.schema(), + if select_all_columns { + self.zero_batch.schema() + } else { + Arc::new(Schema::empty()) }, vec![], ))), diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2e23fef1da768..a433eb34df5a7 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -520,13 +520,8 @@ impl LogicalPlanBuilder { { let sub_plan = p.into_owned(); - if let Some(proj) = table_scan.projection { - let projection_exprs = proj - .into_iter() - .map(|i| { - Expr::Column(Column::from(sub_plan.schema().qualified_field(i))) - }) - .collect::>(); + if let Some(projection_exprs) = table_scan.projection { + // projection is now Vec, use directly return Self::new(sub_plan) .project(projection_exprs)? .alias(table_scan.table_name); diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index c2b01868c97f3..5087b25178ab6 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -42,8 +42,8 @@ pub use plan::{ EmptyRelation, Explain, ExplainOption, Extension, FetchType, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Projection, RecursiveQuery, Repartition, SkipType, Sort, StringifiedPlan, Subquery, - SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, - projection_schema, + SubqueryAlias, TableScan, TableScanBuilder, ToStringifiedPlan, Union, Unnest, Values, + Window, projection_schema, }; pub use statement::{ Deallocate, Execute, Prepare, ResetVariable, SetVariable, Statement, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 032a97bdb3efa..664762e351f30 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -42,7 +42,8 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::logical_plan::{DmlStatement, Statement}; use crate::utils::{ enumerate_grouping_sets, exprlist_to_fields, find_out_reference_exprs, - grouping_set_expr_count, grouping_set_to_exprlist, split_conjunction, + grouping_set_expr_count, grouping_set_to_exprlist, projection_indices_from_exprs, + split_conjunction, }; use crate::{ BinaryExpr, CreateMemoryTable, CreateView, Execute, Expr, ExprSchemable, @@ -1815,11 +1816,16 @@ impl LogicalPlan { .. }) => { let projected_fields = match projection { - Some(indices) => { - let schema = source.schema(); - let names: Vec<&str> = indices + Some(exprs) => { + let names: Vec = exprs .iter() - .map(|i| schema.field(*i).name().as_str()) + .map(|e| { + if let Expr::Column(col) = e { + col.name.clone() + } else { + e.schema_name().to_string() + } + }) .collect(); format!(" projection=[{}]", names.join(", ")) } @@ -2682,8 +2688,9 @@ pub struct TableScan { pub table_name: TableReference, /// The source of the table pub source: Arc, - /// Optional column indices to use as a projection - pub projection: Option>, + /// Optional column expressions to use as a projection. + /// Each expression should be a simple column reference (`Expr::Column`). + pub projection: Option>, /// The schema description of the output pub projected_schema: DFSchemaRef, /// Optional expressions to be used as filters by the table provider @@ -2725,8 +2732,8 @@ impl PartialOrd for TableScan { struct ComparableTableScan<'a> { /// The name of the table pub table_name: &'a TableReference, - /// Optional column indices to use as a projection - pub projection: &'a Option>, + /// Optional column expressions to use as a projection + pub projection: &'a Option>, /// Optional expressions to be used as filters by the table provider pub filters: &'a Vec, /// Optional number of rows to read @@ -2764,6 +2771,9 @@ impl Hash for TableScan { impl TableScan { /// Initialize TableScan with appropriate schema from the given /// arguments. + /// + /// This method accepts column indices for backward compatibility and + /// converts them internally to column expressions. pub fn try_new( table_name: impl Into, table_source: Arc, @@ -2771,46 +2781,145 @@ impl TableScan { filters: Vec, fetch: Option, ) -> Result { - let table_name = table_name.into(); + let schema = table_source.schema(); + let projection_exprs = projection.map(|indices| { + indices + .iter() + .map(|&i| Expr::Column(Column::new_unqualified(schema.field(i).name()))) + .collect::>() + }); + + TableScanBuilder::new(table_name, table_source) + .projection(projection_exprs) + .filters(filters) + .fetch(fetch) + .build() + } +} + +/// Builder for creating `TableScan` nodes with expression-based projections. +/// +/// This builder provides a flexible way to construct `TableScan` nodes, +/// particularly when working with expression-based projections directly. +#[derive(Clone)] +pub struct TableScanBuilder { + table_name: TableReference, + table_source: Arc, + projection: Option>, + filters: Vec, + fetch: Option, +} + +impl Debug for TableScanBuilder { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + f.debug_struct("TableScanBuilder") + .field("table_name", &self.table_name) + .field("table_source", &"...") + .field("projection", &self.projection) + .field("filters", &self.filters) + .field("fetch", &self.fetch) + .finish() + } +} + +impl TableScanBuilder { + /// Create a new TableScanBuilder with the given table name and source. + pub fn new( + table_name: impl Into, + table_source: Arc, + ) -> Self { + Self { + table_name: table_name.into(), + table_source, + projection: None, + filters: vec![], + fetch: None, + } + } - if table_name.table().is_empty() { + /// Set the projection expressions. + pub fn projection(mut self, projection: Option>) -> Self { + self.projection = projection; + self + } + + /// Set the filters. + pub fn filters(mut self, filters: Vec) -> Self { + self.filters = filters; + self + } + + /// Set the fetch limit. + pub fn fetch(mut self, fetch: Option) -> Self { + self.fetch = fetch; + self + } + + /// Build the TableScan. + pub fn build(self) -> Result { + if self.table_name.table().is_empty() { return plan_err!("table_name cannot be empty"); } - let schema = table_source.schema(); + + let schema = self.table_source.schema(); let func_dependencies = FunctionalDependencies::new_from_constraints( - table_source.constraints(), + self.table_source.constraints(), schema.fields.len(), ); - let projected_schema = projection - .as_ref() - .map(|p| { - let projected_func_dependencies = - func_dependencies.project_functional_dependencies(p, p.len()); - let df_schema = DFSchema::new_with_metadata( - p.iter() - .map(|i| { - (Some(table_name.clone()), Arc::clone(&schema.fields()[*i])) - }) - .collect(), - schema.metadata.clone(), + // Build the projected schema from projection expressions + let projected_schema = match &self.projection { + Some(exprs) => { + // Create a qualified schema for expression evaluation + let qualified_schema = DFSchema::try_from_qualified_schema( + self.table_name.clone(), + &schema, )?; - df_schema.with_functional_dependencies(projected_func_dependencies) - }) - .unwrap_or_else(|| { + + // Derive output fields from projection expressions + // For simple column references, qualify with table name + // For complex expressions, don't add qualifier (matches Projection behavior) + let fields: Vec<(Option, FieldRef)> = exprs + .iter() + .map(|expr| { + let (_qualifier, field) = expr.to_field(&qualified_schema)?; + let qualifier = if matches!(expr, Expr::Column(_)) { + Some(self.table_name.clone()) + } else { + None + }; + Ok((qualifier, field)) + }) + .collect::>>()?; + + // Try to compute functional dependencies for simple column projections + let projected_func_dependencies = + match projection_indices_from_exprs(exprs, &schema)? { + Some(indices) => func_dependencies + .project_functional_dependencies(&indices, indices.len()), + None => FunctionalDependencies::empty(), + }; + let df_schema = - DFSchema::try_from_qualified_schema(table_name.clone(), &schema)?; - df_schema.with_functional_dependencies(func_dependencies) - })?; - let projected_schema = Arc::new(projected_schema); + DFSchema::new_with_metadata(fields, schema.metadata.clone())?; + df_schema.with_functional_dependencies(projected_func_dependencies)? + } + None => { + let df_schema = DFSchema::try_from_qualified_schema( + self.table_name.clone(), + &schema, + )?; + df_schema.with_functional_dependencies(func_dependencies)? + } + }; - Ok(Self { - table_name, - source: table_source, - projection, - projected_schema, - filters, - fetch, + Ok(TableScan { + table_name: self.table_name, + source: self.table_source, + projection: self.projection, + projected_schema: Arc::new(projected_schema), + filters: self.filters, + fetch: self.fetch, }) } } diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index b19299981cef3..1f7189f54dbb8 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{ }; use datafusion_common::utils::get_at_indices; use datafusion_common::{ - Column, DFSchema, DFSchemaRef, HashMap, Result, TableReference, internal_err, - plan_err, + Column, DFSchema, DFSchemaRef, DataFusionError, HashMap, Result, TableReference, + internal_err, plan_err, }; #[cfg(not(feature = "sql"))] @@ -1286,6 +1286,43 @@ pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{name}[{state_name}]") } +/// Convert projection expressions (assumed to be column references) to column indices. +/// +/// This function takes a list of expressions (which should be `Expr::Column` variants) +/// and returns the indices of those columns in the given schema. +/// +/// # Arguments +/// * `exprs` - A slice of expressions, expected to be `Expr::Column` variants +/// * `schema` - The schema to look up column indices in +/// +/// # Returns +/// * `Ok(Some(Vec))` - If all expressions are column references found in the schema +/// * `Ok(None)` - If any expression is not a simple column reference +/// * `Err(...)` - If a column reference is not found in the schema (indicates a bug) +pub fn projection_indices_from_exprs( + exprs: &[Expr], + schema: &Schema, +) -> Result>> { + let mut indices = Vec::with_capacity(exprs.len()); + for expr in exprs { + match expr { + Expr::Column(col) => { + let idx = schema.index_of(&col.name).map_err(|_| { + DataFusionError::Internal(format!( + "Column '{}' not found in schema during projection index conversion. \ + Available columns: {:?}", + col.name, + schema.fields().iter().map(|f| f.name()).collect::>() + )) + })?; + indices.push(idx); + } + _ => return Ok(None), // Non-column expression, cannot convert to indices + } + } + Ok(Some(indices)) +} + /// Determine the set of [`Column`]s produced by the subquery. pub fn collect_subquery_cols( exprs: &[Expr], @@ -1304,6 +1341,112 @@ pub fn collect_subquery_cols( }) } +/// Result of splitting a projection into column indices and remainder expressions. +/// +/// See [`split_projection`] for details on how projections are split. +#[derive(Debug, Clone, Default)] +pub struct SplitProjection { + /// If the projection contains complex expressions (not just column references), + /// this contains the full projection to apply on top of the column selection. + /// `None` means no remainder projection is needed (all expressions were simple columns). + pub remainder: Option>, + /// Column indices to scan from the source. `None` means scan all columns. + pub column_indices: Option>, +} + +/// Split a projection into a column mask and an optional remainder projection. +/// +/// Given a list of projection expressions and a schema, this function separates +/// simple column references from complex expressions. This is useful when an +/// operator can push down simple column selections but needs a follow-up +/// projection for computed expressions. +/// +/// # Arguments +/// * `projection` - Optional list of projection expressions. `None` means select all columns. +/// * `schema` - The schema to resolve column indices against +/// +/// # Returns +/// A [`SplitProjection`] containing: +/// * `remainder: None, column_indices: None` - If projection is `None` (select all columns) +/// * `remainder: None, column_indices: Some(indices)` - If all expressions are simple column references +/// * `remainder: Some(exprs), column_indices: Some(indices)` - If any expression is complex +/// +/// # Example +/// Given projection `[col("a"), col("a") + col("c"), col("d")]` and schema `[a, b, c, d]`: +/// - `column_indices = Some([0, 2, 3])` (indices of a, c, d) +/// - `remainder = Some([col("a"), col("a") + col("c"), col("d")])` +pub fn split_projection( + projection: &Option>, + schema: &Schema, +) -> Result { + let Some(exprs) = projection else { + // None means scan all columns, no remainder needed + return Ok(SplitProjection::default()); + }; + + if exprs.is_empty() { + return Ok(SplitProjection { + remainder: None, + column_indices: Some(vec![]), + }); + } + + let mut has_complex_expr = false; + let mut all_required_columns = BTreeSet::new(); + let mut remainder_exprs = vec![]; + + for expr in exprs { + // Collect all column references from this expression + let mut is_complex_expr = false; + expr.apply(|e| { + if let Expr::Column(col) = e { + if let Ok(index) = schema.index_of(col.name()) { + // If we made it this far this must be the first level and the whole + // expression is a simple column reference. + // But we don't know if subsequent expressions might have more complex + // expressions necessitating `remainder_exprs` to be populated, so we + // push to `remainder_exprs` just in case they are needed later. + // It is simpler to do this now than to try to backtrack later since + // we already matched into Expr::Column and thus can simply clone + // `expr` here. + // If `is_complex_expr` is true then we will append the complex + // expression itself to `remainder_exprs` instead later once we've + // fully traversed this expression. + if !is_complex_expr { + remainder_exprs.push(expr.clone()); + } + all_required_columns.insert(index); + } + } else { + // Nothing to do here except note that we will have to append the full + // expression later + is_complex_expr = true; + } + Ok(TreeNodeRecursion::Continue) + })?; + if is_complex_expr { + // If any expression in the projection is not a simple column reference we + // will need to apply a remainder projection + has_complex_expr = true; + // Append the full expression itself to the remainder expressions + // So given a projection like `[a, a + c, d]` we would have: + // all_required_columns = {0, 2, 3} + // original schema: [a: Int, b: Int, c: Int, d: Int] + // projected schema: [a: Int, c: Int, d: Int] + // remainder_exprs = [col(a), col(a) + col(c), col(d)] + remainder_exprs.push(expr.clone()); + } + } + + // Always return explicit indices to ensure compatibility with all providers. + // Some providers (e.g., FFI) cannot distinguish between None (scan all) and + // empty vec (scan nothing), so we always provide explicit column indices. + Ok(SplitProjection { + remainder: has_complex_expr.then_some(remainder_exprs), + column_indices: Some(all_required_columns.into_iter().collect()), + }) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs b/datafusion/optimizer/src/optimize_projections/mod.rs index f97b05ea68fbd..4d7a90f6dcd8f 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -30,8 +30,8 @@ use datafusion_common::{ }; use datafusion_expr::expr::Alias; use datafusion_expr::{ - Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, Unnest, Window, - logical_plan::LogicalPlan, + Aggregate, Distinct, EmptyRelation, Expr, Projection, TableScan, TableScanBuilder, + Unnest, Window, logical_plan::LogicalPlan, }; use crate::optimize_projections::required_indices::RequiredIndices; @@ -262,14 +262,36 @@ fn optimize_projections( projected_schema: _, } = table_scan; - // Get indices referred to in the original (schema with all fields) - // given projected indices. - let projection = match &projection { - Some(projection) => indices.into_mapped_indices(|idx| projection[idx]), - None => indices.into_inner(), + let source_schema = source.schema(); + let new_projection = match &projection { + Some(proj_exprs) => { + // Map required indices through existing projection expressions + let new_exprs: Vec = indices + .into_inner() + .iter() + .map(|&idx| proj_exprs[idx].clone()) + .collect(); + new_exprs + } + None => { + // Create column expressions for required indices + let new_exprs: Vec = indices + .into_inner() + .iter() + .map(|&idx| { + let field = source_schema.field(idx); + Expr::Column(Column::new_unqualified(field.name())) + }) + .collect(); + new_exprs + } }; - let new_scan = - TableScan::try_new(table_name, source, Some(projection), filters, fetch)?; + + let new_scan = TableScanBuilder::new(table_name, source) + .projection(Some(new_projection)) + .filters(filters) + .fetch(fetch) + .build()?; return Ok(Transformed::yes(LogicalPlan::TableScan(new_scan))); } diff --git a/datafusion/optimizer/src/optimize_projections/required_indices.rs b/datafusion/optimizer/src/optimize_projections/required_indices.rs index c1e0885c9b5f2..20691beb58eb6 100644 --- a/datafusion/optimizer/src/optimize_projections/required_indices.rs +++ b/datafusion/optimizer/src/optimize_projections/required_indices.rs @@ -193,15 +193,6 @@ impl RequiredIndices { self } - /// Apply the given function `f` to each index in this instance, returning - /// the mapped indices - pub fn into_mapped_indices(self, f: F) -> Vec - where - F: Fn(usize) -> usize, - { - self.map_indices(f).into_inner() - } - /// Returns the `Expr`s from `exprs` that are at the indices in this instance pub fn get_at_indices(&self, exprs: &[Expr]) -> Vec { self.indices.iter().map(|&idx| exprs[idx].clone()).collect() diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index ecd6a89f2a3e6..9d583b7185276 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -3109,13 +3109,27 @@ mod tests { filters: Vec, projection: Option>, ) -> Result { + use datafusion_common::Column; + let test_provider = PushDownProvider { filter_support }; + let schema = test_provider.schema(); + + // Convert indices to expressions + let projection_exprs = projection.map(|indices| { + indices + .iter() + .map(|&i| { + let field = schema.field(i); + Expr::Column(Column::new_unqualified(field.name())) + }) + .collect::>() + }); let table_scan = LogicalPlan::TableScan(TableScan { table_name: "test".into(), filters, - projected_schema: Arc::new(DFSchema::try_from(test_provider.schema())?), - projection, + projected_schema: Arc::new(DFSchema::try_from(schema)?), + projection: projection_exprs, source: Arc::new(test_provider), fetch: None, }); diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 810ec6d1f17a3..dff1df92f954a 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -87,7 +87,7 @@ message ListingTableScanNode { TableReference table_name = 14; repeated string paths = 2; string file_extension = 3; - ProjectionColumns projection = 4; + ProjectionColumns projection = 4; // Deprecated: use projection_exprs instead datafusion_common.Schema schema = 5; repeated LogicalExprNode filters = 6; repeated PartitionColumn table_partition_cols = 7; @@ -101,6 +101,7 @@ message ListingTableScanNode { datafusion_common.ArrowFormat arrow = 16; } repeated SortExprNodeCollection file_sort_order = 13; + repeated LogicalExprNode projection_exprs = 17; // Expression-based projections } message ViewTableScanNode { @@ -108,18 +109,20 @@ message ViewTableScanNode { TableReference table_name = 6; LogicalPlanNode input = 2; datafusion_common.Schema schema = 3; - ProjectionColumns projection = 4; + ProjectionColumns projection = 4; // Deprecated: use projection_exprs instead string definition = 5; + repeated LogicalExprNode projection_exprs = 7; // Expression-based projections } // Logical Plan to Scan a CustomTableProvider registered at runtime message CustomTableScanNode { reserved 1; // was string table_name TableReference table_name = 6; - ProjectionColumns projection = 2; + ProjectionColumns projection = 2; // Deprecated: use projection_exprs instead datafusion_common.Schema schema = 3; repeated LogicalExprNode filters = 4; bytes custom_table_data = 5; + repeated LogicalExprNode projection_exprs = 7; // Expression-based projections } message ProjectionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7ed20785ab384..8e30f70700d21 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4645,6 +4645,9 @@ impl serde::Serialize for CustomTableScanNode { if !self.custom_table_data.is_empty() { len += 1; } + if !self.projection_exprs.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.CustomTableScanNode", len)?; if let Some(v) = self.table_name.as_ref() { struct_ser.serialize_field("tableName", v)?; @@ -4663,6 +4666,9 @@ impl serde::Serialize for CustomTableScanNode { #[allow(clippy::needless_borrows_for_generic_args)] struct_ser.serialize_field("customTableData", pbjson::private::base64::encode(&self.custom_table_data).as_str())?; } + if !self.projection_exprs.is_empty() { + struct_ser.serialize_field("projectionExprs", &self.projection_exprs)?; + } struct_ser.end() } } @@ -4680,6 +4686,8 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { "filters", "custom_table_data", "customTableData", + "projection_exprs", + "projectionExprs", ]; #[allow(clippy::enum_variant_names)] @@ -4689,6 +4697,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { Schema, Filters, CustomTableData, + ProjectionExprs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4715,6 +4724,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { "schema" => Ok(GeneratedField::Schema), "filters" => Ok(GeneratedField::Filters), "customTableData" | "custom_table_data" => Ok(GeneratedField::CustomTableData), + "projectionExprs" | "projection_exprs" => Ok(GeneratedField::ProjectionExprs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4739,6 +4749,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { let mut schema__ = None; let mut filters__ = None; let mut custom_table_data__ = None; + let mut projection_exprs__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::TableName => { @@ -4773,6 +4784,12 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) ; } + GeneratedField::ProjectionExprs => { + if projection_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("projectionExprs")); + } + projection_exprs__ = Some(map_.next_value()?); + } } } Ok(CustomTableScanNode { @@ -4781,6 +4798,7 @@ impl<'de> serde::Deserialize<'de> for CustomTableScanNode { schema: schema__, filters: filters__.unwrap_or_default(), custom_table_data: custom_table_data__.unwrap_or_default(), + projection_exprs: projection_exprs__.unwrap_or_default(), }) } } @@ -11101,6 +11119,9 @@ impl serde::Serialize for ListingTableScanNode { if !self.file_sort_order.is_empty() { len += 1; } + if !self.projection_exprs.is_empty() { + len += 1; + } if self.file_format_type.is_some() { len += 1; } @@ -11135,6 +11156,9 @@ impl serde::Serialize for ListingTableScanNode { if !self.file_sort_order.is_empty() { struct_ser.serialize_field("fileSortOrder", &self.file_sort_order)?; } + if !self.projection_exprs.is_empty() { + struct_ser.serialize_field("projectionExprs", &self.projection_exprs)?; + } if let Some(v) = self.file_format_type.as_ref() { match v { listing_table_scan_node::FileFormatType::Csv(v) => { @@ -11180,6 +11204,8 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "targetPartitions", "file_sort_order", "fileSortOrder", + "projection_exprs", + "projectionExprs", "csv", "parquet", "avro", @@ -11199,6 +11225,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { CollectStat, TargetPartitions, FileSortOrder, + ProjectionExprs, Csv, Parquet, Avro, @@ -11235,6 +11262,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { "collectStat" | "collect_stat" => Ok(GeneratedField::CollectStat), "targetPartitions" | "target_partitions" => Ok(GeneratedField::TargetPartitions), "fileSortOrder" | "file_sort_order" => Ok(GeneratedField::FileSortOrder), + "projectionExprs" | "projection_exprs" => Ok(GeneratedField::ProjectionExprs), "csv" => Ok(GeneratedField::Csv), "parquet" => Ok(GeneratedField::Parquet), "avro" => Ok(GeneratedField::Avro), @@ -11269,6 +11297,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { let mut collect_stat__ = None; let mut target_partitions__ = None; let mut file_sort_order__ = None; + let mut projection_exprs__ = None; let mut file_format_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -11334,6 +11363,12 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { } file_sort_order__ = Some(map_.next_value()?); } + GeneratedField::ProjectionExprs => { + if projection_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("projectionExprs")); + } + projection_exprs__ = Some(map_.next_value()?); + } GeneratedField::Csv => { if file_format_type__.is_some() { return Err(serde::de::Error::duplicate_field("csv")); @@ -11382,6 +11417,7 @@ impl<'de> serde::Deserialize<'de> for ListingTableScanNode { collect_stat: collect_stat__.unwrap_or_default(), target_partitions: target_partitions__.unwrap_or_default(), file_sort_order: file_sort_order__.unwrap_or_default(), + projection_exprs: projection_exprs__.unwrap_or_default(), file_format_type: file_format_type__, }) } @@ -23856,6 +23892,9 @@ impl serde::Serialize for ViewTableScanNode { if !self.definition.is_empty() { len += 1; } + if !self.projection_exprs.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("datafusion.ViewTableScanNode", len)?; if let Some(v) = self.table_name.as_ref() { struct_ser.serialize_field("tableName", v)?; @@ -23872,6 +23911,9 @@ impl serde::Serialize for ViewTableScanNode { if !self.definition.is_empty() { struct_ser.serialize_field("definition", &self.definition)?; } + if !self.projection_exprs.is_empty() { + struct_ser.serialize_field("projectionExprs", &self.projection_exprs)?; + } struct_ser.end() } } @@ -23888,6 +23930,8 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { "schema", "projection", "definition", + "projection_exprs", + "projectionExprs", ]; #[allow(clippy::enum_variant_names)] @@ -23897,6 +23941,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { Schema, Projection, Definition, + ProjectionExprs, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -23923,6 +23968,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { "schema" => Ok(GeneratedField::Schema), "projection" => Ok(GeneratedField::Projection), "definition" => Ok(GeneratedField::Definition), + "projectionExprs" | "projection_exprs" => Ok(GeneratedField::ProjectionExprs), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -23947,6 +23993,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { let mut schema__ = None; let mut projection__ = None; let mut definition__ = None; + let mut projection_exprs__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::TableName => { @@ -23979,6 +24026,12 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { } definition__ = Some(map_.next_value()?); } + GeneratedField::ProjectionExprs => { + if projection_exprs__.is_some() { + return Err(serde::de::Error::duplicate_field("projectionExprs")); + } + projection_exprs__ = Some(map_.next_value()?); + } } } Ok(ViewTableScanNode { @@ -23987,6 +24040,7 @@ impl<'de> serde::Deserialize<'de> for ViewTableScanNode { schema: schema__, projection: projection__, definition: definition__.unwrap_or_default(), + projection_exprs: projection_exprs__.unwrap_or_default(), }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 0c9320c77892b..2329c4c457ea2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -109,6 +109,7 @@ pub struct ListingTableScanNode { pub paths: ::prost::alloc::vec::Vec<::prost::alloc::string::String>, #[prost(string, tag = "3")] pub file_extension: ::prost::alloc::string::String, + /// Deprecated: use projection_exprs instead #[prost(message, optional, tag = "4")] pub projection: ::core::option::Option, #[prost(message, optional, tag = "5")] @@ -123,6 +124,9 @@ pub struct ListingTableScanNode { pub target_partitions: u32, #[prost(message, repeated, tag = "13")] pub file_sort_order: ::prost::alloc::vec::Vec, + /// Expression-based projections + #[prost(message, repeated, tag = "17")] + pub projection_exprs: ::prost::alloc::vec::Vec, #[prost( oneof = "listing_table_scan_node::FileFormatType", tags = "10, 11, 12, 15, 16" @@ -155,16 +159,21 @@ pub struct ViewTableScanNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] pub schema: ::core::option::Option, + /// Deprecated: use projection_exprs instead #[prost(message, optional, tag = "4")] pub projection: ::core::option::Option, #[prost(string, tag = "5")] pub definition: ::prost::alloc::string::String, + /// Expression-based projections + #[prost(message, repeated, tag = "7")] + pub projection_exprs: ::prost::alloc::vec::Vec, } /// Logical Plan to Scan a CustomTableProvider registered at runtime #[derive(Clone, PartialEq, ::prost::Message)] pub struct CustomTableScanNode { #[prost(message, optional, tag = "6")] pub table_name: ::core::option::Option, + /// Deprecated: use projection_exprs instead #[prost(message, optional, tag = "2")] pub projection: ::core::option::Option, #[prost(message, optional, tag = "3")] @@ -173,6 +182,9 @@ pub struct CustomTableScanNode { pub filters: ::prost::alloc::vec::Vec, #[prost(bytes = "vec", tag = "5")] pub custom_table_data: ::prost::alloc::vec::Vec, + /// Expression-based projections + #[prost(message, repeated, tag = "7")] + pub projection_exprs: ::prost::alloc::vec::Vec, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct ProjectionNode { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 218c2e4e47d04..4eac457736060 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -56,12 +56,12 @@ use datafusion_expr::{ }; use datafusion_expr::{ DistinctOn, DropView, Expr, LogicalPlan, LogicalPlanBuilder, ScalarUDF, SortExpr, - Statement, WindowUDF, dml, + Statement, WindowUDF, col, dml, logical_plan::{ Aggregate, CreateCatalog, CreateCatalogSchema, CreateExternalTable, CreateView, DdlStatement, Distinct, EmptyRelation, Extension, Join, JoinConstraint, Prepare, - Projection, Repartition, Sort, SubqueryAlias, TableScan, Values, Window, - builder::project, + Projection, Repartition, Sort, SubqueryAlias, TableScan, TableScanBuilder, + Values, Window, builder::project, }, }; @@ -480,36 +480,46 @@ impl AsLogicalPlan for LogicalPlanNode { let table_name = from_table_reference(scan.table_name.as_ref(), "ListingTableScan")?; - let mut projection = None; - if let Some(columns) = &scan.projection { - let column_indices = columns - .columns - .iter() - .map(|name| provider.schema().index_of(name)) - .collect::, _>>()?; - projection = Some(column_indices); - } + // Prefer new projection_exprs field, fall back to old projection for backward compatibility + let projection = if !scan.projection_exprs.is_empty() { + // New format: expressions are in projection_exprs + Some(from_proto::parse_exprs( + &scan.projection_exprs, + ctx, + extension_codec, + )?) + } else if let Some(columns) = &scan.projection { + if !columns.columns.is_empty() { + // Backward compatibility: convert old column names to expressions + Some( + columns + .columns + .iter() + .map(|name| col(name.as_str())) + .collect(), + ) + } else { + // New format: projection field is a marker indicating Some([]) + Some(vec![]) + } + } else { + // No projection field means projection is None + None + }; - LogicalPlanBuilder::scan_with_filters( - table_name, - provider_as_source(Arc::new(provider)), - projection, - filters, - )? - .build() + Ok(LogicalPlan::TableScan( + TableScanBuilder::new( + table_name, + provider_as_source(Arc::new(provider)), + ) + .projection(projection) + .filters(filters) + .build()?, + )) } LogicalPlanType::CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); - let mut projection = None; - if let Some(columns) = &scan.projection { - let column_indices = columns - .columns - .iter() - .map(|name| schema.index_of(name)) - .collect::, _>>()?; - projection = Some(column_indices); - } let filters = from_proto::parse_exprs(&scan.filters, ctx, extension_codec)?; @@ -524,13 +534,39 @@ impl AsLogicalPlan for LogicalPlanNode { ctx, )?; - LogicalPlanBuilder::scan_with_filters( - table_name, - provider_as_source(provider), - projection, - filters, - )? - .build() + // Prefer new projection_exprs field, fall back to old projection for backward compatibility + let projection = if !scan.projection_exprs.is_empty() { + // New format: expressions are in projection_exprs + Some(from_proto::parse_exprs( + &scan.projection_exprs, + ctx, + extension_codec, + )?) + } else if let Some(columns) = &scan.projection { + if !columns.columns.is_empty() { + // Backward compatibility: convert old column names to expressions + Some( + columns + .columns + .iter() + .map(|name| col(name.as_str())) + .collect(), + ) + } else { + // New format: projection field is a marker indicating Some([]) + Some(vec![]) + } + } else { + // No projection field means projection is None + None + }; + + Ok(LogicalPlan::TableScan( + TableScanBuilder::new(table_name, provider_as_source(provider)) + .projection(projection) + .filters(filters) + .build()?, + )) } LogicalPlanType::Sort(sort) => { let input: LogicalPlan = @@ -836,17 +872,7 @@ impl AsLogicalPlan for LogicalPlanNode { .build() } LogicalPlanType::ViewScan(scan) => { - let schema: Schema = convert_required!(scan.schema)?; - - let mut projection = None; - if let Some(columns) = &scan.projection { - let column_indices = columns - .columns - .iter() - .map(|name| schema.index_of(name)) - .collect::, _>>()?; - projection = Some(column_indices); - } + let _schema: Schema = convert_required!(scan.schema)?; let input: LogicalPlan = into_logical_plan!(scan.input, ctx, extension_codec)?; @@ -862,12 +888,41 @@ impl AsLogicalPlan for LogicalPlanNode { let table_name = from_table_reference(scan.table_name.as_ref(), "ViewScan")?; - LogicalPlanBuilder::scan( - table_name, - provider_as_source(Arc::new(provider)), - projection, - )? - .build() + // Prefer new projection_exprs field, fall back to old projection for backward compatibility + let projection = if !scan.projection_exprs.is_empty() { + // New format: expressions are in projection_exprs + Some(from_proto::parse_exprs( + &scan.projection_exprs, + ctx, + extension_codec, + )?) + } else if let Some(columns) = &scan.projection { + if !columns.columns.is_empty() { + // Backward compatibility: convert old column names to expressions + Some( + columns + .columns + .iter() + .map(|name| col(name.as_str())) + .collect(), + ) + } else { + // New format: projection field is a marker indicating Some([]) + Some(vec![]) + } + } else { + // No projection field means projection is None + None + }; + + Ok(LogicalPlan::TableScan( + TableScanBuilder::new( + table_name, + provider_as_source(Arc::new(provider)), + ) + .projection(projection) + .build()?, + )) } LogicalPlanType::Prepare(prepare) => { let input: LogicalPlan = @@ -1021,16 +1076,16 @@ impl AsLogicalPlan for LogicalPlanNode { let schema = provider.schema(); let source = provider.as_any(); - let projection = match projection { - None => None, - Some(columns) => { - let column_names = columns - .iter() - .map(|i| schema.field(*i).name().to_owned()) - .collect(); - Some(protobuf::ProjectionColumns { - columns: column_names, - }) + // Serialize projection expressions to the new projection_exprs field + // Use the old projection field as a marker to distinguish None vs Some([]) + let (projection_exprs, projection) = match projection { + None => (vec![], None), + Some(exprs) => { + let serialized = serialize_exprs(exprs, extension_codec)?; + // Set projection as a marker that a projection exists (even if empty) + let marker = + Some(protobuf::ProjectionColumns { columns: vec![] }); + (serialized, marker) } }; @@ -1148,6 +1203,7 @@ impl AsLogicalPlan for LogicalPlanNode { filters, target_partitions: options.target_partitions as u32, file_sort_order: exprs_vec, + projection_exprs: projection_exprs.clone(), }, )), }) @@ -1169,6 +1225,7 @@ impl AsLogicalPlan for LogicalPlanNode { .definition() .map(|s| s.to_string()) .unwrap_or_default(), + projection_exprs: projection_exprs.clone(), }, ))), }) @@ -1198,6 +1255,7 @@ impl AsLogicalPlan for LogicalPlanNode { schema: Some(schema), filters, custom_table_data: bytes, + projection_exprs, }); let node = LogicalPlanNode { logical_plan_type: Some(scan), diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index 9f770f9f45e1d..41d51ece1e295 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -41,7 +41,7 @@ use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ Column, DataFusionError, Result, ScalarValue, TableReference, assert_or_internal_err, internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, + tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; use datafusion_expr::{ @@ -1088,26 +1088,43 @@ impl Unparser<'_> { // Avoid creating a duplicate Projection node, which would result in an additional subquery if a projection already exists. // For example, if the `optimize_projection` rule is applied, there will be a Projection node, and duplicate projection // information included in the TableScan node. - if !already_projected && let Some(project_vec) = &table_scan.projection { - if project_vec.is_empty() { + if !already_projected && let Some(project_exprs) = &table_scan.projection + { + if project_exprs.is_empty() { builder = builder.project(self.empty_projection_fallback())?; } else { - let project_columns = project_vec - .iter() - .cloned() - .map(|i| { - let schema = table_scan.source.schema(); - let field = schema.field(i); - if alias.is_some() { - Column::new(alias.clone(), field.name().clone()) - } else { - Column::new( - Some(table_scan.table_name.clone()), - field.name().clone(), - ) - } - }) - .collect::>(); + // Handle expression-based projections with alias rewriting + let project_columns: Vec = if alias.is_some() { + project_exprs + .iter() + .map(|expr| { + expr.clone() + .transform(|e| { + if let Expr::Column(col) = &e { + if let Some(relation) = &col.relation { + if relation != &table_scan.table_name + { + return Ok(Transformed::no(e)); + } + Ok(Transformed::yes(Expr::Column( + Column::new( + alias.clone(), + col.name().to_string(), + ), + ))) + } else { + Ok(Transformed::no(e)) + } + } else { + Ok(Transformed::no(e)) + } + }) + .map(|t| t.data) + }) + .collect::>>()? + } else { + project_exprs.clone() + }; builder = builder.project(project_columns)?; }; } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index f539c0ddc1e87..7c83979269328 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -27,8 +27,8 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_expr::{ - Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, Unnest, - Window, expr, utils::grouping_set_to_exprlist, + Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, Projection, SortExpr, + TableScanBuilder, Unnest, Window, expr, utils::grouping_set_to_exprlist, }; use indexmap::IndexSet; @@ -385,11 +385,17 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters( } } - let mut builder = LogicalPlanBuilder::scan( - table_scan.table_name.clone(), - Arc::clone(&table_scan.source), - table_scan.projection.clone(), - )?; + // Use TableScanBuilder to preserve expression-based projections + let mut builder = LogicalPlanBuilder::from(LogicalPlan::TableScan( + TableScanBuilder::new( + table_scan.table_name.clone(), + Arc::clone(&table_scan.source), + ) + .projection(table_scan.projection.clone()) + .filters(vec![]) // Filters handled separately + .fetch(table_scan.fetch) + .build()?, + )); if let Some(alias) = table_alias.take() { builder = builder.alias(alias)?; diff --git a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs index 832110e11131c..7473012d7d27e 100644 --- a/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/consumer/rel/read_rel.rs @@ -20,7 +20,7 @@ use crate::logical_plan::consumer::from_substrait_literal; use crate::logical_plan::consumer::from_substrait_named_struct; use crate::logical_plan::consumer::utils::ensure_schema_compatibility; use datafusion::common::{ - DFSchema, DFSchemaRef, TableReference, not_impl_err, plan_err, + Column, DFSchema, DFSchemaRef, TableReference, not_impl_err, plan_err, substrait_datafusion_err, substrait_err, }; use datafusion::datasource::provider_as_source; @@ -325,11 +325,21 @@ fn apply_projection( .map(|(qualifier, field)| (qualifier.cloned(), Arc::clone(field))) .collect(); + // Convert indices to column expressions for the new projection format + let source_schema = scan.source.schema(); + let projection_exprs: Vec = column_indices + .iter() + .map(|&i| { + let field = source_schema.field(i); + Expr::Column(Column::new_unqualified(field.name())) + }) + .collect(); + scan.projected_schema = DFSchemaRef::new(DFSchema::new_with_metadata( fields, df_schema.metadata().clone(), )?); - scan.projection = Some(column_indices); + scan.projection = Some(projection_exprs); Ok(LogicalPlan::TableScan(scan)) } diff --git a/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs index 33920cdf86f7a..ea57dabfbcc2a 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/project_rel.rs @@ -91,7 +91,10 @@ pub fn from_window( /// A DataFusion Projection only outputs expressions. In order to keep the Substrait /// plan consistent with DataFusion, we must apply an output mapping that skips the input /// fields so that the Substrait Project will only output the expression fields. -fn create_project_remapping(expr_count: usize, input_field_count: usize) -> EmitKind { +pub(crate) fn create_project_remapping( + expr_count: usize, + input_field_count: usize, +) -> EmitKind { let expression_field_start = input_field_count; let expression_field_end = expression_field_start + expr_count; let output_mapping = (expression_field_start..expression_field_end) diff --git a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs index 8dfbb36d3767d..30104155aa400 100644 --- a/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs +++ b/datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::logical_plan::producer::rel::project_rel::create_project_remapping; use crate::logical_plan::producer::{ SubstraitProducer, to_substrait_literal, to_substrait_named_struct, }; use datafusion::common::{DFSchema, ToDFSchema, substrait_datafusion_err}; -use datafusion::logical_expr::utils::conjunction; +use datafusion::logical_expr::utils::{conjunction, split_projection}; use datafusion::logical_expr::{EmptyRelation, Expr, TableScan, Values}; use datafusion::scalar::ScalarValue; use std::sync::Arc; @@ -29,7 +30,7 @@ use substrait::proto::expression::mask_expression::{StructItem, StructSelect}; use substrait::proto::expression::nested::Struct as NestedStruct; use substrait::proto::read_rel::{NamedTable, ReadType, VirtualTable}; use substrait::proto::rel::RelType; -use substrait::proto::{ReadRel, Rel}; +use substrait::proto::{ProjectRel, ReadRel, Rel, RelCommon}; /// Converts rows of literal expressions into Substrait literal structs. /// @@ -97,18 +98,24 @@ pub fn from_table_scan( producer: &mut impl SubstraitProducer, scan: &TableScan, ) -> datafusion::common::Result> { - let projection = scan.projection.as_ref().map(|p| { - p.iter() - .map(|i| StructItem { - field: *i as i32, + let source_schema = scan.source.schema(); + + // Compute required column indices and remainder projection expressions. + let split = split_projection(&scan.projection, source_schema.as_ref())?; + + // Build the projection mask from computed scan indices + let projection = split.column_indices.as_ref().map(|indices| { + let struct_items = indices + .iter() + .map(|&i| StructItem { + field: i as i32, child: None, }) - .collect() - }); - - let projection = projection.map(|struct_items| MaskExpression { - select: Some(StructSelect { struct_items }), - maintain_singular_struct: false, + .collect(); + MaskExpression { + select: Some(StructSelect { struct_items }), + maintain_singular_struct: false, + } }); let table_schema = scan.source.schema().to_dfschema_ref()?; @@ -131,7 +138,7 @@ pub fn from_table_scan( Some(Box::new(filter_expr)) }; - Ok(Box::new(Rel { + let read_rel = Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, base_schema: Some(base_schema), @@ -144,7 +151,53 @@ pub fn from_table_scan( advanced_extension: None, })), }))), - })) + }); + + // If we have complex expressions, wrap the ReadRel with a ProjectRel + if let Some(ref proj_exprs) = split.remainder { + // Build a schema for the scanned columns (the output of the ReadRel). + // The projection expressions reference columns by name, and the schema + // tells us the position of each column in the scan output. + // We need to construct this from the source schema and scan indices since + // `projected_schema` is the final output schema after complex projections. + let scan_output_schema = { + let indices = split + .column_indices + .as_ref() + .expect("column_indices should be Some when remainder is Some"); + let projected_arrow_schema = source_schema.project(indices)?; + Arc::new(DFSchema::try_from_qualified_schema( + scan.table_name.clone(), + &projected_arrow_schema, + )?) + }; + + let expressions = proj_exprs + .iter() + .map(|e| producer.handle_expr(e, &scan_output_schema)) + .collect::>>()?; + + let emit_kind = create_project_remapping( + expressions.len(), + scan_output_schema.fields().len(), + ); + let common = RelCommon { + emit_kind: Some(emit_kind), + hint: None, + advanced_extension: None, + }; + + Ok(Box::new(Rel { + rel_type: Some(RelType::Project(Box::new(ProjectRel { + common: Some(common), + input: Some(read_rel), + expressions, + advanced_extension: None, + }))), + })) + } else { + Ok(read_rel) + } } /// Encodes an EmptyRelation as a Substrait VirtualTable. diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f78b255526dc9..c1c4936a9fc2b 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -864,6 +864,19 @@ async fn roundtrip_arithmetic_ops() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_table_scan_complex_projection() -> Result<()> { + // Test TableScan with complex (non-column) projection expressions + // This verifies that the producer wraps the ReadRel with a ProjectRel + // when the projection contains expressions like a + e + roundtrip("SELECT a + e FROM data").await?; + // Mix of simple columns and complex expressions + roundtrip("SELECT a, a + e, f FROM data").await?; + // Complex expression with CAST + roundtrip("SELECT CAST(a AS double) + CAST(e AS double) FROM data").await?; + Ok(()) +} + #[tokio::test] async fn roundtrip_like() -> Result<()> { roundtrip("SELECT f FROM data WHERE f LIKE 'a%b'").await