Skip to content

Commit 6f25337

Browse files
committed
refactor: Update udwf method signature and simplify input handling
- Changed the type hint for the return type in the `_create_window_udf_decorator` method to use `pa.DataType` directly instead of a `TypeVar`.
- Simplified the handling of input types by removing redundant checks and directly using the input types list.
- Removed unnecessary comments and cleaned up the code for better readability.
- Updated the test for `udwf` to use parameterized tests for better coverage and maintainability.
1 parent ad33378 commit 6f25337

File tree

2 files changed

+22
-64
lines changed

2 files changed

+22
-64
lines changed

python/datafusion/udf.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,15 @@
2222
import functools
2323
from abc import ABCMeta, abstractmethod
2424
from enum import Enum
25-
from typing import TYPE_CHECKING, Any, Callable, Optional, overload
25+
from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload
2626

2727
import pyarrow as pa
2828

2929
import datafusion._internal as df_internal
3030
from datafusion.expr import Expr
3131

3232
if TYPE_CHECKING:
33-
# for python 3.10 and above, we can use
34-
# from typing import TypeAlias
35-
# but for python 3.9, we use the following
36-
from typing_extensions import TypeAlias
37-
38-
_R: TypeAlias = pa.DataType
33+
_R = TypeVar("_R", bound=pa.DataType)
3934

4035

4136
class Volatility(Enum):
@@ -719,19 +714,13 @@ def _create_window_udf(
719714
msg = "`func` must implement the abstract base class WindowEvaluator"
720715
raise TypeError(msg)
721716

722-
if name is None:
723-
name = WindowUDF._get_default_name(func)
724-
725-
input_types_list = WindowUDF._normalize_input_types(input_types)
726-
727-
return WindowUDF(
728-
name=name,
729-
func=func,
730-
input_types=input_types_list,
731-
return_type=return_type,
732-
volatility=volatility,
717+
name = name or func.__qualname__.lower()
718+
input_types = (
719+
[input_types] if isinstance(input_types, pa.DataType) else input_types
733720
)
734721

722+
return WindowUDF(name, func, input_types, return_type, volatility)
723+
735724
@staticmethod
736725
def _get_default_name(func: Callable) -> str:
737726
"""Get the default name for a function based on its attributes."""
@@ -751,10 +740,10 @@ def _normalize_input_types(
751740
@staticmethod
752741
def _create_window_udf_decorator(
753742
input_types: pa.DataType | list[pa.DataType],
754-
return_type: _R,
743+
return_type: pa.DataType,
755744
volatility: Volatility | str,
756745
name: Optional[str] = None,
757-
) -> Callable[..., Callable[..., Expr]]:
746+
) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]:
758747
"""Create a decorator for a WindowUDF."""
759748

760749
def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]:

python/tests/test_udwf.py

Lines changed: 13 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -428,55 +428,24 @@ def test_udwf_functions(complex_window_df, name, expr, expected):
428428
assert result.column(0) == pa.array(expected)
429429

430430

431-
def test_udwf_overloads(count_window_df):
432-
"""Test different overload patterns for UDWF function."""
433-
# Single input type syntax
434-
single_input = udwf(
435-
SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable"
436-
)
437-
438-
# List of input types syntax
439-
list_input = udwf(
440-
SimpleWindowCount, [pa.int64()], pa.int64(), volatility="immutable"
441-
)
442-
443-
# Decorator syntax with single input type
444-
@udwf(pa.int64(), pa.int64(), "immutable")
445-
def window_count_single() -> WindowEvaluator:
446-
return SimpleWindowCount()
447-
448-
# Decorator syntax with list of input types
449-
@udwf([pa.int64()], pa.int64(), "immutable")
450-
def window_count_list() -> WindowEvaluator:
451-
return SimpleWindowCount()
452-
453-
# Test all variants produce the same result
431+
@pytest.mark.parametrize(
432+
"udwf_func",
433+
[
434+
udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"),
435+
udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"),
436+
udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()),
437+
udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()),
438+
],
439+
)
440+
def test_udwf_overloads(udwf_func, count_window_df):
454441
df = count_window_df.select(
455-
single_input(column("a"))
442+
udwf_func(column("a"))
456443
.window_frame(WindowFrame("rows", None, None))
457444
.build()
458-
.alias("single"),
459-
list_input(column("a"))
460-
.window_frame(WindowFrame("rows", None, None))
461-
.build()
462-
.alias("list"),
463-
window_count_single(column("a"))
464-
.window_frame(WindowFrame("rows", None, None))
465-
.build()
466-
.alias("decorator_single"),
467-
window_count_list(column("a"))
468-
.window_frame(WindowFrame("rows", None, None))
469-
.build()
470-
.alias("decorator_list"),
445+
.alias("count")
471446
)
472-
473447
result = df.collect()[0]
474-
expected = pa.array([0, 1, 2])
475-
476-
assert result.column(0) == expected
477-
assert result.column(1) == expected
478-
assert result.column(2) == expected
479-
assert result.column(3) == expected
448+
assert result.column(0) == pa.array([0, 1, 2])
480449

481450

482451
def test_udwf_named_function(ctx, count_window_df):

0 commit comments

Comments
 (0)