70 changes: 67 additions & 3 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference.py
@@ -15,6 +15,7 @@
# limitations under the License.
#

import json
import logging
from collections.abc import Iterable
from collections.abc import Mapping
@@ -63,6 +64,7 @@ def __init__(
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
invoke_route: Optional[str] = None,
*,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
@@ -95,6 +97,12 @@ def __init__(
private: optional. if the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
invoke_route: optional. the custom route path to use when invoking
endpoints with arbitrary prediction routes. When specified, uses
`Endpoint.invoke()` instead of `Endpoint.predict()`. The route
should start with a forward slash, e.g., "/predict/v1".
See https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
for more information.
min_batch_size: optional. the minimum batch size to use when batching
inputs.
max_batch_size: optional. the maximum batch size to use when batching
@@ -104,6 +112,7 @@
"""
self._batching_kwargs = {}
self._env_vars = kwargs.get('env_vars', {})
self._invoke_route = invoke_route
if min_batch_size is not None:
self._batching_kwargs["min_batch_size"] = min_batch_size
if max_batch_size is not None:
@@ -203,9 +212,64 @@ def request(
Returns:
An iterable of Predictions.
"""
prediction = model.predict(instances=list(batch), parameters=inference_args)
return utils._convert_to_result(
batch, prediction.predictions, prediction.deployed_model_id)
if self._invoke_route:
# Use invoke() for endpoints with custom prediction routes
request_body: dict[str, Any] = {"instances": list(batch)}
if inference_args:
request_body["parameters"] = inference_args
response = model.invoke(
request_path=self._invoke_route,
body=json.dumps(request_body).encode("utf-8"),
headers={"Content-Type": "application/json"})
return self._parse_invoke_response(batch, bytes(response))
else:
prediction = model.predict(
instances=list(batch), parameters=inference_args)
return utils._convert_to_result(
batch, prediction.predictions, prediction.deployed_model_id)

def _parse_invoke_response(self, batch: Sequence[Any],
response: bytes) -> Iterable[PredictionResult]:
"""Parses the response from Endpoint.invoke() into PredictionResults.

Args:
batch: the original batch of inputs.
response: the raw bytes response from invoke().

Returns:
An iterable of PredictionResults.
"""
try:
response_json = json.loads(response.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError) as e:
LOGGER.warning(
"Failed to decode invoke response as JSON, returning raw bytes: %s",
e)
# Return raw response for each batch item
return [
PredictionResult(example=example, inference=response)
for example in batch
]

# Handle standard Vertex AI response format with "predictions" key
if isinstance(response_json, dict) and "predictions" in response_json:
predictions = response_json["predictions"]
model_id = response_json.get("deployedModelId")
return utils._convert_to_result(batch, predictions, model_id)

# Handle response as a list of predictions (one per input)
if isinstance(response_json, list) and len(response_json) == len(batch):
return utils._convert_to_result(batch, response_json, None)

# Handle single prediction response
if len(batch) == 1:
return [PredictionResult(example=batch[0], inference=response_json)]

# Fallback: return the full response for each batch item
return [
PredictionResult(example=example, inference=response_json)
for example in batch
]

def batch_elements_kwargs(self) -> Mapping[str, Any]:
return self._batching_kwargs
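A minimal usage sketch of the new `invoke_route` path (not part of this diff; the endpoint ID, project, and route below are placeholders, and running it requires GCP credentials and a reachable Vertex AI endpoint that serves a custom prediction route):

```python
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON

# Placeholder endpoint/project values; invoke_route must start with "/".
model_handler = VertexAIModelHandlerJSON(
    endpoint_id="1234567890",
    project="my-project",
    location="us-central1",
    invoke_route="/predict/v1",
    min_batch_size=1,
    max_batch_size=8)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([{"input": "example text"}])
      | RunInference(model_handler)  # batches elements and calls Endpoint.invoke()
      | beam.Map(print))
```

With `invoke_route` set, each batch is serialized as `{"instances": [...], "parameters": ...}` and POSTed to the custom route; without it, the handler keeps the existing `Endpoint.predict()` behavior.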
66 changes: 66 additions & 0 deletions sdks/python/apache_beam/ml/inference/vertex_ai_inference_test.py
@@ -48,5 +48,71 @@ def test_exception_on_private_without_network(self):
private=True)


class ParseInvokeResponseTest(unittest.TestCase):
"""Tests for _parse_invoke_response method."""
def _create_handler_with_invoke_route(self, invoke_route="/test"):
"""Creates a mock handler with invoke_route for testing."""
import unittest.mock as mock
with mock.patch.object(VertexAIModelHandlerJSON,
'_retrieve_endpoint',
return_value=None):
handler = VertexAIModelHandlerJSON(
endpoint_id="1",
project="testproject",
location="us-central1",
invoke_route=invoke_route)
return handler

def test_parse_invoke_response_with_predictions_key(self):
"""Test parsing response with standard 'predictions' key."""
handler = self._create_handler_with_invoke_route()
batch = [{"input": "test1"}, {"input": "test2"}]
response = (
b'{"predictions": ["result1", "result2"], '
b'"deployedModelId": "model123"}')

results = list(handler._parse_invoke_response(batch, response))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].example, {"input": "test1"})
self.assertEqual(results[0].inference, "result1")
self.assertEqual(results[1].example, {"input": "test2"})
self.assertEqual(results[1].inference, "result2")

def test_parse_invoke_response_list_format(self):
"""Test parsing response as a list of predictions."""
handler = self._create_handler_with_invoke_route()
batch = [{"input": "test1"}, {"input": "test2"}]
response = b'["result1", "result2"]'

results = list(handler._parse_invoke_response(batch, response))

self.assertEqual(len(results), 2)
self.assertEqual(results[0].inference, "result1")
self.assertEqual(results[1].inference, "result2")

def test_parse_invoke_response_single_prediction(self):
"""Test parsing response with a single prediction."""
handler = self._create_handler_with_invoke_route()
batch = [{"input": "test1"}]
response = b'{"output": "single result"}'

results = list(handler._parse_invoke_response(batch, response))

self.assertEqual(len(results), 1)
self.assertEqual(results[0].inference, {"output": "single result"})

def test_parse_invoke_response_non_json(self):
"""Test handling non-JSON response."""
handler = self._create_handler_with_invoke_route()
batch = [{"input": "test1"}]
response = b'not valid json'

results = list(handler._parse_invoke_response(batch, response))

self.assertEqual(len(results), 1)
self.assertEqual(results[0].inference, response)


if __name__ == '__main__':
unittest.main()
9 changes: 9 additions & 0 deletions sdks/python/apache_beam/yaml/yaml_ml.py
@@ -168,6 +168,7 @@ def __init__(
experiment: Optional[str] = None,
network: Optional[str] = None,
private: bool = False,
invoke_route: Optional[str] = None,
min_batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
max_batch_duration_secs: Optional[int] = None):
@@ -236,6 +237,13 @@ def __init__(
private: If the deployed Vertex AI endpoint is
private, set to true. Requires a network to be provided
as well.
invoke_route: The custom route path to use when invoking
endpoints with arbitrary prediction routes. When specified, uses
`Endpoint.invoke()` instead of `Endpoint.predict()`. The route
should start with a forward slash, e.g., "/predict/v1".
See
https://cloud.google.com/vertex-ai/docs/predictions/use-arbitrary-custom-routes
for more information.
min_batch_size: The minimum batch size to use when batching
inputs.
max_batch_size: The maximum batch size to use when batching
@@ -258,6 +266,7 @@ def __init__(
experiment=experiment,
network=network,
private=private,
invoke_route=invoke_route,
min_batch_size=min_batch_size,
max_batch_size=max_batch_size,
max_batch_duration_secs=max_batch_duration_secs)
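For reference, a hedged sketch of the JSON contract implied by `request()` and `_parse_invoke_response()` above; it applies whether the handler is configured directly in Python or through the Beam YAML provider in this file. The concrete values are illustrative placeholders:

```python
# Body POSTed to the custom route when invoke_route is set (one entry per batched
# element; "parameters" is included only when inference_args is passed).
request_body = {
    "instances": [{"input": "a"}, {"input": "b"}],
    "parameters": {"temperature": 0.2},
}

# Response shapes that _parse_invoke_response() maps back onto the batch:
standard = {"predictions": ["ra", "rb"], "deployedModelId": "model123"}  # Vertex AI format
per_input = ["ra", "rb"]   # plain list whose length equals the batch size
single = {"output": "r"}   # accepted as-is when the batch has exactly one element
# Any other shape, or a non-JSON body, is attached unchanged to every PredictionResult.
```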