Skip to content

Commit 9e62add

Browse files
timsaucer and claude
committed
Add missing aggregate functions: grouping, percentile_cont, var_population
Expose upstream DataFusion aggregate functions that were not yet available in the Python API. Closes #1454.

- grouping: returns the grouping set membership indicator (rewritten by the ResolveGroupingFunction analyzer rule before physical planning)
- percentile_cont: computes the exact percentile using continuous interpolation (unlike approx_percentile_cont, which uses a t-digest approximation)
- var_population: alias for var_pop

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 16feeb1 commit 9e62add

File tree

3 files changed

+142
-4
lines changed

3 files changed

+142
-4
lines changed

crates/core/src/functions.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -709,9 +709,10 @@ aggregate_function!(var_pop);
709709
aggregate_function!(approx_distinct);
710710
aggregate_function!(approx_median);
711711

712-
// Code is commented out since grouping is not yet implemented
713-
// https://github.com/apache/datafusion-python/issues/861
714-
// aggregate_function!(grouping);
712+
// The grouping function's physical plan is not implemented, but the
713+
// ResolveGroupingFunction analyzer rule rewrites it before the physical
714+
// planner sees it, so it works correctly at runtime.
715+
aggregate_function!(grouping);
715716

716717
#[pyfunction]
717718
#[pyo3(signature = (sort_expression, percentile, num_centroids=None, filter=None))]
@@ -749,6 +750,19 @@ pub fn approx_percentile_cont_with_weight(
749750
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
750751
}
751752

753+
#[pyfunction]
754+
#[pyo3(signature = (sort_expression, percentile, filter=None))]
755+
pub fn percentile_cont(
756+
sort_expression: PySortExpr,
757+
percentile: f64,
758+
filter: Option<PyExpr>,
759+
) -> PyDataFusionResult<PyExpr> {
760+
let agg_fn =
761+
functions_aggregate::expr_fn::percentile_cont(sort_expression.sort, lit(percentile));
762+
763+
add_builder_fns_to_aggregate(agg_fn, None, filter, None, None)
764+
}
765+
752766
// We handle last_value explicitly because the signature expects an order_by
753767
// https://github.com/apache/datafusion/issues/12376
754768
#[pyfunction]
@@ -949,6 +963,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
949963
m.add_wrapped(wrap_pyfunction!(approx_median))?;
950964
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont))?;
951965
m.add_wrapped(wrap_pyfunction!(approx_percentile_cont_with_weight))?;
966+
m.add_wrapped(wrap_pyfunction!(percentile_cont))?;
952967
m.add_wrapped(wrap_pyfunction!(range))?;
953968
m.add_wrapped(wrap_pyfunction!(array_agg))?;
954969
m.add_wrapped(wrap_pyfunction!(arrow_typeof))?;
@@ -997,7 +1012,7 @@ pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
9971012
m.add_wrapped(wrap_pyfunction!(from_unixtime))?;
9981013
m.add_wrapped(wrap_pyfunction!(gcd))?;
9991014
m.add_wrapped(wrap_pyfunction!(greatest))?;
1000-
// m.add_wrapped(wrap_pyfunction!(grouping))?;
1015+
m.add_wrapped(wrap_pyfunction!(grouping))?;
10011016
m.add_wrapped(wrap_pyfunction!(in_list))?;
10021017
m.add_wrapped(wrap_pyfunction!(initcap))?;
10031018
m.add_wrapped(wrap_pyfunction!(isnan))?;

python/datafusion/functions.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@
153153
"from_unixtime",
154154
"gcd",
155155
"greatest",
156+
"grouping",
156157
"ifnull",
157158
"in_list",
158159
"initcap",
@@ -224,6 +225,7 @@
224225
"order_by",
225226
"overlay",
226227
"percent_rank",
228+
"percentile_cont",
227229
"pi",
228230
"pow",
229231
"power",
@@ -294,6 +296,7 @@
294296
"uuid",
295297
"var",
296298
"var_pop",
299+
"var_population",
297300
"var_samp",
298301
"var_sample",
299302
"when",
@@ -3643,6 +3646,47 @@ def approx_percentile_cont_with_weight(
36433646
)
36443647

36453648

