@@ -27,7 +27,7 @@ use datafusion::logical_expr::{
2727} ;
2828use datafusion_ffi:: udaf:: { FFI_AggregateUDF , ForeignAggregateUDF } ;
2929use pyo3:: prelude:: * ;
30- use pyo3:: types:: { PyCapsule , PyTuple } ;
30+ use pyo3:: types:: { PyCapsule , PyDict , PyTuple , PyType } ;
3131
3232use crate :: common:: data_type:: PyScalarValue ;
3333use crate :: errors:: { py_datafusion_err, to_datafusion_err, PyDataFusionResult } ;
@@ -37,11 +37,12 @@ use crate::utils::{parse_volatility, validate_pycapsule};
3737#[ derive( Debug ) ]
3838struct RustAccumulator {
3939 accum : Py < PyAny > ,
40+ return_type : DataType ,
4041}
4142
4243impl RustAccumulator {
43- fn new ( accum : Py < PyAny > ) -> Self {
44- Self { accum }
44+ fn new ( accum : Py < PyAny > , return_type : DataType ) -> Self {
45+ Self { accum, return_type }
4546 }
4647}
4748
@@ -59,10 +60,23 @@ impl Accumulator for RustAccumulator {
5960
6061 fn evaluate ( & mut self ) -> Result < ScalarValue > {
6162 Python :: attach ( |py| {
62- self . accum
63- . bind ( py)
64- . call_method0 ( "evaluate" ) ?
65- . extract :: < PyScalarValue > ( )
63+ let value = self . accum . bind ( py) . call_method0 ( "evaluate" ) ?;
64+ let is_list_type = matches ! (
65+ self . return_type,
66+ DataType :: List ( _) | DataType :: LargeList ( _) | DataType :: FixedSizeList ( _, _)
67+ ) ;
68+ if is_list_type && is_pyarrow_array_like ( py, & value) ? {
69+ let pyarrow = PyModule :: import ( py, "pyarrow" ) ?;
70+ let list_value = value. call_method0 ( "to_pylist" ) ?;
71+ let py_type = self . return_type . to_pyarrow ( py) ?;
72+ let kwargs = PyDict :: new ( py) ;
73+ kwargs. set_item ( "type" , py_type) ?;
74+ return pyarrow
75+ . getattr ( "scalar" ) ?
76+ . call ( ( list_value, ) , Some ( kwargs) ) ?
77+ . extract :: < PyScalarValue > ( ) ;
78+ }
79+ value. extract :: < PyScalarValue > ( )
6680 } )
6781 . map ( |v| v. 0 )
6882 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) )
@@ -144,16 +158,26 @@ impl Accumulator for RustAccumulator {
144158}
145159
146160pub fn to_rust_accumulator ( accum : Py < PyAny > ) -> AccumulatorFactoryFunction {
147- Arc :: new ( move |_ | -> Result < Box < dyn Accumulator > > {
161+ Arc :: new ( move |args | -> Result < Box < dyn Accumulator > > {
148162 let accum = Python :: attach ( |py| {
149163 accum
150164 . call0 ( py)
151165 . map_err ( |e| DataFusionError :: Execution ( format ! ( "{e}" ) ) )
152166 } ) ?;
153- Ok ( Box :: new ( RustAccumulator :: new ( accum) ) )
167+ Ok ( Box :: new ( RustAccumulator :: new (
168+ accum,
169+ args. return_type . clone ( ) ,
170+ ) ) )
154171 } )
155172}
156173
174+ fn is_pyarrow_array_like ( py : Python < ' _ > , value : & Bound < ' _ , PyAny > ) -> PyResult < bool > {
175+ let pyarrow = PyModule :: import ( py, "pyarrow" ) ?;
176+ let array_type = pyarrow. getattr ( "Array" ) ?. downcast :: < PyType > ( ) ?;
177+ let chunked_array_type = pyarrow. getattr ( "ChunkedArray" ) ?. downcast :: < PyType > ( ) ?;
178+ Ok ( value. is_instance ( array_type) ? || value. is_instance ( chunked_array_type) ?)
179+ }
180+
157181fn aggregate_udf_from_capsule ( capsule : & Bound < ' _ , PyCapsule > ) -> PyDataFusionResult < AggregateUDF > {
158182 validate_pycapsule ( capsule, "datafusion_aggregate_udf" ) ?;
159183
0 commit comments