Skip to content

Commit 6c7e55b

Browse files
committed
Add ArrowArrayExportable class and use it to create pyarrow arrays for python UDFs
1 parent 7067a9c commit 6c7e55b

File tree

6 files changed

+117
-77
lines changed

6 files changed

+117
-77
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ dev = [
141141
"maturin>=1.8.1",
142142
"numpy>1.25.0;python_version<'3.14'",
143143
"numpy>=2.3.2;python_version>='3.14'",
144+
"pyarrow>=19.0.0",
144145
"pre-commit>=4.3.0",
145146
"pyyaml>=6.0.3",
146147
"pytest>=7.4.4",

python/tests/test_udf.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pyarrow as pa
1919
import pytest
2020
from datafusion import column, udf
21+
from datafusion import functions as f
2122

2223

2324
@pytest.fixture
@@ -124,3 +125,26 @@ def udf_with_param(values: pa.Array) -> pa.Array:
124125
result = df2.collect()[0].column(0)
125126

126127
assert result == pa.array([False, True, True])
128+
129+
130+
def test_udf_with_metadata(ctx) -> None:
    """Round-trip extension-typed (arrow.uuid) arrays through Python UDFs.

    Verifies that field metadata (the uuid extension type) survives the
    Rust -> pyarrow conversion of UDF arguments: ``uuid_version`` only
    receives proper ``uuid.UUID`` scalars if the extension type reaches
    pyarrow intact.
    """
    from uuid import UUID

    @udf([pa.string()], pa.uuid(), "stable")
    def uuid_from_string(uuid_string):
        # Parse each string into the 16-byte canonical form expected by
        # the uuid storage type.
        return pa.array((UUID(s).bytes for s in uuid_string.to_pylist()), pa.uuid())

    @udf([pa.uuid()], pa.int64(), "stable")
    def uuid_version(uuid):
        # to_pylist() yields uuid.UUID objects only when the extension
        # type metadata was preserved end-to-end.
        return pa.array(s.version for s in uuid.to_pylist())

    batch = pa.record_batch({"idx": pa.array(range(5))})
    results = (
        ctx.create_dataframe([[batch]])
        .with_column("uuid_string", f.uuid())
        .with_column("uuid", uuid_from_string(column("uuid_string")))
        # Alias the UDF *result* (not its input column) so the output
        # column is named "uuid_version".
        .select(uuid_version(column("uuid")).alias("uuid_version"))
        .collect()
    )

    # f.uuid() generates random UUIDs, which are always version 4.
    assert results[0][0].to_pylist() == [4, 4, 4, 4, 4]

src/array.rs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{Array, ArrayRef};
21+
use arrow::datatypes::{Field, FieldRef};
22+
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
23+
use arrow::pyarrow::ToPyArrow;
24+
use pyo3::prelude::{PyAnyMethods, PyCapsuleMethods};
25+
use pyo3::types::PyCapsule;
26+
use pyo3::{pyclass, pymethods, Bound, PyAny, PyResult, Python};
27+
28+
use crate::errors::PyDataFusionResult;
29+
use crate::utils::validate_pycapsule;
30+
31+
/// A Python object which implements the Arrow PyCapsule interface
/// (`__arrow_c_array__`) for importing the wrapped array into other
/// Arrow-aware Python libraries (e.g. pyarrow) via the C data interface.
#[pyclass(name = "ArrowArrayExportable", module = "datafusion", frozen)]
#[derive(Clone)]
pub struct PyArrowArrayExportable {
    // The Arrow array data to export.
    array: ArrayRef,
    // Field describing the array: name, data type, nullability and
    // metadata (e.g. extension-type annotations that a bare DataType
    // would lose).
    field: FieldRef,
}
39+
40+
#[pymethods]
impl PyArrowArrayExportable {
    /// Arrow PyCapsule interface: export the array and its schema as the
    /// pair of capsules (`"arrow_schema"`, `"arrow_array"`) defined by the
    /// Arrow C data interface, so Python libraries can import the data.
    ///
    /// `requested_schema`, if given, must be an `"arrow_schema"` capsule;
    /// the field it describes is exported in place of `self.field`.
    #[pyo3(signature = (requested_schema=None))]
    fn __arrow_c_array__<'py>(
        &'py self,
        py: Python<'py>,
        requested_schema: Option<Bound<'py, PyCapsule>>,
    ) -> PyDataFusionResult<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>)> {
        let field = if let Some(schema_capsule) = requested_schema {
            validate_pycapsule(&schema_capsule, "arrow_schema")?;

            // SAFETY: validate_pycapsule confirmed the capsule is named
            // "arrow_schema", which per the Arrow PyCapsule interface
            // means it holds a valid FFI_ArrowSchema.
            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
            let desired_field = Field::try_from(schema_ptr)?;

            Arc::new(desired_field)
        } else {
            Arc::clone(&self.field)
        };

        // NOTE(review): when a requested_schema is supplied, the requested
        // field is exported but the array buffers are exported unchanged —
        // no cast is performed. Confirm callers only ever request the
        // array's own schema, or the exported pair may be inconsistent.
        let ffi_schema = FFI_ArrowSchema::try_from(&field)?;
        let schema_capsule = PyCapsule::new(py, ffi_schema, Some(cr"arrow_schema".into()))?;

        let ffi_array = FFI_ArrowArray::new(&self.array.to_data());
        let array_capsule = PyCapsule::new(py, ffi_array, Some(cr"arrow_array".into()))?;

        Ok((schema_capsule, array_capsule))
    }
}
68+
69+
impl ToPyArrow for PyArrowArrayExportable {
70+
fn to_pyarrow<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
71+
let module = py.import("pyarrow")?;
72+
let method = module.getattr("array")?;
73+
let array = method.call((self.clone(),), None)?;
74+
Ok(array)
75+
}
76+
}
77+
78+
impl PyArrowArrayExportable {
79+
pub fn new(array: ArrayRef, field: FieldRef) -> Self {
80+
Self { array, field }
81+
}
82+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ pub mod store;
5252
pub mod table;
5353
pub mod unparser;
5454

55+
mod array;
5556
#[cfg(feature = "substrait")]
5657
pub mod substrait;
5758
#[allow(clippy::borrow_deref_ref)]

src/udf.rs

Lines changed: 6 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717

1818
use std::any::Any;
1919
use std::hash::{Hash, Hasher};
20-
use std::ptr::addr_of;
2120
use std::sync::Arc;
2221

2322
use arrow::datatypes::{Field, FieldRef};
24-
use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema};
25-
use datafusion::arrow::array::{make_array, Array, ArrayData, ArrayRef};
23+
use arrow::pyarrow::ToPyArrow;
24+
use datafusion::arrow::array::{make_array, ArrayData};
2625
use datafusion::arrow::datatypes::DataType;
2726
use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType};
2827
use datafusion::error::DataFusionError;
@@ -31,10 +30,10 @@ use datafusion::logical_expr::{
3130
Volatility,
3231
};
3332
use datafusion_ffi::udf::FFI_ScalarUDF;
34-
use pyo3::ffi::Py_uintptr_t;
3533
use pyo3::prelude::*;
3634
use pyo3::types::{PyCapsule, PyTuple};
3735

