diff --git a/src/google/adk/models/gemini_llm_connection.py b/src/google/adk/models/gemini_llm_connection.py index d4e2900e40..028cdf0d4a 100644 --- a/src/google/adk/models/gemini_llm_connection.py +++ b/src/google/adk/models/gemini_llm_connection.py @@ -165,18 +165,12 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]: yield self.__build_full_text_response(text) text = '' yield llm_response - if ( - message.server_content.input_transcription - and message.server_content.input_transcription.text - ): + if message.server_content.input_transcription: llm_response = LlmResponse( input_transcription=message.server_content.input_transcription, ) yield llm_response - if ( - message.server_content.output_transcription - and message.server_content.output_transcription.text - ): + if message.server_content.output_transcription: llm_response = LlmResponse( output_transcription=message.server_content.output_transcription ) diff --git a/tests/unittests/models/test_gemini_llm_connection.py b/tests/unittests/models/test_gemini_llm_connection.py index 7710cce980..23e8697fdd 100644 --- a/tests/unittests/models/test_gemini_llm_connection.py +++ b/tests/unittests/models/test_gemini_llm_connection.py @@ -112,6 +112,46 @@ async def test_close(gemini_connection, mock_gemini_session): @pytest.mark.asyncio +@pytest.mark.parametrize('tx_direction', ['input', 'output']) +async def test_receive_transcript_finished( + gemini_connection, mock_gemini_session, tx_direction +): + """Test receive_transcript_finished for input and output transcription.""" + + finished_tx = types.Transcription(finished=True) + + msg = mock.Mock() + msg.tool_call = None + msg.usage_metadata = None + msg.session_resumption_update = None + msg.server_content.model_turn = None + msg.server_content.interrupted = False + msg.server_content.turn_complete = False + msg.server_content.input_transcription = ( + finished_tx if tx_direction == 'input' else None + ) + msg.server_content.output_transcription = ( + finished_tx if tx_direction == 'output' else None + ) + + async def gen(): + yield msg + + mock_gemini_session.receive = mock.Mock(return_value=gen()) + + responses = [] + async for r in gemini_connection.receive(): + responses.append(r) + + attr_name = f'{tx_direction}_transcription' + tx_resps = [r for r in responses if getattr(r, attr_name)] + assert tx_resps, f'Expected {tx_direction} transcription response' + + transcription = getattr(tx_resps[0], attr_name) + assert transcription.finished is True + assert not transcription.text + + async def test_receive_usage_metadata_and_server_content( gemini_connection, mock_gemini_session ):