Skip to content

Commit 67a6bc1

Browse files
committed
Update tests to pass back raw python values or pyarrow scalar
1 parent 390d753 commit 67a6bc1

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

python/tests/test_udaf.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,28 @@
2828
class Summarize(Accumulator):
2929
"""Interface of a user-defined accumulation."""
3030

31-
def __init__(self, initial_value: float = 0.0):
32-
self._sum = pa.scalar(initial_value)
31+
def __init__(self, initial_value: float = 0.0, as_scalar: bool = False):
32+
self._sum = initial_value
33+
self.as_scalar = as_scalar
3334

3435
def state(self) -> list[pa.Scalar]:
36+
if self.as_scalar:
37+
return [pa.scalar(self._sum)]
3538
return [self._sum]
3639

3740
def update(self, values: pa.Array) -> None:
3841
# Not nice since pyarrow scalars can't be summed yet.
3942
# This breaks on `None`
40-
self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())
43+
self._sum = self._sum + pc.sum(values).as_py()
4144

4245
def merge(self, states: list[pa.Array]) -> None:
4346
# Not nice since pyarrow scalars can't be summed yet.
4447
# This breaks on `None`
45-
self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())
48+
self._sum = self._sum + pc.sum(states[0]).as_py()
4649

4750
def evaluate(self) -> pa.Scalar:
51+
if self.as_scalar:
52+
return pa.scalar(self._sum)
4853
return self._sum
4954

5055

@@ -163,11 +168,12 @@ def summarize():
163168
assert result.column(0) == pa.array([1.0 + 2.0 + 3.0])
164169

165170

166-
def test_udaf_aggregate_with_arguments(df):
171+
@pytest.mark.parametrize("as_scalar", [True, False])
172+
def test_udaf_aggregate_with_arguments(df, as_scalar):
167173
bias = 10.0
168174

169175
summarize = udaf(
170-
lambda: Summarize(bias),
176+
lambda: Summarize(initial_value=bias, as_scalar=as_scalar),
171177
pa.float64(),
172178
pa.float64(),
173179
[pa.float64()],

0 commit comments

Comments
 (0)