@@ -218,7 +218,7 @@ def test_udwf_errors(complex_window_df):
218218def test_udwf_errors_with_message ():
219219 """Test error cases for UDWF creation."""
220220 with pytest .raises (
221- TypeError , match = "`func` must implement the abstract base class WindowEvaluator"
221+ TypeError , match = "`func` must implement the WindowEvaluator protocol "
222222 ):
223223 udwf (
224224 NotSubclassOfWindowEvaluator , pa .int64 (), pa .int64 (), volatility = "immutable"
@@ -466,3 +466,51 @@ def test_udwf_named_function(ctx, count_window_df):
466466 FOLLOWING) FROM test_table"""
467467 ).collect ()[0 ]
468468 assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
469+
470+
471+ def test_window_evaluator_protocol (count_window_df ):
472+ """Test that WindowEvaluator works as a Protocol without explicit inheritance."""
473+
474+ # Define a class that implements the Protocol interface without inheriting
475+ class CounterWithoutInheritance :
476+ def __init__ (self , base : int = 0 ) -> None :
477+ self .base = base
478+
479+ def evaluate_all (self , values : list [pa .Array ], num_rows : int ) -> pa .Array :
480+ return pa .array ([self .base + i for i in range (num_rows )])
481+
482+ # Protocol methods with default implementations don't need to be defined
483+
484+ # Create a UDWF using the class that doesn't inherit from WindowEvaluator
485+ protocol_counter = udwf (
486+ CounterWithoutInheritance , pa .int64 (), pa .int64 (), volatility = "immutable"
487+ )
488+
489+ # Use the window function
490+ df = count_window_df .select (
491+ protocol_counter (column ("a" ))
492+ .window_frame (WindowFrame ("rows" , None , None ))
493+ .build ()
494+ .alias ("count" )
495+ )
496+
497+ result = df .collect ()[0 ]
498+ assert result .column (0 ) == pa .array ([0 , 1 , 2 ])
499+
500+ # Also test with constructor args
501+ protocol_counter_with_args = udwf (
502+ lambda : CounterWithoutInheritance (10 ),
503+ pa .int64 (),
504+ pa .int64 (),
505+ volatility = "immutable" ,
506+ )
507+
508+ df = count_window_df .select (
509+ protocol_counter_with_args (column ("a" ))
510+ .window_frame (WindowFrame ("rows" , None , None ))
511+ .build ()
512+ .alias ("count" )
513+ )
514+
515+ result = df .collect ()[0 ]
516+ assert result .column (0 ) == pa .array ([10 , 11 , 12 ])
0 commit comments