From a2e89a22a516ecabec64a7b5018375016572a380 Mon Sep 17 00:00:00 2001 From: Hangfei Lin Date: Fri, 29 Aug 2025 17:33:00 -0700 Subject: [PATCH] feat: Upgrade ADK stack to use App instead in addition to root_agent The convention: - If some fields(like plugin) are defined both at root_agent and app, then a error will be raised. - app code should be located within agent.py. - an instance named app should be created PiperOrigin-RevId: 801084463 --- .../samples/hello_world_app/__init__.py | 15 ++ contributing/samples/hello_world_app/agent.py | 145 ++++++++++++++++++ contributing/samples/hello_world_app/main.py | 103 +++++++++++++ src/google/adk/apps/__init__.py | 19 +++ src/google/adk/apps/app.py | 52 +++++++ src/google/adk/cli/adk_web_server.py | 22 ++- src/google/adk/cli/cli.py | 25 ++- src/google/adk/cli/utils/agent_loader.py | 36 +++-- src/google/adk/cli/utils/base_agent_loader.py | 4 +- src/google/adk/runners.py | 84 ++++++++-- tests/unittests/cli/utils/test_cli.py | 7 +- 11 files changed, 476 insertions(+), 36 deletions(-) create mode 100755 contributing/samples/hello_world_app/__init__.py create mode 100755 contributing/samples/hello_world_app/agent.py create mode 100755 contributing/samples/hello_world_app/main.py create mode 100644 src/google/adk/apps/__init__.py create mode 100644 src/google/adk/apps/app.py diff --git a/contributing/samples/hello_world_app/__init__.py b/contributing/samples/hello_world_app/__init__.py new file mode 100755 index 0000000000..c48963cdc7 --- /dev/null +++ b/contributing/samples/hello_world_app/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import agent diff --git a/contributing/samples/hello_world_app/agent.py b/contributing/samples/hello_world_app/agent.py new file mode 100755 index 0000000000..95d0e2add8 --- /dev/null +++ b/contributing/samples/hello_world_app/agent.py @@ -0,0 +1,145 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +from google.adk import Agent +from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.callback_context import CallbackContext +from google.adk.apps import App +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools.tool_context import ToolContext +from google.genai import types + + +def roll_die(sides: int, tool_context: ToolContext) -> int: + """Roll a die and return the rolled result. + + Args: + sides: The integer number of sides the die has. + + Returns: + An integer of the result of rolling the die. + """ + result = random.randint(1, sides) + if not 'rolls' in tool_context.state: + tool_context.state['rolls'] = [] + + tool_context.state['rolls'] = tool_context.state['rolls'] + [result] + return result + + +async def check_prime(nums: list[int]) -> str: + """Check if a given list of numbers are prime. + + Args: + nums: The list of numbers to check. + + Returns: + A str indicating which number is prime. + """ + primes = set() + for number in nums: + number = int(number) + if number <= 1: + continue + is_prime = True + for i in range(2, int(number**0.5) + 1): + if number % i == 0: + is_prime = False + break + if is_prime: + primes.add(number) + return ( + 'No prime numbers found.' + if not primes + else f"{', '.join(str(num) for num in primes)} are prime numbers." + ) + + +root_agent = Agent( + model='gemini-2.0-flash', + name='hello_world_agent', + description=( + 'hello world agent that can roll a dice of 8 sides and check prime' + ' numbers.' + ), + instruction=""" + You roll dice and answer questions about the outcome of the dice rolls. + You can roll dice of different sizes. + You can use multiple tools in parallel by calling functions in parallel(in one request and in one round). + It is ok to discuss previous dice roles, and comment on the dice rolls. + When you are asked to roll a die, you must call the roll_die tool with the number of sides. Be sure to pass in an integer. Do not pass in a string. + You should never roll a die on your own. + When checking prime numbers, call the check_prime tool with a list of integers. Be sure to pass in a list of integers. You should never pass in a string. + You should not check prime numbers before calling the tool. + When you are asked to roll a die and check prime numbers, you should always make the following two function calls: + 1. You should first call the roll_die tool to get a roll. Wait for the function response before calling the check_prime tool. + 2. After you get the function response from roll_die tool, you should call the check_prime tool with the roll_die result. + 2.1 If user asks you to check primes based on previous rolls, make sure you include the previous rolls in the list. + 3. When you respond, you must include the roll_die result from step 1. + You should always perform the previous 3 steps when asking for a roll and checking prime numbers. + You should not rely on the previous history on prime results. + """, + tools=[ + roll_die, + check_prime, + ], + # planner=BuiltInPlanner( + # thinking_config=types.ThinkingConfig( + # include_thoughts=True, + # ), + # ), + generate_content_config=types.GenerateContentConfig( + safety_settings=[ + types.SafetySetting( # avoid false alarm about rolling dice. + category=types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, + threshold=types.HarmBlockThreshold.OFF, + ), + ] + ), +) + + +class CountInvocationPlugin(BasePlugin): + """A custom plugin that counts agent and tool invocations.""" + + def __init__(self) -> None: + """Initialize the plugin with counters.""" + super().__init__(name='count_invocation') + self.agent_count: int = 0 + self.tool_count: int = 0 + self.llm_request_count: int = 0 + + async def before_agent_callback( + self, *, agent: BaseAgent, callback_context: CallbackContext + ) -> None: + """Count agent runs.""" + self.agent_count += 1 + print(f'[Plugin] Agent run count: {self.agent_count}') + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> None: + """Count LLM requests.""" + self.llm_request_count += 1 + print(f'[Plugin] LLM request count: {self.llm_request_count}') + + +app = App( + name='hello_world_app', + root_agent=root_agent, + plugins=[CountInvocationPlugin()], +) diff --git a/contributing/samples/hello_world_app/main.py b/contributing/samples/hello_world_app/main.py new file mode 100755 index 0000000000..b9e3035528 --- /dev/null +++ b/contributing/samples/hello_world_app/main.py @@ -0,0 +1,103 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import time + +import agent +from dotenv import load_dotenv +from google.adk.agents.run_config import RunConfig +from google.adk.cli.utils import logs +from google.adk.runners import InMemoryRunner +from google.adk.sessions.session import Session +from google.genai import types + +load_dotenv(override=True) +logs.log_to_tmp_folder() + + +async def main(): + app_name = 'my_app' + user_id_1 = 'user1' + runner = InMemoryRunner( + agent=agent.root_agent, + app_name=app_name, + ) + session_11 = await runner.session_service.create_session( + app_name=app_name, user_id=user_id_1 + ) + + async def run_prompt(session: Session, new_message: str): + content = types.Content( + role='user', parts=[types.Part.from_text(text=new_message)] + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + async def run_prompt_bytes(session: Session, new_message: str): + content = types.Content( + role='user', + parts=[ + types.Part.from_bytes( + data=str.encode(new_message), mime_type='text/plain' + ) + ], + ) + print('** User says:', content.model_dump(exclude_none=True)) + async for event in runner.run_async( + user_id=user_id_1, + session_id=session.id, + new_message=content, + run_config=RunConfig(save_input_blobs_as_artifacts=True), + ): + if event.content.parts and event.content.parts[0].text: + print(f'** {event.author}: {event.content.parts[0].text}') + + async def check_rolls_in_state(rolls_size: int): + session = await runner.session_service.get_session( + app_name=app_name, user_id=user_id_1, session_id=session_11.id + ) + assert len(session.state['rolls']) == rolls_size + for roll in session.state['rolls']: + assert roll > 0 and roll <= 100 + + start_time = time.time() + print('Start time:', start_time) + print('------------------------------------') + await run_prompt(session_11, 'Hi') + await run_prompt(session_11, 'Roll a die with 100 sides') + await check_rolls_in_state(1) + await run_prompt(session_11, 'Roll a die again with 100 sides.') + await check_rolls_in_state(2) + await run_prompt(session_11, 'What numbers did I got?') + await run_prompt_bytes(session_11, 'Hi bytes') + print( + await runner.artifact_service.list_artifact_keys( + app_name=app_name, user_id=user_id_1, session_id=session_11.id + ) + ) + end_time = time.time() + print('------------------------------------') + print('End time:', end_time) + print('Total time:', end_time - start_time) + + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/src/google/adk/apps/__init__.py b/src/google/adk/apps/__init__.py new file mode 100644 index 0000000000..33721570a3 --- /dev/null +++ b/src/google/adk/apps/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .app import App + +__all__ = [ + 'App', +] diff --git a/src/google/adk/apps/app.py b/src/google/adk/apps/app.py new file mode 100644 index 0000000000..67b2f45834 --- /dev/null +++ b/src/google/adk/apps/app.py @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from abc import ABC +from typing import Optional + +from pydantic import BaseModel +from pydantic import ConfigDict +from pydantic import Field + +from ..agents.base_agent import BaseAgent +from ..plugins.base_plugin import BasePlugin +from ..utils.feature_decorator import experimental + + +@experimental +class App(BaseModel): + """Represents an LLM-backed agentic application. + + An `App` is the top-level container for an agentic system powered by LLMs. + It manages a root agent (`root_agent`), which serves as the root of an agent + tree, enabling coordination and communication across all agents in the + hierarchy. + The `plugins` are application-wide components that provide shared capabilities + and services to the entire system. + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="forbid", + ) + + name: str + """The name of the application.""" + + root_agent: BaseAgent + """The root agent in the application. One app can only have one root agent.""" + + plugins: list[BasePlugin] = Field(default_factory=list) + """The plugins in the application.""" diff --git a/src/google/adk/cli/adk_web_server.py b/src/google/adk/cli/adk_web_server.py index 4237120525..b8b9a6ceab 100644 --- a/src/google/adk/cli/adk_web_server.py +++ b/src/google/adk/cli/adk_web_server.py @@ -50,10 +50,12 @@ from watchdog.observers import Observer from . import agent_graph +from ..agents.base_agent import BaseAgent from ..agents.live_request_queue import LiveRequest from ..agents.live_request_queue import LiveRequestQueue from ..agents.run_config import RunConfig from ..agents.run_config import StreamingMode +from ..apps import App from ..artifacts.base_artifact_service import BaseArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService from ..errors.not_found_error import NotFoundError @@ -305,10 +307,17 @@ async def get_runner_async(self, app_name: str) -> Runner: envs.load_dotenv_for_agent(os.path.basename(app_name), self.agents_dir) if app_name in self.runner_dict: return self.runner_dict[app_name] - root_agent = self.agent_loader.load_agent(app_name) + agent_or_app = self.agent_loader.load_agent(app_name) + agentic_app = None + if isinstance(agent_or_app, BaseAgent): + agentic_app = App( + name=app_name, + root_agent=agent_or_app, + ) + else: + agentic_app = agent_or_app runner = Runner( - app_name=app_name, - agent=root_agent, + app=agentic_app, artifact_service=self.artifact_service, session_service=self.session_service, memory_service=self.memory_service, @@ -597,9 +606,10 @@ async def add_session_to_eval_set( invocations = evals.convert_session_to_eval_invocations(session) # Populate the session with initial session state. - initial_session_state = create_empty_state( - self.agent_loader.load_agent(app_name) - ) + agent_or_app = self.agent_loader.load_agent(app_name) + if isinstance(agent_or_app, App): + agent_or_app = agent_or_app.root_agent + initial_session_state = create_empty_state(agent_or_app) new_eval_case = EvalCase( eval_id=req.eval_id, diff --git a/src/google/adk/cli/cli.py b/src/google/adk/cli/cli.py index 70c58d04c3..c14bd7836e 100644 --- a/src/google/adk/cli/cli.py +++ b/src/google/adk/cli/cli.py @@ -16,12 +16,15 @@ from datetime import datetime from typing import Optional +from typing import Union import click from google.genai import types from pydantic import BaseModel +from ..agents.base_agent import BaseAgent from ..agents.llm_agent import LlmAgent +from ..apps.app import App from ..artifacts.base_artifact_service import BaseArtifactService from ..artifacts.in_memory_artifact_service import InMemoryArtifactService from ..auth.credential_service.base_credential_service import BaseCredentialService @@ -43,15 +46,19 @@ class InputFile(BaseModel): async def run_input_file( app_name: str, user_id: str, - root_agent: LlmAgent, + agent_or_app: Union[LlmAgent, App], artifact_service: BaseArtifactService, session_service: BaseSessionService, credential_service: BaseCredentialService, input_path: str, ) -> Session: + app = ( + agent_or_app + if isinstance(agent_or_app, App) + else App(name=app_name, root_agent=agent_or_app) + ) runner = Runner( - app_name=app_name, - agent=root_agent, + app=app, artifact_service=artifact_service, session_service=session_service, credential_service=credential_service, @@ -79,15 +86,19 @@ async def run_input_file( async def run_interactively( - root_agent: LlmAgent, + root_agent_or_app: Union[LlmAgent, App], artifact_service: BaseArtifactService, session: Session, session_service: BaseSessionService, credential_service: BaseCredentialService, ) -> None: + app = ( + root_agent_or_app + if isinstance(root_agent_or_app, App) + else App(name=session.app_name, root_agent=root_agent_or_app) + ) runner = Runner( - app_name=session.app_name, - agent=root_agent, + app=app, artifact_service=artifact_service, session_service=session_service, credential_service=credential_service, @@ -154,7 +165,7 @@ async def run_cli( session = await run_input_file( app_name=agent_folder_name, user_id=user_id, - root_agent=root_agent, + agent_or_app=root_agent, artifact_service=artifact_service, session_service=session_service, credential_service=credential_service, diff --git a/src/google/adk/cli/utils/agent_loader.py b/src/google/adk/cli/utils/agent_loader.py index c5c83e4d23..042b873a4f 100644 --- a/src/google/adk/cli/utils/agent_loader.py +++ b/src/google/adk/cli/utils/agent_loader.py @@ -20,6 +20,7 @@ from pathlib import Path import sys from typing import Optional +from typing import Union from pydantic import ValidationError from typing_extensions import override @@ -27,6 +28,7 @@ from . import envs from ...agents import config_agent_utils from ...agents.base_agent import BaseAgent +from ...apps.app import App from ...utils.feature_decorator import experimental from .base_agent_loader import BaseAgentLoader @@ -50,19 +52,25 @@ class AgentLoader(BaseAgentLoader): def __init__(self, agents_dir: str): self.agents_dir = agents_dir.rstrip("/") self._original_sys_path = None - self._agent_cache: dict[str, BaseAgent] = {} + self._agent_cache: dict[str, Union[BaseAgent, App]] = {} def _load_from_module_or_package( self, agent_name: str - ) -> Optional[BaseAgent]: + ) -> Optional[Union[BaseAgent, App]]: # Load for case: Import "{agent_name}" (as a package or module) # Covers structures: # a) agents_dir/{agent_name}.py (with root_agent in the module) # b) agents_dir/{agent_name}/__init__.py (with root_agent in the package) try: module_candidate = importlib.import_module(agent_name) + # Check for "app" first, then "root_agent" + if hasattr(module_candidate, "app") and isinstance( + module_candidate.app, App + ): + logger.debug("Found app in %s", agent_name) + return module_candidate.app # Check for "root_agent" directly in "{agent_name}" module/package - if hasattr(module_candidate, "root_agent"): + elif hasattr(module_candidate, "root_agent"): logger.debug("Found root_agent directly in %s", agent_name) if isinstance(module_candidate.root_agent, BaseAgent): return module_candidate.root_agent @@ -96,12 +104,20 @@ def _load_from_module_or_package( return None - def _load_from_submodule(self, agent_name: str) -> Optional[BaseAgent]: + def _load_from_submodule( + self, agent_name: str + ) -> Optional[Union[BaseAgent], App]: # Load for case: Import "{agent_name}.agent" and look for "root_agent" # Covers structure: agents_dir/{agent_name}/agent.py (with root_agent defined in the module) try: module_candidate = importlib.import_module(f"{agent_name}.agent") - if hasattr(module_candidate, "root_agent"): + # Check for "app" first, then "root_agent" + if hasattr(module_candidate, "app") and isinstance( + module_candidate.app, App + ): + logger.debug("Found app in %s.agent", agent_name) + return module_candidate.app + elif hasattr(module_candidate, "root_agent"): logger.info("Found root_agent in %s.agent", agent_name) if isinstance(module_candidate.root_agent, BaseAgent): return module_candidate.root_agent @@ -161,7 +177,7 @@ def _load_from_yaml_config(self, agent_name: str) -> Optional[BaseAgent]: ) + e.args[1:] raise e - def _perform_load(self, agent_name: str) -> BaseAgent: + def _perform_load(self, agent_name: str) -> Union[BaseAgent, App]: """Internal logic to load an agent""" # Add self.agents_dir to sys.path if self.agents_dir not in sys.path: @@ -192,16 +208,16 @@ def _perform_load(self, agent_name: str) -> BaseAgent: ) @override - def load_agent(self, agent_name: str) -> BaseAgent: + def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: """Load an agent module (with caching & .env) and return its root_agent.""" if agent_name in self._agent_cache: logger.debug("Returning cached agent for %s (async)", agent_name) return self._agent_cache[agent_name] logger.debug("Loading agent %s - not in cache.", agent_name) - agent = self._perform_load(agent_name) - self._agent_cache[agent_name] = agent - return agent + agent_or_app = self._perform_load(agent_name) + self._agent_cache[agent_name] = agent_or_app + return agent_or_app @override def list_agents(self) -> list[str]: diff --git a/src/google/adk/cli/utils/base_agent_loader.py b/src/google/adk/cli/utils/base_agent_loader.py index 015d450b35..d62a6b8651 100644 --- a/src/google/adk/cli/utils/base_agent_loader.py +++ b/src/google/adk/cli/utils/base_agent_loader.py @@ -18,15 +18,17 @@ from abc import ABC from abc import abstractmethod +from typing import Union from ...agents.base_agent import BaseAgent +from ...apps.app import App class BaseAgentLoader(ABC): """Abstract base class for agent loaders.""" @abstractmethod - def load_agent(self, agent_name: str) -> BaseAgent: + def load_agent(self, agent_name: str) -> Union[BaseAgent, App]: """Loads an instance of an agent with the given name.""" @abstractmethod diff --git a/src/google/adk/runners.py b/src/google/adk/runners.py index bcf29839d1..c996af862b 100644 --- a/src/google/adk/runners.py +++ b/src/google/adk/runners.py @@ -34,6 +34,7 @@ from .agents.live_request_queue import LiveRequestQueue from .agents.llm_agent import LlmAgent from .agents.run_config import RunConfig +from .apps.app import App from .artifacts.base_artifact_service import BaseArtifactService from .artifacts.in_memory_artifact_service import InMemoryArtifactService from .auth.credential_service.base_credential_service import BaseCredentialService @@ -91,8 +92,9 @@ class Runner: def __init__( self, *, - app_name: str, - agent: BaseAgent, + app: Optional[App] = None, + app_name: Optional[str] = None, + agent: Optional[BaseAgent] = None, plugins: Optional[List[BasePlugin]] = None, artifact_service: Optional[BaseArtifactService] = None, session_service: BaseSessionService, @@ -101,23 +103,85 @@ def __init__( ): """Initializes the Runner. + Developers should provide either an `app` instance or both `app_name` and + `agent`. Providing a mix of `app` and `app_name`/`agent` will result in a + `ValueError`. Providing `app` is the recommended way to create a runner. + Args: - app_name: The application name of the runner. - agent: The root agent to run. - plugins: A list of plugins for the runner. + app_name: The application name of the runner. Required if `app` is not + provided. + agent: The root agent to run. Required if `app` is not provided. + app: An optional `App` instance. If provided, `app_name` and `agent` + should not be specified. + plugins: Deprecated. A list of plugins for the runner. Please use the + `app` argument to provide plugins instead. artifact_service: The artifact service for the runner. session_service: The session service for the runner. memory_service: The memory service for the runner. credential_service: The credential service for the runner. + + Raises: + ValueError: If `app` is provided along with `app_name` or `plugins`, or + if `app` is not provided but either `app_name` or `agent` is missing. """ - self.app_name = app_name - self.agent = agent + self.app_name, self.agent, plugins = self._validate_runner_params( + app, app_name, agent, plugins + ) self.artifact_service = artifact_service self.session_service = session_service self.memory_service = memory_service self.credential_service = credential_service self.plugin_manager = PluginManager(plugins=plugins) + def _validate_runner_params( + self, + app: Optional[App], + app_name: Optional[str], + agent: Optional[BaseAgent], + plugins: Optional[List[BasePlugin]], + ) -> tuple[str, BaseAgent, Optional[List[BasePlugin]]]: + """Validates and extracts runner parameters. + + Args: + app: An optional `App` instance. + app_name: The application name of the runner. + agent: The root agent to run. + plugins: A list of plugins for the runner. + + Returns: + A tuple containing (app_name, agent, plugins). + + Raises: + ValueError: If parameters are invalid. + """ + if app: + if app_name: + raise ValueError( + 'When app is provided, app_name should not be provided.' + ) + if agent: + raise ValueError('When app is provided, agent should not be provided.') + if plugins: + raise ValueError( + 'When app is provided, plugins should not be provided and should be' + ' provided in the app instead.' + ) + app_name = app.name + agent = app.root_agent + plugins = app.plugins + elif not app_name or not agent: + raise ValueError( + 'Either app or both app_name and agent must be provided.' + ) + + if plugins: + warnings.warn( + 'The `plugins` argument is deprecated. Please use the `app` argument' + ' to provide plugins instead.', + DeprecationWarning, + ) + return app_name, agent, plugins + def run( self, *, @@ -656,10 +720,11 @@ class InMemoryRunner(Runner): def __init__( self, - agent: BaseAgent, + agent: Optional[BaseAgent] = None, *, - app_name: str = 'InMemoryRunner', + app_name: Optional[str] = 'InMemoryRunner', plugins: Optional[list[BasePlugin]] = None, + app: Optional[App] = None, ): """Initializes the InMemoryRunner. @@ -674,6 +739,7 @@ def __init__( agent=agent, artifact_service=InMemoryArtifactService(), plugins=plugins, + app=app, session_service=self._in_memory_session_service, memory_service=InMemoryMemoryService(), ) diff --git a/tests/unittests/cli/utils/test_cli.py b/tests/unittests/cli/utils/test_cli.py index 2139a8c204..425b2a326e 100644 --- a/tests/unittests/cli/utils/test_cli.py +++ b/tests/unittests/cli/utils/test_cli.py @@ -26,6 +26,7 @@ from typing import Tuple import click +from google.adk.agents.base_agent import BaseAgent import google.adk.cli.cli as cli import pytest @@ -130,12 +131,12 @@ def _echo(msg: str) -> None: artifact_service = cli.InMemoryArtifactService() session_service = cli.InMemorySessionService() credential_service = cli.InMemoryCredentialService() - dummy_root = types.SimpleNamespace(name="root") + dummy_root = BaseAgent(name="root") session = await cli.run_input_file( app_name="app", user_id="user", - root_agent=dummy_root, + agent_or_app=dummy_root, artifact_service=artifact_service, session_service=session_service, credential_service=credential_service, @@ -205,7 +206,7 @@ async def test_run_interactively_whitespace_and_exit( sess = await session_service.create_session(app_name="dummy", user_id="u") artifact_service = cli.InMemoryArtifactService() credential_service = cli.InMemoryCredentialService() - root_agent = types.SimpleNamespace(name="root") + root_agent = BaseAgent(name="root") # fake user input: blank -> 'hello' -> 'exit' answers = iter([" ", "hello", "exit"])