Skip to content

Commit 7823369

Browse files
committed
feat: propagate OTel tracing context from client to server through websocket
Add server-side trace context extraction and span creation so that distributed traces flow end-to-end through River websocket connections. Changes: - Add TransportMessageTracingGetter to extract traceparent/tracestate from incoming TransportMessages (counterpart to existing Setter) - Extract trace context in ServerSession._open_stream_and_call_handler and create a SERVER span that is a child of the client's CLIENT span - Run handler within the extracted context so downstream code inherits the trace - Update and expand tests to verify server spans, trace propagation, span attributes, and independent traces for concurrent RPCs
1 parent bd88e45 commit 7823369

File tree

3 files changed

+369
-18
lines changed

3 files changed

+369
-18
lines changed

src/replit_river/rpc.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
import grpc
2222
from aiochannel import Channel, ChannelClosed
23-
from opentelemetry.propagators.textmap import Setter
23+
from opentelemetry.propagators.textmap import Getter, Setter
2424
from pydantic import BaseModel, ConfigDict, Field
2525

2626
from replit_river.error_schema import (
@@ -126,6 +126,36 @@ def set(self, carrier: TransportMessage, key: str, value: str) -> None:
126126
logger.warning("unknown trace propagation key", extra={"key": key})
127127

128128

129+
class TransportMessageTracingGetter(Getter[TransportMessage]):
130+
"""
131+
Handles extracting tracing context from an incoming transport message.
132+
"""
133+
134+
def get(self, carrier: TransportMessage, key: str) -> list[str] | None:
135+
if not carrier.tracing:
136+
return None
137+
match key:
138+
case "traceparent":
139+
value = carrier.tracing.traceparent
140+
case "tracestate":
141+
value = carrier.tracing.tracestate
142+
case _:
143+
return None
144+
if not value:
145+
return None
146+
return [value]
147+
148+
def keys(self, carrier: TransportMessage) -> list[str]:
149+
if not carrier.tracing:
150+
return []
151+
keys: list[str] = []
152+
if carrier.tracing.traceparent:
153+
keys.append("traceparent")
154+
if carrier.tracing.tracestate:
155+
keys.append("tracestate")
156+
return keys
157+
158+
129159
class GrpcContext(grpc.aio.ServicerContext, Generic[RequestType, ResponseType]):
130160
"""Represents a gRPC-compatible ServicerContext for River interop."""
131161

src/replit_river/server_session.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
import websockets
66
from aiochannel import Channel, ChannelClosed
7+
from opentelemetry import context, trace
8+
from opentelemetry.trace import SpanKind
79
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
810
from websockets.exceptions import ConnectionClosed
911

@@ -24,6 +26,7 @@
2426
STREAM_OPEN_BIT,
2527
GenericRpcHandlerBuilder,
2628
TransportMessage,
29+
TransportMessageTracingGetter,
2730
TransportMessageTracingSetter,
2831
)
2932

@@ -32,9 +35,11 @@
3235

3336
logger = logging.getLogger(__name__)
3437

38+
tracer = trace.get_tracer(__name__)
3539

3640
trace_propagator = TraceContextTextMapPropagator()
3741
trace_setter = TransportMessageTracingSetter()
42+
trace_getter = TransportMessageTracingGetter()
3843

3944

4045
class ServerSession(Session):
@@ -216,6 +221,23 @@ async def _open_stream_and_call_handler(
216221
"upload-stream", # subscription
217222
"stream",
218223
)
224+
225+
# Extract trace context from the incoming message and create a server span.
226+
extracted_context = trace_propagator.extract(
227+
carrier=msg, getter=trace_getter
228+
)
229+
span = tracer.start_span(
230+
f"river.server.{method_type}.{msg.serviceName}.{msg.procedureName}",
231+
context=extracted_context,
232+
kind=SpanKind.SERVER,
233+
)
234+
span.set_attribute("river.service_name", msg.serviceName)
235+
span.set_attribute("river.procedure_name", msg.procedureName)
236+
span.set_attribute("river.method_type", method_type)
237+
span.set_attribute("river.stream_id", msg.streamId)
238+
span.set_attribute("river.client_id", msg.from_)
239+
handler_ctx = trace.set_span_in_context(span, extracted_context)
240+
219241
# New channel pair.
220242
input_stream: Channel[Any] = Channel(
221243
MAX_MESSAGE_BUFFER_SIZE if is_streaming_input else 1
@@ -231,9 +253,13 @@ async def _open_stream_and_call_handler(
231253
await input_stream.put(msg.payload)
232254
except (RuntimeError, ChannelClosed) as e:
233255
raise InvalidMessageException(e) from e
234-
# Start the handler.
256+
# Start the handler with the extracted trace context.
235257
self._task_manager.create_task(
236-
handler_func(msg.from_, input_stream, output_stream), tg
258+
self._run_handler_with_tracing(
259+
handler_func, msg.from_, input_stream, output_stream,
260+
span, handler_ctx,
261+
),
262+
tg,
237263
)
238264
self._task_manager.create_task(
239265
self._send_responses_from_output_stream(
@@ -243,6 +269,29 @@ async def _open_stream_and_call_handler(
243269
)
244270
return input_stream
245271

272+
async def _run_handler_with_tracing(
273+
self,
274+
handler_func: GenericRpcHandlerBuilder,
275+
peer: str,
276+
input_stream: Channel[Any],
277+
output_stream: Channel[Any],
278+
span: trace.Span,
279+
handler_ctx: context.Context,
280+
) -> None:
281+
"""Run an RPC handler within the extracted trace context, ending the span
282+
when the handler completes."""
283+
token = context.attach(handler_ctx)
284+
try:
285+
await handler_func(peer, input_stream, output_stream)
286+
span.set_status(trace.StatusCode.OK)
287+
except Exception as e:
288+
span.set_status(trace.StatusCode.ERROR, str(e))
289+
span.record_exception(e)
290+
raise
291+
finally:
292+
span.end()
293+
context.detach(token)
294+
246295
async def _send_responses_from_output_stream(
247296
self,
248297
stream_id: str,

0 commit comments

Comments
 (0)