Skip to content

Commit 345897c

Browse files
committed
Add integration test with scalar udf
1 parent 0db0bf6 commit 345897c

File tree

9 files changed

+306
-140
lines changed

9 files changed

+306
-140
lines changed

examples/ffi-library/Cargo.lock

Lines changed: 177 additions & 130 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

examples/ffi-library/Cargo.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,18 @@ edition = "2021"
2424
datafusion = { version = "45.0.0" }
2525
datafusion-ffi = { version = "45.0.0" }
2626
pyo3 = { version = "0.23", features = ["extension-module", "abi3", "abi3-py38"] }
27-
arrow = { version = "54" }
28-
arrow-array = { version = "54" }
29-
arrow-schema = { version = "54" }
27+
arrow = { version = "54.2.0" }
28+
arrow-array = { version = "54.2.0" }
29+
arrow-schema = { version = "54.2.0" }
3030

3131
[build-dependencies]
3232
pyo3-build-config = "0.23"
3333

3434
[lib]
3535
name = "datafusion_ffi_library"
3636
crate-type = ["cdylib", "rlib"]
37+
38+
# TODO remove once we update datafusion versions to 46
39+
[patch.crates-io]
40+
datafusion = { git = "https://github.com/apache/datafusion.git", rev = "8ab0661a39bd69783b31b949e7a768fb518629e7", features = ["avro", "unicode_expressions"] }
41+
datafusion-ffi = { git = "https://github.com/apache/datafusion.git", rev = "8ab0661a39bd69783b31b949e7a768fb518629e7" }
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pyarrow as pa
19+
from datafusion import SessionContext, col, ffi_udf
20+
from datafusion_ffi_library import IsEvenFunction
21+
22+
23+
def test_table_loading():
24+
ctx = SessionContext()
25+
df = ctx.from_pydict({"a": [-3, -2, None, 0, 1, 2]})
26+
27+
is_even = ffi_udf(IsEvenFunction())
28+
29+
result = df.select(is_even(col("a"))).collect()
30+
df.with_column("is_even", is_even(col("a"))).show()
31+
print(result)
32+
33+
assert len(result) == 1
34+
assert result[0].num_columns == 1
35+
36+
result = [r.column(0) for r in result]
37+
expected = [
38+
pa.array([False, True, None, None, False, True], type=pa.bool_()),
39+
]
40+
41+
assert result == expected

examples/ffi-library/src/lib.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
// under the License.
1717

1818
use pyo3::prelude::*;
19-
use table_provider::MyTableProvider;
19+
mod scalar_udf;
2020
mod table_provider;
2121

2222
#[pymodule]
2323
fn datafusion_ffi_library(m: &Bound<'_, PyModule>) -> PyResult<()> {
24-
m.add_class::<MyTableProvider>()?;
24+
m.add_class::<table_provider::MyTableProvider>()?;
25+
m.add_class::<scalar_udf::IsEvenFunction>()?;
26+
2527
Ok(())
2628
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::{ffi::CString, sync::Arc};
19+
20+
use arrow::array::BooleanArray;
21+
use arrow_array::ArrayRef;
22+
use datafusion::common::cast::as_int64_array;
23+
use datafusion::logical_expr::create_udf;
24+
use datafusion::logical_expr::Volatility;
25+
use datafusion::physical_plan::ColumnarValue;
26+
use datafusion::{arrow::datatypes::DataType, error::Result};
27+
use datafusion_ffi::udf::FFI_ScalarUDF;
28+
use pyo3::{prelude::*, types::PyCapsule};
29+
30+
#[pyclass(name = "IsEvenFunction", module = "datafusion_ffi_library", subclass)]
31+
#[derive(Clone)]
32+
pub struct IsEvenFunction {}
33+
34+
fn is_even(args: &[ColumnarValue]) -> Result<ColumnarValue> {
35+
assert_eq!(args.len(), 1);
36+
let args = ColumnarValue::values_to_arrays(args)?;
37+
38+
let values = as_int64_array(&args[0]).expect("cast failed");
39+
40+
let array = values
41+
.iter()
42+
.map(|value| value.and_then(|v| if v == 0 { None } else { Some(v % 2 == 0) }))
43+
.collect::<BooleanArray>();
44+
45+
Ok(ColumnarValue::from(Arc::new(array) as ArrayRef))
46+
}
47+
48+
#[pymethods]
49+
impl IsEvenFunction {
50+
#[new]
51+
fn new() -> Self {
52+
Self {}
53+
}
54+
55+
fn __datafusion_scalar_udf__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
56+
let name = CString::new("datafusion_scalar_udf").unwrap();
57+
58+
let func = create_udf(
59+
"is_even",
60+
vec![DataType::Int64],
61+
DataType::Boolean,
62+
Volatility::Immutable,
63+
Arc::new(is_even),
64+
);
65+
66+
let ffi_func: FFI_ScalarUDF = (Arc::new(func)).try_into()?;
67+
68+
PyCapsule::new(py, ffi_func, Some(name))
69+
}
70+
}

examples/ffi-library/src/table_provider.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyCapsule};
3131

3232
/// In order to provide a test that demonstrates different sized record batches,
3333
/// the first batch will have num_rows, the second batch num_rows+1, and so on.
34-
#[pyclass(name = "MyTableProvider", module = "ffi_table_provider", subclass)]
34+
#[pyclass(name = "MyTableProvider", module = "datafusion_ffi_library", subclass)]
3535
#[derive(Clone)]
3636
pub struct MyTableProvider {
3737
num_cols: usize,

python/datafusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ def lit(value):
118118

119119

120120
udf = ScalarUDF.udf
121+
ffi_udf = ScalarUDF.ffi_udf
121122

122123
udaf = AggregateUDF.udaf
123124

python/datafusion/udf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def __call__(self, *args: Expr) -> Expr:
123123
return Expr(self._udf.__call__(*args_raw))
124124

125125
@staticmethod
126-
def from_ffi(func: ScalarUDFExportable) -> ScalarUDF:
126+
def ffi_udf(func: ScalarUDFExportable) -> ScalarUDF:
127127
"""Create a User-Defined Function from a provided PyCapsule."""
128-
udf = df_internal.ScalarUDF.from_ffi(func)
128+
udf = df_internal.ScalarUDF.ffi_udf(func)
129129
return ScalarUDF(None, udf, None, None, None)
130130

131131
@staticmethod

src/udf.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ impl PyScalarUDF {
108108
}
109109

110110
#[staticmethod]
111-
fn from_ffi(func: Bound<PyAny>) -> PyResult<Self> {
111+
fn ffi_udf(func: Bound<PyAny>) -> PyResult<Self> {
112112
if func.hasattr("__datafusion_scalar_udf__")? {
113113
let capsule = func.getattr("__datafusion_scalar_udf__")?.call0()?;
114114
let capsule = capsule.downcast::<PyCapsule>()?;
@@ -122,7 +122,7 @@ impl PyScalarUDF {
122122
})
123123
} else {
124124
Err(py_datafusion_err(
125-
"__datafusion_table_provider__ does not exist on Table Provider object.",
125+
"__datafusion_scalar_udf__ does not exist on Scalar UDF object.",
126126
))
127127
}
128128
}

0 commit comments

Comments
 (0)