From 5f101761d988e8730a57e9975de670742e2eec04 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 12:42:28 +0800
Subject: [PATCH 01/20] Implement UDAF improvements for list type handling

Store UDAF return type in Rust accumulator and wrap pyarrow Array/ChunkedArray
returns into list scalars for list-like return types. Add a UDAF test to
return a list of timestamps via a pyarrow array, validating the aggregate
output for correctness.
---
 python/tests/test_udaf.py | 45 +++++++++++++++++++++++++++++++++++++++
 src/udaf.rs               | 42 ++++++++++++++++++++++++++++--------
 2 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 453ff6f4f..088200e20 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,6 +17,8 @@

 from __future__ import annotations

+from datetime import datetime
+
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -58,6 +60,25 @@ def state(self) -> list[pa.Scalar]:
         return [self._sum]


+class CollectTimestamps(Accumulator):
+    def __init__(self):
+        self._values: list[datetime] = []
+
+    def state(self) -> list[pa.Scalar]:
+        return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+
+    def update(self, values: pa.Array) -> None:
+        self._values.extend(values.to_pylist())
+
+    def merge(self, states: list[pa.Array]) -> None:
+        for state in states[0].to_pylist():
+            if state is not None:
+                self._values.extend(state)
+
+    def evaluate(self) -> pa.Array:
+        return pa.array(self._values, type=pa.timestamp("ns"))
+
+
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
@@ -217,3 +238,27 @@ def test_register_udaf(ctx, df) -> None:
     df_result = ctx.sql("select summarize(b) from test_table")

     assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
+def test_udaf_list_timestamp_return(ctx) -> None:
+    timestamps = [datetime(2024, 1, 1), datetime(2024, 1, 2)]
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    df = ctx.create_dataframe([[batch]], name="timestamp_table")
+
+    collect = udaf(
+        CollectTimestamps,
+        pa.timestamp("ns"),
+        pa.list_(pa.timestamp("ns")),
+        [pa.list_(pa.timestamp("ns"))],
+        volatility="immutable",
+    )
+
+    result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+    assert result.column(0) == pa.array(
+        [timestamps],
+        type=pa.list_(pa.timestamp("ns")),
+    )
diff --git a/src/udaf.rs b/src/udaf.rs
index 92857f9f7..6004f1132 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -27,7 +27,7 @@ use datafusion::logical_expr::{
 };
 use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
 use pyo3::prelude::*;
-use pyo3::types::{PyCapsule, PyTuple};
+use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};

 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
@@ -37,11 +37,12 @@
 #[derive(Debug)]
 struct RustAccumulator {
     accum: Py<PyAny>,
+    return_type: DataType,
 }

 impl RustAccumulator {
-    fn new(accum: Py<PyAny>) -> Self {
-        Self { accum }
+    fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
+        Self { accum, return_type }
     }
 }

@@ -59,10 +60,23 @@

     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("evaluate")?
-                .extract::<PyScalarValue>()
+            let value = self.accum.bind(py).call_method0("evaluate")?;
+            let is_list_type = matches!(
+                self.return_type,
+                DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
+            );
+            if is_list_type && is_pyarrow_array_like(py, &value)? {
+                let pyarrow = PyModule::import(py, "pyarrow")?;
+                let list_value = value.call_method0("to_pylist")?;
+                let py_type = self.return_type.to_pyarrow(py)?;
+                let kwargs = PyDict::new(py);
+                kwargs.set_item("type", py_type)?;
+                return pyarrow
+                    .getattr("scalar")?
+                    .call((list_value,), Some(kwargs))?
+                    .extract::<PyScalarValue>();
+            }
+            value.extract::<PyScalarValue>()
         })
         .map(|v| v.0)
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
@@ -144,16 +158,26 @@
 }

 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
             accum
                 .call0(py)
                 .map_err(|e| DataFusionError::Execution(format!("{e}")))
         })?;
-        Ok(Box::new(RustAccumulator::new(accum)))
+        Ok(Box::new(RustAccumulator::new(
+            accum,
+            args.return_type.clone(),
+        )))
     })
 }

+fn is_pyarrow_array_like(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
+    let pyarrow = PyModule::import(py, "pyarrow")?;
+    let array_type = pyarrow.getattr("Array")?.downcast::<PyType>()?;
+    let chunked_array_type = pyarrow.getattr("ChunkedArray")?.downcast::<PyType>()?;
+    Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
+}
+
 fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
     validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

From fbba2a03c6d437e0cdf2f98e45c15d4f085bc971 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 12:43:06 +0800
Subject: [PATCH 02/20] Document UDAF list-valued scalar returns

Add documented list-valued scalar returns for UDAF accumulators, including an
example with pa.scalar and a note about unsupported pyarrow.Array returns from
evaluate(). Also, introduce a UDAF FAQ entry detailing list-returning patterns
and required return_type/state_type definitions.
---
 .../common-operations/udf-and-udfa.rst | 11 +++++++++++
 python/datafusion/user_defined.py      | 16 +++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst
index 0830fa81c..970a42b04 100644
--- a/docs/source/user-guide/common-operations/udf-and-udfa.rst
+++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -149,6 +149,17 @@ also see how the inputs to ``update`` and ``merge`` differ.

     df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])

+FAQ
+^^^
+
+**How do I return a list from a UDAF?**
+Use a list-valued scalar and declare list types for both the return and state
+definitions. Returning a ``pyarrow.Array`` from ``evaluate`` is not supported
+unless you convert it to a list scalar. For example, in ``evaluate`` you can
+return ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and register the
+UDAF with ``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``.
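+
+As a rough sketch (mirroring the ``CollectTimestamps`` accumulator added to
+the test suite in this change), the pieces fit together like this::
+
+    import pyarrow as pa
+
+    class CollectTimestamps(Accumulator):
+        ...  # update/merge collect timestamps into self._values
+
+        def evaluate(self) -> pa.Scalar:
+            # Wrap the collected values in a list-valued scalar instead of
+            # returning a pyarrow.Array directly.
+            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
+
+    collect = udaf(
+        CollectTimestamps,
+        pa.timestamp("ns"),              # input type
+        pa.list_(pa.timestamp("ns")),    # return_type
+        [pa.list_(pa.timestamp("ns"))],  # state_type
+        volatility="immutable",
+    )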
+

 Window Functions
 ----------------

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 43a72c805..49b7bb39f 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -282,7 +282,21 @@ def merge(self, states: list[pa.Array]) -> None:

     @abstractmethod
     def evaluate(self) -> pa.Scalar:
-        """Return the resultant value."""
+        """Return the resultant value.
+
+        If you need to return a list, wrap it in a scalar with the correct
+        list type, for example::
+
+            import pyarrow as pa
+
+            return pa.scalar(
+                ["2024-01-01T00:00:00"],
+                type=pa.list_(pa.timestamp("ms")),
+            )
+
+        Returning a ``pyarrow.Array`` from ``evaluate`` is not supported unless
+        you explicitly convert it to a list-valued scalar.
+        """


 class AggregateUDFExportable(Protocol):

From 21906bbe43c5fb6f65e92100e0fe05d9936d8a67 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 12:46:46 +0800
Subject: [PATCH 03/20] Fix pyarrow calls and improve type handling in
 RustAccumulator
---
 src/udaf.rs | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/udaf.rs b/src/udaf.rs
index 6004f1132..06359a484 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -73,7 +73,7 @@ impl Accumulator for RustAccumulator {
                 kwargs.set_item("type", py_type)?;
                 return pyarrow
                     .getattr("scalar")?
-                    .call((list_value,), Some(kwargs))?
+                    .call((list_value,), Some(&kwargs))?
                     .extract::<PyScalarValue>();
             }
             value.extract::<PyScalarValue>()
@@ -166,15 +166,17 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
         })?;
         Ok(Box::new(RustAccumulator::new(
             accum,
-            args.return_type.clone(),
+            args.return_type().clone(),
         )))
     })
 }

 fn is_pyarrow_array_like(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
     let pyarrow = PyModule::import(py, "pyarrow")?;
-    let array_type = pyarrow.getattr("Array")?.downcast::<PyType>()?;
-    let chunked_array_type = pyarrow.getattr("ChunkedArray")?.downcast::<PyType>()?;
+    let array_attr = pyarrow.getattr("Array")?;
+    let array_type = array_attr.downcast::<PyType>()?;
+    let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
+    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
     Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
 }

From 7f363a7384a1d4c430bbb5358afc7aedf4b46a03 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 12:53:24 +0800
Subject: [PATCH 04/20] Refactor RustAccumulator to support pyarrow array types
 and improve type checking for list types
---
 python/tests/test_udaf.py |  4 +--
 src/udaf.rs               | 54 +++++++++++++++++++++++++++++++--------
 2 files changed, 45 insertions(+), 13 deletions(-)

diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 088200e20..0b90c0f0f 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -75,8 +75,8 @@ def merge(self, states: list[pa.Array]) -> None:
             if state is not None:
                 self._values.extend(state)

-    def evaluate(self) -> pa.Array:
-        return pa.array(self._values, type=pa.timestamp("ns"))
+    def evaluate(self) -> pa.Scalar:
+        return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))


 @pytest.fixture
diff --git a/src/udaf.rs b/src/udaf.rs
index 06359a484..4357bafa7 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -38,11 +38,52 @@
 struct RustAccumulator {
     accum: Py<PyAny>,
     return_type: DataType,
+    pyarrow_array_type: Option<Py<PyType>>,
+    pyarrow_chunked_array_type: Option<Py<PyType>>,
 }

 impl RustAccumulator {
     fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
-        Self { accum, return_type }
+        Self {
+            accum,
+            return_type,
+            pyarrow_array_type: None,
+            pyarrow_chunked_array_type: None,
+        }
+    }
+
+    fn ensure_pyarrow_types(
+        &mut self,
+        py: Python<'_>,
+    ) -> PyResult<(Bound<'_, PyType>, Bound<'_, PyType>)> {
+        if self.pyarrow_array_type.is_none() || self.pyarrow_chunked_array_type.is_none() {
+            let pyarrow = PyModule::import(py, "pyarrow")?;
+            let array_attr = pyarrow.getattr("Array")?;
+            let array_type = array_attr.downcast::<PyType>()?;
+            let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
+            let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
+            self.pyarrow_array_type = Some(array_type.unbind());
+            self.pyarrow_chunked_array_type = Some(chunked_array_type.unbind());
+        }
+        Ok((
+            self.pyarrow_array_type
+                .as_ref()
+                .expect("array type set")
+                .bind(py),
+            self.pyarrow_chunked_array_type
+                .as_ref()
+                .expect("chunked array type set")
+                .bind(py),
+        ))
+    }
+
+    fn is_pyarrow_array_like(
+        &mut self,
+        py: Python<'_>,
+        value: &Bound<'_, PyAny>,
+    ) -> PyResult<bool> {
+        let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
+        Ok(value.is_instance(&array_type)? || value.is_instance(&chunked_array_type)?)
     }
 }
@@ -65,7 +106,7 @@ impl Accumulator for RustAccumulator {
                 self.return_type,
                 DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
             );
-            if is_list_type && is_pyarrow_array_like(py, &value)? {
+            if is_list_type && self.is_pyarrow_array_like(py, &value)? {
                 let pyarrow = PyModule::import(py, "pyarrow")?;
                 let list_value = value.call_method0("to_pylist")?;
                 let py_type = self.return_type.to_pyarrow(py)?;
@@ -171,15 +212,6 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
     })
 }

-fn is_pyarrow_array_like(py: Python<'_>, value: &Bound<'_, PyAny>) -> PyResult<bool> {
-    let pyarrow = PyModule::import(py, "pyarrow")?;
-    let array_attr = pyarrow.getattr("Array")?;
-    let array_type = array_attr.downcast::<PyType>()?;
-    let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
-    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-    Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
-}

From 5271ba2b0e980c5e83c2b4d68748f08f13e381ab Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 12:57:01 +0800
Subject: [PATCH 05/20] Fixed PyO3 type mismatch by cloning Array/ChunkedArray
 types before unbinding and binding fresh copies when checking array-likeness,
 eliminating the Bound reference error
---
 src/udaf.rs | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/src/udaf.rs b/src/udaf.rs
index 4357bafa7..fa7a5aa4d 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -52,28 +52,25 @@ impl RustAccumulator {
         }
     }

-    fn ensure_pyarrow_types(
-        &mut self,
-        py: Python<'_>,
-    ) -> PyResult<(Bound<'_, PyType>, Bound<'_, PyType>)> {
+    fn ensure_pyarrow_types(&mut self, py: Python<'_>) -> PyResult<(Py<PyType>, Py<PyType>)> {
         if self.pyarrow_array_type.is_none() || self.pyarrow_chunked_array_type.is_none() {
             let pyarrow = PyModule::import(py, "pyarrow")?;
             let array_attr = pyarrow.getattr("Array")?;
             let array_type = array_attr.downcast::<PyType>()?;
             let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
             let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-            self.pyarrow_array_type = Some(array_type.unbind());
-            self.pyarrow_chunked_array_type = Some(chunked_array_type.unbind());
+            self.pyarrow_array_type = Some(array_type.clone().unbind());
+            self.pyarrow_chunked_array_type = Some(chunked_array_type.clone().unbind());
         }
         Ok((
             self.pyarrow_array_type
                 .as_ref()
                 .expect("array type set")
-                .bind(py),
+                .clone_ref(py),
             self.pyarrow_chunked_array_type
                 .as_ref()
                 .expect("chunked array type set")
-                .bind(py),
+                .clone_ref(py),
         ))
     }

@@ -83,6 +80,8 @@
         value: &Bound<'_, PyAny>,
     ) -> PyResult<bool> {
         let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
+        let array_type = array_type.bind(py);
+        let chunked_array_type = chunked_array_type.bind(py);
         Ok(value.is_instance(&array_type)? || value.is_instance(&chunked_array_type)?)
     }
 }

From 9c59258e63511b12638d88e999d2bf502acdf094 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 13:22:47 +0800
Subject: [PATCH 06/20] Add timezone information to datetime objects in
 test_udaf_list_timestamp_return
---
 python/tests/test_udaf.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 0b90c0f0f..cfbbbca1c 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,7 +17,7 @@

 from __future__ import annotations

-from datetime import datetime
+from datetime import datetime, timezone

 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -241,7 +241,10 @@ def test_register_udaf(ctx, df) -> None:


 def test_udaf_list_timestamp_return(ctx) -> None:
-    timestamps = [datetime(2024, 1, 1), datetime(2024, 1, 2)]
+    timestamps = [
+        datetime(2024, 1, 1, tzinfo=timezone.utc),
+        datetime(2024, 1, 2, tzinfo=timezone.utc),
+    ]
     batch = pa.RecordBatch.from_arrays(
         [pa.array(timestamps, type=pa.timestamp("ns"))],
         names=["ts"],

From 6742954c0383ab9590cbde4275130addd967f14f Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Tue, 20 Jan 2026 13:40:05 +0800
Subject: [PATCH 07/20] clippy fix
---
 src/udaf.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/udaf.rs b/src/udaf.rs
index fa7a5aa4d..1edbd0fad 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -82,7 +82,7 @@ impl RustAccumulator {
         let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
         let array_type = array_type.bind(py);
         let chunked_array_type = chunked_array_type.bind(py);
-        Ok(value.is_instance(&array_type)? || value.is_instance(&chunked_array_type)?)
+        Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
     }
 }

From f16c7187e0708b95b741f70d06a42400a4848b07 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 5 Feb 2026 18:47:02 +0800
Subject: [PATCH 08/20] Refactor RustAccumulator and utility functions for
 improved type handling and conversion from Python objects to Arrow types
---
 src/udaf.rs  | 93 ++++++++++++----------------------------------------
 src/utils.rs | 38 ++++++++++++++++++++-
 2 files changed, 58 insertions(+), 73 deletions(-)

diff --git a/src/udaf.rs b/src/udaf.rs
index 1edbd0fad..4d9fe1df4 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,9 +17,9 @@

 use std::sync::Arc;

-use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
+use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::common::ScalarValue;
 use datafusion::error::{DataFusionError, Result};
 use datafusion::logical_expr::{
@@ -27,7 +27,7 @@
 use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
 use pyo3::prelude::*;
-use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};
+use pyo3::types::{PyCapsule, PyTuple};

 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};

 #[derive(Debug)]
 struct RustAccumulator {
     accum: Py<PyAny>,
-    return_type: DataType,
-    pyarrow_array_type: Option<Py<PyType>>,
-    pyarrow_chunked_array_type: Option<Py<PyType>>,
 }

 impl RustAccumulator {
-    fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
-        Self {
-            accum,
-            return_type,
-            pyarrow_array_type: None,
-            pyarrow_chunked_array_type: None,
-        }
-    }
-
-    fn ensure_pyarrow_types(&mut self, py: Python<'_>) -> PyResult<(Py<PyType>, Py<PyType>)> {
-        if self.pyarrow_array_type.is_none() || self.pyarrow_chunked_array_type.is_none() {
-            let pyarrow = PyModule::import(py, "pyarrow")?;
-            let array_attr = pyarrow.getattr("Array")?;
-            let array_type = array_attr.downcast::<PyType>()?;
-            let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
-            let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-            self.pyarrow_array_type = Some(array_type.clone().unbind());
-            self.pyarrow_chunked_array_type = Some(chunked_array_type.clone().unbind());
-        }
-        Ok((
-            self.pyarrow_array_type
-                .as_ref()
-                .expect("array type set")
-                .clone_ref(py),
-            self.pyarrow_chunked_array_type
-                .as_ref()
-                .expect("chunked array type set")
-                .clone_ref(py),
-        ))
-    }
-
-    fn is_pyarrow_array_like(
-        &mut self,
-        py: Python<'_>,
-        value: &Bound<'_, PyAny>,
-    ) -> PyResult<bool> {
-        let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
-        let array_type = array_type.bind(py);
-        let chunked_array_type = chunked_array_type.bind(py);
-        Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
+    fn new(accum: Py<PyAny>) -> Self {
+        Self { accum }
     }
 }

 impl Accumulator for RustAccumulator {
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
         Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("state")?
-                .extract::<Vec<PyScalarValue>>()
+            let values = self.accum.bind(py).call_method0("state")?;
+            let mut scalars = Vec::new();
+            for item in values.iter()? {
+                let item = item?;
+                let scalar = match item.extract::<PyScalarValue>() {
+                    Ok(py_scalar) => py_scalar.0,
+                    Err(_) => py_obj_to_scalar_value(py, item.into_py(py))?,
+                };
+                scalars.push(scalar);
+            }
+            Ok(scalars)
         })
-        .map(|v| v.into_iter().map(|x| x.0).collect())
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }

     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| {
             let value = self.accum.bind(py).call_method0("evaluate")?;
-            let is_list_type = matches!(
-                self.return_type,
-                DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
-            );
-            if is_list_type && self.is_pyarrow_array_like(py, &value)? {
-                let pyarrow = PyModule::import(py, "pyarrow")?;
-                let list_value = value.call_method0("to_pylist")?;
-                let py_type = self.return_type.to_pyarrow(py)?;
-                let kwargs = PyDict::new(py);
-                kwargs.set_item("type", py_type)?;
-                return pyarrow
-                    .getattr("scalar")?
-                    .call((list_value,), Some(&kwargs))?
-                    .extract::<PyScalarValue>();
+            match value.extract::<PyScalarValue>() {
+                Ok(py_scalar) => Ok(py_scalar.0),
+                Err(_) => py_obj_to_scalar_value(py, value.into_py(py)),
             }
-            value.extract::<PyScalarValue>()
         })
-        .map(|v| v.0)
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }

 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
     Arc::new(move |args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
             accum
                 .call0(py)
                 .map_err(|e| DataFusionError::Execution(format!("{e}")))
         })?;
-        Ok(Box::new(RustAccumulator::new(
-            accum,
-            args.return_type().clone(),
-        )))
+        Ok(Box::new(RustAccumulator::new(accum)))
     })
 }

diff --git a/src/utils.rs b/src/utils.rs
index 6038c77b1..afc39fd86 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,6 +19,10 @@
 use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;

+use datafusion::arrow::array::{make_array, ArrayData, ListArray};
+use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use datafusion::arrow::datatypes::Field;
+use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
@@ -26,7 +30,7 @@
 use datafusion::logical_expr::Volatility;
 use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
-use pyo3::types::PyCapsule;
+use pyo3::types::{PyCapsule, PyType};
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio::time::sleep;
@@ -188,6 +192,38 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sc
     // convert Python object to PyScalarValue to ScalarValue

     let pa = py.import("pyarrow")?;
+    let scalar_type = pa.getattr("Scalar")?.downcast::<PyType>()?;
+    let array_type = pa.getattr("Array")?.downcast::<PyType>()?;
+    let chunked_array_type = pa.getattr("ChunkedArray")?.downcast::<PyType>()?;
+
+    let obj_ref = obj.bind(py);
+
+    if obj_ref.is_instance(scalar_type)? {
+        let py_scalar = PyScalarValue::extract_bound(obj_ref)
+            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
+        return Ok(py_scalar.into());
+    }
+
+    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
+        let array_obj = if obj_ref.is_instance(chunked_array_type)? {
+            obj_ref.call_method0("combine_chunks")?.to_object(py)
+        } else {
+            obj_ref.to_object(py)
+        };
+        let array_bound = array_obj.bind(py);
+        let array_data = ArrayData::from_pyarrow_bound(&array_bound)
+            .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
+        let array = make_array(array_data);
+        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
+        let list_array = Arc::new(ListArray::new(
+            Arc::new(Field::new_list_field(array.data_type().clone(), true)),
+            offsets,
+            array,
+            None,
+        ));
+
+        return Ok(ScalarValue::List(list_array));
+    }

     // Convert Python object to PyArrow scalar
     let scalar = pa.call_method1("scalar", (obj,))?;

From dcf61454f42bee2c22d6195f5f0a4700b701f3a5 Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 5 Feb 2026 19:43:35 +0800
Subject: [PATCH 09/20] Enhance PyArrow integration by refining type handling
 and conversion in RustAccumulator and utility functions
---
 src/udaf.rs  | 22 +++++++++++-----------
 src/utils.rs | 13 ++++++++-----
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/src/udaf.rs b/src/udaf.rs
index 4d9fe1df4..3d68178b4 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -19,7 +19,7 @@
 use std::sync::Arc;

 use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::common::ScalarValue;
@@ -47,14 +47,14 @@ impl RustAccumulator {

 impl Accumulator for RustAccumulator {
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Python::attach(|py| {
+        Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
             let values = self.accum.bind(py).call_method0("state")?;
             let mut scalars = Vec::new();
-            for item in values.iter()? {
-                let item = item?;
+            for item in values.try_iter()? {
+                let item: Bound<'_, PyAny> = item?;
                 let scalar = match item.extract::<PyScalarValue>() {
                     Ok(py_scalar) => py_scalar.0,
-                    Err(_) => py_obj_to_scalar_value(py, item.into_py(py))?,
+                    Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
                 };
                 scalars.push(scalar);
             }
             Ok(scalars)
@@ -64,11 +64,11 @@

     fn evaluate(&mut self) -> Result<ScalarValue> {
-        Python::attach(|py| {
+        Python::attach(|py| -> PyResult<ScalarValue> {
             let value = self.accum.bind(py).call_method0("evaluate")?;
             match value.extract::<PyScalarValue>() {
                 Ok(py_scalar) => Ok(py_scalar.0),
-                Err(_) => py_obj_to_scalar_value(py, value.into_py(py)),
+                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
             }
         })
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
@@ -79,7 +79,7 @@
         // 1. cast args to Pyarrow array
         let py_args = values
             .iter()
-            .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+            .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
             .collect::<Vec<_>>();
         let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -100,7 +100,7 @@
             .iter()
             .map(|state| {
                 state
-                    .into_data()
+                    .to_data()
                     .to_pyarrow(py)
                     .map_err(|e| DataFusionError::Execution(format!("{e}")))
             })
@@ -125,7 +125,7 @@
         // 1. cast args to Pyarrow array
         let py_args = values
             .iter()
-            .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+            .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
            .collect::<Vec<_>>();
         let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -150,7 +150,7 @@
 }

 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |args| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
diff --git a/src/utils.rs b/src/utils.rs
index afc39fd86..d0f689fc5 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -192,9 +192,12 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sc
     let pa = py.import("pyarrow")?;
-    let scalar_type = pa.getattr("Scalar")?.downcast::<PyType>()?;
-    let array_type = pa.getattr("Array")?.downcast::<PyType>()?;
-    let chunked_array_type = pa.getattr("ChunkedArray")?.downcast::<PyType>()?;
+    let scalar_attr = pa.getattr("Scalar")?;
+    let scalar_type = scalar_attr.downcast::<PyType>()?;
+    let array_attr = pa.getattr("Array")?;
+    let array_type = array_attr.downcast::<PyType>()?;
+    let chunked_array_attr = pa.getattr("ChunkedArray")?;
+    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;

     let obj_ref = obj.bind(py);

@@ -206,9 +209,9 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sc
-            obj_ref.call_method0("combine_chunks")?.to_object(py)
+            obj_ref.call_method0("combine_chunks")?.unbind()
         } else {
-            obj_ref.to_object(py)
+            obj_ref.clone().unbind()
         };
         let array_bound = array_obj.bind(py);
-        let array_data = ArrayData::from_pyarrow_bound(&array_bound)

From fd485c0892b8372631652233622c8eb88aa74eca Mon Sep 17 00:00:00 2001
From: Siew Kam Onn
Date: Thu, 5 Feb 2026 21:48:38 +0800
Subject: [PATCH 10/20] Fix array data binding in py_obj_to_scalar_value
 function
---
 src/utils.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/utils.rs b/src/utils.rs
index 2f04826f6..3b97ffb88 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -229,7 +229,7 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sc
         let array_bound = array_obj.bind(py);
-        let array_data = ArrayData::from_pyarrow_bound(&array_bound)
+        let array_data = ArrayData::from_pyarrow_bound(array_bound)
             .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;

From: Tim Saucer
Date: Fri, 6 Feb 2026 09:12:25 -0500
Subject: [PATCH 11/20] Implement single point for scalar conversion from
 python objects
---
 src/config.rs       |   6 +--
 src/dataframe.rs    |   9 ++--
 src/pyarrow_util.rs | 117 ++++++++++++++++++++++++++++++++++++++++----
 src/udaf.rs         |  12 ++---
 src/udwf.rs         |   1 -
 src/utils.rs        |  57 ---------------------------------------------
 6 files changed, 117 insertions(+), 85 deletions(-)

diff --git a/src/config.rs b/src/config.rs
index 583dea7ef..38936e6c5 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -22,8 +22,8 @@
 use parking_lot::RwLock;
 use pyo3::prelude::*;
 use pyo3::types::*;

+use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;

 #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
 #[derive(Clone)]
 pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {

     /// Set a configuration option
     pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
         let mut options = self.config.write();
-        options.set(key, scalar_value.to_string().as_str())?;
+        options.set(key, scalar_value.0.to_string().as_str())?;
         Ok(())
     }

diff --git a/src/dataframe.rs b/src/dataframe.rs
index 94105d7ea..0b6eaf2a0 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -48,6 +48,7 @@
 use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use pyo3::PyErr;

+use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
 use crate::expr::PyExpr;
@@ -55,9 +56,7 @@
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
 use crate::sql::logical::PyLogicalPlan;
 use crate::table::{PyTable, TempViewTable};
-use crate::utils::{
-    is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
-};
+use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};

 /// File-level static CStr for the Arrow array stream capsule name.
 static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
         columns: Option<Vec<PyBackedStr>>,
         py: Python,
     ) -> PyDataFusionResult<Self> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;

         let cols = match columns {
             Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
             None => Vec::new(), // Empty vector means fill null for all columns
         };

-        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
+        let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
         Ok(Self::new(df))
     }
 }
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index 264cfd342..6221be1ad 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -17,8 +17,13 @@

 //! Conversions between PyArrow and DataFusion types

-use arrow::array::{Array, ArrayData};
+use std::sync::Arc;
+
+use arrow::array::{make_array, Array, ArrayData, ListArray};
+use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::Field;
 use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::exec_err;
 use datafusion::scalar::ScalarValue;
 use pyo3::types::{PyAnyMethods, PyList};
 use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
@@ -26,21 +31,113 @@
 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;

+fn pyobj_extract_scalar_via_capsule(
+    value: &Bound<'_, PyAny>,
+    as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+    let array_data = ArrayData::from_pyarrow_bound(value)?;
+    let array = make_array(array_data);
+
+    if as_list_array {
+        let field = Arc::new(Field::new_list_field(
+            array.data_type().clone(),
+            array.nulls().is_some(),
+        ));
+        let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+        let list_array = ListArray::new(field, offsets, array, None);
+        Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
+    } else {
+        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+        Ok(PyScalarValue(scalar))
+    }
+}
+
 impl FromPyArrow for PyScalarValue {
     fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
         let py = value.py();
-        let typ = value.getattr("type")?;
+        let pyarrow_mod = py.import("pyarrow");

-        // construct pyarrow array from the python value and pyarrow type
-        let factory = py.import("pyarrow")?.getattr("array")?;
-        let args = PyList::new(py, [value])?;
-        let array = factory.call1((args, typ))?;
+        // Is it a PyArrow object?
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar_type = pa.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                let typ = value.getattr("type")?;

-        // convert the pyarrow array to rust array using C data interface
-        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
-        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                // construct pyarrow array from the python value and pyarrow type
+                let factory = py.import("pyarrow")?.getattr("array")?;
+                let args = PyList::new(py, [value])?;
+                let array = factory.call1((args, typ))?;

-        Ok(PyScalarValue(scalar))
+                return pyobj_extract_scalar_via_capsule(&array, false);
+            }
+
+            let array_type = pa.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it a NanoArrow scalar?
+        if let Ok(na) = py.import("nanoarrow") {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = na.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it an arro3 scalar?
+        if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
+            let scalar_type = arro3.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = arro3.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Does it have a PyCapsule interface but isn't one of our known libraries?
+        // If so, do our "best guess". Try checking the type name, and if that fails
+        // return a single value if the length is 1 and return a List value otherwise.
+        if value.hasattr("__arrow_c_array__")? {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            if type_name.contains("Array")? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+
+            let array_data = ArrayData::from_pyarrow_bound(value)?;
+            let array = make_array(array_data);
+            if array.len() == 1 {
+                let scalar =
+                    ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                return Ok(PyScalarValue(scalar));
+            } else {
+                let field = Arc::new(Field::new_list_field(
+                    array.data_type().clone(),
+                    array.nulls().is_some(),
+                ));
+                let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+                let list_array = ListArray::new(field, offsets, array, None);
+                return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
+            }
+        }
+
+        // Last attempt - try to create a PyArrow scalar from a plain Python object
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar = pa.call_method1("scalar", (value,))?;
+
+            PyScalarValue::from_pyarrow_bound(&scalar)
+        } else {
+            exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
+        }
     }
 }
diff --git a/src/udaf.rs b/src/udaf.rs
index 883170adf..24ef1f6d3 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -32,7 +32,7 @@
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
+use crate::utils::{parse_volatility, validate_pycapsule};

 #[derive(Debug)]
 struct RustAccumulator {
     accum: Py<PyAny>,
@@ -52,10 +52,7 @@ impl Accumulator for RustAccumulator {
             let mut scalars = Vec::new();
             for item in values.try_iter()? {
                 let item: Bound<'_, PyAny> = item?;
-                let scalar = match item.extract::<PyScalarValue>() {
-                    Ok(py_scalar) => py_scalar.0,
-                    Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
-                };
+                let scalar = item.extract::<PyScalarValue>()?.0;
                 scalars.push(scalar);
             }
             Ok(scalars)
@@ -66,10 +63,7 @@
     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| -> PyResult<ScalarValue> {
             let value = self.accum.bind(py).call_method0("evaluate")?;
-            match value.extract::<PyScalarValue>() {
-                Ok(py_scalar) => Ok(py_scalar.0),
-                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
-            }
+            value.extract::<PyScalarValue>().map(|v| v.0)
         })
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
diff --git a/src/udwf.rs b/src/udwf.rs
index 86310609c..6b4f07c36 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
     }

     fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
-        println!("evaluate all called with number of values {}", values.len());
         Python::attach(|py| {
             let py_values = PyList::new(
                 py,
diff --git a/src/utils.rs b/src/utils.rs
index 3b97ffb88..4b45f29bf 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,11 +19,6 @@
 use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;

-use datafusion::arrow::array::{make_array, ArrayData, ListArray};
-use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
-use datafusion::arrow::datatypes::Field;
-use datafusion::arrow::pyarrow::FromPyArrow;
-use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
@@ -37,7 +32,6 @@
 use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
-use pyo3::types::{PyCapsule, PyType};
+use pyo3::types::PyCapsule;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio::time::sleep;

-use crate::common::data_type::PyScalarValue;
 use crate::context::PySessionContext;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
@@ -203,57 +197,6 @@
 }

-pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
-    // convert Python object to PyScalarValue to ScalarValue
-
-    let pa = py.import("pyarrow")?;
-    let scalar_attr = pa.getattr("Scalar")?;
-    let scalar_type = scalar_attr.downcast::<PyType>()?;
-    let array_attr = pa.getattr("Array")?;
-    let array_type = array_attr.downcast::<PyType>()?;
-    let chunked_array_attr = pa.getattr("ChunkedArray")?;
-    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-
-    let obj_ref = obj.bind(py);
-
-    if obj_ref.is_instance(scalar_type)? {
-        let py_scalar = PyScalarValue::extract_bound(obj_ref)
-            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-        return Ok(py_scalar.into());
-    }
-
-    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
-        let array_obj = if obj_ref.is_instance(chunked_array_type)?
{ - obj_ref.call_method0("combine_chunks")?.unbind() - } else { - obj_ref.clone().unbind() - }; - let array_bound = array_obj.bind(py); - let array_data = ArrayData::from_pyarrow_bound(array_bound) - .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?; - let array = make_array(array_data); - let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32])); - let list_array = Arc::new(ListArray::new( - Arc::new(Field::new_list_field(array.data_type().clone(), true)), - offsets, - array, - None, - )); - - return Ok(ScalarValue::List(list_array)); - } - - // Convert Python object to PyArrow scalar - let scalar = pa.call_method1("scalar", (obj,))?; - - // Convert PyArrow scalar to PyScalarValue - let py_scalar = PyScalarValue::extract_bound(scalar.as_ref()) - .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?; - - // Convert PyScalarValue to ScalarValue - Ok(py_scalar.into()) -} - pub(crate) fn extract_logical_extension_codec( py: Python, obj: Option>, From 76dee1ccca291b64da853488f8df5739f6a7e1b8 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 6 Feb 2026 10:32:43 -0500 Subject: [PATCH 12/20] Add unit tests and simplify python wrapper for literal --- python/datafusion/expr.py | 3 --- python/tests/test_expr.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index 695fe7c49..9df58f52a 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -562,8 +562,6 @@ def literal(value: Any) -> Expr: """ if isinstance(value, str): value = pa.scalar(value, type=pa.string_view()) - if not isinstance(value, pa.Scalar): - value = pa.scalar(value) return Expr(expr_internal.RawExpr.literal(value)) @staticmethod @@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr: """ if isinstance(value, str): value = pa.scalar(value, type=pa.string_view()) - value = value if isinstance(value, pa.Scalar) else pa.scalar(value) return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata)) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 39e48f7c3..6ff3f4004 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -20,6 +20,8 @@ from datetime import date, datetime, time, timezone from decimal import Decimal +import arro3.core +import nanoarrow import pyarrow as pa import pytest from datafusion import ( @@ -980,6 +982,34 @@ def test_literal_metadata(ctx): assert expected_field.metadata == actual_field.metadata +def test_scalar_conversion() -> None: + expected_value = lit(1) + assert str(expected_value) == "Expr(Int64(1))" + + # Test pyarrow imports + assert expected_value == lit(pa.scalar(1)) + assert expected_value == lit(pa.scalar(1, type=pa.int32())) + + # Test nanoarrow + na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0] + assert expected_value == lit(na_scalar) + + # Test pyo3 + arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32()) + assert expected_value == lit(arro3_scalar) + + expected_value = lit([1, 2, 3]) + assert str(expected_value) == "Expr(List([1, 2, 3]))" + + assert expected_value == lit(pa.scalar([1, 2, 3])) + + na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32()) + assert expected_value == lit(na_array) + + arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32()) + assert expected_value == lit(arro3_array) + + def test_ensure_expr(): e = col("a") assert ensure_expr(e) is 
e.expr

From 85ee4f7eb05f689dd794b7b2fcf38df2d20a0f73 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 11 Feb 2026 08:18:57 -0500
Subject: [PATCH 13/20] Add nanoarrow and arro3-core to dev dependencies. Sort
 the dependencies alphabetically.
---
 pyproject.toml | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 497943a34..3fa93728b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -138,27 +138,29 @@ ignore-words-list = [

 [dependency-groups]
 dev = [
+    "arro3-core==0.6.5",
+    "codespell==2.4.1",
     "maturin>=1.8.1",
+    "nanoarrow==0.8.0",
     "numpy>1.25.0;python_version<'3.14'",
     "numpy>=2.3.2;python_version>='3.14'",
-    "pyarrow>=19.0.0",
     "pre-commit>=4.3.0",
-    "pyyaml>=6.0.3",
+    "pyarrow>=19.0.0",
+    "pygithub==2.5.0",
     "pytest>=7.4.4",
     "pytest-asyncio>=0.23.3",
+    "pyyaml>=6.0.3",
     "ruff>=0.9.1",
     "toml>=0.10.2",
-    "pygithub==2.5.0",
-    "codespell==2.4.1",
 ]
 docs = [
-    "sphinx>=7.1.2",
-    "pydata-sphinx-theme==0.8.0",
-    "myst-parser>=3.0.1",
-    "jinja2>=3.1.5",
     "ipython>=8.12.3",
+    "jinja2>=3.1.5",
+    "myst-parser>=3.0.1",
     "pandas>=2.0.3",
     "pickleshare>=0.7.5",
+    "pydata-sphinx-theme==0.8.0",
     "setuptools>=75.3.0",
-    "sphinx-autoapi>=3.4.0",
+    "sphinx>=7.1.2",
+    "sphinx-autoapi>=3.4.0",
 ]

From 9d0ac50b0494a97a37d4c520f937a4beb22dda3b Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 11 Feb 2026 08:26:26 -0500
Subject: [PATCH 14/20] Refactor common code into helper function so we do not
 duplicate it.
---
 src/pyarrow_util.rs | 36 +++++++++++++++---------------------
 1 file changed, 15 insertions(+), 21 deletions(-)

diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index 6221be1ad..c22a59562 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -19,7 +19,7 @@

 use std::sync::Arc;

-use arrow::array::{make_array, Array, ArrayData, ListArray};
+use arrow::array::{make_array, Array, ArrayData, ArrayRef, ListArray};
 use arrow::buffer::OffsetBuffer;
 use arrow::datatypes::Field;
 use arrow::pyarrow::{FromPyArrow, ToPyArrow};
@@ -31,13 +31,7 @@
 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;

-fn pyobj_extract_scalar_via_capsule(
-    value: &Bound<'_, PyAny>,
-    as_list_array: bool,
-) -> PyResult<PyScalarValue> {
-    let array_data = ArrayData::from_pyarrow_bound(value)?;
-    let array = make_array(array_data);
-
+fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PyScalarValue> {
     if as_list_array {
         let field = Arc::new(Field::new_list_field(
             array.data_type().clone(),
@@ -52,6 +46,16 @@
     }
 }

+fn pyobj_extract_scalar_via_capsule(
+    value: &Bound<'_, PyAny>,
+    as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+    let array_data = ArrayData::from_pyarrow_bound(value)?;
+    let array = make_array(array_data);
+
+    array_to_scalar_value(array, as_list_array)
+}
+
 impl FromPyArrow for PyScalarValue {
     fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
         let py = value.py();
@@ -115,19 +119,9 @@

         let array_data = ArrayData::from_pyarrow_bound(value)?;
         let array = make_array(array_data);
-        if array.len() == 1 {
-            let scalar =
-                ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
-            return Ok(PyScalarValue(scalar));
-        } else {
-            let field = Arc::new(Field::new_list_field(
-                array.data_type().clone(),
-                array.nulls().is_some(),
-            ));
-            let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
-            let list_array = ListArray::new(field, offsets, array, None);
-            return
Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array)))); - } + + let as_array_list = array.len() != 1; + return array_to_scalar_value(array, as_array_list); } // Last attempt - try to create a PyArrow scalar from a plain Python object From 5ac0164b36cee726ae06344ca2577b883a622fe6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 11 Feb 2026 08:38:30 -0500 Subject: [PATCH 15/20] Update import path to access Scalar type --- src/pyarrow_util.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs index c22a59562..544cecc24 100644 --- a/src/pyarrow_util.rs +++ b/src/pyarrow_util.rs @@ -83,8 +83,8 @@ impl FromPyArrow for PyScalarValue { // Is it a NanoArrow scalar? if let Ok(na) = py.import("nanoarrow") { - let type_name = value.get_type().repr()?; - if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? { + let scalar_type = py.import("nanoarrow.array")?.getattr("Scalar")?; + if value.is_instance(&scalar_type)? { return pyobj_extract_scalar_via_capsule(value, false); } let array_type = na.getattr("Array")?; From 1f7da06cbea68f28be38f9655a9d0781914e06fb Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 11 Feb 2026 08:59:48 -0500 Subject: [PATCH 16/20] Add test for generic python objects that support the C interface --- python/tests/test_expr.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6ff3f4004..92251827b 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -983,6 +983,15 @@ def test_literal_metadata(ctx): def test_scalar_conversion() -> None: + class WrappedPyArrow: + """Wrapper class for testing __arrow_c_array__.""" + + def __init__(self, val: pa.Array) -> None: + self.val = val + + def __arrow_c_array__(self, requested_schema=None): + return self.val.__arrow_c_array__(requested_schema=requested_schema) + expected_value = lit(1) assert str(expected_value) == "Expr(Int64(1))" @@ -998,6 +1007,9 @@ def test_scalar_conversion() -> None: arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32()) assert expected_value == lit(arro3_scalar) + generic_scalar = WrappedPyArrow(pa.array([1])) + assert expected_value == lit(generic_scalar) + expected_value = lit([1, 2, 3]) assert str(expected_value) == "Expr(List([1, 2, 3]))" @@ -1009,6 +1021,9 @@ def test_scalar_conversion() -> None: arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32()) assert expected_value == lit(arro3_array) + generic_array = WrappedPyArrow(pa.array([1, 2, 3])) + assert expected_value == lit(generic_array) + def test_ensure_expr(): e = col("a") From 390d753de95b6a794ed0e85f877b8c1cd6f9af74 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Feb 2026 09:12:33 -0500 Subject: [PATCH 17/20] Update unit test to pass back either pyarrow array or array wrapped as scalar --- python/tests/test_udaf.py | 51 ++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index cfbbbca1c..9fb6c4ca0 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -61,11 +61,14 @@ def state(self) -> list[pa.Scalar]: class CollectTimestamps(Accumulator): - def __init__(self): + def __init__(self, wrap_in_scalar: bool): self._values: list[datetime] = [] + self.wrap_in_scalar = wrap_in_scalar def state(self) -> list[pa.Scalar]: - return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))] + if self.wrap_in_scalar: + 
return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))] + return [pa.array(self._values, type=pa.timestamp("ns"))] def update(self, values: pa.Array) -> None: self._values.extend(values.to_pylist()) @@ -76,7 +79,9 @@ def merge(self, states: list[pa.Array]) -> None: self._values.extend(state) def evaluate(self) -> pa.Scalar: - return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns"))) + if self.wrap_in_scalar: + return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns"))) + return pa.array(self._values, type=pa.timestamp("ns")) @pytest.fixture @@ -240,28 +245,46 @@ def test_register_udaf(ctx, df) -> None: assert df_result.collect()[0][0][0].as_py() == 14.0 -def test_udaf_list_timestamp_return(ctx) -> None: - timestamps = [ +@pytest.mark.parametrize("wrap_in_scalar", [True, False]) +def test_udaf_list_timestamp_return(ctx, wrap_in_scalar) -> None: + timestamps1 = [ datetime(2024, 1, 1, tzinfo=timezone.utc), datetime(2024, 1, 2, tzinfo=timezone.utc), ] - batch = pa.RecordBatch.from_arrays( - [pa.array(timestamps, type=pa.timestamp("ns"))], + timestamps2 = [ + datetime(2024, 1, 3, tzinfo=timezone.utc), + datetime(2024, 1, 4, tzinfo=timezone.utc), + ] + batch1 = pa.RecordBatch.from_arrays( + [pa.array(timestamps1, type=pa.timestamp("ns"))], names=["ts"], ) - df = ctx.create_dataframe([[batch]], name="timestamp_table") + batch2 = pa.RecordBatch.from_arrays( + [pa.array(timestamps2, type=pa.timestamp("ns"))], + names=["ts"], + ) + df = ctx.create_dataframe([[batch1], [batch2]], name="timestamp_table") + + list_type = pa.list_( + pa.field("item", type=pa.timestamp("ns"), nullable=wrap_in_scalar) + ) collect = udaf( - CollectTimestamps, + lambda: CollectTimestamps(wrap_in_scalar), pa.timestamp("ns"), - pa.list_(pa.timestamp("ns")), - [pa.list_(pa.timestamp("ns"))], + list_type, + [list_type], volatility="immutable", ) result = df.aggregate([], [collect(column("ts"))]).collect()[0] - assert result.column(0) == pa.array( - [timestamps], - type=pa.list_(pa.timestamp("ns")), + # There is no guarantee about the ordering of the batches, so perform a sort + # to get consistent results. Alternatively we could sort on evaluate(). + assert ( + result.column(0).values.sort() + == pa.array( + [[*timestamps1, *timestamps2]], + type=list_type, + ).values ) From 67a6bc1e2223eb8f587ea0f58b4ec9e91dd64764 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Feb 2026 09:52:29 -0500 Subject: [PATCH 18/20] Update tests to pass back raw python values or pyarrow scalar --- python/tests/test_udaf.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py index 9fb6c4ca0..8cd480e37 100644 --- a/python/tests/test_udaf.py +++ b/python/tests/test_udaf.py @@ -28,23 +28,28 @@ class Summarize(Accumulator): """Interface of a user-defined accumulation.""" - def __init__(self, initial_value: float = 0.0): - self._sum = pa.scalar(initial_value) + def __init__(self, initial_value: float = 0.0, as_scalar: bool = False): + self._sum = initial_value + self.as_scalar = as_scalar def state(self) -> list[pa.Scalar]: + if self.as_scalar: + return [pa.scalar(self._sum)] return [self._sum] def update(self, values: pa.Array) -> None: # Not nice since pyarrow scalars can't be summed yet. # This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) + self._sum = self._sum + pc.sum(values).as_py() def merge(self, states: list[pa.Array]) -> None: # Not nice since pyarrow scalars can't be summed yet. 
# This breaks on `None` - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) + self._sum = self._sum + pc.sum(states[0]).as_py() def evaluate(self) -> pa.Scalar: + if self.as_scalar: + return pa.scalar(self._sum) return self._sum @@ -163,11 +168,12 @@ def summarize(): assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) -def test_udaf_aggregate_with_arguments(df): +@pytest.mark.parametrize("as_scalar", [True, False]) +def test_udaf_aggregate_with_arguments(df, as_scalar): bias = 10.0 summarize = udaf( - lambda: Summarize(bias), + lambda: Summarize(initial_value=bias, as_scalar=as_scalar), pa.float64(), pa.float64(), [pa.float64()], From 27fa92af18cb195760d3f63ecae08f8261e3ce0b Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 18 Feb 2026 10:41:01 -0500 Subject: [PATCH 19/20] Expand on user documentation for how to return list arrays --- .../common-operations/udf-and-udfa.rst | 34 ++++++++++++------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst index feed436b2..f669721a3 100644 --- a/docs/source/user-guide/common-operations/udf-and-udfa.rst +++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst @@ -123,7 +123,7 @@ also see how the inputs to ``update`` and ``merge`` differ. .. code-block:: python - import pyarrow + import pyarrow as pa import pyarrow.compute import datafusion from datafusion import col, udaf, Accumulator @@ -136,16 +136,16 @@ also see how the inputs to ``update`` and ``merge`` differ. def __init__(self): self._sum = 0.0 - def update(self, values_a: pyarrow.Array, values_b: pyarrow.Array) -> None: + def update(self, values_a: pa.Array, values_b: pa.Array) -> None: self._sum = self._sum + pyarrow.compute.sum(values_a).as_py() - pyarrow.compute.sum(values_b).as_py() - def merge(self, states: List[pyarrow.Array]) -> None: + def merge(self, states: list[pa.Array]) -> None: self._sum = self._sum + pyarrow.compute.sum(states[0]).as_py() - def state(self) -> pyarrow.Array: - return pyarrow.array([self._sum]) + def state(self) -> list[pa.Scalar]: + return [pyarrow.scalar(self._sum)] - def evaluate(self) -> pyarrow.Scalar: + def evaluate(self) -> pa.Scalar: return pyarrow.scalar(self._sum) ctx = datafusion.SessionContext() @@ -156,7 +156,7 @@ also see how the inputs to ``update`` and ``merge`` differ. } ) - my_udaf = udaf(MyAccumulator, [pyarrow.float64(), pyarrow.float64()], pyarrow.float64(), [pyarrow.float64()], 'stable') + my_udaf = udaf(MyAccumulator, [pa.float64(), pa.float64()], pa.float64(), [pa.float64()], 'stable') df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")]) @@ -164,12 +164,20 @@ FAQ ^^^ **How do I return a list from a UDAF?** -Use a list-valued scalar and declare list types for both the return and state -definitions. Returning a ``pyarrow.Array`` from ``evaluate`` is not supported -unless you convert it to a list scalar. For example, in ``evaluate`` you can -return ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and register the -UDAF with ``return_type=pa.list_(pa.timestamp("ms"))`` and -``state_type=[pa.list_(pa.timestamp("ms"))]``. + +Both the ``evaluate`` and the ``state`` functions expect to return scalar values. +If you wish to return a list array as a scalar value, the best practice is to +wrap the values in a ``pyarrow.Scalar`` object. 
For example, you can return a
+timestamp list with ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and
+register the appropriate return or state types as
+``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``, respectively.
+
+As of DataFusion 52.0.0, you can return any Python object, including a
+PyArrow array, as the return value for these functions and DataFusion will
+attempt to create a scalar value from it. DataFusion has been tested to
+convert PyArrow, nanoarrow, and arro3 objects as well as primitive data
+types such as integers and strings.

 Window Functions
 ----------------

From a10ee5a5a23b3f9cd23d0cff6374d5e61cdc59b2 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Wed, 18 Feb 2026 10:53:47 -0500
Subject: [PATCH 20/20] More user documentation
---
 python/datafusion/user_defined.py | 30 +++++++++++++++++-------------
 src/common/data_type.rs           |  3 +++
 src/pyarrow_util.rs               |  7 +++++++
 3 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index d4ebfe049..d4e5302b5 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -298,7 +298,16 @@ class Accumulator(metaclass=ABCMeta):

     @abstractmethod
     def state(self) -> list[pa.Scalar]:
-        """Return the current state."""
+        """Return the current state.
+
+        While this function template declares a list of PyArrow Scalar values
+        as the return type, you can return any values that can be converted
+        into Scalars, including basic Python data types such as integers and
+        strings. In addition to primitive types, PyArrow, nanoarrow, and
+        arro3 objects are currently supported. Other objects that support the
+        Arrow FFI standard will be given a "best attempt" at conversion to
+        scalar objects.
+        """

     @abstractmethod
     def update(self, *values: pa.Array) -> None:
@@ -312,18 +321,13 @@ def merge(self, states: list[pa.Array]) -> None:

     def evaluate(self) -> pa.Scalar:
         """Return the resultant value.

-        If you need to return a list, wrap it in a scalar with the correct
-        list type, for example::
-
-            import pyarrow as pa
-
-            return pa.scalar(
-                ["2024-01-01T00:00:00"],
-                type=pa.list_(pa.timestamp("ms")),
-            )
-
-        Returning a ``pyarrow.Array`` from ``evaluate`` is not supported unless
-        you explicitly convert it to a list-valued scalar.
+        While this function template declares a PyArrow Scalar value as the
+        return type, you can return any value that can be converted into a
+        Scalar, including basic Python data types such as integers and
+        strings. In addition to primitive types, PyArrow, nanoarrow, and
+        arro3 objects are currently supported. Other objects that support the
+        Arrow FFI standard will be given a "best attempt" at conversion to
+        scalar objects.
         """

diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 55848da5c..1ff332ebb 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -22,6 +22,9 @@
 use datafusion::logical_expr::expr::NullTreatment as DFNullTreatment;
 use pyo3::exceptions::{PyNotImplementedError, PyValueError};
 use pyo3::prelude::*;

+/// A wrapper around a [`ScalarValue`] that allows for conversion from a
+/// variety of Python objects. See ``FromPyArrow::from_pyarrow_bound`` for
+/// conversion details.
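+///
+/// A sketch of the typical conversion (assuming `value` is a `Py<PyAny>`
+/// holding a pyarrow scalar, a pyarrow array, or another Arrow-compatible
+/// object; the names here are illustrative):
+///
+/// ```ignore
+/// let scalar_value: PyScalarValue = value.extract(py)?;
+/// let inner: ScalarValue = scalar_value.0;
+/// ```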
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
 pub struct PyScalarValue(pub ScalarValue);
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index fcfaca3b8..2a119274f 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -31,6 +31,9 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;

+/// Helper function to turn an Array into a ScalarValue. If ``as_list_array`` is true,
+/// the array will be turned into a ``ListArray``. Otherwise, we extract the first value
+/// from the array.
 fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PyScalarValue> {
     if as_list_array {
         let field = Arc::new(Field::new_list_field(
@@ -46,6 +49,10 @@ fn array_to_scalar_value(array: ArrayRef, as_list_array: bool) -> PyResult<PySc
     }
 }

 fn pyobj_extract_scalar_via_capsule(
     value: &Bound<'_, PyAny>,
     as_list_array: bool,