diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 2292f5855bfde..1947e25adb467 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -57,6 +57,7 @@ use datafusion_common::{
     plan_datafusion_err, plan_err, unqualified_field_not_found,
 };
 use datafusion_expr::select_expr::SelectExpr;
+use datafusion_expr::utils::find_aggregate_exprs;
 use datafusion_expr::{
     ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case,
     dml::InsertOp,
@@ -410,21 +411,101 @@ impl DataFrame {
         expr_list: impl IntoIterator<Item = impl Into<SelectExpr>>,
     ) -> Result<DataFrame> {
         let expr_list: Vec<SelectExpr> =
-            expr_list.into_iter().map(|e| e.into()).collect::<Vec<SelectExpr>>();
+            expr_list.into_iter().map(|e| e.into()).collect();
 
+        // Extract plain expressions
         let expressions = expr_list.iter().filter_map(|e| match e {
             SelectExpr::Expression(expr) => Some(expr),
             _ => None,
         });
 
-        let window_func_exprs = find_window_exprs(expressions);
+        // Apply window functions first
+        let window_func_exprs = find_window_exprs(expressions.clone());
+
         let plan = if window_func_exprs.is_empty() {
             self.plan
         } else {
             LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)?
         };
 
-        let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+        // Collect the distinct aggregate calls found anywhere inside the selected
+        // expressions (nested ones included)
+        let aggr_exprs = find_aggregate_exprs(expressions.clone());
+
+        // Check if any expression is non-aggregate
+        let has_non_aggregate_expr = expressions
+            .clone()
+            .any(|expr| find_aggregate_exprs(std::iter::once(expr)).is_empty());
+
+        // Fallback to projection:
+        // - already aggregated
+        // - contains non-aggregate expressions
+        // - no aggregates at all
+        if matches!(plan, LogicalPlan::Aggregate(_))
+            || has_non_aggregate_expr
+            || aggr_exprs.is_empty()
+        {
+            let project_plan =
+                LogicalPlanBuilder::from(plan).project(expr_list)?.build()?;
+
+            return Ok(DataFrame {
+                session_state: self.session_state,
+                plan: project_plan,
+                projection_requires_validation: false,
+            });
+        }
+
+        // Build a global (no GROUP BY) Aggregate node; aggregate `i` is exposed
+        // as the column `__agg_{i}`
+        let aggr_columns: Vec<Expr> = (0..aggr_exprs.len())
+            .map(|i| Expr::Column(Column::from_name(format!("__agg_{i}"))))
+            .collect();
+        let aliased_aggr_exprs: Vec<Expr> = aggr_exprs
+            .iter()
+            .cloned()
+            .enumerate()
+            .map(|(i, expr)| expr.alias(format!("__agg_{i}")))
+            .collect();
+
+        let plan = LogicalPlanBuilder::from(plan)
+            .aggregate(Vec::<Expr>::new(), aliased_aggr_exprs)?
+            .build()?;
+
+        // Replace each aggregate call with the column that carries its result.
+        // The lookup is by expression, not by position in `expr_list`: the two
+        // lists diverge when an aggregate is repeated, when one expression holds
+        // several aggregates, or when a wildcard is selected.
+        use datafusion_common::tree_node::{Transformed, TreeNode};
+        let mut rewritten_exprs = Vec::with_capacity(expr_list.len());
+        for select_expr in expr_list {
+            match select_expr {
+                SelectExpr::Expression(expr) => {
+                    let original_name = expr.name_for_alias()?;
+                    let rewritten = expr
+                        .transform(|nested| {
+                            let column = aggr_exprs
+                                .iter()
+                                .zip(&aggr_columns)
+                                .find(|(agg, _)| **agg == nested)
+                                .map(|(_, column)| column.clone());
+                            Ok(match column {
+                                Some(column) => Transformed::yes(column),
+                                None => Transformed::no(nested),
+                            })
+                        })?
+                        .data;
+                    // Keep the output column name the caller asked for
+                    rewritten_exprs.push(SelectExpr::Expression(
+                        rewritten.alias_if_changed(original_name)?,
+                    ));
+                }
+                other => rewritten_exprs.push(other),
+            }
+        }
+
+        let project_plan = LogicalPlanBuilder::from(plan)
+            .project(rewritten_exprs)?
+            .build()?;
 
         Ok(DataFrame {
             session_state: self.session_state,
diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs
index 80bbde1f6ba14..9dcc147339166 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -6854,3 +6854,51 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> {
+    let df = test_table().await?;
+
+    let res = df.select(vec![
+        count(col("c9")).alias("count_c9"),
+        count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"),
+    ])?;
+
+    assert_batches_eq!(
+        &[
+            "+----------+--------------+",
+            "| count_c9 | count_c9_str |",
+            "+----------+--------------+",
+            "| 100      | 100          |",
+            "+----------+--------------+",
+        ],
+        &res.collect().await?
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_dataframe_api_repeated_aggregates_in_select() -> Result<()> {
+    let df = test_table().await?;
+
+    // The same aggregate twice, plus one expression combining two aggregates
+    let res = df.select(vec![
+        count(col("c9")).alias("first"),
+        count(col("c9")).alias("second"),
+        (count(col("c9")) + count(col("c9"))).alias("total"),
+    ])?;
+
+    assert_batches_eq!(
+        &[
+            "+-------+--------+-------+",
+            "| first | second | total |",
+            "+-------+--------+-------+",
+            "| 100   | 100    | 200   |",
+            "+-------+--------+-------+",
+        ],
+        &res.collect().await?
+    );
+
+    Ok(())
+}