|
28 | 28 | class Summarize(Accumulator): |
29 | 29 | """Interface of a user-defined accumulation.""" |
30 | 30 |
|
31 | | - def __init__(self, initial_value: float = 0.0): |
32 | | - self._sum = pa.scalar(initial_value) |
| 31 | + def __init__(self, initial_value: float = 0.0, as_scalar: bool = False): |
| 32 | + self._sum = initial_value |
| 33 | + self.as_scalar = as_scalar |
33 | 34 |
|
34 | 35 | def state(self) -> list[pa.Scalar]: |
| 36 | + if self.as_scalar: |
| 37 | + return [pa.scalar(self._sum)] |
35 | 38 | return [self._sum] |
36 | 39 |
|
37 | 40 | def update(self, values: pa.Array) -> None: |
38 | 41 | # Not nice since pyarrow scalars can't be summed yet. |
39 | 42 | # This breaks on `None` |
40 | | - self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py()) |
| 43 | + self._sum = self._sum + pc.sum(values).as_py() |
41 | 44 |
|
42 | 45 | def merge(self, states: list[pa.Array]) -> None: |
43 | 46 | # Not nice since pyarrow scalars can't be summed yet. |
44 | 47 | # This breaks on `None` |
45 | | - self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py()) |
| 48 | + self._sum = self._sum + pc.sum(states[0]).as_py() |
46 | 49 |
|
47 | 50 | def evaluate(self) -> pa.Scalar: |
| 51 | + if self.as_scalar: |
| 52 | + return pa.scalar(self._sum) |
48 | 53 | return self._sum |
49 | 54 |
|
50 | 55 |
|
@@ -163,11 +168,12 @@ def summarize(): |
163 | 168 | assert result.column(0) == pa.array([1.0 + 2.0 + 3.0]) |
164 | 169 |
|
165 | 170 |
|
166 | | -def test_udaf_aggregate_with_arguments(df): |
| 171 | +@pytest.mark.parametrize("as_scalar", [True, False]) |
| 172 | +def test_udaf_aggregate_with_arguments(df, as_scalar): |
167 | 173 | bias = 10.0 |
168 | 174 |
|
169 | 175 | summarize = udaf( |
170 | | - lambda: Summarize(bias), |
| 176 | + lambda: Summarize(initial_value=bias, as_scalar=as_scalar), |
171 | 177 | pa.float64(), |
172 | 178 | pa.float64(), |
173 | 179 | [pa.float64()], |
|
0 commit comments