Skip to content

Commit 7a908f2

Browse files
authored
fix: Graceful shutdown for all UDFs (#337)
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent bcdc4c3 commit 7a908f2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+2419
-225
lines changed

packages/pynumaflow/pynumaflow/accumulator/async_server.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import asyncio
2+
import contextlib
13
import inspect
4+
import sys
25

36
import aiorun
47
import grpc
58

69
from pynumaflow.accumulator.servicer.async_servicer import AsyncAccumulatorServicer
10+
from pynumaflow.info.server import write as info_server_write
711
from pynumaflow.info.types import ServerInfo, ContainerType, MINIMUM_NUMAFLOW_VERSION
812
from pynumaflow.proto.accumulator import accumulator_pb2_grpc
913

@@ -15,6 +19,7 @@
1519
MAX_NUM_THREADS,
1620
ACCUMULATOR_SOCK_PATH,
1721
ACCUMULATOR_SERVER_INFO_FILE_PATH,
22+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
1823
)
1924

2025
from pynumaflow.accumulator._dtypes import (
@@ -23,7 +28,7 @@
2328
Accumulator,
2429
)
2530

26-
from pynumaflow.shared.server import NumaflowServer, check_instance, start_async_server
31+
from pynumaflow.shared.server import NumaflowServer, check_instance
2732

2833

