diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index eaca7da214..1e0c321376 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -1020,6 +1020,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: user_id=req.user_id, session_id=req.session_id, new_message=req.new_message, + state_delta=req.state_delta, ) ) as agen: events = [event async for event in agen] diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index 423581dfd9..b05bf3c4ab 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -34,6 +34,7 @@ from google.adk.evaluation.eval_set import EvalSet from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager from google.adk.events.event import Event +from google.adk.events.event_actions import EventActions from google.adk.runners import Runner from google.adk.sessions.base_session_service import ListSessionsResponse from google.genai import types @@ -94,6 +95,14 @@ def _event_3(): ) +def _event_state_delta(state_delta: dict[str, Any]): + return Event( + author="dummy agent", + invocation_id="invocation_id", + actions=EventActions(state_delta=state_delta), + ) + + # Define mocked async generator functions for the Runner async def dummy_run_live(self, session, live_request_queue): yield _event_1() @@ -110,6 +119,7 @@ async def dummy_run_async( user_id, session_id, new_message, + state_delta=None, run_config: RunConfig = RunConfig(), ): yield _event_1() @@ -119,6 +129,10 @@ async def dummy_run_async( await asyncio.sleep(0) yield _event_3() + await asyncio.sleep(0) + + if state_delta is not None: + yield _event_state_delta(state_delta) # Define a local mock for EvalCaseResult specific to fast_api tests @@ -744,6 +758,29 @@ def test_agent_run(test_app, create_test_session): logger.info("Agent run test completed successfully") +def test_agent_run_passes_state_delta(test_app, create_test_session): + """Test /run forwards state_delta and surfaces it in events.""" + info = create_test_session + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "new_message": {"role": "user", "parts": [{"text": "Hello"}]}, + "streaming": False, + "state_delta": {"k": "v", "count": 1}, + } + + # Verify the response + response = test_app.post("/run", json=payload) + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + assert len(data) == 4 + + # Verify we got the expected event + assert data[3]["actions"]["stateDelta"] == payload["state_delta"] + + def test_list_artifact_names(test_app, create_test_session): """Test listing artifact names for a session.""" info = create_test_session