Skip to content
Open
22 changes: 22 additions & 0 deletions src/google/adk/plugins/base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,25 @@ async def on_tool_error_callback(
allows the original error to be raised.
"""
pass

async def on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> None:
"""Callback executed when an event carries state changes.
This callback is invoked after an event with a non-empty
``state_delta`` is yielded from the runner. It is observational, but
returning a non-`None` value will short-circuit subsequent plugins.
Args:
callback_context: The context for the current invocation.
state_delta: A copy of the state changes carried by the event.
Mutating this dict does not affect the original state.
Returns:
None
"""
pass
7 changes: 0 additions & 7 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2510,13 +2510,6 @@ async def after_tool_callback(
parent_span_id_override=parent_span_id,
)

if tool_context.actions.state_delta:
await self._log_event(
"STATE_DELTA",
tool_context,
state_delta=tool_context.actions.state_delta,
)

async def on_tool_error_callback(
self,
*,
Expand Down
14 changes: 14 additions & 0 deletions src/google/adk/plugins/plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"after_model_callback",
"on_tool_error_callback",
"on_model_error_callback",
"on_state_change_callback",
]

logger = logging.getLogger("google_adk." + __name__)
Expand Down Expand Up @@ -257,6 +258,19 @@ async def run_on_tool_error_callback(
error=error,
)

async def run_on_state_change_callback(
self,
*,
callback_context: CallbackContext,
state_delta: dict[str, Any],
) -> Optional[Any]:
"""Runs the `on_state_change_callback` for all plugins."""
return await self._run_callbacks(
"on_state_change_callback",
callback_context=callback_context,
state_delta=state_delta,
)

async def _run_callbacks(
self, callback_name: PluginCallbackName, **kwargs: Any
) -> Optional[Any]:
Expand Down
23 changes: 20 additions & 3 deletions src/google/adk/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,6 @@ async def _exec_with_plugin(
# transcription event.
buffered_events: list[Event] = []
is_transcribing: bool = False

async with Aclosing(execute_fn(invocation_context)) as agen:
async for event in agen:
_apply_run_config_custom_metadata(
Expand Down Expand Up @@ -839,9 +838,19 @@ async def _exec_with_plugin(
_apply_run_config_custom_metadata(
modified_event, invocation_context.run_config
)
yield modified_event
final_event = modified_event
else:
yield event
final_event = event
yield final_event

# Step 3b: Notify plugins of state changes, if any.
if final_event.actions.state_delta:
from .agents.callback_context import CallbackContext

await plugin_manager.run_on_state_change_callback(
callback_context=CallbackContext(invocation_context),
state_delta=dict(final_event.actions.state_delta),
)

# Step 4: Run the after_run callbacks to perform global cleanup tasks or
# finalizing logs and metrics data.
Expand Down Expand Up @@ -1485,6 +1494,14 @@ async def _handle_new_message(
state_delta=state_delta,
)

if state_delta:
from .agents.callback_context import CallbackContext

await invocation_context.plugin_manager.run_on_state_change_callback(
callback_context=CallbackContext(invocation_context),
state_delta=dict(state_delta),
)

def _collect_toolset(self, agent: BaseAgent) -> set[BaseToolset]:
toolsets = set()
if isinstance(agent, LlmAgent):
Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/plugins/test_base_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ async def after_model_callback(self, **kwargs) -> str:
async def on_model_error_callback(self, **kwargs) -> str:
return "overridden_on_model_error"

async def on_state_change_callback(self, **kwargs) -> str:
return "overridden_on_state_change"


def test_base_plugin_initialization():
"""Tests that a plugin is initialized with the correct name."""
Expand Down Expand Up @@ -172,6 +175,13 @@ async def test_base_plugin_default_callbacks_return_none():
)
is None
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_context,
state_delta={},
)
is None
)


@pytest.mark.asyncio
Expand Down Expand Up @@ -278,3 +288,10 @@ async def test_base_plugin_all_callbacks_can_be_overridden():
)
== "overridden_on_model_error"
)
assert (
await plugin.on_state_change_callback(
callback_context=mock_callback_context,
state_delta={"key": "value"},
)
== "overridden_on_state_change"
)
38 changes: 10 additions & 28 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1560,16 +1560,19 @@ async def test_after_tool_callback_logs_correctly(
assert content_dict["result"] == {"res": "success"}

@pytest.mark.asyncio
async def test_after_tool_callback_state_delta_logging(
async def test_after_tool_callback_no_inline_state_delta(
self, bq_plugin_inst, mock_write_client, tool_context, dummy_arrow_schema
):
"""after_tool_callback does not log STATE_DELTA inline.

STATE_DELTA is logged exclusively via on_state_change_callback.
"""
mock_tool = mock.create_autospec(
base_tool_lib.BaseTool, instance=True, spec_set=True
)
type(mock_tool).name = mock.PropertyMock(return_value="StateTool")
type(mock_tool).description = mock.PropertyMock(return_value="Sets state")

# Simulate a tool modifying the state
tool_context.actions.state_delta["new_key"] = "new_value"

bigquery_agent_analytics_plugin.TraceManager.push_span(tool_context)
Expand All @@ -1581,31 +1584,11 @@ async def test_after_tool_callback_state_delta_logging(
)
await asyncio.sleep(0.01)

# We should have two events appended: TOOL_COMPLETED and STATE_DELTA
assert mock_write_client.append_rows.call_count >= 1

# Retrieve all flushed events
rows = await _get_captured_rows_async(mock_write_client, dummy_arrow_schema)
assert len(rows) == 2

# Sort by event_type to reliably access them
rows.sort(key=lambda x: x["event_type"])

state_delta_event = (
rows[0] if rows[0]["event_type"] == "STATE_DELTA" else rows[1]
)
tool_event = (
rows[1] if rows[1]["event_type"] == "TOOL_COMPLETED" else rows[0]
# Only TOOL_COMPLETED should be logged
log_entry = await _get_captured_event_dict_async(
mock_write_client, dummy_arrow_schema
)

assert state_delta_event["event_type"] == "STATE_DELTA"
assert tool_event["event_type"] == "TOOL_COMPLETED"

# Verify STATE_DELTA payload
attributes = json.loads(state_delta_event["attributes"])
assert "state_delta" in attributes
assert attributes["state_delta"] == {"new_key": "new_value"}
assert state_delta_event["content"] is None
assert log_entry["event_type"] == "TOOL_COMPLETED"

@pytest.mark.asyncio
async def test_on_state_change_callback_logs_correctly(
Expand All @@ -1615,6 +1598,7 @@ async def test_on_state_change_callback_logs_correctly(
callback_context,
dummy_arrow_schema,
):
"""STATE_DELTA is logged via on_state_change_callback."""
state_delta = {"key": "value", "new_key": 123}
bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)
await bq_plugin_inst.on_state_change_callback(
Expand All @@ -1625,10 +1609,8 @@ async def test_on_state_change_callback_logs_correctly(
mock_write_client, dummy_arrow_schema
)
_assert_common_fields(log_entry, "STATE_DELTA")
# content should be None (as raw_content was not passed)
assert log_entry["content"] is None

# state_delta should be in attributes
attributes = json.loads(log_entry["attributes"])
assert attributes["state_delta"] == state_delta

Expand Down
62 changes: 62 additions & 0 deletions tests/unittests/plugins/test_plugin_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ async def after_model_callback(self, **kwargs):
async def on_model_error_callback(self, **kwargs):
return await self._handle_callback("on_model_error_callback")

async def on_state_change_callback(self, **kwargs):
return await self._handle_callback("on_state_change_callback")


@pytest.fixture
def service() -> PluginManager:
Expand Down Expand Up @@ -252,6 +255,10 @@ async def test_all_callbacks_are_supported(
llm_request=mock_context,
error=mock_context,
)
await service.run_on_state_change_callback(
callback_context=mock_context,
state_delta={"key": "value"},
)

# Verify all callbacks were logged
expected_callbacks = [
Expand All @@ -267,6 +274,7 @@ async def test_all_callbacks_are_supported(
"before_model_callback",
"after_model_callback",
"on_model_error_callback",
"on_state_change_callback",
]
assert set(plugin1.call_log) == set(expected_callbacks)

Expand Down Expand Up @@ -317,3 +325,57 @@ async def slow_close():
assert "Failed to close plugins: 'plugin1': TimeoutError" in str(
excinfo.value
)


# --- on_state_change_callback tests ---


@pytest.mark.asyncio
async def test_run_on_state_change_callback(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that run_on_state_change_callback invokes the callback and returns None."""
service.register_plugin(plugin1)
result = await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)
assert result is None
assert "on_state_change_callback" in plugin1.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_calls_all_plugins(
service: PluginManager, plugin1: TestPlugin, plugin2: TestPlugin
):
"""Tests that on_state_change_callback is called on all plugins."""
service.register_plugin(plugin1)
service.register_plugin(plugin2)

