
Commit 3481904

kosiew and timsaucer authored
Fix Python UDAF list-of-timestamps return by enforcing list-valued scalars and caching PyArrow types (#1347)
* Implement UDAF improvements for list type handling: store the UDAF return type in the Rust accumulator and wrap PyArrow Array/ChunkedArray returns in list scalars for list-like return types. Add a UDAF test that returns a list of timestamps via a PyArrow array, validating the aggregate output for correctness.
* Document UDAF list-valued scalar returns: describe list-valued scalar returns for UDAF accumulators, including an example with pa.scalar and a note about unsupported pyarrow.Array returns from evaluate(). Also introduce a UDAF FAQ entry detailing list-returning patterns and the required return_type/state_type definitions.
* Fix pyarrow calls and improve type handling in RustAccumulator.
* Refactor RustAccumulator to support pyarrow array types and improve type checking for list types.
* Fix a PyO3 type mismatch by cloning Array/ChunkedArray types before unbinding and binding fresh copies when checking array-likeness, eliminating the Bound reference error.
* Add timezone information to datetime objects in test_udaf_list_timestamp_return.
* Apply a clippy fix.
* Refactor RustAccumulator and utility functions for improved type handling and conversion from Python objects to Arrow types.
* Enhance PyArrow integration by refining type handling and conversion in RustAccumulator and utility functions.
* Fix array data binding in the py_obj_to_scalar_value function.
* Implement a single point of scalar conversion from Python objects.
* Add unit tests and simplify the Python wrapper for literal.
* Add nanoarrow and arro3-core to the dev dependencies; sort the dependencies alphabetically.
* Refactor common code into a helper function to avoid duplication.
* Update the import path used to access the Scalar type.
* Add a test for generic Python objects that support the Arrow C interface.
* Update a unit test to pass back either a PyArrow array or an array wrapped as a scalar.
* Update tests to pass back raw Python values or PyArrow scalars.
* Expand the user documentation on how to return list arrays.
* More user documentation.

Co-authored-by: Tim Saucer <timsaucer@gmail.com>
1 parent 4cd5674 commit 3481904

File tree

13 files changed (+320, -81 lines)


docs/source/user-guide/common-operations/udf-and-udfa.rst

Lines changed: 26 additions & 7 deletions
@@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ.

 .. code-block:: python

-    import pyarrow
+    import pyarrow as pa
     import pyarrow.compute
     import datafusion
     from datafusion import col, udaf, Accumulator
@@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` differ.
         def __init__(self):
             self._sum = 0.0

-        def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> None:
+        def update(self, values_a: pa.Array, values_b: pa.Array) -> None:
             self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() - pyarrow.compute.sum(values_b).as_py()

-        def merge(self, states: List[pyarrow.Array]) -> None:
+        def merge(self, states: list[pa.Array]) -> None:
             self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py()

-        def state(self) -> pyarrow.Array:
-            return pyarrow.array([self._sum])
+        def state(self) -> list[pa.Scalar]:
+            return [pyarrow.scalar(self._sum)]

-        def evaluate(self) -> pyarrow.Scalar:
+        def evaluate(self) -> pa.Scalar:
             return pyarrow.scalar(self._sum)

     ctx = datafusion.SessionContext()
@@ -156,10 +156,29 @@ also see how the inputs to ``update`` and ``merge`` differ.
         }
     )

-    my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], pyarrow.float64(), [pyarrow.float64()], 'stable')
+    my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), [pa.float64()], 'stable')

     df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])

+FAQ
+^^^
+
+**How do I return a list from a UDAF?**
+
+Both the ``evaluate`` and the ``state`` functions expect to return scalar values.
+If you wish to return a list array as a scalar value, the best practice is to
+wrap the values in a ``pyarrow.Scalar`` object. For example, you can return a
+timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and
+register the appropriate return or state types as
+``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively.
+
+As of DataFusion 52.0.0, you can return any Python object, including a
+PyArrow array, as the return value for these functions, and DataFusion will
+attempt to create a scalar value from it. DataFusion has been tested to
+convert PyArrow, nanoarrow, and arro3 objects as well as primitive data types
+like integers and strings.
+
 Window Functions
 ----------------
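To make the FAQ pattern above concrete, here is a minimal, self-contained sketch of a list-returning UDAF. The ``CollectList`` name and body are illustrative, distilled from the ``CollectTimestamps`` test added in this commit; the ``udaf`` argument order follows the documentation example above:

    import pyarrow as pa
    from datafusion import udaf, Accumulator

    class CollectList(Accumulator):
        """Sketch: collect inputs and return them as one list-valued scalar."""

        def __init__(self) -> None:
            self._values: list = []

        def update(self, values: pa.Array) -> None:
            self._values.extend(values.to_pylist())

        def merge(self, states: list[pa.Array]) -> None:
            for state in states[0].to_pylist():
                if state is not None:
                    self._values.extend(state)

        def state(self) -> list[pa.Scalar]:
            # Wrap the list in a pa.scalar so it round-trips as a single value.
            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))]

        def evaluate(self) -> pa.Scalar:
            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))

    ts_list = pa.list_(pa.timestamp("ms"))
    collect = udaf(CollectList, pa.timestamp("ms"), ts_list, [ts_list], "stable")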

pyproject.toml

Lines changed: 11 additions & 9 deletions
@@ -173,27 +173,29 @@ ignore-words-list = ["ans", "IST"]

 [dependency-groups]
 dev = [
+    "arro3-core==0.6.5",
+    "codespell==2.4.1",
     "maturin>=1.8.1",
+    "nanoarrow==0.8.0",
     "numpy>1.25.0;python_version<'3.14'",
     "numpy>=2.3.2;python_version>='3.14'",
-    "pyarrow>=19.0.0",
     "pre-commit>=4.3.0",
-    "pyyaml>=6.0.3",
+    "pyarrow>=19.0.0",
+    "pygithub==2.5.0",
     "pytest>=7.4.4",
     "pytest-asyncio>=0.23.3",
+    "pyyaml>=6.0.3",
     "ruff>=0.9.1",
     "toml>=0.10.2",
-    "pygithub==2.5.0",
-    "codespell==2.4.1",
 ]
 docs = [
-    "sphinx>=7.1.2",
-    "pydata-sphinx-theme==0.8.0",
-    "myst-parser>=3.0.1",
-    "jinja2>=3.1.5",
     "ipython>=8.12.3",
+    "jinja2>=3.1.5",
+    "myst-parser>=3.0.1",
     "pandas>=2.0.3",
     "pickleshare>=0.7.5",
-    "sphinx-autoapi>=3.4.0",
+    "pydata-sphinx-theme==0.8.0",
     "setuptools>=75.3.0",
+    "sphinx>=7.1.2",
+    "sphinx-autoapi>=3.4.0",
 ]

python/datafusion/expr.py

Lines changed: 0 additions & 3 deletions
@@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        if not isinstance(value, pa.Scalar):
-            value = pa.scalar(value)
         return Expr(expr_internal.RawExpr.literal(value))

     @staticmethod
@@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        value = value if isinstance(value, pa.Scalar) else pa.scalar(value)

         return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
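The removed checks mean ``lit`` no longer coerces non-string values to ``pa.Scalar`` in Python; the conversion now happens when the value is extracted as a ``PyScalarValue`` on the Rust side (see ``src/common/data_type.rs`` below). The user-facing behavior is unchanged, as the new test in ``python/tests/test_expr.py`` asserts; a short illustration mirroring that test:

    import pyarrow as pa
    from datafusion import lit

    # Plain Python values and pre-built PyArrow scalars produce the same
    # literal; the coercion now happens in the Rust conversion layer.
    assert lit(1) == lit(pa.scalar(1))
    assert str(lit(1)) == "Expr(Int64(1))"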

python/datafusion/user_defined.py

Lines changed: 20 additions & 2 deletions
@@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta):

     @abstractmethod
     def state(self) -> list[pa.Scalar]:
-        """Return the current state."""
+        """Return the current state.
+
+        While this function template expects PyArrow Scalar return values,
+        you can return any value that can be converted into a Scalar. This
+        includes basic Python data types such as integers and strings.
+        Beyond primitive types, PyArrow, nanoarrow, and arro3 objects are
+        currently supported. Other objects that support the Arrow FFI
+        standard will be given a "best attempt" at conversion to scalar
+        objects.
+        """

     @abstractmethod
     def update(self, *values: pa.Array) -> None:
@@ -310,7 +319,16 @@ def merge(self, states: list[pa.Array]) -> None:

     @abstractmethod
     def evaluate(self) -> pa.Scalar:
-        """Return the resultant value."""
+        """Return the resultant value.
+
+        While this function template expects a PyArrow Scalar return value,
+        you can return any value that can be converted into a Scalar. This
+        includes basic Python data types such as integers and strings.
+        Beyond primitive types, PyArrow, nanoarrow, and arro3 objects are
+        currently supported. Other objects that support the Arrow FFI
+        standard will be given a "best attempt" at conversion to scalar
+        objects.
+        """


 class AggregateUDFExportable(Protocol):
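As a sketch of what the expanded docstrings permit, an accumulator may now return plain Python values and rely on the automatic conversion. This mirrors the ``Summarize`` test class in this commit when ``as_scalar=False``; the ``PlainSum`` name is illustrative:

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import Accumulator

    class PlainSum(Accumulator):
        """Sketch: return raw floats and let DataFusion convert them."""

        def __init__(self) -> None:
            self._sum = 0.0

        def update(self, values: pa.Array) -> None:
            self._sum += pc.sum(values).as_py()

        def merge(self, states: list[pa.Array]) -> None:
            self._sum += pc.sum(states[0]).as_py()

        def state(self) -> list:
            return [self._sum]  # plain float, converted to a Scalar internally

        def evaluate(self) -> float:
            return self._sum  # plain float, converted on the Rust side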

python/tests/test_expr.py

Lines changed: 45 additions & 0 deletions
@@ -20,6 +20,8 @@
 from datetime import date, datetime, time, timezone
 from decimal import Decimal

+import arro3.core
+import nanoarrow
 import pyarrow as pa
 import pytest
 from datafusion import (
@@ -980,6 +982,49 @@ def test_literal_metadata(ctx):
     assert expected_field.metadata == actual_field.metadata


+def test_scalar_conversion() -> None:
+    class WrappedPyArrow:
+        """Wrapper class for testing __arrow_c_array__."""
+
+        def __init__(self, val: pa.Array) -> None:
+            self.val = val
+
+        def __arrow_c_array__(self, requested_schema=None):
+            return self.val.__arrow_c_array__(requested_schema=requested_schema)
+
+    expected_value = lit(1)
+    assert str(expected_value) == "Expr(Int64(1))"
+
+    # Test pyarrow imports
+    assert expected_value == lit(pa.scalar(1))
+    assert expected_value == lit(pa.scalar(1, type=pa.int32()))
+
+    # Test nanoarrow
+    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
+    assert expected_value == lit(na_scalar)
+
+    # Test pyo3
+    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_scalar)
+
+    generic_scalar = WrappedPyArrow(pa.array([1]))
+    assert expected_value == lit(generic_scalar)
+
+    expected_value = lit([1, 2, 3])
+    assert str(expected_value) == "Expr(List([1, 2, 3]))"
+
+    assert expected_value == lit(pa.scalar([1, 2, 3]))
+
+    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
+    assert expected_value == lit(na_array)
+
+    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_array)
+
+    generic_array = WrappedPyArrow(pa.array([1, 2, 3]))
+    assert expected_value == lit(generic_array)
+
+
 def test_ensure_expr():
     e = col("a")
     assert ensure_expr(e) is e.expr
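The ``WrappedPyArrow`` helper above exercises the Arrow C data interface: any object exposing ``__arrow_c_array__`` can be converted. A standalone sketch of the same protocol, assuming a pyarrow version recent enough to import such producers directly (the ``MyArrowProducer`` name is hypothetical):

    import pyarrow as pa

    class MyArrowProducer:
        """Minimal Arrow C data interface producer, like WrappedPyArrow."""

        def __init__(self, values) -> None:
            self._array = pa.array(values)

        def __arrow_c_array__(self, requested_schema=None):
            return self._array.__arrow_c_array__(requested_schema=requested_schema)

    # Recent pyarrow versions can import any such producer directly.
    assert pa.array(MyArrowProducer([1, 2, 3])) == pa.array([1, 2, 3])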

python/tests/test_udaf.py

Lines changed: 83 additions & 6 deletions
@@ -17,6 +17,8 @@

 from __future__ import annotations

+from datetime import datetime, timezone
+
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -26,23 +28,28 @@
 class Summarize(Accumulator):
     """Interface of a user-defined accumulation."""

-    def __init__(self, initial_value: float = 0.0):
-        self._sum = pa.scalar(initial_value)
+    def __init__(self, initial_value: float = 0.0, as_scalar: bool = False):
+        self._sum = initial_value
+        self.as_scalar = as_scalar

     def state(self) -> list[pa.Scalar]:
+        if self.as_scalar:
+            return [pa.scalar(self._sum)]
         return [self._sum]

     def update(self, values: pa.Array) -> None:
         # Not nice since pyarrow scalars can't be summed yet.
         # This breaks on `None`
-        self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
+        self._sum = self._sum + pc.sum(values).as_py()

     def merge(self, states: list[pa.Array]) -> None:
         # Not nice since pyarrow scalars can't be summed yet.
         # This breaks on `None`
-        self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
+        self._sum = self._sum + pc.sum(states[0]).as_py()

     def evaluate(self) -> pa.Scalar:
+        if self.as_scalar:
+            return pa.scalar(self._sum)
         return self._sum


@@ -58,6 +65,30 @@ def state(self) -> list[pa.Scalar]:
         return [self._sum]


+class CollectTimestamps(Accumulator):
+    def __init__(self, wrap_in_scalar: bool):
+        self._values: list[datetime] = []
+        self.wrap_in_scalar = wrap_in_scalar
+
+    def state(self) -> list[pa.Scalar]:
+        if self.wrap_in_scalar:
+            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+        return [pa.array(self._values, type=pa.timestamp("ns"))]
+
+    def update(self, values: pa.Array) -> None:
+        self._values.extend(values.to_pylist())
+
+    def merge(self, states: list[pa.Array]) -> None:
+        for state in states[0].to_pylist():
+            if state is not None:
+                self._values.extend(state)
+
+    def evaluate(self) -> pa.Scalar:
+        if self.wrap_in_scalar:
+            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
+        return pa.array(self._values, type=pa.timestamp("ns"))
+
+
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
@@ -137,11 +168,12 @@ def summarize():
     assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])


-def test_udaf_aggregate_with_arguments(df):
+@pytest.mark.parametrize("as_scalar", [True, False])
+def test_udaf_aggregate_with_arguments(df, as_scalar):
     bias = 10.0

     summarize = udaf(
-        lambda: Summarize(bias),
+        lambda: Summarize(initial_value=bias, as_scalar=as_scalar),
         pa.float64(),
         pa.float64(),
         [pa.float64()],
@@ -217,3 +249,48 @@ def test_register_udaf(ctx, df) -> None:
     df_result = ctx.sql("select summarize(b) from test_table")

     assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
+@pytest.mark.parametrize("wrap_in_scalar", [True, False])
+def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None:
+    timestamps1 = [
+        datetime(2024, 1, 1, tzinfo=timezone.utc),
+        datetime(2024, 1, 2, tzinfo=timezone.utc),
+    ]
+    timestamps2 = [
+        datetime(2024, 1, 3, tzinfo=timezone.utc),
+        datetime(2024, 1, 4, tzinfo=timezone.utc),
+    ]
+    batch1 = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps1, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    batch2 = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps2, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table")
+
+    list_type = pa.list_(
+        pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar)
+    )
+
+    collect = udaf(
+        lambda: CollectTimestamps(wrap_in_scalar),
+        pa.timestamp("ns"),
+        list_type,
+        [list_type],
+        volatility="immutable",
+    )
+
+    result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+    # There is no guarantee about the ordering of the batches, so perform a sort
+    # to get consistent results. Alternatively we could sort on evaluate().
+    assert (
+        result.column(0).values.sort()
+        == pa.array(
+            [[*timestamps1, *timestamps2]],
+            type=list_type,
+        ).values
+    )

src/common/data_type.rs

Lines changed: 3 additions & 0 deletions
@@ -22,6 +22,9 @@ use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
 use pyo3::exceptions::{PyNotImplementedError, PyValueError};
 use pyo3::prelude::*;

+/// A [`ScalarValue`] wrapped in a Python object. This struct allows for conversion
+/// from a variety of Python objects into a [`ScalarValue`]. See
+/// ``FromPyArrow::from_pyarrow_bound`` for conversion details.
 #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
 pub struct PyScalarValue(pub ScalarValue);

src/config.rs

Lines changed: 3 additions & 3 deletions
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
 use pyo3::prelude::*;
 use pyo3::types::*;

+use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;
 #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
 #[derive(Clone)]
 pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {

     /// Set a configuration option
     pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
         let mut options = self.config.write();
-        options.set(key, scalar_value.to_string().as_str())?;
+        options.set(key, scalar_value.0.to_string().as_str())?;
         Ok(())
     }
