diff --git a/src/google/adk/models/google_llm.py b/src/google/adk/models/google_llm.py
index 23c9c27810..ee2fe65ec3 100644
--- a/src/google/adk/models/google_llm.py
+++ b/src/google/adk/models/google_llm.py
@@ -19,12 +19,7 @@
 import copy
 from functools import cached_property
 import logging
-from typing import Any
-from typing import AsyncGenerator
-from typing import cast
-from typing import Optional
-from typing import TYPE_CHECKING
-from typing import Union
+from typing import Any, AsyncGenerator, cast, Optional, TYPE_CHECKING, Union
 
 from google.genai import types
 from google.genai.errors import ClientError
@@ -40,6 +35,7 @@
 from .gemini_llm_connection import GeminiLlmConnection
 from .llm_response import LlmResponse
 
+
 if TYPE_CHECKING:
   from google.genai import Client
 
@@ -334,12 +330,23 @@ def _live_api_version(self) -> str:
 
   @cached_property
   def _live_api_client(self) -> Client:
+    import os
+    from google.genai import Client
+    kwargs = {}
+    if os.environ.get('GOOGLE_GENAI_USE_VERTEXAI', '').lower() in ('1', 'true'):
+      from google.cloud.aiplatform import initializer
+
+      kwargs['location'] = initializer.global_config.location
+      kwargs['project'] = initializer.global_config.project
     return Client(
+        **kwargs,
         http_options=types.HttpOptions(
-            headers=self._tracking_headers(), api_version=self._live_api_version
-        )
+            headers=self._tracking_headers(),
+            api_version=self._live_api_version,
+            base_url=self.base_url,
+        ),
     )
 
   @contextlib.asynccontextmanager
diff --git a/tests/unittests/models/test_google_llm.py b/tests/unittests/models/test_google_llm.py
index 70aa01b69d..55a43c5c67 100644
--- a/tests/unittests/models/test_google_llm.py
+++ b/tests/unittests/models/test_google_llm.py
@@ -22,14 +22,18 @@
 from google.adk.agents.context_cache_config import ContextCacheConfig
 from google.adk.models.cache_metadata import CacheMetadata
 from google.adk.models.gemini_llm_connection import GeminiLlmConnection
+from google.adk.models.google_llm import (
+    _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE,
+)
 from google.adk.models.google_llm import _build_function_declaration_log
 from google.adk.models.google_llm import _build_request_log
-from google.adk.models.google_llm import _RESOURCE_EXHAUSTED_POSSIBLE_FIX_MESSAGE
 from google.adk.models.google_llm import _ResourceExhaustedError
 from google.adk.models.google_llm import Gemini
 from google.adk.models.llm_request import LlmRequest
 from google.adk.models.llm_response import LlmResponse
-from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME
+from google.adk.utils._client_labels_utils import (
+    _AGENT_ENGINE_TELEMETRY_ENV_VARIABLE_NAME,
+)
 from google.adk.utils._client_labels_utils import _AGENT_ENGINE_TELEMETRY_TAG
 from google.adk.utils._google_client_headers import get_tracking_headers
 from google.adk.utils.variant_utils import GoogleLLMVariant
@@ -623,15 +627,15 @@ def test_live_api_version_gemini_api(gemini_llm):
 
 
 def test_live_api_client_properties(gemini_llm):
-  """Test that _live_api_client is properly configured with tracking headers and API version."""
+  """Test that _live_api_client uses v1alpha for Gemini API backend."""
   with mock.patch.object(
-      gemini_llm, "_api_backend", GoogleLLMVariant.VERTEX_AI
+      gemini_llm, "_api_backend", GoogleLLMVariant.GEMINI_API
   ):
     client = gemini_llm._live_api_client
 
-    # Verify that the client has the correct headers and API version
+    # Verify that the client has v1alpha for Gemini API
     http_options = client._api_client._http_options
-    assert http_options.api_version == "v1beta1"
+    assert http_options.api_version == "v1alpha"
 
     # Check that tracking headers are included
     tracking_headers = get_tracking_headers()
@@ -640,6 +644,27 @@ def test_live_api_client_properties(gemini_llm):
       assert value in http_options.headers[key]
 
 
+def test_live_api_client_uses_initializer_location(monkeypatch):
+  """Test that _live_api_client uses location/project from vertexai.init().
+
+  vertexai.init(location=...) writes to google.cloud.aiplatform.initializer.
+  Previously genai.Client() ignored that state and fell back to the
+  GOOGLE_CLOUD_LOCATION env var (defaulting to 'global'), causing native audio
+  models (gemini-live-2.5-flash-native-audio) to fail with WebSocket 1008
+  when the user had called vertexai.init(location='us-central1').
+  """
+  monkeypatch.setenv("GOOGLE_GENAI_USE_VERTEXAI", "1")
+  mock_config = mock.MagicMock()
+  mock_config.location = "us-central1"
+  mock_config.project = "my-project"
+  gemini = Gemini(model="gemini-live-2.5-flash-native-audio")
+  with mock.patch(
+      "google.cloud.aiplatform.initializer.global_config", mock_config
+  ):
+    client = gemini._live_api_client
+    assert client._api_client.location == "us-central1"
+
+
 @pytest.mark.asyncio
 async def test_connect_with_custom_headers(gemini_llm, llm_request):
   """Test that connect method updates tracking headers and API version when custom headers are provided."""