Skip to content

Commit aa7d35c

Browse files
committed
Update naming from type to field where appropriate
1 parent 77ae1f8 commit aa7d35c

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

python/datafusion/user_defined.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,8 @@ def __init__(
124124
self,
125125
name: str,
126126
func: Callable[..., _R],
127-
input_types: list[pa.Field],
128-
return_type: _R,
127+
input_fields: list[pa.Field],
128+
return_field: _R,
129129
volatility: Volatility | str,
130130
) -> None:
131131
"""Instantiate a scalar user-defined function (UDF).
@@ -135,10 +135,10 @@ def __init__(
135135
if hasattr(func, "__datafusion_scalar_udf__"):
136136
self._udf = df_internal.ScalarUDF.from_pycapsule(func)
137137
return
138-
if isinstance(input_types, pa.DataType):
139-
input_types = [input_types]
138+
if isinstance(input_fields, pa.DataType):
139+
input_fields = [input_fields]
140140
self._udf = df_internal.ScalarUDF(
141-
name, func, input_types, return_type, str(volatility)
141+
name, func, input_fields, return_field, str(volatility)
142142
)
143143

144144
def __repr__(self) -> str:
@@ -157,8 +157,8 @@ def __call__(self, *args: Expr) -> Expr:
157157
@overload
158158
@staticmethod
159159
def udf(
160-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
161-
return_type: pa.DataType | pa.Field,
160+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
161+
return_field: pa.DataType | pa.Field,
162162
volatility: Volatility | str,
163163
name: str | None = None,
164164
) -> Callable[..., ScalarUDF]: ...
@@ -167,8 +167,8 @@ def udf(
167167
@staticmethod
168168
def udf(
169169
func: Callable[..., _R],
170-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
171-
return_type: pa.DataType | pa.Field,
170+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
171+
return_field: pa.DataType | pa.Field,
172172
volatility: Volatility | str,
173173
name: str | None = None,
174174
) -> ScalarUDF: ...
@@ -194,10 +194,10 @@ def udf(*args: Any, **kwargs: Any): # noqa: D417
194194
backed ScalarUDF within a PyCapsule, you can pass this parameter
195195
and ignore the rest. They will be determined directly from the
196196
underlying function. See the online documentation for more information.
197-
input_types (list[pa.DataType]): The data types of the arguments
198-
to ``func``. This list must be of the same length as the number of
199-
arguments.
200-
return_type (_R): The data type of the return value from the function.
197+
input_fields (list[pa.Field | pa.DataType]): The data types or Fields
198+
of the arguments to ``func``. This list must be of the same length
199+
as the number of arguments.
200+
return_field (_R): The field of the return value from the function.
201201
volatility (Volatility | str): See `Volatility` for allowed values.
202202
name (Optional[str]): A descriptive name for the function.
203203
@@ -221,8 +221,8 @@ def double_udf(x):
221221

222222
def _function(
223223
func: Callable[..., _R],
224-
input_types: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
225-
return_type: pa.DataType | pa.Field,
224+
input_fields: Sequence[pa.DataType | pa.Field] | pa.DataType | pa.Field,
225+
return_field: pa.DataType | pa.Field,
226226
volatility: Volatility | str,
227227
name: str | None = None,
228228
) -> ScalarUDF:
@@ -234,25 +234,25 @@ def _function(
234234
name = func.__qualname__.lower()
235235
else:
236236
name = func.__class__.__name__.lower()
237-
input_types = data_types_or_fields_to_field_list(input_types)
238-
return_type = data_type_or_field_to_field(return_type, "value")
237+
input_fields = data_types_or_fields_to_field_list(input_fields)
238+
return_field = data_type_or_field_to_field(return_field, "value")
239239
return ScalarUDF(
240240
name=name,
241241
func=func,
242-
input_types=input_types,
243-
return_type=return_type,
242+
input_fields=input_fields,
243+
return_field=return_field,
244244
volatility=volatility,
245245
)
246246

247247
def _decorator(
248-
input_types: list[pa.DataType],
249-
return_type: _R,
248+
input_fields: list[pa.DataType],
249+
return_field: _R,
250250
volatility: Volatility | str,
251251
name: str | None = None,
252252
) -> Callable:
253253
def decorator(func: Callable) -> Callable:
254254
udf_caller = ScalarUDF.udf(
255-
func, input_types, return_type, volatility, name
255+
func, input_fields, return_field, volatility, name
256256
)
257257

258258
@functools.wraps(func)
@@ -283,8 +283,8 @@ def from_pycapsule(func: ScalarUDFExportable) -> ScalarUDF:
283283
return ScalarUDF(
284284
name=name,
285285
func=func,
286-
input_types=None,
287-
return_type=None,
286+
input_fields=None,
287+
return_field=None,
288288
volatility=None,
289289
)
290290

0 commit comments

Comments
 (0)