diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 0f7cc2b7d9..9e4302e4b0 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -455,6 +455,15 @@ class LlmAgent(BaseAgent): """ # Callbacks - End + @override + async def _handle_before_agent_callback( + self, ctx: InvocationContext + ) -> Optional[Event]: + event = await super()._handle_before_agent_callback(ctx) + if event is not None: + self.__maybe_save_output_to_state(event) + return event + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/tests/unittests/agents/test_llm_agent_output_save.py b/tests/unittests/agents/test_llm_agent_output_save.py index e904130105..c7b1bf8d95 100644 --- a/tests/unittests/agents/test_llm_agent_output_save.py +++ b/tests/unittests/agents/test_llm_agent_output_save.py @@ -17,6 +17,7 @@ import logging from unittest.mock import patch +from google.adk.agents.callback_context import CallbackContext from google.adk.agents.llm_agent import LlmAgent from google.adk.events.event import Event from google.adk.events.event_actions import EventActions @@ -24,6 +25,8 @@ from pydantic import BaseModel import pytest +from .. import testing_utils + class MockOutputSchema(BaseModel): message: str @@ -276,3 +279,24 @@ def test_maybe_save_output_to_state_handles_empty_final_chunk_with_schema( # ASSERT: Because the method should return early, the state_delta # should remain empty. assert len(event.actions.state_delta) == 0 + + @pytest.mark.asyncio + async def test_output_key_saved_when_before_agent_callback_short_circuits( + self, + ): + """Test that output_key is written to session state when + before_agent_callback short-circuits the agent.""" + + def cache_callback(callback_context: CallbackContext) -> types.Content: + return types.Content(parts=[types.Part.from_text(text="cached answer")]) + + agent = LlmAgent( + name="test_agent", + output_key="result", + before_agent_callback=cache_callback, + ) + + runner = testing_utils.InMemoryRunner(agent) + await runner.run_async("hello") + + assert runner.session.state.get("result") == "cached answer"