Skip to content

Commit 634f270

Browse files
committed
feat: propagate OTel context via WebSocket HTTP upgrade headers
Propagate traceparent, tracestate, and baggage through the WebSocket connection using standard W3C HTTP headers on the upgrade request, matching how any HTTP-based service would propagate OTel context. Client side (v1 + v2): - Use propagate.inject() to capture the current OTel context into a headers dict, then pass it as extra_headers/additional_headers to websockets.connect(). Server side: - In Server.serve(), use propagate.extract() on websocket.request_headers to restore the OTel context, then attach it as the ambient context for the lifetime of the connection.
1 parent bd88e45 commit 634f270

File tree

4 files changed

+280
-31
lines changed

4 files changed

+280
-31
lines changed

src/replit_river/client_transport.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import nanoid
77
import websockets
8+
from opentelemetry import propagate
89
from pydantic import ValidationError
910
from websockets import (
1011
WebSocketCommonProtocol,
@@ -170,7 +171,13 @@ async def _establish_new_connection(
170171

171172
try:
172173
uri_and_metadata = await self._uri_and_metadata_factory()
173-
ws = await websockets.connect(uri_and_metadata["uri"], max_size=None)
174+
otel_headers: dict[str, str] = {}
175+
propagate.inject(otel_headers)
176+
ws = await websockets.connect(
177+
uri_and_metadata["uri"],
178+
max_size=None,
179+
extra_headers=otel_headers,
180+
)
174181
session_id = (
175182
self.generate_nanoid()
176183
if not old_session

src/replit_river/server.py

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Mapping
44

55
import websockets
6+
from opentelemetry import context, propagate
67
from websockets.exceptions import ConnectionClosed
78
from websockets.server import WebSocketServerProtocol
89

@@ -68,34 +69,48 @@ async def serve(self, websocket: WebSocketServerProtocol) -> None:
6869
logger.debug(
6970
"River server started establishing session with ws: %s", websocket.id
7071
)
71-
grace_ms = self._transport_options.handshake_timeout_ms
72+
73+
# Extract OTel context (traceparent, tracestate, baggage) from the
74+
# WebSocket HTTP upgrade request headers and make it the ambient
75+
# context for the lifetime of this connection.
76+
otel_context = propagate.extract(websocket.request_headers)
77+
token = context.attach(otel_context)
78+
7279
try:
73-
session = await asyncio.wait_for(
74-
self._handshake_to_get_session(websocket),
75-
grace_ms / 1000, # wait_for unit is seconds
76-
)
77-
if not session:
80+
grace_ms = self._transport_options.handshake_timeout_ms
81+
try:
82+
session = await asyncio.wait_for(
83+
self._handshake_to_get_session(websocket),
84+
grace_ms / 1000, # wait_for unit is seconds
85+
)
86+
if not session:
87+
return
88+
except asyncio.TimeoutError:
89+
logger.error(
90+
f"Handshake timeout after {grace_ms}ms, closing websocket"
91+
)
92+
await websocket.close()
7893
return
79-
except asyncio.TimeoutError:
80-
logger.error(f"Handshake timeout after {grace_ms}ms, closing websocket")
81-
await websocket.close()
82-
return
83-
except asyncio.CancelledError:
84-
logger.error("Handshake cancelled, closing websocket")
85-
await websocket.close()
86-
return
87-
logger.debug("River server session established, start serving messages")
94+
except asyncio.CancelledError:
95+
logger.error("Handshake cancelled, closing websocket")
96+
await websocket.close()
97+
return
98+
logger.debug("River server session established, start serving messages")
8899

89-
try:
90-
# Session serve will be closed in two cases
91-
# 1. websocket is closed
92-
# 2. exception thrown
93-
# session should be kept in order to be reused by the reconnect within the
94-
# grace period.
95-
await session.serve()
96-
except ConnectionClosed:
97-
logger.debug("ConnectionClosed while serving", exc_info=True)
98-
# We don't have to close the websocket here, it is already closed.
99-
except Exception:
100-
logger.exception("River transport error in server %s", self._server_id)
101-
await websocket.close()
100+
try:
101+
# Session serve will be closed in two cases
102+
# 1. websocket is closed
103+
# 2. exception thrown
104+
# session should be kept in order to be reused by the reconnect within
105+
# the grace period.
106+
await session.serve()
107+
except ConnectionClosed:
108+
logger.debug("ConnectionClosed while serving", exc_info=True)
109+
# We don't have to close the websocket here, it is already closed.
110+
except Exception:
111+
logger.exception(
112+
"River transport error in server %s", self._server_id
113+
)
114+
await websocket.close()
115+
finally:
116+
context.detach(token)

src/replit_river/v2/session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import websockets.asyncio.client
2424
from aiochannel import Channel, ChannelEmpty, ChannelFull
2525
from aiochannel.errors import ChannelClosed
26+
from opentelemetry import propagate
2627
from opentelemetry.trace import Span, use_span
2728
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
2829
from pydantic import ValidationError
@@ -1133,9 +1134,12 @@ async def _do_ensure_connected[HandshakeMetadata](
11331134
ws: ClientConnection | None = None
11341135
try:
11351136
uri_and_metadata = await uri_and_metadata_factory()
1137+
otel_headers: dict[str, str] = {}
1138+
propagate.inject(otel_headers)
11361139
ws = await websockets.asyncio.client.connect(
11371140
uri_and_metadata["uri"],
11381141
max_size=None,
1142+
additional_headers=otel_headers,
11391143
)
11401144
transition_connecting(ws)
11411145

tests/v1/test_opentelemetry.py

Lines changed: 225 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,25 @@
11
import contextlib
2+
import logging
23
from datetime import timedelta
3-
from typing import AsyncGenerator, AsyncIterator, Iterator
4+
from typing import AsyncGenerator, AsyncIterator, Iterator, Literal
45

56
import grpc
67
import grpc.aio
78
import pytest
9+
from opentelemetry import baggage, context, propagate, trace
10+
from opentelemetry.baggage.propagation import W3CBaggagePropagator
11+
from opentelemetry.propagators.composite import CompositePropagator
812
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
913
from opentelemetry.trace import StatusCode
14+
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
15+
from websockets.server import serve
1016

1117
from replit_river.client import Client
18+
from replit_river.client_transport import UriAndMetadata
1219
from replit_river.error_schema import RiverError, RiverException
13-
from replit_river.rpc import stream_method_handler
20+
from replit_river.rpc import rpc_method_handler, stream_method_handler
21+
from replit_river.server import Server
22+
from replit_river.transport_options import TransportOptions
1423
from tests.conftest import (
1524
HandlerMapping,
1625
deserialize_error,
@@ -219,3 +228,217 @@ async def stream_data() -> AsyncGenerator[str, None]:
219228
assert len(spans) == 1
220229
assert spans[0].name == "river.client.stream.test_service.stream_method"
221230
assert spans[0].status.status_code == StatusCode.OK
231+
232+
233+
# ===== OTel context propagation via WebSocket HTTP upgrade headers =====
234+
235+
236+
# A handler that reads OTel baggage from the ambient context and returns it.
237+
async def baggage_echo_handler(
238+
request: str, ctx: grpc.aio.ServicerContext
239+
) -> str:
240+
all_baggage = baggage.get_all()
241+
# Return baggage as a comma-separated "key=value" string
242+
return ",".join(f"{k}={v}" for k, v in sorted(all_baggage.items()))
243+
244+
245+
baggage_echo_handlers: HandlerMapping = {
246+
("test_service", "baggage_echo"): (
247+
"rpc",
248+
rpc_method_handler(baggage_echo_handler, deserialize_request, serialize_response),
249+
)
250+
}
251+
252+
253+
@pytest.fixture
254+
def _enable_baggage_propagator():
255+
"""Temporarily install a composite propagator that includes both
256+
W3C TraceContext and W3C Baggage propagation so that
257+
``propagate.inject()`` / ``propagate.extract()`` handle the
258+
``baggage`` HTTP header."""
259+
previous = propagate.get_global_textmap()
260+
propagate.set_global_textmap(
261+
CompositePropagator([
262+
TraceContextTextMapPropagator(),
263+
W3CBaggagePropagator(),
264+
])
265+
)
266+
yield
267+
propagate.set_global_textmap(previous)
268+
269+
270+
@pytest.mark.asyncio
271+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
272+
@pytest.mark.usefixtures("_enable_baggage_propagator")
273+
async def test_baggage_propagated_via_ws_headers(
274+
no_logging_error: NoErrors,
275+
server: Server,
276+
transport_options: TransportOptions,
277+
) -> None:
278+
"""Verify that OTel baggage set on the client side is propagated to the
279+
server via the WebSocket HTTP upgrade request headers."""
280+
281+
# Set baggage in the ambient OTel context *before* the client connects,
282+
# so that ``propagate.inject()`` (called inside ``websockets.connect()``)
283+
# includes the ``baggage`` header.
284+
ctx = baggage.set_baggage("test-key", "test-value")
285+
ctx = baggage.set_baggage("another-key", "another-value", context=ctx)
286+
token = context.attach(ctx)
287+
288+
binding = None
289+
try:
290+
binding = await serve(server.serve, "127.0.0.1")
291+
sockets = list(binding.sockets)
292+
assert len(sockets) == 1
293+
socket = sockets[0]
294+
295+
async def websocket_uri_factory() -> UriAndMetadata[None]:
296+
return {
297+
"uri": "ws://%s:%d" % socket.getsockname(),
298+
"metadata": None,
299+
}
300+
301+
client: Client[Literal[None]] = Client[None](
302+
uri_and_metadata_factory=websocket_uri_factory,
303+
client_id="test_client",
304+
server_id="test_server",
305+
transport_options=transport_options,
306+
)
307+
try:
308+
response = await client.send_rpc(
309+
"test_service",
310+
"baggage_echo",
311+
"ignored",
312+
serialize_request,
313+
deserialize_response,
314+
deserialize_error,
315+
timedelta(seconds=20),
316+
)
317+
# The handler returns sorted "key=value" pairs
318+
assert response == "another-key=another-value,test-key=test-value"
319+
finally:
320+
logging.debug("Start closing test client")
321+
await client.close()
322+
finally:
323+
context.detach(token)
324+
logging.debug("Start closing test server")
325+
if binding:
326+
binding.close()
327+
await server.close()
328+
if binding:
329+
await binding.wait_closed()
330+
331+
332+
@pytest.mark.asyncio
333+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
334+
@pytest.mark.usefixtures("_enable_baggage_propagator")
335+
async def test_no_baggage_when_none_set(
336+
no_logging_error: NoErrors,
337+
server: Server,
338+
transport_options: TransportOptions,
339+
) -> None:
340+
"""Verify that when no baggage is set, the server sees empty baggage."""
341+
342+
binding = None
343+
try:
344+
binding = await serve(server.serve, "127.0.0.1")
345+
sockets = list(binding.sockets)
346+
assert len(sockets) == 1
347+
socket = sockets[0]
348+
349+
async def websocket_uri_factory() -> UriAndMetadata[None]:
350+
return {
351+
"uri": "ws://%s:%d" % socket.getsockname(),
352+
"metadata": None,
353+
}
354+
355+
client: Client[Literal[None]] = Client[None](
356+
uri_and_metadata_factory=websocket_uri_factory,
357+
client_id="test_client",
358+
server_id="test_server",
359+
transport_options=transport_options,
360+
)
361+
try:
362+
response = await client.send_rpc(
363+
"test_service",
364+
"baggage_echo",
365+
"ignored",
366+
serialize_request,
367+
deserialize_response,
368+
deserialize_error,
369+
timedelta(seconds=20),
370+
)
371+
assert response == ""
372+
finally:
373+
logging.debug("Start closing test client")
374+
await client.close()
375+
finally:
376+
logging.debug("Start closing test server")
377+
if binding:
378+
binding.close()
379+
await server.close()
380+
if binding:
381+
await binding.wait_closed()
382+
383+
384+
@pytest.mark.asyncio
385+
@pytest.mark.parametrize("handlers", [{**baggage_echo_handlers}])
386+
@pytest.mark.usefixtures("_enable_baggage_propagator")
387+
async def test_traceparent_propagated_via_ws_headers(
388+
no_logging_error: NoErrors,
389+
server: Server,
390+
transport_options: TransportOptions,
391+
span_exporter: InMemorySpanExporter,
392+
) -> None:
393+
"""Verify that when a span is active on the client, the traceparent
394+
header is sent on the WS upgrade and the server-side context inherits
395+
the trace."""
396+
tracer = trace.get_tracer(__name__)
397+
398+
with tracer.start_as_current_span("client-operation") as client_span:
399+
# Also set some baggage
400+
ctx = baggage.set_baggage("trace-test", "yes")
401+
token = context.attach(ctx)
402+
403+
binding = None
404+
try:
405+
binding = await serve(server.serve, "127.0.0.1")
406+
sockets = list(binding.sockets)
407+
assert len(sockets) == 1
408+
socket = sockets[0]
409+
410+
async def websocket_uri_factory() -> UriAndMetadata[None]:
411+
return {
412+
"uri": "ws://%s:%d" % socket.getsockname(),
413+
"metadata": None,
414+
}
415+
416+
client: Client[Literal[None]] = Client[None](
417+
uri_and_metadata_factory=websocket_uri_factory,
418+
client_id="test_client",
419+
server_id="test_server",
420+
transport_options=transport_options,
421+
)
422+
try:
423+
response = await client.send_rpc(
424+
"test_service",
425+
"baggage_echo",
426+
"ignored",
427+
serialize_request,
428+
deserialize_response,
429+
deserialize_error,
430+
timedelta(seconds=20),
431+
)
432+
# Verify baggage was propagated
433+
assert response == "trace-test=yes"
434+
finally:
435+
logging.debug("Start closing test client")
436+
await client.close()
437+
finally:
438+
context.detach(token)
439+
logging.debug("Start closing test server")
440+
if binding:
441+
binding.close()
442+
await server.close()
443+
if binding:
444+
await binding.wait_closed()

0 commit comments

Comments
 (0)