Skip to content

Commit da824e6

Browse files
committed
Add overloads for AggregateUDF.__init__ to support different initialization signatures
1 parent f157c37 commit da824e6

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

python/datafusion/user_defined.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,30 @@ class AggregateUDF:
299299
also :py:class:`ScalarUDF` for operating on a row by row basis.
300300
"""
301301

302+
@overload
303+
def __init__(
304+
self,
305+
name: str,
306+
accumulator: Callable[[], Accumulator],
307+
input_types: list[pa.DataType],
308+
return_type: pa.DataType,
309+
state_type: list[pa.DataType],
310+
volatility: Volatility | str,
311+
) -> None:
312+
...
313+
314+
@overload
315+
def __init__(
316+
self,
317+
name: str,
318+
accumulator: AggregateUDFExportable,
319+
input_types: None = ...,
320+
return_type: None = ...,
321+
state_type: None = ...,
322+
volatility: None = ...,
323+
) -> None:
324+
...
325+
302326
def __init__(
303327
self,
304328
name: str,

0 commit comments

Comments
 (0)