From 8329b52d84633e94a173b71622c5ede446bb82d9 Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Wed, 3 Dec 2025 17:47:50 +0000 Subject: [PATCH 1/6] Support Vertex Flex API in GeminiModelHandler --- .../ml/inference/gemini_inference.py | 16 +++++++++++- .../ml/inference/gemini_inference_test.py | 26 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index c840efedd8fd..110793392ed3 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -24,6 +24,7 @@ from typing import Union from google import genai +from google.api_core.client_options import ClientOptions as HttpOptions from google.genai import errors from google.genai.types import Part from PIL.Image import Image @@ -108,6 +109,7 @@ def __init__( api_key: Optional[str] = None, project: Optional[str] = None, location: Optional[str] = None, + use_vertex_flex_api: Optional[bool]=False, *, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -169,6 +171,8 @@ def __init__( self.location = location self.use_vertex = True + self.use_vertex_flex_api = use_vertex_flex_api + super().__init__( namespace='GeminiModelHandler', retry_filter=_retry_on_appropriate_service_error, @@ -180,7 +184,17 @@ def create_client(self) -> genai.Client: provided when the GeminiModelHandler class is instantiated. """ if self.use_vertex: - return genai.Client( + if self.use_vertex_flex_api: + return genai.Client( + vertexai=True, project=self.project, location=self.location, + http_options=HttpOptions( + api_version="v1", + headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, + # Set timeout in the unit of millisecond. + timeout = 600000 + )) + else: + return genai.Client( vertexai=True, project=self.project, location=self.location) return genai.Client(api_key=self.api_key) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py index cb73c7de13f4..fcfe940e82f3 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -81,5 +81,31 @@ def test_missing_all_params(self): ) +@unittest.mock.patch( + 'apache_beam.ml.inference.gemini_inference.genai.Client') +@unittest.mock.patch( + 'apache_beam.ml.inference.gemini_inference.HttpOptions') +class TestGeminiModelHandler(unittest.TestCase): + def test_create_client_with_flex_api( + self, mock_http_options, mock_genai_client): + handler = GeminiModelHandler( + model_name="gemini-pro", + request_fn=generate_from_string, + project="test-project", + location="us-central1", + use_vertex_flex_api=True) + handler.create_client() + mock_http_options.assert_called_with( + api_version="v1", + headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, + timeout=600000, + ) + mock_genai_client.assert_called_with( + vertexai=True, + project="test-project", + location="us-central1", + http_options=mock_http_options.return_value) + + if __name__ == '__main__': unittest.main() From 55f35abe8f26f621511ff805c1a98a7153b0ae35 Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Wed, 3 Dec 2025 21:25:29 +0000 Subject: [PATCH 2/6] Fix lint errors. --- .../ml/inference/gemini_inference.py | 20 ++++++++++--------- .../ml/inference/gemini_inference_test.py | 7 +++---- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index 110793392ed3..e736af9edf43 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -109,7 +109,7 @@ def __init__( api_key: Optional[str] = None, project: Optional[str] = None, location: Optional[str] = None, - use_vertex_flex_api: Optional[bool]=False, + use_vertex_flex_api: Optional[bool] = False, *, min_batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, @@ -139,6 +139,7 @@ def __init__( location: the GCP project to use for Vertex AI requests. Setting this parameter routes requests to Vertex AI. If this paramter is provided, project must also be provided and api_key should not be set. + use_vertex_flex_api: if true, use the Vertex Flex API. min_batch_size: optional. the minimum batch size to use when batching inputs. max_batch_size: optional. the maximum batch size to use when batching @@ -186,16 +187,17 @@ def create_client(self) -> genai.Client: if self.use_vertex: if self.use_vertex_flex_api: return genai.Client( - vertexai=True, project=self.project, location=self.location, - http_options=HttpOptions( - api_version="v1", - headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, - # Set timeout in the unit of millisecond. - timeout = 600000 - )) + vertexai=True, + project=self.project, + location=self.location, + http_options=HttpOptions( + api_version="v1", + headers={"X-Vertex-AI-LLM-Request-Type": "flex"}, + # Set timeout in the unit of millisecond. + timeout=600000)) else: return genai.Client( - vertexai=True, project=self.project, location=self.location) + vertexai=True, project=self.project, location=self.location) return genai.Client(api_key=self.api_key) def request( diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py index fcfe940e82f3..e4ea5b349f38 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -17,6 +17,7 @@ # pytype: skip-file import unittest +from unittest import mock try: from google.genai import errors @@ -81,10 +82,8 @@ def test_missing_all_params(self): ) -@unittest.mock.patch( - 'apache_beam.ml.inference.gemini_inference.genai.Client') -@unittest.mock.patch( - 'apache_beam.ml.inference.gemini_inference.HttpOptions') +@unittest.mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client') +@unittest.mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions') class TestGeminiModelHandler(unittest.TestCase): def test_create_client_with_flex_api( self, mock_http_options, mock_genai_client): From e3161fe737d48b040ecafe9b78502a9e7a889dea Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Thu, 4 Dec 2025 20:46:10 +0000 Subject: [PATCH 3/6] Fix lint. --- sdks/python/apache_beam/ml/inference/gemini_inference_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py index e4ea5b349f38..a7db2279b70a 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -17,7 +17,6 @@ # pytype: skip-file import unittest -from unittest import mock try: from google.genai import errors From 2a77e9dfc2d0eb7b0fc5ef64a6fe805927e19ce0 Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Thu, 4 Dec 2025 21:34:53 +0000 Subject: [PATCH 4/6] Import the correct HttpOptions --- sdks/python/apache_beam/ml/inference/gemini_inference.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index e736af9edf43..27f1d1378cfb 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -24,9 +24,8 @@ from typing import Union from google import genai -from google.api_core.client_options import ClientOptions as HttpOptions from google.genai import errors -from google.genai.types import Part +from google.genai.types import HttpOptions, Part from PIL.Image import Image from apache_beam.ml.inference import utils From 8811991fbec0d3aaa7cab51196d556ec1cf8375d Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Thu, 4 Dec 2025 21:37:11 +0000 Subject: [PATCH 5/6] Fix the unit test so it runs successfully. --- .../python/apache_beam/ml/inference/gemini_inference_test.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py index a7db2279b70a..012287e98f3e 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py @@ -17,6 +17,7 @@ # pytype: skip-file import unittest +from unittest import mock try: from google.genai import errors @@ -81,8 +82,8 @@ def test_missing_all_params(self): ) -@unittest.mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client') -@unittest.mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions') +@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client') +@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions') class TestGeminiModelHandler(unittest.TestCase): def test_create_client_with_flex_api( self, mock_http_options, mock_genai_client): From 99642165fa66cabed85e1e7e99d7e30741006b03 Mon Sep 17 00:00:00 2001 From: Xiaochu Liu Date: Thu, 4 Dec 2025 22:48:49 +0000 Subject: [PATCH 6/6] Fix lint again. --- sdks/python/apache_beam/ml/inference/gemini_inference.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py index 27f1d1378cfb..3b2aef4f88dc 100644 --- a/sdks/python/apache_beam/ml/inference/gemini_inference.py +++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py @@ -25,7 +25,8 @@ from google import genai from google.genai import errors -from google.genai.types import HttpOptions, Part +from google.genai.types import HttpOptions +from google.genai.types import Part from PIL.Image import Image from apache_beam.ml.inference import utils