diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index e57e0c8fa9..e4358b07ce 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -202,7 +202,7 @@ class RunAgentRequest(common.BaseModel): app_name: str user_id: str session_id: str - new_message: types.Content + new_message: Optional[types.Content] = None streaming: bool = False state_delta: Optional[dict[str, Any]] = None # for resume long-running functions @@ -369,7 +369,7 @@ def _otel_env_vars_enabled() -> bool: def _setup_gcp_telemetry( - internal_exporters: list[SpanProcessor] = None, + internal_exporters: list[SpanProcessor] | None = None, ): if typing.TYPE_CHECKING: from ..telemetry.setup import OTelHooks @@ -411,7 +411,7 @@ def _setup_gcp_telemetry( def _setup_telemetry_from_env( - internal_exporters: list[SpanProcessor] = None, + internal_exporters: list[SpanProcessor] | None = None, ): from ..telemetry.setup import maybe_set_otel_providers @@ -507,7 +507,7 @@ def __init__( # Internal properties we want to allow being modified from callbacks. self.runners_to_clean: set[str] = set() self.current_app_name_ref: SharedValue[str] = SharedValue(value="") - self.runner_dict = {} + self.runner_dict: dict[str, Runner] = {} self.url_prefix = url_prefix async def get_runner_async(self, app_name: str) -> Runner: @@ -707,8 +707,8 @@ def get_fast_api_app( A FastAPI app instance. """ # Properties we don't need to modify from callbacks - trace_dict = {} - session_trace_dict = {} + trace_dict: dict[str, Any] = {} + session_trace_dict: dict[str, list[int]] = {} # Set up a file system watcher to detect changes in the agents directory. observer = Observer() setup_observer(observer, self) @@ -1413,6 +1413,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]: session_id=req.session_id, new_message=req.new_message, state_delta=req.state_delta, + invocation_id=req.invocation_id, ) ) as agen: events = [event async for event in agen] diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index 545a0e83e6..5a6b42522b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -23,6 +23,7 @@ from typing import Any from typing import AsyncGenerator from typing import Callable +from typing import cast from typing import Generator from typing import List from typing import Optional @@ -414,7 +415,7 @@ def run( The events generated by the agent. """ run_config = run_config or RunConfig() - event_queue = queue.Queue() + event_queue: queue.Queue[Optional[Event]] = queue.Queue() async def _invoke_run_async(): try: @@ -481,8 +482,8 @@ async def run_async( The events generated by the agent. Raises: - ValueError: If the session is not found; If both invocation_id and - new_message are None. + ValueError: If the session is not found and `auto_create_session` is False, + or if both `invocation_id` and `new_message` are `None`. """ run_config = run_config or RunConfig() @@ -497,6 +498,7 @@ async def _run_with_trace( session = await self._get_or_create_session( user_id=user_id, session_id=session_id ) + if not invocation_id and not new_message: raise ValueError( 'Running an agent requires either a new_message or an ' @@ -1002,7 +1004,7 @@ async def run_live( ) if not session: session = await self._get_or_create_session( - user_id=user_id, session_id=session_id + user_id=cast(str, user_id), session_id=cast(str, session_id) ) invocation_context = self._new_invocation_context_for_live( session, @@ -1321,7 +1323,7 @@ async def _setup_context_for_resumed_invocation( # Step 1: Maybe retrieve a previous user message for the invocation. user_message = new_message or self._find_user_message_for_invocation( - session.events, invocation_id + session.events, cast(str, invocation_id) ) if not user_message: raise ValueError( @@ -1537,12 +1539,7 @@ async def close(self): logger.info('Runner closed.') - if sys.version_info < (3, 11): - Self = 'Runner' # pylint: disable=invalid-name - else: - from typing import Self # pylint: disable=g-import-not-at-top - - async def __aenter__(self) -> Self: + async def __aenter__(self) -> 'Runner': """Async context manager entry.""" return self diff --git a/tests/unittests/cli/test_fast_api.py b/tests/unittests/cli/test_fast_api.py index fa89021ec5..8b018ef47b 100755 --- a/tests/unittests/cli/test_fast_api.py +++ b/tests/unittests/cli/test_fast_api.py @@ -48,6 +48,7 @@ from google.adk.sessions.state import State from google.genai import types from pydantic import BaseModel +from pydantic import Field import pytest # Configure logging to help diagnose server startup issues @@ -132,6 +133,7 @@ async def dummy_run_async( run_config: Optional[RunConfig] = None, invocation_id: Optional[str] = None, ): + run_config = run_config or RunConfig() yield _event_1() await asyncio.sleep(0) @@ -154,9 +156,9 @@ class _MockEvalCaseResult(BaseModel): user_id: str session_id: str eval_set_file: str - eval_metric_results: list = {} - overall_eval_metric_results: list = ({},) - eval_metric_result_per_invocation: list = {} + eval_metric_results: list = Field(default_factory=list) + overall_eval_metric_results: list = Field(default_factory=list) + eval_metric_result_per_invocation: list = Field(default_factory=list) ################################################# @@ -1336,5 +1338,31 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path): assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists() +@pytest.mark.parametrize( + "extra_payload", + [ + {}, + {"state_delta": {"some_key": "some_value"}}, + ], + ids=["no_state_delta", "with_state_delta"], +) +def test_agent_run_resume_without_message_success( + test_app, create_test_session, extra_payload +): + """Test that /run allows resuming a session with only an invocation_id.""" + info = create_test_session + url = "/run" + payload = { + "app_name": info["app_name"], + "user_id": info["user_id"], + "session_id": info["session_id"], + "invocation_id": "test_invocation_id", + "streaming": False, + **extra_payload, + } + response = test_app.post(url, json=payload) + assert response.status_code == 200 + + if __name__ == "__main__": pytest.main(["-xvs", __file__])