Skip to content

Commit dcf6145

Browse files
committed
Enhance PyArrow integration by refining type handling and conversion in RustAccumulator and utility functions
1 parent f16c718 commit dcf6145

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

src/udaf.rs

Lines changed: 11 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ use std::sync::Arc;
1919

2020
use datafusion::arrow::array::ArrayRef;
2121
use datafusion::arrow::datatypes::DataType;
22-
use datafusion::arrow::pyarrow::PyArrowType;
22+
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2323
use datafusion::common::ScalarValue;
2424
use datafusion::error::{DataFusionError, Result};
2525
use datafusion::logical_expr::{
@@ -47,14 +47,14 @@ impl RustAccumulator {
4747

4848
impl Accumulator for RustAccumulator {
4949
fn state(&mut self) -> Result<Vec<ScalarValue>> {
50-
Python::attach(|py| {
50+
Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
5151
let values = self.accum.bind(py).call_method0("state")?;
5252
let mut scalars = Vec::new();
53-
for item in values.iter()? {
54-
let item = item?;
53+
for item in values.try_iter()? {
54+
let item: Bound<'_, PyAny> = item?;
5555
let scalar = match item.extract::<PyScalarValue>() {
5656
Ok(py_scalar) => py_scalar.0,
57-
Err(_) => py_obj_to_scalar_value(py, item.into_py(py))?,
57+
Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
5858
};
5959
scalars.push(scalar);
6060
}
@@ -64,11 +64,11 @@ impl Accumulator for RustAccumulator {
6464
}
6565

6666
fn evaluate(&mut self) -> Result<ScalarValue> {
67-
Python::attach(|py| {
67+
Python::attach(|py| -> PyResult<ScalarValue> {
6868
let value = self.accum.bind(py).call_method0("evaluate")?;
6969
match value.extract::<PyScalarValue>() {
7070
Ok(py_scalar) => Ok(py_scalar.0),
71-
Err(_) => py_obj_to_scalar_value(py, value.into_py(py)),
71+
Err(_) => py_obj_to_scalar_value(py, value.unbind()),
7272
}
7373
})
7474
.map_err(|e| DataFusionError::Execution(format!("{e}")))
@@ -79,7 +79,7 @@ impl Accumulator for RustAccumulator {
7979
// 1. cast args to Pyarrow array
8080
let py_args = values
8181
.iter()
82-
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
82+
.map(|arg| arg.to_data().to_pyarrow(py).unwrap())
8383
.collect::<Vec<_>>();
8484
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
8585

@@ -100,7 +100,7 @@ impl Accumulator for RustAccumulator {
100100
.iter()
101101
.map(|state| {
102102
state
103-
.into_data()
103+
.to_data()
104104
.to_pyarrow(py)
105105
.map_err(|e| DataFusionError::Execution(format!("{e}")))
106106
})
@@ -125,7 +125,7 @@ impl Accumulator for RustAccumulator {
125125
// 1. cast args to Pyarrow array
126126
let py_args = values
127127
.iter()
128-
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
128+
.map(|arg| arg.to_data().to_pyarrow(py).unwrap())
129129
.collect::<Vec<_>>();
130130
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
131131

@@ -150,7 +150,7 @@ impl Accumulator for RustAccumulator {
150150
}
151151

152152
pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
153-
Arc::new(move |args| -> Result<Box<dyn Accumulator>> {
153+
Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
154154
let accum = Python::attach(|py| {
155155
accum
156156
.call0(py)

src/utils.rs

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -192,9 +192,12 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sca
192192
// convert Python object to PyScalarValue to ScalarValue
193193

194194
let pa = py.import("pyarrow")?;
195-
let scalar_type = pa.getattr("Scalar")?.downcast::<PyType>()?;
196-
let array_type = pa.getattr("Array")?.downcast::<PyType>()?;
197-
let chunked_array_type = pa.getattr("ChunkedArray")?.downcast::<PyType>()?;
195+
let scalar_attr = pa.getattr("Scalar")?;
196+
let scalar_type = scalar_attr.downcast::<PyType>()?;
197+
let array_attr = pa.getattr("Array")?;
198+
let array_type = array_attr.downcast::<PyType>()?;
199+
let chunked_array_attr = pa.getattr("ChunkedArray")?;
200+
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
198201

199202
let obj_ref = obj.bind(py);
200203

@@ -206,9 +209,9 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<Sca
206209

207210
if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
208211
let array_obj = if obj_ref.is_instance(chunked_array_type)? {
209-
obj_ref.call_method0("combine_chunks")?.to_object(py)
212+
obj_ref.call_method0("combine_chunks")?.unbind()
210213
} else {
211-
obj_ref.to_object(py)
214+
obj_ref.clone().unbind()
212215
};
213216
let array_bound = array_obj.bind(py);
214217
let array_data = ArrayData::from_pyarrow_bound(&array_bound)

0 commit comments

Comments
 (0)