Skip to content

Commit d275ee8

Browse files
committed
Add PyCapsule typing protocol and helper detection
Introduce a _PyCapsule typing protocol to enable type checkers to recognize PyCapsule-based registrations. Restrict the AggregateUDF udaf overload to the PyCapsule protocol and update from_pycapsule to wrap raw capsule inputs using the internal binding directly.
1 parent da824e6 commit d275ee8

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

python/datafusion/user_defined.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Callable,
2929
Optional,
3030
Protocol,
31+
TypeGuard,
3132
TypeVar,
3233
cast,
3334
overload,
@@ -92,6 +93,16 @@ class ScalarUDFExportable(Protocol):
9293
def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105
9394

9495

96+
class _PyCapsule(Protocol):
97+
"""Lightweight typing proxy for CPython ``PyCapsule`` objects."""
98+
99+
100+
def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
101+
"""Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
102+
103+
return value.__class__.__name__ == "PyCapsule"
104+
105+
95106
class ScalarUDF:
96107
"""Class for performing scalar user-defined functions (UDF).
97108
@@ -401,7 +412,7 @@ def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...
401412

402413
@overload
403414
@staticmethod
404-
def udaf(accum: object) -> AggregateUDF: ...
415+
def udaf(accum: _PyCapsule) -> AggregateUDF: ...
405416

406417
@staticmethod
407418
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
@@ -523,7 +534,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
523534

524535
return decorator
525536

526-
if hasattr(args[0], "__datafusion_aggregate_udf__"):
537+
if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
527538
return AggregateUDF.from_pycapsule(args[0])
528539

529540
if args and callable(args[0]):
@@ -533,12 +544,17 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
533544
return _decorator(*args, **kwargs)
534545

535546
@staticmethod
536-
def from_pycapsule(func: AggregateUDFExportable | object) -> AggregateUDF:
547+
def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
537548
"""Create an Aggregate UDF from AggregateUDF PyCapsule object.
538549
539550
This function will instantiate a Aggregate UDF that uses a DataFusion
540551
AggregateUDF that is exported via the FFI bindings.
541552
"""
553+
if _is_pycapsule(func):
554+
aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
555+
aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
556+
return aggregate
557+
542558
capsule = cast(AggregateUDFExportable, func)
543559
name = str(capsule.__class__)
544560
return AggregateUDF(

0 commit comments

Comments
 (0)