From 734a3a817ab5c883c7022583a189900b7f734c1e Mon Sep 17 00:00:00 2001 From: Phillip Hwang Date: Sat, 7 Mar 2026 17:31:06 +0000 Subject: [PATCH 1/2] fix: forward vertexai.init() location/project to genai.Client MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit vertexai.init(location=...) writes to google.cloud.aiplatform.initializer but genai.Client() never reads that state — it falls back to the GOOGLE_CLOUD_LOCATION env var (defaulting to 'global'). This means region-specific models like gemini-live-2.5-flash-native-audio, which are only available in us-central1/us-east1/europe-west4, would fail with WebSocket 1008 even when the user had called vertexai.init(location='us-central1'). Fix: read initializer.global_config.location and .project in both api_client and _live_api_client and pass them explicitly to Client(), bridging the gap between the two configuration systems. --- src/google/adk/models/google_llm.py | 13 +++++++--- tests/unittests/models/test_google_llm.py | 29 +++++++++++++++++++---- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index 23c9c27810..b985caf575 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -303,13 +303,16 @@ def api_client(self) -> Client: The api client. """ from google.genai import Client + from google.cloud.aiplatform import initializer return Client( + location=initializer.global_config.location, + project=initializer.global_config.project, http_options=types.HttpOptions( headers=self._tracking_headers(), retry_options=self.retry_options, base_url=self.base_url, - ) + ), ) @cached_property @@ -335,11 +338,15 @@ def _live_api_version(self) -> str: @cached_property def _live_api_client(self) -> Client: from google.genai import Client + from google.cloud.aiplatform import initializer return Client( + location=initializer.global_config.location, + project=initializer.global_config.project, http_options=types.HttpOptions( - headers=self._tracking_headers(), api_version=self._live_api_version - ) + headers=self._tracking_headers(), + api_version=self._live_api_version, + ), ) @contextlib.asynccontextmanager diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index 70aa01b69d..ce44416b7c 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -623,15 +623,15 @@ def test_live_api_version_gemini_api(gemini_llm): def test_live_api_client_properties(gemini_llm): - """Test that _live_api_client is properly configured with tracking headers and API version.""" + """Test that _live_api_client uses v1alpha for Gemini API backend.""" with mock.patch.object( - gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI + gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API ): client = gemini_llm._live_api_client - # Verify that the client has the correct headers and API version + # Verify that the client has v1alpha for Gemini API http_options = client._api_client._http_options - assert http_options.api_version == "v1beta1" + assert http_options.api_version == "v1alpha" # Check that tracking headers are included tracking_headers = get_tracking_headers() @@ -640,6 +640,27 @@ def test_live_api_client_properties(gemini_llm): assert value in http_options.headers[key] +def test_live_api_client_uses_initializer_location(monkeypatch): + """Test that _live_api_client uses location/project from vertexai.init(). + + vertexai.init(location=...) writes to google.cloud.aiplatform.initializer. + Previously genai.Client() ignored that state and fell back to the + GOOGLE_CLOUD_LOCATION env var (defaulting to 'global'), causing native audio + models (gemini-live-2.5-flash-native-audio) to fail with WebSocket 1008 + when the user had called vertexai.init(location='us-central1'). + """ + monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "1") + mock_config = mock.MagicMock() + mock_config.location = "us-central1" + mock_config.project = "my-project" + gemini = Gemini(model="gemini-live-2.5-flash-native-audio") + with mock.patch( + "google.cloud.aiplatform.initializer.global_config", mock_config + ): + client = gemini._live_api_client + assert client._api_client._location == "us-central1" + + @pytest.mark.asyncio async def test_connect_with_custom_headers(gemini_llm, llm_request): """Test that connect method updates tracking headers and API version when custom headers are provided.""" From 8a2260572a1f03b3a809a7088acd45de05e4d0de Mon Sep 17 00:00:00 2001 From: Phillip Hwang Date: Fri, 13 Mar 2026 18:25:35 +0000 Subject: [PATCH 2/2] fixed the broken unit tests and verified working again --- src/google/adk/models/google_llm.py | 26 +++++++++++------------ tests/unittests/models/test_google_llm.py | 10 ++++++--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py index b985caf575..ee2fe65ec3 100644 --- a/src/google/adk/models/google_llm.py +++ b/src/google/adk/models/google_llm.py @@ -19,12 +19,7 @@ import copy from functools import cached_property import logging -from typing import Any -from typing import AsyncGenerator -from typing import cast -from typing import Optional -from typing import TYPE_CHECKING -from typing import Union +from typing import Any, AsyncGenerator, cast, Optional, TYPE_CHECKING, Union from google.genai import types from google.genai.errors import ClientError @@ -40,6 +35,7 @@ from .gemini_llm_connection import GeminiLlmConnection from .llm_response import LlmResponse + if TYPE_CHECKING: from google.genai import Client @@ -303,16 +299,13 @@ def api_client(self) -> Client: The api client. """ from google.genai import Client - from google.cloud.aiplatform import initializer return Client( - location=initializer.global_config.location, - project=initializer.global_config.project, http_options=types.HttpOptions( headers=self._tracking_headers(), retry_options=self.retry_options, base_url=self.base_url, - ), + ) ) @cached_property @@ -337,15 +330,22 @@ def _live_api_version(self) -> str: @cached_property def _live_api_client(self) -> Client: + import os + from google.genai import Client - from google.cloud.aiplatform import initializer + kwargs = {} + if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '').lower() in ('1', 'true'): + from google.cloud.aiplatform import initializer + + kwargs['location'] = initializer.global_config.location + kwargs['project'] = initializer.global_config.project return Client( - location=initializer.global_config.location, - project=initializer.global_config.project, + **kwargs, http_options=types.HttpOptions( headers=self._tracking_headers(), api_version=self._live_api_version, + base_url=self.base_url, ), ) diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py index ce44416b7c..55a43c5c67 100644 --- a/tests/unittests/models/test_google_llm.py +++ b/tests/unittests/models/test_google_llm.py @@ -22,14 +22,18 @@ from google.adk.agents.context_cache_config import ContextCacheConfig from google.adk.models.cache_metadata import CacheMetadata from google.adk.models.gemini_llm_connection import GeminiLlmConnection +from google.adk.models.google_llm import ( + _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE, +) from google.adk.models.google_llm import _build_function_declaration_log from google.adk.models.google_llm import _build_request_log -from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE from google.adk.models.google_llm import _ResourceExhaustedError from google.adk.models.google_llm import Gemini from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse -from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME +from google.adk.utils._client_labels_utils import ( + _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME, +) from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG from google.adk.utils._google_client_headers import get_tracking_headers from google.adk.utils.variant_utils import GoogleLLMVariant @@ -658,7 +662,7 @@ def test_live_api_client_uses_initializer_location(monkeypatch): "google.cloud.aiplatform.initializer.global_config", mock_config ): client = gemini._live_api_client - assert client._api_client._location == "us-central1" + assert client._api_client.location == "us-central1" @pytest.mark.asyncio