Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions src/openai/_send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,32 @@
class SendQueue:
"""Bounded byte-size queue for outgoing WebSocket messages.

Messages are stored as pre-serialized strings. The queue enforces a
maximum byte budget so that unbounded buffering cannot occur during
reconnection windows.
Messages are stored as either ``str`` (text frames) or ``bytes`` (binary
frames), preserving the original frame type so that binary payloads are
not corrupted on replay. The queue enforces a maximum byte budget so that
unbounded buffering cannot occur during reconnection windows.
"""

def __init__(self, max_bytes: int = 1_048_576) -> None:
self._queue: list[tuple[str, int]] = [] # (data, byte_length)
self._queue: list[tuple[bytes | str, int]] = [] # (data, byte_length)
self._bytes: int = 0
self._max_bytes = max_bytes
self._lock = threading.Lock()

def enqueue(self, data: str) -> None:
def enqueue(self, data: bytes | str) -> None:
"""Append *data* to the queue.

Raises :class:`WebSocketQueueFullError` if the message would
exceed the byte-size limit.
"""
byte_length = len(data.encode("utf-8"))
byte_length = len(data) if isinstance(data, bytes) else len(data.encode("utf-8"))
with self._lock:
if self._bytes + byte_length > self._max_bytes:
raise WebSocketQueueFullError("send queue is full, message discarded")
self._queue.append((data, byte_length))
self._bytes += byte_length

