From ff0ada71eccc9b827fc6f3b2f88b4d82ed483fea Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 10:06:23 +0400 Subject: [PATCH 1/3] DataFrame API: allow aggregate functions in select() (#17874) --- ...es@explain_plan_environment_overrides.snap | 12 ++-- datafusion/core/src/dataframe/mod.rs | 64 +++++++++++++++++-- datafusion/core/tests/dataframe/mod.rs | 24 +++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 1359cefbe71c7..5f43ca88dc9d7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -18,19 +18,19 @@ exit_code: 0 | logical_plan | [ | | | { | | | "Plan": { | -| | "Expressions": [ | -| | "Int64(123)" | -| | ], | | | "Node Type": "Projection", | -| | "Output": [ | +| | "Expressions": [ | | | "Int64(123)" | | | ], | | | "Plans": [ | | | { | | | "Node Type": "EmptyRelation", | -| | "Output": [], | -| | "Plans": [] | +| | "Plans": [], | +| | "Output": [] | | | } | +| | ], | +| | "Output": [ | +| | "Int64(123)" | | | ] | | | } | | | } | diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 2292f5855bfde..1947e25adb467 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -57,6 +57,7 @@ use datafusion_common::{ plan_datafusion_err, plan_err, unqualified_field_not_found, }; use datafusion_expr::select_expr::SelectExpr; +use datafusion_expr::utils::find_aggregate_exprs; use datafusion_expr::{ ExplainOption, SortExpr, TableProviderFilterPushDown, UNNAMED_TABLE, case, dml::InsertOp, @@ -410,21 +411,76 @@ impl DataFrame { expr_list: impl IntoIterator>, ) -> Result { let expr_list: Vec = - expr_list.into_iter().map(|e| e.into()).collect::>(); + expr_list.into_iter().map(|e| e.into()).collect(); + // Extract plain expressions let expressions = expr_list.iter().filter_map(|e| match e { SelectExpr::Expression(expr) => Some(expr), _ => None, }); - let window_func_exprs = find_window_exprs(expressions); - let plan = if window_func_exprs.is_empty() { + // Apply window functions first + let window_func_exprs = find_window_exprs(expressions.clone()); + + let mut plan = if window_func_exprs.is_empty() { self.plan } else { LogicalPlanBuilder::window_plan(self.plan, window_func_exprs)? }; - let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + // Collect aggregate expressions + let aggr_exprs = find_aggregate_exprs(expressions.clone()); + + // Check if any expression is non-aggregate + let has_non_aggregate_expr = expressions + .clone() + .any(|expr| find_aggregate_exprs(std::iter::once(expr)).is_empty()); + + // Fallback to projection: + // - already aggregated + // - contains non-aggregate expressions + // - no aggregates at all + if matches!(plan, LogicalPlan::Aggregate(_)) + || has_non_aggregate_expr + || aggr_exprs.is_empty() + { + let project_plan = + LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; + + return Ok(DataFrame { + session_state: self.session_state, + plan: project_plan, + projection_requires_validation: false, + }); + } + + // Build Aggregate node + let aggr_exprs: Vec = aggr_exprs + .into_iter() + .enumerate() + .map(|(i, expr)| expr.alias(format!("__agg_{i}"))) + .collect(); + + plan = LogicalPlanBuilder::from(plan) + .aggregate(Vec::::new(), aggr_exprs)? + .build()?; + + // Replace aggregates with their aliases + let mut rewritten_exprs = Vec::with_capacity(expr_list.len()); + for (i, select_expr) in expr_list.into_iter().enumerate() { + match select_expr { + SelectExpr::Expression(expr) => { + let column = Expr::Column(Column::from_name(format!("__agg_{i}"))); + let alias = expr.name_for_alias()?; + rewritten_exprs.push(SelectExpr::Expression(column.alias(alias))); + } + other => rewritten_exprs.push(other), + } + } + + let project_plan = LogicalPlanBuilder::from(plan) + .project(rewritten_exprs)? + .build()?; Ok(DataFrame { session_state: self.session_state, diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 80bbde1f6ba14..5fc67b18b06ed 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,6 +34,7 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; +use datafusion_functions_aggregate::approx_distinct::approx_distinct; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -6854,3 +6855,26 @@ async fn test_duplicate_state_fields_for_dfschema_construct() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { + let df = test_table().await?; + + let res = df.select(vec![ + approx_distinct(col("c9")).alias("count_c9"), + approx_distinct(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + ])?; + + assert_batches_eq!( + &[ + "+----------+--------------+", + "| count_c9 | count_c9_str |", + "+----------+--------------+", + "| 100 | 100 |", + "+----------+--------------+", + ], + &res.collect().await? + ); + + Ok(()) +} From 1659fa7c8a69242a59695370d56b59f893e02a7d Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 10:59:44 +0400 Subject: [PATCH 2/3] use count instead of approx_distinct in test --- datafusion/core/tests/dataframe/mod.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index 5fc67b18b06ed..9dcc147339166 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -34,7 +34,6 @@ use arrow::util::pretty::pretty_format_batches; use arrow_schema::{SortOptions, TimeUnit}; use datafusion::{assert_batches_eq, dataframe}; use datafusion_common::metadata::FieldMetadata; -use datafusion_functions_aggregate::approx_distinct::approx_distinct; use datafusion_functions_aggregate::count::{count_all, count_all_window}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, avg_distinct, count, count_distinct, max, median, min, sum, @@ -6861,8 +6860,8 @@ async fn test_dataframe_api_aggregate_fn_in_select() -> Result<()> { let df = test_table().await?; let res = df.select(vec![ - approx_distinct(col("c9")).alias("count_c9"), - approx_distinct(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), + count(col("c9")).alias("count_c9"), + count(cast(col("c9"), DataType::Utf8View)).alias("count_c9_str"), ])?; assert_batches_eq!( From f9f351e0adf7cff8b1adce5f7005a1efa42d71e4 Mon Sep 17 00:00:00 2001 From: Sergey Zhukov Date: Wed, 18 Mar 2026 16:29:49 +0400 Subject: [PATCH 3/3] Update CLI snapshot --- ...overrides@explain_plan_environment_overrides.snap | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap index 5f43ca88dc9d7..1359cefbe71c7 100644 --- a/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap +++ b/datafusion-cli/tests/snapshots/cli_explain_environment_overrides@explain_plan_environment_overrides.snap @@ -18,19 +18,19 @@ exit_code: 0 | logical_plan | [ | | | { | | | "Plan": { | -| | "Node Type": "Projection", | | | "Expressions": [ | | | "Int64(123)" | | | ], | +| | "Node Type": "Projection", | +| | "Output": [ | +| | "Int64(123)" | +| | ], | | | "Plans": [ | | | { | | | "Node Type": "EmptyRelation", | -| | "Plans": [], | -| | "Output": [] | +| | "Output": [], | +| | "Plans": [] | | | } | -| | ], | -| | "Output": [ | -| | "Int64(123)" | | | ] | | | } | | | } |