diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 72a2292fb3..6e664ca7f8 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -21,6 +21,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore + from a2a.server.tasks import TaskStore from a2a.types import AgentCard except ImportError as e: if sys.version_info < (3, 10): @@ -91,6 +92,7 @@ def to_a2a( protocol: str = "http", agent_card: Optional[Union[AgentCard, str]] = None, runner: Optional[Runner] = None, + task_store: Optional[TaskStore] = None, ) -> Starlette: """Convert an ADK agent to a A2A Starlette application. @@ -104,6 +106,9 @@ def to_a2a( agent. runner: Optional pre-built Runner object. If not provided, a default runner will be created using in-memory services. + task_store: Optional task store instance. If not provided, an + InMemoryTaskStore will be created. Must be compatible with + DefaultRequestHandler's task_store parameter. Returns: A Starlette application that can be run with uvicorn @@ -115,6 +120,11 @@ def to_a2a( # Or with custom agent card: app = to_a2a(agent, agent_card=my_custom_agent_card) + + # Or with custom task store: + from a2a.server.tasks import TaskStore + class MyCustomTaskStore(TaskStore): ... # A user-defined TaskStore; abstract methods must be implemented + app = to_a2a(agent, task_store=MyCustomTaskStore()) """ # Set up ADK logging to ensure logs are visible when using uvicorn directly setup_adk_logger(logging.INFO) @@ -132,7 +142,8 @@ async def create_runner() -> Runner: ) # Create A2A components - task_store = InMemoryTaskStore() + if task_store is None: + task_store = InMemoryTaskStore() agent_executor = A2aAgentExecutor( runner=runner or create_runner, diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index ee80b0233b..69040701b3 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -29,6 +29,7 @@ from a2a.server.apps import A2AStarletteApplication from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import InMemoryTaskStore + from a2a.server.tasks import TaskStore from a2a.types import AgentCard from google.adk.a2a.executor.a2a_agent_executor import A2aAgentExecutor from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder @@ -148,6 +149,47 @@ def test_to_a2a_with_custom_runner( "startup", mock_app.add_event_handler.call_args[0][1] ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") + @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") + @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore") + @patch("google.adk.a2a.utils.agent_to_a2a.AgentCardBuilder") + @patch("google.adk.a2a.utils.agent_to_a2a.Starlette") + def test_to_a2a_with_custom_task_store( + self, + mock_starlette_class, + mock_card_builder_class, + mock_task_store_class, + mock_request_handler_class, + mock_agent_executor_class, + ): + """Test to_a2a with a custom task store.""" + # Arrange + mock_app = Mock(spec=Starlette) + mock_starlette_class.return_value = mock_app + custom_task_store = Mock(spec=TaskStore) + mock_agent_executor = Mock(spec=A2aAgentExecutor) + mock_agent_executor_class.return_value = mock_agent_executor + + # Act + result = to_a2a(self.mock_agent, task_store=custom_task_store) + + # Assert + assert result == mock_app + mock_starlette_class.assert_called_once() + # Verify InMemoryTaskStore was NOT created since we provided a custom one + mock_task_store_class.assert_not_called() + mock_agent_executor_class.assert_called_once() + # Verify the custom task store was used + mock_request_handler_class.assert_called_once_with( + agent_executor=mock_agent_executor, task_store=custom_task_store + ) + mock_card_builder_class.assert_called_once_with( + agent=self.mock_agent, rpc_url="http://localhost:8000/" + ) + mock_app.add_event_handler.assert_called_once_with( + "startup", mock_app.add_event_handler.call_args[0][1] + ) + @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @patch("google.adk.a2a.utils.agent_to_a2a.InMemoryTaskStore")