Skip to content

Commit e20884d

Browse files
committed
Initial commit for scalar udf pycapsule
1 parent 23be92b commit e20884d

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

src/udf.rs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
use std::sync::Arc;
1919

20+
use datafusion_ffi::udf::{FFI_ScalarUDF, ForeignScalarUDF};
21+
use pyo3::types::PyCapsule;
2022
use pyo3::{prelude::*, types::PyTuple};
2123

2224
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
@@ -28,9 +30,9 @@ use datafusion::logical_expr::function::ScalarFunctionImplementation;
2830
use datafusion::logical_expr::ScalarUDF;
2931
use datafusion::logical_expr::{create_udf, ColumnarValue};
3032

31-
use crate::errors::to_datafusion_err;
33+
use crate::errors::{py_datafusion_err, to_datafusion_err};
3234
use crate::expr::PyExpr;
33-
use crate::utils::parse_volatility;
35+
use crate::utils::{parse_volatility, validate_pycapsule};
3436

3537
/// Create a Rust callable function from a python function that expects pyarrow arrays
3638
fn pyarrow_function_to_rust(
@@ -105,6 +107,26 @@ impl PyScalarUDF {
105107
Ok(Self { function })
106108
}
107109

110+
#[staticmethod]
111+
fn from_ffi(func: Bound<PyAny>) -> PyResult<Self> {
112+
if func.hasattr("__datafusion_scalar_udf__")? {
113+
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
114+
let capsule = capsule.downcast::<PyCapsule>()?;
115+
validate_pycapsule(capsule, "datafusion_scalar_udf")?;
116+
117+
let func = unsafe { capsule.reference::<FFI_ScalarUDF>() };
118+
let func: ForeignScalarUDF = func.try_into().map_err(py_datafusion_err)?;
119+
120+
Ok(Self {
121+
function: ScalarUDF::from(func),
122+
})
123+
} else {
124+
Err(py_datafusion_err(
125+
"__datafusion_table_provider__ does not exist on Table Provider object.",
126+
))
127+
}
128+
}
129+
108130
/// creates a new PyExpr with the call of the udf
109131
#[pyo3(signature = (*args))]
110132
fn __call__(&self, args: Vec<PyExpr>) -> PyResult<PyExpr> {

0 commit comments

Comments
 (0)