Skip to content

Commit 8ee000c

Browse files
MarkDaoust and copybara-github
authored and committed
chore: refactor live connection parameters into separate vertex and mldev functions.
chore: move the credentials refresh into asyncio.to_thread (the function is async, so this should be too). The diff looks messy, but this really just moves the contents of the giant if/else to _prepare_connection_vertex and _prepare_connection_mldev. No behavior changes expected. Existing tests pass, additional tests added to cover additional cases. PiperOrigin-RevId: 835355573
1 parent 99058b6 commit 8ee000c

File tree

2 files changed

+244
-112
lines changed

2 files changed

+244
-112
lines changed

google/genai/live.py

Lines changed: 170 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -929,76 +929,172 @@ async def connect(
929929
base_url = self._api_client._websocket_base_url()
930930
if isinstance(base_url, bytes):
931931
base_url = base_url.decode('utf-8')
932-
transformed_model = t.t_model(self._api_client, model) # type: ignore
933932

934933
parameter_model = await _t_live_connect_config(self._api_client, config)
935934

936-
if self._api_client.api_key and not self._api_client.vertexai:
937-
version = self._api_client._http_options.api_version
938-
api_key = self._api_client.api_key
939-
method = 'BidiGenerateContent'
940-
original_headers = self._api_client._http_options.headers
941-
headers = original_headers.copy() if original_headers is not None else {}
942-
if api_key.startswith('auth_tokens/'):
935+
if self._api_client.vertexai:
936+
uri, headers, request = await self._prepare_connection_vertex(
937+
base_url=base_url, model=model, parameter_model=parameter_model
938+
)
939+
else:
940+
uri, headers, request = await self._prepare_connection_mldev(
941+
base_url=base_url, model=model, parameter_model=parameter_model
942+
)
943+
944+
if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
945+
parameter_model.tools
946+
):
947+
if headers is None:
948+
headers = {}
949+
_mcp_utils.set_mcp_usage_header(headers)
950+
951+
async with ws_connect(
952+
uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
953+
) as ws:
954+
await ws.send(request)
955+
try:
956+
# websockets 14.0+
957+
raw_response = await ws.recv(decode=False)
958+
except TypeError:
959+
raw_response = await ws.recv() # type: ignore[assignment]
960+
if raw_response:
961+
try:
962+
response = json.loads(raw_response)
963+
except json.decoder.JSONDecodeError as e:
964+
raise ValueError(f'Failed to parse response: {raw_response!r}') from e
965+
else:
966+
response = {}
967+
968+
if self._api_client.vertexai:
969+
response_dict = live_converters._LiveServerMessage_from_vertex(response)
970+
else:
971+
response_dict = response
972+
973+
setup_response = types.LiveServerMessage._from_response(
974+
response=response_dict, kwargs=parameter_model.model_dump()
975+
)
976+
if setup_response.setup_complete:
977+
session_id = setup_response.setup_complete.session_id
978+
else:
979+
session_id = None
980+
yield AsyncSession(
981+
api_client=self._api_client,
982+
websocket=ws,
983+
session_id=session_id,
984+
)
985+
986+
async def _prepare_connection_mldev(
987+
self, *,
988+
base_url: str,
989+
model: str,
990+
parameter_model: types.LiveConnectConfig,
991+
) -> tuple[str, _common.StringDict, str]:
992+
"""Prepares live connection parameters for the MLDev API.
993+
994+
Constructs the WebSocket URI, headers, and request body necessary
995+
to establish a connection with the MLDev backend.
996+
997+
Args:
998+
base_url: The base URL for the WebSocket connection.
999+
model: The name of the model to use.
1000+
parameter_model: Configuration parameters for the connection.
1001+
1002+
Returns:
1003+
A tuple containing:
1004+
- uri: The WebSocket connection URI.
1005+
- headers: A dictionary of headers for the connection.
1006+
- request: The JSON-serialized request body.
1007+
1008+
Raises:
1009+
ValueError: If an API key is not provided.
1010+
"""
1011+
transformed_model = t.t_model(self._api_client, model) # type: ignore
1012+
version = self._api_client._http_options.api_version
1013+
method = 'BidiGenerateContent'
1014+
original_headers = self._api_client._http_options.headers
1015+
headers = original_headers.copy() if original_headers is not None else {}
1016+
api_key = self._api_client.api_key
1017+
1018+
if not api_key:
1019+
# this shouldn't happen
1020+
raise ValueError('Genai live connection requires an API key.')
1021+
1022+
if api_key.startswith('auth_tokens/'):
1023+
method = 'BidiGenerateContentConstrained'
1024+
headers['Authorization'] = f'Token {api_key}'
1025+
warnings.warn(
1026+
message=(
1027+
"The SDK's ephemeral token support is experimental, and may"
1028+
' change in future versions.'
1029+
),
1030+
category=errors.ExperimentalWarning,
1031+
)
1032+
if version != 'v1alpha':
9431033
warnings.warn(
9441034
message=(
945-
"The SDK's ephemeral token support is experimental, and may"
946-
' change in future versions.'
1035+
"The SDK's ephemeral token support is in v1alpha only."
1036+
'Please use client = genai.Client(api_key=token.name, '
1037+
'http_options=types.HttpOptions(api_version="v1alpha"))'
1038+
' before session connection.'
9471039
),
9481040
category=errors.ExperimentalWarning,
9491041
)
950-
method = 'BidiGenerateContentConstrained'
951-
headers['Authorization'] = f'Token {api_key}'
952-
if version != 'v1alpha':
953-
warnings.warn(
954-
message=(
955-
"The SDK's ephemeral token support is in v1alpha only."
956-
'Please use client = genai.Client(api_key=token.name, '
957-
'http_options=types.HttpOptions(api_version="v1alpha"))'
958-
' before session connection.'
959-
),
960-
category=errors.ExperimentalWarning,
961-
)
962-
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'
963-
964-
request_dict = _common.convert_to_dict(
965-
live_converters._LiveConnectParameters_to_mldev(
966-
api_client=self._api_client,
967-
from_object=types.LiveConnectParameters(
968-
model=transformed_model,
969-
config=parameter_model,
970-
).model_dump(exclude_none=True),
971-
)
972-
)
973-
del request_dict['config']
9741042

