Skip to content

Commit e1d3ce2

Browse files
committed
working through functions to register udtf
1 parent 16597db commit e1d3ce2

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

python/datafusion/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,9 +752,9 @@ def register_table_provider(
752752
"""
753753
self.ctx.register_table_provider(name, provider)
754754

755-
def register_udtf(self, name: str, func: TableFunction) -> None:
755+
def register_udtf(self, func: TableFunction) -> None:
756756
"""Register a user defined table function."""
757-
self.ctx.register_udtf(name, func._udtf)
757+
self.ctx.register_udtf(func._udtf)
758758

759759
def register_record_batches(
760760
self, name: str, partitions: list[list[pa.RecordBatch]]

python/datafusion/udf.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -777,12 +777,12 @@ def __init__(
777777
See :py:func:`udtf` for a convenience function and argument
778778
descriptions.
779779
"""
780-
self._udtf = df_internal.user_defined.TableFunction(name, func)
780+
self._udtf = df_internal.TableFunction(name, func)
781781

782782
def __call__(self, *args: Expr) -> Any:
783783
"""Execute the UDTF and return a table provider."""
784784
args_raw = [arg.expr for arg in args]
785-
return Expr(self._udtf.__call__(*args_raw))
785+
return self._udtf.__call__(*args_raw)
786786

787787
@overload
788788
@staticmethod
@@ -803,6 +803,8 @@ def udtf(*args: Any, **kwargs: Any):
803803
if args and callable(args[0]):
804804
# Case 1: Used as a function, require the first parameter to be callable
805805
return TableFunction._create_table_udf(*args, **kwargs)
806+
if args and hasattr(args[0], "__datafusion_table_function__"):
807+
return TableFunction(args[1], args[0])
806808
# Case 2: Used as a decorator with parameters
807809
return TableFunction._create_table_udf_decorator(*args, **kwargs)
808810

0 commit comments

Comments
 (0)