@@ -747,7 +747,6 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
747747
748748 return decorator
749749
750-
751750class TableFunction :
752751 """Class for performing user-defined table functions (UDTF).
753752
@@ -756,33 +755,33 @@ class TableFunction:
756755 """
757756
758757 def __init__ (
759- self ,
760- name : str ,
761- func : Callable [[], any ],
758+ self ,
759+ name : str ,
760+ func : Callable [[], any ],
762761 ) -> None :
763762 """Instantiate a user-defined table function (UDTF).
764763
765764 See :py:func:`udtf` for a convenience function and argument
766765 descriptions.
767766 """
768- self ._udtf = df_internal .user_defined . TableFunction (name , func )
767+ self ._udtf = df_internal .TableFunction (name , func )
769768
770769 def __call__ (self , * args : Expr ) -> Any :
771770 """Execute the UDTF and return a table provider."""
772771 args_raw = [arg .expr for arg in args ]
773- return Expr ( self ._udtf .__call__ (* args_raw ) )
772+ return self ._udtf .__call__ (* args_raw )
774773
775774 @overload
776775 @staticmethod
777776 def udtf (
778- name : str ,
777+ name : str ,
779778 ) -> Callable [..., Any ]: ...
780779
781780 @overload
782781 @staticmethod
783782 def udtf (
784- func : Callable [[], Any ],
785- name : str ,
783+ func : Callable [[], Any ],
784+ name : str ,
786785 ) -> TableFunction : ...
787786
788787 @staticmethod
@@ -791,13 +790,15 @@ def udtf(*args: Any, **kwargs: Any):
791790 if args and callable (args [0 ]):
792791 # Case 1: Used as a function, require the first parameter to be callable
793792 return TableFunction ._create_table_udf (* args , ** kwargs )
793+ if args and hasattr (args [0 ], "__datafusion_table_function__" ):
794+ return TableFunction (args [1 ], args [0 ])
794795 # Case 2: Used as a decorator with parameters
795796 return TableFunction ._create_table_udf_decorator (* args , ** kwargs )
796797
797798 @staticmethod
798799 def _create_table_udf (
799- func : Callable [..., Any ],
800- name : str ,
800+ func : Callable [..., Any ],
801+ name : str ,
801802 ) -> TableFunction :
802803 """Create a TableFunction instance from function arguments."""
803804 if not callable (func ):
@@ -810,7 +811,6 @@ def __repr__(self) -> str:
810811 """User printable representation."""
811812 return self ._udtf .__repr__ ()
812813
813-
814814# Convenience exports so we can import instead of treating as
815815# variables at the package root
816816udf = ScalarUDF .udf
0 commit comments