diff --git a/README.md b/README.md index 1913ad0..768fabc 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Tusk Drift currently supports the following packages and versions: | HTTPX | all versions | | aiohttp | all versions | | urllib3 | all versions | +| grpcio (client-side only) | all versions | | psycopg | `>=3.1.12` | | psycopg2 | all versions | | Redis | `>=4.0.0` | diff --git a/drift/core/drift_sdk.py b/drift/core/drift_sdk.py index 678779a..5c013ae 100644 --- a/drift/core/drift_sdk.py +++ b/drift/core/drift_sdk.py @@ -451,6 +451,16 @@ def _init_auto_instrumentations(self) -> None: except ImportError: pass + try: + import grpc # type: ignore[unresolved-import] + + from ..instrumentation.grpc import GrpcInstrumentation + + _ = GrpcInstrumentation() + logger.debug("gRPC instrumentation initialized") + except ImportError: + pass + try: import django diff --git a/drift/instrumentation/grpc/__init__.py b/drift/instrumentation/grpc/__init__.py new file mode 100644 index 0000000..66cc70a --- /dev/null +++ b/drift/instrumentation/grpc/__init__.py @@ -0,0 +1,5 @@ +"""gRPC client instrumentation.""" + +from .instrumentation import GrpcInstrumentation + +__all__ = ["GrpcInstrumentation"] diff --git a/drift/instrumentation/grpc/e2e-tests/.tusk/config.yaml b/drift/instrumentation/grpc/e2e-tests/.tusk/config.yaml new file mode 100644 index 0000000..cd9e2bb --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/.tusk/config.yaml @@ -0,0 +1,27 @@ +version: 1 + +service: + id: "grpc-e2e-test-id" + name: "grpc-e2e-test" + port: 8000 + start: + command: "python src/app.py" + readiness_check: + command: "curl -f http://localhost:8000/health" + timeout: 45s + interval: 5s + +tusk_api: + url: "http://localhost:8000" + +test_execution: + concurrent_limit: 10 + batch_size: 10 + timeout: 30s + +recording: + sampling_rate: 1.0 + export_spans: false + +replay: + enable_telemetry: false diff --git a/drift/instrumentation/grpc/e2e-tests/Dockerfile b/drift/instrumentation/grpc/e2e-tests/Dockerfile new file mode 100644 index 0000000..91dc058 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/Dockerfile @@ -0,0 +1,21 @@ +FROM python-e2e-base:latest + +# Copy SDK source for editable install +COPY . /sdk + +# Copy test files +COPY drift/instrumentation/grpc/e2e-tests /app + +WORKDIR /app + +# Install dependencies (requirements.txt uses -e /sdk for SDK) +RUN pip install -q -r requirements.txt + +# Make entrypoint executable +RUN chmod +x entrypoint.py + +# Create .tusk directories +RUN mkdir -p /app/.tusk/traces /app/.tusk/logs + +# Run entrypoint +ENTRYPOINT ["python", "entrypoint.py"] diff --git a/drift/instrumentation/grpc/e2e-tests/docker-compose.yml b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml new file mode 100644 index 0000000..a249d83 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/docker-compose.yml @@ -0,0 +1,19 @@ +services: + app: + build: + context: ../../../.. 
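+      # Build context is the repository root (four levels up) so the Dockerfile
+      # can COPY the SDK source into /sdk and the e2e test files into /app.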
+ dockerfile: drift/instrumentation/grpc/e2e-tests/Dockerfile + args: + - TUSK_CLI_VERSION=${TUSK_CLI_VERSION:-latest} + environment: + - PORT=8000 + - TUSK_ANALYTICS_DISABLED=1 + - PYTHONUNBUFFERED=1 + working_dir: /app + volumes: + # Mount SDK source for hot reload (no rebuild needed for SDK changes) + - ../../../..:/sdk + # Mount app source for development + - ./src:/app/src + # Mount .tusk folder to persist traces + - ./.tusk:/app/.tusk diff --git a/drift/instrumentation/grpc/e2e-tests/entrypoint.py b/drift/instrumentation/grpc/e2e-tests/entrypoint.py new file mode 100644 index 0000000..3c5a8ea --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/entrypoint.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +""" +E2E Test Entrypoint for gRPC Instrumentation + +This script orchestrates the full e2e test lifecycle: +1. Setup: Install dependencies, generate proto files +2. Record: Start app in RECORD mode, execute requests +3. Test: Run Tusk CLI tests +4. Teardown: Cleanup and return exit code +""" + +import sys +from pathlib import Path + +# Add SDK to path for imports +sys.path.insert(0, "/sdk") + +from drift.instrumentation.e2e_common.base_runner import E2ETestRunnerBase + + +class GrpcE2ETestRunner(E2ETestRunnerBase): + """E2E test runner for gRPC instrumentation.""" + + def __init__(self): + import os + + port = int(os.getenv("PORT", "8000")) + super().__init__(app_port=port) + + def setup(self): + """Phase 1: Setup dependencies and generate proto files.""" + self.log("=" * 50, self.Colors.BLUE) + self.log("Phase 1: Setup", self.Colors.BLUE) + self.log("=" * 50, self.Colors.BLUE) + + self.log("Installing Python dependencies...", self.Colors.BLUE) + self.run_command(["pip", "install", "-q", "-r", "requirements.txt"]) + + # Generate proto files + self.log("Generating proto files...", self.Colors.BLUE) + self.run_command( + [ + "python", + "-m", + "grpc_tools.protoc", + "-I", + "src/proto", + "--python_out=src", + "--grpc_python_out=src", + "src/proto/greeter.proto", + ] + ) + + self.log("Setup complete", self.Colors.GREEN) + + # Use Colors from base class + @property + def Colors(self): + from drift.instrumentation.e2e_common.base_runner import Colors + + return Colors + + +if __name__ == "__main__": + runner = GrpcE2ETestRunner() + exit_code = runner.run() + sys.exit(exit_code) diff --git a/drift/instrumentation/grpc/e2e-tests/requirements.txt b/drift/instrumentation/grpc/e2e-tests/requirements.txt new file mode 100644 index 0000000..538204c --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/requirements.txt @@ -0,0 +1,5 @@ +-e /sdk +Flask>=3.1.2 +grpcio>=1.60.0 +grpcio-tools>=1.60.0 +protobuf>=6.0 diff --git a/drift/instrumentation/grpc/e2e-tests/run.sh b/drift/instrumentation/grpc/e2e-tests/run.sh new file mode 100755 index 0000000..d67bae1 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/run.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# Exit on error +set -e + +# Accept optional port parameter (default: 8000) +APP_PORT=${1:-8000} +export APP_PORT + +# Generate unique docker compose project name +# Get the instrumentation name (parent directory of e2e-tests) +TEST_NAME="$(basename "$(dirname "$(pwd)")")" +PROJECT_NAME="python-${TEST_NAME}-${APP_PORT}" + +# Colors for output +GREEN='\033[0;32m' +RED='\033[0;31m' +YELLOW='\033[1;33m' +BLUE='\033[0;34m' +NC='\033[0m' + +echo -e "${BLUE}========================================${NC}" +echo -e "${BLUE}Running Python E2E Test: ${TEST_NAME}${NC}" +echo -e "${BLUE}Port: ${APP_PORT}${NC}" +echo -e 
"${BLUE}========================================${NC}" +echo "" + +# Cleanup function +cleanup() { + echo "" + echo -e "${YELLOW}Cleaning up containers...${NC}" + docker compose -p "$PROJECT_NAME" down -v 2>/dev/null || true +} + +# Register cleanup on exit +trap cleanup EXIT + +# Build containers +echo -e "${BLUE}Building containers...${NC}" +docker compose -p "$PROJECT_NAME" build --no-cache + +# Run the test container +echo -e "${BLUE}Starting test...${NC}" +echo "" + +# Run container and capture exit code (always use port 8000 inside container) +# Disable set -e temporarily to capture exit code +set +e +docker compose -p "$PROJECT_NAME" run --rm app +EXIT_CODE=$? +set -e + +echo "" +if [ $EXIT_CODE -eq 0 ]; then + echo -e "${GREEN}========================================${NC}" + echo -e "${GREEN}Test passed!${NC}" + echo -e "${GREEN}========================================${NC}" +else + echo -e "${RED}========================================${NC}" + echo -e "${RED}Test failed with exit code ${EXIT_CODE}${NC}" + echo -e "${RED}========================================${NC}" +fi + +exit $EXIT_CODE diff --git a/drift/instrumentation/grpc/e2e-tests/src/app.py b/drift/instrumentation/grpc/e2e-tests/src/app.py new file mode 100644 index 0000000..296ea82 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/app.py @@ -0,0 +1,222 @@ +"""Flask test app for e2e tests - gRPC instrumentation testing. + +This app acts as an HTTP gateway that makes gRPC calls to a backend service. +This pattern is common in microservices architectures. +""" + +import sys +import threading +import time + +from flask import Flask, jsonify, request + +from drift import TuskDrift + +# Initialize SDK +sdk = TuskDrift.initialize( + api_key="tusk-test-key", + log_level="debug", +) + +# Import gRPC modules (generated from proto) +import grpc + +# Add src directory to path for generated proto files +sys.path.insert(0, "/app/src") + +import greeter_pb2 +import greeter_pb2_grpc + +# Import and start the gRPC server +from grpc_server import serve as start_grpc_server + +app = Flask(__name__) + +# gRPC channel and stub (will be created after server starts) +grpc_channel = None +grpc_stub = None + +GRPC_SERVER_PORT = 50051 + + +def init_grpc_client(): + """Initialize gRPC client connection.""" + global grpc_channel, grpc_stub + grpc_channel = grpc.insecure_channel(f"localhost:{GRPC_SERVER_PORT}") + grpc_stub = greeter_pb2_grpc.GreeterStub(grpc_channel) + + +# Health check endpoint +@app.route("/health", methods=["GET"]) +def health(): + return jsonify({"status": "healthy"}) + + +# Simple unary RPC +@app.route("/api/greet", methods=["GET"]) +def greet(): + """Test simple unary gRPC call.""" + name = request.args.get("name", "World") + try: + response = grpc_stub.SayHello(greeter_pb2.HelloRequest(name=name)) + return jsonify({"message": response.message}) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + + +# Unary RPC with more complex request/response +@app.route("/api/greet-with-info", methods=["POST"]) +def greet_with_info(): + """Test unary gRPC call with complex request.""" + data = request.get_json() or {} + try: + grpc_request = greeter_pb2.HelloRequestWithInfo( + name=data.get("name", "World"), + age=data.get("age", 25), + city=data.get("city", "Unknown"), + ) + response = grpc_stub.SayHelloWithInfo(grpc_request) + return jsonify( + { + "message": response.message, + "greeting_id": response.greeting_id, + "timestamp": response.timestamp, + } + ) + except grpc.RpcError as e: + return 
jsonify({"error": str(e)}), 500 + + +# Server streaming RPC +@app.route("/api/greet-stream", methods=["GET"]) +def greet_stream(): + """Test server streaming gRPC call.""" + name = request.args.get("name", "World") + try: + responses = grpc_stub.SayHelloStream(greeter_pb2.HelloRequest(name=name)) + messages = [r.message for r in responses] + return jsonify({"messages": messages, "count": len(messages)}) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + + +# Multiple sequential gRPC calls +@app.route("/api/greet-chain", methods=["GET"]) +def greet_chain(): + """Test multiple sequential gRPC calls.""" + try: + # First call + response1 = grpc_stub.SayHello(greeter_pb2.HelloRequest(name="Alice")) + + # Second call + response2 = grpc_stub.SayHello(greeter_pb2.HelloRequest(name="Bob")) + + # Third call with more info + response3 = grpc_stub.SayHelloWithInfo(greeter_pb2.HelloRequestWithInfo(name="Charlie", age=30, city="NYC")) + + return jsonify( + { + "greeting1": response1.message, + "greeting2": response2.message, + "greeting3": { + "message": response3.message, + "greeting_id": response3.greeting_id, + }, + } + ) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + + +# Test with_call method +@app.route("/api/greet-with-call", methods=["GET"]) +def greet_with_call(): + """Test unary gRPC call using with_call to get metadata.""" + name = request.args.get("name", "World") + try: + response, call = grpc_stub.SayHello.with_call(greeter_pb2.HelloRequest(name=name)) + # Get metadata from call + initial_metadata = dict(call.initial_metadata()) + trailing_metadata = dict(call.trailing_metadata()) + + return jsonify( + { + "message": response.message, + "has_initial_metadata": len(initial_metadata) > 0, + "has_trailing_metadata": len(trailing_metadata) > 0, + } + ) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + + +@app.route("/test/future-call", methods=["GET"]) +def test_future_call(): + """Test async future gRPC call.""" + name = request.args.get("name", "FutureUser") + try: + # Use .future() for async call + future = grpc_stub.SayHello.future(greeter_pb2.HelloRequest(name=name)) + # Wait for result + response = future.result(timeout=5.0) + return jsonify({"message": response.message, "method": "future"}) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + except Exception as e: + return jsonify({"error": str(e), "type": type(e).__name__}), 500 + + +@app.route("/test/stream-unary", methods=["GET"]) +def test_stream_unary(): + """Test client streaming gRPC call (stream-unary pattern).""" + try: + # Create an iterator of requests + def request_iterator(): + names = ["Alice", "Bob", "Charlie"] + for name in names: + yield greeter_pb2.HelloRequest(name=name) + + # Make stream-unary call + response = grpc_stub.SayHelloToMany(request_iterator()) + return jsonify({"message": response.message, "method": "stream_unary"}) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + except Exception as e: + return jsonify({"error": str(e), "type": type(e).__name__}), 500 + + +@app.route("/test/stream-stream", methods=["GET"]) +def test_stream_stream(): + """Test bidirectional streaming gRPC call (stream-stream pattern).""" + try: + # Create an iterator of requests + def request_iterator(): + names = ["Echo1", "Echo2", "Echo3"] + for name in names: + yield greeter_pb2.HelloRequest(name=name) + + # Make stream-stream call + responses = grpc_stub.Chat(request_iterator()) + messages = [r.message for r in responses] + 
return jsonify({"messages": messages, "count": len(messages), "method": "stream_stream"}) + except grpc.RpcError as e: + return jsonify({"error": str(e)}), 500 + except Exception as e: + return jsonify({"error": str(e), "type": type(e).__name__}), 500 + + +if __name__ == "__main__": + # Start gRPC server in background thread + grpc_server = start_grpc_server(port=GRPC_SERVER_PORT) + + # Wait a moment for server to start + time.sleep(0.5) + + # Initialize gRPC client + init_grpc_client() + + # Mark app as ready + sdk.mark_app_as_ready() + + # Start Flask app + app.run(host="0.0.0.0", port=8000, debug=False) diff --git a/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2.py b/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2.py new file mode 100644 index 0000000..e1e5a5e --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: greeter.proto +# Protobuf Python Version: 6.31.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 31, + 1, + '', + 'greeter.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rgreeter.proto\x12\x07greeter\"\x1c\n\x0cHelloRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\"?\n\x14HelloRequestWithInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0b\n\x03\x61ge\x18\x02 \x01(\x05\x12\x0c\n\x04\x63ity\x18\x03 \x01(\t\"\x1d\n\nHelloReply\x12\x0f\n\x07message\x18\x01 \x01(\t\"M\n\x12HelloReplyWithInfo\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\x13\n\x0bgreeting_id\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\x03\x32\xd3\x02\n\x07Greeter\x12\x38\n\x08SayHello\x12\x15.greeter.HelloRequest\x1a\x13.greeter.HelloReply\"\x00\x12P\n\x10SayHelloWithInfo\x12\x1d.greeter.HelloRequestWithInfo\x1a\x1b.greeter.HelloReplyWithInfo\"\x00\x12@\n\x0eSayHelloStream\x12\x15.greeter.HelloRequest\x1a\x13.greeter.HelloReply\"\x00\x30\x01\x12@\n\x0eSayHelloToMany\x12\x15.greeter.HelloRequest\x1a\x13.greeter.HelloReply\"\x00(\x01\x12\x38\n\x04\x43hat\x12\x15.greeter.HelloRequest\x1a\x13.greeter.HelloReply\"\x00(\x01\x30\x01\x62\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'greeter_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_HELLOREQUEST']._serialized_start=26 + _globals['_HELLOREQUEST']._serialized_end=54 + _globals['_HELLOREQUESTWITHINFO']._serialized_start=56 + _globals['_HELLOREQUESTWITHINFO']._serialized_end=119 + _globals['_HELLOREPLY']._serialized_start=121 + _globals['_HELLOREPLY']._serialized_end=150 + _globals['_HELLOREPLYWITHINFO']._serialized_start=152 + _globals['_HELLOREPLYWITHINFO']._serialized_end=229 + _globals['_GREETER']._serialized_start=232 + _globals['_GREETER']._serialized_end=571 +# @@protoc_insertion_point(module_scope) diff --git a/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2_grpc.py b/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2_grpc.py new file mode 
100644 index 0000000..fa28693 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/greeter_pb2_grpc.py @@ -0,0 +1,279 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import greeter_pb2 as greeter__pb2 + +GRPC_GENERATED_VERSION = '1.76.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + ' but the generated code in greeter_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class GreeterStub(object): + """The greeting service definition. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SayHello = channel.unary_unary( + '/greeter.Greeter/SayHello', + request_serializer=greeter__pb2.HelloRequest.SerializeToString, + response_deserializer=greeter__pb2.HelloReply.FromString, + _registered_method=True) + self.SayHelloWithInfo = channel.unary_unary( + '/greeter.Greeter/SayHelloWithInfo', + request_serializer=greeter__pb2.HelloRequestWithInfo.SerializeToString, + response_deserializer=greeter__pb2.HelloReplyWithInfo.FromString, + _registered_method=True) + self.SayHelloStream = channel.unary_stream( + '/greeter.Greeter/SayHelloStream', + request_serializer=greeter__pb2.HelloRequest.SerializeToString, + response_deserializer=greeter__pb2.HelloReply.FromString, + _registered_method=True) + self.SayHelloToMany = channel.stream_unary( + '/greeter.Greeter/SayHelloToMany', + request_serializer=greeter__pb2.HelloRequest.SerializeToString, + response_deserializer=greeter__pb2.HelloReply.FromString, + _registered_method=True) + self.Chat = channel.stream_stream( + '/greeter.Greeter/Chat', + request_serializer=greeter__pb2.HelloRequest.SerializeToString, + response_deserializer=greeter__pb2.HelloReply.FromString, + _registered_method=True) + + +class GreeterServicer(object): + """The greeting service definition. 
+ """ + + def SayHello(self, request, context): + """Sends a greeting (unary-unary) + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SayHelloWithInfo(self, request, context): + """Sends a greeting with additional info (unary-unary) + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SayHelloStream(self, request, context): + """Server streaming - sends multiple greetings (unary-stream) + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SayHelloToMany(self, request_iterator, context): + """Client streaming - receives multiple names (stream-unary) + BUG #3: Channel.stream_unary is NOT patched + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Chat(self, request_iterator, context): + """Bidirectional streaming (stream-stream) + BUG #4: Channel.stream_stream is NOT patched + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_GreeterServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SayHello': grpc.unary_unary_rpc_method_handler( + servicer.SayHello, + request_deserializer=greeter__pb2.HelloRequest.FromString, + response_serializer=greeter__pb2.HelloReply.SerializeToString, + ), + 'SayHelloWithInfo': grpc.unary_unary_rpc_method_handler( + servicer.SayHelloWithInfo, + request_deserializer=greeter__pb2.HelloRequestWithInfo.FromString, + response_serializer=greeter__pb2.HelloReplyWithInfo.SerializeToString, + ), + 'SayHelloStream': grpc.unary_stream_rpc_method_handler( + servicer.SayHelloStream, + request_deserializer=greeter__pb2.HelloRequest.FromString, + response_serializer=greeter__pb2.HelloReply.SerializeToString, + ), + 'SayHelloToMany': grpc.stream_unary_rpc_method_handler( + servicer.SayHelloToMany, + request_deserializer=greeter__pb2.HelloRequest.FromString, + response_serializer=greeter__pb2.HelloReply.SerializeToString, + ), + 'Chat': grpc.stream_stream_rpc_method_handler( + servicer.Chat, + request_deserializer=greeter__pb2.HelloRequest.FromString, + response_serializer=greeter__pb2.HelloReply.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'greeter.Greeter', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('greeter.Greeter', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class Greeter(object): + """The greeting service definition. 
+ """ + + @staticmethod + def SayHello(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/greeter.Greeter/SayHello', + greeter__pb2.HelloRequest.SerializeToString, + greeter__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloWithInfo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/greeter.Greeter/SayHelloWithInfo', + greeter__pb2.HelloRequestWithInfo.SerializeToString, + greeter__pb2.HelloReplyWithInfo.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloStream(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/greeter.Greeter/SayHelloStream', + greeter__pb2.HelloRequest.SerializeToString, + greeter__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SayHelloToMany(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/greeter.Greeter/SayHelloToMany', + greeter__pb2.HelloRequest.SerializeToString, + greeter__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Chat(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/greeter.Greeter/Chat', + greeter__pb2.HelloRequest.SerializeToString, + greeter__pb2.HelloReply.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/drift/instrumentation/grpc/e2e-tests/src/grpc_server.py b/drift/instrumentation/grpc/e2e-tests/src/grpc_server.py new file mode 100644 index 0000000..6f321c1 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/grpc_server.py @@ -0,0 +1,74 @@ +"""gRPC server for e2e tests.""" + +import time +from concurrent import futures + +# These will be generated from proto file +import greeter_pb2 +import greeter_pb2_grpc +import grpc + + +class GreeterServicer(greeter_pb2_grpc.GreeterServicer): + """Implementation of the Greeter service.""" + + def SayHello(self, request, context): + """Handle SayHello RPC (unary-unary).""" + return greeter_pb2.HelloReply(message=f"Hello, {request.name}!") + + def SayHelloWithInfo(self, request, context): + 
"""Handle SayHelloWithInfo RPC (unary-unary).""" + # Use deterministic values for testing (no dynamic UUIDs or timestamps) + return greeter_pb2.HelloReplyWithInfo( + message=f"Hello, {request.name} from {request.city}! You are {request.age} years old.", + greeting_id="test-greeting-id-12345", + timestamp=1234567890000, + ) + + def SayHelloStream(self, request, context): + """Handle SayHelloStream RPC - server streaming (unary-stream).""" + greetings = [ + f"Hello, {request.name}!", + f"Welcome, {request.name}!", + f"Greetings, {request.name}!", + ] + for greeting in greetings: + yield greeter_pb2.HelloReply(message=greeting) + time.sleep(0.1) # Small delay between messages + + def SayHelloToMany(self, request_iterator, context): + """Handle SayHelloToMany RPC - client streaming (stream-unary). + + This endpoint exposes BUG #3: Channel.stream_unary is NOT patched. + """ + names = [] + for request in request_iterator: + names.append(request.name) + + combined_greeting = f"Hello to all: {', '.join(names)}!" + return greeter_pb2.HelloReply(message=combined_greeting) + + def Chat(self, request_iterator, context): + """Handle Chat RPC - bidirectional streaming (stream-stream). + + This endpoint exposes BUG #4: Channel.stream_stream is NOT patched. + """ + for request in request_iterator: + response_message = f"Echo: {request.name}" + yield greeter_pb2.HelloReply(message=response_message) + time.sleep(0.05) # Small delay between responses + + +def serve(port: int = 50051): + """Start the gRPC server.""" + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + greeter_pb2_grpc.add_GreeterServicer_to_server(GreeterServicer(), server) + server.add_insecure_port(f"[::]:{port}") + server.start() + print(f"gRPC server started on port {port}") + return server + + +if __name__ == "__main__": + server = serve() + server.wait_for_termination() diff --git a/drift/instrumentation/grpc/e2e-tests/src/proto/greeter.proto b/drift/instrumentation/grpc/e2e-tests/src/proto/greeter.proto new file mode 100644 index 0000000..bd265b4 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/proto/greeter.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +package greeter; + +// The greeting service definition. +service Greeter { + // Sends a greeting (unary-unary) + rpc SayHello (HelloRequest) returns (HelloReply) {} + + // Sends a greeting with additional info (unary-unary) + rpc SayHelloWithInfo (HelloRequestWithInfo) returns (HelloReplyWithInfo) {} + + // Server streaming - sends multiple greetings (unary-stream) + rpc SayHelloStream (HelloRequest) returns (stream HelloReply) {} + + // Client streaming - receives multiple names (stream-unary) + // BUG #3: Channel.stream_unary is NOT patched + rpc SayHelloToMany (stream HelloRequest) returns (HelloReply) {} + + // Bidirectional streaming (stream-stream) + // BUG #4: Channel.stream_stream is NOT patched + rpc Chat (stream HelloRequest) returns (stream HelloReply) {} +} + +// The request message containing the user's name. 
+message HelloRequest { + string name = 1; +} + +// The request message with additional info +message HelloRequestWithInfo { + string name = 1; + int32 age = 2; + string city = 3; +} + +// The response message containing the greeting +message HelloReply { + string message = 1; +} + +// The response message with additional info +message HelloReplyWithInfo { + string message = 1; + string greeting_id = 2; + int64 timestamp = 3; +} diff --git a/drift/instrumentation/grpc/e2e-tests/src/test_requests.py b/drift/instrumentation/grpc/e2e-tests/src/test_requests.py new file mode 100644 index 0000000..1195f34 --- /dev/null +++ b/drift/instrumentation/grpc/e2e-tests/src/test_requests.py @@ -0,0 +1,42 @@ +"""Execute test requests against the Flask app to exercise the gRPC instrumentation.""" + +from drift.instrumentation.e2e_common.test_utils import make_request, print_request_summary + +if __name__ == "__main__": + print("Starting test request sequence for gRPC instrumentation...\n") + + # Health check + make_request("GET", "/health") + + # Simple unary gRPC call + make_request("GET", "/api/greet?name=TestUser") + + # Unary gRPC call with different name + make_request("GET", "/api/greet?name=AnotherUser") + + # Unary gRPC call with complex request + make_request( + "POST", + "/api/greet-with-info", + json={"name": "John", "age": 30, "city": "San Francisco"}, + ) + + # Server streaming gRPC call + make_request("GET", "/api/greet-stream?name=StreamUser") + + # Multiple sequential gRPC calls + make_request("GET", "/api/greet-chain") + + # Test with_call method + make_request("GET", "/api/greet-with-call?name=CallUser") + + # Future calls (async unary) + make_request("GET", "/test/future-call?name=FutureTest") + + # Client streaming (stream-unary) + make_request("GET", "/test/stream-unary") + + # Bidirectional streaming (stream-stream) + make_request("GET", "/test/stream-stream") + + print_request_summary() diff --git a/drift/instrumentation/grpc/instrumentation.py b/drift/instrumentation/grpc/instrumentation.py new file mode 100644 index 0000000..090961a --- /dev/null +++ b/drift/instrumentation/grpc/instrumentation.py @@ -0,0 +1,1976 @@ +"""Instrumentation for gRPC client library (grpcio).""" + +from __future__ import annotations + +import json +import logging +from enum import Enum +from typing import Any + +from opentelemetry.trace import Span, Status +from opentelemetry.trace import SpanKind as OTelSpanKind +from opentelemetry.trace import StatusCode as OTelStatusCode + +from ...core.data_normalization import remove_none_values +from ...core.drift_sdk import TuskDrift +from ...core.mode_utils import handle_record_mode, handle_replay_mode +from ...core.tracing import TdSpanAttributes +from ...core.tracing.span_utils import CreateSpanOptions, SpanUtils +from ...core.types import ( + PackageType, + SpanKind, + SpanStatus, + StatusCode, + TuskDriftMode, + calling_library_context, +) +from ..base import InstrumentationBase +from .utils import ( + deserialize_grpc_payload, + parse_grpc_path, + serialize_grpc_metadata, + serialize_grpc_payload, +) + +logger = logging.getLogger(__name__) + +GRPC_MODULE_NAME = "grpc" + + +class ReplayResponseType(Enum): + """How to format the mock response in replay mode.""" + + DIRECT = "direct" # Return mock directly + WITH_CALL = "with_call" # Return (mock, MockGrpcCall()) + ITERATOR = "iterator" # Return iter(mock) + FUTURE = "future" # Return MockGrpcFuture(mock) + + +class GrpcInstrumentation(InstrumentationBase): + """Instrumentation for the grpcio gRPC client 
library. + + Patches grpc.Channel methods to: + - Intercept gRPC requests in REPLAY mode and return mocked responses + - Capture request/response data as CLIENT spans in RECORD mode + + This instrumentation focuses on client-side gRPC calls (unary and server streaming). + Server-side instrumentation is not yet implemented. + """ + + def __init__(self, enabled: bool = True) -> None: + super().__init__( + name="GrpcInstrumentation", + module_name="grpc", + supported_versions="*", + enabled=enabled, + ) + + def patch(self, module: Any) -> None: + """Patch the grpc module. + + Patches the Channel class to intercept all RPC calls: + - unary_unary: Single request, single response + - unary_stream: Single request, streaming response (server streaming) + - stream_unary: Streaming request, single response (client streaming) + - stream_stream: Streaming request, streaming response (bidirectional) + """ + # CRITICAL: Patch the concrete _channel.Channel class, not just the abstract class. + # The abstract grpc.Channel defines the interface, but grpc._channel.Channel + # provides the implementation. Python's MRO resolves methods from the concrete + # class first, so patching the abstract class has no effect. + try: + from grpc import _channel + + if hasattr(_channel, "Channel"): + self._patch_channel_class(_channel.Channel) + logger.debug("Patched grpc._channel.Channel (concrete implementation)") + except ImportError: + logger.warning("Could not import grpc._channel - falling back to abstract class patching") + # Fallback to abstract class (less reliable but better than nothing) + if hasattr(module, "Channel"): + self._patch_channel_class(module.Channel) + + # Also patch insecure_channel and secure_channel factory functions + # to return instrumented channels + self._patch_channel_factories(module) + + logger.info("grpc module instrumented") + + def _patch_channel_factories(self, module: Any) -> None: + """Patch channel factory functions to instrument created channels.""" + instrumentation_self = self + + # Patch insecure_channel + if hasattr(module, "insecure_channel"): + original_insecure_channel = module.insecure_channel + + def patched_insecure_channel(*args, **kwargs): + channel = original_insecure_channel(*args, **kwargs) + return instrumentation_self._wrap_channel(channel) + + module.insecure_channel = patched_insecure_channel + logger.debug("Patched grpc.insecure_channel") + + # Patch secure_channel + if hasattr(module, "secure_channel"): + original_secure_channel = module.secure_channel + + def patched_secure_channel(*args, **kwargs): + channel = original_secure_channel(*args, **kwargs) + return instrumentation_self._wrap_channel(channel) + + module.secure_channel = patched_secure_channel + logger.debug("Patched grpc.secure_channel") + + def _patch_channel_class(self, channel_class: Any) -> None: + """Patch the Channel class methods for all RPC patterns.""" + instrumentation_self = self + + # Store original methods + original_unary_unary = channel_class.unary_unary + original_unary_stream = channel_class.unary_stream + original_stream_unary = channel_class.stream_unary + original_stream_stream = channel_class.stream_stream + + def patched_unary_unary(channel_self: Any, method: str, *args: Any, **kwargs: Any) -> Any: + """Patched unary_unary that returns instrumented callable.""" + original_callable = original_unary_unary(channel_self, method, *args, **kwargs) + return instrumentation_self._wrap_unary_unary_callable(original_callable, method) + + def patched_unary_stream(channel_self: Any, 
method: str, *args: Any, **kwargs: Any) -> Any: + """Patched unary_stream that returns instrumented callable.""" + original_callable = original_unary_stream(channel_self, method, *args, **kwargs) + return instrumentation_self._wrap_unary_stream_callable(original_callable, method) + + def patched_stream_unary(channel_self: Any, method: str, *args: Any, **kwargs: Any) -> Any: + """Patched stream_unary that returns instrumented callable.""" + original_callable = original_stream_unary(channel_self, method, *args, **kwargs) + return instrumentation_self._wrap_stream_unary_callable(original_callable, method) + + def patched_stream_stream(channel_self: Any, method: str, *args: Any, **kwargs: Any) -> Any: + """Patched stream_stream that returns instrumented callable.""" + original_callable = original_stream_stream(channel_self, method, *args, **kwargs) + return instrumentation_self._wrap_stream_stream_callable(original_callable, method) + + channel_class.unary_unary = patched_unary_unary + channel_class.unary_stream = patched_unary_stream + channel_class.stream_unary = patched_stream_unary + channel_class.stream_stream = patched_stream_stream + logger.debug("Patched grpc.Channel methods (unary_unary, unary_stream, stream_unary, stream_stream)") + + def _wrap_channel(self, channel: Any) -> Any: + """Wrap an existing channel with instrumented methods. + + This is used for channels created before patching. + """ + # The channel is already using the patched Channel class methods + # due to how we patch at the class level + return channel + + def _wrap_unary_unary_callable(self, original_callable: Any, method: str) -> Any: + """Wrap a unary-unary callable with instrumentation.""" + instrumentation_self = self + + class InstrumentedUnaryUnaryCallable: + """Wrapper for unary-unary RPC callable.""" + + def __init__(self, original: Any, grpc_method: str): + self._original = original + self._method = grpc_method + + def __call__( + self, + request: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make the unary-unary RPC call.""" + return instrumentation_self._handle_unary_unary_call( + self._original, + self._method, + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def with_call( + self, + request: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> tuple[Any, Any]: + """Make the unary-unary RPC call and return (response, call).""" + return instrumentation_self._handle_unary_unary_with_call( + self._original, + self._method, + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def future( + self, + request: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make async unary-unary RPC call returning a future.""" + return instrumentation_self._handle_unary_unary_future( + self._original, + self._method, + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + return InstrumentedUnaryUnaryCallable(original_callable, method) + + def _wrap_unary_stream_callable(self, original_callable: Any, 
method: str) -> Any: + """Wrap a unary-stream callable with instrumentation.""" + instrumentation_self = self + + class InstrumentedUnaryStreamCallable: + """Wrapper for unary-stream RPC callable.""" + + def __init__(self, original: Any, grpc_method: str): + self._original = original + self._method = grpc_method + + def __call__( + self, + request: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make the unary-stream RPC call.""" + return instrumentation_self._handle_unary_stream_call( + self._original, + self._method, + request, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + return InstrumentedUnaryStreamCallable(original_callable, method) + + def _wrap_stream_unary_callable(self, original_callable: Any, method: str) -> Any: + """Wrap a stream-unary callable with instrumentation.""" + instrumentation_self = self + + class InstrumentedStreamUnaryCallable: + """Wrapper for stream-unary RPC callable (client streaming).""" + + def __init__(self, original: Any, grpc_method: str): + self._original = original + self._method = grpc_method + + def __call__( + self, + request_iterator: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make the stream-unary RPC call.""" + return instrumentation_self._handle_stream_unary_call( + self._original, + self._method, + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def with_call( + self, + request_iterator: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> tuple[Any, Any]: + """Make the stream-unary RPC call and return (response, call).""" + return instrumentation_self._handle_stream_unary_with_call( + self._original, + self._method, + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + def future( + self, + request_iterator: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make async stream-unary RPC call returning a future.""" + return instrumentation_self._handle_stream_unary_future( + self._original, + self._method, + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + return InstrumentedStreamUnaryCallable(original_callable, method) + + def _wrap_stream_stream_callable(self, original_callable: Any, method: str) -> Any: + """Wrap a stream-stream callable with instrumentation.""" + instrumentation_self = self + + class InstrumentedStreamStreamCallable: + """Wrapper for stream-stream RPC callable (bidirectional streaming).""" + + def __init__(self, original: Any, grpc_method: str): + self._original = original + self._method = grpc_method + + def __call__( + self, + request_iterator: Any, + timeout: float | None = None, + metadata: Any = None, + credentials: Any = None, + wait_for_ready: bool | None = None, + compression: Any = None, + ) -> Any: + """Make the stream-stream RPC call.""" + return 
instrumentation_self._handle_stream_stream_call( + self._original, + self._method, + request_iterator, + timeout=timeout, + metadata=metadata, + credentials=credentials, + wait_for_ready=wait_for_ready, + compression=compression, + ) + + return InstrumentedStreamStreamCallable(original_callable, method) + + def _build_input_value(self, method: str, request: Any, metadata: Any) -> dict[str, Any]: + """Build the input value for a gRPC request.""" + grpc_method, service = parse_grpc_path(method) + readable_body, buffer_map, jsonable_string_map = serialize_grpc_payload(request) + readable_metadata = serialize_grpc_metadata(metadata) + + input_value = { + "method": grpc_method, + "service": service, + "body": readable_body, + "metadata": readable_metadata, + "inputMeta": { + "bufferMap": buffer_map, + "jsonableStringMap": jsonable_string_map, + }, + } + + return input_value + + def _handle_unary_unary_call( + self, + original_callable: Any, + method: str, + request: Any, + **kwargs, + ) -> Any: + """Handle a unary-unary RPC call.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable(request, **kwargs) + + # Set calling_library_context to suppress socket instrumentation warnings + context_token = calling_library_context.set("grpc") + try: + + def original_call(): + return original_callable(request, **kwargs) + + metadata = kwargs.get("metadata") + input_value = self._build_input_value(method, request, metadata) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_unary_unary(sdk, method, input_value, request), + no_op_request_handler=lambda: self._get_default_response(), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_unary_unary( + original_callable, method, request, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_unary_unary_with_call( + self, + original_callable: Any, + method: str, + request: Any, + **kwargs, + ) -> tuple[Any, Any]: + """Handle a unary-unary RPC call with call object returned.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable.with_call(request, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + + def original_call(): + return original_callable.with_call(request, **kwargs) + + metadata = kwargs.get("metadata") + input_value = self._build_input_value(method, request, metadata) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_unary_unary_with_call( + sdk, method, input_value, request + ), + no_op_request_handler=lambda: (self._get_default_response(), None), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_unary_unary_with_call( + original_callable, method, request, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_unary_unary_future( + self, + original_callable: Any, + method: str, + request: Any, + **kwargs, + ) -> Any: + """Handle an async unary-unary RPC call (future). + + Wraps the returned Future to intercept result() calls for recording/replay. 
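+
+        In RECORD mode the real future is wrapped so the response (and any
+        initial/trailing metadata) can be captured when result() is called;
+        in REPLAY mode the recorded response is returned wrapped in a mock
+        future instead of hitting the network.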
+ """ + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable.future(request, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + metadata = kwargs.get("metadata") + input_value = self._build_input_value(method, request, metadata) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_unary_unary_future( + sdk, method, input_value, request + ), + no_op_request_handler=lambda: MockGrpcFuture(None), + is_server_request=False, + ) + + def original_call(): + return original_callable.future(request, **kwargs) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_unary_unary_future( + original_callable, method, request, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_unary_stream_call( + self, + original_callable: Any, + method: str, + request: Any, + **kwargs, + ) -> Any: + """Handle a unary-stream RPC call.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable(request, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + + def original_call(): + return original_callable(request, **kwargs) + + metadata = kwargs.get("metadata") + input_value = self._build_input_value(method, request, metadata) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_unary_stream(sdk, method, input_value, request), + no_op_request_handler=lambda: iter([]), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_unary_stream( + original_callable, method, request, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _get_default_response(self) -> None: + """Return default response for background requests in REPLAY mode.""" + logger.debug("[GrpcInstrumentation] Returning default response for background request") + return None + + def _handle_record_unary_unary( + self, + original_callable: Any, + method: str, + request: Any, + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle unary-unary call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.unary", is_pre_app_start) + + if not span_info: + return original_callable(request, **kwargs) + + error = None + response = None + response_metadata = None + trailing_metadata = None + + try: + with SpanUtils.with_span(span_info): + try: + response, call = original_callable.with_call(request, **kwargs) + response_metadata = call.initial_metadata() + trailing_metadata = call.trailing_metadata() + return response + except Exception as e: + error = e + raise + finally: + self._finalize_unary_span( + span_info.span, + input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + finally: + span_info.span.end() + + def _handle_record_unary_unary_with_call( + self, + original_callable: Any, + method: str, + request: Any, + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> tuple[Any, Any]: + """Handle unary-unary with_call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.unary", is_pre_app_start) + + 
if not span_info: + return original_callable.with_call(request, **kwargs) + + error = None + response = None + call = None + response_metadata = None + trailing_metadata = None + + try: + with SpanUtils.with_span(span_info): + try: + response, call = original_callable.with_call(request, **kwargs) + response_metadata = call.initial_metadata() + trailing_metadata = call.trailing_metadata() + return response, call + except Exception as e: + error = e + raise + finally: + self._finalize_unary_span( + span_info.span, + input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + finally: + span_info.span.end() + + def _handle_record_unary_stream( + self, + original_callable: Any, + method: str, + request: Any, + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle unary-stream call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.server_stream", is_pre_app_start) + + if not span_info: + return original_callable(request, **kwargs) + + # For streaming, we need to wrap the iterator to capture all responses + instrumentation_self = self + + class RecordingStreamIterator: + """Iterator that records streaming responses.""" + + def __init__(self, original_iterator: Any, span_info_ref: Any, input_val: dict): + self._original = original_iterator + self._span_info = span_info_ref + self._input_value = input_val + self._responses: list[dict] = [] + self._error: Exception | None = None + self._finished = False + + def __iter__(self): + return self + + def __next__(self): + try: + response = next(self._original) + # Serialize the response + readable_body, buffer_map, jsonable_string_map = serialize_grpc_payload(response) + self._responses.append( + { + "body": readable_body, + "bufferMap": buffer_map, + "jsonableStringMap": jsonable_string_map, + } + ) + return response + except StopIteration: + self._finish() + raise + except Exception as e: + self._error = e + self._finish() + raise + + def _finish(self): + if self._finished: + return + self._finished = True + + try: + instrumentation_self._finalize_stream_span( + self._span_info.span, + self._input_value, + self._responses, + self._error, + self._original, + ) + finally: + self._span_info.span.end() + + # Get the original iterator and wrap it + try: + with SpanUtils.with_span(span_info): + original_iterator = original_callable(request, **kwargs) + return RecordingStreamIterator(original_iterator, span_info, input_value) + except Exception as e: + # If we fail to even start the stream, finalize and re-raise + self._finalize_stream_span(span_info.span, input_value, [], e, None) + span_info.span.end() + raise + + def _handle_record_unary_unary_future( + self, + original_callable: Any, + method: str, + request: Any, + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle unary-unary future call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.unary", is_pre_app_start) + + if not span_info: + return original_callable.future(request, **kwargs) + + # Create a wrapper future that records the result when accessed + instrumentation_self = self + + class RecordingFuture: + """Future wrapper that records the result when accessed.""" + + def __init__(self, original_future: Any, span_info_ref: Any, input_val: dict): + self._original = original_future + self._span_info = span_info_ref + self._input_value = input_val + self._recorded = False + + def result(self, timeout: float | None = None) -> Any: + """Get the result and record 
it.""" + error = None + response = None + response_metadata = None + trailing_metadata = None + + try: + response = self._original.result(timeout=timeout) + # Try to get metadata + if hasattr(self._original, "initial_metadata"): + response_metadata = self._original.initial_metadata() + if hasattr(self._original, "trailing_metadata"): + trailing_metadata = self._original.trailing_metadata() + return response + except Exception as e: + error = e + raise + finally: + if not self._recorded: + self._recorded = True + instrumentation_self._finalize_unary_span( + self._span_info.span, + self._input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + self._span_info.span.end() + + def exception(self, timeout: float | None = None) -> Any: + """Get the exception if any.""" + return self._original.exception(timeout=timeout) + + def traceback(self, timeout: float | None = None) -> Any: + """Get the traceback if any.""" + return self._original.traceback(timeout=timeout) + + def add_done_callback(self, fn: Any) -> None: + """Add a callback to be called when the future completes.""" + self._original.add_done_callback(fn) + + def cancelled(self) -> bool: + """Return True if the future was cancelled.""" + return self._original.cancelled() + + def running(self) -> bool: + """Return True if the future is currently running.""" + return self._original.running() + + def done(self) -> bool: + """Return True if the future is done.""" + return self._original.done() + + def cancel(self) -> bool: + """Attempt to cancel the future.""" + return self._original.cancel() + + # gRPC-specific methods + def initial_metadata(self) -> Any: + """Get initial metadata.""" + if hasattr(self._original, "initial_metadata"): + return self._original.initial_metadata() + return [] + + def trailing_metadata(self) -> Any: + """Get trailing metadata.""" + if hasattr(self._original, "trailing_metadata"): + return self._original.trailing_metadata() + return [] + + def code(self) -> Any: + """Get status code.""" + if hasattr(self._original, "code"): + return self._original.code() + return None + + def details(self) -> str: + """Get status details.""" + if hasattr(self._original, "details"): + return self._original.details() + return "" + + # Get the original future and wrap it + try: + with SpanUtils.with_span(span_info): + original_future = original_callable.future(request, **kwargs) + return RecordingFuture(original_future, span_info, input_value) + except Exception as e: + # If we fail to even create the future, finalize and re-raise + self._finalize_unary_span(span_info.span, input_value, None, e, None, None) + span_info.span.end() + raise + + def _handle_stream_unary_call( + self, + original_callable: Any, + method: str, + request_iterator: Any, + **kwargs, + ) -> Any: + """Handle a stream-unary RPC call (client streaming).""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable(request_iterator, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + # For client streaming, we need to consume the iterator to capture all requests + # This changes the behavior slightly - the iterator is consumed upfront + requests_list = list(request_iterator) + metadata = kwargs.get("metadata") + input_value = self._build_stream_input_value(method, requests_list, metadata) + + def original_call(): + return original_callable(iter(requests_list), **kwargs) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: 
self._handle_replay_stream_unary(sdk, method, input_value), + no_op_request_handler=lambda: self._get_default_response(), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_stream_unary( + original_callable, method, requests_list, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_stream_unary_with_call( + self, + original_callable: Any, + method: str, + request_iterator: Any, + **kwargs, + ) -> tuple[Any, Any]: + """Handle a stream-unary RPC call with call object returned.""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable.with_call(request_iterator, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + requests_list = list(request_iterator) + metadata = kwargs.get("metadata") + input_value = self._build_stream_input_value(method, requests_list, metadata) + + def original_call(): + return original_callable.with_call(iter(requests_list), **kwargs) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_stream_unary_with_call(sdk, method, input_value), + no_op_request_handler=lambda: (self._get_default_response(), None), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_stream_unary_with_call( + original_callable, method, requests_list, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_stream_unary_future( + self, + original_callable: Any, + method: str, + request_iterator: Any, + **kwargs, + ) -> Any: + """Handle an async stream-unary RPC call (future).""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable.future(request_iterator, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + requests_list = list(request_iterator) + metadata = kwargs.get("metadata") + input_value = self._build_stream_input_value(method, requests_list, metadata) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_stream_unary_future(sdk, method, input_value), + no_op_request_handler=lambda: MockGrpcFuture(None), + is_server_request=False, + ) + + def original_call(): + return original_callable.future(iter(requests_list), **kwargs) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_stream_unary_future( + original_callable, method, requests_list, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _handle_stream_stream_call( + self, + original_callable: Any, + method: str, + request_iterator: Any, + **kwargs, + ) -> Any: + """Handle a stream-stream RPC call (bidirectional streaming).""" + sdk = TuskDrift.get_instance() + + if sdk.mode == TuskDriftMode.DISABLED: + return original_callable(request_iterator, **kwargs) + + context_token = calling_library_context.set("grpc") + try: + # For bidirectional streaming, we need to capture both request and response streams + requests_list = list(request_iterator) + metadata = 
kwargs.get("metadata") + input_value = self._build_stream_input_value(method, requests_list, metadata) + + def original_call(): + return original_callable(iter(requests_list), **kwargs) + + if sdk.mode == TuskDriftMode.REPLAY: + return handle_replay_mode( + replay_mode_handler=lambda: self._handle_replay_stream_stream(sdk, method, input_value), + no_op_request_handler=lambda: iter([]), + is_server_request=False, + ) + + return handle_record_mode( + original_function_call=original_call, + record_mode_handler=lambda is_pre_app_start: self._handle_record_stream_stream( + original_callable, method, requests_list, input_value, is_pre_app_start, **kwargs + ), + span_kind=OTelSpanKind.CLIENT, + ) + finally: + calling_library_context.reset(context_token) + + def _build_stream_input_value(self, method: str, requests: list[Any], metadata: Any) -> dict[str, Any]: + """Build the input value for a streaming gRPC request.""" + grpc_method, service = parse_grpc_path(method) + readable_metadata = serialize_grpc_metadata(metadata) + + # Serialize all requests in the stream + serialized_requests = [] + combined_buffer_map: dict[str, dict[str, str]] = {} + combined_jsonable_string_map: dict[str, str] = {} + + for i, request in enumerate(requests): + readable_body, buffer_map, jsonable_string_map = serialize_grpc_payload(request) + serialized_requests.append( + { + "body": readable_body, + "bufferMap": buffer_map, + "jsonableStringMap": jsonable_string_map, + } + ) + # Prefix keys to avoid collisions + for key, value in buffer_map.items(): + combined_buffer_map[f"{i}_{key}"] = value + for key, value in jsonable_string_map.items(): + combined_jsonable_string_map[f"{i}_{key}"] = value + + input_value = { + "method": grpc_method, + "service": service, + "body": serialized_requests, # List of request bodies + "metadata": readable_metadata, + "inputMeta": { + "bufferMap": combined_buffer_map, + "jsonableStringMap": combined_jsonable_string_map, + }, + } + + return input_value + + # ========================================================================= + # Helper Methods + # ========================================================================= + + def _create_client_span(self, span_name: str, is_pre_app_start: bool) -> Any: + """Create a gRPC CLIENT span with standard attributes. + + Args: + span_name: Name for the span (e.g., "grpc.client.unary") + is_pre_app_start: Whether this is before app startup + + Returns: + SpanInfo object or None if span creation fails + """ + return SpanUtils.create_span( + CreateSpanOptions( + name=span_name, + kind=OTelSpanKind.CLIENT, + attributes={ + TdSpanAttributes.NAME: span_name, + TdSpanAttributes.PACKAGE_NAME: GRPC_MODULE_NAME, + TdSpanAttributes.INSTRUMENTATION_NAME: "GrpcInstrumentation", + TdSpanAttributes.SUBMODULE_NAME: "client", + TdSpanAttributes.PACKAGE_TYPE: PackageType.GRPC.name, + TdSpanAttributes.IS_PRE_APP_START: is_pre_app_start, + }, + is_pre_app_start=is_pre_app_start, + ) + ) + + def _handle_replay_generic( + self, + sdk: TuskDrift, + method: str, + input_value: dict[str, Any], + span_name: str, + is_stream: bool, + response_type: ReplayResponseType, + ) -> Any: + """Generic replay handler for all gRPC call types. 
+ + Args: + sdk: TuskDrift instance + method: gRPC method path + input_value: Serialized input for mock matching + span_name: Name for the span + is_stream: Whether this is a streaming response + response_type: How to format the mock response + + Returns: + Mock response formatted according to response_type + """ + span_info = self._create_client_span(span_name, not sdk.app_ready) + + if not span_info: + raise RuntimeError(f"Error creating span in replay mode for gRPC {method}") + + try: + with SpanUtils.with_span(span_info): + mock_response = self._try_get_mock( + sdk, + method, + span_info.trace_id, + span_info.span_id, + input_value, + is_stream=is_stream, + ) + + if mock_response is None: + raise RuntimeError(f"No mock found for gRPC {method} in REPLAY mode") + + # Format response based on type + if response_type == ReplayResponseType.DIRECT: + return mock_response + elif response_type == ReplayResponseType.WITH_CALL: + return mock_response, MockGrpcCall() + elif response_type == ReplayResponseType.ITERATOR: + return iter(mock_response) + elif response_type == ReplayResponseType.FUTURE: + return MockGrpcFuture(mock_response) + else: + return mock_response + finally: + span_info.span.end() + + # ========================================================================= + # Record Mode Handlers + # ========================================================================= + + def _handle_record_stream_unary( + self, + original_callable: Any, + method: str, + requests: list[Any], + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle stream-unary call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.client_stream", is_pre_app_start) + + if not span_info: + return original_callable(iter(requests), **kwargs) + + error = None + response = None + response_metadata = None + trailing_metadata = None + + try: + with SpanUtils.with_span(span_info): + try: + response, call = original_callable.with_call(iter(requests), **kwargs) + response_metadata = call.initial_metadata() + trailing_metadata = call.trailing_metadata() + return response + except Exception as e: + error = e + raise + finally: + self._finalize_unary_span( + span_info.span, + input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + finally: + span_info.span.end() + + def _handle_record_stream_unary_with_call( + self, + original_callable: Any, + method: str, + requests: list[Any], + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> tuple[Any, Any]: + """Handle stream-unary with_call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.client_stream", is_pre_app_start) + + if not span_info: + return original_callable.with_call(iter(requests), **kwargs) + + error = None + response = None + call = None + response_metadata = None + trailing_metadata = None + + try: + with SpanUtils.with_span(span_info): + try: + response, call = original_callable.with_call(iter(requests), **kwargs) + response_metadata = call.initial_metadata() + trailing_metadata = call.trailing_metadata() + return response, call + except Exception as e: + error = e + raise + finally: + self._finalize_unary_span( + span_info.span, + input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + finally: + span_info.span.end() + + def _handle_record_stream_unary_future( + self, + original_callable: Any, + method: str, + requests: list[Any], + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle 
stream-unary future call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.client_stream", is_pre_app_start) + + if not span_info: + return original_callable.future(iter(requests), **kwargs) + + instrumentation_self = self + + class RecordingFuture: + """Future wrapper that records the result when accessed.""" + + def __init__(self, original_future: Any, span_info_ref: Any, input_val: dict): + self._original = original_future + self._span_info = span_info_ref + self._input_value = input_val + self._recorded = False + + def result(self, timeout: float | None = None) -> Any: + error = None + response = None + response_metadata = None + trailing_metadata = None + + try: + response = self._original.result(timeout=timeout) + if hasattr(self._original, "initial_metadata"): + response_metadata = self._original.initial_metadata() + if hasattr(self._original, "trailing_metadata"): + trailing_metadata = self._original.trailing_metadata() + return response + except Exception as e: + error = e + raise + finally: + if not self._recorded: + self._recorded = True + instrumentation_self._finalize_unary_span( + self._span_info.span, + self._input_value, + response, + error, + response_metadata, + trailing_metadata, + ) + self._span_info.span.end() + + def exception(self, timeout: float | None = None) -> Any: + return self._original.exception(timeout=timeout) + + def traceback(self, timeout: float | None = None) -> Any: + return self._original.traceback(timeout=timeout) + + def add_done_callback(self, fn: Any) -> None: + self._original.add_done_callback(fn) + + def cancelled(self) -> bool: + return self._original.cancelled() + + def running(self) -> bool: + return self._original.running() + + def done(self) -> bool: + return self._original.done() + + def cancel(self) -> bool: + return self._original.cancel() + + def initial_metadata(self) -> Any: + if hasattr(self._original, "initial_metadata"): + return self._original.initial_metadata() + return [] + + def trailing_metadata(self) -> Any: + if hasattr(self._original, "trailing_metadata"): + return self._original.trailing_metadata() + return [] + + def code(self) -> Any: + if hasattr(self._original, "code"): + return self._original.code() + return None + + def details(self) -> str: + if hasattr(self._original, "details"): + return self._original.details() + return "" + + try: + with SpanUtils.with_span(span_info): + original_future = original_callable.future(iter(requests), **kwargs) + return RecordingFuture(original_future, span_info, input_value) + except Exception as e: + self._finalize_unary_span(span_info.span, input_value, None, e, None, None) + span_info.span.end() + raise + + def _handle_record_stream_stream( + self, + original_callable: Any, + method: str, + requests: list[Any], + input_value: dict[str, Any], + is_pre_app_start: bool, + **kwargs, + ) -> Any: + """Handle stream-stream call in RECORD mode.""" + span_info = self._create_client_span("grpc.client.bidi_stream", is_pre_app_start) + + if not span_info: + return original_callable(iter(requests), **kwargs) + + instrumentation_self = self + + class RecordingStreamIterator: + """Iterator that records streaming responses for bidirectional streaming.""" + + def __init__(self, original_iterator: Any, span_info_ref: Any, input_val: dict): + self._original = original_iterator + self._span_info = span_info_ref + self._input_value = input_val + self._responses: list[dict] = [] + self._error: Exception | None = None + self._finished = False + + def __iter__(self): + return self + + 
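# NOTE: __next__ below serializes each streamed response as the caller consumes it; + # _finish() then finalizes and ends the span exactly once (on StopIteration or on error). + 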
def __next__(self): + try: + response = next(self._original) + readable_body, buffer_map, jsonable_string_map = serialize_grpc_payload(response) + self._responses.append( + { + "body": readable_body, + "bufferMap": buffer_map, + "jsonableStringMap": jsonable_string_map, + } + ) + return response + except StopIteration: + self._finish() + raise + except Exception as e: + self._error = e + self._finish() + raise + + def _finish(self): + if self._finished: + return + self._finished = True + + try: + instrumentation_self._finalize_stream_span( + self._span_info.span, + self._input_value, + self._responses, + self._error, + self._original, + ) + finally: + self._span_info.span.end() + + try: + with SpanUtils.with_span(span_info): + original_iterator = original_callable(iter(requests), **kwargs) + return RecordingStreamIterator(original_iterator, span_info, input_value) + except Exception as e: + self._finalize_stream_span(span_info.span, input_value, [], e, None) + span_info.span.end() + raise + + # ========================================================================= + # Replay Mode Handlers (using generic handler) + # ========================================================================= + + def _handle_replay_unary_unary(self, sdk: TuskDrift, method: str, input_value: dict[str, Any], request: Any) -> Any: + """Handle unary-unary call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.unary", + is_stream=False, + response_type=ReplayResponseType.DIRECT, + ) + + def _handle_replay_unary_unary_with_call( + self, sdk: TuskDrift, method: str, input_value: dict[str, Any], request: Any + ) -> tuple[Any, Any]: + """Handle unary-unary with_call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.unary", + is_stream=False, + response_type=ReplayResponseType.WITH_CALL, + ) + + def _handle_replay_unary_stream( + self, sdk: TuskDrift, method: str, input_value: dict[str, Any], request: Any + ) -> Any: + """Handle unary-stream call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.server_stream", + is_stream=True, + response_type=ReplayResponseType.ITERATOR, + ) + + def _handle_replay_unary_unary_future( + self, sdk: TuskDrift, method: str, input_value: dict[str, Any], request: Any + ) -> Any: + """Handle unary-unary future call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.unary", + is_stream=False, + response_type=ReplayResponseType.FUTURE, + ) + + def _handle_replay_stream_unary(self, sdk: TuskDrift, method: str, input_value: dict[str, Any]) -> Any: + """Handle stream-unary call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.client_stream", + is_stream=False, + response_type=ReplayResponseType.DIRECT, + ) + + def _handle_replay_stream_unary_with_call( + self, sdk: TuskDrift, method: str, input_value: dict[str, Any] + ) -> tuple[Any, Any]: + """Handle stream-unary with_call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.client_stream", + is_stream=False, + response_type=ReplayResponseType.WITH_CALL, + ) + + def _handle_replay_stream_unary_future(self, sdk: TuskDrift, method: str, input_value: dict[str, Any]) -> Any: + """Handle stream-unary future call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.client_stream", + 
is_stream=False, + response_type=ReplayResponseType.FUTURE, + ) + + def _handle_replay_stream_stream(self, sdk: TuskDrift, method: str, input_value: dict[str, Any]) -> Any: + """Handle stream-stream call in REPLAY mode.""" + return self._handle_replay_generic( + sdk, + method, + input_value, + "grpc.client.bidi_stream", + is_stream=True, + response_type=ReplayResponseType.ITERATOR, + ) + + def _try_get_mock( + self, + sdk: TuskDrift, + method: str, + trace_id: str, + span_id: str, + input_value: dict[str, Any], + is_stream: bool = False, + ) -> Any: + """Try to get a mocked response from CLI.""" + try: + grpc_method, service = parse_grpc_path(method) + span_name = "grpc.client.server_stream" if is_stream else "grpc.client.unary" + + # Use centralized mock finding utility + from ...core.mock_utils import find_mock_response_sync + + mock_response_output = find_mock_response_sync( + sdk=sdk, + trace_id=trace_id, + span_id=span_id, + name=span_name, + package_name=GRPC_MODULE_NAME, + package_type=PackageType.GRPC, + instrumentation_name="GrpcInstrumentation", + submodule_name="client", + input_value=input_value, + kind=SpanKind.CLIENT, + is_pre_app_start=not sdk.app_ready, + ) + + if not mock_response_output or not mock_response_output.found: + logger.debug(f"No mock found for gRPC {method} (trace_id={trace_id})") + return None + + if mock_response_output.response is None: + logger.debug(f"Mock found but response data is None for gRPC {method}") + return None + + return self._create_mock_response(mock_response_output.response, is_stream) + + except Exception as e: + logger.error(f"Error getting mock for gRPC {method}: {e}") + return None + + def _create_mock_response(self, mock_data: dict[str, Any], is_stream: bool) -> Any: + """Create a mocked gRPC response. + + Args: + mock_data: Mock response data from CLI + is_stream: Whether this is a streaming response + + Returns: + Mocked response object or list of responses for streaming + """ + # Check if it's an error response + if "error" in mock_data: + import grpc + + error_info = mock_data["error"] + status_info = mock_data.get("status", {}) + status_code = status_info.get("code", grpc.StatusCode.UNKNOWN.value[0]) + + # Map numeric code to StatusCode + try: + grpc_status = grpc.StatusCode(status_code) + except ValueError: + grpc_status = grpc.StatusCode.UNKNOWN + + raise grpc.RpcError(grpc_status, error_info.get("message", "Unknown error")) + + # Get the response body + body = mock_data.get("body") + buffer_map = mock_data.get("bufferMap", {}) + jsonable_string_map = mock_data.get("jsonableStringMap", {}) + + if is_stream: + # For streams, body should be a list of responses + if isinstance(body, list): + responses = [] + for item in body: + item_body = item.get("body") if isinstance(item, dict) else item + item_buffer_map = item.get("bufferMap", {}) if isinstance(item, dict) else {} + item_string_map = item.get("jsonableStringMap", {}) if isinstance(item, dict) else {} + restored = deserialize_grpc_payload(item_body, item_buffer_map, item_string_map) + # Convert dict to object with attribute access (like protobuf messages) + responses.append(self._dict_to_object(restored)) + return responses + return [] + + # For unary, restore the body + restored_body = deserialize_grpc_payload(body, buffer_map, jsonable_string_map) + # Convert dict to object with attribute access (like protobuf messages) + return self._dict_to_object(restored_body) + + def _dict_to_object(self, data: Any) -> Any: + """Convert a dict to an object with attribute access. 
+ + This allows mock responses to be accessed like protobuf messages: + response.message instead of response["message"] + + Args: + data: Dictionary or other value to convert + + Returns: + MockProtoMessage object or original value if not a dict + """ + if data is None: + return None + if isinstance(data, dict): + return MockProtoMessage(data) + if isinstance(data, list): + return [self._dict_to_object(item) for item in data] + return data + + def _finalize_unary_span( + self, + span: Span, + input_value: dict[str, Any], + response: Any, + error: Exception | None, + response_metadata: Any, + trailing_metadata: Any, + ) -> None: + """Finalize span with request/response data for unary call.""" + try: + # Build output value + output_value: dict[str, Any] = {} + status = SpanStatus(code=StatusCode.OK, message="") + + if error: + error_output: dict[str, Any] = { + "error": { + "message": str(error), + "name": type(error).__name__, + } + } + + # Try to get gRPC status from error + # Use getattr to safely access gRPC-specific error attributes + code_fn = getattr(error, "code", None) + if code_fn is not None and callable(code_fn): + try: + code = code_fn() + details_fn = getattr(error, "details", None) + trailing_fn = getattr(error, "trailing_metadata", None) + error_output["status"] = { + "code": code.value[0] if hasattr(code, "value") else int(code), + "details": str(details_fn()) if details_fn and callable(details_fn) else str(error), + "metadata": serialize_grpc_metadata( + trailing_fn() if trailing_fn and callable(trailing_fn) else None + ), + } + except Exception: + error_output["status"] = {"code": 2, "details": str(error), "metadata": {}} + else: + error_output["status"] = {"code": 2, "details": str(error), "metadata": {}} + + if response_metadata: + error_output["metadata"] = serialize_grpc_metadata(response_metadata) + + output_value = error_output + status = SpanStatus(code=StatusCode.ERROR, message=str(error)) + elif response is not None: + # Serialize response + readable_body, buffer_map, jsonable_string_map = serialize_grpc_payload(response) + + output_value = { + "body": readable_body, + "metadata": serialize_grpc_metadata(response_metadata), + "status": { + "code": 0, # OK + "details": "", + "metadata": serialize_grpc_metadata(trailing_metadata), + }, + "bufferMap": buffer_map, + "jsonableStringMap": jsonable_string_map, + } + + # Set span attributes + normalized_input = remove_none_values(input_value) + normalized_output = remove_none_values(output_value) + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(normalized_input)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(normalized_output)) + + # Set status + if status.code == StatusCode.ERROR: + span.set_status(Status(OTelStatusCode.ERROR, status.message)) + else: + span.set_status(Status(OTelStatusCode.OK)) + + except Exception as e: + logger.error(f"Error finalizing gRPC span: {e}") + span.set_status(Status(OTelStatusCode.ERROR, str(e))) + + def _finalize_stream_span( + self, + span: Span, + input_value: dict[str, Any], + responses: list[dict], + error: Exception | None, + original_iterator: Any, + ) -> None: + """Finalize span with request/response data for streaming call.""" + try: + # Build output value + output_value: dict[str, Any] = {} + status = SpanStatus(code=StatusCode.OK, message="") + + if error: + error_output: dict[str, Any] = { + "error": { + "message": str(error), + "name": type(error).__name__, + } + } + + # Try to get gRPC status from error + # Use getattr to safely access 
gRPC-specific error attributes + code_fn = getattr(error, "code", None) + if code_fn is not None and callable(code_fn): + try: + code = code_fn() + details_fn = getattr(error, "details", None) + error_output["status"] = { + "code": code.value[0] if hasattr(code, "value") else int(code), + "details": str(details_fn()) if details_fn and callable(details_fn) else str(error), + "metadata": {}, + } + except Exception: + error_output["status"] = {"code": 2, "details": str(error), "metadata": {}} + else: + error_output["status"] = {"code": 2, "details": str(error), "metadata": {}} + + output_value = error_output + status = SpanStatus(code=StatusCode.ERROR, message=str(error)) + else: + # Get metadata from iterator if available + response_metadata = {} + trailing_metadata = {} + if original_iterator: + try: + if hasattr(original_iterator, "initial_metadata"): + response_metadata = serialize_grpc_metadata(original_iterator.initial_metadata()) + if hasattr(original_iterator, "trailing_metadata"): + trailing_metadata = serialize_grpc_metadata(original_iterator.trailing_metadata()) + except Exception: + pass + + output_value = { + "body": responses, + "metadata": response_metadata, + "status": { + "code": 0, # OK + "details": "", + "metadata": trailing_metadata, + }, + "bufferMap": {}, + "jsonableStringMap": {}, + } + + # Set span attributes + normalized_input = remove_none_values(input_value) + normalized_output = remove_none_values(output_value) + span.set_attribute(TdSpanAttributes.INPUT_VALUE, json.dumps(normalized_input)) + span.set_attribute(TdSpanAttributes.OUTPUT_VALUE, json.dumps(normalized_output)) + + # Set status + if status.code == StatusCode.ERROR: + span.set_status(Status(OTelStatusCode.ERROR, status.message)) + else: + span.set_status(Status(OTelStatusCode.OK)) + + except Exception as e: + logger.error(f"Error finalizing gRPC stream span: {e}") + span.set_status(Status(OTelStatusCode.ERROR, str(e))) + + +class MockGrpcCall: + """Mock gRPC call object for replay mode.""" + + def __init__( + self, + initial_metadata: list[tuple[str, str | bytes]] | None = None, + trailing_metadata: list[tuple[str, str | bytes]] | None = None, + ): + self._initial_metadata = initial_metadata or [] + self._trailing_metadata = trailing_metadata or [] + + def initial_metadata(self) -> list[tuple[str, str | bytes]]: + return self._initial_metadata + + def trailing_metadata(self) -> list[tuple[str, str | bytes]]: + return self._trailing_metadata + + def code(self) -> Any: + import grpc + + return grpc.StatusCode.OK + + def details(self) -> str: + return "" + + +class MockGrpcFuture: + """Mock gRPC future object for replay mode. + + Implements the Future interface to return pre-recorded responses. 
+ """ + + def __init__( + self, + result_value: Any, + initial_metadata: list[tuple[str, str | bytes]] | None = None, + trailing_metadata: list[tuple[str, str | bytes]] | None = None, + ): + self._result_value = result_value + self._initial_metadata = initial_metadata or [] + self._trailing_metadata = trailing_metadata or [] + + def result(self, timeout: float | None = None) -> Any: + """Return the pre-recorded result.""" + return self._result_value + + def exception(self, timeout: float | None = None) -> None: + """Return None (no exception for successful mocks).""" + return None + + def traceback(self, timeout: float | None = None) -> None: + """Return None (no traceback for successful mocks).""" + return None + + def add_done_callback(self, fn: Any) -> None: + """Call the callback immediately (future is already done).""" + fn(self) + + def cancelled(self) -> bool: + """Return False (mock futures are never cancelled).""" + return False + + def running(self) -> bool: + """Return False (mock futures are already done).""" + return False + + def done(self) -> bool: + """Return True (mock futures are already done).""" + return True + + def cancel(self) -> bool: + """Return False (cannot cancel a completed future).""" + return False + + # gRPC-specific methods + def initial_metadata(self) -> list[tuple[str, str | bytes]]: + """Return initial metadata.""" + return self._initial_metadata + + def trailing_metadata(self) -> list[tuple[str, str | bytes]]: + """Return trailing metadata.""" + return self._trailing_metadata + + def code(self) -> Any: + """Return OK status code.""" + import grpc + + return grpc.StatusCode.OK + + def details(self) -> str: + """Return empty details.""" + return "" + + +class MockProtoMessage: + """Mock protobuf message that allows attribute access to dict values. + + This wrapper makes mock responses behave like protobuf messages, + supporting both attribute access (response.message) and dict access + (response["message"]). + + Also handles type coercion for int64 fields that were serialized as strings + by protobuf's MessageToDict (which converts int64 to strings for JS compat). + """ + + def __init__(self, data: dict[str, Any]): + # Store data with a private name to avoid conflicts + object.__setattr__(self, "_data", data) + # Recursively convert nested dicts and handle type coercion + for key, value in data.items(): + if isinstance(value, dict): + data[key] = MockProtoMessage(value) + elif isinstance(value, list): + data[key] = [ + MockProtoMessage(item) if isinstance(item, dict) else self._coerce_type(item) for item in value + ] + else: + data[key] = self._coerce_type(value) + + @staticmethod + def _coerce_type(value: Any) -> Any: + """Convert numeric strings back to numbers (int64 → string → int). + + Protobuf's MessageToDict converts int64 to strings to preserve + precision for JavaScript. We reverse this during replay. 
+ """ + if isinstance(value, str): + # Try to convert to int if it looks like a number + if value.lstrip("-").isdigit(): + try: + return int(value) + except (ValueError, OverflowError): + pass + # Try to convert to float for scientific notation + elif value.replace(".", "", 1).replace("-", "", 1).replace("e", "", 1).replace("+", "", 1).isdigit(): + try: + return float(value) + except ValueError: + pass + return value + + def __getattr__(self, name: str) -> Any: + data = object.__getattribute__(self, "_data") + if name in data: + return data[name] + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: + if name == "_data": + object.__setattr__(self, name, value) + else: + self._data[name] = value + + def __getitem__(self, key: str) -> Any: + return self._data[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._data[key] = value + + def __contains__(self, key: str) -> bool: + return key in self._data + + def __repr__(self) -> str: + return f"MockProtoMessage({self._data})" + + def __str__(self) -> str: + return str(self._data) + + def keys(self): + """Return dict keys.""" + return self._data.keys() + + def values(self): + """Return dict values.""" + return self._data.values() + + def items(self): + """Return dict items.""" + return self._data.items() + + def get(self, key: str, default: Any = None) -> Any: + """Get value with default.""" + return self._data.get(key, default) diff --git a/drift/instrumentation/grpc/notes.md b/drift/instrumentation/grpc/notes.md new file mode 100644 index 0000000..c69156d --- /dev/null +++ b/drift/instrumentation/grpc/notes.md @@ -0,0 +1,230 @@ +# gRPC Instrumentation Notes + +## 1. gRPC Communication Patterns + +gRPC supports 4 types of RPC patterns: + +### Unary RPC ✅ (Implemented) + +```text +Client ──[Request]──> Server +Client <──[Response]── Server +``` + +- Single request, single response +- Like a regular function call +- Example: `SayHello(HelloRequest) → HelloReply` + +### Server Streaming ✅ (Implemented) + +```text +Client ──[Request]──────────> Server +Client <──[Response 1]─────── Server +Client <──[Response 2]─────── Server +Client <──[Response N]─────── Server +``` + +- Single request, stream of responses +- Example: `ListFeatures(Rectangle) → stream Feature` +- Use case: Fetching paginated data, real-time feeds, LLM token streaming + +### Client Streaming ✅ (Implemented) + +```text +Client ──[Request 1]──────> Server +Client ──[Request 2]──────> Server +Client ──[Request N]──────> Server +Client <──[Response]─────── Server +``` + +- Stream of requests, single response +- Example: `RecordRoute(stream Point) → RouteSummary` +- Use case: File uploads, aggregating data from client + +### Bidirectional Streaming ✅ (Implemented) + +```text +Client ──[Request 1]──────> Server +Client <──[Response 1]───── Server +Client ──[Request 2]──────> Server +Client <──[Response 2]───── Server + ... (interleaved) +``` + +- Stream of requests AND stream of responses (simultaneously) +- Example: `RouteChat(stream RouteNote) → stream RouteNote` +- Use case: Chat applications, real-time audio processing, collaborative editing + +## 2. 
Python gRPC API + +### Channel Methods + +The `grpcio` library's `Channel` class provides methods to create callable objects for each pattern: + +```python +channel.unary_unary(method) # Returns UnaryUnaryMultiCallable +channel.unary_stream(method) # Returns UnaryStreamMultiCallable +channel.stream_unary(method) # Returns StreamUnaryMultiCallable +channel.stream_stream(method) # Returns StreamStreamMultiCallable +``` + +### Invocation Variants + +Each `MultiCallable` supports different invocation styles: + +| Variant | Syntax | Behavior | +|---------|--------|----------| +| **Direct call** | `response = callable(request)` | Blocks until response received | +| **with_call** | `response, call = callable.with_call(request)` | Returns response + Call object with metadata | +| **future** | `future = callable.future(request)` | Returns Future for async handling | + +Example: + +```python +# Direct call - simplest, blocks +response = stub.SayHello(request) + +# with_call - access metadata +response, call = stub.SayHello.with_call(request) +initial_metadata = call.initial_metadata() +trailing_metadata = call.trailing_metadata() + +# future - non-blocking +future = stub.SayHello.future(request) +# ... do other work ... +response = future.result() # blocks when needed +``` + +### Naming in Instrumentation Code + +The instrumentation handlers follow the pattern `_handle_{pattern}_{variant}`: + +| Handler | Pattern | Variant | +|---------|---------|---------| +| `_handle_unary_unary` | unary_unary | Direct call | +| `_handle_unary_unary_with_call` | unary_unary | with_call | +| `_handle_unary_unary_future` | unary_unary | future | +| `_handle_unary_stream` | unary_stream | Direct call (iterator) | +| `_handle_stream_unary` | stream_unary | Direct call | +| `_handle_stream_unary_with_call` | stream_unary | with_call | +| `_handle_stream_unary_future` | stream_unary | future | +| `_handle_stream_stream` | stream_stream | Direct call (iterator) | + +Note: Streaming responses (`unary_stream`, `stream_stream`) don't have `with_call`/`future` variants - they always return iterators. + +## 3. Use Cases in Python Ecosystem + +gRPC is commonly used in Python for ML inference, data engineering, and cloud services. + +### ML Inference Services + +| Service | Primary Pattern | Use Case | +|---------|----------------|----------| +| TensorFlow Serving | Unary | `Predict(PredictRequest) → PredictResponse` | +| Triton Inference Server | Unary | `ModelInfer(ModelInferRequest) → ModelInferResponse` | +| vLLM / Text Generation | Server Streaming | `Generate(Prompt) → stream Token` (token-by-token output) | +| Batch Inference | Unary | Send batch of inputs, get batch of outputs | + +**Verdict**: Mostly unary, with server streaming for LLM token streaming + +### Google Cloud APIs + +| Service | Primary Pattern | Use Case | +|---------|----------------|----------| +| BigQuery | Unary + Server Streaming | Query submission (unary), large result streaming | +| Cloud Storage | Unary | Upload/download objects | +| Pub/Sub | Server Streaming | `StreamingPull()` for message consumption | +| Firestore | Server Streaming | Real-time listeners | +| Speech-to-Text | Bidirectional | `StreamingRecognize()` - send audio chunks, get transcripts | +| Dialogflow | Bidirectional | `StreamingDetectIntent()` for conversations | + +**Verdict**: Mostly unary + server streaming. Bidirectional mainly for real-time audio/conversation. 
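+
+For reference, both streaming shapes surface to Python client code as plain iterators, which is exactly the call shape the client instrumentation wraps (in RECORD mode each yielded message is serialized as it is consumed). A minimal sketch (the stub, message types, and method names below are hypothetical stand-ins for whatever the generated `*_pb2_grpc` module provides):
+
+```python
+import grpc
+
+# Hypothetical generated modules, used only to make the call shapes concrete.
+import speech_pb2
+import speech_pb2_grpc
+
+channel = grpc.insecure_channel("localhost:50051")
+stub = speech_pb2_grpc.SpeechStub(channel)
+
+# Server streaming (unary_stream): one request in, an iterator of responses out.
+for update in stub.Subscribe(speech_pb2.SubscribeRequest(topic="updates")):
+    print(update.payload)
+
+
+# Bidirectional streaming (stream_stream): pass an iterator of requests and
+# iterate the responses as they arrive.
+def audio_chunks():
+    for chunk in (b"\x00\x01", b"\x02\x03"):
+        yield speech_pb2.AudioChunk(data=chunk)
+
+
+for transcript in stub.StreamingRecognize(audio_chunks()):
+    print(transcript.text)
+```
+
+The client-streaming (`stream_unary`) shape is the mirror image: the stub accepts a request iterator and returns a single response.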
+ +### Data Engineering + +| Tool | Primary Pattern | Use Case | +|------|----------------|----------| +| Apache Beam | Unary | Job submission, status checks | +| Apache Kafka (gRPC bridge) | Server Streaming | Consuming message streams | +| Ray Serve | Unary | Remote function calls | +| Dask | Unary | Task submission | + +**Verdict**: Almost entirely unary + +### Coverage Summary + +| Pattern | Estimated Usage | Status | Typical Use Cases | +|---------|-----------------|--------|-------------------| +| Unary | ~80-85% | ✅ Implemented | Predictions, queries, CRUD, job submission | +| Server Streaming | ~10-15% | ✅ Implemented | LLM tokens, large results, real-time feeds | +| Client Streaming | ~2-3% | ✅ Implemented | Audio upload (Speech-to-Text) | +| Bidirectional | ~2-3% | ✅ Implemented | Real-time audio/conversations | + +**All 4 gRPC communication patterns are now implemented for client-side instrumentation.** + +## 4. Implementation Details + +### What's Patched + +The instrumentation patches the concrete `grpc._channel.Channel` class (not just the abstract `grpc.Channel`): + +- `grpc._channel.Channel.unary_unary()` - Returns instrumented callable +- `grpc._channel.Channel.unary_stream()` - Returns instrumented callable +- `grpc._channel.Channel.stream_unary()` - Returns instrumented callable +- `grpc._channel.Channel.stream_stream()` - Returns instrumented callable + +**Important**: We patch `grpc._channel.Channel` (the concrete implementation), not `grpc.Channel` (the abstract base class). Python's MRO resolves methods from the concrete class first, so patching the abstract class has no effect. + +### Methods Instrumented + +| Method | Description | Record | Replay | +|--------|-------------|--------|--------| +| `callable()` | Direct unary call | ✅ | ✅ | +| `callable.with_call()` | Unary with metadata | ✅ | ✅ | +| `callable.future()` | Async unary (future) | ✅ | ✅ | +| `stream_callable()` | Server streaming | ✅ | ✅ | +| `stream_unary_callable()` | Client streaming | ✅ | ✅ | +| `stream_stream_callable()` | Bidirectional streaming | ✅ | ✅ | + +### Data Serialization + +gRPC messages often contain binary data (protobuf `bytes` fields). The instrumentation handles this using: + +1. `serialize_grpc_payload()`: Converts protobuf messages to JSON-serializable dicts, replacing `bytes` with placeholders and storing actual data in a `buffer_map` + +2. `deserialize_grpc_payload()`: Restores `bytes` fields from the `buffer_map` during replay + +3. `serialize_grpc_metadata()`: Converts gRPC metadata (headers/trailers) to a JSON-serializable format + +### Comparison with Node SDK + +This implementation follows the same patterns as the Node SDK's gRPC instrumentation: + +| Feature | Node SDK | Python SDK | +|---------|----------|------------| +| Unary calls | ✅ `makeUnaryRequest` | ✅ `unary_unary` | +| Server streaming | ✅ `makeServerStreamRequest` | ✅ `unary_stream` | +| Client streaming | ❌ Not implemented | ✅ `stream_unary` | +| Bidirectional | ❌ Not implemented | ✅ `stream_stream` | +| Server-side | ❌ Commented out | ❌ Not implemented | +| Buffer handling | ✅ Placeholder-based | ✅ Placeholder-based | + +**Note**: Python SDK has more complete client-side coverage than the Node SDK. + +## 5. Future Considerations + +### Server-Side Instrumentation + +Server-side gRPC instrumentation (inbound requests) is not yet implemented. Similar to the Node SDK, we'll hold off until a customer asks for it. + +If needed, this would involve: + +1. Patching `grpc.server()` or `Server.add_insecure_port()` +2. 
Wrapping service handlers to create SERVER spans +3. Handling replay of server responses (more complex than client-side) + +### Known Limitations + +1. **Client streaming iterator consumption**: For `stream_unary` and `stream_stream` calls, the request iterator is consumed upfront to capture all requests. This changes behavior slightly if the iterator has side effects. + +2. **Protobuf int64 serialization**: Protobuf's `MessageToDict` converts `int64` fields to strings for JSON compatibility. The `MockProtoMessage` class handles this by converting numeric strings back to integers during replay. diff --git a/drift/instrumentation/grpc/types.py b/drift/instrumentation/grpc/types.py new file mode 100644 index 0000000..2d83a80 --- /dev/null +++ b/drift/instrumentation/grpc/types.py @@ -0,0 +1,90 @@ +"""Type definitions for gRPC instrumentation.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Union + + +@dataclass +class BufferMetadata: + """Metadata for handling binary buffers in gRPC payloads.""" + + buffer_map: dict[str, dict[str, str]] = field(default_factory=dict) + """Map of field paths to buffer info (value + encoding).""" + + jsonable_string_map: dict[str, str] = field(default_factory=dict) + """Map of field paths to JSON-able strings.""" + + +@dataclass +class GrpcClientInputValue: + """Input value structure for gRPC client requests.""" + + method: str + """gRPC method name.""" + + service: str + """gRPC service name (package.ServiceName).""" + + body: Any + """Request body (protobuf message as dict).""" + + metadata: dict[str, list[str | dict[str, str]]] + """gRPC metadata (headers).""" + + input_meta: BufferMetadata | None = None + """Buffer metadata for binary fields.""" + + +@dataclass +class GrpcStatus: + """gRPC response status.""" + + code: int + """gRPC status code.""" + + details: str + """Status details/message.""" + + metadata: dict[str, list[str | dict[str, str]]] = field(default_factory=dict) + """Trailing metadata.""" + + +@dataclass +class GrpcOutputValue: + """Output value structure for successful gRPC responses.""" + + body: Any + """Response body (protobuf message as dict).""" + + metadata: dict[str, list[str | dict[str, str]]] + """Initial response metadata.""" + + status: GrpcStatus + """gRPC status.""" + + buffer_map: dict[str, dict[str, str]] = field(default_factory=dict) + """Buffer metadata for binary fields in response.""" + + jsonable_string_map: dict[str, str] = field(default_factory=dict) + """Map of field paths to JSON-able strings.""" + + +@dataclass +class GrpcErrorOutput: + """Output value structure for gRPC errors.""" + + error: dict[str, str] + """Error info (message, name, stack).""" + + status: GrpcStatus + """gRPC status.""" + + metadata: dict[str, list[str | dict[str, str]]] = field(default_factory=dict) + """Response metadata.""" + + +# Type alias for readable metadata values +ReadableMetadataValue = Union[str, dict[str, str]] +ReadableMetadata = dict[str, list[ReadableMetadataValue]] diff --git a/drift/instrumentation/grpc/utils.py b/drift/instrumentation/grpc/utils.py new file mode 100644 index 0000000..16ca303 --- /dev/null +++ b/drift/instrumentation/grpc/utils.py @@ -0,0 +1,299 @@ +"""Utility functions for gRPC instrumentation.""" + +from __future__ import annotations + +import base64 +import copy +import logging +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + +# Sentinel value for replaced buffers +BUFFER_PLACEHOLDER = 
"__tusk_drift_buffer_replaced__" + + +def is_utf8(data: bytes) -> bool: + """Check if bytes contain valid UTF-8 text.""" + try: + decoded = data.decode("utf-8") + # Verify round-trip works + return data == decoded.encode("utf-8") + except (UnicodeDecodeError, UnicodeEncodeError): + return False + + +def serialize_grpc_metadata(metadata: Any) -> dict[str, list[str | dict[str, str]]]: + """ + Convert gRPC Metadata object to a plain Python dict. + + Args: + metadata: grpc.Metadata or similar object + + Returns: + Dict mapping keys to lists of values (strings or encoded buffers) + """ + if metadata is None: + return {} + + readable_metadata: dict[str, list[str | dict[str, str]]] = {} + + # Handle different metadata formats + # grpc.Metadata can be iterated as (key, value) tuples + try: + items = list(metadata) if hasattr(metadata, "__iter__") else [] + except TypeError: + return {} + + for key, value in items: + if key not in readable_metadata: + readable_metadata[key] = [] + + if isinstance(value, str): + readable_metadata[key].append(value) + elif isinstance(value, bytes): + # Handle binary values + if is_utf8(value): + readable_metadata[key].append({"value": value.decode("utf-8"), "encoding": "utf8"}) + else: + readable_metadata[key].append({"value": base64.b64encode(value).decode("ascii"), "encoding": "base64"}) + else: + # Convert other types to string + readable_metadata[key].append(str(value)) + + return readable_metadata + + +def deserialize_grpc_metadata( + readable_metadata: dict[str, list[str | dict[str, str]]], +) -> list[tuple[str, str | bytes]]: + """ + Convert a plain Python dict back to gRPC metadata tuples. + + Args: + readable_metadata: Dict from serialize_grpc_metadata + + Returns: + List of (key, value) tuples suitable for grpc.Metadata + """ + result: list[tuple[str, str | bytes]] = [] + + for key, values in readable_metadata.items(): + for value in values: + if isinstance(value, str): + result.append((key, value)) + elif isinstance(value, dict) and "value" in value and "encoding" in value: + # Handle encoded buffer + if value["encoding"] == "utf8": + result.append((key, value["value"].encode("utf-8"))) + else: + result.append((key, base64.b64decode(value["value"]))) + + return result + + +def parse_grpc_path(path: str) -> tuple[str, str]: + """ + Extract service and method name from gRPC path. + + Path format: /package.ServiceName/MethodName + + Args: + path: gRPC method path + + Returns: + Tuple of (method, service) + """ + if not path: + return ("", "") + + # Remove leading slash and split + parts = path.lstrip("/").split("/") + service = parts[0] if len(parts) > 0 else "" + method = parts[1] if len(parts) > 1 else "" + + return (method, service) + + +def serialize_grpc_payload(payload: Any) -> tuple[Any, dict[str, dict[str, str]], dict[str, str]]: + """ + Convert request/response body to a serializable format, handling bytes. + + Protobuf messages often contain bytes fields which need special handling + for JSON serialization. This function replaces bytes with placeholders + and stores the actual data in a separate map. 
+ + Args: + payload: Protobuf message (as dict or object with __dict__) + + Returns: + Tuple of (readable_body, buffer_map, jsonable_string_map) + """ + buffer_map: dict[str, dict[str, str]] = {} + jsonable_string_map: dict[str, str] = {} + + # Convert protobuf message to dict if needed + if hasattr(payload, "DESCRIPTOR"): + # It's a protobuf message - convert to dict + try: + from google.protobuf.json_format import MessageToDict + + readable_body = MessageToDict(payload, preserving_proto_field_name=True) + except ImportError: + # Fallback: try to access fields directly + readable_body = _proto_to_dict(payload) + elif isinstance(payload, dict): + readable_body = copy.deepcopy(payload) + else: + # Try to convert to dict + readable_body = copy.deepcopy(payload) if payload is not None else None + + # Process the body recursively to handle bytes + if readable_body is not None: + _process_payload_for_serialization(readable_body, buffer_map, jsonable_string_map, []) + + return (readable_body, buffer_map, jsonable_string_map) + + +def _proto_to_dict(message: Any) -> dict[str, Any]: + """Convert a protobuf message to dict without protobuf library.""" + result: dict[str, Any] = {} + + if hasattr(message, "DESCRIPTOR"): + for field in message.DESCRIPTOR.fields: + value = getattr(message, field.name) + if field.message_type is not None: + # Nested message + if field.label == field.LABEL_REPEATED: + result[field.name] = [_proto_to_dict(v) for v in value] + else: + result[field.name] = _proto_to_dict(value) + else: + result[field.name] = value + elif hasattr(message, "__dict__"): + result = copy.deepcopy(message.__dict__) + + return result + + +def _process_payload_for_serialization( + payload: Any, + buffer_map: dict[str, dict[str, str]], + jsonable_string_map: dict[str, str], + path: list[str], +) -> None: + """ + Recursively process a payload to convert bytes to placeholders. + + Args: + payload: Object to process (dict or list) + buffer_map: Map to store buffer info + jsonable_string_map: Map for JSON-able strings + path: Current path in the object tree + """ + if payload is None or not isinstance(payload, (dict, list)): + return + + if isinstance(payload, list): + for i, item in enumerate(payload): + current_path = [*path, str(i)] + if isinstance(item, bytes): + path_str = ".".join(current_path) + if is_utf8(item): + buffer_map[path_str] = {"value": item.decode("utf-8"), "encoding": "utf8"} + else: + buffer_map[path_str] = {"value": base64.b64encode(item).decode("ascii"), "encoding": "base64"} + payload[i] = BUFFER_PLACEHOLDER + elif isinstance(item, (dict, list)): + _process_payload_for_serialization(item, buffer_map, jsonable_string_map, current_path) + return + + # Handle dict + for key in list(payload.keys()): + current_path = [*path, key] + path_str = ".".join(current_path) + value = payload[key] + + if isinstance(value, bytes): + if is_utf8(value): + buffer_map[path_str] = {"value": value.decode("utf-8"), "encoding": "utf8"} + else: + buffer_map[path_str] = {"value": base64.b64encode(value).decode("ascii"), "encoding": "base64"} + payload[key] = BUFFER_PLACEHOLDER + elif isinstance(value, (dict, list)): + _process_payload_for_serialization(value, buffer_map, jsonable_string_map, current_path) + + +def deserialize_grpc_payload( + readable_payload: Any, + buffer_map: dict[str, dict[str, str]], + jsonable_string_map: dict[str, str], +) -> Any: + """ + Convert a serialized payload back to its original format with bytes restored. 
+ + Args: + readable_payload: Payload from serialize_grpc_payload + buffer_map: Buffer map from serialize_grpc_payload + jsonable_string_map: String map from serialize_grpc_payload + + Returns: + Payload with bytes fields restored + """ + if readable_payload is None: + return None + + result = copy.deepcopy(readable_payload) + _restore_payload_from_serialization(result, buffer_map, jsonable_string_map, []) + return result + + +def _restore_payload_from_serialization( + payload: Any, + buffer_map: dict[str, dict[str, str]], + jsonable_string_map: dict[str, str], + path: list[str], +) -> None: + """ + Recursively restore bytes in a payload. + + Args: + payload: Object to process + buffer_map: Buffer map with stored values + jsonable_string_map: String map + path: Current path in the object tree + """ + if payload is None or not isinstance(payload, (dict, list)): + return + + if isinstance(payload, list): + for i, item in enumerate(payload): + current_path = [*path, str(i)] + path_str = ".".join(current_path) + if item == BUFFER_PLACEHOLDER and path_str in buffer_map: + buffer_info = buffer_map[path_str] + if buffer_info["encoding"] == "utf8": + payload[i] = buffer_info["value"].encode("utf-8") + else: + payload[i] = base64.b64decode(buffer_info["value"]) + elif isinstance(item, (dict, list)): + _restore_payload_from_serialization(item, buffer_map, jsonable_string_map, current_path) + return + + # Handle dict + for key in list(payload.keys()): + current_path = [*path, key] + path_str = ".".join(current_path) + value = payload[key] + + if value == BUFFER_PLACEHOLDER and path_str in buffer_map: + buffer_info = buffer_map[path_str] + if buffer_info["encoding"] == "utf8": + payload[key] = buffer_info["value"].encode("utf-8") + else: + payload[key] = base64.b64decode(buffer_info["value"]) + elif isinstance(value, (dict, list)): + _restore_payload_from_serialization(value, buffer_map, jsonable_string_map, current_path) diff --git a/pyproject.toml b/pyproject.toml index 5d65a6c..124c539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,11 @@ prerelease = "allow" [tool.ruff] target-version = "py39" line-length = 120 +exclude = [ + # Ignore all rules for generated protobuf and grpc + "**/greeter_pb2.py", + "**/greeter_pb2_grpc.py", +] [tool.ruff.lint] select = [ @@ -127,6 +132,7 @@ exclude = ["**/e2e-tests/**"] # Disable unresolved-import errors for instrumentation files with optional dependencies include = [ "drift/instrumentation/django/**", + "drift/instrumentation/grpc/**", "drift/instrumentation/psycopg/**", "drift/instrumentation/psycopg2/**", "drift/instrumentation/redis/**", diff --git a/scripts/generate_manifest.py b/scripts/generate_manifest.py index 482f09a..69e6bb1 100644 --- a/scripts/generate_manifest.py +++ b/scripts/generate_manifest.py @@ -27,7 +27,7 @@ import ast import json import sys -from datetime import UTC, datetime +from datetime import datetime, timezone from pathlib import Path # Script location @@ -177,7 +177,7 @@ def main() -> int: "sdkVersion": sdk_version, "language": "python", "pythonVersion": ">=3.12", - "generatedAt": datetime.now(UTC).isoformat().replace("+00:00", "Z"), + "generatedAt": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), "instrumentations": instrumentations, }