Skip to content

Commit c6415f8

Browse files
committed
Adding Python wrapper classes for FFI scalar UDF
1 parent e20884d commit c6415f8

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

python/datafusion/udf.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from abc import ABCMeta, abstractmethod
2323
from enum import Enum
24-
from typing import TYPE_CHECKING, Callable, List, Optional, TypeVar
24+
from typing import TYPE_CHECKING, Callable, List, Optional, Protocol, TypeVar
2525

2626
import pyarrow
2727

@@ -76,6 +76,15 @@ def __str__(self):
7676
return self.name.lower()
7777

7878

79+
class ScalarUDFExportable(Protocol):
80+
"""Type hint for an object that exposes a ``__datafusion_scalar_udf__`` PyCapsule.
81+
82+
https://datafusion.apache.org/python/user-guide/common-operations/udf-and-udfa.html
83+
"""
84+
85+
def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105
86+
87+
7988
class ScalarUDF:
8089
"""Class for performing scalar user-defined functions (UDF).
8190
@@ -86,20 +95,23 @@ class ScalarUDF:
8695
def __init__(
8796
self,
8897
name: Optional[str],
89-
func: Callable[..., _R],
90-
input_types: pyarrow.DataType | list[pyarrow.DataType],
91-
return_type: _R,
92-
volatility: Volatility | str,
98+
func: Callable[..., _R] | df_internal.ScalarUDF,
99+
input_types: pyarrow.DataType | list[pyarrow.DataType] | None,
100+
return_type: Optional[_R],
101+
volatility: Volatility | str | None,
93102
) -> None:
94103
"""Instantiate a scalar user-defined function (UDF).
95104
96105
See helper method :py:func:`udf` for argument details.
97106
"""
98-
if isinstance(input_types, pyarrow.DataType):
99-
input_types = [input_types]
100-
self._udf = df_internal.ScalarUDF(
101-
name, func, input_types, return_type, str(volatility)
102-
)
107+
if isinstance(func, df_internal.ScalarUDF):
108+
self._udf = func
109+
else:
110+
if isinstance(input_types, pyarrow.DataType):
111+
input_types = [input_types]
112+
self._udf = df_internal.ScalarUDF(
113+
name, func, input_types, return_type, str(volatility)
114+
)
103115

104116
def __call__(self, *args: Expr) -> Expr:
105117
"""Execute the UDF.
@@ -110,6 +122,12 @@ def __call__(self, *args: Expr) -> Expr:
110122
args_raw = [arg.expr for arg in args]
111123
return Expr(self._udf.__call__(*args_raw))
112124

125+
@staticmethod
126+
def from_ffi(func: ScalarUDFExportable) -> ScalarUDF:
127+
"""Create a scalar user-defined function (UDF) from an object that exports a ``__datafusion_scalar_udf__`` PyCapsule."""
128+
udf = df_internal.ScalarUDF.from_ffi(func)
129+
return ScalarUDF(None, udf, None, None, None)
130+
113131
@staticmethod
114132
def udf(
115133
func: Callable[..., _R],

0 commit comments

Comments
 (0)