2934
def get_handler(
@@ -157,6 +162,7 @@ def __init__(
157162
]
158163
# Get the servicer instance for the async server
159164
self.servicer = AsyncAccumulatorServicer(self.accumulator_handler)
165+
self._error: BaseException | None = None
160166

161167
def start(self):
162168
"""
@@ -167,6 +173,9 @@ def start(self):
167173
"Starting Async Accumulator Server",
168174
)
169175
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
176+
if self._error:
177+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
178+
sys.exit(1)
170179

171180
async def aexec(self):
172181
"""
@@ -176,18 +185,62 @@ async def aexec(self):
176185
# As the server is async, we need to create a new server instance in the
177186
# same thread as the event loop so that all the async calls are made in the
178187
# same context
179-
# Create a new async server instance and add the servicer to it
180188
server = grpc.aio.server(options=self._server_options)
181189
server.add_insecure_port(self.sock_path)
190+
191+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
192+
# because it must be bound to the running event loop that aiorun creates.
193+
# At __init__ time no event loop exists yet.
194+
shutdown_event = asyncio.Event()
195+
self.servicer.set_shutdown_event(shutdown_event)
196+
182197
accumulator_pb2_grpc.add_AccumulatorServicer_to_server(self.servicer, server)
183198

184199
serv_info = ServerInfo.get_default_server_info()
185200
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Accumulator]
186-
await start_async_server(
187-
server_async=server,
188-
sock_path=self.sock_path,
189-
max_threads=self.max_threads,
190-
cleanup_coroutines=list(),
191-
server_info_file=self.server_info_file,
192-
server_info=serv_info,
201+
202+
await server.start()
203+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
204+
205+
_LOGGER.info(
206+
"Async GRPC Server listening on: %s with max threads: %s",
207+
self.sock_path,
208+
self.max_threads,
193209
)
210+
211+
async def _watch_for_shutdown():
212+
"""Wait for the shutdown event and stop the server with a grace period."""
213+
await shutdown_event.wait()
214+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
215+
# Stop accepting new requests and wait for a maximum of
216+
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
217+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
218+
219+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
220+
try:
221+
await server.wait_for_termination()
222+
except asyncio.CancelledError:
223+
# SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error
224+
# path (where _watch_for_shutdown calls server.stop()), this path
225+
# must stop the gRPC server explicitly. Without this, the server
226+
# object is never stopped and when it is garbage-collected, its
227+
# __del__ tries to schedule a cleanup coroutine on an event loop
228+
# that is already closed, causing errors/warnings.
229+
_LOGGER.info("Received cancellation, stopping server gracefully...")
230+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
231+
232+
# Propagate error so start() can exit with a non-zero code
233+
self._error = self.servicer._error
234+
235+
shutdown_task.cancel()
236+
with contextlib.suppress(asyncio.CancelledError):
237+
await shutdown_task
238+
239+
_LOGGER.info("Stopping event loop...")
240+
# We use aiorun to manage the event loop. The aiorun.run() runs
241+
# forever until loop.stop() is called. If we don't stop the
242+
# event loop explicitly here, the python process will not exit.
243+
# It reamins stuck for 5 minutes until liveness and readiness probe
244+
# fails enough times and k8s sends a SIGTERM
245+
asyncio.get_running_loop().stop()
246+
_LOGGER.info("Event loop stopped")

packages/pynumaflow/pynumaflow/accumulator/servicer/async_servicer.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from google.protobuf import empty_pb2 as _empty_pb2
55

6-
from pynumaflow._constants import ERR_UDF_EXCEPTION_STRING
6+
from pynumaflow._constants import _LOGGER, ERR_UDF_EXCEPTION_STRING
77
from pynumaflow.proto.accumulator import accumulator_pb2, accumulator_pb2_grpc
88
from pynumaflow.accumulator._dtypes import (
99
Datum,
@@ -13,7 +13,7 @@
1313
KeyedWindow,
1414
)
1515
from pynumaflow.accumulator.servicer.task_manager import TaskManager
16-
from pynumaflow.shared.server import handle_async_error
16+
from pynumaflow.shared.server import update_context_err
1717
from pynumaflow.types import NumaflowServicerContext
1818

1919

@@ -57,6 +57,12 @@ def __init__(
5757
):
5858
# The accumulator handler can be a function or a builder class instance.
5959
self.__accumulator_handler: AccumulatorAsyncCallable | _AccumulatorBuilderClass = handler
60+
self._shutdown_event: asyncio.Event | None = None
61+
self._error: BaseException | None = None
62+
63+
def set_shutdown_event(self, event: asyncio.Event):
64+
"""Wire up the shutdown event created by the server's aexec() coroutine."""
65+
self._shutdown_event = event
6066

6167
async def AccumulateFn(
6268
self,
@@ -104,20 +110,49 @@ async def AccumulateFn(
104110
async for msg in consumer:
105111
# If the message is an exception, we raise the exception
106112
if isinstance(msg, BaseException):
107-
await handle_async_error(context, msg, ERR_UDF_EXCEPTION_STRING)
113+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
114+
_LOGGER.critical(err_msg, exc_info=True)
115+
update_context_err(context, msg, err_msg)
116+
self._error = msg
117+
if self._shutdown_event is not None:
118+
self._shutdown_event.set()
108119
return
109120
# Send window EOF response or Window result response
110121
# back to the client
111122
else:
112123
yield msg
124+
except asyncio.CancelledError:
125+
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
126+
_LOGGER.info("Server shutting down, cancelling RPC.")
127+
if self._shutdown_event is not None:
128+
self._shutdown_event.set()
129+
return
130+
113131
except BaseException as e:
114-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
132+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
133+
_LOGGER.critical(err_msg, exc_info=True)
134+
update_context_err(context, e, err_msg)
135+
self._error = e
136+
if self._shutdown_event is not None:
137+
self._shutdown_event.set()
115138
return
116139
# Wait for the process_input_stream task to finish for a clean exit
117140
try:
118141
await producer
142+
except asyncio.CancelledError:
143+
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
144+
_LOGGER.info("Server shutting down, cancelling RPC.")
145+
if self._shutdown_event is not None:
146+
self._shutdown_event.set()
147+
return
148+
119149
except BaseException as e:
120-
await handle_async_error(context, e, ERR_UDF_EXCEPTION_STRING)
150+
err_msg = f"{ERR_UDF_EXCEPTION_STRING}: {repr(e)}"
151+
_LOGGER.critical(err_msg, exc_info=True)
152+
update_context_err(context, e, err_msg)
153+
self._error = e
154+
if self._shutdown_event is not None:
155+
self._shutdown_event.set()
121156
return
122157

123158
async def IsReady(

packages/pynumaflow/pynumaflow/accumulator/servicer/task_manager.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ async def __invoke_accumulator(
213213
_ = await new_instance(request_iterator, output)
214214
# send EOF to the output stream
215215
await output.put(STREAM_EOF)
216+
except asyncio.CancelledError:
217+
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
218+
return
216219
# If there is an error in the accumulator operation, log and
217220
# then send the error to the result queue
218221
except BaseException as err:
@@ -243,6 +246,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator
243246
case _:
244247
_LOGGER.debug(f"No operation matched for request: {request}", exc_info=True)
245248

249+
except asyncio.CancelledError:
250+
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
251+
return
246252
# If there is an error in the accumulator operation, log and
247253
# then send the error to the result queue
248254
except BaseException as e:
@@ -274,6 +280,9 @@ async def process_input_stream(self, request_iterator: AsyncIterable[Accumulator
274280

275281
# Now send STREAM_EOF to terminate the global result queue iterator
276282
await self.global_result_queue.put(STREAM_EOF)
283+
except asyncio.CancelledError:
284+
# Task cancelled during shutdown (e.g. SIGTERM) — not a UDF fault.
285+
return
277286
except BaseException as e:
278287
err_msg = f"Accumulator Streaming Error: {repr(e)}"
279288
_LOGGER.critical(err_msg, exc_info=True)

packages/pynumaflow/pynumaflow/batchmapper/async_server.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import asyncio
2+
import contextlib
3+
import sys
4+
15
import aiorun
26
import grpc
37

@@ -8,9 +12,11 @@
812
BATCH_MAP_SOCK_PATH,
913
MAP_SERVER_INFO_FILE_PATH,
1014
MAX_NUM_THREADS,
15+
NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS,
1116
)
1217
from pynumaflow.batchmapper._dtypes import BatchMapCallable
1318
from pynumaflow.batchmapper.servicer.async_servicer import AsyncBatchMapServicer
19+
from pynumaflow.info.server import write as info_server_write
1420
from pynumaflow.info.types import (
1521
ServerInfo,
1622
MAP_MODE_KEY,
@@ -19,7 +25,7 @@
1925
ContainerType,
2026
)
2127
from pynumaflow.proto.mapper import map_pb2_grpc
22-
from pynumaflow.shared.server import NumaflowServer, start_async_server
28+
from pynumaflow.shared.server import NumaflowServer
2329

2430

2531
class BatchMapAsyncServer(NumaflowServer):
@@ -92,13 +98,17 @@ async def handler(
9298
]
9399

94100
self.servicer = AsyncBatchMapServicer(handler=self.batch_mapper_instance)
101+
self._error: BaseException | None = None
95102

96103
def start(self):
97104
"""
98105
Starter function for the Async Batch Map server, we need a separate caller
99106
to the aexec so that all the async coroutines can be started from a single context
100107
"""
101108
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
109+
if self._error:
110+
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
111+
sys.exit(1)
102112

103113
async def aexec(self):
104114
"""
@@ -108,25 +118,64 @@ async def aexec(self):
108118
# As the server is async, we need to create a new server instance in the
109119
# same thread as the event loop so that all the async calls are made in the
110120
# same context
111-
# Create a new async server instance and add the servicer to it
112121
server = grpc.aio.server(options=self._server_options)
113122
server.add_insecure_port(self.sock_path)
114-
map_pb2_grpc.add_MapServicer_to_server(
115-
self.servicer,
116-
server,
117-
)
118-
_LOGGER.info("Starting Batch Map Server")
123+
124+
# The asyncio.Event must be created here (inside aexec) rather than in __init__,
125+
# because it must be bound to the running event loop that aiorun creates.
126+
# At __init__ time no event loop exists yet.
127+
shutdown_event = asyncio.Event()
128+
self.servicer.set_shutdown_event(shutdown_event)
129+
130+
map_pb2_grpc.add_MapServicer_to_server(self.servicer, server)
131+
119132
serv_info = ServerInfo.get_default_server_info()
120133
serv_info.minimum_numaflow_version = MINIMUM_NUMAFLOW_VERSION[ContainerType.Mapper]
121134
# Add the MAP_MODE metadata to the server info for the correct map mode
122135
serv_info.metadata[MAP_MODE_KEY] = MapMode.BatchMap
123136

124-
# Start the async server
125-
await start_async_server(
126-
server_async=server,
127-
sock_path=self.sock_path,
128-
max_threads=self.max_threads,
129-
cleanup_coroutines=list(),
130-
server_info_file=self.server_info_file,
131-
server_info=serv_info,
137+
await server.start()
138+
info_server_write(server_info=serv_info, info_file=self.server_info_file)
139+
140+
_LOGGER.info(
141+
"Async GRPC Server listening on: %s with max threads: %s",
142+
self.sock_path,
143+
self.max_threads,
132144
)
145+
146+
async def _watch_for_shutdown():
147+
"""Wait for the shutdown event and stop the server with a grace period."""
148+
await shutdown_event.wait()
149+
_LOGGER.info("Shutdown signal received, stopping server gracefully...")
150+
# Stop accepting new requests and wait for a maximum of
151+
# NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS seconds for in-flight requests to complete
152+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
153+
154+
shutdown_task = asyncio.create_task(_watch_for_shutdown())
155+
try:
156+
await server.wait_for_termination()
157+
except asyncio.CancelledError:
158+
# SIGTERM received — aiorun cancels all tasks. Unlike the UDF-error
159+
# path (where _watch_for_shutdown calls server.stop()), this path
160+
# must stop the gRPC server explicitly. Without this, the server
161+
# object is never stopped and when it is garbage-collected, its
162+
# __del__ tries to schedule a cleanup coroutine on an event loop
163+
# that is already closed, causing errors/warnings.
164+
_LOGGER.info("Received cancellation, stopping server gracefully...")
165+
await server.stop(NUMAFLOW_GRPC_SHUTDOWN_GRACE_PERIOD_SECONDS)
166+
167+
# Propagate error so start() can exit with a non-zero code
168+
self._error = self.servicer._error
169+
170+
shutdown_task.cancel()
171+
with contextlib.suppress(asyncio.CancelledError):
172+
await shutdown_task
173+
174+
_LOGGER.info("Stopping event loop...")
175+
# We use aiorun to manage the event loop. The aiorun.run() runs
176+
# forever until loop.stop() is called. If we don't stop the
177+
# event loop explicitly here, the python process will not exit.
178+
# It reamins stuck for 5 minutes until liveness and readiness probe
179+
# fails enough times and k8s sends a SIGTERM
180+
asyncio.get_running_loop().stop()
181+
_LOGGER.info("Event loop stopped")

0 commit comments

Comments
 (0)