@@ -124,8 +124,8 @@ def __init__(
124124 self ,
125125 name : str ,
126126 func : Callable [..., _R ],
127- input_types : list [pa .Field ],
128- return_type : _R ,
127+ input_fields : list [pa .Field ],
128+ return_field : _R ,
129129 volatility : Volatility | str ,
130130 ) -> None :
131131 """Instantiate a scalar user-defined function (UDF).
@@ -135,10 +135,10 @@ def __init__(
135135 if hasattr (func , "__datafusion_scalar_udf__" ):
136136 self ._udf = df_internal .ScalarUDF .from_pycapsule (func )
137137 return
138- if isinstance (input_types , pa .DataType ):
139- input_types = [input_types ]
138+ if isinstance (input_fields , pa .DataType ):
139+ input_fields = [input_fields ]
140140 self ._udf = df_internal .ScalarUDF (
141- name , func , input_types , return_type , str (volatility )
141+ name , func , input_fields , return_field , str (volatility )
142142 )
143143
144144 def __repr__ (self ) -> str :
@@ -157,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr:
157157 @overload
158158 @staticmethod
159159 def udf (
160- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
161- return_type : pa .DataType | pa .Field ,
160+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
161+ return_field : pa .DataType | pa .Field ,
162162 volatility : Volatility | str ,
163163 name : str | None = None ,
164164 ) -> Callable [..., ScalarUDF ]: ...
@@ -167,8 +167,8 @@ def udf(
167167 @staticmethod
168168 def udf (
169169 func : Callable [..., _R ],
170- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
171- return_type : pa .DataType | pa .Field ,
170+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
171+ return_field : pa .DataType | pa .Field ,
172172 volatility : Volatility | str ,
173173 name : str | None = None ,
174174 ) -> ScalarUDF : ...
@@ -194,10 +194,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
194194 backed ScalarUDF within a PyCapsule, you can pass this parameter
195195 and ignore the rest. They will be determined directly from the
196196 underlying function. See the online documentation for more information.
197- input_types (list[pa.DataType]): The data types of the arguments
198- to ``func``. This list must be of the same length as the number of
199- arguments.
200- return_type (_R): The data type of the return value from the function.
197+ input_fields (list[pa.Field | pa. DataType]): The data types or Fields
198+ of the arguments to ``func``. This list must be of the same length
199+ as the number of arguments.
200+ return_field (_R): The field of the return value from the function.
201201 volatility (Volatility | str): See `Volatility` for allowed values.
202202 name (Optional[str]): A descriptive name for the function.
203203
@@ -221,8 +221,8 @@ def double_udf(x):
221221
222222 def _function (
223223 func : Callable [..., _R ],
224- input_types : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
225- return_type : pa .DataType | pa .Field ,
224+ input_fields : Sequence [pa .DataType | pa .Field ] | pa .DataType | pa .Field ,
225+ return_field : pa .DataType | pa .Field ,
226226 volatility : Volatility | str ,
227227 name : str | None = None ,
228228 ) -> ScalarUDF :
@@ -234,25 +234,25 @@ def _function(
234234 name = func .__qualname__ .lower ()
235235 else :
236236 name = func .__class__ .__name__ .lower ()
237- input_types = data_types_or_fields_to_field_list (input_types )
238- return_type = data_type_or_field_to_field (return_type , "value" )
237+ input_fields = data_types_or_fields_to_field_list (input_fields )
238+ return_field = data_type_or_field_to_field (return_field , "value" )
239239 return ScalarUDF (
240240 name = name ,
241241 func = func ,
242- input_types = input_types ,
243- return_type = return_type ,
242+ input_fields = input_fields ,
243+ return_field = return_field ,
244244 volatility = volatility ,
245245 )
246246
247247 def _decorator (
248- input_types : list [pa .DataType ],
249- return_type : _R ,
248+ input_fields : list [pa .DataType ],
249+ return_field : _R ,
250250 volatility : Volatility | str ,
251251 name : str | None = None ,
252252 ) -> Callable :
253253 def decorator (func : Callable ) -> Callable :
254254 udf_caller = ScalarUDF .udf (
255- func , input_types , return_type , volatility , name
255+ func , input_fields , return_field , volatility , name
256256 )
257257
258258 @functools .wraps (func )
@@ -283,8 +283,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
283283 return ScalarUDF (
284284 name = name ,
285285 func = func ,
286- input_types = None ,
287- return_type = None ,
286+ input_fields = None ,
287+ return_field = None ,
288288 volatility = None ,
289289 )
290290
0 commit comments