Skip to content

Commit 6dcb82c

Browse files
committed
Add test for user-defined aggregation function (UDAF) with DataFusion
- Implement MyAccumulator class following Accumulator interface - Register UDAF named "my_accumulator" in SessionContext - Create test DataFrame and run SQL query using UDAF with GROUP BY - Verify results match expected aggregated values - Ensure correct integration and functionality of UDAF in Python bindings
1 parent ba0b49e commit 6dcb82c

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

test_udaf_script.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pyarrow as pa
2+
import pyarrow.compute as pc
3+
from datafusion import Accumulator, SessionContext, udaf
4+
5+
6+
# Define a user-defined aggregation function (UDAF)
7+
class MyAccumulator(Accumulator):
    """Running-sum accumulator used as a user-defined aggregation (UDAF).

    The running total is kept as a pyarrow scalar that starts at ``0.0``.
    Input batches and partial states whose sum is null (empty or all-null
    arrays) leave the total unchanged instead of raising ``TypeError``.
    """

    def __init__(self) -> None:
        self._sum = pa.scalar(0.0)

    def _add(self, values: pa.Array) -> None:
        """Fold the sum of ``values`` into the running total."""
        # pyarrow scalars cannot be added to each other directly, so the
        # arithmetic goes through plain Python numbers.
        partial = pc.sum(values).as_py()
        # ``pc.sum`` yields a null scalar (``None`` after ``as_py()``) when
        # there is nothing to sum; adding that to a float would raise.
        if partial is not None:
            self._sum = pa.scalar(self._sum.as_py() + partial)

    def update(self, values: pa.Array) -> None:
        """Add the values of one input batch to the running total."""
        self._add(values)

    def merge(self, states: list[pa.Array]) -> None:
        """Combine partial totals produced by other accumulator instances.

        Only ``states[0]`` is read, matching the single-element list
        returned by :meth:`state`.
        """
        self._add(states[0])

    def state(self) -> list[pa.Scalar]:
        """Return the intermediate state: a one-element list with the total."""
        return [self._sum]

    def evaluate(self) -> pa.Scalar:
        """Return the final aggregated value (the running total)."""
        return self._sum
28+
29+
30+
# Wrap the accumulator class as an aggregate function usable by DataFusion.
# The positional arguments are kept exactly as in the original call; see the
# datafusion ``udaf`` documentation for what each position means.
my_udaf = udaf(
    MyAccumulator,
    pa.float64(),
    pa.float64(),
    [pa.float64()],
    "stable",
    # Name under which the UDAF is exposed in SQL. When omitted, it defaults
    # to the accumulator class name.
    name="my_accumulator",
)

# Session that owns both the registered function and the table.
ctx = SessionContext()

# Make the UDAF callable from SQL.
ctx.register_udaf(my_udaf)

# Build table "t" from an in-memory Python dictionary:
#
#   +---+---+
#   | a | b |
#   +---+---+
#   | 1 | 4 |
#   | 1 | 5 |
#   | 3 | 6 |
#   +---+---+
source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t")

# Aggregate column "b" per distinct "a" through the UDAF. Expected output:
#
#   +---+--------------+
#   | a | b_aggregated |
#   +---+--------------+
#   | 1 | 9            |
#   | 3 | 6            |
#   +---+--------------+
result_df = ctx.sql(
    "select a, my_accumulator(b) as b_aggregated from t group by a order by a"
)

# Materialize the result and check it against the expected values.
result_dict = result_df.to_pydict()
print("Result:", result_dict)
assert result_dict["a"] == [1, 3]
assert result_dict["b_aggregated"] == [9.0, 6.0]
print("Test passed successfully!")

0 commit comments

Comments
 (0)