Skip to content

Commit ec66d20

Browse files
committed
Refactor test UDAF registration by removing unused accumulator code
1 parent 3154a6b commit ec66d20

File tree

1 file changed

+3
-39
lines changed

1 file changed

+3
-39
lines changed
Lines changed: 3 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,4 @@
1-
import pyarrow as pa
2-
from datafusion import udaf, SessionContext
3-
from datafusion.user_defined import Accumulator # base class for aggregators
4-
5-
# Define a simple test accumulator for demonstration:
6-
class TestAccumulator(Accumulator):
7-
def __init__(self) -> None:
8-
self.total = 0
9-
10-
def state(self) -> list[pa.Scalar]:
11-
return [pa.scalar(self.total)]
12-
13-
def update(self, *values: pa.Array) -> None:
14-
# Sum up integer values from the first argument
15-
self.total += sum(value.as_py() for value in values[0])
16-
17-
def merge(self, states: list[pa.Array]) -> None:
18-
# Assumes the state is a list with one scalar integer per actor
19-
self.total += sum(state[0].as_py() for state in states)
20-
21-
def evaluate(self) -> pa.Scalar:
22-
return pa.scalar(self.total)
23-
24-
# Create the test UDAF using TestAccumulator.
25-
# Note: the overload taking (accum, input_types, return_type, state_type, volatility, name)
26-
test_udaf = udaf(
27-
TestAccumulator, # accumulator function or type producing an Accumulator object
28-
[pa.int64()], # input types (list of one int64)
29-
pa.int64(), # return type
30-
[pa.int64()], # state type (list of one int64)
31-
"immutable", # volatility indicator
32-
name="test_udaf"
33-
)
34-
35-
# Register UDAF into a session context (if needed)
1+
from datafusion import SessionContext, udf, udaf
2+
from geodatafusion import native
363
ctx = SessionContext()
37-
ctx.register_udaf(test_udaf)
38-
39-
# The code should type check without error:
40-
print("Type checking passed for test_udaf!")
4+
ctx.register_udaf(udaf(native.Extent()))

0 commit comments

Comments
 (0)