Skip to content

Commit e9f43a0

Browse files
shawn-yang-googlecopybara-github
authored andcommitted
feat: Support Inline Source Deployment in Agent Engine
PiperOrigin-RevId: 817827304
1 parent 9ef8d05 commit e9f43a0

7 files changed

Lines changed: 751 additions & 64 deletions

File tree

tests/unit/vertexai/genai/replays/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from vertexai._genai import (
2323
client as vertexai_genai_client_module,
2424
)
25+
from vertexai._genai import _agent_engines_utils
2526
from google.cloud import storage, bigquery
2627
from google.genai import _replay_api_client
2728
from google.genai import client as google_genai_client_module
@@ -122,6 +123,16 @@ def replays_prefix():
122123
return "test"
123124

124125

126+
@pytest.fixture
127+
def mock_agent_engine_create_base64_encoded_tarball():
128+
"""Mocks the _create_base64_encoded_tarball function."""
129+
with mock.patch.object(
130+
_agent_engines_utils, "_create_base64_encoded_tarball"
131+
) as mock_create_base64_encoded_tarball:
132+
mock_create_base64_encoded_tarball.return_value = "H4sIAAAAAAAAA-3UvWrDMBAHcM9-CpEpGRLkD8VQ6JOUElT7LFxkydEHxG9f2V1CKXSyu_x_i6TjJN2gk6N7HByNZIK_hEfINsCTa82zilcNTyMvRSPKao2vBM8KwZu6vJZXITJepGyRMb5FMT9FH6RjLHsM0mpr1CyN-i1vcsMo3aycjdMede0kV9YqTedW29id5TBpGXrrxjep0pO4kVGDIf-e_3edsI1APtxG20VNl2ne5o6_-r-oRer_Ypk2dd0s_Z82oP_3kLdaes-ensFLzpKOenaP5OajJ92fvoMLRyE6ww7LjrTwkzWeDnm-nmA_PqkN7PX5vOMJnwcAAAAAAAAAAAAAAAAAAADAdr4AI-kzQQAoAAA="
133+
yield mock_create_base64_encoded_tarball
134+
135+
125136
def _get_replay_id(use_vertex: bool, replays_prefix: str) -> str:
126137
test_name_ending = os.environ.get("PYTEST_CURRENT_TEST").split("::")[-1]
127138
test_name = test_name_ending.split(" ")[0].split("[")[0] + "." + "vertex"

tests/unit/vertexai/genai/replays/test_create_agent_engine.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from tests.unit.vertexai.genai.replays import pytest_helper
2020
from vertexai._genai import types
2121

22+
_TEST_CLASS_METHODS = [
23+
{"name": "query", "api_mode": ""},
24+
]
25+
2226

2327
def test_create_config_lightweight(client):
2428
agent_display_name = "test-display-name"
@@ -108,6 +112,36 @@ def test_create_with_context_spec(client):
108112
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
109113

110114

115+
def test_create_with_source_packages(
116+
client,
117+
mock_agent_engine_create_base64_encoded_tarball,
118+
):
119+
"""Tests creating an agent engine with source packages."""
120+
with mock_agent_engine_create_base64_encoded_tarball:
121+
agent_engine = client.agent_engines.create(
122+
config={
123+
"display_name": "test-agent-engine-source-packages",
124+
"source_packages": [
125+
"test_module.py",
126+
"requirements.txt",
127+
],
128+
"entrypoint_module": "test_module",
129+
"entrypoint_object": "test_object",
130+
"class_methods": _TEST_CLASS_METHODS,
131+
"http_options": {
132+
"base_url": "https://us-west1-aiplatform.googleapis.com",
133+
"api_version": "v1beta1",
134+
},
135+
},
136+
)
137+
assert (
138+
agent_engine.api_resource.display_name
139+
== "test-agent-engine-source-packages"
140+
)
141+
# Clean up resources.
142+
client.agent_engines.delete(name=agent_engine.api_resource.name, force=True)
143+
144+
111145
pytestmark = pytest_helper.setup(
112146
file=__file__,
113147
globals_for_file=globals(),

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 168 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
# limitations under the License.
1414
#
1515
import asyncio
16+
import base64
1617
import importlib
18+
import io
1719
import json
1820
import logging
1921
import os
2022
import sys
23+
import tarfile
2124
import tempfile
2225
from typing import Any, AsyncIterable, Dict, Iterable, List
2326
from unittest import mock
@@ -901,6 +904,48 @@ def test_create_agent_engine_config_full(self, mock_prepare):
901904
== _TEST_AGENT_ENGINE_CUSTOM_SERVICE_ACCOUNT
902905
)
903906

