From 4588b9d6c3d2355d9ba22b5d889f4974299354e2 Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Thu, 13 Nov 2025 23:05:36 +0100 Subject: [PATCH 1/6] feat: implement enterprise search agent tool and related functionality --- src/google/adk/agents/llm_agent.py | 12 ++ .../adk/flows/llm_flows/base_llm_flow.py | 6 +- .../adk/tools/enterprise_search_agent_tool.py | 141 ++++++++++++++++++ .../adk/tools/enterprise_search_tool.py | 10 +- .../test_enterprise_search_agent_tool.py | 110 ++++++++++++++ 5 files changed, 276 insertions(+), 3 deletions(-) create mode 100644 src/google/adk/tools/enterprise_search_agent_tool.py create mode 100644 tests/unittests/tools/test_enterprise_search_agent_tool.py diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 2f8a969fad..22298cf9ae 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -138,6 +138,7 @@ async def _convert_tool_union_to_tools( model: Union[str, BaseLlm], multiple_tools: bool = False, ) -> list[BaseTool]: + from ..tools.enterprise_search_tool import EnterpriseWebSearchTool from ..tools.google_search_tool import GoogleSearchTool from ..tools.vertex_ai_search_tool import VertexAiSearchTool @@ -171,6 +172,17 @@ async def _convert_tool_union_to_tools( ) ] + # Wrap enterprise_web_search tool with AgentTool if there are multiple tools + # because the built-in tools cannot be used together with other tools. + # TODO(b/448114567): Remove once the workaround is no longer needed. + if multiple_tools and isinstance(tool_union, EnterpriseWebSearchTool): + from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent + from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool + + enterprise_tool = cast(EnterpriseWebSearchTool, tool_union) + if enterprise_tool.bypass_multi_tools_limit: + return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))] + if isinstance(tool_union, BaseTool): return [tool_union] if callable(tool_union): diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index a95d6b8dcc..23ba58aab2 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -846,7 +846,11 @@ async def _maybe_add_grounding_metadata( tools = await agent.canonical_tools(readonly_context) invocation_context.canonical_tools_cache = tools - if not any(tool.name == 'google_search_agent' for tool in tools): + if not any( + tool.name == 'google_search_agent' + or tool.name == 'enterprise_search_agent' + for tool in tools + ): return response ground_metadata = invocation_context.session.state.get( 'temp:_adk_grounding_metadata', None diff --git a/src/google/adk/tools/enterprise_search_agent_tool.py b/src/google/adk/tools/enterprise_search_agent_tool.py new file mode 100644 index 0000000000..ae7ce8b124 --- /dev/null +++ b/src/google/adk/tools/enterprise_search_agent_tool.py @@ -0,0 +1,141 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any +from typing import Union + +from google.genai import types +from typing_extensions import override + +from ..agents.llm_agent import LlmAgent +from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..models.base_llm import BaseLlm +from ..runners import Runner +from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils.context_utils import Aclosing +from ._forwarding_artifact_service import ForwardingArtifactService +from .agent_tool import AgentTool +from .enterprise_search_tool import enterprise_web_search_tool +from .tool_context import ToolContext + + +def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: + """Create a sub-agent that only uses enterprise_web_search tool.""" + return LlmAgent( + name='enterprise_search_agent', + model=model, + description=( + 'An agent for performing Enterprise search using the' + ' `enterprise_web_search` tool' + ), + instruction=""" + You are a specialized Enterprise search agent. + + When given a search query, use the `enterprise_web_search` tool to find the related information. + """, + tools=[enterprise_web_search_tool], + ) + + +class EnterpriseSearchAgentTool(AgentTool): + """A tool that wraps a sub-agent that only uses enterprise_web_search tool. + + This is a workaround to support using enterprise_web_search tool with other tools. + TODO(b/448114567): Remove once the workaround is no longer needed. + + Attributes: + model: The model to use for the sub-agent. + """ + + def __init__(self, agent: LlmAgent): + self.agent = agent + super().__init__(agent=self.agent) + + @override + async def run_async( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> Any: + from ..agents.llm_agent import LlmAgent + + if isinstance(self.agent, LlmAgent) and self.agent.input_schema: + input_value = self.agent.input_schema.model_validate(args) + content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=input_value.model_dump_json(exclude_none=True) + ) + ], + ) + else: + content = types.Content( + role='user', + parts=[types.Part.from_text(text=args['request'])], + ) + runner = Runner( + app_name=self.agent.name, + agent=self.agent, + artifact_service=ForwardingArtifactService(tool_context), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=tool_context._invocation_context.credential_service, + plugins=list(tool_context._invocation_context.plugin_manager.plugins), + ) + + state_dict = { + k: v + for k, v in tool_context.state.to_dict().items() + if not k.startswith('_adk') # Filter out adk internal states + } + session = await runner.session_service.create_session( + app_name=self.agent.name, + user_id=tool_context._invocation_context.user_id, + state=state_dict, + ) + + last_content = None + last_grounding_metadata = None + async with Aclosing( + runner.run_async( + user_id=session.user_id, session_id=session.id, new_message=content + ) + ) as agen: + async for event in agen: + # Forward state delta to parent session. + if event.actions.state_delta: + tool_context.state.update(event.actions.state_delta) + if event.content: + last_content = event.content + last_grounding_metadata = event.grounding_metadata + + if not last_content: + return '' + merged_text = '\n'.join(p.text for p in last_content.parts if p.text) + if isinstance(self.agent, LlmAgent) and self.agent.output_schema: + tool_result = self.agent.output_schema.model_validate_json( + merged_text + ).model_dump(exclude_none=True) + else: + tool_result = merged_text + + if last_grounding_metadata: + tool_context.state['temp:_adk_grounding_metadata'] = ( + last_grounding_metadata + ) + return tool_result diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 7980f8f028..5c51138448 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -35,12 +35,18 @@ class EnterpriseWebSearchTool(BaseTool): https://cloud.google.com/vertex-ai/generative-ai/docs/grounding/web-grounding-enterprise. """ - def __init__(self): - """Initializes the Vertex AI Search tool.""" + def __init__(self, *, bypass_multi_tools_limit: bool = False): + """Initializes the Google search tool. + + Args: + bypass_multi_tools_limit: Whether to bypass the multi tools limitation, + so that the tool can be used with other tools in the same agent. + """ # Name and description are not used because this is a model built-in tool. super().__init__( name='enterprise_web_search', description='enterprise_web_search' ) + self.bypass_multi_tools_limit = bypass_multi_tools_limit @override async def process_llm_request( diff --git a/tests/unittests/tools/test_enterprise_search_agent_tool.py b/tests/unittests/tools/test_enterprise_search_agent_tool.py new file mode 100644 index 0000000000..1599100522 --- /dev/null +++ b/tests/unittests/tools/test_enterprise_search_agent_tool.py @@ -0,0 +1,110 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from unittest import mock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.agents.run_config import RunConfig +from google.adk.agents.sequential_agent import SequentialAgent +from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService +from google.adk.memory.in_memory_memory_service import InMemoryMemoryService +from google.adk.plugins.plugin_manager import PluginManager +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.enterprise_search_agent_tool import create_enterprise_search_agent +from google.adk.tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool +from google.adk.tools.tool_context import ToolContext +from pytest import mark + + +async def _create_tool_context() -> ToolContext: + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=agent, + session=session, + session_service=session_service, + artifact_service=InMemoryArtifactService(), + memory_service=InMemoryMemoryService(), + plugin_manager=PluginManager(), + run_config=RunConfig(), + ) + return ToolContext(invocation_context=invocation_context) + + +class TestEnterpriseSearchAgentTool: + """Test the EnterpriseSearchAgentTool class.""" + + def test_create_enterprise_search_agent(self): + """Test that create_enterprise_search_agent creates a valid agent.""" + agent = create_enterprise_search_agent('gemini-pro') + assert isinstance(agent, LlmAgent) + assert agent.name == 'enterprise_search_agent' + assert 'enterprise_web_search' in [t.name for t in agent.tools] + + def test_enterprise_search_agent_tool_init(self): + """Test initialization of EnterpriseSearchAgentTool.""" + mock_agent = mock.MagicMock(spec=LlmAgent) + mock_agent.name = 'test_agent' + mock_agent.description = 'test_description' + tool = EnterpriseSearchAgentTool(mock_agent) + assert tool.agent == mock_agent + + @mark.asyncio + @mock.patch('google.adk.tools.enterprise_search_agent_tool.Runner') + async def test_run_async_succeeds(self, mock_runner_class): + """Test that run_async executes the sub-agent and returns the result.""" + # Arrange + mock_agent = mock.MagicMock(spec=LlmAgent) + mock_agent.name = 'enterprise_search_agent' + mock_agent.description = 'test_description' + mock_agent.input_schema = None + mock_agent.output_schema = None + + tool = EnterpriseSearchAgentTool(mock_agent) + tool_context = await _create_tool_context() + + async def mock_run_async_gen(): + yield mock.MagicMock( + actions=mock.MagicMock(state_delta={'key': 'value'}), content=None + ) + yield mock.MagicMock( + actions=mock.MagicMock(state_delta=None), + content=mock.MagicMock(parts=[mock.MagicMock(text='test response')]), + ) + + mock_runner_instance = mock.MagicMock() + mock_runner_instance.run_async.return_value = mock_run_async_gen() + mock_runner_instance.session_service = mock.AsyncMock() + mock_runner_instance.session_service.create_session.return_value = ( + tool_context._invocation_context.session + ) + mock_runner_class.return_value = mock_runner_instance + + # Act + result = await tool.run_async( + args={'request': 'test query'}, tool_context=tool_context + ) + + # Assert + mock_runner_class.assert_called_once() + mock_runner_instance.run_async.assert_called_once() + assert tool_context.state['key'] == 'value' + assert result == 'test response' From 3a20b7dbd69194a59b45f8d3a303b9d415059589 Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Fri, 14 Nov 2025 00:12:25 +0100 Subject: [PATCH 2/6] feat: refactor search agent tools to inherit from a common base class --- .../adk/flows/llm_flows/base_llm_flow.py | 3 +- src/google/adk/tools/_search_agent_tool.py | 113 ++++++++++++++++++ .../adk/tools/enterprise_search_agent_tool.py | 90 +------------- .../adk/tools/enterprise_search_tool.py | 2 +- .../adk/tools/google_search_agent_tool.py | 90 +------------- .../test_enterprise_search_agent_tool.py | 30 ++--- 6 files changed, 126 insertions(+), 202 deletions(-) create mode 100644 src/google/adk/tools/_search_agent_tool.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 23ba58aab2..f2054a8a80 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -847,8 +847,7 @@ async def _maybe_add_grounding_metadata( invocation_context.canonical_tools_cache = tools if not any( - tool.name == 'google_search_agent' - or tool.name == 'enterprise_search_agent' + tool.name in {'google_search_agent', 'enterprise_search_agent'} for tool in tools ): return response diff --git a/src/google/adk/tools/_search_agent_tool.py b/src/google/adk/tools/_search_agent_tool.py new file mode 100644 index 0000000000..37742e6a02 --- /dev/null +++ b/src/google/adk/tools/_search_agent_tool.py @@ -0,0 +1,113 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +from google.genai import types +from typing_extensions import override + +from ..agents.llm_agent import LlmAgent +from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..runners import Runner +from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils.context_utils import Aclosing +from ._forwarding_artifact_service import ForwardingArtifactService +from .agent_tool import AgentTool +from .tool_context import ToolContext + + +class _SearchAgentTool(AgentTool): + """A base class for search agent tools.""" + + @override + async def run_async( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> Any: + from ..agents.llm_agent import LlmAgent + + if isinstance(self.agent, LlmAgent) and self.agent.input_schema: + input_value = self.agent.input_schema.model_validate(args) + content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=input_value.model_dump_json(exclude_none=True) + ) + ], + ) + else: + content = types.Content( + role='user', + parts=[types.Part.from_text(text=args['request'])], + ) + runner = Runner( + app_name=self.agent.name, + agent=self.agent, + artifact_service=ForwardingArtifactService(tool_context), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=tool_context._invocation_context.credential_service, + plugins=list(tool_context._invocation_context.plugin_manager.plugins), + ) + try: + state_dict = { + k: v + for k, v in tool_context.state.to_dict().items() + if not k.startswith('_adk') # Filter out adk internal states + } + session = await runner.session_service.create_session( + app_name=self.agent.name, + user_id=tool_context._invocation_context.user_id, + state=state_dict, + ) + + last_content = None + last_grounding_metadata = None + async with Aclosing( + runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=content, + ) + ) as agen: + async for event in agen: + # Forward state delta to parent session. + if event.actions.state_delta: + tool_context.state.update(event.actions.state_delta) + if event.content: + last_content = event.content + last_grounding_metadata = event.grounding_metadata + + if not last_content: + return '' + merged_text = '\n'.join(p.text for p in last_content.parts if p.text) + if isinstance(self.agent, LlmAgent) and self.agent.output_schema: + tool_result = self.agent.output_schema.model_validate_json( + merged_text + ).model_dump(exclude_none=True) + else: + tool_result = merged_text + + if last_grounding_metadata: + tool_context.state['temp:_adk_grounding_metadata'] = ( + last_grounding_metadata + ) + return tool_result + finally: + await runner.close() diff --git a/src/google/adk/tools/enterprise_search_agent_tool.py b/src/google/adk/tools/enterprise_search_agent_tool.py index ae7ce8b124..e064f17ff6 100644 --- a/src/google/adk/tools/enterprise_search_agent_tool.py +++ b/src/google/adk/tools/enterprise_search_agent_tool.py @@ -14,22 +14,12 @@ from __future__ import annotations -from typing import Any from typing import Union -from google.genai import types -from typing_extensions import override - from ..agents.llm_agent import LlmAgent -from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.base_llm import BaseLlm -from ..runners import Runner -from ..sessions.in_memory_session_service import InMemorySessionService -from ..utils.context_utils import Aclosing -from ._forwarding_artifact_service import ForwardingArtifactService -from .agent_tool import AgentTool +from ._search_agent_tool import _SearchAgentTool from .enterprise_search_tool import enterprise_web_search_tool -from .tool_context import ToolContext def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: @@ -50,7 +40,7 @@ def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: ) -class EnterpriseSearchAgentTool(AgentTool): +class EnterpriseSearchAgentTool(_SearchAgentTool): """A tool that wraps a sub-agent that only uses enterprise_web_search tool. This is a workaround to support using enterprise_web_search tool with other tools. @@ -63,79 +53,3 @@ class EnterpriseSearchAgentTool(AgentTool): def __init__(self, agent: LlmAgent): self.agent = agent super().__init__(agent=self.agent) - - @override - async def run_async( - self, - *, - args: dict[str, Any], - tool_context: ToolContext, - ) -> Any: - from ..agents.llm_agent import LlmAgent - - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: - input_value = self.agent.input_schema.model_validate(args) - content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=input_value.model_dump_json(exclude_none=True) - ) - ], - ) - else: - content = types.Content( - role='user', - parts=[types.Part.from_text(text=args['request'])], - ) - runner = Runner( - app_name=self.agent.name, - agent=self.agent, - artifact_service=ForwardingArtifactService(tool_context), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), - credential_service=tool_context._invocation_context.credential_service, - plugins=list(tool_context._invocation_context.plugin_manager.plugins), - ) - - state_dict = { - k: v - for k, v in tool_context.state.to_dict().items() - if not k.startswith('_adk') # Filter out adk internal states - } - session = await runner.session_service.create_session( - app_name=self.agent.name, - user_id=tool_context._invocation_context.user_id, - state=state_dict, - ) - - last_content = None - last_grounding_metadata = None - async with Aclosing( - runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content - ) - ) as agen: - async for event in agen: - # Forward state delta to parent session. - if event.actions.state_delta: - tool_context.state.update(event.actions.state_delta) - if event.content: - last_content = event.content - last_grounding_metadata = event.grounding_metadata - - if not last_content: - return '' - merged_text = '\n'.join(p.text for p in last_content.parts if p.text) - if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - tool_result = self.agent.output_schema.model_validate_json( - merged_text - ).model_dump(exclude_none=True) - else: - tool_result = merged_text - - if last_grounding_metadata: - tool_context.state['temp:_adk_grounding_metadata'] = ( - last_grounding_metadata - ) - return tool_result diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 5c51138448..7325137dc4 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -36,7 +36,7 @@ class EnterpriseWebSearchTool(BaseTool): """ def __init__(self, *, bypass_multi_tools_limit: bool = False): - """Initializes the Google search tool. + """Initializes the Enterprise web search tool. Args: bypass_multi_tools_limit: Whether to bypass the multi tools limitation, diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index 77cb6fedf9..771537dfe8 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -14,20 +14,12 @@ from __future__ import annotations -from typing import Any from typing import Union -from google.genai import types -from typing_extensions import override - from ..agents.llm_agent import LlmAgent -from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.base_llm import BaseLlm -from ..utils.context_utils import Aclosing -from ._forwarding_artifact_service import ForwardingArtifactService -from .agent_tool import AgentTool +from ._search_agent_tool import _SearchAgentTool from .google_search_tool import google_search -from .tool_context import ToolContext def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: @@ -47,7 +39,7 @@ def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: ) -class GoogleSearchAgentTool(AgentTool): +class GoogleSearchAgentTool(_SearchAgentTool): """A tool that wraps a sub-agent that only uses google_search tool. This is a workaround to support using google_search tool with other tools. @@ -60,81 +52,3 @@ class GoogleSearchAgentTool(AgentTool): def __init__(self, agent: LlmAgent): self.agent = agent super().__init__(agent=self.agent) - - @override - async def run_async( - self, - *, - args: dict[str, Any], - tool_context: ToolContext, - ) -> Any: - from ..agents.llm_agent import LlmAgent - from ..runners import Runner - from ..sessions.in_memory_session_service import InMemorySessionService - - if isinstance(self.agent, LlmAgent) and self.agent.input_schema: - input_value = self.agent.input_schema.model_validate(args) - content = types.Content( - role='user', - parts=[ - types.Part.from_text( - text=input_value.model_dump_json(exclude_none=True) - ) - ], - ) - else: - content = types.Content( - role='user', - parts=[types.Part.from_text(text=args['request'])], - ) - runner = Runner( - app_name=self.agent.name, - agent=self.agent, - artifact_service=ForwardingArtifactService(tool_context), - session_service=InMemorySessionService(), - memory_service=InMemoryMemoryService(), - credential_service=tool_context._invocation_context.credential_service, - plugins=list(tool_context._invocation_context.plugin_manager.plugins), - ) - - state_dict = { - k: v - for k, v in tool_context.state.to_dict().items() - if not k.startswith('_adk') # Filter out adk internal states - } - session = await runner.session_service.create_session( - app_name=self.agent.name, - user_id=tool_context._invocation_context.user_id, - state=state_dict, - ) - - last_content = None - last_grounding_metadata = None - async with Aclosing( - runner.run_async( - user_id=session.user_id, session_id=session.id, new_message=content - ) - ) as agen: - async for event in agen: - # Forward state delta to parent session. - if event.actions.state_delta: - tool_context.state.update(event.actions.state_delta) - if event.content: - last_content = event.content - last_grounding_metadata = event.grounding_metadata - - if not last_content: - return '' - merged_text = '\n'.join(p.text for p in last_content.parts if p.text) - if isinstance(self.agent, LlmAgent) and self.agent.output_schema: - tool_result = self.agent.output_schema.model_validate_json( - merged_text - ).model_dump(exclude_none=True) - else: - tool_result = merged_text - - if last_grounding_metadata: - tool_context.state['temp:_adk_grounding_metadata'] = ( - last_grounding_metadata - ) - return tool_result diff --git a/tests/unittests/tools/test_enterprise_search_agent_tool.py b/tests/unittests/tools/test_enterprise_search_agent_tool.py index 1599100522..6b107a50b2 100644 --- a/tests/unittests/tools/test_enterprise_search_agent_tool.py +++ b/tests/unittests/tools/test_enterprise_search_agent_tool.py @@ -68,9 +68,9 @@ def test_enterprise_search_agent_tool_init(self): assert tool.agent == mock_agent @mark.asyncio - @mock.patch('google.adk.tools.enterprise_search_agent_tool.Runner') - async def test_run_async_succeeds(self, mock_runner_class): - """Test that run_async executes the sub-agent and returns the result.""" + @mock.patch('google.adk.tools._search_agent_tool._SearchAgentTool.run_async') + async def test_run_async_succeeds(self, mock_run_async): + """Test that run_async calls the base class method.""" # Arrange mock_agent = mock.MagicMock(spec=LlmAgent) mock_agent.name = 'enterprise_search_agent' @@ -80,23 +80,7 @@ async def test_run_async_succeeds(self, mock_runner_class): tool = EnterpriseSearchAgentTool(mock_agent) tool_context = await _create_tool_context() - - async def mock_run_async_gen(): - yield mock.MagicMock( - actions=mock.MagicMock(state_delta={'key': 'value'}), content=None - ) - yield mock.MagicMock( - actions=mock.MagicMock(state_delta=None), - content=mock.MagicMock(parts=[mock.MagicMock(text='test response')]), - ) - - mock_runner_instance = mock.MagicMock() - mock_runner_instance.run_async.return_value = mock_run_async_gen() - mock_runner_instance.session_service = mock.AsyncMock() - mock_runner_instance.session_service.create_session.return_value = ( - tool_context._invocation_context.session - ) - mock_runner_class.return_value = mock_runner_instance + mock_run_async.return_value = 'test response' # Act result = await tool.run_async( @@ -104,7 +88,7 @@ async def mock_run_async_gen(): ) # Assert - mock_runner_class.assert_called_once() - mock_runner_instance.run_async.assert_called_once() - assert tool_context.state['key'] == 'value' + mock_run_async.assert_called_once_with( + args={'request': 'test query'}, tool_context=tool_context + ) assert result == 'test response' From b19353a1d2224e7c77414f30c39cafc34aabaf19 Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Fri, 14 Nov 2025 01:36:28 +0100 Subject: [PATCH 3/6] feat: streamline tool handling in LLM agent by implementing dedicated workarounds --- src/google/adk/agents/llm_agent.py | 102 ++++++---- src/google/adk/tools/_search_agent_tool.py | 1 - .../adk/tools/enterprise_search_agent_tool.py | 2 +- .../test_enterprise_search_agent_tool.py | 175 +++++++++++------- 4 files changed, 175 insertions(+), 105 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 22298cf9ae..82e5ebe7c1 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -142,46 +142,33 @@ async def _convert_tool_union_to_tools( from ..tools.google_search_tool import GoogleSearchTool from ..tools.vertex_ai_search_tool import VertexAiSearchTool - # Wrap google_search tool with AgentTool if there are multiple tools because - # the built-in tools cannot be used together with other tools. + # Handle built-in tool workarounds when multiple tools are present. + # Built-in tools cannot be used together with other tools, so we wrap or + # replace them with compatible alternatives. # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, GoogleSearchTool): - from ..tools.google_search_agent_tool import create_google_search_agent - from ..tools.google_search_agent_tool import GoogleSearchAgentTool - - search_tool = cast(GoogleSearchTool, tool_union) - if search_tool.bypass_multi_tools_limit: - return [GoogleSearchAgentTool(create_google_search_agent(model))] - - # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are - # multiple tools because the built-in tools cannot be used together with - # other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, VertexAiSearchTool): - from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool - - vais_tool = cast(VertexAiSearchTool, tool_union) - if vais_tool.bypass_multi_tools_limit: - return [ - DiscoveryEngineSearchTool( - data_store_id=vais_tool.data_store_id, - data_store_specs=vais_tool.data_store_specs, - search_engine_id=vais_tool.search_engine_id, - filter=vais_tool.filter, - max_results=vais_tool.max_results, - ) - ] - - # Wrap enterprise_web_search tool with AgentTool if there are multiple tools - # because the built-in tools cannot be used together with other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, EnterpriseWebSearchTool): - from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent - from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool - - enterprise_tool = cast(EnterpriseWebSearchTool, tool_union) - if enterprise_tool.bypass_multi_tools_limit: - return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))] + if multiple_tools: + tool_workarounds = [ + # GoogleSearchTool: wrap with AgentTool + { + 'tool_class': GoogleSearchTool, + 'handler': lambda: _handle_google_search_tool(tool_union, model), + }, + # VertexAiSearchTool: replace with DiscoveryEngineSearchTool + { + 'tool_class': VertexAiSearchTool, + 'handler': lambda: _handle_vertex_ai_search_tool(tool_union), + }, + # EnterpriseWebSearchTool: wrap with AgentTool + { + 'tool_class': EnterpriseWebSearchTool, + 'handler': lambda: _handle_enterprise_search_tool(tool_union, model), + }, + ] + + for workaround in tool_workarounds: + if isinstance(tool_union, workaround['tool_class']): + if tool_union.bypass_multi_tools_limit: + return workaround['handler']() if isinstance(tool_union, BaseTool): return [tool_union] @@ -192,6 +179,43 @@ async def _convert_tool_union_to_tools( return await tool_union.get_tools_with_prefix(ctx) +def _handle_google_search_tool( + tool_union: ToolUnion, model: Union[str, BaseLlm] +) -> list[BaseTool]: + """Handle GoogleSearchTool workaround by wrapping with AgentTool.""" + from ..tools.google_search_agent_tool import create_google_search_agent + from ..tools.google_search_agent_tool import GoogleSearchAgentTool + + return [GoogleSearchAgentTool(create_google_search_agent(model))] + + +def _handle_vertex_ai_search_tool(tool_union: ToolUnion) -> list[BaseTool]: + """Handle VertexAiSearchTool workaround by replacing with DiscoveryEngineSearchTool.""" + from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool + from ..tools.vertex_ai_search_tool import VertexAiSearchTool + + vais_tool = cast(VertexAiSearchTool, tool_union) + return [ + DiscoveryEngineSearchTool( + data_store_id=vais_tool.data_store_id, + data_store_specs=vais_tool.data_store_specs, + search_engine_id=vais_tool.search_engine_id, + filter=vais_tool.filter, + max_results=vais_tool.max_results, + ) + ] + + +def _handle_enterprise_search_tool( + tool_union: ToolUnion, model: Union[str, BaseLlm] +) -> list[BaseTool]: + """Handle EnterpriseWebSearchTool workaround by wrapping with AgentTool.""" + from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent + from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool + + return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))] + + class LlmAgent(BaseAgent): """LLM-based Agent.""" diff --git a/src/google/adk/tools/_search_agent_tool.py b/src/google/adk/tools/_search_agent_tool.py index 37742e6a02..5c37ca6e91 100644 --- a/src/google/adk/tools/_search_agent_tool.py +++ b/src/google/adk/tools/_search_agent_tool.py @@ -39,7 +39,6 @@ async def run_async( args: dict[str, Any], tool_context: ToolContext, ) -> Any: - from ..agents.llm_agent import LlmAgent if isinstance(self.agent, LlmAgent) and self.agent.input_schema: input_value = self.agent.input_schema.model_validate(args) diff --git a/src/google/adk/tools/enterprise_search_agent_tool.py b/src/google/adk/tools/enterprise_search_agent_tool.py index e064f17ff6..1ba40eb424 100644 --- a/src/google/adk/tools/enterprise_search_agent_tool.py +++ b/src/google/adk/tools/enterprise_search_agent_tool.py @@ -47,7 +47,7 @@ class EnterpriseSearchAgentTool(_SearchAgentTool): TODO(b/448114567): Remove once the workaround is no longer needed. Attributes: - model: The model to use for the sub-agent. + agent: The sub-agent that this tool wraps. """ def __init__(self, agent: LlmAgent): diff --git a/tests/unittests/tools/test_enterprise_search_agent_tool.py b/tests/unittests/tools/test_enterprise_search_agent_tool.py index 6b107a50b2..75f190d99b 100644 --- a/tests/unittests/tools/test_enterprise_search_agent_tool.py +++ b/tests/unittests/tools/test_enterprise_search_agent_tool.py @@ -12,83 +12,130 @@ # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - -from unittest import mock - from google.adk.agents.invocation_context import InvocationContext -from google.adk.agents.llm_agent import LlmAgent -from google.adk.agents.run_config import RunConfig -from google.adk.agents.sequential_agent import SequentialAgent -from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService -from google.adk.memory.in_memory_memory_service import InMemoryMemoryService -from google.adk.plugins.plugin_manager import PluginManager +from google.adk.agents.llm_agent import Agent +from google.adk.models.llm_response import LlmResponse from google.adk.sessions.in_memory_session_service import InMemorySessionService -from google.adk.tools.enterprise_search_agent_tool import create_enterprise_search_agent from google.adk.tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool from google.adk.tools.tool_context import ToolContext +from google.genai import types +from google.genai.types import Part from pytest import mark +from .. import testing_utils + +function_call_no_schema = Part.from_function_call( + name='tool_agent', args={'request': 'test1'} +) + + +grounding_metadata = types.GroundingMetadata(web_search_queries=['test query']) + + +# TODO(b/448114567): Remove test_grounding_metadata_ tests once the workaround +# is no longer needed. + + +@mark.asyncio +async def test_grounding_metadata_is_stored_in_state_during_invocation(): + """Verify grounding_metadata is stored in the state during invocation.""" + + # Mock model for the tool_agent that returns grounding_metadata + tool_agent_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content( + parts=[Part.from_text(text='response from tool')] + ), + grounding_metadata=grounding_metadata, + ) + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=tool_agent_model, + ) + + agent_tool = EnterpriseSearchAgentTool(agent=tool_agent) -async def _create_tool_context() -> ToolContext: session_service = InMemorySessionService() session = await session_service.create_session( app_name='test_app', user_id='test_user' ) - agent = SequentialAgent(name='test_agent') + invocation_context = InvocationContext( invocation_id='invocation_id', - agent=agent, + agent=tool_agent, session=session, session_service=session_service, - artifact_service=InMemoryArtifactService(), - memory_service=InMemoryMemoryService(), - plugin_manager=PluginManager(), - run_config=RunConfig(), ) - return ToolContext(invocation_context=invocation_context) - - -class TestEnterpriseSearchAgentTool: - """Test the EnterpriseSearchAgentTool class.""" - - def test_create_enterprise_search_agent(self): - """Test that create_enterprise_search_agent creates a valid agent.""" - agent = create_enterprise_search_agent('gemini-pro') - assert isinstance(agent, LlmAgent) - assert agent.name == 'enterprise_search_agent' - assert 'enterprise_web_search' in [t.name for t in agent.tools] - - def test_enterprise_search_agent_tool_init(self): - """Test initialization of EnterpriseSearchAgentTool.""" - mock_agent = mock.MagicMock(spec=LlmAgent) - mock_agent.name = 'test_agent' - mock_agent.description = 'test_description' - tool = EnterpriseSearchAgentTool(mock_agent) - assert tool.agent == mock_agent - - @mark.asyncio - @mock.patch('google.adk.tools._search_agent_tool._SearchAgentTool.run_async') - async def test_run_async_succeeds(self, mock_run_async): - """Test that run_async calls the base class method.""" - # Arrange - mock_agent = mock.MagicMock(spec=LlmAgent) - mock_agent.name = 'enterprise_search_agent' - mock_agent.description = 'test_description' - mock_agent.input_schema = None - mock_agent.output_schema = None - - tool = EnterpriseSearchAgentTool(mock_agent) - tool_context = await _create_tool_context() - mock_run_async.return_value = 'test response' - - # Act - result = await tool.run_async( - args={'request': 'test query'}, tool_context=tool_context - ) - - # Assert - mock_run_async.assert_called_once_with( - args={'request': 'test query'}, tool_context=tool_context - ) - assert result == 'test response' + tool_context = ToolContext(invocation_context=invocation_context) + tool_result = await agent_tool.run_async( + args=function_call_no_schema.function_call.args, tool_context=tool_context + ) + + # Verify the tool result + assert tool_result == 'response from tool' + + # Verify grounding_metadata is stored in the state + assert tool_context.state['temp:_adk_grounding_metadata'] == ( + grounding_metadata + ) + + +@mark.asyncio +async def test_grounding_metadata_is_not_stored_in_state_after_invocation(): + """Verify grounding_metadata is not stored in the state after invocation.""" + + # Mock model for the tool_agent that returns grounding_metadata + tool_agent_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content( + parts=[Part.from_text(text='response from tool')] + ), + grounding_metadata=grounding_metadata, + ) + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=tool_agent_model, + ) + + # Mock model for the root_agent + root_agent_model = testing_utils.MockModel.create( + responses=[ + function_call_no_schema, # Call the tool_agent + 'Final response from root', + ] + ) + + root_agent = Agent( + name='root_agent', + model=root_agent_model, + tools=[EnterpriseSearchAgentTool(agent=tool_agent)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + events = runner.run('test input') + + # Find the function response event + function_response_event = None + for event in events: + if event.get_function_responses(): + function_response_event = event + break + + # Verify the function response + assert function_response_event is not None + function_responses = function_response_event.get_function_responses() + assert len(function_responses) == 1 + tool_output = function_responses[0].response + assert tool_output == {'result': 'response from tool'} + + # Verify grounding_metadata is not stored in the root_agent's state + assert 'temp:_adk_grounding_metadata' not in runner.session.state + From e41ea88821ab6986c6b64645fc47dca2dcd7612b Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Fri, 14 Nov 2025 01:50:54 +0100 Subject: [PATCH 4/6] fix: enhance state filtering in _SearchAgentTool to exclude temporary states --- src/google/adk/tools/_search_agent_tool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/tools/_search_agent_tool.py b/src/google/adk/tools/_search_agent_tool.py index 5c37ca6e91..f3015b5232 100644 --- a/src/google/adk/tools/_search_agent_tool.py +++ b/src/google/adk/tools/_search_agent_tool.py @@ -68,7 +68,7 @@ async def run_async( state_dict = { k: v for k, v in tool_context.state.to_dict().items() - if not k.startswith('_adk') # Filter out adk internal states + if not k.startswith('_adk') and not k.startswith('temp:') } session = await runner.session_service.create_session( app_name=self.agent.name, From 1856e83aef99da6ba4946300de88adaa501af687 Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Fri, 14 Nov 2025 20:56:54 +0100 Subject: [PATCH 5/6] style: format lambda handler for better readability in llm_agent.py refactor: remove unnecessary blank line in test_enterprise_search_agent_tool.py --- src/google/adk/agents/llm_agent.py | 4 +++- tests/unittests/tools/test_enterprise_search_agent_tool.py | 1 - 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 82e5ebe7c1..16b3ff270a 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -161,7 +161,9 @@ async def _convert_tool_union_to_tools( # EnterpriseWebSearchTool: wrap with AgentTool { 'tool_class': EnterpriseWebSearchTool, - 'handler': lambda: _handle_enterprise_search_tool(tool_union, model), + 'handler': lambda: _handle_enterprise_search_tool( + tool_union, model + ), }, ] diff --git a/tests/unittests/tools/test_enterprise_search_agent_tool.py b/tests/unittests/tools/test_enterprise_search_agent_tool.py index 75f190d99b..4aec16ec7e 100644 --- a/tests/unittests/tools/test_enterprise_search_agent_tool.py +++ b/tests/unittests/tools/test_enterprise_search_agent_tool.py @@ -138,4 +138,3 @@ async def test_grounding_metadata_is_not_stored_in_state_after_invocation(): # Verify grounding_metadata is not stored in the root_agent's state assert 'temp:_adk_grounding_metadata' not in runner.session.state - From 8436e1f62a5e0f3f75bc27dcfd24b9561d5b19b0 Mon Sep 17 00:00:00 2001 From: Michael Baden Date: Mon, 26 Jan 2026 23:56:32 +0100 Subject: [PATCH 6/6] fix: resolve import issues in google_search_agent_tool.py after merge --- src/google/adk/tools/google_search_agent_tool.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index a385cc1a22..33b2edf23e 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -17,16 +17,14 @@ from typing import Any from typing import Union -from async_generator import aclosing as Aclosing from google.genai import types from typing_extensions import override from ..agents.llm_agent import LlmAgent from ..memory.in_memory_memory_service import InMemoryMemoryService from ..models.base_llm import BaseLlm -from ..services.artifact.forwarding_artifact_service import ( - ForwardingArtifactService, -) +from ..utils.context_utils import Aclosing +from ._forwarding_artifact_service import ForwardingArtifactService from ._search_agent_tool import _SearchAgentTool from .google_search_tool import google_search from .tool_context import ToolContext