Commit 5f10176

Implement UDAF improvements for list type handling

Store the UDAF return type in the Rust accumulator and wrap pyarrow Array/ChunkedArray returns in list scalars when the declared return type is list-like. Add a UDAF test that returns a list of timestamps via a pyarrow array, validating the aggregate output.

1 parent 7aff363 commit 5f10176

File tree: 2 files changed, +78 -9 lines changed
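
To make the commit message concrete: when an accumulator's evaluate() returns a pyarrow Array (or ChunkedArray) but the UDAF's declared return type is list-like, the value is now converted into a single list scalar. A minimal sketch of that conversion, with illustrative int64 values rather than the timestamps used in the test below:

import pyarrow as pa

arr = pa.array([1, 2, 3], type=pa.int64())  # what evaluate() hands back
list_type = pa.list_(pa.int64())            # the UDAF's declared return type
wrapped = pa.scalar(arr.to_pylist(), type=list_type)
assert wrapped.as_py() == [1, 2, 3]         # one list value per aggregation group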

python/tests/test_udaf.py (45 additions, 0 deletions)

@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+from datetime import datetime
+
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -58,6 +60,25 @@ def state(self) -> list[pa.Scalar]:
         return [self._sum]
 
 
+class CollectTimestamps(Accumulator):
+    def __init__(self):
+        self._values: list[datetime] = []
+
+    def state(self) -> list[pa.Scalar]:
+        return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+
+    def update(self, values: pa.Array) -> None:
+        self._values.extend(values.to_pylist())
+
+    def merge(self, states: list[pa.Array]) -> None:
+        for state in states[0].to_pylist():
+            if state is not None:
+                self._values.extend(state)
+
+    def evaluate(self) -> pa.Array:
+        return pa.array(self._values, type=pa.timestamp("ns"))
+
+
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
@@ -217,3 +238,27 @@ def test_register_udaf(ctx, df) -> None:
     df_result = ctx.sql("select summarize(b) from test_table")
 
     assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
+def test_udaf_list_timestamp_return(ctx) -> None:
+    timestamps = [datetime(2024, 1, 1), datetime(2024, 1, 2)]
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    df = ctx.create_dataframe([[batch]], name="timestamp_table")
+
+    collect = udaf(
+        CollectTimestamps,
+        pa.timestamp("ns"),
+        pa.list_(pa.timestamp("ns")),
+        [pa.list_(pa.timestamp("ns"))],
+        volatility="immutable",
+    )
+
+    result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+    assert result.column(0) == pa.array(
+        [timestamps],
+        type=pa.list_(pa.timestamp("ns")),
+    )

src/udaf.rs (33 additions, 9 deletions)

@@ -27,7 +27,7 @@ use datafusion::logical_expr::{
 };
 use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
 use pyo3::prelude::*;
-use pyo3::types::{PyCapsule, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};
 
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
@@ -37,11 +37,12 @@ use crate::utils::{parse_volatility, validate_pycapsule};
 #[derive(Debug)]
 struct RustAccumulator {
     accum: Py<PyAny>,
+    return_type: DataType,
 }
 
 impl RustAccumulator {
-    fn new(accum: Py<PyAny>) -> Self {
-        Self { accum }
+    fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
+        Self { accum, return_type }
     }
 }
 
@@ -59,10 +60,23 @@ impl Accumulator for RustAccumulator {
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("evaluate")?
-                .extract::<PyScalarValue>()
+            let value = self.accum.bind(py).call_method0("evaluate")?;
+            let is_list_type = matches!(
+                self.return_type,
+                DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
+            );
+            if is_list_type && is_pyarrow_array_like(py, &value)? {
+                let pyarrow = PyModule::import(py, "pyarrow")?;
+                let list_value = value.call_method0("to_pylist")?;
+                let py_type = self.return_type.to_pyarrow(py)?;
+                let kwargs = PyDict::new(py);
+                kwargs.set_item("type", py_type)?;
+                return pyarrow
+                    .getattr("scalar")?
+                    .call((list_value,), Some(&kwargs))?
+                    .extract::<PyScalarValue>();
+            }
+            value.extract::<PyScalarValue>()
         })
         .map(|v| v.0)
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
@@ -144,16 +158,26 @@ impl Accumulator for RustAccumulator {
 }
 
 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
             accum
                 .call0(py)
                 .map_err(|e| DataFusionError::Execution(format!("{e}")))
         })?;
-        Ok(Box::new(RustAccumulator::new(accum)))
+        Ok(Box::new(RustAccumulator::new(
+            accum,
+            args.return_type.clone(),
+        )))
     })
 }
 
+fn is_pyarrow_array_like(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
+    let pyarrow = PyModule::import(py, "pyarrow")?;
+    let array_type = pyarrow.getattr("Array")?.downcast_into::<PyType>()?;
+    let chunked_array_type = pyarrow.getattr("ChunkedArray")?.downcast_into::<PyType>()?;
+    Ok(value.is_instance(&array_type)? || value.is_instance(&chunked_array_type)?)
+}
+
 fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
     validate_pycapsule(capsule, "datafusion_aggregate_udf")?;
 
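For readers who don't follow Rust, the new evaluate path behaves roughly like the Python below (a sketch only; the helper name wrap_evaluate_result is hypothetical and not part of the commit):

import pyarrow as pa

def wrap_evaluate_result(value, declared_return_type):
    # Hypothetical mirror of the Rust logic: only Array/ChunkedArray results
    # combined with a list-like declared return type get wrapped.
    is_list_type = (
        pa.types.is_list(declared_return_type)
        or pa.types.is_large_list(declared_return_type)
        or pa.types.is_fixed_size_list(declared_return_type)
    )
    if is_list_type and isinstance(value, (pa.Array, pa.ChunkedArray)):
        # to_pylist() flattens ChunkedArray chunks, so both cases collapse
        # into one list scalar of the declared type.
        return pa.scalar(value.to_pylist(), type=declared_return_type)
    return value

Since ChunkedArray.to_pylist() flattens across chunks, a chunked evaluate() result takes the same path as a plain Array.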