await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "on_state_change_callback" in plugin1.call_log
assert "on_state_change_callback" in plugin2.call_log


@pytest.mark.asyncio
async def test_run_on_state_change_callback_wraps_exceptions(
service: PluginManager, plugin1: TestPlugin
):
"""Tests that exceptions in on_state_change_callback are wrapped in RuntimeError."""
original_exception = ValueError("state change error")
plugin1.exceptions_to_raise["on_state_change_callback"] = original_exception
service.register_plugin(plugin1)

with pytest.raises(RuntimeError) as excinfo:
await service.run_on_state_change_callback(
callback_context=Mock(),
state_delta={"key": "value"},
)

assert "Error in plugin 'plugin1'" in str(excinfo.value)
assert "on_state_change_callback" in str(excinfo.value)
assert excinfo.value.__cause__ is original_exception
47 changes: 47 additions & 0 deletions tests/unittests/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def __init__(self):
self.enable_user_message_callback = False
self.enable_event_callback = False
self.user_content_seen_in_before_run_callback = None
self.state_change_deltas: list[dict] = []

async def on_user_message_callback(
self,
Expand All @@ -169,6 +170,9 @@ async def before_run_callback(
invocation_context.user_content
)

async def on_state_change_callback(self, *, callback_context, state_delta, **kwargs):
self.state_change_deltas.append(state_delta)

async def on_event_callback(
self, *, invocation_context: InvocationContext, event: Event
) -> Optional[Event]:
Expand Down Expand Up @@ -853,6 +857,49 @@ async def test_runner_passes_plugin_close_timeout(self):
)
assert runner.plugin_manager._close_timeout == 10.0

@pytest.mark.asyncio
async def test_state_delta_in_run_async_triggers_on_state_change_callback(
self,
):
"""Test that caller-supplied state_delta triggers on_state_change_callback."""
await self.session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
state_delta = {"lang": "en", "theme": "dark"}
events = []
async for event in self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(
role="user", parts=[types.Part(text="Hello")]
),
state_delta=state_delta,
):
events.append(event)

assert len(self.plugin.state_change_deltas) >= 1
assert self.plugin.state_change_deltas[0] == state_delta

@pytest.mark.asyncio
async def test_no_state_delta_does_not_trigger_on_state_change_callback(
self,
):
"""Test that on_state_change_callback is not called when no state_delta is provided."""
await self.session_service.create_session(
app_name=TEST_APP_ID, user_id=TEST_USER_ID, session_id=TEST_SESSION_ID
)
events = []
async for event in self.runner.run_async(
user_id=TEST_USER_ID,
session_id=TEST_SESSION_ID,
new_message=types.Content(
role="user", parts=[types.Part(text="Hello")]
),
):
events.append(event)

assert len(self.plugin.state_change_deltas) == 0

@pytest.mark.filterwarnings(
"ignore:The `plugins` argument is deprecated:DeprecationWarning"
)
Expand Down