diff --git a/src/openai/_send_queue.py b/src/openai/_send_queue.py index b35d0fbcba..998b14850a 100644 --- a/src/openai/_send_queue.py +++ b/src/openai/_send_queue.py @@ -11,31 +11,32 @@ class SendQueue: """Bounded byte-size queue for outgoing WebSocket messages. - Messages are stored as pre-serialized strings. The queue enforces a - maximum byte budget so that unbounded buffering cannot occur during - reconnection windows. + Messages are stored as either ``str`` (text frames) or ``bytes`` (binary + frames), preserving the original frame type so that binary payloads are + not corrupted on replay. The queue enforces a maximum byte budget so that + unbounded buffering cannot occur during reconnection windows. """ def __init__(self, max_bytes: int = 1_048_576) -> None: - self._queue: list[tuple[str, int]] = [] # (data, byte_length) + self._queue: list[tuple[bytes | str, int]] = [] # (data, byte_length) self._bytes: int = 0 self._max_bytes = max_bytes self._lock = threading.Lock() - def enqueue(self, data: str) -> None: + def enqueue(self, data: bytes | str) -> None: """Append *data* to the queue. Raises :class:`WebSocketQueueFullError` if the message would exceed the byte-size limit. """ - byte_length = len(data.encode("utf-8")) + byte_length = len(data) if isinstance(data, bytes) else len(data.encode("utf-8")) with self._lock: if self._bytes + byte_length > self._max_bytes: raise WebSocketQueueFullError("send queue is full, message discarded") self._queue.append((data, byte_length)) self._bytes += byte_length - def flush_sync(self, send: typing.Callable[[str], object]) -> None: + def flush_sync(self, send: typing.Callable[[bytes | str], object]) -> None: """Send every queued message via *send*. If *send* raises, the failing message and all subsequent messages @@ -56,7 +57,7 @@ def flush_sync(self, send: typing.Callable[[str], object]) -> None: self._bytes = sum(bl for _, bl in self._queue) raise - async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object]]) -> None: + async def flush_async(self, send: typing.Callable[[bytes | str], typing.Awaitable[object]]) -> None: """Async variant of :meth:`flush_sync`.""" with self._lock: pending = list(self._queue) @@ -73,7 +74,7 @@ async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object self._bytes = sum(bl for _, bl in self._queue) raise - def drain(self) -> list[str]: + def drain(self) -> list[bytes | str]: """Remove and return all queued messages.""" with self._lock: items = [data for data, _ in self._queue] diff --git a/src/openai/resources/realtime/realtime.py b/src/openai/resources/realtime/realtime.py index e4c5bd8163..a29248e3c8 100644 --- a/src/openai/resources/realtime/realtime.py +++ b/src/openai/resources/realtime/realtime.py @@ -359,10 +359,13 @@ async def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> N async def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - await self._connection.send(data) + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True @@ -839,10 +842,13 @@ def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None: def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - self._connection.send(data) + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True diff --git a/src/openai/resources/responses/responses.py b/src/openai/resources/responses/responses.py index 5019d7e831..06f3d8e28a 100644 --- a/src/openai/resources/responses/responses.py +++ b/src/openai/resources/responses/responses.py @@ -3852,10 +3852,13 @@ async def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> async def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - await self._connection.send(data) + try: + await self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise async def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True @@ -4309,10 +4312,13 @@ def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None: def send_raw(self, data: bytes | str) -> None: if self._is_reconnecting: - raw = data if isinstance(data, str) else data.decode("utf-8") - self._send_queue.enqueue(raw) + self._send_queue.enqueue(data) return - self._connection.send(data) + try: + self._connection.send(data) + except Exception: + self._send_queue.enqueue(data) + raise def close(self, *, code: int = 1000, reason: str = "") -> None: self._intentionally_closed = True diff --git a/tests/test_realtime_reconnect.py b/tests/test_realtime_reconnect.py new file mode 100644 index 0000000000..27b14744bc --- /dev/null +++ b/tests/test_realtime_reconnect.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from openai.resources.realtime.realtime import AsyncRealtimeConnection + + +def _connection_closed_error(code: int = 1011) -> Exception: + from websockets.frames import Close + from websockets.exceptions import ConnectionClosedError + + return ConnectionClosedError(Close(code=code, reason=""), None) + + +class _DeadConnection: + """A connection whose send() always fails, simulating a dropped socket.""" + + async def send(self, _data: bytes | str) -> None: + raise _connection_closed_error() + + async def close(self, *, code: int = 1000, reason: str = "") -> None: + pass + + +class _RecordingConnection: + """The connection returned after a successful reconnect.""" + + def __init__(self) -> None: + self.sent: list[bytes | str] = [] + + async def send(self, data: bytes | str) -> None: + self.sent.append(data) + + +def _make_connection(new_conn: _RecordingConnection) -> AsyncRealtimeConnection: + async def make_ws(_extra_query: Any, _extra_headers: Any) -> Any: + return new_conn + + return AsyncRealtimeConnection( + _DeadConnection(), # type: ignore[arg-type] + make_ws=make_ws, + on_reconnecting=lambda _event: None, + max_retries=1, + initial_delay=0.0, + max_delay=0.0, + ) + + +@pytest.mark.asyncio +async def test_reconnect_resends_binary_payload_unchanged() -> None: + """End-to-end: a binary send_raw() that fails mid-send is queued and + replayed byte-for-byte after reconnect, without UTF-8 corruption.""" + from websockets.exceptions import ConnectionClosedError + + new_conn = _RecordingConnection() + conn = _make_connection(new_conn) + + binary = b"\xff\xfe\x00audio" # not valid UTF-8 (would crash on decode) + + # send fails on the dead socket -> the original connection error must + # surface (NOT a UnicodeDecodeError from decoding the binary payload), + # and the payload must be queued for replay. + with pytest.raises(ConnectionClosedError): + await conn.send_raw(binary) + + # Drive the real reconnect path, which flushes the queue to the new socket. + reconnected = await conn._reconnect(_connection_closed_error()) + assert reconnected is True + + assert new_conn.sent == [binary] + assert isinstance(new_conn.sent[0], bytes) + + +@pytest.mark.asyncio +async def test_reconnect_resends_text_payload() -> None: + """A str send_raw() is replayed as text after reconnect.""" + from websockets.exceptions import ConnectionClosedError + + new_conn = _RecordingConnection() + conn = _make_connection(new_conn) + + with pytest.raises(ConnectionClosedError): + await conn.send_raw('{"type": "input_audio_buffer.append"}') + + assert await conn._reconnect(_connection_closed_error()) is True + assert new_conn.sent == ['{"type": "input_audio_buffer.append"}'] + assert isinstance(new_conn.sent[0], str) diff --git a/tests/test_send_queue.py b/tests/test_send_queue.py index 61db916bc4..6d0676337f 100644 --- a/tests/test_send_queue.py +++ b/tests/test_send_queue.py @@ -19,6 +19,39 @@ def test_enqueue_and_drain(self) -> None: assert items == ['{"type": "session.update"}', '{"type": "response.create"}'] assert len(q) == 0 + def test_enqueue_preserves_binary_frames(self) -> None: + """Binary payloads must be stored as-is, not decoded to text. + + Decoding to UTF-8 would corrupt binary frames and raise + ``UnicodeDecodeError`` for arbitrary bytes (e.g. audio chunks). + """ + q = SendQueue() + binary = b"\xff\xfe\x00audio" # not valid UTF-8 + q.enqueue(binary) + q.enqueue("text") + + items = q.drain() + assert items == [binary, "text"] + assert isinstance(items[0], bytes) + assert isinstance(items[1], str) + + def test_enqueue_counts_binary_byte_length(self) -> None: + q = SendQueue(max_bytes=4) + q.enqueue(b"\xff\xfe\xfd\xfc") # 4 bytes, fits exactly + with pytest.raises(WebSocketQueueFullError): + q.enqueue(b"\x00") # would exceed + assert len(q) == 1 + + def test_flush_sync_preserves_binary(self) -> None: + q = SendQueue() + binary = b"\xff\xfe" + q.enqueue(binary) + q.enqueue("text") + + sent: list[bytes | str] = [] + q.flush_sync(sent.append) + assert sent == [binary, "text"] + def test_enqueue_respects_byte_limit(self) -> None: q = SendQueue(max_bytes=10) q.enqueue("12345") # 5 bytes, fits