diff --git a/sdk/rt/speechmatics/rt/_async_client.py b/sdk/rt/speechmatics/rt/_async_client.py
index 5e581e1..892e254 100644
--- a/sdk/rt/speechmatics/rt/_async_client.py
+++ b/sdk/rt/speechmatics/rt/_async_client.py
@@ -12,6 +12,7 @@
 from ._exceptions import TimeoutError
 from ._exceptions import TranscriptionError
 from ._logging import get_logger
+from ._models import AudioEncoding
 from ._models import AudioEventsConfig
 from ._models import AudioFormat
 from ._models import ClientMessageType
@@ -97,6 +98,8 @@ def __init__(
         self.on(ServerMessageType.WARNING, self._on_warning)
         self.on(ServerMessageType.AUDIO_ADDED, self._on_audio_added)
 
+        self._audio_format = AudioFormat(encoding=AudioEncoding.PCM_S16LE, sample_rate=44100, chunk_size=4096)
+
         self._logger.debug("AsyncClient initialized (request_id=%s)", self._session.request_id)
 
     async def start_session(
@@ -133,6 +136,9 @@ async def start_session(
             ...     await client.start_session()
             ...     await client.send_audio(frame)
         """
+        if audio_format is not None:
+            self._audio_format = audio_format
+
         await self._start_recognition_session(
             transcription_config=transcription_config,
             audio_format=audio_format,
@@ -161,16 +167,24 @@ async def stop_session(self) -> None:
         await self._session_done_evt.wait()  # Wait for end of transcript event to indicate we can stop listening
         await self.close()
 
-    async def force_end_of_utterance(self) -> None:
+    async def force_end_of_utterance(self, timestamp: Optional[float] = None) -> None:
         """
         This method sends a ForceEndOfUtterance message to the server to signal
         the end of an utterance. Forcing end of utterance will cause the final
         transcript to be sent to the client early.
 
+        The optional timestamp gives the engine a marker to use for timing the
+        end of the utterance. If not provided, the timestamp is calculated from
+        the cumulative audio sent to the server.
+
+        Args:
+            timestamp: Optional timestamp, in seconds of audio sent, marking
+                the end of the utterance.
+
         Raises:
             ConnectionError: If the WebSocket connection fails.
             TranscriptionError: If the server reports an error during teardown.
             TimeoutError: If the connection or teardown times out.
+            ValueError: If the audio format does not have an encoding set.
 
         Examples:
             Basic streaming:
@@ -179,7 +193,19 @@ async def force_end_of_utterance(self) -> None:
             ...     await client.send_audio(frame)
             ...     await client.force_end_of_utterance()
         """
-        await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE})
+        if timestamp is None:
+            timestamp = self.audio_seconds_sent
+
+        await self.send_message({"message": ClientMessageType.FORCE_END_OF_UTTERANCE, "timestamp": timestamp})
+
+    @property
+    def audio_seconds_sent(self) -> float:
+        """Seconds of audio sent to the server, derived from the bytes sent.
+
+        Raises:
+            ValueError: If the audio format does not have an encoding set.
+        """
+        return self._audio_bytes_sent / (self._audio_format.sample_rate * self._audio_format.bytes_per_sample)
 
     async def transcribe(
         self,
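Together, these changes let a caller pin a forced end of utterance to the amount of audio already streamed. A minimal usage sketch, assuming the rt client class is the `AsyncClient` that `_async_client.py` logs itself as, and that `transcription_config` can be passed through as in the docstring examples; `frames` is a hypothetical iterable of PCM_S16LE byte chunks:

```python
from speechmatics.rt import AsyncClient, AudioEncoding, AudioFormat


async def stream_with_forced_eou(client: AsyncClient, transcription_config, frames) -> None:
    # 16 kHz, 2 bytes per sample, so 32000 bytes of audio per second.
    fmt = AudioFormat(encoding=AudioEncoding.PCM_S16LE, sample_rate=16000, chunk_size=320)
    await client.start_session(transcription_config=transcription_config, audio_format=fmt)
    for frame in frames:
        await client.send_audio(frame)
        # audio_seconds_sent == audio_bytes_sent / (sample_rate * bytes_per_sample)
        if client.audio_seconds_sent >= 2.0:
            # Omitting `timestamp` falls back to audio_seconds_sent anyway.
            await client.force_end_of_utterance(timestamp=client.audio_seconds_sent)
            break
    await client.stop_session()
```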
+ """ + return self._audio_bytes_sent / (self._audio_format.sample_rate * self._audio_format.bytes_per_sample) async def transcribe( self, diff --git a/sdk/rt/speechmatics/rt/_base_client.py b/sdk/rt/speechmatics/rt/_base_client.py index 0ac6d08..89167e2 100644 --- a/sdk/rt/speechmatics/rt/_base_client.py +++ b/sdk/rt/speechmatics/rt/_base_client.py @@ -42,6 +42,7 @@ def __init__(self, transport: Transport) -> None: self._recv_task: Optional[asyncio.Task[None]] = None self._closed_evt = asyncio.Event() self._eos_sent = False + self._audio_bytes_sent = 0 self._seq_no = 0 self._logger = get_logger("speechmatics.rt.base_client") @@ -122,11 +123,17 @@ async def send_audio(self, payload: bytes) -> None: try: await self._transport.send_message(payload) + self._audio_bytes_sent += len(payload) self._seq_no += 1 except Exception: self._closed_evt.set() raise + @property + def audio_bytes_sent(self) -> int: + """Number of audio bytes sent to the server.""" + return self._audio_bytes_sent + async def send_message(self, message: dict[str, Any]) -> None: """ Send a message through the WebSocket. diff --git a/sdk/rt/speechmatics/rt/_models.py b/sdk/rt/speechmatics/rt/_models.py index 84e5720..d1f6acb 100644 --- a/sdk/rt/speechmatics/rt/_models.py +++ b/sdk/rt/speechmatics/rt/_models.py @@ -183,6 +183,29 @@ class AudioFormat: sample_rate: int = 44100 chunk_size: int = 4096 + _BYTES_PER_SAMPLE = { + AudioEncoding.PCM_F32LE: 4, + AudioEncoding.PCM_S16LE: 2, + AudioEncoding.MULAW: 1, + } + + @property + def bytes_per_sample(self) -> int: + """Number of bytes per audio sample based on encoding. + + Raises: + ValueError: If encoding is None (file type) or unrecognized. + """ + if self.encoding is None: + raise ValueError( + "Cannot determine bytes per sample for file-type audio format. " + "Set an explicit encoding on AudioFormat." + ) + try: + return self._BYTES_PER_SAMPLE[self.encoding] + except KeyError: + raise ValueError(f"Unknown encoding: {self.encoding}") + def to_dict(self) -> dict[str, Any]: """ Convert audio format to dictionary. diff --git a/sdk/voice/speechmatics/voice/_client.py b/sdk/voice/speechmatics/voice/_client.py index c0988dd..bc36f39 100644 --- a/sdk/voice/speechmatics/voice/_client.py +++ b/sdk/voice/speechmatics/voice/_client.py @@ -472,12 +472,15 @@ def _prepare_config( # LIFECYCLE METHODS # ============================================================================ - async def connect(self) -> None: + async def connect(self, ws_headers: Optional[dict] = None) -> None: """Connect to the Speechmatics API. Establishes WebSocket connection and starts the transcription session. This must be called before sending audio. + Args: + ws_headers: Optional headers to pass to the WebSocket connection. + Raises: Exception: If connection fails. @@ -521,6 +524,7 @@ async def connect(self) -> None: await self.start_session( transcription_config=self._transcription_config, audio_format=self._audio_format, + ws_headers=ws_headers, ) self._is_connected = True self._start_metrics_task() @@ -717,14 +721,11 @@ def update_diarization_config(self, config: SpeakerFocusConfig) -> None: # PUBLIC UTTERANCE / TURN MANAGEMENT # ============================================================================ - def finalize(self, end_of_turn: bool = False) -> None: + def finalize(self) -> None: """Finalize segments. This function will emit segments in the buffer without any further checks on the contents of the segments. - - Args: - end_of_turn: Whether to emit an end of turn message. 
""" # Clear smart turn cutoff diff --git a/tests/voice/test_17_eou_feou.py b/tests/voice/test_17_eou_feou.py index f78c6ab..4e4e50d 100644 --- a/tests/voice/test_17_eou_feou.py +++ b/tests/voice/test_17_eou_feou.py @@ -48,41 +48,41 @@ class TranscriptionTests(BaseModel): SAMPLES: TranscriptionTests = TranscriptionTests.from_dict( { "samples": [ - # { - # "id": "07b", - # "path": "./assets/audio_07b_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, - # {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, - # {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, - # {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96}, - # {"text": "Behind.", "start_time": 12.03, "end_time": 12.73}, - # {"text": "In front.", "start_time": 14.84, "end_time": 15.52}, - # {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32}, - # {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08}, - # {"text": "Banana.", "start_time": 22.98, "end_time": 23.53}, - # {"text": "When?", "start_time": 25.49, "end_time": 25.96}, - # {"text": "Today.", "start_time": 27.66, "end_time": 28.15}, - # {"text": "This morning.", "start_time": 29.91, "end_time": 30.47}, - # {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68}, - # ], - # }, - # { - # "id": "08", - # "path": "./assets/audio_08_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 0.4, "end_time": 0.75}, - # {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5}, - # {"text": "Banana.", "start_time": 3.84, "end_time": 4.27}, - # {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42}, - # {"text": "Before.", "start_time": 7.76, "end_time": 8.16}, - # {"text": "After.", "start_time": 9.56, "end_time": 10.05}, - # ], - # }, + { + "id": "07b", + "path": "./assets/audio_07b_16kHz.wav", + "sample_rate": 16000, + "language": "en", + "segments": [ + {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, + {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, + {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, + {"text": "Of course. 
diff --git a/tests/voice/test_18_feou_timestamp.py b/tests/voice/test_18_feou_timestamp.py
new file mode 100644
index 0000000..39d85bf
--- /dev/null
+++ b/tests/voice/test_18_feou_timestamp.py
@@ -0,0 +1,75 @@
+import os
+
+import pytest
+from _utils import get_client
+from _utils import send_silence
+
+from speechmatics.rt import AudioEncoding
+from speechmatics.voice import VoiceAgentConfig
+
+# Constants
+API_KEY = os.getenv("SPEECHMATICS_API_KEY")
+
+# Skip in CI and when no API key is provided (a list applies both markers)
+pytestmark = [
+    pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping in CI"),
+    pytest.mark.skipif(API_KEY is None, reason="Skipping when no API key is provided"),
+]
+
+# How much silence to send (seconds)
+SILENCE_DURATION = 3.0
+
+# Tolerance for the timestamp check (0.0 requires an exact match)
+TOLERANCE = 0.00
+
+# Audio format configurations to test: (encoding, chunk_size, bytes_per_sample)
+AUDIO_FORMATS = [
+    pytest.param(AudioEncoding.PCM_S16LE, 160, 2, id="s16-chunk160"),
+    pytest.param(AudioEncoding.PCM_S16LE, 320, 2, id="s16-chunk320"),
+    pytest.param(AudioEncoding.PCM_F32LE, 160, 4, id="f32-chunk160"),
+    pytest.param(AudioEncoding.PCM_F32LE, 320, 4, id="f32-chunk320"),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("encoding,chunk_size,sample_size", AUDIO_FORMATS)
+async def test_feou_timestamp(encoding: AudioEncoding, chunk_size: int, sample_size: int):
+    """Test that audio_seconds_sent correctly computes elapsed audio time.
+
+    Sends 3 seconds of silence (zero bytes) with different audio encodings
+    and chunk sizes, then verifies that audio_seconds_sent returns the
+    correct duration.
+    """
+
+    # Create and connect client
+    config = VoiceAgentConfig(audio_encoding=encoding, chunk_size=chunk_size)
+    client = await get_client(
+        api_key=API_KEY,
+        connect=False,
+        config=config,
+    )
+
+    try:
+        await client.connect()
+    except Exception:
+        pytest.skip("Failed to connect to server")
+
+    assert client._is_connected
+
+    # Send 3 seconds of silence
+    await send_silence(
+        client,
+        duration=SILENCE_DURATION,
+        chunk_size=chunk_size,
+        sample_size=sample_size,
+    )
+
+    # Check the computed audio seconds
+    actual_seconds = client.audio_seconds_sent
+    assert (
+        abs(actual_seconds - SILENCE_DURATION) <= TOLERANCE
+    ), f"Expected ~{SILENCE_DURATION}s but got {actual_seconds:.4f}s"
+
+    # Clean up
+    await client.disconnect()
+    assert not client._is_connected
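For reference, the byte totals behind the zero-tolerance assertion divide evenly into the parametrized chunk sizes, so `audio_seconds_sent` can reproduce the duration exactly. A worked check, assuming the 16 kHz rate used by the test assets elsewhere in this suite (the `VoiceAgentConfig` default is not shown in this diff):

```python
SAMPLE_RATE = 16000  # assumed; not set explicitly by the test above

for sample_size, chunk_size in [(2, 160), (2, 320), (4, 160), (4, 320)]:
    total_bytes = int(3.0 * SAMPLE_RATE * sample_size)  # 96000 or 192000
    assert total_bytes % chunk_size == 0                # whole chunks only
    assert total_bytes / (SAMPLE_RATE * sample_size) == 3.0
```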