Skip to content

Commit f157c37

Browse files
committed
Add fallback overload for AggregateUDF.udaf
Implement fallback for PyCapsule-backed providers, ensuring type checkers are satisfied without protocol-aware stubs. Update typing imports and cast PyCapsule inputs in AggregateUDF.from_pycapsule for precise constructor typing.
1 parent 6b16285 commit f157c37

File tree

1 file changed

+39
-9
lines changed

1 file changed

+39
-9
lines changed

python/datafusion/user_defined.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,16 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
25+
from typing import (
26+
TYPE_CHECKING,
27+
Any,
28+
Callable,
29+
Optional,
30+
Protocol,
31+
TypeVar,
32+
cast,
33+
overload,
34+
)
2635

2736
import pyarrow as pa
2837

@@ -293,11 +302,11 @@ class AggregateUDF:
293302
def __init__(
294303
self,
295304
name: str,
296-
accumulator: Callable[[], Accumulator],
297-
input_types: list[pa.DataType],
298-
return_type: pa.DataType,
299-
state_type: list[pa.DataType],
300-
volatility: Volatility | str,
305+
accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
306+
input_types: list[pa.DataType] | None,
307+
return_type: pa.DataType | None,
308+
state_type: list[pa.DataType] | None,
309+
volatility: Volatility | str | None,
301310
) -> None:
302311
"""Instantiate a user-defined aggregate function (UDAF).
303312
@@ -307,6 +316,18 @@ def __init__(
307316
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
308317
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
309318
return
319+
if (
320+
input_types is None
321+
or return_type is None
322+
or state_type is None
323+
or volatility is None
324+
):
325+
msg = (
326+
"`input_types`, `return_type`, `state_type`, and `volatility` "
327+
"must be provided when `accumulator` is callable."
328+
)
329+
raise TypeError(msg)
330+
310331
self._udaf = df_internal.AggregateUDF(
311332
name,
312333
accumulator,
@@ -350,6 +371,14 @@ def udaf(
350371
name: Optional[str] = None,
351372
) -> AggregateUDF: ...
352373

374+
@overload
375+
@staticmethod
376+
def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
377+
378+
@overload
379+
@staticmethod
380+
def udaf(accum: object) -> AggregateUDF: ...
381+
353382
@staticmethod
354383
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
355384
"""Create a new User-Defined Aggregate Function (UDAF).
@@ -480,16 +509,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
480509
return _decorator(*args, **kwargs)
481510

482511
@staticmethod
483-
def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
512+
def from_pycapsule(func: AggregateUDFExportable | object) -> AggregateUDF:
484513
"""Create an Aggregate UDF from AggregateUDF PyCapsule object.
485514
486515
This function will instantiate a Aggregate UDF that uses a DataFusion
487516
AggregateUDF that is exported via the FFI bindings.
488517
"""
489-
name = str(func.__class__)
518+
capsule = cast(AggregateUDFExportable, func)
519+
name = str(capsule.__class__)
490520
return AggregateUDF(
491521
name=name,
492-
accumulator=func,
522+
accumulator=capsule,
493523
input_types=None,
494524
return_type=None,
495525
state_type=None,

0 commit comments

Comments
 (0)