Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 74 additions & 5 deletions python/datafusion/user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,31 @@
import functools
from abc import ABCMeta, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, TypeVar, overload
from typing import (
TYPE_CHECKING,
Any,
Callable,
Optional,
Protocol,
TypeVar,
cast,
overload,
)

import pyarrow as pa
from typing_extensions import TypeGuard

import datafusion._internal as df_internal
from datafusion.expr import Expr

if TYPE_CHECKING:
from _typeshed import CapsuleType as _PyCapsule

_R = TypeVar("_R", bound=pa.DataType)
else:

class _PyCapsule:
"""Lightweight typing proxy for CPython ``PyCapsule`` objects."""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed?

Copy link
Contributor Author

@kosiew kosiew Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_typeshed.CapsuleType only exists for static type checkers (TYPE_CHECKING), so inside the else: branch (runtime, not TYPE_CHECKING) we provide a lightweight runtime proxy to keep the _PyCapsule symbol defined.

Without the class _PyCapsule in the else branch, the following will error:

from typing import TYPE_CHECKING, TypeGuard

if TYPE_CHECKING:
    from _typeshed import CapsuleType as _PyCapsule

def is_capsule(obj: object) -> TypeGuard[_PyCapsule]:
    return hasattr(obj, "__capsule__")
Traceback (most recent call last):
  File "/Users/kosiew/GitHub/datafusion-python/examples/example_fail.py", line 6, in <module>
    def is_capsule(obj: object) -> TypeGuard[_PyCapsule]:
                                             ^^^^^^^^^^
NameError: name '_PyCapsule' is not defined

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but I think this is okay because if the same file appropriately sets future annotations it has no issue:

from __future__ import annotations

from typing import TYPE_CHECKING, TypeGuard

if TYPE_CHECKING:
    from _typeshed import CapsuleType as _PyCapsule

def is_capsule(obj: object) -> TypeGuard[_PyCapsule]:
    return hasattr(obj, "__capsule__")

The first line prevents this error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed
class _PyCapsule



class Volatility(Enum):
Expand Down Expand Up @@ -83,6 +99,11 @@ class ScalarUDFExportable(Protocol):
def __datafusion_scalar_udf__(self) -> object: ... # noqa: D105


def _is_pycapsule(value: object) -> TypeGuard[_PyCapsule]:
"""Return ``True`` when ``value`` is a CPython ``PyCapsule``."""
return value.__class__.__name__ == "PyCapsule"


class ScalarUDF:
"""Class for performing scalar user-defined functions (UDF).

Expand Down Expand Up @@ -290,6 +311,7 @@ class AggregateUDF:
also :py:class:`ScalarUDF` for operating on a row by row basis.
"""