36+
use crate::array::PyArrowArrayExportable;
3837
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
3938
use crate::expr::PyExpr;
4039
use crate::utils::{parse_volatility, validate_pycapsule};
@@ -92,26 +91,6 @@ impl Hash for PythonFunctionScalarUDF {
9291
}
9392
}
9493

95-
fn array_to_pyarrow_with_field(
96-
py: Python,
97-
array: ArrayRef,
98-
field: &FieldRef,
99-
) -> PyResult<Py<PyAny>> {
100-
let array = FFI_ArrowArray::new(&array.to_data());
101-
let schema = FFI_ArrowSchema::try_from(field).map_err(py_datafusion_err)?;
102-
103-
let module = py.import("pyarrow")?;
104-
let class = module.getattr("Array")?;
105-
let array = class.call_method1(
106-
"_import_from_c",
107-
(
108-
addr_of!(array) as Py_uintptr_t,
109-
addr_of!(schema) as Py_uintptr_t,
110-
),
111-
)?;
112-
Ok(array.unbind())
113-
}
114-
11594
impl ScalarUDFImpl for PythonFunctionScalarUDF {
11695
fn as_any(&self) -> &dyn Any {
11796
self
@@ -149,7 +128,9 @@ impl ScalarUDFImpl for PythonFunctionScalarUDF {
149128
.zip(args.arg_fields)
150129
.map(|(arg, field)| {
151130
let array = arg.to_array(num_rows)?;
152-
array_to_pyarrow_with_field(py, array, &field).map_err(to_datafusion_err)
131+
PyArrowArrayExportable::new(array, field)
132+
.to_pyarrow(py)
133+
.map_err(to_datafusion_err)
153134
})
154135
.collect::<Result<Vec<_>, _>>()?;
155136
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

uv.lock

Lines changed: 3 additions & 52 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)