Skip to content

Commit 226e873

Browse files
wukathcopybara-github
authored andcommitted
fix: Ensure consistent ADC quota project override in ADK
Fix discovery engine search tool, bigquery agent analytics plugin, and application integration tool to correctly handle the ADC quota project override -- the x-goog-user-project should be set based on the ADC quota project, per gcloud auth team's requirements. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 853841124
1 parent 8afb99a commit 226e873

6 files changed

Lines changed: 108 additions & 22 deletions

File tree

src/google/adk/plugins/bigquery_agent_analytics_plugin.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import uuid
3737
import weakref
3838

39+
from google.api_core import client_options
3940
from google.api_core.exceptions import InternalServerError
4041
from google.api_core.exceptions import ServiceUnavailable
4142
from google.api_core.exceptions import TooManyRequests
@@ -1352,19 +1353,31 @@ async def _lazy_setup(self, **kwargs) -> None:
13521353
if _GLOBAL_WRITE_CLIENT is None:
13531354

13541355
def get_credentials():
1355-
creds, _ = google.auth.default(
1356+
creds, project_id = google.auth.default(
13561357
scopes=["https://www.googleapis.com/auth/cloud-platform"]
13571358
)
1358-
return creds
1359+
return creds, project_id
13591360

1360-
creds = await loop.run_in_executor(self._executor, get_credentials)
1361+
creds, project_id = await loop.run_in_executor(
1362+
self._executor, get_credentials
1363+
)
1364+
quota_project_id = (
1365+
getattr(creds, "quota_project_id", None) or project_id
1366+
)
1367+
options = (
1368+
client_options.ClientOptions(quota_project_id=quota_project_id)
1369+
if quota_project_id
1370+
else None
1371+
)
13611372
client_info = gapic_client_info.ClientInfo(
13621373
user_agent=f"google-adk-bq-logger/{__version__}"
13631374
)
13641375
# Initialize the async client in the current event loop, not in the
13651376
# executor.
13661377
_GLOBAL_WRITE_CLIENT = BigQueryWriteAsyncClient(
1367-
credentials=creds, client_info=client_info
1378+
credentials=creds,
1379+
client_info=client_info,
1380+
client_options=options,
13681381
)
13691382
self.write_client = _GLOBAL_WRITE_CLIENT
13701383

src/google/adk/tools/application_integration_tool/clients/integration_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
self.actions = actions if actions is not None else []
7676
self.service_account_json = service_account_json
7777
self.credential_cache = None
78+
self._quota_project_id = None
7879

7980
def get_openapi_spec_for_integration(self):
8081
"""Gets the OpenAPI spec for the integration.
@@ -92,6 +93,8 @@ def get_openapi_spec_for_integration(self):
9293
"Content-Type": "application/json",
9394
"Authorization": f"Bearer {self._get_access_token()}",
9495
}
96+
if not self.service_account_json:
97+
headers["x-goog-user-project"] = self._quota_project_id or self.project
9598
data = {
9699
"apiTriggerResources": [
97100
{
@@ -247,11 +250,14 @@ def _get_access_token(self) -> str:
247250
)
248251
else:
249252
try:
250-
credentials, _ = default_service_credential(
253+
credentials, project_id = default_service_credential(
251254
scopes=["https://www.googleapis.com/auth/cloud-platform"]
252255
)
253256
except:
254257
credentials = None
258+
if credentials:
259+
quota_project_id = getattr(credentials, "quota_project_id", None)
260+
self._quota_project_id = quota_project_id or project_id
255261

256262
if not credentials:
257263
raise ValueError(

src/google/adk/tools/discovery_engine_search_tool.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any
1818
from typing import Optional
1919

20+
from google.api_core import client_options
2021
from google.api_core.exceptions import GoogleAPICallError
2122
import google.auth
2223
from google.cloud import discoveryengine_v1beta as discoveryengine
@@ -72,8 +73,14 @@ def __init__(
7273
self._max_results = max_results
7374

7475
credentials, _ = google.auth.default()
76+
quota_project_id = getattr(credentials, "quota_project_id", None)
77+
options = (
78+
client_options.ClientOptions(quota_project_id=quota_project_id)
79+
if quota_project_id
80+
else None
81+
)
7582
self._discovery_engine_client = discoveryengine.SearchServiceClient(
76-
credentials=credentials
83+
credentials=credentials, client_options=options
7784
)
7885

7986
def discovery_engine_search(

tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,41 @@ async def test_global_client_reuse(
15681568
await plugin2.shutdown()
15691569
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
15701570

1571+
@pytest.mark.asyncio
1572+
async def test_quota_project_id_used_in_client(
1573+
self,
1574+
mock_bq_client,
1575+
mock_to_arrow_schema,
1576+
mock_asyncio_to_thread,
1577+
):
1578+
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
1579+
mock_creds = mock.create_autospec(
1580+
google.auth.credentials.Credentials, instance=True, spec_set=True
1581+
)
1582+
mock_creds.quota_project_id = "quota-project"
1583+
with mock.patch.object(
1584+
google.auth,
1585+
"default",
1586+
autospec=True,
1587+
return_value=(mock_creds, PROJECT_ID),
1588+
) as mock_auth_default:
1589+
with mock.patch.object(
1590+
bigquery_agent_analytics_plugin,
1591+
"BigQueryWriteAsyncClient",
1592+
autospec=True,
1593+
) as mock_bq_write_cls:
1594+
plugin = bigquery_agent_analytics_plugin.BigQueryAgentAnalyticsPlugin(
1595+
project_id=PROJECT_ID,
1596+
dataset_id=DATASET_ID,
1597+
table_id=TABLE_ID,
1598+
)
1599+
await plugin._ensure_started()
1600+
mock_auth_default.assert_called_once()
1601+
mock_bq_write_cls.assert_called_once()
1602+
_, kwargs = mock_bq_write_cls.call_args
1603+
assert kwargs["client_options"].quota_project_id == "quota-project"
1604+
bigquery_agent_analytics_plugin._GLOBAL_WRITE_CLIENT = None
1605+
15711606
@pytest.mark.asyncio
15721607
async def test_pickle_safety(self, mock_auth_default, mock_bq_client):
15731608
"""Test that the plugin can be pickled safely."""

tests/unittests/tools/application_integration_tool/clients/test_integration_client.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import re
1717
from unittest import mock
1818

19+
from google.adk.tools.application_integration_tool.clients import integration_client
1920
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
2021
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
2122
import google.auth
@@ -110,18 +111,21 @@ def test_get_openapi_spec_for_integration_success(
110111
mock_credentials,
111112
mock_connections_client,
112113
):
114+
mock_credentials.quota_project_id = "quota-project"
115+
mock_credentials.expired = False
113116
expected_spec = {"openapi": "3.0.0", "info": {"title": "Test Integration"}}
114117
mock_response = mock.MagicMock()
115118
mock_response.status_code = 200
116119
mock_response.json.return_value = {"openApiSpec": json.dumps(expected_spec)}
117120

118121
with (
119122
mock.patch.object(
120-
IntegrationClient,
121-
"_get_access_token",
122-
return_value=mock_credentials.token,
123+
integration_client,
124+
"default_service_credential",
125+
return_value=(mock_credentials, project),
123126
),
124-
mock.patch("requests.post", return_value=mock_response),
127+
mock.patch.object(mock_credentials, "refresh", return_value=None),
128+
mock.patch.object(requests, "post", return_value=mock_response),
125129
):
126130
client = IntegrationClient(
127131
project=project,
@@ -140,6 +144,7 @@ def test_get_openapi_spec_for_integration_success(
140144
headers={
141145
"Content-Type": "application/json",
142146
"Authorization": f"Bearer {mock_credentials.token}",
147+
"x-goog-user-project": "quota-project",
143148
},
144149
json={
145150
"apiTriggerResources": [{

tests/unittests/tools/test_discovery_engine_search_tool.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414

1515
from unittest import mock
1616

17+
from google.adk.tools import discovery_engine_search_tool
1718
from google.adk.tools.discovery_engine_search_tool import DiscoveryEngineSearchTool
1819
from google.api_core import exceptions
1920
from google.cloud import discoveryengine_v1beta as discoveryengine
2021
import pytest
2122

23+
from google import auth
24+
2225

2326
@mock.patch(
2427
"google.auth.default",
@@ -76,10 +79,14 @@ def test_init_with_data_store_specs_without_search_engine_id_raises_error(
7679
data_store_id="test_data_store", data_store_specs=[{"id": "123"}]
7780
)
7881

79-
@mock.patch(
80-
"google.cloud.discoveryengine_v1beta.SearchServiceClient",
82+
@mock.patch.object(discovery_engine_search_tool, "client_options")
83+
@mock.patch.object(
84+
discoveryengine,
85+
"SearchServiceClient",
8186
)
82-
def test_discovery_engine_search_success(self, mock_search_client):
87+
def test_discovery_engine_search_success(
88+
self, mock_search_client, mock_client_options
89+
):
8390
"""Test successful discovery engine search."""
8491
mock_response = discoveryengine.SearchResponse()
8592
mock_response.results = [
@@ -98,15 +105,28 @@ def test_discovery_engine_search_success(self, mock_search_client):
98105
)
99106
]
100107
mock_search_client.return_value.search.return_value = mock_response
101-
102-
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
103-
result = tool.discovery_engine_search("test query")
104-
105-
assert result["status"] == "success"
106-
assert len(result["results"]) == 1
107-
assert result["results"][0]["title"] == "Test Title"
108-
assert result["results"][0]["url"] == "http://example.com"
109-
assert result["results"][0]["content"] == "Test Content"
108+
mock_credentials = mock.MagicMock()
109+
mock_credentials.quota_project_id = "test-quota-project"
110+
111+
with mock.patch.object(
112+
auth, "default", return_value=(mock_credentials, "project")
113+
) as mock_auth:
114+
tool = DiscoveryEngineSearchTool(data_store_id="test_data_store")
115+
result = tool.discovery_engine_search("test query")
116+
117+
assert result["status"] == "success"
118+
assert len(result["results"]) == 1
119+
assert result["results"][0]["title"] == "Test Title"
120+
assert result["results"][0]["url"] == "http://example.com"
121+
assert result["results"][0]["content"] == "Test Content"
122+
mock_auth.assert_called_once()
123+
mock_client_options.ClientOptions.assert_called_once_with(
124+
quota_project_id="test-quota-project"
125+
)
126+
mock_search_client.assert_called_once_with(
127+
credentials=mock_credentials,
128+
client_options=mock_client_options.ClientOptions.return_value,
129+
)
110130

111131
@mock.patch(
112132
"google.cloud.discoveryengine_v1beta.SearchServiceClient",

0 commit comments

Comments
 (0)