From 6dcb82c294895135c977fe4edd412bf70e9fc6ff Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 25 Jul 2025 20:00:11 +0800 Subject: [PATCH 1/5] Add test for user-defined aggregation function (UDAF) with DataFusion - Implement MyAccumulator class following Accumulator interface - Register UDAF named "my_accumulator" in SessionContext - Create test DataFrame and run SQL query using UDAF with GROUP BY - Verify results match expected aggregated values - Ensure correct integration and functionality of UDAF in Python bindings --- test_udaf_script.py | 74 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 test_udaf_script.py diff --git a/test_udaf_script.py b/test_udaf_script.py new file mode 100644 index 000000000..04ca2c7bc --- /dev/null +++ b/test_udaf_script.py @@ -0,0 +1,74 @@ +import pyarrow as pa +import pyarrow.compute as pc +from datafusion import Accumulator, SessionContext, udaf + + +# Define a user-defined aggregation function (UDAF) +class MyAccumulator(Accumulator): + """ + Interface of a user-defined accumulation. + """ + + def __init__(self) -> None: + self._sum = pa.scalar(0.0) + + def update(self, values: pa.Array) -> None: + # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` + self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + + def merge(self, states: list[pa.Array]) -> None: + # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` + self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + + def state(self) -> list[pa.Scalar]: + return [self._sum] + + def evaluate(self) -> pa.Scalar: + return self._sum + + +my_udaf = udaf( + MyAccumulator, + pa.float64(), + pa.float64(), + [pa.float64()], + "stable", + # This will be the name of the UDAF in SQL + # If not specified it will by default the same as accumulator class name + name="my_accumulator", +) + +# Create a context +ctx = SessionContext() + +# Create a datafusion DataFrame from a Python dictionary +source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t") +# Dataframe: +# +---+---+ +# | a | b | +# +---+---+ +# | 1 | 4 | +# | 1 | 5 | +# | 3 | 6 | +# +---+---+ + +# Register UDF for use in SQL +ctx.register_udaf(my_udaf) + +# Query the DataFrame using SQL +result_df = ctx.sql( + "select a, my_accumulator(b) as b_aggregated from t group by a order by a" +) +# Dataframe: +# +---+--------------+ +# | a | b_aggregated | +# +---+--------------+ +# | 1 | 9 | +# | 3 | 6 | +# +---+--------------+ + +result_dict = result_df.to_pydict() +print("Result:", result_dict) +assert result_dict["a"] == [1, 3] +assert result_dict["b_aggregated"] == [9.0, 6.0] +print("Test passed successfully!") \ No newline at end of file From 3ae32607334581a8ad6d15b0353b65cff0dd5ab3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 25 Jul 2025 21:36:08 +0800 Subject: [PATCH 2/5] fix: update UDAF implementation to use correct pyarrow compute functions --- test_udaf_script.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test_udaf_script.py b/test_udaf_script.py index 04ca2c7bc..519bb9a93 100644 --- a/test_udaf_script.py +++ b/test_udaf_script.py @@ -14,11 +14,11 @@ def __init__(self) -> None: def update(self, values: pa.Array) -> None: # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py()) def merge(self, states: list[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py()) def state(self) -> list[pa.Scalar]: return [self._sum] From 8541b257a1df87cdbe1f75e146a585426dafabaf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 25 Jul 2025 21:50:49 +0800 Subject: [PATCH 3/5] fix: correct comment capitalization and remove unused import in UDAF script --- test_udaf_script.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test_udaf_script.py b/test_udaf_script.py index 519bb9a93..c76773833 100644 --- a/test_udaf_script.py +++ b/test_udaf_script.py @@ -1,5 +1,4 @@ import pyarrow as pa -import pyarrow.compute as pc from datafusion import Accumulator, SessionContext, udaf @@ -13,12 +12,12 @@ def __init__(self) -> None: self._sum = pa.scalar(0.0) def update(self, values: pa.Array) -> None: - # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` + # not nice since pyarrow scalars can't be summed yet. This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py()) def merge(self, states: list[pa.Array]) -> None: - # Not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py()) + # not nice since pyarrow scalars can't be summed yet. This breaks on `None` + self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states[0]).as_py()) def state(self) -> list[pa.Scalar]: return [self._sum] From 5ef214ed398dabee03865a381b7f1a73104ae848 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Jul 2025 15:34:08 +0800 Subject: [PATCH 4/5] fix: update UDAF accumulator methods to handle state and summation correctly --- examples/sql-using-python-udaf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/sql-using-python-udaf.py b/examples/sql-using-python-udaf.py index 32ce38900..f42bbdc23 100644 --- a/examples/sql-using-python-udaf.py +++ b/examples/sql-using-python-udaf.py @@ -28,16 +28,16 @@ class MyAccumulator(Accumulator): def __init__(self) -> None: self._sum = pa.scalar(0.0) - def update(self, values: pa.Array) -> None: + def update(self, values: list[pa.Array]) -> None: # not nice since pyarrow scalars can't be summed yet. This breaks on `None` self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py()) def merge(self, states: pa.Array) -> None: # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states).as_py()) + self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states[0]).as_py()) - def state(self) -> pa.Array: - return pa.array([self._sum.as_py()]) + def state(self) -> list[pa.Array]: + return [self._sum] def evaluate(self) -> pa.Scalar: return self._sum From 631851173606278861492eb169c1d2d86bf68db8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sun, 27 Jul 2025 15:39:15 +0800 Subject: [PATCH 5/5] Remove test script --- test_udaf_script.py | 73 --------------------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 test_udaf_script.py diff --git a/test_udaf_script.py b/test_udaf_script.py deleted file mode 100644 index c76773833..000000000 --- a/test_udaf_script.py +++ /dev/null @@ -1,73 +0,0 @@ -import pyarrow as pa -from datafusion import Accumulator, SessionContext, udaf - - -# Define a user-defined aggregation function (UDAF) -class MyAccumulator(Accumulator): - """ - Interface of a user-defined accumulation. - """ - - def __init__(self) -> None: - self._sum = pa.scalar(0.0) - - def update(self, values: pa.Array) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(values).as_py()) - - def merge(self, states: list[pa.Array]) -> None: - # not nice since pyarrow scalars can't be summed yet. This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pa.compute.sum(states[0]).as_py()) - - def state(self) -> list[pa.Scalar]: - return [self._sum] - - def evaluate(self) -> pa.Scalar: - return self._sum - - -my_udaf = udaf( - MyAccumulator, - pa.float64(), - pa.float64(), - [pa.float64()], - "stable", - # This will be the name of the UDAF in SQL - # If not specified it will by default the same as accumulator class name - name="my_accumulator", -) - -# Create a context -ctx = SessionContext() - -# Create a datafusion DataFrame from a Python dictionary -source_df = ctx.from_pydict({"a": [1, 1, 3], "b": [4, 5, 6]}, name="t") -# Dataframe: -# +---+---+ -# | a | b | -# +---+---+ -# | 1 | 4 | -# | 1 | 5 | -# | 3 | 6 | -# +---+---+ - -# Register UDF for use in SQL -ctx.register_udaf(my_udaf) - -# Query the DataFrame using SQL -result_df = ctx.sql( - "select a, my_accumulator(b) as b_aggregated from t group by a order by a" -) -# Dataframe: -# +---+--------------+ -# | a | b_aggregated | -# +---+--------------+ -# | 1 | 9 | -# | 3 | 6 | -# +---+--------------+ - -result_dict = result_df.to_pydict() -print("Result:", result_dict) -assert result_dict["a"] == [1, 3] -assert result_dict["b_aggregated"] == [9.0, 6.0] -print("Test passed successfully!") \ No newline at end of file