Skip to content
94 changes: 66 additions & 28 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,38 +138,39 @@ async def _convert_tool_union_to_tools(
model: Union[str, BaseLlm],
multiple_tools: bool = False,
) -> list[BaseTool]:
from ..tools.enterprise_search_tool import EnterpriseWebSearchTool
from ..tools.google_search_tool import GoogleSearchTool
from ..tools.vertex_ai_search_tool import VertexAiSearchTool

# Wrap google_search tool with AgentTool if there are multiple tools because
# the built-in tools cannot be used together with other tools.
# Handle built-in tool workarounds when multiple tools are present.
# Built-in tools cannot be used together with other tools, so we wrap or
# replace them with compatible alternatives.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if multiple_tools and isinstance(tool_union, GoogleSearchTool):
from ..tools.google_search_agent_tool import create_google_search_agent
from ..tools.google_search_agent_tool import GoogleSearchAgentTool

search_tool = cast(GoogleSearchTool, tool_union)
if search_tool.bypass_multi_tools_limit:
return [GoogleSearchAgentTool(create_google_search_agent(model))]

# Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are
# multiple tools because the built-in tools cannot be used together with
# other tools.
# TODO(b/448114567): Remove once the workaround is no longer needed.
if multiple_tools and isinstance(tool_union, VertexAiSearchTool):
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool

vais_tool = cast(VertexAiSearchTool, tool_union)
if vais_tool.bypass_multi_tools_limit:
return [
DiscoveryEngineSearchTool(
data_store_id=vais_tool.data_store_id,
data_store_specs=vais_tool.data_store_specs,
search_engine_id=vais_tool.search_engine_id,
filter=vais_tool.filter,
max_results=vais_tool.max_results,
)
]
if multiple_tools:
tool_workarounds = [
# GoogleSearchTool: wrap with AgentTool
{
'tool_class': GoogleSearchTool,
'handler': lambda: _handle_google_search_tool(tool_union, model),
},
# VertexAiSearchTool: replace with DiscoveryEngineSearchTool
{
'tool_class': VertexAiSearchTool,
'handler': lambda: _handle_vertex_ai_search_tool(tool_union),
},
# EnterpriseWebSearchTool: wrap with AgentTool
{
'tool_class': EnterpriseWebSearchTool,
'handler': lambda: _handle_enterprise_search_tool(
tool_union, model
),
},
]

for workaround in tool_workarounds:
if isinstance(tool_union, workaround['tool_class']):
if tool_union.bypass_multi_tools_limit:
return workaround['handler']()