907+
@mock.patch.object(
908+
_agent_engines_utils,
909+
"_create_base64_encoded_tarball",
910+
return_value="test_tarball",
911+
)
912+
def test_create_agent_engine_config_with_source_packages(
913+
self, mock_create_base64_encoded_tarball
914+
):
915+
with tempfile.TemporaryDirectory() as tmpdir:
916+
test_file_path = os.path.join(tmpdir, "test_file.txt")
917+
with open(test_file_path, "w") as f:
918+
f.write("test content")
919+
requirements_file_path = os.path.join(tmpdir, "requirements.txt")
920+
with open(requirements_file_path, "w") as f:
921+
f.write("requests==2.0.0")
922+
923+
config = self.client.agent_engines._create_config(
924+
mode="create",
925+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
926+
description=_TEST_AGENT_ENGINE_DESCRIPTION,
927+
source_packages=[test_file_path],
928+
entrypoint_module="main",
929+
entrypoint_object="app",
930+
requirements_file=requirements_file_path,
931+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
932+
)
933+
assert config["display_name"] == _TEST_AGENT_ENGINE_DISPLAY_NAME
934+
assert config["description"] == _TEST_AGENT_ENGINE_DESCRIPTION
935+
assert config["spec"]["source_code_spec"] == {
936+
"inline_source": {"source_archive": "test_tarball"},
937+
"python_spec": {
938+
"version": _TEST_PYTHON_VERSION,
939+
"entrypoint_module": "main",
940+
"entrypoint_object": "app",
941+
"requirements_file": requirements_file_path,
942+
},
943+
}
944+
assert config["spec"]["class_methods"] == _TEST_AGENT_ENGINE_CLASS_METHODS
945+
mock_create_base64_encoded_tarball.assert_called_once_with(
946+
source_packages=[test_file_path]
947+
)
948+
904949
@mock.patch.object(_agent_engines_utils, "_prepare")
905950
def test_update_agent_engine_config_full(self, mock_prepare):
906951
config = self.client.agent_engines._create_config(
@@ -951,10 +996,10 @@ def test_update_agent_engine_config_full(self, mock_prepare):
951996
"spec.package_spec.pickle_object_gcs_uri",
952997
"spec.package_spec.dependency_files_gcs_uri",
953998
"spec.package_spec.requirements_gcs_uri",
999+
"spec.class_methods",
9541000
"spec.deployment_spec.env",
9551001
"spec.deployment_spec.secret_env",
9561002
"spec.service_account",
957-
"spec.class_methods",
9581003
"spec.agent_framework",
9591004
]
9601005
)
@@ -1170,6 +1215,45 @@ def test_to_parsed_json(self, obj, expected):
11701215
for got, want in zip(_agent_engines_utils._yield_parsed_json(obj), expected):
11711216
assert got == want
11721217

