Skip to content

Commit 4485379

Browse files
xuanyang15copybara-github
authored andcommitted
ADK changes
PiperOrigin-RevId: 816288113
1 parent 0989d64 commit 4485379

6 files changed

Lines changed: 322 additions & 4 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ dependencies = [
3434
"google-api-python-client>=2.157.0, <3.0.0", # Google API client discovery
3535
"google-cloud-aiplatform[agent_engines]>=1.112.0, <2.0.0",# For VertexAI integrations, e.g. example store.
3636
"google-cloud-bigtable>=2.32.0", # For Bigtable database
37+
"google-cloud-discoveryengine>=0.13.12, <0.14.0", # For Discovery Engine Search Tool
3738
"google-cloud-secret-manager>=2.22.0, <3.0.0", # Fetching secrets in RestAPI Tool
3839
"google-cloud-spanner>=3.56.0, <4.0.0", # For Spanner database
3940
"google-cloud-speech>=2.30.0, <3.0.0", # For Audio Transcription

src/google/adk/agents/llm_agent.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import AsyncGenerator
2222
from typing import Awaitable
2323
from typing import Callable
24+
from typing import cast
2425
from typing import ClassVar
2526
from typing import Dict
2627
from typing import Literal
@@ -118,6 +119,7 @@ async def _convert_tool_union_to_tools(
118119
multiple_tools: bool = False,
119120
) -> list[BaseTool]:
120121
from ..tools.google_search_tool import google_search
122+
from ..tools.vertex_ai_search_tool import VertexAiSearchTool
121123

122124
# Wrap google_search tool with AgentTool if there are multiple tools because
123125
# the built-in tools cannot be used together with other tools.
@@ -128,6 +130,24 @@ async def _convert_tool_union_to_tools(
128130

129131
return [GoogleSearchAgentTool(create_google_search_agent(model))]
130132

133+
# Replace VertexAiSearchTool with DiscoveryEngineSearchTool if there are
134+
# multiple tools because the built-in tools cannot be used together with
135+
# other tools.
136+
# TODO(b/448114567): Remove once the workaround is no longer needed.
137+
if multiple_tools and isinstance(tool_union, VertexAiSearchTool):
138+
from ..tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
139+
140+
vais_tool = cast(VertexAiSearchTool, tool_union)
141+
return [
142+
DiscoveryEngineSearchTool(
143+
data_store_id=vais_tool.data_store_id,
144+
data_store_specs=vais_tool.data_store_specs,
145+
search_engine_id=vais_tool.search_engine_id,
146+
filter=vais_tool.filter,
147+
max_results=vais_tool.max_results,
148+
)
149+
]
150+
131151
if isinstance(tool_union, BaseTool):
132152
return [tool_union]
133153
if callable(tool_union):

src/google/adk/tools/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .agent_tool import AgentTool
1919
from .apihub_tool.apihub_toolset import APIHubToolset
2020
from .base_tool import BaseTool
21+
from .discovery_engine_search_tool import DiscoveryEngineSearchTool
2122
from .enterprise_search_tool import enterprise_web_search_tool as enterprise_web_search
2223
from .example_tool import ExampleTool
2324
from .exit_loop_tool import exit_loop
@@ -39,6 +40,7 @@
3940
'APIHubToolset',
4041
'AuthToolArguments',
4142
'BaseTool',
43+
'DiscoveryEngineSearchTool',
4244
'enterprise_web_search',
4345
'google_maps_grounding',
4446
'google_search',
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import Any
18+
from typing import Optional
19+
20+
from google.api_core.exceptions import GoogleAPICallError
21+
import google.auth
22+
from google.cloud import discoveryengine_v1beta as discoveryengine
23+
from google.genai import types
24+
25+
from .function_tool import FunctionTool
26+
27+
28+
class DiscoveryEngineSearchTool(FunctionTool):
29+
"""Tool for searching the discovery engine."""
30+
31+
def __init__(
32+
self,
33+
data_store_id: Optional[str] = None,
34+
data_store_specs: Optional[
35+
list[types.VertexAISearchDataStoreSpec]
36+
] = None,
37+
search_engine_id: Optional[str] = None,
38+
filter: Optional[str] = None,
39+
max_results: Optional[int] = None,
40+
):
41+
"""Initializes the DiscoveryEngineSearchTool.
42+
43+
Args:
44+
data_store_id: The Vertex AI search data store resource ID in the format
45+
of
46+
"projects/{project}/locations/{location}/collections/{collection}/dataStores/{dataStore}".
47+
data_store_specs: Specifications that define the specific DataStores to be
48+
searched. It should only be set if engine is used.
49+
search_engine_id: The Vertex AI search engine resource ID in the format of
50+
"projects/{project}/locations/{location}/collections/{collection}/engines/{engine}".
51+
filter: The filter to be applied to the search request. Default is None.
52+
max_results: The maximum number of results to return. Default is None.
53+
"""
54+
super().__init__(self.discovery_engine_search)
55+
if (data_store_id is None and search_engine_id is None) or (
56+
data_store_id is not None and search_engine_id is not None
57+
):
58+
raise ValueError(
59+
"Either data_store_id or search_engine_id must be specified."
60+
)
61+
if data_store_specs is not None and search_engine_id is None:
62+
raise ValueError(
63+
"search_engine_id must be specified if data_store_specs is specified."
64+
)
65+
66+
self._serving_config = (
67+
f"{data_store_id or search_engine_id}/servingConfigs/default_config"
68+
)
69+
self._data_store_specs = data_store_specs
70+
self._search_engine_id = search_engine_id
71+
self._filter = filter
72+
self._max_results = max_results
73+
74+
credentials, _ = google.auth.default()
75+
self._discovery_engine_client = discoveryengine.SearchServiceClient(
76+
credentials=credentials
77+
)
78+
79+
def discovery_engine_search(
80+
self,
81+
query: str,
82+
) -> dict[str, Any]:
83+
"""Search the discovery engine.
84+
85+
Args:
86+
query: The search query.
87+
88+
Returns:
89+
A dictionary containing the status of the request and the list of search
90+
results, which contains the title, url and content.
91+
"""
92+
request = discoveryengine.SearchRequest(
93+
serving_config=self._serving_config,
94+
query=query,
95+
content_search_spec=discoveryengine.SearchRequest.ContentSearchSpec(
96+
search_result_mode=discoveryengine.SearchRequest.ContentSearchSpec.SearchResultMode.CHUNKS,
97+
chunk_spec=discoveryengine.SearchRequest.ContentSearchSpec.ChunkSpec(
98+
num_previous_chunks=0,
99+
num_next_chunks=0,
100+
),
101+
),
102+
)
103+
104+
if self._data_store_specs:
105+
request.data_store_specs = self._data_store_specs
106+
if self._filter:
107+
request.filter = self._filter
108+
if self._max_results:
109+
request.page_size = self._max_results
110+
111+
results = []
112+
try:
113+
response = self._discovery_engine_client.search(request)
114+
for item in response.results:
115+
chunk = item.chunk
116+
if not chunk or not chunk.document_metadata:
117+
continue
118+
119+
results.append({
120+
"title": chunk.document_metadata.title,
121+
"url": chunk.document_metadata.uri,
122+
"content": chunk.content,
123+
})
124+
except GoogleAPICallError as e:
125+
return {"status": "error", "error_message": str(e)}
126+
return {"status": "success", "results": results}

tests/unittests/agents/test_llm_agent_fields.py

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,17 @@
1515
"""Unit tests for canonical_xxx fields in LlmAgent."""
1616

1717
from typing import Any
18-
from typing import cast
1918
from typing import Optional
2019

2120
from google.adk.agents.callback_context import CallbackContext
2221
from google.adk.agents.invocation_context import InvocationContext
2322
from google.adk.agents.llm_agent import LlmAgent
24-
from google.adk.agents.loop_agent import LoopAgent
2523
from google.adk.agents.readonly_context import ReadonlyContext
2624
from google.adk.models.llm_request import LlmRequest
2725
from google.adk.models.registry import LLMRegistry
2826
from google.adk.sessions.in_memory_session_service import InMemorySessionService
2927
from google.adk.tools.google_search_tool import google_search
28+
from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool
3029
from google.genai import types
3130
from pydantic import BaseModel
3231
import pytest
@@ -306,6 +305,7 @@ async def test_handle_google_search_with_other_tools(self):
306305

307306
assert len(tools) == 2
308307
assert tools[0].name == '_my_tool'
308+
assert tools[0].__class__.__name__ == 'FunctionTool'
309309
assert tools[1].name == 'google_search_agent'
310310
assert tools[1].__class__.__name__ == 'GoogleSearchAgentTool'
311311

@@ -325,8 +325,8 @@ async def test_handle_google_search_only(self):
325325
assert tools[0].name == 'google_search'
326326
assert tools[0].__class__.__name__ == 'GoogleSearchTool'
327327

328-
async def test_no_google_search(self):
329-
"""Test other tools are not affected."""
328+
async def test_function_tool_only(self):
329+
"""Test that function tool is not affected."""
330330
agent = LlmAgent(
331331
name='test_agent',
332332
model='gemini-pro',
@@ -340,3 +340,38 @@ async def test_no_google_search(self):
340340
assert len(tools) == 1
341341
assert tools[0].name == '_my_tool'
342342
assert tools[0].__class__.__name__ == 'FunctionTool'
343+
344+
async def test_handle_google_vais_with_other_tools(self):
345+
"""Test that VertexAiSearchTool is wrapped into an agent."""
346+
agent = LlmAgent(
347+
name='test_agent',
348+
model='gemini-pro',
349+
tools=[
350+
self._my_tool,
351+
VertexAiSearchTool(data_store_id='test_data_store_id'),
352+
],
353+
)
354+
ctx = await _create_readonly_context(agent)
355+
tools = await agent.canonical_tools(ctx)
356+
357+
assert len(tools) == 2
358+
assert tools[0].name == '_my_tool'
359+
assert tools[0].__class__.__name__ == 'FunctionTool'
360+
assert tools[1].name == 'discovery_engine_search'
361+
assert tools[1].__class__.__name__ == 'DiscoveryEngineSearchTool'
362+
363+
async def test_handle_vais_only(self):
364+
"""Test that VertexAiSearchTool is not wrapped into an agent."""
365+
agent = LlmAgent(
366+
name='test_agent',
367+
model='gemini-pro',
368+
tools=[
369+
VertexAiSearchTool(data_store_id='test_data_store_id'),
370+
],
371+
)
372+
ctx = await _create_readonly_context(agent)
373+
tools = await agent.canonical_tools(ctx)
374+
375+
assert len(tools) == 1
376+
assert tools[0].name == 'vertex_ai_search'
377+
assert tools[0].__class__.__name__ == 'VertexAiSearchTool'

0 commit comments

Comments
 (0)