diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py
index c840efedd8fd..3b2aef4f88dc 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -25,6 +25,7 @@
 from google import genai
 from google.genai import errors
+from google.genai.types import HttpOptions
 from google.genai.types import Part
 from PIL.Image import Image
@@ -108,6 +109,7 @@ def __init__(
       api_key: Optional[str] = None,
       project: Optional[str] = None,
       location: Optional[str] = None,
+      use_vertex_flex_api: Optional[bool] = False,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -137,6 +139,7 @@
       location: the GCP location to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this parameter is provided,
         project must also be provided and api_key should not be set.
+      use_vertex_flex_api: if true, send Vertex AI requests via the Flex API.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -169,6 +172,8 @@
     self.location = location
     self.use_vertex = True
 
+    self.use_vertex_flex_api = use_vertex_flex_api
+
     super().__init__(
         namespace='GeminiModelHandler',
         retry_filter=_retry_on_appropriate_service_error,
@@ -180,8 +185,19 @@ def create_client(self) -> genai.Client:
     provided when the GeminiModelHandler class is instantiated.
     """
     if self.use_vertex:
-      return genai.Client(
-          vertexai=True, project=self.project, location=self.location)
+      if self.use_vertex_flex_api:
+        return genai.Client(
+            vertexai=True,
+            project=self.project,
+            location=self.location,
+            http_options=HttpOptions(
+                api_version="v1",
+                headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+                # The timeout is specified in milliseconds.
+                timeout=600000))
+      else:
+        return genai.Client(
+            vertexai=True, project=self.project, location=self.location)
     return genai.Client(api_key=self.api_key)
 
   def request(
diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
index cb73c7de13f4..012287e98f3e 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import unittest
+from unittest import mock
 
 try:
   from google.genai import errors
@@ -81,5 +82,29 @@ def test_missing_all_params(self):
     )
 
 
+@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
+@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
+class TestGeminiModelHandler(unittest.TestCase):
+  def test_create_client_with_flex_api(
+      self, mock_http_options, mock_genai_client):
+    handler = GeminiModelHandler(
+        model_name="gemini-pro",
+        request_fn=generate_from_string,
+        project="test-project",
+        location="us-central1",
+        use_vertex_flex_api=True)
+    handler.create_client()
+    mock_http_options.assert_called_with(
+        api_version="v1",
+        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+        timeout=600000,
+    )
+    mock_genai_client.assert_called_with(
+        vertexai=True,
+        project="test-project",
+        location="us-central1",
+        http_options=mock_http_options.return_value)
+
+
 if __name__ == '__main__':
   unittest.main()
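
For reviewers, a minimal usage sketch (not part of the diff) showing how the new flag fits into a `RunInference` pipeline. It reuses the module's existing `generate_from_string` request function; the model name, project, and location are placeholder values:

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
from apache_beam.ml.inference.gemini_inference import generate_from_string

# Placeholder model/project/location values; substitute your own.
handler = GeminiModelHandler(
    model_name="gemini-2.0-flash",
    request_fn=generate_from_string,
    project="my-gcp-project",
    location="us-central1",
    # Routes requests through the Vertex AI Flex API with the 10-minute
    # HTTP timeout set in create_client, instead of the default path.
    use_vertex_flex_api=True)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create(["Tell me about Apache Beam."])
      | RunInference(handler))
```

Since `use_vertex_flex_api` only changes how `create_client` builds the `genai.Client`, the rest of the pipeline is unchanged; the flag requires `project` and `location` to be set, the same as any other Vertex AI request.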