From 3afbc7d8caee6234ba88d6e42b95873c840e367a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 17 May 2026 12:41:09 -0400 Subject: [PATCH] feat: accept distinct kwarg on sum and avg MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Upstream exposes `sum_distinct` / `avg_distinct` / `count_distinct` as sibling functions that call the same underlying UDAF with `distinct: bool = true`. The Rust binding side already routes `distinct=Some(true)` through the aggregate builder for `sum`, `avg`, and `count` — but only `count` exposed the kwarg on the Python wrapper. Add `distinct: bool = False` to `sum()` and `avg()` mirroring the existing `count()` signature, and update SKILL.md so the check-upstream audit does not re-flag the three upstream `*_distinct` shortcuts as gaps. The plan emitted by `sum(col, distinct=True)` matches what upstream's `sum_distinct(col)` builds. Co-Authored-By: Claude Opus 4.7 (1M context) --- .ai/skills/check-upstream/SKILL.md | 9 ++++++++- python/datafusion/functions.py | 30 ++++++++++++++++++++++++++---- python/tests/test_functions.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 5 deletions(-) diff --git a/.ai/skills/check-upstream/SKILL.md b/.ai/skills/check-upstream/SKILL.md index 3bac018ef..529add0c3 100644 --- a/.ai/skills/check-upstream/SKILL.md +++ b/.ai/skills/check-upstream/SKILL.md @@ -82,11 +82,18 @@ The user may specify an area via `$ARGUMENTS`. If no area is specified or "all" - Python API: `python/datafusion/functions.py` (aggregate functions are mixed in with scalar functions) - Rust bindings: `crates/core/src/functions.rs` +**Evaluated and not requiring separate Python exposure:** +- `count_distinct` — covered by `count(expr, distinct=True)`. Both forms call + `count_udaf` with `distinct: bool = true` and produce the same logical plan. +- `sum_distinct` — covered by `sum(expr, distinct=True)`. +- `avg_distinct` — covered by `avg(expr, distinct=True)`. + **How to check:** 1. Fetch the upstream aggregate function documentation page 2. Compare against aggregate functions in `python/datafusion/functions.py` (check `__all__` list and function definitions) 3. A function is covered if it exists in the Python API, even if it aliases another function's Rust binding -4. Report only functions missing from the Python API +4. Check against the "evaluated and not requiring exposure" list before flagging as a gap +5. Report only functions missing from the Python API ### 3. Window Functions diff --git a/python/datafusion/functions.py b/python/datafusion/functions.py index 9761d1879..83af282ff 100644 --- a/python/datafusion/functions.py +++ b/python/datafusion/functions.py @@ -4521,6 +4521,7 @@ def grouping( def avg( expression: Expr, + distinct: bool = False, filter: Expr | None = None, ) -> Expr: """Returns the average value. @@ -4528,10 +4529,12 @@ def avg( This aggregate function expects a numeric expression and will return a float. If using the builder functions described in ref:`_aggregation` this function ignores - the options ``order_by``, ``null_treatment``, and ``distinct``. + the options ``order_by`` and ``null_treatment``. Args: expression: Values to combine into an array + distinct: If True, only distinct values are averaged. Equivalent to the + upstream ``avg_distinct`` shortcut. filter: If provided, only compute against rows for which the filter is True Examples: @@ -4551,9 +4554,17 @@ def avg( ... ).alias("v")]) >>> result.collect_column("v")[0].as_py() 2.5 + + >>> df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]}) + >>> result = df.aggregate( + ... [], [dfn.functions.avg( + ... dfn.col("a"), distinct=True, + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 2.0 """ filter_raw = filter.expr if filter is not None else None - return Expr(f.avg(expression.expr, filter=filter_raw)) + return Expr(f.avg(expression.expr, distinct=distinct, filter=filter_raw)) def corr(value_y: Expr, value_x: Expr, filter: Expr | None = None) -> Expr: @@ -4838,6 +4849,7 @@ def min(expression: Expr, filter: Expr | None = None) -> Expr: def sum( expression: Expr, + distinct: bool = False, filter: Expr | None = None, ) -> Expr: """Computes the sum of a set of numbers. @@ -4845,10 +4857,12 @@ def sum( This aggregate function expects a numeric expression. If using the builder functions described in ref:`_aggregation` this function ignores - the options ``order_by``, ``null_treatment``, and ``distinct``. + the options ``order_by`` and ``null_treatment``. Args: expression: Values to combine into an array + distinct: If True, only distinct values are summed. Equivalent to the + upstream ``sum_distinct`` shortcut. filter: If provided, only compute against rows for which the filter is True Examples: @@ -4868,9 +4882,17 @@ def sum( ... ).alias("v")]) >>> result.collect_column("v")[0].as_py() 5 + + >>> df = ctx.from_pydict({"a": [1, 1, 2, 3]}) + >>> result = df.aggregate( + ... [], [dfn.functions.sum( + ... dfn.col("a"), distinct=True, + ... ).alias("v")]) + >>> result.collect_column("v")[0].as_py() + 6 """ filter_raw = filter.expr if filter is not None else None - return Expr(f.sum(expression.expr, filter=filter_raw)) + return Expr(f.sum(expression.expr, distinct=distinct, filter=filter_raw)) def stddev(expression: Expr, filter: Expr | None = None) -> Expr: diff --git a/python/tests/test_functions.py b/python/tests/test_functions.py index 5538fc33b..34435ac12 100644 --- a/python/tests/test_functions.py +++ b/python/tests/test_functions.py @@ -1957,6 +1957,36 @@ def test_get_field(df): assert result.column(1) == pa.array([4, 5, 6]) +def test_sum_distinct_kwarg(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 1, 2, 3]}) + distinct = ( + df.aggregate([], [f.sum(column("a"), distinct=True).alias("v")]) + .collect_column("v")[0] + .as_py() + ) + total = ( + df.aggregate([], [f.sum(column("a")).alias("v")]).collect_column("v")[0].as_py() + ) + assert distinct == 6 + assert total == 7 + + +def test_avg_distinct_kwarg(): + ctx = SessionContext() + df = ctx.from_pydict({"a": [1.0, 1.0, 2.0, 3.0]}) + distinct = ( + df.aggregate([], [f.avg(column("a"), distinct=True).alias("v")]) + .collect_column("v")[0] + .as_py() + ) + mean = ( + df.aggregate([], [f.avg(column("a")).alias("v")]).collect_column("v")[0].as_py() + ) + assert distinct == 2.0 + assert mean == 1.75 + + def test_arrow_metadata(): ctx = SessionContext() field = pa.field("val", pa.int64(), metadata={"key1": "value1", "key2": "value2"})