20 changes: 18 additions & 2 deletions sdks/python/apache_beam/ml/inference/gemini_inference.py
@@ -25,6 +25,7 @@
 
 from google import genai
 from google.genai import errors
+from google.genai.types import HttpOptions
 from google.genai.types import Part
 from PIL.Image import Image
 
@@ -108,6 +109,7 @@ def __init__(
       api_key: Optional[str] = None,
       project: Optional[str] = None,
       location: Optional[str] = None,
+      use_vertex_flex_api: Optional[bool] = False,
       *,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
@@ -137,6 +139,7 @@ def __init__(
       location: the GCP location to use for Vertex AI requests. Setting this
         parameter routes requests to Vertex AI. If this parameter is provided,
         project must also be provided and api_key should not be set.
+      use_vertex_flex_api: if true, use the Vertex Flex API for requests.
       min_batch_size: optional. the minimum batch size to use when batching
        inputs.
      max_batch_size: optional. the maximum batch size to use when batching
@@ -169,6 +172,8 @@ def __init__(
       self.location = location
       self.use_vertex = True
 
+    self.use_vertex_flex_api = use_vertex_flex_api
+
     super().__init__(
         namespace='GeminiModelHandler',
         retry_filter=_retry_on_appropriate_service_error,
@@ -180,8 +185,19 @@ def create_client(self) -> genai.Client:
     provided when the GeminiModelHandler class is instantiated.
     """
     if self.use_vertex:
-      return genai.Client(
-          vertexai=True, project=self.project, location=self.location)
+      if self.use_vertex_flex_api:
+        return genai.Client(
+            vertexai=True,
+            project=self.project,
+            location=self.location,
+            http_options=HttpOptions(
+                api_version="v1",
+                headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+                # Timeout in milliseconds.
+                timeout=600000))
+      else:
+        return genai.Client(
+            vertexai=True, project=self.project, location=self.location)
     return genai.Client(api_key=self.api_key)
 
   def request(
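For context, a minimal sketch of how a pipeline could opt into the new flex path. This example is not part of the PR; the model name, project, and prompt are placeholders, and generate_from_string is the request function also used in the tests below.

import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.gemini_inference import GeminiModelHandler
from apache_beam.ml.inference.gemini_inference import generate_from_string

# Placeholder project and model values. Setting use_vertex_flex_api=True
# makes create_client() attach the flex request-type header and a
# 10-minute timeout, per the diff above.
handler = GeminiModelHandler(
    model_name="gemini-pro",
    request_fn=generate_from_string,
    project="my-gcp-project",
    location="us-central1",
    use_vertex_flex_api=True)

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create(["Tell me a fun fact about Apache Beam."])
      | RunInference(handler)
      | beam.Map(print))
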
25 changes: 25 additions & 0 deletions sdks/python/apache_beam/ml/inference/gemini_inference_test.py
@@ -17,6 +17,7 @@
 # pytype: skip-file
 
 import unittest
+from unittest import mock
 
 try:
   from google.genai import errors
@@ -81,5 +82,29 @@ def test_missing_all_params(self):
     )
 
 
+@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
+@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
+class TestGeminiModelHandler(unittest.TestCase):
+  def test_create_client_with_flex_api(
+      self, mock_http_options, mock_genai_client):
+    handler = GeminiModelHandler(
+        model_name="gemini-pro",
+        request_fn=generate_from_string,
+        project="test-project",
+        location="us-central1",
+        use_vertex_flex_api=True)
+    handler.create_client()
+    mock_http_options.assert_called_with(
+        api_version="v1",
+        headers={"X-Vertex-AI-LLM-Request-Type": "flex"},
+        timeout=600000,
+    )
+    mock_genai_client.assert_called_with(
+        vertexai=True,
+        project="test-project",
+        location="us-central1",
+        http_options=mock_http_options.return_value)
+
+
 if __name__ == '__main__':
   unittest.main()
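A natural companion check, sketched here rather than taken from the PR: with use_vertex_flex_api left at its default, create_client() should build the Vertex client without custom HttpOptions. The class and test names below are hypothetical.

@mock.patch('apache_beam.ml.inference.gemini_inference.genai.Client')
@mock.patch('apache_beam.ml.inference.gemini_inference.HttpOptions')
class TestGeminiModelHandlerDefault(unittest.TestCase):
  def test_create_client_without_flex_api(
      self, mock_http_options, mock_genai_client):
    # Hypothetical test: mirrors the flex test above but leaves the new
    # flag unset, exercising the else branch of create_client().
    handler = GeminiModelHandler(
        model_name="gemini-pro",
        request_fn=generate_from_string,
        project="test-project",
        location="us-central1")
    handler.create_client()
    mock_http_options.assert_not_called()
    mock_genai_client.assert_called_with(
        vertexai=True, project="test-project", location="us-central1")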