@@ -929,76 +929,172 @@ async def connect(
929929 base_url = self ._api_client ._websocket_base_url ()
930930 if isinstance (base_url , bytes ):
931931 base_url = base_url .decode ('utf-8' )
932- transformed_model = t .t_model (self ._api_client , model ) # type: ignore
933932
934933 parameter_model = await _t_live_connect_config (self ._api_client , config )
935934
936- if self ._api_client .api_key and not self ._api_client .vertexai :
937- version = self ._api_client ._http_options .api_version
938- api_key = self ._api_client .api_key
939- method = 'BidiGenerateContent'
940- original_headers = self ._api_client ._http_options .headers
941- headers = original_headers .copy () if original_headers is not None else {}
942- if api_key .startswith ('auth_tokens/' ):
935+ if self ._api_client .vertexai :
936+ uri , headers , request = await self ._prepare_connection_vertex (
937+ base_url = base_url , model = model , parameter_model = parameter_model
938+ )
939+ else :
940+ uri , headers , request = await self ._prepare_connection_mldev (
941+ base_url = base_url , model = model , parameter_model = parameter_model
942+ )
943+
944+ if parameter_model .tools and _mcp_utils .has_mcp_tool_usage (
945+ parameter_model .tools
946+ ):
947+ if headers is None :
948+ headers = {}
949+ _mcp_utils .set_mcp_usage_header (headers )
950+
951+ async with ws_connect (
952+ uri , additional_headers = headers , ** self ._api_client ._websocket_ssl_ctx
953+ ) as ws :
954+ await ws .send (request )
955+ try :
956+ # websockets 14.0+
957+ raw_response = await ws .recv (decode = False )
958+ except TypeError :
959+ raw_response = await ws .recv () # type: ignore[assignment]
960+ if raw_response :
961+ try :
962+ response = json .loads (raw_response )
963+ except json .decoder .JSONDecodeError as e :
964+ raise ValueError (f'Failed to parse response: { raw_response !r} ' ) from e
965+ else :
966+ response = {}
967+
968+ if self ._api_client .vertexai :
969+ response_dict = live_converters ._LiveServerMessage_from_vertex (response )
970+ else :
971+ response_dict = response
972+
973+ setup_response = types .LiveServerMessage ._from_response (
974+ response = response_dict , kwargs = parameter_model .model_dump ()
975+ )
976+ if setup_response .setup_complete :
977+ session_id = setup_response .setup_complete .session_id
978+ else :
979+ session_id = None
980+ yield AsyncSession (
981+ api_client = self ._api_client ,
982+ websocket = ws ,
983+ session_id = session_id ,
984+ )
985+
986+ async def _prepare_connection_mldev (
987+ self , * ,
988+ base_url : str ,
989+ model : str ,
990+ parameter_model : types .LiveConnectConfig ,
991+ ) -> tuple [str , _common .StringDict , str ]:
992+ """Prepares live connection parameters for the MLDev API.
993+
994+ Constructs the WebSocket URI, headers, and request body necessary
995+ to establish a connection with the MLDev backend.
996+
997+ Args:
998+ base_url: The base URL for the WebSocket connection.
999+ model: The name of the model to use.
1000+ parameter_model: Configuration parameters for the connection.
1001+
1002+ Returns:
1003+ A tuple containing:
1004+ - uri: The WebSocket connection URI.
1005+ - headers: A dictionary of headers for the connection.
1006+ - request: The JSON-serialized request body.
1007+
1008+ Raises:
1009+ ValueError: If an API key is not provided.
1010+ """
1011+ transformed_model = t .t_model (self ._api_client , model ) # type: ignore
1012+ version = self ._api_client ._http_options .api_version
1013+ method = 'BidiGenerateContent'
1014+ original_headers = self ._api_client ._http_options .headers
1015+ headers = original_headers .copy () if original_headers is not None else {}
1016+ api_key = self ._api_client .api_key
1017+
1018+ if not api_key :
1019+ # this shouldn't happen
1020+ raise ValueError ('Genai live connection requires an API key.' )
1021+
1022+ if api_key .startswith ('auth_tokens/' ):
1023+ method = 'BidiGenerateContentConstrained'
1024+ headers ['Authorization' ] = f'Token { api_key } '
1025+ warnings .warn (
1026+ message = (
1027+ "The SDK's ephemeral token support is experimental, and may"
1028+ ' change in future versions.'
1029+ ),
1030+ category = errors .ExperimentalWarning ,
1031+ )
1032+ if version != 'v1alpha' :
9431033 warnings .warn (
9441034 message = (
945- "The SDK's ephemeral token support is experimental, and may"
946- ' change in future versions.'
1035+ "The SDK's ephemeral token support is in v1alpha only."
1036+ 'Please use client = genai.Client(api_key=token.name, '
1037+ 'http_options=types.HttpOptions(api_version="v1alpha"))'
1038+ ' before session connection.'
9471039 ),
9481040 category = errors .ExperimentalWarning ,
9491041 )
950- method = 'BidiGenerateContentConstrained'
951- headers ['Authorization' ] = f'Token { api_key } '
952- if version != 'v1alpha' :
953- warnings .warn (
954- message = (
955- "The SDK's ephemeral token support is in v1alpha only."
956- 'Please use client = genai.Client(api_key=token.name, '
957- 'http_options=types.HttpOptions(api_version="v1alpha"))'
958- ' before session connection.'
959- ),
960- category = errors .ExperimentalWarning ,
961- )
962- uri = f'{ base_url } /ws/google.ai.generativelanguage.{ version } .GenerativeService.{ method } '
963-
964- request_dict = _common .convert_to_dict (
965- live_converters ._LiveConnectParameters_to_mldev (
966- api_client = self ._api_client ,
967- from_object = types .LiveConnectParameters (
968- model = transformed_model ,
969- config = parameter_model ,
970- ).model_dump (exclude_none = True ),
971- )
972- )
973- del request_dict ['config' ]
9741042
975- setv ( request_dict , [ 'setup' , 'model' ], transformed_model )
1043+ uri = f' { base_url } /ws/google.ai.generativelanguage. { version } .GenerativeService. { method } '
9761044
977- request = json .dumps (request_dict )
978- elif self ._api_client .api_key and self ._api_client .vertexai :
979- # Headers already contains api key for express mode.
980- api_key = self ._api_client .api_key
981- version = self ._api_client ._http_options .api_version
982- uri = f'{ base_url } /ws/google.cloud.aiplatform.{ version } .LlmBidiService/BidiGenerateContent'
983- original_headers = self ._api_client ._http_options .headers
984- headers = original_headers .copy () if original_headers is not None else {}
985-
986- request_dict = _common .convert_to_dict (
987- live_converters ._LiveConnectParameters_to_vertex (
988- api_client = self ._api_client ,
989- from_object = types .LiveConnectParameters (
990- model = transformed_model ,
991- config = parameter_model ,
992- ).model_dump (exclude_none = True ),
993- )
994- )
995- del request_dict ['config' ]
1045+ request_dict = _common .convert_to_dict (
1046+ live_converters ._LiveConnectParameters_to_mldev (
1047+ api_client = self ._api_client ,
1048+ from_object = types .LiveConnectParameters (
1049+ model = transformed_model ,
1050+ config = parameter_model ,
1051+ ).model_dump (exclude_none = True ),
1052+ )
1053+ )
1054+ del request_dict ['config' ]
1055+
1056+ setv (request_dict , ['setup' , 'model' ], transformed_model )
1057+
1058+ return uri , headers , json .dumps (request_dict )
1059+
1060+
1061+ async def _prepare_connection_vertex (
1062+ self , * ,
1063+ base_url : str ,
1064+ model : str ,
1065+ parameter_model : types .LiveConnectConfig ,
1066+ ) -> tuple [str , _common .StringDict , str ]:
1067+ """Prepares live connection parameters for the Vertex AI API.
9961068
997- setv (request_dict , ['setup' , 'model' ], transformed_model )
1069+ Constructs the WebSocket URI, headers, and request body necessary
1070+ to establish a connection with the Vertex AI backend. Handles
1071+ authentication using either an API key or default credentials.
9981072
999- request = json .dumps (request_dict )
1073+ Args:
1074+ base_url: The base URL for the WebSocket connection.
1075+ model: The name of the model to use.
1076+ parameter_model: Configuration parameters for the connection.
1077+
1078+ Returns:
1079+ A tuple containing:
1080+ - uri: The WebSocket connection URI.
1081+ - headers: A dictionary of headers for the connection.
1082+ - request: The JSON-serialized request body.
1083+
1084+ Raises:
1085+ ValueError: If project and location are not provided when
1086+ default credentials are used.
1087+ """
1088+ transformed_model = t .t_model (self ._api_client , model ) # type: ignore
1089+ version = self ._api_client ._http_options .api_version
1090+ original_headers = self ._api_client ._http_options .headers
1091+ headers = (
1092+ original_headers .copy () if original_headers is not None else {}
1093+ )
1094+ if api_key := self ._api_client .api_key :
1095+ # Headers already contains api key
1096+ uri = f'{ base_url } /ws/google.cloud.aiplatform.{ version } .LlmBidiService/BidiGenerateContent'
10001097 else :
1001- version = self ._api_client ._http_options .api_version
10021098 has_sufficient_auth = (
10031099 self ._api_client .project and self ._api_client .location
10041100 )
@@ -1028,13 +1124,9 @@ async def connect(
10281124 # Need to refresh credentials to populate those
10291125 if not (creds .token and creds .valid ):
10301126 auth_req = google .auth .transport .requests .Request () # type: ignore
1031- creds .refresh ( auth_req )
1127+ await asyncio . to_thread ( creds .refresh , auth_req )
10321128 bearer_token = creds .token
10331129
1034- original_headers = self ._api_client ._http_options .headers
1035- headers = (
1036- original_headers .copy () if original_headers is not None else {}
1037- )
10381130 if not headers .get ('Authorization' ):
10391131 headers ['Authorization' ] = f'Bearer { bearer_token } '
10401132
@@ -1044,17 +1136,22 @@ async def connect(
10441136 transformed_model = (
10451137 f'projects/{ project } /locations/{ location } /' + transformed_model
10461138 )
1047- request_dict = _common .convert_to_dict (
1048- live_converters ._LiveConnectParameters_to_vertex (
1049- api_client = self ._api_client ,
1050- from_object = types .LiveConnectParameters (
1051- model = transformed_model ,
1052- config = parameter_model ,
1053- ).model_dump (exclude_none = True ),
1054- )
1055- )
1056- del request_dict ['config' ]
10571139
1140+ request_dict = _common .convert_to_dict (
1141+ live_converters ._LiveConnectParameters_to_vertex (
1142+ api_client = self ._api_client ,
1143+ from_object = types .LiveConnectParameters (
1144+ model = transformed_model ,
1145+ config = parameter_model ,
1146+ ).model_dump (exclude_none = True ),
1147+ )
1148+ )
1149+ del request_dict ['config' ]
1150+
1151+ if api_key is None :
1152+ # Refactor note: I'm surprised the two paths are different, you'd have
1153+ # to test every model to be sure. The goal of this refactor is to not
1154+ # change any behavior so leaving it as is.
10581155 if (
10591156 getv (
10601157 request_dict , ['setup' , 'generationConfig' , 'responseModalities' ]
@@ -1067,49 +1164,10 @@ async def connect(
10671164 ['AUDIO' ],
10681165 )
10691166
1070- request = json .dumps (request_dict )
1167+ return uri , headers , json .dumps (request_dict )
10711168
1072- if parameter_model .tools and _mcp_utils .has_mcp_tool_usage (
1073- parameter_model .tools
1074- ):
1075- if headers is None :
1076- headers = {}
1077- _mcp_utils .set_mcp_usage_header (headers )
10781169
1079- async with ws_connect (
1080- uri , additional_headers = headers , ** self ._api_client ._websocket_ssl_ctx
1081- ) as ws :
1082- await ws .send (request )
1083- try :
1084- # websockets 14.0+
1085- raw_response = await ws .recv (decode = False )
1086- except TypeError :
1087- raw_response = await ws .recv () # type: ignore[assignment]
1088- if raw_response :
1089- try :
1090- response = json .loads (raw_response )
1091- except json .decoder .JSONDecodeError :
1092- raise ValueError (f'Failed to parse response: { raw_response !r} ' )
1093- else :
1094- response = {}
10951170
1096- if self ._api_client .vertexai :
1097- response_dict = live_converters ._LiveServerMessage_from_vertex (response )
1098- else :
1099- response_dict = response
1100-
1101- setup_response = types .LiveServerMessage ._from_response (
1102- response = response_dict , kwargs = parameter_model .model_dump ()
1103- )
1104- if setup_response .setup_complete :
1105- session_id = setup_response .setup_complete .session_id
1106- else :
1107- session_id = None
1108- yield AsyncSession (
1109- api_client = self ._api_client ,
1110- websocket = ws ,
1111- session_id = session_id ,
1112- )
11131171
11141172
11151173async def _t_live_connect_config (
0 commit comments