@overload
def __init__(
self,
name: str,
Expand All @@ -298,6 +320,27 @@ def __init__(
return_type: pa.DataType,
state_type: list[pa.DataType],
volatility: Volatility | str,
) -> None: ...

@overload
def __init__(
self,
name: str,
accumulator: AggregateUDFExportable,
input_types: None = ...,
return_type: None = ...,
state_type: None = ...,
volatility: None = ...,
) -> None: ...

def __init__(
self,
name: str,
accumulator: Callable[[], Accumulator] | AggregateUDFExportable,
input_types: list[pa.DataType] | None,
return_type: pa.DataType | None,
state_type: list[pa.DataType] | None,
volatility: Volatility | str | None,
) -> None:
"""Instantiate a user-defined aggregate function (UDAF).

Expand All @@ -307,6 +350,18 @@ def __init__(
if hasattr(accumulator, "__datafusion_aggregate_udf__"):
self._udaf = df_internal.AggregateUDF.from_pycapsule(accumulator)
return
if (
input_types is None
or return_type is None
or state_type is None
or volatility is None
):
msg = (
"`input_types`, `return_type`, `state_type`, and `volatility` "
"must be provided when `accumulator` is callable."
)
raise TypeError(msg)

self._udaf = df_internal.AggregateUDF(
name,
accumulator,
Expand Down Expand Up @@ -350,6 +405,14 @@ def udaf(
name: Optional[str] = None,
) -> AggregateUDF: ...

# Overload: accept an object that exposes ``__datafusion_aggregate_udf__``
# (the FFI export protocol) and wrap it directly.
@overload
@staticmethod
def udaf(accum: AggregateUDFExportable) -> AggregateUDF: ...

# Overload: accept a raw CPython ``PyCapsule`` carrying an exported
# DataFusion aggregate UDF.
@overload
@staticmethod
def udaf(accum: _PyCapsule) -> AggregateUDF: ...

@staticmethod
def udaf(*args: Any, **kwargs: Any): # noqa: D417, C901
"""Create a new User-Defined Aggregate Function (UDAF).
Expand Down Expand Up @@ -470,7 +533,7 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:

return decorator

if hasattr(args[0], "__datafusion_aggregate_udf__"):
if hasattr(args[0], "__datafusion_aggregate_udf__") or _is_pycapsule(args[0]):
return AggregateUDF.from_pycapsule(args[0])

if args and callable(args[0]):
Expand All @@ -480,16 +543,22 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr:
return _decorator(*args, **kwargs)

@staticmethod
def from_pycapsule(func: AggregateUDFExportable) -> AggregateUDF:
def from_pycapsule(func: AggregateUDFExportable | _PyCapsule) -> AggregateUDF:
"""Create an Aggregate UDF from AggregateUDF PyCapsule object.

This function will instantiate an Aggregate UDF that uses a DataFusion
AggregateUDF that is exported via the FFI bindings.
"""
name = str(func.__class__)
if _is_pycapsule(func):
aggregate = cast(AggregateUDF, object.__new__(AggregateUDF))
aggregate._udaf = df_internal.AggregateUDF.from_pycapsule(func)
return aggregate

capsule = cast(AggregateUDFExportable, func)
name = str(capsule.__class__)
return AggregateUDF(
name=name,
accumulator=func,
accumulator=capsule,
input_types=None,
return_type=None,
state_type=None,
Expand Down
3 changes: 1 addition & 2 deletions python/tests/test_pyclass_frozen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@
r"(?P<key>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*\"(?P<value>[^\"]+)\"",
)
STRUCT_NAME_RE = re.compile(
r"\b(?:pub\s+)?(?:struct|enum)\s+"
r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
Comment on lines -38 to -39
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not related to this PR but this came up as a Ruff error.

r"\b(?:pub\s+)?(?:struct|enum)\s+" r"(?P<name>[A-Za-z_][A-Za-z0-9_]*)",
)


Expand Down
33 changes: 21 additions & 12 deletions src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,15 @@ pub fn to_rust_accumulator(accum: PyObject) -> AccumulatorFactoryFunction {
})
}

/// Convert a PyCapsule produced by the DataFusion FFI bindings into an
/// [`AggregateUDF`].
///
/// # Errors
///
/// Returns an error when capsule validation fails or when the FFI payload
/// cannot be converted into a `ForeignAggregateUDF`.
fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult<AggregateUDF> {
    // Reject capsules that were not created under the expected capsule name.
    validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

    // SAFETY: validate_pycapsule accepted the capsule, which is assumed to
    // guarantee its payload is a valid FFI_AggregateUDF — NOTE(review):
    // confirm that all producers of this capsule name uphold that invariant.
    let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
    let udaf: ForeignAggregateUDF = udaf.try_into()?;

    Ok(udaf.into())
}

/// Represents an AggregateUDF
#[pyclass(frozen, name = "AggregateUDF", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -186,22 +195,22 @@ impl PyAggregateUDF {

#[staticmethod]
pub fn from_pycapsule(func: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
if func.is_instance_of::<PyCapsule>() {
let capsule = func.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
let function = aggregate_udf_from_capsule(&capsule)?;
return Ok(Self { function });
}

if func.hasattr("__datafusion_aggregate_udf__")? {
let capsule = func.getattr("__datafusion_aggregate_udf__")?.call0()?;
let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_aggregate_udf")?;

let udaf = unsafe { capsule.reference::<FFI_AggregateUDF>() };
let udaf: ForeignAggregateUDF = udaf.try_into()?;

Ok(Self {
function: udaf.into(),
})
} else {
Err(crate::errors::PyDataFusionError::Common(
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
))
let function = aggregate_udf_from_capsule(&capsule)?;
return Ok(Self { function });
}

Err(crate::errors::PyDataFusionError::Common(
"__datafusion_aggregate_udf__ does not exist on AggregateUDF object.".to_string(),
))
}

/// creates a new PyExpr with the call of the udf
Expand Down
Loading