975-
setv(request_dict, ['setup', 'model'], transformed_model)
1043+
uri = f'{base_url}/ws/google.ai.generativelanguage.{version}.GenerativeService.{method}'
9761044

977-
request = json.dumps(request_dict)
978-
elif self._api_client.api_key and self._api_client.vertexai:
979-
# Headers already contains api key for express mode.
980-
api_key = self._api_client.api_key
981-
version = self._api_client._http_options.api_version
982-
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
983-
original_headers = self._api_client._http_options.headers
984-
headers = original_headers.copy() if original_headers is not None else {}
985-
986-
request_dict = _common.convert_to_dict(
987-
live_converters._LiveConnectParameters_to_vertex(
988-
api_client=self._api_client,
989-
from_object=types.LiveConnectParameters(
990-
model=transformed_model,
991-
config=parameter_model,
992-
).model_dump(exclude_none=True),
993-
)
994-
)
995-
del request_dict['config']
1045+
request_dict = _common.convert_to_dict(
1046+
live_converters._LiveConnectParameters_to_mldev(
1047+
api_client=self._api_client,
1048+
from_object=types.LiveConnectParameters(
1049+
model=transformed_model,
1050+
config=parameter_model,
1051+
).model_dump(exclude_none=True),
1052+
)
1053+
)
1054+
del request_dict['config']
1055+
1056+
setv(request_dict, ['setup', 'model'], transformed_model)
1057+
1058+
return uri, headers, json.dumps(request_dict)
1059+
1060+
1061+
async def _prepare_connection_vertex(
1062+
self, *,
1063+
base_url: str,
1064+
model: str,
1065+
parameter_model: types.LiveConnectConfig,
1066+
) -> tuple[str, _common.StringDict, str]:
1067+
"""Prepares live connection parameters for the Vertex AI API.
9961068
997-
setv(request_dict, ['setup', 'model'], transformed_model)
1069+
Constructs the WebSocket URI, headers, and request body necessary
1070+
to establish a connection with the Vertex AI backend. Handles
1071+
authentication using either an API key or default credentials.
9981072
999-
request = json.dumps(request_dict)
1073+
Args:
1074+
base_url: The base URL for the WebSocket connection.
1075+
model: The name of the model to use.
1076+
parameter_model: Configuration parameters for the connection.
1077+
1078+
Returns:
1079+
A tuple containing:
1080+
- uri: The WebSocket connection URI.
1081+
- headers: A dictionary of headers for the connection.
1082+
- request: The JSON-serialized request body.
1083+
1084+
Raises:
1085+
ValueError: If project and location are not provided when
1086+
default credentials are used.
1087+
"""
1088+
transformed_model = t.t_model(self._api_client, model) # type: ignore
1089+
version = self._api_client._http_options.api_version
1090+
original_headers = self._api_client._http_options.headers
1091+
headers = (
1092+
original_headers.copy() if original_headers is not None else {}
1093+
)
1094+
if api_key := self._api_client.api_key:
1095+
# Headers already contains api key
1096+
uri = f'{base_url}/ws/google.cloud.aiplatform.{version}.LlmBidiService/BidiGenerateContent'
10001097
else:
1001-
version = self._api_client._http_options.api_version
10021098
has_sufficient_auth = (
10031099
self._api_client.project and self._api_client.location
10041100
)
@@ -1028,13 +1124,9 @@ async def connect(
10281124
# Need to refresh credentials to populate those
10291125
if not (creds.token and creds.valid):
10301126
auth_req = google.auth.transport.requests.Request() # type: ignore
1031-
creds.refresh(auth_req)
1127+
await asyncio.to_thread(creds.refresh, auth_req)
10321128
bearer_token = creds.token
10331129

1034-
original_headers = self._api_client._http_options.headers
1035-
headers = (
1036-
original_headers.copy() if original_headers is not None else {}
1037-
)
10381130
if not headers.get('Authorization'):
10391131
headers['Authorization'] = f'Bearer {bearer_token}'
10401132

@@ -1044,17 +1136,22 @@ async def connect(
10441136
transformed_model = (
10451137
f'projects/{project}/locations/{location}/' + transformed_model
10461138
)
1047-
request_dict = _common.convert_to_dict(
1048-
live_converters._LiveConnectParameters_to_vertex(
1049-
api_client=self._api_client,
1050-
from_object=types.LiveConnectParameters(
1051-
model=transformed_model,
1052-
config=parameter_model,
1053-
).model_dump(exclude_none=True),
1054-
)
1055-
)
1056-
del request_dict['config']
10571139

1140+
request_dict = _common.convert_to_dict(
1141+
live_converters._LiveConnectParameters_to_vertex(
1142+
api_client=self._api_client,
1143+
from_object=types.LiveConnectParameters(
1144+
model=transformed_model,
1145+
config=parameter_model,
1146+
).model_dump(exclude_none=True),
1147+
)
1148+
)
1149+
del request_dict['config']
1150+
1151+
if api_key is None:
1152+
# Refactor note: I'm surprised the two paths are different, you'd have
1153+
# to test every model to be sure. The goal of this refactor is to not
1154+
# change any behavior so leaving it as is.
10581155
if (
10591156
getv(
10601157
request_dict, ['setup', 'generationConfig', 'responseModalities']
@@ -1067,49 +1164,10 @@ async def connect(
10671164
['AUDIO'],
10681165
)
10691166

1070-
request = json.dumps(request_dict)
1167+
return uri, headers, json.dumps(request_dict)
10711168

1072-
if parameter_model.tools and _mcp_utils.has_mcp_tool_usage(
1073-
parameter_model.tools
1074-
):
1075-
if headers is None:
1076-
headers = {}
1077-
_mcp_utils.set_mcp_usage_header(headers)
10781169

1079-
async with ws_connect(
1080-
uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx
1081-
) as ws:
1082-
await ws.send(request)
1083-
try:
1084-
# websockets 14.0+
1085-
raw_response = await ws.recv(decode=False)
1086-
except TypeError:
1087-
raw_response = await ws.recv() # type: ignore[assignment]
1088-
if raw_response:
1089-
try:
1090-
response = json.loads(raw_response)
1091-
except json.decoder.JSONDecodeError:
1092-
raise ValueError(f'Failed to parse response: {raw_response!r}')
1093-
else:
1094-
response = {}
10951170

1096-
if self._api_client.vertexai:
1097-
response_dict = live_converters._LiveServerMessage_from_vertex(response)
1098-
else:
1099-
response_dict = response
1100-
1101-
setup_response = types.LiveServerMessage._from_response(
1102-
response=response_dict, kwargs=parameter_model.model_dump()
1103-
)
1104-
if setup_response.setup_complete:
1105-
session_id = setup_response.setup_complete.session_id
1106-
else:
1107-
session_id = None
1108-
yield AsyncSession(
1109-
api_client=self._api_client,
1110-
websocket=ws,
1111-
session_id=session_id,
1112-
)
11131171

11141172

11151173
async def _t_live_connect_config(

0 commit comments

Comments
 (0)