def flush_sync(self, send: typing.Callable[[str], object]) -> None:
def flush_sync(self, send: typing.Callable[[bytes | str], object]) -> None:
"""Send every queued message via *send*.

If *send* raises, the failing message and all subsequent messages
Expand All @@ -56,7 +57,7 @@ def flush_sync(self, send: typing.Callable[[str], object]) -> None:
self._bytes = sum(bl for _, bl in self._queue)
raise

async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object]]) -> None:
async def flush_async(self, send: typing.Callable[[bytes | str], typing.Awaitable[object]]) -> None:
"""Async variant of :meth:`flush_sync`."""
with self._lock:
pending = list(self._queue)
Expand All @@ -73,7 +74,7 @@ async def flush_async(self, send: typing.Callable[[str], typing.Awaitable[object
self._bytes = sum(bl for _, bl in self._queue)
raise

def drain(self) -> list[str]:
def drain(self) -> list[bytes | str]:
"""Remove and return all queued messages."""
with self._lock:
items = [data for data, _ in self._queue]
Expand Down
18 changes: 12 additions & 6 deletions src/openai/resources/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,13 @@ async def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> N

async def send_raw(self, data: bytes | str) -> None:
if self._is_reconnecting:
raw = data if isinstance(data, str) else data.decode("utf-8")
self._send_queue.enqueue(raw)
self._send_queue.enqueue(data)
return
await self._connection.send(data)
try:
await self._connection.send(data)
except Exception:
self._send_queue.enqueue(data)
raise

async def close(self, *, code: int = 1000, reason: str = "") -> None:
self._intentionally_closed = True
Expand Down Expand Up @@ -839,10 +842,13 @@ def send(self, event: RealtimeClientEvent | RealtimeClientEventParam) -> None:

def send_raw(self, data: bytes | str) -> None:
if self._is_reconnecting:
raw = data if isinstance(data, str) else data.decode("utf-8")
self._send_queue.enqueue(raw)
self._send_queue.enqueue(data)
return
self._connection.send(data)
try:
self._connection.send(data)
except Exception:
self._send_queue.enqueue(data)
raise

def close(self, *, code: int = 1000, reason: str = "") -> None:
self._intentionally_closed = True
Expand Down
18 changes: 12 additions & 6 deletions src/openai/resources/responses/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3852,10 +3852,13 @@ async def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) ->

async def send_raw(self, data: bytes | str) -> None:
if self._is_reconnecting:
raw = data if isinstance(data, str) else data.decode("utf-8")
self._send_queue.enqueue(raw)
self._send_queue.enqueue(data)
return
await self._connection.send(data)
try:
await self._connection.send(data)
except Exception:
self._send_queue.enqueue(data)
raise

async def close(self, *, code: int = 1000, reason: str = "") -> None:
self._intentionally_closed = True
Expand Down Expand Up @@ -4309,10 +4312,13 @@ def send(self, event: ResponsesClientEvent | ResponsesClientEventParam) -> None:

def send_raw(self, data: bytes | str) -> None:
if self._is_reconnecting:
raw = data if isinstance(data, str) else data.decode("utf-8")
self._send_queue.enqueue(raw)
self._send_queue.enqueue(data)
return
self._connection.send(data)
try:
self._connection.send(data)
except Exception:
self._send_queue.enqueue(data)
raise

def close(self, *, code: int = 1000, reason: str = "") -> None:
self._intentionally_closed = True
Expand Down
89 changes: 89 additions & 0 deletions tests/test_realtime_reconnect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

from typing import Any

import pytest

from openai.resources.realtime.realtime import AsyncRealtimeConnection


def _connection_closed_error(code: int = 1011) -> Exception:
from websockets.frames import Close
from websockets.exceptions import ConnectionClosedError

return ConnectionClosedError(Close(code=code, reason=""), None)


class _DeadConnection:
"""A connection whose send() always fails, simulating a dropped socket."""

async def send(self, _data: bytes | str) -> None:
raise _connection_closed_error()

async def close(self, *, code: int = 1000, reason: str = "") -> None:
pass


class _RecordingConnection:
"""The connection returned after a successful reconnect."""

def __init__(self) -> None:
self.sent: list[bytes | str] = []

async def send(self, data: bytes | str) -> None:
self.sent.append(data)


def _make_connection(new_conn: _RecordingConnection) -> AsyncRealtimeConnection:
async def make_ws(_extra_query: Any, _extra_headers: Any) -> Any:
return new_conn

return AsyncRealtimeConnection(
_DeadConnection(), # type: ignore[arg-type]
make_ws=make_ws,
on_reconnecting=lambda _event: None,
max_retries=1,
initial_delay=0.0,
max_delay=0.0,
)


@pytest.mark.asyncio
async def test_reconnect_resends_binary_payload_unchanged() -> None:
"""End-to-end: a binary send_raw() that fails mid-send is queued and
replayed byte-for-byte after reconnect, without UTF-8 corruption."""
from websockets.exceptions import ConnectionClosedError

new_conn = _RecordingConnection()
conn = _make_connection(new_conn)

binary = b"\xff\xfe\x00audio" # not valid UTF-8 (would crash on decode)

# send fails on the dead socket -> the original connection error must
# surface (NOT a UnicodeDecodeError from decoding the binary payload),
# and the payload must be queued for replay.
with pytest.raises(ConnectionClosedError):
await conn.send_raw(binary)

# Drive the real reconnect path, which flushes the queue to the new socket.
reconnected = await conn._reconnect(_connection_closed_error())
assert reconnected is True

assert new_conn.sent == [binary]
assert isinstance(new_conn.sent[0], bytes)


@pytest.mark.asyncio
async def test_reconnect_resends_text_payload() -> None:
"""A str send_raw() is replayed as text after reconnect."""
from websockets.exceptions import ConnectionClosedError

new_conn = _RecordingConnection()
conn = _make_connection(new_conn)

with pytest.raises(ConnectionClosedError):
await conn.send_raw('{"type": "input_audio_buffer.append"}')

assert await conn._reconnect(_connection_closed_error()) is True
assert new_conn.sent == ['{"type": "input_audio_buffer.append"}']
assert isinstance(new_conn.sent[0], str)
33 changes: 33 additions & 0 deletions tests/test_send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@ def test_enqueue_and_drain(self) -> None:
assert items == ['{"type": "session.update"}', '{"type": "response.create"}']
assert len(q) == 0

def test_enqueue_preserves_binary_frames(self) -> None:
"""Binary payloads must be stored as-is, not decoded to text.

Decoding to UTF-8 would corrupt binary frames and raise
``UnicodeDecodeError`` for arbitrary bytes (e.g. audio chunks).
"""
q = SendQueue()
binary = b"\xff\xfe\x00audio" # not valid UTF-8
q.enqueue(binary)
q.enqueue("text")

items = q.drain()
assert items == [binary, "text"]
assert isinstance(items[0], bytes)
assert isinstance(items[1], str)

def test_enqueue_counts_binary_byte_length(self) -> None:
q = SendQueue(max_bytes=4)
q.enqueue(b"\xff\xfe\xfd\xfc") # 4 bytes, fits exactly
with pytest.raises(WebSocketQueueFullError):
q.enqueue(b"\x00") # would exceed
assert len(q) == 1

def test_flush_sync_preserves_binary(self) -> None:
q = SendQueue()
binary = b"\xff\xfe"
q.enqueue(binary)
q.enqueue("text")

sent: list[bytes | str] = []
q.flush_sync(sent.append)
assert sent == [binary, "text"]

def test_enqueue_respects_byte_limit(self) -> None:
q = SendQueue(max_bytes=10)
q.enqueue("12345") # 5 bytes, fits
Expand Down