2121
2222from abc import ABCMeta , abstractmethod
2323from enum import Enum
24- from typing import TYPE_CHECKING , Callable , List , Optional , TypeVar
24+ from typing import TYPE_CHECKING , Callable , List , Optional , Protocol , TypeVar
2525
2626import pyarrow
2727
@@ -76,6 +76,15 @@ def __str__(self):
7676 return self .name .lower ()
7777
7878
79+ class ScalarUDFExportable (Protocol ):
80+ """Type hint for object that has __datafusion_scalar_udf__ PyCapsule.
81+
82+ https://datafusion.apache.org/python/user-guide/common-operations/udf-and-udfa.html
83+ """
84+
85+ def __datafusion_scalar_udf__ (self ) -> object : ... # noqa: D105
86+
87+
7988class ScalarUDF :
8089 """Class for performing scalar user-defined functions (UDF).
8190
@@ -86,20 +95,23 @@ class ScalarUDF:
8695 def __init__ (
8796 self ,
8897 name : Optional [str ],
89- func : Callable [..., _R ],
90- input_types : pyarrow .DataType | list [pyarrow .DataType ],
91- return_type : _R ,
92- volatility : Volatility | str ,
98+ func : Callable [..., _R ] | df_internal . ScalarUDF ,
99+ input_types : pyarrow .DataType | list [pyarrow .DataType ] | None ,
100+ return_type : Optional [ _R ] ,
101+ volatility : Volatility | str | None ,
93102 ) -> None :
94103 """Instantiate a scalar user-defined function (UDF).
95104
96105 See helper method :py:func:`udf` for argument details.
97106 """
98- if isinstance (input_types , pyarrow .DataType ):
99- input_types = [input_types ]
100- self ._udf = df_internal .ScalarUDF (
101- name , func , input_types , return_type , str (volatility )
102- )
107+ if isinstance (func , df_internal .ScalarUDF ):
108+ self ._udf = func
109+ else :
110+ if isinstance (input_types , pyarrow .DataType ):
111+ input_types = [input_types ]
112+ self ._udf = df_internal .ScalarUDF (
113+ name , func , input_types , return_type , str (volatility )
114+ )
103115
104116 def __call__ (self , * args : Expr ) -> Expr :
105117 """Execute the UDF.
@@ -110,6 +122,12 @@ def __call__(self, *args: Expr) -> Expr:
110122 args_raw = [arg .expr for arg in args ]
111123 return Expr (self ._udf .__call__ (* args_raw ))
112124
125+ @staticmethod
126+ def from_ffi (func : ScalarUDFExportable ) -> ScalarUDF :
127+ """Create a User-Defined Function from a provided PyCapsule."""
128+ udf = df_internal .ScalarUDF .from_ffi (func )
129+ return ScalarUDF (None , udf , None , None , None )
130+
113131 @staticmethod
114132 def udf (
115133 func : Callable [..., _R ],
0 commit comments