3649+
def percentile_cont(
3650+
sort_expression: Expr | SortExpr,
3651+
percentile: float,
3652+
filter: Expr | None = None,
3653+
) -> Expr:
3654+
"""Computes the exact percentile of input values using continuous interpolation.
3655+
3656+
Unlike :py:func:`approx_percentile_cont`, this function computes the exact
3657+
percentile value rather than an approximation.
3658+
3659+
If using the builder functions described in ref:`_aggregation` this function ignores
3660+
the options ``order_by``, ``null_treatment``, and ``distinct``.
3661+
3662+
Args:
3663+
sort_expression: Values for which to find the percentile
3664+
percentile: This must be between 0.0 and 1.0, inclusive
3665+
filter: If provided, only compute against rows for which the filter is True
3666+
3667+
Examples:
3668+
>>> ctx = dfn.SessionContext()
3669+
>>> df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
3670+
>>> result = df.aggregate(
3671+
... [], [dfn.functions.percentile_cont(
3672+
... dfn.col("a"), 0.5
3673+
... ).alias("v")])
3674+
>>> result.collect_column("v")[0].as_py()
3675+
3.0
3676+
3677+
>>> result = df.aggregate(
3678+
... [], [dfn.functions.percentile_cont(
3679+
... dfn.col("a"), 0.5,
3680+
... filter=dfn.col("a") > dfn.lit(1.0),
3681+
... ).alias("v")])
3682+
>>> result.collect_column("v")[0].as_py()
3683+
3.5
3684+
"""
3685+
sort_expr_raw = sort_or_default(sort_expression)
3686+
filter_raw = filter.expr if filter is not None else None
3687+
return Expr(f.percentile_cont(sort_expr_raw, percentile, filter=filter_raw))
3688+
3689+
36463690
def array_agg(
36473691
expression: Expr,
36483692
distinct: bool = False,
@@ -3701,6 +3745,30 @@ def array_agg(
37013745
)
37023746

37033747

3748+
def grouping(
3749+
expression: Expr,
3750+
distinct: bool | None = None,
3751+
filter: Expr | None = None,
3752+
) -> Expr:
3753+
"""Returns 1 if the data is aggregated across the specified column, or 0 otherwise.
3754+
3755+
This function is used with ``GROUPING SETS``, ``CUBE``, or ``ROLLUP`` to
3756+
distinguish between aggregated and non-aggregated rows. In a regular
3757+
``GROUP BY`` without grouping sets, it always returns 0.
3758+
3759+
Note: The ``grouping`` aggregate function is rewritten by the query
3760+
optimizer before execution, so it works correctly even though its
3761+
physical plan is not directly implemented.
3762+
3763+
Args:
3764+
expression: The column to check grouping status for
3765+
distinct: If True, compute on distinct values only
3766+
filter: If provided, only compute against rows for which the filter is True
3767+
"""
3768+
filter_raw = filter.expr if filter is not None else None
3769+
return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))
3770+
3771+
37043772
def avg(
37053773
expression: Expr,
37063774
filter: Expr | None = None,
@@ -4172,6 +4240,15 @@ def var_pop(expression: Expr, filter: Expr | None = None) -> Expr:
41724240
return Expr(f.var_pop(expression.expr, filter=filter_raw))
41734241

41744242

4243+
def var_population(expression: Expr, filter: Expr | None = None) -> Expr:
4244+
"""Computes the population variance of the argument.
4245+
4246+
See Also:
4247+
This is an alias for :py:func:`var_pop`.
4248+
"""
4249+
return var_pop(expression, filter)
4250+
4251+
41754252
def var_samp(expression: Expr, filter: Expr | None = None) -> Expr:
41764253
"""Computes the sample variance of the argument.
41774254

python/tests/test_functions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1660,3 +1660,49 @@ def df_with_nulls():
16601660
def test_conditional_functions(df_with_nulls, expr, expected):
16611661
result = df_with_nulls.select(expr.alias("result")).collect()[0]
16621662
assert result.column(0) == expected
1663+
1664+
1665+
def test_percentile_cont():
1666+
ctx = SessionContext()
1667+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1668+
result = df.aggregate(
1669+
[], [f.percentile_cont(column("a"), 0.5).alias("v")]
1670+
).collect()[0]
1671+
assert result.column(0)[0].as_py() == 3.0
1672+
1673+
1674+
def test_percentile_cont_with_filter():
1675+
ctx = SessionContext()
1676+
df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
1677+
result = df.aggregate(
1678+
[],
1679+
[
1680+
f.percentile_cont(
1681+
column("a"), 0.5, filter=column("a") > literal(1.0)
1682+
).alias("v")
1683+
],
1684+
).collect()[0]
1685+
assert result.column(0)[0].as_py() == 3.5
1686+
1687+
1688+
def test_grouping():
1689+
ctx = SessionContext()
1690+
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
1691+
# In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows.
1692+
# Note: grouping() must not be aliased directly in the aggregate expression list
1693+
# due to an upstream DataFusion analyzer limitation (the ResolveGroupingFunction
1694+
# rule doesn't unwrap Alias nodes). Apply aliases via a follow-up select instead.
1695+
result = df.aggregate(
1696+
[column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
1697+
).collect()
1698+
grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist()
1699+
assert all(v == 0 for v in grouping_col)
1700+
1701+
1702+
def test_var_population():
1703+
ctx = SessionContext()
1704+
df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]})
1705+
result = df.aggregate([], [f.var_population(column("a")).alias("v")]).collect()[0]
1706+
# var_population is an alias for var_pop
1707+
expected = df.aggregate([], [f.var_pop(column("a")).alias("v")]).collect()[0]
1708+
assert abs(result.column(0)[0].as_py() - expected.column(0)[0].as_py()) < 1e-10

0 commit comments

Comments
 (0)