diff --git a/decart/realtime/messages.py b/decart/realtime/messages.py index cd6f63d..ed99237 100644 --- a/decart/realtime/messages.py +++ b/decart/realtime/messages.py @@ -171,4 +171,9 @@ def message_to_json(message: OutgoingMessage) -> str: Returns: JSON string """ + # SetAvatarImageMessage uses exclude_unset so explicitly-passed None values + # (e.g. image_data=None, prompt=None for passthrough) are serialized as null, + # while fields that were never set are omitted. + if isinstance(message, SetAvatarImageMessage): + return message.model_dump_json(exclude_unset=True) return message.model_dump_json(exclude_none=True) diff --git a/decart/realtime/webrtc_connection.py b/decart/realtime/webrtc_connection.py index 216a6d9..98b183a 100644 --- a/decart/realtime/webrtc_connection.py +++ b/decart/realtime/webrtc_connection.py @@ -97,7 +97,9 @@ async def connect( ) elif initial_prompt: await self._send_initial_prompt_and_wait(initial_prompt) - + elif local_track is not None: + # No image and no prompt — send passthrough (skip for subscribe mode which has no local stream) + await self._send_passthrough_and_wait() await self._setup_peer_connection(local_track, model_name=model_name) await self._create_and_send_offer() @@ -171,6 +173,31 @@ async def _send_initial_prompt_and_wait(self, prompt: dict, timeout: float = 15. finally: self.unregister_prompt_wait(prompt_text) + async def _send_passthrough_and_wait(self, timeout: float = 30.0) -> None: + """Send passthrough set_image (null image + null prompt) and wait for ack. + + When connecting without an initial prompt or image, the server still + expects an explicit initial state. Sending image_data=null + prompt=null + tells the server to use passthrough mode. + """ + event, result = self.register_image_set_wait() + + try: + message = SetAvatarImageMessage(type="set_image", image_data=None, prompt=None) + await self._send_message(message) + + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + raise WebRTCError("Passthrough acknowledgment timed out") + + if not result["success"]: + raise WebRTCError( + f"Failed to send passthrough: {result.get('error', 'unknown error')}" + ) + finally: + self.unregister_image_set_wait() + async def _setup_peer_connection( self, local_track: Optional[MediaStreamTrack], @@ -344,6 +371,20 @@ def _handle_set_image_ack(self, message: SetImageAckMessage) -> None: def _handle_error(self, message: ErrorMessage) -> None: logger.error(f"Received error from server: {message.error}") error = WebRTCError(message.error) + + # Fail-fast: resolve any pending Phase-2 waits so they surface the + # real server error instead of timing out after 30 s. + if self._pending_image_set: + event, result = self._pending_image_set + result["success"] = False + result["error"] = message.error + event.set() + + for _prompt, (event, result) in list(self._pending_prompts.items()): + result["success"] = False + result["error"] = message.error + event.set() + if self._on_error: self._on_error(error) diff --git a/tests/test_realtime_unit.py b/tests/test_realtime_unit.py index ae6f300..f12c5bb 100644 --- a/tests/test_realtime_unit.py +++ b/tests/test_realtime_unit.py @@ -1094,3 +1094,175 @@ async def test_image_to_base64_file_path_string(tmp_path): mock_session = MagicMock() result = await _image_to_base64(str(img), mock_session) assert result == base64.b64encode(b"PNGDATA").decode("utf-8") + + +# Tests for passthrough mode (no initial prompt/image) + + +@pytest.mark.asyncio +async def test_connect_without_initial_state_sends_passthrough(): + """Connecting without prompt/image sends passthrough set_image (null image + null prompt).""" + client = DecartClient(api_key="test-key") + + with ( + patch("decart.realtime.client.WebRTCManager") as mock_manager_class, + patch("decart.realtime.client.aiohttp.ClientSession") as mock_session_cls, + ): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager.is_connected = MagicMock(return_value=True) + mock_manager_class.return_value = mock_manager + + mock_session = MagicMock() + mock_session.closed = False + mock_session.close = AsyncMock() + mock_session_cls.return_value = mock_session + + mock_track = MagicMock() + + from decart.realtime.types import RealtimeConnectOptions + + realtime_client = await RealtimeClient.connect( + base_url=client.base_url, + api_key=client.api_key, + local_track=mock_track, + options=RealtimeConnectOptions( + model=models.realtime("mirage"), + on_remote_stream=lambda t: None, + # No initial_state — should trigger passthrough + ), + ) + + assert realtime_client is not None + mock_manager.connect.assert_called_once() + call_kwargs = mock_manager.connect.call_args[1] + # initial_image and initial_prompt should both be None + assert call_kwargs.get("initial_image") is None + assert call_kwargs.get("initial_prompt") is None + + +@pytest.mark.asyncio +async def test_passthrough_sends_set_image_with_null_prompt(): + """_send_passthrough_and_wait sends set_image with null image_data and null prompt.""" + from decart.realtime.webrtc_connection import WebRTCConnection + + connection = WebRTCConnection() + + sent_messages: list = [] + + async def capture_send(message): + sent_messages.append(message) + # Simulate set_image_ack arriving immediately (like FakeWebSocket in JS tests) + if connection._pending_image_set: + event, result = connection._pending_image_set + result["success"] = True + event.set() + + connection._send_message = capture_send # type: ignore[assignment] + + await connection._send_passthrough_and_wait() + + assert len(sent_messages) == 1 + msg = sent_messages[0] + assert msg.type == "set_image" + assert msg.image_data is None + assert msg.prompt is None + + # Verify JSON serialization includes null values + from decart.realtime.messages import message_to_json + import json + + json_str = message_to_json(msg) + parsed = json.loads(json_str) + assert parsed == {"type": "set_image", "image_data": None, "prompt": None} + + +@pytest.mark.asyncio +async def test_subscribe_mode_skips_passthrough(): + """Subscribe mode (null local_track) must not send passthrough set_image.""" + client = DecartClient(api_key="test-key") + + with (patch("decart.realtime.client.WebRTCManager") as mock_manager_class,): + mock_manager = AsyncMock() + mock_manager.connect = AsyncMock(return_value=True) + mock_manager_class.return_value = mock_manager + + # subscribe() passes local_track=None internally + from decart.realtime.subscribe import SubscribeClient, encode_subscribe_token + + token = encode_subscribe_token("test-sid", "1.2.3.4", 8080) + + from decart.realtime.subscribe import SubscribeOptions + + sub_client = await RealtimeClient.subscribe( + base_url=client.base_url, + api_key=client.api_key, + options=SubscribeOptions( + token=token, + on_remote_stream=lambda t: None, + ), + ) + + assert sub_client is not None + # Verify connect was called with local_track=None (subscribe mode) + mock_manager.connect.assert_called_once() + call_args = mock_manager.connect.call_args + assert call_args[0][0] is None # first positional arg is local_track=None + + +@pytest.mark.asyncio +async def test_server_error_during_passthrough_fails_fast(): + """Server error during passthrough surfaces real error instead of 30s timeout.""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.realtime.messages import ErrorMessage + from decart.errors import WebRTCError + + connection = WebRTCConnection() + + async def fake_send(message): + # Simulate the server responding with an error instead of set_image_ack + await asyncio.sleep(0) # yield so wait_for is listening + connection._handle_error(ErrorMessage(type="error", error="insufficient_credits")) + + connection._send_message = fake_send # type: ignore[assignment] + + with pytest.raises(WebRTCError, match="insufficient_credits"): + await connection._send_passthrough_and_wait() + + +@pytest.mark.asyncio +async def test_server_error_during_initial_image_fails_fast(): + """Server error during initial image setup surfaces real error (pre-existing fix).""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.realtime.messages import ErrorMessage + from decart.errors import WebRTCError + + connection = WebRTCConnection() + + async def fake_send(message): + await asyncio.sleep(0) + connection._handle_error(ErrorMessage(type="error", error="invalid_image")) + + connection._send_message = fake_send # type: ignore[assignment] + + with pytest.raises(WebRTCError, match="invalid_image"): + await connection._send_initial_image_and_wait("base64data") + + +@pytest.mark.asyncio +async def test_server_error_during_initial_prompt_fails_fast(): + """Server error during initial prompt setup surfaces real error (pre-existing fix).""" + from decart.realtime.webrtc_connection import WebRTCConnection + from decart.realtime.messages import ErrorMessage + from decart.errors import WebRTCError + + connection = WebRTCConnection() + + async def fake_send(message): + await asyncio.sleep(0) + connection._handle_error(ErrorMessage(type="error", error="rate_limited")) + + connection._send_message = fake_send # type: ignore[assignment] + + with pytest.raises(WebRTCError, match="rate_limited"): + await connection._send_initial_prompt_and_wait({"text": "test", "enhance": True})