Skip to content

Commit 68960e3

Browse files
committed
working through functions to register udtf
1 parent b534319 commit 68960e3

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

python/datafusion/context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -752,9 +752,9 @@ def register_table_provider(
752752
"""
753753
self.ctx.register_table_provider(name, provider)
754754

755-
def register_udtf(self, name: str, func: TableFunction) -> None:
755+
def register_udtf(self, func: TableFunction) -> None:
756756
"""Register a user defined table function."""
757-
self.ctx.register_udtf(name, func._udtf)
757+
self.ctx.register_udtf(func._udtf)
758758

759759
def register_record_batches(
760760
self, name: str, partitions: list[list[pa.RecordBatch]]

python/datafusion/user_defined.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -747,7 +747,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
747747

748748
return decorator
749749

750-
751750
class TableFunction:
752751
"""Class for performing user-defined table functions (UDTF).
753752
@@ -756,33 +755,33 @@ class TableFunction:
756755
"""
757756

758757
def __init__(
759-
self,
760-
name: str,
761-
func: Callable[[], any],
758+
self,
759+
name: str,
760+
func: Callable[[], any],
762761
) -> None:
763762
"""Instantiate a user-defined table function (UDTF).
764763
765764
See :py:func:`udtf` for a convenience function and argument
766765
descriptions.
767766
"""
768-
self._udtf = df_internal.user_defined.TableFunction(name, func)
767+
self._udtf = df_internal.TableFunction(name, func)
769768

770769
def __call__(self, *args: Expr) -> Any:
771770
"""Execute the UDTF and return a table provider."""
772771
args_raw = [arg.expr for arg in args]
773-
return Expr(self._udtf.__call__(*args_raw))
772+
return self._udtf.__call__(*args_raw)
774773

775774
@overload
776775
@staticmethod
777776
def udtf(
778-
name: str,
777+
name: str,
779778
) -> Callable[..., Any]: ...
780779

781780
@overload
782781
@staticmethod
783782
def udtf(
784-
func: Callable[[], Any],
785-
name: str,
783+
func: Callable[[], Any],
784+
name: str,
786785
) -> TableFunction: ...
787786

788787
@staticmethod
@@ -791,13 +790,15 @@ def udtf(*args: Any, **kwargs: Any):
791790
if args and callable(args[0]):
792791
# Case 1: Used as a function, require the first parameter to be callable
793792
return TableFunction._create_table_udf(*args, **kwargs)
793+
if args and hasattr(args[0], "__datafusion_table_function__"):
794+
return TableFunction(args[1], args[0])
794795
# Case 2: Used as a decorator with parameters
795796
return TableFunction._create_table_udf_decorator(*args, **kwargs)
796797

797798
@staticmethod
798799
def _create_table_udf(
799-
func: Callable[..., Any],
800-
name: str,
800+
func: Callable[..., Any],
801+
name: str,
801802
) -> TableFunction:
802803
"""Create a TableFunction instance from function arguments."""
803804
if not callable(func):
@@ -810,7 +811,6 @@ def __repr__(self) -> str:
810811
"""User printable representation."""
811812
return self._udtf.__repr__()
812813

813-
814814
# Convenience exports so we can import instead of treating as
815815
# variables at the package root
816816
udf = ScalarUDF.udf

0 commit comments

Comments
 (0)