diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 9133092c3f..f63162d6d0 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -138,38 +138,39 @@ async def _convert_tool_union_to_tools( model: Union[str, BaseLlm], multiple_tools: bool = False, ) -> list[BaseTool]: + from ..tools.enterprise_search_tool import EnterpriseWebSearchTool from ..tools.google_search_tool import GoogleSearchTool from ..tools.vertex_ai_search_tool import VertexAiSearchTool - # Wrap google_search tool with AgentTool if there are multiple tools because - # the built-in tools cannot be used together with other tools. + # Handle built-in tool workarounds when multiple tools are present. + # Built-in tools cannot be used together with other tools, so we wrap or + # replace them with compatible alternatives. # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, GoogleSearchTool): - from ..tools.google_search_agent_tool import create_google_search_agent - from ..tools.google_search_agent_tool import GoogleSearchAgentTool - - search_tool = cast(GoogleSearchTool, tool_union) - if search_tool.bypass_multi_tools_limit: - return [GoogleSearchAgentTool(create_google_search_agent(model))] - - # Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are - # multiple tools because the built-in tools cannot be used together with - # other tools. - # TODO(b/448114567): Remove once the workaround is no longer needed. - if multiple_tools and isinstance(tool_union, VertexAiSearchTool): - from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool - - vais_tool = cast(VertexAiSearchTool, tool_union) - if vais_tool.bypass_multi_tools_limit: - return [ - DiscoveryEngineSearchTool( - data_store_id=vais_tool.data_store_id, - data_store_specs=vais_tool.data_store_specs, - search_engine_id=vais_tool.search_engine_id, - filter=vais_tool.filter, - max_results=vais_tool.max_results, - ) - ] + if multiple_tools: + tool_workarounds = [ + # GoogleSearchTool: wrap with AgentTool + { + 'tool_class': GoogleSearchTool, + 'handler': lambda: _handle_google_search_tool(tool_union, model), + }, + # VertexAiSearchTool: replace with DiscoveryEngineSearchTool + { + 'tool_class': VertexAiSearchTool, + 'handler': lambda: _handle_vertex_ai_search_tool(tool_union), + }, + # EnterpriseWebSearchTool: wrap with AgentTool + { + 'tool_class': EnterpriseWebSearchTool, + 'handler': lambda: _handle_enterprise_search_tool( + tool_union, model + ), + }, + ] + + for workaround in tool_workarounds: + if isinstance(tool_union, workaround['tool_class']): + if tool_union.bypass_multi_tools_limit: + return workaround['handler']() if isinstance(tool_union, BaseTool): return [tool_union] @@ -180,6 +181,43 @@ async def _convert_tool_union_to_tools( return await tool_union.get_tools_with_prefix(ctx) +def _handle_google_search_tool( + tool_union: ToolUnion, model: Union[str, BaseLlm] +) -> list[BaseTool]: + """Handle GoogleSearchTool workaround by wrapping with AgentTool.""" + from ..tools.google_search_agent_tool import create_google_search_agent + from ..tools.google_search_agent_tool import GoogleSearchAgentTool + + return [GoogleSearchAgentTool(create_google_search_agent(model))] + + +def _handle_vertex_ai_search_tool(tool_union: ToolUnion) -> list[BaseTool]: + """Handle VertexAiSearchTool workaround by replacing with DiscoveryEngineSearchTool.""" + from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool + from ..tools.vertex_ai_search_tool import VertexAiSearchTool + + vais_tool = cast(VertexAiSearchTool, tool_union) + return [ + DiscoveryEngineSearchTool( + data_store_id=vais_tool.data_store_id, + data_store_specs=vais_tool.data_store_specs, + search_engine_id=vais_tool.search_engine_id, + filter=vais_tool.filter, + max_results=vais_tool.max_results, + ) + ] + + +def _handle_enterprise_search_tool( + tool_union: ToolUnion, model: Union[str, BaseLlm] +) -> list[BaseTool]: + """Handle EnterpriseWebSearchTool workaround by wrapping with AgentTool.""" + from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent + from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool + + return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))] + + class LlmAgent(BaseAgent): """LLM-based Agent.""" diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index f9f80e6cd0..e89348bb9b 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -930,7 +930,10 @@ async def _maybe_add_grounding_metadata( tools = await agent.canonical_tools(readonly_context) invocation_context.canonical_tools_cache = tools - if not any(tool.name == 'google_search_agent' for tool in tools): + if not any( + tool.name in {'google_search_agent', 'enterprise_search_agent'} + for tool in tools + ): return response ground_metadata = invocation_context.session.state.get( 'temp:_adk_grounding_metadata', None diff --git a/src/google/adk/tools/_search_agent_tool.py b/src/google/adk/tools/_search_agent_tool.py new file mode 100644 index 0000000000..f3015b5232 --- /dev/null +++ b/src/google/adk/tools/_search_agent_tool.py @@ -0,0 +1,112 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any + +from google.genai import types +from typing_extensions import override + +from ..agents.llm_agent import LlmAgent +from ..memory.in_memory_memory_service import InMemoryMemoryService +from ..runners import Runner +from ..sessions.in_memory_session_service import InMemorySessionService +from ..utils.context_utils import Aclosing +from ._forwarding_artifact_service import ForwardingArtifactService +from .agent_tool import AgentTool +from .tool_context import ToolContext + + +class _SearchAgentTool(AgentTool): + """A base class for search agent tools.""" + + @override + async def run_async( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> Any: + + if isinstance(self.agent, LlmAgent) and self.agent.input_schema: + input_value = self.agent.input_schema.model_validate(args) + content = types.Content( + role='user', + parts=[ + types.Part.from_text( + text=input_value.model_dump_json(exclude_none=True) + ) + ], + ) + else: + content = types.Content( + role='user', + parts=[types.Part.from_text(text=args['request'])], + ) + runner = Runner( + app_name=self.agent.name, + agent=self.agent, + artifact_service=ForwardingArtifactService(tool_context), + session_service=InMemorySessionService(), + memory_service=InMemoryMemoryService(), + credential_service=tool_context._invocation_context.credential_service, + plugins=list(tool_context._invocation_context.plugin_manager.plugins), + ) + try: + state_dict = { + k: v + for k, v in tool_context.state.to_dict().items() + if not k.startswith('_adk') and not k.startswith('temp:') + } + session = await runner.session_service.create_session( + app_name=self.agent.name, + user_id=tool_context._invocation_context.user_id, + state=state_dict, + ) + + last_content = None + last_grounding_metadata = None + async with Aclosing( + runner.run_async( + user_id=session.user_id, + session_id=session.id, + new_message=content, + ) + ) as agen: + async for event in agen: + # Forward state delta to parent session. + if event.actions.state_delta: + tool_context.state.update(event.actions.state_delta) + if event.content: + last_content = event.content + last_grounding_metadata = event.grounding_metadata + + if not last_content: + return '' + merged_text = '\n'.join(p.text for p in last_content.parts if p.text) + if isinstance(self.agent, LlmAgent) and self.agent.output_schema: + tool_result = self.agent.output_schema.model_validate_json( + merged_text + ).model_dump(exclude_none=True) + else: + tool_result = merged_text + + if last_grounding_metadata: + tool_context.state['temp:_adk_grounding_metadata'] = ( + last_grounding_metadata + ) + return tool_result + finally: + await runner.close() diff --git a/src/google/adk/tools/enterprise_search_agent_tool.py b/src/google/adk/tools/enterprise_search_agent_tool.py new file mode 100644 index 0000000000..1ba40eb424 --- /dev/null +++ b/src/google/adk/tools/enterprise_search_agent_tool.py @@ -0,0 +1,55 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Union + +from ..agents.llm_agent import LlmAgent +from ..models.base_llm import BaseLlm +from ._search_agent_tool import _SearchAgentTool +from .enterprise_search_tool import enterprise_web_search_tool + + +def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: + """Create a sub-agent that only uses enterprise_web_search tool.""" + return LlmAgent( + name='enterprise_search_agent', + model=model, + description=( + 'An agent for performing Enterprise search using the' + ' `enterprise_web_search` tool' + ), + instruction=""" + You are a specialized Enterprise search agent. + + When given a search query, use the `enterprise_web_search` tool to find the related information. + """, + tools=[enterprise_web_search_tool], + ) + + +class EnterpriseSearchAgentTool(_SearchAgentTool): + """A tool that wraps a sub-agent that only uses enterprise_web_search tool. + + This is a workaround to support using enterprise_web_search tool with other tools. + TODO(b/448114567): Remove once the workaround is no longer needed. + + Attributes: + agent: The sub-agent that this tool wraps. + """ + + def __init__(self, agent: LlmAgent): + self.agent = agent + super().__init__(agent=self.agent) diff --git a/src/google/adk/tools/enterprise_search_tool.py b/src/google/adk/tools/enterprise_search_tool.py index 4f7a0d7f35..c084539090 100644 --- a/src/google/adk/tools/enterprise_search_tool.py +++ b/src/google/adk/tools/enterprise_search_tool.py @@ -40,12 +40,18 @@ class EnterpriseWebSearchTool(BaseTool): """ - def __init__(self): - """Initializes the Enterprise Web Search tool.""" + def __init__(self, *, bypass_multi_tools_limit: bool = False): + """Initializes the Enterprise web search tool. + + Args: + bypass_multi_tools_limit: Whether to bypass the multi tools limitation, + so that the tool can be used with other tools in the same agent. + """ # Name and description are not used because this is a model built-in tool. super().__init__( name='enterprise_web_search', description='enterprise_web_search' ) + self.bypass_multi_tools_limit = bypass_multi_tools_limit @override async def process_llm_request( diff --git a/src/google/adk/tools/google_search_agent_tool.py b/src/google/adk/tools/google_search_agent_tool.py index 56da204e5f..33b2edf23e 100644 --- a/src/google/adk/tools/google_search_agent_tool.py +++ b/src/google/adk/tools/google_search_agent_tool.py @@ -25,7 +25,7 @@ from ..models.base_llm import BaseLlm from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService -from .agent_tool import AgentTool +from ._search_agent_tool import _SearchAgentTool from .google_search_tool import google_search from .tool_context import ToolContext @@ -47,7 +47,7 @@ def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent: ) -class GoogleSearchAgentTool(AgentTool): +class GoogleSearchAgentTool(_SearchAgentTool): """A tool that wraps a sub-agent that only uses google_search tool. This is a workaround to support using google_search tool with other tools. diff --git a/tests/unittests/tools/test_enterprise_search_agent_tool.py b/tests/unittests/tools/test_enterprise_search_agent_tool.py new file mode 100644 index 0000000000..4aec16ec7e --- /dev/null +++ b/tests/unittests/tools/test_enterprise_search_agent_tool.py @@ -0,0 +1,140 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import Agent +from google.adk.models.llm_response import LlmResponse +from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +from google.genai.types import Part +from pytest import mark + +from .. import testing_utils + +function_call_no_schema = Part.from_function_call( + name='tool_agent', args={'request': 'test1'} +) + + +grounding_metadata = types.GroundingMetadata(web_search_queries=['test query']) + + +# TODO(b/448114567): Remove test_grounding_metadata_ tests once the workaround +# is no longer needed. + + +@mark.asyncio +async def test_grounding_metadata_is_stored_in_state_during_invocation(): + """Verify grounding_metadata is stored in the state during invocation.""" + + # Mock model for the tool_agent that returns grounding_metadata + tool_agent_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content( + parts=[Part.from_text(text='response from tool')] + ), + grounding_metadata=grounding_metadata, + ) + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=tool_agent_model, + ) + + agent_tool = EnterpriseSearchAgentTool(agent=tool_agent) + + session_service = InMemorySessionService() + session = await session_service.create_session( + app_name='test_app', user_id='test_user' + ) + + invocation_context = InvocationContext( + invocation_id='invocation_id', + agent=tool_agent, + session=session, + session_service=session_service, + ) + tool_context = ToolContext(invocation_context=invocation_context) + tool_result = await agent_tool.run_async( + args=function_call_no_schema.function_call.args, tool_context=tool_context + ) + + # Verify the tool result + assert tool_result == 'response from tool' + + # Verify grounding_metadata is stored in the state + assert tool_context.state['temp:_adk_grounding_metadata'] == ( + grounding_metadata + ) + + +@mark.asyncio +async def test_grounding_metadata_is_not_stored_in_state_after_invocation(): + """Verify grounding_metadata is not stored in the state after invocation.""" + + # Mock model for the tool_agent that returns grounding_metadata + tool_agent_model = testing_utils.MockModel.create( + responses=[ + LlmResponse( + content=types.Content( + parts=[Part.from_text(text='response from tool')] + ), + grounding_metadata=grounding_metadata, + ) + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=tool_agent_model, + ) + + # Mock model for the root_agent + root_agent_model = testing_utils.MockModel.create( + responses=[ + function_call_no_schema, # Call the tool_agent + 'Final response from root', + ] + ) + + root_agent = Agent( + name='root_agent', + model=root_agent_model, + tools=[EnterpriseSearchAgentTool(agent=tool_agent)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + events = runner.run('test input') + + # Find the function response event + function_response_event = None + for event in events: + if event.get_function_responses(): + function_response_event = event + break + + # Verify the function response + assert function_response_event is not None + function_responses = function_response_event.get_function_responses() + assert len(function_responses) == 1 + tool_output = function_responses[0].response + assert tool_output == {'result': 'response from tool'} + + # Verify grounding_metadata is not stored in the root_agent's state + assert 'temp:_adk_grounding_metadata' not in runner.session.state