2222import functools
2323from abc import ABCMeta , abstractmethod
2424from enum import Enum
25- from typing import TYPE_CHECKING , Any , Callable , Optional , Protocol , TypeVar , overload
25+ from typing import (
26+ TYPE_CHECKING ,
27+ Any ,
28+ Callable ,
29+ Optional ,
30+ Protocol ,
31+ TypeVar ,
32+ cast ,
33+ overload ,
34+ )
2635
2736import pyarrow as pa
2837
@@ -293,11 +302,11 @@ class AggregateUDF:
293302 def __init__ (
294303 self ,
295304 name : str ,
296- accumulator : Callable [[], Accumulator ],
297- input_types : list [pa .DataType ],
298- return_type : pa .DataType ,
299- state_type : list [pa .DataType ],
300- volatility : Volatility | str ,
305+ accumulator : Callable [[], Accumulator ] | AggregateUDFExportable ,
306+ input_types : list [pa .DataType ] | None ,
307+ return_type : pa .DataType | None ,
308+ state_type : list [pa .DataType ] | None ,
309+ volatility : Volatility | str | None ,
301310 ) -> None :
302311 """Instantiate a user-defined aggregate function (UDAF).
303312
@@ -307,6 +316,18 @@ def __init__(
307316 if hasattr (accumulator , "__datafusion_aggregate_udf__" ):
308317 self ._udaf = df_internal .AggregateUDF .from_pycapsule (accumulator )
309318 return
319+ if (
320+ input_types is None
321+ or return_type is None
322+ or state_type is None
323+ or volatility is None
324+ ):
325+ msg = (
326+ "`input_types`, `return_type`, `state_type`, and `volatility` "
327+ "must be provided when `accumulator` is callable."
328+ )
329+ raise TypeError (msg )
330+
310331 self ._udaf = df_internal .AggregateUDF (
311332 name ,
312333 accumulator ,
@@ -350,6 +371,14 @@ def udaf(
350371 name : Optional [str ] = None ,
351372 ) -> AggregateUDF : ...
352373
374+ @overload
375+ @staticmethod
376+ def udaf (accum : AggregateUDFExportable ) -> AggregateUDF : ...
377+
378+ @overload
379+ @staticmethod
380+ def udaf (accum : object ) -> AggregateUDF : ...
381+
353382 @staticmethod
354383 def udaf (* args : Any , ** kwargs : Any ): # noqa: D417, C901
355384 """Create a new User-Defined Aggregate Function (UDAF).
@@ -480,16 +509,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
480509 return _decorator (* args , ** kwargs )
481510
482511 @staticmethod
483- def from_pycapsule (func : AggregateUDFExportable ) -> AggregateUDF :
512+ def from_pycapsule (func : AggregateUDFExportable | object ) -> AggregateUDF :
484513 """Create an Aggregate UDF from AggregateUDF PyCapsule object.
485514
486515 This function will instantiate a Aggregate UDF that uses a DataFusion
487516 AggregateUDF that is exported via the FFI bindings.
488517 """
489- name = str (func .__class__ )
518+ capsule = cast (AggregateUDFExportable , func )
519+ name = str (capsule .__class__ )
490520 return AggregateUDF (
491521 name = name ,
492- accumulator = func ,
522+ accumulator = capsule ,
493523 input_types = None ,
494524 return_type = None ,
495525 state_type = None ,
0 commit comments