diff --git a/clients/python/src/taskbroker_client/worker/client.py b/clients/python/src/taskbroker_client/worker/client.py index 14d2112e..d8b1f104 100644 --- a/clients/python/src/taskbroker_client/worker/client.py +++ b/clients/python/src/taskbroker_client/worker/client.py @@ -225,7 +225,18 @@ def request_deserializer(serialized_request: bytes) -> Any: _RPC_SIGNATURE_AUTH_TLS.failed = True return inner_deserializer(b"") - return inner_deserializer(serialized_request) + try: + return inner_deserializer(serialized_request) + except Exception: + # gRPC swallows deserializer exceptions in `grpc._common._transform` + # (it logs to the `grpc._common` logger and returns None), and the server + # then aborts the call with an opaque INTERNAL "Exception deserializing + # request!". Log here so the failure is visible on our own logger. + logger.exception( + "taskworker.grpc_server.request_deserialization_failed", + extra={"method": method}, + ) + raise def unary_unary(request: Any, context: grpc.ServicerContext) -> Any: if getattr(_RPC_SIGNATURE_AUTH_TLS, "failed", False): diff --git a/clients/python/tests/worker/test_client.py b/clients/python/tests/worker/test_client.py index 33d9083a..7db50de2 100644 --- a/clients/python/tests/worker/test_client.py +++ b/clients/python/tests/worker/test_client.py @@ -5,7 +5,9 @@ import string import time from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Generator +from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from pathlib import Path from typing import Any from unittest.mock import Mock, patch @@ -13,6 +15,7 @@ import grpc import pytest from google.protobuf.message import Message +from sentry_protos.taskbroker.v1 import taskbroker_pb2_grpc from sentry_protos.taskbroker.v1.taskbroker_pb2 import ( TASK_ACTIVATION_STATUS_COMPLETE, TASK_ACTIVATION_STATUS_RETRY, @@ -20,6 +23,7 @@ GetTaskRequest, GetTaskResponse, PushTaskRequest, + PushTaskResponse, SetTaskStatusRequest, SetTaskStatusResponse, TaskActivation, @@ -394,6 +398,70 @@ def test_request_signature_server_interceptor_skips_grpc_health_check() -> None: assert interceptor.intercept_service(lambda _: handler, handler_call_details) is handler +class _RecordingWorkerServicer(taskbroker_pb2_grpc.WorkerServiceServicer): + """Records the requests it receives, like the real WorkerServicer would process them""" + + def __init__(self) -> None: + self.requests: list[PushTaskRequest] = [] + + def PushTask(self, request: PushTaskRequest, context: grpc.ServicerContext) -> PushTaskResponse: + self.requests.append(request) + return PushTaskResponse() + + +@contextmanager +def _running_worker_server( + secrets: list[str], +) -> Generator[tuple[_RecordingWorkerServicer, grpc.Channel], None, None]: + """Run a real gRPC server with the signature interceptor, as PushTaskWorker does. + + The signature is verified inside a deserializer wrapper that gRPC runs on a + different thread than the servicer, so this behaviour can only be exercised + against a real server, not a mocked channel. + """ + servicer = _RecordingWorkerServicer() + server = grpc.server( + ThreadPoolExecutor(max_workers=2), + interceptors=[RequestSignatureServerInterceptor(secrets)], + ) + taskbroker_pb2_grpc.add_WorkerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") + server.start() + channel = grpc.insecure_channel(f"localhost:{port}") + try: + yield servicer, channel + finally: + channel.close() + server.stop(grace=None) + + +def test_request_signature_server_interceptor_logs_deserialization_errors( + caplog: pytest.LogCaptureFixture, +) -> None: + # gRPC swallows deserializer exceptions in `grpc._common._transform` and the + # server aborts the call with an opaque INTERNAL status, so the interceptor must + # log the exception itself to leave a trace on our own logger. + body = b"\xff\xff\xff\xff not a protobuf" + signature = _push_task_hmac(b"secret", body) + + with _running_worker_server(["secret"]) as (servicer, channel): + multicallable = channel.unary_unary( + _PUSH_TASK_METHOD, + request_serializer=None, + response_deserializer=PushTaskResponse.FromString, + ) + with caplog.at_level("ERROR", logger="taskbroker_client.worker.client"): + with pytest.raises(grpc.RpcError) as excinfo: + multicallable(body, timeout=5, metadata=(("sentry-signature", signature),)) + + assert excinfo.value.code() == grpc.StatusCode.INTERNAL + assert servicer.requests == [] + assert any( + record.message == "taskworker.grpc_server.request_deserialization_failed" + for record in caplog.records + ) + + def test_get_task_with_namespace() -> None: channel = MockChannel() channel.add_response(