diff --git a/src/google/adk/plugins/base_plugin.py b/src/google/adk/plugins/base_plugin.py index 3639f61aa2..4d5a016173 100644 --- a/src/google/adk/plugins/base_plugin.py +++ b/src/google/adk/plugins/base_plugin.py @@ -370,3 +370,25 @@ async def on_tool_error_callback( allows the original error to be raised. """ pass + + async def on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> None: + """Callback executed when an event carries state changes. + + This callback is invoked after an event with a non-empty + ``state_delta`` is yielded from the runner. It is observational, but + returning a non-`None` value will short-circuit subsequent plugins. + + Args: + callback_context: The context for the current invocation. + state_delta: A copy of the state changes carried by the event. + Mutating this dict does not affect the original state. + + Returns: + None + """ + pass diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 7cbf931ca9..721047bde8 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -2510,13 +2510,6 @@ async def after_tool_callback( parent_span_id_override=parent_span_id, ) - if tool_context.actions.state_delta: - await self._log_event( - "STATE_DELTA", - tool_context, - state_delta=tool_context.actions.state_delta, - ) - async def on_tool_error_callback( self, *, diff --git a/src/google/adk/plugins/plugin_manager.py b/src/google/adk/plugins/plugin_manager.py index c781e8fa4e..46954d7706 100644 --- a/src/google/adk/plugins/plugin_manager.py +++ b/src/google/adk/plugins/plugin_manager.py @@ -52,6 +52,7 @@ "after_model_callback", "on_tool_error_callback", "on_model_error_callback", + "on_state_change_callback", ] logger = logging.getLogger("google_adk." + __name__) @@ -257,6 +258,19 @@ async def run_on_tool_error_callback( error=error, ) + async def run_on_state_change_callback( + self, + *, + callback_context: CallbackContext, + state_delta: dict[str, Any], + ) -> Optional[Any]: + """Runs the `on_state_change_callback` for all plugins.""" + return await self._run_callbacks( + "on_state_change_callback", + callback_context=callback_context, + state_delta=state_delta, + ) + async def _run_callbacks( self, callback_name: PluginCallbackName, **kwargs: Any ) -> Optional[Any]: diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 545a0e83e6..4d7319f9f5 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -774,7 +774,6 @@ async def _exec_with_plugin( # transcription event. buffered_events: list[Event] = [] is_transcribing: bool = False - async with Aclosing(execute_fn(invocation_context)) as agen: async for event in agen: _apply_run_config_custom_metadata( @@ -839,9 +838,19 @@ async def _exec_with_plugin( _apply_run_config_custom_metadata( modified_event, invocation_context.run_config ) - yield modified_event + final_event = modified_event else: - yield event + final_event = event + yield final_event + + # Step 3b: Notify plugins of state changes, if any. + if final_event.actions.state_delta: + from .agents.callback_context import CallbackContext + + await plugin_manager.run_on_state_change_callback( + callback_context=CallbackContext(invocation_context), + state_delta=dict(final_event.actions.state_delta), + ) # Step 4: Run the after_run callbacks to perform global cleanup tasks or # finalizing logs and metrics data. @@ -1485,6 +1494,14 @@ async def _handle_new_message( state_delta=state_delta, ) + if state_delta: + from .agents.callback_context import CallbackContext + + await invocation_context.plugin_manager.run_on_state_change_callback( + callback_context=CallbackContext(invocation_context), + state_delta=dict(state_delta), + ) + def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]: toolsets = set() if isinstance(agent, LlmAgent): diff --git a/tests/unittests/plugins/test_base_plugin.py b/tests/unittests/plugins/test_base_plugin.py index aa7c17fb01..fbe98b71df 100644 --- a/tests/unittests/plugins/test_base_plugin.py +++ b/tests/unittests/plugins/test_base_plugin.py @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str: async def on_model_error_callback(self, **kwargs) -> str: return "overridden_on_model_error" + async def on_state_change_callback(self, **kwargs) -> str: + return "overridden_on_state_change" + def test_base_plugin_initialization(): """Tests that a plugin is initialized with the correct name.""" @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none(): ) is None ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_context, + state_delta={}, + ) + is None + ) @pytest.mark.asyncio @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden(): ) == "overridden_on_model_error" ) + assert ( + await plugin.on_state_change_callback( + callback_context=mock_callback_context, + state_delta={"key": "value"}, + ) + == "overridden_on_state_change" + ) diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index b11d5659dc..b5bba8e31e 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -1560,16 +1560,19 @@ async def test_after_tool_callback_logs_correctly( assert content_dict["result"] == {"res": "success"} @pytest.mark.asyncio - async def test_after_tool_callback_state_delta_logging( + async def test_after_tool_callback_no_inline_state_delta( self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema ): + """after_tool_callback does not log STATE_DELTA inline. + + STATE_DELTA is logged exclusively via on_state_change_callback. + """ mock_tool = mock.create_autospec( base_tool_lib.BaseTool, instance=True, spec_set=True ) type(mock_tool).name = mock.PropertyMock(return_value="StateTool") type(mock_tool).description = mock.PropertyMock(return_value="Sets state") - # Simulate a tool modifying the state tool_context.actions.state_delta["new_key"] = "new_value" bigquery_agent_analytics_plugin.TraceManager.push_span(tool_context) @@ -1581,31 +1584,11 @@ async def test_after_tool_callback_state_delta_logging( ) await asyncio.sleep(0.01) - # We should have two events appended: TOOL_COMPLETED and STATE_DELTA - assert mock_write_client.append_rows.call_count >= 1 - - # Retrieve all flushed events - rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema) - assert len(rows) == 2 - - # Sort by event_type to reliably access them - rows.sort(key=lambda x: x["event_type"]) - - state_delta_event = ( - rows[0] if rows[0]["event_type"] == "STATE_DELTA" else rows[1] - ) - tool_event = ( - rows[1] if rows[1]["event_type"] == "TOOL_COMPLETED" else rows[0] + # Only TOOL_COMPLETED should be logged + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema ) - - assert state_delta_event["event_type"] == "STATE_DELTA" - assert tool_event["event_type"] == "TOOL_COMPLETED" - - # Verify STATE_DELTA payload - attributes = json.loads(state_delta_event["attributes"]) - assert "state_delta" in attributes - assert attributes["state_delta"] == {"new_key": "new_value"} - assert state_delta_event["content"] is None + assert log_entry["event_type"] == "TOOL_COMPLETED" @pytest.mark.asyncio async def test_on_state_change_callback_logs_correctly( @@ -1615,6 +1598,7 @@ async def test_on_state_change_callback_logs_correctly( callback_context, dummy_arrow_schema, ): + """STATE_DELTA is logged via on_state_change_callback.""" state_delta = {"key": "value", "new_key": 123} bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) await bq_plugin_inst.on_state_change_callback( @@ -1625,10 +1609,8 @@ async def test_on_state_change_callback_logs_correctly( mock_write_client, dummy_arrow_schema ) _assert_common_fields(log_entry, "STATE_DELTA") - # content should be None (as raw_content was not passed) assert log_entry["content"] is None - # state_delta should be in attributes attributes = json.loads(log_entry["attributes"]) assert attributes["state_delta"] == state_delta diff --git a/tests/unittests/plugins/test_plugin_manager.py b/tests/unittests/plugins/test_plugin_manager.py index ba070ea8f3..fe47ee47a1 100644 --- a/tests/unittests/plugins/test_plugin_manager.py +++ b/tests/unittests/plugins/test_plugin_manager.py @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs): async def on_model_error_callback(self, **kwargs): return await self._handle_callback("on_model_error_callback") + async def on_state_change_callback(self, **kwargs): + return await self._handle_callback("on_state_change_callback") + @pytest.fixture def service() -> PluginManager: @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported( llm_request=mock_context, error=mock_context, ) + await service.run_on_state_change_callback( + callback_context=mock_context, + state_delta={"key": "value"}, + ) # Verify all callbacks were logged expected_callbacks = [ @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported( "before_model_callback", "after_model_callback", "on_model_error_callback", + "on_state_change_callback", ] assert set(plugin1.call_log) == set(expected_callbacks) @@ -317,3 +325,57 @@ async def slow_close(): assert "Failed to close plugins: 'plugin1': TimeoutError" in str( excinfo.value ) + + +# --- on_state_change_callback tests --- + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback( + service: PluginManager, plugin1: TestPlugin +): + """Tests that run_on_state_change_callback invokes the callback and returns None.""" + service.register_plugin(plugin1) + result = await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + assert result is None + assert "on_state_change_callback" in plugin1.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_calls_all_plugins( + service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin +): + """Tests that on_state_change_callback is called on all plugins.""" + service.register_plugin(plugin1) + service.register_plugin(plugin2) + + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "on_state_change_callback" in plugin1.call_log + assert "on_state_change_callback" in plugin2.call_log + + +@pytest.mark.asyncio +async def test_run_on_state_change_callback_wraps_exceptions( + service: PluginManager, plugin1: TestPlugin +): + """Tests that exceptions in on_state_change_callback are wrapped in RuntimeError.""" + original_exception = ValueError("state change error") + plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception + service.register_plugin(plugin1) + + with pytest.raises(RuntimeError) as excinfo: + await service.run_on_state_change_callback( + callback_context=Mock(), + state_delta={"key": "value"}, + ) + + assert "Error in plugin 'plugin1'" in str(excinfo.value) + assert "on_state_change_callback" in str(excinfo.value) + assert excinfo.value.__cause__ is original_exception diff --git a/tests/unittests/test_runners.py b/tests/unittests/test_runners.py index 62b8d7334b..c28afd966a 100644 --- a/tests/unittests/test_runners.py +++ b/tests/unittests/test_runners.py @@ -146,6 +146,7 @@ def __init__(self): self.enable_user_message_callback = False self.enable_event_callback = False self.user_content_seen_in_before_run_callback = None + self.state_change_deltas: list[dict] = [] async def on_user_message_callback( self, @@ -169,6 +170,9 @@ async def before_run_callback( invocation_context.user_content ) + async def on_state_change_callback(self, *, callback_context, state_delta, **kwargs): + self.state_change_deltas.append(state_delta) + async def on_event_callback( self, *, invocation_context: InvocationContext, event: Event ) -> Optional[Event]: @@ -853,6 +857,49 @@ async def test_runner_passes_plugin_close_timeout(self): ) assert runner.plugin_manager._close_timeout == 10.0 + @pytest.mark.asyncio + async def test_state_delta_in_run_async_triggers_on_state_change_callback( + self, + ): + """Test that caller-supplied state_delta triggers on_state_change_callback.""" + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + state_delta = {"lang": "en", "theme": "dark"} + events = [] + async for event in self.runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + state_delta=state_delta, + ): + events.append(event) + + assert len(self.plugin.state_change_deltas) >= 1 + assert self.plugin.state_change_deltas[0] == state_delta + + @pytest.mark.asyncio + async def test_no_state_delta_does_not_trigger_on_state_change_callback( + self, + ): + """Test that on_state_change_callback is not called when no state_delta is provided.""" + await self.session_service.create_session( + app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID + ) + events = [] + async for event in self.runner.run_async( + user_id=TEST_USER_ID, + session_id=TEST_SESSION_ID, + new_message=types.Content( + role="user", parts=[types.Part(text="Hello")] + ), + ): + events.append(event) + + assert len(self.plugin.state_change_deltas) == 0 + @pytest.mark.filterwarnings( "ignore:The `plugins` argument is deprecated:DeprecationWarning" )