2222import functools
2323from abc import ABCMeta , abstractmethod
2424from enum import Enum
25- from typing import Any , Callable , Optional , Protocol , Sequence , overload
25+ from typing import TYPE_CHECKING , Any , Callable , Optional , Protocol , TypeVar , overload
2626
2727import pyarrow as pa
2828
2929import datafusion ._internal as df_internal
3030from datafusion .expr import Expr
3131
32+ if TYPE_CHECKING :
33+ _R = TypeVar ("_R" , bound = pa .DataType )
34+
35+
3236class Volatility (Enum ):
3337 """Defines how stable or volatile a function is.
3438
@@ -73,40 +77,6 @@ def __str__(self) -> str:
7377 return self .name .lower ()
7478
7579
76- def _normalize_field (value : pa .DataType | pa .Field , * , default_name : str ) -> pa .Field :
77- if isinstance (value , pa .Field ):
78- return value
79- if isinstance (value , pa .DataType ):
80- return pa .field (default_name , value )
81- msg = "Expected a pyarrow.DataType or pyarrow.Field"
82- raise TypeError (msg )
83-
84-
85- def _normalize_input_fields (
86- values : pa .DataType | pa .Field | Sequence [pa .DataType | pa .Field ],
87- ) -> list [pa .Field ]:
88- if isinstance (values , (pa .DataType , pa .Field )):
89- sequence : Sequence [pa .DataType | pa .Field ] = [values ]
90- elif isinstance (values , Sequence ) and not isinstance (values , (str , bytes )):
91- sequence = values
92- else :
93- msg = "input_types must be a DataType, Field, or a sequence of them"
94- raise TypeError (msg )
95-
96- return [
97- _normalize_field (value , default_name = f"arg_{ idx } " ) for idx , value in enumerate (sequence )
98- ]
99-
100-
101- def _normalize_return_field (
102- value : pa .DataType | pa .Field ,
103- * ,
104- name : str ,
105- ) -> pa .Field :
106- default_name = f"{ name } _result" if name else "result"
107- return _normalize_field (value , default_name = default_name )
108-
109-
11080class ScalarUDFExportable (Protocol ):
11181 """Type hint for object that has __datafusion_scalar_udf__ PyCapsule."""
11282
@@ -123,9 +93,9 @@ class ScalarUDF:
12393 def __init__ (
12494 self ,
12595 name : str ,
126- func : Callable [..., Any ],
127- input_types : pa .DataType | pa . Field | Sequence [pa .DataType | pa . Field ],
128- return_type : pa . DataType | pa . Field ,
96+ func : Callable [..., _R ],
97+ input_types : pa .DataType | list [pa .DataType ],
98+ return_type : _R ,
12999 volatility : Volatility | str ,
130100 ) -> None :
131101 """Instantiate a scalar user-defined function (UDF).
@@ -135,10 +105,10 @@ def __init__(
135105 if hasattr (func , "__datafusion_scalar_udf__" ):
136106 self ._udf = df_internal .ScalarUDF .from_pycapsule (func )
137107 return
138- normalized_inputs = _normalize_input_fields (input_types )
139- normalized_return = _normalize_return_field ( return_type , name = name )
108+ if isinstance (input_types , pa . DataType ):
109+ input_types = [ input_types ]
140110 self ._udf = df_internal .ScalarUDF (
141- name , func , normalized_inputs , normalized_return , str (volatility )
111+ name , func , input_types , return_type , str (volatility )
142112 )
143113
144114 def __repr__ (self ) -> str :
@@ -157,18 +127,18 @@ def __call__(self, *args: Expr) -> Expr:
157127 @overload
158128 @staticmethod
159129 def udf (
160- input_types : list [pa .DataType | pa . Field ],
161- return_type : pa . DataType | pa . Field ,
130+ input_types : list [pa .DataType ],
131+ return_type : _R ,
162132 volatility : Volatility | str ,
163133 name : Optional [str ] = None ,
164134 ) -> Callable [..., ScalarUDF ]: ...
165135
166136 @overload
167137 @staticmethod
168138 def udf (
169- func : Callable [..., Any ],
170- input_types : list [pa .DataType | pa . Field ],
171- return_type : pa . DataType | pa . Field ,
139+ func : Callable [..., _R ],
140+ input_types : list [pa .DataType ],
141+ return_type : _R ,
172142 volatility : Volatility | str ,
173143 name : Optional [str ] = None ,
174144 ) -> ScalarUDF : ...
@@ -194,11 +164,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
194164 backed ScalarUDF within a PyCapsule, you can pass this parameter
195165 and ignore the rest. They will be determined directly from the
196166 underlying function. See the online documentation for more information.
197- input_types (list[pa.DataType | pa.Field]): The argument types for ``func``.
198- This list must be of the same length as the number of arguments. Pass
199- :class:`pyarrow.Field` instances to preserve extension metadata.
200- return_type (pa.DataType | pa.Field): The return type of the function. Use a
201- :class:`pyarrow.Field` to preserve metadata on extension arrays.
167+ input_types (list[pa.DataType]): The data types of the arguments
168+ to ``func``. This list must be of the same length as the number of
169+ arguments.
170+ return_type (_R): The data type of the return value from the function.
202171 volatility (Volatility | str): See `Volatility` for allowed values.
203172 name (Optional[str]): A descriptive name for the function.
204173
0 commit comments