1218+
def test_create_base64_encoded_tarball(self):
1219+
with tempfile.TemporaryDirectory() as tmpdir:
1220+
test_file_path = os.path.join(tmpdir, "test_file.txt")
1221+
with open(test_file_path, "w") as f:
1222+
f.write("test content")
1223+
1224+
origin_dir = os.getcwd()
1225+
try:
1226+
os.chdir(tmpdir)
1227+
encoded_tarball = _agent_engines_utils._create_base64_encoded_tarball(
1228+
source_packages=["test_file.txt"]
1229+
)
1230+
finally:
1231+
os.chdir(origin_dir)
1232+
1233+
decoded_tarball = base64.b64decode(encoded_tarball)
1234+
with tarfile.open(fileobj=io.BytesIO(decoded_tarball), mode="r:gz") as tar:
1235+
names = tar.getnames()
1236+
assert "test_file.txt" in names
1237+
1238+
def test_create_base64_encoded_tarball_outside_project_dir_raises(self):
1239+
with tempfile.TemporaryDirectory() as tmpdir:
1240+
project_dir = os.path.join(tmpdir, "project")
1241+
os.makedirs(project_dir)
1242+
sibling_path = os.path.join(tmpdir, "sibling.txt")
1243+
with open(sibling_path, "w") as f:
1244+
f.write("test content")
1245+
1246+
origin_dir = os.getcwd()
1247+
try:
1248+
os.chdir(project_dir)
1249+
with pytest.raises(ValueError) as excinfo:
1250+
_agent_engines_utils._create_base64_encoded_tarball(
1251+
source_packages=["../sibling.txt"]
1252+
)
1253+
assert "is outside the project directory" in str(excinfo.value)
1254+
finally:
1255+
os.chdir(origin_dir)
1256+
11731257

