From b708133730db5244e93ddd9dde1cfc36e68281a3 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 15:25:01 +0800 Subject: [PATCH 01/18] feat: Introduce create_udwf method for User-Defined Window Functions - Added `create_udwf` static method to `WindowUDF` class, allowing users to create User-Defined Window Functions (UDWF) as both a function and a decorator. - Updated type hinting for `_R` using `TypeAlias` for better clarity. - Enhanced documentation with usage examples for both function and decorator styles, improving usability and understanding. --- python/datafusion/udf.py | 102 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 603b7063d..4ae4928f2 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload import pyarrow as pa @@ -30,7 +30,9 @@ from datafusion.expr import Expr if TYPE_CHECKING: - _R = TypeVar("_R", bound=pa.DataType) + from typing import TypeAlias + + _R: TypeAlias = pa.DataType class Volatility(Enum): @@ -684,9 +686,103 @@ def bias_10() -> BiasedNumbers: volatility=volatility, ) + @staticmethod + def create_udwf( + *args: Any, **kwargs: Any + ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]: + """Create a new User-Defined Window Function (UDWF). + + This class can be used both as a **function** and as a **decorator**. + + Usage: + - **As a function**: Call `udwf(func, input_types, return_type, volatility, name)`. + - **As a decorator**: Use `@udwf(input_types, return_type, volatility, name)`. + When using `udwf` as a decorator, **do not pass `func` explicitly**. + + **Function example:** + ``` + import pyarrow as pa + + class BiasedNumbers(WindowEvaluator): + def __init__(self, start: int = 0) -> None: + self.start = start + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + return pa.array([self.start + i for i in range(num_rows)]) + + def bias_10() -> BiasedNumbers: + return BiasedNumbers(10) + + udwf1 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") + ``` + + **Decorator example:** + ``` + @udwf(pa.int64(), pa.int64(), "immutable") + def biased_numbers() -> BiasedNumbers: + return BiasedNumbers(10) + ``` + + Args: + func: **Only needed when calling as a function. Skip this argument when using + `udwf` as a decorator.** + input_types: The data types of the arguments. + return_type: The data type of the return value. + volatility: See :py:class:`Volatility` for allowed values. + name: A descriptive name for the function. + + Returns: + A user-defined window function that can be used in window function calls. + """ + + def _function( + func: Callable[[], WindowEvaluator], + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> WindowUDF: + if not callable(func): + msg = "`func` argument must be callable" + raise TypeError(msg) + if not isinstance(func(), WindowEvaluator): + msg = "`func` must implement the abstract base class WindowEvaluator" + raise TypeError(msg) + if name is None: + if hasattr(func, "__qualname__"): + name = func.__qualname__.lower() + else: + name = func.__class__.__name__.lower() + if isinstance(input_types, pa.DataType): + input_types = [input_types] + return WindowUDF( + name=name, + func=func, + input_types=input_types, + return_type=return_type, + volatility=volatility, + ) + + def _decorator( + input_types: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[[Callable[[], WindowEvaluator]], WindowUDF]: + def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF: + return _function(func, input_types, return_type, volatility, name) + + return decorator + + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return _function(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return _decorator(*args, **kwargs) + # Convenience exports so we can import instead of treating as # variables at the package root udf = ScalarUDF.udf udaf = AggregateUDF.udaf -udwf = WindowUDF.udwf +udwf = WindowUDF.create_udwf From 333b80e77822865fc52e5765e302d983cbcdf3a8 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 15:36:53 +0800 Subject: [PATCH 02/18] refactor: Simplify UDWF test suite and introduce SimpleWindowCount evaluator - Removed multiple exponential smoothing classes to streamline the code. - Introduced SimpleWindowCount class for basic row counting functionality. - Updated test cases to validate the new SimpleWindowCount evaluator. - Refactored fixture and test functions for clarity and consistency. - Enhanced error handling in UDWF creation tests. --- python/tests/test_udwf.py | 340 ++++++++++---------------------------- 1 file changed, 89 insertions(+), 251 deletions(-) diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 3d6dcf9d8..0d5b9d61b 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -20,146 +20,18 @@ import pyarrow as pa import pytest from datafusion import SessionContext, column, lit, udwf -from datafusion import functions as f from datafusion.expr import WindowFrame from datafusion.udf import WindowEvaluator -class ExponentialSmoothDefault(WindowEvaluator): - def __init__(self, alpha: float = 0.9) -> None: - self.alpha = alpha +class SimpleWindowCount(WindowEvaluator): + """A simple window evaluator that counts rows.""" - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: - results = [] - curr_value = 0.0 - values = values[0] - for idx in range(num_rows): - if idx == 0: - curr_value = values[idx].as_py() - else: - curr_value = values[idx].as_py() * self.alpha + curr_value * ( - 1.0 - self.alpha - ) - results.append(curr_value) - - return pa.array(results) - - -class ExponentialSmoothBounded(WindowEvaluator): - def __init__(self, alpha: float = 0.9) -> None: - self.alpha = alpha - - def supports_bounded_execution(self) -> bool: - return True - - def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: - # Override the default range of current row since uses_window_frame is False - # So for the purpose of this test we just smooth from the previous row to - # current. - if idx == 0: - return (0, 0) - return (idx - 1, idx) - - def evaluate( - self, values: list[pa.Array], eval_range: tuple[int, int] - ) -> pa.Scalar: - (start, stop) = eval_range - curr_value = 0.0 - values = values[0] - for idx in range(start, stop + 1): - if idx == start: - curr_value = values[idx].as_py() - else: - curr_value = values[idx].as_py() * self.alpha + curr_value * ( - 1.0 - self.alpha - ) - return pa.scalar(curr_value).cast(pa.float64()) - - -class ExponentialSmoothRank(WindowEvaluator): - def __init__(self, alpha: float = 0.9) -> None: - self.alpha = alpha - - def include_rank(self) -> bool: - return True - - def evaluate_all_with_rank( - self, num_rows: int, ranks_in_partition: list[tuple[int, int]] - ) -> pa.Array: - results = [] - for idx in range(num_rows): - if idx == 0: - prior_value = 1.0 - matching_row = [ - i - for i in range(len(ranks_in_partition)) - if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx - ][0] + 1 - curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha) - results.append(curr_value) - prior_value = matching_row - - return pa.array(results) - - -class ExponentialSmoothFrame(WindowEvaluator): - def __init__(self, alpha: float = 0.9) -> None: - self.alpha = alpha - - def uses_window_frame(self) -> bool: - return True - - def evaluate( - self, values: list[pa.Array], eval_range: tuple[int, int] - ) -> pa.Scalar: - (start, stop) = eval_range - curr_value = 0.0 - if len(values) > 1: - order_by = values[1] # noqa: F841 - values = values[0] - else: - values = values[0] - for idx in range(start, stop): - if idx == start: - curr_value = values[idx].as_py() - else: - curr_value = values[idx].as_py() * self.alpha + curr_value * ( - 1.0 - self.alpha - ) - return pa.scalar(curr_value).cast(pa.float64()) - - -class SmoothTwoColumn(WindowEvaluator): - """This class demonstrates using two columns. - - If the second column is above a threshold, then smooth over the first column from - the previous and next rows. - """ - - def __init__(self, alpha: float = 0.9) -> None: - self.alpha = alpha + def __init__(self, base: int = 0) -> None: + self.base = base def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: - results = [] - values_a = values[0] - values_b = values[1] - for idx in range(num_rows): - if values_b[idx].as_py() > 7: - if idx == 0: - results.append(values_a[1].cast(pa.float64())) - elif idx == num_rows - 1: - results.append(values_a[num_rows - 2].cast(pa.float64())) - else: - results.append( - pa.scalar( - values_a[idx - 1].as_py() * self.alpha - + values_a[idx + 1].as_py() * (1.0 - self.alpha) - ) - ) - else: - results.append(values_a[idx].cast(pa.float64())) - - return pa.array(results) + return pa.array([self.base + i for i in range(num_rows)]) class NotSubclassOfWindowEvaluator: @@ -167,142 +39,108 @@ class NotSubclassOfWindowEvaluator: @pytest.fixture -def df(): - ctx = SessionContext() +def ctx(): + return SessionContext() + +@pytest.fixture +def df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( - [ - pa.array([0, 1, 2, 3, 4, 5, 6]), - pa.array([7, 4, 3, 8, 9, 1, 6]), - pa.array(["A", "A", "A", "A", "B", "B", "B"]), - ], - names=["a", "b", "c"], + [pa.array([1, 2, 3]), pa.array([4, 4, 6])], + names=["a", "b"], ) - return ctx.create_dataframe([[batch]]) + return ctx.create_dataframe([[batch]], name="test_table") -def test_udwf_errors(df): - with pytest.raises(TypeError): +def test_udwf_errors(): + """Test error cases for UDWF creation.""" + with pytest.raises( + TypeError, match="`func` must implement the abstract base class WindowEvaluator" + ): udwf( - NotSubclassOfWindowEvaluator, - pa.float64(), - pa.float64(), - volatility="immutable", + NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable" ) -smooth_default = udwf( - ExponentialSmoothDefault, - pa.float64(), - pa.float64(), - volatility="immutable", -) +def test_udwf_basic_usage(df): + """Test basic UDWF usage with a simple counting window function.""" + simple_count = udwf( + SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" + ) + + df = df.select( + simple_count(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + -smooth_w_arguments = udwf( - lambda: ExponentialSmoothDefault(0.8), - pa.float64(), - pa.float64(), - volatility="immutable", -) +def test_udwf_with_args(df): + """Test UDWF with constructor arguments.""" + count_base10 = udwf( + lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable" + ) -smooth_bounded = udwf( - ExponentialSmoothBounded, - pa.float64(), - pa.float64(), - volatility="immutable", -) + df = df.select( + count_base10(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([10, 11, 12]) -smooth_rank = udwf( - ExponentialSmoothRank, - pa.utf8(), - pa.float64(), - volatility="immutable", -) -smooth_frame = udwf( - ExponentialSmoothFrame, - pa.float64(), - pa.float64(), - volatility="immutable", -) +def test_udwf_decorator_basic(df): + """Test UDWF used as a decorator.""" -smooth_two_col = udwf( - SmoothTwoColumn, - [pa.int64(), pa.int64()], - pa.float64(), - volatility="immutable", -) + @udwf([pa.int64()], pa.int64(), "immutable") + def window_count() -> WindowEvaluator: + return SimpleWindowCount() -data_test_udwf_functions = [ - ( - "default_udwf_no_arguments", - smooth_default(column("a")), - [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], - ), - ( - "default_udwf_w_arguments", - smooth_w_arguments(column("a")), - [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], - ), - ( - "default_udwf_partitioned", - smooth_default(column("a")).partition_by(column("c")).build(), - [0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89], - ), - ( - "default_udwf_ordered", - smooth_default(column("a")).order_by(column("b")).build(), - [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], - ), - ( - "bounded_udwf", - smooth_bounded(column("a")), - [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], - ), - ( - "bounded_udwf_ignores_frame", - smooth_bounded(column("a")) + df = df.select( + window_count(column("a")) .window_frame(WindowFrame("rows", None, None)) - .build(), - [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], - ), - ( - "rank_udwf", - smooth_rank(column("c")).order_by(column("c")).build(), - [1, 1, 1, 1, 1.9, 2, 2], - ), - ( - "frame_unbounded_udwf", - smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(), - [5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889], - ), - ( - "frame_bounded_udwf", - smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(), - [0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], - ), - ( - "frame_bounded_udwf", - smooth_frame(column("a")) - .window_frame(WindowFrame("rows", None, 0)) - .order_by(column("b")) - .build(), - [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], - ), - ( - "two_column_udwf", - smooth_two_col(column("a"), column("b")), - [0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0], - ), -] + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) -@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) -def test_udwf_functions(df, name, expr, expected): - df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) +def test_udwf_decorator_with_args(df): + """Test UDWF decorator with constructor arguments.""" - # execute and collect the first (and only) batch - result = df.sort(column("a")).select(column(name)).collect()[0] + @udwf([pa.int64()], pa.int64(), "immutable") + def window_count_base10() -> WindowEvaluator: + return SimpleWindowCount(10) + + df = df.select( + window_count_base10(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + result = df.collect()[0] + assert result.column(0) == pa.array([10, 11, 12]) + + +def test_register_udwf(ctx, df): + """Test registering and using UDWF in SQL context.""" + window_count = udwf( + SimpleWindowCount, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="window_count", + ) - assert result.column(0) == pa.array(expected) + ctx.register_udwf(window_count) + result = ctx.sql( + "SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" + ).collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) From a52af174e40b98e78563220064039ef068140e54 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 15:44:57 +0800 Subject: [PATCH 03/18] fix: Update type alias import to use typing_extensions for compatibility --- python/datafusion/udf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 4ae4928f2..dbb4c02c4 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -30,7 +30,10 @@ from datafusion.expr import Expr if TYPE_CHECKING: - from typing import TypeAlias + # for python 3.10 and above, we can use + # from typing import TypeAlias + # but for python 3.9, we use the following + from typing_extensions import TypeAlias _R: TypeAlias = pa.DataType From cd972b53277cd6ec360a21dc812d70da70a4b507 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 15:57:40 +0800 Subject: [PATCH 04/18] Add udwf tests for multiple input types and decorator syntax --- python/tests/test_udwf.py | 302 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 291 insertions(+), 11 deletions(-) diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 0d5b9d61b..bbcf9bb52 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -20,10 +20,148 @@ import pyarrow as pa import pytest from datafusion import SessionContext, column, lit, udwf +from datafusion import functions as f from datafusion.expr import WindowFrame from datafusion.udf import WindowEvaluator +class ExponentialSmoothDefault(WindowEvaluator): + def __init__(self, alpha: float = 0.9) -> None: + self.alpha = alpha + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + results = [] + curr_value = 0.0 + values = values[0] + for idx in range(num_rows): + if idx == 0: + curr_value = values[idx].as_py() + else: + curr_value = values[idx].as_py() * self.alpha + curr_value * ( + 1.0 - self.alpha + ) + results.append(curr_value) + + return pa.array(results) + + +class ExponentialSmoothBounded(WindowEvaluator): + def __init__(self, alpha: float = 0.9) -> None: + self.alpha = alpha + + def supports_bounded_execution(self) -> bool: + return True + + def get_range(self, idx: int, num_rows: int) -> tuple[int, int]: + # Override the default range of current row since uses_window_frame is False + # So for the purpose of this test we just smooth from the previous row to + # current. + if idx == 0: + return (0, 0) + return (idx - 1, idx) + + def evaluate( + self, values: list[pa.Array], eval_range: tuple[int, int] + ) -> pa.Scalar: + (start, stop) = eval_range + curr_value = 0.0 + values = values[0] + for idx in range(start, stop + 1): + if idx == start: + curr_value = values[idx].as_py() + else: + curr_value = values[idx].as_py() * self.alpha + curr_value * ( + 1.0 - self.alpha + ) + return pa.scalar(curr_value).cast(pa.float64()) + + +class ExponentialSmoothRank(WindowEvaluator): + def __init__(self, alpha: float = 0.9) -> None: + self.alpha = alpha + + def include_rank(self) -> bool: + return True + + def evaluate_all_with_rank( + self, num_rows: int, ranks_in_partition: list[tuple[int, int]] + ) -> pa.Array: + results = [] + for idx in range(num_rows): + if idx == 0: + prior_value = 1.0 + matching_row = [ + i + for i in range(len(ranks_in_partition)) + if ranks_in_partition[i][0] <= idx and ranks_in_partition[i][1] > idx + ][0] + 1 + curr_value = matching_row * self.alpha + prior_value * (1.0 - self.alpha) + results.append(curr_value) + prior_value = matching_row + + return pa.array(results) + + +class ExponentialSmoothFrame(WindowEvaluator): + def __init__(self, alpha: float = 0.9) -> None: + self.alpha = alpha + + def uses_window_frame(self) -> bool: + return True + + def evaluate( + self, values: list[pa.Array], eval_range: tuple[int, int] + ) -> pa.Scalar: + (start, stop) = eval_range + curr_value = 0.0 + if len(values) > 1: + order_by = values[1] # noqa: F841 + values = values[0] + else: + values = values[0] + for idx in range(start, stop): + if idx == start: + curr_value = values[idx].as_py() + else: + curr_value = values[idx].as_py() * self.alpha + curr_value * ( + 1.0 - self.alpha + ) + return pa.scalar(curr_value).cast(pa.float64()) + + +class SmoothTwoColumn(WindowEvaluator): + """This class demonstrates using two columns. + + If the second column is above a threshold, then smooth over the first column from + the previous and next rows. + """ + + def __init__(self, alpha: float = 0.9) -> None: + self.alpha = alpha + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + results = [] + values_a = values[0] + values_b = values[1] + for idx in range(num_rows): + if values_b[idx].as_py() > 7: + if idx == 0: + results.append(values_a[1].cast(pa.float64())) + elif idx == num_rows - 1: + results.append(values_a[num_rows - 2].cast(pa.float64())) + else: + results.append( + pa.scalar( + values_a[idx - 1].as_py() * self.alpha + + values_a[idx + 1].as_py() * (1.0 - self.alpha) + ) + ) + else: + results.append(values_a[idx].cast(pa.float64())) + + return pa.array(results) + + class SimpleWindowCount(WindowEvaluator): """A simple window evaluator that counts rows.""" @@ -44,7 +182,23 @@ def ctx(): @pytest.fixture -def df(ctx): +def df(): + ctx = SessionContext() + + # create a RecordBatch and a new DataFrame from it + batch = pa.RecordBatch.from_arrays( + [ + pa.array([0, 1, 2, 3, 4, 5, 6]), + pa.array([7, 4, 3, 8, 9, 1, 6]), + pa.array(["A", "A", "A", "A", "B", "B", "B"]), + ], + names=["a", "b", "c"], + ) + return ctx.create_dataframe([[batch]]) + + +@pytest.fixture +def simple_df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 4, 6])], @@ -53,7 +207,17 @@ def df(ctx): return ctx.create_dataframe([[batch]], name="test_table") -def test_udwf_errors(): +def test_udwf_errors(df): + with pytest.raises(TypeError): + udwf( + NotSubclassOfWindowEvaluator, + pa.float64(), + pa.float64(), + volatility="immutable", + ) + + +def test_udwf_errors_with_message(): """Test error cases for UDWF creation.""" with pytest.raises( TypeError, match="`func` must implement the abstract base class WindowEvaluator" @@ -63,13 +227,13 @@ def test_udwf_errors(): ) -def test_udwf_basic_usage(df): +def test_udwf_basic_usage(simple_df): """Test basic UDWF usage with a simple counting window function.""" simple_count = udwf( SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" ) - df = df.select( + df = simple_df.select( simple_count(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -79,13 +243,13 @@ def test_udwf_basic_usage(df): assert result.column(0) == pa.array([0, 1, 2]) -def test_udwf_with_args(df): +def test_udwf_with_args(simple_df): """Test UDWF with constructor arguments.""" count_base10 = udwf( lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable" ) - df = df.select( + df = simple_df.select( count_base10(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -95,14 +259,14 @@ def test_udwf_with_args(df): assert result.column(0) == pa.array([10, 11, 12]) -def test_udwf_decorator_basic(df): +def test_udwf_decorator_basic(simple_df): """Test UDWF used as a decorator.""" @udwf([pa.int64()], pa.int64(), "immutable") def window_count() -> WindowEvaluator: return SimpleWindowCount() - df = df.select( + df = simple_df.select( window_count(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -112,14 +276,14 @@ def window_count() -> WindowEvaluator: assert result.column(0) == pa.array([0, 1, 2]) -def test_udwf_decorator_with_args(df): +def test_udwf_decorator_with_args(simple_df): """Test UDWF decorator with constructor arguments.""" @udwf([pa.int64()], pa.int64(), "immutable") def window_count_base10() -> WindowEvaluator: return SimpleWindowCount(10) - df = df.select( + df = simple_df.select( window_count_base10(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -129,7 +293,7 @@ def window_count_base10() -> WindowEvaluator: assert result.column(0) == pa.array([10, 11, 12]) -def test_register_udwf(ctx, df): +def test_register_udwf(ctx, simple_df): """Test registering and using UDWF in SQL context.""" window_count = udwf( SimpleWindowCount, @@ -144,3 +308,119 @@ def test_register_udwf(ctx, df): "SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" ).collect()[0] assert result.column(0) == pa.array([0, 1, 2]) + + +smooth_default = udwf( + ExponentialSmoothDefault, + pa.float64(), + pa.float64(), + volatility="immutable", +) + +smooth_w_arguments = udwf( + lambda: ExponentialSmoothDefault(0.8), + pa.float64(), + pa.float64(), + volatility="immutable", +) + +smooth_bounded = udwf( + ExponentialSmoothBounded, + pa.float64(), + pa.float64(), + volatility="immutable", +) + +smooth_rank = udwf( + ExponentialSmoothRank, + pa.utf8(), + pa.float64(), + volatility="immutable", +) + +smooth_frame = udwf( + ExponentialSmoothFrame, + pa.float64(), + pa.float64(), + volatility="immutable", +) + +smooth_two_col = udwf( + SmoothTwoColumn, + [pa.int64(), pa.int64()], + pa.float64(), + volatility="immutable", +) + +data_test_udwf_functions = [ + ( + "default_udwf_no_arguments", + smooth_default(column("a")), + [0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], + ), + ( + "default_udwf_w_arguments", + smooth_w_arguments(column("a")), + [0, 0.8, 1.76, 2.752, 3.75, 4.75, 5.75], + ), + ( + "default_udwf_partitioned", + smooth_default(column("a")).partition_by(column("c")).build(), + [0, 0.9, 1.89, 2.889, 4.0, 4.9, 5.89], + ), + ( + "default_udwf_ordered", + smooth_default(column("a")).order_by(column("b")).build(), + [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], + ), + ( + "bounded_udwf", + smooth_bounded(column("a")), + [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], + ), + ( + "bounded_udwf_ignores_frame", + smooth_bounded(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build(), + [0, 0.9, 1.9, 2.9, 3.9, 4.9, 5.9], + ), + ( + "rank_udwf", + smooth_rank(column("c")).order_by(column("c")).build(), + [1, 1, 1, 1, 1.9, 2, 2], + ), + ( + "frame_unbounded_udwf", + smooth_frame(column("a")).window_frame(WindowFrame("rows", None, None)).build(), + [5.889, 5.889, 5.889, 5.889, 5.889, 5.889, 5.889], + ), + ( + "frame_bounded_udwf", + smooth_frame(column("a")).window_frame(WindowFrame("rows", None, 0)).build(), + [0.0, 0.9, 1.89, 2.889, 3.889, 4.889, 5.889], + ), + ( + "frame_bounded_udwf", + smooth_frame(column("a")) + .window_frame(WindowFrame("rows", None, 0)) + .order_by(column("b")) + .build(), + [0.551, 1.13, 2.3, 2.755, 3.876, 5.0, 5.513], + ), + ( + "two_column_udwf", + smooth_two_col(column("a"), column("b")), + [0.0, 1.0, 2.0, 2.2, 3.2, 5.0, 6.0], + ), +] + + +@pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) +def test_udwf_functions(df, name, expr, expected): + df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) + + # execute and collect the first (and only) batch + result = df.sort(column("a")).select(column(name)).collect()[0] + + assert result.column(0) == pa.array(expected) From d7ffa02e82f57e615a4369c8bcfff7f04a7b0a5f Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 16:16:53 +0800 Subject: [PATCH 05/18] replace old def udwf --- python/datafusion/udf.py | 65 +--------------------------------------- 1 file changed, 1 insertion(+), 64 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index dbb4c02c4..dac79d15b 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -628,69 +628,6 @@ def __call__(self, *args: Expr) -> Expr: @staticmethod def udwf( - func: Callable[[], WindowEvaluator], - input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> WindowUDF: - """Create a new User-Defined Window Function. - - If your :py:class:`WindowEvaluator` can be instantiated with no arguments, you - can simply pass it's type as ``func``. If you need to pass additional arguments - to it's constructor, you can define a lambda or a factory method. During runtime - the :py:class:`WindowEvaluator` will be constructed for every instance in - which this UDWF is used. The following examples are all valid. - - .. code-block:: python - - import pyarrow as pa - - class BiasedNumbers(WindowEvaluator): - def __init__(self, start: int = 0) -> None: - self.start = start - - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: - return pa.array([self.start + i for i in range(num_rows)]) - - def bias_10() -> BiasedNumbers: - return BiasedNumbers(10) - - udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") - udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") - udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") - - Args: - func: A callable to create the window function. - input_types: The data types of the arguments to ``func``. - return_type: The data type of the return value. - volatility: See :py:class:`Volatility` for allowed values. - arguments: A list of arguments to pass in to the __init__ method for accum. - name: A descriptive name for the function. - - Returns: - A user-defined window function. - """ # noqa: W505, E501 - if not callable(func): - msg = "`func` must be callable." - raise TypeError(msg) - if not isinstance(func(), WindowEvaluator): - msg = "`func` must implement the abstract base class WindowEvaluator" - raise TypeError(msg) - if name is None: - name = func().__class__.__qualname__.lower() - if isinstance(input_types, pa.DataType): - input_types = [input_types] - return WindowUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) - - @staticmethod - def create_udwf( *args: Any, **kwargs: Any ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]: """Create a new User-Defined Window Function (UDWF). @@ -788,4 +725,4 @@ def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF: # variables at the package root udf = ScalarUDF.udf udaf = AggregateUDF.udaf -udwf = WindowUDF.create_udwf +udwf = WindowUDF.udwf From 3eade9586a9fa41414009cd1787a6fbb0ef05e7a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 16:18:51 +0800 Subject: [PATCH 06/18] refactor: Simplify df fixture by passing ctx as an argument --- python/tests/test_udwf.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index bbcf9bb52..44fa54c59 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -182,9 +182,7 @@ def ctx(): @pytest.fixture -def df(): - ctx = SessionContext() - +def df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [ From 86cc70e971d2550d708804f8d54fbb1f068288ca Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 16:21:44 +0800 Subject: [PATCH 07/18] refactor: Rename DataFrame fixtures and update test functions - Renamed `df` fixture to `complex_window_df` for clarity. - Renamed `simple_df` fixture to `count_window_df` to better reflect its purpose. - Updated test functions to use the new fixture names, enhancing readability and maintainability. --- python/tests/test_udwf.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 44fa54c59..c51857f6e 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -182,7 +182,7 @@ def ctx(): @pytest.fixture -def df(ctx): +def complex_window_df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [ @@ -196,7 +196,7 @@ def df(ctx): @pytest.fixture -def simple_df(ctx): +def count_window_df(ctx): # create a RecordBatch and a new DataFrame from it batch = pa.RecordBatch.from_arrays( [pa.array([1, 2, 3]), pa.array([4, 4, 6])], @@ -205,7 +205,7 @@ def simple_df(ctx): return ctx.create_dataframe([[batch]], name="test_table") -def test_udwf_errors(df): +def test_udwf_errors(complex_window_df): with pytest.raises(TypeError): udwf( NotSubclassOfWindowEvaluator, @@ -225,13 +225,13 @@ def test_udwf_errors_with_message(): ) -def test_udwf_basic_usage(simple_df): +def test_udwf_basic_usage(count_window_df): """Test basic UDWF usage with a simple counting window function.""" simple_count = udwf( SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" ) - df = simple_df.select( + df = count_window_df.select( simple_count(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -241,13 +241,13 @@ def test_udwf_basic_usage(simple_df): assert result.column(0) == pa.array([0, 1, 2]) -def test_udwf_with_args(simple_df): +def test_udwf_with_args(count_window_df): """Test UDWF with constructor arguments.""" count_base10 = udwf( lambda: SimpleWindowCount(10), pa.int64(), pa.int64(), volatility="immutable" ) - df = simple_df.select( + df = count_window_df.select( count_base10(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -257,14 +257,14 @@ def test_udwf_with_args(simple_df): assert result.column(0) == pa.array([10, 11, 12]) -def test_udwf_decorator_basic(simple_df): +def test_udwf_decorator_basic(count_window_df): """Test UDWF used as a decorator.""" @udwf([pa.int64()], pa.int64(), "immutable") def window_count() -> WindowEvaluator: return SimpleWindowCount() - df = simple_df.select( + df = count_window_df.select( window_count(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -274,14 +274,14 @@ def window_count() -> WindowEvaluator: assert result.column(0) == pa.array([0, 1, 2]) -def test_udwf_decorator_with_args(simple_df): +def test_udwf_decorator_with_args(count_window_df): """Test UDWF decorator with constructor arguments.""" @udwf([pa.int64()], pa.int64(), "immutable") def window_count_base10() -> WindowEvaluator: return SimpleWindowCount(10) - df = simple_df.select( + df = count_window_df.select( window_count_base10(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() @@ -291,7 +291,7 @@ def window_count_base10() -> WindowEvaluator: assert result.column(0) == pa.array([10, 11, 12]) -def test_register_udwf(ctx, simple_df): +def test_register_udwf(ctx, count_window_df): """Test registering and using UDWF in SQL context.""" window_count = udwf( SimpleWindowCount, @@ -415,8 +415,8 @@ def test_register_udwf(ctx, simple_df): @pytest.mark.parametrize(("name", "expr", "expected"), data_test_udwf_functions) -def test_udwf_functions(df, name, expr, expected): - df = df.select("a", "b", f.round(expr, lit(3)).alias(name)) +def test_udwf_functions(complex_window_df, name, expr, expected): + df = complex_window_df.select("a", "b", f.round(expr, lit(3)).alias(name)) # execute and collect the first (and only) batch result = df.sort(column("a")).select(column(name)).collect()[0] From ae623835f07e3a199d57567c7c0923736c50efb2 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 16:31:01 +0800 Subject: [PATCH 08/18] refactor: Update udwf calls in WindowUDF to use BiasedNumbers directly - Changed udwf1 to use BiasedNumbers instead of bias_10. - Added udwf2 to call udwf with bias_10. - Introduced udwf3 to demonstrate a lambda function returning BiasedNumbers(20). --- python/datafusion/udf.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index dac79d15b..5ee807fdd 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -653,7 +653,10 @@ def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: def bias_10() -> BiasedNumbers: return BiasedNumbers(10) - udwf1 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") + udwf1 = udwf(BiasedNumbers, pa.int64(), pa.int64(), "immutable") + udwf2 = udwf(bias_10, pa.int64(), pa.int64(), "immutable") + udwf3 = udwf(lambda: BiasedNumbers(20), pa.int64(), pa.int64(), "immutable") + ``` **Decorator example:** From 4c397cf4dac0724d9e1a39d77b4f3e0830325f54 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 16:57:15 +0800 Subject: [PATCH 09/18] feat: Add overloads for udwf function to support multiple input types and decorator syntax --- python/datafusion/udf.py | 21 ++++++++++++ python/tests/test_udwf.py | 68 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 5ee807fdd..3219a0453 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -626,6 +626,27 @@ def __call__(self, *args: Expr) -> Expr: args_raw = [arg.expr for arg in args] return Expr(self._udwf.__call__(*args_raw)) + @overload + @staticmethod + def udwf( + input_type: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: str, + name: Optional[str] = None, + ) -> Callable[..., WindowUDF]: ... + + @overload + @staticmethod + def udwf( + windown: Callable[[], WindowEvaluator], + input_type: pa.DataType | list[pa.DataType], + return_type: pa.DataType, + state_type: list[pa.DataType], + volatility: str, + name: Optional[str] = None, + ) -> WindowUDF: ... + @staticmethod def udwf( *args: Any, **kwargs: Any diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index c51857f6e..e018d38b9 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -422,3 +422,71 @@ def test_udwf_functions(complex_window_df, name, expr, expected): result = df.sort(column("a")).select(column(name)).collect()[0] assert result.column(0) == pa.array(expected) + + +def test_udwf_overloads(count_window_df): + """Test different overload patterns for UDWF function.""" + # Single input type syntax + single_input = udwf( + SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" + ) + + # List of input types syntax + list_input = udwf( + SimpleWindowCount, [pa.int64()], pa.int64(), volatility="immutable" + ) + + # Decorator syntax with single input type + @udwf(pa.int64(), pa.int64(), "immutable") + def window_count_single() -> WindowEvaluator: + return SimpleWindowCount() + + # Decorator syntax with list of input types + @udwf([pa.int64()], pa.int64(), "immutable") + def window_count_list() -> WindowEvaluator: + return SimpleWindowCount() + + # Test all variants produce the same result + df = count_window_df.select( + single_input(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("single"), + list_input(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("list"), + window_count_single(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("decorator_single"), + window_count_list(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("decorator_list"), + ) + + result = df.collect()[0] + expected = pa.array([0, 1, 2]) + + assert result.column(0) == expected + assert result.column(1) == expected + assert result.column(2) == expected + assert result.column(3) == expected + + +def test_udwf_named_function(ctx, count_window_df): + """Test UDWF with explicit name parameter.""" + window_count = udwf( + SimpleWindowCount, + pa.int64(), + pa.int64(), + volatility="immutable", + name="my_custom_counter", + ) + + ctx.register_udwf(window_count) + result = ctx.sql( + "SELECT my_custom_counter(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" + ).collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) From 1164374e58902a40d0d959ce7b93249c87fc9728 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 17:01:04 +0800 Subject: [PATCH 10/18] refactor: Simplify udwf method signature by removing redundant type hints --- python/datafusion/udf.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 3219a0453..7ce34da89 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, overload import pyarrow as pa @@ -639,7 +639,7 @@ def udwf( @overload @staticmethod def udwf( - windown: Callable[[], WindowEvaluator], + func: Callable[[], WindowEvaluator], input_type: pa.DataType | list[pa.DataType], return_type: pa.DataType, state_type: list[pa.DataType], @@ -648,9 +648,7 @@ def udwf( ) -> WindowUDF: ... @staticmethod - def udwf( - *args: Any, **kwargs: Any - ) -> Union[WindowUDF, Callable[[Callable[[], WindowEvaluator]], WindowUDF]]: + def udwf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Window Function (UDWF). This class can be used both as a **function** and as a **decorator**. From d29acf6f3dbf468226934eddaf964a4556ee1eaf Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 17:27:58 +0800 Subject: [PATCH 11/18] refactor: Remove state_type from udwf method signature and update return type handling - Eliminated the state_type parameter from the udwf method to simplify the function signature. - Updated return type handling in the _function and _decorator methods to use a generic type _R for better type flexibility. - Enhanced the decorator to wrap the original function, allowing for improved argument handling and expression return. --- python/datafusion/udf.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 7ce34da89..7ab376fd3 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -19,6 +19,7 @@ from __future__ import annotations +from ast import Call import functools from abc import ABCMeta, abstractmethod from enum import Enum @@ -631,7 +632,6 @@ def __call__(self, *args: Expr) -> Expr: def udwf( input_type: pa.DataType | list[pa.DataType], return_type: pa.DataType, - state_type: list[pa.DataType], volatility: str, name: Optional[str] = None, ) -> Callable[..., WindowUDF]: ... @@ -642,7 +642,6 @@ def udwf( func: Callable[[], WindowEvaluator], input_type: pa.DataType | list[pa.DataType], return_type: pa.DataType, - state_type: list[pa.DataType], volatility: str, name: Optional[str] = None, ) -> WindowUDF: ... @@ -700,7 +699,7 @@ def biased_numbers() -> BiasedNumbers: def _function( func: Callable[[], WindowEvaluator], input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, + return_type: _R, volatility: Volatility | str, name: Optional[str] = None, ) -> WindowUDF: @@ -727,12 +726,20 @@ def _function( def _decorator( input_types: pa.DataType | list[pa.DataType], - return_type: pa.DataType, + return_type: _R, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable[[Callable[[], WindowEvaluator]], WindowUDF]: - def decorator(func: Callable[[], WindowEvaluator]) -> WindowUDF: - return _function(func, input_types, return_type, volatility, name) + ) -> Callable[..., Callable[..., Expr]]: + def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: + udwf_caller = WindowUDF.udwf( + func, input_types, return_type, volatility, name + ) + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Expr: + return udwf_caller(*args, **kwargs) + + return wrapper return decorator From 46097d1460f682348722d3bee689830d75a04967 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 17:33:30 +0800 Subject: [PATCH 12/18] refactor: Update volatility parameter type in udwf method signature to support Volatility enum --- python/datafusion/udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 7ab376fd3..cbfa756ab 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -632,7 +632,7 @@ def __call__(self, *args: Expr) -> Expr: def udwf( input_type: pa.DataType | list[pa.DataType], return_type: pa.DataType, - volatility: str, + volatility: Volatility | str, name: Optional[str] = None, ) -> Callable[..., WindowUDF]: ... @@ -642,7 +642,7 @@ def udwf( func: Callable[[], WindowEvaluator], input_type: pa.DataType | list[pa.DataType], return_type: pa.DataType, - volatility: str, + volatility: Volatility | str, name: Optional[str] = None, ) -> WindowUDF: ... From b0a1803c4b43bcddca8a9a4ef169beff086f29e9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 18:15:21 +0800 Subject: [PATCH 13/18] Fix ruff errors --- python/datafusion/udf.py | 18 ++++++++++-------- python/tests/test_udwf.py | 11 +++++++++-- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index cbfa756ab..504561a41 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -19,7 +19,6 @@ from __future__ import annotations -from ast import Call import functools from abc import ABCMeta, abstractmethod from enum import Enum @@ -647,15 +646,17 @@ def udwf( ) -> WindowUDF: ... @staticmethod - def udwf(*args: Any, **kwargs: Any): # noqa: D417 + def udwf(*args: Any, **kwargs: Any): # noqa: D417, C901 """Create a new User-Defined Window Function (UDWF). This class can be used both as a **function** and as a **decorator**. Usage: - - **As a function**: Call `udwf(func, input_types, return_type, volatility, name)`. - - **As a decorator**: Use `@udwf(input_types, return_type, volatility, name)`. - When using `udwf` as a decorator, **do not pass `func` explicitly**. + - **As a function**: Call `udwf(func, input_types, return_type, volatility, + name)`. + - **As a decorator**: Use `@udwf(input_types, return_type, volatility, + name)`. When using `udwf` as a decorator, **do not pass `func` + explicitly**. **Function example:** ``` @@ -665,7 +666,8 @@ class BiasedNumbers(WindowEvaluator): def __init__(self, start: int = 0) -> None: self.start = start - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + def evaluate_all(self, values: list[pa.Array], + num_rows: int) -> pa.Array: return pa.array([self.start + i for i in range(num_rows)]) def bias_10() -> BiasedNumbers: @@ -685,8 +687,8 @@ def biased_numbers() -> BiasedNumbers: ``` Args: - func: **Only needed when calling as a function. Skip this argument when using - `udwf` as a decorator.** + func: **Only needed when calling as a function. Skip this argument when + using `udwf` as a decorator.** input_types: The data types of the arguments. return_type: The data type of the return value. volatility: See :py:class:`Volatility` for allowed values. diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index e018d38b9..3336a0acb 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -303,7 +303,11 @@ def test_register_udwf(ctx, count_window_df): ctx.register_udwf(window_count) result = ctx.sql( - "SELECT window_count(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" + """ + SELECT window_count(a) + OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED + FOLLOWING) FROM test_table + """ ).collect()[0] assert result.column(0) == pa.array([0, 1, 2]) @@ -487,6 +491,9 @@ def test_udwf_named_function(ctx, count_window_df): ctx.register_udwf(window_count) result = ctx.sql( - "SELECT my_custom_counter(a) OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) FROM test_table" + """ + SELECT my_custom_counter(a) + OVER (ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED + FOLLOWING) FROM test_table""" ).collect()[0] assert result.column(0) == pa.array([0, 1, 2]) From ad33378402dfdb2329551e6c6be7680c9e64d16a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 18:25:05 +0800 Subject: [PATCH 14/18] fix C901 for def udwf --- python/datafusion/udf.py | 115 +++++++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 48 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 504561a41..96efd00ba 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -646,7 +646,7 @@ def udwf( ) -> WindowUDF: ... @staticmethod - def udwf(*args: Any, **kwargs: Any): # noqa: D417, C901 + def udwf(*args: Any, **kwargs: Any): # noqa: D417 """Create a new User-Defined Window Function (UDWF). This class can be used both as a **function** and as a **decorator**. @@ -697,59 +697,78 @@ def biased_numbers() -> BiasedNumbers: Returns: A user-defined window function that can be used in window function calls. """ + if args and callable(args[0]): + # Case 1: Used as a function, require the first parameter to be callable + return WindowUDF._create_window_udf(*args, **kwargs) + # Case 2: Used as a decorator with parameters + return WindowUDF._create_window_udf_decorator(*args, **kwargs) - def _function( - func: Callable[[], WindowEvaluator], - input_types: pa.DataType | list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> WindowUDF: - if not callable(func): - msg = "`func` argument must be callable" - raise TypeError(msg) - if not isinstance(func(), WindowEvaluator): - msg = "`func` must implement the abstract base class WindowEvaluator" - raise TypeError(msg) - if name is None: - if hasattr(func, "__qualname__"): - name = func.__qualname__.lower() - else: - name = func.__class__.__name__.lower() - if isinstance(input_types, pa.DataType): - input_types = [input_types] - return WindowUDF( - name=name, - func=func, - input_types=input_types, - return_type=return_type, - volatility=volatility, - ) + @staticmethod + def _create_window_udf( + func: Callable[[], WindowEvaluator], + input_types: pa.DataType | list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> WindowUDF: + """Create a WindowUDF instance from function arguments.""" + if not callable(func): + msg = "`func` argument must be callable" + raise TypeError(msg) + if not isinstance(func(), WindowEvaluator): + msg = "`func` must implement the abstract base class WindowEvaluator" + raise TypeError(msg) + + if name is None: + name = WindowUDF._get_default_name(func) + + input_types_list = WindowUDF._normalize_input_types(input_types) + + return WindowUDF( + name=name, + func=func, + input_types=input_types_list, + return_type=return_type, + volatility=volatility, + ) - def _decorator( - input_types: pa.DataType | list[pa.DataType], - return_type: _R, - volatility: Volatility | str, - name: Optional[str] = None, - ) -> Callable[..., Callable[..., Expr]]: - def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: - udwf_caller = WindowUDF.udwf( - func, input_types, return_type, volatility, name - ) + @staticmethod + def _get_default_name(func: Callable) -> str: + """Get the default name for a function based on its attributes.""" + if hasattr(func, "__qualname__"): + return func.__qualname__.lower() + return func.__class__.__name__.lower() - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Expr: - return udwf_caller(*args, **kwargs) + @staticmethod + def _normalize_input_types( + input_types: pa.DataType | list[pa.DataType], + ) -> list[pa.DataType]: + """Convert a single DataType to a list if needed.""" + if isinstance(input_types, pa.DataType): + return [input_types] + return input_types - return wrapper + @staticmethod + def _create_window_udf_decorator( + input_types: pa.DataType | list[pa.DataType], + return_type: _R, + volatility: Volatility | str, + name: Optional[str] = None, + ) -> Callable[..., Callable[..., Expr]]: + """Create a decorator for a WindowUDF.""" - return decorator + def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: + udwf_caller = WindowUDF._create_window_udf( + func, input_types, return_type, volatility, name + ) - if args and callable(args[0]): - # Case 1: Used as a function, require the first parameter to be callable - return _function(*args, **kwargs) - # Case 2: Used as a decorator with parameters - return _decorator(*args, **kwargs) + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Expr: + return udwf_caller(*args, **kwargs) + + return wrapper + + return decorator # Convenience exports so we can import instead of treating as From 6f253372238cf63c8d75a6717674c5a2b231794a Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 18:38:27 +0800 Subject: [PATCH 15/18] refactor: Update udwf method signature and simplify input handling - Changed the type hint for the return type in the _create_window_udf_decorator method to use pa.DataType directly instead of a TypeVar. - Simplified the handling of input types by removing redundant checks and directly using the input types list. - Removed unnecessary comments and cleaned up the code for better readability. - Updated the test for udwf to use parameterized tests for better coverage and maintainability. --- python/datafusion/udf.py | 29 +++++++------------- python/tests/test_udwf.py | 57 +++++++++------------------------------ 2 files changed, 22 insertions(+), 64 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 96efd00ba..5e55f0f28 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,7 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, overload +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload import pyarrow as pa @@ -30,12 +30,7 @@ from datafusion.expr import Expr if TYPE_CHECKING: - # for python 3.10 and above, we can use - # from typing import TypeAlias - # but for python 3.9, we use the following - from typing_extensions import TypeAlias - - _R: TypeAlias = pa.DataType + _R = TypeVar("_R", bound=pa.DataType) class Volatility(Enum): @@ -719,19 +714,13 @@ def _create_window_udf( msg = "`func` must implement the abstract base class WindowEvaluator" raise TypeError(msg) - if name is None: - name = WindowUDF._get_default_name(func) - - input_types_list = WindowUDF._normalize_input_types(input_types) - - return WindowUDF( - name=name, - func=func, - input_types=input_types_list, - return_type=return_type, - volatility=volatility, + name = name or func.__qualname__.lower() + input_types = ( + [input_types] if isinstance(input_types, pa.DataType) else input_types ) + return WindowUDF(name, func, input_types, return_type, volatility) + @staticmethod def _get_default_name(func: Callable) -> str: """Get the default name for a function based on its attributes.""" @@ -751,10 +740,10 @@ def _normalize_input_types( @staticmethod def _create_window_udf_decorator( input_types: pa.DataType | list[pa.DataType], - return_type: _R, + return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, - ) -> Callable[..., Callable[..., Expr]]: + ) -> Callable[[Callable[[], WindowEvaluator]], Callable[..., Expr]]: """Create a decorator for a WindowUDF.""" def decorator(func: Callable[[], WindowEvaluator]) -> Callable[..., Expr]: diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 3336a0acb..4190e7d64 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -428,55 +428,24 @@ def test_udwf_functions(complex_window_df, name, expr, expected): assert result.column(0) == pa.array(expected) -def test_udwf_overloads(count_window_df): - """Test different overload patterns for UDWF function.""" - # Single input type syntax - single_input = udwf( - SimpleWindowCount, pa.int64(), pa.int64(), volatility="immutable" - ) - - # List of input types syntax - list_input = udwf( - SimpleWindowCount, [pa.int64()], pa.int64(), volatility="immutable" - ) - - # Decorator syntax with single input type - @udwf(pa.int64(), pa.int64(), "immutable") - def window_count_single() -> WindowEvaluator: - return SimpleWindowCount() - - # Decorator syntax with list of input types - @udwf([pa.int64()], pa.int64(), "immutable") - def window_count_list() -> WindowEvaluator: - return SimpleWindowCount() - - # Test all variants produce the same result +@pytest.mark.parametrize( + "udwf_func", + [ + udwf(SimpleWindowCount, pa.int64(), pa.int64(), "immutable"), + udwf(SimpleWindowCount, [pa.int64()], pa.int64(), "immutable"), + udwf([pa.int64()], pa.int64(), "immutable")(lambda: SimpleWindowCount()), + udwf(pa.int64(), pa.int64(), "immutable")(lambda: SimpleWindowCount()), + ], +) +def test_udwf_overloads(udwf_func, count_window_df): df = count_window_df.select( - single_input(column("a")) + udwf_func(column("a")) .window_frame(WindowFrame("rows", None, None)) .build() - .alias("single"), - list_input(column("a")) - .window_frame(WindowFrame("rows", None, None)) - .build() - .alias("list"), - window_count_single(column("a")) - .window_frame(WindowFrame("rows", None, None)) - .build() - .alias("decorator_single"), - window_count_list(column("a")) - .window_frame(WindowFrame("rows", None, None)) - .build() - .alias("decorator_list"), + .alias("count") ) - result = df.collect()[0] - expected = pa.array([0, 1, 2]) - - assert result.column(0) == expected - assert result.column(1) == expected - assert result.column(2) == expected - assert result.column(3) == expected + assert result.column(0) == pa.array([0, 1, 2]) def test_udwf_named_function(ctx, count_window_df): From 20d5dd9ff55b246e0d6e9d9c3496523aa3d6d995 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Thu, 13 Mar 2025 18:55:13 +0800 Subject: [PATCH 16/18] refactor: Rename input_type to input_types in udwf method signature for clarity --- python/datafusion/udf.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 5e55f0f28..e93a34ca5 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -624,7 +624,7 @@ def __call__(self, *args: Expr) -> Expr: @overload @staticmethod def udwf( - input_type: pa.DataType | list[pa.DataType], + input_types: pa.DataType | list[pa.DataType], return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, @@ -634,7 +634,7 @@ def udwf( @staticmethod def udwf( func: Callable[[], WindowEvaluator], - input_type: pa.DataType | list[pa.DataType], + input_types: pa.DataType | list[pa.DataType], return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, @@ -702,13 +702,13 @@ def biased_numbers() -> BiasedNumbers: def _create_window_udf( func: Callable[[], WindowEvaluator], input_types: pa.DataType | list[pa.DataType], - return_type: _R, + return_type: pa.DataType, volatility: Volatility | str, name: Optional[str] = None, ) -> WindowUDF: """Create a WindowUDF instance from function arguments.""" if not callable(func): - msg = "`func` argument must be callable" + msg = "`func` must be callable." raise TypeError(msg) if not isinstance(func(), WindowEvaluator): msg = "`func` must implement the abstract base class WindowEvaluator" From 16dbe5f3fd88f42d0a304384b162009bd9e49a35 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 14 Mar 2025 17:26:18 +0800 Subject: [PATCH 17/18] refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization --- python/datafusion/udf.py | 18 ++++++++++---- python/tests/test_udwf.py | 50 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index e93a34ca5..03d8328ce 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,7 +22,16 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Optional, + Protocol, + runtime_checkable, + overload, + TypeVar, +) import pyarrow as pa @@ -429,8 +438,9 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) -class WindowEvaluator: - """Evaluator class for user-defined window functions (UDWF). +@runtime_checkable +class WindowEvaluator(Protocol): + """Protocol defining interface for user-defined window functions (UDWF). It is up to the user to decide which evaluate function is appropriate. @@ -711,7 +721,7 @@ def _create_window_udf( msg = "`func` must be callable." raise TypeError(msg) if not isinstance(func(), WindowEvaluator): - msg = "`func` must implement the abstract base class WindowEvaluator" + msg = "`func` must implement the WindowEvaluator protocol" raise TypeError(msg) name = name or func.__qualname__.lower() diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index 4190e7d64..c267e81f0 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df): def test_udwf_errors_with_message(): """Test error cases for UDWF creation.""" with pytest.raises( - TypeError, match="`func` must implement the abstract base class WindowEvaluator" + TypeError, match="`func` must implement the WindowEvaluator protocol" ): udwf( NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable" @@ -466,3 +466,51 @@ def test_udwf_named_function(ctx, count_window_df): FOLLOWING) FROM test_table""" ).collect()[0] assert result.column(0) == pa.array([0, 1, 2]) + + +def test_window_evaluator_protocol(count_window_df): + """Test that WindowEvaluator works as a Protocol without explicit inheritance.""" + + # Define a class that implements the Protocol interface without inheriting + class CounterWithoutInheritance: + def __init__(self, base: int = 0) -> None: + self.base = base + + def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: + return pa.array([self.base + i for i in range(num_rows)]) + + # Protocol methods with default implementations don't need to be defined + + # Create a UDWF using the class that doesn't inherit from WindowEvaluator + protocol_counter = udwf( + CounterWithoutInheritance, pa.int64(), pa.int64(), volatility="immutable" + ) + + # Use the window function + df = count_window_df.select( + protocol_counter(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + + result = df.collect()[0] + assert result.column(0) == pa.array([0, 1, 2]) + + # Also test with constructor args + protocol_counter_with_args = udwf( + lambda: CounterWithoutInheritance(10), + pa.int64(), + pa.int64(), + volatility="immutable", + ) + + df = count_window_df.select( + protocol_counter_with_args(column("a")) + .window_frame(WindowFrame("rows", None, None)) + .build() + .alias("count") + ) + + result = df.collect()[0] + assert result.column(0) == pa.array([10, 11, 12]) From 78c0203f718f8201b26441be2742d912fdbc013d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Fri, 14 Mar 2025 17:42:33 +0800 Subject: [PATCH 18/18] Revert "refactor: Enhance typing in udf.py by introducing Protocol for WindowEvaluator and improving import organization" This reverts commit 16dbe5f3fd88f42d0a304384b162009bd9e49a35. --- python/datafusion/udf.py | 18 ++++---------- python/tests/test_udwf.py | 50 +-------------------------------------- 2 files changed, 5 insertions(+), 63 deletions(-) diff --git a/python/datafusion/udf.py b/python/datafusion/udf.py index 03d8328ce..e93a34ca5 100644 --- a/python/datafusion/udf.py +++ b/python/datafusion/udf.py @@ -22,16 +22,7 @@ import functools from abc import ABCMeta, abstractmethod from enum import Enum -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Optional, - Protocol, - runtime_checkable, - overload, - TypeVar, -) +from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, overload import pyarrow as pa @@ -438,9 +429,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Expr: return _decorator(*args, **kwargs) -@runtime_checkable -class WindowEvaluator(Protocol): - """Protocol defining interface for user-defined window functions (UDWF). +class WindowEvaluator: + """Evaluator class for user-defined window functions (UDWF). It is up to the user to decide which evaluate function is appropriate. @@ -721,7 +711,7 @@ def _create_window_udf( msg = "`func` must be callable." raise TypeError(msg) if not isinstance(func(), WindowEvaluator): - msg = "`func` must implement the WindowEvaluator protocol" + msg = "`func` must implement the abstract base class WindowEvaluator" raise TypeError(msg) name = name or func.__qualname__.lower() diff --git a/python/tests/test_udwf.py b/python/tests/test_udwf.py index c267e81f0..4190e7d64 100644 --- a/python/tests/test_udwf.py +++ b/python/tests/test_udwf.py @@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df): def test_udwf_errors_with_message(): """Test error cases for UDWF creation.""" with pytest.raises( - TypeError, match="`func` must implement the WindowEvaluator protocol" + TypeError, match="`func` must implement the abstract base class WindowEvaluator" ): udwf( NotSubclassOfWindowEvaluator, pa.int64(), pa.int64(), volatility="immutable" @@ -466,51 +466,3 @@ def test_udwf_named_function(ctx, count_window_df): FOLLOWING) FROM test_table""" ).collect()[0] assert result.column(0) == pa.array([0, 1, 2]) - - -def test_window_evaluator_protocol(count_window_df): - """Test that WindowEvaluator works as a Protocol without explicit inheritance.""" - - # Define a class that implements the Protocol interface without inheriting - class CounterWithoutInheritance: - def __init__(self, base: int = 0) -> None: - self.base = base - - def evaluate_all(self, values: list[pa.Array], num_rows: int) -> pa.Array: - return pa.array([self.base + i for i in range(num_rows)]) - - # Protocol methods with default implementations don't need to be defined - - # Create a UDWF using the class that doesn't inherit from WindowEvaluator - protocol_counter = udwf( - CounterWithoutInheritance, pa.int64(), pa.int64(), volatility="immutable" - ) - - # Use the window function - df = count_window_df.select( - protocol_counter(column("a")) - .window_frame(WindowFrame("rows", None, None)) - .build() - .alias("count") - ) - - result = df.collect()[0] - assert result.column(0) == pa.array([0, 1, 2]) - - # Also test with constructor args - protocol_counter_with_args = udwf( - lambda: CounterWithoutInheritance(10), - pa.int64(), - pa.int64(), - volatility="immutable", - ) - - df = count_window_df.select( - protocol_counter_with_args(column("a")) - .window_frame(WindowFrame("rows", None, None)) - .build() - .alias("count") - ) - - result = df.collect()[0] - assert result.column(0) == pa.array([10, 11, 12])