Skip to content

Commit 5b458ad

Browse files
authored
fix: Validate message fields before protobuf encoding for better error messages (#333)
Signed-off-by: Sreekanth <prsreekanth920@gmail.com>
1 parent b37f25f commit 5b458ad

16 files changed

Lines changed: 275 additions & 6 deletions

File tree

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
def _validate_message_fields(value, keys, tags):
2+
"""Validate common Message fields at construction time.
3+
4+
Raises TypeError with a clear message pointing at the caller's code
5+
rather than letting bad types propagate to protobuf serialization.
6+
"""
7+
if value is not None and not isinstance(value, bytes):
8+
raise TypeError(f"Message 'value' must be bytes, got {type(value).__name__}")
9+
if keys is not None:
10+
if not isinstance(keys, list) or not all(isinstance(k, str) for k in keys):
11+
raise TypeError(f"Message 'keys' must be a list of strings, got {keys!r}")
12+
if tags is not None:
13+
if not isinstance(tags, list) or not all(isinstance(t, str) for t in tags):
14+
raise TypeError(f"Message 'tags' must be a list of strings, got {tags!r}")

packages/pynumaflow/pynumaflow/accumulator/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pynumaflow.shared.asynciter import NonBlockingIterator
1010
from pynumaflow._constants import DROP
11+
from pynumaflow._validate import _validate_message_fields
1112

1213
M = TypeVar("M", bound="Message")
1314

@@ -389,6 +390,7 @@ def __init__(
389390
"""
390391
Creates a Message object to send value to a vertex.
391392
"""
393+
_validate_message_fields(value, keys, tags)
392394
self._keys = keys or []
393395
self._tags = tags or []
394396
self._value = value or b""

packages/pynumaflow/pynumaflow/batchmapper/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import AsyncIterable, Callable
77

88
from pynumaflow._constants import DROP
9+
from pynumaflow._validate import _validate_message_fields
910

1011
M = TypeVar("M", bound="Message")
1112
B = TypeVar("B", bound="BatchResponse")
@@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
3132
"""
3233
Creates a Message object to send value to a vertex.
3334
"""
35+
_validate_message_fields(value, keys, tags)
3436
self._keys = keys or []
3537
self._tags = tags or []
3638
self._value = value or b""

packages/pynumaflow/pynumaflow/mapper/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pynumaflow._constants import DROP
1010
from pynumaflow._metadata import UserMetadata, SystemMetadata
11+
from pynumaflow._validate import _validate_message_fields
1112

1213
M = TypeVar("M", bound="Message")
1314
Ms = TypeVar("Ms", bound="Messages")
@@ -40,6 +41,7 @@ def __init__(
4041
"""
4142
Creates a Message object to send value to a vertex.
4243
"""
44+
_validate_message_fields(value, keys, tags)
4345
self._keys = keys or []
4446
self._tags = tags or []
4547
self._value = value or b""

packages/pynumaflow/pynumaflow/mapstreamer/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from warnings import warn
88

99
from pynumaflow._constants import DROP
10+
from pynumaflow._validate import _validate_message_fields
1011

1112
M = TypeVar("M", bound="Message")
1213
Ms = TypeVar("Ms", bound="Messages")
@@ -31,6 +32,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
3132
"""
3233
Creates a Message object to send value to a vertex.
3334
"""
35+
_validate_message_fields(value, keys, tags)
3436
self._keys = keys or []
3537
self._tags = tags or []
3638
self._value = value or b""

packages/pynumaflow/pynumaflow/reducer/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from pynumaflow.shared.asynciter import NonBlockingIterator
1212
from pynumaflow._constants import DROP
13+
from pynumaflow._validate import _validate_message_fields
1314

1415
M = TypeVar("M", bound="Message")
1516
Ms = TypeVar("Ms", bound="Messages")
@@ -48,6 +49,7 @@ def __init__(self, value: bytes, keys: list[str] | None = None, tags: list[str]
4849
"""
4950
Creates a Message object to send value to a vertex.
5051
"""
52+
_validate_message_fields(value, keys, tags)
5153
self._keys = keys or []
5254
self._tags = tags or []
5355
self._value = value or b""

packages/pynumaflow/pynumaflow/reducestreamer/_dtypes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from pynumaflow.shared.asynciter import NonBlockingIterator
1010
from pynumaflow._constants import DROP
11+
from pynumaflow._validate import _validate_message_fields
1112

1213
M = TypeVar("M", bound="Message")
1314

@@ -270,6 +271,7 @@ def __init__(
270271
"""
271272
Creates a Message object to send value to a vertex.
272273
"""
274+
_validate_message_fields(value, keys, tags)
273275
self._keys = keys or []
274276
self._tags = tags or []
275277
self._value = value or b""

packages/pynumaflow/pynumaflow/reducestreamer/async_server.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,13 @@ def start(self):
171171
_LOGGER.info(
172172
"Starting Async Reduce Stream Server",
173173
)
174-
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=self.shutdown_callback)
174+
175+
def _shutdown_handler(loop):
176+
_LOGGER.info("Received graceful shutdown signal, shutting down ReduceStream server")
177+
if self.shutdown_callback:
178+
self.shutdown_callback(loop)
179+
180+
aiorun.run(self.aexec(), use_uvloop=True, shutdown_callback=_shutdown_handler)
175181
if self._error:
176182
_LOGGER.critical("Server exiting due to UDF error: %s", self._error)
177183
sys.exit(1)

packages/pynumaflow/pynumaflow/reducestreamer/servicer/async_servicer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ async def ReduceFn(
101101
# If the message is an exception, we raise the exception
102102
if isinstance(msg, BaseException):
103103
err_msg = f"ReduceStreamError, {ERR_UDF_EXCEPTION_STRING}: {repr(msg)}"
104-
_LOGGER.critical(err_msg, exc_info=True)
104+
_LOGGER.critical(err_msg, exc_info=msg)
105105
update_context_err(context, msg, err_msg)
106106
self._error = msg
107107
if self._shutdown_event is not None:

packages/pynumaflow/pynumaflow/reducestreamer/servicer/task_manager.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,18 @@ async def process_input_stream(self, request_iterator: AsyncIterable[reduce_pb2.
275275
_LOGGER.critical(err_msg, exc_info=True)
276276
await self.global_result_queue.put(e)
277277

278+
# Cancel and await remaining tasks to suppress "never retrieved" warnings
279+
for task in self.get_tasks():
280+
for fut in (task.future, task.consumer_future):
281+
if fut and not fut.done():
282+
fut.cancel()
283+
for fut in (task.future, task.consumer_future):
284+
if fut:
285+
try:
286+
await fut
287+
except (asyncio.CancelledError, BaseException):
288+
pass
289+
278290
async def write_to_global_queue(
279291
self, input_queue: NonBlockingIterator, output_queue: NonBlockingIterator, window
280292
):
@@ -284,10 +296,19 @@ async def write_to_global_queue(
284296
to the global result queue
285297
"""
286298
reader = input_queue.read_iterator()
287-
async for msg in reader:
288-
res = reduce_pb2.ReduceResponse.Result(keys=msg.keys, value=msg.value, tags=msg.tags)
289-
out = reduce_pb2.ReduceResponse(result=res, window=window)
290-
await output_queue.put(out)
299+
try:
300+
async for msg in reader:
301+
res = reduce_pb2.ReduceResponse.Result(
302+
keys=msg.keys, value=msg.value, tags=msg.tags
303+
)
304+
out = reduce_pb2.ReduceResponse(result=res, window=window)
305+
await output_queue.put(out)
306+
except Exception as e:
307+
# Using Exception (not BaseException) so that asyncio.CancelledError
308+
# (a BaseException subclass in Python 3.9+) propagates normally
309+
# when the task is cancelled during shutdown.
310+
_LOGGER.critical("Error serializing reduce result: %s", e, exc_info=True)
311+
await output_queue.put(e)
291312

292313
def clean_background(self, task):
293314
"""

0 commit comments

Comments
 (0)