Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions src/deepgram/speak/v1/socket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import websockets.sync.connection as websockets_sync_connection
from ...core.events import EventEmitterMixin, EventType
from ...core.pydantic_utilities import parse_obj_as
from websockets import WebSocketClientProtocol

try:
from websockets.legacy.client import WebSocketClientProtocol # type: ignore
Expand All @@ -28,9 +29,9 @@
# Response union type with binary support
V1SocketClientResponse = typing.Union[
SpeakV1AudioChunkEvent, # Binary audio data
SpeakV1MetadataEvent, # JSON metadata
SpeakV1ControlEvent, # JSON control responses (Flushed, Cleared)
SpeakV1WarningEvent, # JSON warnings
SpeakV1MetadataEvent, # JSON metadata
SpeakV1ControlEvent, # JSON control responses (Flushed, Cleared)
SpeakV1WarningEvent, # JSON warnings
bytes, # Raw binary audio chunks
]

Expand All @@ -50,17 +51,14 @@ def _handle_binary_message(self, message: bytes) -> typing.Any:

def _handle_json_message(self, message: str) -> typing.Any:
"""Handle a JSON message by parsing it."""
json_data = json.loads(message)
return parse_obj_as(V1SocketClientResponse, json_data) # type: ignore
return parse_obj_as(V1SocketClientResponse, json.loads(message)) # type: ignore

def _process_message(self, raw_message: typing.Any) -> typing.Tuple[typing.Any, bool]:
"""Process a raw message, detecting if it's binary or JSON."""
if self._is_binary_message(raw_message):
processed = self._handle_binary_message(raw_message)
return processed, True
if isinstance(raw_message, bytes) or isinstance(raw_message, bytearray):
return raw_message, True
else:
processed = self._handle_json_message(raw_message)
return processed, False
return self._handle_json_message(raw_message), False

async def __aiter__(self):
async for message in self._websocket:
Expand Down