Skip to content

Commit 7067a9c (1 parent: 3227276)

Commit message: Pass Field information back and forth when using scalar UDFs

File tree

2 files changed: +159 additions, −42 deletions

python/datafusion/user_defined.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
from _typeshed import CapsuleType as _PyCapsule
3535

3636
_R = TypeVar("_R", bound=pa.DataType)
37-
from collections.abc import Callable
37+
from collections.abc import Callable, Sequence
3838

3939

4040
class Volatility(Enum):
@@ -81,6 +81,27 @@ def __str__(self) -> str:
8181
return self.name.lower()
8282

8383

84+
def data_type_or_field_to_field(value: pa.DataType | pa.Field, name: str) -> pa.Field:
85+
"""Helper function to return a Field from either a Field or DataType."""
86+
if isinstance(value, pa.Field):
87+
return value
88+
return pa.field(name, type=value)
89+
90+
91+
def data_types_or_fields_to_field_list(
92+
inputs: Sequence[pa.Field | pa.DataType] | pa.Field | pa.DataType,
93+
) -> list[pa.Field]:
94+
"""Helper function to return a list of Fields."""
95+
if isinstance(inputs, pa.DataType):
96+
return [pa.field("value", type=inputs)]
97+
if isinstance(inputs, pa.Field):
98+
return [inputs]
99+
100+
return [
101+
data_type_or_field_to_field(v, f"value_{idx}") for (idx, v) in enumerate(inputs)
102+
]
103+
104+
84105
class ScalarUDFExportable(Protocol):
85106
"""Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
86107

@@ -103,7 +124,7 @@ def __init__(
103124
self,
104125
name: str,
105126
func: Callable[..., _R],
106-
input_types: pa.DataType | list[pa.DataType],
127+
input_types: list[pa.Field],
107128
return_type: _R,
108129
volatility: Volatility | str,
109130
) -> None:
@@ -136,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr:
136157
@overload
137158
@staticmethod
138159
def udf(
139-
input_types: list[pa.DataType],
140-
return_type: _R,
160+
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
161+
return_type: pa.DataType | pa.Field,
141162
volatility: Volatility | str,
142163
name: str | None = None,
143164
) -> Callable[..., ScalarUDF]: ...
@@ -146,8 +167,8 @@ def udf(
146167
@staticmethod
147168
def udf(
148169
func: Callable[..., _R],
149-
input_types: list[pa.DataType],
150-
return_type: _R,
170+
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
171+
return_type: pa.DataType | pa.Field,
151172
volatility: Volatility | str,
152173
name: str | None = None,
153174
) -> ScalarUDF: ...
@@ -200,8 +221,8 @@ def double_udf(x):
200221

201222
def _function(
202223
func: Callable[..., _R],
203-
input_types: list[pa.DataType],
204-
return_type: _R,
224+
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
225+
return_type: pa.DataType | pa.Field,
205226
volatility: Volatility | str,
206227
name: str | None = None,
207228
) -> ScalarUDF:
@@ -213,6 +234,8 @@ def _function(
213234
name = func.__qualname__.lower()
214235
else:
215236
name = func.__class__.__name__.lower()
237+
input_types = data_types_or_fields_to_field_list(input_types)
238+
return_type = data_type_or_field_to_field(return_type, "value")
216239
return ScalarUDF(
217240
name=name,
218241
func=func,

src/udf.rs

Lines changed: 128 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,67 +15,159 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::any::Any;
19+
use std::hash::{Hash, Hasher};
20+
use std::ptr::addr_of;
1821
use std::sync::Arc;
1922

23+
use arrow::datatypes::{Field, FieldRef};
24+
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
2025
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
2126
use datafusion::arrow::datatypes::DataType;
22-
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow};
27+
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
2328
use datafusion::error::DataFusionError;
24-
use datafusion::logical_expr::function::ScalarFunctionImplementation;
25-
use datafusion::logical_expr::{create_udf, ColumnarValue, ScalarUDF, ScalarUDFImpl};
29+
use datafusion::logical_expr::{
30+
ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature,
31+
Volatility,
32+
};
2633
use datafusion_ffi::udf::FFI_ScalarUDF;
34+
use pyo3::ffi::Py_uintptr_t;
2735
use pyo3::prelude::*;
2836
use pyo3::types::{PyCapsule, PyTuple};
2937

3038
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3139
use crate::expr::PyExpr;
3240
use crate::utils::{parse_volatility, validate_pycapsule};
3341

34-
/// Create a Rust callable function from a python function that expects pyarrow arrays
35-
fn pyarrow_function_to_rust(
42+
/// This struct holds the Python written function that is a
43+
/// ScalarUDF.
44+
#[derive(Debug)]
45+
struct PythonFunctionScalarUDF {
46+
name: String,
3647
func: Py<PyAny>,
37-
) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
38-
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
48+
signature: Signature,
49+
return_field: FieldRef,
50+
}
51+
52+
impl PythonFunctionScalarUDF {
53+
fn new(
54+
name: String,
55+
func: Py<PyAny>,
56+
input_fields: Vec<Field>,
57+
return_field: Field,
58+
volatility: Volatility,
59+
) -> Self {
60+
let input_types = input_fields.iter().map(|f| f.data_type().clone()).collect();
61+
let signature = Signature::exact(input_types, volatility);
62+
Self {
63+
name,
64+
func,
65+
signature,
66+
return_field: Arc::new(return_field),
67+
}
68+
}
69+
}
70+
71+
impl Eq for PythonFunctionScalarUDF {}
72+
impl PartialEq for PythonFunctionScalarUDF {
73+
fn eq(&self, other: &Self) -> bool {
74+
self.name == other.name
75+
&& self.signature == other.signature
76+
&& self.return_field == other.return_field
77+
&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
78+
}
79+
}
80+
81+
impl Hash for PythonFunctionScalarUDF {
82+
fn hash<H: Hasher>(&self, state: &mut H) {
83+
self.name.hash(state);
84+
self.signature.hash(state);
85+
self.return_field.hash(state);
86+
87+
Python::attach(|py| {
88+
let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects
89+
90+
state.write_isize(py_hash);
91+
});
92+
}
93+
}
94+
95+
fn array_to_pyarrow_with_field(
96+
py: Python,
97+
array: ArrayRef,
98+
field: &FieldRef,
99+
) -> PyResult<Py<PyAny>> {
100+
let array = FFI_ArrowArray::new(&array.to_data());
101+
let schema = FFI_ArrowSchema::try_from(field).map_err(py_datafusion_err)?;
102+
103+
let module = py.import("pyarrow")?;
104+
let class = module.getattr("Array")?;
105+
let array = class.call_method1(
106+
"_import_from_c",
107+
(
108+
addr_of!(array) as Py_uintptr_t,
109+
addr_of!(schema) as Py_uintptr_t,
110+
),
111+
)?;
112+
Ok(array.unbind())
113+
}
114+
115+
impl ScalarUDFImpl for PythonFunctionScalarUDF {
116+
fn as_any(&self) -> &dyn Any {
117+
self
118+
}
119+
120+
fn name(&self) -> &str {
121+
&self.name
122+
}
123+
124+
fn signature(&self) -> &Signature {
125+
&self.signature
126+
}
127+
128+
fn return_type(&self, _arg_types: &[DataType]) -> datafusion::common::Result<DataType> {
129+
unimplemented!()
130+
}
131+
132+
fn return_field_from_args(
133+
&self,
134+
_args: ReturnFieldArgs,
135+
) -> datafusion::common::Result<FieldRef> {
136+
Ok(Arc::clone(&self.return_field))
137+
}
138+
139+
fn invoke_with_args(
140+
&self,
141+
args: ScalarFunctionArgs,
142+
) -> datafusion::common::Result<ColumnarValue> {
143+
let num_rows = args.number_rows;
39144
Python::attach(|py| {
40145
// 1. cast args to Pyarrow arrays
41146
let py_args = args
42-
.iter()
43-
.map(|arg| {
44-
arg.into_data()
45-
.to_pyarrow(py)
46-
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))
147+
.args
148+
.into_iter()
149+
.zip(args.arg_fields)
150+
.map(|(arg, field)| {
151+
let array = arg.to_array(num_rows)?;
152+
array_to_pyarrow_with_field(py, array, &field).map_err(to_datafusion_err)
47153
})
48154
.collect::<Result<Vec<_>, _>>()?;
49155
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
50156

51157
// 2. call function
52-
let value = func
158+
let value = self
159+
.func
53160
.call(py, py_args, None)
54161
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
55162

56163
// 3. cast to arrow::array::Array
57164
let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
58165
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
59-
Ok(make_array(array_data))
166+
Ok(ColumnarValue::Array(make_array(array_data)))
60167
})
61168
}
62169
}
63170

64-
/// Create a DataFusion's UDF implementation from a python function
65-
/// that expects pyarrow arrays. This is more efficient as it performs
66-
/// a zero-copy of the contents.
67-
fn to_scalar_function_impl(func: Py<PyAny>) -> ScalarFunctionImplementation {
68-
// Make the python function callable from rust
69-
let pyarrow_func = pyarrow_function_to_rust(func);
70-
71-
// Convert input/output from datafusion ColumnarValue to arrow arrays
72-
Arc::new(move |args: &[ColumnarValue]| {
73-
let array_refs = ColumnarValue::values_to_arrays(args)?;
74-
let array_result = pyarrow_func(&array_refs)?;
75-
Ok(array_result.into())
76-
})
77-
}
78-
79171
/// Represents a PyScalarUDF
80172
#[pyclass(frozen, name = "ScalarUDF", module = "datafusion", subclass)]
81173
#[derive(Debug, Clone)]
@@ -88,19 +180,21 @@ impl PyScalarUDF {
88180
#[new]
89181
#[pyo3(signature=(name, func, input_types, return_type, volatility))]
90182
fn new(
91-
name: &str,
183+
name: String,
92184
func: Py<PyAny>,
93-
input_types: PyArrowType<Vec<DataType>>,
94-
return_type: PyArrowType<DataType>,
185+
input_types: PyArrowType<Vec<Field>>,
186+
return_type: PyArrowType<Field>,
95187
volatility: &str,
96188
) -> PyResult<Self> {
97-
let function = create_udf(
189+
let py_function = PythonFunctionScalarUDF::new(
98190
name,
191+
func,
99192
input_types.0,
100193
return_type.0,
101194
parse_volatility(volatility)?,
102-
to_scalar_function_impl(func),
103195
);
196+
let function = ScalarUDF::new_from_impl(py_function);
197+
104198
Ok(Self { function })
105199
}
106200

0 commit comments

Comments (0)