11741258
@pytest.mark.usefixtures("google_auth_mock")
11751259
class TestAgentEngine:
@@ -1365,6 +1449,10 @@ def test_create_agent_engine_with_env_vars_dict(
13651449
agent_server_mode=None,
13661450
labels=None,
13671451
class_methods=None,
1452+
source_packages=None,
1453+
entrypoint_module=None,
1454+
entrypoint_object=None,
1455+
requirements_file=None,
13681456
)
13691457
request_mock.assert_called_with(
13701458
"post",
@@ -1447,6 +1535,10 @@ def test_create_agent_engine_with_custom_service_account(
14471535
labels=None,
14481536
agent_server_mode=None,
14491537
class_methods=None,
1538+
source_packages=None,
1539+
entrypoint_module=None,
1540+
entrypoint_object=None,
1541+
requirements_file=None,
14501542
)
14511543
request_mock.assert_called_with(
14521544
"post",
@@ -1531,6 +1623,10 @@ def test_create_agent_engine_with_experimental_mode(
15311623
labels=None,
15321624
agent_server_mode=_genai_types.AgentServerMode.EXPERIMENTAL,
15331625
class_methods=None,
1626+
source_packages=None,
1627+
entrypoint_module=None,
1628+
entrypoint_object=None,
1629+
requirements_file=None,
15341630
)
15351631
request_mock.assert_called_with(
15361632
"post",
@@ -1553,6 +1649,72 @@ def test_create_agent_engine_with_experimental_mode(
15531649
None,
15541650
)
15551651

1652+
@mock.patch.object(
1653+
_agent_engines_utils,
1654+
"_create_base64_encoded_tarball",
1655+
return_value="test_tarball",
1656+
)
1657+
@mock.patch.object(_agent_engines_utils, "_await_operation")
1658+
def test_create_agent_engine_with_source_packages(
1659+
self,
1660+
mock_await_operation,
1661+
mock_create_base64_encoded_tarball,
1662+
):
1663+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
1664+
response=_genai_types.ReasoningEngine(
1665+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
1666+
spec=_TEST_AGENT_ENGINE_SPEC,
1667+
)
1668+
)
1669+
with tempfile.TemporaryDirectory() as tmpdir:
1670+
test_file_path = os.path.join(tmpdir, "test_file.txt")
1671+
with open(test_file_path, "w") as f:
1672+
f.write("test content")
1673+
requirements_file_path = os.path.join(tmpdir, "requirements.txt")
1674+
with open(requirements_file_path, "w") as f:
1675+
f.write("requests==2.0.0")
1676+
1677+
with mock.patch.object(
1678+
self.client.agent_engines._api_client, "request"
1679+
) as request_mock:
1680+
request_mock.return_value = genai_types.HttpResponse(body="")
1681+
self.client.agent_engines.create(
1682+
config=_genai_types.AgentEngineConfig(
1683+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
1684+
description=_TEST_AGENT_ENGINE_DESCRIPTION,
1685+
source_packages=[test_file_path],
1686+
entrypoint_module="main",
1687+
entrypoint_object="app",
1688+
requirements_file=requirements_file_path,
1689+
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1690+
),
1691+
)
1692+
request_mock.assert_called_with(
1693+
"post",
1694+
"reasoningEngines",
1695+
{
1696+
"displayName": _TEST_AGENT_ENGINE_DISPLAY_NAME,
1697+
"description": _TEST_AGENT_ENGINE_DESCRIPTION,
1698+
"spec": {
1699+
"agent_framework": "custom",
1700+
"source_code_spec": {
1701+
"inline_source": {"source_archive": "test_tarball"},
1702+
"python_spec": {
1703+
"version": _TEST_PYTHON_VERSION,
1704+
"entrypoint_module": "main",
1705+
"entrypoint_object": "app",
1706+
"requirements_file": requirements_file_path,
1707+
},
1708+
},
1709+
"class_methods": _TEST_AGENT_ENGINE_CLASS_METHODS,
1710+
},
1711+
},
1712+
None,
1713+
)
1714+
mock_create_base64_encoded_tarball.assert_called_once_with(
1715+
source_packages=[test_file_path]
1716+
)
1717+
15561718
@mock.patch.object(agent_engines.AgentEngines, "_create_config")
15571719
@mock.patch.object(_agent_engines_utils, "_await_operation")
15581720
def test_create_agent_engine_with_class_methods(
@@ -1613,6 +1775,10 @@ def test_create_agent_engine_with_class_methods(
16131775
labels=None,
16141776
agent_server_mode=None,
16151777
class_methods=_TEST_AGENT_ENGINE_CLASS_METHODS,
1778+
source_packages=None,
1779+
entrypoint_module=None,
1780+
entrypoint_object=None,
1781+
requirements_file=None,
16161782
)
16171783
request_mock.assert_called_with(
16181784
"post",
@@ -1772,9 +1938,9 @@ def test_update_agent_engine_env_vars(
17721938
[
17731939
"spec.package_spec.pickle_object_gcs_uri",
17741940
"spec.package_spec.requirements_gcs_uri",
1941+
"spec.class_methods",
17751942
"spec.deployment_spec.env",
17761943
"spec.deployment_spec.secret_env",
1777-
"spec.class_methods",
17781944
"spec.agent_framework",
17791945
]
17801946
)

vertexai/_genai/_agent_engines_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import abc
1818
import asyncio
19+
import base64
1920
from importlib import metadata as importlib_metadata
2021
import inspect
2122
import io
@@ -1161,6 +1162,30 @@ def _upload_extra_packages(
11611162
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")
11621163

11631164

1165+
def _create_base64_encoded_tarball(
1166+
*,
1167+
source_packages: Sequence[str],
1168+
) -> str:
1169+
"""Creates a base64 encoded tarball from the source packages."""
1170+
logger.info("Creating in-memory tarfile of source_packages")
1171+
tar_fileobj = io.BytesIO()
1172+
project_dir = os.path.realpath(os.getcwd())
1173+
with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar:
1174+
for file in source_packages:
1175+
real_file_path = os.path.realpath(file)
1176+
if real_file_path != project_dir and not real_file_path.startswith(
1177+
project_dir + os.sep
1178+
):
1179+
raise ValueError(
1180+
f"File path '{file}' is outside the project directory "
1181+
f"'{project_dir}'."
1182+
)
1183+
tar.add(file)
1184+
tar_fileobj.seek(0)
1185+
tarball_bytes = tar_fileobj.read()
1186+
return base64.b64encode(tarball_bytes).decode("utf-8")
1187+
1188+
11641189
def _validate_extra_packages_or_raise(
11651190
*,
11661191
extra_packages: Sequence[str],

0 commit comments

Comments
 (0)