Skip to content

Commit f16c718

Browse files
committed
Refactor RustAccumulator and utility functions for improved type handling and conversion from Python objects to Arrow types
1 parent 6742954 commit f16c718

File tree

2 files changed

+58
-73
lines changed

2 files changed

+58
-73
lines changed

src/udaf.rs

Lines changed: 21 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,108 +17,60 @@
1717

1818
use std::sync::Arc;
1919

20-
use datafusion::arrow::array::{Array, ArrayRef};
20+
use datafusion::arrow::array::ArrayRef;
2121
use datafusion::arrow::datatypes::DataType;
22-
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
22+
use datafusion::arrow::pyarrow::PyArrowType;
2323
use datafusion::common::ScalarValue;
2424
use datafusion::error::{DataFusionError, Result};
2525
use datafusion::logical_expr::{
2626
create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF,
2727
};
2828
use datafusion_ffi::udaf::{FFI_AggregateUDF, ForeignAggregateUDF};
2929
use pyo3::prelude::*;
30-
use pyo3::types::{PyCapsule, PyDict, PyTuple, PyType};
30+
use pyo3::types::{PyCapsule, PyTuple};
3131

3232
use crate::common::data_type::PyScalarValue;
3333
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3434
use crate::expr::PyExpr;
35-
use crate::utils::{parse_volatility, validate_pycapsule};
35+
use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
3636

3737
#[derive(Debug)]
3838
struct RustAccumulator {
3939
accum: Py<PyAny>,
40-
return_type: DataType,
41-
pyarrow_array_type: Option<Py<PyType>>,
42-
pyarrow_chunked_array_type: Option<Py<PyType>>,
4340
}
4441

4542
impl RustAccumulator {
46-
fn new(accum: Py<PyAny>, return_type: DataType) -> Self {
47-
Self {
48-
accum,
49-
return_type,
50-
pyarrow_array_type: None,
51-
pyarrow_chunked_array_type: None,
52-
}
53-
}
54-
55-
fn ensure_pyarrow_types(&mut self, py: Python<'_>) -> PyResult<(Py<PyType>, Py<PyType>)> {
56-
if self.pyarrow_array_type.is_none() || self.pyarrow_chunked_array_type.is_none() {
57-
let pyarrow = PyModule::import(py, "pyarrow")?;
58-
let array_attr = pyarrow.getattr("Array")?;
59-
let array_type = array_attr.downcast::<PyType>()?;
60-
let chunked_array_attr = pyarrow.getattr("ChunkedArray")?;
61-
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
62-
self.pyarrow_array_type = Some(array_type.clone().unbind());
63-
self.pyarrow_chunked_array_type = Some(chunked_array_type.clone().unbind());
64-
}
65-
Ok((
66-
self.pyarrow_array_type
67-
.as_ref()
68-
.expect("array type set")
69-
.clone_ref(py),
70-
self.pyarrow_chunked_array_type
71-
.as_ref()
72-
.expect("chunked array type set")
73-
.clone_ref(py),
74-
))
75-
}
76-
77-
fn is_pyarrow_array_like(
78-
&mut self,
79-
py: Python<'_>,
80-
value: &Bound<'_, PyAny>,
81-
) -> PyResult<bool> {
82-
let (array_type, chunked_array_type) = self.ensure_pyarrow_types(py)?;
83-
let array_type = array_type.bind(py);
84-
let chunked_array_type = chunked_array_type.bind(py);
85-
Ok(value.is_instance(array_type)? || value.is_instance(chunked_array_type)?)
43+
fn new(accum: Py<PyAny>) -> Self {
44+
Self { accum }
8645
}
8746
}
8847

8948
impl Accumulator for RustAccumulator {
9049
fn state(&mut self) -> Result<Vec<ScalarValue>> {
9150
Python::attach(|py| {
92-
self.accum
93-
.bind(py)
94-
.call_method0("state")?
95-
.extract::<Vec<PyScalarValue>>()
51+
let values = self.accum.bind(py).call_method0("state")?;
52+
let mut scalars = Vec::new();
53+
for item in values.iter()? {
54+
let item = item?;
55+
let scalar = match item.extract::<PyScalarValue>() {
56+
Ok(py_scalar) => py_scalar.0,
57+
Err(_) => py_obj_to_scalar_value(py, item.into_py(py))?,
58+
};
59+
scalars.push(scalar);
60+
}
61+
Ok(scalars)
9662
})
97-
.map(|v| v.into_iter().map(|x| x.0).collect())
9863
.map_err(|e| DataFusionError::Execution(format!("{e}")))
9964
}
10065

10166
fn evaluate(&mut self) -> Result<ScalarValue> {
10267
Python::attach(|py| {
10368
let value = self.accum.bind(py).call_method0("evaluate")?;
104-
let is_list_type = matches!(
105-
self.return_type,
106-
DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _)
107-
);
108-
if is_list_type && self.is_pyarrow_array_like(py, &value)? {
109-
let pyarrow = PyModule::import(py, "pyarrow")?;
110-
let list_value = value.call_method0("to_pylist")?;
111-
let py_type = self.return_type.to_pyarrow(py)?;
112-
let kwargs = PyDict::new(py);
113-
kwargs.set_item("type", py_type)?;
114-
return pyarrow
115-
.getattr("scalar")?
116-
.call((list_value,), Some(&kwargs))?
117-
.extract::<PyScalarValue>();
69+
match value.extract::<PyScalarValue>() {
70+
Ok(py_scalar) => Ok(py_scalar.0),
71+
Err(_) => py_obj_to_scalar_value(py, value.into_py(py)),
11872
}
119-
value.extract::<PyScalarValue>()
12073
})
121-
.map(|v| v.0)
12274
.map_err(|e| DataFusionError::Execution(format!("{e}")))
12375
}
12476

@@ -204,10 +156,7 @@ pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
204156
.call0(py)
205157
.map_err(|e| DataFusionError::Execution(format!("{e}")))
206158
})?;
207-
Ok(Box::new(RustAccumulator::new(
208-
accum,
209-
args.return_type().clone(),
210-
)))
159+
Ok(Box::new(RustAccumulator::new(accum)))
211160
})
212161
}
213162

