diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference.py b/sdks/python/apache_beam/ml/inference/gemini_inference.py
index 3b2aef4f88dc..c840efedd8fd 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -25,7 +25,6 @@
 
 from google import genai
 from google.genai import errors
-from google.genai.types import HttpOptions
 from google.genai.types import Part
 from PIL.Image import Image
 
@@ -109,7 +108,6 @@ def __init__(
       api_key: Optional[str] = None,
       project: Optional[str] = None,
       location: Optional[str] = None,
-      use_vertex_flex_api: Optional[bool] = False,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -139,7 +137,6 @@
       location: the GCP project to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this paramter is provided,
         project must also be provided and api_key should not be set.
-      use_vertex_flex_api: if true, use the Vertex Flex API.
       min_batch_size: optional. the minimum batch size to use when batching
         inputs.
       max_batch_size: optional. the maximum batch size to use when batching
@@ -172,8 +169,6 @@
     self.location = location
     self.use_vertex = True
 
-    self.use_vertex_flex_api = use_vertex_flex_api
-
     super().__init__(
         namespace='GeminiModelHandler',
         retry_filter=_retry_on_appropriate_service_error,
@@ -185,19 +180,8 @@ def create_client(self) -> genai.Client:
     provided when the GeminiModelHandler class is instantiated.
     """
     if self.use_vertex:
-      if self.use_vertex_flex_api:
-        return genai.Client(
-            vertexai=True,
-            project=self.project,
-            location=self.location,
-            http_options=HttpOptions(
-                api_version="v1",
-                headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
-                # Set timeout in the unit of millisecond.
-                timeout=600000))
-      else:
-        return genai.Client(
-            vertexai=True, project=self.project, location=self.location)
+      return genai.Client(
+          vertexai=True, project=self.project, location=self.location)
     return genai.Client(api_key=self.api_key)
 
   def request(
diff --git a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
index 012287e98f3e..cb73c7de13f4 100644
--- a/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
+++ b/sdks/python/apache_beam/ml/inference/gemini_inference_test.py
@@ -17,7 +17,6 @@
 # pytype: skip-file
 
 import unittest
-from unittest import mock
 
 try:
   from google.genai import errors
@@ -82,29 +81,5 @@ def test_missing_all_params(self):
     )
 
 
-@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
-@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
-class TestGeminiModelHandler(unittest.TestCase):
-  def test_create_client_with_flex_api(
-      self, mock_http_options, mock_genai_client):
-    handler = GeminiModelHandler(
-        model_name="gemini-pro",
-        request_fn=generate_from_string,
-        project="test-project",
-        location="us-central1",
-        use_vertex_flex_api=True)
-    handler.create_client()
-    mock_http_options.assert_called_with(
-        api_version="v1",
-        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
-        timeout=600000,
-    )
-    mock_genai_client.assert_called_with(
-        vertexai=True,
-        project="test-project",
-        location="us-central1",
-        http_options=mock_http_options.return_value)
-
-
 if __name__ == '__main__':
   unittest.main()
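For reviewers, a minimal sketch of how callers construct the handler after this removal. The model name, project id, and API key below are placeholders rather than values from this PR, and the pipeline wiring is only illustrative; the point is that Vertex routing is now driven solely by `project`/`location`, with no flex-API flag on either the constructor or `create_client`:

```python
# Sketch (assumed usage) of GeminiModelHandler after this change; the model
# name, project, and key are placeholders, not values from this PR.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
from apache_beam.ml.inference.gemini_inference import generate_from_string

# Vertex AI path: providing project/location routes requests to Vertex AI.
# create_client() now always builds the plain genai.Client(vertexai=True, ...).
vertex_handler = GeminiModelHandler(
    model_name="gemini-2.0-flash",  # placeholder model name
    request_fn=generate_from_string,
    project="my-gcp-project",  # placeholder project id
    location="us-central1")

# Gemini Developer API path: an API key instead of project/location.
keyed_handler = GeminiModelHandler(
    model_name="gemini-2.0-flash",
    request_fn=generate_from_string,
    api_key="API_KEY")  # placeholder; read from a secret in real use

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(["What is Apache Beam?"])
      | RunInference(vertex_handler))
```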