if isinstance(tool_union, BaseTool):
return [tool_union]
Expand All @@ -180,6 +181,43 @@ async def _convert_tool_union_to_tools(
return await tool_union.get_tools_with_prefix(ctx)


def _handle_google_search_tool(
tool_union: ToolUnion, model: Union[str, BaseLlm]
) -> list[BaseTool]:
"""Handle GoogleSearchTool workaround by wrapping with AgentTool."""
from ..tools.google_search_agent_tool import create_google_search_agent
from ..tools.google_search_agent_tool import GoogleSearchAgentTool

return [GoogleSearchAgentTool(create_google_search_agent(model))]


def _handle_vertex_ai_search_tool(tool_union: ToolUnion) -> list[BaseTool]:
"""Handle VertexAiSearchTool workaround by replacing with DiscoveryEngineSearchTool."""
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
from ..tools.vertex_ai_search_tool import VertexAiSearchTool

vais_tool = cast(VertexAiSearchTool, tool_union)
return [
DiscoveryEngineSearchTool(
data_store_id=vais_tool.data_store_id,
data_store_specs=vais_tool.data_store_specs,
search_engine_id=vais_tool.search_engine_id,
filter=vais_tool.filter,
max_results=vais_tool.max_results,
)
]


def _handle_enterprise_search_tool(
tool_union: ToolUnion, model: Union[str, BaseLlm]
) -> list[BaseTool]:
"""Handle EnterpriseWebSearchTool workaround by wrapping with AgentTool."""
from ..tools.enterprise_search_agent_tool import create_enterprise_search_agent
from ..tools.enterprise_search_agent_tool import EnterpriseSearchAgentTool

return [EnterpriseSearchAgentTool(create_enterprise_search_agent(model))]


class LlmAgent(BaseAgent):
"""LLM-based Agent."""

Expand Down
5 changes: 4 additions & 1 deletion src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,7 +930,10 @@ async def _maybe_add_grounding_metadata(
tools = await agent.canonical_tools(readonly_context)
invocation_context.canonical_tools_cache = tools

if not any(tool.name == 'google_search_agent' for tool in tools):
if not any(
tool.name in {'google_search_agent', 'enterprise_search_agent'}
for tool in tools
):
return response
ground_metadata = invocation_context.session.state.get(
'temp:_adk_grounding_metadata', None
Expand Down
112 changes: 112 additions & 0 deletions src/google/adk/tools/_search_agent_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Any

from google.genai import types
from typing_extensions import override

from ..agents.llm_agent import LlmAgent
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from .agent_tool import AgentTool
from .tool_context import ToolContext


class _SearchAgentTool(AgentTool):
"""A base class for search agent tools."""

@override
async def run_async(
self,
*,
args: dict[str, Any],
tool_context: ToolContext,
) -> Any:

if isinstance(self.agent, LlmAgent) and self.agent.input_schema:
input_value = self.agent.input_schema.model_validate(args)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
)
else:
content = types.Content(
role='user',
parts=[types.Part.from_text(text=args['request'])],
)
runner = Runner(
app_name=self.agent.name,
agent=self.agent,
artifact_service=ForwardingArtifactService(tool_context),
session_service=InMemorySessionService(),
memory_service=InMemoryMemoryService(),
credential_service=tool_context._invocation_context.credential_service,
plugins=list(tool_context._invocation_context.plugin_manager.plugins),
)
try:
state_dict = {
k: v
for k, v in tool_context.state.to_dict().items()
if not k.startswith('_adk') and not k.startswith('temp:')
}
session = await runner.session_service.create_session(
app_name=self.agent.name,
user_id=tool_context._invocation_context.user_id,
state=state_dict,
)

last_content = None
last_grounding_metadata = None
async with Aclosing(
runner.run_async(
user_id=session.user_id,
session_id=session.id,
new_message=content,
)
) as agen:
async for event in agen:
# Forward state delta to parent session.
if event.actions.state_delta:
tool_context.state.update(event.actions.state_delta)
if event.content:
last_content = event.content
last_grounding_metadata = event.grounding_metadata

if not last_content:
return ''
merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
tool_result = self.agent.output_schema.model_validate_json(
merged_text
).model_dump(exclude_none=True)
else:
tool_result = merged_text

if last_grounding_metadata:
tool_context.state['temp:_adk_grounding_metadata'] = (
last_grounding_metadata
)
return tool_result
finally:
await runner.close()
55 changes: 55 additions & 0 deletions src/google/adk/tools/enterprise_search_agent_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Union

from ..agents.llm_agent import LlmAgent
from ..models.base_llm import BaseLlm
from ._search_agent_tool import _SearchAgentTool
from .enterprise_search_tool import enterprise_web_search_tool


def create_enterprise_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
"""Create a sub-agent that only uses enterprise_web_search tool."""
return LlmAgent(
name='enterprise_search_agent',
model=model,
description=(
'An agent for performing Enterprise search using the'
' `enterprise_web_search` tool'
),
instruction="""
You are a specialized Enterprise search agent.

When given a search query, use the `enterprise_web_search` tool to find the related information.
""",
tools=[enterprise_web_search_tool],
)


class EnterpriseSearchAgentTool(_SearchAgentTool):
"""A tool that wraps a sub-agent that only uses enterprise_web_search tool.

This is a workaround to support using enterprise_web_search tool with other tools.
TODO(b/448114567): Remove once the workaround is no longer needed.

Attributes:
agent: The sub-agent that this tool wraps.
"""

def __init__(self, agent: LlmAgent):
self.agent = agent
super().__init__(agent=self.agent)
10 changes: 8 additions & 2 deletions src/google/adk/tools/enterprise_search_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,18 @@ class EnterpriseWebSearchTool(BaseTool):

"""

def __init__(self):
"""Initializes the Enterprise Web Search tool."""
def __init__(self, *, bypass_multi_tools_limit: bool = False):
"""Initializes the Enterprise web search tool.

Args:
bypass_multi_tools_limit: Whether to bypass the multi tools limitation,
so that the tool can be used with other tools in the same agent.
"""
# Name and description are not used because this is a model built-in tool.
super().__init__(
name='enterprise_web_search', description='enterprise_web_search'
)
self.bypass_multi_tools_limit = bypass_multi_tools_limit

@override
async def process_llm_request(
Expand Down
4 changes: 2 additions & 2 deletions src/google/adk/tools/google_search_agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..models.base_llm import BaseLlm
from ..utils.context_utils import Aclosing
from ._forwarding_artifact_service import ForwardingArtifactService
from .agent_tool import AgentTool
from ._search_agent_tool import _SearchAgentTool
from .google_search_tool import google_search
from .tool_context import ToolContext

Expand All @@ -47,7 +47,7 @@ def create_google_search_agent(model: Union[str, BaseLlm]) -> LlmAgent:
)


class GoogleSearchAgentTool(AgentTool):
class GoogleSearchAgentTool(_SearchAgentTool):
"""A tool that wraps a sub-agent that only uses google_search tool.

This is a workaround to support using google_search tool with other tools.
Expand Down
Loading
Loading