src/utils.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,18 @@ use std::future::Future;
1919
use std::sync::{Arc, OnceLock};
2020
use std::time::Duration;
2121

22+
use datafusion::arrow::array::{make_array, ArrayData, ListArray};
23+
use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
24+
use datafusion::arrow::datatypes::Field;
25+
use datafusion::arrow::pyarrow::FromPyArrow;
2226
use datafusion::common::ScalarValue;
2327
use datafusion::datasource::TableProvider;
2428
use datafusion::execution::context::SessionContext;
2529
use datafusion::logical_expr::Volatility;
2630
use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
2731
use pyo3::exceptions::PyValueError;
2832
use pyo3::prelude::*;
29-
use pyo3::types::PyCapsule;
33+
use pyo3::types::{PyCapsule, PyType};
3034
use tokio::runtime::Runtime;
3135
use tokio::task::JoinHandle;
3236
use tokio::time::sleep;
@@ -188,6 +192,38 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sca
188192
// Convert the Python object to a PyScalarValue, then to a ScalarValue
189193

190194
let pa = py.import("pyarrow")?;
195+
let scalar_type = pa.getattr("Scalar")?.downcast::<PyType>()?;
196+
let array_type = pa.getattr("Array")?.downcast::<PyType>()?;
197+
let chunked_array_type = pa.getattr("ChunkedArray")?.downcast::<PyType>()?;
198+
199+
let obj_ref = obj.bind(py);
200+
201+
if obj_ref.is_instance(scalar_type)? {
202+
let py_scalar = PyScalarValue::extract_bound(obj_ref)
203+
.map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
204+
return Ok(py_scalar.into());
205+
}
206+
207+
if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
208+
let array_obj = if obj_ref.is_instance(chunked_array_type)? {
209+
obj_ref.call_method0("combine_chunks")?.to_object(py)
210+
} else {
211+
obj_ref.to_object(py)
212+
};
213+
let array_bound = array_obj.bind(py);
214+
let array_data = ArrayData::from_pyarrow_bound(&array_bound)
215+
.map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
216+
let array = make_array(array_data);
217+
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
218+
let list_array = Arc::new(ListArray::new(
219+
Arc::new(Field::new_list_field(array.data_type().clone(), true)),
220+
offsets,
221+
array,
222+
None,
223+
));
224+
225+
return Ok(ScalarValue::List(list_array));
226+
}
191227

192228
// Convert Python object to PyArrow scalar
193229
let scalar = pa.call_method1("scalar", (obj,))?;

0 commit comments

Comments
 (0)