diff --git a/agentplatform/__init__.py b/agentplatform/__init__.py index 9756768002..fa669c8760 100644 --- a/agentplatform/__init__.py +++ b/agentplatform/__init__.py @@ -15,11 +15,16 @@ """The agentplatform module.""" import importlib +import sys + from google.cloud.aiplatform import init from google.cloud.aiplatform import version as aiplatform_version __version__ = aiplatform_version.__version__ +_genai_client = None +_genai_types = None + def __getattr__(name): # type: ignore[no-untyped-def] if name == "preview": @@ -30,10 +35,26 @@ def __getattr__(name): # type: ignore[no-untyped-def] # `import google.cloud.aiplatform.agentplatform.preview as` # `agentplatform_preview` return importlib.import_module(".preview", __name__) + if name == "Client": + global _genai_client + if _genai_client is None: + _genai_client = importlib.import_module("._genai.client", __name__) + return getattr(_genai_client, name) + + if name == "types": + global _genai_types + if _genai_types is None: + _genai_types = importlib.import_module("._genai.types", __name__) + if "vertexai.types" not in sys.modules: + sys.modules["vertexai.types"] = _genai_types + return _genai_types + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") __all__ = [ "init", "preview", + "Client", + "types", ] diff --git a/agentplatform/_genai/__init__.py b/agentplatform/_genai/__init__.py new file mode 100644 index 0000000000..e7a13c9de3 --- /dev/null +++ b/agentplatform/_genai/__init__.py @@ -0,0 +1,43 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The agentplatform module.""" + +import importlib + +from .client import Client + +_evals = None + + +def __getattr__(name): # type: ignore[no-untyped-def] + if name == "evals": + global _evals + if _evals is None: + try: + _evals = importlib.import_module(".evals", __package__) + except ImportError as e: + raise ImportError( + "The 'evals' module requires additional dependencies. " + "Please install them using pip install " + "google-cloud-aiplatform[evaluation]" + ) from e + return _evals + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +__all__ = [ + "Client", + "evals", +] diff --git a/agentplatform/_genai/_agent_engines_utils.py b/agentplatform/_genai/_agent_engines_utils.py new file mode 100644 index 0000000000..b269ba0833 --- /dev/null +++ b/agentplatform/_genai/_agent_engines_utils.py @@ -0,0 +1,1996 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Utility functions for agent engines.""" + +import abc +import asyncio +import base64 +from importlib import metadata as importlib_metadata +import inspect +import io +import json +import logging +import os +import re +import sys +import tarfile +import time +import types +import typing +from typing import ( + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + Iterator, + List, + Mapping, + Optional, + Protocol, + Sequence, + Set, + TypedDict, + Union, +) + +import httpx + +import proto + +from google.api_core import exceptions +from google.genai import types as google_genai_types +from google.protobuf import struct_pb2 +from google.protobuf import json_format + +from . import types as genai_types + + +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + + +try: + _BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names +except AttributeError: + _BUILTIN_MODULE_NAMES: Sequence[str] = [] # type: ignore[no-redef] + +try: + _PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = ( + importlib_metadata.packages_distributions() + ) +except AttributeError: + _PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = {} # type: ignore[no-redef] + +try: + # sys.stdlib_module_names is available from Python 3.10 onwards. + _STDLIB_MODULE_NAMES: frozenset[str] = sys.stdlib_module_names +except AttributeError: + _STDLIB_MODULE_NAMES: frozenset[str] = frozenset() # type: ignore[no-redef] + + +if typing.TYPE_CHECKING: + from google.cloud import storage # type: ignore[attr-defined] + + _StorageBucket: TypeAlias = storage.Bucket +else: + try: + from google.cloud import storage # type: ignore[attr-defined] + + _StorageBucket: type[Any] = storage.Bucket + except (ImportError, AttributeError): + _StorageBucket: type[Any] = Any # type: ignore[no-redef] + + +if typing.TYPE_CHECKING: + import packaging + + _SpecifierSet = packaging.specifiers.SpecifierSet +else: + try: + import packaging + + _SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet + except (ImportError, AttributeError): + _SpecifierSet: type[Any] = Any # type: ignore[no-redef] + + +try: + from a2a.types import ( + AgentCard, + TransportProtocol, + Message, + TaskIdParams, + TaskQueryParams, + ) + from a2a.client import ClientConfig, ClientFactory + + AgentCard = AgentCard + TransportProtocol = TransportProtocol + Message = Message + ClientConfig = ClientConfig + ClientFactory = ClientFactory + TaskIdParams = TaskIdParams + TaskQueryParams = TaskQueryParams +except (ImportError, AttributeError): + AgentCard = None + TransportProtocol = None + Message = None + ClientConfig = None + ClientFactory = None + TaskIdParams = None + TaskQueryParams = None + +_ACTIONS_KEY = "actions" +_ACTION_APPEND = "append" +_AGENT_FRAMEWORK_ATTR = "agent_framework" +_ASYNC_API_MODE = "async" +_ASYNC_STREAM_API_MODE = "async_stream" +_BIDI_STREAM_API_MODE = "bidi_stream" +_BASE_MODULES = set(_BUILTIN_MODULE_NAMES).union(_STDLIB_MODULE_NAMES) +_BLOB_FILENAME = "agent_engine.pkl" +_DEFAULT_AGENT_FRAMEWORK = "custom" +_SUPPORTED_AGENT_FRAMEWORKS = frozenset( + [ + "google-adk", + "langchain", + "langgraph", + "ag2", + "llama-index", + "custom", + "a2a", + ] +) +_DEFAULT_ASYNC_METHOD_NAME = "async_query" +_DEFAULT_ASYNC_METHOD_RETURN_TYPE = "Coroutine[Any]" +_DEFAULT_ASYNC_STREAM_METHOD_NAME = "async_stream_query" +_DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE = "AsyncIterable[Any]" +_DEFAULT_GCS_DIR_NAME = "agent_engine" +_DEFAULT_METHOD_DOCSTRING_TEMPLATE = """ + Runs the Agent Engine to serve the 
user request. + This will be based on the `.{method_name}(...)` of the python object that + was passed in when creating the Agent Engine. The method will invoke the + `{default_method_name}` API client of the python object. + Args: + **kwargs: + Optional. The arguments of the `.{method_name}(...)` method. + Returns: + {return_type}: The response from serving the user request. +""" +_DEFAULT_METHOD_NAME = "query" +_DEFAULT_METHOD_RETURN_TYPE = "dict[str, Any]" +_DEFAULT_STREAM_METHOD_RETURN_TYPE = "Iterable[Any]" +_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"]) +_DEFAULT_STREAM_METHOD_NAME = "stream_query" +_DEFAULT_BIDI_STREAM_METHOD_NAME = "bidi_stream_query" +_EXTRA_PACKAGES_FILE = "dependencies.tar.gz" +_FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE = ( + "Failed to register API methods. Please follow the guide to " + "register the API methods: " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#custom-methods. " + "Error: {%s}" +) +_INSTALLATION_SUBDIR = "installation_scripts" +_METHOD_NAME_KEY_IN_SCHEMA = "name" +_MODE_KEY_IN_SCHEMA = "api_mode" +_REQUIREMENTS_FILE = "requirements.txt" +_STANDARD_API_MODE = "" +_STREAM_API_MODE = "stream" +_A2A_EXTENSION_MODE = "a2a_extension" +_A2A_AGENT_CARD = "a2a_agent_card" +_WARNINGS_KEY = "warnings" +_WARNING_MISSING = "missing" +_WARNING_INCOMPATIBLE = "incompatible" + +_DEFAULT_METHOD_NAME_MAP = { + _STANDARD_API_MODE: _DEFAULT_METHOD_NAME, + _ASYNC_API_MODE: _DEFAULT_ASYNC_METHOD_NAME, + _STREAM_API_MODE: _DEFAULT_STREAM_METHOD_NAME, + _ASYNC_STREAM_API_MODE: _DEFAULT_ASYNC_STREAM_METHOD_NAME, +} +_DEFAULT_METHOD_RETURN_TYPE_MAP = { + _STANDARD_API_MODE: _DEFAULT_METHOD_RETURN_TYPE, + _ASYNC_API_MODE: _DEFAULT_ASYNC_METHOD_RETURN_TYPE, + _STREAM_API_MODE: _DEFAULT_STREAM_METHOD_RETURN_TYPE, + _ASYNC_STREAM_API_MODE: _DEFAULT_ASYNC_STREAM_METHOD_RETURN_TYPE, +} + + +logger = logging.getLogger("agentplatform_genai.agentengines") + + +@typing.runtime_checkable +class Queryable(Protocol): + """Protocol for Agent Engines that can be queried.""" + + @abc.abstractmethod + def query(self, **kwargs): # type: ignore[no-untyped-def] + """Runs the Agent Engine to serve the user query.""" + + +@typing.runtime_checkable +class AsyncQueryable(Protocol): + """Protocol for Agent Engines that can be queried asynchronously.""" + + @abc.abstractmethod + def async_query(self, **kwargs): # type: ignore[no-untyped-def] + """Runs the Agent Engine to serve the user query asynchronously.""" + + +@typing.runtime_checkable +class AsyncStreamQueryable(Protocol): + """Protocol for Agent Engines that can stream responses asynchronously.""" + + @abc.abstractmethod + async def async_stream_query(self, **kwargs) -> AsyncIterator[Any]: # type: ignore[no-untyped-def] + """Asynchronously stream responses to serve the user query.""" + + +@typing.runtime_checkable +class StreamQueryable(Protocol): + """Protocol for Agent Engines that can stream responses.""" + + @abc.abstractmethod + def stream_query(self, **kwargs) -> Iterator[Any]: # type: ignore[no-untyped-def] + """Stream responses to serve the user query.""" + + +@typing.runtime_checkable +class BidiStreamQueryable(Protocol): + """Protocol for Agent Engines that can stream requests and responses.""" + + @abc.abstractmethod + async def bidi_stream_query( + self, input_queue: asyncio.Queue[Any] + ) -> AsyncIterator[Any]: + """Stream requests and responses to serve the user queries.""" + + +@typing.runtime_checkable +class Cloneable(Protocol): + """Protocol for Agent 
Engines that can be cloned.""" + + @abc.abstractmethod + def clone(self) -> Any: + """Return a clone of the object.""" + + +@typing.runtime_checkable +class OperationRegistrable(Protocol): + """Protocol for agents that have registered operations.""" + + @abc.abstractmethod + def register_operations(self, **kwargs: Any) -> dict[str, list[str]]: + """Register the user provided operations (modes and methods).""" + pass + + +if typing.TYPE_CHECKING: + from google.adk.agents import BaseAgent + + ADKAgent: TypeAlias = BaseAgent +else: + try: + from google.adk.agents import BaseAgent + + ADKAgent: Optional[TypeAlias] = BaseAgent + except (ImportError, AttributeError): + ADKAgent = None # type: ignore[no-redef] + +_AgentEngineInterface = Union[ + ADKAgent, + AsyncQueryable, + AsyncStreamQueryable, + OperationRegistrable, + Queryable, + StreamQueryable, + BidiStreamQueryable, +] + + +class _ModuleAgentAttributes(TypedDict, total=False): + module_name: str + agent_name: str + register_operations: Dict[str, list[str]] + sys_paths: Optional[Sequence[str]] + agent: _AgentEngineInterface + + +class ModuleAgent(Cloneable, OperationRegistrable): + """Agent that is defined by a module and an agent name. + + This agent is instantiated by importing a module and instantiating an agent + from that module. It also allows to register operations that are defined in + the agent. + """ + + def __init__( + self, + *, + module_name: str, + agent_name: str, + register_operations: Dict[str, list[str]], + sys_paths: Optional[Sequence[str]] = None, + ): + """Initializes a module-based agent. + + Args: + module_name (str): + Required. The name of the module to import. + agent_name (str): + Required. The name of the agent in the module to instantiate. + register_operations (Dict[str, list[str]]): + Required. A dictionary of API modes to a list of method names. + sys_paths (Sequence[str]): + Optional. The system paths to search for the module. It should + be relative to the directory where the code will be running. + I.e. it should correspond to the directory being passed to + `extra_packages=...` in the create method. It will be appended + to the system path in the sequence being specified here, and + only be appended if it is not already in the system path. + """ + self._tmpl_attrs: _ModuleAgentAttributes = { + "module_name": module_name, + "agent_name": agent_name, + "register_operations": register_operations, + "sys_paths": sys_paths, + } + + def clone(self) -> "ModuleAgent": + """Return a clone of the agent.""" + return ModuleAgent( + module_name=self._tmpl_attrs.get("module_name"), + agent_name=self._tmpl_attrs.get("agent_name"), + register_operations=self._tmpl_attrs.get("register_operations"), + sys_paths=self._tmpl_attrs.get("sys_paths"), + ) + + def register_operations(self, **kwargs: Any) -> dict[str, list[str]]: + reg_operations = self._tmpl_attrs.get("register_operations") + if reg_operations is None: + raise ValueError("Register operations is not set.") + return reg_operations + + def set_up(self) -> None: + """Sets up the agent for execution of queries at runtime. + + It runs the code to import the agent from the module, and registers the + operations of the agent. 
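+
+        A minimal illustrative sketch (the module and agent names below are
+        hypothetical):
+
+            agent = ModuleAgent(
+                module_name="my_package.my_agent",
+                agent_name="root_agent",
+                register_operations={"": ["query"]},
+                sys_paths=["my_package"],
+            )
+            agent.set_up()  # Imports the module and binds `query`.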
+ """ + sys_paths = self._tmpl_attrs.get("sys_paths") + if isinstance(sys_paths, Sequence): + import sys + + for sys_path in sys_paths: + abs_path = os.path.abspath(sys_path) + if abs_path not in sys.path: + sys.path.append(abs_path) + + import importlib + + module = importlib.import_module(self._tmpl_attrs.get("module_name")) + try: + importlib.reload(module) + except Exception as e: + logger.warning( + f"Failed to reload module {self._tmpl_attrs.get('module_name')}: {e}" + ) + agent_name = self._tmpl_attrs.get("agent_name") + try: + agent = getattr(module, agent_name) + except AttributeError as e: + raise AttributeError( + f"Agent {agent_name} not found in module " + f"{self._tmpl_attrs.get('module_name')}" + ) from e + self._tmpl_attrs["agent"] = agent + if hasattr(agent, "set_up"): + agent.set_up() + for operations in self.register_operations().values(): + for operation in operations: + op = _wrap_agent_operation(agent=agent, operation=operation) + setattr(self, operation, types.MethodType(op, self)) + + +class _RequirementsValidationActions(TypedDict): + append: Set[str] + + +class _RequirementsValidationWarnings(TypedDict): + missing: Set[str] + incompatible: Set[str] + + +class _RequirementsValidationResult(TypedDict): + warnings: _RequirementsValidationWarnings + actions: _RequirementsValidationActions + + +AgentEngineOperationUnion = Union[ + genai_types.AgentEngineOperation, + genai_types.AgentEngineMemoryOperation, + genai_types.AgentEngineGenerateMemoriesOperation, +] + + +class GetOperationFunction(Protocol): + def __call__( + self, *, operation_name: str, **kwargs: Any + ) -> AgentEngineOperationUnion: + pass + + +class GetAsyncOperationFunction(Protocol): + async def __call__( + self, *, operation_name: str, **kwargs: Any + ) -> AgentEngineOperationUnion: + pass + + +def _get_reasoning_engine_id(operation_name: str = "", resource_name: str = "") -> str: + """Returns reasoning engine ID from operation name or resource name.""" + if not resource_name and not operation_name: + raise ValueError("Resource name or operation name cannot be empty.") + + if resource_name: + match = re.match( + r"^projects/[^/]+/locations/[^/]+/reasoningEngines/([^/]+)$", + resource_name, + ) + if match: + return match.group(1) + else: + raise ValueError( + "Failed to parse reasoning engine ID from resource name: " + f"`{resource_name}`" + ) + + if not operation_name: + raise ValueError("Operation name cannot be empty.") + + match = re.match( + r"^projects/[^/]+/locations/[^/]+/reasoningEngines/([^/]+)/operations/[^/]+$", + operation_name, + ) + if match: + return match.group(1) + raise ValueError( + "Failed to parse reasoning engine ID from operation name: " + f"`{operation_name}`" + ) + + +async def _await_async_operation( + *, + operation_name: str, + get_operation_fn: GetAsyncOperationFunction, + poll_interval_seconds: float = 10, +) -> Any: + """Waits for the operation for creating an agent engine to complete. + + Args: + operation_name (str): + Required. The name of the operation for creating the Agent Engine. + poll_interval_seconds (float): + The number of seconds to wait between each poll. + get_operation_fn (Callable[[str], Awaitable[Any]]): + Optional. The async function to use for getting the operation. If not + provided, `self._get_agent_operation` will be used. + + Returns: + The operation that has completed (i.e. `operation.done==True`). 
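+
+    Example (illustrative; assumes `get_agent_operation` is an async callable
+    matching the `GetAsyncOperationFunction` protocol and `lro` is a
+    previously returned long-running operation):
+
+        operation = await _await_async_operation(
+            operation_name=lro.name,
+            get_operation_fn=get_agent_operation,
+            poll_interval_seconds=5,
+        )
+        assert operation.done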
+ """ + operation = await get_operation_fn(operation_name=operation_name) + while not operation.done: + await asyncio.sleep(poll_interval_seconds) + operation = await get_operation_fn(operation_name=operation.name) + + return operation + + +def _await_operation( + *, + operation_name: str, + get_operation_fn: GetOperationFunction, + poll_interval_seconds: float = 10, +) -> Any: + """Waits for the operation for creating an agent engine to complete. + + Args: + operation_name (str): + Required. The name of the operation for creating the Agent Engine. + poll_interval_seconds (float): + The number of seconds to wait between each poll. + get_operation_fn (Callable[[str], Any]): + Optional. The function to use for getting the operation. If not + provided, `self._get_agent_operation` will be used. + + Returns: + The operation that has completed (i.e. `operation.done==True`). + """ + operation = get_operation_fn(operation_name=operation_name) + while not operation.done: + time.sleep(poll_interval_seconds) + operation = get_operation_fn(operation_name=operation.name) + + return operation + + +def _compare_requirements( + *, + requirements: Mapping[str, str], + constraints: Union[Sequence[str], Mapping[str, Optional["_SpecifierSet"]]], + required_packages: Optional[Iterator[str]] = None, +) -> _RequirementsValidationResult: + """Compares the requirements with the constraints. + + Args: + requirements (Mapping[str, str]): + Required. The packages (and their versions) to compare with the constraints. + This is assumed to be the result of `scan_requirements`. + constraints (Union[Sequence[str], Mapping[str, SpecifierSet]]): + Required. The package constraints to compare against. This is assumed + to be the result of `parse_constraints`. + required_packages (Iterator[str]): + Optional. The set of packages that are required to be in the + constraints. It defaults to the set of packages that are required + for deployment on Agent Engine. + + Returns: + dict[str, dict[str, Any]]: The comparison result as a dictionary containing: + * warnings: + * missing: The set of packages that are not in the constraints. + * incompatible: The set of packages that are in the constraints + but have versions that are not in the constraint specifier. + * actions: + * append: The set of packages that are not in the constraints + but should be appended to the constraints. 
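+
+    Example (illustrative values):
+
+        result = _compare_requirements(
+            requirements={"cloudpickle": "3.0.0", "requests": "2.31.0"},
+            constraints=["cloudpickle==2.2.1"],
+        )
+        # result["warnings"]["missing"] == {"requests"} and
+        # result["warnings"]["incompatible"] contains
+        # "cloudpickle==3.0.0 (required: ==2.2.1)".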
+ """ + packaging_version = _import_packaging_version_or_raise() + if required_packages is None: + required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment] + result = _RequirementsValidationResult( + warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()), + actions=_RequirementsValidationActions(append=set()), + ) + if isinstance(constraints, list): + constraints = _parse_constraints(constraints=constraints) + for package, package_version in requirements.items(): + if package not in constraints: + result[_WARNINGS_KEY][_WARNING_MISSING].add(package) # type: ignore[literal-required] + if package in required_packages: # type: ignore[operator] + result[_ACTIONS_KEY][_ACTION_APPEND].add( # type: ignore[literal-required] + f"{package}=={package_version}" + ) + continue + if package_version: + package_specifier = constraints[package] # type: ignore[call-overload] + if not package_specifier: + continue + if packaging_version.Version(package_version) not in package_specifier: + result[_WARNINGS_KEY][_WARNING_INCOMPATIBLE].add( # type: ignore[literal-required] + f"{package}=={package_version} (required: {str(package_specifier)})" + ) + return result + + +def _generate_class_methods_spec_or_raise( + *, + agent: _AgentEngineInterface, + operations: Dict[str, List[str]], +) -> List[proto.Message]: + """Generates a ReasoningEngineSpec based on the registered operations. + + Args: + agent: The AgentEngine instance. + operations: A dictionary of API modes and method names. + + Returns: + A list of ReasoningEngineSpec.ClassMethod messages. + + Raises: + ValueError: If a method defined in `register_operations` is not found on + the AgentEngine. + """ + if isinstance(agent, ModuleAgent): + # We do a dry-run of setting up the agent engine to have the operations + # needed for registration. + agent: ModuleAgent = agent.clone() # type: ignore[no-redef] + try: + agent.set_up() + except Exception as e: + raise ValueError(f"Failed to set up agent {agent}: {e}") from e + class_methods_spec = [] + for mode, method_names in operations.items(): + for method_name in method_names: + if not hasattr(agent, method_name): + raise ValueError( + f"Method `{method_name}` defined in `register_operations`" + " not found on agent." 
+                )
+
+            method = getattr(agent, method_name)
+            try:
+                schema_dict = _generate_schema(method, schema_name=method_name)
+            except Exception as e:
+                logger.warning(f"failed to generate schema for {method_name}: {e}")
+                continue
+
+            class_method = _to_proto(schema_dict)
+            class_method[_MODE_KEY_IN_SCHEMA] = mode
+            if hasattr(agent, "agent_card"):
+                class_method[_A2A_AGENT_CARD] = json_format.MessageToJson(
+                    getattr(agent, "agent_card")
+                )
+            class_methods_spec.append(class_method)
+
+    return class_methods_spec
+
+
+def _class_methods_to_class_methods_spec(
+    class_methods: List[dict[str, Any]],
+) -> List[proto.Message]:
+    """Converts a list of class methods to a list of ReasoningEngineSpec.ClassMethod messages."""
+    return [_to_proto(class_method) for class_method in class_methods]
+
+
+def _is_pydantic_serializable(param: inspect.Parameter) -> bool:
+    """Checks if the parameter is pydantic serializable."""
+
+    if param.annotation == inspect.Parameter.empty:
+        return True
+
+    if "ForwardRef" in repr(param.annotation):
+        return True
+
+    if isinstance(param.annotation, str):
+        return False
+
+    pydantic = _import_pydantic_or_raise()
+    try:
+        pydantic.TypeAdapter(param.annotation)
+        return True
+    except Exception:
+        return False
+
+
+def _generate_schema(
+    f: Callable[..., Any],
+    *,
+    schema_name: Optional[str] = None,
+    descriptions: Mapping[str, str] = {},
+    required: Sequence[str] = [],
+) -> Dict[str, Any]:
+    """Generates the OpenAPI Schema for a callable object.
+
+    Only positional and keyword arguments of the function `f` will be supported
+    in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will
+    not be present in the OpenAPI schema returned from this function. For those
+    cases, you can either include it in the docstring for `f`, or modify the
+    OpenAPI schema returned from this function to include additional arguments.
+
+    Args:
+        f (Callable):
+            Required. The function to generate an OpenAPI Schema for.
+        schema_name (str):
+            Optional. The name for the OpenAPI schema. If unspecified, the name
+            of the Callable will be used.
+        descriptions (Mapping[str, str]):
+            Optional. A `{name: description}` mapping for annotating input
+            arguments of the function with user-provided descriptions. It
+            defaults to an empty dictionary (i.e. there will not be any
+            description for any of the inputs).
+        required (Sequence[str]):
+            Optional. For the user to specify the set of required arguments in
+            function calls to `f`. If unspecified, it will be automatically
+            inferred from `f`.
+
+    Returns:
+        dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format.
+    """
+    pydantic = _import_pydantic_or_raise()
+    defaults = dict(inspect.signature(f).parameters)
+    fields_dict = {
+        name: (
+            # 1. We infer the argument type here: use Any rather than None so
+            # it will not try to auto-infer the type based on the default value.
+            (
+                param.annotation
+                if param.annotation != inspect.Parameter.empty
+                and "ForwardRef" not in repr(param.annotation)
+                else Any
+            ),
+            pydantic.Field(
+                # 2. We do not support default values for now.
+                # default=(
+                #     param.default if param.default != inspect.Parameter.empty
+                #     else None
+                # ),
+                # 3. We support user-provided descriptions.
+ description=descriptions.get(name, None), + ), + ) + for name, param in defaults.items() + # We do not support *args or **kwargs + if param.kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ) + # For a bidi endpoint, it requires an asyncio.Queue as the input, but + # it is not JSON serializable. We hence exclude it from the schema. + and param.annotation != asyncio.Queue and _is_pydantic_serializable(param) + } + parameters = pydantic.create_model(f.__name__, **fields_dict).schema() + # Postprocessing + # 4. Suppress unnecessary title generation: + # * https://github.com/pydantic/pydantic/issues/1051 + # * http://cl/586221780 + parameters.pop("title", "") + for name, function_arg in parameters.get("properties", {}).items(): + function_arg.pop("title", "") + annotation = defaults[name].annotation + # 5. Nullable fields: + # * https://github.com/pydantic/pydantic/issues/1270 + # * https://stackoverflow.com/a/58841311 + # * https://github.com/pydantic/pydantic/discussions/4872 + if typing.get_origin(annotation) is Union and type(None) in typing.get_args( + annotation + ): + # for "typing.Optional" arguments, function_arg might be a + # dictionary like + # + # {'anyOf': [{'type': 'integer'}, {'type': 'null'}] + for schema in function_arg.pop("anyOf", []): + schema_type = schema.get("type") + if schema_type and schema_type != "null": + function_arg["type"] = schema_type + break + function_arg["nullable"] = True + # 6. Annotate required fields. + if required: + # We use the user-provided "required" fields if specified. + parameters["required"] = required + else: + # Otherwise we infer it from the function signature. + parameters["required"] = [ + k + for k in defaults + if ( + defaults[k].default == inspect.Parameter.empty + and defaults[k].kind + in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_ONLY, + ) + ) + ] + schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters) + if schema_name: + schema["name"] = schema_name + return schema + + +def _get_agent_framework( + *, + agent_framework: Optional[str], + agent: _AgentEngineInterface, +) -> Union[str, Any]: + """Gets the agent framework to use. + + The agent framework is determined in the following order of priority: + 1. The `agent_framework` passed to this function. + 2. The `agent_framework` attribute on the `agent` object. + 3. The default framework, "custom". + + Args: + agent_framework (str): + The agent framework provided by the user. + agent (_AgentEngineInterface): + The agent engine instance. + + Returns: + str: The name of the agent framework to use. + """ + if agent_framework is not None and agent_framework in _SUPPORTED_AGENT_FRAMEWORKS: + logger.info(f"Using agent framework: {agent_framework}") + return agent_framework + if hasattr(agent, _AGENT_FRAMEWORK_ATTR): + agent_framework_attr = getattr(agent, _AGENT_FRAMEWORK_ATTR) + if ( + agent_framework_attr is not None + and isinstance(agent_framework_attr, str) + and agent_framework_attr in _SUPPORTED_AGENT_FRAMEWORKS + ): + logger.info(f"Using agent framework: {agent_framework_attr}") + return agent_framework_attr + logger.info( + f"The provided agent framework {agent_framework} is not supported." + f" Defaulting to {_DEFAULT_AGENT_FRAMEWORK}." 
+    )
+    return _DEFAULT_AGENT_FRAMEWORK
+
+
+def _get_gcs_bucket(
+    *,
+    project: str,
+    location: str,
+    staging_bucket: str,
+    credentials: Optional[Any] = None,
+) -> _StorageBucket:
+    """Gets or creates the GCS bucket."""
+    storage = _import_cloud_storage_or_raise()
+    storage_client = storage.Client(project=project, credentials=credentials)
+    staging_bucket = staging_bucket.replace("gs://", "")
+    try:
+        gcs_bucket = storage_client.get_bucket(staging_bucket)
+        logger.info(f"Using bucket {staging_bucket}")
+    except exceptions.NotFound:
+        new_bucket = storage_client.bucket(staging_bucket)
+        gcs_bucket = storage_client.create_bucket(new_bucket, location=location)
+        logger.info(f"Created bucket {staging_bucket} in {location=}")
+    return gcs_bucket
+
+
+def _get_registered_operations(
+    *,
+    agent: _AgentEngineInterface,
+) -> dict[str, list[str]]:
+    """Retrieves registered operations for an AgentEngine."""
+    if isinstance(agent, OperationRegistrable):
+        return agent.register_operations()
+
+    operations = {}
+    if isinstance(agent, Queryable):
+        operations[_STANDARD_API_MODE] = [_DEFAULT_METHOD_NAME]
+    if isinstance(agent, AsyncQueryable):
+        operations[_ASYNC_API_MODE] = [_DEFAULT_ASYNC_METHOD_NAME]
+    if isinstance(agent, StreamQueryable):
+        operations[_STREAM_API_MODE] = [_DEFAULT_STREAM_METHOD_NAME]
+    if isinstance(agent, AsyncStreamQueryable):
+        operations[_ASYNC_STREAM_API_MODE] = [_DEFAULT_ASYNC_STREAM_METHOD_NAME]
+    if isinstance(agent, BidiStreamQueryable):
+        operations[_BIDI_STREAM_API_MODE] = [_DEFAULT_BIDI_STREAM_METHOD_NAME]
+    return operations
+
+
+def _import_cloudpickle_or_raise() -> types.ModuleType:
+    """Tries to import the cloudpickle module."""
+    try:
+        import cloudpickle  # noqa:F401
+    except ImportError as e:
+        raise ImportError(
+            "cloudpickle is not installed. Please call "
+            "'pip install google-cloud-aiplatform[agent_engines]'."
+        ) from e
+    return cloudpickle  # type: ignore[no-any-return]
+
+
+def _import_cloud_storage_or_raise() -> types.ModuleType:
+    """Tries to import the Cloud Storage module."""
+    try:
+        from google.cloud import storage  # type: ignore[attr-defined]
+    except ImportError as e:
+        raise ImportError(
+            "Cloud Storage is not installed. Please call "
+            "'pip install google-cloud-aiplatform[agent_engines]'."
+        ) from e
+    return storage  # type: ignore[no-any-return]
+
+
+def _import_packaging_requirements_or_raise() -> types.ModuleType:
+    """Tries to import the packaging.requirements module."""
+    try:
+        from packaging import requirements
+    except ImportError as e:
+        raise ImportError(
+            "packaging.requirements is not installed. Please call "
+            "'pip install google-cloud-aiplatform[agent_engines]'."
+        ) from e
+    return requirements
+
+
+def _import_packaging_version_or_raise() -> types.ModuleType:
+    """Tries to import the packaging.version module."""
+    try:
+        from packaging import version
+    except ImportError as e:
+        raise ImportError(
+            "packaging.version is not installed. Please call "
+            "'pip install google-cloud-aiplatform[agent_engines]'."
+        ) from e
+    return version
+
+
+def _import_pydantic_or_raise() -> types.ModuleType:
+    """Tries to import the pydantic module."""
+    try:
+        import pydantic
+
+        _ = pydantic.Field
+    except AttributeError:
+        from pydantic import v1 as pydantic  # type: ignore[no-redef]
+    except ImportError as e:
+        raise ImportError(
+            "pydantic is not installed. Please call "
+            "'pip install google-cloud-aiplatform[agent_engines]'."
+ ) from e + return pydantic + + +def _parse_constraints( + *, + constraints: Sequence[str], +) -> Mapping[str, Optional["_SpecifierSet"]]: + """Parses a list of constraints into a dict of requirements. + + Args: + constraints (list[str]): + Required. The list of package requirements to parse. This is assumed + to come from the `requirements.txt` file. + + Returns: + dict[str, SpecifierSet]: The specifiers for each package. + """ + requirements = _import_packaging_requirements_or_raise() + result: Dict[str, Optional[_SpecifierSet]] = {} + for constraint in constraints: + try: + if constraint.endswith(".whl"): + constraint = os.path.basename(constraint) + requirement = requirements.Requirement(constraint) + except Exception as e: + logger.warning(f"Failed to parse constraint: {constraint}. Exception: {e}") + continue + result[requirement.name] = requirement.specifier or None + return result + + +def _prepare( + *, + agent: Optional[_AgentEngineInterface], + requirements: Optional[Sequence[str]], + extra_packages: Optional[Sequence[str]], + project: str, + location: str, + staging_bucket: str, + gcs_dir_name: str, + credentials: Optional[Any] = None, +) -> None: + """Prepares the agent engine for creation or updates in Vertex AI. + + This involves packaging and uploading artifacts to Cloud Storage. Note that + 1. This does not actually update the Agent Engine in Vertex AI. + 2. This will only generate and upload a pickled object if specified. + 3. This will only generate and upload the dependencies.tar.gz file if + extra_packages is non-empty. + + Args: + agent: The agent engine to be prepared. + requirements (Sequence[str]): The set of PyPI dependencies needed. + extra_packages (Sequence[str]): The set of extra user-provided packages. + project (str): The project for the staging bucket. + location (str): The location for the staging bucket. + staging_bucket (str): The staging bucket name in the form "gs://...". + gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to use + for staging the artifacts needed. + credentials: The credentials to use for the storage client. + """ + if agent is None: + return + gcs_bucket = _get_gcs_bucket( + project=project, + location=location, + staging_bucket=staging_bucket, + credentials=credentials, + ) + _upload_agent_engine( + agent=agent, + gcs_bucket=gcs_bucket, + gcs_dir_name=gcs_dir_name, + ) + if requirements is not None: + _upload_requirements( + requirements=requirements, + gcs_bucket=gcs_bucket, + gcs_dir_name=gcs_dir_name, + ) + if extra_packages is not None: + _upload_extra_packages( + extra_packages=extra_packages, + gcs_bucket=gcs_bucket, + gcs_dir_name=gcs_dir_name, + ) + + +def _register_api_methods_or_raise( + *, + agent_engine: genai_types.AgentEngine | genai_types.AgentEngineRuntimeRevision, + wrap_operation_fn: Optional[ + dict[str, Callable[[str, str], Callable[..., Any]]] + ] = None, +) -> None: + """Registers Agent Engine API methods based on operation schemas. + + This function iterates through operation schemas provided by the + `agent_engine`. Each schema defines an API mode and method name. + It dynamically creates and registers methods on the `agent_engine` + to handle API calls based on the specified API mode. + Currently, only standard API mode `` is supported. + + Args: + agent_engine: The AgentEngine to augment with API methods. + wrap_operation_fn: A dictionary of API modes and method wrapping + functions. 
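+            For example (illustrative), `{"": my_wrapper}` overrides how
+            standard-mode methods are wrapped, where `my_wrapper` is a
+            hypothetical function with the same shape as
+            `_wrap_query_operation`.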
+ + Raises: + ValueError: If the API mode is not supported or if the operation schema + is missing any required fields (e.g. `api_mode` or `name`). + """ + operation_schemas = agent_engine.operation_schemas() + if not operation_schemas: + return + for operation_schema in operation_schemas: + if _MODE_KEY_IN_SCHEMA not in operation_schema: + raise ValueError( + f"Operation schema {operation_schema} does not" + f" contain an `{_MODE_KEY_IN_SCHEMA}` field." + ) + api_mode = operation_schema.get(_MODE_KEY_IN_SCHEMA) + # For bidi stream api mode, we don't need to wrap the operation. + if api_mode == _BIDI_STREAM_API_MODE: + continue + + if _METHOD_NAME_KEY_IN_SCHEMA not in operation_schema: + raise ValueError( + f"Operation schema {operation_schema} does not" + f" contain a `{_METHOD_NAME_KEY_IN_SCHEMA}` field." + ) + method_name = operation_schema.get(_METHOD_NAME_KEY_IN_SCHEMA) + if not isinstance(method_name, str): + raise ValueError( + "Operation schema has a non-string value for" + f" `{_METHOD_NAME_KEY_IN_SCHEMA}`: {method_name}" + ) + method_description = operation_schema.get( + "description", + _DEFAULT_METHOD_DOCSTRING_TEMPLATE.format( + method_name=method_name, + default_method_name=_DEFAULT_METHOD_NAME_MAP.get( + api_mode, _DEFAULT_METHOD_NAME + ), + return_type=_DEFAULT_METHOD_RETURN_TYPE_MAP.get( + api_mode, + _DEFAULT_METHOD_RETURN_TYPE, + ), + ), + ) + _wrap_operation_map = { + _STANDARD_API_MODE: _wrap_query_operation, + _ASYNC_API_MODE: _wrap_async_query_operation, + _STREAM_API_MODE: _wrap_stream_query_operation, + _ASYNC_STREAM_API_MODE: _wrap_async_stream_query_operation, + _A2A_EXTENSION_MODE: _wrap_a2a_operation, + } + if isinstance(wrap_operation_fn, dict) and api_mode in wrap_operation_fn: + # Override the default function with user-specified function if it exists. + _wrap_operation = wrap_operation_fn[api_mode] + elif api_mode in _wrap_operation_map: + _wrap_operation = _wrap_operation_map[api_mode] # type: ignore[assignment] + else: + supported_api_modes = ", ".join( + f"`{mode}`" for mode in sorted(_wrap_operation_map.keys()) + ) + raise ValueError( + f"Unsupported api mode: `{api_mode}`," + f" Supported modes are: {supported_api_modes}." + ) + + # Bind the method to the object. + if api_mode == _A2A_EXTENSION_MODE: + agent_card = operation_schema.get(_A2A_AGENT_CARD) + method = _wrap_operation( + method_name=method_name, agent_card=agent_card + ) # type: ignore[call-arg] + else: + method = _wrap_operation(method_name=method_name) # type: ignore[call-arg] + method.__name__ = method_name + if method_description and isinstance(method_description, str): + method.__doc__ = method_description + setattr(agent_engine, method_name, types.MethodType(method, agent_engine)) + + +def _scan_requirements( + *, + obj: Any, + ignore_modules: Optional[Sequence[str]] = None, + package_distributions: Optional[Mapping[str, Sequence[str]]] = None, + inspect_getmembers_kwargs: Optional[Mapping[str, Any]] = None, +) -> Mapping[str, str]: + """Scans the object for modules and returns the requirements discovered. + + This is not a comprehensive scan of the object, and only detects for common + cases based on the members of the object returned by `dir(obj)`. + + Args: + obj (Any): + Required. The object to scan for package requirements. + ignore_modules (Sequence[str]): + Optional. The set of modules to ignore. It defaults to the set of + built-in and stdlib modules. + package_distributions (Mapping[str, Sequence[str]]): + Optional. 
The mapping of module names to the set of packages that + contain them. It defaults to the set of packages from + `importlib_metadata.packages_distributions()`. + inspect_getmembers_kwargs (Mapping[str, Any]): + Optional. The keyword arguments to pass to `inspect.getmembers`. It + defaults to an empty dictionary. + + Returns: + Sequence[str]: The list of requirements that were discovered. + """ + if ignore_modules is None: + ignore_modules = _BASE_MODULES # type: ignore[assignment] + if package_distributions is None: + package_distributions = _PACKAGE_DISTRIBUTIONS + modules_found = set(_DEFAULT_REQUIRED_PACKAGES) + inspect_getmembers_kwargs = inspect_getmembers_kwargs or {} + for _, attr in inspect.getmembers(obj, **inspect_getmembers_kwargs): + if not attr or inspect.isbuiltin(attr) or not hasattr(attr, "__module__"): + continue + module_name = (attr.__module__ or "").split(".")[0] + if module_name and module_name not in ignore_modules: # type: ignore[operator] + for module in package_distributions.get(module_name, []): + modules_found.add(module) + return {module: importlib_metadata.version(module) for module in modules_found} + + +def _to_dict(message: proto.Message) -> Dict[str, Any]: + """Converts the contents of the protobuf message to JSON format. + + Args: + message (proto.Message): + Required. The proto message to be converted to a JSON dictionary. + + Returns: + dict[str, Any]: A dictionary containing the contents of the proto. + """ + try: + # Best effort attempt to convert the message into a JSON dictionary. + result: Dict[str, Any] = json.loads( + json_format.MessageToJson( + message._pb, + preserving_proto_field_name=True, + ) + ) + except AttributeError: + result: Dict[str, Any] = json.loads( # type: ignore[no-redef] + json_format.MessageToJson( + message, + preserving_proto_field_name=True, + ) + ) + return result + + +def _to_proto( + obj: Union[Dict[str, Any], proto.Message], + message: Optional[proto.Message] = None, +) -> proto.Message: + """Parses a JSON-like object into a message. + + If the object is already a message, this will return the object as-is. If + the object is a JSON Dict, this will parse and merge the object into the + message. + + Args: + obj (Union[dict[str, Any], proto.Message]): + Required. The object to convert to a proto message. + message (proto.Message): + Optional. A protocol buffer message to merge the obj into. It + defaults to Struct() if unspecified. + + Returns: + proto.Message: The same message passed as argument. + """ + if message is None: + message = struct_pb2.Struct() + if isinstance(obj, (proto.Message, struct_pb2.Struct)): + return obj + try: + json_format.ParseDict(obj, message._pb) + except AttributeError: + json_format.ParseDict(obj, message) + return message + + +def _upload_agent_engine( + *, + agent: _AgentEngineInterface, + gcs_bucket: _StorageBucket, + gcs_dir_name: str, +) -> None: + """Uploads the agent engine to GCS.""" + cloudpickle = _import_cloudpickle_or_raise() + blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") + with blob.open("wb") as f: + try: + cloudpickle.dump(agent, f) + except Exception as e: + url = "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/custom#deployment-considerations" + error_msg = f"Failed to serialize agent engine. Visit {url} for details." + if "google._upb._message" in str(e) or "Descriptor" in str(e): + error_msg += ( + " This is often caused by protobuf objects (like Part, AgentCard) " + "being imported at the global module level. 
Please move these "
+                    "imports inside the functions or methods where they are used. "
+                    "Alternatively, you can import the entire module: "
+                    "`from a2a import types`."
+                )
+            raise TypeError(error_msg) from e
+    with blob.open("rb") as f:
+        try:
+            _ = cloudpickle.load(f)
+        except Exception as e:
+            raise TypeError("Agent engine serialized to an invalid format") from e
+    dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
+    logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}")
+
+
+def _upload_requirements(
+    *,
+    requirements: Sequence[str],
+    gcs_bucket: _StorageBucket,
+    gcs_dir_name: str,
+) -> None:
+    """Uploads the requirements file to GCS."""
+    blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}")
+    blob.upload_from_string("\n".join(requirements))
+    dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
+    logger.info(f"Wrote to {dir_name}/{_REQUIREMENTS_FILE}")
+
+
+def _upload_extra_packages(
+    *,
+    extra_packages: Sequence[str],
+    gcs_bucket: _StorageBucket,
+    gcs_dir_name: str,
+) -> None:
+    """Uploads extra packages to GCS."""
+    logger.info("Creating in-memory tarfile of extra_packages")
+    tar_fileobj = io.BytesIO()
+    with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar:
+        for file in extra_packages:
+            tar.add(file)
+    tar_fileobj.seek(0)
+    blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}")
+    blob.upload_from_string(tar_fileobj.read())
+    dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
+    logger.info(f"Wrote to {dir_name}/{_EXTRA_PACKAGES_FILE}")
+
+
+def _create_base64_encoded_tarball(
+    *,
+    source_packages: Sequence[str],
+) -> str:
+    """Creates a base64 encoded tarball from the source packages."""
+    logger.info("Creating in-memory tarfile of source_packages")
+    tar_fileobj = io.BytesIO()
+    project_dir = os.path.realpath(os.getcwd())
+    with tarfile.open(fileobj=tar_fileobj, mode="w|gz") as tar:
+        for file in source_packages:
+            real_file_path = os.path.realpath(file)
+            if real_file_path != project_dir and not real_file_path.startswith(
+                project_dir + os.sep
+            ):
+                raise ValueError(
+                    f"File path '{file}' is outside the project directory "
+                    f"'{project_dir}'."
+                )
+            tar.add(file)
+    tar_fileobj.seek(0)
+    tarball_bytes = tar_fileobj.read()
+    return base64.b64encode(tarball_bytes).decode("utf-8")
+
+
+def _validate_packages_or_raise(
+    *,
+    packages: Sequence[str],
+    build_options: Optional[Dict[str, Sequence[str]]] = None,
+) -> Sequence[str]:
+    """Tries to validate the packages."""
+    packages = packages or []
+    if build_options and _INSTALLATION_SUBDIR in build_options:
+        _validate_installation_scripts_or_raise(
+            script_paths=build_options[_INSTALLATION_SUBDIR],
+            packages=packages,
+        )
+    for package in packages:
+        if not os.path.exists(package):
+            raise FileNotFoundError(f"Package specified but not found: {package=}")
+    return packages
+
+
+def _validate_installation_scripts_or_raise(
+    *,
+    script_paths: Sequence[str],
+    packages: Sequence[str],
+) -> None:
+    """Validates the installation script paths explicitly provided by the user.
+
+    Args:
+        script_paths (Sequence[str]):
+            Required. The paths to the installation scripts.
+        packages (Sequence[str]):
+            Required. The user-provided packages.
+
+    Raises:
+        ValueError: If a user-defined script is not under the expected
+            subdirectory, or not in `packages`, or if a package is
+            in the installation scripts subdirectory, but is not specified as an
+            installation script.
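+
+    Example (illustrative): a script passed as
+    `script_paths=["installation_scripts/install.sh"]` is only valid if
+    `packages` also contains `"installation_scripts/install.sh"`.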
+ """ + for script_path in script_paths: + if not script_path.startswith(_INSTALLATION_SUBDIR): + logger.warning( + f"User-defined installation script '{script_path}' is not in " + f"the expected '{_INSTALLATION_SUBDIR}' subdirectory. " + f"Ensure it is placed in '{_INSTALLATION_SUBDIR}' within your " + f"'extra_packages' or 'source_packages'." + ) + raise ValueError( + f"Required installation script '{script_path}' " + f"is not under '{_INSTALLATION_SUBDIR}'" + ) + + if script_path not in packages: + logger.warning( + f"User-defined installation script '{script_path}' is not in " + f"'extra_packages' or 'source_packages'. Ensure it is added to " + f"'extra_packages' or 'source_packages'." + ) + raise ValueError( + f"User-defined installation script '{script_path}' " + f"does not exist in 'extra_packages' or 'source_packages'." + ) + + for package in packages: + if package.startswith(_INSTALLATION_SUBDIR) and package not in script_paths: + logger.warning( + f"Package '{package}' is in the installation " + "scripts subdirectory, but is not specified as an installation " + "script in `build_options`. " + "Ensure it is added to installation_scripts for " + "automatic execution." + ) + raise ValueError( + f"Package '{package}' is in the installation " + "scripts subdirectory, but is not specified as an installation " + "script in `build_options`." + ) + return + + +def _validate_staging_bucket_or_raise(*, staging_bucket: str) -> str: + """Tries to validate the staging bucket.""" + if not staging_bucket: + raise ValueError( + "Please provide a `staging_bucket` in `client.agent_engines.create(...)`." + ) + if not staging_bucket.startswith("gs://"): + raise ValueError(f"{staging_bucket=} must start with `gs://`") + return staging_bucket + + +def _validate_requirements_or_warn( + *, + obj: Any, + requirements: List[str], +) -> List[str]: + """Compiles the requirements into a list of requirements.""" + requirements = requirements.copy() + try: + current_requirements = _scan_requirements(obj=obj) + logger.info(f"Identified the following requirements: {current_requirements}") + constraints = _parse_constraints(constraints=requirements) + missing_requirements = _compare_requirements( + requirements=current_requirements, + constraints=constraints, + ) + for warning_type, warnings in missing_requirements["warnings"].items(): + if warnings: + logger.warning( + f"The following requirements are {warning_type}: {warnings}" + ) + for action_type, actions in missing_requirements["actions"].items(): + if actions and action_type == _ACTION_APPEND: + for action in actions: # type: ignore[attr-defined] + requirements.append(action) + logger.info(f"The following requirements are appended: {actions}") + except Exception as e: + logger.warning(f"Failed to compile requirements: {e}") + return requirements + + +def _validate_requirements_or_raise( + *, + agent: Any, + requirements: Optional[Sequence[str]] = None, +) -> Sequence[str]: + """Tries to validate the requirements.""" + if requirements is None: + requirements = [] + elif isinstance(requirements, str): + try: + logger.info(f"Reading requirements from {requirements=}") + with open(requirements) as f: + requirements = f.read().splitlines() + logger.info(f"Read the following lines: {requirements}") + except IOError as err: + raise IOError(f"Failed to read requirements from {requirements=}") from err + requirements = _validate_requirements_or_warn( + obj=agent, + requirements=requirements, + ) + logger.info(f"The final list of requirements: {requirements}") + 
return requirements + + +def _validate_agent_or_raise( + *, + agent: _AgentEngineInterface, +) -> _AgentEngineInterface: + """Tries to validate the agent engine. + + The agent engine must have one of the following: + * a callable method named `query` + * a callable method named `stream_query` + * a callable method named `async_stream_query` + * a callable method named `bidi_stream_query` + * a callable method named `register_operations` + + Args: + agent: The agent to be validated. + + Returns: + The validated agent engine. + + Raises: + TypeError: If `agent_engine` has no callable method named `query`, + `stream_query` or `register_operations`. + ValueError: If `agent_engine` has an invalid `query`, `stream_query` or + `register_operations` signature. + """ + try: + from google.adk.agents import BaseAgent + + if isinstance(agent, BaseAgent): + logger.info("Deploying google.adk.agents.Agent as an application.") + from agentplatform import agent_engines + + agent = agent_engines.AdkApp(agent=agent) + except Exception: + pass + is_queryable = isinstance(agent, Queryable) and callable(agent.query) + is_async_queryable = isinstance(agent, AsyncQueryable) and callable( + agent.async_query + ) + is_stream_queryable = isinstance(agent, StreamQueryable) and callable( + agent.stream_query + ) + is_async_stream_queryable = isinstance(agent, AsyncStreamQueryable) and callable( + agent.async_stream_query + ) + is_bidi_stream_queryable = isinstance(agent, BidiStreamQueryable) and callable( + agent.bidi_stream_query + ) + is_operation_registrable = isinstance(agent, OperationRegistrable) and callable( + agent.register_operations + ) + + if not ( + is_queryable + or is_async_queryable + or is_stream_queryable + or is_operation_registrable + or is_bidi_stream_queryable + or is_async_stream_queryable + ): + raise TypeError( + "agent_engine has none of the following callable methods: " + "`query`, `async_query`, `stream_query`, `async_stream_query`, " + "`bidi_stream_query`, or `register_operations`." + ) + + if is_queryable: + try: + inspect.signature(getattr(agent, "query")) + except ValueError as err: + raise ValueError( + "Invalid query signature. This might be due to a missing " + "`self` argument in the agent.query method." + ) from err + + if is_async_queryable: + try: + inspect.signature(getattr(agent, "async_query")) + except ValueError as err: + raise ValueError( + "Invalid async_query signature. This might be due to a missing " + "`self` argument in the agent.async_query method." + ) from err + + if is_stream_queryable: + try: + inspect.signature(getattr(agent, "stream_query")) + except ValueError as err: + raise ValueError( + "Invalid stream_query signature. This might be due to a missing" + " `self` argument in the agent.stream_query method." + ) from err + + if is_async_stream_queryable: + try: + inspect.signature(getattr(agent, "async_stream_query")) + except ValueError as err: + raise ValueError( + "Invalid async_stream_query signature. This might be due to a " + " missing `self` argument in the agent.async_stream_query method." + ) from err + + if is_bidi_stream_queryable: + try: + inspect.signature(getattr(agent, "bidi_stream_query")) + except ValueError as err: + raise ValueError( + "Invalid bidi_stream_query signature. This might be due to a " + " missing `self` argument in the agent.bidi_stream_query method." 
+            ) from err
+
+    if is_operation_registrable:
+        try:
+            inspect.signature(getattr(agent, "register_operations"))
+        except ValueError as err:
+            raise ValueError(
+                "Invalid register_operations signature. This might be due to a "
+                "missing `self` argument in the agent.register_operations method."
+            ) from err
+
+    if isinstance(agent, Cloneable):
+        # Avoid undeployable states.
+        agent = agent.clone()
+    return agent
+
+
+def _wrap_agent_operation(*, agent: Any, operation: str) -> Callable[..., Any]:
+    """Wraps an agent operation into a method (works for all API modes)."""
+
+    def _method(self, **kwargs) -> Any:  # type: ignore[no-untyped-def]
+        if not self._tmpl_attrs.get("agent"):
+            self.set_up()
+        return getattr(self._tmpl_attrs["agent"], operation)(**kwargs)
+
+    _method.__name__ = operation
+    _method.__doc__ = getattr(agent, operation).__doc__
+    return _method
+
+
+def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
+    """Wraps an Agent Engine method, creating a callable for `query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `query` API. It handles the creation of
+    the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `query` API.
+    """
+
+    def _method(self: genai_types.AgentEngine, **kwargs) -> Any:  # type: ignore[no-untyped-def]
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        response = self.api_client._query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        )
+        return response.output
+
+    return _method
+
+
+def _wrap_async_query_operation(
+    *, method_name: str
+) -> Callable[..., Coroutine[Any, Any, Any]]:
+    """Wraps an Agent Engine method, creating an async callable for `query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method asynchronously using the `query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `query` API.
+    """
+
+    async def _method(
+        self: genai_types.AgentEngine, **kwargs: Any
+    ) -> Union[Coroutine[Any, Any, Any], Any]:
+        if not self.api_async_client:
+            raise ValueError("api_async_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        response = await self.api_async_client._query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        )
+        return response.output
+
+    return _method
+
+
+def _wrap_stream_query_operation(*, method_name: str) -> Callable[..., Iterator[Any]]:
+    """Wraps an Agent Engine method, creating a callable for `stream_query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `stream_query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `stream_query` API.
+    """
+
+    def _method(self: genai_types.AgentEngine, **kwargs) -> Iterator[Any]:  # type: ignore[no-untyped-def]
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        for http_response in self.api_client._stream_query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        ):
+            for line in _yield_parsed_json(http_response=http_response):
+                if line is not None:
+                    yield line
+
+    return _method
+
+
+def _wrap_async_stream_query_operation(
+    *, method_name: str
+) -> Callable[..., AsyncIterator[Any]]:
+    """Wraps an Agent Engine method, creating an async callable for `stream_query` API.
+
+    This function creates a callable object that executes the specified
+    Agent Engine method using the `stream_query` API. It handles the
+    creation of the API request and the processing of the API response.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the `stream_query` API.
+    """
+
+    async def _method(self: genai_types.AgentEngine, **kwargs) -> AsyncIterator[Any]:  # type: ignore[no-untyped-def]
+        if not self.api_client:
+            raise ValueError("api_client is not initialized.")
+        if not self.api_resource:
+            raise ValueError("api_resource is not initialized.")
+        async for http_response in self.api_client._async_stream_query(
+            name=self.api_resource.name,
+            config={
+                "class_method": method_name,
+                "input": kwargs,
+                "include_all_fields": True,
+            },
+        ):
+            for line in _yield_parsed_json(http_response=http_response):
+                if line is not None:
+                    yield line
+
+    return _method
+
+
+def _wrap_a2a_operation(method_name: str, agent_card: str) -> Callable[..., list[Any]]:
+    """Wraps an Agent Engine method, creating a callable for A2A API.
+
+    Args:
+        method_name: The name of the Agent Engine method to call.
+        agent_card: The agent card to use for the A2A API call.
+            Example:
+                {'additionalInterfaces': None,
+                 'capabilities': {'extensions': None,
+                                  'pushNotifications': None,
+                                  'stateTransitionHistory': None,
+                                  'streaming': False},
+                 'defaultInputModes': ['text'],
+                 'defaultOutputModes': ['text'],
+                 'description': (
+                     'A helpful assistant agent that can answer questions.'
+                 ),
+                 'documentationUrl': None,
+                 'iconUrl': None,
+                 'name': 'Q&A Agent',
+                 'preferredTransport': 'JSONRPC',
+                 'protocolVersion': '0.3.0',
+                 'provider': None,
+                 'security': None,
+                 'securitySchemes': None,
+                 'signatures': None,
+                 'skills': [{
+                     'description': (
+                         'A helpful assistant agent that can answer questions.'
+                     ),
+                     'examples': ['Who is leading 2025 F1 Standings?',
+                                  'Where can I find an active volcano?'],
+                     'id': 'question_answer',
+                     'inputModes': None,
+                     'name': 'Q&A Agent',
+                     'outputModes': None,
+                     'security': None,
+                     'tags': ['Question-Answer']}],
+                 'supportsAuthenticatedExtendedCard': True,
+                 'url': 'http://localhost:8080/',
+                 'version': '1.0.0'}
+
+    Returns:
+        A callable object that executes the method on the Agent Engine via
+        the A2A API.
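+
+    Example (illustrative): once bound by `_register_api_methods_or_raise`,
+    a hypothetical call might look like:
+
+        chunks = await agent_engine.on_message_send(
+            message_id="...", role="user", parts=[...]
+        )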
+ """ + + async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def] + """Wraps an Agent Engine method, creating a callable for A2A API.""" + if not self.api_client: + raise ValueError("api_client is not initialized.") + if not self.api_resource: + raise ValueError("api_resource is not initialized.") + a2a_agent_card = AgentCard(**json.loads(agent_card)) + # A2A + AE integration currently only supports Rest API. + if ( + a2a_agent_card.preferred_transport + and a2a_agent_card.preferred_transport != TransportProtocol.http_json + ): + raise ValueError( + "Only HTTP+JSON is supported for preferred transport on agent card " + ) + + # Set preferred transport to HTTP+JSON if not set. + if not hasattr(a2a_agent_card, "preferred_transport"): + a2a_agent_card.preferred_transport = TransportProtocol.http_json + + if not hasattr(a2a_agent_card.capabilities, "streaming"): + a2a_agent_card.capabilities.streaming = False + + # agent_card is set on the class_methods before set_up is invoked. + # Ensure that the agent_card url is set correctly before the client is created. + base_url = self.api_client._api_client._http_options.base_url.rstrip("/") + api_version = self.api_client._api_client._http_options.api_version + a2a_agent_card.url = f"{base_url}/{api_version}/{self.api_resource.name}/a2a" + + # Using a2a client, inject the auth token from the global config. + config = ClientConfig( + supported_transports=[ + TransportProtocol.http_json, + ], + use_client_preference=True, + httpx_client=httpx.AsyncClient( + headers={ + "Authorization": ( + f"Bearer {self.api_client._api_client._credentials.token}" + ) + }, + timeout=( + self.api_client._api_client._http_options.timeout / 1000.0 + if self.api_client._api_client._http_options.timeout + else None + ), + ), + ) + factory = ClientFactory(config) + client = factory.create(a2a_agent_card) + + if method_name == "on_message_send": + response = client.send_message(Message(**kwargs)) + chunks = [] + async for chunk in response: + chunks.append(chunk) + return chunks + elif method_name == "on_get_task": + response = await client.get_task(TaskQueryParams(**kwargs)) + elif method_name == "on_cancel_task": + response = await client.cancel_task(TaskIdParams(**kwargs)) + elif method_name == "handle_authenticated_agent_card": + response = await client.get_card() + else: + raise ValueError(f"Unknown method name: {method_name}") + + return response + + return _method # type: ignore[return-value] + + +def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]: + """Converts the body of the HTTP Response message to JSON format. + + Args: + http_response (google.genai.types.HttpResponse): + Required. The httpbody body to be converted to JSON object(s). + + Yields: + Any: A JSON object or line of the original body or None. + """ + if not http_response.body: + yield None + return + + # Handle the case of multiple dictionaries delimited by newlines. + for line in http_response.body.split("\n"): + if line: + try: + line = json.loads(line) + except Exception as e: + logger.warning(f"failed to parse json: {line}. Exception: {e}") + yield line + + +def _validate_resource_limits_or_raise(resource_limits: dict[str, str]) -> None: + """Validates the resource limits. + + Checks that the resource limits are a dict with 'cpu' and 'memory' keys. + Checks that the 'cpu' value is one of 1, 2, 4, 6, 8. + Checks that the 'memory' value is a string ending with 'Gi'. + Checks that the memory size is smaller than 32Gi. 
+def _validate_resource_limits_or_raise(resource_limits: dict[str, str]) -> None:
+    """Validates the resource limits.
+
+    Checks that the resource limits are a dict with 'cpu' and 'memory' keys.
+    Checks that the 'cpu' value is one of 1, 2, 4, 6, 8.
+    Checks that the 'memory' value is a string ending with 'Gi'.
+    Checks that the memory size does not exceed 32Gi.
+    Checks that the 'cpu' value is large enough for the requested memory size.
+
+    Args:
+        resource_limits: The resource limits to be validated.
+
+    Raises:
+        TypeError: If the resource limits are not a dict.
+        KeyError: If the resource limits do not contain 'cpu' and 'memory' keys.
+        ValueError: If the 'cpu' value is not one of 1, 2, 4, 6, 8.
+        ValueError: If the 'memory' value is not a string ending with 'Gi'.
+        ValueError: If the memory size exceeds 32Gi.
+        ValueError: If the memory size requires more CPUs than the specified
+            'cpu' value.
+    """
+    if not isinstance(resource_limits, dict):
+        raise TypeError(f"resource_limits must be a dict. Got {type(resource_limits)}")
+    if "cpu" not in resource_limits or "memory" not in resource_limits:
+        raise KeyError("resource_limits must contain 'cpu' and 'memory' keys.")
+
+    cpu = int(resource_limits["cpu"])
+    memory_str = resource_limits["memory"]
+
+    if cpu not in [1, 2, 4, 6, 8]:
+        raise ValueError(
+            f"resource_limits['cpu'] must be one of 1, 2, 4, 6, 8. Got {cpu}"
+        )
+
+    if not isinstance(memory_str, str) or not memory_str.endswith("Gi"):
+        raise ValueError(
+            "resource_limits['memory'] must be a string ending with 'Gi'."
+            f" Got {memory_str}"
+        )
+
+    try:
+        memory_gb = int(memory_str[:-2])
+    except ValueError as exc:
+        raise ValueError(
+            f"Invalid memory value: {memory_str}. Must be an integer"
+            " followed by 'Gi'."
+        ) from exc
+
+    # https://cloud.google.com/run/docs/configuring/memory-limits
+    if memory_gb > 32:
+        raise ValueError(
+            f"Memory size of {memory_str} is too large. Must not exceed 32Gi."
+        )
+    if memory_gb > 24:
+        min_cpu = 8
+    elif memory_gb > 16:
+        min_cpu = 6
+    elif memory_gb > 8:
+        min_cpu = 4
+    elif memory_gb > 4:
+        min_cpu = 2
+    else:
+        min_cpu = 1
+
+    if cpu < min_cpu:
+        raise ValueError(
+            f"Memory size of {memory_str} requires at least {min_cpu} CPUs."
+            f" Got {cpu}"
+        )
+
+
+def _is_adk_agent(agent_engine: _AgentEngineInterface) -> bool:
+    """Checks if the agent engine is an ADK agent.
+
+    Args:
+        agent_engine: The agent engine to check.
+
+    Returns:
+        True if the agent engine is an ADK agent, False otherwise.
+    """
+
+    from agentplatform.agent_engines.templates import adk
+
+    return isinstance(agent_engine, adk.AdkApp)
+
+
+def _add_telemetry_enablement_env(
+    env_vars: Optional[Dict[str, Union[str, Any]]]
+) -> Optional[Dict[str, Union[str, Any]]]:
+    """Adds the telemetry enablement env var to the env vars.
+
+    This makes telemetry enabled by default. If the telemetry enablement env
+    var is already set, we do not override it.
+
+    Args:
+        env_vars: The env vars to add the telemetry enablement env var to.
+
+    Returns:
+        The env vars with the telemetry enablement env var added.
+
+    Raises:
+        TypeError: If env_vars is not a dict.
+    """
+
+    GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
+        "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
+    )
+    env_to_add = {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}
+
+    if env_vars is None:
+        return env_to_add
+
+    if not isinstance(env_vars, dict):
+        raise TypeError(f"env_vars must be a dict, but got {type(env_vars)}.")
+
+    if GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY in env_vars:
+        return env_vars
+
+    return env_vars | env_to_add
diff --git a/agentplatform/_genai/_bigquery_utils.py b/agentplatform/_genai/_bigquery_utils.py
new file mode 100644
index 0000000000..94813df432
--- /dev/null
+++ b/agentplatform/_genai/_bigquery_utils.py
@@ -0,0 +1,49 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import logging + +from google.cloud import bigquery +from google.genai._api_client import BaseApiClient +import pandas as pd + + +logger = logging.getLogger(__name__) + + +class BigQueryUtils: + """Handles BigQuery operations.""" + + def __init__(self, api_client: BaseApiClient): + self.api_client = api_client + self.bigquery_client = bigquery.Client( + project=self.api_client.project, + credentials=self.api_client._credentials, + ) + + def load_bigquery_to_dataframe(self, table_uri: str) -> "pd.DataFrame": + """Loads data from a BigQuery table into a DataFrame.""" + table = self.bigquery_client.get_table(table_uri) + return self.bigquery_client.list_rows(table).to_dataframe() + + def upload_dataframe_to_bigquery( + self, df: "pd.DataFrame", bq_table_uri: str + ) -> None: + """Uploads a Pandas DataFrame to a BigQuery table.""" + job = self.bigquery_client.load_table_from_dataframe(df, bq_table_uri) + job.result() + logger.info( + f"DataFrame successfully uploaded to BigQuery table: {bq_table_uri}" + ) diff --git a/agentplatform/_genai/_datasets_utils.py b/agentplatform/_genai/_datasets_utils.py new file mode 100644 index 0000000000..7b0774eeaa --- /dev/null +++ b/agentplatform/_genai/_datasets_utils.py @@ -0,0 +1,280 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Utility functions for multimodal dataset.""" + +import asyncio +import datetime +from typing import Any, Type, TypeVar +import uuid + +import google.auth.credentials +from agentplatform._genai.types import common +from google.genai import _common + + +METADATA_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" +) +_BQ_MULTIREGIONS = {"us", "eu"} +_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets" +_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset" + +T = TypeVar("T", bound=_common.BaseModel) + + +def create_from_response( + model_type: Type[T], + response: dict[str, Any], + config: Any | None = None, +) -> T: + """Creates a model from a response.""" + kwargs = ( + { + "config": { + "response_schema": getattr(config, "response_schema", None), + "response_json_schema": getattr(config, "response_json_schema", None), + "include_all_fields": getattr(config, "include_all_fields", None), + } + } + if config + else {} + ) + return model_type._from_response(response=response, kwargs=kwargs) + + +def validate_multimodal_dataset_bigquery_uri( + multimodal_dataset: common.MultimodalDataset, +) -> None: + """Validates that a multimodal dataset has a bigquery uri or raises ValueError.""" + if ( + not hasattr(multimodal_dataset, "metadata") + or multimodal_dataset.metadata is None + ): + raise ValueError("Multimodal dataset metadata is required.") + if ( + not hasattr(multimodal_dataset.metadata, "input_config") + or multimodal_dataset.metadata.input_config is None + ): + raise ValueError("Multimodal dataset input config is required.") + if ( + not hasattr(multimodal_dataset.metadata.input_config, "bigquery_source") + or multimodal_dataset.metadata.input_config.bigquery_source is None + ): + raise ValueError("Multimodal dataset input config bigquery source is required.") + if ( + not hasattr(multimodal_dataset.metadata.input_config.bigquery_source, "uri") + or multimodal_dataset.metadata.input_config.bigquery_source.uri is None + ): + raise ValueError( + "Multimodal dataset input config bigquery source uri is required." + ) + if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith( + "bq://" + ): + raise ValueError( + "Multimodal dataset bigquery source uri must start with 'bq://'." + ) + + +def _try_import_bigframes() -> Any: + """Tries to import `bigframes`.""" + try: + import bigframes + import bigframes.pandas + import bigframes.bigquery + + return bigframes + except ImportError as exc: + raise ImportError( + "`bigframes` is not installed. Please call 'pip install bigframes'." + ) from exc + + +def _try_import_bigquery() -> Any: + """Tries to import `bigquery`.""" + try: + from google.cloud import bigquery + + return bigquery + except ImportError as exc: + raise ImportError( + "`bigquery` is not installed. Please call 'pip install" + " google-cloud-bigquery'." 
+ ) from exc + + +def _bq_dataset_location_allowed( + vertex_location: str, bq_dataset_location: str +) -> bool: + if bq_dataset_location == vertex_location: + return True + if bq_dataset_location in _BQ_MULTIREGIONS: + return vertex_location.startswith(bq_dataset_location) + return False + + +def _normalize_and_validate_table_id( + *, + table_id: str, + project: str, + location: str, + credentials: google.auth.credentials.Credentials, +) -> str: + bigquery = _try_import_bigquery() + + table_ref = bigquery.TableReference.from_string(table_id, default_project=project) + if table_ref.project != project: + raise ValueError( + "The BigQuery table " + f"`{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}`" + " must be in the same project as the multimodal dataset." + f" The multimodal dataset is in `{project}`, but the BigQuery table" + f" is in `{table_ref.project}`." + ) + + dataset_ref = bigquery.DatasetReference( + project=table_ref.project, dataset_id=table_ref.dataset_id + ) + client = bigquery.Client(project=project, credentials=credentials) + bq_dataset = client.get_dataset(dataset_ref=dataset_ref) + if not _bq_dataset_location_allowed(location, bq_dataset.location): + raise ValueError( + "The BigQuery dataset" + f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the" + " same location as the multimodal dataset. The multimodal dataset" + f" is in `{location}`, but the BigQuery dataset is in" + f" `{bq_dataset.location}`." + ) + return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" + + +async def _normalize_and_validate_table_id_async( + *, + table_id: str, + project: str, + location: str, + credentials: google.auth.credentials.Credentials, +) -> str: + bigquery = _try_import_bigquery() + + table_ref = bigquery.TableReference.from_string(table_id, default_project=project) + if table_ref.project != project: + raise ValueError( + "The BigQuery table " + f"`{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}`" + " must be in the same project as the multimodal dataset." + f" The multimodal dataset is in `{project}`, but the BigQuery table" + f" is in `{table_ref.project}`." + ) + + dataset_ref = bigquery.DatasetReference( + project=table_ref.project, dataset_id=table_ref.dataset_id + ) + client = bigquery.Client(project=project, credentials=credentials) + bq_dataset = await asyncio.to_thread(client.get_dataset, dataset_ref=dataset_ref) + if not _bq_dataset_location_allowed(location, bq_dataset.location): + raise ValueError( + "The BigQuery dataset" + f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the" + " same location as the multimodal dataset. The multimodal dataset" + f" is in `{location}`, but the BigQuery dataset is in" + f" `{bq_dataset.location}`." 
+ ) + return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" + + +def _create_default_bigquery_dataset_if_not_exists( + *, + project: str, + location: str, + credentials: google.auth.credentials.Credentials, +) -> str: + bigquery = _try_import_bigquery() + + bigquery_client = bigquery.Client(project=project, credentials=credentials) + location_str = location.lower().replace("-", "_") + dataset_id = bigquery.DatasetReference( + project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{location_str}" + ) + dataset = bigquery.Dataset(dataset_ref=dataset_id) + dataset.location = location + bigquery_client.create_dataset(dataset, exists_ok=True) + return f"{dataset_id.project}.{dataset_id.dataset_id}" + + +async def _create_default_bigquery_dataset_if_not_exists_async( + *, + project: str, + location: str, + credentials: google.auth.credentials.Credentials, +) -> str: + bigquery = _try_import_bigquery() + + bigquery_client = bigquery.Client(project=project, credentials=credentials) + location_str = location.lower().replace("-", "_") + dataset_id = bigquery.DatasetReference( + project, f"{_DEFAULT_BQ_DATASET_PREFIX}_{location_str}" + ) + dataset = bigquery.Dataset(dataset_ref=dataset_id) + dataset.location = location + await asyncio.to_thread(bigquery_client.create_dataset, dataset, exists_ok=True) + return f"{dataset_id.project}.{dataset_id.dataset_id}" + + +def _generate_target_table_id(dataset_id: str) -> str: + return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{str(uuid.uuid4())}" + + +def generate_multimodal_dataset_display_name() -> str: + """Generates a display name with a timestamp.""" + return f"MultimodalDataset {datetime.datetime.now().isoformat(sep=' ')}" + + +def save_dataframe_to_bigquery( + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + target_table_id: str, + bq_client: "bigquery.Client", # type: ignore # noqa: F821 +) -> None: + # `to_gbq` does not support cross-region use cases. We use `copy_table` as a workaround. + temp_table_id = dataframe.to_gbq() + copy_job = bq_client.copy_table( + sources=temp_table_id, + destination=target_table_id, + ) + copy_job.result() + bq_client.delete_table(temp_table_id) + + +async def save_dataframe_to_bigquery_async( + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + target_table_id: str, + bq_client: "bigquery.Client", # type: ignore # noqa: F821 +) -> None: + # `to_gbq` does not support cross-region use cases. We use `copy_table` as a workaround. + temp_table_id = await asyncio.to_thread(dataframe.to_gbq) + copy_job = await asyncio.to_thread( + bq_client.copy_table, + sources=temp_table_id, + destination=target_table_id, + ) + await asyncio.to_thread(copy_job.result) + await asyncio.to_thread(bq_client.delete_table, temp_table_id) + + +def resolve_dataset_name(resource_name_or_id: str, project: str, location: str) -> str: + """Resolves a dataset name or ID to a full resource name.""" + if "/" not in resource_name_or_id: + return f"projects/{project}/locations/{location}/datasets/{resource_name_or_id}" + return resource_name_or_id diff --git a/agentplatform/_genai/_evals_common.py b/agentplatform/_genai/_evals_common.py new file mode 100644 index 0000000000..cfe31fe351 --- /dev/null +++ b/agentplatform/_genai/_evals_common.py @@ -0,0 +1,2975 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Common utilities for evals.""" + +import asyncio +import base64 +import collections +import concurrent.futures +import contextlib +import datetime +import json +import logging +import os +import threading +import time +from typing import Any, Callable, Literal, Optional, Union, cast +import uuid + +from google.api_core import exceptions as api_exceptions +import agentplatform +from google.genai import types as genai_types +from google.genai._api_client import BaseApiClient +from google.genai.models import Models +import pandas as pd +from tqdm import tqdm +from pydantic import ValidationError + +from . import _evals_constant +from . import _evals_data_converters +from . import _evals_metric_handlers +from . import _evals_metric_loaders +from . import _evals_utils +from . import _gcs_utils +from . import evals +from . import types +from . import _transformers as t + +logger = logging.getLogger(__name__) + +try: + import litellm +except ImportError: + litellm = None + + +_thread_local_data = threading.local() + +MAX_WORKERS = 100 +AGENT_MAX_WORKERS = 20 +CONTENT = _evals_constant.CONTENT +PARTS = _evals_constant.PARTS +USER_AUTHOR = _evals_constant.USER_AUTHOR +AGENT_DATA = _evals_constant.AGENT_DATA + + +@contextlib.contextmanager +def _temp_logger_level(logger_name: str, level: int) -> None: # type: ignore[misc] + """Temporarily sets the level of a logger.""" + logger_instance = logging.getLogger(logger_name) + original_level = logger_instance.getEffectiveLevel() + logger_instance.setLevel(level) + try: + yield + finally: + logger_instance.setLevel(original_level) + + +def _get_api_client_with_location( + api_client: BaseApiClient, location: Optional[str] +) -> BaseApiClient: + """Returns a new API client with the specified location.""" + if not location or location == api_client.location: + return api_client + + logger.info( + "Model endpoint location set to %s, overriding client location %s for" + " this API call.", + location, + api_client.location, + ) + return agentplatform.Client( # type: ignore[no-any-return] + project=api_client.project, + location=location, + credentials=api_client._credentials, + http_options=api_client._http_options, + )._api_client + + +def _get_agent_engine_instance( + agent_name: str, api_client: BaseApiClient +) -> Union[types.AgentEngine, Any]: + """Gets or creates an agent engine instance for the current thread.""" + if not hasattr(_thread_local_data, "agent_engine_instances"): + _thread_local_data.agent_engine_instances = {} + if agent_name not in _thread_local_data.agent_engine_instances: + client = agentplatform.Client( + project=api_client.project, + location=api_client.location, + ) + _thread_local_data.agent_engine_instances[agent_name] = ( + client.agent_engines.get(name=agent_name) + ) + return _thread_local_data.agent_engine_instances[agent_name] + + +def _generate_content_with_retry( + api_client: BaseApiClient, + model: str, + contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], + config: Optional[genai_types.GenerateContentConfig] = None, + max_retries: int = 3, +) -> 
Union[genai_types.GenerateContentResponse, dict[str, Any]]: + """Generates content using the model's generate_content with retries.""" + models_module = Models(api_client_=api_client) + + for attempt in range(max_retries): + try: + response = models_module.generate_content( + model=model, + contents=contents, + config=config, + ) + if not response.candidates: + logger.warning( + "Prompt blocked. Attempt %d/%d. Feedback: %s. Prompt: %s.", + attempt + 1, + max_retries, + response.prompt_feedback, + contents, + ) + if attempt == max_retries - 1: + feedback_dict = {} + if response.prompt_feedback: + feedback_dict = response.prompt_feedback.model_dump( + mode="json", exclude_none=True + ) + return { + "error": "Prompt blocked after retries", + "prompt_feedback": feedback_dict, + } + else: + candidate = response.candidates[0] + if candidate.finish_reason not in ( + genai_types.FinishReason.STOP, + genai_types.FinishReason.MAX_TOKENS, + genai_types.FinishReason.FINISH_REASON_UNSPECIFIED, + ): + logger.warning( + "Generate content did not finish successfully." + "Finish reason: %s. Finish message: %s." + "Retry attempt: %d/%d", + candidate.finish_reason, + candidate.finish_message, + attempt + 1, + max_retries, + ) + if attempt == max_retries - 1: + return { + "error": ( + "Generate content unsuccessful after retries:" + f" {candidate.finish_reason}" + ), + "finish_reason": str(candidate.finish_reason), + "finish_message": candidate.finish_message or "", + } + else: + return response + except api_exceptions.ResourceExhausted as e: + logger.warning( + "Resource Exhausted error on attempt %d/%d: %s. Retrying in %s" + " seconds...", + attempt + 1, + max_retries, + e, + 2**attempt, + ) + if attempt == max_retries - 1: + return {"error": f"Resource exhausted after retries: {e}"} + time.sleep(2**attempt) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Unexpected error during generate_content on attempt %d/%d: %s", + attempt + 1, + max_retries, + e, + ) + + if attempt == max_retries - 1: + return {"error": f"Failed after retries: {e}"} + time.sleep(1) + return {"error": f"Failed to generate content after {max_retries} retries"} + + +def _build_generate_content_config( + request_dict: Union[dict[str, Any], str], + global_config: Optional[genai_types.GenerateContentConfig] = None, +) -> genai_types.GenerateContentConfig: + """Builds a GenerateContentConfig from the request dictionary or provided config.""" + if global_config: + # If a global config is provided, apply it as a base config. Parts of + # the global config can be overridden by providing configs in the + # request. 
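+        # For example (hypothetical values): with a global config of
+        # {"temperature": 0.2} and a request containing
+        # {"generation_config": {"temperature": 0.9}}, the merged config
+        # uses temperature=0.9 because request-level settings win.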
+ merged_config_dict = global_config.model_dump(exclude_none=True) + else: + merged_config_dict = {} + + if not isinstance(request_dict, dict): + return genai_types.GenerateContentConfig(**merged_config_dict) + + for key in [ + "system_instruction", + "tools", + "tools_config", + "safety_settings", + "labels", + ]: + if key in request_dict: + merged_config_dict[key] = request_dict[key] + if "generation_config" in request_dict and isinstance( + request_dict["generation_config"], dict + ): + merged_config_dict.update(request_dict["generation_config"]) + if "labels" in request_dict: + merged_config_dict["labels"] = request_dict["labels"] + + return genai_types.GenerateContentConfig(**merged_config_dict) + + +def _extract_contents_for_inference( + request_dict_or_raw_text: Any, +) -> Any: + """Extracts contents from a request dictionary or returns the raw text.""" + if not request_dict_or_raw_text: + raise ValueError("Prompt cannot be empty.") + if isinstance(request_dict_or_raw_text, dict): + contents_for_fn = request_dict_or_raw_text.get("contents", None) + if not contents_for_fn: + raise ValueError("Contents in the request cannot be empty.") + return contents_for_fn + else: + return request_dict_or_raw_text + + +def _eval_cases_to_dataframe( + eval_cases: list[types.EvalCase], +) -> pd.DataFrame: + """Converts a list of EvalCase objects to a pandas DataFrame. + + Each EvalCase is converted to a row in the DataFrame. Structured fields + like ``agent_data`` are preserved as-is (not flattened) so that downstream + agent execution paths can consume them directly. + + Args: + eval_cases: The list of EvalCase objects to convert. + + Returns: + A DataFrame with one row per EvalCase. + """ + rows = [] + for case in eval_cases: + row: dict[str, Any] = {} + if case.prompt: + row[_evals_constant.PROMPT] = _evals_data_converters._get_content_text( + case.prompt + ) + + if case.responses and len(case.responses) > 0 and case.responses[0].response: + row[_evals_constant.RESPONSE] = _evals_data_converters._get_content_text( + case.responses[0].response + ) + + if case.reference and case.reference.response: + row[_evals_constant.REFERENCE] = _evals_data_converters._get_content_text( + case.reference.response + ) + + if case.agent_data: + row[AGENT_DATA] = case.agent_data + + if case.intermediate_events: + row[_evals_constant.INTERMEDIATE_EVENTS] = [ + {CONTENT: event.content} + for event in case.intermediate_events + if event.content + ] + + if case.conversation_history: + history_parts = [] + for msg in case.conversation_history: + if msg.content: + role = msg.content.role or "user" + text = _evals_data_converters._get_content_text(msg.content) + history_parts.append(f"{role}: {text}") + if history_parts: + row[_evals_constant.CONVERSATION_HISTORY] = "\n".join(history_parts) + + if case.user_scenario: + if case.user_scenario.starting_prompt: + row[_evals_constant.STARTING_PROMPT] = ( + case.user_scenario.starting_prompt + ) + if case.user_scenario.conversation_plan: + row[_evals_constant.CONVERSATION_PLAN] = ( + case.user_scenario.conversation_plan + ) + + rows.append(row) + return pd.DataFrame(rows) + + +def _extract_prompt_from_agent_data( + agent_data: types.evals.AgentData, +) -> tuple[genai_types.Content, list[types.evals.AgentEvent]]: + """Extracts the last user message and prior events from agent_data. + + The last event across all turns must be authored by ``"user"``; it is + treated as the current prompt that the agent should respond to. + Everything before it is returned as conversation history. 
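+
+    For example, if the turns contain the events
+    ``[user: "hi", agent: "hello", user: "thanks, bye?"]``, the returned
+    prompt is the content of the final user event and the history is the
+    first two events.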
+ + Args: + agent_data: The AgentData containing conversation turns. + + Returns: + A tuple of ``(last_user_content, history_events)`` where + ``last_user_content`` is the ``Content`` of the final user event + and ``history_events`` is the ordered list of all prior + ``AgentEvent`` objects. + + Raises: + ValueError: If ``agent_data`` has no turns, no events, or the last + event is not a user event. + """ + if not agent_data.turns: + raise ValueError("agent_data must have at least one turn.") + + all_events: list[types.evals.AgentEvent] = [] + for turn in agent_data.turns: + if turn.events: + all_events.extend(turn.events) + + if not all_events: + raise ValueError("agent_data turns contain no events.") + + last_event = all_events[-1] + if last_event.author != USER_AUTHOR: + raise ValueError( + "agent_data must end with a user event, but the last event has" + f" author='{last_event.author}'." + ) + + if not last_event.content: + raise ValueError("The last user event in agent_data has no content.") + + return last_event.content, all_events[:-1] + + +def _is_n_plus_1_inference( + agent_data: Union[types.evals.AgentData, dict[str, Any]], +) -> bool: + """Returns True if agent_data represents an N+1 inference case. + + An N+1 case means the trace is incomplete: N prior conversation turns + exist plus 1 final user query that the agent should respond to. This + is detected by checking whether the very last event across all turns + is authored by ``"user"``. + + Returns ``False`` for completed traces (last event from the agent), + empty traces, or invalid data. + """ + if isinstance(agent_data, dict): + try: + agent_data = types.evals.AgentData.model_validate(agent_data) + except Exception: # pylint: disable=broad-exception-caught + return False + if not isinstance(agent_data, types.evals.AgentData): + return False + if not agent_data.turns: + return False + all_events: list[types.evals.AgentEvent] = [] + for turn in agent_data.turns or []: + if turn.events: + all_events.extend(turn.events) + if not all_events: + return False + return all_events[-1].author == USER_AUTHOR + + +def _extract_response_from_completed_trace( + agent_data: types.evals.AgentData, +) -> list[dict[str, Any]]: + """Extracts all events from a completed agent trace as event dicts. + + For BYOD (bring-your-own-data) use cases where the agent trace is + already complete, this returns all events formatted as a list of + dicts compatible with ``_process_single_turn_agent_response``. The + last element is the final agent response; preceding elements become + intermediate events. 
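+
+    For example, a completed trace with events ``[user, tool_call, agent]``
+    yields three event dicts, and the trailing agent event is treated as the
+    final response.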
+ """ + event_dicts: list[dict[str, Any]] = [] + for turn in agent_data.turns or []: + if not turn.events: + continue + for event in turn.events: + d: dict[str, Any] = {"author": event.author or "agent"} + if event.content: + d[CONTENT] = event.content.model_dump(exclude_none=True) + event_dicts.append(d) + return event_dicts + + +def _resolve_dataset( + api_client: BaseApiClient, + dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset], + dest: str, + parsed_agent_info: Optional[types.evals.AgentInfo] = None, +) -> types.EvaluationRunDataSource: + """Resolves dataset for the evaluation run.""" + if isinstance(dataset, types.EvaluationDataset): + candidate_name = _get_candidate_name(dataset, parsed_agent_info) + eval_df = dataset.eval_dataset_df + if eval_df is None and dataset.eval_cases: + eval_df = _eval_cases_to_dataframe(dataset.eval_cases) + + eval_set = _create_evaluation_set_from_dataframe( + api_client, + dest, + eval_df, + candidate_name, + ) + dataset = types.EvaluationRunDataSource(evaluation_set=eval_set.name) + return dataset + + +def _get_default_prompt_template( + api_client: BaseApiClient, + inference_config: types.EvaluationRunInferenceConfigOrDict, + dataset: types.EvaluationRunDataSource, +) -> Any: + """Resolves prompt template data for the evaluation run.""" + if isinstance(inference_config, dict): + if inference_config.get("prompt_template"): + return inference_config["prompt_template"] + elif inference_config.prompt_template: + return inference_config.prompt_template + + try: + evals_module = evals.Evals(api_client_=api_client) + eval_set = evals_module.get_evaluation_set(name=dataset.evaluation_set) + if eval_set and eval_set.evaluation_items: + eval_item = evals_module.get_evaluation_item( + name=eval_set.evaluation_items[0] + ) + if ( + eval_item + and eval_item.evaluation_request + and eval_item.evaluation_request.prompt + and eval_item.evaluation_request.prompt.prompt_template_data + ): + if ( + "prompt" + in eval_item.evaluation_request.prompt.prompt_template_data.values + ): + return "{prompt}" + except Exception as e: + logger.warning("Failed to get prompt template from evaluation set: %s", e) + return None + + +def _resolve_inference_configs( + api_client: BaseApiClient, + dataset: types.EvaluationRunDataSource, + inference_configs: Optional[ + dict[str, types.EvaluationRunInferenceConfigOrDict] + ] = None, + parsed_agent_info: Optional[types.evals.AgentInfo] = None, +) -> Optional[dict[str, types.EvaluationRunInferenceConfigOrDict]]: + """Resolves inference configs for the evaluation run.""" + # Resolve agent config + if parsed_agent_info and parsed_agent_info.name: + if inference_configs is None: + inference_configs = {} + + # We might have used "candidate-1" as a placeholder key in the caller, + # let's migrate it to the agent name, or if it doesn't exist, just create it. 
+ if "candidate-1" in inference_configs: + inference_configs[parsed_agent_info.name] = inference_configs.pop( + "candidate-1" + ) + + if parsed_agent_info.name not in inference_configs: + inference_configs[parsed_agent_info.name] = ( + types.EvaluationRunInferenceConfig( + agent_configs=parsed_agent_info.agents + ) + ) + else: + config = inference_configs[parsed_agent_info.name] + if isinstance(config, dict): + config["agent_configs"] = parsed_agent_info.agents + else: + config.agent_configs = parsed_agent_info.agents + + # Resolve prompt template data + if inference_configs: + for inference_config in inference_configs.values(): + prompt_template_val = ( + inference_config.get("prompt_template") + if isinstance(inference_config, dict) + else inference_config.prompt_template + ) + if not prompt_template_val: + default_prompt_template = _get_default_prompt_template( + api_client, inference_config, dataset + ) + if default_prompt_template: + prompt_template_to_set = default_prompt_template + if not isinstance( + default_prompt_template, types.EvaluationRunPromptTemplate + ): + prompt_template_to_set = types.EvaluationRunPromptTemplate( + prompt_template=default_prompt_template + ) + if isinstance(inference_config, dict): + inference_config["prompt_template"] = ( + prompt_template_to_set.model_dump(exclude_none=True) + ) + else: + inference_config.prompt_template = ( + prompt_template_to_set.model_dump(exclude_none=True) + ) + return inference_configs + + +def _add_evaluation_run_labels( + labels: Optional[dict[str, str]] = None, + agent: Optional[str] = None, +) -> Optional[dict[str, str]]: + """Adds labels to the evaluation run.""" + if agent: + labels = labels or {} + labels["vertex-ai-evaluation-agent-engine-id"] = agent.split( + "reasoningEngines/" + )[-1] + return labels + + +def _get_candidate_name( + dataset: types.EvaluationDataset, + parsed_agent_info: Optional[types.evals.AgentInfo] = None, +) -> Optional[str]: + """Internal helper to get candidate name.""" + if parsed_agent_info is not None and ( + dataset.candidate_name + and parsed_agent_info + and parsed_agent_info.name + and dataset.candidate_name != parsed_agent_info.name + ): + logger.warning( + "Evaluation dataset candidate_name and agent_info.name are different. Please make sure this is intended." 
+ ) + elif dataset.candidate_name is None and parsed_agent_info: + return parsed_agent_info.name + return dataset.candidate_name or None + + +def _execute_inference_concurrently( + api_client: BaseApiClient, + prompt_dataset: pd.DataFrame, + progress_desc: str, + model_or_fn: Optional[Union[str, Callable[[Any], Any]]] = None, + gemini_config: Optional[genai_types.GenerateContentConfig] = None, + inference_fn: Optional[Callable[..., Any]] = None, + agent_engine: Optional[Union[str, types.AgentEngine]] = None, + agent: Optional["LlmAgent"] = None, # type: ignore # noqa: F821 + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, +) -> list[ + Union[ + genai_types.GenerateContentResponse, + dict[str, Any], + list[dict[str, Any]], + ] +]: + """Internal helper to run inference with concurrency.""" + logger.info( + "Generating responses for %d prompts using model or function: %s", + len(prompt_dataset), + model_or_fn, + ) + responses: list[ + Union[ + genai_types.GenerateContentResponse, + dict[str, Any], + list[dict[str, Any]], + None, + ] + ] = [None] * len(prompt_dataset) + tasks = [] + + # When running with an agent and agent_data is present, we extract the + # prompt from the structured agent_data rather than requiring a flat + # prompt/request column. + has_agent_data = ( + agent is not None or agent_engine is not None + ) and AGENT_DATA in prompt_dataset.columns + + primary_prompt_column: Optional[str] = None + if "request" in prompt_dataset.columns: + primary_prompt_column = "request" + elif "prompt" in prompt_dataset.columns: + primary_prompt_column = "prompt" + elif "starting_prompt" in prompt_dataset.columns: + primary_prompt_column = "starting_prompt" + elif not has_agent_data: + raise ValueError( + "Dataset must contain either 'prompt', 'request', or" + " 'starting_prompt'." + f" Found: {prompt_dataset.columns.tolist()}" + ) + + max_workers = AGENT_MAX_WORKERS if agent_engine or agent else MAX_WORKERS + with tqdm(total=len(prompt_dataset), desc=progress_desc) as pbar: + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + for index, row in prompt_dataset.iterrows(): + try: + if ( + has_agent_data + and AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + ): + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate( + agent_data_obj + ) + if _is_n_plus_1_inference(agent_data_obj): + last_user_content, _ = _extract_prompt_from_agent_data( + agent_data_obj + ) + contents = _evals_data_converters._get_content_text( + last_user_content + ) + else: + logger.info( + "Row %s has a completed agent trace" + " (last event is not from user)." + " Skipping inference and using existing" + " agent response.", + index, + ) + responses[index] = _extract_response_from_completed_trace( + agent_data_obj + ) + pbar.update(1) + continue + else: + if primary_prompt_column is None: + raise ValueError( + "Row has no agent_data and dataset has no" + " 'prompt', 'request', or 'starting_prompt'" + " column." + ) + request_dict_or_raw_text = row[primary_prompt_column] + contents = _extract_contents_for_inference( + request_dict_or_raw_text + ) + except ValueError as e: + error_message = ( + f"Failed to extract contents for prompt at index" + f" {index}: {e}. Skipping prompt." 
+ ) + logger.error(error_message) + responses[index] = {"error": error_message} + pbar.update(1) + continue + + if agent_engine or agent: + + def agent_run_wrapper( # type: ignore[no-untyped-def] + row_arg, + contents_arg, + agent_engine_arg, + agent_arg, + inference_fn_arg, + api_client_arg, + user_simulator_config_arg, + ) -> Any: + if agent_engine_arg: + if isinstance(agent_engine_arg, str): + agent_engine_instance = _get_agent_engine_instance( + agent_engine_arg, api_client_arg + ) + else: + agent_engine_instance = agent_engine_arg + + return inference_fn_arg( + row=row_arg, + contents=contents_arg, + agent_engine=agent_engine_instance, + ) + elif agent_arg: + return inference_fn_arg( + row=row_arg, + contents=contents_arg, + user_simulator_config=user_simulator_config_arg, + agent=agent_arg, + ) + + future = executor.submit( + agent_run_wrapper, + row, + contents, + agent_engine, + agent, + inference_fn, + api_client, + user_simulator_config, + ) + elif isinstance(model_or_fn, str): + generation_content_config = _build_generate_content_config( + request_dict_or_raw_text, + gemini_config, + ) + future = executor.submit( + inference_fn, + api_client=api_client, + model=model_or_fn, + contents=contents, + config=generation_content_config, + ) + else: + future = executor.submit(model_or_fn, contents) + future.add_done_callback(lambda _: pbar.update(1)) + tasks.append((future, index)) + + for future, index in tasks: + try: + result = future.result() + responses[index] = result + except Exception as e: + logger.error( + "Error processing prompt at index %d: %s", + index, + e, + ) + responses[index] = {"error": f"Inference task failed: {e}"} + return responses # type: ignore[return-value] + + +def _run_gemini_inference( + api_client: BaseApiClient, + model: str, + prompt_dataset: pd.DataFrame, + config: Optional[genai_types.GenerateContentConfig] = None, +) -> list[ + Union[ + genai_types.GenerateContentResponse, + dict[str, Any], + list[dict[str, Any]], + ] +]: + """Internal helper to run inference using Gemini model with concurrency.""" + return _execute_inference_concurrently( + api_client=api_client, + model_or_fn=model, + prompt_dataset=prompt_dataset, + progress_desc="Gemini Inference", + gemini_config=config, + inference_fn=_generate_content_with_retry, + ) + + +def _run_custom_inference( + model_fn: Callable[[Any], Any], + prompt_dataset: pd.DataFrame, +) -> list[Any]: + """Internal helper to run inference using a custom function with concurrency.""" + return _execute_inference_concurrently( + api_client=None, + model_or_fn=model_fn, + prompt_dataset=prompt_dataset, + progress_desc="Custom Inference", + ) + + +def _convert_prompt_row_to_litellm_messages( + row: pd.Series, +) -> list[dict[str, Any]]: + """Converts a DataFrame row into LiteLLM's messages format by detecting the input schema.""" + messages: list[dict[str, Any]] = [] + row_dict = row.to_dict() + + # Case 1: The row is an OpenAI request body itself. + if "messages" in row_dict and isinstance(row_dict.get("messages"), list): + return row_dict["messages"] # type: ignore[no-any-return] + + # Case 2: The row contains a 'request' key with an OpenAI request body. + elif "request" in row_dict and isinstance(row_dict.get("request"), dict): + request_body = row_dict["request"] + if "messages" in request_body and isinstance( + request_body.get("messages"), list + ): + return request_body["messages"] # type: ignore[no-any-return] + + # Case 3: The 'request' key is in Gemini 'contents' format. 
+ elif "contents" in request_body and isinstance( + request_body.get("contents"), list + ): + for content in request_body["contents"]: + role = content.get("role", USER_AUTHOR) + text_parts = [part.get("text", "") for part in content.get("parts", [])] + messages.append({"role": role, "content": " ".join(text_parts)}) + return messages + + # Case 4: Fallback to a simple 'prompt' key with a raw string. + elif "prompt" in row_dict and isinstance(row_dict.get("prompt"), str): + return [{"role": USER_AUTHOR, "content": row_dict["prompt"]}] + + raise ValueError( + "Could not determine prompt/messages format from input row. Expected" + " OpenAI request body with a 'messages' key, or a 'request' key with" + " OpenAI request body, or Gemini request body with a 'contents' key, or" + f" a 'prompt' key with a raw string. Found keys: {list(row_dict.keys())}" + ) + + +def _call_litellm_completion( + model: str, messages: list[dict[str, Any]] +) -> dict[str, Any]: + """Wrapper for a single litellm.completion call.""" + try: + response = litellm.completion(model=model, messages=messages) + return response.model_dump() # type: ignore[no-any-return] + except Exception as e: + logger.error("LiteLLM completion failed for model %s: %s", model, e) + return {"error": str(e)} + + +def _run_litellm_inference( + model: str, prompt_dataset: pd.DataFrame +) -> list[Optional[dict[str, Any]]]: + """Runs inference using LiteLLM with concurrency.""" + logger.info( + "Generating responses for %d prompts using LiteLLM for third party model: %s", + len(prompt_dataset), + model, + ) + responses: list[Optional[dict[str, Any]]] = [None] * len(prompt_dataset) + tasks = [] + + with tqdm(total=len(prompt_dataset), desc=f"LiteLLM Inference ({model})") as pbar: + with concurrent.futures.ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor: + for index, row in prompt_dataset.iterrows(): + messages = _convert_prompt_row_to_litellm_messages(row) + future = executor.submit( + _call_litellm_completion, model=model, messages=messages + ) + future.add_done_callback(lambda _: pbar.update(1)) + tasks.append((future, index)) + + for future, index in tasks: + try: + result = future.result() + responses[index] = result + except Exception as e: + logger.error("Error processing prompt at index %d: %s", index, e) + responses[index] = {"error": f"LiteLLM task failed: {e}"} + + return responses + + +def _is_litellm_vertex_maas_model(model: str) -> bool: + """Checks if the model is a Vertex MAAS model to be handled by LiteLLM.""" + return any( + model.startswith(prefix) + for prefix in _evals_constant.SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES + ) + + +def _is_litellm_model(model: str) -> bool: + """Checks if the model name corresponds to a valid LiteLLM model name.""" + if litellm is None: + return False + + try: + litellm.get_llm_provider(model) + return True + except ValueError: + return False + + +def _is_gemini_model(model: str) -> bool: + """Checks if the model name corresponds to a Gemini/Vertex AI model.""" + return ( + model.startswith("gemini-") + or model.startswith("projects/") + or model.startswith("models/") + or model.startswith("publishers/") + or model.startswith("tunedModels/") + ) + + +def _run_inference_internal( + api_client: BaseApiClient, + model: Union[Callable[[Any], Any], str], + prompt_dataset: pd.DataFrame, + config: Optional[genai_types.GenerateContentConfig] = None, +) -> pd.DataFrame: + """Runs inference on a given dataset using the specified model or function.""" + + if isinstance(model, str) and _is_gemini_model(model): 
+ if ( + "prompt" not in prompt_dataset.columns + and "request" not in prompt_dataset.columns + and "starting_prompt" not in prompt_dataset.columns + ): + raise ValueError( + "Prompt dataset for Gemini model must contain either 'prompt'," + " 'request' or 'starting_prompt' column for inference. " + f"Found columns: {prompt_dataset.columns.tolist()}" + ) + + logger.info("Running inference with Gemini model name: %s", model) + raw_responses = _run_gemini_inference( + api_client=api_client, + model=model, + prompt_dataset=prompt_dataset, + config=config, + ) + processed_responses = [] + for resp_item in raw_responses: + if isinstance(resp_item, genai_types.GenerateContentResponse): + text_response = resp_item.text + processed_responses.append( + text_response + if text_response is not None + else json.dumps({"error": "Empty response text"}) + ) + elif isinstance(resp_item, dict) and "error" in resp_item: + processed_responses.append(json.dumps(resp_item)) + else: + error_payload = { + "error": "Unexpected response type from Gemini inference", + "response_type": str(type(resp_item)), + "details": str(resp_item), + } + processed_responses.append(json.dumps(error_payload)) + responses = processed_responses + elif callable(model): + logger.info("Running inference with custom callable function.") + custom_responses_raw = _run_custom_inference( + model_fn=model, prompt_dataset=prompt_dataset + ) + processed_custom_responses = [] + for resp_item in custom_responses_raw: + if isinstance(resp_item, str): + processed_custom_responses.append(resp_item) + elif isinstance(resp_item, dict) and "error" in resp_item: + processed_custom_responses.append(json.dumps(resp_item)) + else: + try: + processed_custom_responses.append(json.dumps(resp_item)) + except TypeError: + processed_custom_responses.append(str(resp_item)) + responses = processed_custom_responses + elif isinstance(model, str): + if litellm is None: + raise ImportError( + "The 'litellm' library is required to use this model." + " Please install it using 'pip install" + " google-cloud-aiplatform[evaluation]'." + ) + + processed_model_id = model + if model.startswith("vertex_ai/"): + # Already correctly prefixed for LiteLLM's Vertex AI provider + pass + elif _is_litellm_vertex_maas_model(model): + processed_model_id = f"vertex_ai/{model}" + logger.info( + "Detected Vertex AI Model Garden managed MaaS model. " + "Using LiteLLM ID: %s", + processed_model_id, + ) + elif _is_litellm_model(model): + # Other LiteLLM supported model + logger.info("Running inference with LiteLLM for model: %s", model) + else: + # Unsupported model string + raise TypeError( + f"Unsupported string model name: {model}. Expecting a Gemini model" + " name (e.g., 'gemini-2.5-pro', 'projects/.../models/...') or a" + " LiteLLM supported model name (e.g., 'openai/gpt-4o')." + " If using a third-party model via LiteLLM, ensure the" + " necessary environment variables are set (e.g., for OpenAI:" + " `os.environ['OPENAI_API_KEY'] = 'Your API Key'`). 
See" + " LiteLLM documentation for details:" + " https://docs.litellm.ai/docs/set_keys#environment-variables" + ) + + logger.info("Running inference via LiteLLM for model: %s", processed_model_id) + raw_responses = _run_litellm_inference( # type: ignore[assignment] + model=processed_model_id, prompt_dataset=prompt_dataset + ) + processed_llm_responses = [] + for response_dict in raw_responses: + if not isinstance(response_dict, dict): + processed_llm_responses.append( + json.dumps( + { + "error": "Invalid LiteLLM response format", + "details": str(response_dict), + } + ) + ) + continue + + if "error" in response_dict: + processed_llm_responses.append(json.dumps(response_dict)) + continue + + if ( + "choices" in response_dict + and isinstance(response_dict["choices"], list) + and len(response_dict["choices"]) > 0 + ): + first_choice = response_dict["choices"][0] + if "message" in first_choice and isinstance( + first_choice["message"], dict + ): + message = first_choice["message"] + if "content" in message and isinstance(message["content"], str): + processed_llm_responses.append(message["content"]) + else: + processed_llm_responses.append( + json.dumps( + { + "error": "LiteLLM response missing 'content' in message", + "details": response_dict, + } + ) + ) + else: + processed_llm_responses.append( + json.dumps( + { + "error": "LiteLLM response missing 'message' in first choice", + "details": response_dict, + } + ) + ) + else: + processed_llm_responses.append( + json.dumps( + { + "error": "LiteLLM response missing 'choices'", + "details": response_dict, + } + ) + ) + responses = processed_llm_responses + else: + raise TypeError( + f"Unsupported model type: {type(model)}. Expecting string (model" + " name) or Callable." + ) + + if len(responses) != len(prompt_dataset): + raise RuntimeError( + "Critical prompt/response count mismatch: %d prompts vs %d" + " responses. This indicates an issue in response collection." + % (len(prompt_dataset), len(responses)) + ) + + results_df_responses_only = pd.DataFrame( + { + _evals_constant.RESPONSE: responses, + } + ) + + prompt_dataset_indexed = prompt_dataset.reset_index(drop=True) + + # Drop existing 'response' column to prevent duplicate column names when + # re-running inference on a dataset that already has responses. + if _evals_constant.RESPONSE in prompt_dataset_indexed.columns: + logger.warning( + "A column named '%s' already exists in the prompt dataset. " + "The existing column will be dropped and replaced with the new " + "inference results.", + _evals_constant.RESPONSE, + ) + prompt_dataset_indexed = prompt_dataset_indexed.drop( + columns=[_evals_constant.RESPONSE] + ) + + results_df_responses_only_indexed = results_df_responses_only.reset_index(drop=True) + + results_df = pd.concat( + [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1 + ) + + return results_df + + +async def _run_adk_user_simulation( + row: pd.Series, + agent: "LlmAgent", # type: ignore # noqa: F821 + config: Optional[types.evals.UserSimulatorConfig] = None, +) -> list[dict[str, Any]]: + """Runs a multi-turn user simulation using ADK's EvaluationGenerator.""" + # Lazy-import ADK dependencies to avoid top-level import failures when + # google-adk is not installed. 
+ from google.adk.evaluation.conversation_scenarios import ConversationScenario + from google.adk.evaluation.eval_case import SessionInput as ADK_SessionInput + from google.adk.evaluation.evaluation_generator import EvaluationGenerator + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulator, + ) + from google.adk.evaluation.simulation.llm_backed_user_simulator import ( + LlmBackedUserSimulatorConfig, + ) + + starting_prompt = row.get("starting_prompt") + conversation_plan = row.get("conversation_plan") + user_persona = "EVALUATOR" + + if not starting_prompt or not conversation_plan: + raise ValueError( + "User simulation requires 'starting_prompt' and 'conversation_plan'" + " columns." + ) + + scenario = ConversationScenario( + starting_prompt=starting_prompt, + conversation_plan=conversation_plan, + user_persona=user_persona, + ) + + user_simulator_kwargs: dict[str, Any] = {} + if config: + if config.model_name: + user_simulator_kwargs["model"] = config.model_name + if config.model_configuration is not None: + user_simulator_kwargs["model_configuration"] = ( + config.model_configuration.model_dump(exclude_none=True) + ) + if config.max_turn is not None: + user_simulator_kwargs["max_allowed_invocations"] = config.max_turn + + user_simulator_config = LlmBackedUserSimulatorConfig(**user_simulator_kwargs) + user_simulator = LlmBackedUserSimulator( + conversation_scenario=scenario, config=user_simulator_config + ) + + try: + initial_session = _get_session_inputs(row) + app_name = initial_session.app_name or "user_simulation_app" + user_id = initial_session.user_id or "user_simulation_default_user" + state = initial_session.state or {} + except (KeyError, TypeError, ValueError): + app_name = "user_simulation_app" + user_id = "user_simulation_default_user" + state = {} + + invocations = await EvaluationGenerator._generate_inferences_from_root_agent( # pylint: disable=protected-access + root_agent=agent, + user_simulator=user_simulator, + reset_func=getattr(agent, "reset_data", None), + initial_session=ADK_SessionInput( + app_name=app_name, + user_id=user_id, + state=state, + ), + ) + + turns = [] + for i, invocation in enumerate(invocations): + events = [] + if invocation.user_content: + events.append( + { + "author": "user", + "content": invocation.user_content.model_dump( + mode="json", exclude_none=True + ), + "event_time": datetime.datetime.fromtimestamp( + invocation.creation_timestamp, tz=datetime.timezone.utc + ), + } + ) + if invocation.intermediate_data: + if ( + hasattr(invocation.intermediate_data, "invocation_events") + and invocation.intermediate_data.invocation_events + ): + for ie in invocation.intermediate_data.invocation_events: + events.append( + { + "author": ie.author, + "content": ( + ie.content.model_dump(mode="json", exclude_none=True) + if ie.content + else None + ), + "event_time": datetime.datetime.fromtimestamp( + invocation.creation_timestamp, tz=datetime.timezone.utc + ), + } + ) + elif hasattr(invocation.intermediate_data, "tool_uses"): + for tool_call in invocation.intermediate_data.tool_uses: + events.append( + { + "author": "tool_call", + "content": tool_call.model_dump( + mode="json", exclude_none=True + ), + "event_time": datetime.datetime.fromtimestamp( + invocation.creation_timestamp, tz=datetime.timezone.utc + ), + } + ) + + if invocation.final_response: + events.append( + { + "author": "agent", + "content": invocation.final_response.model_dump( + mode="json", exclude_none=True + ), + "event_time": 
datetime.datetime.fromtimestamp(
+                        invocation.creation_timestamp, tz=datetime.timezone.utc
+                    ),
+                }
+            )
+
+        turns.append(
+            {
+                "turn_index": i,
+                "turn_id": invocation.invocation_id or str(uuid.uuid4()),
+                "events": events,
+            }
+        )
+
+    return turns
+
+
+def _apply_prompt_template(
+    df: pd.DataFrame, prompt_template: types.PromptTemplate
+) -> None:
+    """Applies a prompt template to a DataFrame.
+
+    The DataFrame is expected to have columns corresponding to the variables
+    in the prompt template. The result is stored in a new 'request' column.
+
+    Args:
+        df: The input DataFrame to modify.
+        prompt_template: The prompt template to apply.
+
+    Returns:
+        None. The DataFrame is modified in place.
+    """
+    missing_vars = [var for var in prompt_template.variables if var not in df.columns]
+    if missing_vars:
+        raise ValueError(
+            "Missing columns in DataFrame for prompt template variables:"
+            f" {', '.join(missing_vars)}. Available columns:"
+            f" {', '.join(df.columns.tolist())}"
+        )
+
+    if "prompt" in df.columns:
+        logger.info(
+            "Templated prompts are stored in the 'request' column and will be"
+            " used for inference. The original 'prompt' column is kept but not"
+            " used for inference."
+        )
+    elif "request" in df.columns:
+        logger.info("The 'request' column will be replaced with templated prompts.")
+
+    templated_prompts = []
+    for _, row in df.iterrows():
+        templated_prompts.append(prompt_template.assemble(**row.to_dict()))
+
+    df["request"] = templated_prompts
+
+
+def _load_dataframe(
+    api_client: BaseApiClient, src: Union[str, pd.DataFrame]
+) -> pd.DataFrame:
+    """Loads and prepares the prompt dataset for inference."""
+    logger.info("Loading prompt dataset from: %s", src)
+    try:
+        loader = _evals_utils.EvalDatasetLoader(api_client=api_client)
+        dataset_list_of_dicts = loader.load(src)
+        if not dataset_list_of_dicts:
+            raise ValueError("Prompt dataset must not be empty.")
+        return pd.DataFrame(dataset_list_of_dicts)
+    except Exception as e:
+        logger.error("Failed to load prompt dataset from source: %s. Error: %s", src, e)
+        raise
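+
+
+# Illustrative sketch of the templating step above (column and template values
+# are hypothetical): the assembled prompts land in a new "request" column:
+#
+#     df = pd.DataFrame({"context": ["the docs"], "question": ["What is AE?"]})
+#     template = types.PromptTemplate(text="Answer using {context}: {question}")
+#     _apply_prompt_template(df, template)
+#     # df["request"][0] == "Answer using the docs: What is AE?"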
+
+
+def _load_dataframe(
+    api_client: BaseApiClient, src: Union[str, pd.DataFrame]
+) -> pd.DataFrame:
+    """Loads and prepares the prompt dataset for inference."""
+    logger.info("Loading prompt dataset from: %s", src)
+    try:
+        loader = _evals_utils.EvalDatasetLoader(api_client=api_client)
+        dataset_list_of_dicts = loader.load(src)
+        if not dataset_list_of_dicts:
+            raise ValueError("Prompt dataset must not be empty.")
+        return pd.DataFrame(dataset_list_of_dicts)
+    except Exception as e:
+        logger.error("Failed to load prompt dataset from source: %s. Error: %s", src, e)
+        raise e
+
+
+def _execute_inference(
+    *,
+    api_client: BaseApiClient,
+    src: Union[str, pd.DataFrame],
+    model: Optional[Union[Callable[[Any], Any], str]] = None,
+    agent_engine: Optional[Union[str, types.AgentEngine]] = None,
+    agent: Optional["LlmAgent"] = None,  # type: ignore # noqa: F821
+    dest: Optional[str] = None,
+    config: Optional[genai_types.GenerateContentConfig] = None,
+    prompt_template: Optional[Union[str, types.PromptTemplateOrDict]] = None,
+    location: Optional[str] = None,
+    user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None,
+    allow_cross_region_model: bool = False,
+) -> types.EvaluationDataset:
+    """Executes inference on a given dataset using the specified model or agent.
+
+    Args:
+        api_client: The API client.
+        src: The source of the dataset. Can be a string (path to a local file, a
+            GCS path, or a BigQuery table) or a Pandas DataFrame.
+        model: The model to use for inference. Can be a callable function or a
+            string representing a model.
+        agent_engine: The agent engine to use for inference. Can be a resource
+            name string or an `AgentEngine` instance.
+        agent: The local agent to use for inference. Can be an ADK agent instance.
+        dest: The destination to save the inference results. Can be a string
+            representing a file path or a GCS URI.
+        config: The generation configuration for the model.
+        prompt_template: The prompt template to use for inference.
+        location: The location to use for the inference. If not specified, the
+            location configured in the client will be used.
+        user_simulator_config: The configuration for the user simulator in
+            multi-turn agent scraping.
+        allow_cross_region_model: Whether to allow requests to be routed to a
+            model outside the configured region (for example, the 'global'
+            region).
+
+    Returns:
+        A types.EvaluationDataset wrapping the inference results DataFrame.
+    """
+    if not api_client:
+        raise ValueError("'api_client' instance must be provided.")
+    if location:
+        api_client = _get_api_client_with_location(api_client, location)
+
+    if sum(x is not None for x in [model, agent_engine, agent]) != 1:
+        raise ValueError(
+            "Exactly one of model, agent_engine, or agent must be provided."
+        )
+
+    prompt_dataset = _load_dataframe(api_client, src)
+    if prompt_template:
+        logger.info("Applying prompt template...")
+        if isinstance(prompt_template, str):
+            prompt_template = types.PromptTemplate(text=prompt_template)
+        elif isinstance(prompt_template, dict):
+            prompt_template = types.PromptTemplate.model_validate(prompt_template)
+
+        _apply_prompt_template(prompt_dataset, prompt_template)
+
+    if model:
+        start_time = time.time()
+        logger.debug("Starting inference process ...")
+        results_df = _run_inference_internal(
+            api_client=api_client,
+            model=model,
+            prompt_dataset=prompt_dataset,
+            config=config,
+        )
+        end_time = time.time()
+        logger.info("Inference completed in %.2f seconds.", end_time - start_time)
+
+        candidate_name = None
+        if isinstance(model, str):
+            candidate_name = model
+        elif callable(model):
+            candidate_name = getattr(model, "__name__", None)
+
+        results_df = _drop_empty_columns(results_df)
+        evaluation_dataset = types.EvaluationDataset(
+            eval_dataset_df=results_df,
+            candidate_name=candidate_name,
+        )
+    elif agent_engine or agent:
+        candidate_name = None
+        if agent_engine:
+            candidate_name = "agent_engine_0"
+        elif agent:
+            agent_config = types.evals.AgentConfig.from_agent(agent)
+            candidate_name = agent_config.agent_id or "agent_0"
+
+        if (
+            agent_engine
+            and not isinstance(agent_engine, str)
+            and not (
+                hasattr(agent_engine, "api_client")
+                and type(agent_engine).__name__ == "AgentEngine"
+            )
+        ):
+            raise TypeError(
+                f"Unsupported agent_engine type: {type(agent_engine)}. Expecting a"
+                " string (agent engine resource name in"
+                " 'projects/{project_id}/locations/{location_id}/reasoningEngines/{reasoning_engine_id}'"
+                " format) or a types.AgentEngine instance."
+            )
+        if (
+            _evals_constant.INTERMEDIATE_EVENTS in prompt_dataset.columns
+            or _evals_constant.RESPONSE in prompt_dataset.columns
+        ):
+            raise ValueError(
+                "The eval dataset provided for agent run should not contain"
+                f" '{_evals_constant.INTERMEDIATE_EVENTS}' or"
+                f" '{_evals_constant.RESPONSE}' columns, as these columns will be"
+                " generated by the agent run."
+ ) + start_time = time.time() + logger.debug("Starting Agent Run process ...") + results_df = _run_agent_internal( + api_client=api_client, + agent_engine=agent_engine, + agent=agent, + prompt_dataset=prompt_dataset, + user_simulator_config=user_simulator_config, + allow_cross_region_model=allow_cross_region_model, + ) + end_time = time.time() + logger.info("Agent Run completed in %.2f seconds.", end_time - start_time) + + results_df = _drop_empty_columns(results_df) + evaluation_dataset = types.EvaluationDataset( + eval_dataset_df=results_df, + candidate_name=candidate_name, + ) + else: + raise ValueError("Either model, agent_engine or agent must be provided.") + + if dest: + file_name = "inference_results.jsonl" if model else "agent_run_results.jsonl" + is_gcs_path = dest.startswith(_gcs_utils.GCS_PREFIX) + + if is_gcs_path: + full_dest_path = os.path.join(dest, file_name) + else: + os.makedirs(dest, exist_ok=True) + full_dest_path = os.path.join(dest, file_name) + + logger.info("Saving inference / agent run results to: %s", full_dest_path) + try: + if is_gcs_path: + _gcs_utils.GcsUtils(api_client=api_client).upload_dataframe( + df=results_df, + gcs_destination_blob_path=full_dest_path, + file_type="jsonl", + ) + logger.info("Results saved to GCS: %s", full_dest_path) + evaluation_dataset.gcs_source = genai_types.GcsSource( + uris=[full_dest_path] + ) + else: + results_df.to_json(full_dest_path, orient="records", lines=True) + logger.info("Results saved locally to: %s", full_dest_path) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Failed to save results to %s. Error: %s", full_dest_path, e) + + return evaluation_dataset + + +def _get_dataset_source( + ds_item: types.EvaluationDataset, +) -> Union[str, pd.DataFrame]: + """Returns the source of the dataset, either a DataFrame, GCS URI, or BigQuery URI.""" + if ds_item.eval_dataset_df is not None: + return ds_item.eval_dataset_df + elif ds_item.gcs_source is not None and ds_item.gcs_source.uris: + if len(ds_item.gcs_source.uris) > 1: + logger.warning( + "Multiple GCS URIs in GcsSource. Using the first one: %s", + ds_item.gcs_source.uris[0], + ) + return ds_item.gcs_source.uris[0] + elif ds_item.bigquery_source is not None and ds_item.bigquery_source.input_uri: + return ds_item.bigquery_source.input_uri + else: + raise ValueError( + "EvaluationDataset item has no valid source" + " (eval_dataset_df, gcs_source with uris, or bigquery_source with" + " input_uri)." + ) + + +def _resolve_dataset_inputs( + dataset: list[types.EvaluationDataset], + dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]], + loader: "_evals_utils.EvalDatasetLoader", + agent_info: Optional[types.evals.AgentInfo] = None, +) -> tuple[types.EvaluationDataset, int]: + """Loads and processes single or multiple datasets for evaluation. + + Args: + dataset: The dataset(s) to process. Can be a single EvaluationDataset or a + list of them. + dataset_schema: The schema to use for the dataset(s). If None, it will be + auto-detected. + loader: An instance of EvalDatasetLoader to load data. + agent_info: The agent info of the agent under evaluation. + + Returns: + A tuple containing: + - processed_eval_dataset: The processed EvaluationDataset containing + evaluation cases. + - num_response_candidates: The number of response candidates. 
+ """ + if not dataset: + raise ValueError("Input dataset list cannot be empty.") + + num_response_candidates = len(dataset) + datasets_to_process = dataset + logger.info("Processing %s dataset(s).", num_response_candidates) + + if len(datasets_to_process) == 1 and datasets_to_process[0].eval_cases: + return datasets_to_process[0], 1 + + parsed_evaluation_datasets: list[types.EvaluationDataset] = [] + + for i, ds_item in enumerate(datasets_to_process): + if not isinstance(ds_item, types.EvaluationDataset): + logger.error( + "Unexpected item type in dataset list at index %d: %s. Expected" + " types.EvaluationDataset.", + i, + type(ds_item), + ) + raise TypeError( + f"Item at index {i} is not an EvaluationDataset: {type(ds_item)}" + ) + + if ds_item.eval_cases: + logger.info("Dataset %d already contains eval_cases.", i) + parsed_evaluation_datasets.append(ds_item) + continue + + ds_source_for_loader = _get_dataset_source(ds_item) + current_loaded_data = loader.load(ds_source_for_loader) + + if dataset_schema: + current_schema = _evals_data_converters.EvalDatasetSchema(dataset_schema) + else: + current_schema = _evals_data_converters.auto_detect_dataset_schema( # type: ignore[assignment] + current_loaded_data + ) + + logger.info( + "Dataset %d: Schema: %s. Using %s converter.", + i, + current_schema, + _evals_data_converters.get_dataset_converter( + current_schema + ).__class__.__name__, + ) + converter = _evals_data_converters.get_dataset_converter(current_schema) + parsed_evaluation_datasets.append(converter.convert(current_loaded_data)) + + processed_eval_dataset = _evals_data_converters.merge_evaluation_datasets( + datasets=parsed_evaluation_datasets, + agent_info=agent_info, + ) + + if not processed_eval_dataset.eval_cases: + raise ValueError("No evaluation cases found in the dataset.") + return processed_eval_dataset, num_response_candidates + + +def _resolve_evaluation_run_metrics( + metrics: Union[list[types.EvaluationRunMetric], list[types.Metric]], api_client: Any +) -> list[types.EvaluationRunMetric]: + """Resolves a list of evaluation run metric instances, loading RubricMetric if necessary.""" + if not metrics: + return [] + resolved_metrics_list = [] + for metric_instance in metrics: + if isinstance(metric_instance, types.EvaluationRunMetric): + resolved_metrics_list.append(metric_instance) + elif isinstance( + metric_instance, _evals_metric_loaders.LazyLoadedPrebuiltMetric + ): + try: + resolved_metric = metric_instance.resolve(api_client=api_client) + if resolved_metric.name: + resolved_metrics_list.append( + types.EvaluationRunMetric( + metric=resolved_metric.name, + metric_config=types.UnifiedMetric( + predefined_metric_spec=genai_types.PredefinedMetricSpec( + metric_spec_name=resolved_metric.name, + ) + ), + ) + ) + except Exception as e: + logger.error( + "Failed to resolve RubricMetric %s@%s: %s", + metric_instance.name, + metric_instance.version, + e, + ) + raise + elif isinstance(metric_instance, types.Metric): + config_dict = t.t_metrics([metric_instance])[0] + res_name = getattr(metric_instance, "metric_resource_name", None) + resolved_metrics_list.append( + types.EvaluationRunMetric( + metric=metric_instance.name, + metric_config=config_dict if config_dict else None, + metric_resource_name=res_name, + ) + ) + else: + try: + metric_name_str = str(metric_instance) + lazy_metric_instance = getattr( + _evals_metric_loaders.RubricMetric, metric_name_str.upper() + ) + if isinstance( + lazy_metric_instance, _evals_metric_loaders.LazyLoadedPrebuiltMetric + ): + 
resolved_metric = lazy_metric_instance.resolve( + api_client=api_client + ) + if resolved_metric.name: + resolved_metrics_list.append( + types.EvaluationRunMetric( + metric=resolved_metric.name, + metric_config=types.UnifiedMetric( + predefined_metric_spec=genai_types.PredefinedMetricSpec( + metric_spec_name=resolved_metric.name, + ) + ), + ) + ) + else: + raise TypeError( + f"RubricMetric.{metric_name_str.upper()} cannot be resolved." + ) + except AttributeError as exc: + raise TypeError( + "Unsupported metric type or invalid RubricMetric name:" + f" {metric_instance}" + ) from exc + return resolved_metrics_list + + +def _resolve_metrics( + metrics: list[types.Metric], api_client: Any +) -> list[types.Metric]: + """Resolves a list of metric instances, loading RubricMetric if necessary.""" + resolved_metrics_list = [] + for metric_instance in metrics: + if isinstance(metric_instance, _evals_metric_loaders.LazyLoadedPrebuiltMetric): + try: + resolved_metrics_list.append( + metric_instance.resolve(api_client=api_client) + ) + except Exception as e: + logger.error( + "Failed to resolve RubricMetric %s@%s: %s", + metric_instance.name, + metric_instance.version, + e, + ) + raise + elif isinstance(metric_instance, types.Metric): + resolved_metrics_list.append(metric_instance) + else: + try: + metric_name_str = str(metric_instance) + lazy_metric_instance = getattr( + _evals_metric_loaders.RubricMetric, metric_name_str.upper() + ) + if isinstance( + lazy_metric_instance, _evals_metric_loaders.LazyLoadedPrebuiltMetric + ): + resolved_metrics_list.append( + lazy_metric_instance.resolve(api_client=api_client) + ) + else: + raise TypeError( + f"RubricMetric.{metric_name_str.upper()} cannot be resolved." + ) + except AttributeError as exc: + raise TypeError( + "Unsupported metric type or invalid RubricMetric name:" + f" {metric_instance}" + ) from exc + return resolved_metrics_list + + +def _execute_evaluation( # type: ignore[no-untyped-def] + *, + api_client: Any, + dataset: Union[types.EvaluationDataset, list[types.EvaluationDataset]], + metrics: list[types.Metric], + dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = None, + dest: Optional[str] = None, + location: Optional[str] = None, + evaluation_service_qps: Optional[float] = None, + **kwargs, +) -> types.EvaluationResult: + """Evaluates a dataset using the provided metrics. + + Args: + api_client: The API client. + dataset: The dataset to evaluate. + metrics: The metrics to evaluate the dataset against. + dataset_schema: The schema of the dataset. + dest: The destination to save the evaluation results. + location: The location to use for the evaluation. If not specified, the + location configured in the client will be used. + evaluation_service_qps: The rate limit (queries per second) for calls + to the evaluation service. Defaults to 10. Increase this value if + your project has a higher EvaluateInstances API quota. + **kwargs: Extra arguments to pass to evaluation, such as `agent_info`. + + Returns: + The evaluation result. + """ + + if location: + api_client = _get_api_client_with_location(api_client, location) + + logger.info("Preparing dataset(s) and metrics...") + if isinstance(dataset, types.EvaluationDataset): + dataset_list = [dataset] + elif isinstance(dataset, list): + for item in dataset: + if not isinstance(item, types.EvaluationDataset): + raise TypeError( + f"Unsupported dataset type: {type(item)}. " + "Must be EvaluationDataset." 
+                )
+        dataset_list = dataset
+    else:
+        raise TypeError(
+            f"Unsupported dataset type: {type(dataset)}. Must be an"
+            " EvaluationDataset or a list of EvaluationDataset."
+        )
+    original_candidate_names = [
+        ds.candidate_name or f"candidate_{i + 1}" for i, ds in enumerate(dataset_list)
+    ]
+    name_counts = collections.Counter(original_candidate_names)
+    deduped_candidate_names = []
+    current_name_counts: collections.defaultdict[Any, int] = collections.defaultdict(
+        int
+    )
+
+    for name in original_candidate_names:
+        if name_counts[name] > 1:
+            current_name_counts[name] += 1
+            deduped_candidate_names.append(f"{name} #{current_name_counts[name]}")
+        else:
+            deduped_candidate_names.append(name)
+
+    loader = _evals_utils.EvalDatasetLoader(api_client=api_client)
+
+    agent_info = kwargs.get("agent_info", None)
+    validated_agent_info = None
+    if agent_info:
+        if isinstance(agent_info, dict):
+            validated_agent_info = types.evals.AgentInfo.model_validate(agent_info)
+        elif isinstance(agent_info, types.evals.AgentInfo):
+            validated_agent_info = agent_info
+        else:
+            raise TypeError(
+                "agent_info must be of type types.evals.AgentInfo or dict,"
+                f" but got {type(agent_info)}."
+            )
+
+    processed_eval_dataset, num_response_candidates = _resolve_dataset_inputs(
+        dataset=dataset_list,
+        dataset_schema=dataset_schema,
+        loader=loader,
+        agent_info=validated_agent_info,
+    )
+
+    resolved_metrics = _resolve_metrics(metrics, api_client)
+
+    evaluation_run_config = _evals_metric_handlers.EvaluationRunConfig(
+        evals_module=evals.Evals(api_client_=api_client),
+        dataset=processed_eval_dataset,
+        metrics=resolved_metrics,
+        num_response_candidates=num_response_candidates,
+    )
+
+    logger.info("Running Metric Computation...")
+    t1 = time.perf_counter()
+    evaluation_result = _evals_metric_handlers.compute_metrics_and_aggregate(
+        evaluation_run_config,
+        evaluation_service_qps=evaluation_service_qps,
+    )
+    t2 = time.perf_counter()
+    logger.info("Evaluation took: %f seconds", t2 - t1)
+
+    evaluation_result.evaluation_dataset = dataset_list
+    evaluation_result.agent_info = validated_agent_info
+
+    if not evaluation_result.metadata:
+        evaluation_result.metadata = types.EvaluationRunMetadata()
+
+    evaluation_result.metadata.creation_timestamp = datetime.datetime.now(
+        datetime.timezone.utc
+    )
+
+    if deduped_candidate_names:
+        evaluation_result.metadata.candidate_names = deduped_candidate_names
+
+    logger.info("Evaluation run completed.")
+
+    if dest:
+        uploaded_path = _gcs_utils.GcsUtils(
+            api_client=api_client
+        ).upload_json_to_prefix(
+            data=evaluation_result.model_dump(
+                mode="json",
+                exclude_none=True,
+                exclude={"evaluation_dataset"},
+            ),
+            gcs_dest_prefix=dest,
+            filename_prefix="evaluation_result",
+        )
+        logger.info(
+            "Evaluation results uploaded successfully to GCS: %s", uploaded_path
+        )
+    return evaluation_result
+
+
+def _get_session_inputs(row: pd.Series) -> types.evals.SessionInput:
+    """Parses session inputs from a row."""
+    if isinstance(row["session_inputs"], str):
+        return types.evals.SessionInput.model_validate(
+            json.loads(row["session_inputs"])
+        )
+    elif isinstance(row["session_inputs"], dict):
+        return types.evals.SessionInput.model_validate(row["session_inputs"])
+    elif isinstance(row["session_inputs"], types.evals.SessionInput):
+        return row["session_inputs"]
+    else:
+        raise TypeError(
+            f"Unsupported session_inputs type: {type(row['session_inputs'])}. "
+            "Expecting a string, dict, or types.evals.SessionInput instance."
+ ) + + +def _is_multi_turn_agent_simulation( + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, + prompt_dataset: pd.DataFrame = None, +) -> bool: + """Checks if the agent run is a multi-turn user simulation.""" + return ( + user_simulator_config is not None + or "conversation_plan" in prompt_dataset.columns + ) + + +def _process_multi_turn_agent_response( + resp_item: Any, + agent_data_agents: Optional[dict[str, Any]], +) -> Optional[Union[str, dict[str, Any]]]: + """Processes a multi-turn agent response.""" + if isinstance(resp_item, dict) and "error" in resp_item: + return json.dumps(resp_item) + return types.evals.AgentData( + turns=resp_item, + agents=agent_data_agents, + ).model_dump(exclude_unset=True) + + +def _process_single_turn_agent_response( + resp_item: Any, + agent_data_agents: Optional[dict[str, Any]], +) -> tuple[ + Optional[Union[str, dict[str, Any]]], + list[dict[str, Any]], + Optional[Union[str, dict[str, Any]]], +]: + """Processes a single-turn agent response.""" + intermediate_events_row: list[dict[str, Any]] = [] + response_row: Optional[Union[str, dict[str, Any]]] = None + agent_data_row: Optional[Union[str, dict[str, Any]]] = None + + if isinstance(resp_item, list): + try: + response_row = resp_item[-1]["content"]["parts"][0]["text"] + for intermediate_event in resp_item[:-1]: + intermediate_events_row.append( + { + "event_id": intermediate_event.get("id"), + "content": intermediate_event.get("content"), + "creation_timestamp": intermediate_event.get("timestamp"), + "author": intermediate_event.get("author"), + } + ) + # Construct AgentData natively for single-turn runs + agent_events = [] + for event_dict in resp_item: + content_dict = event_dict.get("content") + content_obj = None + if content_dict: + content_obj = genai_types.Content.model_validate(content_dict) + + agent_events.append( + types.evals.AgentEvent( + author=event_dict.get("author", "model"), + content=content_obj, + ) + ) + + turn = types.evals.ConversationTurn( + turn_index=0, + turn_id="turn_0", + events=agent_events, + ) + agent_data_row = types.evals.AgentData( + turns=[turn], + agents=agent_data_agents, + ).model_dump(exclude_unset=True) + except Exception as e: # pylint: disable=broad-exception-caught + error_payload = { + "error": ( + f"Failed to parse agent run response {str(resp_item)} to " + f"agent data: {e}" + ), + } + response_row = json.dumps(error_payload) + agent_data_row = json.dumps(error_payload) + elif isinstance(resp_item, dict) and "error" in resp_item: + response_row = json.dumps(resp_item) + else: + error_payload = { + "error": "Unexpected response type from agent run", + "response_type": str(type(resp_item)), + "details": str(resp_item), + } + response_row = json.dumps(error_payload) + + return response_row, intermediate_events_row, agent_data_row + + +def _create_agent_results_dataframe( + prompt_dataset: pd.DataFrame, + processed_responses: list[Any], + processed_intermediate_events: list[Any], + processed_agent_data: list[Any], + is_user_simulation: bool, +) -> pd.DataFrame: + """Creates a DataFrame from the processed agent responses.""" + df_dict: dict[str, Any] = {} + if is_user_simulation: + df_dict[AGENT_DATA] = processed_agent_data + if len(processed_agent_data) != len(prompt_dataset): + raise RuntimeError( + "Critical prompt/agent_data count mismatch: %d" + " prompts vs %d agent_data. This indicates an issue in response" + " collection." 
+                % (
+                    len(prompt_dataset),
+                    len(processed_agent_data),
+                )
+            )
+    else:
+        df_dict[_evals_constant.INTERMEDIATE_EVENTS] = processed_intermediate_events
+        df_dict[_evals_constant.RESPONSE] = processed_responses
+        df_dict[AGENT_DATA] = processed_agent_data
+        if len(processed_responses) != len(prompt_dataset) or len(
+            processed_responses
+        ) != len(processed_intermediate_events):
+            raise RuntimeError(
+                "Critical prompt/response/intermediate_events count mismatch: %d"
+                " prompts vs %d responses vs %d intermediate_events. This"
+                " indicates an issue in response collection."
+                % (
+                    len(prompt_dataset),
+                    len(processed_responses),
+                    len(processed_intermediate_events),
+                )
+            )
+
+    results_df_raw = pd.DataFrame(df_dict)
+
+    prompt_dataset_indexed = prompt_dataset.reset_index(drop=True)
+    results_df_responses_only_indexed = results_df_raw.reset_index(drop=True)
+
+    # Drop columns from input that will be overwritten by results to avoid
+    # duplicate columns after concatenation (e.g. agent_data).
+    overlap = prompt_dataset_indexed.columns.intersection(
+        results_df_responses_only_indexed.columns
+    )
+    if not overlap.empty:
+        prompt_dataset_indexed = prompt_dataset_indexed.drop(columns=overlap)
+
+    results_df = pd.concat(
+        [prompt_dataset_indexed, results_df_responses_only_indexed], axis=1
+    )
+    return results_df
+
+
+def _run_agent_internal(
+    api_client: BaseApiClient,
+    agent_engine: Optional[Union[str, types.AgentEngine]],
+    agent: Optional["LlmAgent"],  # type: ignore # noqa: F821
+    prompt_dataset: pd.DataFrame,
+    user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None,
+    allow_cross_region_model: bool = False,
+) -> pd.DataFrame:
+    """Runs an agent."""
+    raw_responses = _run_agent(
+        api_client=api_client,
+        agent_engine=agent_engine,
+        agent=agent,
+        prompt_dataset=prompt_dataset,
+        user_simulator_config=user_simulator_config,
+        allow_cross_region_model=allow_cross_region_model,
+    )
+    processed_intermediate_events = []
+    processed_responses = []
+    processed_agent_data = []
+    agent_data_agents = None
+    if agent:
+        agent_data_agents = types.evals.AgentData.get_agents_map(agent)
+
+    is_user_simulation = _is_multi_turn_agent_simulation(
+        user_simulator_config, prompt_dataset
+    )
+
+    for resp_item in raw_responses:
+        if is_user_simulation:
+            agent_data_row = _process_multi_turn_agent_response(
+                resp_item, agent_data_agents
+            )
+            processed_agent_data.append(agent_data_row)
+        else:
+            response_row, intermediate_events_row, agent_data_row = (
+                _process_single_turn_agent_response(resp_item, agent_data_agents)
+            )
+            processed_responses.append(response_row)
+            processed_intermediate_events.append(intermediate_events_row)
+            processed_agent_data.append(agent_data_row)
+
+    return _create_agent_results_dataframe(
+        prompt_dataset,
+        processed_responses,
+        processed_intermediate_events,
+        processed_agent_data,
+        is_user_simulation,
+    )
+
+
+def _run_agent(
+    api_client: BaseApiClient,
+    agent_engine: Optional[Union[str, types.AgentEngine]],
+    agent: Optional["LlmAgent"],  # type: ignore # noqa: F821
+    prompt_dataset: pd.DataFrame,
+    user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None,
+    allow_cross_region_model: bool = False,
+) -> list[
+    Union[
+        list[dict[str, Any]],
+        dict[str, Any],
+        genai_types.GenerateContentResponse,
+    ]
+]:
+    """Internal helper to run agent inference with concurrency."""
+    original_location = os.environ.get("GOOGLE_CLOUD_LOCATION")
+    location_overridden = False
+
+    if user_simulator_config and user_simulator_config.model_name:
+        model_name = 
user_simulator_config.model_name + if model_name.startswith("gemini-3") and "/" not in model_name: + current_location = original_location or api_client.location or "us-central1" + if current_location != "global" and not allow_cross_region_model: + raise ValueError( + f"The model '{model_name}' is currently only available in the" + " 'global' region. Because this request originated in" + f" '{current_location}', you must explicitly set" + " allow_cross_region_model=True to allow your data to be routed" + " outside of your request's region." + ) + + logger.warning( + "Model %s is only available in the global region. Routing to global.", + model_name, + ) + user_simulator_config.model_name = f"projects/{api_client.project}/locations/global/publishers/google/models/{model_name}" + if original_location != "global": + os.environ["GOOGLE_CLOUD_LOCATION"] = "global" + location_overridden = True + + try: + if agent_engine: + return _execute_inference_concurrently( + api_client=api_client, + agent_engine=agent_engine, + prompt_dataset=prompt_dataset, + progress_desc="Agent Run", + gemini_config=None, + user_simulator_config=None, + inference_fn=_execute_agent_run_with_retry, + ) + elif agent: + return _execute_inference_concurrently( + api_client=api_client, + agent=agent, + prompt_dataset=prompt_dataset, + progress_desc="Local Agent Run", + gemini_config=None, + user_simulator_config=user_simulator_config, + inference_fn=_execute_local_agent_run_with_retry, + ) + else: + raise ValueError("Neither agent_engine nor agent is provided.") + finally: + if location_overridden: + if original_location is None: + del os.environ["GOOGLE_CLOUD_LOCATION"] + else: + os.environ["GOOGLE_CLOUD_LOCATION"] = original_location + + +def _create_agent_engine_session( + *, + agent_engine: types.AgentEngine, + user_id: str, + session_state: Optional[dict[str, Any]] = None, +) -> Any: + """Creates a session for an agent engine and returns the session ID. + + First attempts to use the agent engine's own `create_session` operation + (available for agents deployed via AdkApp). If the agent engine does not + have `create_session` registered, falls back to the managed Vertex AI + Sessions API. + + Args: + agent_engine: The AgentEngine instance. + user_id: The user ID for the session. + session_state: Optional initial state for the session. + + Returns: + The session ID string. + + Raises: + RuntimeError: If the session could not be created via either path. + """ + try: + session = agent_engine.create_session( # type: ignore[attr-defined] + user_id=user_id, + state=session_state, + ) + return session["id"] + except AttributeError as exc: + # Agent engine does not have create_session registered (e.g. deployed + # via Console, gcloud, or source code deployment without AdkApp). + # Fall back to the managed Vertex AI Sessions API. + logger.info( + "Agent engine does not have 'create_session' operation registered." + " Falling back to managed Sessions API." + ) + if agent_engine.api_resource is None: + raise RuntimeError( + "Failed to create session: agent_engine.api_resource is None." + ) from exc + if agent_engine.api_client is None: + raise RuntimeError( + "Failed to create session: agent_engine.api_client is None." 
+ ) from exc + operation = agent_engine.api_client.sessions.create( + name=agent_engine.api_resource.name, + user_id=user_id, + config=types.CreateAgentEngineSessionConfig( + session_state=session_state, + ), + ) + if operation.response and operation.response.name: + # Session name format: + # projects/{p}/locations/{l}/reasoningEngines/{re}/sessions/{id} + return operation.response.name.split("/")[-1] + elif operation.error: + raise RuntimeError( + f"Failed to create session via managed API: {operation.error}" + ) from exc + else: + raise RuntimeError( + "Failed to create session via managed API: " + "operation returned no response." + ) from exc + + +def _execute_agent_run_with_retry( + row: pd.Series, + contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], + agent_engine: types.AgentEngine, + max_retries: int = 3, +) -> Union[list[dict[str, Any]], dict[str, Any]]: + """Executes agent run over agent engine for a single prompt.""" + try: + if "session_inputs" in row.index and row.get("session_inputs") is not None: + session_inputs = _get_session_inputs(row) + user_id = session_inputs.user_id or str(uuid.uuid4()) + session_state = session_inputs.state + else: + user_id = str(uuid.uuid4()) + session_state = None + except KeyError as e: + return {"error": f"Failed to get all required agent engine inputs: {e}"} + + try: + session_id = _create_agent_engine_session( + agent_engine=agent_engine, + user_id=user_id, + session_state=session_state, + ) + except Exception as e: # pylint: disable=broad-exception-caught + return {"error": f"Failed to create a new session: {e}"} + + # Pre-populate remote session with agent_data history (N+1 case only). + if ( + AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + and _is_n_plus_1_inference(row[AGENT_DATA]) + ): + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) + _, history_events = _extract_prompt_from_agent_data(agent_data_obj) + + if agent_engine.api_resource is None: + return {"error": "agent_engine.api_resource is None."} + if agent_engine.api_client is None: + return {"error": "agent_engine.api_client is None."} + session_name = f"{agent_engine.api_resource.name}/sessions/{session_id}" + base_ts = datetime.datetime(2000, 1, 1, tzinfo=datetime.timezone.utc) + for i, ag_event in enumerate(history_events): + agent_engine.api_client.sessions.events.append( + name=session_name, + author=ag_event.author or "user", + invocation_id="history", + timestamp=base_ts + datetime.timedelta(seconds=i), + config=types.AppendAgentEngineSessionEventConfig( + content=ag_event.content, + ), + ) + + # stream_query retry loop (shared for both agent_data and prompt paths). + for attempt in range(max_retries): + try: + responses = [] + for event in agent_engine.stream_query( # type: ignore[attr-defined] + user_id=user_id, + session_id=session_id, + message=contents, + ): + if event and CONTENT in event and PARTS in event[CONTENT]: + responses.append(event) + return responses + except api_exceptions.ResourceExhausted as e: + logger.warning( + "Resource Exhausted error on attempt %d/%d: %s. 
Retrying in %s" + " seconds...", + attempt + 1, + max_retries, + e, + 2**attempt, + ) + if attempt == max_retries - 1: + return {"error": f"Resource exhausted after retries: {e}"} + time.sleep(2**attempt) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Unexpected error during agent engine run on attempt %d/%d: %s", + attempt + 1, + max_retries, + e, + ) + if attempt == max_retries - 1: + return {"error": f"Failed after retries: {e}"} + time.sleep(1) + return {"error": f"Failed to get agent run results after {max_retries} retries"} + + +def _execute_local_agent_run_with_retry( + row: pd.Series, + contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], + agent: "LlmAgent", # type: ignore # noqa: F821 + max_retries: int = 3, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, +) -> Union[list[dict[str, Any]], dict[str, Any]]: + """Executes agent run locally for a single prompt synchronously.""" + return asyncio.run( + _execute_local_agent_run_with_retry_async( + row, contents, agent, max_retries, user_simulator_config + ) + ) + + +async def _execute_local_agent_run_with_retry_async( + row: pd.Series, + contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict], + agent: "LlmAgent", # type: ignore # noqa: F821 + max_retries: int = 3, + user_simulator_config: Optional[types.evals.UserSimulatorConfig] = None, +) -> Union[list[dict[str, Any]], dict[str, Any]]: + """Executes agent run locally for a single prompt asynchronously.""" + # Lazy-import ADK dependencies to avoid top-level import failures when + # google-adk is not installed. + from google.adk.runners import Runner + from google.adk.sessions import InMemorySessionService + + # Multi-turn agent scraping with user simulation. + if user_simulator_config or "conversation_plan" in row: + try: + return await _run_adk_user_simulation(row, agent, user_simulator_config) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Multi-turn agent run with user simulation failed: %s", e) + return {"error": f"Multi-turn agent run with user simulation failed: {e}"} + + if "session_inputs" in row.index and row.get("session_inputs") is not None: + session_inputs = _get_session_inputs(row) + user_id = session_inputs.user_id or str(uuid.uuid4()) + app_name = session_inputs.app_name or "local agent run" + else: + user_id = str(uuid.uuid4()) + app_name = "local agent run" + session_id = str(uuid.uuid4()) + + session_service = InMemorySessionService() + await session_service.create_session( + app_name=app_name, user_id=user_id, session_id=session_id + ) + + # Pre-populate session with agent_data history (N+1 case only). 
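+    # The agent_data payload is expected to look roughly like the sketch
+    # below (hypothetical values; the authoritative schema is
+    # types.evals.AgentData):
+    #   {"turns": [{"turn_index": 0, "turn_id": "turn_0", "events": [
+    #       {"author": "user", "content": {...}},
+    #       {"author": "agent", "content": {...}}]}]}
+    # _extract_prompt_from_agent_data() yields the extracted prompt and the
+    # history events; only the history events are replayed below.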
+ if ( + AGENT_DATA in row.index + and row.get(AGENT_DATA) is not None + and _is_n_plus_1_inference(row[AGENT_DATA]) + ): + from google.adk.events.event import Event as AdkEvent + + agent_data_obj = row[AGENT_DATA] + if isinstance(agent_data_obj, dict): + agent_data_obj = types.evals.AgentData.model_validate(agent_data_obj) + _, history_events = _extract_prompt_from_agent_data(agent_data_obj) + internal_session = session_service.sessions[app_name][user_id][session_id] + for ag_event in history_events: + adk_event = AdkEvent( + author=ag_event.author or "user", + content=ag_event.content, + invocation_id="history", + ) + internal_session.events.append(adk_event) + + agent_runner = Runner( + agent=agent, app_name=app_name, session_service=session_service + ) + new_message_content = genai_types.Content( + role=USER_AUTHOR, + parts=[genai_types.Part(text=contents)], + ) + # Avoid printing out warning from agent_runner.run() + # WARNING:google_genai.types:Warning: there are non-text parts in the + # response: ['function_call'], returning concatenated text result from + # text parts. Check the full candidates.content.parts accessor to get + # the full model response. + # TODO: Update retry mechanism + with _temp_logger_level("google_genai.types", logging.ERROR): + for attempt in range(max_retries): + try: + events = [] + async for event in agent_runner.run_async( + user_id=user_id, + session_id=session_id, + new_message=new_message_content, + ): + if event: + event = event.model_dump(exclude_none=True) + if event and CONTENT in event and PARTS in event[CONTENT]: + events.append(event) + return events + except api_exceptions.ResourceExhausted as e: + logger.warning( + "Resource Exhausted error on attempt %d/%d: %s. Retrying" + " in %s seconds...", + attempt + 1, + max_retries, + e, + 2**attempt, + ) + if attempt == max_retries - 1: + return {"error": f"Resource exhausted after retries: {e}"} + await asyncio.sleep(2**attempt) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Unexpected error during agent run on attempt %d/%d: %s", + attempt + 1, + max_retries, + e, + ) + if attempt == max_retries - 1: + return {"error": f"Failed after retries: {e}"} + await asyncio.sleep(1) + return {"error": f"Failed to get agent run results after {max_retries} retries"} + + +def _convert_gcs_to_evaluation_item_result( + api_client: BaseApiClient, + gcs_uri: str, +) -> types.EvaluationItemResult: + """Converts a json file to an EvaluationItemResult.""" + logger.info("Loading evaluation item result from GCS: %s", gcs_uri) + gcs_utils = _gcs_utils.GcsUtils(api_client=api_client) + try: + eval_item_data = json.loads(gcs_utils.read_file_contents(gcs_uri)) + return types.EvaluationItemResult(**eval_item_data) + except Exception as e: + logger.error( + "Failed to load evaluation result from GCS: %s. Error: %s", gcs_uri, e + ) + return types.EvaluationItemResult() + + +def _convert_gcs_to_evaluation_item_request( + api_client: BaseApiClient, + gcs_uri: str, +) -> types.EvaluationItemRequest: + """Converts a json file to an EvaluationItemRequest.""" + logger.info("Loading evaluation item request from GCS: %s", gcs_uri) + gcs_utils = _gcs_utils.GcsUtils(api_client=api_client) + try: + eval_item_data = json.loads(gcs_utils.read_file_contents(gcs_uri)) + return types.EvaluationItemRequest(**eval_item_data) + except Exception as e: + logger.error( + "Failed to load evaluation request from GCS: %s. 
Error: %s", gcs_uri, e
+        )
+        return types.EvaluationItemRequest()
+
+
+def _get_aggregated_metrics(
+    results: types.EvaluationRunResults,
+) -> list[types.AggregatedMetricResult]:
+    """Computes aggregated metric results from the run's summary metrics."""
+    if (
+        not results
+        or not results.summary_metrics
+        or not results.summary_metrics.metrics
+    ):
+        return []
+
+    aggregated_metrics_dict: dict[str, dict[str, Any]] = {}
+    for name, value in results.summary_metrics.metrics.items():
+        result = name.rsplit("/", 1)
+        full_metric_name = result[0]
+        aggregated_metric_name = result[1]
+        if full_metric_name not in aggregated_metrics_dict:
+            aggregated_metrics_dict[full_metric_name] = {}
+        aggregated_metrics_dict[full_metric_name]["sub_metric_name"] = (
+            full_metric_name.split("/")[-1]
+        )
+        aggregated_metrics_dict[full_metric_name][aggregated_metric_name] = value
+
+    items_sorted = sorted(
+        aggregated_metrics_dict.items(),
+        key=lambda item: (item[1]["sub_metric_name"], item[0]),
+    )
+
+    return [
+        types.AggregatedMetricResult(
+            metric_name=name.split("/")[-1],
+            mean_score=values.get("AVERAGE"),
+            stdev_score=values.get("STANDARD_DEVIATION"),
+        )
+        for name, values in items_sorted
+    ]
+
+
+def _get_eval_case_result_from_eval_item(
+    index: int,
+    eval_item: types.EvaluationItem,
+) -> types.EvalCaseResult:
+    """Transforms EvaluationItem to EvalCaseResult."""
+    metric_results = {}
+    if (
+        eval_item.evaluation_response
+        and eval_item.evaluation_response.candidate_results
+    ):
+        for candidate_result in eval_item.evaluation_response.candidate_results:
+            metric_results[candidate_result.metric] = types.EvalCaseMetricResult(
+                metric_name=candidate_result.metric,
+                score=candidate_result.score,
+                explanation=candidate_result.explanation,
+                rubric_verdicts=candidate_result.rubric_verdicts,
+                error_message=(eval_item.error.message if eval_item.error else None),
+            )
+    return types.EvalCaseResult(
+        eval_case_index=index,
+        response_candidate_results=[
+            types.ResponseCandidateResult(
+                response_index=0,
+                metric_results=metric_results,
+            )
+        ],
+    )
+
+
+def _convert_request_to_dataset_row(
+    request: types.EvaluationItemRequest,
+) -> dict[str, Any]:
+    """Converts an EvaluationItemRequest to a dictionary."""
+    dict_row: dict[str, Any] = {}
+    dict_row[_evals_constant.PROMPT] = (
+        request.prompt.text if request.prompt and request.prompt.text else None
+    )
+    dict_row[_evals_constant.REFERENCE] = request.golden_response
+
+    if request.prompt and request.prompt.user_scenario:
+        dict_row[_evals_constant.STARTING_PROMPT] = (
+            request.prompt.user_scenario.starting_prompt
+        )
+        dict_row[_evals_constant.CONVERSATION_PLAN] = (
+            request.prompt.user_scenario.conversation_plan
+        )
+
+    intermediate_events = []
+    agent_data = None
+    if request.candidate_responses:
+        for candidate in request.candidate_responses:
+            if candidate.candidate is not None:
+                dict_row[candidate.candidate] = (
+                    candidate.text if candidate.text else None
+                )
+            if candidate.events:
+                for event in candidate.events:
+                    content_dict = {"parts": event.parts, "role": event.role}
+                    int_events_dict = {
+                        "event_id": candidate.candidate,
+                        "content": content_dict,
+                    }
+                    intermediate_events.append(int_events_dict)
+        agent_data = request.candidate_responses[0].agent_data
+
+    dict_row[_evals_constant.INTERMEDIATE_EVENTS] = intermediate_events
+    dict_row[_evals_constant.AGENT_DATA] = (
+        agent_data.model_dump() if agent_data else None
+    )
+    return dict_row
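+
+
+# Illustrative flattening (hypothetical values): an EvaluationItemRequest with
+# prompt text "Hi", a golden response, and one candidate named "model_a" with
+# text "Hey" becomes roughly:
+#   {"prompt": "Hi", "reference": <golden CandidateResponse>,
+#    "model_a": "Hey", "intermediate_events": [], "agent_data": None}
+# Each candidate name becomes its own column in the resulting dataset rows.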
+
+
+def _drop_empty_columns(df: "pd.DataFrame") -> "pd.DataFrame":
+    """Drops columns that are all None or all empty lists/dicts."""
+    if df is None or df.empty or pd is None:
+        return df
+
+    def is_empty(x: Any) -> bool:
+        if isinstance(x, (list, dict)):
+            return not x
+        return pd.isna(x)  # type: ignore[no-any-return]
+
+    cols_to_drop = [col for col in df.columns if df[col].apply(is_empty).all()]
+    return df.drop(columns=cols_to_drop)
+
+
+def _transform_dataframe(
+    rows: list[dict[str, Any]],
+) -> list[types.EvaluationDataset]:
+    """Transforms rows to a list of EvaluationDatasets.
+
+    Args:
+        rows: A list of rows, each a dictionary mapping candidate name to
+            response text.
+
+    Returns:
+        A list of EvaluationDatasets, one for each candidate.
+    """
+    df = pd.DataFrame(rows)
+    candidates = [
+        col for col in df.columns if col not in _evals_constant.COMMON_DATASET_COLUMNS
+    ]
+
+    eval_dfs = []
+    for candidate in candidates:
+        temp_df = df.rename(columns={candidate: _evals_constant.RESPONSE})
+        temp_df = _drop_empty_columns(temp_df)
+        eval_dfs.append(
+            types.EvaluationDataset(
+                candidate_name=candidate,
+                eval_dataset_df=temp_df,
+            )
+        )
+    return eval_dfs
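+
+
+# Usage sketch (hypothetical rows): any column not listed in
+# _evals_constant.COMMON_DATASET_COLUMNS is treated as a candidate column:
+#   rows = [{"prompt": "Hi", "model_a": "Hey"}]
+#   dsets = _transform_dataframe(rows)
+#   # len(dsets) == 1; dsets[0].candidate_name == "model_a", and the
+#   # "model_a" column is renamed to "response" in dsets[0].eval_dataset_df.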
+
+
+def _get_eval_cases_eval_dfs_from_eval_items(
+    eval_items: list[types.EvaluationItem],
+) -> tuple[list[types.EvalCaseResult], list[types.EvaluationDataset]]:
+    """Converts EvaluationItems to EvalCaseResults and EvaluationDatasets.
+
+    Args:
+        eval_items: The list of EvaluationItems to convert.
+
+    Returns:
+        A tuple of two lists:
+        - eval_case_results: A list of EvalCaseResults, one for each evaluation
+          item.
+        - eval_dfs: A list of EvaluationDatasets, one for each candidate.
+    """
+    dataset_rows = []
+    eval_case_results = []
+    for index, eval_item in enumerate(eval_items):
+        if (
+            eval_item
+            and eval_item.evaluation_response
+            and eval_item.evaluation_response.request
+        ):
+            eval_case_results.append(
+                _get_eval_case_result_from_eval_item(index, eval_item)
+            )
+            dataset_rows.append(
+                _convert_request_to_dataset_row(eval_item.evaluation_response.request)
+            )
+    eval_dfs = _transform_dataframe(dataset_rows)
+    return eval_case_results, eval_dfs
+
+
+def _get_agent_info_from_inference_configs(
+    candidate_names: list[str],
+    inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None,
+) -> Optional[types.evals.AgentInfo]:
+    """Retrieves an AgentInfo from the inference configs."""
+    # TODO(lakeyk): Support multiple agents.
+    if not (
+        inference_configs
+        and candidate_names
+        and candidate_names[0] in inference_configs
+        and inference_configs[candidate_names[0]].agent_config
+    ):
+        return None
+    if len(inference_configs.keys()) > 1:
+        logger.warning(
+            "Multiple agents are not supported yet. Displaying the first agent."
+        )
+    agent_config = inference_configs[candidate_names[0]].agent_config
+    di = (
+        agent_config.developer_instruction
+        if agent_config and agent_config.developer_instruction
+        else None
+    )
+    instruction = di.parts[0].text if di and di.parts and di.parts[0].text else None
+    tools = agent_config.tools if agent_config and agent_config.tools else None
+
+    return types.evals.AgentInfo(
+        name=candidate_names[0],
+        agents={
+            "agent_0": types.evals.AgentConfig(
+                instruction=instruction,
+                tools=tools,
+            )
+        },
+        root_agent_id="agent_0",
+    )
+
+
+def _get_eval_result_from_eval_items(
+    results: types.EvaluationRunResults,
+    eval_items: list[types.EvaluationItem],
+    inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None,
+) -> types.EvaluationResult:
+    """Retrieves an EvaluationResult from the EvaluationRunResults.
+
+    Converts an EvaluationRunResults object from the Evaluation Management API
+    into an EvaluationResult object, which is used to display the evaluation
+    results in the UI.
+
+    Args:
+        results: The EvaluationRunResults object.
+        eval_items: The list of EvaluationItems.
+        inference_configs: Optional mapping of candidate name to inference
+            config, used to populate agent info.
+
+    Returns:
+        An EvaluationResult object.
+    """
+    aggregated_metrics = _get_aggregated_metrics(results)
+    eval_case_results, eval_dfs = _get_eval_cases_eval_dfs_from_eval_items(eval_items)
+    candidate_names = [eval_df.candidate_name for eval_df in eval_dfs]
+    eval_result = types.EvaluationResult(
+        summary_metrics=aggregated_metrics,
+        eval_case_results=eval_case_results,
+        evaluation_dataset=eval_dfs,
+        metadata=types.EvaluationRunMetadata(
+            candidate_names=candidate_names,
+        ),
+        agent_info=_get_agent_info_from_inference_configs(
+            candidate_names, inference_configs
+        ),
+    )
+    return eval_result
+
+
+def _build_eval_item_map(
+    eval_items: list[types.EvaluationItem],
+) -> dict[str, dict[str, Any]]:
+    """Builds a mapping from EvaluationItem resource name to serialized data.
+
+    This is used by the loss analysis visualization to enrich examples with
+    scenario and rubric data from the original evaluation items.
+
+    Args:
+        eval_items: The list of EvaluationItem objects.
+
+    Returns:
+        A dict mapping evaluation item resource name to the serialized
+        evaluation_response dict (which the JS visualization reads as
+        ``evaluation_result``).
+    """
+    item_map: dict[str, dict[str, Any]] = {}
+    for item in eval_items:
+        if item.name and item.evaluation_response:
+            try:
+                item_map[item.name] = item.evaluation_response.model_dump(
+                    mode="json", exclude_none=True
+                )
+            except Exception:
+                pass
+    return item_map
+
+
+def _convert_evaluation_run_results(
+    api_client: BaseApiClient,
+    evaluation_run_results: types.EvaluationRunResults,
+    inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None,
+) -> tuple[Optional[types.EvaluationResult], dict[str, dict[str, Any]]]:
+    """Retrieves an EvaluationResult and item map from EvaluationRunResults.
+
+    Returns:
+        A tuple of (EvaluationResult, eval_item_map). The eval_item_map maps
+        evaluation item resource names to their serialized evaluation response
+        data, used for enriching loss analysis visualization.
+ """ + if not evaluation_run_results or not evaluation_run_results.evaluation_set: + return None, {} + + evals_module = evals.Evals(api_client_=api_client) + eval_set = evals_module.get_evaluation_set( + name=evaluation_run_results.evaluation_set + ) + + eval_items = [] + if eval_set and eval_set.evaluation_items: + eval_items = [ + evals_module.get_evaluation_item(name=item_name) + for item_name in eval_set.evaluation_items + ] + eval_result = _get_eval_result_from_eval_items( + evaluation_run_results, eval_items, inference_configs + ) + eval_item_map = _build_eval_item_map(eval_items) + return eval_result, eval_item_map + + +async def _convert_evaluation_run_results_async( + api_client: BaseApiClient, + evaluation_run_results: types.EvaluationRunResults, + inference_configs: Optional[dict[str, types.EvaluationRunInferenceConfig]] = None, +) -> tuple[Optional[types.EvaluationResult], dict[str, dict[str, Any]]]: + """Retrieves an EvaluationResult and item map from EvaluationRunResults.""" + if not evaluation_run_results or not evaluation_run_results.evaluation_set: + return None, {} + + evals_module = evals.AsyncEvals(api_client_=api_client) + eval_set = await evals_module.get_evaluation_set( + name=evaluation_run_results.evaluation_set + ) + + eval_items = [] + if eval_set and eval_set.evaluation_items: + tasks = [ + evals_module.get_evaluation_item(name=eval_item) + for eval_item in eval_set.evaluation_items + ] + eval_items = await asyncio.gather(*tasks) + eval_result = _get_eval_result_from_eval_items( + evaluation_run_results, eval_items, inference_configs + ) + eval_item_map = _build_eval_item_map(eval_items) + return eval_result, eval_item_map + + +def _object_to_dict(obj: Any) -> Union[dict[str, Any], Any]: + """Converts an object to a dictionary.""" + if obj is None: + return obj + if isinstance(obj, (int, float, str, bool)): + return obj + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + if isinstance(obj, (list, tuple)): + return [_object_to_dict(item) for item in obj] + if isinstance(obj, dict): + return {k: _object_to_dict(v) for k, v in obj.items()} + + if not hasattr(obj, "__dict__"): + return obj # Not an object with attributes, return as is (e.g., set) + + result: dict[str, Any] = {} + for key, value in obj.__dict__.items(): + if value is None: + continue + result[key] = _object_to_dict(value) + return result + + +def _get_content(row: dict[str, Any], column: str) -> Optional[genai_types.Content]: + if isinstance(row[column], str): + return genai_types.Content( + parts=[genai_types.Part(text=row[column])], + role=_evals_constant.USER_AUTHOR, + ) + elif isinstance(row[column], genai_types.Content): + return cast(genai_types.Content, row[column]) + else: + raise ValueError( + f"{column} must be a string or a Content object. Got {type(row[column])}." 
+ ) + + +def _create_evaluation_set_from_dataframe( + api_client: BaseApiClient, + gcs_dest_prefix: str, + eval_df: pd.DataFrame, + candidate_name: Optional[str] = None, +) -> Union[types.EvaluationSet, Any]: + """Converts a dataframe to an EvaluationSet.""" + eval_item_requests = [] + for _, row in eval_df.iterrows(): + intermediate_events = [] + if ( + _evals_constant.INTERMEDIATE_EVENTS in row + and isinstance(row[_evals_constant.INTERMEDIATE_EVENTS], list) + and len(row[_evals_constant.INTERMEDIATE_EVENTS]) > 0 + ): + for event in row[_evals_constant.INTERMEDIATE_EVENTS]: + if CONTENT in event: + intermediate_events.append(event[CONTENT]) + + agent_data_obj = None + if _evals_constant.AGENT_DATA in row: + agent_data_val = row[AGENT_DATA] + if isinstance(agent_data_val, str): + try: + agent_data_val = json.loads(agent_data_val) + except json.JSONDecodeError: + pass + if isinstance(agent_data_val, dict): + try: + agent_data_obj = types.evals.AgentData.model_validate( + agent_data_val + ) + except ValidationError: + pass + elif isinstance(agent_data_val, types.evals.AgentData): + agent_data_obj = agent_data_val + + candidate_responses = [] + if _evals_constant.RESPONSE in row or agent_data_obj or intermediate_events: + # Resolve the oneof conflict: prioritize agent_data over flat text + response_text = row.get(_evals_constant.RESPONSE) or None + + if agent_data_obj and response_text: + logger.info( + "Both 'response' and 'agent_data' columns found in the evaluation dataset. " + "Prioritizing 'agent_data' and omitting 'response' text to satisfy " + "CandidateResponse protobuf oneof constraints." + ) + response_text = None + + candidate_responses.append( + types.CandidateResponse( + candidate=candidate_name or "Candidate 1", + text=response_text, + events=intermediate_events or None, + agent_data=agent_data_obj, + ) + ) + + prompt = None + # Determine which history column name is present, preferring + # "conversation_history" over "history" if both exist. 
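+        # (If a row carries both columns, only "conversation_history" is
+        # forwarded into the prompt template data; the "history" value is
+        # ignored.)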
+ history_col = None + if _evals_constant.CONVERSATION_HISTORY in row: + history_col = _evals_constant.CONVERSATION_HISTORY + elif _evals_constant.HISTORY in row: + history_col = _evals_constant.HISTORY + + if ( + _evals_constant.STARTING_PROMPT in row + and _evals_constant.CONVERSATION_PLAN in row + ): + prompt = types.EvaluationPrompt( + user_scenario=types.evals.UserScenario( + starting_prompt=row[_evals_constant.STARTING_PROMPT], + conversation_plan=row[_evals_constant.CONVERSATION_PLAN], + ) + ) + elif _evals_constant.CONTEXT in row or history_col: + values = {} + if _evals_constant.CONTEXT in row: + values[_evals_constant.CONTEXT] = _get_content( + row, _evals_constant.CONTEXT + ) + if history_col: + values[_evals_constant.CONVERSATION_HISTORY] = _get_content( + row, history_col + ) + if _evals_constant.PROMPT in row: + values[_evals_constant.PROMPT] = _get_content( + row, _evals_constant.PROMPT + ) + prompt = types.EvaluationPrompt( + prompt_template_data=types.PromptTemplateData(values=values) + ) + elif _evals_constant.PROMPT in row: + prompt = types.EvaluationPrompt(text=row[_evals_constant.PROMPT]) + + eval_item_requests.append( + types.EvaluationItemRequest( + prompt=prompt or None, + golden_response=( + types.CandidateResponse(text=row[_evals_constant.REFERENCE]) + if _evals_constant.REFERENCE in row + else None + ), + candidate_responses=( + candidate_responses if candidate_responses else None + ), + ) + ) + logger.info("Writing evaluation item requests to GCS.") + gcs_utils = _gcs_utils.GcsUtils(api_client=api_client) + evals_module = evals.Evals(api_client_=api_client) + eval_items = [] + for eval_item_request in eval_item_requests: + gcs_uri = gcs_utils.upload_json_to_prefix( + data=_object_to_dict(eval_item_request), + gcs_dest_prefix=gcs_dest_prefix, + filename_prefix="request", + ) + eval_item = evals_module.create_evaluation_item( + evaluation_item_type=types.EvaluationItemType.REQUEST, + gcs_uri=gcs_uri, + display_name="sdk-generated-eval-item", + ) + eval_items.append(eval_item.name) + logger.info("Creating evaluation set from GCS URIs") + evaluation_set = evals_module.create_evaluation_set( + evaluation_items=eval_items, + ) + + return evaluation_set diff --git a/agentplatform/_genai/_evals_constant.py b/agentplatform/_genai/_evals_constant.py new file mode 100644 index 0000000000..822f8e685a --- /dev/null +++ b/agentplatform/_genai/_evals_constant.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +"""Constants for evals module.""" + +SUPPORTED_PREDEFINED_METRICS = frozenset( + { + "general_quality_v1", + "text_quality_v1", + "instruction_following_v1", + "grounding_v1", + "safety_v1", + "multi_turn_general_quality_v1", + "multi_turn_text_quality_v1", + "multi_turn_tool_use_quality_v1", + "multi_turn_trajectory_quality_v1", + "multi_turn_task_success_v1", + "final_response_match_v2", + "final_response_reference_free_v1", + "final_response_quality_v1", + "hallucination_v1", + "tool_use_quality_v1", + "gecko_text2image_v1", + "gecko_text2video_v1", + } +) + +SUPPORTED_VERTEX_MAAS_MODEL_PREFIXES = frozenset( + { + "meta/", # Meta/Llama + "deepseek-ai/", # DeepSeek AI + "qwen/", # Qwen + "openai/", # OpenAI (GPT-OSS) + "claude-", # Anthropic (Claude) + "mistral-", # Mistral AI + "jamba-", # AI21 (Jamba) + } +) +INTERMEDIATE_EVENTS = "intermediate_events" +RESPONSE = "response" +PROMPT = "prompt" +REFERENCE = "reference" +SESSION_INPUT = "session_inputs" +CONTEXT = "context" +CONTENT = "content" +PARTS = "parts" +USER_AUTHOR = "user" +AGENT_DATA = "agent_data" +STARTING_PROMPT = "starting_prompt" +CONVERSATION_PLAN = "conversation_plan" +HISTORY = "history" +CONVERSATION_HISTORY = "conversation_history" + +COMMON_DATASET_COLUMNS = frozenset( + { + INTERMEDIATE_EVENTS, + PROMPT, + REFERENCE, + SESSION_INPUT, + CONTEXT, + HISTORY, + CONVERSATION_HISTORY, + STARTING_PROMPT, + CONVERSATION_PLAN, + AGENT_DATA, + } +) diff --git a/agentplatform/_genai/_evals_data_converters.py b/agentplatform/_genai/_evals_data_converters.py new file mode 100644 index 0000000000..21564ff0a8 --- /dev/null +++ b/agentplatform/_genai/_evals_data_converters.py @@ -0,0 +1,918 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Dataset converters for evals.""" + +import copy +import json +import logging +from typing import Any, Optional, Union + +from google.genai import _common +from google.genai import types as genai_types +from pydantic import ValidationError +from typing_extensions import override + +from . import _evals_utils +from . import _observability_data_converter +from . 
import types + + +logger = logging.getLogger("agentplatform_genai._evals_data_converters") + + +class EvalDatasetSchema(_common.CaseInSensitiveEnum): + """Represents the schema of an evaluation dataset.""" + + GEMINI = "gemini" + FLATTEN = "flatten" + OPENAI = "openai" + OBSERVABILITY = "observability" + UNKNOWN = "unknown" + + +_PLACEHOLDER_RESPONSE_TEXT = "Error: Missing response for this candidate" + + +def _create_placeholder_response_candidate( + text: str = _PLACEHOLDER_RESPONSE_TEXT, +) -> types.ResponseCandidate: + """Creates a ResponseCandidate with placeholder text.""" + return types.ResponseCandidate( + response=genai_types.Content(parts=[genai_types.Part(text=text)]) + ) + + +class _GeminiEvalDataConverter(_evals_utils.EvalDataConverter): + """Converter for dataset in the Gemini format.""" + + def _parse_request(self, request_data: dict[str, Any]) -> tuple[ + genai_types.Content, + genai_types.Content, + list[types.evals.Message], + types.ResponseCandidate, + ]: + """Parses a request from a Gemini dataset.""" + system_instruction = genai_types.Content() + prompt = genai_types.Content() + reference = types.ResponseCandidate() + conversation_history = [] + + if "system_instruction" in request_data: + system_instruction = genai_types.Content.model_validate( + request_data["system_instruction"] + ) + for turn_id, content_dict in enumerate(request_data.get("contents", [])): + if not isinstance(content_dict, dict): + raise TypeError( + "Expected a dictionary for content at turn %s, but got %s: %s" + % (turn_id, type(content_dict).__name__, content_dict) + ) + if "parts" not in content_dict: + raise ValueError( + "Missing 'parts' key in content structure at turn %s: %s" + % (turn_id, content_dict) + ) + conversation_history.append( + types.evals.Message( + turn_id=str(turn_id), + content=genai_types.Content.model_validate(content_dict), + ) + ) + if conversation_history: + last_message = conversation_history.pop() + last_message_role = ( + last_message.content.role if last_message.content else "user" + ) + if last_message_role in ["user", None]: + prompt = ( + last_message.content + if last_message.content + else genai_types.Content() + ) + elif last_message_role == "model": + reference = types.ResponseCandidate(response=last_message.content) + if conversation_history: + second_to_last_message = conversation_history.pop() + prompt = ( + second_to_last_message.content + if second_to_last_message.content + else genai_types.Content() + ) + else: + prompt = genai_types.Content() + + return prompt, system_instruction, conversation_history, reference + + @override + def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: + """Converts a list of raw data into an EvaluationDataset.""" + eval_cases = [] + + for i, item in enumerate(raw_data): + eval_case_id = "gemini_eval_case_%s" % i + request_data = item.get("request", {}) + response_data = item.get("response", {}) + + ( + prompt, + system_instruction, + conversation_history, + reference, + ) = self._parse_request(request_data) + + responses = [] + if isinstance(response_data, str): + responses.append( + types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=response_data)] + ) + ) + ) + elif isinstance(response_data, dict): + try: + generate_content_response = ( + genai_types.GenerateContentResponse.model_validate( + response_data + ) + ) + if generate_content_response.candidates: + candidate = generate_content_response.candidates[0] + if candidate.content: + responses.append( + 
types.ResponseCandidate( + response=genai_types.Content.model_validate( + candidate.content + ) + ) + ) + else: + responses.append(_create_placeholder_response_candidate()) + except Exception: + responses.append(_create_placeholder_response_candidate()) + else: + responses.append(_create_placeholder_response_candidate()) + + eval_case = types.EvalCase( + eval_case_id=eval_case_id, + prompt=prompt, + responses=responses, + reference=reference, + system_instruction=system_instruction, + conversation_history=conversation_history, + ) + eval_cases.append(eval_case) + + return types.EvaluationDataset(eval_cases=eval_cases) + + +class _FlattenEvalDataConverter(_evals_utils.EvalDataConverter): + """Converter for datasets in a structured table format.""" + + def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: + """Converts a list of raw data into an EvaluationDataset.""" + eval_cases = [] + for i, item_dict in enumerate(raw_data): + if not isinstance(item_dict, dict): + raise TypeError( + "Expected a dictionary for item at index %s, but got %s: %s" + % (i, type(item_dict).__name__, item_dict) + ) + item = copy.deepcopy(item_dict) + eval_case_id = "eval_case_%s" % i + prompt_data = item.pop("prompt", None) + if not prompt_data: + prompt_data = item.pop("source", None) + + conversation_history_data = item.pop("conversation_history", None) + if conversation_history_data is None: + conversation_history_data = item.pop("history", None) + response_data = item.pop("response", None) + reference_data = item.pop("reference", None) + system_instruction_data = item.pop("instruction", None) + rubric_groups_data = item.pop("rubric_groups", None) + intermediate_events_data = item.pop("intermediate_events", None) + agent_data_raw = item.pop("agent_data", None) + + if not response_data and not agent_data_raw: + raise ValueError( + "Response is required but missing for %s." % eval_case_id + ) + if not prompt_data and not agent_data_raw: + raise ValueError( + "Prompt is required but missing for %s." % eval_case_id + ) + + prompt: Optional[genai_types.Content] = None + if isinstance(prompt_data, str): + prompt = genai_types.Content(parts=[genai_types.Part(text=prompt_data)]) + elif isinstance(prompt_data, dict): + prompt = genai_types.Content.model_validate(prompt_data) + elif isinstance(prompt_data, genai_types.Content): + prompt = prompt_data + elif not agent_data_raw: + raise ValueError( + "Invalid prompt type for case %s: %s" % (i, type(prompt_data)) + ) + + conversation_history: Optional[list[types.evals.Message]] = None + if isinstance(conversation_history_data, list): + conversation_history = [] + for turn_id, content in enumerate(conversation_history_data): + if isinstance(content, genai_types.Content): + conversation_history.append( + types.evals.Message( + turn_id=str(turn_id), + content=content, + ) + ) + elif isinstance(content, dict): + try: + validated_content = genai_types.Content.model_validate( + content + ) + conversation_history.append( + types.evals.Message( + turn_id=str(turn_id), + content=validated_content, + ) + ) + except ValidationError as e: + logger.warning( + "Item at index %s in 'history' column for case " + " %s is a dict but could not be validated as" + " genai_types.Content: %s", + turn_id, + eval_case_id, + e, + ) + else: + logger.warning( + "Invalid type in 'history' column for case %s at index %s. " + "Expected genai_types.Content or dict, but got %s. 
" + "Skipping this history item.", + eval_case_id, + turn_id, + type(content), + ) + + responses: Optional[list[types.ResponseCandidate]] = None + if isinstance(response_data, dict): + responses = [ + types.ResponseCandidate( + response=genai_types.Content.model_validate(response_data) + ) + ] + elif isinstance(response_data, str): + responses = [ + types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=response_data)] + ) + ) + ] + elif isinstance(response_data, genai_types.Content): + responses = [types.ResponseCandidate(response=response_data)] + elif not agent_data_raw: + raise ValueError( + "Invalid response type for case %s: %s" % (i, type(response_data)) + ) + + reference: Optional[types.ResponseCandidate] = None + if reference_data: + if isinstance(reference_data, dict): + reference = types.ResponseCandidate( + response=genai_types.Content.model_validate(reference_data) + ) + elif isinstance(reference_data, str): + reference = types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=reference_data)] + ) + ) + elif isinstance(reference_data, genai_types.Content): + reference = types.ResponseCandidate(response=reference_data) + + system_instruction: Optional[genai_types.Content] = None + if system_instruction_data: + if isinstance(system_instruction_data, dict): + system_instruction = genai_types.Content.model_validate( + system_instruction_data + ) + elif isinstance(system_instruction_data, str): + system_instruction = genai_types.Content( + parts=[genai_types.Part(text=system_instruction_data)] + ) + elif isinstance(system_instruction_data, genai_types.Content): + system_instruction = system_instruction_data + + rubric_groups: Optional[dict[str, types.RubricGroup]] = None + if rubric_groups_data: + if isinstance(rubric_groups_data, dict): + rubric_groups = {} + for key, value in rubric_groups_data.items(): + if isinstance(value, list): + try: + validated_rubrics = [ + ( + types.evals.Rubric.model_validate(r) + if isinstance(r, dict) + else r + ) + for r in value + ] + if all( + isinstance(r, types.evals.Rubric) + for r in validated_rubrics + ): + rubric_groups[key] = types.RubricGroup( + rubrics=validated_rubrics + ) + else: + logger.warning( + "Invalid item type in rubric list for group '%s' in case %s.", + key, + i, + ) + except Exception as e: + logger.warning( + "Failed to validate rubrics for group '%s' in case %s: %s", + key, + i, + e, + ) + elif isinstance(value, types.RubricGroup): + rubric_groups[key] = value + elif isinstance(value, dict): + try: + rubric_groups[key] = types.RubricGroup.model_validate( + value + ) + except Exception as e: + logger.warning( + "Failed to validate RubricGroup dict for group '%s' in case %s: %s", + key, + i, + e, + ) + else: + logger.warning( + "Invalid type for rubric group '%s' in case %s." + " Expected list of rubrics, dict, or RubricGroup.", + key, + i, + ) + else: + logger.warning( + "Invalid type for rubric_groups in case %s. 
Expected dict.", + i, + ) + + intermediate_events: Optional[list[types.evals.Event]] = None + if intermediate_events_data: + if isinstance(intermediate_events_data, list): + intermediate_events = [] + for event in intermediate_events_data: + if isinstance(event, dict): + try: + validated_event = types.evals.Event.model_validate( + event + ) + intermediate_events.append(validated_event) + except Exception as e: + logger.warning( + "Failed to validate intermediate event dict for" + " case %s: %s", + i, + e, + ) + elif isinstance(event, types.evals.Event): + intermediate_events.append(event) + else: + logger.warning( + "Invalid type for intermediate_event in case" + " %s. Expected list of dicts or list of" + " types.evals.Event objects.", + i, + ) + else: + logger.warning( + "Invalid type for intermediate_events in case %s. Expected" + " list of types.evals.Event objects.", + i, + ) + + agent_data: Optional[types.evals.AgentData] = None + if agent_data_raw: + if isinstance(agent_data_raw, str): + try: + agent_data_dict = json.loads(agent_data_raw) + agent_data = types.evals.AgentData.model_validate( + agent_data_dict + ) + except json.JSONDecodeError: + logger.warning( + "Could not decode agent_data JSON string for case %s.", i + ) + except ValidationError as e: + logger.warning( + "Failed to validate agent_data for case %s: %s", i, e + ) + elif isinstance(agent_data_raw, dict): + try: + agent_data = types.evals.AgentData.model_validate( + agent_data_raw + ) + except ValidationError as e: + logger.warning( + "Failed to validate agent_data for case %s: %s", i, e + ) + elif isinstance(agent_data_raw, types.evals.AgentData): + agent_data = agent_data_raw + else: + logger.warning( + "Invalid type for agent_data in case %s. Expected str, dict" + " or types.evals.AgentData object. Got %s", + i, + type(agent_data_raw), + ) + + eval_case = types.EvalCase( + eval_case_id=eval_case_id, + prompt=prompt, + responses=responses, + reference=reference, + conversation_history=conversation_history, + system_instruction=system_instruction, + rubric_groups=rubric_groups, + intermediate_events=intermediate_events, + agent_data=agent_data, + **item, # Pass remaining columns as extra fields to EvalCase. + # They can be used for custom metric prompt templates. 
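+                # (e.g. an extra "tone" column could back a {tone} placeholder
+                # in an LLMMetric prompt template; the column name here is
+                # purely illustrative.)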
+ ) + eval_cases.append(eval_case) + + return types.EvaluationDataset(eval_cases=eval_cases) + + +class _OpenAIDataConverter(_evals_utils.EvalDataConverter): + """Converter for dataset in OpenAI's Chat Completion format.""" + + def _parse_messages(self, messages: list[dict[str, Any]]) -> tuple[ + Optional[genai_types.Content], + list[types.evals.Message], + Optional[genai_types.Content], + Optional[types.ResponseCandidate], + ]: + """Parses a list of messages into instruction, history, prompt, and reference.""" + system_instruction = None + prompt = None + reference = None + conversation_history = [] + + if messages and messages[0].get("role") in ["system", "developer"]: + system_instruction = genai_types.Content( + parts=[genai_types.Part(text=messages[0].get("content"))] + ) + messages = messages[1:] + + for turn_id, msg in enumerate(messages): + role = msg.get("role", "user") + content = msg.get("content", "") + conversation_history.append( + types.evals.Message( + turn_id=str(turn_id), + content=genai_types.Content( + parts=[genai_types.Part(text=content)], role=role + ), + author=role, + ) + ) + + if conversation_history: + last_message = conversation_history.pop() + if last_message.content and last_message.content.role == "user": + prompt = last_message.content + elif last_message.content and last_message.content.role == "assistant": + reference = types.ResponseCandidate(response=last_message.content) + if conversation_history: + second_to_last_message = conversation_history.pop() + prompt = second_to_last_message.content + + return system_instruction, conversation_history, prompt, reference + + @override + def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: + """Converts a list of OpenAI ChatCompletion data into an EvaluationDataset.""" + eval_cases = [] + for i, item in enumerate(raw_data): + eval_case_id = "openai_eval_case_%s" % i + + if "request" not in item or "response" not in item: + logger.warning( + "Skipping case %s due to missing 'request' or 'response' key.", i + ) + continue + + request_data = item.get("request", {}) + response_data_raw = item.get("response", {}) + + response_data = {} + if isinstance(response_data_raw, str): + try: + loaded_json = json.loads(response_data_raw) + if isinstance(loaded_json, dict): + response_data = loaded_json + else: + logger.warning( + "Decoded response JSON is not a dictionary for case" + " %s. Type: %s", + i, + type(loaded_json), + ) + except json.JSONDecodeError: + logger.warning( + "Could not decode response JSON string for case %s." + " Treating as empty response.", + i, + ) + elif isinstance(response_data_raw, dict): + response_data = response_data_raw + + messages = request_data.get("messages", []) + choices = response_data.get("choices", []) + + ( + system_instruction, + conversation_history, + prompt, + reference, + ) = self._parse_messages(messages) + + if prompt is None and reference is None: + logger.warning( + "Could not determine a user prompt or reference for case %s." 
+ " Skipping.", + i, + ) + continue + + responses = [] + if ( + choices + and isinstance(choices, list) + and isinstance(choices[0], dict) + and choices[0].get("message") + ): + response_content = choices[0]["message"].get("content", "") + responses.append( + types.ResponseCandidate( + response=genai_types.Content( + parts=[genai_types.Part(text=response_content)] + ) + ) + ) + else: + responses.append(_create_placeholder_response_candidate()) + + other_fields = { + k: v for k, v in item.items() if k not in ["request", "response"] + } + + eval_case = types.EvalCase( + eval_case_id=eval_case_id, + prompt=prompt, + responses=responses, + reference=reference, + system_instruction=system_instruction, + conversation_history=conversation_history, + **other_fields, + ) + eval_cases.append(eval_case) + + return types.EvaluationDataset(eval_cases=eval_cases) + + +def auto_detect_dataset_schema( + raw_dataset: list[dict[str, Any]], +) -> Union[EvalDatasetSchema, str]: + """Detects the schema of a raw dataset.""" + if not raw_dataset: + return EvalDatasetSchema.UNKNOWN + + first_item = raw_dataset[0] + keys = set(first_item.keys()) + + if "format" in keys: + format_content = first_item.get("format", "") + if isinstance(format_content, str) and format_content == "observability": + return EvalDatasetSchema.OBSERVABILITY + + if "request" in keys and "response" in keys: + request_content = first_item.get("request", {}) + if isinstance(request_content, dict) and "contents" in request_content: + contents_list = request_content.get("contents") + if ( + contents_list + and isinstance(contents_list, list) + and isinstance(contents_list[0], dict) + ): + if "parts" in contents_list[0]: + return EvalDatasetSchema.GEMINI + + if "request" in keys and "response" in keys: + request_content = first_item.get("request", {}) + if isinstance(request_content, dict) and "messages" in request_content: + messages_list = request_content.get("messages") + if ( + messages_list + and isinstance(messages_list, list) + and isinstance(messages_list[0], dict) + ): + if "role" in messages_list[0] and "content" in messages_list[0]: + return EvalDatasetSchema.OPENAI + + if "agent_data" in keys: + return EvalDatasetSchema.FLATTEN + + if {"prompt", "response"}.issubset(keys) or { + "response", + "reference", + }.issubset(keys): + return EvalDatasetSchema.FLATTEN + else: + return EvalDatasetSchema.UNKNOWN + + +_CONVERTER_REGISTRY = { + EvalDatasetSchema.GEMINI: _GeminiEvalDataConverter, + EvalDatasetSchema.FLATTEN: _FlattenEvalDataConverter, + EvalDatasetSchema.OPENAI: _OpenAIDataConverter, + EvalDatasetSchema.OBSERVABILITY: _observability_data_converter.ObservabilityDataConverter, +} + + +def get_dataset_converter( + dataset_schema: EvalDatasetSchema, +) -> _evals_utils.EvalDataConverter: + """Returns the appropriate dataset converter for the given schema.""" + if dataset_schema in _CONVERTER_REGISTRY: + return _CONVERTER_REGISTRY[dataset_schema]() # type: ignore[abstract] + else: + raise ValueError("Unsupported dataset schema: %s" % dataset_schema) + + +def _get_content_text(content: genai_types.Content) -> str: + """Safely extracts text from all parts of a content. + + If the content has multiple parts, text from all parts is concatenated. + If a part is not text, it is ignored. If no text parts are found, + an empty string is returned. 
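+
+    Example (illustrative sketch of the concatenation behavior):
+
+        content = genai_types.Content(
+            parts=[genai_types.Part(text="Hello, "), genai_types.Part(text="world")]
+        )
+        _get_content_text(content)  # -> "Hello, world"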
+ """ + text_parts = [] + if ( + content + and hasattr(content, "parts") + and isinstance(content.parts, list) + and content.parts + ): + for part in content.parts: + if hasattr(part, "text") and part.text is not None: + text_parts.append(str(part.text)) + return "".join(text_parts) + + +def _get_text_from_reference( + reference: Optional[types.ResponseCandidate], +) -> Optional[str]: + """Safely extracts text from a reference field.""" + if reference and hasattr(reference, "response") and reference.response: + return _get_content_text(reference.response) + return None + + +def _validate_case_consistency( + base_case: types.EvalCase, + current_case: types.EvalCase, + case_idx: int, + dataset_idx: int, +) -> None: + """Logs warnings if prompt or reference mismatches occur.""" + if base_case.prompt != current_case.prompt: + base_prompt_text_preview = _get_content_text(base_case.prompt)[:50] + current_prompt_text_preview = _get_content_text(current_case.prompt)[:50] + logger.warning( + "Prompt mismatch for case index %d between base dataset (0)" + " and dataset %d. Using prompt from base. Base prompt" + " preview: '%s...', Dataset" + " %d prompt preview: '%s...'", + case_idx, + dataset_idx, + base_prompt_text_preview, + dataset_idx, + current_prompt_text_preview, + ) + + base_ref_text = _get_text_from_reference(base_case.reference) + current_ref_text = _get_text_from_reference(current_case.reference) + + if bool(base_case.reference) != bool(current_case.reference): + logger.warning( + "Reference presence mismatch for case index %d between base" + " dataset (0) and dataset %d. Using reference (or lack" + " thereof) from base.", + case_idx, + dataset_idx, + ) + elif base_ref_text != current_ref_text: + logger.warning( + "Reference text mismatch for case index %d between base" + " dataset (0) and dataset %d. Using reference from base. " + " Base ref: '%s...', Current ref:" + " '%s...'", + case_idx, + dataset_idx, + str(base_ref_text)[:50], + str(current_ref_text)[:50], + ) + + +def merge_evaluation_datasets( + datasets: list[types.EvaluationDataset], + agent_info: Optional[types.evals.AgentInfo] = None, +) -> types.EvaluationDataset: + """Merges multiple EvaluationDatasets into a single EvaluationDataset. + + Assumes that each dataset has responses corresponding to the same set of + prompts, in the same order. The prompt, reference, system_instruction, and + conversation_history are taken from the first dataset. + """ + if not datasets: + raise ValueError("Input 'datasets' cannot be empty.") + + num_expected_cases = 0 + if datasets[0].eval_cases: + num_expected_cases = len(datasets[0].eval_cases) + + if num_expected_cases == 0: + logger.warning( + "The first dataset has no evaluation cases. Result will be empty." + ) + return types.EvaluationDataset(eval_cases=[]) + + for i, ds in enumerate(datasets): + current_len = len(ds.eval_cases) if ds.eval_cases else 0 + if current_len != num_expected_cases: + raise ValueError( + "All datasets must have the same number of evaluation cases. " + "Base dataset (0) has %s, but dataset %s has %s." 
+ % (num_expected_cases, i, current_len) + ) + + merged_eval_cases: list[types.EvalCase] = [] + base_parsed_dataset = datasets[0] + + for case_idx in range(num_expected_cases): + base_eval_case: types.EvalCase = ( + base_parsed_dataset.eval_cases[case_idx] + if base_parsed_dataset.eval_cases + else types.EvalCase() + ) + candidate_responses: list[types.ResponseCandidate] = [] + + if base_eval_case.responses: + candidate_responses.append(base_eval_case.responses[0]) + elif base_eval_case.agent_data: + candidate_responses.append(_create_placeholder_response_candidate("")) + else: + logger.warning( + "No response or agent data found for base dataset (index 0) in case %s. " + "Adding placeholder.", + case_idx, + ) + candidate_responses.append( + _create_placeholder_response_candidate( + "Missing response from base dataset (0) for case %s" % case_idx + ) + ) + + eval_case_custom_columns = base_eval_case.model_dump( + exclude={ + "eval_case_id", + "prompt", + "responses", + "reference", + "system_instruction", + "conversation_history", + "intermediate_events", + "agent_data", + "agent_info", + }, + exclude_none=True, + ) + for dataset_idx_offset, current_parsed_ds in enumerate(datasets[1:], start=1): + current_ds_eval_case: types.EvalCase = ( + current_parsed_ds.eval_cases[case_idx] + if current_parsed_ds.eval_cases + else types.EvalCase() + ) + + _validate_case_consistency( + base_eval_case, current_ds_eval_case, case_idx, dataset_idx_offset + ) + + current_ds_extra_attrs = current_ds_eval_case.model_dump( + exclude={ + "eval_case_id", + "prompt", + "responses", + "reference", + "system_instruction", + "conversation_history", + "intermediate_events", + "agent_data", + "agent_info", + }, + exclude_none=True, + ) + eval_case_custom_columns.update(current_ds_extra_attrs) + + if current_ds_eval_case.responses: + candidate_responses.append(current_ds_eval_case.responses[0]) + elif current_ds_eval_case.agent_data: + candidate_responses.append(_create_placeholder_response_candidate("")) + else: + logger.warning( + "No response or agent data found for dataset %s in case %s. Adding" + " placeholder.", + dataset_idx_offset, + case_idx, + ) + candidate_responses.append( + _create_placeholder_response_candidate( + "Missing response from dataset %s for case %s" + % (dataset_idx_offset, case_idx) + ) + ) + + merged_case = types.EvalCase( + eval_case_id=base_eval_case.eval_case_id + or "merged_eval_case_%s" % case_idx, + prompt=base_eval_case.prompt, + responses=candidate_responses if candidate_responses else None, + reference=base_eval_case.reference, + system_instruction=base_eval_case.system_instruction, + conversation_history=base_eval_case.conversation_history, + agent_info=agent_info or base_eval_case.agent_info, + agent_data=base_eval_case.agent_data, + intermediate_events=base_eval_case.intermediate_events, + **eval_case_custom_columns, + ) + merged_eval_cases.append(merged_case) + + return types.EvaluationDataset(eval_cases=merged_eval_cases) + + +def merge_response_datasets_into_canonical_format( + raw_datasets: list[list[dict[str, Any]]], + schemas: list[str], + agent_info: Optional[types.evals.AgentInfo] = None, +) -> types.EvaluationDataset: + """Merges multiple raw response datasets into a single EvaluationDataset. + + Assumes that each dataset in raw_datasets has responses corresponding + to the same set of prompts, in the same order. The prompt, reference, + system_instruction, and conversation_history are taken from the first dataset. 
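+
+    Example (illustrative sketch; the inline rows are minimal flatten-schema
+    dicts, not fixtures from this module):
+
+        merged = merge_response_datasets_into_canonical_format(
+            raw_datasets=[
+                [{"prompt": "Hi", "response": "Hello!"}],
+                [{"prompt": "Hi", "response": "Hey there."}],
+            ],
+            schemas=[EvalDatasetSchema.FLATTEN, EvalDatasetSchema.FLATTEN],
+        )
+        # merged.eval_cases[0].responses holds one candidate per input dataset.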
+ """ + if not isinstance(raw_datasets, list): + raise TypeError( + "Input 'raw_datasets' must be a list, got %s." % type(raw_datasets) + ) + if not raw_datasets or not all(isinstance(ds, list) for ds in raw_datasets): + raise ValueError( + "Input 'raw_datasets' cannot be empty and must be a list of lists." + ) + if not schemas or len(schemas) != len(raw_datasets): + raise ValueError( + "A list of schemas must be provided, one for each raw dataset. " + "Got %s schemas for %s datasets." % (len(schemas), len(raw_datasets)) + ) + + parsed_evaluation_datasets: list[types.EvaluationDataset] = [] + for i, (raw_ds_entry, schema) in enumerate(zip(raw_datasets, schemas)): + converter = get_dataset_converter(schema) + parsed_evaluation_datasets.append(converter.convert(raw_ds_entry)) + + return merge_evaluation_datasets(parsed_evaluation_datasets, agent_info) diff --git a/agentplatform/_genai/_evals_metric_handlers.py b/agentplatform/_genai/_evals_metric_handlers.py new file mode 100644 index 0000000000..4571802dbc --- /dev/null +++ b/agentplatform/_genai/_evals_metric_handlers.py @@ -0,0 +1,1755 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Handlers for computing evaluation metrics.""" + +import abc +import collections +from concurrent import futures +import json +import logging +import random +import statistics +import time +from typing import Any, Callable, Generic, Optional, TypeVar, Union + +from google.genai import errors as genai_errors +from google.genai import _common +from google.genai import types as genai_types +from tqdm import tqdm +from typing_extensions import override + +from . import _evals_common +from . import _evals_constant +from . import _evals_utils +from . import evals +from . import types + + +logger = logging.getLogger(__name__) +_MAX_RETRIES = 5 +# HTTP status codes that are safe to retry with backoff. +_RETRYABLE_STATUS_CODES = frozenset( + { + 408, # RequestTimeout (DEADLINE_EXCEEDED) + 409, # Conflict / Aborted (ABORTED) + 429, # TooManyRequests / ResourceExhausted (RESOURCE_EXHAUSTED) + 499, # Client Closed Request (CANCELLED) + 500, # InternalServerError (INTERNAL) + 502, # BadGateway + 503, # ServiceUnavailable (UNAVAILABLE) + 504, # GatewayTimeout (DEADLINE_EXCEEDED) + } +) + +R = TypeVar("R") +T = TypeVar("T", types.Metric, types.MetricSource, types.LLMMetric) + + +def _call_with_retry( + fn: Callable[[], R], + metric_name: str, +) -> R: + """Calls ``fn()`` with exponential backoff + jitter on retryable errors. + + Retries up to ``_MAX_RETRIES`` times on errors whose HTTP status code is + in ``_RETRYABLE_STATUS_CODES`` (Aborted, DeadlineExceeded, + ResourceExhausted, ServiceUnavailable, Cancelled). Non-retryable errors + are re-raised immediately. If all retries are exhausted the last + exception is re-raised so the caller can decide how to handle it. + + Args: + fn: A zero-argument callable that performs the API call. + metric_name: Name of the metric, used for log messages. + + Returns: + The return value of ``fn()``. 
+ + Raises: + genai_errors.APIError: If all retries are exhausted or the error is + not retryable. + """ + for attempt in range(_MAX_RETRIES): + try: + return fn() + except genai_errors.APIError as e: + if e.code in _RETRYABLE_STATUS_CODES: + backoff = 2**attempt + random.uniform(0, 1) + logger.warning( + "Retryable error (code=%s) on attempt %d/%d for metric" + " '%s': %s. Retrying in %.1f seconds...", + e.code, + attempt + 1, + _MAX_RETRIES, + metric_name, + e, + backoff, + ) + if attempt == _MAX_RETRIES - 1: + raise + time.sleep(backoff) + else: + raise + raise genai_errors.APIError( + code=504, response_json={"message": "Retries exhausted"} + ) + + +def _has_tool_call(events: Optional[list[Any]]) -> bool: + """Checks if any event in events has a function call.""" + if not events: + return False + for event in events: + if getattr(event, "content", None) and getattr(event.content, "parts", None): + for part in event.content.parts: + if hasattr(part, "function_call") and part.function_call: + return True + return False + + +def _extract_text_from_content( + content: Optional[genai_types.Content], warn_property: str = "text" +) -> Optional[str]: + """Extracts and concatenates all text parts from a Content object.""" + if not content or not content.parts: + return None + + text_accumulator = "" + any_text_part_found = False + non_text_part_names = [] + + for part_obj in content.parts: + part_dump = part_obj.model_dump(exclude={"text", "thought"}) + for field_name, field_value in part_dump.items(): + if field_value is not None: + if field_name not in non_text_part_names: + non_text_part_names.append(field_name) + + if isinstance(part_obj.text, str): + if ( + hasattr(part_obj, "thought") + and isinstance(part_obj.thought, bool) + and part_obj.thought + ): + continue + any_text_part_found = True + text_accumulator += part_obj.text + + if non_text_part_names and any_text_part_found: + logger.warning( + "Warning: content contains non-text parts: %s. Returning" + " concatenated %s result from text parts. 
Inspect individual parts" + " for full content.", + non_text_part_names, + warn_property, + ) + return text_accumulator if any_text_part_found else None + + +def _get_prompt_from_eval_case( + eval_case: types.EvalCase, +) -> Optional[genai_types.Content]: + """Extracts prompt content from eval_case.prompt or starting_prompt.""" + if eval_case.prompt: + return eval_case.prompt + + user_scenario = getattr(eval_case, "user_scenario", None) + if user_scenario and user_scenario.starting_prompt: + return genai_types.Content( + parts=[genai_types.Part(text=user_scenario.starting_prompt)] + ) + + return None + + +def _get_response_from_eval_case( + eval_case: types.EvalCase, response_index: int, metric_name: str +) -> Optional[genai_types.Content]: + """Extracts response content from eval_case.responses.""" + response_content = None + if eval_case.responses and response_index < len(eval_case.responses): + response_content = eval_case.responses[response_index].response + + return response_content + + +def _value_to_content_list(value: Any) -> list[genai_types.Content]: + """Converts a value to a list of Content objects.""" + if isinstance(value, genai_types.Content): + return [value] + if isinstance(value, types.ResponseCandidate): + return [value.response] if value.response else [] + if isinstance(value, list) and value: + if isinstance(value[0], genai_types.Content): + return value + if isinstance(value[0], types.evals.Message): + history_texts = [] + for msg_obj in value: + msg_text = _extract_text_from_content(msg_obj.content) + if msg_text: + role = msg_obj.content.role or msg_obj.author or "user" + history_texts.append(f"{role}: {msg_text}") + return [ + genai_types.Content( + parts=[genai_types.Part(text="\n".join(history_texts))] + ) + ] + return [genai_types.Content(parts=[genai_types.Part(text=json.dumps(value))])] + if isinstance(value, dict): + return [genai_types.Content(parts=[genai_types.Part(text=json.dumps(value))])] + return [genai_types.Content(parts=[genai_types.Part(text=str(value))])] + + +def _get_autorater_config(metric: types.Metric) -> dict[str, Any]: + """Extracts autorater config settings from a metric.""" + autorater_config: dict[str, Any] = {} + if metric.judge_model: + autorater_config["autorater_model"] = metric.judge_model + if metric.judge_model_generation_config: + autorater_config["generation_config"] = metric.judge_model_generation_config + if metric.judge_model_sampling_count: + autorater_config["sampling_count"] = metric.judge_model_sampling_count + return autorater_config + + +def _default_aggregate_scores( + metric_name: str, + eval_case_metric_results: list[types.EvalCaseMetricResult], + calculate_pass_rate: bool = False, +) -> types.AggregatedMetricResult: + """Default aggregation logic using mean and standard deviation.""" + scores = [] + num_error = 0 + num_valid = 0 + num_passing = 0 + + for result in eval_case_metric_results: + if result.error_message is None and result.score is not None: + try: + score = float(result.score) + scores.append(score) + num_valid += 1 + if calculate_pass_rate and score == 1.0: + num_passing += 1 + except (ValueError, TypeError): + logger.warning( + "Could not convert score '%s' to float for metric '%s' during" + " default aggregation. 
Counting as error.", + result.score, + metric_name, + ) + num_error += 1 + else: + num_error += 1 + + mean_score = None + stdev_score = None + pass_rate = None + + if num_valid > 0: + try: + mean_score = statistics.mean(scores) + except statistics.StatisticsError as e: + logger.warning("Could not calculate mean for %s: %s", metric_name, e) + if calculate_pass_rate: + pass_rate = num_passing / num_valid + + if num_valid > 1: + try: + stdev_score = statistics.stdev(scores) + except statistics.StatisticsError as e: + logger.warning("Could not calculate stdev for %s: %s", metric_name, e) + + return types.AggregatedMetricResult( + metric_name=metric_name, + num_cases_total=len(eval_case_metric_results), + num_cases_valid=num_valid, + num_cases_error=num_error, + mean_score=mean_score, + stdev_score=stdev_score, + pass_rate=pass_rate if calculate_pass_rate else None, + ) + + +class MetricHandler(abc.ABC, Generic[T]): + """Abstract base class for metric handlers.""" + + def __init__(self, module: "evals.Evals", metric: T): + self.module = module + self.metric: T = metric + + @property + @abc.abstractmethod + def metric_name(self) -> str: + """Returns the name of the metric polymorphically.""" + raise NotImplementedError() + + @abc.abstractmethod + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a specific metric.""" + raise NotImplementedError() + + @abc.abstractmethod + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a specific metric.""" + raise NotImplementedError() + + +class ComputationMetricHandler(MetricHandler[types.Metric]): + """Metric handler for computation metrics.""" + + SUPPORTED_COMPUTATION_METRICS = frozenset( + { + "exact_match", + "bleu", + "rouge_1", + "rouge_l_sum", + "tool_call_valid", + "tool_name_match", + "tool_parameter_key_match", + "tool_parameter_kv_match", + # TODO b/423934249 - Add trajectory metrics once they are supported. + } + ) + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.Metric): + super().__init__(module=module, metric=metric) + if self.metric.name not in self.SUPPORTED_COMPUTATION_METRICS: + raise ValueError( + f"Metric '{self.metric.name}' is not supported for computation." + ) + + def _build_request_payload( + self, eval_case: types.EvalCase, response_index: int + ) -> dict[str, Any]: + """Builds the request parameters for evaluate instances.""" + request_payload = {} + + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric.name + ) + prediction_text = _extract_text_from_content(response_content) + + if prediction_text is None: + raise ValueError( + f"Response text missing for candidate {response_index} in eval_case" + f" {eval_case.eval_case_id or 'Unknown ID'}." + ) + + if ( + eval_case.reference is None + or _extract_text_from_content(eval_case.reference.response) is None + ): + raise ValueError( + "Reference text missing for eval_case" + f" {eval_case.eval_case_id or 'Unknown ID'}." 
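+                # Every computation metric here compares a prediction against
+                # a reference, so a missing reference is fatal.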
+ ) + logger.debug("eval_case: %s", eval_case) + + if self.metric.name and self.metric.name.startswith("rouge"): + request_payload["rouge_input"] = { + "metric_spec": { + "rouge_type": ( + "rougeLsum" if self.metric.name == "rouge_l_sum" else "rouge1" + ), + }, + "instances": [ + { + "prediction": prediction_text, + "reference": _extract_text_from_content( + eval_case.reference.response + ), + } + ], + } + else: + request_payload[f"{self.metric.name}_input"] = { + "metric_spec": {}, + "instances": [ + { + "prediction": prediction_text, + "reference": _extract_text_from_content( + eval_case.reference.response + ), + } + ], + } + logger.debug("request_payload: %s", request_payload) + return request_payload + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a specific computation metric.""" + + metric_name = self.metric.name + logger.debug( + "ComputationMetricHandler: Processing '%s' for case: %s", + metric_name, + eval_case.model_dump(exclude_none=True), + ) + response = _call_with_retry( + lambda: self.module.evaluate_instances( + metric_config=self._build_request_payload(eval_case, response_index) + ).model_dump(exclude_none=True), + metric_name, + ) + logger.debug("response: %s", response) + score = None + for _, result_value in response.items(): + if isinstance(result_value, dict) and result_value: + for _, metric_value in result_value.items(): + if isinstance(metric_value, list) and metric_value: + score = metric_value[0]["score"] + break + logger.debug("Metric result: %s", score) + return types.EvalCaseMetricResult( + metric_name=metric_name, + score=score, + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a computation metric.""" + logger.debug("Aggregating results for computation metric: %s", self.metric.name) + return _default_aggregate_scores(self.metric.name, eval_case_metric_results) + + +class TranslationMetricHandler(MetricHandler[types.Metric]): + """Metric handler for translation metrics.""" + + SUPPORTED_TRANSLATION_METRICS = frozenset({"comet", "metricx"}) + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.Metric): + super().__init__(module=module, metric=metric) + + if self.metric.name not in self.SUPPORTED_TRANSLATION_METRICS: + raise ValueError( + f"Metric '{self.metric.name}' is not supported for translation." 
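+                # Only the names in SUPPORTED_TRANSLATION_METRICS ("comet",
+                # "metricx") are routed through this handler.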
+ ) + + def _build_request_payload( + self, eval_case: types.EvalCase, response_index: int + ) -> dict[str, Any]: + """Builds the request parameters for evaluate instances.""" + request_payload = {} + metric_input_name = f"{self.metric.name}_input" + version = None + if hasattr(self.metric, "version"): + version = self.metric.version + elif self.metric.name == "comet": + version = "COMET_22_SRC_REF" + elif self.metric.name == "metricx": + version = "METRICX_24_SRC_REF" + + source_language = None + target_language = None + if hasattr(self.metric, "source_language"): + source_language = self.metric.source_language + if hasattr(self.metric, "target_language"): + target_language = self.metric.target_language + + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric.name + ) + prediction_text = _extract_text_from_content(response_content) + prompt_text = _extract_text_from_content(_get_prompt_from_eval_case(eval_case)) + + if prediction_text is None: + raise ValueError( + f"Response text missing for candidate {response_index} in eval_case" + f" {eval_case.eval_case_id or 'Unknown ID'}." + ) + + if ( + eval_case.reference is None + or _extract_text_from_content(eval_case.reference.response) is None + ): + raise ValueError( + "Reference text missing for eval_case" + f" {eval_case.eval_case_id or 'Unknown ID'}." + ) + if prompt_text is None: + raise ValueError( + "Prompt text (source for translation) missing for eval_case" + f" {eval_case.eval_case_id or 'Unknown ID'}." + ) + + request_payload[metric_input_name] = { + "metric_spec": { + "version": version, + "source_language": source_language, + "target_language": target_language, + }, + "instance": { + "prediction": prediction_text, + "reference": _extract_text_from_content(eval_case.reference.response), + "source": prompt_text, + }, + } + return request_payload + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a specific translation metric.""" + metric_name = self.metric.name + logger.debug( + "TranslationMetricHandler: Processing '%s' for case: %s", + metric_name, + eval_case, + ) + api_response = _call_with_retry( + lambda: self.module.evaluate_instances( + metric_config=self._build_request_payload(eval_case, response_index) + ), + metric_name, + ) + logger.debug("API Response: %s", api_response) + + score = None + error_message = None + + try: + if metric_name == "comet": + if api_response and api_response.comet_result: + score = api_response.comet_result.score + else: + logger.warning( + "Comet result missing in API response for metric '%s'." + " API response: %s", + metric_name, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + ) + elif metric_name == "metricx": + if api_response and api_response.metricx_result: + score = api_response.metricx_result.score + else: + logger.warning( + "MetricX result missing in API response for metric '%s'." + " API response: %s", + metric_name, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + ) + if score is None and not error_message: + logger.warning( + "Score could not be extracted for translation metric '%s'." 
+ " API response: %s", + metric_name, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error processing/extracting score for translation metric '%s': %s." + " API response: %s", + metric_name, + e, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + exc_info=True, + ) + error_message = f"Error extracting score: {e}" + + return types.EvalCaseMetricResult( + metric_name=metric_name, + score=score, + error_message=error_message, + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a translation metric.""" + logger.debug("Aggregating results for translation metric: %s", self.metric.name) + return _default_aggregate_scores(self.metric.name, eval_case_metric_results) + + +def _content_to_instance_data( + content: Optional[genai_types.Content], +) -> Optional[types.evals.InstanceData]: + """Converts a genai_types.Content object to a types.InstanceData object.""" + if not content: + return None + return types.evals.InstanceData( + contents=types.evals.InstanceDataContents(contents=[content]) + ) + + +def _eval_case_to_agent_data( + eval_case: types.EvalCase, + prompt_content: Optional[genai_types.Content] = None, + response_content: Optional[genai_types.Content] = None, +) -> Optional[types.evals.AgentData]: + """Converts an EvalCase object to a single turn AgentData object. + + If `eval_case.agent_data` is provided, it is returned directly, and + `prompt_content` and `response_content` are ignored. + """ + if getattr(eval_case, "agent_data", None): + return eval_case.agent_data + + if ( + not eval_case.agent_info + and not eval_case.intermediate_events + and not prompt_content + and not response_content + ): + return None + + agents_map = eval_case.agent_info.agents if eval_case.agent_info else None + events = [] + if prompt_content: + events.append(types.evals.AgentEvent(author="user", content=prompt_content)) + + if eval_case.intermediate_events: + for event in eval_case.intermediate_events: + events.append( + types.evals.AgentEvent( + author=event.author, + content=event.content, + event_time=event.creation_timestamp, + ) + ) + + if response_content: + events.append(types.evals.AgentEvent(author="model", content=response_content)) + + turns = ( + [types.evals.ConversationTurn(turn_index=0, turn_id="turn_0", events=events)] + if events + else None + ) + return types.evals.AgentData(agents=agents_map, turns=turns) + + +def _build_evaluation_instance( + eval_case: types.EvalCase, + response_content: Optional[genai_types.Content], + prompt_instance_data: Optional[types.evals.InstanceData] = None, + prompt_template: Optional[str] = None, +) -> types.EvaluationInstance: + """Builds a unified EvaluationInstance. Multi-turn logic is handled by the caller.""" + extracted_prompt = _get_prompt_from_eval_case(eval_case) + + # 1. Use caller-provided prompt data (multi-turn) or default to simple content + if prompt_instance_data is None: + prompt_instance_data = _content_to_instance_data(extracted_prompt) + + # 2. 
Collect placeholders for other_data + other_data_map: dict[str, Any] = {} + if hasattr(eval_case, "context") and eval_case.context: + if isinstance(eval_case.context, str): + other_data_map["context"] = types.evals.InstanceData(text=eval_case.context) + elif isinstance(eval_case.context, genai_types.Content): + other_data_map["context"] = _content_to_instance_data(eval_case.context) + + # 3. Extract custom variables from LLMMetric templates + if prompt_template: + template_vars = types.PromptTemplate(text=prompt_template).variables + standard_fields = {"prompt", "response", "reference", "context", "agent_data"} + for full_path in template_vars: + # Extract the root variable (e.g. 'metadata' from 'metadata.user_id') + root_var = full_path.split(".")[0].split("[")[0] + + if root_var not in standard_fields and hasattr(eval_case, root_var): + val = getattr(eval_case, root_var) + # Add the root object to other_data so the backend can traverse it + other_data_map[root_var] = types.evals.InstanceData( + contents=types.evals.InstanceDataContents( + contents=_value_to_content_list(val) + ) + ) + + return types.EvaluationInstance( + prompt=prompt_instance_data, + response=_content_to_instance_data(response_content), + reference=( + _content_to_instance_data(eval_case.reference.response) + if eval_case.reference + else None + ), + rubric_groups=eval_case.rubric_groups, + other_data=( + types.MapInstance(map_instance=other_data_map) if other_data_map else None + ), + agent_data=_eval_case_to_agent_data( + eval_case, extracted_prompt, response_content + ), + ) + + +class LLMMetricHandler(MetricHandler[types.LLMMetric]): + """Metric handler for LLM metrics.""" + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.LLMMetric): + super().__init__(module=module, metric=metric) + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case using the unified backend interface.""" + try: + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric_name + ) + if not response_content: + raise ValueError( + f"Response content missing for candidate {response_index}." 
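+                    # _get_response_from_eval_case returns None when
+                    # eval_case.responses is empty or shorter than
+                    # response_index + 1.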
+ ) + + instance = _build_evaluation_instance( + eval_case, response_content, prompt_template=self.metric.prompt_template + ) + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=instance, + ), + self.metric_name, + ) + + if api_response and api_response.metric_results: + result = api_response.metric_results[0] + error_msg = None + if result.error and getattr(result.error, "code"): + error_msg = f"Error in metric result: {result.error}" + + return types.EvalCaseMetricResult( + metric_name=self.metric_name, + score=result.score, + explanation=result.explanation, + rubric_verdicts=result.rubric_verdicts, + error_message=error_msg, + ) + else: + return types.EvalCaseMetricResult( + metric_name=self.metric_name, + error_message="Metric results missing in API response.", + ) + + except Exception as e: + logger.error( + "Error processing metric %s for case %s.", + self.metric_name, + eval_case.eval_case_id, + exc_info=True, + ) + return types.EvalCaseMetricResult( + metric_name=self.metric_name, error_message=str(e) + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a LLM metric.""" + if self.metric.aggregate_summary_fn and callable( + self.metric.aggregate_summary_fn + ): + logger.info( + "Using custom aggregate_summary_fn for metric '%s'", self.metric.name + ) + try: + custom_summary_dict = self.metric.aggregate_summary_fn( + eval_case_metric_results + ) + if not isinstance(custom_summary_dict, dict): + raise TypeError("aggregate_summary_fn must return a dictionary.") + + num_cases_total = len(eval_case_metric_results) + num_cases_error = len( + [ + result + for result in eval_case_metric_results + if result.error_message is not None + ] + ) + num_cases_valid = num_cases_total - num_cases_error + required_fields = { + "num_cases_total": num_cases_total, + "num_cases_error": num_cases_error, + "num_cases_valid": num_cases_valid, + } + final_summary_dict = {**required_fields, **custom_summary_dict} + + return types.AggregatedMetricResult( + metric_name=self.metric.name, + **final_summary_dict, + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error executing custom aggregate_summary_fn for metric '%s': %s." + " Falling back to default aggregation.", + self.metric.name, + e, + exc_info=True, + ) + return _default_aggregate_scores( + self.metric.name, eval_case_metric_results + ) + else: + logger.debug( + "Using default aggregation for LLM metric '%s'", self.metric.name + ) + return _default_aggregate_scores(self.metric.name, eval_case_metric_results) + + +class CustomMetricHandler(MetricHandler[types.Metric]): + """Metric handler for custom metrics.""" + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.Metric): + super().__init__(module=module, metric=metric) + + if not self.metric.custom_function: + raise ValueError( + f"CustomMetricHandler for '{self.metric.name}' needs " + " Metric.custom_function to be set." + ) + if not isinstance(self.metric.custom_function, Callable): + raise ValueError( + f"CustomMetricHandler for '{self.metric.name}' needs " + " Metric.custom_function to be a callable function." 
+ ) + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a custom metric.""" + metric_name = self.metric.name + logger.debug( + "CustomMetricHandler: Processing '%s' for case: %s", + metric_name, + eval_case.model_dump(exclude_none=True), + ) + + try: + response_content = _get_response_from_eval_case( + eval_case, response_index, metric_name + ) + except ValueError as e: + return types.EvalCaseMetricResult( + metric_name=metric_name, + error_message=str(e), + ) + + if not response_content: + return types.EvalCaseMetricResult( + metric_name=metric_name, + error_message=( + f"No response found for candidate {response_index} in EvalCase" + f" {eval_case.eval_case_id}." + ), + ) + + instance_for_custom_fn = eval_case.model_dump( + exclude={"responses"}, mode="json", exclude_none=True + ) + instance_for_custom_fn["response"] = response_content.model_dump( + mode="json", exclude_none=True + ) + extracted_prompt = _get_prompt_from_eval_case(eval_case) + if extracted_prompt: + instance_for_custom_fn["prompt"] = extracted_prompt.model_dump( + mode="json", exclude_none=True + ) + + error_msg = None + score = None + explanation = None + try: + if self.metric.custom_function and callable(self.metric.custom_function): + custom_function_result = self.metric.custom_function( + instance_for_custom_fn + ) + + if isinstance(custom_function_result, types.EvalCaseMetricResult): + return custom_function_result + elif ( + isinstance(custom_function_result, dict) + and "score" in custom_function_result + ): + score = custom_function_result["score"] + explanation = custom_function_result.get("explanation", None) + elif isinstance(custom_function_result, (float, int)): + score = custom_function_result + explanation = None + else: + error_msg = ( + f"CustomFunctionError({self.metric.custom_function}): Returned" + f" unexpected type {type(custom_function_result)}" + ) + + except Exception as e: # pylint: disable=broad-exception-caught + if self.metric.custom_function and hasattr( + self.metric.custom_function, "__name__" + ): + custom_function_name = self.metric.custom_function.__name__ + else: + custom_function_name = "unknown_custom_function" + error_msg = f"CustomFunctionError({custom_function_name}): {e}" + score = None + explanation = None + + return types.EvalCaseMetricResult( + metric_name=self.metric.name, + score=score, + explanation=explanation, + error_message=error_msg, + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a custom metric.""" + logger.debug("Aggregating results for custom metric: %s", self.metric.name) + return _default_aggregate_scores(self.metric.name, eval_case_metric_results) + + +class PredefinedMetricHandler(MetricHandler[types.Metric]): + """Metric handler for predefined metrics.""" + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.Metric): + super().__init__(module=module, metric=metric) + if self.metric.name not in _evals_constant.SUPPORTED_PREDEFINED_METRICS: + raise ValueError( + f"Metric '{self.metric.name}' is not a supported predefined metric." 
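+                # Valid names are listed in
+                # _evals_constant.SUPPORTED_PREDEFINED_METRICS, e.g.
+                # "general_quality_v1" or "final_response_match_v2".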
+ ) + + def _build_request_payload( + self, eval_case: types.EvalCase, response_index: int + ) -> dict[str, Any]: + """Builds the request parameters for evaluate instances request.""" + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric.name + ) + + if not response_content and not getattr(eval_case, "agent_data", None): + raise ValueError( + f"Response content missing for candidate {response_index}." + ) + + if self.metric.name == "tool_use_quality_v1": + has_tool_call = _has_tool_call(eval_case.intermediate_events) + + # Check agent_data for tool calls if intermediate_events is empty + agent_data = getattr(eval_case, "agent_data", None) + if not has_tool_call and agent_data: + for turn in agent_data.turns or []: + if _has_tool_call(turn.events): + has_tool_call = True + break + + if not has_tool_call: + logger.warning( + "Metric 'tool_use_quality_v1' requires tool usage in " + "'intermediate_events' or 'agent_data', but no tool usage was found for case %s.", + eval_case.eval_case_id, + ) + + extracted_prompt = _get_prompt_from_eval_case(eval_case) + prompt_instance_data = None + if self.metric.name and self.metric.name.startswith("multi_turn"): + prompt_contents = [ + msg.content for msg in (eval_case.conversation_history or []) + ] + if extracted_prompt: + prompt_contents.append(extracted_prompt) + prompt_instance_data = types.evals.InstanceData( + contents=types.evals.InstanceDataContents(contents=prompt_contents) + ) + + instance_payload = _build_evaluation_instance( + eval_case=eval_case, + response_content=response_content, + prompt_instance_data=prompt_instance_data, + ) + + request_payload: dict[str, Any] = { + "instance": instance_payload, + } + + autorater_config = _get_autorater_config(self.metric) + if autorater_config: + request_payload["autorater_config"] = genai_types.AutoraterConfig( + **autorater_config + ) + return request_payload + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a specific predefined metric.""" + metric_name = self.metric.name + try: + payload = self._build_request_payload(eval_case, response_index) + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=payload.get("instance"), + autorater_config=payload.get("autorater_config"), + ), + metric_name, + ) + + if ( + api_response + and hasattr(api_response, "metric_results") + and api_response.metric_results + ): + result_data = api_response.metric_results[0] + + error_message = None + if result_data.error and getattr(result_data.error, "code"): + error_message = f"Error in metric result: {result_data.error}" + return types.EvalCaseMetricResult( + metric_name=metric_name, + score=result_data.score, + explanation=result_data.explanation, + rubric_verdicts=result_data.rubric_verdicts, + error_message=error_message, + ) + else: + logger.error( + "Metric results missing in API response for predefined metric '%s'." 
+ " API response: %s", + metric_name, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + ) + return types.EvalCaseMetricResult( + metric_name=metric_name, + error_message="Metric results missing in API response.", + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error processing metric %s for case %s: %s", + metric_name, + eval_case.eval_case_id, + e, + exc_info=True, + ) + return types.EvalCaseMetricResult( + metric_name=metric_name, error_message=str(e) + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a predefined metric.""" + logger.debug("Aggregating results for predefined metric: %s", self.metric.name) + return _default_aggregate_scores( + self.metric.name, eval_case_metric_results, calculate_pass_rate=True + ) + + +class CustomCodeExecutionMetricHandler(MetricHandler[types.Metric]): + """Metric handler for custom code execution metrics.""" + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + def __init__(self, module: "evals.Evals", metric: types.Metric): + super().__init__(module=module, metric=metric) + + if not self.metric.remote_custom_function and not self.metric.custom_function: + raise ValueError( + f"CustomCodeExecutionMetricHandler for '{self.metric.name}' needs " + " custom function to be set." + ) + + def _build_request_payload( + self, eval_case: types.EvalCase, response_index: int + ) -> dict[str, Any]: + """Builds the request parameters for evaluate instances request.""" + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric.name + ) + + if not response_content and not getattr(eval_case, "agent_data", None): + raise ValueError( + f"Response content missing for candidate {response_index}." + ) + + reference_instance_data = None + if eval_case.reference: + reference_instance_data = _content_to_instance_data( + eval_case.reference.response + ) + + extracted_prompt = _get_prompt_from_eval_case(eval_case) + prompt_instance_data = _content_to_instance_data(extracted_prompt) + + instance_payload = types.EvaluationInstance( + prompt=prompt_instance_data, + response=_content_to_instance_data(response_content), + reference=reference_instance_data, + agent_data=_eval_case_to_agent_data(eval_case), + ) + + return { + "instance": instance_payload, + } + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case for a specific custom code execution metric.""" + metric_name = self.metric.name + try: + payload = self._build_request_payload(eval_case, response_index) + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metrics=[self.metric], + instance=payload.get("instance"), + ), + metric_name, + ) + + if ( + api_response + and hasattr(api_response, "metric_results") + and api_response.metric_results + ): + result_data = api_response.metric_results[0] + + error_message = None + if result_data.error and getattr(result_data.error, "code"): + error_message = f"Error in metric result: {result_data.error}" + return types.EvalCaseMetricResult( + metric_name=metric_name, + score=result_data.score, + explanation=result_data.explanation, + error_message=error_message, + ) + else: + logger.error( + "Metric results missing in API response for metric '%s'." 
+ " API response: %s", + metric_name, + ( + api_response.model_dump_json(exclude_none=True) + if api_response + else "None" + ), + ) + return types.EvalCaseMetricResult( + metric_name=metric_name, + error_message="Metric results missing in API response.", + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error processing metric %s for case %s", + metric_name, + eval_case.eval_case_id, + exc_info=True, + ) + return types.EvalCaseMetricResult( + metric_name=metric_name, error_message=str(e) + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a custom code execution metric.""" + logger.debug( + "Aggregating results for custom code execution metric: %s", self.metric.name + ) + return _default_aggregate_scores( + self.metric.name, eval_case_metric_results, calculate_pass_rate=True + ) + + +class RegisteredMetricHandler(MetricHandler[types.Metric]): + """Metric handler for registered metrics.""" + + def __init__( + self, + module: "evals.Evals", + metric: types.Metric, + ): + if isinstance(metric, dict): + metric = types.MetricSource(**metric) + super().__init__(module=module, metric=metric) + + def _build_request_payload( + self, eval_case: types.EvalCase, response_index: int + ) -> dict[str, Any]: + """Builds request payload for registered metric by assembling EvaluationInstance.""" + response_content = _get_response_from_eval_case( + eval_case, response_index, self.metric_name + ) + + if not response_content and not getattr(eval_case, "agent_data", None): + raise ValueError( + f"Response content missing for candidate {response_index}." + ) + + reference_instance_data = None + if eval_case.reference: + reference_instance_data = _content_to_instance_data( + eval_case.reference.response + ) + + extracted_prompt = _get_prompt_from_eval_case(eval_case) + prompt_instance_data = _content_to_instance_data(extracted_prompt) + + instance_payload = types.EvaluationInstance( + prompt=prompt_instance_data, + response=_content_to_instance_data(response_content), + reference=reference_instance_data, + rubric_groups=eval_case.rubric_groups, + agent_data=_eval_case_to_agent_data(eval_case), + ) + + request_payload = { + "instance": instance_payload, + } + return request_payload + + @property + def metric_name(self) -> str: + return self.metric.name or "unknown_metric" + + @override + def get_metric_result( + self, eval_case: types.EvalCase, response_index: int + ) -> types.EvalCaseMetricResult: + """Processes a single evaluation case using a MetricSource reference.""" + metric_name = self.metric_name + metric_source = types.MetricSource( + metric_resource_name=self.metric.metric_resource_name + ) + + try: + payload = self._build_request_payload(eval_case, response_index) + api_response = _call_with_retry( + lambda: self.module._evaluate_instances( + metric_sources=[metric_source], + instance=payload.get("instance"), + autorater_config=payload.get("autorater_config"), + ), + metric_name, + ) + + if api_response and api_response.metric_results: + result_data = api_response.metric_results[0] + error_message = None + if result_data.error and getattr(result_data.error, "code"): + error_message = f"Error in metric result: {result_data.error}" + return types.EvalCaseMetricResult( + metric_name=metric_name, + score=result_data.score, + explanation=result_data.explanation, + rubric_verdicts=result_data.rubric_verdicts, + error_message=error_message, + ) + 
else: + return types.EvalCaseMetricResult( + metric_name=metric_name, + error_message="Metric results missing in API response.", + ) + except Exception as e: + return types.EvalCaseMetricResult( + metric_name=metric_name, error_message=str(e) + ) + + @override + def aggregate( + self, eval_case_metric_results: list[types.EvalCaseMetricResult] + ) -> types.AggregatedMetricResult: + """Aggregates the metric results for a registered metric.""" + return _default_aggregate_scores( + self.metric_name, eval_case_metric_results, calculate_pass_rate=True + ) + + +_METRIC_HANDLER_MAPPING = [ + ( + lambda m: ( + # Recognize the user-facing class + isinstance(m, types.CodeExecutionMetric) + and (hasattr(m, "custom_function") and m.custom_function) + ) + or (hasattr(m, "remote_custom_function") and m.remote_custom_function) + # Recognize base Metric objects that have been coerced by Pydantic + or ( + isinstance(m, types.Metric) + and isinstance(getattr(m, "custom_function", None), str) + ), + CustomCodeExecutionMetricHandler, + ), + ( + lambda m: m.custom_function and isinstance(m.custom_function, Callable), + CustomMetricHandler, + ), + ( + lambda m: getattr(m, "metric_resource_name", None) is not None, + RegisteredMetricHandler, + ), + ( + lambda m: m.name in ComputationMetricHandler.SUPPORTED_COMPUTATION_METRICS, + ComputationMetricHandler, + ), + ( + lambda m: m.name in TranslationMetricHandler.SUPPORTED_TRANSLATION_METRICS, + TranslationMetricHandler, + ), + ( + lambda m: m.name in _evals_constant.SUPPORTED_PREDEFINED_METRICS, + PredefinedMetricHandler, + ), + (lambda m: isinstance(m, types.LLMMetric), LLMMetricHandler), +] + +MetricHandlerType = TypeVar( + "MetricHandlerType", + ComputationMetricHandler, + TranslationMetricHandler, + LLMMetricHandler, + CustomMetricHandler, + CustomCodeExecutionMetricHandler, + PredefinedMetricHandler, +) + + +def get_handler_for_metric( + module: "evals.Evals", metric: types.Metric +) -> Union[MetricHandlerType, Any]: + """Returns a metric handler for the given metric.""" + for condition, handler_class in _METRIC_HANDLER_MAPPING: + if condition(metric): # type: ignore[no-untyped-call] + return handler_class(module=module, metric=metric) + raise ValueError(f"Unsupported metric: {metric.name}") + + +def calculate_win_rates(eval_result: types.EvaluationResult) -> dict[str, Any]: + """Calculates win/tie rates for comparison results.""" + if not eval_result.eval_case_results: + return {} + max_models = max( + ( + len(case.response_candidate_results) + for case in eval_result.eval_case_results + if case.response_candidate_results + ), + default=0, + ) + if max_models == 0: + return {} + stats: collections.defaultdict[str, dict[str, Any]] = collections.defaultdict( + lambda: {"wins": [0] * max_models, "ties": 0, "valid_comparisons": 0} + ) + for case in eval_result.eval_case_results: + if not case.response_candidate_results: + continue + scores_by_metric = collections.defaultdict(list) + for idx, candidate in enumerate(case.response_candidate_results): + for name, res in ( + candidate.metric_results.items() if candidate.metric_results else {} + ): + if res.score is not None: + scores_by_metric[name].append({"score": res.score, "cand_idx": idx}) + for name, scores in scores_by_metric.items(): + if not scores: + continue + stats[name]["valid_comparisons"] += 1 + max_score = max(s["score"] for s in scores) + winners = [s["cand_idx"] for s in scores if s["score"] == max_score] + if len(winners) == 1: + stats[name]["wins"][winners[0]] += 1 + else: + stats[name]["ties"] += 1 
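+    # Convert the accumulated win/tie counts into rates over the cases where
+    # the metric produced at least one valid score.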
+ win_rates = {} + for name, metric_stats in stats.items(): + if metric_stats["valid_comparisons"] > 0: + win_rates[name] = { + "win_rates": [ + w / metric_stats["valid_comparisons"] for w in metric_stats["wins"] + ], + "tie_rate": metric_stats["ties"] / metric_stats["valid_comparisons"], + } + return win_rates + + +def _aggregate_metric_results( + metric_handlers: list[MetricHandler[Any]], + eval_case_results: list[types.EvalCaseResult], +) -> list[types.AggregatedMetricResult]: + """Aggregates results by calling the aggregate method of each handler.""" + aggregated_metric_results = [] + logger.info("Aggregating results per metric...") + for handler in metric_handlers: + metric_name = handler.metric_name + results_for_this_metric: list[types.EvalCaseMetricResult] = [] + for case_result in eval_case_results: + if case_result.response_candidate_results: + for response_candidate_res in case_result.response_candidate_results: + if ( + response_candidate_res.metric_results + and metric_name in response_candidate_res.metric_results + and isinstance(metric_name, str) + ): + results_for_this_metric.append( + response_candidate_res.metric_results[metric_name] + ) + if not results_for_this_metric: + logger.warning( + "No results found for metric '%s' to aggregate.", metric_name + ) + continue + + try: + summary = handler.aggregate(results_for_this_metric) + aggregated_metric_results.append(summary) + except NotImplementedError: + logger.warning( + "Aggregation not implemented for metric handler: %s (metric: '%s')." + " Skipping summary.", + type(handler).__name__, + metric_name, + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error during aggregation for metric '%s' using handler %s: %s", + metric_name, + type(handler).__name__, + e, + exc_info=True, + ) + aggregated_metric_results.append( + types.AggregatedMetricResult( + metric_name=metric_name, + num_cases_total=len(results_for_this_metric), + num_cases_valid=0, + num_cases_error=len(results_for_this_metric), + mean_score=None, + stdev_score=None, + ) + ) + logger.debug("Finished aggregation, returning: %s", aggregated_metric_results) + return aggregated_metric_results + + +class EvaluationRunConfig(_common.BaseModel): + """Configuration for an evaluation run.""" + + evals_module: Any + """The module to be used for the evaluation run.""" + dataset: types.EvaluationDataset + """The dataset to be used for the evaluation run.""" + metrics: list[types.Metric] + """The list of metrics to be used for the evaluation run.""" + num_response_candidates: int + """The number of response candidates for the evaluation run.""" + + +def _rate_limited_get_metric_result( + rate_limiter: _evals_utils.RateLimiter, + handler: MetricHandler[Any], + eval_case: types.EvalCase, + response_index: int, +) -> types.EvalCaseMetricResult: + """Wraps a handler's get_metric_result with rate limiting.""" + rate_limiter.sleep_and_advance() + return handler.get_metric_result(eval_case, response_index) + + +def compute_metrics_and_aggregate( + evaluation_run_config: EvaluationRunConfig, + evaluation_service_qps: Optional[float] = None, +) -> types.EvaluationResult: + """Computes metrics and aggregates them for a given evaluation run config. + + Args: + evaluation_run_config: The configuration for the evaluation run. + evaluation_service_qps: Optional QPS limit for the evaluation service. + Defaults to _DEFAULT_EVAL_SERVICE_QPS (10). Users with higher + quotas can increase this value. 
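+
+    Returns:
+        A types.EvaluationResult with per-case metric results, aggregated
+        summary metrics, and win/tie rates when more than one response
+        candidate is evaluated.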
+ """ + metric_handlers = [] + all_futures = [] + results_by_case_response_metric: collections.defaultdict[ + Any, collections.defaultdict[Any, dict[Any, Any]] + ] = collections.defaultdict(lambda: collections.defaultdict(dict)) + submission_errors = [] + execution_errors = [] + case_indices_with_errors = set() + + if evaluation_service_qps is not None and evaluation_service_qps <= 0: + raise ValueError("evaluation_service_qps must be a positive number.") + qps = evaluation_service_qps or _evals_utils._DEFAULT_EVAL_SERVICE_QPS + rate_limiter = _evals_utils.RateLimiter(rate=qps) + logger.info("Rate limiting evaluation service requests to %.1f QPS.", qps) + + for eval_metric in evaluation_run_config.metrics: + metric_handlers.append( + get_handler_for_metric(evaluation_run_config.evals_module, eval_metric) + ) + + eval_case_count = len(evaluation_run_config.dataset.eval_cases) + logger.info("Total number of evaluation cases: %d", eval_case_count) + logger.info( + "Number of response candidates: %d", + evaluation_run_config.num_response_candidates, + ) + total_metric_computations = ( + eval_case_count + * len(metric_handlers) + * evaluation_run_config.num_response_candidates + ) + logger.info("Total number of metric computations: %d", total_metric_computations) + + with tqdm( + total=total_metric_computations, + desc="Computing Metrics for Evaluation Dataset", + ) as pbar: + with futures.ThreadPoolExecutor( + max_workers=_evals_common.MAX_WORKERS + ) as executor: + for metric_handler_instance in metric_handlers: + for eval_case_index, eval_case in enumerate( + evaluation_run_config.dataset.eval_cases + ): + num_responses = ( + len(eval_case.responses) if eval_case.responses else 0 + ) + if num_responses == 0 and getattr(eval_case, "agent_data", None): + num_responses = 1 + + actual_num_candidates_for_case = min( + evaluation_run_config.num_response_candidates, + num_responses, + ) + for response_index in range(actual_num_candidates_for_case): + try: + future = executor.submit( + _rate_limited_get_metric_result, + rate_limiter, + metric_handler_instance, + eval_case, + response_index, + ) + future.add_done_callback(lambda _: pbar.update(1)) + logger.debug( + "Submitting metric computation for case %d, " + "response %d for metric %s.", + eval_case_index, + response_index, + metric_handler_instance.metric_name, + ) + all_futures.append( + ( + future, + metric_handler_instance.metric_name, + eval_case_index, + response_index, + ) + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error submitting metric computation for case %d, " + "response %d for metric %s: %s", + eval_case_index, + response_index, + metric_handler_instance.metric_name, + e, + exc_info=True, + ) + submission_errors.append( + ( + metric_handler_instance.metric_name, + eval_case_index, + response_index, + f"Error: {e}", + ) + ) + error_result = types.EvalCaseMetricResult( + metric_name=metric_handler_instance.metric_name, + error_message=f"Submission Error: {e}", + ) + results_by_case_response_metric[eval_case_index][ + response_index + ][metric_handler_instance.metric_name] = error_result + case_indices_with_errors.add(eval_case_index) + pbar.update(1) + + for future, metric_name, eval_case_index, response_index in all_futures: + try: + eval_case_metric_result = future.result() + logger.debug( + "Successfully obtained result for metric '%s', case %d, response" + " %d: %s.", + metric_name, + eval_case_index, + response_index, + eval_case_metric_result, + ) + 
results_by_case_response_metric[eval_case_index][response_index][ + metric_name + ] = eval_case_metric_result + logger.debug( + "Stored result for metric '%s', case %d, response %d.", + metric_name, + eval_case_index, + response_index, + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error( + "Error executing metric '%s' for case %s, response %s: %s", + metric_name, + eval_case_index, + response_index, + e, + exc_info=True, + ) + error_msg = ( + f"Error executing metric '{metric_name}' for case" + f" {eval_case_index}, response {response_index}: {e}" + ) + execution_errors.append( + ( + metric_name, + eval_case_index, + response_index, + error_msg, + ) + ) + case_indices_with_errors.add(eval_case_index) + error_result = types.EvalCaseMetricResult( + metric_name=metric_name, + error_message=error_msg, + ) + results_by_case_response_metric[eval_case_index][response_index][ + metric_name + ] = error_result + + final_eval_case_results = [] + sorted_eval_case_indices = sorted(results_by_case_response_metric.keys()) + for eval_case_index in sorted_eval_case_indices: + per_response_results_for_this_case = results_by_case_response_metric[ + eval_case_index + ] + + current_response_candidate_results_list = [] + sorted_response_indices = sorted(per_response_results_for_this_case.keys()) + + for response_index in sorted_response_indices: + metric_results_for_this_response = per_response_results_for_this_case[ + response_index + ] + + response_candidate_result_obj = types.ResponseCandidateResult( + response_index=response_index, + metric_results=metric_results_for_this_response, + ) + current_response_candidate_results_list.append( + response_candidate_result_obj + ) + + if current_response_candidate_results_list: + eval_case_result = types.EvalCaseResult( + eval_case_index=eval_case_index, + response_candidate_results=current_response_candidate_results_list, + ) + final_eval_case_results.append(eval_case_result) + elif eval_case_index in case_indices_with_errors or any( + err_case_idx == eval_case_index + for _, err_case_idx, _, _ in submission_errors + ): + logger.warning( + "EvalCase %d had errors but no metric results were" + " processed into the structure.", + eval_case_index, + ) + eval_case_result = types.EvalCaseResult( + eval_case_index=eval_case_index, + response_candidate_results=[], + ) + final_eval_case_results.append(eval_case_result) + + if submission_errors: + logger.warning("Encountered %d submission errors.", len(submission_errors)) + logger.warning("Submission errors: %s", submission_errors) + if execution_errors: + logger.warning("Encountered %d execution errors.", len(execution_errors)) + logger.warning("Execution errors: %s", execution_errors) + + aggregated_metric_results = _aggregate_metric_results( + metric_handlers, final_eval_case_results + ) + eval_result = types.EvaluationResult( + eval_case_results=final_eval_case_results, + summary_metrics=aggregated_metric_results, + ) + if evaluation_run_config.num_response_candidates > 1: + try: + eval_result.win_rates = calculate_win_rates(eval_result) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Error calculating win rates: %s", e, exc_info=True) + return eval_result diff --git a/agentplatform/_genai/_evals_metric_loaders.py b/agentplatform/_genai/_evals_metric_loaders.py new file mode 100644 index 0000000000..a9f62cb92a --- /dev/null +++ b/agentplatform/_genai/_evals_metric_loaders.py @@ -0,0 +1,381 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache 
License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Loaders for prebuilt evaluation metrics."""
+
+import json
+import logging
+import os
+import re
+from typing import Any, Optional, Union, TYPE_CHECKING
+
+try:
+    import yaml
+except ImportError:  # pyyaml is optional; its absence is checked before parsing.
+    yaml = None  # type: ignore[assignment]
+
+from . import _evals_constant
+from . import _gcs_utils
+
+if TYPE_CHECKING:
+    from . import types
+
+
+logger = logging.getLogger(__name__)
+
+
+class LazyLoadedPrebuiltMetric:
+    """A proxy object representing a prebuilt metric to be loaded on demand.
+
+    This can resolve to either an API Predefined Metric or an LLM Metric
+    loaded from GCS.
+    """
+
+    _cache: dict[str, "types.Metric"] = {}
+    _base_gcs_path = (
+        "gs://vertex-ai-generative-ai-eval-sdk-resources/metrics/{metric_name}/"
+    )
+
+    def __init__(self, name: str, version: Optional[str] = None, **kwargs: Any):
+        self.name = name.upper()
+        self.version = version
+        self.metric_kwargs = kwargs
+        self._resolved_metric: Optional["types.Metric"] = None
+
+    def _get_api_metric_spec_name(self) -> Optional[str]:
+        """Constructs the metric_spec_name for API Predefined Metrics."""
+        base_name = self.name.lower()
+        if self.version:
+            # Explicit version provided.
+            version = self.version.lower()
+            potential_name = f"{base_name}_{version}"
+            return (
+                potential_name
+                if potential_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS
+                else None
+            )
+        else:
+            # Default versioning: try "_v1" first, then the base name.
+            v1_name = f"{base_name}_v1"
+            if v1_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
+                return v1_name
+            if base_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS:
+                return base_name
+            return None
+
+    def _resolve_api_predefined(self) -> Optional["types.Metric"]:
+        """Attempts to resolve as an API Predefined Metric."""
+        from .
import types + + metric_spec_name = self._get_api_metric_spec_name() + if metric_spec_name: + logger.info( + "Resolving '%s' as API Predefined Metric with spec name: %s", + self.name, + metric_spec_name, + ) + return types.Metric(name=metric_spec_name, **self.metric_kwargs) + return None + + def _get_latest_version_uri(self, api_client: Any, metric_gcs_dir: str) -> str: + """Lists files in GCS directory and determines the latest version URI.""" + gcs_utils = _gcs_utils.GcsUtils(api_client) + bucket_name, prefix = gcs_utils.parse_gcs_path(metric_gcs_dir) + + blobs = gcs_utils.storage_client.list_blobs(bucket_name, prefix=prefix) + + version_files: list[dict[str, Union[list[int], str]]] = ( + [] + ) # {'version_parts': [1,0,0], 'filename': 'v1.0.0.yaml'} + + version_pattern = re.compile( + r"v(\d+)(?:\.(\d+))?(?:\.(\d+))?\.(yaml|yml|json)$", re.IGNORECASE + ) + + for blob in blobs: + match = version_pattern.match(os.path.basename(blob.name)) + if match: + major = int(match.group(1)) + minor = int(match.group(2)) if match.group(2) else 0 + patch = int(match.group(3)) if match.group(3) else 0 + version_files.append( + { + "version_parts": [major, minor, patch], + "filename": os.path.basename(blob.name), + } + ) + + if not version_files: + raise IOError(f"No versioned metric files found in {metric_gcs_dir}") + + version_files.sort(key=lambda x: x["version_parts"], reverse=True) + + latest_filename = version_files[0]["filename"] + return os.path.join(metric_gcs_dir, latest_filename) + + def _fetch_and_parse(self, api_client: Any) -> "types.LLMMetric": + """Fetches and parses the metric definition from GCS.""" + + from . import types + + metric_gcs_dir = self._base_gcs_path.format(metric_name=self.name.lower()) + uri: str + if self.version == "latest" or self.version is None: + uri = self._get_latest_version_uri(api_client, metric_gcs_dir) + resolved_version_match = re.match( + r"(v\d+(?:\.\d+)*)\.(?:yaml|yml|json)", + os.path.basename(uri), + re.IGNORECASE, + ) + if resolved_version_match: + self.version = resolved_version_match.group(1) + else: + # Fallback if regex fails + self.version = os.path.splitext(os.path.basename(uri))[0] + else: + yaml_uri = os.path.join(metric_gcs_dir, f"{self.version}.yaml") + json_uri = os.path.join(metric_gcs_dir, f"{self.version}.json") + + gcs_utils = _gcs_utils.GcsUtils(api_client) + try: + bucket_name, blob_path = gcs_utils.parse_gcs_path(yaml_uri) + if ( + gcs_utils.storage_client.bucket(bucket_name) + .blob(blob_path) + .exists() + ): + uri = yaml_uri + else: + bucket_name_json, blob_path_json = gcs_utils.parse_gcs_path( + json_uri + ) + if ( + gcs_utils.storage_client.bucket(bucket_name_json) + .blob(blob_path_json) + .exists() + ): + uri = json_uri + else: + raise IOError( + f"Metric file for version '{self.version}' " + f"not found as .yaml or .json in {metric_gcs_dir}" + ) + except Exception as e: + raise IOError( + f"Error checking for metric file version '{self.version}' in" + f" {metric_gcs_dir}: {e}" + ) from e + + logger.info( + "Fetching predefined metric '%s@%s' from %s...", + self.name, + self.version, + uri, + ) + + gcs_utils = _gcs_utils.GcsUtils(api_client) + content_str = gcs_utils.read_file_contents(uri) + + file_extension = os.path.splitext(uri)[1].lower() + data: dict[str, Any] + if file_extension == ".yaml" or file_extension == ".yml": + if yaml is None: + raise ImportError( + "YAML parsing requires the pyyaml library. Please install it" + " with `pip install google-cloud-aiplatform[evaluation]`." 
+ ) + data = yaml.safe_load(content_str) + elif file_extension == ".json": + data = json.loads(content_str) + else: + raise ValueError(f"Unsupported file extension: {file_extension}") + + if not isinstance(data, dict): + raise ValueError("Metric config content did not parse into a dictionary.") + + metric_obj = types.LLMMetric.model_validate({**data, **self.metric_kwargs}) + metric_obj._is_predefined = True + metric_obj._config_source = uri + metric_obj._version = self.version + return metric_obj + + def resolve(self, api_client: Any) -> "types.Metric": + """Resolves the metric by checking API Predefined, then GCS, caching results.""" + if self._resolved_metric: + return self._resolved_metric + + cache_key = f"{self.name}@{self.version or 'default'}" + if cache_key in LazyLoadedPrebuiltMetric._cache: + self._resolved_metric = LazyLoadedPrebuiltMetric._cache[cache_key] + logger.debug("Metric '%s' found in cache.", cache_key) + return self._resolved_metric + + # Try resolving as API Predefined Metric first + api_metric = self._resolve_api_predefined() + if api_metric: + self._resolved_metric = api_metric + LazyLoadedPrebuiltMetric._cache[cache_key] = self._resolved_metric + return self._resolved_metric + + # Fallback to GCS loading for custom LLM-based Prebuilt Metrics + logger.debug( + "Metric '%s' not an API Predefined Metric, trying GCS...", self.name + ) + try: + gcs_metric = self._fetch_and_parse(api_client) + final_cache_key = f"{self.name}@{self.version}" + LazyLoadedPrebuiltMetric._cache[final_cache_key] = gcs_metric + self._resolved_metric = gcs_metric + return self._resolved_metric + except Exception as e: + logger.error( + "Error loading metric %s (requested version: %s) from GCS: %s", + self.name, + self.version, + e, + ) + raise ValueError( + f"Metric '{self.name}' could not be resolved as an API " + "Predefined Metric or loaded from GCS." + ) from e + + def __call__( + self, version: Optional[str] = None, **kwargs: Any + ) -> "LazyLoadedPrebuiltMetric": + """Allows setting a specific version and other metric attributes.""" + updated_kwargs = self.metric_kwargs.copy() + updated_kwargs.update(kwargs) + return LazyLoadedPrebuiltMetric( + name=self.name, version=version or self.version, **updated_kwargs + ) + + +class PrebuiltMetricLoader: + """Provides access to predefined evaluation metrics via attributes. + + This class provides a set of predefined LLM-based metrics (Autorater recipes) + for evaluation. These metrics are lazily loaded from a GCS repository + when they are first accessed. 
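+
+    A metric name is resolved against the API predefined metric specs
+    first; when no spec matches, its definition is fetched from the GCS
+    repository (see LazyLoadedPrebuiltMetric.resolve).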
+ + Example: + from agentplatform import types + text_quality_metric = types.RubricMetric.TEXT_QUALITY + """ + + def __getattr__( + self, name: str, version: Optional[str] = None, **kwargs: Any + ) -> LazyLoadedPrebuiltMetric: + return LazyLoadedPrebuiltMetric(name=name, version=version, **kwargs) + + @property + def GENERAL_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("GENERAL_QUALITY") + + @property + def TEXT_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("TEXT_QUALITY") + + @property + def INSTRUCTION_FOLLOWING(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("INSTRUCTION_FOLLOWING") + + @property + def SAFETY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("SAFETY") + + @property + def MULTI_TURN_GENERAL_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_GENERAL_QUALITY") + + @property + def MULTI_TURN_TEXT_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_TEXT_QUALITY") + + @property + def MULTI_TURN_TOOL_USE_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_TOOL_USE_QUALITY", version="v1") + + @property + def MULTI_TURN_TRAJECTORY_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_TRAJECTORY_QUALITY", version="v1") + + @property + def MULTI_TURN_TASK_SUCCESS(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_TASK_SUCCESS", version="v1") + + @property + def FINAL_RESPONSE_MATCH(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("FINAL_RESPONSE_MATCH", version="v2") + + @property + def FINAL_RESPONSE_REFERENCE_FREE(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("FINAL_RESPONSE_REFERENCE_FREE") + + @property + def COHERENCE(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("COHERENCE") + + @property + def FLUENCY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("FLUENCY") + + @property + def VERBOSITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("VERBOSITY") + + @property + def SUMMARIZATION_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("SUMMARIZATION_QUALITY") + + @property + def QUESTION_ANSWERING_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("QUESTION_ANSWERING_QUALITY") + + @property + def MULTI_TURN_CHAT_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_CHAT_QUALITY") + + @property + def MULTI_TURN_SAFETY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("MULTI_TURN_SAFETY") + + @property + def FINAL_RESPONSE_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("FINAL_RESPONSE_QUALITY") + + @property + def HALLUCINATION(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("HALLUCINATION") + + @property + def TOOL_USE_QUALITY(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("TOOL_USE_QUALITY") + + @property + def GECKO_TEXT2IMAGE(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("GECKO_TEXT2IMAGE") + + @property + def GECKO_TEXT2VIDEO(self) -> LazyLoadedPrebuiltMetric: + return self.__getattr__("GECKO_TEXT2VIDEO") + + +PrebuiltMetric = PrebuiltMetricLoader() +RubricMetric = PrebuiltMetric + + +def CodeExecutionMetric( + name: str, custom_function: str, **kwargs: Any +) -> "types.Metric": + """Instantiates a code execution metric.""" + from . 
import types

+
+    return types.Metric(name=name, remote_custom_function=custom_function, **kwargs)
diff --git a/agentplatform/_genai/_evals_utils.py b/agentplatform/_genai/_evals_utils.py
new file mode 100644
index 0000000000..feb24bfbf1
--- /dev/null
+++ b/agentplatform/_genai/_evals_utils.py
@@ -0,0 +1,1029 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Utility functions for evals."""
+
+import abc
+import asyncio
+import json
+import logging
+import os
+import threading
+import time
+from typing import Any, Optional, Union
+
+from google.genai._api_client import BaseApiClient
+from google.genai._common import get_value_by_path as getv
+from google.genai._common import set_value_by_path as setv
+import pandas as pd
+
+from . import _bigquery_utils
+from . import _gcs_utils
+from . import _transformers
+from . import types
+
+
+logger = logging.getLogger(__name__)
+
+
+GCS_PREFIX = "gs://"
+BQ_PREFIX = "bq://"
+_DEFAULT_EVAL_SERVICE_QPS = 10
+
+
+class RateLimiter:
+    """Helper class for rate-limiting requests to Vertex AI to improve QoS.
+
+    Implements a token bucket algorithm to limit the rate at which API calls
+    can occur. Designed for cases where the batch size is always 1 for traffic
+    shaping and rate limiting.
+
+    Attributes:
+        seconds_per_event: The time interval (in seconds) between events to
+            maintain the desired rate.
+        _next_allowed: The monotonic timestamp of the next admissible event.
+        _lock: A lock to ensure thread safety.
+    """
+
+    def __init__(self, rate: float) -> None:
+        """Initializes the rate limiter.
+
+        Args:
+            rate: The number of queries allowed per second.
+
+        Raises:
+            ValueError: If the rate is not positive.
+        """
+        if not rate or rate <= 0:
+            raise ValueError("Rate must be a positive number")
+        self.seconds_per_event = 1.0 / rate
+        self._next_allowed = time.monotonic()
+        self._lock = threading.Lock()
+
+    def sleep_and_advance(self) -> None:
+        """Blocks the current thread until the next event can be admitted.
+
+        The lock is held only long enough to reserve a time slot. The
+        actual sleep happens outside the lock so that multiple threads
+        can be sleeping concurrently with staggered wake-up times.
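+
+        Example (an illustrative sketch; pending_requests and send_request
+        are placeholders for the caller's own work):
+
+            limiter = RateLimiter(rate=10)  # admit roughly 10 calls/second
+            for request in pending_requests:
+                limiter.sleep_and_advance()  # blocks until a slot is free
+                send_request(request)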
+ """ + with self._lock: + now = time.monotonic() + wait_until = max(now, self._next_allowed) + delay = wait_until - now + self._next_allowed = wait_until + self.seconds_per_event + + if delay > 0: + time.sleep(delay) + + +class EvalDatasetLoader: + """A loader for datasets from various sources, using a shared client.""" + + def __init__(self, api_client: BaseApiClient) -> None: + self.api_client = api_client + self.gcs_utils = _gcs_utils.GcsUtils(self.api_client) + self.bigquery_utils = _bigquery_utils.BigQueryUtils(self.api_client) + + def _load_file( + self, filepath: str, file_type: str + ) -> Union[list[dict[str, Any]], Any]: + """Loads data from a file into a list of dictionaries.""" + if filepath.startswith(GCS_PREFIX): + df = self.gcs_utils.read_gcs_file_to_dataframe(filepath, file_type) + return df.to_dict(orient="records") + else: + if file_type == "jsonl": + df = pd.read_json(filepath, lines=True) + return df.to_dict(orient="records") + elif file_type == "csv": + df = pd.read_csv(filepath, encoding="utf-8") + return df.to_dict(orient="records") + else: + raise ValueError( + f"Unsupported file type: '{file_type}'. Please provide 'jsonl' or" + " 'csv'." + ) + + def load( + self, source: Union[str, "pd.DataFrame"] + ) -> Union[list[dict[str, Any]], Any]: + """Loads dataset from various sources into a list of dictionaries.""" + if isinstance(source, pd.DataFrame): + return source.to_dict(orient="records") + elif isinstance(source, str): + if source.startswith(BQ_PREFIX): + df = self.bigquery_utils.load_bigquery_to_dataframe( + source[len(BQ_PREFIX) :] + ) + return df.to_dict(orient="records") + + _, extension = os.path.splitext(source) + file_type = extension.lower()[1:] + + if file_type == "jsonl": + return self._load_file(source, "jsonl") + elif file_type == "csv": + return self._load_file(source, "csv") + else: + raise TypeError( + f"Unsupported file type: {file_type} from {source}. Please" + " provide a valid GCS path with `jsonl` or `csv` suffix, " + "a local file path, or a valid BigQuery table URI." + ) + else: + raise TypeError( + "Unsupported dataset type. Must be a `pd.DataFrame`, Python" + " a valid GCS path with `jsonl` or `csv` suffix, a local" + " file path, or a valid BigQuery table URI." 
+ ) + + +class BatchEvaluateRequestPreparer: + """Prepares data for requests.""" + + @staticmethod + def _EvaluationDataset_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["gcs_source"]) is not None: + setv( + to_object, + ["gcs_source"], + getv(from_object, ["gcs_source"]), + ) + + if getv(from_object, ["bigquery_source"]) is not None: + setv( + to_object, + ["bigquery_source"], + getv(from_object, ["bigquery_source"]), + ) + + return to_object + + @staticmethod + def _Metric_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["prompt_template"]) is not None: + setv( + to_object, + ["pointwise_metric_spec", "prompt_template"], + getv(from_object, ["prompt_template"]), + ) + + if getv(from_object, ["judge_model"]) is not None: + setv( + parent_object, + ["autorater_config", "autorater_model"], + getv(from_object, ["judge_model"]), + ) + + if getv(from_object, ["judge_model_sampling_count"]) is not None: + setv( + parent_object, + ["autorater_config", "sampling_count"], + getv(from_object, ["judge_model_sampling_count"]), + ) + + if getv(from_object, ["judge_model_system_instruction"]) is not None: + setv( + to_object, + ["pointwise_metric_spec", "system_instruction"], + getv(from_object, ["judge_model_system_instruction"]), + ) + + if getv(from_object, ["return_raw_output"]) is not None: + setv( + to_object, + [ + "pointwise_metric_spec", + "custom_output_format_config", + "return_raw_output", + ], + getv(from_object, ["return_raw_output"]), + ) + + return to_object + + @staticmethod + def _OutputConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["gcs_destination"]) is not None: + setv( + to_object, + ["gcsDestination"], + getv(from_object, ["gcs_destination"]), + ) + + return to_object + + @staticmethod + def _EvaluationDataset_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["dataset", "gcs_source"]) is not None: + setv( + to_object, + ["gcs_source"], + getv(from_object, ["dataset", "gcs_source"]), + ) + + if getv(from_object, ["dataset", "bigquery_source"]) is not None: + setv( + to_object, + ["bigquery_source"], + getv(from_object, ["dataset", "bigquery_source"]), + ) + + return to_object + + @staticmethod + def _AutoraterConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["sampling_count"]) is not None: + setv(to_object, ["samplingCount"], getv(from_object, ["sampling_count"])) + + if getv(from_object, ["flip_enabled"]) is not None: + setv(to_object, ["flipEnabled"], getv(from_object, ["flip_enabled"])) + + if getv(from_object, ["autorater_model"]) is not None: + setv( + to_object, + ["autoraterModel"], + getv(from_object, ["autorater_model"]), + ) + + return to_object + + @staticmethod + def EvaluateDatasetOperation_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if 
getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["done"]) is not None: + setv(to_object, ["done"], getv(from_object, ["done"])) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["response"]) is not None: + setv( + to_object, + ["response"], + BatchEvaluateRequestPreparer._EvaluationDataset_from_vertex( + getv(from_object, ["response"]), to_object + ), + ) + + return to_object + + @staticmethod + def EvaluateDatasetRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["dataset"]) is not None: + setv( + to_object, + ["dataset"], + BatchEvaluateRequestPreparer._EvaluationDataset_to_vertex( + getv(from_object, ["dataset"]), to_object + ), + ) + + if getv(from_object, ["metrics"]) is not None: + setv( + to_object, + ["metrics"], + [ + BatchEvaluateRequestPreparer._Metric_to_vertex(item, to_object) + for item in getv(from_object, ["metrics"]) + ], + ) + + if getv(from_object, ["output_config"]) is not None: + setv( + to_object, + ["outputConfig"], + BatchEvaluateRequestPreparer._OutputConfig_to_vertex( + getv(from_object, ["output_config"]), to_object + ), + ) + + if getv(from_object, ["autorater_config"]) is not None: + setv( + to_object, + ["autoraterConfig"], + BatchEvaluateRequestPreparer._AutoraterConfig_to_vertex( + getv(from_object, ["autorater_config"]), to_object + ), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + @staticmethod + def prepare_metric_payload( + request_dict: dict[str, Any], resolved_metrics: list["types.MetricSubclass"] + ) -> dict[str, Any]: + """Prepares the metric payload for the evaluation request. + + Args: + request_dict: The dictionary containing the request details. + resolved_metrics: A list of resolved metric objects. + + Returns: + The updated request dictionary with the prepared metric payload. + """ + request_dict["metrics"] = _transformers.t_metrics( + resolved_metrics, set_default_aggregation_metrics=True + ) + return request_dict + + +class EvalDataConverter(abc.ABC): + """Abstract base class for dataset converters.""" + + @abc.abstractmethod + def convert(self, raw_data: Any) -> "types.EvaluationDataset": + """Converts a loaded raw dataset into an EvaluationDataset.""" + raise NotImplementedError() + + +def _postprocess_user_scenarios_response( + response: types.GenerateUserScenariosResponse, +) -> types.EvaluationDataset: + """Postprocesses the response from generating user scenarios.""" + eval_cases = [] + data_for_df = [] + if hasattr(response, "user_scenarios") and response.user_scenarios: + for scenario in response.user_scenarios: + eval_case = types.EvalCase( + user_scenario=scenario, + ) + eval_cases.append(eval_case) + data_for_df.append( + { + "starting_prompt": scenario.starting_prompt, + "conversation_plan": scenario.conversation_plan, + } + ) + eval_dataset_df = None + if pd is not None: + eval_dataset_df = pd.DataFrame(data_for_df) + else: + logger.warning("Pandas is not installed. 
eval_dataset_df will be None.") + return types.EvaluationDataset( + eval_cases=eval_cases, eval_dataset_df=eval_dataset_df + ) + + +def _display_loss_analysis_result( + result: types.LossAnalysisResult, +) -> None: + """Displays a LossAnalysisResult as a formatted pandas DataFrame.""" + metric = result.config.metric if result.config else None + candidate = result.config.candidate if result.config else None + rows: list[dict[str, Any]] = [] + for cluster in result.clusters or []: + entry = cluster.taxonomy_entry + row = { + "metric": metric, + "candidate": candidate, + "cluster_id": cluster.cluster_id, + "l1_category": entry.l1_category if entry else None, + "l2_category": entry.l2_category if entry else None, + "description": entry.description if entry else None, + "item_count": cluster.item_count, + } + rows.append(row) + + if not rows: + logger.info("No loss clusters found.") + return + + df = pd.DataFrame(rows) + try: + from IPython.display import display # pylint: disable=g-import-not-at-top + + display(df) + except ImportError: + print(df.to_string()) # pylint: disable=print-function + + +def _resolve_metric_name( + metric: Optional[Any], +) -> Optional[str]: + """Extracts a metric name string from a metric argument. + + Accepts a string, a Metric object, or a LazyLoadedPrebuiltMetric + (RubricMetric) and returns the metric name as a string. + + For LazyLoadedPrebuiltMetric (e.g., RubricMetric.MULTI_TURN_TASK_SUCCESS), + this resolves to the API metric spec name (e.g., + "multi_turn_task_success_v1") so it matches the keys in eval results. + + Args: + metric: A metric name string, Metric object, RubricMetric enum value, or + None. + + Returns: + The metric name as a string, or None if metric is None. + """ + if metric is None: + return None + if isinstance(metric, str): + return metric + # LazyLoadedPrebuiltMetric: resolve to versioned API spec name. + if hasattr(metric, "_get_api_metric_spec_name"): + spec_name: Optional[str] = metric._get_api_metric_spec_name() + if spec_name: + return spec_name + # Metric objects and other types with a .name attribute. + if hasattr(metric, "name"): + return str(metric.name) + return str(metric) + + +def _resolve_eval_run_loss_configs( + loss_analysis_metrics: Optional[list[Any]] = None, + loss_analysis_configs: Optional[list[Any]] = None, + inference_configs: Optional[dict[str, Any]] = None, +) -> Optional[list[types.LossAnalysisConfig]]: + """Resolves loss analysis configs for create_evaluation_run. + + Supports two modes: + 1. ``loss_analysis_metrics``: A simplified list of metrics. The candidate + is auto-inferred from ``inference_configs`` when there is exactly one + candidate. Each metric is resolved via ``_resolve_metric_name()``. + 2. ``loss_analysis_configs``: Explicit ``LossAnalysisConfig`` objects or + dicts for full control. + + Args: + loss_analysis_metrics: Optional list of metric references (strings, + Metric objects, or RubricMetric enums). + loss_analysis_configs: Optional list of LossAnalysisConfig or dicts. + inference_configs: The resolved inference_configs dict (candidate name + -> config). Used to auto-infer candidate for the metrics path. + + Returns: + A list of resolved LossAnalysisConfig objects, or None if neither + loss_analysis_metrics nor loss_analysis_configs is provided. + + Raises: + ValueError: If candidate cannot be inferred for loss_analysis_metrics. 
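+
+    Example (illustrative; the metric string and config value are
+    placeholders):
+
+        configs = _resolve_eval_run_loss_configs(
+            loss_analysis_metrics=["general_quality_v1"],
+            inference_configs={"candidate_1": some_config},
+        )
+        # -> [LossAnalysisConfig(metric="general_quality_v1",
+        #                        candidate="candidate_1")]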
+ """ + if not loss_analysis_metrics and not loss_analysis_configs: + return None + + if loss_analysis_configs: + return [ + types.LossAnalysisConfig.model_validate(c) if isinstance(c, dict) else c + for c in loss_analysis_configs + ] + + # loss_analysis_metrics path: auto-infer candidate from inference_configs + candidate = None + if inference_configs and len(inference_configs) == 1: + candidate = next(iter(inference_configs)) + elif inference_configs and len(inference_configs) > 1: + raise ValueError( + "Cannot infer candidate for loss analysis: multiple candidates" + f" found in inference_configs: {list(inference_configs.keys())}." + " Please use loss_analysis_configs with explicit candidate values" + " instead." + ) + + configs = [] + for m in loss_analysis_metrics or []: + metric_name = _resolve_metric_name(m) + configs.append( + types.LossAnalysisConfig(metric=metric_name, candidate=candidate) + ) + return configs + + +def _resolve_loss_analysis_config( + eval_result: types.EvaluationResult, + config: Optional[types.LossAnalysisConfig] = None, + metric: Optional[str] = None, + candidate: Optional[str] = None, +) -> types.LossAnalysisConfig: + """Resolves and validates the LossAnalysisConfig for generate_loss_clusters. + + Auto-infers `metric` and `candidate` from the EvaluationResult when not + explicitly provided. Validates that provided values exist in the eval result. + + Args: + eval_result: The EvaluationResult from client.evals.evaluate(). + config: Optional explicit LossAnalysisConfig. If provided, metric and + candidate from config take precedence over the separate arguments. + metric: Optional metric name override. + candidate: Optional candidate name override. + + Returns: + A resolved LossAnalysisConfig with metric and candidate populated. + + Raises: + ValueError: If metric/candidate cannot be inferred or are invalid. + """ + # Start from config if provided, otherwise create a new one. + if config is not None: + resolved_metric = metric or config.metric + resolved_candidate = candidate or config.candidate + resolved_config = config.model_copy( + update={"metric": resolved_metric, "candidate": resolved_candidate} + ) + else: + resolved_config = types.LossAnalysisConfig(metric=metric, candidate=candidate) + + # Collect available metric names from the eval result. + available_metrics: set[str] = set() + if eval_result.eval_case_results: + for case_result in eval_result.eval_case_results: + for resp_cand in case_result.response_candidate_results or []: + for m_name in (resp_cand.metric_results or {}).keys(): + available_metrics.add(m_name) + + # Collect available candidate names from metadata. + available_candidates: list[str] = [] + if eval_result.metadata and eval_result.metadata.candidate_names: + available_candidates = list(eval_result.metadata.candidate_names) + + # Auto-infer metric if not provided. + if not resolved_config.metric: + if len(available_metrics) == 1: + resolved_config = resolved_config.model_copy( + update={"metric": next(iter(available_metrics))} + ) + elif len(available_metrics) == 0: + raise ValueError( + "Cannot infer metric: no metric results found in eval_result." + " Please provide metric explicitly via" + " config=types.LossAnalysisConfig(metric='...')." + ) + else: + raise ValueError( + "Cannot infer metric: multiple metrics found in eval_result:" + f" {sorted(available_metrics)}. Please provide metric" + " explicitly via config=types.LossAnalysisConfig(metric='...')." + ) + + # Validate metric if provided explicitly. 
+ if available_metrics and resolved_config.metric not in available_metrics: + raise ValueError( + f"Metric '{resolved_config.metric}' not found in eval_result." + f" Available metrics: {sorted(available_metrics)}." + ) + + # Auto-infer candidate if not provided. + if not resolved_config.candidate: + if len(available_candidates) == 1: + resolved_config = resolved_config.model_copy( + update={"candidate": available_candidates[0]} + ) + elif len(available_candidates) == 0: + # Fallback: use default candidate naming convention from SDK. + resolved_config = resolved_config.model_copy( + update={"candidate": "candidate_1"} + ) + logger.warning( + "No candidate names found in eval_result.metadata." + " Defaulting to 'candidate_1'. If this is incorrect, provide" + " candidate explicitly via" + " config=types.LossAnalysisConfig(candidate='...')." + ) + else: + raise ValueError( + "Cannot infer candidate: multiple candidates found in" + f" eval_result: {available_candidates}. Please provide" + " candidate explicitly via" + " config=types.LossAnalysisConfig(candidate='...')." + ) + + # Validate candidate if provided explicitly and candidates are known. + if available_candidates and resolved_config.candidate not in available_candidates: + raise ValueError( + f"Candidate '{resolved_config.candidate}' not found in" + f" eval_result. Available candidates: {available_candidates}." + ) + + return resolved_config + + +def _build_rubric_description_map( + eval_result: types.EvaluationResult, +) -> dict[str, str]: + """Builds a rubric_id -> description map from the EvaluationResult.""" + rubric_map: dict[str, str] = {} + for case_result in eval_result.eval_case_results or []: + for resp_cand in case_result.response_candidate_results or []: + for metric_res in (resp_cand.metric_results or {}).values(): + for verdict in metric_res.rubric_verdicts or []: + rubric = verdict.evaluated_rubric + if rubric and rubric.rubric_id and rubric.content: + if ( + rubric.content.property + and rubric.content.property.description + ): + rubric_map[rubric.rubric_id] = ( + rubric.content.property.description + ) + return rubric_map + + +def _extract_scenario_preview_from_dict( + eval_result_dict: dict[str, Any], +) -> Optional[str]: + """Extracts the first user message from an evaluation_result dict. + + Handles both snake_case (SDK-side) and camelCase (API echo-back) keys. + """ + request = eval_result_dict.get("request") + if not request: + return None + prompt = request.get("prompt") + if not prompt: + return None + # Try agent_data (snake_case or camelCase) + agent_data = prompt.get("agent_data") or prompt.get("agentData") + if agent_data and isinstance(agent_data, dict): + turns = agent_data.get("turns", []) + for turn in turns: + events = turn.get("events", []) + for event in events: + author = event.get("author", "") + content = event.get("content") + if author.lower() == "user" and content and isinstance(content, dict): + parts = content.get("parts", []) + for part in parts: + text = str(part.get("text", "")).strip() + if text: + if len(text) > 150: + return text[:150] + "..." + return text + # Try simple prompt path + parts = prompt.get("parts", []) + for part in parts: + text = str(part.get("text", "")).strip() + if text: + if len(text) > 150: + return text[:150] + "..." 
+ return text + return None + + +def _extract_scenario_from_agent_data(agent_data: Any) -> Optional[str]: + """Extracts the first user message from an AgentData object or dict.""" + if agent_data is None: + return None + if hasattr(agent_data, "model_dump"): + agent_data = agent_data.model_dump() + if isinstance(agent_data, str): + try: + agent_data = json.loads(agent_data) + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(agent_data, dict): + return None + turns = agent_data.get("turns", []) + if not isinstance(turns, list): + return None + for turn in turns: + if not isinstance(turn, dict): + continue + events = turn.get("events", []) + if not isinstance(events, list): + continue + for event in events: + if not isinstance(event, dict): + continue + author = event.get("author", "") + if not isinstance(author, str) or author.lower() != "user": + continue + content = event.get("content") + if not content or not isinstance(content, dict): + continue + parts = content.get("parts", []) + if not isinstance(parts, list): + continue + for part in parts: + if not isinstance(part, dict): + continue + text = str(part.get("text", "")).strip() + if text: + if len(text) > 150: + return text[:150] + "..." + return text + return None + + +def _truncate_scenario(text: str, max_len: int = 150) -> str: + """Truncates a scenario preview to max_len characters.""" + text = text.strip() + if len(text) > max_len: + return text[:max_len] + "..." + return text + + +def _build_scenario_preview_list( + eval_result: types.EvaluationResult, +) -> list[Optional[str]]: + """Builds an ordered list of scenario previews from the EvaluationResult. + + Returns one scenario preview per eval_case_result, in the same order as + eval_case_results. This extracts the first user message from the original + SDK EvaluationResult (via eval_cases or DataFrame), rather than relying + on the API echo-back which may not preserve the request data. + + Extraction priority per eval case: + 1. eval_case.agent_data → first user message in turns + 2. eval_case.user_scenario.starting_prompt + 3. eval_case.prompt → text content + 4. DataFrame agent_data column → first user message + 5. DataFrame starting_prompt column + """ + eval_dataset = eval_result.evaluation_dataset + eval_cases: list[Any] = [] + if isinstance(eval_dataset, list) and eval_dataset: + eval_cases = getv(eval_dataset[0], ["eval_cases"]) or [] + + eval_case_results = eval_result.eval_case_results or [] + scenarios: list[Optional[str]] = [] + + for case_result in eval_case_results: + case_idx = case_result.eval_case_index or 0 + scenario: Optional[str] = None + + eval_case = None + if 0 <= case_idx < len(eval_cases): + eval_case = eval_cases[case_idx] + + if eval_case: + # 1. Try agent_data (populated after run_inference) + agent_data = getv(eval_case, ["agent_data"]) + if agent_data: + scenario = _extract_scenario_from_agent_data(agent_data) + + # 2. Try user_scenario.starting_prompt (from + # generate_conversation_scenarios) + if scenario is None: + user_scenario = getv(eval_case, ["user_scenario"]) + if user_scenario: + starting_prompt = getv(user_scenario, ["starting_prompt"]) + if starting_prompt and isinstance(starting_prompt, str): + scenario = _truncate_scenario(starting_prompt) + + # 3. Try prompt text + if scenario is None: + prompt = getv(eval_case, ["prompt"]) + if prompt: + from . import _evals_data_converters + + text = _evals_data_converters._get_content_text(prompt) + if text: + scenario = _truncate_scenario(str(text)) + + # 4. 
Fallback: extract agent_data from DataFrame + if scenario is None and eval_dataset: + df_agent_data = _transformers._extract_agent_data_from_df( + eval_dataset, case_idx + ) + if df_agent_data is not None: + scenario = _extract_scenario_from_agent_data(df_agent_data) + + # 5. Fallback: extract starting_prompt from DataFrame + if scenario is None and eval_dataset: + ds = eval_dataset[0] if isinstance(eval_dataset, list) else eval_dataset + df = getv(ds, ["eval_dataset_df"]) + if df is not None and hasattr(df, "iloc"): + if 0 <= case_idx < len(df): + row = df.iloc[case_idx] + sp = row.get("starting_prompt") + if sp and isinstance(sp, str) and sp.strip(): + scenario = _truncate_scenario(sp) + + scenarios.append(scenario) + + return scenarios + + +def _enrich_loss_response_with_rubric_descriptions( + response: types.GenerateLossClustersResponse, + eval_result: types.EvaluationResult, +) -> None: + """Enriches loss response with rubric descriptions and scenario previews. + + Rubric descriptions and scenario previews are extracted from the original + SDK EvaluationResult object, because the API echo-back in + LossExample.evaluation_result may not preserve all request data (e.g., + agent_data turns with user messages). + """ + rubric_map = _build_rubric_description_map(eval_result) + scenario_list = _build_scenario_preview_list(eval_result) + logger.debug( + "Enriching loss response: %d scenarios extracted, %d rubric" " descriptions", + sum(1 for s in scenario_list if s), + len(rubric_map), + ) + for result in response.results or []: + for cluster in result.clusters or []: + for example in cluster.examples or []: + if example.evaluation_result is None: + example.evaluation_result = {} + if rubric_map: + example.evaluation_result["rubric_descriptions"] = rubric_map + # Try extracting scenario from the API echo-back first + if "scenario_preview" not in example.evaluation_result: + scenario = _extract_scenario_preview_from_dict( + example.evaluation_result + ) + if scenario: + example.evaluation_result["scenario_preview"] = scenario + # Fallback: match against scenarios from original eval_result + if "scenario_preview" not in example.evaluation_result: + if scenario_list: + for s in scenario_list: + if s: + example.evaluation_result["scenario_preview"] = s + break + + +def _poll_operation( + api_client: BaseApiClient, + operation: types.GenerateLossClustersOperation, + poll_interval_seconds: float = 5.0, +) -> types.GenerateLossClustersOperation: + """Polls a long-running operation until completion. + + Args: + api_client: The API client to use for polling. + operation: The initial operation returned from the API call. + poll_interval_seconds: Time between polls. + + Returns: + The completed operation. + """ + if operation.done: + return operation + start_time = time.time() + while True: + response = api_client.request("get", operation.name, {}, None) + response_dict = {} if not response.body else json.loads(response.body) + polled = types.GenerateLossClustersOperation._from_response( + response=response_dict, kwargs={} + ) + if polled.done: + return polled + elapsed = int(time.time() - start_time) + logger.info( + "Loss analysis operation still running... Elapsed time: %d seconds", + elapsed, + ) + time.sleep(poll_interval_seconds) + + +async def _poll_operation_async( + api_client: BaseApiClient, + operation: types.GenerateLossClustersOperation, + poll_interval_seconds: float = 5.0, +) -> types.GenerateLossClustersOperation: + """Polls a long-running operation until completion (async). 
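+
+    Mirrors the synchronous _poll_operation, but awaits asyncio.sleep between
+    polls so the running event loop is not blocked while waiting.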
+ + Args: + api_client: The API client to use for polling. + operation: The initial operation returned from the API call. + poll_interval_seconds: Time between polls. + + Returns: + The completed operation. + """ + if operation.done: + return operation + start_time = time.time() + while True: + response = await api_client.async_request("get", operation.name, {}, None) + response_dict = {} if not response.body else json.loads(response.body) + polled = types.GenerateLossClustersOperation._from_response( + response=response_dict, kwargs={} + ) + if polled.done: + return polled + elapsed = int(time.time() - start_time) + logger.info( + "Loss analysis operation still running... Elapsed time: %d seconds", + elapsed, + ) + await asyncio.sleep(poll_interval_seconds) + + +def _validate_dataset_agent_data( + dataset: types.EvaluationDataset, + inference_configs: Optional[dict[str, Any]] = None, +) -> None: + """Validates agent_data in the EvaluationDataset. + + Checks that agent_data matches the expected AgentData type and that + 'agents' are not defined in both the dataset's agent_data and inference_configs. + """ + has_inference_agent_configs = False + if inference_configs: + for cand_config in inference_configs.values(): + if isinstance(cand_config, dict) and cand_config.get("agent_configs"): + has_inference_agent_configs = True + elif hasattr(cand_config, "agent_configs") and cand_config.agent_configs: + has_inference_agent_configs = True + + def _validate_single_agent_data(agent_data_val: Any, identifier: str) -> None: + + if not agent_data_val: + return + + agent_data_obj = None + if isinstance(agent_data_val, str): + try: + agent_data_val = json.loads(agent_data_val) + if "error" in agent_data_val: + return + agent_data_obj = types.evals.AgentData.model_validate(agent_data_val) + except json.JSONDecodeError as e: + raise ValueError( + f"{identifier}: 'agent_data' is not valid JSON: {e}" + ) from e + elif isinstance(agent_data_val, dict) and "error" in agent_data_val: + return + elif isinstance(agent_data_val, dict): + try: + agent_data_obj = types.evals.AgentData.model_validate(agent_data_val) + except Exception as e: + raise ValueError( + f"{identifier}: 'agent_data' " + f"is inconsistent with AgentData type: {e}" + ) from e + elif isinstance(agent_data_val, types.evals.AgentData): + agent_data_obj = agent_data_val + else: + raise ValueError( + f"{identifier}: 'agent_data' is inconsistent with AgentData type. " + f"Got {type(agent_data_val)}" + ) + + if agent_data_obj and agent_data_obj.agents and has_inference_agent_configs: + raise ValueError( + f"{identifier}: Cannot provide 'agents' in the dataset's 'agent_data' " + "and 'agent_configs' in inference_configs at the same time." 
+ ) + + if ( + dataset.eval_dataset_df is not None + and "agent_data" in dataset.eval_dataset_df.columns + ): + for idx, row in dataset.eval_dataset_df.iterrows(): + _validate_single_agent_data(row.get("agent_data"), f"Row {idx}") + + if dataset.eval_cases: + for idx, eval_case in enumerate(dataset.eval_cases): + agent_data = None + if isinstance(eval_case, dict): + agent_data = eval_case.get("agent_data", None) + elif hasattr(eval_case, "agent_data"): + agent_data = eval_case.agent_data + _validate_single_agent_data(agent_data, f"EvalCase {idx}") diff --git a/agentplatform/_genai/_evals_visualization.py b/agentplatform/_genai/_evals_visualization.py new file mode 100644 index 0000000000..ed5e82a93a --- /dev/null +++ b/agentplatform/_genai/_evals_visualization.py @@ -0,0 +1,2017 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Visualization utilities for GenAI Evaluation SDK.""" + +import base64 +import datetime +import html +import json +import logging +import textwrap +from typing import Any, Optional + +import pandas as pd +from pydantic import errors + +from . import types + + +logger = logging.getLogger(__name__) + + +def _is_ipython_env() -> bool: + """Checks if the code is running in an IPython environment.""" + try: + from IPython import get_ipython + + return get_ipython() is not None + except ImportError: + return False + + +def _pydantic_serializer(obj: Any) -> Any: + """Custom serializer for Pydantic models.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + if isinstance(obj, datetime.datetime): + return obj.isoformat() + if isinstance(obj, bytes): + return base64.b64encode(obj).decode("utf-8") + raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable") + + +def _preprocess_df_for_json(df: Optional[pd.DataFrame]) -> Optional[pd.DataFrame]: + """Prepares a DataFrame for JSON serialization by converting complex objects to strings.""" + if df is None: + return None + df_copy = df.copy() + + for col in df_copy.columns: + if ( + df_copy[col].dtype == "object" + or df_copy[col].apply(lambda x: isinstance(x, (dict, list))).any() + ): + + def stringify_cell(cell: Any) -> Optional[str]: + if isinstance(cell, (dict, list)): + try: + return json.dumps( + cell, ensure_ascii=False, default=_pydantic_serializer + ) + except TypeError: + return str(cell) + elif pd.isna(cell): + return None + elif not isinstance(cell, (str, int, float, bool)): + if hasattr(cell, "model_dump"): + return json.dumps( + cell.model_dump(mode="json"), ensure_ascii=False + ) + return str(cell) + return str(cell) + + df_copy[col] = df_copy[col].apply(stringify_cell) + return df_copy + + +def _encode_to_base64(data: str) -> str: + """Encodes a string to a web-safe Base64 string.""" + return base64.b64encode(data.encode("utf-8")).decode("utf-8") + + +def _extract_text_and_raw_json(content: Any) -> dict[str, str]: + """Extracts display text and raw JSON from a content object.""" + if hasattr(content, "model_dump"): + content = 
content.model_dump(mode="json", exclude_none=True) + + if not isinstance(content, (str, dict)): + return {"display_text": str(content or ""), "raw_json": ""} + + try: + data = json.loads(content) if isinstance(content, str) else content + + if not isinstance(data, dict): + return {"display_text": str(content), "raw_json": ""} + + pretty_json = json.dumps(data, indent=2, ensure_ascii=False) + + # Gemini format check (API Wrapper format). + if ( + "contents" in data + and isinstance(data.get("contents"), list) + and data["contents"] + ): + first_part = data["contents"][0].get("parts", [{}])[0] + display_text = first_part.get("text", str(data)) + return {"display_text": display_text, "raw_json": pretty_json} + + # Direct Gemini Content Object Check + elif "parts" in data and isinstance(data.get("parts"), list) and data["parts"]: + text_parts = [p.get("text", "") for p in data["parts"] if "text" in p] + display_text = "\n".join(text_parts) if text_parts else str(data) + return {"display_text": display_text, "raw_json": pretty_json} + + # OpenAI response format check. + elif ( + "choices" in data + and isinstance(data.get("choices"), list) + and data["choices"] + ): + message = data["choices"][0].get("message", {}) + display_text = message.get("content", str(data)) + return {"display_text": display_text, "raw_json": pretty_json} + + # OpenAI request format check. + elif ( + "messages" in data + and isinstance(data.get("messages"), list) + and data["messages"] + ): + user_messages = [ + message.get("content", "") + for message in data["messages"] + if message.get("role") == "user" + ] + display_text = user_messages[-1] if user_messages else str(data) + return {"display_text": display_text, "raw_json": pretty_json} + else: + # Not a recognized format. + return {"display_text": str(content), "raw_json": pretty_json} + + except (json.JSONDecodeError, TypeError, IndexError): + return {"display_text": str(content), "raw_json": ""} + + +def _extract_dataset_rows(dataset: types.EvaluationDataset) -> list[dict[str, Any]]: + """Helper to consistently extract rows from either a dataframe or raw eval_cases list.""" + processed_rows = [] + + # Process from DataFrame if available + if getattr(dataset, "eval_dataset_df", None) is not None: + processed_df = _preprocess_df_for_json(dataset.eval_dataset_df) + if processed_df is not None: + for _, row in processed_df.iterrows(): + prompt_key = "request" if "request" in row else "prompt" + prompt_info = _extract_text_and_raw_json(row.get(prompt_key)) + response_info = _extract_text_and_raw_json(row.get("response")) + ref_info = _extract_text_and_raw_json(row.get("reference")) + processed_row = { + "prompt_display_text": prompt_info["display_text"], + "prompt_raw_json": prompt_info["raw_json"], + "reference": ref_info["display_text"], + "reference_raw_json": ref_info["raw_json"], + "response_display_text": response_info["display_text"], + "response_raw_json": response_info["raw_json"], + "intermediate_events": row.get("intermediate_events", None), + "agent_data": row.get("agent_data", None), + } + processed_rows.append(processed_row) + + # Fallback to pure eval_cases extraction + elif dataset.eval_cases: + for case in dataset.eval_cases: + prompt_info = ( + _extract_text_and_raw_json(case.prompt) + if case.prompt + else {"display_text": "", "raw_json": ""} + ) + + response_info = {"display_text": "", "raw_json": ""} + if case.responses and case.responses[0].response: + response_info = _extract_text_and_raw_json(case.responses[0].response) + + reference_text = "" + 
reference_raw_json = ""
+            if case.reference and case.reference.response:
+                ref_info = _extract_text_and_raw_json(case.reference.response)
+                reference_text = ref_info["display_text"]
+                reference_raw_json = ref_info["raw_json"]
+
+            agent_data_json = None
+            if case.agent_data:
+                agent_data_json = json.dumps(
+                    case.agent_data.model_dump(mode="json", exclude_none=True),
+                    ensure_ascii=False,
+                )
+
+            intermediate_events_json = None
+            if case.intermediate_events:
+                intermediate_events_json = json.dumps(
+                    [
+                        e.model_dump(mode="json", exclude_none=True)
+                        for e in case.intermediate_events
+                    ],
+                    ensure_ascii=False,
+                )
+
+            processed_row = {
+                "prompt_display_text": prompt_info["display_text"],
+                "prompt_raw_json": prompt_info["raw_json"],
+                "reference": reference_text,
+                "reference_raw_json": reference_raw_json,
+                "response_display_text": response_info["display_text"],
+                "response_raw_json": response_info["raw_json"],
+                "intermediate_events": intermediate_events_json,
+                "agent_data": agent_data_json,
+            }
+            processed_rows.append(processed_row)
+
+    return processed_rows
+
+
+def _get_evaluation_html(eval_result_json: str) -> str:
+    """Returns a self-contained HTML for single evaluation visualization."""
+    payload_b64 = _encode_to_base64(eval_result_json)
+    return textwrap.dedent(
+        f"""
+        <!DOCTYPE html>
+        <html>
+          <head>
+            <meta charset="utf-8" />
+            <title>Evaluation Report</title>
+          </head>
+          <body>
+            <div id="evaluation-report" data-payload="{payload_b64}"></div>
+            <!-- renderer markup/script omitted -->
+          </body>
+        </html>
+"""
+    )
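+
+
+# _get_evaluation_html and the sibling helpers below embed their JSON payload
+# base64-encoded so the generated report is a single self-contained HTML
+# document. Rough sketch of the round trip (illustrative values):
+#
+#   payload_b64 = _encode_to_base64('{"summary_metrics": []}')
+#   json.loads(base64.b64decode(payload_b64).decode("utf-8"))
+#   # -> {'summary_metrics': []}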
+
+
+def _get_comparison_html(eval_result_json: str) -> str:
+    """Returns a self-contained HTML for a side-by-side eval comparison."""
+    payload_b64 = _encode_to_base64(eval_result_json)
+    return textwrap.dedent(
+        f"""
+        <!DOCTYPE html>
+        <html>
+          <head>
+            <meta charset="utf-8" />
+            <title>Eval Comparison Report</title>
+          </head>
+          <body>
+            <div id="comparison-report" data-payload="{payload_b64}"></div>
+            <!-- renderer markup/script omitted -->
+          </body>
+        </html>
+"""
+    )
+
+
+def _get_inference_html(dataframe_json: str) -> str:
+    """Returns a self-contained HTML for displaying inference results."""
+    payload_b64 = _encode_to_base64(dataframe_json)
+    return textwrap.dedent(
+        f"""
+        <!DOCTYPE html>
+        <html>
+          <head>
+            <meta charset="utf-8" />
+            <title>Evaluation Dataset</title>
+          </head>
+          <body>
+            <div id="inference-report" data-payload="{payload_b64}"></div>
+            <!-- renderer markup/script omitted -->
+          </body>
+        </html>
+ + + +""" + ) + + +def display_evaluation_result( + eval_result_obj: types.EvaluationResult, + candidate_names: Optional[list[str]] = None, +) -> None: + """Displays evaluation result in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + try: + result_dump = eval_result_obj.model_dump( + mode="json", exclude_none=True, exclude={"evaluation_dataset"} + ) + except errors.PydanticSerializationError as e: + logger.error( + "Serialization Error: %s\nCould not display the evaluation " + "result due to a data serialization issue. Please check the " + "content of the EvaluationResult object.", + e, + ) + return + except Exception as e: + logger.error("Failed to serialize EvaluationResult: %s", e, exc_info=True) + raise + + input_dataset_list = eval_result_obj.evaluation_dataset + is_comparison = input_dataset_list and len(input_dataset_list) > 1 + + metadata_payload = result_dump.get("metadata", {}) + metadata_payload["candidate_names"] = candidate_names or metadata_payload.get( + "candidate_names" + ) + + if is_comparison and input_dataset_list: + if input_dataset_list[0]: + metadata_payload["dataset"] = _extract_dataset_rows(input_dataset_list[0]) + + if "eval_case_results" in result_dump: + for case_res in result_dump["eval_case_results"]: + for resp_idx, cand_res in enumerate( + case_res.get("response_candidate_results", []) + ): + if ( + input_dataset_list is not None + and resp_idx < len(input_dataset_list) + and input_dataset_list[resp_idx] + ): + rows = _extract_dataset_rows(input_dataset_list[resp_idx]) + case_idx = case_res.get("eval_case_index") + if case_idx is not None and case_idx < len(rows): + original_case = rows[case_idx] + cand_res["display_text"] = original_case[ + "response_display_text" + ] + cand_res["raw_json"] = original_case["response_raw_json"] + + win_rates = eval_result_obj.win_rates if eval_result_obj.win_rates else {} + if "summary_metrics" in result_dump: + for summary in result_dump["summary_metrics"]: + if summary.get("metric_name") in win_rates: + summary.update(win_rates[summary["metric_name"]]) + + result_dump["metadata"] = metadata_payload + html_content = _get_comparison_html(json.dumps(result_dump)) + else: + single_dataset = input_dataset_list[0] if input_dataset_list else None + processed_rows = [] + if single_dataset is not None: + processed_rows = _extract_dataset_rows(single_dataset) + metadata_payload["dataset"] = processed_rows + + if "eval_case_results" in result_dump and processed_rows: + for case_res in result_dump["eval_case_results"]: + case_idx = case_res.get("eval_case_index") + if ( + case_idx is not None + and case_idx < len(processed_rows) + and case_res.get("response_candidate_results") + ): + original_case = processed_rows[case_idx] + cand_res = case_res["response_candidate_results"][0] + cand_res["display_text"] = original_case[ + "response_display_text" + ] + cand_res["raw_json"] = original_case["response_raw_json"] + + result_dump["metadata"] = metadata_payload + html_content = _get_evaluation_html(json.dumps(result_dump)) + + display.display(display.HTML(html_content)) + + +def display_evaluation_dataset(eval_dataset_obj: types.EvaluationDataset) -> None: + """Displays an evaluation dataset in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + if ( + eval_dataset_obj.eval_dataset_df is None 
+        or eval_dataset_obj.eval_dataset_df.empty
+    ):
+        logger.warning("No inference data to display.")
+        return
+
+    processed_rows = []
+    df = eval_dataset_obj.eval_dataset_df
+
+    for _, row in df.iterrows():
+        processed_row = {}
+        for col_name, cell_value in row.items():
+            if col_name in ["prompt", "request", "response"]:
+                processed_row[col_name] = _extract_text_and_raw_json(cell_value)
+            elif col_name == "rubric_groups":
+                # Special handling for rubric_groups to keep it as a dict.
+                if isinstance(cell_value, dict):
+                    processed_row[col_name] = {
+                        k: [  # type: ignore[misc]
+                            (
+                                v_item.model_dump(mode="json")
+                                if hasattr(v_item, "model_dump")
+                                else v_item
+                            )
+                            for v_item in v
+                        ]
+                        for k, v in cell_value.items()
+                    }
+                else:
+                    processed_row[col_name] = cell_value
+            else:
+                if isinstance(cell_value, (dict, list)):
+                    processed_row[col_name] = json.dumps(  # type: ignore[assignment]
+                        cell_value, ensure_ascii=False, default=_pydantic_serializer
+                    )
+                else:
+                    processed_row[col_name] = cell_value
+        processed_rows.append(processed_row)
+
+    dataframe_json_string = json.dumps(processed_rows, ensure_ascii=False, default=str)
+    html_content = _get_inference_html(dataframe_json_string)
+    display.display(display.HTML(html_content))
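+
+
+# Illustrative notebook usage (assumes an existing pandas DataFrame `df`
+# with, e.g., "prompt" and "response" columns):
+#
+#   dataset = types.EvaluationDataset(eval_dataset_df=df)
+#   display_evaluation_dataset(dataset)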
+
+
+def _get_loss_analysis_html(loss_analysis_json: str) -> str:
+    """Returns self-contained HTML for loss pattern analysis visualization."""
+    payload_b64 = _encode_to_base64(loss_analysis_json)
+    return textwrap.dedent(
+        f"""
+        <!DOCTYPE html>
+        <html>
+          <head>
+            <meta charset="utf-8" />
+            <title>Loss Pattern Analysis</title>
+          </head>
+          <body>
+            <div id="loss-analysis-report" data-payload="{payload_b64}"></div>
+            <!-- renderer markup/script omitted -->
+          </body>
+        </html>
+ + + +""" + ) + + +def display_loss_clusters_response( + response_obj: "types.GenerateLossClustersResponse", +) -> None: + """Displays a GenerateLossClustersResponse in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + try: + result_dump = response_obj.model_dump(mode="json", exclude_none=True) + except Exception as e: + logger.error( + "Failed to serialize GenerateLossClustersResponse: %s", + e, + exc_info=True, + ) + raise + + html_content = _get_loss_analysis_html( + json.dumps(result_dump, ensure_ascii=False, default=_pydantic_serializer) + ) + display.display(display.HTML(html_content)) + + +def display_loss_analysis_result( + result_obj: "types.LossAnalysisResult", +) -> None: + """Displays a single LossAnalysisResult in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + try: + # Wrap in a response-like structure for the shared HTML generator + wrapped = {"results": [result_obj.model_dump(mode="json", exclude_none=True)]} + except Exception as e: + logger.error( + "Failed to serialize LossAnalysisResult: %s", + e, + exc_info=True, + ) + raise + + html_content = _get_loss_analysis_html( + json.dumps(wrapped, ensure_ascii=False, default=_pydantic_serializer) + ) + display.display(display.HTML(html_content)) + + +def _get_status_html(status: str, error_message: Optional[str] = None) -> str: + """Returns a simple HTML string for displaying a status and optional error.""" + error_html = "" + if error_message: + error_html = f""" +

+        <div class="error">
+          <strong>Error:</strong>
+          <pre>{html.escape(error_message)}</pre>
+        </div>
+        """
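+    # Wrap the status (and any error details) in minimal markup suitable
+    # for notebook display.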
+
+    return textwrap.dedent(
+        f"""
+        <div class="status">
+          <strong>Status:</strong> {html.escape(status)}
+          {error_html}
+        </div>
+ """ + ) + + +def _enrich_loss_examples_with_eval_items( + results: list["types.LossAnalysisResult"], + eval_item_map: Optional[dict[str, dict[str, Any]]], +) -> list[dict[str, Any]]: + """Enriches loss analysis examples with eval item data for visualization. + + For the eval run path, loss examples only have ``evaluation_item`` + (a resource name) but no ``evaluation_result``. The JS visualization + needs ``evaluation_result`` to extract scenario previews and rubric + descriptions. This function joins the loss examples with the eval + item map so the visualization works identically to the LRO path. + + Args: + results: Loss analysis results from the eval run. + eval_item_map: Optional mapping from evaluation item resource name + to serialized evaluation response data (built by + ``_evals_common._build_eval_item_map``). + + Returns: + A list of dicts ready for JSON serialization, with ``evaluation_result`` + populated on each example where a match is found. + """ + result_dicts = [] + for r in results: + r_dump = r.model_dump(mode="json", exclude_none=True) + if eval_item_map: + clusters = r_dump.get("clusters", []) + for cluster in clusters: + examples = cluster.get("examples", []) + for ex in examples: + # Skip if evaluation_result is already populated (LRO path) + if ex.get("evaluation_result"): + continue + # Match by evaluation_item resource name + eval_item_ref = ex.get("evaluation_item") + if eval_item_ref and eval_item_ref in eval_item_map: + ex["evaluation_result"] = eval_item_map[eval_item_ref] + result_dicts.append(r_dump) + return result_dicts + + +def display_loss_analysis_results( + results: list["types.LossAnalysisResult"], + eval_item_map: Optional[dict[str, dict[str, Any]]] = None, +) -> None: + """Displays loss analysis results from an EvaluationRun. + + Wraps the list of LossAnalysisResult objects into the same JSON + structure used by GenerateLossClustersResponse and renders using + the shared _get_loss_analysis_html() function. + + When ``eval_item_map`` is provided (from + ``get_evaluation_run(include_evaluation_items=True)``), the examples + are enriched with scenario and rubric data for the visualization. + + Args: + results: A list of LossAnalysisResult objects from + EvaluationRunResults.loss_analysis_results. + eval_item_map: Optional mapping from evaluation item resource name + to serialized evaluation response data for enrichment. 
+ """ + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + try: + result_dicts = _enrich_loss_examples_with_eval_items(results, eval_item_map) + wrapped = {"results": result_dicts} + except Exception as e: + logger.error( + "Failed to serialize loss analysis results: %s", + e, + exc_info=True, + ) + raise + + html_content = _get_loss_analysis_html( + json.dumps(wrapped, ensure_ascii=False, default=_pydantic_serializer) + ) + display.display(display.HTML(html_content)) + + +def display_evaluation_run_status(eval_run_obj: "types.EvaluationRun") -> None: + """Displays the status of an evaluation run in an IPython environment.""" + if not _is_ipython_env(): + logger.warning("Skipping display: not in an IPython environment.") + return + else: + from IPython import display + + status = eval_run_obj.state.name if eval_run_obj.state else "UNKNOWN" + error_message = str(eval_run_obj.error) if eval_run_obj.error else None + html_content = _get_status_html(status, error_message) + display.display(display.HTML(html_content)) diff --git a/agentplatform/_genai/_gcs_utils.py b/agentplatform/_genai/_gcs_utils.py new file mode 100644 index 0000000000..176f2109e4 --- /dev/null +++ b/agentplatform/_genai/_gcs_utils.py @@ -0,0 +1,192 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import io +import json +import logging +from typing import Any, Union + +from google.cloud import storage # type: ignore[attr-defined] +from google.cloud.aiplatform.utils.gcs_utils import blob_from_uri +from google.genai._api_client import BaseApiClient +import pandas as pd +import uuid + + +logger = logging.getLogger(__name__) + + +GCS_PREFIX = "gs://" + + +class GcsUtils: + """Handles File I/O operations with Google Cloud Storage (GCS)""" + + def __init__(self, api_client: BaseApiClient): + self.api_client = api_client + self.storage_client = storage.Client( + project=self.api_client.project, + credentials=self.api_client._credentials, + ) + + def parse_gcs_path(self, gcs_path: str) -> tuple[str, str]: + """Helper to parse gs://bucket/path into (bucket_name, blob_path).""" + if not gcs_path.startswith(GCS_PREFIX): + raise ValueError( + f"Invalid GCS path: '{gcs_path}'. It must start with '{GCS_PREFIX}'." + ) + path_without_prefix = gcs_path[len(GCS_PREFIX) :] + if "/" not in path_without_prefix: + return path_without_prefix, "" + bucket_name, blob_path = path_without_prefix.split("/", 1) + return bucket_name, blob_path + + def upload_file_to_gcs(self, upload_gcs_path: str, filename: str) -> None: + """Uploads the provided file to a Google Cloud Storage location.""" + + blob_from_uri( + uri=upload_gcs_path, client=self.storage_client + ).upload_from_filename(filename) + + def upload_dataframe( + self, + df: "pd.DataFrame", + gcs_destination_blob_path: str, + file_type: str = "jsonl", + ) -> None: + """Uploads a Pandas DataFrame to a Google Cloud Storage location. 
+ + Args: + df: The Pandas DataFrame to upload. + gcs_destination_blob_path: The full GCS path for the destination blob + (e.g., 'gs://bucket/data/my_dataframe.jsonl'). + file_type: The format to save the DataFrame ('jsonl' or 'csv'). Defaults + to 'jsonl'. + """ + bucket_name, blob_name = self.parse_gcs_path(gcs_destination_blob_path) + if not blob_name: + raise ValueError( + f"Invalid GCS path for blob: '{gcs_destination_blob_path}'. " + "It must include the object name (e.g., gs://bucket/file.csv)." + ) + bucket = self.storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + buffer = io.StringIO() + if file_type == "csv": + df.to_csv(buffer, index=False) + content_type = "text/csv" + elif file_type == "jsonl": + df.to_json(buffer, orient="records", lines=True) + content_type = "application/jsonl" + else: + raise ValueError( + f"Unsupported file type: '{file_type}'. " + "Please provide 'jsonl' or 'csv'." + ) + blob.upload_from_string(buffer.getvalue(), content_type=content_type) + + logger.info( + f"DataFrame successfully uploaded to: gs://{bucket.name}/{blob.name}" + ) + + def upload_json(self, data: dict[str, Any], gcs_destination_blob_path: str) -> None: + """Uploads a dictionary as a JSON file to Google Cloud Storage.""" + bucket_name, blob_name = self.parse_gcs_path(gcs_destination_blob_path) + if not blob_name: + raise ValueError( + f"Invalid GCS path for blob: '{gcs_destination_blob_path}'. " + "It must include the object name (e.g., gs://bucket/file.json)." + ) + bucket = self.storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + + json_data = json.dumps(data, indent=2) + blob.upload_from_string(json_data, content_type="application/json") + + logger.info( + f"JSON data successfully uploaded to: gs://{bucket_name}/{blob_name}" + ) + + def upload_json_to_prefix( + self, + data: dict[str, Any], + gcs_dest_prefix: str, + filename_prefix: str = "data", + ) -> str: + """Uploads a dictionary to a GCS prefix with a UUID JSON filename. + + Args: + data: The dictionary to upload. + gcs_dest_prefix: The GCS prefix (e.g., 'gs://bucket/path/prefix/'). + filename_prefix: Prefix for the generated filename. Defaults to 'data'. + + Returns: + The full GCS path where the file was uploaded. + + Raises: + ValueError: If the gcs_dest_prefix is not a valid GCS path. + """ + if not gcs_dest_prefix.startswith(GCS_PREFIX): + raise ValueError( + f"Invalid GCS destination prefix: '{gcs_dest_prefix}'. Must start" + f" with '{GCS_PREFIX}'." + ) + + gcs_path_without_scheme = gcs_dest_prefix[len(GCS_PREFIX) :] + bucket_name, *path_parts = gcs_path_without_scheme.split("/") + + user_prefix_path = "/".join(path_parts) + if user_prefix_path and not user_prefix_path.endswith("/"): + user_prefix_path += "/" + + filename = f"{filename_prefix}_{uuid.uuid4()}.json" + + blob_name = f"{user_prefix_path}{filename}" + + full_gcs_path = f"{GCS_PREFIX}{bucket_name}/{blob_name}" + + self.upload_json(data, full_gcs_path) + return full_gcs_path + + def read_file_contents(self, gcs_filepath: str) -> Union[str, Any]: + """Reads the contents of a file from Google Cloud Storage.""" + + bucket_name, blob_path = self.parse_gcs_path(gcs_filepath) + if not blob_path: + raise ValueError( + f"Invalid GCS file path: '{gcs_filepath}'. Path must point to a file," + " not just a bucket." 
+        )
+        bucket = self.storage_client.bucket(bucket_name)
+        blob = bucket.blob(blob_path)
+        content = blob.download_as_bytes().decode("utf-8")
+        logger.info(f"Successfully read content from '{gcs_filepath}'")
+        return content
+
+    def read_gcs_file_to_dataframe(
+        self, gcs_filepath: str, file_type: str
+    ) -> "pd.DataFrame":
+        """Reads a file from Google Cloud Storage into a Pandas DataFrame."""
+        file_contents = self.read_file_contents(gcs_filepath)
+        if file_type == "csv":
+            return pd.read_csv(io.StringIO(file_contents), encoding="utf-8")
+        elif file_type == "jsonl":
+            return pd.read_json(io.StringIO(file_contents), lines=True)
+        else:
+            raise ValueError(
+                f"Unsupported file type: '{file_type}'. Please provide 'jsonl' or"
+                " 'csv'."
+            )
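+
+
+# Illustrative usage of GcsUtils (hypothetical bucket and object names; the
+# api_client is an existing google-genai BaseApiClient):
+#
+#   gcs = GcsUtils(api_client)
+#   gcs.parse_gcs_path("gs://my-bucket/evals/run1.jsonl")
+#   # -> ("my-bucket", "evals/run1.jsonl")
+#   gcs.upload_dataframe(df, "gs://my-bucket/evals/run1.jsonl", file_type="jsonl")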
diff --git a/agentplatform/_genai/_logging_utils.py b/agentplatform/_genai/_logging_utils.py
new file mode 100644
index 0000000000..cee6b23df4
--- /dev/null
+++ b/agentplatform/_genai/_logging_utils.py
@@ -0,0 +1,47 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+from typing import Any, Callable
+import warnings
+
+from google.genai import _common
+
+
+def show_deprecation_warning_once(
+    message: str,
+) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+    """Decorator to show a deprecation warning once for a function."""
+
+    def decorator(func: Any) -> Any:
+        warning_done = False
+
+        @functools.wraps(func)
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
+            nonlocal warning_done
+            if not warning_done:
+                warning_done = True
+                warnings.warn(message, DeprecationWarning, stacklevel=2)
+
+            # Suppress ExperimentalWarning while executing the deprecated
+            # wrapper. We ignore it because the user will see it when they
+            # migrate to the new prompts module.
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore", category=_common.ExperimentalWarning)
+                return func(*args, **kwargs)
+
+        return wrapper
+
+    return decorator
diff --git a/agentplatform/_genai/_observability_data_converter.py b/agentplatform/_genai/_observability_data_converter.py
new file mode 100644
index 0000000000..d5bef28d9c
--- /dev/null
+++ b/agentplatform/_genai/_observability_data_converter.py
@@ -0,0 +1,186 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Dataset converter for Google Observability GenAI data."""
+
+import json
+import logging
+from typing import Any, Optional
+
+from google.genai import types as genai_types
+from typing_extensions import override
+
+from . import _evals_utils
+from . import types
+
+
+logger = logging.getLogger("agentplatform_genai._observability_data_converters")
+
+
+def _load_jsonl(data: Any, case_id: str) -> list[dict[Any, Any]]:
+    """Parses raw JSONL data into a list of dicts, one per line."""
+    if isinstance(data, str):
+        json_list = []
+        for line in data.splitlines():
+            loaded_json = json.loads(line)
+            if not isinstance(loaded_json, dict):
+                raise TypeError(
+                    f"Decoded JSON payload is not a dict for case "
+                    f"{case_id}. Type found: {type(loaded_json).__name__}"
+                )
+            json_list.append(loaded_json)
+        return json_list
+    else:
+        raise TypeError(
+            f"Payload is not a JSONL string for case {case_id}. Type "
+            f"found: {type(data).__name__}"
+        )
+
+
+class ObservabilityDataConverter(_evals_utils.EvalDataConverter):
+    """Converter for dataset in GCP Observability GenAI format."""
+
+    def _message_to_content(self, message: dict[str, Any]) -> genai_types.Content:
+        """Converts Observability GenAI Message format to Content."""
+        parts = []
+        message_parts = message.get("parts", [])
+        if isinstance(message_parts, list):
+            for message_part in message_parts:
+                part = None
+                part_type = message_part.get("type", "")
+                if part_type == "text":
+                    part = genai_types.Part(text=message_part.get("content", ""))
+                elif part_type == "blob":
+                    part = genai_types.Part(
+                        inline_data=genai_types.Blob(
+                            data=message_part.get("data", ""),
+                            mime_type=message_part.get("mime_type", ""),
+                        )
+                    )
+                elif part_type == "file_data":
+                    part = genai_types.Part(
+                        file_data=genai_types.FileData(
+                            file_uri=message_part.get("file_uri", ""),
+                            mime_type=message_part.get("mime_type", ""),
+                        )
+                    )
+                elif part_type == "tool_call":
+                    # O11y format requires use of id in place of name.
+                    part = genai_types.Part(
+                        function_call=genai_types.FunctionCall(
+                            id=message_part.get("id", ""),
+                            name=message_part.get("id", ""),
+                            args=message_part.get("arguments", {}),
+                        )
+                    )
+                elif part_type == "tool_call_response":
+                    # O11y format requires use of id in place of name.
+                    part = genai_types.Part(
+                        function_response=genai_types.FunctionResponse(
+                            id=message_part.get("id", ""),
+                            name=message_part.get("id", ""),
+                            response=message_part.get("result", {}),
+                        )
+                    )
+                else:
+                    logger.warning(
+                        "Skipping message part due to unrecognized message "
+                        "part type of '%s'",
+                        part_type,
+                    )
+
+                if part is not None:
+                    parts.append(part)
+
+        return genai_types.Content(parts=parts, role=message.get("role", ""))
+
+    def _parse_messages(
+        self,
+        eval_case_id: str,
+        request_msgs: list[Any],
+        response_msgs: list[Any],
+        system_instruction_msg: Optional[dict[str, Any]] = None,
+    ) -> types.EvalCase:
+        """Parses a set of Observability messages into an EvalCase."""
+        # System instruction message.
+        system_instruction = None
+        if system_instruction_msg is not None:
+            system_instruction = self._message_to_content(system_instruction_msg)
+
+        # Request messages.
+        prompt = None
+        conversation_history = []
+        if request_msgs:
+            # Extract the latest message as the prompt.
+            prompt = self._message_to_content(request_msgs[-1])
+
+            # All previous messages are conversation history.
+            if len(request_msgs) > 1:
+                for i, msg in enumerate(request_msgs[:-1]):
+                    conversation_history.append(
+                        types.evals.Message(
+                            turn_id=str(i),
+                            content=self._message_to_content(msg),
author=msg.get("role", ""), + ) + ) + + # Output messages + responses = [] + for msg in response_msgs: + response = types.ResponseCandidate(response=self._message_to_content(msg)) + responses.append(response) + + return types.EvalCase( + eval_case_id=eval_case_id, + prompt=prompt, + responses=responses, + system_instruction=system_instruction, + conversation_history=conversation_history, + reference=None, + ) + + @override + def convert(self, raw_data: list[dict[str, Any]]) -> types.EvaluationDataset: + """Converts a list of GCP Observability GenAI cases into an EvaluationDataset.""" + eval_cases = [] + + for i, case in enumerate(raw_data): + eval_case_id = f"observability_eval_case_{i}" + + if "request" not in case or "response" not in case: + logger.warning( + "Skipping case %s due to missing 'request' or 'response' key.", + eval_case_id, + ) + continue + + request_data = case.get("request", []) + request_list = _load_jsonl(request_data, eval_case_id) + + response_data = case.get("response", []) + response_list = _load_jsonl(response_data, eval_case_id) + + system_dict = None + if "system_instruction" in case: + system_data = case.get("system_instruction", {}) + system_list = _load_jsonl(system_data, eval_case_id) + system_dict = system_list[0] if system_list else {} + + eval_case = self._parse_messages( + eval_case_id, request_list, response_list, system_dict + ) + eval_cases.append(eval_case) + + return types.EvaluationDataset(eval_cases=eval_cases) diff --git a/agentplatform/_genai/_prompt_management_utils.py b/agentplatform/_genai/_prompt_management_utils.py new file mode 100644 index 0000000000..e8de400c0d --- /dev/null +++ b/agentplatform/_genai/_prompt_management_utils.py @@ -0,0 +1,147 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Utility functions for prompt management.""" + +from typing import Optional + +from google.genai import types as genai_types + +from . 
import types + + +DEFAULT_API_SCHEMA_VERSION = "1.0.0" +PROMPT_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/text_prompt_1.0.0.yaml" +) +PROMPT_TYPE = "multimodal_freeform" + + +def _create_dataset_metadata_from_prompt( + prompt: types.Prompt, + variables: Optional[list[dict[str, genai_types.Part]]] = None, +) -> types.SchemaTextPromptDatasetMetadata: + """Convert a types.Prompt into types.SchemaTextPromptDatasetMetadata.""" + + prompt_metadata = types.SchemaTextPromptDatasetMetadata() + + prompt_api_schema = types.SchemaPromptApiSchema() + prompt_api_schema.multimodal_prompt = types.SchemaPromptSpecMultimodalPrompt( + prompt_message=prompt.prompt_data + ) + + prompt_api_schema.api_schema_version = DEFAULT_API_SCHEMA_VERSION + + prompt_metadata.has_prompt_variable = bool(variables) + + if variables: + prompt_execution_list = [] + for prompt_var in variables: + prompt_instance_execution = types.SchemaPromptInstancePromptExecution() + prompt_instance_execution.arguments = {} + for key, val in prompt_var.items(): + prompt_instance_execution.arguments[key] = ( + types.SchemaPromptInstanceVariableValue( + part_list=types.SchemaPromptSpecPartList(parts=[val]) + ) + ) + prompt_execution_list.append(prompt_instance_execution) + prompt_api_schema.executions = prompt_execution_list + + # Need to exclude variables from the prompt message as it is a client side + # only field + if prompt_api_schema.multimodal_prompt.prompt_message: + prompt_message_dict = ( + prompt_api_schema.multimodal_prompt.prompt_message.model_dump( + exclude=["variables"], exclude_none=True + ) + ) + prompt_api_schema.multimodal_prompt.prompt_message = ( + types.SchemaPromptSpecPromptMessage(**prompt_message_dict) + ) + prompt_metadata.prompt_api_schema = prompt_api_schema + + prompt_metadata.prompt_type = PROMPT_TYPE + + return prompt_metadata + + +def _create_prompt_from_dataset_metadata( + dataset: types.Dataset, +) -> types.Prompt: + """Constructs a types.Prompt from a types.Dataset resource returned from the API. + + Args: + dataset: The types.Dataset object containing the prompt metadata. + + Returns: + A types.Prompt object reconstructed from the dataset metadata. + """ + if ( + not hasattr(dataset, "metadata") + or dataset.metadata is None + or not isinstance(dataset.metadata, types.SchemaTextPromptDatasetMetadata) + ): + raise ValueError( + "Error retrieving prompt: prompt dataset resource is missing 'metadata'." 
+ ) + api_schema = dataset.metadata.prompt_api_schema + prompt = types.Prompt() + + if api_schema is None: + return prompt + + if api_schema.multimodal_prompt: + + prompt_message = api_schema.multimodal_prompt.prompt_message + prompt.prompt_data = prompt_message + + if api_schema.executions: + executions = api_schema.executions + if executions and prompt.prompt_data is not None: + prompt.prompt_data.variables = [] + for execution in executions: + if execution.arguments: + args = execution.arguments + var_map = {} + for key, val in args.items(): + if ( + val.part_list is not None + and val.part_list.parts is not None + ): + part_list = val.part_list.parts + if part_list and part_list[0].text: + var_map[key] = part_list[0] + if var_map and prompt.prompt_data.variables is not None: + prompt.prompt_data.variables.append(var_map) + + return prompt + + +def _raise_for_invalid_prompt( + prompt: types.Prompt, +) -> None: + + if not prompt.prompt_data: + raise ValueError("Prompt data must be provided.") + if not prompt.prompt_data.contents: + raise ValueError("Prompt contents must be provided.") + if not prompt.prompt_data.model: + raise ValueError("Model name must be provided.") + if ( + prompt.prompt_data + and prompt.prompt_data.contents + and len(prompt.prompt_data.contents) > 1 + ): + raise ValueError("Multi-turn prompts are not currently supported.") diff --git a/agentplatform/_genai/_prompt_optimizer_utils.py b/agentplatform/_genai/_prompt_optimizer_utils.py new file mode 100644 index 0000000000..2a218432ee --- /dev/null +++ b/agentplatform/_genai/_prompt_optimizer_utils.py @@ -0,0 +1,215 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Utility functions for prompt optimizer.""" + +import json +from typing import Any, Optional, Union +from typing_extensions import TypeAlias + +from pydantic import ValidationError + +from . 
import types
+
+try:
+    import pandas as pd  # pylint: disable=g-import-not-at-top
+
+    PandasDataFrame: TypeAlias = pd.DataFrame
+except ImportError:
+    pd = None
+    PandasDataFrame = Any  # type: ignore[misc]
+
+
+def _construct_input_prompt(
+    example_df: PandasDataFrame,
+    *,
+    prompt_col_name: str,
+    model_response_col_name: str,
+    rubrics_col_name: str,
+    rubrics_evaluations_col_name: str,
+    target_response_col_name: str,
+    system_instruction: Optional[str] = None,
+) -> str:
+    """Constructs the input prompt for the few-shot prompt optimizer."""
+
+    all_prompts = []
+    for row in example_df.to_dict(orient="records"):
+        example_data = {
+            "prompt": row[prompt_col_name],
+            "model_response": row[model_response_col_name],
+        }
+        if rubrics_col_name:
+            example_data["rubrics"] = row[rubrics_col_name]
+        if rubrics_evaluations_col_name:
+            example_data["rubrics_evaluations"] = row[rubrics_evaluations_col_name]
+        if target_response_col_name:
+            example_data["target_response"] = row[target_response_col_name]
+
+        json_str = json.dumps(example_data, indent=2)
+        all_prompts.append(f"```JSON\n{json_str}\n```")
+
+    all_prompts_str = "\n\n".join(all_prompts)
+
+    if system_instruction is None:
+        system_instruction = ""
+
+    return "\n".join(
+        [
+            "Original System Instructions:\n",
+            system_instruction,
+            "Examples:\n",
+            all_prompts_str,
+            "\nNew Output:\n",
+        ]
+    )
+
+
+def _get_few_shot_prompt(
+    system_instruction: str,
+    config: types.OptimizeConfig,
+) -> str:
+    """Builds the few-shot prompt."""
+
+    if config.examples_dataframe is None:
+        raise ValueError("The 'examples_dataframe' is required in the config.")
+
+    if "prompt" not in config.examples_dataframe.columns:
+        raise ValueError("'prompt' is required in the examples_dataframe.")
+    prompt_col_name = "prompt"
+
+    if "model_response" not in config.examples_dataframe.columns:
+        raise ValueError("'model_response' is required in the examples_dataframe.")
+    model_response_col_name = "model_response"
+
+    target_response_col_name = ""
+    rubrics_col_name = ""
+    rubrics_evaluations_col_name = ""
+
+    if (
+        config.optimization_target
+        == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+    ):
+        if "target_response" not in config.examples_dataframe.columns:
+            raise ValueError(
+                "'target_response' is required in the examples_dataframe."
+            )
+        target_response_col_name = "target_response"
+        if "rubrics" in config.examples_dataframe.columns:
+            raise ValueError(
+                "Only 'target_response' should be provided "
+                "for OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE "
+                "but 'rubrics' was provided."
+            )
+
+    elif (
+        config.optimization_target
+        == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS
+    ):
+        if not {"rubrics", "rubrics_evaluations"}.issubset(
+            config.examples_dataframe.columns
+        ):
+            raise ValueError(
+                "'rubrics' and 'rubrics_evaluations' are required in the "
+                "examples_dataframe when rubrics is set."
+            )
+
+        rubrics_col_name = "rubrics"
+        rubrics_evaluations_col_name = "rubrics_evaluations"
+        if "target_response" in config.examples_dataframe.columns:
+            raise ValueError(
+                "Only 'rubrics' and 'rubrics_evaluations' should be provided "
+                "for OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS "
+                "but 'target_response' was provided."
+ ) + else: + raise ValueError("One of 'target_response' or 'rubrics' must be provided.") + + return _construct_input_prompt( + config.examples_dataframe, + prompt_col_name=prompt_col_name, + model_response_col_name=model_response_col_name, + rubrics_col_name=rubrics_col_name, + rubrics_evaluations_col_name=rubrics_evaluations_col_name, + target_response_col_name=target_response_col_name, + system_instruction=system_instruction, + ) + + +def _get_service_account( + config: types.PromptOptimizerConfigOrDict, +) -> str: + """Get the service account from the config for the custom job.""" + if isinstance(config, dict): + config = types.PromptOptimizerConfig.model_validate(config) + + if ( + config.service_account and config.service_account_project_number + ): # pytype: disable=attribute-error + raise ValueError( + "Only one of service_account or " + "service_account_project_number can be provided." + ) + elif config.service_account: # pytype: disable=attribute-error + return config.service_account # pytype: disable=attribute-error + elif config.service_account_project_number: # pytype: disable=attribute-error + return f"{config.service_account_project_number}-compute@developer.gserviceaccount.com" # pytype: disable=attribute-error + else: + raise ValueError( + "Either service_account or service_account_project_number " "is required." + ) + + +def _clean_and_parse_optimized_prompt(output_str: str) -> Optional[Any]: + """Cleans a string response returned from the prompt optimizer endpoint. + + Args: + output_str: The optimized prompt string containing the JSON data, + potentially with markdown formatting like ```json ... ```. + + Returns: + The parsed JSON data, or None if parsing fails. + """ + lines = output_str.strip().split("\n") + # Remove markdown delimiters + if lines and lines[0].strip().startswith("```"): + cleaned_string = "\n".join(lines[1:-1]) + else: + cleaned_string = output_str + + # remove any 'json' labels if they exist on the first line. + if cleaned_string.strip().startswith("json"): + cleaned_string = cleaned_string.strip()[4:].strip() + + try: + return json.loads(cleaned_string) + except json.JSONDecodeError as e: + # TODO(b/437144880): raise errors.ClientError here instead + raise ValueError( + f"Failed to parse the response from prompt optimizer endpoint. {e}" + ) from e + + +def _parse( + output_str: str, +) -> Union[ + types.prompts.ParsedResponse, + types.prompts.ParsedResponseFewShot, +]: + """Parses the output string from the prompt optimizer endpoint.""" + parsed_out = _clean_and_parse_optimized_prompt(output_str) + try: + return types.prompts.ParsedResponse(**parsed_out) + except ValidationError: + return types.prompts.ParsedResponseFewShot(**parsed_out) diff --git a/agentplatform/_genai/_transformers.py b/agentplatform/_genai/_transformers.py new file mode 100644 index 0000000000..fb6f477cf8 --- /dev/null +++ b/agentplatform/_genai/_transformers.py @@ -0,0 +1,538 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Transformers module for Vertex addons.""" +import json +import re +from typing import Any + +from google.genai._common import get_value_by_path as getv + +from . import _evals_constant +from . import _evals_data_converters +from . import types + +_METRIC_RES_NAME_RE = r"^projects/[^/]+/locations/[^/]+/evaluationMetrics/[^/]+$" + + +def t_metrics( + metrics: "list[types.MetricSubclass]", + set_default_aggregation_metrics: bool = False, +) -> list[dict[str, Any]]: + """Prepares the metric payload for the evaluation request. + + Args: + metrics: A list of metrics used for evaluation. + set_default_aggregation_metrics: Whether to set default aggregation metrics. + Returns: + A list of resolved metric payloads for the evaluation request. + """ + metrics_payload = [] + + for metric in metrics: + metric_payload_item: dict[str, Any] = {} + + metric_id = getv(metric, ["metric"]) or getv(metric, ["name"]) + metric_name = metric_id.lower() if metric_id else None + + if set_default_aggregation_metrics: + metric_payload_item["aggregation_metrics"] = [ + "AVERAGE", + "STANDARD_DEVIATION", + ] + + if metric_name == "exact_match": + metric_payload_item["exact_match_spec"] = {} + elif metric_name == "bleu": + metric_payload_item["bleu_spec"] = {} + elif metric_name and metric_name.startswith("rouge"): + rouge_type = metric_name.replace("_", "") + metric_payload_item["rouge_spec"] = {"rouge_type": rouge_type} + # API Pre-defined metrics + elif ( + metric_name and metric_name in _evals_constant.SUPPORTED_PREDEFINED_METRICS + ): + metric_payload_item["predefined_metric_spec"] = { + "metric_spec_name": metric_name, + "metric_spec_parameters": metric.metric_spec_parameters, + } + # Custom Code Execution Metric + elif ( + hasattr(metric, "remote_custom_function") and metric.remote_custom_function + ): + metric_payload_item["custom_code_execution_spec"] = { + "evaluation_function": metric.remote_custom_function + } + elif ( + isinstance(metric, types.CodeExecutionMetric) + or ( + isinstance(metric, types.Metric) + and isinstance(getattr(metric, "custom_function", None), str) + ) + ) and getattr(metric, "custom_function", None): + metric_payload_item["custom_code_execution_spec"] = { + "evaluation_function": metric.custom_function + } + # LLM-based metrics + elif hasattr(metric, "prompt_template") and metric.prompt_template: + llm_based_spec: dict[str, Any] = { + "metric_prompt_template": metric.prompt_template + } + system_instruction = getv(metric, ["judge_model_system_instruction"]) + if system_instruction: + llm_based_spec["system_instruction"] = system_instruction + rubric_group_name = getv(metric, ["rubric_group_name"]) + if rubric_group_name: + llm_based_spec["rubric_group_key"] = rubric_group_name + return_raw_output = getv(metric, ["return_raw_output"]) + if return_raw_output: + llm_based_spec["custom_output_format_config"] = { + "return_raw_output": return_raw_output + } + + autorater_config: dict[str, Any] = {} + if hasattr(metric, "judge_model") and metric.judge_model: + autorater_config["autorater_model"] = metric.judge_model + if ( + hasattr(metric, "judge_model_generation_config") + and metric.judge_model_generation_config + ): + autorater_config["generation_config"] = ( + metric.judge_model_generation_config + ) + if ( + hasattr(metric, "judge_model_sampling_count") + and metric.judge_model_sampling_count + ): + autorater_config["sampling_count"] = metric.judge_model_sampling_count + + if autorater_config: + llm_based_spec["judge_autorater_config"] = autorater_config + + 
result_parsing_function = getv(metric, ["result_parsing_function"]) + if result_parsing_function: + llm_based_spec["result_parser_config"] = { + "custom_code_parser_config": { + "parsing_function": result_parsing_function + } + } + + metric_payload_item["llm_based_metric_spec"] = llm_based_spec + elif getattr(metric, "metric_resource_name", None) is not None: + # Safe pass + pass + else: + raise ValueError( + f"Unsupported metric type or invalid metric name: {metric_name}" + ) + metrics_payload.append(metric_payload_item) + return metrics_payload + + +def t_metric_sources(metrics: list[Any]) -> list[dict[str, Any]]: + """Prepares the MetricSource payload.""" + sources_payload = [] + for metric in metrics: + resource_name = getattr(metric, "metric_resource_name", None) + if ( + not resource_name + and isinstance(metric, str) + and re.match(_METRIC_RES_NAME_RE, metric) + ): + resource_name = metric + + if resource_name: + sources_payload.append({"metric_resource_name": resource_name}) + else: + if hasattr(metric, "metric") and not isinstance(metric, str): + metric = metric.metric + + if not hasattr(metric, "name"): + metric = types.Metric(name=str(metric)) + + metric_payload = t_metrics([metric])[0] + sources_payload.append({"metric": metric_payload}) + return sources_payload + + +def t_user_scenario_generation_config( + config: "types.evals.UserScenarioGenerationConfigOrDict", +) -> dict[str, Any]: + """Transforms UserScenarioGenerationConfig to Vertex AI format.""" + payload: dict[str, Any] = {} + config_dict = config if isinstance(config, dict) else config.model_dump() + + if getv(config_dict, ["count"]) is not None: + payload["user_scenario_count"] = getv(config_dict, ["count"]) + if getv(config_dict, ["generation_instruction"]) is not None: + payload["simulation_instruction"] = getv( + config_dict, ["generation_instruction"] + ) + if getv(config_dict, ["environment_context"]) is not None: + payload["environment_data"] = getv(config_dict, ["environment_context"]) + if getv(config_dict, ["model_name"]) is not None: + payload["model_name"] = getv(config_dict, ["model_name"]) + + return payload + + +def t_metric_for_registry( + metric: "types.Metric", +) -> dict[str, Any]: + """Prepares the metric payload specifically for EvaluationMetric registration.""" + metric_payload_item: dict[str, Any] = {} + metric_name = getattr(metric, "name", None) + if metric_name: + metric_name = metric_name.lower() + + # Custom Code Execution Metric + if hasattr(metric, "remote_custom_function") and metric.remote_custom_function: + metric_payload_item["custom_code_execution_spec"] = { + "evaluation_function": metric.remote_custom_function + } + elif ( + isinstance(metric, types.CodeExecutionMetric) + or ( + isinstance(metric, types.Metric) + and isinstance(getattr(metric, "custom_function", None), str) + ) + ) and getattr(metric, "custom_function", None): + metric_payload_item["custom_code_execution_spec"] = { + "evaluation_function": metric.custom_function + } + + # LLM-based metric + elif (hasattr(metric, "prompt_template") and metric.prompt_template) or ( + hasattr(metric, "rubric_group_name") and metric.rubric_group_name + ): + llm_based_spec: dict[str, Any] = {} + + if hasattr(metric, "prompt_template") and metric.prompt_template: + llm_based_spec["metric_prompt_template"] = metric.prompt_template + system_instruction = getv(metric, ["judge_model_system_instruction"]) + if system_instruction: + llm_based_spec["system_instruction"] = system_instruction + rubric_group_name = getv(metric, 
["rubric_group_name"]) + if rubric_group_name: + llm_based_spec["rubric_group_key"] = rubric_group_name + + autorater_config: dict[str, Any] = {} + if hasattr(metric, "judge_model") and metric.judge_model: + autorater_config["autorater_model"] = metric.judge_model + if ( + hasattr(metric, "judge_model_generation_config") + and metric.judge_model_generation_config + ): + autorater_config["generation_config"] = metric.judge_model_generation_config + if ( + hasattr(metric, "judge_model_sampling_count") + and metric.judge_model_sampling_count + ): + autorater_config["sampling_count"] = metric.judge_model_sampling_count + + if autorater_config: + llm_based_spec["judge_autorater_config"] = autorater_config + + result_parsing_function = getv(metric, ["result_parsing_function"]) + if result_parsing_function: + llm_based_spec["result_parser_config"] = { + "custom_code_parser_config": { + "parsing_function": result_parsing_function + } + } + + metric_payload_item["llm_based_metric_spec"] = llm_based_spec + + else: + raise ValueError(f"Unsupported metric type: {metric_name}") + + return metric_payload_item + + +_ALLOWED_PART_FIELDS = frozenset( + { + "text", + "inline_data", + "file_data", + "function_call", + "function_response", + "video_metadata", + "thought", + "thought_signature", + "code_execution_result", + "executable_code", + "media_resolution", + } +) + + +def _sanitize_agent_data(agent_data: dict[str, Any]) -> dict[str, Any]: + """Strips SDK-only fields from agent_data so the API accepts the payload. + + The SDK's AgentData model may contain fields like 'tool_call', + 'tool_response', 'part_metadata', and 'will_continue' that don't exist + in the API's AgentData / Content proto. This function recursively removes + them from content parts and keeps only API-recognized top-level fields. + """ + if not isinstance(agent_data, dict): + return agent_data + + sanitized: dict[str, Any] = {} + for key, value in agent_data.items(): + if key == "turns" and isinstance(value, list): + sanitized["turns"] = [ + _sanitize_turn(t) for t in value if isinstance(t, dict) + ] + elif key == "agents" and isinstance(value, dict): + sanitized["agents"] = { + k: _sanitize_agent_config(v) if isinstance(v, dict) else v + for k, v in value.items() + } + # Skip unknown top-level fields (e.g. "error" from failed agent runs). + return sanitized + + +def _sanitize_agent_config(config: dict[str, Any]) -> dict[str, Any]: + """Sanitizes an AgentConfig dict, keeping only API-known fields.""" + allowed = { + "agent_id", + "agent_type", + "description", + "instruction", + "tools", + "sub_agents", + } + return {k: v for k, v in config.items() if k in allowed} + + +def _sanitize_turn(turn: dict[str, Any]) -> dict[str, Any]: + """Sanitizes a ConversationTurn dict.""" + sanitized: dict[str, Any] = {} + for key, value in turn.items(): + if key == "events" and isinstance(value, list): + sanitized["events"] = [ + _sanitize_event(e) for e in value if isinstance(e, dict) + ] + else: + sanitized[key] = value + return sanitized + + +def _sanitize_event(event: dict[str, Any]) -> dict[str, Any]: + """Sanitizes an AgentEvent dict.""" + sanitized: dict[str, Any] = {} + for key, value in event.items(): + if key == "content" and isinstance(value, dict): + sanitized["content"] = _sanitize_content(value) + elif key in ("author", "event_time", "state_delta", "active_tools"): + sanitized[key] = value + # Skip unknown event-level fields. 
+ return sanitized + + +def _sanitize_content(content: dict[str, Any]) -> dict[str, Any]: + """Sanitizes a Content dict, stripping unknown fields from parts.""" + sanitized: dict[str, Any] = {} + for key, value in content.items(): + if key == "parts" and isinstance(value, list): + sanitized["parts"] = [ + _sanitize_part(p) for p in value if isinstance(p, dict) + ] + elif key == "role": + sanitized["role"] = value + return sanitized + + +def _sanitize_part(part: dict[str, Any]) -> dict[str, Any]: + """Keeps only API-recognized fields in a Part dict.""" + sanitized: dict[str, Any] = {} + for key, value in part.items(): + if key in _ALLOWED_PART_FIELDS: + if key == "function_response" and isinstance(value, dict): + # Strip unknown sub-fields like 'will_continue'. + sanitized[key] = { + k: v for k, v in value.items() if k in ("name", "id", "response") + } + else: + sanitized[key] = value + return sanitized + + +def _extract_agent_data_from_df( + eval_dataset: Any, + case_idx: int, +) -> Any: + """Extracts agent_data from a DataFrame-based EvaluationDataset by row index.""" + if not eval_dataset: + return None + ds = eval_dataset[0] if isinstance(eval_dataset, list) else eval_dataset + df = getv(ds, ["eval_dataset_df"]) + if df is None or not hasattr(df, "iloc"): + return None + if case_idx < 0 or case_idx >= len(df): + return None + row = df.iloc[case_idx] + if "agent_data" not in row or row["agent_data"] is None: + return None + return row["agent_data"] + + +def t_inline_results( + eval_results: list[Any], +) -> list[dict[str, Any]]: + """Transforms a list of SDK EvaluationResults into API EvaluationResults.""" + api_results: list[dict[str, Any]] = [] + + for eval_result in eval_results: + metadata = getv(eval_result, ["metadata"]) + candidate_names = getv(metadata, ["candidate_names"]) if metadata else [] + candidate_names = candidate_names or [] + + eval_dataset = getv(eval_result, ["evaluation_dataset"]) + eval_cases: list[Any] = [] + if isinstance(eval_dataset, list) and eval_dataset: + eval_cases = getv(eval_dataset[0], ["eval_cases"]) or [] + + eval_case_results = getv(eval_result, ["eval_case_results"]) or [] + + for case_result in eval_case_results: + case_idx = getv(case_result, ["eval_case_index"]) or 0 + + eval_case = None + if 0 <= case_idx < len(eval_cases): + eval_case = eval_cases[case_idx] + + prompt_payload: dict[str, Any] = {} + if eval_case: + agent_data = getv(eval_case, ["agent_data"]) + prompt = getv(eval_case, ["prompt"]) + + if agent_data: + if hasattr(agent_data, "model_dump"): + prompt_payload["agent_data"] = _sanitize_agent_data( + agent_data.model_dump(exclude_none=True) + ) + elif isinstance(agent_data, dict): + prompt_payload["agent_data"] = _sanitize_agent_data(agent_data) + else: + prompt_payload["agent_data"] = agent_data + elif prompt: + text = _evals_data_converters._get_content_text( + prompt + ) # pylint: disable=protected-access + if text: + prompt_payload["text"] = str(text) + + # Fallback: extract agent_data from the DataFrame when eval_cases + # are not available (e.g., run_inference -> evaluate flow). 
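+ # Illustrative shape of that flow (hypothetical values): eval_dataset_df has + # one row per eval case, and its "agent_data" cells may be pydantic models, + # dicts, or JSON strings such as '{"turns": [...]}'.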
+ if not prompt_payload: + df_agent_data = _extract_agent_data_from_df(eval_dataset, case_idx) + if df_agent_data is not None: + if hasattr(df_agent_data, "model_dump"): + prompt_payload["agent_data"] = _sanitize_agent_data( + df_agent_data.model_dump(exclude_none=True) + ) + elif isinstance(df_agent_data, str): + try: + parsed = json.loads(df_agent_data) + if isinstance(parsed, dict) and "error" in parsed: + pass # Skip error payloads from failed agent runs. + else: + prompt_payload["agent_data"] = _sanitize_agent_data( + parsed + ) + except (json.JSONDecodeError, ValueError): + pass + elif isinstance(df_agent_data, dict): + if "error" not in df_agent_data: + prompt_payload["agent_data"] = _sanitize_agent_data( + df_agent_data + ) + + cand_results = getv(case_result, ["response_candidate_results"]) or [] + for resp_cand_result in cand_results: + resp_idx = getv(resp_cand_result, ["response_index"]) or 0 + cand_name = f"candidate-{resp_idx}" + if 0 <= resp_idx < len(candidate_names): + cand_name = candidate_names[resp_idx] + + metric_results = getv(resp_cand_result, ["metric_results"]) or {} + + for metric_name, metric_res in metric_results.items(): + api_rubric_verdicts: list[dict[str, Any]] = [] + rubric_verdicts = getv(metric_res, ["rubric_verdicts"]) or [] + + for verdict in rubric_verdicts: + verdict_dict: dict[str, Any] = {} + eval_rubric = getv(verdict, ["evaluated_rubric"]) + + if eval_rubric: + rubric_dict: dict[str, Any] = {} + rubric_id = getv(eval_rubric, ["rubric_id"]) + if rubric_id: + rubric_dict["rubric_id"] = str(rubric_id) + + rubric_content = getv(eval_rubric, ["content"]) + if rubric_content: + text = getv(rubric_content, ["text"]) + prop = getv(rubric_content, ["property"]) + + content_dict: dict[str, Any] = {} + if text: + content_dict["text"] = str(text) + if prop: + desc = getv(prop, ["description"]) + if desc: + content_dict["property"] = { + "description": str(desc) + } + rubric_dict["content"] = content_dict + verdict_dict["evaluated_rubric"] = rubric_dict + + verdict_bool = getv(verdict, ["verdict"]) + if verdict_bool is not None: + verdict_dict["verdict"] = bool(verdict_bool) + + reasoning = getv(verdict, ["reasoning"]) + if reasoning: + verdict_dict["reasoning"] = str(reasoning) + + if verdict_dict: + api_rubric_verdicts.append(verdict_dict) + + score = getv(metric_res, ["score"]) + explanation = getv(metric_res, ["explanation"]) + + candidate_result_payload: dict[str, Any] = { + "candidate": str(cand_name), + "metric": str(metric_name), + } + if score is not None: + candidate_result_payload["score"] = float(score) + if explanation: + candidate_result_payload["explanation"] = str(explanation) + if api_rubric_verdicts: + candidate_result_payload["rubric_verdicts"] = ( + api_rubric_verdicts + ) + + api_eval_result = { + "request": {"prompt": prompt_payload}, + "metric": str(metric_name), + "candidate_results": [candidate_result_payload], + } + api_results.append(api_eval_result) + + return api_results diff --git a/agentplatform/_genai/a2a_task_events.py b/agentplatform/_genai/a2a_task_events.py new file mode 100644 index 0000000000..f987dbc487 --- /dev/null +++ b/agentplatform/_genai/a2a_task_events.py @@ -0,0 +1,508 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . import types + +logger = logging.getLogger("agentplatform_genai.a2ataskevents") + +logger.setLevel(logging.INFO) + + +def _AppendAgentEngineTaskEventRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["task_events"]) is not None: + setv( + to_object, + ["taskEvents"], + [item for item in getv(from_object, ["task_events"])], + ) + + return to_object + + +def _AppendAgentEngineTaskEventResponse_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + return to_object + + +def _ListAgentEngineTaskEventsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["order_by"]) is not None: + setv(parent_object, ["_query", "orderBy"], getv(from_object, ["order_by"])) + + return to_object + + +def _ListAgentEngineTaskEventsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineTaskEventsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class A2aTaskEvents(_api_module.BaseModule): + + def append( + self, + *, + name: str, + task_events: list[types.TaskEventOrDict], + config: Optional[types.AppendAgentEngineTaskEventConfigOrDict] = None, + ) -> types.AppendAgentEngineTaskEventResponse: + """ + Adds events to an Agent Engine task. + + Args: + name (str): Required. The name of the Agent Engine task to append the events to. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{a2a_task_id}`. + task_events (list[TaskEvent]): + Required. The events to append to the task. 
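+ + Example (illustrative; the accessor path and event payload are assumptions, not defined in this module): + + client.a2a_tasks.events.append( + name=".../reasoningEngines/123/a2aTasks/456", + task_events=[event], # types.TaskEvent instances or dicts + )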
+ + Returns: + AppendAgentEngineTaskEventResponse: The response for appending the task events. + + """ + + parameter_model = types._AppendAgentEngineTaskEventRequestParameters( + name=name, + task_events=task_events, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AppendAgentEngineTaskEventRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:appendEvents".format_map(request_url_dict) + else: + path = "{name}:appendEvents" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AppendAgentEngineTaskEventResponse_from_vertex( + response_dict + ) + + return_value = types.AppendAgentEngineTaskEventResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTaskEventsConfigOrDict] = None, + ) -> types.ListAgentEngineTaskEventsResponse: + """ + Lists Agent Engine task events. + + Args: + name (str): Required. The name of the Agent Engine task to list events for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{a2a_task_id}`. + config (ListAgentEngineTaskEventsConfig): + Optional. Additional configurations for listing the Agent Engine tasks. + + Returns: + ListAgentEngineTaskEventsResponse: The requested Agent Engine tasks. + + """ + + parameter_model = types._ListAgentEngineTaskEventsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineTaskEventsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/events".format_map(request_url_dict) + else: + path = "{name}/events" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
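+ # With query params applied, the request is effectively + # GET {name}/events?pageSize=...&pageToken=... (illustrative values).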
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineTaskEventsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTaskEventsConfigOrDict] = None, + ) -> Iterator[types.TaskEvent]: + """Lists the A2A tasks of an Agent Engine. + + Args: + name (str): + Required. The name of the agent engine to list tasks for. + config (List): + Optional. The configuration for the tasks to list. + + Returns: + Iterable[TaskEvent]: An iterable of Task events. + """ + + return Pager( + "taskEvents", + functools.partial(self._list, name=name), + self._list(name=name, config=config), + config, + ) + + +class AsyncA2aTaskEvents(_api_module.BaseModule): + + async def append( + self, + *, + name: str, + task_events: list[types.TaskEventOrDict], + config: Optional[types.AppendAgentEngineTaskEventConfigOrDict] = None, + ) -> types.AppendAgentEngineTaskEventResponse: + """ + Adds events to an Agent Engine task. + + Args: + name (str): Required. The name of the Agent Engine task to append the events to. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{a2a_task_id}`. + task_events (list[TaskEvent]): + Required. The events to append to the task. + + Returns: + AppendAgentEngineTaskEventResponse: The response for appending the task events. + + """ + + parameter_model = types._AppendAgentEngineTaskEventRequestParameters( + name=name, + task_events=task_events, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AppendAgentEngineTaskEventRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:appendEvents".format_map(request_url_dict) + else: + path = "{name}:appendEvents" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AppendAgentEngineTaskEventResponse_from_vertex( + response_dict + ) + + return_value = types.AppendAgentEngineTaskEventResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTaskEventsConfigOrDict] = None, + ) -> types.ListAgentEngineTaskEventsResponse: + """ + Lists Agent Engine task events. + + Args: + name (str): Required. The name of the Agent Engine task to list events for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{a2a_task_id}`. + config (ListAgentEngineTaskEventsConfig): + Optional. Additional configurations for listing the Agent Engine tasks. + + Returns: + ListAgentEngineTaskEventsResponse: The requested Agent Engine tasks. + + """ + + parameter_model = types._ListAgentEngineTaskEventsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineTaskEventsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/events".format_map(request_url_dict) + else: + path = "{name}/events" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineTaskEventsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTaskEventsConfigOrDict] = None, + ) -> AsyncPager[types.TaskEvent]: + """Lists the A2A tasks of an Agent Engine. + + Args: + name (str): + Required. The name of the agent engine to list tasks for. + config (List): + Optional. The configuration for the tasks to list. + + Returns: + AsyncPager[TaskEvent]: An async pager of Task events. + """ + + return AsyncPager( + "taskEvents", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) diff --git a/agentplatform/_genai/a2a_tasks.py b/agentplatform/_genai/a2a_tasks.py new file mode 100644 index 0000000000..563e4a7284 --- /dev/null +++ b/agentplatform/_genai/a2a_tasks.py @@ -0,0 +1,861 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import importlib +import json +import logging +import typing +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . import types + +if typing.TYPE_CHECKING: + from . 
import a2a_task_events as a2a_task_events_module + + _ = a2a_task_events_module + + +logger = logging.getLogger("agentplatform_genai.a2atasks") + +logger.setLevel(logging.INFO) + + +def _CreateAgentEngineTaskConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["context_id"]) is not None: + setv(parent_object, ["contextId"], getv(from_object, ["context_id"])) + + if getv(from_object, ["metadata"]) is not None: + setv(parent_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["status_details"]) is not None: + setv(parent_object, ["statusDetails"], getv(from_object, ["status_details"])) + + if getv(from_object, ["output"]) is not None: + setv(parent_object, ["output"], getv(from_object, ["output"])) + + return to_object + + +def _CreateAgentEngineTaskRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["a2a_task_id"]) is not None: + setv(to_object, ["_query", "a2a_task_id"], getv(from_object, ["a2a_task_id"])) + + if getv(from_object, ["config"]) is not None: + _CreateAgentEngineTaskConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _DeleteAgentEngineTaskRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GetAgentEngineTaskRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ListAgentEngineTasksConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["order_by"]) is not None: + setv(parent_object, ["_query", "orderBy"], getv(from_object, ["order_by"])) + + return to_object + + +def _ListAgentEngineTasksRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineTasksConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +class A2aTasks(_api_module.BaseModule): + + def delete( + self, + *, + name: str, + config: 
Optional[types.DeleteAgentEngineTaskConfigOrDict] = None, + ) -> None: + """ + Deletes an agent engine task. + + Args: + name (str): Required. The name of the Agent Engine task to delete. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{task_id}`. + config (DeleteAgentEngineTaskConfig): + Optional. Additional configurations for deleting the Agent Engine task. + + Returns: + None + + """ + + parameter_model = types._DeleteAgentEngineTaskRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + self._api_client.request("delete", path, request_dict, http_options) + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineTaskConfigOrDict] = None, + ) -> types.A2aTask: + """ + Gets an agent engine task. + + Args: + name (str): Required. The name of the Agent Engine task to get. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{task_id}`. + config (GetAgentEngineTaskConfig): + Optional. Additional configurations for getting the Agent Engine task. + + Returns: + AgentEngineTask: The requested Agent Engine task. + + """ + + parameter_model = types._GetAgentEngineTaskRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
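+ # The GET goes to the fully-qualified task name, e.g. + # projects/{project}/locations/{location}/reasoningEngines/{id}/a2aTasks/{task_id} + # (format per the docstring above; values are placeholders).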
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.A2aTask._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTasksConfigOrDict] = None, + ) -> types.ListAgentEngineTasksResponse: + """ + Lists Agent Engine tasks. + + Args: + name (str): Required. The name of the Agent Engine to list tasks for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineTasksConfig): + Optional. Additional configurations for listing the Agent Engine tasks. + + Returns: + ListAgentEngineTasksResponse: The requested Agent Engine tasks. + + """ + + parameter_model = types._ListAgentEngineTasksRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineTasksRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/a2aTasks".format_map(request_url_dict) + else: + path = "{name}/a2aTasks" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineTasksResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def create( + self, + *, + name: str, + a2a_task_id: str, + config: Optional[types.CreateAgentEngineTaskConfigOrDict] = None, + ) -> types.A2aTask: + """ + Creates a new task in the Agent Engine. + + Args: + name (str): Required. The name of the Agent Engine to create the task under. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + a2a_task_id (str): Required. The user ID of the task. + context_id (str): Required. The ID of the context to use for the task. + config (CreateAgentEngineTaskConfig): + Optional. Additional configurations for creating the Agent Engine task. + + Returns: + A2aTask: The created Agent Engine task. + + """ + + parameter_model = types._CreateAgentEngineTaskRequestParameters( + name=name, + a2a_task_id=a2a_task_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/a2aTasks".format_map(request_url_dict) + else: + path = "{name}/a2aTasks" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.A2aTask._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + _events = None + + def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTasksConfigOrDict] = None, + ) -> Iterator[types.A2aTask]: + """Lists the A2A tasks of an Agent Engine. + + Args: + name (str): + Required. The name of the agent engine to list tasks for. + config (List): + Optional. The configuration for the tasks to list. + + Returns: + Iterable[A2aTask]: An iterable of A2A tasks. + """ + + return Pager( + "a2aTasks", + functools.partial(self._list, name=name), + self._list(name=name, config=config), + config, + ) + + @property + def events(self) -> "a2a_task_events_module.A2aTaskEvents": + if self._events is None: + try: + # We need to lazy load the events module to handle the + # possibility of ImportError when dependencies are not installed. + self._events = importlib.import_module(".a2a_task_events", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.a2a_tasks.events' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._events.A2aTaskEvents(self._api_client) # type: ignore[no-any-return] + + +class AsyncA2aTasks(_api_module.BaseModule): + + async def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineTaskConfigOrDict] = None, + ) -> None: + """ + Deletes an agent engine task. + + Args: + name (str): Required. The name of the Agent Engine task to delete. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{task_id}`. + config (DeleteAgentEngineTaskConfig): + Optional. Additional configurations for deleting the Agent Engine task. + + Returns: + None + + """ + + parameter_model = types._DeleteAgentEngineTaskRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + await self._api_client.async_request("delete", path, request_dict, http_options) + + async def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineTaskConfigOrDict] = None, + ) -> types.A2aTask: + """ + Gets an agent engine task. + + Args: + name (str): Required. The name of the Agent Engine task to get. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/a2aTasks/{task_id}`. + config (GetAgentEngineTaskConfig): + Optional. Additional configurations for getting the Agent Engine task. + + Returns: + AgentEngineTask: The requested Agent Engine task. + + """ + + parameter_model = types._GetAgentEngineTaskRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.A2aTask._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTasksConfigOrDict] = None, + ) -> types.ListAgentEngineTasksResponse: + """ + Lists Agent Engine tasks. + + Args: + name (str): Required. The name of the Agent Engine to list tasks for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineTasksConfig): + Optional. Additional configurations for listing the Agent Engine tasks. + + Returns: + ListAgentEngineTasksResponse: The requested Agent Engine tasks. 
+ + """ + + parameter_model = types._ListAgentEngineTasksRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineTasksRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/a2aTasks".format_map(request_url_dict) + else: + path = "{name}/a2aTasks" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineTasksResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def create( + self, + *, + name: str, + a2a_task_id: str, + config: Optional[types.CreateAgentEngineTaskConfigOrDict] = None, + ) -> types.A2aTask: + """ + Creates a new task in the Agent Engine. + + Args: + name (str): Required. The name of the Agent Engine to create the task under. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + a2a_task_id (str): Required. The user ID of the task. + context_id (str): Required. The ID of the context to use for the task. + config (CreateAgentEngineTaskConfig): + Optional. Additional configurations for creating the Agent Engine task. + + Returns: + A2aTask: The created Agent Engine task. + + """ + + parameter_model = types._CreateAgentEngineTaskRequestParameters( + name=name, + a2a_task_id=a2a_task_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineTaskRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/a2aTasks".format_map(request_url_dict) + else: + path = "{name}/a2aTasks" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.A2aTask._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + _events = None + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineTasksConfigOrDict] = None, + ) -> AsyncPager[types.A2aTask]: + """Lists the A2A tasks of an Agent Engine. + + Args: + name (str): + Required. The name of the agent engine to list tasks for. + config (List): + Optional. The configuration for the tasks to list. + + Returns: + AsyncPager[A2aTask]: An async pager of A2A tasks. + """ + + return AsyncPager( + "a2aTasks", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) + + @property + def events(self) -> "a2a_task_events_module.AsyncA2aTaskEvents": + if self._events is None: + try: + # We need to lazy load the events module to handle the + # possibility of ImportError when dependencies are not installed. + self._events = importlib.import_module(".a2a_task_events", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.a2a_tasks.events' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._events.AsyncA2aTaskEvents(self._api_client) # type: ignore[no-any-return] diff --git a/agentplatform/_genai/agent_engines.py b/agentplatform/_genai/agent_engines.py new file mode 100644 index 0000000000..12fe9f5ee8 --- /dev/null +++ b/agentplatform/_genai/agent_engines.py @@ -0,0 +1,4114 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
+ +import datetime +import importlib +import json +import logging +import typing +from typing import Any, AsyncIterator, Iterator, Optional, Sequence, Tuple, Union +from urllib.parse import urlencode +import warnings + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import Pager + +from . import _agent_engines_utils +from . import types + +if typing.TYPE_CHECKING: + from . import sessions as sessions_module + from . import memories as memories_module + from . import a2a_tasks as a2a_tasks_module + from . import runtimes as runtimes_module + + _ = sessions_module + __ = memories_module + ___ = a2a_tasks_module + ____ = runtimes_module + + +logger = logging.getLogger("agentplatform_genai.agentengines") + +logger.setLevel(logging.INFO) + + +def _AgentEngineOperation_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["done"]) is not None: + setv(to_object, ["done"], getv(from_object, ["done"])) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["response"]) is not None: + setv( + to_object, + ["response"], + _ReasoningEngine_from_vertex(getv(from_object, ["response"]), to_object), + ) + + return to_object + + +def _CancelQueryJobAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["operation_name"]) is not None: + setv(parent_object, ["operationName"], getv(from_object, ["operation_name"])) + + return to_object + + +def _CancelQueryJobAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CancelQueryJobAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _CheckQueryJobAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["retrieve_result"]) is not None: + setv(parent_object, ["retrieveResult"], getv(from_object, ["retrieve_result"])) + + return to_object + + +def _CheckQueryJobAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CheckQueryJobAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return 
to_object + + +def _CheckQueryJobResult_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(parent_object, ["operationName"]) is not None: + setv(to_object, ["operation_name"], getv(parent_object, ["operationName"])) + + if getv(parent_object, ["outputGcsUri"]) is not None: + setv(to_object, ["output_gcs_uri"], getv(parent_object, ["outputGcsUri"])) + + if getv(parent_object, ["status"]) is not None: + setv(to_object, ["status"], getv(parent_object, ["status"])) + + if getv(parent_object, ["result"]) is not None: + setv(to_object, ["result"], getv(parent_object, ["result"])) + + return to_object + + +def _CreateAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["spec"]) is not None: + setv(parent_object, ["spec"], getv(from_object, ["spec"])) + + if getv(from_object, ["context_spec"]) is not None: + setv( + parent_object, + ["contextSpec"], + _ReasoningEngineContextSpec_to_vertex( + getv(from_object, ["context_spec"]), to_object + ), + ) + + if getv(from_object, ["psc_interface_config"]) is not None: + setv( + parent_object, + ["pscInterfaceConfig"], + getv(from_object, ["psc_interface_config"]), + ) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(parent_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["source_packages"]) is not None: + setv(parent_object, ["sourcePackages"], getv(from_object, ["source_packages"])) + + if getv(from_object, ["entrypoint_module"]) is not None: + setv( + parent_object, + ["entrypointModule"], + getv(from_object, ["entrypoint_module"]), + ) + + if getv(from_object, ["entrypoint_object"]) is not None: + setv( + parent_object, + ["entrypointObject"], + getv(from_object, ["entrypoint_object"]), + ) + + if getv(from_object, ["requirements_file"]) is not None: + setv( + parent_object, + ["requirementsFile"], + getv(from_object, ["requirements_file"]), + ) + + if getv(from_object, ["agent_framework"]) is not None: + setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"])) + + if getv(from_object, ["python_version"]) is not None: + setv(parent_object, ["pythonVersion"], getv(from_object, ["python_version"])) + + if getv(from_object, ["agent_gateway_config"]) is not None: + setv( + parent_object, + ["agentGatewayConfig"], + getv(from_object, ["agent_gateway_config"]), + ) + + return to_object + + +def _CreateAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + _CreateAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _DeleteAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = 
{} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["force"]) is not None: + setv(to_object, ["force"], getv(from_object, ["force"])) + + return to_object + + +def _GetAgentEngineOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ListAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + _ListAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _ListReasoningEnginesResponse_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["sdkHttpResponse"]) is not None: + setv(to_object, ["sdk_http_response"], getv(from_object, ["sdkHttpResponse"])) + + if getv(from_object, ["nextPageToken"]) is not None: + setv(to_object, ["next_page_token"], getv(from_object, ["nextPageToken"])) + + if getv(from_object, ["reasoningEngines"]) is not None: + setv( + to_object, + ["reasoning_engines"], + [ + _ReasoningEngine_from_vertex(item, to_object) + for item in getv(from_object, ["reasoningEngines"]) + ], + ) + + return to_object + + +def _QueryAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["class_method"]) is not None: + setv(parent_object, ["classMethod"], getv(from_object, ["class_method"])) + + if getv(from_object, ["input"]) is not None: + setv(parent_object, ["input"], getv(from_object, ["input"])) + + if getv(from_object, ["include_all_fields"]) is not None: + setv(to_object, ["includeAllFields"], getv(from_object, ["include_all_fields"])) + + return to_object + + +def _QueryAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + 
if getv(from_object, ["config"]) is not None: + _QueryAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _ReasoningEngineContextSpecMemoryBankConfig_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["customizationConfigs"]) is not None: + setv( + to_object, + ["customization_configs"], + [item for item in getv(from_object, ["customizationConfigs"])], + ) + + if getv(from_object, ["disableMemoryRevisions"]) is not None: + setv( + to_object, + ["disable_memory_revisions"], + getv(from_object, ["disableMemoryRevisions"]), + ) + + if getv(from_object, ["generationConfig"]) is not None: + setv(to_object, ["generation_config"], getv(from_object, ["generationConfig"])) + + if getv(from_object, ["similaritySearchConfig"]) is not None: + setv( + to_object, + ["similarity_search_config"], + getv(from_object, ["similaritySearchConfig"]), + ) + + if getv(from_object, ["ttlConfig"]) is not None: + setv(to_object, ["ttl_config"], getv(from_object, ["ttlConfig"])) + + if getv(from_object, ["structuredMemoryConfigs"]) is not None: + setv( + to_object, + ["structured_memory_configs"], + [ + _StructuredMemoryConfig_from_vertex(item, to_object) + for item in getv(from_object, ["structuredMemoryConfigs"]) + ], + ) + + return to_object + + +def _ReasoningEngineContextSpecMemoryBankConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["customization_configs"]) is not None: + setv( + to_object, + ["customizationConfigs"], + [item for item in getv(from_object, ["customization_configs"])], + ) + + if getv(from_object, ["disable_memory_revisions"]) is not None: + setv( + to_object, + ["disableMemoryRevisions"], + getv(from_object, ["disable_memory_revisions"]), + ) + + if getv(from_object, ["generation_config"]) is not None: + setv(to_object, ["generationConfig"], getv(from_object, ["generation_config"])) + + if getv(from_object, ["similarity_search_config"]) is not None: + setv( + to_object, + ["similaritySearchConfig"], + getv(from_object, ["similarity_search_config"]), + ) + + if getv(from_object, ["ttl_config"]) is not None: + setv(to_object, ["ttlConfig"], getv(from_object, ["ttl_config"])) + + if getv(from_object, ["structured_memory_configs"]) is not None: + setv( + to_object, + ["structuredMemoryConfigs"], + [ + _StructuredMemoryConfig_to_vertex(item, to_object) + for item in getv(from_object, ["structured_memory_configs"]) + ], + ) + + return to_object + + +def _ReasoningEngineContextSpec_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["memoryBankConfig"]) is not None: + setv( + to_object, + ["memory_bank_config"], + _ReasoningEngineContextSpecMemoryBankConfig_from_vertex( + getv(from_object, ["memoryBankConfig"]), to_object + ), + ) + + return to_object + + +def _ReasoningEngineContextSpec_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["memory_bank_config"]) is not None: + setv( + to_object, + ["memoryBankConfig"], + _ReasoningEngineContextSpecMemoryBankConfig_to_vertex( + getv(from_object, 
["memory_bank_config"]), to_object + ), + ) + + return to_object + + +def _ReasoningEngine_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["encryptionSpec"]) is not None: + setv(to_object, ["encryption_spec"], getv(from_object, ["encryptionSpec"])) + + if getv(from_object, ["contextSpec"]) is not None: + setv( + to_object, + ["context_spec"], + _ReasoningEngineContextSpec_from_vertex( + getv(from_object, ["contextSpec"]), to_object + ), + ) + + if getv(from_object, ["createTime"]) is not None: + setv(to_object, ["create_time"], getv(from_object, ["createTime"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["displayName"]) is not None: + setv(to_object, ["display_name"], getv(from_object, ["displayName"])) + + if getv(from_object, ["etag"]) is not None: + setv(to_object, ["etag"], getv(from_object, ["etag"])) + + if getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["spec"]) is not None: + setv(to_object, ["spec"], getv(from_object, ["spec"])) + + if getv(from_object, ["updateTime"]) is not None: + setv(to_object, ["update_time"], getv(from_object, ["updateTime"])) + + if getv(from_object, ["trafficConfig"]) is not None: + setv(to_object, ["traffic_config"], getv(from_object, ["trafficConfig"])) + + return to_object + + +def _RunQueryJobAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["input_gcs_uri"]) is not None: + setv(parent_object, ["inputGcsUri"], getv(from_object, ["input_gcs_uri"])) + + if getv(from_object, ["output_gcs_uri"]) is not None: + setv(parent_object, ["outputGcsUri"], getv(from_object, ["output_gcs_uri"])) + + return to_object + + +def _RunQueryJobAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _RunQueryJobAgentEngineConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _StructuredMemoryConfig_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["schemaConfigs"]) is not None: + setv( + to_object, + ["schema_configs"], + [ + _StructuredMemorySchemaConfig_from_vertex(item, to_object) + for item in getv(from_object, ["schemaConfigs"]) + ], + ) + + if getv(from_object, ["scopeKeys"]) is not None: + setv(to_object, ["scope_keys"], getv(from_object, ["scopeKeys"])) + + return to_object + + +def _StructuredMemoryConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["schema_configs"]) is not None: + setv( + to_object, + ["schemaConfigs"], + [ + 
_StructuredMemorySchemaConfig_to_vertex(item, to_object) + for item in getv(from_object, ["schema_configs"]) + ], + ) + + if getv(from_object, ["scope_keys"]) is not None: + setv(to_object, ["scopeKeys"], getv(from_object, ["scope_keys"])) + + return to_object + + +def _StructuredMemorySchemaConfig_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["schema"]) is not None: + setv(to_object, ["memory_schema"], getv(from_object, ["schema"])) + + if getv(from_object, ["id"]) is not None: + setv(to_object, ["id"], getv(from_object, ["id"])) + + if getv(from_object, ["memoryType"]) is not None: + setv(to_object, ["memory_type"], getv(from_object, ["memoryType"])) + + return to_object + + +def _StructuredMemorySchemaConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["memory_schema"]) is not None: + setv(to_object, ["schema"], getv(from_object, ["memory_schema"])) + + if getv(from_object, ["id"]) is not None: + setv(to_object, ["id"], getv(from_object, ["id"])) + + if getv(from_object, ["memory_type"]) is not None: + setv(to_object, ["memoryType"], getv(from_object, ["memory_type"])) + + return to_object + + +def _UpdateAgentEngineConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["spec"]) is not None: + setv(parent_object, ["spec"], getv(from_object, ["spec"])) + + if getv(from_object, ["context_spec"]) is not None: + setv( + parent_object, + ["contextSpec"], + _ReasoningEngineContextSpec_to_vertex( + getv(from_object, ["context_spec"]), to_object + ), + ) + + if getv(from_object, ["psc_interface_config"]) is not None: + setv( + parent_object, + ["pscInterfaceConfig"], + getv(from_object, ["psc_interface_config"]), + ) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(parent_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["source_packages"]) is not None: + setv(parent_object, ["sourcePackages"], getv(from_object, ["source_packages"])) + + if getv(from_object, ["entrypoint_module"]) is not None: + setv( + parent_object, + ["entrypointModule"], + getv(from_object, ["entrypoint_module"]), + ) + + if getv(from_object, ["entrypoint_object"]) is not None: + setv( + parent_object, + ["entrypointObject"], + getv(from_object, ["entrypoint_object"]), + ) + + if getv(from_object, ["requirements_file"]) is not None: + setv( + parent_object, + ["requirementsFile"], + getv(from_object, ["requirements_file"]), + ) + + if getv(from_object, ["agent_framework"]) is not None: + setv(parent_object, ["agentFramework"], getv(from_object, ["agent_framework"])) + + if getv(from_object, ["python_version"]) is not None: + setv(parent_object, ["pythonVersion"], getv(from_object, ["python_version"])) + + if getv(from_object, ["agent_gateway_config"]) is not None: + setv( + 
parent_object, + ["agentGatewayConfig"], + getv(from_object, ["agent_gateway_config"]), + ) + + if getv(from_object, ["update_mask"]) is not None: + setv( + parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) + ) + + if getv(from_object, ["traffic_config"]) is not None: + setv(parent_object, ["trafficConfig"], getv(from_object, ["traffic_config"])) + + return to_object + + +def _UpdateAgentEngineRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _UpdateAgentEngineConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +class AgentEngines(_api_module.BaseModule): + + def cancel_query_job( + self, + *, + name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + """ + Cancels a long-running query job on an Agent Engine. + + Args: + name (str): + Required. The reasoning engine resource name. + config (CancelQueryJobAgentEngineConfigOrDict): + Optional. The configuration for the cancel_query_job. + + """ + + parameter_model = types._CancelQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CancelQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:cancelAsyncQuery".format_map(request_url_dict) + else: + path = "{name}:cancelAsyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.CancelQueryJobResult._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _check_query_job( + self, + *, + name: str, + config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CheckQueryJobResult: + """ + Query an Agent Engine asynchronously. 
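+
+        More precisely, this checks the status of a query job that was
+        started earlier (e.g. via `_run_query_job`), reporting whether it
+        is still running, has succeeded, or has failed.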
+ """ + + parameter_model = types._CheckQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CheckQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:checkQueryJob".format_map(request_url_dict) + else: + path = "{name}:checkQueryJob" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CheckQueryJobResult_from_vertex(response_dict) + + return_value = types.CheckQueryJobResult._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _run_query_job( + self, + *, + name: str, + config: Optional[types._RunQueryJobAgentEngineConfigOrDict] = None, + ) -> types.AgentEngineOperation: + """ + Run a query job on an agent engine. + """ + + parameter_model = types._RunQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RunQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:asyncQuery".format_map(request_url_dict) + else: + path = "{name}:asyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
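+        # `config` carries client-side options only (e.g. `http_options` and
+        # the response schemas read below); it must not be serialized into
+        # the request body, hence the pop.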
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create( + self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """ + Creates a new Agent Engine. + """ + + parameter_model = types._CreateAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """ + Delete an Agent Engine resource. + + Args: + name (str): + Required. The name of the Agent Engine to be deleted. 
Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): + Optional. If set to True, child resources will also be deleted. + Otherwise, the request will fail with FAILED_PRECONDITION error when + the Agent Engine has undeleted child resources. Defaults to False. + config (DeleteAgentEngineConfig): + Optional. Additional configurations for deleting the Agent Engine. + + """ + + parameter_model = types._DeleteAgentEngineRequestParameters( + name=name, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get( + self, *, name: str, config: Optional[types.GetAgentEngineConfigOrDict] = None + ) -> types.ReasoningEngine: + """ + Get an Agent Engine instance. + """ + + parameter_model = types._GetAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ReasoningEngine_from_vertex(response_dict) + + return_value = types.ReasoningEngine._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> types.ListReasoningEnginesResponse: + """ + Lists Agent Engines. + """ + + parameter_model = types._ListAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
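+        # Pagination and filtering (`pageSize`, `pageToken`, `filter`) travel
+        # as URL query parameters, assembled into `path` above.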
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ListReasoningEnginesResponse_from_vertex(response_dict) + + return_value = types.ListReasoningEnginesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_agent_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineOperation: + parameter_model = types._GetAgentEngineOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _query( + self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None + ) -> types.QueryReasoningEngineResponse: + """ + Query an Agent Engine. 
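+
+        Example (illustrative; the resource name and class method below are
+        placeholders, not values from your project):
+
+        .. code-block:: python
+
+            response = client.agent_engines._query(
+                name="projects/my-project/locations/us-central1/reasoningEngines/123",
+                config={
+                    "class_method": "query",
+                    "input": {"input": "What is the exchange rate?"},
+                },
+            )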
+ """ + + parameter_model = types._QueryAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _update( + self, *, name: str, config: Optional[types.UpdateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """ + Updates an Agent Engine. + """ + + parameter_model = types._UpdateAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
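+        # The update is sent as a PATCH request; the `updateMask` query
+        # parameter (derived from `config.update_mask`) controls which
+        # fields are modified.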
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("patch", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + _a2a_tasks = None + _memories = None + _sandboxes = None + _sessions = None + _runtimes = None + + @property + def runtimes(self) -> "runtimes_module.Runtimes": + if self._runtimes is None: + try: + # We need to lazy load the runtimes module to handle the + # possibility of ImportError when dependencies are not installed. + self._runtimes = importlib.import_module(".runtimes", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.runtimes' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._runtimes.Runtimes(self._api_client) # type: ignore[no-any-return] + + @property + def a2a_tasks(self) -> "a2a_tasks_module.A2aTasks": + if self._a2a_tasks is None: + try: + # We need to lazy load the a2a_tasks module to handle the + # possibility of ImportError when dependencies are not installed. + self._a2a_tasks = importlib.import_module(".a2a_tasks", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.a2a_tasks' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._a2a_tasks.A2aTasks(self._api_client) # type: ignore[no-any-return] + + @property + def memories(self) -> "memories_module.Memories": + if self._memories is None: + try: + # We need to lazy load the memories module to handle the + # possibility of ImportError when dependencies are not installed. + self._memories = importlib.import_module(".memories", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.memories' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._memories.Memories(self._api_client) # type: ignore[no-any-return] + + @property + def sandboxes(self) -> Any: + if self._sandboxes is None: + try: + # We need to lazy load the sandboxes module to handle the + # possibility of ImportError when dependencies are not installed. + self._sandboxes = importlib.import_module(".sandboxes", __package__) + except ImportError as e: + raise ImportError( + "The agent_engines.sandboxes module requires additional packages. 
" + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._sandboxes.Sandboxes(self._api_client) + + @property + def sessions(self) -> "sessions_module.Sessions": + if self._sessions is None: + try: + # We need to lazy load the sessions module to handle the + # possibility of ImportError when dependencies are not installed. + self._sessions = importlib.import_module(".sessions", __package__) + except ImportError as e: + raise ImportError( + "The agent_engines.sessions module requires additional packages. " + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._sessions.Sessions(self._api_client) # type: ignore[no-any-return] + + def _list_pager( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> Pager[types.ReasoningEngine]: + return Pager( + "reasoning_engines", + self._list, + self._list(config=config), + config, + ) + + def check_query_job( + self, + *, + name: str, + config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CheckQueryJobResult: + """Checks a query job on an agent engine and optionally returns the results. + + Args: + name (str): + Required. A fully-qualified resource name or ID. + config (CheckQueryJobAgentEngineConfigOrDict): + Optional. The configuration for the check_query_job. If not provided, + the default configuration will be used. This can be used to specify + the following fields: + - retrieve_result: Whether to retrieve the results of the query job. + """ + from google.cloud import storage # type: ignore[attr-defined] + import json + + if config is None: + config = types.CheckQueryJobAgentEngineConfig() + elif isinstance(config, dict): + config = types.CheckQueryJobAgentEngineConfig(**config) + + raw_response = self._api_client.request("get", name, {}) + if hasattr(raw_response, "body"): + operation = ( + json.loads(raw_response.body) + if isinstance(raw_response.body, str) + else raw_response.body + ) + else: + operation = raw_response + + status = "RUNNING" + if isinstance(operation, dict): + if operation.get("done"): + status = "FAILED" if operation.get("error") else "SUCCESS" + + response_dict = operation.get("response", {}) + output_gcs_uri = response_dict.get("outputGcsUri") or response_dict.get( + "output_gcs_uri" + ) + error = operation.get("error") + else: + if getattr(operation, "done", False): + status = "FAILED" if getattr(operation, "error", None) else "SUCCESS" + + response_obj = getattr(operation, "response", None) + if isinstance(response_obj, dict): + output_gcs_uri = response_obj.get("outputGcsUri") or response_obj.get( + "output_gcs_uri" + ) + else: + output_gcs_uri = ( + getattr( + response_obj, + "output_gcs_uri", + getattr(response_obj, "outputGcsUri", None), + ) + if response_obj + else None + ) + error = getattr(operation, "error", None) + + result_str = None + if status == "SUCCESS" and config.retrieve_result and output_gcs_uri: + storage_client = storage.Client( + project=self._api_client.project, + credentials=self._api_client._credentials, + ) + bucket_name = output_gcs_uri.replace("gs://", "").split("/")[0] + blob_name = output_gcs_uri.replace(f"gs://{bucket_name}/", "") + bucket = storage_client.bucket(bucket_name) + blob = bucket.blob(blob_name) + if blob.exists(): + result_str = blob.download_as_string().decode("utf-8") + else: + raise ValueError( + f"Failed to retrieve blob results for {output_gcs_uri}" + ) + + elif status == "FAILED" and error: + result_str = 
str(error) + + return types.CheckQueryJobResult( + operation_name=name, + output_gcs_uri=output_gcs_uri, + status=status, + result=result_str, + ) + + def _is_lightweight_creation( + self, agent: Any, config: types.AgentEngineConfig + ) -> bool: + if ( + agent + or config.source_packages + or config.developer_connect_source + or config.agent_config_source + or config.container_spec + ): + return False + return True + + def run_query_job( + self, + *, + name: str, + config: Optional[types.RunQueryJobAgentEngineConfigOrDict] = None, + ) -> types.RunQueryJobResult: + """Launches a long-running query job on an Agent Engine + + Args: + name (str): + Required. A fully-qualified resource name or ID. + config (RunQueryJobAgentEngineConfigOrDict): + Optional. The configuration for the async query. If not provided, + the default configuration will be used. This can be used to specify + the following fields: + - query: The query to send to the agent engine. + - output_gcs_uri: The GCS URI to use for the output. + """ + from google.cloud import storage # type: ignore[attr-defined] + from google.api_core import exceptions + import uuid + + if config is None: + config = types.RunQueryJobAgentEngineConfig() + elif isinstance(config, dict): + config = types.RunQueryJobAgentEngineConfig(**config) + + if not config.query: + raise ValueError("`query` is required in the config object.") + if not config.output_gcs_uri: + raise ValueError("`output_gcs_uri` is required in the config object.") + + output_gcs_uri = config.output_gcs_uri + is_file = False + last_part = "" + if not output_gcs_uri.endswith("/"): + last_part = output_gcs_uri.split("/")[-1] + if "." in last_part: + is_file = True + + if is_file: + path_parts = output_gcs_uri.split("/") + file_name = path_parts[-1] + base_uri = "/".join(path_parts[:-1]) + name_parts = file_name.rsplit(".", 1) + if len(name_parts) == 2: + name_part, ext = name_parts[0], "." + name_parts[1] + else: + name_part = name_parts[0] + ext = "" + input_gcs_uri = f"{base_uri}/{name_part}_input{ext}" + else: + job_uuid = uuid.uuid4().hex + gcs_path = output_gcs_uri.rstrip("/") + input_gcs_uri = f"{gcs_path}/{job_uuid}_input.json" + output_gcs_uri = f"{gcs_path}/{job_uuid}_output.json" + + storage_client = storage.Client( + project=self._api_client.project, credentials=self._api_client._credentials + ) + + # Handle creating the bucket if it does not exist + bucket_name = config.output_gcs_uri.replace("gs://", "").split("/")[0] + bucket = storage_client.bucket(bucket_name) + + try: + bucket_exists = bucket.exists() + except exceptions.Forbidden as e: + raise ValueError( + f"Permission denied to check existence of bucket '{bucket_name}'. " + "The service account may lack 'storage.buckets.get' permission." + ) from e + + if not bucket_exists: + try: + bucket.create() + except exceptions.Forbidden as e: + raise ValueError( + f"Permission denied to create bucket '{bucket_name}'. " + "The service account may lack 'storage.buckets.create' permission." 
+ ) from e + + input_blob_name = input_gcs_uri.replace(f"gs://{bucket_name}/", "") + blob = bucket.blob(input_blob_name) + blob.upload_from_string(config.query) + + new_config = types._RunQueryJobAgentEngineConfig( + input_gcs_uri=input_gcs_uri, + output_gcs_uri=output_gcs_uri, + ) + + # Proceed with sending the async query via the auto-generated method + operation = self._run_query_job(name=name, config=new_config) + + return types.RunQueryJobResult( + job_name=operation.name, + input_gcs_uri=input_gcs_uri, + output_gcs_uri=output_gcs_uri, + ) + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineConfigOrDict] = None, + ) -> types.AgentEngine: + """Gets an agent engine. + + Args: + name (str): + Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456" or + a shortened name such as "reasoningEngines/456". + """ + api_resource = self._get(name=name, config=config) + agent_engine = types.AgentEngine( + api_client=self, + api_async_client=AsyncAgentEngines(api_client_=self._api_client), + api_resource=api_resource, + ) + if api_resource.spec: + self._register_api_methods(agent_engine=agent_engine) + return agent_engine + + def delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """ + Delete an Agent Engine resource. + + Args: + name (str): + Required. The name of the Agent Engine to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): + Optional. If set to True, child resources will also be deleted. + Otherwise, the request will fail with FAILED_PRECONDITION error when + the Agent Engine has undeleted child resources. Defaults to False. + config (DeleteAgentEngineConfig): + Optional. Additional configurations for deleting the Agent Engine. + + """ + logger.info(f"Deleting AgentEngine resource: {name}") + operation = self._delete(name=name, force=force, config=config) + logger.info(f"Started AgentEngine delete operation: {operation.name}") + return operation + + def create( + self, + *, + agent_engine: Any = None, + agent: Any = None, + config: Optional[types.AgentEngineConfigOrDict] = None, + ) -> types.AgentEngine: + """Creates an agent engine. + + The Agent Engine will be an instance of the `agent_engine` that + was passed in, running remotely on Vertex AI. + + Sample ``src_dir`` contents (e.g. ``./user_src_dir``): + + .. code-block:: python + + user_src_dir/ + |-- main.py + |-- requirements.txt + |-- user_code/ + | |-- utils.py + | |-- ... + |-- ... + + To build an Agent Engine with the above files, run: + + .. code-block:: python + + client = vertexai.Client( + project="your-project", + location="us-central1", + ) + remote_agent = client.agent_engines.create( + agent=local_agent, + config=dict( + requirements=[ + # I.e. the PyPI dependencies listed in requirements.txt + "google-cloud-aiplatform[agent_engines,adk]", + ... + ], + extra_packages=[ + "./user_src_dir/main.py", # a single file + "./user_src_dir/user_code", # a directory + ... + ], + ), + ) + + Args: + agent (Any): + Optional. The Agent to be created. If not specified, this will + correspond to a lightweight instance that cannot be queried + (but can be updated to future instances that can be queried). + agent_engine (Any): + Optional. This is deprecated. Please use `agent` instead. + config (AgentEngineConfig): + Optional. 
The configurations to use for creating the Agent Engine.
+
+        Returns:
+            AgentEngine: The created Agent Engine instance.
+
+        Raises:
+            ValueError: If the `project` was not set using `vertexai.Client`.
+            ValueError: If the `location` was not set using `vertexai.Client`.
+            ValueError: If `config.staging_bucket` was not set when `agent`
+                is specified.
+            ValueError: If `config.staging_bucket` does not start with "gs://".
+            ValueError: If `config.extra_packages` is specified but `agent`
+                is None.
+            ValueError: If `config.requirements` is specified but `agent` is None.
+            ValueError: If `config.env_vars` has a dictionary entry that does not
+                correspond to an environment variable value or a SecretRef.
+            TypeError: If `config.env_vars` is not a dictionary.
+            FileNotFoundError: If `config.extra_packages` includes a file or
+                directory that does not exist.
+            IOError: If `config.requirements` is a string that corresponds to a
+                nonexistent file.
+        """
+        if config is None:
+            config = {}
+        if isinstance(config, dict):
+            config = types.AgentEngineConfig.model_validate(config)
+        elif not isinstance(config, types.AgentEngineConfig):
+            raise TypeError(
+                f"config must be a dict or AgentEngineConfig, but got {type(config)}."
+            )
+        context_spec = config.context_spec
+        if context_spec is not None:
+            # Conversion to a dict for _create_config
+            context_spec = json.loads(context_spec.model_dump_json())
+        developer_connect_source = config.developer_connect_source
+        if developer_connect_source is not None:
+            developer_connect_source = json.loads(
+                developer_connect_source.model_dump_json()
+            )
+        agent_config_source = config.agent_config_source
+        if agent_config_source is not None:
+            agent_config_source = json.loads(agent_config_source.model_dump_json())
+        keep_alive_probe = config.keep_alive_probe
+        if keep_alive_probe is not None:
+            keep_alive_probe = json.loads(
+                keep_alive_probe.model_dump_json(exclude_none=True)
+            )
+        if agent and agent_engine:
+            raise ValueError("Please specify only one of `agent` or `agent_engine`.")
+        elif agent_engine:
+            # Warn rather than raise so the deprecated `agent_engine`
+            # argument keeps working via the fallback assignment below.
+            logger.warning(
+                "The `agent_engine` argument is deprecated. Please use `agent` instead."
+ ) + agent = agent or agent_engine + api_config = self._create_config( + mode="create", + agent=agent, + identity_type=config.identity_type, + staging_bucket=config.staging_bucket, + requirements=config.requirements, + display_name=config.display_name, + description=config.description, + gcs_dir_name=config.gcs_dir_name, + extra_packages=config.extra_packages, + env_vars=config.env_vars, + service_account=config.service_account, + context_spec=context_spec, + psc_interface_config=config.psc_interface_config, + agent_gateway_config=config.agent_gateway_config, + min_instances=config.min_instances, + max_instances=config.max_instances, + resource_limits=config.resource_limits, + container_concurrency=config.container_concurrency, + encryption_spec=config.encryption_spec, + agent_server_mode=config.agent_server_mode, + labels=config.labels, + class_methods=config.class_methods, + source_packages=config.source_packages, + developer_connect_source=developer_connect_source, + entrypoint_module=config.entrypoint_module, + entrypoint_object=config.entrypoint_object, + requirements_file=config.requirements_file, + agent_framework=config.agent_framework, + python_version=config.python_version, + build_options=config.build_options, + image_spec=config.image_spec, + agent_config_source=agent_config_source, + container_spec=config.container_spec, + keep_alive_probe=keep_alive_probe, + ) + operation = self._create(config=api_config) + reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id( + operation_name=operation.name + ) + logger.info( + "View progress and logs at https://console.cloud.google.com/logs/query?" + f"project={self._api_client.project}" + "&query=resource.type%3D%22aiplatform.googleapis.com%2FReasoningEngine%22%0A" + f"resource.labels.reasoning_engine_id%3D%22{reasoning_engine_id}%22." + ) + if not self._is_lightweight_creation(agent, config): + poll_interval_seconds = 10 + else: + poll_interval_seconds = 1 # Lightweight agent engine resource creation. + operation = _agent_engines_utils._await_operation( + operation_name=operation.name, + get_operation_fn=self._get_agent_operation, + poll_interval_seconds=poll_interval_seconds, + ) + + agent_engine = types.AgentEngine( + api_client=self, + api_async_client=AsyncAgentEngines(api_client_=self._api_client), + api_resource=operation.response, + ) + if agent_engine.api_resource: + logger.info("Agent Engine created. To use it in another session:") + logger.info( + f"agent_engine=client.agent_engines.get(name='{agent_engine.api_resource.name}')" + ) + elif operation.error: + raise RuntimeError(f"Failed to create Agent Engine: {operation.error}") + else: + logger.warning("The operation returned an empty response.") + if not self._is_lightweight_creation(agent, config): + # If the user did not provide an agent_engine (e.g. lightweight + # provisioning), it will not have any API methods registered. 
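+            # Registration binds the engine's class methods onto the returned
+            # object so it can be called directly (mirroring `get()`, which
+            # registers methods when `api_resource.spec` is present).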
+ agent_engine = self._register_api_methods(agent_engine=agent_engine) + return agent_engine # type: ignore[no-any-return] + + def _set_source_code_spec( + self, + *, + spec: types.ReasoningEngineSpecDict, + update_masks: list[str], + source_packages: Optional[Sequence[str]] = None, + developer_connect_source: Optional[ + types.ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict + ] = None, + class_methods: Optional[Sequence[dict[str, Any]]] = None, + entrypoint_module: Optional[str] = None, + entrypoint_object: Optional[str] = None, + requirements_file: Optional[str] = None, + sys_version: str, + build_options: Optional[dict[str, list[str]]] = None, + image_spec: Optional[ + types.ReasoningEngineSpecSourceCodeSpecImageSpecDict + ] = None, + agent_config_source: Optional[ + types.ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict + ] = None, + ) -> None: + """Sets source_code_spec for agent engine inside the `spec`.""" + source_code_spec = types.ReasoningEngineSpecSourceCodeSpecDict() + if source_packages and not agent_config_source: + source_packages = _agent_engines_utils._validate_packages_or_raise( + packages=source_packages, + build_options=build_options, + ) + update_masks.append("spec.source_code_spec.inline_source.source_archive") + source_code_spec["inline_source"] = { # type: ignore[typeddict-item] + "source_archive": _agent_engines_utils._create_base64_encoded_tarball( + source_packages=source_packages + ) + } + elif developer_connect_source: + update_masks.append("spec.source_code_spec.developer_connect_source") + source_code_spec["developer_connect_source"] = { + "config": developer_connect_source + } + elif not agent_config_source: + raise ValueError( + "Please specify one of `source_packages`, `developer_connect_source`, " + "or `agent_config_source`." + ) + if class_methods is not None: + update_masks.append("spec.class_methods") + class_methods_spec_list = ( + _agent_engines_utils._class_methods_to_class_methods_spec( + class_methods=class_methods + ) + ) + spec["class_methods"] = [ + _agent_engines_utils._to_dict(class_method_spec) + for class_method_spec in class_methods_spec_list + ] + elif image_spec is None: + raise ValueError( + "`class_methods` must be specified if `source_packages`, " + "`developer_connect_source`, or `agent_config_source` is " + "specified without a Dockerfile or `image_spec`." + ) + if image_spec is not None: + if entrypoint_module or entrypoint_object or requirements_file: + raise ValueError( + "`image_spec` cannot be specified alongside `entrypoint_module`, " + "`entrypoint_object`, or `requirements_file`, as they are " + "mutually exclusive." + ) + if agent_config_source: + raise ValueError( + "`image_spec` cannot be specified alongside `agent_config_source`, " + "as they are mutually exclusive." + ) + update_masks.append("spec.source_code_spec.image_spec") + source_code_spec["image_spec"] = image_spec + spec["source_code_spec"] = source_code_spec + return + + update_masks.append("spec.source_code_spec.python_spec.version") + python_spec: types.ReasoningEngineSpecSourceCodeSpecPythonSpecDict = { + "version": sys_version, + } + if agent_config_source is not None: + if entrypoint_module or entrypoint_object: + logger.warning( + "`entrypoint_module` and `entrypoint_object` are ignored when " + "`agent_config_source` is specified, as they are pre-defined." 
+                )
+ ) + if source_packages: + source_packages = _agent_engines_utils._validate_packages_or_raise( + packages=source_packages, + build_options=build_options, + ) + update_masks.append( + "spec.source_code_spec.agent_config_source.inline_source.source_archive" + ) + agent_config_source["inline_source"] = { # type: ignore[typeddict-item] + "source_archive": _agent_engines_utils._create_base64_encoded_tarball( + source_packages=source_packages + ) + } + update_masks.append("spec.source_code_spec.agent_config_source") + source_code_spec["agent_config_source"] = agent_config_source + + if requirements_file is not None: + update_masks.append( + "spec.source_code_spec.python_spec.requirements_file" + ) + python_spec["requirements_file"] = requirements_file + source_code_spec["python_spec"] = python_spec + + spec["source_code_spec"] = source_code_spec + return + + if not entrypoint_module: + raise ValueError( + "`entrypoint_module` must be specified if `source_packages` or `developer_connect_source` is specified." + ) + update_masks.append("spec.source_code_spec.python_spec.entrypoint_module") + python_spec["entrypoint_module"] = entrypoint_module + if not entrypoint_object: + raise ValueError( + "`entrypoint_object` must be specified if `source_packages` or `developer_connect_source` is specified." + ) + update_masks.append("spec.source_code_spec.python_spec.entrypoint_object") + python_spec["entrypoint_object"] = entrypoint_object + if requirements_file is not None: + update_masks.append("spec.source_code_spec.python_spec.requirements_file") + python_spec["requirements_file"] = requirements_file + source_code_spec["python_spec"] = python_spec + spec["source_code_spec"] = source_code_spec + + def _set_package_spec( + self, + *, + spec: types.ReasoningEngineSpecDict, + update_masks: list[str], + agent: Any, + staging_bucket: Optional[str] = None, + requirements: Optional[Union[str, Sequence[str]]] = None, + gcs_dir_name: Optional[str] = None, + extra_packages: Optional[Sequence[str]] = None, + class_methods: Optional[Sequence[dict[str, Any]]] = None, + sys_version: str, + build_options: Optional[dict[str, list[str]]] = None, + ) -> None: + """Sets package spec for agent engine.""" + project = self._api_client.project + if project is None: + raise ValueError("project must be set using `vertexai.Client`.") + location = self._api_client.location + if location is None: + raise ValueError("location must be set using `vertexai.Client`.") + gcs_dir_name = gcs_dir_name or _agent_engines_utils._DEFAULT_GCS_DIR_NAME + staging_bucket = _agent_engines_utils._validate_staging_bucket_or_raise( + staging_bucket=staging_bucket, + ) + requirements = _agent_engines_utils._validate_requirements_or_raise( + agent=agent, + requirements=requirements, + ) + extra_packages = _agent_engines_utils._validate_packages_or_raise( + packages=extra_packages, + build_options=build_options, + ) + # Prepares the Agent Engine for creation/update in Vertex AI. This + # involves packaging and uploading the artifacts for agent_engine, + # requirements and extra_packages to `staging_bucket/gcs_dir_name`. + _agent_engines_utils._prepare( + agent=agent, + requirements=requirements, + project=project, + location=location, + staging_bucket=staging_bucket, + gcs_dir_name=gcs_dir_name, + extra_packages=extra_packages, + credentials=self._api_client._credentials, + ) + # Update the package spec. 
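+        # The artifacts staged by `_prepare` above are referenced below via
+        # "{staging_bucket}/{gcs_dir_name}/<file>" GCS URIs.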
+ update_masks.append("spec.package_spec.pickle_object_gcs_uri") + package_spec: types.ReasoningEngineSpecPackageSpecDict = { + "python_version": sys_version, + "pickle_object_gcs_uri": "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines_utils._BLOB_FILENAME, + ), + } + if extra_packages: + update_masks.append("spec.package_spec.dependency_files_gcs_uri") + package_spec["dependency_files_gcs_uri"] = "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines_utils._EXTRA_PACKAGES_FILE, + ) + if requirements: + update_masks.append("spec.package_spec.requirements_gcs_uri") + package_spec["requirements_gcs_uri"] = "{}/{}/{}".format( + staging_bucket, + gcs_dir_name, + _agent_engines_utils._REQUIREMENTS_FILE, + ) + spec["package_spec"] = package_spec + + update_masks.append("spec.class_methods") + if class_methods is not None: + class_methods_spec_list = ( + _agent_engines_utils._class_methods_to_class_methods_spec( + class_methods=class_methods + ) + ) + else: + class_methods_spec_list = ( + _agent_engines_utils._generate_class_methods_spec_or_raise( + agent=agent, + operations=_agent_engines_utils._get_registered_operations( + agent=agent + ), + ) + ) + spec["class_methods"] = [ + _agent_engines_utils._to_dict(class_method_spec) + for class_method_spec in class_methods_spec_list + ] + + def _create_config( + self, + *, + mode: str, + agent: Any = None, + identity_type: Optional[types.IdentityType] = None, + staging_bucket: Optional[str] = None, + requirements: Optional[Union[str, Sequence[str]]] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + gcs_dir_name: Optional[str] = None, + extra_packages: Optional[Sequence[str]] = None, + env_vars: Optional[dict[str, Union[str, Any]]] = None, + service_account: Optional[str] = None, + context_spec: Optional[types.ReasoningEngineContextSpecDict] = None, + psc_interface_config: Optional[types.PscInterfaceConfigDict] = None, + agent_gateway_config: Optional[ + types.ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict + ] = None, + min_instances: Optional[int] = None, + max_instances: Optional[int] = None, + resource_limits: Optional[dict[str, str]] = None, + container_concurrency: Optional[int] = None, + encryption_spec: Optional[genai_types.EncryptionSpecDict] = None, + labels: Optional[dict[str, str]] = None, + agent_server_mode: Optional[types.AgentServerMode] = None, + class_methods: Optional[Sequence[dict[str, Any]]] = None, + source_packages: Optional[Sequence[str]] = None, + developer_connect_source: Optional[ + types.ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict + ] = None, + entrypoint_module: Optional[str] = None, + entrypoint_object: Optional[str] = None, + requirements_file: Optional[str] = None, + agent_framework: Optional[str] = None, + python_version: Optional[str] = None, + build_options: Optional[dict[str, list[str]]] = None, + image_spec: Optional[ + types.ReasoningEngineSpecSourceCodeSpecImageSpecDict + ] = None, + agent_config_source: Optional[ + types.ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict + ] = None, + container_spec: Optional[types.ReasoningEngineSpecContainerSpecDict] = None, + keep_alive_probe: Optional[dict[str, Any]] = None, + traffic_config: Optional[types.ReasoningEngineTrafficConfigDict] = None, + ) -> types.UpdateAgentEngineConfigDict: + import sys + + config: types.UpdateAgentEngineConfigDict = {} + update_masks = [] + if mode not in ["create", "update"]: + raise ValueError(f"Unsupported mode: {mode}") + if agent is None: + 
if requirements is not None: + raise ValueError("requirements must be None if agent is None.") + if extra_packages is not None: + raise ValueError("extra_packages must be None if agent is None.") + if display_name is not None: + update_masks.append("display_name") + config["display_name"] = display_name + if description is not None: + update_masks.append("description") + config["description"] = description + if context_spec is not None: + update_masks.append("context_spec") + config["context_spec"] = context_spec + if encryption_spec is not None: + update_masks.append("encryption_spec") + config["encryption_spec"] = encryption_spec + if labels is not None: + update_masks.append("labels") + config["labels"] = labels + if traffic_config is not None: + update_masks.append("traffic_config") + config["traffic_config"] = traffic_config + + if agent_framework == "google-adk": + env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) + + if python_version: + sys_version = python_version + else: + sys_version = f"{sys.version_info.major}.{sys.version_info.minor}" + + if agent: + if source_packages: + raise ValueError( + "If you have provided `source_packages` in `config`, please " + "do not specify `agent` in `agent_engines.create()` or " + "`agent_engines.update()`." + ) + if developer_connect_source: + raise ValueError( + "If you have provided `developer_connect_source` in `config`, please " + "do not specify `agent` in `agent_engines.create()` or " + "`agent_engines.update()`." + ) + elif source_packages and developer_connect_source: + raise ValueError( + "Please specify only one of `source_packages` or `developer_connect_source` in `config`." + ) + + if container_spec: + if agent: + raise ValueError( + "If you have provided `container_spec` in `config`, please " + "do not specify `agent` in `agent_engines.create()` or " + "`agent_engines.update()`." + ) + if source_packages or developer_connect_source: + raise ValueError( + "If you have provided `container_spec` in `config`, please " + "do not specify `source_packages` or `developer_connect_source` in `config`." 
+ ) + + agent_engine_spec: Any = None + if agent: + agent_engine_spec = {} + agent = _agent_engines_utils._validate_agent_or_raise(agent=agent) + if _agent_engines_utils._is_adk_agent(agent): + env_vars = _agent_engines_utils._add_telemetry_enablement_env(env_vars) + self._set_package_spec( + spec=agent_engine_spec, + update_masks=update_masks, + agent=agent, + staging_bucket=staging_bucket, + requirements=requirements, + gcs_dir_name=gcs_dir_name, + extra_packages=extra_packages, + class_methods=class_methods, + sys_version=sys_version, + build_options=build_options, + ) + elif ( + source_packages + or developer_connect_source + or image_spec + or agent_config_source + ): + agent_engine_spec = {} + self._set_source_code_spec( + spec=agent_engine_spec, + update_masks=update_masks, + source_packages=source_packages, + developer_connect_source=developer_connect_source, + class_methods=class_methods, + entrypoint_module=entrypoint_module, + entrypoint_object=entrypoint_object, + requirements_file=requirements_file, + sys_version=sys_version, + build_options=build_options, + image_spec=image_spec, + agent_config_source=agent_config_source, + ) + elif container_spec: + agent_engine_spec = {} + if class_methods is not None: + update_masks.append("spec.class_methods") + class_methods_spec_list = ( + _agent_engines_utils._class_methods_to_class_methods_spec( + class_methods=class_methods + ) + ) + agent_engine_spec["class_methods"] = [ + _agent_engines_utils._to_dict(class_method_spec) + for class_method_spec in class_methods_spec_list + ] + update_masks.append("spec.container_spec") + agent_engine_spec["container_spec"] = container_spec + + is_deployment_spec_updated = ( + env_vars is not None + or psc_interface_config is not None + or agent_gateway_config is not None + or min_instances is not None + or max_instances is not None + or resource_limits is not None + or container_concurrency is not None + or keep_alive_probe is not None + ) + if agent_engine_spec is None and is_deployment_spec_updated: + raise ValueError( + "To update `env_vars`, `psc_interface_config`, `min_instances`, " + "`max_instances`, `resource_limits`, `container_concurrency`, or " + "`keep_alive_probe`, you must also provide the `agent` variable or " + "the source code options (`source_packages`, " + "`developer_connect_source` or `agent_config_source`)." 
+ ) + + if agent_engine_spec is not None: + if is_deployment_spec_updated: + ( + deployment_spec, + deployment_update_masks, + ) = self._generate_deployment_spec_or_raise( + env_vars=env_vars, + psc_interface_config=psc_interface_config, + agent_gateway_config=agent_gateway_config, + min_instances=min_instances, + max_instances=max_instances, + resource_limits=resource_limits, + container_concurrency=container_concurrency, + keep_alive_probe=keep_alive_probe, + ) + update_masks.extend(deployment_update_masks) + agent_engine_spec["deployment_spec"] = deployment_spec + + if agent_server_mode: + if not agent_engine_spec.get("deployment_spec"): + agent_engine_spec["deployment_spec"] = ( + types.ReasoningEngineSpecDeploymentSpecDict() + ) + agent_engine_spec["deployment_spec"][ + "agent_server_mode" + ] = agent_server_mode + + agent_engine_spec["agent_framework"] = ( + _agent_engines_utils._get_agent_framework( + agent_framework=agent_framework, + agent=agent, + ) + ) + + if hasattr(agent, "agent_card"): + agent_card = getattr(agent, "agent_card") + if agent_card: + try: + from google.protobuf import json_format + + agent_engine_spec["agent_card"] = json_format.MessageToDict( + agent_card + ) + except Exception as e: + raise ValueError( + f"Failed to convert agent card to dict (serialization error): {e}" + ) from e + update_masks.append("spec.agent_framework") + + if identity_type is not None or service_account is not None: + if agent_engine_spec is None: + agent_engine_spec = {} + + if identity_type is not None: + agent_engine_spec["identity_type"] = identity_type + update_masks.append("spec.identity_type") + if service_account is not None: + # Clear the field in case of empty service_account. + if service_account: + agent_engine_spec["service_account"] = service_account + update_masks.append("spec.service_account") + + if agent_engine_spec is not None: + config["spec"] = agent_engine_spec + + if update_masks and mode == "update": + config["update_mask"] = ",".join(update_masks) + return config + + def _generate_deployment_spec_or_raise( + self, + *, + env_vars: Optional[dict[str, Union[str, Any]]] = None, + psc_interface_config: Optional[types.PscInterfaceConfigDict] = None, + agent_gateway_config: Optional[ + types.ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict + ] = None, + min_instances: Optional[int] = None, + max_instances: Optional[int] = None, + resource_limits: Optional[dict[str, str]] = None, + container_concurrency: Optional[int] = None, + keep_alive_probe: Optional[dict[str, Any]] = None, + ) -> Tuple[dict[str, Any], Sequence[str]]: + deployment_spec: dict[str, Any] = {} + update_masks = [] + if env_vars: + deployment_spec["env"] = [] + deployment_spec["secret_env"] = [] + if isinstance(env_vars, dict): + self._update_deployment_spec_with_env_vars_dict_or_raise( + deployment_spec=deployment_spec, + env_vars=env_vars, + ) + else: + raise TypeError(f"env_vars must be a dict, but got {type(env_vars)}.") + if deployment_spec.get("env"): + update_masks.append("spec.deployment_spec.env") + if deployment_spec.get("secret_env"): + update_masks.append("spec.deployment_spec.secret_env") + if psc_interface_config: + deployment_spec["psc_interface_config"] = psc_interface_config + update_masks.append("spec.deployment_spec.psc_interface_config") + if agent_gateway_config: + deployment_spec["agent_gateway_config"] = agent_gateway_config + update_masks.append("spec.deployment_spec.agent_gateway_config") + if min_instances is not None: + if not 0 <= min_instances <= 10: + raise 
ValueError(
+                    f"min_instances must be between 0 and 10. Got {min_instances}"
+                )
+            deployment_spec["min_instances"] = min_instances
+            update_masks.append("spec.deployment_spec.min_instances")
+        if max_instances is not None:
+            if psc_interface_config and not 1 <= max_instances <= 100:
+                raise ValueError(
+                    f"max_instances must be between 1 and 100 when PSC-I is enabled. Got {max_instances}"
+                )
+            elif not psc_interface_config and not 1 <= max_instances <= 1000:
+                raise ValueError(
+                    f"max_instances must be between 1 and 1000. Got {max_instances}"
+                )
+            deployment_spec["max_instances"] = max_instances
+            update_masks.append("spec.deployment_spec.max_instances")
+        if resource_limits:
+            _agent_engines_utils._validate_resource_limits_or_raise(
+                resource_limits=resource_limits
+            )
+            deployment_spec["resource_limits"] = resource_limits
+            update_masks.append("spec.deployment_spec.resource_limits")
+        if container_concurrency:
+            deployment_spec["container_concurrency"] = container_concurrency
+            update_masks.append("spec.deployment_spec.container_concurrency")
+        if keep_alive_probe is not None:
+            deployment_spec["keep_alive_probe"] = keep_alive_probe
+            update_masks.append("spec.deployment_spec.keep_alive_probe")
+        return deployment_spec, update_masks
+
+    def _update_deployment_spec_with_env_vars_dict_or_raise(
+        self,
+        *,
+        deployment_spec: dict[str, Any],
+        env_vars: dict[str, Any],
+    ) -> None:
+        for key, value in env_vars.items():
+            if isinstance(value, dict):
+                if "secret_env" not in deployment_spec:
+                    deployment_spec["secret_env"] = []
+                deployment_spec["secret_env"].append({"name": key, "secret_ref": value})
+            elif isinstance(value, str):
+                if "env" not in deployment_spec:
+                    deployment_spec["env"] = []
+                deployment_spec["env"].append({"name": key, "value": value})
+            else:
+                raise TypeError(
+                    f"Unknown value type in env_vars for {key}. "
+                    f"Must be a str or SecretRef: {value}"
+                )
+
+    def _register_api_methods(
+        self,
+        *,
+        agent_engine: types.AgentEngine,
+    ) -> types.AgentEngine:
+        """Registers the API methods for the agent engine."""
+        try:
+            _agent_engines_utils._register_api_methods_or_raise(
+                agent_engine=agent_engine,
+                wrap_operation_fn={
+                    "": _agent_engines_utils._wrap_query_operation,  # type: ignore[dict-item]
+                    "async": _agent_engines_utils._wrap_async_query_operation,  # type: ignore[dict-item]
+                    "stream": _agent_engines_utils._wrap_stream_query_operation,  # type: ignore[dict-item]
+                    "async_stream": _agent_engines_utils._wrap_async_stream_query_operation,  # type: ignore[dict-item]
+                    "a2a_extension": _agent_engines_utils._wrap_a2a_operation,
+                },
+            )
+        except Exception as e:
+            logger.warning(
+                _agent_engines_utils._FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE, e
+            )
+        return agent_engine
+
+    def list(
+        self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None
+    ) -> Iterator[types.AgentEngine]:
+        """Lists all Agent Engine instances matching the filter.
+
+        Example Usage:
+
+        .. code-block:: python
+
+            import vertexai
+
+            client = vertexai.Client(project="my_project", location="us-central1")
+            for agent in client.agent_engines.list(
+                config={"filter": 'display_name="My Custom Agent"'},
+            ):
+                print(agent.api_resource.name)
+
+        Args:
+            config (ListAgentEngineConfig):
+                Optional. The config (e.g. filter) for the agents to be listed.
+
+        Returns:
+            Iterable[AgentEngine]: An iterable of Agent Engines matching the filter.
+        """
+
+        for reasoning_engine in self._list_pager(config=config):
+            yield types.AgentEngine(
+                api_client=self,
+                api_async_client=AsyncAgentEngines(api_client_=self._api_client),
+                api_resource=reasoning_engine,
+            )
+
+    def update(
+        self,
+        *,
+        name: str,
+        agent: Any = None,
+        agent_engine: Any = None,
+        config: types.AgentEngineConfigOrDict,
+    ) -> types.AgentEngine:
+        """Updates an existing Agent Engine.
+
+        This method updates the configuration of an existing Agent Engine running
+        remotely, which is identified by its name.
+
+        Args:
+            name (str): Required. A fully-qualified resource name or ID such as
+                "projects/123/locations/us-central1/reasoningEngines/456" or a
+                shortened name such as "reasoningEngines/456".
+            agent (Any):
+                Optional. The instance to be used as the updated Agent Engine.
+                If it is not specified, the existing instance will be used.
+            agent_engine (Any):
+                Optional. This is deprecated. Please use `agent` instead.
+            config (AgentEngineConfig):
+                Required. The configuration to use for updating the Agent Engine.
+
+        Returns:
+            AgentEngine: The updated Agent Engine.
+
+        Raises:
+            ValueError: If the `project` was not set using `vertexai.Client`.
+            ValueError: If the `location` was not set using `vertexai.Client`.
+            ValueError: If `config.staging_bucket` was not set when `agent_engine`
+                is specified.
+            ValueError: If `config.staging_bucket` does not start with "gs://".
+            ValueError: If `config.extra_packages` is specified but `agent_engine`
+                is None.
+            ValueError: If `config.requirements` is specified but `agent_engine` is
+                None.
+            ValueError: If `config.env_vars` has a dictionary entry that does not
+                correspond to an environment variable value or a SecretRef.
+            TypeError: If `config.env_vars` is not a dictionary.
+            FileNotFoundError: If `config.extra_packages` includes a file or
+                directory that does not exist.
+            IOError: If `config.requirements` is a string that corresponds to a
+                nonexistent file.
+        """
+        if isinstance(config, dict):
+            config = types.AgentEngineConfig.model_validate(config)
+        elif not isinstance(config, types.AgentEngineConfig):
+            raise TypeError(
+                f"config must be a dict or AgentEngineConfig, but got {type(config)}."
+            )
+        context_spec = config.context_spec
+        if context_spec is not None:
+            # Conversion to a dict for _create_config
+            context_spec = json.loads(context_spec.model_dump_json())
+        developer_connect_source = config.developer_connect_source
+        if developer_connect_source is not None:
+            developer_connect_source = json.loads(
+                developer_connect_source.model_dump_json()
+            )
+        agent_config_source = config.agent_config_source
+        if agent_config_source is not None:
+            agent_config_source = json.loads(agent_config_source.model_dump_json())
+        keep_alive_probe = config.keep_alive_probe
+        if keep_alive_probe is not None:
+            keep_alive_probe = json.loads(
+                keep_alive_probe.model_dump_json(exclude_none=True)
+            )
+        traffic_config = config.traffic_config
+        if traffic_config is not None:
+            traffic_config = json.loads(traffic_config.model_dump_json())
+        if agent and agent_engine:
+            raise ValueError("Please specify only one of `agent` or `agent_engine`.")
+        elif agent_engine:
+            # Warn (rather than raise) so the deprecated argument keeps working
+            # via the `agent = agent or agent_engine` fallback below.
+            warnings.warn(
+                "The `agent_engine` argument is deprecated. Please use `agent` instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+        image_spec = config.image_spec
+        if image_spec is not None:
+            # Conversion to a dict for _create_config
+            image_spec = json.loads(image_spec.model_dump_json())
+        container_spec = config.container_spec
+        if container_spec is not None:
+            # Conversion to a dict for _create_config
+            container_spec = json.loads(container_spec.model_dump_json())
+        agent = agent or agent_engine
+        api_config = self._create_config(
+            mode="update",
+            agent=agent,
+            identity_type=config.identity_type,
+            staging_bucket=config.staging_bucket,
+            requirements=config.requirements,
+            display_name=config.display_name,
+            description=config.description,
+            gcs_dir_name=config.gcs_dir_name,
+            extra_packages=config.extra_packages,
+            env_vars=config.env_vars,
+            service_account=config.service_account,
+            context_spec=context_spec,
+            psc_interface_config=config.psc_interface_config,
+            agent_gateway_config=config.agent_gateway_config,
+            min_instances=config.min_instances,
+            max_instances=config.max_instances,
+            resource_limits=config.resource_limits,
+            container_concurrency=config.container_concurrency,
+            labels=config.labels,
+            class_methods=config.class_methods,
+            source_packages=config.source_packages,
+            developer_connect_source=developer_connect_source,
+            entrypoint_module=config.entrypoint_module,
+            entrypoint_object=config.entrypoint_object,
+            requirements_file=config.requirements_file,
+            agent_framework=config.agent_framework,
+            python_version=config.python_version,
+            build_options=config.build_options,
+            image_spec=image_spec,
+            agent_config_source=agent_config_source,
+            container_spec=container_spec,
+            keep_alive_probe=keep_alive_probe,
+            traffic_config=traffic_config,
+        )
+        operation = self._update(name=name, config=api_config)
+        reasoning_engine_id = _agent_engines_utils._get_reasoning_engine_id(
+            resource_name=name
+        )
+        logger.info(
+            "View progress and logs at https://console.cloud.google.com/logs/query?"
+            f"project={self._api_client.project}"
+            "&query=resource.type%3D%22aiplatform.googleapis.com%2FReasoningEngine%22%0A"
+            f"resource.labels.reasoning_engine_id%3D%22{reasoning_engine_id}%22."
+        )
+        operation = _agent_engines_utils._await_operation(
+            operation_name=operation.name,
+            get_operation_fn=self._get_agent_operation,
+        )
+        agent_engine = types.AgentEngine(
+            api_client=self,
+            api_async_client=AsyncAgentEngines(api_client_=self._api_client),
+            api_resource=operation.response,
+        )
+        if agent_engine.api_resource:
+            logger.info("Agent Engine updated. To use it in another session:")
+            logger.info(
+                f"agent_engine=client.agent_engines.get(name='{agent_engine.api_resource.name}')"
+            )
+        elif operation.error:
+            raise RuntimeError(f"Failed to update Agent Engine: {operation.error}")
+        if agent_engine.api_resource and agent_engine.api_resource.spec:
+            agent_engine = self._register_api_methods(agent_engine=agent_engine)
+        return agent_engine  # type: ignore[no-any-return]
+
+    def _stream_query(
+        self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None
+    ) -> Iterator[Any]:
+        """Streams the response of the agent engine."""
+        parameter_model = types._QueryAgentEngineRequestParameters(
+            name=name,
+            config=config,
+        )
+        request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model)
+        request_url_dict = request_dict.get("_url")
+        if request_url_dict:
+            path = "{name}:streamQuery?alt=sse".format_map(request_url_dict)
+        else:
+            path = "{name}:streamQuery?alt=sse"
+        query_params = request_dict.get("_query")
+        if query_params:
+            # The path already carries `?alt=sse`, so additional query
+            # parameters must be appended with `&` rather than a second `?`.
+            path = f"{path}&{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
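+        # The raw `config` entry must not be serialized into the request body;
+        # its `http_options` are read from `parameter_model` below, so drop
+        # the leftover key here.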
+        request_dict.pop("config", None)
+        http_options = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+        for response in self._api_client.request_streamed(
+            "post", path, request_dict, http_options
+        ):
+            yield response
+
+    # TODO: b/436704146 - Replace with generated methods
+    # TODO: b/437129724 - Add replay test for async stream query
+    async def _async_stream_query(
+        self,
+        *,
+        name: str,
+        config: Optional[types.QueryAgentEngineConfigOrDict] = None,
+    ) -> AsyncIterator[Any]:
+        """Streams the response of the agent engine asynchronously."""
+        parameter_model = types._QueryAgentEngineRequestParameters(
+            name=name,
+            config=config,
+        )
+        request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model)
+        request_url_dict = request_dict.get("_url")
+        if request_url_dict:
+            path = "{name}:streamQuery?alt=sse".format_map(request_url_dict)
+        else:
+            path = "{name}:streamQuery?alt=sse"
+        query_params = request_dict.get("_query")
+        if query_params:
+            # As in `_stream_query`, append extra query parameters with `&`
+            # because the path already contains `?alt=sse`.
+            path = f"{path}&{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+        http_options = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+        async_iterator = await self._api_client.async_request_streamed(
+            "post", path, request_dict, http_options
+        )
+        async for response in async_iterator:
+            yield response
+
+    def create_memory(
+        self,
+        *,
+        name: str,
+        fact: str,
+        scope: dict[str, str],
+        config: Optional[types.AgentEngineMemoryConfigOrDict] = None,
+    ) -> types.AgentEngineMemoryOperation:
+        """Deprecated. Use agent_engines.memories.create instead."""
+        warnings.warn(
+            (
+                "agent_engines.create_memory is deprecated. "
+                "Use agent_engines.memories.create instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.memories.create(
+            name=name,
+            fact=fact,
+            scope=scope,
+            config=config,
+        )
+
+    def delete_memory(
+        self,
+        *,
+        name: str,
+        config: Optional[types.DeleteAgentEngineMemoryConfigOrDict] = None,
+    ) -> types.DeleteAgentEngineMemoryOperation:
+        """Deprecated. Use agent_engines.memories.delete instead."""
+        warnings.warn(
+            (
+                "agent_engines.delete_memory is deprecated. "
+                "Use agent_engines.memories.delete instead."
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self.memories.delete(name=name, config=config)
+
+    def generate_memories(
+        self,
+        *,
+        name: str,
+        vertex_session_source: Optional[
+            types.GenerateMemoriesRequestVertexSessionSourceOrDict
+        ] = None,
+        direct_contents_source: Optional[
+            types.GenerateMemoriesRequestDirectContentsSourceOrDict
+        ] = None,
+        direct_memories_source: Optional[
+            types.GenerateMemoriesRequestDirectMemoriesSourceOrDict
+        ] = None,
+        scope: Optional[dict[str, str]] = None,
+        config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None,
+    ) -> types.AgentEngineGenerateMemoriesOperation:
+        """Deprecated. Use agent_engines.memories.generate instead."""
+        warnings.warn(
+            (
+                "agent_engines.generate_memories is deprecated. "
+                "Use agent_engines.memories.generate instead."
+ ), + DeprecationWarning, + stacklevel=2, + ) + return self.memories.generate( + name=name, + vertex_session_source=vertex_session_source, + direct_contents_source=direct_contents_source, + direct_memories_source=direct_memories_source, + scope=scope, + config=config, + ) + + def get_memory( + self, + *, + name: str, + config: Optional[types.GetAgentEngineMemoryConfigOrDict] = None, + ) -> types.Memory: + """Deprecated. Use agent_engines.memories.get instead.""" + warnings.warn( + ( + "agent_engines.get_memory is deprecated. " + "Use agent_engines.memories.get instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.memories.get(name=name, config=config) + + def list_memories( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None, + ) -> Iterator[types.Memory]: + """Deprecated. Use agent_engines.memories.list instead.""" + warnings.warn( + ( + "agent_engines.list_memories is deprecated. " + "Use agent_engines.memories.list instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.memories.list(name=name, config=config) + + def retrieve_memories( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> Iterator[types.RetrieveMemoriesResponseRetrievedMemory]: + """Deprecated. Use agent_engines.memories.retrieve instead.""" + warnings.warn( + ( + "agent_engines.retrieve_memories is deprecated. " + "Use agent_engines.memories.retrieve instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.memories.retrieve( + name=name, + scope=scope, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + config=config, + ) + + def create_session( + self, + *, + name: str, + user_id: str, + config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + """Deprecated. Use agent_engines.sessions.create instead.""" + warnings.warn( + ( + "agent_engines.create_session is deprecated. " + "Use agent_engines.sessions.create instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.create(name=name, user_id=user_id, config=config) + + def delete_session( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSessionConfigOrDict] = None, + ) -> types.DeleteAgentEngineSessionOperation: + """Deprecated. Use agent_engines.sessions.delete instead.""" + warnings.warn( + ( + "agent_engines.delete_session is deprecated. " + "Use agent_engines.sessions.delete instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.delete(name=name, config=config) + + def get_session( + self, + *, + name: str, + config: Optional[types.GetAgentEngineSessionConfigOrDict] = None, + ) -> types.Session: + """Deprecated. Use agent_engines.sessions.get instead.""" + warnings.warn( + ( + "agent_engines.get_session is deprecated. " + "Use agent_engines.sessions.get instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.get(name=name, config=config) + + def list_sessions( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None, + ) -> Iterator[types.Session]: + """Deprecated. 
Use agent_engines.sessions.list instead.""" + warnings.warn( + ( + "agent_engines.list_sessions is deprecated. " + "Use agent_engines.sessions.list instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.list(name=name, config=config) + + def append_session_event( + self, + *, + name: str, + author: str, + invocation_id: str, + timestamp: datetime.datetime, + config: Optional[types.AppendAgentEngineSessionEventConfigOrDict] = None, + ) -> types.AppendAgentEngineSessionEventResponse: + """Deprecated. Use agent_engines.sessions.events.append instead.""" + warnings.warn( + ( + "agent_engines.append_session_event is deprecated. " + "Use agent_engines.sessions.events.append instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.events.append( + name=name, + author=author, + invocation_id=invocation_id, + timestamp=timestamp, + config=config, + ) + + def list_session_events( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None, + ) -> Iterator[types.SessionEvent]: + """Deprecated. Use agent_engines.sessions.events.list instead.""" + warnings.warn( + ( + "agent_engines.list_session_events is deprecated. " + "Use agent_engines.sessions.events.list instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return self.sessions.events.list(name=name, config=config) + + +class AsyncAgentEngines(_api_module.BaseModule): + + async def cancel_query_job( + self, + *, + name: str, + config: Optional[types.CancelQueryJobAgentEngineConfigOrDict] = None, + ) -> types.CancelQueryJobResult: + """ + Cancels a long-running query job on an Agent Engine. + + Args: + name (str): + Required. The reasoning engine resource name. + config (CancelQueryJobAgentEngineConfigOrDict): + Optional. The configuration for the cancel_query_job. + + """ + + parameter_model = types._CancelQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CancelQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:cancelAsyncQuery".format_map(request_url_dict) + else: + path = "{name}:cancelAsyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "post", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.CancelQueryJobResult._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _check_query_job(
+        self,
+        *,
+        name: str,
+        config: Optional[types.CheckQueryJobAgentEngineConfigOrDict] = None,
+    ) -> types.CheckQueryJobResult:
+        """
+        Checks the status of an asynchronous query job on an Agent Engine.
+        """
+
+        parameter_model = types._CheckQueryJobAgentEngineRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _CheckQueryJobAgentEngineRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}:checkQueryJob".format_map(request_url_dict)
+            else:
+                path = "{name}:checkQueryJob"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "post", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        if self._api_client.vertexai:
+            response_dict = _CheckQueryJobResult_from_vertex(response_dict)
+
+        return_value = types.CheckQueryJobResult._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _run_query_job(
+        self,
+        *,
+        name: str,
+        config: Optional[types._RunQueryJobAgentEngineConfigOrDict] = None,
+    ) -> types.AgentEngineOperation:
+        """
+        Runs a query job on an Agent Engine.
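+
+        Args:
+            name (str):
+                Required. The reasoning engine resource name.
+            config (_RunQueryJobAgentEngineConfigOrDict):
+                Optional. The configuration for the query job request.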
+ """ + + parameter_model = types._RunQueryJobAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RunQueryJobAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:asyncQuery".format_map(request_url_dict) + else: + path = "{name}:asyncQuery" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create( + self, *, config: Optional[types.CreateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """ + Creates a new Agent Engine. + """ + + parameter_model = types._CreateAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """ + Delete an Agent Engine resource. + + Args: + name (str): + Required. The name of the Agent Engine to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): + Optional. If set to True, child resources will also be deleted. + Otherwise, the request will fail with FAILED_PRECONDITION error when + the Agent Engine has undeleted child resources. Defaults to False. + config (DeleteAgentEngineConfig): + Optional. Additional configurations for deleting the Agent Engine. + + """ + + parameter_model = types._DeleteAgentEngineRequestParameters( + name=name, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get( + self, *, name: str, config: Optional[types.GetAgentEngineConfigOrDict] = None + ) -> types.ReasoningEngine: + """ + Get an Agent Engine instance. + """ + + parameter_model = types._GetAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ReasoningEngine_from_vertex(response_dict) + + return_value = types.ReasoningEngine._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, *, config: Optional[types.ListAgentEngineConfigOrDict] = None + ) -> types.ListReasoningEnginesResponse: + """ + Lists Agent Engines. 
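+
+        Returns the raw `ListReasoningEnginesResponse` for a single request.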
+ """ + + parameter_model = types._ListAgentEngineRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "reasoningEngines".format_map(request_url_dict) + else: + path = "reasoningEngines" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ListReasoningEnginesResponse_from_vertex(response_dict) + + return_value = types.ListReasoningEnginesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_agent_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineOperation: + parameter_model = types._GetAgentEngineOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _query( + self, *, name: str, config: Optional[types.QueryAgentEngineConfigOrDict] = None + ) -> types.QueryReasoningEngineResponse: + """ + Query an Agent Engine. + """ + + parameter_model = types._QueryAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _QueryAgentEngineRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update( + self, *, name: str, config: Optional[types.UpdateAgentEngineConfigOrDict] = None + ) -> types.AgentEngineOperation: + """ + Updates an Agent Engine. 
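+
+        Args:
+            name (str):
+                Required. The resource name of the Agent Engine to update.
+            config (UpdateAgentEngineConfigOrDict):
+                Optional. The update config, including the `update_mask` and
+                the new `spec`.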
+ """ + + parameter_model = types._UpdateAgentEngineRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateAgentEngineRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _AgentEngineOperation_from_vertex(response_dict) + + return_value = types.AgentEngineOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + _a2a_tasks = None + _memories = None + _sessions = None + _runtimes = None + + async def delete( + self, + *, + name: str, + force: Optional[bool] = None, + config: Optional[types.DeleteAgentEngineConfigOrDict] = None, + ) -> types.DeleteAgentEngineOperation: + """ + Delete an Agent Engine resource. + + Args: + name (str): + Required. The name of the Agent Engine to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}` + or `reasoningEngines/{resource_id}`. + force (bool): + Optional. If set to True, child resources will also be deleted. + Otherwise, the request will fail with FAILED_PRECONDITION error when + the Agent Engine has undeleted child resources. Defaults to False. + config (DeleteAgentEngineConfig): + Optional. Additional configurations for deleting the Agent Engine. + + """ + logger.info(f"Deleting AgentEngine resource: {name}") + operation = await self._delete(name=name, force=force, config=config) + logger.info(f"Started AgentEngine delete operation: {operation.name}") + return operation + + @property + def runtimes(self) -> "runtimes_module.AsyncRuntimes": + if self._runtimes is None: + try: + # We need to lazy load the runtimes module to handle the + # possibility of ImportError when dependencies are not installed. + self._runtimes = importlib.import_module(".runtimes", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.runtimes' module requires additional " + "packages. 
Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._runtimes.AsyncRuntimes(self._api_client) # type: ignore[no-any-return] + + @property + def a2a_tasks(self) -> "a2a_tasks_module.AsyncA2aTasks": + if self._a2a_tasks is None: + try: + # We need to lazy load the a2a_tasks module to handle the + # possibility of ImportError when dependencies are not installed. + self._a2a_tasks = importlib.import_module(".a2a_tasks", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.a2a_tasks' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._a2a_tasks.AsyncA2aTasks(self._api_client) # type: ignore[no-any-return] + + @property + def memories(self) -> "memories_module.AsyncMemories": + if self._memories is None: + try: + # We need to lazy load the memories module to handle the + # possibility of ImportError when dependencies are not installed. + self._memories = importlib.import_module(".memories", __package__) + except ImportError as e: + raise ImportError( + "The 'agent_engines.memories' module requires additional " + "packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._memories.AsyncMemories(self._api_client) # type: ignore[no-any-return] + + @property + def sessions(self) -> "sessions_module.AsyncSessions": + if self._sessions is None: + try: + # We need to lazy load the sessions module to handle the + # possibility of ImportError when dependencies are not installed. + self._sessions = importlib.import_module(".sessions", __package__) + except ImportError as e: + raise ImportError( + "The agent_engines.sessions module requires additional packages. " + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._sessions.AsyncSessions(self._api_client) # type: ignore[no-any-return] + + async def append_session_event( + self, + *, + name: str, + author: str, + invocation_id: str, + timestamp: datetime.datetime, + config: Optional[types.AppendAgentEngineSessionEventConfigOrDict] = None, + ) -> types.AppendAgentEngineSessionEventResponse: + """Deprecated. Use agent_engines.sessions.events.append instead.""" + warnings.warn( + ( + "agent_engines.append_session_event is deprecated. " + "Use agent_engines.sessions.events.append instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return await self.sessions.events.append( + name=name, + author=author, + invocation_id=invocation_id, + timestamp=timestamp, + config=config, + ) + + async def delete_memory( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineMemoryConfigOrDict] = None, + ) -> types.DeleteAgentEngineMemoryOperation: + """Deprecated. Use agent_engines.memories.delete instead.""" + warnings.warn( + ( + "agent_engines.delete_memory is deprecated. " + "Use agent_engines.memories.delete instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return await self.memories.delete(name=name, config=config) + + async def delete_session( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSessionConfigOrDict] = None, + ) -> types.DeleteAgentEngineSessionOperation: + """Deprecated. Use agent_engines.sessions.delete instead.""" + warnings.warn( + ( + "agent_engines.delete_session is deprecated. " + "Use agent_engines.sessions.delete instead." 
+ ), + DeprecationWarning, + stacklevel=2, + ) + return await self.sessions.delete(name=name, config=config) + + async def get_memory( + self, + *, + name: str, + config: Optional[types.GetAgentEngineMemoryConfigOrDict] = None, + ) -> types.Memory: + """Deprecated. Use agent_engines.memories.get instead.""" + warnings.warn( + ( + "agent_engines.get_memory is deprecated. " + "Use agent_engines.memories.get instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return await self.memories.get(name=name, config=config) + + async def get_session( + self, + *, + name: str, + config: Optional[types.GetAgentEngineSessionConfigOrDict] = None, + ) -> types.Session: + """Deprecated. Use agent_engines.sessions.get instead.""" + warnings.warn( + ( + "agent_engines.get_session is deprecated. " + "Use agent_engines.sessions.get instead." + ), + DeprecationWarning, + stacklevel=2, + ) + return await self.sessions.get(name=name, config=config) diff --git a/agentplatform/_genai/client.py b/agentplatform/_genai/client.py new file mode 100644 index 0000000000..1289b4fde1 --- /dev/null +++ b/agentplatform/_genai/client.py @@ -0,0 +1,358 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib +from typing import Optional, Union, TYPE_CHECKING +from types import TracebackType, ModuleType + +import google.auth +from google.cloud.aiplatform import version as aip_version +from google.genai import _common +from google.genai import client as genai_client +from google.genai import types +from . import live + +if TYPE_CHECKING: + from agentplatform._genai import ( + agent_engines as agent_engines_module, + ) + from agentplatform._genai import datasets as datasets_module + from agentplatform._genai import evals as evals_module + from agentplatform._genai import ( + prompt_optimizer as prompt_optimizer_module, + ) + from agentplatform._genai import prompts as prompts_module + from agentplatform._genai import skills as skills_module + from agentplatform._genai import live as live_module + + +_GENAI_MODULES_TELEMETRY_HEADER = "vertex-genai-modules" + + +class AsyncClient: + """Async Gen AI Client for the Vertex SDK.""" + + def __init__(self, api_client: genai_client.BaseApiClient): # type: ignore[name-defined] + self._api_client = api_client + self._live = live.AsyncLive(self._api_client) + self._evals: Optional[ModuleType] = None + self._agent_engines: Optional[ModuleType] = None + self._prompt_optimizer: Optional[ModuleType] = None + self._prompts: Optional[ModuleType] = None + self._datasets: Optional[ModuleType] = None + self._skills: Optional[ModuleType] = None + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI live module is experimental, and may change in future " + "versions." 
+    )
+    def live(self) -> "live_module.AsyncLive":
+        return self._live
+
+    @property
+    def evals(self) -> "evals_module.AsyncEvals":
+        if self._evals is None:
+            try:
+                # We need to lazy load the evals module to avoid ImportError when
+                # pandas/tqdm are not installed.
+                self._evals = importlib.import_module(".evals", __package__)
+            except ImportError as e:
+                raise ImportError(
+                    "The 'evals' module requires 'pandas' and 'tqdm'. "
+                    "Please install them using pip install "
+                    "google-cloud-aiplatform[evaluation]"
+                ) from e
+        return self._evals.AsyncEvals(self._api_client)  # type: ignore[no-any-return]
+
+    @property
+    def prompt_optimizer(self) -> "prompt_optimizer_module.AsyncPromptOptimizer":
+        if self._prompt_optimizer is None:
+            self._prompt_optimizer = importlib.import_module(
+                ".prompt_optimizer", __package__
+            )
+        return self._prompt_optimizer.AsyncPromptOptimizer(self._api_client)  # type: ignore[no-any-return]
+
+    @property
+    def agent_engines(self) -> "agent_engines_module.AsyncAgentEngines":
+        if self._agent_engines is None:
+            try:
+                # We need to lazy load the agent_engines module to handle the
+                # possibility of ImportError when dependencies are not installed.
+                self._agent_engines = importlib.import_module(
+                    ".agent_engines",
+                    __package__,
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "The 'agent_engines' module requires additional packages. "
+                    "Please install them using pip install "
+                    "google-cloud-aiplatform[agent_engines]"
+                ) from e
+        return self._agent_engines.AsyncAgentEngines(self._api_client)  # type: ignore[no-any-return]
+
+    @property
+    def prompts(self) -> "prompts_module.AsyncPrompts":
+        if self._prompts is None:
+            self._prompts = importlib.import_module(
+                ".prompts",
+                __package__,
+            )
+        return self._prompts.AsyncPrompts(self._api_client)  # type: ignore[no-any-return]
+
+    @property
+    @_common.experimental_warning(
+        "The Vertex SDK GenAI async datasets module is experimental, "
+        "and may change in future versions."
+    )
+    def datasets(self) -> "datasets_module.AsyncDatasets":
+        if self._datasets is None:
+            self._datasets = importlib.import_module(
+                ".datasets",
+                __package__,
+            )
+        return self._datasets.AsyncDatasets(self._api_client)  # type: ignore[no-any-return]
+
+    @property
+    def skills(self) -> "skills_module.AsyncSkills":
+        if self._skills is None:
+            self._skills = importlib.import_module(
+                ".skills",
+                __package__,
+            )
+        return self._skills.AsyncSkills(self._api_client)  # type: ignore[no-any-return]
+
+    async def aclose(self) -> None:
+        """Closes the async client explicitly.
+
+        Example usage:
+
+            import agentplatform
+
+            async_client = agentplatform.Client(
+                project='my-project-id', location='us-central1'
+            ).aio
+            prompt_1 = await async_client.prompts.create(...)
+            prompt_2 = await async_client.prompts.create(...)
+            # Close the client to release resources.
+            await async_client.aclose()
+        """
+        await self._api_client.aclose()
+
+    async def __aenter__(self) -> "AsyncClient":
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        await self.aclose()
+
+    def __del__(self) -> None:
+        try:
+            asyncio.get_running_loop().create_task(self.aclose())
+        except Exception:
+            pass
+
+
+class Client:
+    """Gen AI Client for the Vertex SDK.
+
+    Use this client to interact with Vertex-specific Gemini features.
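+
+    A minimal usage sketch (illustrative only; assumes a Google Cloud project
+    with the Vertex AI API enabled and Application Default Credentials
+    configured):
+
+        import agentplatform
+
+        client = agentplatform.Client(
+            project='my-project-id', location='us-central1'
+        )
+        prompt = client.prompts.create(...)
+        # The async surface of the same client is available via `.aio`:
+        # await client.aio.prompts.create(...)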
+ """ + + def __init__( + self, + *, + api_key: Optional[str] = None, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[genai_client.DebugConfig] = None, + http_options: Optional[Union[types.HttpOptions, types.HttpOptionsDict]] = None, + ): + """Initializes the client. + + Args: + api_key (str): The `API key + `_ + to use for authentication. Applies to Vertex AI in express mode only. + credentials (google.auth.credentials.Credentials): The credentials to use + for authentication when calling the Vertex AI APIs. Credentials can be + obtained from environment variables and default credentials. For more + information, see `Set up Application Default Credentials + `_. + project (str): The `Google Cloud project ID + `_ to + use for quota. Can be obtained from environment variables (for example, + ``GOOGLE_CLOUD_PROJECT``). + location (str): The `location + `_ + to send API requests to (for example, ``us-central1``). Can be obtained + from environment variables. + debug_config (DebugConfig): Config settings that control network behavior + of the client. This is typically used when running test code. + http_options (Union[HttpOptions, HttpOptionsDict]): Http options to use + for the client. + """ + + self._debug_config = debug_config or genai_client.DebugConfig() + if isinstance(http_options, dict): + http_options = types.HttpOptions(**http_options) + if http_options is None: + http_options = types.HttpOptions() + if http_options.headers is None: + http_options.headers = {} + + tracking_label = f"{_GENAI_MODULES_TELEMETRY_HEADER}/{aip_version.__version__}" + + if "user-agent" in http_options.headers: + http_options.headers["user-agent"] = ( + f"{http_options.headers['user-agent']} {tracking_label}" + ) + else: + http_options.headers["user-agent"] = tracking_label + + if "x-goog-api-client" in http_options.headers: + http_options.headers["x-goog-api-client"] = ( + f"{http_options.headers['x-goog-api-client']} {tracking_label}" + ) + else: + http_options.headers["x-goog-api-client"] = tracking_label + + self._api_client = genai_client.Client._get_api_client( + vertexai=True, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + debug_config=self._debug_config, + http_options=http_options, + ) + self._aio = AsyncClient(self._api_client) + self._evals: Optional[ModuleType] = None + self._prompt_optimizer: Optional[ModuleType] = None + self._agent_engines: Optional[ModuleType] = None + self._prompts: Optional[ModuleType] = None + self._datasets: Optional[ModuleType] = None + self._skills: Optional[ModuleType] = None + + @property + def evals(self) -> "evals_module.Evals": + if self._evals is None: + try: + # We need to lazy load the evals module to avoid ImportError when + # pandas/tqdm are not installed. + self._evals = importlib.import_module(".evals", __package__) + except ImportError as e: + raise ImportError( + "The 'evals' module requires additional dependencies. 
" + "Please install them using pip install " + "google-cloud-aiplatform[evaluation]" + ) from e + return self._evals.Evals(self._api_client) # type: ignore[no-any-return] + + @property + def prompt_optimizer(self) -> "prompt_optimizer_module.PromptOptimizer": + if self._prompt_optimizer is None: + self._prompt_optimizer = importlib.import_module( + ".prompt_optimizer", __package__ + ) + return self._prompt_optimizer.PromptOptimizer(self._api_client) # type: ignore[no-any-return] + + @property + def aio(self) -> "AsyncClient": + return self._aio + + # This is only used for replay tests + @staticmethod + def _get_api_client( + api_key: Optional[str] = None, + credentials: Optional[google.auth.credentials.Credentials] = None, + project: Optional[str] = None, + location: Optional[str] = None, + debug_config: Optional[genai_client.DebugConfig] = None, + http_options: Optional[types.HttpOptions] = None, + ) -> Optional[genai_client.BaseApiClient]: # type: ignore[name-defined] + if debug_config and debug_config.client_mode in [ + "record", + "replay", + "auto", + ]: + return genai_client.ReplayApiClient( # type: ignore[attr-defined] + mode=debug_config.client_mode, + replay_id=debug_config.replay_id, + replays_directory=debug_config.replays_directory, + vertexai=True, + api_key=api_key, + credentials=credentials, + project=project, + location=location, + http_options=http_options, + ) + return None + + @property + def agent_engines(self) -> "agent_engines_module.AgentEngines": + if self._agent_engines is None: + try: + # We need to lazy load the agent_engines module to handle the + # possibility of ImportError when dependencies are not installed. + self._agent_engines = importlib.import_module( + ".agent_engines", + __package__, + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines' module requires 'additional packages'. " + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._agent_engines.AgentEngines(self._api_client) # type: ignore[no-any-return] + + @property + def prompts(self) -> "prompts_module.Prompts": + if self._prompts is None: + # Lazy loading the prompts module + self._prompts = importlib.import_module( + ".prompts", + __package__, + ) + return self._prompts.Prompts(self._api_client) # type: ignore[no-any-return] + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI datasets module is experimental, " + "and may change in future versions." + ) + def datasets(self) -> "datasets_module.Datasets": + if self._datasets is None: + self._datasets = importlib.import_module( + ".datasets", + __package__, + ) + return self._datasets.Datasets(self._api_client) # type: ignore[no-any-return] + + @property + def skills(self) -> "skills_module.Skills": + if self._skills is None: + self._skills = importlib.import_module( + ".skills", + __package__, + ) + return self._skills.Skills(self._api_client) # type: ignore[no-any-return] diff --git a/agentplatform/_genai/datasets.py b/agentplatform/_genai/datasets.py new file mode 100644 index 0000000000..ec0a8addd5 --- /dev/null +++ b/agentplatform/_genai/datasets.py @@ -0,0 +1,2799 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import asyncio +import json +import logging +import time +from typing import Any, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +import pandas as pd + +from . import _datasets_utils +from . import types + +logger = logging.getLogger("agentplatform_genai.datasets") + + +def _AssembleDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["gemini_request_read_config"]) is not None: + setv( + to_object, + ["geminiRequestReadConfig"], + getv(from_object, ["gemini_request_read_config"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _AssessDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["gemini_request_read_config"]) is not None: + setv( + to_object, + ["geminiRequestReadConfig"], + getv(from_object, ["gemini_request_read_config"]), + ) + + if getv(from_object, ["tuning_resource_usage_assessment_config"]) is not None: + setv( + to_object, + ["tuningResourceUsageAssessmentConfig"], + getv(from_object, ["tuning_resource_usage_assessment_config"]), + ) + + if getv(from_object, ["tuning_validation_assessment_config"]) is not None: + setv( + to_object, + ["tuningValidationAssessmentConfig"], + getv(from_object, ["tuning_validation_assessment_config"]), + ) + + if ( + getv(from_object, ["batch_prediction_resource_usage_assessment_config"]) + is not None + ): + setv( + to_object, + ["batchPredictionResourceUsageAssessmentConfig"], + getv(from_object, ["batch_prediction_resource_usage_assessment_config"]), + ) + + if getv(from_object, ["batch_prediction_validation_assessment_config"]) is not None: + setv( + to_object, + ["batchPredictionValidationAssessmentConfig"], + getv(from_object, ["batch_prediction_validation_assessment_config"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _CreateMultimodalDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], 
getv(from_object, ["display_name"])) + + if getv(from_object, ["metadata_schema_uri"]) is not None: + setv( + to_object, ["metadataSchemaUri"], getv(from_object, ["metadata_schema_uri"]) + ) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(to_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _DeleteMultimodalDatasetRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetMultimodalDatasetOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["operation_id"]) is not None: + setv(to_object, ["_url", "operation_id"], getv(from_object, ["operation_id"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetMultimodalDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _ListMultimodalDatasetsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListMultimodalDatasetsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _ListMultimodalDatasetsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _UpdateMultimodalDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, 
["name"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(to_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +class Datasets(_api_module.BaseModule): + + def _assemble_multimodal_dataset( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssembleDatasetConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Assembles a multimodal dataset resource. + """ + + parameter_model = types._AssembleDatasetParameters( + name=name, + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AssembleDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:assemble".format_map(request_url_dict) + else: + path = "{name}:assemble" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _assess_multimodal_dataset( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + tuning_resource_usage_assessment_config: Optional[ + types.TuningResourceUsageAssessmentConfigOrDict + ] = None, + tuning_validation_assessment_config: Optional[ + types.TuningValidationAssessmentConfigOrDict + ] = None, + batch_prediction_resource_usage_assessment_config: Optional[ + types.BatchPredictionResourceUsageAssessmentConfigOrDict + ] = None, + batch_prediction_validation_assessment_config: Optional[ + types.BatchPredictionValidationAssessmentConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> 
types.MultimodalDatasetOperation: + """ + Assesses a multimodal dataset resource. + """ + + parameter_model = types._AssessDatasetParameters( + name=name, + gemini_request_read_config=gemini_request_read_config, + tuning_resource_usage_assessment_config=tuning_resource_usage_assessment_config, + tuning_validation_assessment_config=tuning_validation_assessment_config, + batch_prediction_resource_usage_assessment_config=batch_prediction_resource_usage_assessment_config, + batch_prediction_validation_assessment_config=batch_prediction_validation_assessment_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AssessDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:assess".format_map(request_url_dict) + else: + path = "{name}:assess" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_multimodal_dataset( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + metadata_schema_uri: Optional[str] = None, + metadata: Optional[types.SchemaTablesDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Creates a dataset resource to store multimodal datasets. + """ + + parameter_model = types._CreateMultimodalDatasetParameters( + name=name, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _CreateMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete_multimodal_dataset( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.MultimodalDatasetOperation: + """ + Deletes a multimodal dataset resource. + """ + + parameter_model = types._DeleteMultimodalDatasetRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteMultimodalDatasetRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
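+            # (`config` carries client-side options such as `http_options`
+            # rather than request fields, so it is removed before the request
+            # body is serialized.)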
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_multimodal_dataset( + self, + *, + name: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """ + Gets a multimodal dataset resource. + """ + + parameter_model = types._GetMultimodalDatasetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_multimodal_dataset_operation( + self, + *, + dataset_id: Optional[str] = None, + operation_id: Optional[str] = None, + config: Optional[types.GetMultimodalDatasetOperationConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Gets the operation from creating a multimodal dataset. 
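+
+        The IDs are the path segments of an operation resource name: e.g. an
+        operation named (illustratively)
+        "projects/p/locations/l/datasets/123/operations/456" has
+        dataset_id "123" and operation_id "456".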
+ """ + + parameter_model = types._GetMultimodalDatasetOperationParameters( + dataset_id=dataset_id, + operation_id=operation_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetMultimodalDatasetOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/operations/{operation_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/operations/{operation_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list_multimodal_datasets( + self, *, config: Optional[types.ListMultimodalDatasetsConfigOrDict] = None + ) -> types.ListMultimodalDatasetsResponse: + """ + Lists multimodal datasets. + """ + + parameter_model = types._ListMultimodalDatasetsRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListMultimodalDatasetsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListMultimodalDatasetsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _update_multimodal_dataset( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[types.SchemaTablesDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """ + Updates a multimodal dataset resource. + """ + + parameter_model = types._UpdateMultimodalDatasetParameters( + name=name, + display_name=display_name, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+            request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("patch", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.MultimodalDataset._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def _wait_for_operation(
+        self,
+        operation: types.MultimodalDatasetOperation,
+        timeout_seconds: int,
+    ) -> dict[str, Any]:
+        """Waits for a multimodal or assemble dataset operation to complete.
+
+        Args:
+          operation: The multimodal or assemble dataset operation to wait for.
+          timeout_seconds: The maximum time in seconds to wait for the operation
+            to complete.
+
+        Returns:
+          A dict containing the operation response.
+
+        Raises:
+          TimeoutError: If the operation does not complete within the timeout.
+          ValueError: If the operation fails.
+        """
+        response_operation_name = operation.name
+        if response_operation_name is None:
+            raise ValueError("Dataset operation name is empty.")
+        dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0]
+        operation_id = response_operation_name.split("/")[-1]
+
+        start_time = time.time()
+        sleep_duration_seconds = 5
+        wait_multiplier = 2
+        max_wait_time_seconds = 60
+
+        while (time.time() - start_time) < timeout_seconds:
+            operation = self._get_multimodal_dataset_operation(
+                dataset_id=dataset_id,
+                operation_id=operation_id,
+            )
+            if operation.done:
+                break
+            time.sleep(sleep_duration_seconds)
+            # Exponential backoff between polls, capped at max_wait_time_seconds.
+            sleep_duration_seconds = min(
+                sleep_duration_seconds * wait_multiplier, max_wait_time_seconds
+            )
+        else:
+            # The while/else branch runs only when the loop never hit `break`,
+            # i.e. the operation did not finish in time.
+            raise TimeoutError(
+                "The operation did not complete within the"
+                f" specified timeout of {timeout_seconds} seconds."
+            )
+        # Surface a server-reported error before checking for a missing
+        # response; a failed operation has `error` set and `response` unset.
+        if hasattr(operation, "error") and operation.error is not None:
+            raise ValueError(f"Error running the operation: {operation.error}")
+        if operation.response is None:
+            logger.error(
+                f"Operation {response_operation_name} completed without a response."
+            )
+            raise ValueError(
+                f"Operation {response_operation_name} completed without a response."
+            )
+        return operation.response
+
+    def create_from_bigquery(
+        self,
+        *,
+        bigquery_uri: Optional[str] = None,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a BigQuery table.
+
+        Args:
+          bigquery_uri:
+            Optional. The BigQuery URI of the table to create the dataset from.
+            e.g. "bq://project.dataset.table". If both `bigquery_uri` and
+            `multimodal_dataset` are provided, and `multimodal_dataset` also
+            contains a BigQuery URI, the `bigquery_uri` parameter takes precedence.
+          multimodal_dataset:
+            Optional. A representation of a multimodal dataset.
If `bigquery_uri` + is set, `multimodal_dataset` can still be used to set other metadata + fields. If both `bigquery_uri` and `multimodal_dataset` are provided, + and `multimodal_dataset` also contains a BigQuery URI, the + `bigquery_uri` parameter takes precedence. + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDataset object representing a multimodal dataset. + """ + if not bigquery_uri and not multimodal_dataset: + raise ValueError( + "At least one of `bigquery_uri` or `multimodal_dataset` must be" + " provided." + ) + + if multimodal_dataset is None: + multimodal_dataset = types.MultimodalDataset() + elif isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + + if bigquery_uri: + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(bigquery_uri) + + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) + + if isinstance(config, dict): + config = types.CreateMultimodalDatasetConfig(**config) + elif not config: + config = types.CreateMultimodalDatasetConfig() + + display_name = ( + multimodal_dataset.display_name + if multimodal_dataset.display_name is not None + else _datasets_utils.generate_multimodal_dataset_display_name() + ) + multimodal_dataset_operation = self._create_multimodal_dataset( + config=config, + display_name=display_name, + metadata_schema_uri=_datasets_utils.METADATA_SCHEMA_URI, + metadata=multimodal_dataset.metadata, + ) + response = self._wait_for_operation( + operation=multimodal_dataset_operation, + timeout_seconds=config.timeout, + ) + return _datasets_utils.create_from_response( + types.MultimodalDataset, response, config + ) + + def create_from_pandas( + self, + *, + dataframe: pd.DataFrame, + multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None, + target_table_id: Optional[str] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Creates a multimodal dataset from a pandas dataframe. + + Args: + dataframe (pandas.DataFrame): + The pandas dataframe to be used for the created dataset. + multimodal_dataset: + Optional. A representation of a multimodal dataset. + target_table_id (str): + Optional. The BigQuery table id where the dataframe will be + uploaded. The table id can be in the format of "dataset.table" + or "project.dataset.table". Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. 
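+
+        Example usage (a sketch, not a definitive recipe; assumes an
+        initialized `client`, dataframe columns matching the request schema
+        you intend, and a hypothetical table id):
+
+            import pandas as pd
+
+            df = pd.DataFrame({"question": ["What is AI?"]})
+            dataset = client.datasets.create_from_pandas(
+                dataframe=df,
+                target_table_id="my_dataset.my_table",
+            )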
+ """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigframes = _datasets_utils._try_import_bigframes() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + session_options = bigframes.BigQueryOptions( + credentials=credentials, + project=project, + location=location, + ) + with bigframes.connect(session_options) as session: + return self.create_from_bigframes( + dataframe=session.read_pandas(dataframe), + multimodal_dataset=multimodal_dataset, + target_table_id=target_table_id, + config=config, + ) + + def create_from_bigframes( + self, + *, + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None, + target_table_id: Optional[str] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Creates a multimodal dataset from a bigframes dataframe. + + Args: + dataframe (bigframes.pandas.DataFrame): + The BigFrames dataframe that will be used for the created + dataset. + multimodal_dataset: + Optional. A representation of a multimodal dataset. + target_table_id (str): + Optional. The BigQuery table id where the dataframe will be + uploaded. The table id can be in the format of "dataset.table" + or "project.dataset.table". Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigquery = _datasets_utils._try_import_bigquery() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + if target_table_id: + target_table_id = _datasets_utils._normalize_and_validate_table_id( + table_id=target_table_id, + project=project, + location=location, + credentials=credentials, + ) + else: + dataset_id = _datasets_utils._create_default_bigquery_dataset_if_not_exists( + project=project, location=location, credentials=credentials + ) + target_table_id = _datasets_utils._generate_target_table_id(dataset_id) + + client = bigquery.Client(project=project, credentials=credentials) + _datasets_utils.save_dataframe_to_bigquery( + dataframe, + target_table_id, + client, + ) + + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}") + return self.create_from_bigquery( + multimodal_dataset=multimodal_dataset, config=config + ) + + def update_multimodal_dataset( + self, + *, + multimodal_dataset: types.MultimodalDatasetOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Updates a multimodal dataset. + + Updatable fields include: + - display_name + - description + + Args: + multimodal_dataset: + Required. A representation of a multimodal dataset. 
+ config: + Optional. A configuration for updating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDataset object representing the retrieved multimodal + dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) + + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + return self._update_multimodal_dataset( + config=config, + name=multimodal_dataset.name, + display_name=multimodal_dataset.display_name, + description=multimodal_dataset.description, + metadata=multimodal_dataset.metadata, + ) + + def get_multimodal_dataset( + self, + *, + name: str, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Gets a multimodal dataset. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + config: + Optional. A configuration for getting the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDataset object representing the retrieved multimodal + dataset. + """ + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + return self._get_multimodal_dataset(config=config, name=name) + + def delete_multimodal_dataset( + self, + *, + name: str, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """Deletes a multimodal dataset. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + config: + Optional. A configuration for deleting the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDatasetOperation object representing the delete + multimodal dataset operation. + """ + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + return self._delete_multimodal_dataset(config=config, name=name) + + def assemble( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssembleDatasetConfigOrDict] = None, + ) -> str: + """Assemble the dataset into a BigQuery table. + + Waits for the assemble operation to complete before returning. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + gemini_request_read_config: + Optional. The read config to use to assemble the dataset. If + not provided, the read config attached to the dataset will be + used. + config: + Optional. A configuration for assembling the dataset. If not + provided, the default configuration will be used. + + Returns: + The URI of the bigquery table of the assembled dataset. 
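+
+        Example usage (a sketch; assumes an initialized `client` and a
+        hypothetical dataset ID "123"):
+
+            bq_uri = client.datasets.assemble(name="123")
+            # e.g. "bq://my-project.my_dataset.assembled_table"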
+ """ + if isinstance(config, dict): + config = types.AssembleDatasetConfig(**config) + elif not config: + config = types.AssembleDatasetConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + operation = self._assemble_multimodal_dataset( + name=name, + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + return response["bigqueryDestination"] # type: ignore[no-any-return] + + def assess_tuning_resources( + self, + *, + dataset_name: str, + model_name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.TuningResourceUsageAssessmentResult: + """Assess the tuning resources required for a given model. + + Args: + dataset_name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + model_name: + Required. The name of the model to assess the tuning resources + for. + gemini_request_read_config: + Optional. The read config used to assemble the dataset + before assessing the tuning resources. If not provided, the + read config attached to the dataset will be used. Required + if no read config is attached to the dataset. + config: + Optional. A configuration for assessing the tuning resources. If not + provided, the default configuration will be used. + + Returns: + A types.TuningResourceUsageAssessmentResult object representing the + tuning resource usage assessment result. + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + + operation = self._assess_multimodal_dataset( + name=dataset_name, + tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( + model_name=model_name + ), + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + return _datasets_utils.create_from_response( + types.TuningResourceUsageAssessmentResult, + response["tuningResourceUsageAssessmentResult"], + config, + ) + + def assess_tuning_validity( + self, + *, + dataset_name: str, + model_name: str, + dataset_usage: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.TuningValidationAssessmentResult: + """Assess if the assembled dataset is valid in terms of tuning a given + model. + + Args: + dataset_name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + model_name: + Required. The name of the model to assess the tuning validity + for. + dataset_usage: + Required. The dataset usage to assess the tuning validity for. + Must be one of the following: SFT_TRAINING, SFT_VALIDATION. + gemini_request_read_config: + Optional. The read config used to assemble the dataset + before assessing the tuning validity. If not provided, the + read config attached to the dataset will be used. Required + if no read config is attached to the dataset. + config: + Optional. 
A configuration for assessing the tuning validity. If not
+            provided, the default configuration will be used.
+
+        Returns:
+          A types.TuningValidationAssessmentResult object representing the
+          tuning validity assessment result. It contains the following fields:
+          - errors: A list of errors that occurred during the tuning validity
+            assessment.
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        dataset_name = _datasets_utils.resolve_dataset_name(
+            dataset_name, self._api_client.project, self._api_client.location
+        )
+
+        operation = self._assess_multimodal_dataset(
+            name=dataset_name,
+            tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
+                model_name=model_name,
+                dataset_usage=dataset_usage,
+            ),
+            gemini_request_read_config=gemini_request_read_config,
+            config=config,
+        )
+        response = self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.TuningValidationAssessmentResult,
+            response["tuningValidationAssessmentResult"],
+            config,
+        )
+
+    def assess_batch_prediction_resources(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        gemini_request_read_config: Optional[
+            types.GeminiRequestReadConfigOrDict
+        ] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.BatchPredictionResourceUsageAssessmentResult:
+        """Assess the batch prediction resources required for a given model.
+
+        Args:
+          dataset_name:
+            Required. A fully-qualified resource name or ID of the dataset.
+            Example: "projects/.../locations/.../datasets/123" or "123".
+          model_name:
+            Required. The name of the model to assess the batch prediction
+            resources for.
+          gemini_request_read_config:
+            Optional. The read config used to assemble the dataset
+            before assessing the batch prediction resources. If not provided,
+            the read config attached to the dataset will be used. Required
+            if no read config is attached to the dataset.
+          config:
+            Optional. A configuration for assessing the batch prediction
+            resources. If not provided, the default configuration will be
+            used.
+
+        Returns:
+          A types.BatchPredictionResourceUsageAssessmentResult object
+          representing the batch prediction resource usage assessment result.
+          It contains the following fields:
+          - token_count: The number of tokens in the dataset.
+          - audio_token_count: The number of audio tokens in the dataset.
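+
+        Example usage (a sketch; the dataset ID and model name are
+        assumptions):
+
+            result = client.datasets.assess_batch_prediction_resources(
+                dataset_name="123",
+                model_name="gemini-2.0-flash-001",
+            )
+            print(result.token_count, result.audio_token_count)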
+ + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + + operation = self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionResourceUsageAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionResourceUsageAssessmentResult, result, config + ) + + def assess_batch_prediction_validity( + self, + *, + dataset_name: str, + model_name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.BatchPredictionValidationAssessmentResult: + """Assess if the assembled dataset is valid in terms of batch prediction + for a given model. Raises an error if the dataset is invalid, otherwise + returns None. + + Args: + dataset_name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + model_name: + Required. The name of the model to assess the batch prediction + validity for. + gemini_request_read_config: + Optional. The read config used to assemble the dataset + before assessing the batch prediction validity. If not provided, the + read config attached to the dataset will be used. Required + if no read config is attached to the dataset. + config: + Optional. A configuration for assessing the batch prediction validity. + If not provided, the default configuration will be used. + + Returns: + A types.BatchPredictionValidationAssessmentResult object representing + the batch prediction validity assessment result. + It contains the following keys: + - errors: A list of errors that occurred during the batch prediction + validity assessment. + """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + + operation = self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionValidationAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionValidationAssessmentResult, result, config + ) + + +class AsyncDatasets(_api_module.BaseModule): + + async def _assemble_multimodal_dataset( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssembleDatasetConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Assembles a multimodal dataset resource. 
+ """ + + parameter_model = types._AssembleDatasetParameters( + name=name, + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AssembleDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:assemble".format_map(request_url_dict) + else: + path = "{name}:assemble" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _assess_multimodal_dataset( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + tuning_resource_usage_assessment_config: Optional[ + types.TuningResourceUsageAssessmentConfigOrDict + ] = None, + tuning_validation_assessment_config: Optional[ + types.TuningValidationAssessmentConfigOrDict + ] = None, + batch_prediction_resource_usage_assessment_config: Optional[ + types.BatchPredictionResourceUsageAssessmentConfigOrDict + ] = None, + batch_prediction_validation_assessment_config: Optional[ + types.BatchPredictionValidationAssessmentConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Assesses a multimodal dataset resource. + """ + + parameter_model = types._AssessDatasetParameters( + name=name, + gemini_request_read_config=gemini_request_read_config, + tuning_resource_usage_assessment_config=tuning_resource_usage_assessment_config, + tuning_validation_assessment_config=tuning_validation_assessment_config, + batch_prediction_resource_usage_assessment_config=batch_prediction_resource_usage_assessment_config, + batch_prediction_validation_assessment_config=batch_prediction_validation_assessment_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _AssessDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:assess".format_map(request_url_dict) + else: + path = "{name}:assess" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_multimodal_dataset( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + metadata_schema_uri: Optional[str] = None, + metadata: Optional[types.SchemaTablesDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Creates a dataset resource to store multimodal datasets. + """ + + parameter_model = types._CreateMultimodalDatasetParameters( + name=name, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
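+            # `config` carries client-side options only (for example
+            # `http_options` and response-schema hints, which are consumed
+            # locally below); it is not a field of the request body, so it
+            # is stripped before the dict is serialized and sent.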
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete_multimodal_dataset( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.MultimodalDatasetOperation: + """ + Deletes a multimodal dataset resource. + """ + + parameter_model = types._DeleteMultimodalDatasetRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteMultimodalDatasetRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_multimodal_dataset( + self, + *, + name: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """ + Gets a multimodal dataset resource. 
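+
+        Unlike the public `get_multimodal_dataset`, this low-level variant
+        does not resolve short dataset IDs into full resource names; `name`
+        is used as given.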
+ """ + + parameter_model = types._GetMultimodalDatasetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_multimodal_dataset_operation( + self, + *, + dataset_id: Optional[str] = None, + operation_id: Optional[str] = None, + config: Optional[types.GetMultimodalDatasetOperationConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """ + Gets the operation from creating a multimodal dataset. + """ + + parameter_model = types._GetMultimodalDatasetOperationParameters( + dataset_id=dataset_id, + operation_id=operation_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetMultimodalDatasetOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/operations/{operation_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/operations/{operation_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list_multimodal_datasets( + self, *, config: Optional[types.ListMultimodalDatasetsConfigOrDict] = None + ) -> types.ListMultimodalDatasetsResponse: + """ + Lists multimodal datasets. + """ + + parameter_model = types._ListMultimodalDatasetsRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListMultimodalDatasetsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListMultimodalDatasetsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update_multimodal_dataset( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[types.SchemaTablesDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """ + Updates a multimodal dataset resource. 
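+
+        Issues an HTTP PATCH against the dataset resource; only the fields
+        provided in the parameters are serialized into the request body.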
+ """ + + parameter_model = types._UpdateMultimodalDatasetParameters( + name=name, + display_name=display_name, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateMultimodalDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MultimodalDataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _wait_for_operation( + self, + operation: types.MultimodalDatasetOperation, + timeout_seconds: int, + ) -> dict[str, Any]: + """Waits for a multimodal dataset operation to complete. + + Args: + operation: The multimodal dataset operation to wait for. + timeout_seconds: The maximum time in seconds to wait for the operation + to complete. + + Returns: + A dict containing the operation response. + + Raises: + TimeoutError: If the operation does not complete within the timeout. + ValueError: If the operation fails. + """ + response_operation_name = operation.name + if response_operation_name is None: + raise ValueError("Dataset operation name is empty.") + dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0] + operation_id = response_operation_name.split("/")[-1] + + start_time = time.time() + sleep_duration_seconds = 5 + wait_multiplier = 2 + max_wait_time_seconds = 60 + + while (time.time() - start_time) < timeout_seconds: + operation = await self._get_multimodal_dataset_operation( + dataset_id=dataset_id, + operation_id=operation_id, + ) + if operation.done: + break + await asyncio.sleep(sleep_duration_seconds) + sleep_duration_seconds = min( + sleep_duration_seconds * wait_multiplier, max_wait_time_seconds + ) + else: + raise TimeoutError( + "The operation did not complete within the" + f" specified timeout of {timeout_seconds} seconds." 
+            )
+        if hasattr(operation, "error") and operation.error is not None:
+            raise ValueError(f"Error running the operation: {operation.error}")
+        if not operation or operation.response is None:
+            logger.error("Operation %s completed without a response.", operation_id)
+            raise ValueError(
+                f"Operation {operation_id} completed without a response."
+            )
+        return operation.response
+
+    async def create_from_bigquery(
+        self,
+        *,
+        bigquery_uri: Optional[str] = None,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a BigQuery table.
+
+        Args:
+            bigquery_uri:
+                Optional. The BigQuery URI of the table to create the dataset
+                from, e.g. "bq://project.dataset.table". If both
+                `bigquery_uri` and `multimodal_dataset` are provided, and
+                `multimodal_dataset` also contains a BigQuery URI, the
+                `bigquery_uri` parameter takes precedence.
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset. If
+                `bigquery_uri` is set, `multimodal_dataset` can still be used
+                to set other metadata fields. If both `bigquery_uri` and
+                `multimodal_dataset` are provided, and `multimodal_dataset`
+                also contains a BigQuery URI, the `bigquery_uri` parameter
+                takes precedence.
+            config:
+                Optional. A configuration for creating the multimodal dataset.
+                If not provided, the default configuration will be used.
+
+        Returns:
+            A types.MultimodalDataset object representing a multimodal
+            dataset.
+        """
+        if not bigquery_uri and not multimodal_dataset:
+            raise ValueError(
+                "At least one of `bigquery_uri` or `multimodal_dataset` must"
+                " be provided."
+            )
+
+        if multimodal_dataset is None:
+            multimodal_dataset = types.MultimodalDataset()
+        elif isinstance(multimodal_dataset, dict):
+            multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
+
+        if bigquery_uri:
+            multimodal_dataset = multimodal_dataset.model_copy(deep=True)
+            multimodal_dataset.set_bigquery_uri(bigquery_uri)
+
+        _datasets_utils.validate_multimodal_dataset_bigquery_uri(
+            multimodal_dataset
+        )
+
+        if isinstance(config, dict):
+            config = types.CreateMultimodalDatasetConfig(**config)
+        elif not config:
+            config = types.CreateMultimodalDatasetConfig()
+
+        display_name = (
+            multimodal_dataset.display_name
+            if multimodal_dataset.display_name is not None
+            else _datasets_utils.generate_multimodal_dataset_display_name()
+        )
+        multimodal_dataset_operation = await self._create_multimodal_dataset(
+            config=config,
+            display_name=display_name,
+            metadata_schema_uri=_datasets_utils.METADATA_SCHEMA_URI,
+            metadata=multimodal_dataset.metadata,
+        )
+        response = await self._wait_for_operation(
+            operation=multimodal_dataset_operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.MultimodalDataset, response, config
+        )
+
+    async def create_from_pandas(
+        self,
+        *,
+        dataframe: pd.DataFrame,
+        multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
+        target_table_id: Optional[str] = None,
+        config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
+    ) -> types.MultimodalDataset:
+        """Creates a multimodal dataset from a pandas dataframe.
+
+        Args:
+            dataframe (pandas.DataFrame):
+                The pandas dataframe to be used for the created dataset.
+            multimodal_dataset:
+                Optional. A representation of a multimodal dataset.
+            target_table_id (str):
+                Optional. The BigQuery table id where the dataframe will be
+                uploaded.
The table id can be in the format of "dataset.table" + or "project.dataset.table". Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigframes = _datasets_utils._try_import_bigframes() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + session_options = bigframes.BigQueryOptions( + credentials=credentials, + project=project, + location=location, + ) + with bigframes.connect(session_options) as session: + return await self.create_from_bigframes( + dataframe=session.read_pandas(dataframe), + multimodal_dataset=multimodal_dataset, + target_table_id=target_table_id, + config=config, + ) + + async def create_from_bigframes( + self, + *, + dataframe: "bigframes.pandas.DataFrame", # type: ignore # noqa: F821 + multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None, + target_table_id: Optional[str] = None, + config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Creates a multimodal dataset from a bigframes dataframe. + + Args: + dataframe (bigframes.pandas.DataFrame): + The BigFrames dataframe that will be used for the created + dataset. + multimodal_dataset: + Optional. A representation of a multimodal dataset. + target_table_id (str): + Optional. The BigQuery table id where the dataframe will be + uploaded. The table id can be in the format of "dataset.table" + or "project.dataset.table". Note that the BigQuery + dataset must already exist and be in the same location as the + multimodal dataset. If not provided, a generated table id will + be created in the `vertex_datasets` dataset (e.g. + `project.vertex_datasets_us_central1.multimodal_dataset_4cbf7ffd`). + config: + Optional. A configuration for creating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + dataset (MultimodalDataset): + The created multimodal dataset. 
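+
+        Example (illustrative sketch; assumes an async client that exposes
+        this module as `client.aio.datasets` and an existing BigQuery
+        dataset for the target table):
+
+            import bigframes.pandas as bpd
+
+            df = bpd.read_gbq("project.dataset.source_table")
+            dataset = await client.aio.datasets.create_from_bigframes(
+                dataframe=df,
+                target_table_id="project.dataset.assembled_table",
+            )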
+ """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + elif not multimodal_dataset: + multimodal_dataset = types.MultimodalDataset() + + bigquery = _datasets_utils._try_import_bigquery() + project = self._api_client.project + location = self._api_client.location + credentials = self._api_client._credentials + + if target_table_id: + target_table_id = ( + await _datasets_utils._normalize_and_validate_table_id_async( + table_id=target_table_id, + project=project, + location=location, + credentials=credentials, + ) + ) + else: + dataset_id = await _datasets_utils._create_default_bigquery_dataset_if_not_exists_async( + project=project, location=location, credentials=credentials + ) + target_table_id = _datasets_utils._generate_target_table_id(dataset_id) + + client = bigquery.Client(project=project, credentials=credentials) + await _datasets_utils.save_dataframe_to_bigquery_async( + dataframe, + target_table_id, + client, + ) + + multimodal_dataset = multimodal_dataset.model_copy(deep=True) + multimodal_dataset.set_bigquery_uri(f"bq://{target_table_id}") + return await self.create_from_bigquery( + multimodal_dataset=multimodal_dataset, config=config + ) + + async def update_multimodal_dataset( + self, + *, + multimodal_dataset: types.MultimodalDatasetOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Updates a multimodal dataset. + + Args: + multimodal_dataset: + Required. A representation of a multimodal dataset. + config: + Optional. A configuration for updating the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDataset object representing the updated multimodal + dataset. + """ + if isinstance(multimodal_dataset, dict): + multimodal_dataset = types.MultimodalDataset(**multimodal_dataset) + _datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset) + + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + return await self._update_multimodal_dataset( + config=config, + name=multimodal_dataset.name, + display_name=multimodal_dataset.display_name, + description=multimodal_dataset.description, + metadata=multimodal_dataset.metadata, + ) + + async def get_multimodal_dataset( + self, + *, + name: str, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDataset: + """Gets a multimodal dataset. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + config: + Optional. A configuration for getting the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDataset object representing the retrieved multimodal + dataset. + """ + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + return await self._get_multimodal_dataset(config=config, name=name) + + async def delete_multimodal_dataset( + self, + *, + name: str, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.MultimodalDatasetOperation: + """Deletes a multimodal dataset. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. 
+ Example: "projects/.../locations/.../datasets/123" or "123". + config: + Optional. A configuration for deleting the multimodal dataset. If not + provided, the default configuration will be used. + + Returns: + A types.MultimodalDatasetOperation object representing the delete + multimodal dataset operation. + """ + if isinstance(config, dict): + config = types.VertexBaseConfig(**config) + elif not config: + config = types.VertexBaseConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + return await self._delete_multimodal_dataset(config=config, name=name) + + async def assemble( + self, + *, + name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssembleDatasetConfigOrDict] = None, + ) -> str: + """Assemble the dataset into a BigQuery table. + + Waits for the assemble operation to complete before returning. + + Args: + name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + gemini_request_read_config: + Optional. The read config to use to assemble the dataset. If + not provided, the read config attached to the dataset will be + used. + config: + Optional. A configuration for assembling the dataset. If not + provided, the default configuration will be used. + + Returns: + The URI of the bigquery table of the assembled dataset. + """ + if isinstance(config, dict): + config = types.AssembleDatasetConfig(**config) + elif not config: + config = types.AssembleDatasetConfig() + + name = _datasets_utils.resolve_dataset_name( + name, self._api_client.project, self._api_client.location + ) + + operation = await self._assemble_multimodal_dataset( + name=name, + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = await self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + return response["bigqueryDestination"] # type: ignore[no-any-return] + + async def assess_tuning_resources( + self, + *, + dataset_name: str, + model_name: str, + gemini_request_read_config: Optional[ + types.GeminiRequestReadConfigOrDict + ] = None, + config: Optional[types.AssessDatasetConfigOrDict] = None, + ) -> types.TuningResourceUsageAssessmentResult: + """Assess the tuning resources required for a given model. + + Args: + dataset_name: + Required. A fully-qualified resource name or ID of the dataset. + Example: "projects/.../locations/.../datasets/123" or "123". + model_name: + Required. The name of the model to assess the tuning resources + for. + gemini_request_read_config: + Optional. The read config used to assemble the dataset + before assessing the tuning resources. If not provided, the + read config attached to the dataset will be used. Required + if no read config is attached to the dataset. + config: + Optional. A configuration for assessing the tuning resources. If not + provided, the default configuration will be used. + + Returns: + A types.TuningResourceUsageAssessmentResult object representing the + tuning resource usage assessment result. 
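+
+        Example (illustrative sketch; assumes an async client exposing this
+        module as `client.aio.datasets`; the model name is a placeholder):
+
+            result = await client.aio.datasets.assess_tuning_resources(
+                dataset_name="123",
+                model_name="gemini-2.0-flash-001",
+            )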
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        dataset_name = _datasets_utils.resolve_dataset_name(
+            dataset_name, self._api_client.project, self._api_client.location
+        )
+
+        operation = await self._assess_multimodal_dataset(
+            name=dataset_name,
+            tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig(
+                model_name=model_name
+            ),
+            gemini_request_read_config=gemini_request_read_config,
+            config=config,
+        )
+        response = await self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.TuningResourceUsageAssessmentResult,
+            response["tuningResourceUsageAssessmentResult"],
+            config,
+        )
+
+    async def assess_tuning_validity(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        dataset_usage: str,
+        gemini_request_read_config: Optional[
+            types.GeminiRequestReadConfigOrDict
+        ] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.TuningValidationAssessmentResult:
+        """Assess if the assembled dataset is valid for tuning a given model.
+
+        Args:
+            dataset_name:
+                Required. A fully-qualified resource name or ID of the dataset.
+                Example: "projects/.../locations/.../datasets/123" or "123".
+            model_name:
+                Required. The name of the model to assess the tuning validity
+                for.
+            dataset_usage:
+                Required. The dataset usage to assess the tuning validity for.
+                Must be one of the following: SFT_TRAINING, SFT_VALIDATION.
+            gemini_request_read_config:
+                Optional. The read config used to assemble the dataset
+                before assessing the tuning validity. If not provided, the
+                read config attached to the dataset will be used. Required
+                if no read config is attached to the dataset.
+            config:
+                Optional. A configuration for assessing the tuning validity.
+                If not provided, the default configuration will be used.
+
+        Returns:
+            A types.TuningValidationAssessmentResult object representing the
+            tuning validity assessment result. It contains the following
+            fields:
+            - errors: A list of errors that occurred during the tuning
+              validity assessment.
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        dataset_name = _datasets_utils.resolve_dataset_name(
+            dataset_name, self._api_client.project, self._api_client.location
+        )
+
+        operation = await self._assess_multimodal_dataset(
+            name=dataset_name,
+            tuning_validation_assessment_config=types.TuningValidationAssessmentConfig(
+                model_name=model_name,
+                dataset_usage=dataset_usage,
+            ),
+            gemini_request_read_config=gemini_request_read_config,
+            config=config,
+        )
+        response = await self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        return _datasets_utils.create_from_response(
+            types.TuningValidationAssessmentResult,
+            response["tuningValidationAssessmentResult"],
+            config,
+        )
+
+    async def assess_batch_prediction_resources(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        gemini_request_read_config: Optional[
+            types.GeminiRequestReadConfigOrDict
+        ] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.BatchPredictionResourceUsageAssessmentResult:
+        """Assess the batch prediction resources required for a given model.
+
+        Args:
+            dataset_name:
+                Required. A fully-qualified resource name or ID of the dataset.
+                Example: "projects/.../locations/.../datasets/123" or "123".
+            model_name:
+                Required. The name of the model to assess the batch prediction
+                resources for.
+            gemini_request_read_config:
+                Optional. The read config used to assemble the dataset
+                before assessing the batch prediction resources. If not
+                provided, the read config attached to the dataset will be
+                used. Required if no read config is attached to the dataset.
+            config:
+                Optional. A configuration for assessing the batch prediction
+                resources. If not provided, the default configuration will be
+                used.
+
+        Returns:
+            A types.BatchPredictionResourceUsageAssessmentResult object
+            representing the batch prediction resource usage assessment
+            result. It contains the following fields:
+            - token_count: The number of tokens in the dataset.
+            - audio_token_count: The number of audio tokens in the dataset.
+        """
+        if isinstance(config, dict):
+            config = types.AssessDatasetConfig(**config)
+        elif not config:
+            config = types.AssessDatasetConfig()
+
+        dataset_name = _datasets_utils.resolve_dataset_name(
+            dataset_name, self._api_client.project, self._api_client.location
+        )
+
+        operation = await self._assess_multimodal_dataset(
+            name=dataset_name,
+            batch_prediction_resource_usage_assessment_config=types.BatchPredictionResourceUsageAssessmentConfig(
+                model_name=model_name,
+            ),
+            gemini_request_read_config=gemini_request_read_config,
+            config=config,
+        )
+        response = await self._wait_for_operation(
+            operation=operation,
+            timeout_seconds=config.timeout,
+        )
+        result = response["batchPredictionResourceUsageAssessmentResult"]
+        return _datasets_utils.create_from_response(
+            types.BatchPredictionResourceUsageAssessmentResult, result, config
+        )
+
+    async def assess_batch_prediction_validity(
+        self,
+        *,
+        dataset_name: str,
+        model_name: str,
+        gemini_request_read_config: Optional[
+            types.GeminiRequestReadConfigOrDict
+        ] = None,
+        config: Optional[types.AssessDatasetConfigOrDict] = None,
+    ) -> types.BatchPredictionValidationAssessmentResult:
+        """Assess if the assembled dataset is valid for batch prediction
+        with a given model.
+
+        Args:
+            dataset_name:
+                Required. A fully-qualified resource name or ID of the dataset.
+                Example: "projects/.../locations/.../datasets/123" or "123".
+            model_name:
+                Required. The name of the model to assess the batch prediction
+                validity for.
+            gemini_request_read_config:
+                Optional. The read config used to assemble the dataset
+                before assessing the batch prediction validity. If not
+                provided, the read config attached to the dataset will be
+                used. Required if no read config is attached to the dataset.
+            config:
+                Optional. A configuration for assessing the batch prediction
+                validity. If not provided, the default configuration will be
+                used.
+
+        Returns:
+            A types.BatchPredictionValidationAssessmentResult object
+            representing the batch prediction validity assessment result.
+            It contains the following fields:
+            - errors: A list of errors found during the batch prediction
+              validity assessment.
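+
+        Example (illustrative sketch; assumes an async client exposing this
+        module as `client.aio.datasets`; the model name is a placeholder):
+
+            result = await client.aio.datasets.assess_batch_prediction_validity(
+                dataset_name="123",
+                model_name="gemini-2.0-flash-001",
+            )
+            if result.errors:
+                raise RuntimeError(f"Dataset is not valid: {result.errors}")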
+ """ + if isinstance(config, dict): + config = types.AssessDatasetConfig(**config) + elif not config: + config = types.AssessDatasetConfig() + + dataset_name = _datasets_utils.resolve_dataset_name( + dataset_name, self._api_client.project, self._api_client.location + ) + + operation = await self._assess_multimodal_dataset( + name=dataset_name, + batch_prediction_validation_assessment_config=types.BatchPredictionValidationAssessmentConfig( + model_name=model_name, + ), + gemini_request_read_config=gemini_request_read_config, + config=config, + ) + response = await self._wait_for_operation( + operation=operation, + timeout_seconds=config.timeout, + ) + result = response["batchPredictionValidationAssessmentResult"] + return _datasets_utils.create_from_response( + types.BatchPredictionValidationAssessmentResult, result, config + ) diff --git a/agentplatform/_genai/evals.py b/agentplatform/_genai/evals.py new file mode 100644 index 0000000000..35cc9e9986 --- /dev/null +++ b/agentplatform/_genai/evals.py @@ -0,0 +1,4872 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import json +import logging +from typing import Any, Callable, Optional, Union, cast +from urllib.parse import urlencode +import uuid + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +import pandas as pd + +from . import _evals_common +from . import _evals_utils +from . import _transformers as t +from . 
import types +from .types import evals as evals_types + +try: + from google.adk.agents import LlmAgent +except ImportError: + LlmAgent = None + + +logger = logging.getLogger("agentplatform_genai.evals") + + +def _CreateEvaluationItemParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["evaluation_item_type"]) is not None: + setv( + to_object, + ["evaluationItemType"], + getv(from_object, ["evaluation_item_type"]), + ) + + if getv(from_object, ["gcs_uri"]) is not None: + setv(to_object, ["gcsUri"], getv(from_object, ["gcs_uri"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _CreateEvaluationMetricParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["metric"]) is not None: + setv( + to_object, + ["metric"], + t.t_metric_for_registry(getv(from_object, ["metric"])), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _CreateEvaluationRunParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["data_source"]) is not None: + setv(to_object, ["dataSource"], getv(from_object, ["data_source"])) + + if getv(from_object, ["evaluation_config"]) is not None: + setv( + to_object, + ["evaluationConfig"], + _EvaluationRunConfig_to_vertex( + getv(from_object, ["evaluation_config"]), to_object + ), + ) + + if getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["inference_configs"]) is not None: + setv( + to_object, + ["inferenceConfigs"], + { + k: _EvaluationRunInferenceConfig_to_vertex(v, to_object) + for k, v in getv(from_object, ["inference_configs"]).items() + }, + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _CreateEvaluationSetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["evaluation_items"]) is not None: + setv(to_object, ["evaluationItems"], getv(from_object, ["evaluation_items"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def 
_CustomCodeExecutionSpec_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["evaluationFunction"]) is not None: + setv( + to_object, + ["evaluation_function"], + getv(from_object, ["evaluationFunction"]), + ) + + if getv(from_object, ["evaluation_function"]) is not None: + setv( + to_object, + ["remote_custom_function"], + getv(from_object, ["evaluation_function"]), + ) + + return to_object + + +def _CustomCodeExecutionSpec_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["evaluation_function"]) is not None: + setv( + to_object, + ["evaluationFunction"], + getv(from_object, ["evaluation_function"]), + ) + + if getv(from_object, ["remote_custom_function"]) is not None: + setv( + to_object, + ["evaluation_function"], + getv(from_object, ["remote_custom_function"]), + ) + + return to_object + + +def _DeleteEvaluationMetricParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_resource_name"]) is not None: + setv( + to_object, + ["_url", "evaluation_metric"], + getv(from_object, ["metric_resource_name"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _EvaluateInstancesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["bleu_input"]) is not None: + setv(to_object, ["bleuInput"], getv(from_object, ["bleu_input"])) + + if getv(from_object, ["exact_match_input"]) is not None: + setv(to_object, ["exactMatchInput"], getv(from_object, ["exact_match_input"])) + + if getv(from_object, ["rouge_input"]) is not None: + setv(to_object, ["rougeInput"], getv(from_object, ["rouge_input"])) + + if getv(from_object, ["pointwise_metric_input"]) is not None: + setv( + to_object, + ["pointwiseMetricInput"], + getv(from_object, ["pointwise_metric_input"]), + ) + + if getv(from_object, ["pairwise_metric_input"]) is not None: + setv( + to_object, + ["pairwiseMetricInput"], + getv(from_object, ["pairwise_metric_input"]), + ) + + if getv(from_object, ["tool_call_valid_input"]) is not None: + setv( + to_object, + ["toolCallValidInput"], + getv(from_object, ["tool_call_valid_input"]), + ) + + if getv(from_object, ["tool_name_match_input"]) is not None: + setv( + to_object, + ["toolNameMatchInput"], + getv(from_object, ["tool_name_match_input"]), + ) + + if getv(from_object, ["tool_parameter_key_match_input"]) is not None: + setv( + to_object, + ["toolParameterKeyMatchInput"], + getv(from_object, ["tool_parameter_key_match_input"]), + ) + + if getv(from_object, ["tool_parameter_kv_match_input"]) is not None: + setv( + to_object, + ["toolParameterKvMatchInput"], + getv(from_object, ["tool_parameter_kv_match_input"]), + ) + + if getv(from_object, ["rubric_based_metric_input"]) is not None: + setv( + to_object, + ["rubricBasedMetricInput"], + _RubricBasedMetricInput_to_vertex( + getv(from_object, ["rubric_based_metric_input"]), to_object + ), + ) + + if getv(from_object, ["autorater_config"]) is not None: + setv(to_object, ["autoraterConfig"], getv(from_object, 
["autorater_config"])) + + if getv(from_object, ["metrics"]) is not None: + setv( + to_object, + ["metrics"], + [item for item in t.t_metrics(getv(from_object, ["metrics"]))], + ) + + if getv(from_object, ["instance"]) is not None: + setv( + to_object, + ["instance"], + _EvaluationInstance_to_vertex(getv(from_object, ["instance"]), to_object), + ) + + if getv(from_object, ["metric_sources"]) is not None: + setv( + to_object, + ["metricSources"], + [ + item + for item in t.t_metric_sources(getv(from_object, ["metric_sources"])) + ], + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _EvaluationInstance_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prompt"]) is not None: + setv(to_object, ["prompt"], getv(from_object, ["prompt"])) + + if getv(from_object, ["response"]) is not None: + setv(to_object, ["response"], getv(from_object, ["response"])) + + if getv(from_object, ["reference"]) is not None: + setv(to_object, ["reference"], getv(from_object, ["reference"])) + + if getv(from_object, ["other_data"]) is not None: + setv(to_object, ["otherData"], getv(from_object, ["other_data"])) + + if getv(from_object, ["agent_data"]) is not None: + setv(to_object, ["agent_eval_data"], getv(from_object, ["agent_data"])) + + if getv(from_object, ["rubric_groups"]) is not None: + setv(to_object, ["rubricGroups"], getv(from_object, ["rubric_groups"])) + + return to_object + + +def _EvaluationMetric_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["displayName"]) is not None: + setv(to_object, ["display_name"], getv(from_object, ["displayName"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["metric"]) is not None: + setv( + to_object, + ["metric"], + _UnifiedMetric_from_vertex(getv(from_object, ["metric"]), to_object), + ) + + return to_object + + +def _EvaluationRunConfig_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metrics"]) is not None: + setv( + to_object, + ["metrics"], + [ + _EvaluationRunMetric_from_vertex(item, to_object) + for item in getv(from_object, ["metrics"]) + ], + ) + + if getv(from_object, ["outputConfig"]) is not None: + setv(to_object, ["output_config"], getv(from_object, ["outputConfig"])) + + if getv(from_object, ["autoraterConfig"]) is not None: + setv(to_object, ["autorater_config"], getv(from_object, ["autoraterConfig"])) + + if getv(from_object, ["promptTemplate"]) is not None: + setv(to_object, ["prompt_template"], getv(from_object, ["promptTemplate"])) + + if getv(from_object, ["lossAnalysisConfig"]) is not None: + setv( + to_object, + ["loss_analysis_config"], + [item for item in getv(from_object, ["lossAnalysisConfig"])], + ) + + return to_object + + +def _EvaluationRunConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, 
["metrics"]) is not None: + setv( + to_object, + ["metrics"], + [ + _EvaluationRunMetric_to_vertex(item, to_object) + for item in getv(from_object, ["metrics"]) + ], + ) + + if getv(from_object, ["output_config"]) is not None: + setv(to_object, ["outputConfig"], getv(from_object, ["output_config"])) + + if getv(from_object, ["autorater_config"]) is not None: + setv(to_object, ["autoraterConfig"], getv(from_object, ["autorater_config"])) + + if getv(from_object, ["prompt_template"]) is not None: + setv(to_object, ["promptTemplate"], getv(from_object, ["prompt_template"])) + + if getv(from_object, ["loss_analysis_config"]) is not None: + setv( + to_object, + ["lossAnalysisConfig"], + [item for item in getv(from_object, ["loss_analysis_config"])], + ) + + return to_object + + +def _EvaluationRunInferenceConfig_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["agentConfig"]) is not None: + setv(to_object, ["agent_config"], getv(from_object, ["agentConfig"])) + + if getv(from_object, ["model"]) is not None: + setv(to_object, ["model"], getv(from_object, ["model"])) + + if getv(from_object, ["promptTemplate"]) is not None: + setv(to_object, ["prompt_template"], getv(from_object, ["promptTemplate"])) + + if getv(from_object, ["agentRunConfig"]) is not None: + setv(to_object, ["agent_run_config"], getv(from_object, ["agentRunConfig"])) + + if getv(from_object, ["agents"]) is not None: + setv(to_object, ["agent_configs"], getv(from_object, ["agents"])) + + return to_object + + +def _EvaluationRunInferenceConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["agent_config"]) is not None: + setv(to_object, ["agentConfig"], getv(from_object, ["agent_config"])) + + if getv(from_object, ["model"]) is not None: + setv(to_object, ["model"], getv(from_object, ["model"])) + + if getv(from_object, ["prompt_template"]) is not None: + setv(to_object, ["promptTemplate"], getv(from_object, ["prompt_template"])) + + if getv(from_object, ["agent_run_config"]) is not None: + setv(to_object, ["agentRunConfig"], getv(from_object, ["agent_run_config"])) + + if getv(from_object, ["agent_configs"]) is not None: + setv(to_object, ["agents"], getv(from_object, ["agent_configs"])) + + return to_object + + +def _EvaluationRunMetric_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric"]) is not None: + setv(to_object, ["metric"], getv(from_object, ["metric"])) + + if getv(from_object, ["metricResourceName"]) is not None: + setv( + to_object, + ["metric_resource_name"], + getv(from_object, ["metricResourceName"]), + ) + + if getv(from_object, ["metricConfig"]) is not None: + setv( + to_object, + ["metric_config"], + _UnifiedMetric_from_vertex(getv(from_object, ["metricConfig"]), to_object), + ) + + return to_object + + +def _EvaluationRunMetric_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric"]) is not None: + setv(to_object, ["metric"], getv(from_object, ["metric"])) + + if getv(from_object, ["metric_resource_name"]) is not None: + setv( + to_object, + 
["metricResourceName"], + getv(from_object, ["metric_resource_name"]), + ) + + if getv(from_object, ["metric_config"]) is not None: + setv( + to_object, + ["metricConfig"], + _UnifiedMetric_to_vertex(getv(from_object, ["metric_config"]), to_object), + ) + + return to_object + + +def _EvaluationRun_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["displayName"]) is not None: + setv(to_object, ["display_name"], getv(from_object, ["displayName"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["createTime"]) is not None: + setv(to_object, ["create_time"], getv(from_object, ["createTime"])) + + if getv(from_object, ["completionTime"]) is not None: + setv(to_object, ["completion_time"], getv(from_object, ["completionTime"])) + + if getv(from_object, ["state"]) is not None: + setv(to_object, ["state"], getv(from_object, ["state"])) + + if getv(from_object, ["evaluationSetSnapshot"]) is not None: + setv( + to_object, + ["evaluation_set_snapshot"], + getv(from_object, ["evaluationSetSnapshot"]), + ) + + if getv(from_object, ["error"]) is not None: + setv(to_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["dataSource"]) is not None: + setv(to_object, ["data_source"], getv(from_object, ["dataSource"])) + + if getv(from_object, ["evaluationResults"]) is not None: + setv( + to_object, + ["evaluation_run_results"], + getv(from_object, ["evaluationResults"]), + ) + + if getv(from_object, ["evaluationConfig"]) is not None: + setv( + to_object, + ["evaluation_config"], + _EvaluationRunConfig_from_vertex( + getv(from_object, ["evaluationConfig"]), to_object + ), + ) + + if getv(from_object, ["inferenceConfigs"]) is not None: + setv( + to_object, + ["inference_configs"], + { + k: _EvaluationRunInferenceConfig_from_vertex(v, to_object) + for k, v in getv(from_object, ["inferenceConfigs"]).items() + }, + ) + + if getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + return to_object + + +def _GenerateInstanceRubricsRequest_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["contents"]) is not None: + setv(to_object, ["contents"], getv(from_object, ["contents"])) + + if getv(from_object, ["predefined_rubric_generation_spec"]) is not None: + setv( + to_object, + ["predefinedRubricGenerationSpec"], + getv(from_object, ["predefined_rubric_generation_spec"]), + ) + + if getv(from_object, ["rubric_generation_spec"]) is not None: + setv( + to_object, + ["rubricGenerationSpec"], + getv(from_object, ["rubric_generation_spec"]), + ) + + if getv(from_object, ["metric_resource_name"]) is not None: + setv( + to_object, + ["metricResourceName"], + getv(from_object, ["metric_resource_name"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GenerateLossClustersParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["location"]) is not None: + 
setv(to_object, ["location"], getv(from_object, ["location"])) + + if getv(from_object, ["evaluation_set"]) is not None: + setv(to_object, ["evaluationSet"], getv(from_object, ["evaluation_set"])) + + if getv(from_object, ["inline_results"]) is not None: + setv( + to_object, + ["inlineResults", "evaluationResults"], + [ + item + for item in t.t_inline_results(getv(from_object, ["inline_results"])) + ], + ) + + if getv(from_object, ["configs"]) is not None: + setv(to_object, ["configs"], [item for item in getv(from_object, ["configs"])]) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GenerateUserScenariosParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["location"]) is not None: + setv(to_object, ["location"], getv(from_object, ["location"])) + + if getv(from_object, ["agents"]) is not None: + setv(to_object, ["agents"], getv(from_object, ["agents"])) + + if getv(from_object, ["root_agent_id"]) is not None: + setv(to_object, ["rootAgentId"], getv(from_object, ["root_agent_id"])) + + if getv(from_object, ["user_scenario_generation_config"]) is not None: + setv( + to_object, + ["userScenarioGenerationConfig"], + t.t_user_scenario_generation_config( + getv(from_object, ["user_scenario_generation_config"]) + ), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + if getv(from_object, ["allow_cross_region_model"]) is not None: + setv( + to_object, + ["allowCrossRegionModel"], + getv(from_object, ["allow_cross_region_model"]), + ) + + return to_object + + +def _GetEvaluationItemParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetEvaluationMetricParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_resource_name"]) is not None: + setv( + to_object, + ["_url", "evaluation_metric"], + getv(from_object, ["metric_resource_name"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetEvaluationRunParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetEvaluationSetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], 
getv(from_object, ["config"])) + + return to_object + + +def _ListEvaluationMetricsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["order_by"]) is not None: + setv(parent_object, ["_query", "orderBy"], getv(from_object, ["order_by"])) + + return to_object + + +def _ListEvaluationMetricsParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _ListEvaluationMetricsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _ListEvaluationMetricsResponse_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["sdkHttpResponse"]) is not None: + setv(to_object, ["sdk_http_response"], getv(from_object, ["sdkHttpResponse"])) + + if getv(from_object, ["nextPageToken"]) is not None: + setv(to_object, ["next_page_token"], getv(from_object, ["nextPageToken"])) + + if getv(from_object, ["evaluationMetrics"]) is not None: + setv( + to_object, + ["evaluation_metrics"], + [ + _EvaluationMetric_from_vertex(item, to_object) + for item in getv(from_object, ["evaluationMetrics"]) + ], + ) + + return to_object + + +def _RubricBasedMetricInput_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_spec"]) is not None: + setv( + to_object, + ["metricSpec"], + _RubricBasedMetricSpec_to_vertex( + getv(from_object, ["metric_spec"]), to_object + ), + ) + + if getv(from_object, ["instance"]) is not None: + setv(to_object, ["instance"], getv(from_object, ["instance"])) + + return to_object + + +def _RubricBasedMetricSpec_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["metric_prompt_template"]) is not None: + setv( + to_object, + ["metricPromptTemplate"], + getv(from_object, ["metric_prompt_template"]), + ) + + if getv(from_object, ["judge_autorater_config"]) is not None: + setv( + to_object, + ["judgeAutoraterConfig"], + getv(from_object, ["judge_autorater_config"]), + ) + + if getv(from_object, ["inline_rubrics"]) is not None: + setv( + to_object, + ["inline_rubrics", "rubrics"], + getv(from_object, ["inline_rubrics"]), + ) + + if getv(from_object, ["rubric_group_key"]) is not None: + setv(to_object, ["rubricGroupKey"], getv(from_object, ["rubric_group_key"])) + + if getv(from_object, ["rubric_generation_spec"]) is not None: + setv( + to_object, + ["rubricGenerationSpec"], + getv(from_object, ["rubric_generation_spec"]), + ) + + return to_object + + +def _UnifiedMetric_from_vertex( + from_object: Union[dict[str, Any], object], + parent_object: 
Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["bleuSpec"]) is not None: + setv(to_object, ["bleu_spec"], getv(from_object, ["bleuSpec"])) + + if getv(from_object, ["rougeSpec"]) is not None: + setv(to_object, ["rouge_spec"], getv(from_object, ["rougeSpec"])) + + if getv(from_object, ["pointwiseMetricSpec"]) is not None: + setv( + to_object, + ["pointwise_metric_spec"], + getv(from_object, ["pointwiseMetricSpec"]), + ) + + if getv(from_object, ["llmBasedMetricSpec"]) is not None: + setv( + to_object, + ["llm_based_metric_spec"], + getv(from_object, ["llmBasedMetricSpec"]), + ) + + if getv(from_object, ["customCodeExecutionSpec"]) is not None: + setv( + to_object, + ["custom_code_execution_spec"], + _CustomCodeExecutionSpec_from_vertex( + getv(from_object, ["customCodeExecutionSpec"]), to_object + ), + ) + + if getv(from_object, ["predefinedMetricSpec"]) is not None: + setv( + to_object, + ["predefined_metric_spec"], + getv(from_object, ["predefinedMetricSpec"]), + ) + + if getv(from_object, ["computationBasedMetricSpec"]) is not None: + setv( + to_object, + ["computation_based_metric_spec"], + getv(from_object, ["computationBasedMetricSpec"]), + ) + + return to_object + + +def _UnifiedMetric_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["bleu_spec"]) is not None: + setv(to_object, ["bleuSpec"], getv(from_object, ["bleu_spec"])) + + if getv(from_object, ["rouge_spec"]) is not None: + setv(to_object, ["rougeSpec"], getv(from_object, ["rouge_spec"])) + + if getv(from_object, ["pointwise_metric_spec"]) is not None: + setv( + to_object, + ["pointwiseMetricSpec"], + getv(from_object, ["pointwise_metric_spec"]), + ) + + if getv(from_object, ["llm_based_metric_spec"]) is not None: + setv( + to_object, + ["llmBasedMetricSpec"], + getv(from_object, ["llm_based_metric_spec"]), + ) + + if getv(from_object, ["custom_code_execution_spec"]) is not None: + setv( + to_object, + ["customCodeExecutionSpec"], + _CustomCodeExecutionSpec_to_vertex( + getv(from_object, ["custom_code_execution_spec"]), to_object + ), + ) + + if getv(from_object, ["predefined_metric_spec"]) is not None: + setv( + to_object, + ["predefinedMetricSpec"], + getv(from_object, ["predefined_metric_spec"]), + ) + + if getv(from_object, ["computation_based_metric_spec"]) is not None: + setv( + to_object, + ["computationBasedMetricSpec"], + getv(from_object, ["computation_based_metric_spec"]), + ) + + return to_object + + +class Evals(_api_module.BaseModule): + + def _create_evaluation_item( + self, + *, + evaluation_item_type: str, + gcs_uri: str, + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """ + Creates an EvaluationItem. + """ + + parameter_model = types._CreateEvaluationItemParameters( + evaluation_item_type=evaluation_item_type, + gcs_uri=gcs_uri, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _CreateEvaluationItemParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationItems".format_map(request_url_dict) + else: + path = "evaluationItems" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationItem._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_evaluation_metric( + self, + *, + display_name: Optional[str] = None, + description: Optional[str] = None, + metric: Optional[types.MetricOrDict] = None, + config: Optional[types.CreateEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """ + Creates an EvaluationMetric. + """ + + parameter_model = types._CreateEvaluationMetricParameters( + display_name=display_name, + description=description, + metric=metric, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationMetrics".format_map(request_url_dict) + else: + path = "evaluationMetrics" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
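+        # Every request method here repeats this shape: pop "config" from the
+        # request body, then honor a per-call config.http_options override when
+        # the caller supplied one. A hedged sketch, assuming this module is
+        # exposed as client.evals (HttpOptions contents illustrative):
+        #   >>> client.evals._create_evaluation_metric(
+        #   ...     display_name="my-metric",
+        #   ...     config=types.CreateEvaluationMetricConfig(
+        #   ...         http_options=types.HttpOptions(...)  # e.g. a per-call timeout
+        #   ...     ),
+        #   ... )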
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationMetric_from_vertex(response_dict) + + return_value = types.EvaluationMetric._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_evaluation_run( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + data_source: types.EvaluationRunDataSourceOrDict, + evaluation_config: types.EvaluationRunConfigOrDict, + labels: Optional[dict[str, str]] = None, + inference_configs: Optional[ + dict[str, types.EvaluationRunInferenceConfigOrDict] + ] = None, + config: Optional[types.CreateEvaluationRunConfigOrDict] = None, + ) -> types.EvaluationRun: + """ + Creates an EvaluationRun. + """ + + parameter_model = types._CreateEvaluationRunParameters( + name=name, + display_name=display_name, + data_source=data_source, + evaluation_config=evaluation_config, + labels=labels, + inference_configs=inference_configs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationRunParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationRuns".format_map(request_url_dict) + else: + path = "evaluationRuns" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
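+        # On the response side, Vertex camelCase payloads are mapped back to
+        # snake_case (via the *_from_vertex converters above) before being
+        # validated into typed models. Illustrative sketch with a truncated
+        # payload:
+        #   >>> raw = {"displayName": "run-1", "createTime": "2025-01-01T00:00:00Z"}
+        #   >>> _EvaluationRun_from_vertex(raw)
+        #   {'display_name': 'run-1', 'create_time': '2025-01-01T00:00:00Z'}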
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationRun_from_vertex(response_dict) + + return_value = types.EvaluationRun._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_evaluation_set( + self, + *, + evaluation_items: list[str], + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """ + Creates an EvaluationSet. + """ + + parameter_model = types._CreateEvaluationSetParameters( + evaluation_items=evaluation_items, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationSetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationSets".format_map(request_url_dict) + else: + path = "evaluationSets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationSet._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.DeleteEvaluationMetricConfigOrDict] = None, + ) -> types.DeleteEvaluationMetricOperation: + """ + Deletes an EvaluationMetric. 
+ """ + + parameter_model = types._DeleteEvaluationMetricParameters( + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{evaluation_metric}".format_map(request_url_dict) + else: + path = "{evaluation_metric}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteEvaluationMetricOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _evaluate_instances( + self, + *, + bleu_input: Optional[types.BleuInputOrDict] = None, + exact_match_input: Optional[types.ExactMatchInputOrDict] = None, + rouge_input: Optional[types.RougeInputOrDict] = None, + pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None, + pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None, + tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None, + tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None, + tool_parameter_key_match_input: Optional[ + types.ToolParameterKeyMatchInputOrDict + ] = None, + tool_parameter_kv_match_input: Optional[ + types.ToolParameterKVMatchInputOrDict + ] = None, + rubric_based_metric_input: Optional[types.RubricBasedMetricInputOrDict] = None, + autorater_config: Optional[genai_types.AutoraterConfigOrDict] = None, + metrics: Optional[list[types.MetricOrDict]] = None, + instance: Optional[types.EvaluationInstanceOrDict] = None, + metric_sources: Optional[list[types.MetricSourceOrDict]] = None, + config: Optional[types.EvaluateInstancesConfigOrDict] = None, + ) -> types.EvaluateInstancesResponse: + """ + Evaluates instances based on a given metric. 
+ """ + + parameter_model = types._EvaluateInstancesRequestParameters( + bleu_input=bleu_input, + exact_match_input=exact_match_input, + rouge_input=rouge_input, + pointwise_metric_input=pointwise_metric_input, + pairwise_metric_input=pairwise_metric_input, + tool_call_valid_input=tool_call_valid_input, + tool_name_match_input=tool_name_match_input, + tool_parameter_key_match_input=tool_parameter_key_match_input, + tool_parameter_kv_match_input=tool_parameter_kv_match_input, + rubric_based_metric_input=rubric_based_metric_input, + autorater_config=autorater_config, + metrics=metrics, + instance=instance, + metric_sources=metric_sources, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _EvaluateInstancesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":evaluateInstances".format_map(request_url_dict) + else: + path = ":evaluateInstances" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluateInstancesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _generate_user_scenarios( + self, + *, + location: Optional[str] = None, + agents: Optional[dict[str, evals_types.AgentConfigOrDict]] = None, + root_agent_id: Optional[str] = None, + user_scenario_generation_config: Optional[ + evals_types.UserScenarioGenerationConfigOrDict + ] = None, + config: Optional[types.GenerateUserScenariosConfigOrDict] = None, + allow_cross_region_model: Optional[bool] = None, + ) -> types.GenerateUserScenariosResponse: + """ + Generates user scenarios for agent evaluation. + """ + + parameter_model = types._GenerateUserScenariosParameters( + location=location, + agents=agents, + root_agent_id=root_agent_id, + user_scenario_generation_config=user_scenario_generation_config, + config=config, + allow_cross_region_model=allow_cross_region_model, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GenerateUserScenariosParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateUserScenarios".format_map(request_url_dict) + else: + path = ":generateUserScenarios" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateUserScenariosResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _generate_loss_clusters( + self, + *, + location: Optional[str] = None, + evaluation_set: Optional[str] = None, + inline_results: Optional[list[types.EvaluationResultOrDict]] = None, + configs: Optional[list[types.LossAnalysisConfigOrDict]] = None, + config: Optional[types.GenerateLossClustersConfigOrDict] = None, + ) -> types.GenerateLossClustersOperation: + """ + Generates loss clusters from evaluation results. + """ + + parameter_model = types._GenerateLossClustersParameters( + location=location, + evaluation_set=evaluation_set, + inline_results=inline_results, + configs=configs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateLossClusters".format_map(request_url_dict) + else: + path = ":generateLossClusters" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateLossClustersOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _generate_rubrics( + self, + *, + contents: list[genai_types.ContentOrDict], + predefined_rubric_generation_spec: Optional[ + genai_types.PredefinedMetricSpecOrDict + ] = None, + rubric_generation_spec: Optional[genai_types.RubricGenerationSpecOrDict] = None, + metric_resource_name: Optional[str] = None, + config: Optional[types.RubricGenerationConfigOrDict] = None, + ) -> types.GenerateInstanceRubricsResponse: + """ + Generates rubrics for a given prompt. + """ + + parameter_model = types._GenerateInstanceRubricsRequest( + contents=contents, + predefined_rubric_generation_spec=predefined_rubric_generation_spec, + rubric_generation_spec=rubric_generation_spec, + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateInstanceRubricsRequest_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateInstanceRubrics".format_map(request_url_dict) + else: + path = ":generateInstanceRubrics" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateInstanceRubricsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.GetEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """ + Retrieves an EvaluationMetric from the resource name. + """ + + parameter_model = types._GetEvaluationMetricParameters( + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{evaluation_metric}".format_map(request_url_dict) + else: + path = "{evaluation_metric}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationMetric_from_vertex(response_dict) + + return_value = types.EvaluationMetric._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_evaluation_run( + self, *, name: str, config: Optional[types.GetEvaluationRunConfigOrDict] = None + ) -> types.EvaluationRun: + """ + Retrieves an EvaluationRun from the resource name. 
+ """ + + parameter_model = types._GetEvaluationRunParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationRunParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationRuns/{name}".format_map(request_url_dict) + else: + path = "evaluationRuns/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationRun_from_vertex(response_dict) + + return_value = types.EvaluationRun._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_evaluation_set( + self, *, name: str, config: Optional[types.GetEvaluationSetConfigOrDict] = None + ) -> types.EvaluationSet: + """ + Retrieves an EvaluationSet from the resource name. + """ + + parameter_model = types._GetEvaluationSetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationSetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationSets/{name}".format_map(request_url_dict) + else: + path = "evaluationSets/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationSet._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_evaluation_item( + self, *, name: str, config: Optional[types.GetEvaluationItemConfigOrDict] = None + ) -> types.EvaluationItem: + """ + Retrieves an EvaluationItem from the resource name. + """ + + parameter_model = types._GetEvaluationItemParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationItemParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationItems/{name}".format_map(request_url_dict) + else: + path = "evaluationItems/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationItem._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list_evaluation_metrics( + self, *, config: Optional[types.ListEvaluationMetricsConfigOrDict] = None + ) -> types.ListEvaluationMetricsResponse: + """ + Lists EvaluationMetrics. + """ + + parameter_model = types._ListEvaluationMetricsParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
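+            # _list_evaluation_metrics supports standard list pagination through
+            # config.page_size / config.page_token and the response's
+            # next_page_token. Hedged paging sketch (values illustrative):
+            #   >>> page = client.evals._list_evaluation_metrics(
+            #   ...     config=types.ListEvaluationMetricsConfig(page_size=10)
+            #   ... )
+            #   >>> while page.next_page_token:
+            #   ...     page = client.evals._list_evaluation_metrics(
+            #   ...         config=types.ListEvaluationMetricsConfig(
+            #   ...             page_size=10, page_token=page.next_page_token
+            #   ...         )
+            #   ...     )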
+            )
+        else:
+            request_dict = _ListEvaluationMetricsParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "evaluationMetrics".format_map(request_url_dict)
+            else:
+                path = "evaluationMetrics"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        if self._api_client.vertexai:
+            response_dict = _ListEvaluationMetricsResponse_from_vertex(response_dict)
+
+        return_value = types.ListEvaluationMetricsResponse._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def evaluate_instances(
+        self,
+        *,
+        metric_config: types._EvaluateInstancesRequestParameters,
+    ) -> types.EvaluateInstancesResponse:
+        """Evaluates instances based on the given metric config."""
+
+        if isinstance(metric_config, types._EvaluateInstancesRequestParameters):
+            metric_config = metric_config.model_dump()  # type: ignore[assignment]
+        else:
+            metric_config = dict(metric_config)
+
+        return self._evaluate_instances(
+            **metric_config,
+        )
+
+    def run_inference(
+        self,
+        *,
+        src: Union[str, pd.DataFrame, types.EvaluationDataset],
+        model: Optional[Union[str, Callable[[Any], Any]]] = None,
+        agent: Optional[Union[str, types.AgentEngine, LlmAgent]] = None,
+        location: Optional[str] = None,
+        config: Optional[types.EvalRunInferenceConfigOrDict] = None,
+    ) -> types.EvaluationDataset:
+        """Runs inference on a dataset for evaluation.
+
+        Args:
+            src: The source of the dataset. Can be a string (path to a local file,
+                a GCS path, or a BigQuery table), a Pandas DataFrame, or an
+                EvaluationDataset object. An EvaluationDataset may have either
+                ``eval_dataset_df`` or ``eval_cases`` populated. When
+                ``eval_cases`` with ``agent_data`` is provided, the last user
+                event in the turns is used as the current prompt and prior
+                events are replayed as session history for local ADK agents.
+            model: The model to use for inference. Optional for agent
+                evaluations. This field is experimental and may change in future
+                versions.
+                - For Google Gemini models, provide the model name string (e.g., "gemini-2.5-flash").
+                - For third-party models via LiteLLM, use the format "provider/model_name"
+                  (e.g., "openai/gpt-4o"). Ensure the necessary API key (e.g., OPENAI_API_KEY)
+                  is set as an environment variable.
+                - For custom logic, provide a callable function that accepts a prompt and
+                  returns a response.
+            agent: The deployed agent engine or local agent to run. Optional for
+                non-agent evaluations. This field is experimental and may change
+                in future versions. Can be one of:
+                - An agent engine resource name string, in the format
+                  `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`;
+                  run_inference will fetch the agent engine from the resource name.
+                - A `types.AgentEngine` object.
+                - An ADK agent of `LlmAgent` type.
+            location: The location to use for the inference. If not specified, the
+                location configured in the client will be used. If specified,
+                this will override the location set in `vertexai.Client` only
+                for this API call.
+            config: The optional configuration for the inference run. Must be a
+                dict or a `types.EvalRunInferenceConfig` object.
+                - dest: The destination path for storage of the inference results.
+                - prompt_template: The template string to use for constructing prompts.
+                - generate_content_config: The config for the Gemini generate content call.
+                - allow_cross_region_model: Opt-in flag to authorize cross-region routing for LLM models.
+
+        Returns:
+            The evaluation dataset.
+        """
+        if not config:
+            config = types.EvalRunInferenceConfig()
+        if isinstance(config, dict):
+            config = types.EvalRunInferenceConfig.model_validate(config)
+
+        if isinstance(src, types.EvaluationDataset):
+            if src.eval_dataset_df is not None:
+                src = src.eval_dataset_df
+            elif src.eval_cases:
+                src = _evals_common._eval_cases_to_dataframe(src.eval_cases)
+            else:
+                raise ValueError(
+                    "EvaluationDataset must have eval_dataset_df or eval_cases"
+                    " populated."
+                )
+
+        agent_engine_instance = None
+        agent_instance = None
+        if agent:
+            if isinstance(agent, (str, types.AgentEngine)):
+                agent_engine_instance = agent
+            else:
+                agent_instance = agent
+
+        return _evals_common._execute_inference(  # type: ignore[no-any-return]
+            api_client=self._api_client,
+            model=model,
+            agent_engine=agent_engine_instance,
+            agent=agent_instance,
+            src=src,
+            dest=config.dest,
+            prompt_template=config.prompt_template,
+            location=location,
+            config=config.generate_content_config,
+            user_simulator_config=getattr(config, "user_simulator_config", None),
+            allow_cross_region_model=getattr(config, "allow_cross_region_model", False),
+        )
+
+    def evaluate(
+        self,
+        *,
+        dataset: Union[
+            pd.DataFrame,
+            types.EvaluationDatasetOrDict,
+            list[types.EvaluationDatasetOrDict],
+        ],
+        metrics: Optional[list[types.MetricOrDict]] = None,
+        location: Optional[str] = None,
+        config: Optional[types.EvaluateMethodConfigOrDict] = None,
+        **kwargs: Any,
+    ) -> types.EvaluationResult:
+        """Evaluates candidate responses in the provided dataset(s) using the specified metrics.
+
+        Args:
+            dataset: The dataset(s) to evaluate. Can be a pandas DataFrame, a single
+                `types.EvaluationDataset`, or a list of `types.EvaluationDataset`.
+            metrics: The list of metrics to use for evaluation.
+            location: The location to use for the evaluation service. If not specified,
+                the location configured in the client will be used. If specified,
+                this will override the location set in `vertexai.Client` only for
+                this API call.
+            config: Optional configuration for the evaluation. Can be a dictionary or a
+                `types.EvaluateMethodConfig` object.
+                - dataset_schema: Schema to use for the dataset. If not specified, the
+                  dataset schema will be inferred from the dataset automatically.
+                - dest: Destination path for storing evaluation results.
+                - evaluation_service_qps: The rate limit (queries per second) for
+                  calls to the evaluation service. Defaults to 10. Increase this
+                  value if your project has a higher EvaluateInstances API quota.
+            **kwargs: Extra arguments to pass to the evaluation, such as `agent_info`.
+
+        Returns:
+            The evaluation result.
+        """
+        if not config:
+            config = types.EvaluateMethodConfig()
+        if isinstance(config, dict):
+            config = types.EvaluateMethodConfig.model_validate(config)
+
+        if isinstance(dataset, pd.DataFrame):
+            dataset = types.EvaluationDataset(eval_dataset_df=dataset)
+
+        if isinstance(dataset, list):
+            dataset = [
+                (
+                    types.EvaluationDataset.model_validate(ds_item)
+                    if isinstance(ds_item, dict)
+                    else ds_item
+                )
+                for ds_item in dataset
+            ]
+        elif isinstance(dataset, dict):
+            dataset = types.EvaluationDataset.model_validate(dataset)
+        if metrics is None:
+            metrics = [types.Metric(name="general_quality_v1")]
+
+        # TODO: Replace kwargs with agent_info after the experimental phase.
+        if kwargs:
+            logger.warning(
+                "The `kwargs` argument to the `evaluate` method is experimental"
+                " and may change in future versions."
+            )
+
+        return _evals_common._execute_evaluation(
+            api_client=self._api_client,
+            dataset=dataset,
+            metrics=metrics,
+            dataset_schema=config.dataset_schema,
+            dest=config.dest,
+            location=location,
+            evaluation_service_qps=getattr(config, "evaluation_service_qps", None),
+            **kwargs,
+        )
+
+    def batch_evaluate(
+        self,
+        *,
+        dataset: types.EvaluationDatasetOrDict,
+        metrics: list[types.MetricOrDict],
+        dest: str,
+        config: Optional[types.EvaluateDatasetConfigOrDict] = None,
+    ) -> types.EvaluateDatasetOperation:
+        """Evaluates a dataset based on a set of given metrics."""
+
+        resolved_metrics = _evals_common._resolve_metrics(metrics, self._api_client)
+        output_config = genai_types.OutputConfig(
+            gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest)
+        )
+        parameter_model = types.EvaluateDatasetRequestParameters(
+            dataset=dataset,
+            metrics=resolved_metrics,
+            output_config=output_config,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _evals_utils.BatchEvaluateRequestPreparer.EvaluateDatasetRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = ":evaluateDataset".format_map(request_url_dict)
+            else:
+                path = ":evaluateDataset"
+
+        request_dict = _evals_utils.BatchEvaluateRequestPreparer.prepare_metric_payload(
+            request_dict, resolved_metrics
+        )
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
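+        # run_inference -> evaluate is the client-side loop, while
+        # batch_evaluate offloads scoring to the service and returns a
+        # long-running operation. A hedged end-to-end sketch, assuming this
+        # module is exposed as client.evals and that run_inference accepts a
+        # "prompt" column (names illustrative):
+        #   >>> import pandas as pd
+        #   >>> df = pd.DataFrame({"prompt": ["Say hi", "Say bye"]})
+        #   >>> ds = client.evals.run_inference(src=df, model="gemini-2.5-flash")
+        #   >>> result = client.evals.evaluate(
+        #   ...     dataset=ds, metrics=[types.Metric(name="general_quality_v1")]
+        #   ... )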
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("post", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        if self._api_client.vertexai:
+            response_dict = _evals_utils.BatchEvaluateRequestPreparer.EvaluateDatasetOperation_from_vertex(
+                response_dict
+            )
+
+        return_value = types.EvaluateDatasetOperation._from_response(
+            response=response_dict, kwargs=parameter_model.model_dump()
+        )
+        self._api_client._verify_response(return_value)
+
+        return return_value
+
+    def generate_rubrics(
+        self,
+        *,
+        src: Union[str, "pd.DataFrame", types.EvaluationDataset],
+        rubric_group_name: str,
+        prompt_template: Optional[str] = None,
+        generator_model_config: Optional["genai_types.AutoraterConfigOrDict"] = None,
+        rubric_content_type: Optional["types.RubricContentType"] = None,
+        rubric_type_ontology: Optional[list[str]] = None,
+        predefined_spec_name: Optional[Union[str, "types.PrebuiltMetric"]] = None,
+        metric_spec_parameters: Optional[dict[str, Any]] = None,
+        metric: Optional[types.MetricOrDict] = None,
+        config: Optional[types.RubricGenerationConfigOrDict] = None,
+    ) -> types.EvaluationDataset:
+        """Generates rubrics for each prompt in the source and adds them as a new
+        column structured as a dictionary.
+
+        You can generate rubrics in one of three ways:
+        1. Provide a `metric` to use a pre-registered metric resource.
+        2. Provide a `predefined_spec_name` to use a Vertex AI backend recipe.
+        3. Provide a `prompt_template` along with other configuration parameters
+           (`generator_model_config`, `rubric_content_type`, `rubric_type_ontology`)
+           for custom rubric generation.
+
+        These modes are mutually exclusive; provide only one of them. If more
+        than one is supplied, `metric` takes precedence over
+        `predefined_spec_name`, and `predefined_spec_name` takes precedence over
+        `prompt_template`.
+
+        Args:
+            src: The source of the prompts. Can be a string (path to a local
+                file, a GCS path, or a BigQuery table), a Pandas DataFrame, or
+                an EvaluationDataset object. The loaded data must contain either
+                a 'prompt' column (for text) or a 'request' column (for text or
+                multimodal Gemini Content).
+            rubric_group_name: Name for the key within the dictionary in the new
+                column.
+            prompt_template: Optional. Template for the rubric generation prompt. Used for
+                custom rubric generation. Mutually exclusive with `predefined_spec_name`.
+                If using a 'prompt' column, use {prompt} as the placeholder. If using a
+                'request' column, this template is passed to the service along
+                with the content.
+            generator_model_config: Optional. Configuration for the model used
+                in custom rubric generation. Only used if `prompt_template` is provided.
+                e.g., {"autorater_model": "gemini-2.5-flash"}.
+            rubric_content_type: Optional. The type of rubric content to be
+                generated. Only used if `prompt_template` is provided.
+            rubric_type_ontology: Optional. A pre-defined list of allowed types
+                for generated rubrics. Only used if `prompt_template` is provided.
+            predefined_spec_name: Optional. The name of a Predefined Metric to use
+                for rubric generation (e.g., "general_quality_v1") or a types.PrebuiltMetric object.
+                Mutually exclusive with `prompt_template` and its related parameters.
+ metric_spec_parameters: Optional. Parameters for the Predefined Metric, + used to customize rubric generation. Only used if `predefined_spec_name` is set. + Example: {"guidelines": ["The response must be in Japanese."]} + metric: Optional. A types.Metric object containing a metric_resource_name, + or a resource name string. If provided, this will take precedence over + predefined_spec_name and prompt_template. + config: Optional. Configuration for the rubric generation process. + + Returns: + An `EvaluationDataset` with an added column named `rubric_groups` in its + `eval_dataset_df`. Each cell in this column contains a dictionary like: + {rubric_group_name: [list[Rubric]]}. + """ + if isinstance(src, types.EvaluationDataset): + if src.eval_dataset_df is not None: + prompts_df = src.eval_dataset_df + elif src.eval_cases: + prompts_df = _evals_common._eval_cases_to_dataframe(src.eval_cases) + else: + raise ValueError( + "EvaluationDataset must have eval_dataset_df or eval_cases" + " populated." + ) + elif isinstance(src, (str, pd.DataFrame)): + try: + prompts_df = _evals_common._load_dataframe(self._api_client, src) + except Exception as e: + raise ValueError( + f"Failed to load prompt dataset from source: {src}. Error: {e}" + ) + else: + raise TypeError( + "Unsupported type for src. Must be str, pd.DataFrame, or types.EvaluationDataset." + ) + + if "prompt" not in prompts_df.columns and "request" not in prompts_df.columns: + raise ValueError("Loaded dataset must have a 'prompt' or 'request' column.") + if not rubric_group_name: + raise ValueError("rubric_group_name cannot be empty.") + + input_column = "request" if "request" in prompts_df.columns else "prompt" + logger.info( + "Generating rubrics for %d prompts from column '%s', group: '%s'...", + len(prompts_df), + input_column, + rubric_group_name, + ) + all_rubric_groups: list[dict[str, list[types.Rubric]]] = [] + + actual_metric_resource_name = None + if metric: + if isinstance(metric, str) and metric.startswith("projects/"): + actual_metric_resource_name = metric + else: + metric_obj = ( + types.Metric.model_validate(metric) + if isinstance(metric, dict) + else metric + ) + actual_metric_resource_name = getattr( + metric_obj, "metric_resource_name", None + ) + if not actual_metric_resource_name: + raise ValueError( + "The provided Metric object must have metric_resource_name set." + ) + + rubric_gen_spec = None + predefined_spec = None + + if actual_metric_resource_name: + # Precedence: Registered metric resource overrides everything else. + predefined_spec = None + rubric_gen_spec = None + elif predefined_spec_name: + if prompt_template: + logger.warning( + "prompt_template is ignored when predefined_spec_name is provided." + ) + if generator_model_config: + logger.warning( + "generator_model_config is ignored when predefined_spec_name is provided." + ) + if rubric_content_type: + logger.warning( + "rubric_content_type is ignored when predefined_spec_name is provided." + ) + if rubric_type_ontology: + logger.warning( + "rubric_type_ontology is ignored when predefined_spec_name is provided." 
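+                # Precedence recap: metric wins over predefined_spec_name, which
+                # wins over prompt_template; losing parameters are warned about
+                # and ignored, as above. Hedged sketch of the predefined-recipe
+                # path (df: a DataFrame with a 'prompt' column; other values
+                # illustrative):
+                #   >>> ds = client.evals.generate_rubrics(
+                #   ...     src=df,
+                #   ...     rubric_group_name="quality_rubrics",
+                #   ...     predefined_spec_name="general_quality_v1",
+                #   ... )
+                #   >>> ds.eval_dataset_df["rubric_groups"][0]
+                #   {'quality_rubrics': [...]}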
+ ) + + if isinstance(predefined_spec_name, str): + actual_predefined_spec_name = predefined_spec_name + elif hasattr( + predefined_spec_name, "resolve" + ): # Check if it's LazyLoadedPrebuiltMetric + try: + resolved_metric = predefined_spec_name.resolve(self._api_client) + actual_predefined_spec_name = resolved_metric.name + except Exception as e: + raise ValueError(f"Failed to resolve PrebuiltMetric: {e}") + else: + raise TypeError( + "predefined_spec_name must be a string or types.PrebuiltMetric" + ) + + if not actual_predefined_spec_name: + raise ValueError( + "Could not determine metric_spec_name from predefined_spec_name" + ) + + predefined_spec = genai_types.PredefinedMetricSpec( + metric_spec_name=actual_predefined_spec_name, + metric_spec_parameters=metric_spec_parameters, + ) + elif prompt_template: + if metric_spec_parameters: + logger.warning( + "metric_spec_parameters is ignored when prompt_template is provided." + ) + spec_dict = { + "prompt_template": prompt_template, + "rubric_content_type": rubric_content_type, + "rubric_type_ontology": rubric_type_ontology, + "generator_model_config": generator_model_config, + } + spec_dict = {k: v for k, v in spec_dict.items() if v is not None} + rubric_gen_spec = genai_types.RubricGenerationSpec.model_validate(spec_dict) + else: + raise ValueError( + "Either metric, predefined_spec_name or prompt_template must be provided." + ) + + for _, row in prompts_df.iterrows(): + input_data = row[input_column] + if isinstance(input_data, str): + contents = [ + genai_types.Content(parts=[genai_types.Part(text=input_data)]) + ] + elif isinstance(input_data, list): + contents = input_data + else: + logger.warning( + f"Skipping row: Unexpected input format in column '{input_column}'." + ) + all_rubric_groups.append({rubric_group_name: []}) + continue + + try: + response = self._generate_rubrics( + contents=contents, + rubric_generation_spec=rubric_gen_spec, + predefined_rubric_generation_spec=predefined_spec, + metric_resource_name=actual_metric_resource_name, + config=config, + ) + rubric_group = {rubric_group_name: response.generated_rubrics} + all_rubric_groups.append(rubric_group) + except Exception as e: + logger.error( + "Rubric generation failed for input: %s... Error: %s", + str(input_data)[:50], + e, + exc_info=True, + ) + all_rubric_groups.append({rubric_group_name: []}) + + prompts_with_rubrics = prompts_df.copy() + prompts_with_rubrics["rubric_groups"] = all_rubric_groups + logger.info( + f"Rubric generation complete. Added column 'rubric_groups' with key '{rubric_group_name}'." + ) + return types.EvaluationDataset(eval_dataset_df=prompts_with_rubrics) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_run module is experimental, " + "and may change in future versions." + ) + def get_evaluation_run( + self, + *, + name: str, + include_evaluation_items: bool = False, + config: Optional[types.GetEvaluationRunConfigOrDict] = None, + ) -> types.EvaluationRun: + """Retrieves an EvaluationRun from the resource name. + Args: + name: The resource name of the EvaluationRun. Format: + `projects/{project}/locations/{location}/evaluationRuns/{evaluation_run}` + include_evaluation_items: Whether to include the evaluation items in the + response. + config: The optional configuration for the evaluation run. Must be a dict or + `types.GetEvaluationRunConfigOrDict` type. + + Returns: + The evaluation run. + Raises: + ValueError: If the name is empty or invalid. 
+ """ + if not name: + raise ValueError("name cannot be empty.") + if name.startswith("projects/"): + name = name.split("/")[-1] + result = self._get_evaluation_run(name=name, config=config) + if include_evaluation_items: + eval_result, eval_item_map = _evals_common._convert_evaluation_run_results( + self._api_client, + result.evaluation_run_results, + result.inference_configs, + ) + result.evaluation_item_results = eval_result + # Bypass pydantic validation (extra='forbid') for this internal field. + object.__setattr__(result, "_eval_item_map", eval_item_map) + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_run module is experimental, " + "and may change in future versions." + ) + def create_evaluation_run( + self, + *, + dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset], + dest: str, + metrics: list[types.EvaluationRunMetricOrDict], + name: Optional[str] = None, + display_name: Optional[str] = None, + agent_info: Optional[evals_types.AgentInfoOrDict] = None, + agent: Optional[str] = None, + user_simulator_config: Optional[evals_types.UserSimulatorConfigOrDict] = None, + inference_configs: Optional[ + dict[str, types.EvaluationRunInferenceConfigOrDict] + ] = None, + labels: Optional[dict[str, str]] = None, + loss_analysis_metrics: Optional[list[Union[str, types.MetricOrDict]]] = None, + loss_analysis_configs: Optional[list[types.LossAnalysisConfigOrDict]] = None, + config: Optional[types.CreateEvaluationRunConfigOrDict] = None, + ) -> types.EvaluationRun: + """Creates an EvaluationRun. + + Args: + dataset: The dataset to evaluate. Either an EvaluationRunDataSource or an EvaluationDataset. + dest: The GCS URI prefix to write the evaluation results to. + metrics: The list of metrics to evaluate. + name: The name of the evaluation run. + display_name: The display name of the evaluation run. + agent_info: The agent info to evaluate. Mutually exclusive with + `inference_configs`. + agent: The agent engine resource name in str type, with format + `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`. + If provided, runs inference with the deployed agent to get agent responses + for evaluation. This is required if `agent_info` is provided. + user_simulator_config: The user simulator configuration for agent evaluation. + If `agent_info` is provided without `inference_configs`, this config is used + to automatically construct the inference configuration. If not specified, + or if `max_turn` is not set, `max_turn` defaults to 5. + The `model_name` inside this config can be either a full model path or a + short model name, e.g. `gemini-3-preview-flash`. + inference_configs: The candidate to inference config map for the evaluation run. + The key is the candidate name, and the value is the inference config. + If provided, `agent_info` must be None. If omitted and `agent_info` is provided, + this will be automatically constructed using `agent_info` and `user_simulator_config`. + Example: + {"candidate-1": types.EvaluationRunInferenceConfig(model="gemini-2.5-flash")} + labels: The labels to apply to the evaluation run. + loss_analysis_metrics: This field is experimental and may change in future + versions. Optional list of metrics to run loss analysis on. The + candidate is auto-inferred from ``inference_configs`` or + ``agent_info`` when there is exactly one candidate. 
Each metric can be + a string (e.g., ``"multi_turn_task_success_v1"``), a ``Metric`` + object, or a ``RubricMetric`` enum + (e.g., ``types.RubricMetric.MULTI_TURN_TASK_SUCCESS``). Loss analysis + runs after metric calculation completes. + Mutually exclusive with ``loss_analysis_configs``. + Example:: + + loss_analysis_metrics=[ + types.RubricMetric.MULTI_TURN_TASK_SUCCESS, + types.RubricMetric.MULTI_TURN_TOOL_USE_QUALITY, + ] + loss_analysis_configs: This field is experimental and may change in future + versions. Optional list of ``LossAnalysisConfig`` objects for full + control over loss analysis, including explicit candidate and + advanced options like ``predefined_taxonomy`` and + ``max_top_cluster_count``. Mutually exclusive with + ``loss_analysis_metrics``. + config: The configuration for the evaluation run. + + Returns: + The created evaluation run. + """ + if loss_analysis_metrics and loss_analysis_configs: + raise ValueError( + "At most one of loss_analysis_metrics or loss_analysis_configs" + " can be provided." + ) + if agent_info and inference_configs: + raise ValueError( + "At most one of agent_info or inference_configs can be provided." + ) + parsed_agent_info = ( + evals_types.AgentInfo.model_validate(agent_info) + if isinstance(agent_info, dict) + else (agent_info or evals_types.AgentInfo()) + ) + + if agent_info and not inference_configs: + parsed_user_simulator_config = ( + evals_types.UserSimulatorConfig.model_validate(user_simulator_config) + if isinstance(user_simulator_config, dict) + else (user_simulator_config or evals_types.UserSimulatorConfig()) + ) + if getattr(parsed_user_simulator_config, "max_turn", None) is None: + parsed_user_simulator_config.max_turn = 5 + + candidate_name = parsed_agent_info.name or "candidate-1" + inference_configs = { + candidate_name: types.EvaluationRunInferenceConfig( + agent_configs=parsed_agent_info.agents, + agent_run_config=types.AgentRunConfig( + agent_engine=agent, + user_simulator_config=parsed_user_simulator_config, + ), + ) + } + + if isinstance(dataset, types.EvaluationDataset): + _evals_utils._validate_dataset_agent_data(dataset, inference_configs) + resolved_dataset = _evals_common._resolve_dataset( + self._api_client, dataset, dest, parsed_agent_info + ) + output_config = genai_types.OutputConfig( + gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest) + ) + resolved_metrics = _evals_common._resolve_evaluation_run_metrics( + metrics, self._api_client + ) + resolved_loss_configs = _evals_utils._resolve_eval_run_loss_configs( + loss_analysis_metrics=loss_analysis_metrics, + loss_analysis_configs=loss_analysis_configs, + inference_configs=inference_configs, + ) + evaluation_config = types.EvaluationRunConfig( + output_config=output_config, + metrics=resolved_metrics, + loss_analysis_config=resolved_loss_configs, + ) + resolved_inference_configs = _evals_common._resolve_inference_configs( + self._api_client, resolved_dataset, inference_configs, parsed_agent_info + ) + resolved_labels = _evals_common._add_evaluation_run_labels(labels, agent) + resolved_name = name or f"evaluation_run_{uuid.uuid4()}" + return self._create_evaluation_run( + name=resolved_name, + display_name=display_name or resolved_name, + data_source=resolved_dataset, + evaluation_config=evaluation_config, + inference_configs=resolved_inference_configs, + labels=resolved_labels, + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_set method is experimental, " + "and may change in future versions." 
+ ) + def get_evaluation_set( + self, + *, + name: str, + config: Optional[types.GetEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """Retrieves an EvaluationSet from the resource name. + + Args: + name: The resource name of the EvaluationSet. Format: + `projects/{project}/locations/{location}/evaluationSets/{evaluation_set}` + config: The optional configuration for the evaluation set. Must be a dict or + `types.GetEvaluationSetConfigOrDict` type. + + Returns: + The evaluation set. + """ + + if not name: + raise ValueError("name cannot be empty.") + if name.startswith("projects/"): + name = name.split("/")[-1] + return self._get_evaluation_set(name=name, config=config) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_item method is experimental, " + "and may change in future versions." + ) + def get_evaluation_item( + self, + *, + name: str, + config: Optional[types.GetEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """Retrieves an EvaluationItem from the resource name. + + Args: + name: The resource name of the EvaluationItem. Format: + `projects/{project}/locations/{location}/evaluationItems/{evaluation_item}` + config: The optional configuration for the evaluation item. Must be a dict or + `types.GetEvaluationItemConfigOrDict` type. + + Returns: + The evaluation item. + """ + if not name: + raise ValueError("name cannot be empty.") + if name.startswith("projects/"): + name = name.split("/")[-1] + result = self._get_evaluation_item(name=name, config=config) + if ( + result.gcs_uri + and result.evaluation_item_type == types.EvaluationItemType.RESULT + ): + result.evaluation_response = ( + _evals_common._convert_gcs_to_evaluation_item_result( + self._api_client, result.gcs_uri + ) + ) + elif ( + result.gcs_uri + and result.evaluation_item_type == types.EvaluationItemType.REQUEST + ): + result.evaluation_request = ( + _evals_common._convert_gcs_to_evaluation_item_request( + self._api_client, result.gcs_uri + ) + ) + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_item module is experimental, " + "and may change in future versions." + ) + def create_evaluation_item( + self, + *, + evaluation_item_type: types.EvaluationItemType, + gcs_uri: str, + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """Creates an EvaluationItem. + + Args: + evaluation_item_type: The type of the evaluation item. + gcs_uri: The GCS URI of the evaluation item. + display_name: The display name of the evaluation item. + config: The optional configuration for the evaluation item. Must be a dict or + `types.CreateEvaluationItemConfigOrDict` type. + + Returns: + The evaluation item. + """ + return self._create_evaluation_item( + evaluation_item_type=evaluation_item_type, + gcs_uri=gcs_uri, + display_name=display_name, + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_set module is experimental, " + "and may change in future versions." + ) + def create_evaluation_set( + self, + *, + evaluation_items: list[str], + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """Creates an EvaluationSet. + + Args: + evaluation_items: The list of evaluation item names. 
Format: + `projects/{project}/locations/{location}/evaluationItems/{evaluation_item}` + display_name: The display name of the evaluation set. + config: The optional configuration for the evaluation set. Must be a dict or + `types.CreateEvaluationSetConfigOrDict` type. + + Returns: + The evaluation set. + """ + return self._create_evaluation_set( + evaluation_items=evaluation_items, + display_name=display_name, + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.generate_conversation_scenarios module is experimental, " + "and may change in future versions." + ) + def generate_conversation_scenarios( + self, + *, + agent_info: evals_types.AgentInfoOrDict, + config: evals_types.UserScenarioGenerationConfigOrDict, + allow_cross_region_model: Optional[bool] = None, + ) -> types.EvaluationDataset: + """Generates an evaluation dataset with user scenarios, + which helps to generate conversations between a simulated user + and the agent under test. + + Args: + agent_info: The agent info to generate user scenarios for. + config: Configuration for generating user scenarios. + allow_cross_region_model: Opt-in flag to authorize cross-region + routing for model inference. + + Returns: + An EvaluationDataset containing the generated user scenarios. + """ + parsed_agent_info = ( + evals_types.AgentInfo.model_validate(agent_info) + if isinstance(agent_info, dict) + else agent_info + ) + response = self._generate_user_scenarios( + agents=parsed_agent_info.agents, + root_agent_id=parsed_agent_info.root_agent_id, + user_scenario_generation_config=config, + allow_cross_region_model=allow_cross_region_model, + ) + return _evals_utils._postprocess_user_scenarios_response(response) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.generate_loss_clusters module is experimental, " + "and may change in future versions." + ) + def generate_loss_clusters( + self, + *, + eval_result: types.EvaluationResult, + metric: Optional[Union[str, types.MetricOrDict]] = None, + candidate: Optional[str] = None, + config: Optional[types.LossAnalysisConfigOrDict] = None, + ) -> types.GenerateLossClustersResponse: + """Generates loss clusters from evaluation results. + + Analyzes "Pass/Fail" signals from rubric-based autoraters and groups + them into semantic "Loss Patterns" (e.g., "Hallucination of Action"). + + This method calls the GenerateLossClusters LRO and polls until + completion, returning the results directly. + + If ``metric`` or ``candidate`` are not provided, they will be + auto-inferred from ``eval_result`` when unambiguous (i.e., when the + eval result contains exactly one metric or one candidate). For + multi-metric or multi-candidate evaluations, provide them explicitly. + + Available candidate names can be found in + ``eval_result.metadata.candidate_names``. + + Note: This API is only available in the ``global`` region. + + Args: + eval_result: The EvaluationResult object returned from + client.evals.evaluate(). + metric: The metric to analyze. Can be a metric name string + (e.g., "multi_turn_task_success_v1"), a Metric object, or a + RubricMetric enum (e.g., types.RubricMetric.MULTI_TURN_TASK_SUCCESS). + If not provided and config does not specify it, auto-inferred + from eval_result. + candidate: The candidate to analyze. If not provided and config + does not specify it, auto-inferred from eval_result. + config: Optional LossAnalysisConfig with additional options + (predefined_taxonomy, max_top_cluster_count). 
Can also + specify metric/candidate, but explicit arguments take + precedence. + + Returns: + A GenerateLossClustersResponse containing the analysis results. + Call .show() to visualize, or access .results for individual + LossAnalysisResult objects (each with their own .show()). + """ + metric_name = _evals_utils._resolve_metric_name(metric) + parsed_config = ( + types.LossAnalysisConfig.model_validate(config) + if isinstance(config, dict) + else config + ) + resolved_config = _evals_utils._resolve_loss_analysis_config( + eval_result=eval_result, + config=parsed_config, + metric=metric_name, + candidate=candidate, + ) + operation = self._generate_loss_clusters( + inline_results=[eval_result], + configs=[resolved_config], + ) + completed = _evals_utils._poll_operation( + api_client=self._api_client, + operation=operation, + ) + if completed.error: + raise RuntimeError(f"Loss analysis operation failed: {completed.error}") + if completed.response is None: + raise RuntimeError( + "Loss analysis operation completed but returned no response." + ) + _evals_utils._enrich_loss_response_with_rubric_descriptions( + completed.response, eval_result + ) + return completed.response + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_metric method is experimental, " + "and may change in future versions." + ) + def create_evaluation_metric( + self, + *, + display_name: Optional[str] = None, + description: Optional[str] = None, + metric: Optional[types.MetricOrDict] = None, + config: Optional[types.CreateEvaluationMetricConfigOrDict] = None, + ) -> str: + """Creates an EvaluationMetric.""" + if metric and not isinstance(metric, dict): + # metric is now Metric | LazyLoadedPrebuiltMetric (RubricMetric) + # Mypy correctly narrows the type here, so cast is not needed. + resolved_metrics = _evals_common._resolve_metrics( + [metric], self._api_client + ) + metric = resolved_metrics[0] + + # Add fallback logic for display_name + if display_name is None and metric: + if isinstance(metric, dict): + display_name = metric.get("name") + else: + display_name = getattr(metric, "name", None) + + result = self._create_evaluation_metric( + display_name=display_name, + description=description, + metric=metric, + config=config, + ) + # result.name is Optional[str], but we know it's always returned on creation + return cast(str, result.name) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_metric module is experimental, " + "and may change in future versions." + ) + def get_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.GetEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """Retrieves an EvaluationMetric from the resource name.""" + return self._get_evaluation_metric( + metric_resource_name=metric_resource_name, + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.list_evaluation_metrics module is experimental, " + "and may change in future versions." + ) + def list_evaluation_metrics( + self, + *, + filter: Optional[str] = None, + order_by: Optional[str] = None, + config: Optional[types.ListEvaluationMetricsConfigOrDict] = None, + ) -> types.ListEvaluationMetricsResponse: + """Lists EvaluationMetrics. + + Args: + filter: An expression for filtering the results of the request. For + field names both snake_case and camelCase are supported. For more + information about filter syntax, see + `AIP-160 `_. + Example: ``'display_name="my_metric"'``. 
+ order_by: A comma-separated list of fields to order by, sorted in + ascending order by default. Use ``desc`` after a field name for + descending. Example: ``"create_time desc"``. + config: Optional configuration for the list operation, including + pagination (``page_size``, ``page_token``), ``filter``, and + ``order_by``. Top-level ``filter`` and ``order_by`` arguments + take precedence over values set in ``config``. + + Returns: + The list evaluation metrics response. + """ + if config is None: + config = types.ListEvaluationMetricsConfig() + if isinstance(config, dict): + config = types.ListEvaluationMetricsConfig.model_validate(config) + if filter is not None: + config.filter = filter + if order_by is not None: + config.order_by = order_by + return self._list_evaluation_metrics( + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.delete_evaluation_metric method is experimental, " + "and may change in future versions." + ) + def delete_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.DeleteEvaluationMetricConfigOrDict] = None, + ) -> None: + """Deletes an EvaluationMetric. + + Args: + metric_resource_name: The resource name of the EvaluationMetric to delete. + Format: + `projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric}` + config: The optional configuration for the delete operation. + """ + self._delete_evaluation_metric( + metric_resource_name=metric_resource_name, + config=config, + ) + + +class AsyncEvals(_api_module.BaseModule): + + async def _create_evaluation_item( + self, + *, + evaluation_item_type: str, + gcs_uri: str, + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """ + Creates an EvaluationItem. + """ + + parameter_model = types._CreateEvaluationItemParameters( + evaluation_item_type=evaluation_item_type, + gcs_uri=gcs_uri, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationItemParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationItems".format_map(request_url_dict) + else: + path = "evaluationItems" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
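+        # `config` carries client-side options only (the http_options and
+        # response-schema hints read below), so it is stripped from the wire
+        # payload before the request is sent.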
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationItem._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_evaluation_metric( + self, + *, + display_name: Optional[str] = None, + description: Optional[str] = None, + metric: Optional[types.MetricOrDict] = None, + config: Optional[types.CreateEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """ + Creates an EvaluationMetric. + """ + + parameter_model = types._CreateEvaluationMetricParameters( + display_name=display_name, + description=description, + metric=metric, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationMetrics".format_map(request_url_dict) + else: + path = "evaluationMetrics" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationMetric_from_vertex(response_dict) + + return_value = types.EvaluationMetric._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_evaluation_run( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + data_source: types.EvaluationRunDataSourceOrDict, + evaluation_config: types.EvaluationRunConfigOrDict, + labels: Optional[dict[str, str]] = None, + inference_configs: Optional[ + dict[str, types.EvaluationRunInferenceConfigOrDict] + ] = None, + config: Optional[types.CreateEvaluationRunConfigOrDict] = None, + ) -> types.EvaluationRun: + """ + Creates an EvaluationRun. + """ + + parameter_model = types._CreateEvaluationRunParameters( + name=name, + display_name=display_name, + data_source=data_source, + evaluation_config=evaluation_config, + labels=labels, + inference_configs=inference_configs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationRunParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationRuns".format_map(request_url_dict) + else: + path = "evaluationRuns" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationRun_from_vertex(response_dict) + + return_value = types.EvaluationRun._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_evaluation_set( + self, + *, + evaluation_items: list[str], + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """ + Creates an EvaluationSet. + """ + + parameter_model = types._CreateEvaluationSetParameters( + evaluation_items=evaluation_items, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateEvaluationSetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationSets".format_map(request_url_dict) + else: + path = "evaluationSets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationSet._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.DeleteEvaluationMetricConfigOrDict] = None, + ) -> types.DeleteEvaluationMetricOperation: + """ + Deletes an EvaluationMetric. 
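+
+        Returns the delete operation (`types.DeleteEvaluationMetricOperation`).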
+ """ + + parameter_model = types._DeleteEvaluationMetricParameters( + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{evaluation_metric}".format_map(request_url_dict) + else: + path = "{evaluation_metric}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteEvaluationMetricOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _evaluate_instances( + self, + *, + bleu_input: Optional[types.BleuInputOrDict] = None, + exact_match_input: Optional[types.ExactMatchInputOrDict] = None, + rouge_input: Optional[types.RougeInputOrDict] = None, + pointwise_metric_input: Optional[types.PointwiseMetricInputOrDict] = None, + pairwise_metric_input: Optional[types.PairwiseMetricInputOrDict] = None, + tool_call_valid_input: Optional[types.ToolCallValidInputOrDict] = None, + tool_name_match_input: Optional[types.ToolNameMatchInputOrDict] = None, + tool_parameter_key_match_input: Optional[ + types.ToolParameterKeyMatchInputOrDict + ] = None, + tool_parameter_kv_match_input: Optional[ + types.ToolParameterKVMatchInputOrDict + ] = None, + rubric_based_metric_input: Optional[types.RubricBasedMetricInputOrDict] = None, + autorater_config: Optional[genai_types.AutoraterConfigOrDict] = None, + metrics: Optional[list[types.MetricOrDict]] = None, + instance: Optional[types.EvaluationInstanceOrDict] = None, + metric_sources: Optional[list[types.MetricSourceOrDict]] = None, + config: Optional[types.EvaluateInstancesConfigOrDict] = None, + ) -> types.EvaluateInstancesResponse: + """ + Evaluates instances based on a given metric. 
+ """ + + parameter_model = types._EvaluateInstancesRequestParameters( + bleu_input=bleu_input, + exact_match_input=exact_match_input, + rouge_input=rouge_input, + pointwise_metric_input=pointwise_metric_input, + pairwise_metric_input=pairwise_metric_input, + tool_call_valid_input=tool_call_valid_input, + tool_name_match_input=tool_name_match_input, + tool_parameter_key_match_input=tool_parameter_key_match_input, + tool_parameter_kv_match_input=tool_parameter_kv_match_input, + rubric_based_metric_input=rubric_based_metric_input, + autorater_config=autorater_config, + metrics=metrics, + instance=instance, + metric_sources=metric_sources, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _EvaluateInstancesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":evaluateInstances".format_map(request_url_dict) + else: + path = ":evaluateInstances" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluateInstancesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _generate_user_scenarios( + self, + *, + location: Optional[str] = None, + agents: Optional[dict[str, evals_types.AgentConfigOrDict]] = None, + root_agent_id: Optional[str] = None, + user_scenario_generation_config: Optional[ + evals_types.UserScenarioGenerationConfigOrDict + ] = None, + config: Optional[types.GenerateUserScenariosConfigOrDict] = None, + allow_cross_region_model: Optional[bool] = None, + ) -> types.GenerateUserScenariosResponse: + """ + Generates user scenarios for agent evaluation. + """ + + parameter_model = types._GenerateUserScenariosParameters( + location=location, + agents=agents, + root_agent_id=root_agent_id, + user_scenario_generation_config=user_scenario_generation_config, + config=config, + allow_cross_region_model=allow_cross_region_model, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GenerateUserScenariosParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateUserScenarios".format_map(request_url_dict) + else: + path = ":generateUserScenarios" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateUserScenariosResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _generate_loss_clusters( + self, + *, + location: Optional[str] = None, + evaluation_set: Optional[str] = None, + inline_results: Optional[list[types.EvaluationResultOrDict]] = None, + configs: Optional[list[types.LossAnalysisConfigOrDict]] = None, + config: Optional[types.GenerateLossClustersConfigOrDict] = None, + ) -> types.GenerateLossClustersOperation: + """ + Generates loss clusters from evaluation results. + """ + + parameter_model = types._GenerateLossClustersParameters( + location=location, + evaluation_set=evaluation_set, + inline_results=inline_results, + configs=configs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateLossClustersParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateLossClusters".format_map(request_url_dict) + else: + path = ":generateLossClusters" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateLossClustersOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _generate_rubrics( + self, + *, + contents: list[genai_types.ContentOrDict], + predefined_rubric_generation_spec: Optional[ + genai_types.PredefinedMetricSpecOrDict + ] = None, + rubric_generation_spec: Optional[genai_types.RubricGenerationSpecOrDict] = None, + metric_resource_name: Optional[str] = None, + config: Optional[types.RubricGenerationConfigOrDict] = None, + ) -> types.GenerateInstanceRubricsResponse: + """ + Generates rubrics for a given prompt. + """ + + parameter_model = types._GenerateInstanceRubricsRequest( + contents=contents, + predefined_rubric_generation_spec=predefined_rubric_generation_spec, + rubric_generation_spec=rubric_generation_spec, + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateInstanceRubricsRequest_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":generateInstanceRubrics".format_map(request_url_dict) + else: + path = ":generateInstanceRubrics" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.GenerateInstanceRubricsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.GetEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """ + Retrieves an EvaluationMetric from the resource name. + """ + + parameter_model = types._GetEvaluationMetricParameters( + metric_resource_name=metric_resource_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationMetricParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{evaluation_metric}".format_map(request_url_dict) + else: + path = "{evaluation_metric}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationMetric_from_vertex(response_dict) + + return_value = types.EvaluationMetric._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_evaluation_run( + self, *, name: str, config: Optional[types.GetEvaluationRunConfigOrDict] = None + ) -> types.EvaluationRun: + """ + Retrieves an EvaluationRun from the resource name. 
+ """ + + parameter_model = types._GetEvaluationRunParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationRunParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationRuns/{name}".format_map(request_url_dict) + else: + path = "evaluationRuns/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _EvaluationRun_from_vertex(response_dict) + + return_value = types.EvaluationRun._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_evaluation_set( + self, *, name: str, config: Optional[types.GetEvaluationSetConfigOrDict] = None + ) -> types.EvaluationSet: + """ + Retrieves an EvaluationSet from the resource name. + """ + + parameter_model = types._GetEvaluationSetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationSetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationSets/{name}".format_map(request_url_dict) + else: + path = "evaluationSets/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationSet._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_evaluation_item( + self, *, name: str, config: Optional[types.GetEvaluationItemConfigOrDict] = None + ) -> types.EvaluationItem: + """ + Retrieves an EvaluationItem from the resource name. + """ + + parameter_model = types._GetEvaluationItemParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetEvaluationItemParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationItems/{name}".format_map(request_url_dict) + else: + path = "evaluationItems/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.EvaluationItem._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list_evaluation_metrics( + self, *, config: Optional[types.ListEvaluationMetricsConfigOrDict] = None + ) -> types.ListEvaluationMetricsResponse: + """ + Lists EvaluationMetrics. 
+ """ + + parameter_model = types._ListEvaluationMetricsParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListEvaluationMetricsParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "evaluationMetrics".format_map(request_url_dict) + else: + path = "evaluationMetrics" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _ListEvaluationMetricsResponse_from_vertex(response_dict) + + return_value = types.ListEvaluationMetricsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def batch_evaluate( + self, + *, + dataset: types.EvaluationDatasetOrDict, + metrics: list[types.MetricOrDict], + dest: str, + config: Optional[types.EvaluateDatasetConfigOrDict] = None, + ) -> types.EvaluateDatasetOperation: + """Evaluates a dataset based on a set of given metrics.""" + resolved_metrics = _evals_common._resolve_metrics(metrics, self._api_client) + output_config = genai_types.OutputConfig( + gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest) + ) + parameter_model = types.EvaluateDatasetRequestParameters( + dataset=dataset, + metrics=resolved_metrics, + output_config=output_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _evals_utils.BatchEvaluateRequestPreparer.EvaluateDatasetRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = ":evaluateDataset".format_map(request_url_dict) + else: + path = ":evaluateDataset" + + request_dict = _evals_utils.BatchEvaluateRequestPreparer.prepare_metric_payload( + request_dict, resolved_metrics + ) + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "post", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        if self._api_client.vertexai:
+            response_dict = _evals_utils.BatchEvaluateRequestPreparer.EvaluateDatasetOperation_from_vertex(
+                response_dict
+            )
+
+        return_value = types.EvaluateDatasetOperation._from_response(
+            response=response_dict, kwargs=parameter_model.model_dump()
+        )
+        self._api_client._verify_response(return_value)
+
+        return return_value
+
+    async def evaluate_instances(
+        self,
+        *,
+        metric_config: types._EvaluateInstancesRequestParameters,
+    ) -> types.EvaluateInstancesResponse:
+        """Evaluates instances based on the given metric configuration."""
+
+        if isinstance(metric_config, types._EvaluateInstancesRequestParameters):
+            metric_config = metric_config.model_dump()  # type: ignore[assignment]
+        else:
+            metric_config = dict(metric_config)
+
+        result = await self._evaluate_instances(
+            **metric_config,
+        )
+
+        return result
+
+    @_common.experimental_warning(
+        "The Vertex SDK GenAI evals.get_evaluation_run module is experimental, "
+        "and may change in future versions."
+    )
+    async def get_evaluation_run(
+        self,
+        *,
+        name: str,
+        include_evaluation_items: bool = False,
+        config: Optional[types.GetEvaluationRunConfigOrDict] = None,
+    ) -> types.EvaluationRun:
+        """Retrieves an EvaluationRun from the resource name.
+
+        Args:
+            name: The resource name of the EvaluationRun. Format:
+                `projects/{project}/locations/{location}/evaluationRuns/{evaluation_run}`
+            include_evaluation_items: Whether to include the evaluation items in the
+                response.
+            config: The optional configuration for the evaluation run. Must be a dict or
+                `types.GetEvaluationRunConfigOrDict` type.
+
+        Returns:
+            The evaluation run.
+
+        Raises:
+            ValueError: If the name is empty or invalid.
+        """
+        if not name:
+            raise ValueError("name cannot be empty.")
+        if name.startswith("projects/"):
+            name = name.split("/")[-1]
+        result = await self._get_evaluation_run(name=name, config=config)
+        if include_evaluation_items:
+            eval_result, eval_item_map = (
+                await _evals_common._convert_evaluation_run_results_async(
+                    self._api_client,
+                    result.evaluation_run_results,
+                    result.inference_configs,
+                )
+            )
+            result.evaluation_item_results = eval_result
+            # Bypass pydantic validation (extra='forbid') for this internal field.
+            object.__setattr__(result, "_eval_item_map", eval_item_map)
+
+        return result
+
+    @_common.experimental_warning(
+        "The Vertex SDK GenAI evals.create_evaluation_run module is experimental, "
+        "and may change in future versions."
+ ) + async def create_evaluation_run( + self, + *, + dataset: Union[types.EvaluationRunDataSource, types.EvaluationDataset], + dest: str, + metrics: list[types.EvaluationRunMetricOrDict], + name: Optional[str] = None, + display_name: Optional[str] = None, + agent_info: Optional[evals_types.AgentInfo] = None, + agent: Optional[str] = None, + user_simulator_config: Optional[evals_types.UserSimulatorConfigOrDict] = None, + inference_configs: Optional[ + dict[str, types.EvaluationRunInferenceConfigOrDict] + ] = None, + labels: Optional[dict[str, str]] = None, + loss_analysis_metrics: Optional[list[Union[str, types.MetricOrDict]]] = None, + loss_analysis_configs: Optional[list[types.LossAnalysisConfigOrDict]] = None, + config: Optional[types.CreateEvaluationRunConfigOrDict] = None, + ) -> types.EvaluationRun: + """Creates an EvaluationRun. + + Args: + dataset: The dataset to evaluate. Either an EvaluationRunDataSource or an EvaluationDataset. + dest: The GCS URI prefix to write the evaluation results to. + metrics: The list of metrics to evaluate. + name: The name of the evaluation run. + display_name: The display name of the evaluation run. + agent_info: The agent info to evaluate. Mutually exclusive with + `inference_configs`. + agent: The agent engine resource name in str type, with format + `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine_id}`. + If provided, runs inference with the deployed agent to get agent responses + for evaluation. This is required if `agent_info` is provided. + user_simulator_config: The user simulator configuration for agent evaluation. + If `agent_info` is provided without `inference_configs`, this config is used + to automatically construct the inference configuration. If not specified, + or if `max_turn` is not set, `max_turn` defaults to 5. + The `model_name` inside this config can be either a full model path or a + short model name, e.g. `gemini-3-preview-flash`. + inference_configs: The candidate to inference config map for the evaluation run. + The key is the candidate name, and the value is the inference config. + If provided, `agent_info` must be None. If omitted and `agent_info` is provided, + this will be automatically constructed using `agent_info` and `user_simulator_config`. + Example: + {"candidate-1": types.EvaluationRunInferenceConfig(model="gemini-2.5-flash")} + labels: The labels to apply to the evaluation run. + loss_analysis_metrics: This field is experimental and may change in future + versions. Optional list of metrics to run loss analysis on. The + candidate is auto-inferred from ``inference_configs`` or + ``agent_info`` when there is exactly one candidate. Each metric can be + a string (e.g., ``"multi_turn_task_success_v1"``), a ``Metric`` + object, or a ``RubricMetric`` enum + (e.g., ``types.RubricMetric.MULTI_TURN_TASK_SUCCESS``). Loss analysis + runs after metric calculation completes. + Mutually exclusive with ``loss_analysis_configs``. + Example:: + + loss_analysis_metrics=[ + types.RubricMetric.MULTI_TURN_TASK_SUCCESS, + types.RubricMetric.MULTI_TURN_TOOL_USE_QUALITY, + ] + loss_analysis_configs: This field is experimental and may change in future + versions. Optional list of ``LossAnalysisConfig`` objects for full + control over loss analysis, including explicit candidate and + advanced options like ``predefined_taxonomy`` and + ``max_top_cluster_count``. Mutually exclusive with + ``loss_analysis_metrics``. + config: The configuration for the evaluation run. + + Returns: + The created evaluation run. 
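+
+        Example (a minimal sketch; ``client.aio.evals`` is assumed to be the
+        async counterpart of ``client.evals``, and the dataset, metrics, and
+        GCS destination below are illustrative)::
+
+            run = await client.aio.evals.create_evaluation_run(
+                dataset=my_data_source,  # EvaluationRunDataSource or EvaluationDataset
+                dest="gs://my-bucket/eval-results",
+                metrics=my_metrics,  # list of EvaluationRunMetric objects
+                display_name="my-eval-run",
+            )
+            print(run.name)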
+ """ + if loss_analysis_metrics and loss_analysis_configs: + raise ValueError( + "At most one of loss_analysis_metrics or loss_analysis_configs" + " can be provided." + ) + if agent_info and inference_configs: + raise ValueError( + "At most one of agent_info or inference_configs can be provided." + ) + parsed_agent_info = ( + evals_types.AgentInfo.model_validate(agent_info) + if isinstance(agent_info, dict) + else (agent_info or evals_types.AgentInfo()) + ) + + if agent_info and not inference_configs: + parsed_user_simulator_config = ( + evals_types.UserSimulatorConfig.model_validate(user_simulator_config) + if isinstance(user_simulator_config, dict) + else (user_simulator_config or evals_types.UserSimulatorConfig()) + ) + if getattr(parsed_user_simulator_config, "max_turn", None) is None: + parsed_user_simulator_config.max_turn = 5 + + candidate_name = parsed_agent_info.name or "candidate-1" + inference_configs = { + candidate_name: types.EvaluationRunInferenceConfig( + agent_configs=parsed_agent_info.agents, + agent_run_config=types.AgentRunConfig( + agent_engine=agent, + user_simulator_config=parsed_user_simulator_config, + ), + ) + } + + if isinstance(dataset, types.EvaluationDataset): + _evals_utils._validate_dataset_agent_data(dataset, inference_configs) + resolved_dataset = _evals_common._resolve_dataset( + self._api_client, dataset, dest, parsed_agent_info + ) + output_config = genai_types.OutputConfig( + gcs_destination=genai_types.GcsDestination(output_uri_prefix=dest) + ) + resolved_metrics = _evals_common._resolve_evaluation_run_metrics( + metrics, self._api_client + ) + resolved_loss_configs = _evals_utils._resolve_eval_run_loss_configs( + loss_analysis_metrics=loss_analysis_metrics, + loss_analysis_configs=loss_analysis_configs, + inference_configs=inference_configs, + ) + evaluation_config = types.EvaluationRunConfig( + output_config=output_config, + metrics=resolved_metrics, + loss_analysis_config=resolved_loss_configs, + ) + resolved_inference_configs = _evals_common._resolve_inference_configs( + self._api_client, resolved_dataset, inference_configs, parsed_agent_info + ) + resolved_labels = _evals_common._add_evaluation_run_labels(labels, agent) + resolved_name = name or f"evaluation_run_{uuid.uuid4()}" + + result = await self._create_evaluation_run( + name=resolved_name, + display_name=display_name or resolved_name, + data_source=resolved_dataset, + evaluation_config=evaluation_config, + inference_configs=resolved_inference_configs, + labels=resolved_labels, + config=config, + ) + + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_set method is experimental, " + "and may change in future versions." + ) + async def get_evaluation_set( + self, + *, + name: str, + config: Optional[types.GetEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """Retrieves an EvaluationSet from the resource name. + + Args: + name: The resource name of the EvaluationSet. Format: + `projects/{project}/locations/{location}/evaluationSets/{evaluation_set}` + config: The optional configuration for the evaluation set. Must be a dict or + `types.GetEvaluationSetConfigOrDict` type. + + Returns: + The evaluation set. 
+ """ + if not name: + raise ValueError("name cannot be empty.") + if name.startswith("projects/"): + name = name.split("/")[-1] + result = await self._get_evaluation_set(name=name, config=config) + + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_item method is experimental, " + "and may change in future versions." + ) + async def get_evaluation_item( + self, + *, + name: str, + config: Optional[types.GetEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """Retrieves an EvaluationItem from the resource name. + + Args: + name: The resource name of the EvaluationItem. Format: + `projects/{project}/locations/{location}/evaluationItems/{evaluation_item}` + config: The optional configuration for the evaluation item. Must be a dict or + `types.GetEvaluationItemConfigOrDict` type. + + Returns: + The evaluation item. + """ + if not name: + raise ValueError("name cannot be empty.") + if name.startswith("projects/"): + name = name.split("/")[-1] + result = await self._get_evaluation_item(name=name, config=config) + if ( + result.gcs_uri + and result.evaluation_item_type == types.EvaluationItemType.RESULT + ): + result.evaluation_response = ( + _evals_common._convert_gcs_to_evaluation_item_result( + self._api_client, result.gcs_uri + ) + ) + elif ( + result.gcs_uri + and result.evaluation_item_type == types.EvaluationItemType.REQUEST + ): + result.evaluation_request = ( + _evals_common._convert_gcs_to_evaluation_item_request( + self._api_client, result.gcs_uri + ) + ) + + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_item module is experimental, " + "and may change in future versions." + ) + async def create_evaluation_item( + self, + *, + evaluation_item_type: types.EvaluationItemType, + gcs_uri: str, + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationItemConfigOrDict] = None, + ) -> types.EvaluationItem: + """Creates an EvaluationItem. + + Args: + evaluation_item_type: The type of the evaluation item. + gcs_uri: The GCS URI of the evaluation item. + display_name: The display name of the evaluation item. + config: The optional configuration for the evaluation item. Must be a dict or + `types.CreateEvaluationItemConfigOrDict` type. + + Returns: + The evaluation item. + """ + result = await self._create_evaluation_item( + evaluation_item_type=evaluation_item_type, + gcs_uri=gcs_uri, + display_name=display_name, + config=config, + ) + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_set module is experimental, " + "and may change in future versions." + ) + async def create_evaluation_set( + self, + *, + evaluation_items: list[str], + display_name: Optional[str] = None, + config: Optional[types.CreateEvaluationSetConfigOrDict] = None, + ) -> types.EvaluationSet: + """Creates an EvaluationSet. + + Args: + evaluation_items: The list of evaluation item names. Format: + `projects/{project}/locations/{location}/evaluationItems/{evaluation_item}` + display_name: The display name of the evaluation set. + config: The optional configuration for the evaluation set. Must be a dict or + `types.CreateEvaluationSetConfigOrDict` type. + + Returns: + The evaluation set. 
+ """ + result = await self._create_evaluation_set( + evaluation_items=evaluation_items, + display_name=display_name, + config=config, + ) + return result + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.generate_conversation_scenarios module is experimental, " + "and may change in future versions." + ) + async def generate_conversation_scenarios( + self, + *, + agent_info: evals_types.AgentInfoOrDict, + config: evals_types.UserScenarioGenerationConfigOrDict, + allow_cross_region_model: Optional[bool] = None, + ) -> types.EvaluationDataset: + """Generates an evaluation dataset with user scenarios, + which helps to generate conversations between a simulated user + and the agent under test. + + Args: + agent_info: The agent info to generate user scenarios for. + config: Configuration for generating user scenarios. + allow_cross_region_model: Opt-in flag to authorize cross-region + routing for model inference. + + Returns: + An EvaluationDataset containing the generated user scenarios. + """ + parsed_agent_info = ( + evals_types.AgentInfo.model_validate(agent_info) + if isinstance(agent_info, dict) + else agent_info + ) + response = await self._generate_user_scenarios( + agents=parsed_agent_info.agents, + root_agent_id=parsed_agent_info.root_agent_id, + user_scenario_generation_config=config, + allow_cross_region_model=allow_cross_region_model, + ) + return _evals_utils._postprocess_user_scenarios_response(response) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.generate_loss_clusters module is experimental, " + "and may change in future versions." + ) + async def generate_loss_clusters( + self, + *, + eval_result: types.EvaluationResult, + metric: Optional[Union[str, types.MetricOrDict]] = None, + candidate: Optional[str] = None, + config: Optional[types.LossAnalysisConfigOrDict] = None, + ) -> types.GenerateLossClustersResponse: + """Generates loss clusters from evaluation results. + + Analyzes "Pass/Fail" signals from rubric-based autoraters and groups + them into semantic "Loss Patterns" (e.g., "Hallucination of Action"). + + This method calls the GenerateLossClusters LRO and polls until + completion, returning the results directly. + + If ``metric`` or ``candidate`` are not provided, they will be + auto-inferred from ``eval_result`` when unambiguous (i.e., when the + eval result contains exactly one metric or one candidate). For + multi-metric or multi-candidate evaluations, provide them explicitly. + + Available candidate names can be found in + ``eval_result.metadata.candidate_names``. + + Note: This API is only available in the ``global`` region. + + Args: + eval_result: The EvaluationResult object returned from + client.evals.evaluate(). + metric: The metric to analyze. Can be a metric name string + (e.g., "multi_turn_task_success_v1"), a Metric object, or a + RubricMetric enum (e.g., types.RubricMetric.MULTI_TURN_TASK_SUCCESS). + If not provided and config does not specify it, auto-inferred + from eval_result. + candidate: The candidate to analyze. If not provided and config + does not specify it, auto-inferred from eval_result. + config: Optional LossAnalysisConfig with additional options + (predefined_taxonomy, max_top_cluster_count). Can also + specify metric/candidate, but explicit arguments take + precedence. + + Returns: + A GenerateLossClustersResponse containing the analysis results. + Call .show() to visualize, or access .results for individual + LossAnalysisResult objects (each with their own .show()). 
+ """ + metric_name = _evals_utils._resolve_metric_name(metric) + parsed_config = ( + types.LossAnalysisConfig.model_validate(config) + if isinstance(config, dict) + else config + ) + resolved_config = _evals_utils._resolve_loss_analysis_config( + eval_result=eval_result, + config=parsed_config, + metric=metric_name, + candidate=candidate, + ) + operation = await self._generate_loss_clusters( + inline_results=[eval_result], + configs=[resolved_config], + ) + completed = await _evals_utils._poll_operation_async( + api_client=self._api_client, + operation=operation, + ) + if completed.error: + raise RuntimeError(f"Loss analysis operation failed: {completed.error}") + if completed.response is None: + raise RuntimeError( + "Loss analysis operation completed but returned no response." + ) + _evals_utils._enrich_loss_response_with_rubric_descriptions( + completed.response, eval_result + ) + return completed.response + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.create_evaluation_metric module is experimental, " + "and may change in future versions." + ) + async def create_evaluation_metric( + self, + *, + display_name: Optional[str] = None, + description: Optional[str] = None, + metric: Optional[types.MetricOrDict] = None, + config: Optional[types.CreateEvaluationMetricConfigOrDict] = None, + ) -> str: + """Creates an EvaluationMetric.""" + if metric and not isinstance(metric, dict): + resolved_metrics = _evals_common._resolve_metrics( + [metric], self._api_client + ) + metric = resolved_metrics[0] + + # Add fallback logic for display_name + if display_name is None and metric: + if isinstance(metric, dict): + display_name = metric.get("name") + else: + display_name = getattr(metric, "name", None) + + result = await self._create_evaluation_metric( + display_name=display_name, + description=description, + metric=metric, + config=config, + ) + return cast(str, result.name) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.get_evaluation_metric module is experimental, " + "and may change in future versions." + ) + async def get_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.GetEvaluationMetricConfigOrDict] = None, + ) -> types.EvaluationMetric: + """Retrieves an EvaluationMetric from the resource name.""" + return await self._get_evaluation_metric( + metric_resource_name=metric_resource_name, + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.list_evaluation_metrics module is experimental, " + "and may change in future versions." + ) + async def list_evaluation_metrics( + self, + *, + filter: Optional[str] = None, + order_by: Optional[str] = None, + config: Optional[types.ListEvaluationMetricsConfigOrDict] = None, + ) -> types.ListEvaluationMetricsResponse: + """Lists EvaluationMetrics. + + Args: + filter: An expression for filtering the results of the request. For + field names both snake_case and camelCase are supported. For more + information about filter syntax, see + `AIP-160 `_. + Example: ``'display_name="my_metric"'``. + order_by: A comma-separated list of fields to order by, sorted in + ascending order by default. Use ``desc`` after a field name for + descending. Example: ``"create_time desc"``. + config: Optional configuration for the list operation, including + pagination (``page_size``, ``page_token``), ``filter``, and + ``order_by``. Top-level ``filter`` and ``order_by`` arguments + take precedence over values set in ``config``. + + Returns: + The list evaluation metrics response. 
+ """ + if config is None: + config = types.ListEvaluationMetricsConfig() + if isinstance(config, dict): + config = types.ListEvaluationMetricsConfig.model_validate(config) + if filter is not None: + config.filter = filter + if order_by is not None: + config.order_by = order_by + return await self._list_evaluation_metrics( + config=config, + ) + + @_common.experimental_warning( + "The Vertex SDK GenAI evals.delete_evaluation_metric method is experimental, " + "and may change in future versions." + ) + async def delete_evaluation_metric( + self, + *, + metric_resource_name: str, + config: Optional[types.DeleteEvaluationMetricConfigOrDict] = None, + ) -> None: + """Deletes an EvaluationMetric. + + Args: + metric_resource_name: The resource name of the EvaluationMetric to delete. + Format: + `projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric}` + config: The optional configuration for the delete operation. + """ + await self._delete_evaluation_metric( + metric_resource_name=metric_resource_name, + config=config, + ) diff --git a/agentplatform/_genai/live.py b/agentplatform/_genai/live.py new file mode 100644 index 0000000000..1a4bbbf006 --- /dev/null +++ b/agentplatform/_genai/live.py @@ -0,0 +1,64 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""[Preview] Live API client.""" + +import importlib +import logging + +from typing import Optional, TYPE_CHECKING +from types import ModuleType + +from google.genai import _api_module +from google.genai import _common +from google.genai._api_client import BaseApiClient + +logger = logging.getLogger("google_genai.live") + +if TYPE_CHECKING: + from agentplatform._genai import ( + live_agent_engines as live_agent_engines_module, + ) + + +class AsyncLive(_api_module.BaseModule): + """[Preview] AsyncLive.""" + + def __init__(self, api_client: BaseApiClient): + super().__init__(api_client) + self._agent_engines: Optional[ModuleType] = None + + @property + @_common.experimental_warning( + "The Vertex SDK GenAI agent engines module is experimental, " + "and may change in future versions." + ) + def agent_engines(self) -> "live_agent_engines_module.AsyncLiveAgentEngines": + if self._agent_engines is None: + try: + # We need to lazy load the live_agent_engines module to handle + # the possibility of ImportError when dependencies are not + # installed. + self._agent_engines = importlib.import_module( + ".live_agent_engines", + __package__, + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines' module requires 'additional packages'. 
" + "Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._agent_engines.AsyncLiveAgentEngines(self._api_client) # type: ignore[no-any-return] diff --git a/agentplatform/_genai/live_agent_engines.py b/agentplatform/_genai/live_agent_engines.py new file mode 100644 index 0000000000..ed79ed5d0c --- /dev/null +++ b/agentplatform/_genai/live_agent_engines.py @@ -0,0 +1,179 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Live AgentEngine API client.""" + +import contextlib +import json +from typing import Any, AsyncIterator, Dict, Optional +import google.auth + +from google.genai import _api_module +from .types import QueryAgentEngineConfig, QueryAgentEngineConfigOrDict + + +try: + from websockets.asyncio.client import ClientConnection + from websockets.asyncio.client import connect as ws_connect +except ModuleNotFoundError: + # This try/except is for TAP, mypy complains about it which is why we have the type: ignore + from websockets.client import ClientConnection # type: ignore + from websockets.client import connect as ws_connect # type: ignore + + +class AsyncLiveAgentEngineSession: + """AsyncLiveAgentEngineSession.""" + + def __init__(self, websocket: ClientConnection): + self._ws = websocket + + async def send(self, query_input: Dict[str, Any]) -> None: + """Send a query input to the Agent. + + Args: + query_input: A JSON serializable Python Dict to be send to the Agent. + """ + + try: + json_request = json.dumps({"bidi_stream_input": query_input}) + except Exception as exc: + raise ValueError( + "Failed to encode query input to JSON in live_agent_engines: " + f"{str(query_input)}" + ) from exc + await self._ws.send(json_request) + + async def receive(self) -> Any: + """Receive one response from the Agent. + + Returns: + A response from the Agent. + + Raises: + websockets.exceptions.ConnectionClosed: If the connection is closed. + """ + + response = await self._ws.recv() + try: + return json.loads(response) + except json.decoder.JSONDecodeError as exc: + raise ValueError( + "Failed to parse response to JSON in live_agent_engines: " + f"{str(response)}" + ) from exc + + async def close(self) -> None: + """Close the connection.""" + await self._ws.close() + + +class AsyncLiveAgentEngines(_api_module.BaseModule): + """AsyncLiveAgentEngines. + + Example usage: + + .. 
+    .. code-block:: python
+
+        import asyncio
+
+        from google.cloud.aiplatform import agentplatform
+
+        class MyAgentEngine:
+            async def bidi_stream_query(self, input_queue: asyncio.Queue):
+                while True:
+                    input = await input_queue.get()
+                    yield {"output": f"Agent received {input}!"}
+
+        client = agentplatform.Client(project="my-project", location="us-central1")
+        agent_engine = client.agent_engines.create(agent=MyAgentEngine())
+
+        async with client.aio.live.agent_engines.connect(
+            agent_engine=agent_engine.api_resource.name,
+            config={"class_method": "bidi_stream_query"},
+        ) as session:
+            await session.send(query_input={"input": "Hello world"})
+
+            response = await session.receive()
+            # {"output": "Agent received Hello world!"}
+        ...
+    """
+
+    @contextlib.asynccontextmanager
+    async def connect(
+        self,
+        *,
+        agent_engine: str,
+        config: Optional[QueryAgentEngineConfigOrDict] = None,
+    ) -> AsyncIterator[AsyncLiveAgentEngineSession]:
+        """Connect to the agent deployed to Agent Engine in a live (bidirectional streaming) session.
+
+        Args:
+          agent_engine: The resource name of the Agent Engine to use for the
+            live session.
+          config: The optional configuration for starting the live Agent Engine
+            session. A custom class_method and an optional initial input can be
+            provided. If no class_method is provided, the default class_method
+            "bidi_stream_query" will be used by the Agent Engine.
+
+        Yields:
+          An AsyncLiveAgentEngineSession object.
+        """
+        if isinstance(config, dict):
+            config = QueryAgentEngineConfig(**config)
+
+        agent_engine_resource_name = agent_engine
+        if not agent_engine_resource_name.startswith("projects/"):
+            agent_engine_resource_name = f"projects/{self._api_client.project}/locations/{self._api_client.location}/reasoningEngines/{agent_engine}"
+        request_dict = {"setup": {"name": agent_engine_resource_name}}
+        if config is not None and config.class_method:
+            request_dict["setup"]["class_method"] = config.class_method
+        if config is not None and config.input:
+            request_dict["setup"]["input"] = config.input  # type: ignore[assignment]
+
+        request = json.dumps(request_dict)
+
+        if not self._api_client._credentials:
+            # Get a bearer token through Application Default Credentials.
+            creds, _ = google.auth.default(
+                scopes=["https://www.googleapis.com/auth/cloud-platform"]
+            )
+        else:
+            creds = self._api_client._credentials
+        # Freshly obtained credentials have creds.valid == False and
+        # creds.token == None; refresh to populate them.
+        if not (creds.token and creds.valid):
+            auth_req = google.auth.transport.requests.Request()
+            creds.refresh(auth_req)  # type: ignore[no-untyped-call]
+        bearer_token = creds.token
+
+        original_headers = self._api_client._http_options.headers
+        headers = original_headers.copy() if original_headers is not None else {}
+        headers["Authorization"] = f"Bearer {bearer_token}"
+
+        base_url = self._api_client._websocket_base_url()
+        if isinstance(base_url, bytes):
+            base_url = base_url.decode("utf-8")
+        uri = (
+            f"{base_url}/ws/google.cloud.aiplatform."
+ f"{self._api_client._http_options.api_version}" + ".ReasoningEngineExecutionService/BidiQueryReasoningEngine" + ) + + async with ws_connect( + uri, additional_headers=headers, **self._api_client._websocket_ssl_ctx + ) as ws: + await ws.send(request) + yield AsyncLiveAgentEngineSession(websocket=ws) diff --git a/agentplatform/_genai/memories.py b/agentplatform/_genai/memories.py new file mode 100644 index 0000000000..b2145cd8e0 --- /dev/null +++ b/agentplatform/_genai/memories.py @@ -0,0 +1,3427 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import importlib +import json +import logging +import typing +from typing import Any, Iterator, List, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . import _agent_engines_utils +from . import types + +if typing.TYPE_CHECKING: + from . import memory_revisions as memory_revisions_module + + _ = memory_revisions_module + + +logger = logging.getLogger("agentplatform_genai.memories") + +logger.setLevel(logging.INFO) + + +def _AgentEngineMemoryConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + if getv(from_object, ["expire_time"]) is not None: + setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + + if getv(from_object, ["revision_expire_time"]) is not None: + setv( + parent_object, + ["revisionExpireTime"], + getv(from_object, ["revision_expire_time"]), + ) + + if getv(from_object, ["revision_ttl"]) is not None: + setv(parent_object, ["revisionTtl"], getv(from_object, ["revision_ttl"])) + + if getv(from_object, ["disable_memory_revisions"]) is not None: + setv( + parent_object, + ["disableMemoryRevisions"], + getv(from_object, ["disable_memory_revisions"]), + ) + + if getv(from_object, ["topics"]) is not None: + setv( + parent_object, ["topics"], [item for item in getv(from_object, ["topics"])] + ) + + if getv(from_object, ["metadata"]) is not None: + setv( + parent_object, + ["metadata"], + {k: v for k, v in getv(from_object, ["metadata"]).items()}, + ) + + if getv(from_object, ["memory_id"]) is not None: + setv(parent_object, ["_query", "memoryId"], getv(from_object, ["memory_id"])) + + return to_object + + +def _CreateAgentEngineMemoryRequestParameters_to_vertex( + from_object: 
Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["fact"]) is not None: + setv(to_object, ["fact"], getv(from_object, ["fact"])) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + if getv(from_object, ["config"]) is not None: + _AgentEngineMemoryConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _DeleteAgentEngineMemoryRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GenerateAgentEngineMemoriesConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["disable_consolidation"]) is not None: + setv( + parent_object, + ["disableConsolidation"], + getv(from_object, ["disable_consolidation"]), + ) + + if getv(from_object, ["revision_labels"]) is not None: + setv(parent_object, ["revisionLabels"], getv(from_object, ["revision_labels"])) + + if getv(from_object, ["revision_expire_time"]) is not None: + setv( + parent_object, + ["revisionExpireTime"], + getv(from_object, ["revision_expire_time"]), + ) + + if getv(from_object, ["revision_ttl"]) is not None: + setv(parent_object, ["revisionTtl"], getv(from_object, ["revision_ttl"])) + + if getv(from_object, ["disable_memory_revisions"]) is not None: + setv( + parent_object, + ["disableMemoryRevisions"], + getv(from_object, ["disable_memory_revisions"]), + ) + + if getv(from_object, ["metadata"]) is not None: + setv( + parent_object, + ["metadata"], + {k: v for k, v in getv(from_object, ["metadata"]).items()}, + ) + + if getv(from_object, ["metadata_merge_strategy"]) is not None: + setv( + parent_object, + ["metadataMergeStrategy"], + getv(from_object, ["metadata_merge_strategy"]), + ) + + if getv(from_object, ["allowed_topics"]) is not None: + setv( + parent_object, + ["allowedTopics"], + [item for item in getv(from_object, ["allowed_topics"])], + ) + + return to_object + + +def _GenerateAgentEngineMemoriesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["vertex_session_source"]) is not None: + setv( + to_object, + ["vertexSessionSource"], + getv(from_object, ["vertex_session_source"]), + ) + + if getv(from_object, ["direct_contents_source"]) is not None: + setv( + to_object, + ["directContentsSource"], + getv(from_object, ["direct_contents_source"]), + ) + + if getv(from_object, ["direct_memories_source"]) is not None: + setv( + to_object, + ["directMemoriesSource"], + getv(from_object, ["direct_memories_source"]), + ) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + if getv(from_object, ["config"]) is not None: + _GenerateAgentEngineMemoriesConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + 
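+    # Note: entries under "_url" and "_query" are routing metadata (path
+    # template values and query-string parameters) that the request methods
+    # read to build the URL; they are not request-body fields.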
+ return to_object + + +def _GetAgentEngineGenerateMemoriesOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetAgentEngineMemoryOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetAgentEngineMemoryRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _IngestEventsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["force_flush"]) is not None: + setv(parent_object, ["forceFlush"], getv(from_object, ["force_flush"])) + + return to_object + + +def _IngestEventsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["stream_id"]) is not None: + setv(to_object, ["streamId"], getv(from_object, ["stream_id"])) + + if getv(from_object, ["direct_contents_source"]) is not None: + setv( + to_object, + ["directContentsSource"], + getv(from_object, ["direct_contents_source"]), + ) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + if getv(from_object, ["generation_trigger_config"]) is not None: + setv( + to_object, + ["generationTriggerConfig"], + getv(from_object, ["generation_trigger_config"]), + ) + + if getv(from_object, ["config"]) is not None: + _IngestEventsConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _ListAgentEngineMemoryConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["order_by"]) is not None: + setv(parent_object, ["_query", "orderBy"], getv(from_object, ["order_by"])) + + return to_object + + +def _ListAgentEngineMemoryRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], 
getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineMemoryConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _PurgeAgentEngineMemoriesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["filter"]) is not None: + setv(to_object, ["filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["filter_groups"]) is not None: + setv( + to_object, + ["filterGroups"], + [item for item in getv(from_object, ["filter_groups"])], + ) + + if getv(from_object, ["force"]) is not None: + setv(to_object, ["force"], getv(from_object, ["force"])) + + return to_object + + +def _RetrieveAgentEngineMemoriesConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["filter"], getv(from_object, ["filter"])) + + if getv(from_object, ["filter_groups"]) is not None: + setv( + parent_object, + ["filterGroups"], + [item for item in getv(from_object, ["filter_groups"])], + ) + + if getv(from_object, ["memory_types"]) is not None: + setv(parent_object, ["memoryTypes"], getv(from_object, ["memory_types"])) + + return to_object + + +def _RetrieveAgentEngineMemoriesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + if getv(from_object, ["similarity_search_params"]) is not None: + setv( + to_object, + ["similaritySearchParams"], + getv(from_object, ["similarity_search_params"]), + ) + + if getv(from_object, ["simple_retrieval_params"]) is not None: + setv( + to_object, + ["simpleRetrievalParams"], + getv(from_object, ["simple_retrieval_params"]), + ) + + if getv(from_object, ["config"]) is not None: + _RetrieveAgentEngineMemoriesConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _RetrieveMemoryProfilesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + return to_object + + +def _RollbackAgentEngineMemoryRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["target_revision_id"]) is not None: + setv(to_object, ["targetRevisionId"], getv(from_object, ["target_revision_id"])) + + return to_object + + +def _UpdateAgentEngineMemoryConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: 
Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + if getv(from_object, ["expire_time"]) is not None: + setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + + if getv(from_object, ["revision_expire_time"]) is not None: + setv( + parent_object, + ["revisionExpireTime"], + getv(from_object, ["revision_expire_time"]), + ) + + if getv(from_object, ["revision_ttl"]) is not None: + setv(parent_object, ["revisionTtl"], getv(from_object, ["revision_ttl"])) + + if getv(from_object, ["disable_memory_revisions"]) is not None: + setv( + parent_object, + ["disableMemoryRevisions"], + getv(from_object, ["disable_memory_revisions"]), + ) + + if getv(from_object, ["topics"]) is not None: + setv( + parent_object, ["topics"], [item for item in getv(from_object, ["topics"])] + ) + + if getv(from_object, ["metadata"]) is not None: + setv( + parent_object, + ["metadata"], + {k: v for k, v in getv(from_object, ["metadata"]).items()}, + ) + + if getv(from_object, ["memory_id"]) is not None: + setv(parent_object, ["_query", "memoryId"], getv(from_object, ["memory_id"])) + + if getv(from_object, ["update_mask"]) is not None: + setv( + parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) + ) + + return to_object + + +def _UpdateAgentEngineMemoryRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["fact"]) is not None: + setv(to_object, ["fact"], getv(from_object, ["fact"])) + + if getv(from_object, ["scope"]) is not None: + setv(to_object, ["scope"], getv(from_object, ["scope"])) + + if getv(from_object, ["config"]) is not None: + _UpdateAgentEngineMemoryConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class Memories(_api_module.BaseModule): + + def _create( + self, + *, + name: str, + fact: str, + scope: dict[str, str], + config: Optional[types.AgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + """ + Creates a new memory in the Agent Engine. + """ + + parameter_model = types._CreateAgentEngineMemoryRequestParameters( + name=name, + fact=fact, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories".format_map(request_url_dict) + else: + path = "{name}/memories" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
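+            # The popped "config" entry only carries client-side options
+            # (e.g. http_options), which are read from parameter_model below;
+            # it must not be forwarded in the request body.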
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineMemoryConfigOrDict] = None, + ) -> types.DeleteAgentEngineMemoryOperation: + """ + Delete an Agent Engine memory. + + Args: + name (str): + Required. The name of the Agent Engine memory to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory}`. + config (DeleteAgentEngineMemoryConfig): + Optional. Additional configurations for deleting the Agent Engine. + + """ + + parameter_model = types._DeleteAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _generate( + self, + *, + name: str, + vertex_session_source: Optional[ + types.GenerateMemoriesRequestVertexSessionSourceOrDict + ] = None, + direct_contents_source: Optional[ + types.GenerateMemoriesRequestDirectContentsSourceOrDict + ] = None, + direct_memories_source: Optional[ + types.GenerateMemoriesRequestDirectMemoriesSourceOrDict + ] = None, + scope: Optional[dict[str, str]] = None, + config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEngineGenerateMemoriesOperation: + """ + Generates memories for an Agent Engine. + """ + + parameter_model = types._GenerateAgentEngineMemoriesRequestParameters( + name=name, + vertex_session_source=vertex_session_source, + direct_contents_source=direct_contents_source, + direct_memories_source=direct_memories_source, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:generate".format_map(request_url_dict) + else: + path = "{name}/memories:generate" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineGenerateMemoriesOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineMemoryConfigOrDict] = None, + ) -> types.Memory: + """ + Gets an agent engine memory. + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456/memories/789" + or a shortened name such as "reasoningEngines/456/memories/789". + + """ + + parameter_model = types._GetAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Memory._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _ingest_events( + self, + *, + name: str, + stream_id: Optional[str] = None, + direct_contents_source: Optional[ + types.IngestionDirectContentsSourceOrDict + ] = None, + scope: Optional[dict[str, str]] = None, + generation_trigger_config: Optional[ + types.MemoryGenerationTriggerConfigOrDict + ] = None, + config: Optional[types.IngestEventsConfigOrDict] = None, + ) -> types.MemoryBankIngestEventsOperation: + """ + Ingest events into a Memory Bank. + """ + + parameter_model = types._IngestEventsRequestParameters( + name=name, + stream_id=stream_id, + direct_contents_source=direct_contents_source, + scope=scope, + generation_trigger_config=generation_trigger_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _IngestEventsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:ingestEvents".format_map(request_url_dict) + else: + path = "{name}/memories:ingestEvents" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MemoryBankIngestEventsOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None, + ) -> types.ListReasoningEnginesMemoriesResponse: + """ + Lists Agent Engine memories. + """ + + parameter_model = types._ListAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories".format_map(request_url_dict) + else: + path = "{name}/memories" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListReasoningEnginesMemoriesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_memory_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + parameter_model = types._GetAgentEngineMemoryOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineMemoryOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_generate_memories_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineGenerateMemoriesOperation: + parameter_model = types._GetAgentEngineGenerateMemoriesOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineGenerateMemoriesOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineGenerateMemoriesOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _retrieve( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> types.RetrieveMemoriesResponse: + """ + Retrieves memories for an Agent Engine. + """ + + parameter_model = types._RetrieveAgentEngineMemoriesRequestParameters( + name=name, + scope=scope, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:retrieve".format_map(request_url_dict) + else: + path = "{name}/memories:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveMemoriesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def retrieve_profiles( + self, + *, + name: str, + scope: dict[str, str], + config: Optional[types.RetrieveMemoryProfilesConfigOrDict] = None, + ) -> types.RetrieveProfilesResponse: + """ + Retrieves memory profiles for an Agent Engine. + + For example, you can use the following code to retrieve all memory profiles + for scope `{'user_id': '123'}`: + + ```python + result = client.agent_engines.memories.retrieve_profiles( + name="projects/123/locations/us-central1/reasoningEngines/456", + scope={"user_id": "123"} + ) + + for profile in result.profiles.values(): + # Each profile is a dictionary corresponding to the relevant schema. + print(profile.profile) + ``` + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456". + scope (dict[str, str]): Required. The scope of the memories to retrieve. + A memory must have exactly the same scope as the scope provided here + to be retrieved (i.e. same keys and values). Order does not matter, + but it is case-sensitive. + + Returns: + RetrieveProfilesResponse: The retrieved memory profiles. + + """ + + parameter_model = types._RetrieveMemoryProfilesRequestParameters( + name=name, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveMemoryProfilesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:retrieveProfiles".format_map(request_url_dict) + else: + path = "{name}/memories:retrieveProfiles" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveProfilesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _rollback( + self, + *, + name: str, + target_revision_id: str, + config: Optional[types.RollbackAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineRollbackMemoryOperation: + """ + Rollback a memory to a previous revision. + """ + + parameter_model = types._RollbackAgentEngineMemoryRequestParameters( + name=name, + target_revision_id=target_revision_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RollbackAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:rollback".format_map(request_url_dict) + else: + path = "{name}:rollback" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineRollbackMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _update( + self, + *, + name: str, + fact: Optional[str] = None, + scope: Optional[dict[str, str]] = None, + config: Optional[types.UpdateAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + """ + Updates an Agent Engine memory. 
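+
+ Args:
+ name (str):
+ Required. The name of the memory to update.
+ fact (str):
+ Optional. The new fact to store in the memory.
+ scope (dict[str, str]):
+ Optional. The new scope of the memory.
+ config (UpdateAgentEngineMemoryConfig):
+ Optional. The configuration for the update.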
+ """ + + parameter_model = types._UpdateAgentEngineMemoryRequestParameters( + name=name, + fact=fact, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("patch", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _purge( + self, + *, + name: str, + filter: Optional[str] = None, + filter_groups: Optional[list[types.MemoryConjunctionFilterOrDict]] = None, + force: Optional[bool] = None, + config: Optional[types.PurgeAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEnginePurgeMemoriesOperation: + """ + Purges memories from an Agent Engine. + """ + + parameter_model = types._PurgeAgentEngineMemoriesRequestParameters( + name=name, + filter=filter, + filter_groups=filter_groups, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _PurgeAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:purge".format_map(request_url_dict) + else: + path = "{name}/memories:purge" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = self._api_client.request("post", path, request_dict, http_options)
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.AgentEnginePurgeMemoriesOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ _revisions = None
+
+ @property
+ def revisions(self) -> "memory_revisions_module.MemoryRevisions":
+ if self._revisions is None:
+ try:
+ # We need to lazy load the revisions module to handle the
+ # possibility of ImportError when dependencies are not installed.
+ self._revisions = importlib.import_module(
+ ".memory_revisions", __package__
+ )
+ except ImportError as e:
+ raise ImportError(
+ "The 'agent_engines.memories.revisions' module requires "
+ "additional packages. Please install them using pip install "
+ "google-cloud-aiplatform[agent_engines]"
+ ) from e
+ return self._revisions.MemoryRevisions(self._api_client) # type: ignore[no-any-return]
+
+ def create(
+ self,
+ *,
+ name: str,
+ fact: str,
+ scope: dict[str, str],
+ config: Optional[types.AgentEngineMemoryConfigOrDict] = None,
+ ) -> types.AgentEngineMemoryOperation:
+ """Creates a new memory in the Agent Engine.
+
+ Args:
+ name (str):
+ Required. The name of the Agent Engine to create the memory in.
+ fact (str):
+ Required. The fact to be stored in the memory.
+ scope (dict[str, str]):
+ Required. The scope of the memory. For example, {"user_id": "123"}.
+ config (AgentEngineMemoryConfigOrDict):
+ Optional. The configuration for the memory.
+
+ Returns:
+ AgentEngineMemoryOperation: The operation for creating the memory.
+ """
+ if config is None:
+ config = types.AgentEngineMemoryConfig()
+ elif isinstance(config, dict):
+ config = types.AgentEngineMemoryConfig.model_validate(config)
+ operation = self._create(
+ name=name,
+ fact=fact,
+ scope=scope,
+ config=config,
+ )
+ if config.wait_for_completion:
+ if not operation.done:
+ operation = _agent_engines_utils._await_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_memory_operation,
+ poll_interval_seconds=0.5,
+ )
+ # We need to make a call to get the memory because the operation
+ # response might not contain the relevant fields.
+ if operation.response:
+ operation.response = self.get(name=operation.response.name)
+ elif operation.error:
+ raise RuntimeError(f"Failed to create memory: {operation.error}")
+ else:
+ raise RuntimeError("Error creating memory.")
+ return operation
+
+ def generate(
+ self,
+ *,
+ name: str,
+ vertex_session_source: Optional[
+ types.GenerateMemoriesRequestVertexSessionSourceOrDict
+ ] = None,
+ direct_contents_source: Optional[
+ types.GenerateMemoriesRequestDirectContentsSourceOrDict
+ ] = None,
+ direct_memories_source: Optional[
+ types.GenerateMemoriesRequestDirectMemoriesSourceOrDict
+ ] = None,
+ scope: Optional[dict[str, str]] = None,
+ config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None,
+ ) -> types.AgentEngineGenerateMemoriesOperation:
+ """Generates memories for the agent engine.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to generate memories for.
+ vertex_session_source (GenerateMemoriesRequestVertexSessionSource):
+ Optional. The vertex session source to use for generating
+ memories. Only one of vertex_session_source,
+ direct_contents_source, or direct_memories_source can be
+ specified.
+ direct_contents_source (GenerateMemoriesRequestDirectContentsSource):
+ Optional. The direct contents source to use for generating
+ memories. Only one of vertex_session_source, direct_contents_source,
+ or direct_memories_source can be specified.
+ direct_memories_source (GenerateMemoriesRequestDirectMemoriesSource):
+ Optional. The direct memories source to use for generating
+ memories. Only one of vertex_session_source, direct_contents_source,
+ or direct_memories_source can be specified.
+ scope (dict[str, str]):
+ Optional. The scope of the memories to generate. This is optional
+ if vertex_session_source is used, otherwise it must be specified.
+ config (GenerateAgentEngineMemoriesConfig):
+ Optional. The configuration for the memories to generate.
+
+ Returns:
+ AgentEngineGenerateMemoriesOperation:
+ The operation for generating the memories.
+ """
+ if config is None:
+ config = types.GenerateAgentEngineMemoriesConfig()
+ elif isinstance(config, dict):
+ config = types.GenerateAgentEngineMemoriesConfig.model_validate(config)
+ operation = self._generate(
+ name=name,
+ vertex_session_source=vertex_session_source,
+ direct_contents_source=direct_contents_source,
+ direct_memories_source=direct_memories_source,
+ scope=scope,
+ config=config,
+ )
+ if config.wait_for_completion and not operation.done:
+ operation = _agent_engines_utils._await_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_generate_memories_operation,
+ poll_interval_seconds=0.5,
+ )
+ if operation.error:
+ raise RuntimeError(f"Failed to generate memory: {operation.error}")
+ return operation
+
+ def list(
+ self,
+ *,
+ name: str,
+ config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None,
+ ) -> Iterator[types.Memory]:
+ """Lists Agent Engine memories.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to list memories for.
+ config (ListAgentEngineMemoryConfig):
+ Optional. The configuration for the memories to list.
+
+ Returns:
+ Iterator[Memory]: An iterator of memories.
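+
+ For example, you can use the following code to iterate over all
+ memories of an agent engine (illustrative resource name):
+
+ ```python
+ for memory in client.agent_engines.memories.list(
+ name="projects/123/locations/us-central1/reasoningEngines/456"
+ ):
+ print(memory.fact)
+ ```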
+ """ + + return Pager( + "memories", + functools.partial(self._list, name=name), + self._list(name=name, config=config), + config, + ) + + def retrieve( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> Iterator[types.RetrieveMemoriesResponseRetrievedMemory]: + """Retrieves memories for the agent. + + Args: + name (str): + Required. The name of the agent engine to retrieve memories for. + scope (dict[str, str]): + Required. The scope of the memories to retrieve. For example, + {"user_id": "123"}. + similarity_search_params (RetrieveMemoriesRequestSimilaritySearchParams): + Optional. The similarity search parameters to use for retrieving + memories. + simple_retrieval_params (RetrieveMemoriesRequestSimpleRetrievalParams): + Optional. The simple retrieval parameters to use for retrieving + memories. + config (RetrieveAgentEngineMemoriesConfig): + Optional. The configuration for the memories to retrieve. + + Returns: + Iterator[RetrieveMemoriesResponseRetrievedMemory]: An iterable of + retrieved memories. + """ + return Pager( + "retrieved_memories", + lambda config: self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + config, + ) + + def rollback( + self, + *, + name: str, + target_revision_id: str, + config: Optional[types.RollbackAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineRollbackMemoryOperation: + """Rolls back a memory to a previous revision. + + Args: + name (str): + Required. The name of the memory to rollback. + target_revision_id (str): + Required. The revision ID to roll back to + config (RollbackAgentEngineMemoryConfig): + Optional. The configuration for the rollback. + + Returns: + AgentEngineRollbackMemoryOperation: + The operation for rolling back the memory. + """ + if config is None: + config = types.RollbackAgentEngineMemoryConfig() + elif isinstance(config, dict): + config = types.RollbackAgentEngineMemoryConfig.model_validate(config) + operation = self._rollback( + name=name, + target_revision_id=target_revision_id, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = _agent_engines_utils._await_operation( + operation_name=operation.name, + get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to rollback memory: {operation.error}") + return operation + + def purge( + self, + *, + name: str, + filter: Optional[str] = None, + filter_groups: Optional[List[types.MemoryConjunctionFilter]] = None, + force: bool = False, + config: Optional[types.PurgeAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEnginePurgeMemoriesOperation: + """Purges memories from an Agent Engine. + + Args: + name (str): + Required. The name of the Agent Engine to purge memories from. + filter (str): + Optional. The standard list filter to determine which memories to purge. + filter_groups (list[MemoryConjunctionFilter]): + Optional. 
+ `metadata` using OR logic. Filters are defined using disjunctive
+ normal form (OR of ANDs).
+ force (bool):
+ Optional. Whether to force the purge operation. If false, the
+ operation will be staged but not executed.
+ config (PurgeAgentEngineMemoriesConfig):
+ Optional. The configuration for the purge operation.
+
+ Returns:
+ AgentEnginePurgeMemoriesOperation:
+ The operation for purging the memories.
+ """
+ if config is None:
+ config = types.PurgeAgentEngineMemoriesConfig()
+ elif isinstance(config, dict):
+ config = types.PurgeAgentEngineMemoriesConfig.model_validate(config)
+ operation = self._purge(
+ name=name,
+ filter=filter,
+ filter_groups=filter_groups,
+ force=force,
+ config=config,
+ )
+ if config.wait_for_completion and not operation.done:
+ operation = _agent_engines_utils._await_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_memory_operation,
+ poll_interval_seconds=0.5,
+ )
+ if operation.error:
+ raise RuntimeError(f"Failed to purge memories: {operation.error}")
+ return operation
+
+ def ingest_events(
+ self,
+ *,
+ name: str,
+ scope: dict[str, str],
+ stream_id: str = "",
+ direct_contents_source: Optional[
+ types.IngestionDirectContentsSourceOrDict
+ ] = None,
+ generation_trigger_config: Optional[
+ types.MemoryGenerationTriggerConfigOrDict
+ ] = None,
+ config: Optional[types.IngestEventsConfigOrDict] = None,
+ ) -> types.MemoryBankIngestEventsOperation:
+ """Ingests events into an Agent Engine.
+
+ Example usage:
+ ```
+ client.agent_engines.memories.ingest_events(
+ name="projects/test-project/locations/us-central1/reasoningEngines/test-agent-engine",
+ scope={"user_id": "test-user-id"},
+ direct_contents_source={
+ "events": [
+ {
+ "content": {
+ "role": "user",
+ "parts": [
+ {"text": "I am a software engineer."}
+ ],
+ }
+ }
+ ]
+ },
+ generation_trigger_config={
+ "generation_rule": {
+ "idle_duration": "60s"
+ }
+ }
+ )
+ ```
+
+ Args:
+ name (str):
+ Required. The name of the Agent Engine to ingest events into.
+ scope (dict[str, str]):
+ Required. The scope of the events to ingest. For example,
+ {"user_id": "123"}.
+ stream_id (str):
+ Optional. The ID of the stream to ingest events into. If not
+ specified, the events will be ingested into the default stream.
+ direct_contents_source (IngestionDirectContentsSource):
+ Optional. The direct contents source, containing the events to
+ ingest.
+ generation_trigger_config (MemoryGenerationTriggerConfig):
+ Optional. The configuration for triggering memory generation.
+ config (IngestEventsConfig):
+ Optional. The configuration for the ingest events operation.
+
+ Returns:
+ MemoryBankIngestEventsOperation:
+ The operation for ingesting the events.
+ """ + if config is None: + config = types.IngestEventsConfig() + elif isinstance(config, dict): + config = types.IngestEventsConfig.model_validate(config) + operation = self._ingest_events( + name=name, + scope=scope, + stream_id=stream_id, + generation_trigger_config=generation_trigger_config, + direct_contents_source=direct_contents_source, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = _agent_engines_utils._await_operation( + operation_name=operation.name, + get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to ingest events: {operation.error}") + return operation + + +class AsyncMemories(_api_module.BaseModule): + + async def _create( + self, + *, + name: str, + fact: str, + scope: dict[str, str], + config: Optional[types.AgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + """ + Creates a new memory in the Agent Engine. + """ + + parameter_model = types._CreateAgentEngineMemoryRequestParameters( + name=name, + fact=fact, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories".format_map(request_url_dict) + else: + path = "{name}/memories" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineMemoryConfigOrDict] = None, + ) -> types.DeleteAgentEngineMemoryOperation: + """ + Delete an Agent Engine memory. + + Args: + name (str): + Required. The name of the Agent Engine memory to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory}`. + config (DeleteAgentEngineMemoryConfig): + Optional. Additional configurations for deleting the Agent Engine. 
+ + """ + + parameter_model = types._DeleteAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _generate( + self, + *, + name: str, + vertex_session_source: Optional[ + types.GenerateMemoriesRequestVertexSessionSourceOrDict + ] = None, + direct_contents_source: Optional[ + types.GenerateMemoriesRequestDirectContentsSourceOrDict + ] = None, + direct_memories_source: Optional[ + types.GenerateMemoriesRequestDirectMemoriesSourceOrDict + ] = None, + scope: Optional[dict[str, str]] = None, + config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEngineGenerateMemoriesOperation: + """ + Generates memories for an Agent Engine. + """ + + parameter_model = types._GenerateAgentEngineMemoriesRequestParameters( + name=name, + vertex_session_source=vertex_session_source, + direct_contents_source=direct_contents_source, + direct_memories_source=direct_memories_source, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GenerateAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:generate".format_map(request_url_dict) + else: + path = "{name}/memories:generate" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineGenerateMemoriesOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineMemoryConfigOrDict] = None, + ) -> types.Memory: + """ + Gets an agent engine memory. + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456/memories/789" + or a shortened name such as "reasoningEngines/456/memories/789". + + """ + + parameter_model = types._GetAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Memory._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _ingest_events( + self, + *, + name: str, + stream_id: Optional[str] = None, + direct_contents_source: Optional[ + types.IngestionDirectContentsSourceOrDict + ] = None, + scope: Optional[dict[str, str]] = None, + generation_trigger_config: Optional[ + types.MemoryGenerationTriggerConfigOrDict + ] = None, + config: Optional[types.IngestEventsConfigOrDict] = None, + ) -> types.MemoryBankIngestEventsOperation: + """ + Ingest events into a Memory Bank. + """ + + parameter_model = types._IngestEventsRequestParameters( + name=name, + stream_id=stream_id, + direct_contents_source=direct_contents_source, + scope=scope, + generation_trigger_config=generation_trigger_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _IngestEventsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:ingestEvents".format_map(request_url_dict) + else: + path = "{name}/memories:ingestEvents" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MemoryBankIngestEventsOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None, + ) -> types.ListReasoningEnginesMemoriesResponse: + """ + Lists Agent Engine memories. + """ + + parameter_model = types._ListAgentEngineMemoryRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories".format_map(request_url_dict) + else: + path = "{name}/memories" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListReasoningEnginesMemoriesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_memory_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + parameter_model = types._GetAgentEngineMemoryOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineMemoryOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_generate_memories_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineGenerateMemoriesOperation: + parameter_model = types._GetAgentEngineGenerateMemoriesOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineGenerateMemoriesOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineGenerateMemoriesOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _retrieve( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> types.RetrieveMemoriesResponse: + """ + Retrieves memories for an Agent Engine. + """ + + parameter_model = types._RetrieveAgentEngineMemoriesRequestParameters( + name=name, + scope=scope, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:retrieve".format_map(request_url_dict) + else: + path = "{name}/memories:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveMemoriesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def retrieve_profiles( + self, + *, + name: str, + scope: dict[str, str], + config: Optional[types.RetrieveMemoryProfilesConfigOrDict] = None, + ) -> types.RetrieveProfilesResponse: + """ + Retrieves memory profiles for an Agent Engine. + + For example, you can use the following code to retrieve all memory profiles + for scope `{'user_id': '123'}`: + + ```python + result = client.agent_engines.memories.retrieve_profiles( + name="projects/123/locations/us-central1/reasoningEngines/456", + scope={"user_id": "123"} + ) + + for profile in result.profiles.values(): + # Each profile is a dictionary corresponding to the relevant schema. + print(profile.profile) + ``` + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456". + scope (dict[str, str]): Required. The scope of the memories to retrieve. + A memory must have exactly the same scope as the scope provided here + to be retrieved (i.e. same keys and values). Order does not matter, + but it is case-sensitive. + + Returns: + RetrieveProfilesResponse: The retrieved memory profiles. + + """ + + parameter_model = types._RetrieveMemoryProfilesRequestParameters( + name=name, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveMemoryProfilesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:retrieveProfiles".format_map(request_url_dict) + else: + path = "{name}/memories:retrieveProfiles" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveProfilesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _rollback( + self, + *, + name: str, + target_revision_id: str, + config: Optional[types.RollbackAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineRollbackMemoryOperation: + """ + Rollback a memory to a previous revision. + """ + + parameter_model = types._RollbackAgentEngineMemoryRequestParameters( + name=name, + target_revision_id=target_revision_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RollbackAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:rollback".format_map(request_url_dict) + else: + path = "{name}:rollback" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineRollbackMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update( + self, + *, + name: str, + fact: Optional[str] = None, + scope: Optional[dict[str, str]] = None, + config: Optional[types.UpdateAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineMemoryOperation: + """ + Updates an Agent Engine memory. 
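+
+ Args:
+ name (str):
+ Required. The name of the memory to update.
+ fact (str):
+ Optional. The new fact to store in the memory.
+ scope (dict[str, str]):
+ Optional. The new scope of the memory.
+ config (UpdateAgentEngineMemoryConfig):
+ Optional. The configuration for the update.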
+ """ + + parameter_model = types._UpdateAgentEngineMemoryRequestParameters( + name=name, + fact=fact, + scope=scope, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateAgentEngineMemoryRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineMemoryOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _purge( + self, + *, + name: str, + filter: Optional[str] = None, + filter_groups: Optional[list[types.MemoryConjunctionFilterOrDict]] = None, + force: Optional[bool] = None, + config: Optional[types.PurgeAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEnginePurgeMemoriesOperation: + """ + Purges memories from an Agent Engine. + """ + + parameter_model = types._PurgeAgentEngineMemoriesRequestParameters( + name=name, + filter=filter, + filter_groups=filter_groups, + force=force, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _PurgeAgentEngineMemoriesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/memories:purge".format_map(request_url_dict) + else: + path = "{name}/memories:purge" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = await self._api_client.async_request(
+ "post", path, request_dict, http_options
+ )
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.AgentEnginePurgeMemoriesOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ _revisions = None
+
+ @property
+ def revisions(self) -> "memory_revisions_module.AsyncMemoryRevisions":
+ if self._revisions is None:
+ try:
+ # We need to lazy load the revisions module to handle the
+ # possibility of ImportError when dependencies are not installed.
+ self._revisions = importlib.import_module(
+ ".memory_revisions", __package__
+ )
+ except ImportError as e:
+ raise ImportError(
+ "The 'agent_engines.memories.revisions' module requires "
+ "additional packages. Please install them using pip install "
+ "google-cloud-aiplatform[agent_engines]"
+ ) from e
+ return self._revisions.AsyncMemoryRevisions(self._api_client) # type: ignore[no-any-return]
+
+ async def create(
+ self,
+ *,
+ name: str,
+ fact: str,
+ scope: dict[str, str],
+ config: Optional[types.AgentEngineMemoryConfigOrDict] = None,
+ ) -> types.AgentEngineMemoryOperation:
+ """Creates a new memory in the Agent Engine.
+
+ Args:
+ name (str):
+ Required. The name of the Agent Engine to create the memory in.
+ fact (str):
+ Required. The fact to be stored in the memory.
+ scope (dict[str, str]):
+ Required. The scope of the memory. For example, {"user_id": "123"}.
+ config (AgentEngineMemoryConfigOrDict):
+ Optional. The configuration for the memory.
+
+ Returns:
+ AgentEngineMemoryOperation: The operation for creating the memory.
+ """
+ if config is None:
+ config = types.AgentEngineMemoryConfig()
+ elif isinstance(config, dict):
+ config = types.AgentEngineMemoryConfig.model_validate(config)
+ operation = await self._create(
+ name=name,
+ fact=fact,
+ scope=scope,
+ config=config,
+ )
+ if config.wait_for_completion:
+ if not operation.done:
+ operation = await _agent_engines_utils._await_async_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_memory_operation,
+ poll_interval_seconds=0.5,
+ )
+ # We need to make a call to get the memory because the operation
+ # response might not contain the relevant fields.
+ if operation.response:
+ operation.response = await self.get(name=operation.response.name)
+ elif operation.error:
+ raise RuntimeError(f"Failed to create memory: {operation.error}")
+ else:
+ raise RuntimeError("Error creating memory.")
+ return operation
+
+ async def generate(
+ self,
+ *,
+ name: str,
+ vertex_session_source: Optional[
+ types.GenerateMemoriesRequestVertexSessionSourceOrDict
+ ] = None,
+ direct_contents_source: Optional[
+ types.GenerateMemoriesRequestDirectContentsSourceOrDict
+ ] = None,
+ direct_memories_source: Optional[
+ types.GenerateMemoriesRequestDirectMemoriesSourceOrDict
+ ] = None,
+ scope: Optional[dict[str, str]] = None,
+ config: Optional[types.GenerateAgentEngineMemoriesConfigOrDict] = None,
+ ) -> types.AgentEngineGenerateMemoriesOperation:
+ """Generates memories for the agent engine.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to generate memories for.
+ vertex_session_source (GenerateMemoriesRequestVertexSessionSource):
+ Optional. The vertex session source to use for generating
+ memories. Only one of vertex_session_source,
+ direct_contents_source, or direct_memories_source can be
+ specified.
+ direct_contents_source (GenerateMemoriesRequestDirectContentsSource):
+ Optional. The direct contents source to use for generating
+ memories. Only one of vertex_session_source, direct_contents_source,
+ or direct_memories_source can be specified.
+ direct_memories_source (GenerateMemoriesRequestDirectMemoriesSource):
+ Optional. The direct memories source to use for generating
+ memories. Only one of vertex_session_source, direct_contents_source,
+ or direct_memories_source can be specified.
+ scope (dict[str, str]):
+ Optional. The scope of the memories to generate. This is optional
+ if vertex_session_source is used, otherwise it must be specified.
+ config (GenerateAgentEngineMemoriesConfig):
+ Optional. The configuration for the memories to generate.
+
+ Returns:
+ AgentEngineGenerateMemoriesOperation:
+ The operation for generating the memories.
+ """
+ if config is None:
+ config = types.GenerateAgentEngineMemoriesConfig()
+ elif isinstance(config, dict):
+ config = types.GenerateAgentEngineMemoriesConfig.model_validate(config)
+ operation = await self._generate(
+ name=name,
+ vertex_session_source=vertex_session_source,
+ direct_contents_source=direct_contents_source,
+ direct_memories_source=direct_memories_source,
+ scope=scope,
+ config=config,
+ )
+ if config.wait_for_completion and not operation.done:
+ operation = await _agent_engines_utils._await_async_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_generate_memories_operation,
+ poll_interval_seconds=0.5,
+ )
+ if operation.error:
+ raise RuntimeError(f"Failed to generate memory: {operation.error}")
+ return operation
+
+ async def list(
+ self,
+ *,
+ name: str,
+ config: Optional[types.ListAgentEngineMemoryConfigOrDict] = None,
+ ) -> AsyncPager[types.Memory]:
+ """Lists Agent Engine memories.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to list memories for.
+ config (ListAgentEngineMemoryConfig):
+ Optional. The configuration for the memories to list.
+
+ Returns:
+ AsyncPager[Memory]: An async pager of memories.
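+
+ For example, you can use the following code to iterate over all
+ memories of an agent engine (illustrative resource name):
+
+ ```python
+ pager = await client.aio.agent_engines.memories.list(
+ name="projects/123/locations/us-central1/reasoningEngines/456"
+ )
+ async for memory in pager:
+ print(memory.fact)
+ ```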
+ """ + + return AsyncPager( + "memories", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) + + async def retrieve( + self, + *, + name: str, + scope: dict[str, str], + similarity_search_params: Optional[ + types.RetrieveMemoriesRequestSimilaritySearchParamsOrDict + ] = None, + simple_retrieval_params: Optional[ + types.RetrieveMemoriesRequestSimpleRetrievalParamsOrDict + ] = None, + config: Optional[types.RetrieveAgentEngineMemoriesConfigOrDict] = None, + ) -> AsyncPager[types.RetrieveMemoriesResponseRetrievedMemory]: + """Retrieves memories for the agent. + + Args: + name (str): + Required. The name of the agent engine to retrieve memories for. + scope (dict[str, str]): + Required. The scope of the memories to retrieve. For example, + {"user_id": "123"}. + similarity_search_params (RetrieveMemoriesRequestSimilaritySearchParams): + Optional. The similarity search parameters to use for retrieving + memories. + simple_retrieval_params (RetrieveMemoriesRequestSimpleRetrievalParams): + Optional. The simple retrieval parameters to use for retrieving + memories. + config (RetrieveAgentEngineMemoriesConfig): + Optional. The configuration for the memories to retrieve. + + Returns: + AsyncPager[RetrieveMemoriesResponseRetrievedMemory]: An async pager of + retrieved memories. + """ + return AsyncPager( + "retrieved_memories", + lambda config: self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + await self._retrieve( + name=name, + similarity_search_params=similarity_search_params, + simple_retrieval_params=simple_retrieval_params, + scope=scope, + config=config, + ), + config, + ) + + async def rollback( + self, + *, + name: str, + target_revision_id: str, + config: Optional[types.RollbackAgentEngineMemoryConfigOrDict] = None, + ) -> types.AgentEngineRollbackMemoryOperation: + """Rolls back a memory to a previous revision. + + Args: + name (str): + Required. The name of the memory to rollback. + target_revision_id (str): + Required. The revision ID to roll back to + config (RollbackAgentEngineMemoryConfig): + Optional. The configuration for the rollback. + + Returns: + AgentEngineRollbackMemoryOperation: + The operation for rolling back the memory. + """ + if config is None: + config = types.RollbackAgentEngineMemoryConfig() + elif isinstance(config, dict): + config = types.RollbackAgentEngineMemoryConfig.model_validate(config) + operation = await self._rollback( + name=name, + target_revision_id=target_revision_id, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to rollback memory: {operation.error}") + return operation + + async def purge( + self, + *, + name: str, + filter: Optional[str] = None, + filter_groups: Optional[List[types.MemoryConjunctionFilter]] = None, + force: bool = False, + config: Optional[types.PurgeAgentEngineMemoriesConfigOrDict] = None, + ) -> types.AgentEnginePurgeMemoriesOperation: + """Purges memories from an Agent Engine. + + Args: + name (str): + Required. The name of the Agent Engine to purge memories from. + filter (str): + Optional. The standard list filter to determine which memories to purge. 
+            filter_groups (list[MemoryConjunctionFilter]):
+                Optional. Metadata filters that will be applied to the memories'
+                `metadata` using OR logic. Filters are defined using disjunctive
+                normal form (OR of ANDs).
+            force (bool):
+                Optional. Whether to force the purge operation. If false, the
+                operation will be staged but not executed.
+            config (PurgeAgentEngineMemoriesConfig):
+                Optional. The configuration for the purge operation.
+
+        Returns:
+            AgentEnginePurgeMemoriesOperation:
+                The operation for purging the memories.
+        """
+        if config is None:
+            config = types.PurgeAgentEngineMemoriesConfig()
+        elif isinstance(config, dict):
+            config = types.PurgeAgentEngineMemoriesConfig.model_validate(config)
+        operation = await self._purge(
+            name=name,
+            filter=filter,
+            filter_groups=filter_groups,
+            force=force,
+            config=config,
+        )
+        if config.wait_for_completion and not operation.done:
+            operation = await _agent_engines_utils._await_async_operation(
+                operation_name=operation.name,
+                get_operation_fn=self._get_memory_operation,
+                poll_interval_seconds=0.5,
+            )
+        if operation.error:
+            raise RuntimeError(f"Failed to purge memories: {operation.error}")
+        return operation
+
+    async def ingest_events(
+        self,
+        *,
+        name: str,
+        scope: dict[str, str],
+        stream_id: str = "",
+        direct_contents_source: Optional[
+            types.IngestionDirectContentsSourceOrDict
+        ] = None,
+        generation_trigger_config: Optional[
+            types.MemoryGenerationTriggerConfigOrDict
+        ] = None,
+        config: Optional[types.IngestEventsConfigOrDict] = None,
+    ) -> types.MemoryBankIngestEventsOperation:
+        """Ingests events into an Agent Engine.
+
+        Example usage:
+        ```
+        await client.aio.agent_engines.memories.ingest_events(
+            name="projects/test-project/locations/us-central1/reasoningEngines/test-agent-engine",
+            scope={"user_id": "test-user-id"},
+            direct_contents_source={
+                "events": [
+                    {
+                        "content": {
+                            "role": "user",
+                            "parts": [
+                                {"text": "I am a software engineer."}
+                            ],
+                        }
+                    }
+                ]
+            },
+            generation_trigger_config={
+                "generation_rule": {
+                    "idle_duration": "60s"
+                }
+            }
+        )
+        ```
+
+        Args:
+            name (str):
+                Required. The name of the Agent Engine to ingest events into.
+            scope (dict[str, str]):
+                Required. The scope of the events to ingest. For example,
+                {"user_id": "123"}.
+            stream_id (str):
+                Optional. The ID of the stream to ingest events into. If not
+                specified, the events will be ingested into the default stream.
+            direct_contents_source (IngestionDirectContentsSource):
+                Optional. The direct contents source, containing the events to
+                ingest.
+            generation_trigger_config (MemoryGenerationTriggerConfig):
+                Optional. The configuration for triggering memory generation.
+            config (IngestEventsConfig):
+                Optional. The configuration for the ingest events operation.
+
+        Returns:
+            MemoryBankIngestEventsOperation:
+                The operation for ingesting the events.
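+                When `config.wait_for_completion` is true, this method polls
+                the operation until it is done before returning.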
+ """ + if config is None: + config = types.IngestEventsConfig() + elif isinstance(config, dict): + config = types.IngestEventsConfig.model_validate(config) + operation = await self._ingest_events( + name=name, + scope=scope, + stream_id=stream_id, + generation_trigger_config=generation_trigger_config, + direct_contents_source=direct_contents_source, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_memory_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError(f"Failed to ingest events: {operation.error}") + return operation diff --git a/agentplatform/_genai/memory_revisions.py b/agentplatform/_genai/memory_revisions.py new file mode 100644 index 0000000000..1aef74c6c3 --- /dev/null +++ b/agentplatform/_genai/memory_revisions.py @@ -0,0 +1,475 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . 
import types
+
+logger = logging.getLogger("agentplatform_genai.memoryrevisions")
+
+logger.setLevel(logging.INFO)
+
+
+def _GetAgentEngineMemoryRevisionRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    return to_object
+
+
+def _ListAgentEngineMemoryRevisionsConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+
+    if getv(from_object, ["page_size"]) is not None:
+        setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"]))
+
+    if getv(from_object, ["page_token"]) is not None:
+        setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"]))
+
+    if getv(from_object, ["filter"]) is not None:
+        setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"]))
+
+    return to_object
+
+
+def _ListAgentEngineMemoryRevisionsRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["config"]) is not None:
+        _ListAgentEngineMemoryRevisionsConfig_to_vertex(
+            getv(from_object, ["config"]), to_object
+        )
+
+    return to_object
+
+
+class MemoryRevisions(_api_module.BaseModule):
+
+    def get(
+        self,
+        *,
+        name: str,
+        config: Optional[types.GetAgentEngineMemoryRevisionConfigOrDict] = None,
+    ) -> types.MemoryRevision:
+        """
+        Gets an agent engine memory revision.
+
+        Args:
+            name (str): Required. The name of the Agent Engine memory revision to get. Format:
+                `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory_id}/revisions/{revision_id}`.
+            config (GetAgentEngineMemoryRevisionConfig):
+                Optional. Additional configurations for getting the Agent Engine memory revision.
+
+        Returns:
+            MemoryRevision: The requested Agent Engine memory revision.
+        """
+
+        parameter_model = types._GetAgentEngineMemoryRevisionRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _GetAgentEngineMemoryRevisionRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}".format_map(request_url_dict)
+            else:
+                path = "{name}"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
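+        # The config's routing fields were already folded into `path` above and
+        # its `http_options` are read from `parameter_model` below, so the raw
+        # config must not be sent as part of the request body.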
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MemoryRevision._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryRevisionsConfigOrDict] = None, + ) -> types.ListAgentEngineMemoryRevisionsResponse: + """ + Lists Agent Engine memory revisions. + + Args: + name (str): Required. The name of the Agent Engine memory to list revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory_id}`. + config (ListAgentEngineMemoryRevisionsConfig): + Optional. Additional configurations for listing the Agent Engine memory revisions. + + Returns: + ListAgentEngineMemoryRevisionsResponse: The requested Agent Engine memory revisions. + + """ + + parameter_model = types._ListAgentEngineMemoryRevisionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineMemoryRevisionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/revisions".format_map(request_url_dict) + else: + path = "{name}/revisions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.ListAgentEngineMemoryRevisionsResponse._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListAgentEngineMemoryRevisionsConfigOrDict] = None,
+    ) -> Iterator[types.MemoryRevision]:
+        """Lists Agent Engine memory revisions.
+
+        Args:
+            name (str):
+                Required. The name of the Memory to list revisions for.
+            config (ListAgentEngineMemoryRevisionsConfigOrDict):
+                Optional. The configuration for listing the memory revisions.
+
+        Returns:
+            Iterator[MemoryRevision]: An iterator of memory revisions.
+        """
+
+        return Pager(
+            "memory_revisions",
+            functools.partial(self._list, name=name),
+            self._list(name=name, config=config),
+            config,
+        )
+
+
+class AsyncMemoryRevisions(_api_module.BaseModule):
+
+    async def get(
+        self,
+        *,
+        name: str,
+        config: Optional[types.GetAgentEngineMemoryRevisionConfigOrDict] = None,
+    ) -> types.MemoryRevision:
+        """
+        Gets an agent engine memory revision.
+
+        Args:
+            name (str): Required. The name of the Agent Engine memory revision to get. Format:
+                `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory_id}/revisions/{revision_id}`.
+            config (GetAgentEngineMemoryRevisionConfig):
+                Optional. Additional configurations for getting the Agent Engine memory revision.
+
+        Returns:
+            MemoryRevision: The requested Agent Engine memory revision.
+        """
+
+        parameter_model = types._GetAgentEngineMemoryRevisionRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _GetAgentEngineMemoryRevisionRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}".format_map(request_url_dict)
+            else:
+                path = "{name}"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.MemoryRevision._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryRevisionsConfigOrDict] = None, + ) -> types.ListAgentEngineMemoryRevisionsResponse: + """ + Lists Agent Engine memory revisions. + + Args: + name (str): Required. The name of the Agent Engine memory to list revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/memories/{memory_id}`. + config (ListAgentEngineMemoryRevisionsConfig): + Optional. Additional configurations for listing the Agent Engine memory revisions. + + Returns: + ListAgentEngineMemoryRevisionsResponse: The requested Agent Engine memory revisions. + + """ + + parameter_model = types._ListAgentEngineMemoryRevisionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineMemoryRevisionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/revisions".format_map(request_url_dict) + else: + path = "{name}/revisions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineMemoryRevisionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineMemoryRevisionsConfigOrDict] = None, + ) -> AsyncPager[types.MemoryRevision]: + """Lists Agent Engine memory revisions. + + Args: + name (str): + Required. The name of the Memory to list revisions for. + config (ListAgentEngineMemoryRevisionsConfigOrDict): + Optional. The configuration for the memories to list revisions. + + Returns: + AsyncPager[MemoryRevision]: An async pager of memory revisions. + """ + + return AsyncPager( + "memory_revisions", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) diff --git a/agentplatform/_genai/mypy.ini b/agentplatform/_genai/mypy.ini new file mode 100644 index 0000000000..daaaba6a5f --- /dev/null +++ b/agentplatform/_genai/mypy.ini @@ -0,0 +1,10 @@ +[mypy] +# TODO(b/422425982): Fix arg-type errors +disable_error_code = import-not-found, import-untyped, arg-type + +# We only want to run mypy on _genai dir, ignore dependent modules +[mypy-agentplatform.*] +ignore_errors = True + +[mypy-agentplatform._genai.*] +ignore_errors = False diff --git a/agentplatform/_genai/prompt_optimizer.py b/agentplatform/_genai/prompt_optimizer.py new file mode 100644 index 0000000000..d525c2a2d3 --- /dev/null +++ b/agentplatform/_genai/prompt_optimizer.py @@ -0,0 +1,995 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import json +import logging +import time +from typing import Any, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv + +from . import _logging_utils +from . import _prompt_optimizer_utils +from . 
import prompts
+from . import types
+
+logger = logging.getLogger("agentplatform_genai.promptoptimizer")
+
+
+def _CustomJobParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["custom_job"]) is not None:
+        setv(
+            parent_object,
+            ["customJob"],
+            _CustomJob_to_vertex(getv(from_object, ["custom_job"]), to_object),
+        )
+
+    if getv(from_object, ["config"]) is not None:
+        setv(to_object, ["config"], getv(from_object, ["config"]))
+
+    return to_object
+
+
+def _CustomJob_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    # All fields are read from `from_object`: callers pass the raw response
+    # dict as `from_object` and no parent object.
+    if getv(from_object, ["displayName"]) is not None:
+        setv(to_object, ["display_name"], getv(from_object, ["displayName"]))
+
+    if getv(from_object, ["jobSpec"]) is not None:
+        setv(to_object, ["job_spec"], getv(from_object, ["jobSpec"]))
+
+    if getv(from_object, ["encryptionSpec"]) is not None:
+        setv(to_object, ["encryption_spec"], getv(from_object, ["encryptionSpec"]))
+
+    if getv(from_object, ["state"]) is not None:
+        setv(to_object, ["state"], getv(from_object, ["state"]))
+
+    if getv(from_object, ["error"]) is not None:
+        setv(to_object, ["error"], getv(from_object, ["error"]))
+
+    if getv(from_object, ["createTime"]) is not None:
+        setv(to_object, ["create_time"], getv(from_object, ["createTime"]))
+
+    if getv(from_object, ["endTime"]) is not None:
+        setv(to_object, ["end_time"], getv(from_object, ["endTime"]))
+
+    if getv(from_object, ["labels"]) is not None:
+        setv(to_object, ["labels"], getv(from_object, ["labels"]))
+
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["satisfiesPzi"]) is not None:
+        setv(to_object, ["satisfies_pzi"], getv(from_object, ["satisfiesPzi"]))
+
+    if getv(from_object, ["satisfiesPzs"]) is not None:
+        setv(to_object, ["satisfies_pzs"], getv(from_object, ["satisfiesPzs"]))
+
+    if getv(from_object, ["startTime"]) is not None:
+        setv(to_object, ["start_time"], getv(from_object, ["startTime"]))
+
+    if getv(from_object, ["updateTime"]) is not None:
+        setv(to_object, ["update_time"], getv(from_object, ["updateTime"]))
+
+    if getv(from_object, ["webAccessUris"]) is not None:
+        setv(to_object, ["web_access_uris"], getv(from_object, ["webAccessUris"]))
+
+    return to_object
+
+
+def _CustomJob_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["display_name"]) is not None:
+        setv(parent_object, ["displayName"], getv(from_object, ["display_name"]))
+
+    if getv(from_object, ["job_spec"]) is not None:
+        setv(parent_object, ["jobSpec"], getv(from_object, ["job_spec"]))
+
+    if getv(from_object, ["encryption_spec"]) is not None:
+        setv(parent_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"]))
+
+    if getv(from_object, ["state"]) is not None:
+        setv(to_object, ["state"], getv(from_object, ["state"]))
+
+    if getv(from_object, ["error"]) is not None:
+        setv(parent_object, ["error"], getv(from_object, ["error"]))
+
+    if getv(from_object, ["create_time"]) is not None:
+        setv(to_object, ["createTime"], getv(from_object, ["create_time"]))
+
+    if getv(from_object, ["end_time"]) is not None:
+        setv(to_object, ["endTime"], getv(from_object, ["end_time"]))
+
+    if
getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["satisfies_pzi"]) is not None: + setv(to_object, ["satisfiesPzi"], getv(from_object, ["satisfies_pzi"])) + + if getv(from_object, ["satisfies_pzs"]) is not None: + setv(to_object, ["satisfiesPzs"], getv(from_object, ["satisfies_pzs"])) + + if getv(from_object, ["start_time"]) is not None: + setv(to_object, ["startTime"], getv(from_object, ["start_time"])) + + if getv(from_object, ["update_time"]) is not None: + setv(to_object, ["updateTime"], getv(from_object, ["update_time"])) + + if getv(from_object, ["web_access_uris"]) is not None: + setv(to_object, ["webAccessUris"], getv(from_object, ["web_access_uris"])) + + return to_object + + +def _GetCustomJobParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _OptimizeConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["optimization_target"]) is not None: + setv( + parent_object, + ["optimizationTarget"], + getv(from_object, ["optimization_target"]), + ) + + return to_object + + +def _OptimizeRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["content"]) is not None: + setv(to_object, ["content"], getv(from_object, ["content"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _OptimizeConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +class PromptOptimizer(_api_module.BaseModule): + """Prompt Optimizer""" + + def _optimize_prompt( + self, + *, + content: Optional[genai_types.ContentOrDict] = None, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponseEndpoint: + """ + Optimize a single prompt. + """ + + parameter_model = types._OptimizeRequestParameters( + content=content, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "tuningJobs:optimizePrompt".format_map(request_url_dict) + else: + path = "tuningJobs:optimizePrompt" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_custom_job_resource( + self, + *, + custom_job: types.CustomJobOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.CustomJob: + """ + Creates a custom job. + """ + + parameter_model = types._CustomJobParameters( + custom_job=custom_job, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs".format_map(request_url_dict) + else: + path = "customJobs" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_custom_job( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.CustomJob: + """ + Gets a custom job. 
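+
+        Used by `_wait_for_completion` to poll the job state.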
+ """ + + parameter_model = types._GetCustomJobParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetCustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs/{name}".format_map(request_url_dict) + else: + path = "customJobs/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + """Prompt Optimizer PO-Data.""" + + def _wait_for_completion(self, job_name: str) -> types.CustomJob: + + JOB_COMPLETE_STATES = [ + genai_types.JobState.JOB_STATE_SUCCEEDED, + genai_types.JobState.JOB_STATE_FAILED, + genai_types.JobState.JOB_STATE_CANCELLED, + genai_types.JobState.JOB_STATE_PAUSED, + ] + JOB_ERROR_STATES = [ + genai_types.JobState.JOB_STATE_FAILED, + genai_types.JobState.JOB_STATE_CANCELLED, + ] + + log_wait = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + + job = self._get_custom_job(name=job_name) + + while job.state not in JOB_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + logger.info(f"Waiting for job to complete. Current state: {job.state}") + log_wait = min(log_wait * wait_multiplier, max_wait_time) + previous_time = current_time + time.sleep(log_wait) + job = self._get_custom_job(name=job_name) + + logger.info(f"Job state: {job.state}") + + if job.state in JOB_ERROR_STATES: + raise RuntimeError(f"Job failed with state: {job.state}") + else: + logger.info(f"Job completed with state: {job.state}") + return job + + @_logging_utils.show_deprecation_warning_once( + "The prompt_optimizer.optimize method is deprecated. Please use" + " prompts.launch_optimization_job instead." + ) + def optimize( + self, + method: types.PromptOptimizerMethod, + config: types.PromptOptimizerConfigOrDict, + ) -> types.CustomJob: + """Call PO-Data optimizer. + + Args: + method: The method for optimizing multiple prompts. Supported methods: + VAPO, OPTIMIZATION_TARGET_GEMINI_NANO. + config: PromptOptimizerConfig instance containing the + configuration for prompt optimization. 
+        Returns:
+            The custom job that was created.
+        """
+        prompts_module = prompts.Prompts(api_client_=self._api_client)
+
+        return prompts_module.launch_optimization_job(  # type: ignore[no-any-return]
+            method=method, config=config
+        )
+
+    @_logging_utils.show_deprecation_warning_once(
+        "The prompt_optimizer.optimize_prompt method is deprecated. Please use"
+        " prompts.optimize instead."
+    )
+    def optimize_prompt(
+        self,
+        *,
+        prompt: str,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponse:
+        """Makes an API request to _optimize_prompt and returns the parsed response.
+
+        Example usage:
+            client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
+            prompt = "Generate system instructions for analyzing medical articles"
+            response = client.prompt_optimizer.optimize_prompt(prompt=prompt)
+            print(response.suggested_prompt)
+
+        Args:
+            prompt: The prompt to optimize.
+            config: Optional. The configuration for prompt optimization. To
+                optimize prompts for the Android API, provide
+                types.OptimizeConfig(
+                    optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
+                )
+                For few-shot optimization, provide:
+                optim_target = types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS
+                or
+                optim_target = types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+                types.OptimizeConfig(
+                    optimization_target=optim_target,
+                    examples_dataframe=dataframe
+                )
+                OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS indicates that the few-shot
+                examples include specific scoring rubrics and their corresponding
+                evaluations.
+                OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE indicates that the
+                few-shot examples include a ground-truth target response.
+
+        Returns:
+            The parsed response from the API request.
+        """
+        prompts_module = prompts.Prompts(api_client_=self._api_client)
+
+        return prompts_module.optimize(  # type: ignore[no-any-return]
+            prompt=prompt, config=config
+        )
+
+    def _custom_optimize_prompt(
+        self,
+        *,
+        content: Optional[genai_types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponse:
+        """Optimize a single prompt.
+
+        Sends a request to the tuningJobs:optimizePrompt streaming endpoint,
+        then gathers the streamed chunks, concatenates them into one string,
+        and returns the parsed response.
+        """
+        if isinstance(config, dict):
+            config.pop("examples_dataframe", None)
+        elif config and hasattr(config, "examples_dataframe"):
+            del config.examples_dataframe
+
+        parameter_model = types._OptimizeRequestParameters(
+            content=content,
+            config=config,
+        )
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError("This method is only supported in the Vertex AI client.")
+        else:
+            request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
+            else:
+                path = "tuningJobs:optimizePrompt"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[genai_types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("post", path, request_dict, http_options)
+
+        # The streaming endpoint returns a JSON list of partial responses.
+        response_list = [] if not response.body else json.loads(response.body)
+
+        return_value = []
+
+        for response_dict in response_list:
+            response_value = types.OptimizeResponseEndpoint._from_response(
+                response=response_dict, kwargs=parameter_model.model_dump()
+            )
+            self._api_client._verify_response(response_value)
+            content = response_value.content
+            if content is not None:
+                parts = content.parts
+                if parts and parts[0].text is not None:
+                    return_value.append(parts[0].text)
+
+        output = "".join(return_value)
+        final_response = types.OptimizeResponse(raw_text_response=output)
+        try:
+            final_response.parsed_response = _prompt_optimizer_utils._parse(output)
+        except Exception as e:
+            logger.warning(
+                f"Failed to parse response: {e}. Returning only raw_text_response."
+            )
+        return final_response
+
+
+class AsyncPromptOptimizer(_api_module.BaseModule):
+    """Prompt Optimizer"""
+
+    async def _optimize_prompt(
+        self,
+        *,
+        content: Optional[genai_types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponseEndpoint:
+        """
+        Optimize a single prompt.
+        """
+
+        parameter_model = types._OptimizeRequestParameters(
+            content=content,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
+            else:
+                path = "tuningJobs:optimizePrompt"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_custom_job_resource( + self, + *, + custom_job: types.CustomJobOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.CustomJob: + """ + Creates a custom job. + """ + + parameter_model = types._CustomJobParameters( + custom_job=custom_job, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs".format_map(request_url_dict) + else: + path = "customJobs" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_custom_job( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.CustomJob: + """ + Gets a custom job. 
+ """ + + parameter_model = types._GetCustomJobParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetCustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs/{name}".format_map(request_url_dict) + else: + path = "customJobs/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + # Todo: b/428953357 - Add example in the README. + @_logging_utils.show_deprecation_warning_once( + "The prompt_optimizer.optimize method is deprecated. Please use" + " prompts.launch_optimization_job instead." + ) + async def optimize( + self, + method: types.PromptOptimizerMethod, + config: types.PromptOptimizerConfigOrDict, + ) -> types.CustomJob: + """Call async Vertex AI Prompt Optimizer (VAPO). + + + Note: The `wait_for_completion` parameter in the config will be + ignored when using the AsyncClient, as it is not supported. + + Example usage: + client = vertexai.Client(project=PROJECT_NAME, location='us-central1') + vapo_config = vertexai.types.PromptOptimizerConfig( + config_path='gs://you-bucket-name/your-config.json', + service_account=service_account, + ) + job = await client.aio.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.VAPO, config=vapo_config) + + Args: + method: The method for optimizing multiple prompts. Supported methods: + VAPO, OPTIMIZATION_TARGET_GEMINI_NANO. + config: PromptOptimizerConfig instance containing the + configuration for prompt optimization. + Returns: + The custom job that was created. 
+ """ + prompts_module = prompts.AsyncPrompts(api_client_=self._api_client) + + return await prompts_module.launch_optimization_job( # type: ignore[no-any-return] + method=method, config=config + ) + + async def _custom_optimize_prompt( + self, + *, + content: Optional[genai_types.ContentOrDict] = None, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponse: + """Optimize a single prompt.""" + if isinstance(config, dict): + config.pop("examples_dataframe", None) + elif config and hasattr(config, "examples_dataframe"): + del config.examples_dataframe + + parameter_model = types._OptimizeRequestParameters( + content=content, + config=config, + ) + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError("This method is only supported in the Vertex AI client.") + else: + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "tuningJobs:optimizePrompt".format_map(request_url_dict) + else: + path = "tuningJobs:optimizePrompt" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_list = "" if not response.body else json.loads(response.body) + + return_value = [] + + for response_dict in response_list: + response_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(response_value) + content = response_value.content + if content is not None: + parts = content.parts + if parts and parts[0].text is not None: + return_value.append(parts[0].text) + + output = "".join(return_value) + final_response = types.OptimizeResponse(raw_text_response=output) + try: + final_response.parsed_response = _prompt_optimizer_utils._parse(output) + except Exception as e: + logger.warning( + f"Failed to parse response: {e}. Returning only raw_text_response." + ) + return final_response + + @_logging_utils.show_deprecation_warning_once( + "The prompt_optimizer.optimize_prompt method is deprecated. Please use" + " prompts.optimize instead." + ) + async def optimize_prompt( + self, + *, + prompt: str, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponse: + """Makes an async request to _optimize_prompt and returns an optimized prompt. + + Example usage: + client = vertexai.Client(project=PROJECT_NAME, location='us-central1') + prompt = "Generate system instructions for analyzing medical articles" + response = await client.aio.prompt_optimizer.optimize_prompt(prompt=prompt) + + Args: + prompt: The prompt to optimize. + config: Optional.The configuration for prompt optimization. 
To optimize + prompts from Android API provide + types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO + ) + For few-shot optimization, provide: + optim_target = types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS # or types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE + types.OptimizeConfig( + optimization_target=optim_target, + examples_dataframe=dataframe + ) + OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS indicates that the few-shot + examples include specific scoring rubrics and their corresponding + evaluations. + OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE indicates that the few-shot + examples include a ground-truth target response. + Returns: + The parsed response from the API request. + """ + prompts_module = prompts.AsyncPrompts(api_client_=self._api_client) + + return await prompts_module.optimize( # type: ignore[no-any-return] + prompt=prompt, config=config + ) diff --git a/agentplatform/_genai/prompts.py b/agentplatform/_genai/prompts.py new file mode 100644 index 0000000000..2bcb36b706 --- /dev/null +++ b/agentplatform/_genai/prompts.py @@ -0,0 +1,4385 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import asyncio +import datetime +import json +import logging +import time +from typing import Any, AsyncIterator, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai import operations +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager +from pydantic import ValidationError + +from . import _prompt_management_utils +from . import _prompt_optimizer_utils +from . 
import types
+
+logger = logging.getLogger("agentplatform_genai.prompts")
+
+
+def _CreateDatasetParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["display_name"]) is not None:
+        setv(to_object, ["displayName"], getv(from_object, ["display_name"]))
+
+    if getv(from_object, ["metadata_schema_uri"]) is not None:
+        setv(
+            to_object, ["metadataSchemaUri"], getv(from_object, ["metadata_schema_uri"])
+        )
+
+    if getv(from_object, ["metadata"]) is not None:
+        setv(to_object, ["metadata"], getv(from_object, ["metadata"]))
+
+    if getv(from_object, ["description"]) is not None:
+        setv(to_object, ["description"], getv(from_object, ["description"]))
+
+    if getv(from_object, ["encryption_spec"]) is not None:
+        setv(to_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"]))
+
+    if getv(from_object, ["model_reference"]) is not None:
+        setv(to_object, ["modelReference"], getv(from_object, ["model_reference"]))
+
+    if getv(from_object, ["config"]) is not None:
+        setv(to_object, ["config"], getv(from_object, ["config"]))
+
+    return to_object
+
+
+def _CreateDatasetVersionParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["dataset_name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["dataset_name"]))
+
+    if getv(from_object, ["metadata"]) is not None:
+        setv(to_object, ["metadata"], getv(from_object, ["metadata"]))
+
+    if getv(from_object, ["model_reference"]) is not None:
+        setv(to_object, ["modelReference"], getv(from_object, ["model_reference"]))
+
+    if getv(from_object, ["parent"]) is not None:
+        setv(to_object, ["parent"], getv(from_object, ["parent"]))
+
+    if getv(from_object, ["display_name"]) is not None:
+        setv(to_object, ["displayName"], getv(from_object, ["display_name"]))
+
+    if getv(from_object, ["config"]) is not None:
+        setv(to_object, ["config"], getv(from_object, ["config"]))
+
+    return to_object
+
+
+def _CustomJobParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["custom_job"]) is not None:
+        setv(
+            parent_object,
+            ["customJob"],
+            _CustomJob_to_vertex(getv(from_object, ["custom_job"]), to_object),
+        )
+
+    if getv(from_object, ["config"]) is not None:
+        setv(to_object, ["config"], getv(from_object, ["config"]))
+
+    return to_object
+
+
+def _CustomJob_from_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    # All fields are read from `from_object`: callers pass the raw response
+    # dict as `from_object` and no parent object.
+    if getv(from_object, ["displayName"]) is not None:
+        setv(to_object, ["display_name"], getv(from_object, ["displayName"]))
+
+    if getv(from_object, ["jobSpec"]) is not None:
+        setv(to_object, ["job_spec"], getv(from_object, ["jobSpec"]))
+
+    if getv(from_object, ["encryptionSpec"]) is not None:
+        setv(to_object, ["encryption_spec"], getv(from_object, ["encryptionSpec"]))
+
+    if getv(from_object, ["state"]) is not None:
+        setv(to_object, ["state"], getv(from_object, ["state"]))
+
+    if getv(from_object, ["error"]) is not None:
+        setv(to_object, ["error"], getv(from_object, ["error"]))
+
+    if
getv(from_object, ["createTime"]) is not None: + setv(to_object, ["create_time"], getv(from_object, ["createTime"])) + + if getv(from_object, ["endTime"]) is not None: + setv(to_object, ["end_time"], getv(from_object, ["endTime"])) + + if getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["satisfiesPzi"]) is not None: + setv(to_object, ["satisfies_pzi"], getv(from_object, ["satisfiesPzi"])) + + if getv(from_object, ["satisfiesPzs"]) is not None: + setv(to_object, ["satisfies_pzs"], getv(from_object, ["satisfiesPzs"])) + + if getv(from_object, ["startTime"]) is not None: + setv(to_object, ["start_time"], getv(from_object, ["startTime"])) + + if getv(from_object, ["updateTime"]) is not None: + setv(to_object, ["update_time"], getv(from_object, ["updateTime"])) + + if getv(from_object, ["webAccessUris"]) is not None: + setv(to_object, ["web_access_uris"], getv(from_object, ["webAccessUris"])) + + return to_object + + +def _CustomJob_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["job_spec"]) is not None: + setv(parent_object, ["jobSpec"], getv(from_object, ["job_spec"])) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(parent_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["state"]) is not None: + setv(to_object, ["state"], getv(from_object, ["state"])) + + if getv(from_object, ["error"]) is not None: + setv(parent_object, ["error"], getv(from_object, ["error"])) + + if getv(from_object, ["create_time"]) is not None: + setv(to_object, ["createTime"], getv(from_object, ["create_time"])) + + if getv(from_object, ["end_time"]) is not None: + setv(to_object, ["endTime"], getv(from_object, ["end_time"])) + + if getv(from_object, ["labels"]) is not None: + setv(to_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["satisfies_pzi"]) is not None: + setv(to_object, ["satisfiesPzi"], getv(from_object, ["satisfies_pzi"])) + + if getv(from_object, ["satisfies_pzs"]) is not None: + setv(to_object, ["satisfiesPzs"], getv(from_object, ["satisfies_pzs"])) + + if getv(from_object, ["start_time"]) is not None: + setv(to_object, ["startTime"], getv(from_object, ["start_time"])) + + if getv(from_object, ["update_time"]) is not None: + setv(to_object, ["updateTime"], getv(from_object, ["update_time"])) + + if getv(from_object, ["web_access_uris"]) is not None: + setv(to_object, ["webAccessUris"], getv(from_object, ["web_access_uris"])) + + return to_object + + +def _DeleteDatasetRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prompt_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["prompt_id"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _DeletePromptVersionRequestParameters_to_vertex( + 
from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["prompt_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["prompt_id"])) + + if getv(from_object, ["version_id"]) is not None: + setv(to_object, ["_url", "version_id"], getv(from_object, ["version_id"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetCustomJobParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetDatasetOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["operation_id"]) is not None: + setv(to_object, ["_url", "operation_id"], getv(from_object, ["operation_id"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _GetDatasetVersionParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["dataset_version_id"]) is not None: + setv( + to_object, + ["_url", "dataset_version_id"], + getv(from_object, ["dataset_version_id"]), + ) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _ListDatasetVersionsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["read_mask"]) is not None: + setv(to_object, ["_url", "read_mask"], getv(from_object, ["read_mask"])) + + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _ListPromptsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +def _ListDatasetsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["config"]) is not None: + setv( + 
to_object, + ["config"], + _ListPromptsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +def _ListPromptsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _OptimizeConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["optimization_target"]) is not None: + setv( + parent_object, + ["optimizationTarget"], + getv(from_object, ["optimization_target"]), + ) + + return to_object + + +def _OptimizeRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["content"]) is not None: + setv(to_object, ["content"], getv(from_object, ["content"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _OptimizeConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +def _RestoreVersionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["version_id"]) is not None: + setv(to_object, ["_url", "version_id"], getv(from_object, ["version_id"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _UpdateDatasetParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["name"], getv(from_object, ["name"])) + + if getv(from_object, ["dataset_id"]) is not None: + setv(to_object, ["_url", "dataset_id"], getv(from_object, ["dataset_id"])) + + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["metadata"]) is not None: + setv(to_object, ["metadata"], getv(from_object, ["metadata"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["encryption_spec"]) is not None: + setv(to_object, ["encryptionSpec"], getv(from_object, ["encryption_spec"])) + + if getv(from_object, ["model_reference"]) is not None: + setv(to_object, ["modelReference"], getv(from_object, ["model_reference"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +class Prompts(_api_module.BaseModule): + + def _create_dataset_resource( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, 
+ metadata_schema_uri: Optional[str] = None, + metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + model_reference: Optional[str] = None, + config: Optional[types.CreateDatasetConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Creates a dataset resource to store prompts. + """ + + parameter_model = types._CreateDatasetParameters( + name=name, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + model_reference=model_reference, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_dataset_version_resource( + self, + *, + dataset_name: Optional[str] = None, + metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None, + model_reference: Optional[str] = None, + parent: Optional[str] = None, + display_name: Optional[str] = None, + config: Optional[types.CreateDatasetVersionConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Creates a dataset version resource to store prompts. + """ + + parameter_model = types._CreateDatasetVersionParameters( + dataset_name=dataset_name, + metadata=metadata, + model_reference=model_reference, + parent=parent, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _CreateDatasetVersionParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{name}/datasetVersions".format_map(request_url_dict) + else: + path = "datasets/{name}/datasetVersions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_dataset_resource( + self, + *, + name: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.Dataset: + """ + Gets a dataset resource to store prompts. + """ + + parameter_model = types._GetDatasetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{name}".format_map(request_url_dict) + else: + path = "datasets/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
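+        # `config` carries client-side options (such as the `http_options`
+        # read below) rather than request-body fields, so it is dropped from
+        # the payload before serialization. The same pattern repeats in each
+        # request helper in this module.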
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Dataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_dataset_version_resource( + self, + *, + dataset_id: Optional[str] = None, + dataset_version_id: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.DatasetVersion: + """ + Gets a dataset version resource to store prompts. + """ + + parameter_model = types._GetDatasetVersionParameters( + dataset_id=dataset_id, + dataset_version_id=dataset_version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetVersionParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{dataset_version_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{dataset_version_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetVersion._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_dataset_operation( + self, + *, + dataset_id: Optional[str] = None, + operation_id: Optional[str] = None, + config: Optional[types.GetDatasetOperationConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Gets the operation from creating a dataset. 
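+
+        Used by `_wait_for_operation` to poll the long-running operation
+        returned when a prompt dataset or dataset version is created.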
+ """ + + parameter_model = types._GetDatasetOperationParameters( + dataset_id=dataset_id, + operation_id=operation_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/operations/{operation_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/operations/{operation_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list_prompts( + self, *, config: Optional[types.ListPromptsConfigOrDict] = None + ) -> types.ListDatasetsResponse: + """ + Lists Agent Engines. + """ + + parameter_model = types._ListDatasetsRequestParameters( + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListDatasetsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListDatasetsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list_versions( + self, + *, + read_mask: Optional[str] = None, + dataset_id: Optional[str] = None, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> types.ListDatasetVersionsResponse: + """ + Lists Agent Engines. + """ + + parameter_model = types._ListDatasetVersionsRequestParameters( + read_mask=read_mask, + dataset_id=dataset_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListDatasetVersionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListDatasetVersionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete_dataset( + self, *, prompt_id: str, config: Optional[types.DeletePromptConfigOrDict] = None + ) -> types.DeletePromptOperation: + parameter_model = types._DeleteDatasetRequestParameters( + prompt_id=prompt_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteDatasetRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}".format_map(request_url_dict) + else: + path = "datasets/{dataset_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeletePromptOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete_dataset_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.DeletePromptConfigOrDict] = None, + ) -> types.DeletePromptVersionOperation: + parameter_model = types._DeletePromptVersionRequestParameters( + prompt_id=prompt_id, + version_id=version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _DeletePromptVersionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{version_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{version_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeletePromptVersionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _restore_version( + self, + *, + dataset_id: str, + version_id: str, + config: Optional[types.RestoreVersionConfigOrDict] = None, + ) -> types.RestoreVersionOperation: + """ + Restores the provided prompt version to the latest version. + """ + + parameter_model = types._RestoreVersionRequestParameters( + dataset_id=dataset_id, + version_id=version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RestoreVersionRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{version_id}:restore".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{version_id}:restore" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.RestoreVersionOperation._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def _update_dataset_resource(
+        self,
+        *,
+        name: Optional[str] = None,
+        dataset_id: Optional[str] = None,
+        display_name: Optional[str] = None,
+        metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None,
+        description: Optional[str] = None,
+        encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None,
+        model_reference: Optional[str] = None,
+        config: Optional[types.UpdatePromptConfigOrDict] = None,
+    ) -> types.Dataset:
+        """
+        Updates a dataset resource that stores prompts.
+        """
+
+        parameter_model = types._UpdateDatasetParameters(
+            name=name,
+            dataset_id=dataset_id,
+            display_name=display_name,
+            metadata=metadata,
+            description=description,
+            encryption_spec=encryption_spec,
+            model_reference=model_reference,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _UpdateDatasetParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "datasets/{dataset_id}".format_map(request_url_dict)
+            else:
+                path = "datasets/{dataset_id}"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("patch", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Dataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create_custom_job_resource( + self, + *, + custom_job: types.CustomJobOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.CustomJob: + """ + Creates a custom job. + """ + + parameter_model = types._CustomJobParameters( + custom_job=custom_job, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs".format_map(request_url_dict) + else: + path = "customJobs" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_custom_job( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.CustomJob: + """ + Gets a custom job. + """ + + parameter_model = types._GetCustomJobParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GetCustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs/{name}".format_map(request_url_dict) + else: + path = "customJobs/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _optimize( + self, + *, + content: Optional[genai_types.ContentOrDict] = None, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponseEndpoint: + """ + Optimize a single prompt. + """ + + parameter_model = types._OptimizeRequestParameters( + content=content, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "tuningJobs:optimizePrompt".format_map(request_url_dict) + else: + path = "tuningJobs:optimizePrompt" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def create( + self, + *, + prompt: types.PromptOrDict, + config: Optional[types.CreatePromptConfigOrDict] = None, + ) -> types.Prompt: + """Creates a new prompt in a Vertex Dataset resource. + + This method waits for prompt creation to be complete before returning. + + Note: This method does not create a versioned resource for your prompt. + Call create_version() to create a new prompt resource with a version. + + Args: + prompt: The prompt to create. + config: Optional configuration for creating the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated + Dataset resources. + """ + if isinstance(prompt, dict): + prompt = types.Prompt(**prompt) + if isinstance(config, dict): + config = types.CreatePromptConfig(**config) + elif not config: + config = types.CreatePromptConfig() + + _prompt_management_utils._raise_for_invalid_prompt(prompt) + + if prompt.prompt_data is None: + raise ValueError("Prompt data is required to create a prompt.") + + prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt( + prompt, + variables=( + prompt.prompt_data.variables + if prompt.prompt_data and prompt.prompt_data.variables + else None + ), + ) + + # Step 1: Create the dataset resource for the prompt and wait for the operation to complete. + create_prompt_dataset_operation = self._create_dataset_resource( + display_name=( + config.prompt_display_name + if config and config.prompt_display_name + else f"prompt_{time.strftime('%Y%m%d-%H%M%S')}" + ), + name=f"projects/{self._api_client.project}/locations/{self._api_client.location}", + metadata_schema_uri=_prompt_management_utils.PROMPT_SCHEMA_URI, + metadata=prompt_metadata, + model_reference=prompt.prompt_data.model, + encryption_spec=( + config.encryption_spec if config and config.encryption_spec else None + ), + ) + dataset_resource_name = self._wait_for_operation( + operation=create_prompt_dataset_operation, + timeout=config.timeout if config else 90, + ) + dataset_id = dataset_resource_name.split("/")[-1] + + # Step 2: Get the dataset resource + dataset_resource = self._get_dataset_resource( + name=dataset_id, + ) + prompt._dataset = dataset_resource + return prompt + + def create_version( + self, + *, + prompt: types.PromptOrDict, + prompt_id: Optional[str] = None, + config: Optional[types.CreatePromptVersionConfigOrDict] = None, + ) -> types.Prompt: + """Creates a prompt resource and an initial prompt version. 
+
+        When creating new prompt and prompt version resources, this waits for
+        the create operation to complete before returning.
+
+        Note: This method is recommended instead of create() since it creates a
+        versioned resource for your prompt.
+
+        Args:
+          prompt: The prompt to create.
+          prompt_id: Deprecated. This method creates a new prompt each time it is called; if prompt_id is provided, a DeprecationWarning is raised.
+          config: Optional configuration for creating the prompt and prompt version.
+
+        Returns:
+          A types.Prompt object representing the prompt with its associated
+          Dataset and Dataset Version resources.
+        """
+        if prompt_id:
+            raise DeprecationWarning(
+                "The prompt_id argument is deprecated; create_version always creates a new prompt."
+            )
+
+        if isinstance(prompt, dict):
+            prompt = types.Prompt(**prompt)
+        if isinstance(config, dict):
+            config = types.CreatePromptVersionConfig(**config)
+        elif not config:
+            config = types.CreatePromptVersionConfig()
+
+        _prompt_management_utils._raise_for_invalid_prompt(prompt)
+
+        if prompt.prompt_data is None:
+            raise ValueError("Prompt data is required to create a prompt.")
+
+        prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt(
+            prompt,
+            variables=(
+                prompt.prompt_data.variables
+                if prompt.prompt_data and prompt.prompt_data.variables
+                else None
+            ),
+        )
+
+        # Step 1: Create the dataset resource for the prompt and wait for the operation to complete.
+        create_prompt_dataset_operation = self._create_dataset_resource(
+            display_name=(
+                config.prompt_display_name
+                if config and config.prompt_display_name
+                else f"prompt_{time.strftime('%Y%m%d-%H%M%S')}"
+            ),
+            name=f"projects/{self._api_client.project}/locations/{self._api_client.location}",
+            metadata_schema_uri=_prompt_management_utils.PROMPT_SCHEMA_URI,
+            metadata=prompt_metadata,
+            model_reference=prompt.prompt_data.model,
+            encryption_spec=(
+                config.encryption_spec if config and config.encryption_spec else None
+            ),
+        )
+        dataset_resource_name = self._wait_for_operation(
+            operation=create_prompt_dataset_operation,
+            timeout=config.timeout if config else 90,
+        )
+        dataset_id = dataset_resource_name.split("/")[-1]
+
+        # Step 2: Get the dataset resource
+        dataset_resource = self._get_dataset_resource(
+            name=dataset_id,
+        )
+        prompt._dataset = dataset_resource
+
+        if prompt._dataset.name is None:
+            raise ValueError("Invalid dataset resource.")
+
+        # Step 3: Create the dataset version
+        create_dataset_version_operation = self._create_dataset_version_resource(
+            dataset_name=prompt._dataset.name.split("/")[-1],
+            display_name=(
+                config.version_display_name
+                if config and config.version_display_name is not None
+                else f"prompt_version_{time.strftime('%Y%m%d-%H%M%S')}"
+            ),
+        )
+        dataset_version_resource_name = self._wait_for_operation(
+            operation=create_dataset_version_operation,
+            timeout=config.timeout if config else 90,
+        )
+
+        # Step 4: Get the dataset version resource and return it with the prompt
+        dataset_version_resource = self._get_dataset_version_resource(
+            dataset_id=dataset_id,
+            dataset_version_id=dataset_version_resource_name.split("/")[-1],
+        )
+        prompt = _prompt_management_utils._create_prompt_from_dataset_metadata(
+            dataset_version_resource
+        )
+        prompt._dataset = dataset_resource
+        prompt._dataset_version = dataset_version_resource
+        return prompt
+
+    def _wait_for_operation(
+        self,
+        operation: types.DatasetOperation,
+        timeout: int,
+    ) -> str:
+        """Waits for a dataset operation to complete.
+
+        Args:
+          operation: The dataset operation to wait for.
+ timeout: The maximum time to wait for the operation to complete. + + Returns: + The name of the Dataset resource from the operation result. + + Raises: + TimeoutError: If the operation does not complete within the timeout. + ValueError: If the operation fails. + """ + done = False + prompt_dataset_operation: Optional[types.DatasetOperation] = None + + response_operation_name = operation.name + if response_operation_name is None: + raise ValueError("Invalid operation name.") + + dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0] + operation_id = response_operation_name.split("/")[-1] + + start_time = time.time() + sleep_duration = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + + while not done: + if (time.time() - start_time) > timeout: + raise TimeoutError( + "Create prompt operation did not complete within the" + f" specified timeout of {timeout} seconds." + ) + current_time = time.time() + if current_time - previous_time >= sleep_duration: + sleep_duration = min(sleep_duration * wait_multiplier, max_wait_time) + previous_time = current_time + time.sleep(sleep_duration) + prompt_dataset_operation = self._get_dataset_operation( + dataset_id=dataset_id, + operation_id=operation_id, + ) + done = ( + (prompt_dataset_operation.done or False) + if hasattr(prompt_dataset_operation, "done") + else False + ) + if ( + not prompt_dataset_operation + or prompt_dataset_operation.response is None + or prompt_dataset_operation.response.get("name") is None + ): + raise ValueError("Error creating prompt version resource.") + if ( + hasattr(prompt_dataset_operation, "error") + and prompt_dataset_operation.error is not None + ): + raise ValueError( + f"Error creating prompt version resource: {prompt_dataset_operation.error}" + ) + return prompt_dataset_operation.response.get("name") # type: ignore[return-value] + + def get( + self, + *, + prompt_id: str, + config: Optional[types.GetPromptConfig] = None, + ) -> types.Prompt: + """Gets a prompt resource from a Vertex Dataset. + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + config: Optional configuration for getting the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated Dataset resources. + """ + + prompt_dataset_resource = self._get_dataset_resource(name=prompt_id) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + prompt_dataset_resource, + ) + prompt._dataset = prompt_dataset_resource + + return prompt + + def get_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.GetPromptConfig] = None, + ) -> types.Prompt: + """Gets a prompt resource from a Vertex Dataset. + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + version_id: The id of the Vertex Dataset Version resource containing the prompt version. For example, if the prompt version resource name is "projects/123/locations/us-central1/datasets/456/datasetVersions/1", then the version_id is "1". + config: Optional configuration for getting the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated Dataset and Dataset Version resources. 
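+
+        Example usage (illustrative; the ids below are placeholders for
+        existing resources):
+
+        ```
+        prompt = client.prompts.get_version(prompt_id="456", version_id="1")
+        print(prompt.prompt_data)
+        ```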
+ """ + + prompt_dataset_resource = self._get_dataset_resource(name=prompt_id) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + prompt_dataset_resource, + ) + prompt._dataset = prompt_dataset_resource + + prompt_version_resource = self._get_dataset_version_resource( + dataset_id=prompt_id, + dataset_version_id=version_id, + ) + prompt._dataset_version = prompt_version_resource + + return prompt + + def _list_prompts_pager( + self, + *, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> Pager[types.Dataset]: + return Pager( + "datasets", + self._list_prompts, + self._list_prompts(config=config), + config, + ) + + def _list_versions_pager( + self, + *, + prompt_id: str, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> Pager[types.DatasetVersion]: + return Pager( + "dataset_versions", + self._list_versions, + self._list_versions(config=config, dataset_id=prompt_id), + config, + ) + + def list( + self, + *, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> Iterator[types.PromptRef]: + """Lists prompt resources in a project. + + This method retrieves all the prompts from the project provided in the + vertexai.Client constructor and returns a list of prompt references containing the prompt_id and model for the prompt. + + To get the full types.Prompt resource for a PromptRef after calling this method, use the get() method with the prompt_id as the prompt_id argument. + Example usage: + + ``` + # Using an iterator + prompt_refs = client.prompt_management.list_prompts() + for prompt_ref in prompt_refs: + client.prompt_management.get(prompt_id=prompt_ref.prompt_id) + + # Using a list + prompts_list = list(client.prompt_management.list_prompts()) + client.prompt_management.get(prompt_id=prompts_list[0].prompt_id) + ``` + + Args: + config: Optional configuration for listing prompts. + + Returns: + An iterable of types.PromptRef objects. + """ + if isinstance(config, dict): + config = types.ListPromptsConfig(**config) + elif not config: + config = types.ListPromptsConfig() + for dataset in self._list_prompts_pager(config=config): + if not dataset.name: + continue + prompt_ref = types.PromptRef( + model=dataset.model_reference, prompt_id=dataset.name.split("/")[-1] + ) + yield prompt_ref + + def list_versions( + self, + *, + prompt_id: str, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> Iterator[types.PromptVersionRef]: + """Lists prompt version resources for a provided prompt_id. + + This method retrieves all the prompt versions for a provided prompt_id. + + To get the full types.Prompt resource for a PromptVersionRef after calling this method, use the get() method with the returned prompt_id and version_id. + Example usage: + + ``` + # Using an iterator + prompt_version_refs = client.prompt_management.list_versions(prompt_id="123") + for version_ref in prompt_version_refs: + client.prompt_management.get(prompt_id=version_ref.prompt_id, version_id=version_ref.version_id) + + # Using a list + prompt_versions_list = list(client.prompt_management.list_versions(prompt_id="123")) + client.prompt_management.get(prompt_id=prompt_versions_list[0].prompt_id, version_id=prompt_versions_list[0].version_id) + ``` + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + config: Optional configuration for listing prompts. 
+ + Returns: + An iterable of types.PromptVersionRef objects representing the prompt version resources for the provided prompt_id. + + """ + if isinstance(config, dict): + config = types.ListPromptsConfig(**config) + elif not config: + config = types.ListPromptsConfig() + for dataset_version in self._list_versions_pager( + config=config, prompt_id=prompt_id + ): + if ( + not dataset_version + or not dataset_version.model_reference + or not dataset_version.name + ): + continue + prompt_version_ref = types.PromptVersionRef( + model=dataset_version.model_reference, + version_id=dataset_version.name.split("/")[-1], + prompt_id=prompt_id, + ) + yield prompt_version_ref + + def _wait_for_project_operation( + self, + operation: genai_types.ProjectOperation, + timeout: int, + ) -> None: + """Waits for a dataset deletion operation to complete. + + Delete operations are project level operations and are separate from dataset resource operations, for example: projects/123/locations/us-central1/operations/789. + + Args: + operation: The project operation to wait for. + timeout: The maximum time to wait for the operation to complete. + Raises: + TimeoutError: If the operation does not complete within the timeout. + ValueError: If the operation fails. + """ + done = False + + start_time = time.time() + sleep_duration = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + while not done: + if (time.time() - start_time) > timeout: + raise TimeoutError( + f"Delete operation did not complete within the" + f" specified timeout of {timeout} seconds." + ) + current_time = time.time() + if current_time - previous_time >= sleep_duration: + sleep_duration = min(sleep_duration * wait_multiplier, max_wait_time) + previous_time = current_time + time.sleep(sleep_duration) + operations_module = operations.Operations(api_client_=self._api_client) + + if operation.name is None: + raise ValueError("Invalid operation name.") + operation = operations_module._get( + operation_id=operation.name.split("/")[-1], + ) + done = (operation.done or False) if hasattr(operation, "done") else False + if hasattr(operation, "error") and operation.error is not None: + raise ValueError(f"Error in delete operation: {operation.error}") + + def delete( + self, + *, + prompt_id: str, + config: Optional[types.DeletePromptConfig] = None, + ) -> None: + """Deletes a prompt resource. + + Args: + prompt_id: The id of the prompt resource to delete. + + Raises: + TimeoutError: If the delete operation does not complete within the timeout. + ValueError: If the delete operation fails. + """ + + delete_prompt_operation = self._delete_dataset( + prompt_id=prompt_id, + config=config, + ) + self._wait_for_project_operation( + operation=delete_prompt_operation, timeout=config.timeout if config else 90 + ) + logger.info(f"Deleted prompt with id: {prompt_id}") + + def delete_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.DeletePromptConfig] = None, + ) -> None: + """Deletes a prompt version resource. + + Args: + prompt_id: The id of the prompt resource to delete. + version_id: The id of the prompt version resource to delete. + + Raises: + TimeoutError: If the delete operation does not complete within the timeout. + ValueError: If the delete operation fails. 
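+
+        Example usage (illustrative; the ids below are placeholders for
+        existing resources):
+
+        ```
+        client.prompts.delete_version(prompt_id="456", version_id="789")
+        ```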
+ """ + delete_version_operation = self._delete_dataset_version( + prompt_id=prompt_id, + version_id=version_id, + config=config, + ) + + self._wait_for_project_operation( + operation=delete_version_operation, timeout=config.timeout if config else 90 + ) + logger.info( + f"Deleted prompt version {version_id} from prompt with id: {prompt_id}" + ) + + def restore_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.RestoreVersionConfig] = None, + ) -> types.Prompt: + """Restores the provided prompt version to the latest version. + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + version_id: The id of the Vertex Dataset Version resource to restore. For example, if the version resource name is "projects/123/locations/us-central1/datasets/456/datasetVersions/789", then the version_id is "789". + config: Optional configuration for restoring the prompt version. + + Returns: + A types.Prompt object representing the prompt with the updated Dataset Version resource. + """ + + restore_prompt_operation = self._restore_version( + dataset_id=prompt_id, + version_id=version_id, + ) + self._wait_for_project_operation( + operation=restore_prompt_operation, + timeout=90, + ) + dataset_version_resource = self._get_dataset_version_resource( + dataset_id=prompt_id, + dataset_version_id=version_id, + ) + updated_prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + dataset_version_resource, + ) + updated_prompt._dataset_version = dataset_version_resource + return updated_prompt + + def _wait_for_completion(self, job_name: str) -> types.CustomJob: + + JOB_COMPLETE_STATES = [ + genai_types.JobState.JOB_STATE_SUCCEEDED, + genai_types.JobState.JOB_STATE_FAILED, + genai_types.JobState.JOB_STATE_CANCELLED, + genai_types.JobState.JOB_STATE_PAUSED, + ] + JOB_ERROR_STATES = [ + genai_types.JobState.JOB_STATE_FAILED, + genai_types.JobState.JOB_STATE_CANCELLED, + ] + + log_wait = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + + job = self._get_custom_job(name=job_name) + + while job.state not in JOB_COMPLETE_STATES: + current_time = time.time() + if current_time - previous_time >= log_wait: + logger.info(f"Waiting for job to complete. Current state: {job.state}") + log_wait = min(log_wait * wait_multiplier, max_wait_time) + previous_time = current_time + time.sleep(log_wait) + job = self._get_custom_job(name=job_name) + + logger.info(f"Job state: {job.state}") + + if job.state in JOB_ERROR_STATES: + raise RuntimeError(f"Job failed with state: {job.state}") + else: + logger.info(f"Job completed with state: {job.state}") + return job + + @_common.experimental_warning( + "The Vertex SDK GenAI prompts.launch_optimization_job method is " + "experimental, and may change in future versions." + ) + def launch_optimization_job( + self, + method: types.PromptOptimizerMethod, + config: types.PromptOptimizerConfigOrDict, + ) -> types.CustomJob: + """Call PO-Data optimizer. + + Args: + method: The method for optimizing multiple prompts. Supported methods: + VAPO, OPTIMIZATION_TARGET_GEMINI_NANO. + config: PromptOptimizerConfig instance containing the + configuration for prompt optimization. + Returns: + The custom job that was created. 
+ """ + + if isinstance(config, dict): + config = types.PromptOptimizerConfig(**config) + + if not config.config_path: + raise ValueError("Config path is required.") + + _OPTIMIZER_METHOD_TO_CONTAINER_URI = { + types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0", + types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_android_v1_0", + } + container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method) + if not container_uri: + raise ValueError( + 'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" ' + "methods are currently supported." + ) + + if config.optimizer_job_display_name: + display_name = config.optimizer_job_display_name + else: + timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + display_name = f"{method.value.lower()}-optimizer-{timestamp}" + + wait_for_completion = config.wait_for_completion + bucket = "/".join(config.config_path.split("/")[:-1]) + + region = self._api_client.location + project = self._api_client.project + container_args = { + "config": config.config_path, + } + args = ["--%s=%s" % (k, v) for k, v in container_args.items()] + worker_pool_specs = [ + types.WorkerPoolSpec( + replica_count=1, + machine_spec=types.MachineSpec(machine_type="n1-standard-4"), + container_spec=types.ContainerSpec( + image_uri=container_uri, + args=args, + ), + ) + ] + + service_account = _prompt_optimizer_utils._get_service_account(config) + + job_spec = types.CustomJobSpec( + worker_pool_specs=worker_pool_specs, + base_output_directory=genai_types.GcsDestination(output_uri_prefix=bucket), + service_account=service_account, + ) + + custom_job = types.CustomJob( + display_name=display_name, + job_spec=job_spec, + ) + + job = self._create_custom_job_resource( + custom_job=custom_job, + ) + + # Get the job resource name + job_resource_name = job.name + if not job_resource_name: + raise ValueError(f"Error creating job: {job}") + job_id = job_resource_name.split("/")[-1] + logger.info("Job created: %s", job.name) + + # Construct the dashboard URL + dashboard_url = f"https://console.cloud.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?project={project}" + logger.info("View the job status at: %s", dashboard_url) + + if wait_for_completion: + job = self._wait_for_completion(job_id) + return job + + @_common.experimental_warning( + "The Vertex SDK GenAI prompts.optimize method is " + "experimental, and may change in future versions." + ) + def optimize( + self, + *, + prompt: str, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponse: + """Makes an API request to optimize a prompt and returns the parsed response. + + Example usage: + client = vertexai.Client(project=PROJECT_NAME, location='us-central1') + prompt = "Generate system instructions for analyzing medical articles" + response = client.prompts.optimize(prompt=prompt) + + Args: + prompt: Required. The prompt to optimize. + config: Optional. The configuration for prompt optimization. 
To optimize prompts for the Android (Gemini Nano) API, provide:
+
+                    types.OptimizeConfig(
+                        optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
+                    )
+
+                For few-shot optimization, set optim_target to
+                types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS or
+                types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+                and provide a dataframe of examples:
+
+                    types.OptimizeConfig(
+                        optimization_target=optim_target,
+                        examples_dataframe=dataframe,
+                    )
+
+                OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS indicates that the
+                few-shot examples include specific scoring rubrics and their
+                corresponding evaluations.
+                OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE indicates that
+                the few-shot examples include a ground-truth target response.
+
+        Returns:
+            The parsed response from the API request.
+        """
+
+        if isinstance(config, dict):
+            config = types.OptimizeConfig(**config)
+
+        optimization_target: Optional[types.OptimizeTarget] = None
+        if config is not None:
+            optimization_target = config.optimization_target
+
+        final_prompt = prompt
+        if (
+            optimization_target
+            == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS
+            or optimization_target
+            == types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+        ):
+            final_prompt = _prompt_optimizer_utils._get_few_shot_prompt(prompt, config)
+
+        # TODO: b/435653980 - replace the custom method with a generated method.
+        config_for_api = config.model_copy() if config else None
+        return self._custom_optimize(
+            content=genai_types.Content(
+                parts=[genai_types.Part(text=final_prompt)], role="user"
+            ),
+            config=config_for_api,
+        )
+
+    def _custom_optimize(
+        self,
+        *,
+        content: Optional[genai_types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponse:
+        """Internal method to call the optimizePrompt endpoint.
+
+        Sends a request to the tuningJobs:optimizePrompt streaming endpoint,
+        gathers the response, concatenates it into one string, and returns
+        the parsed response.
+        """
+        if isinstance(config, dict):
+            config.pop("examples_dataframe", None)
+        elif config and hasattr(config, "examples_dataframe"):
+            del config.examples_dataframe
+
+        parameter_model = types._OptimizeRequestParameters(
+            content=content,
+            config=config,
+        )
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError("This method is only supported in the Vertex AI client.")
+        else:
+            request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
+            else:
+                path = "tuningJobs:optimizePrompt"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
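+        # `config` carries client-side options (e.g. http_options) rather than
+        # request payload fields, so it is stripped before serialization.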
+ request_dict.pop("config", None) + + http_options: Optional[genai_types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_list = "" if not response.body else json.loads(response.body) + + return_value = [] + + for response_dict in response_list: + response_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, kwargs=parameter_model.model_dump() + ) + self._api_client._verify_response(response_value) + content = response_value.content + if content is not None: + parts = content.parts + if parts and parts[0].text is not None: + return_value.append(parts[0].text) + + output = "".join(return_value) + final_response = types.OptimizeResponse(raw_text_response=output) + try: + final_response.parsed_response = _prompt_optimizer_utils._parse(output) + except (ValueError, TypeError, ValidationError) as e: + logger.warning( + f"Failed to parse response: {e}. Returning only raw_text_response." + ) + return final_response + + def update( + self, + *, + prompt_id: str, + prompt: types.PromptOrDict, + config: Optional[types.UpdatePromptConfigOrDict] = None, + ) -> types.Prompt: + """Updates an existing prompt and creates a new version for the prompt associated with the provided prompt_id. + + Args: + prompt_id: The ID of the prompt to create a new version for. + prompt: The updated prompt. + config: Optional configuration for updating the prompt. + + Returns: + A types.Prompt object representing the updated prompt with its associated + Dataset and Dataset Version resources. + """ + + if isinstance(prompt, dict): + prompt = types.Prompt(**prompt) + if isinstance(config, dict): + config = types.UpdatePromptConfig(**config) + elif not config: + config = types.UpdatePromptConfig() + + prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt( + prompt, + variables=( + prompt.prompt_data.variables + if prompt.prompt_data and prompt.prompt_data.variables + else None + ), + ) + + if not prompt.prompt_data: + raise ValueError("Prompt data is required to update a prompt.") + + # Step 1: Update the dataset resource for the prompt and wait for the operation to complete. + updated_dataset_resource = self._update_dataset_resource( + name=f"projects/{self._api_client.project}/locations/{self._api_client.location}", + dataset_id=prompt_id, + display_name=( + config.prompt_display_name + if config and config.prompt_display_name + else None + ), + metadata=prompt_metadata, + model_reference=prompt.prompt_data.model, + encryption_spec=( + config.encryption_spec if config and config.encryption_spec else None + ), + config=config, + ) + + if not updated_dataset_resource.name: + raise ValueError("Failed to update dataset resource.") + + dataset_id = updated_dataset_resource.name.split("/")[-1] + + # Step 2: Create a dataset version for the prompt. 
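+        # Only a display name is supplied here; the new version captures the
+        # dataset state written in Step 1.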
+ create_dataset_version_operation = self._create_dataset_version_resource( + dataset_name=dataset_id, + display_name=( + config.version_display_name + if config and config.version_display_name is not None + else f"prompt_version_{time.strftime('%Y%m%d-%H%M%S')}" + ), + ) + dataset_version_resource_name = self._wait_for_operation( + operation=create_dataset_version_operation, + timeout=config.timeout if config else 90, + ) + dataset_version_id = dataset_version_resource_name.split("/")[-1] + + # Step 3: Get the dataset version resource and return it with the prompt. + dataset_version_resource = self._get_dataset_version_resource( + dataset_id=dataset_id, + dataset_version_id=dataset_version_id, + ) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + dataset_version_resource + ) + prompt._dataset = updated_dataset_resource + prompt._dataset_version = dataset_version_resource + return prompt + + +class AsyncPrompts(_api_module.BaseModule): + + async def _create_dataset_resource( + self, + *, + name: Optional[str] = None, + display_name: Optional[str] = None, + metadata_schema_uri: Optional[str] = None, + metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + model_reference: Optional[str] = None, + config: Optional[types.CreateDatasetConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Creates a dataset resource to store prompts. + """ + + parameter_model = types._CreateDatasetParameters( + name=name, + display_name=display_name, + metadata_schema_uri=metadata_schema_uri, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + model_reference=model_reference, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets".format_map(request_url_dict) + else: + path = "datasets" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_dataset_version_resource( + self, + *, + dataset_name: Optional[str] = None, + metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None, + model_reference: Optional[str] = None, + parent: Optional[str] = None, + display_name: Optional[str] = None, + config: Optional[types.CreateDatasetVersionConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Creates a dataset version resource to store prompts. + """ + + parameter_model = types._CreateDatasetVersionParameters( + dataset_name=dataset_name, + metadata=metadata, + model_reference=model_reference, + parent=parent, + display_name=display_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateDatasetVersionParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{name}/datasetVersions".format_map(request_url_dict) + else: + path = "datasets/{name}/datasetVersions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_dataset_resource( + self, + *, + name: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.Dataset: + """ + Gets a dataset resource to store prompts. + """ + + parameter_model = types._GetDatasetParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{name}".format_map(request_url_dict) + else: + path = "datasets/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Dataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_dataset_version_resource( + self, + *, + dataset_id: Optional[str] = None, + dataset_version_id: Optional[str] = None, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.DatasetVersion: + """ + Gets a dataset version resource to store prompts. 
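+
+        For example, version "789" of dataset "456" has the resource name
+        "projects/{project}/locations/{location}/datasets/456/datasetVersions/789".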
+ """ + + parameter_model = types._GetDatasetVersionParameters( + dataset_id=dataset_id, + dataset_version_id=dataset_version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetVersionParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{dataset_version_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{dataset_version_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DatasetVersion._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_dataset_operation( + self, + *, + dataset_id: Optional[str] = None, + operation_id: Optional[str] = None, + config: Optional[types.GetDatasetOperationConfigOrDict] = None, + ) -> types.DatasetOperation: + """ + Gets the operation from creating a dataset. + """ + + parameter_model = types._GetDatasetOperationParameters( + dataset_id=dataset_id, + operation_id=operation_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetDatasetOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/operations/{operation_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/operations/{operation_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "get", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.DatasetOperation._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _list_prompts(
+        self, *, config: Optional[types.ListPromptsConfigOrDict] = None
+    ) -> types.ListDatasetsResponse:
+        """
+        Lists the prompt dataset resources in a project.
+        """
+
+        parameter_model = types._ListDatasetsRequestParameters(
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _ListDatasetsRequestParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "datasets".format_map(request_url_dict)
+            else:
+                path = "datasets"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "get", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.ListDatasetsResponse._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _list_versions(
+        self,
+        *,
+        read_mask: Optional[str] = None,
+        dataset_id: Optional[str] = None,
+        config: Optional[types.ListPromptsConfigOrDict] = None,
+    ) -> types.ListDatasetVersionsResponse:
+        """
+        Lists the dataset version resources for a prompt dataset.
+ """ + + parameter_model = types._ListDatasetVersionsRequestParameters( + read_mask=read_mask, + dataset_id=dataset_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListDatasetVersionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListDatasetVersionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete_dataset( + self, *, prompt_id: str, config: Optional[types.DeletePromptConfigOrDict] = None + ) -> types.DeletePromptOperation: + parameter_model = types._DeleteDatasetRequestParameters( + prompt_id=prompt_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteDatasetRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}".format_map(request_url_dict) + else: + path = "datasets/{dataset_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeletePromptOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete_dataset_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.DeletePromptConfigOrDict] = None, + ) -> types.DeletePromptVersionOperation: + parameter_model = types._DeletePromptVersionRequestParameters( + prompt_id=prompt_id, + version_id=version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeletePromptVersionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{version_id}".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{version_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeletePromptVersionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _restore_version( + self, + *, + dataset_id: str, + version_id: str, + config: Optional[types.RestoreVersionConfigOrDict] = None, + ) -> types.RestoreVersionOperation: + """ + Restores the provided prompt version to the latest version. 
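+
+        The returned operation is a project-level operation (for example,
+        "projects/123/locations/us-central1/operations/789"), which the public
+        restore_version method polls via _wait_for_project_operation.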
+ """ + + parameter_model = types._RestoreVersionRequestParameters( + dataset_id=dataset_id, + version_id=version_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RestoreVersionRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}/datasetVersions/{version_id}:restore".format_map( + request_url_dict + ) + else: + path = "datasets/{dataset_id}/datasetVersions/{version_id}:restore" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RestoreVersionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update_dataset_resource( + self, + *, + name: Optional[str] = None, + dataset_id: Optional[str] = None, + display_name: Optional[str] = None, + metadata: Optional[types.SchemaTextPromptDatasetMetadataOrDict] = None, + description: Optional[str] = None, + encryption_spec: Optional[genai_types.EncryptionSpecOrDict] = None, + model_reference: Optional[str] = None, + config: Optional[types.UpdatePromptConfigOrDict] = None, + ) -> types.Dataset: + """ + Creates a dataset resource to store prompts. + """ + + parameter_model = types._UpdateDatasetParameters( + name=name, + dataset_id=dataset_id, + display_name=display_name, + metadata=metadata, + description=description, + encryption_spec=encryption_spec, + model_reference=model_reference, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _UpdateDatasetParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "datasets/{dataset_id}".format_map(request_url_dict) + else: + path = "datasets/{dataset_id}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "patch", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Dataset._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create_custom_job_resource( + self, + *, + custom_job: types.CustomJobOrDict, + config: Optional[types.VertexBaseConfigOrDict] = None, + ) -> types.CustomJob: + """ + Creates a custom job. + """ + + parameter_model = types._CustomJobParameters( + custom_job=custom_job, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs".format_map(request_url_dict) + else: + path = "customJobs" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_custom_job( + self, *, name: str, config: Optional[types.VertexBaseConfigOrDict] = None + ) -> types.CustomJob: + """ + Gets a custom job. 
+ """ + + parameter_model = types._GetCustomJobParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetCustomJobParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "customJobs/{name}".format_map(request_url_dict) + else: + path = "customJobs/{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + if self._api_client.vertexai: + response_dict = _CustomJob_from_vertex(response_dict) + + return_value = types.CustomJob._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _optimize( + self, + *, + content: Optional[genai_types.ContentOrDict] = None, + config: Optional[types.OptimizeConfigOrDict] = None, + ) -> types.OptimizeResponseEndpoint: + """ + Optimize a single prompt. + """ + + parameter_model = types._OptimizeRequestParameters( + content=content, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _OptimizeRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "tuningJobs:optimizePrompt".format_map(request_url_dict) + else: + path = "tuningJobs:optimizePrompt" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.OptimizeResponseEndpoint._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def create( + self, + *, + prompt: types.PromptOrDict, + config: Optional[types.CreatePromptConfigOrDict] = None, + ) -> types.Prompt: + """Creates a new prompt in a Vertex Dataset resource. + + This method waits for prompt creation to be complete before returning. + + Note: This method does not create a versioned resource for your prompt. + Call create_version() to create a new prompt resource with a version. + + Args: + prompt: The prompt to create. + config: Optional configuration for creating the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated + Dataset resources. + """ + if isinstance(prompt, dict): + prompt = types.Prompt(**prompt) + if isinstance(config, dict): + config = types.CreatePromptConfig(**config) + elif not config: + config = types.CreatePromptConfig() + + _prompt_management_utils._raise_for_invalid_prompt(prompt) + + if prompt.prompt_data is None: + raise ValueError("Prompt data is required to create a prompt.") + + prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt( + prompt, + variables=( + prompt.prompt_data.variables + if prompt.prompt_data and prompt.prompt_data.variables + else None + ), + ) + + # Step 1: Create the dataset resource for the prompt and wait for the operation to complete. 
+        create_prompt_dataset_operation = await self._create_dataset_resource(
+            display_name=(
+                config.prompt_display_name
+                if config and config.prompt_display_name
+                else f"prompt_{time.strftime('%Y%m%d-%H%M%S')}"
+            ),
+            name=f"projects/{self._api_client.project}/locations/{self._api_client.location}",
+            metadata_schema_uri=_prompt_management_utils.PROMPT_SCHEMA_URI,
+            metadata=prompt_metadata,
+            model_reference=prompt.prompt_data.model,
+            encryption_spec=(
+                config.encryption_spec if config and config.encryption_spec else None
+            ),
+        )
+        dataset_resource_name = await self._wait_for_operation(
+            operation=create_prompt_dataset_operation,
+            timeout=config.timeout if config else 90,
+        )
+        dataset_id = dataset_resource_name.split("/")[-1]
+
+        # Step 2: Get the dataset resource
+        dataset_resource = await self._get_dataset_resource(
+            name=dataset_id,
+        )
+        prompt._dataset = dataset_resource
+        return prompt
+
+    async def create_version(
+        self,
+        *,
+        prompt: types.PromptOrDict,
+        prompt_id: Optional[str] = None,
+        config: Optional[types.CreatePromptVersionConfigOrDict] = None,
+    ) -> types.Prompt:
+        """Creates a prompt resource and an initial prompt version.
+
+        When creating new prompt and prompt version resources, this waits for
+        the create operation to complete before returning.
+
+        Note: This method is recommended instead of create() since it creates a
+        versioned resource for your prompt.
+
+        Args:
+            prompt: The prompt to create.
+            prompt_id: Deprecated. This method creates a new prompt resource each time it is called, so prompt_id must not be provided; a DeprecationWarning is raised if a value is passed.
+            config: Optional configuration for creating the prompt and prompt version.
+
+        Returns:
+            A types.Prompt object representing the prompt with its associated
+            Dataset and Dataset Version resources.
+        """
+        if prompt_id:
+            raise DeprecationWarning(
+                "The prompt_id argument is deprecated and must not be provided."
+            )
+        if isinstance(prompt, dict):
+            prompt = types.Prompt(**prompt)
+        if isinstance(config, dict):
+            config = types.CreatePromptVersionConfig(**config)
+        elif not config:
+            config = types.CreatePromptVersionConfig()
+
+        _prompt_management_utils._raise_for_invalid_prompt(prompt)
+
+        if prompt.prompt_data is None:
+            raise ValueError("Prompt data is required to create a prompt.")
+
+        prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt(
+            prompt,
+            variables=(
+                prompt.prompt_data.variables
+                if prompt.prompt_data and prompt.prompt_data.variables
+                else None
+            ),
+        )
+
+        # Step 1: Create the dataset resource for the prompt and wait for the operation to complete.
+ create_prompt_dataset_operation = await self._create_dataset_resource( + display_name=( + config.prompt_display_name + if config and config.prompt_display_name + else f"prompt_{time.strftime('%Y%m%d-%H%M%S')}" + ), + name=f"projects/{self._api_client.project}/locations/{self._api_client.location}", + metadata_schema_uri=_prompt_management_utils.PROMPT_SCHEMA_URI, + metadata=prompt_metadata, + model_reference=prompt.prompt_data.model, + encryption_spec=( + config.encryption_spec if config and config.encryption_spec else None + ), + ) + dataset_resource_name = await self._wait_for_operation( + operation=create_prompt_dataset_operation, + timeout=config.timeout if config else 90, + ) + dataset_id = dataset_resource_name.split("/")[-1] + + # Step 2: Get the dataset resource + dataset_resource = await self._get_dataset_resource( + name=dataset_id, + ) + prompt._dataset = dataset_resource + + if prompt._dataset.name is None: + raise ValueError("Invalid dataset resource.") + + # Step 3: Create the dataset version + create_dataset_version_operation = await self._create_dataset_version_resource( + dataset_name=prompt._dataset.name.split("/")[-1], + display_name=( + config.version_display_name + if config and config.version_display_name is not None + else f"prompt_version_{time.strftime('%Y%m%d-%H%M%S')}" + ), + ) + dataset_version_resource_name = await self._wait_for_operation( + operation=create_dataset_version_operation, + timeout=config.timeout if config else 90, + ) + + # Step 4: Get the dataset version resource and return it with the prompt + dataset_version_resource = await self._get_dataset_version_resource( + dataset_id=dataset_id, + dataset_version_id=dataset_version_resource_name.split("/")[-1], + ) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + dataset_version_resource + ) + prompt._dataset = dataset_resource + prompt._dataset_version = dataset_version_resource + return prompt + + async def update( + self, + *, + prompt_id: str, + prompt: types.PromptOrDict, + config: Optional[types.UpdatePromptConfigOrDict] = None, + ) -> types.Prompt: + """Updates an existing prompt and creates a new version for the prompt + associated with the provided prompt_id. + + Args: + prompt_id: The ID of the prompt to create a new version for. + prompt: The updated prompt. + config: Optional configuration for updating the prompt. + + Returns: + A types.Prompt object representing the updated prompt with its associated + Dataset and Dataset Version resources. + """ + + if isinstance(prompt, dict): + prompt = types.Prompt(**prompt) + if isinstance(config, dict): + config = types.UpdatePromptConfig(**config) + elif not config: + config = types.UpdatePromptConfig() + + prompt_metadata = _prompt_management_utils._create_dataset_metadata_from_prompt( + prompt, + variables=( + prompt.prompt_data.variables + if prompt.prompt_data and prompt.prompt_data.variables + else None + ), + ) + + if not prompt.prompt_data: + raise ValueError("Prompt data is required to update a prompt.") + + # Step 1: Update the dataset resource for the prompt and wait for the operation to complete. 
+ updated_dataset_resource = await self._update_dataset_resource( + name=f"projects/{self._api_client.project}/locations/{self._api_client.location}", + dataset_id=prompt_id, + display_name=( + config.prompt_display_name + if config and config.prompt_display_name + else None + ), + metadata=prompt_metadata, + model_reference=prompt.prompt_data.model, + encryption_spec=( + config.encryption_spec if config and config.encryption_spec else None + ), + config=config, + ) + + if not updated_dataset_resource.name: + raise ValueError("Failed to update dataset resource.") + + dataset_id = updated_dataset_resource.name.split("/")[-1] + + # Step 2: Create a dataset version for the prompt. + create_dataset_version_operation = await self._create_dataset_version_resource( + dataset_name=dataset_id, + display_name=( + config.version_display_name + if config and config.version_display_name is not None + else f"prompt_version_{time.strftime('%Y%m%d-%H%M%S')}" + ), + ) + dataset_version_resource_name = await self._wait_for_operation( + operation=create_dataset_version_operation, + timeout=config.timeout if config else 90, + ) + dataset_version_id = dataset_version_resource_name.split("/")[-1] + + # Step 3: Get the dataset version resource and return it with the prompt. + dataset_version_resource = await self._get_dataset_version_resource( + dataset_id=dataset_id, + dataset_version_id=dataset_version_id, + ) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + dataset_version_resource + ) + prompt._dataset = updated_dataset_resource + prompt._dataset_version = dataset_version_resource + return prompt + + async def _wait_for_operation( + self, + operation: types.DatasetOperation, + timeout: int, + ) -> str: + """Waits for a dataset operation to complete. + + Args: + operation: The dataset operation to wait for. + timeout: The maximum time to wait for the operation to complete. + + Returns: + The name of the Dataset resource from the operation result. + + Raises: + TimeoutError: If the operation does not complete within the timeout. + ValueError: If the operation fails. + """ + done = False + prompt_dataset_operation: Optional[types.DatasetOperation] = None + + response_operation_name = operation.name + if response_operation_name is None: + raise ValueError("Invalid operation name.") + + dataset_id = response_operation_name.split("/datasets/")[1].split("/")[0] + operation_id = response_operation_name.split("/")[-1] + + start_time = time.time() + sleep_duration = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + + while not done: + if (time.time() - start_time) > timeout: + raise TimeoutError( + "Create prompt operation did not complete within the" + f" specified timeout of {timeout} seconds." 
+ ) + current_time = time.time() + if current_time - previous_time >= sleep_duration: + sleep_duration = min(sleep_duration * wait_multiplier, max_wait_time) + previous_time = current_time + await asyncio.sleep(sleep_duration) + prompt_dataset_operation = await self._get_dataset_operation( + dataset_id=dataset_id, + operation_id=operation_id, + ) + done = ( + (prompt_dataset_operation.done or False) + if hasattr(prompt_dataset_operation, "done") + else False + ) + if ( + not prompt_dataset_operation + or prompt_dataset_operation.response is None + or prompt_dataset_operation.response.get("name") is None + ): + raise ValueError("Error creating prompt version resource.") + if ( + hasattr(prompt_dataset_operation, "error") + and prompt_dataset_operation.error is not None + ): + raise ValueError( + f"Error creating prompt version resource: {prompt_dataset_operation.error}" + ) + return prompt_dataset_operation.response.get("name") # type: ignore[return-value] + + async def get( + self, + *, + prompt_id: str, + config: Optional[types.GetPromptConfig] = None, + ) -> types.Prompt: + """Gets a prompt resource from a Vertex Dataset. + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + config: Optional configuration for getting the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated Dataset and Dataset Version resources. + """ + + prompt_dataset_resource = await self._get_dataset_resource(name=prompt_id) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + prompt_dataset_resource, + ) + prompt._dataset = prompt_dataset_resource + + return prompt + + async def get_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.GetPromptConfig] = None, + ) -> types.Prompt: + """Gets a prompt resource from a Vertex Dataset. + + Args: + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". + version_id: The id of the Vertex Dataset Version resource containing the prompt version. For example, if the prompt version resource name is "projects/123/locations/us-central1/datasets/456/datasetVersions/1", then the version_id is "1". + config: Optional configuration for getting the prompt. + + Returns: + A types.Prompt object representing the prompt with its associated Dataset and Dataset Version resources. + """ + + prompt_dataset_resource = await self._get_dataset_resource(name=prompt_id) + prompt = _prompt_management_utils._create_prompt_from_dataset_metadata( + prompt_dataset_resource, + ) + prompt._dataset = prompt_dataset_resource + + prompt_version_resource = await self._get_dataset_version_resource( + dataset_id=prompt_id, + dataset_version_id=version_id, + ) + prompt._dataset_version = prompt_version_resource + + return prompt + + async def _wait_for_project_operation( + self, + operation: genai_types.ProjectOperation, + timeout: int, + ) -> None: + """Waits for a dataset deletion operation to complete. + + Delete operations are project level operations and are separate from dataset resource operations, for example: projects/123/locations/us-central1/operations/789. + + Args: + operation: The project operation to wait for. + timeout: The maximum time to wait for the operation to complete. 
+ Raises: + TimeoutError: If the operation does not complete within the timeout. + ValueError: If the operation fails. + """ + done = False + + start_time = time.time() + sleep_duration = 5 + wait_multiplier = 2 + max_wait_time = 60 + previous_time = time.time() + while not done: + if (time.time() - start_time) > timeout: + raise TimeoutError( + f"Delete operation did not complete within the" + f" specified timeout of {timeout} seconds." + ) + current_time = time.time() + if current_time - previous_time >= sleep_duration: + sleep_duration = min(sleep_duration * wait_multiplier, max_wait_time) + previous_time = current_time + await asyncio.sleep(sleep_duration) + operations_module = operations.AsyncOperations(api_client_=self._api_client) + + if operation.name is None: + raise ValueError("Invalid operation name.") + operation = await operations_module._get( + operation_id=operation.name.split("/")[-1], + ) + done = (operation.done or False) if hasattr(operation, "done") else False + if hasattr(operation, "error") and operation.error is not None: + raise ValueError(f"Error in delete operation: {operation.error}") + + async def delete( + self, + *, + prompt_id: str, + config: Optional[types.DeletePromptConfig] = None, + ) -> None: + """Deletes a prompt resource. + + Args: + prompt_id: The id of the prompt resource to delete. + + Raises: + TimeoutError: If the delete operation does not complete within the timeout. + ValueError: If the delete operation fails. + """ + + delete_prompt_operation = await self._delete_dataset( + prompt_id=prompt_id, + config=config, + ) + await self._wait_for_project_operation( + operation=delete_prompt_operation, timeout=config.timeout if config else 90 + ) + logger.info(f"Deleted prompt with id: {prompt_id}") + + async def delete_version( + self, + *, + prompt_id: str, + version_id: str, + config: Optional[types.DeletePromptConfig] = None, + ) -> None: + """Deletes a prompt version resource. + + Args: + prompt_id: The id of the prompt resource to delete. + version_id: The id of the prompt version resource to delete. + + Raises: + TimeoutError: If the delete operation does not complete within the timeout. + ValueError: If the delete operation fails. + """ + delete_version_operation = await self._delete_dataset_version( + prompt_id=prompt_id, + version_id=version_id, + config=config, + ) + + await self._wait_for_project_operation( + operation=delete_version_operation, timeout=config.timeout if config else 90 + ) + logger.info( + f"Deleted prompt version {version_id} from prompt with id: {prompt_id}" + ) + + async def _list_prompts_pager( + self, + *, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> AsyncPager[types.Dataset]: + return AsyncPager( + "datasets", + self._list_prompts, + await self._list_prompts(config=config), + config, + ) + + async def _list_versions_pager( + self, + *, + prompt_id: str, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> AsyncPager[types.DatasetVersion]: + return AsyncPager( + "dataset_versions", + self._list_versions, + await self._list_versions(config=config, dataset_id=prompt_id), + config, + ) + + async def list( + self, + *, + config: Optional[types.ListPromptsConfigOrDict] = None, + ) -> AsyncIterator[types.PromptRef]: + """Lists prompt resources in a project. + + This method retrieves all the prompts from the project provided in the + vertexai.Client constructor and returns a list of prompt references containing the prompt_id and model for the prompt. 
+
+        To get the full types.Prompt resource for a PromptRef after calling this method, pass its prompt_id to the get() method.
+        Example usage:
+
+        ```
+        prompt_refs = client.aio.prompts.list()
+        async for prompt_ref in prompt_refs:
+            await client.aio.prompts.get(prompt_id=prompt_ref.prompt_id)
+        ```
+
+        Args:
+            config: Optional configuration for listing prompts.
+
+        Returns:
+            An async iterator of types.PromptRef objects.
+        """
+        if isinstance(config, dict):
+            config = types.ListPromptsConfig(**config)
+        elif not config:
+            config = types.ListPromptsConfig()
+        async for dataset in await self._list_prompts_pager(config=config):
+            if not dataset or not dataset.model_reference or not dataset.name:
+                continue
+            prompt_ref = types.PromptRef(
+                model=dataset.model_reference, prompt_id=dataset.name.split("/")[-1]
+            )
+            yield prompt_ref
+
+    async def list_versions(
+        self,
+        *,
+        prompt_id: str,
+        config: Optional[types.ListPromptsConfigOrDict] = None,
+    ) -> AsyncIterator[types.PromptVersionRef]:
+        """Lists prompt version resources for a provided prompt_id.
+
+        This method retrieves all the prompt versions for a provided prompt_id.
+
+        To get the full types.Prompt resource for a PromptVersionRef after calling this method, pass its prompt_id and version_id to the get_version() method.
+        Example usage:
+
+        ```
+        prompt_version_refs = client.aio.prompts.list_versions(prompt_id="123")
+        async for version_ref in prompt_version_refs:
+            await client.aio.prompts.get_version(prompt_id=version_ref.prompt_id, version_id=version_ref.version_id)
+        ```
+
+        Args:
+            prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456".
+            config: Optional configuration for listing prompt versions.
+
+        Returns:
+            An async iterator of types.PromptVersionRef objects representing the prompt version resources for the provided prompt_id.
+        """
+        if isinstance(config, dict):
+            config = types.ListPromptsConfig(**config)
+        elif not config:
+            config = types.ListPromptsConfig()
+        async for dataset_version in await self._list_versions_pager(
+            config=config, prompt_id=prompt_id
+        ):
+            if (
+                not dataset_version
+                or not dataset_version.model_reference
+                or not dataset_version.name
+            ):
+                continue
+            prompt_version_ref = types.PromptVersionRef(
+                model=dataset_version.model_reference,
+                version_id=dataset_version.name.split("/")[-1],
+                prompt_id=prompt_id,
+            )
+            yield prompt_version_ref
+
+    async def restore_version(
+        self,
+        *,
+        prompt_id: str,
+        version_id: str,
+        config: Optional[types.RestoreVersionConfig] = None,
+    ) -> types.Prompt:
+        """Restores the provided prompt version to the latest version.
+
+        Args:
+            prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456".
+            version_id: The id of the Vertex Dataset Version resource to restore. For example, if the version resource name is "projects/123/locations/us-central1/datasets/456/datasetVersions/789", then the version_id is "789".
+            config: Optional configuration for restoring the prompt version.
+
+        Returns:
+            A types.Prompt object representing the prompt with the updated Dataset Version resource.
+        """
+
+        restore_prompt_operation = await self._restore_version(
+            dataset_id=prompt_id,
+            version_id=version_id,
+        )
+        await self._wait_for_project_operation(
+            operation=restore_prompt_operation,
+            timeout=90,
+        )
+        dataset_version_resource = await self._get_dataset_version_resource(
+            dataset_id=prompt_id,
+            dataset_version_id=version_id,
+        )
+        updated_prompt = _prompt_management_utils._create_prompt_from_dataset_metadata(
+            dataset_version_resource,
+        )
+        updated_prompt._dataset_version = dataset_version_resource
+        return updated_prompt
+
+    @_common.experimental_warning(
+        "The Vertex SDK GenAI prompts.launch_optimization_job method is "
+        "experimental, and may change in future versions."
+    )
+    async def launch_optimization_job(
+        self,
+        method: types.PromptOptimizerMethod,
+        config: types.PromptOptimizerConfigOrDict,
+    ) -> types.CustomJob:
+        """Calls the async Vertex AI Prompt Optimizer (VAPO).
+
+        Note: The `wait_for_completion` parameter in the config will be
+        ignored when using the AsyncClient, as it is not supported.
+
+        Example usage:
+            client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
+            vapo_config = vertexai.types.PromptOptimizerConfig(
+                config_path='gs://your-bucket-name/your-config.json',
+                service_account=service_account,
+            )
+            job = await client.aio.prompts.launch_optimization_job(
+                method=types.PromptOptimizerMethod.VAPO, config=vapo_config)
+
+        Args:
+            method: The method for optimizing multiple prompts. Supported methods:
+                VAPO, OPTIMIZATION_TARGET_GEMINI_NANO.
+            config: PromptOptimizerConfig instance containing the
+                configuration for prompt optimization.
+
+        Returns:
+            The custom job that was created.
+        """
+        if isinstance(config, dict):
+            config = types.PromptOptimizerConfig(**config)
+
+        if not config.config_path:
+            raise ValueError("Config path is required.")
+
+        _OPTIMIZER_METHOD_TO_CONTAINER_URI = {
+            types.PromptOptimizerMethod.VAPO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_v1_0",
+            types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO: "us-docker.pkg.dev/vertex-ai/cair/vaipo:preview_android_v1_0",
+        }
+        container_uri = _OPTIMIZER_METHOD_TO_CONTAINER_URI.get(method)
+        if not container_uri:
+            raise ValueError(
+                'Only "VAPO" and "OPTIMIZATION_TARGET_GEMINI_NANO" '
+                "methods are currently supported."
+            )
+
+        if config.wait_for_completion:
+            logger.info(
+                "Ignoring wait_for_completion=True since the AsyncClient does not support it."
+            )
+
+        if config.optimizer_job_display_name:
+            display_name = config.optimizer_job_display_name
+        else:
+            timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
+            display_name = f"{method.value.lower()}-optimizer-{timestamp}"
+
+        # config_path was validated above; the config file's parent directory
+        # doubles as the job's output bucket.
+        bucket = "/".join(config.config_path.split("/")[:-1])
+
+        region = self._api_client.location
+        project = self._api_client.project
+        container_args = {
+            "config": config.config_path,
+        }
+        args = [f"--{k}={v}" for k, v in container_args.items()]
+        worker_pool_specs = [
+            types.WorkerPoolSpec(
+                replica_count=1,
+                machine_spec=types.MachineSpec(machine_type="n1-standard-4"),
+                container_spec=types.ContainerSpec(
+                    image_uri=container_uri,
+                    args=args,
+                ),
+            )
+        ]
+
+        service_account = _prompt_optimizer_utils._get_service_account(config)
+
+        job_spec = types.CustomJobSpec(
+            worker_pool_specs=worker_pool_specs,
+            base_output_directory=genai_types.GcsDestination(output_uri_prefix=bucket),
+            service_account=service_account,
+        )
+
+        custom_job = types.CustomJob(
+            display_name=display_name,
+            job_spec=job_spec,
+        )
+
+        job = await self._create_custom_job_resource(
+            custom_job=custom_job,
+        )
+
+        # Derive the job id from the resource name for the dashboard URL.
+        job_resource_name = job.name
+        if not job_resource_name:
+            raise ValueError(f"Error creating job: {job}")
+        job_id = job_resource_name.split("/")[-1]
+        logger.info("Job created: %s", job.name)
+
+        # Construct the dashboard URL to show to the user.
+        dashboard_url = f"https://console.cloud.google.com/vertex-ai/locations/{region}/training/{job_id}/cpu?project={project}"
+        logger.info("View the job status at: %s", dashboard_url)
+
+        return job
+
+    async def _custom_optimize(
+        self,
+        *,
+        content: Optional[genai_types.ContentOrDict] = None,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponse:
+        """Optimizes a single prompt via the tuningJobs:optimizePrompt endpoint."""
+        # Strip examples_dataframe before building the request: it is consumed
+        # client-side (see optimize()) and is not part of the API payload.
+        if isinstance(config, dict):
+            config.pop("examples_dataframe", None)
+        elif config and hasattr(config, "examples_dataframe"):
+            del config.examples_dataframe
+
+        parameter_model = types._OptimizeRequestParameters(
+            content=content,
+            config=config,
+        )
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError("This method is only supported in the Vertex AI client.")
+        else:
+            request_dict = _OptimizeRequestParameters_to_vertex(parameter_model)
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "tuningJobs:optimizePrompt".format_map(request_url_dict)
+            else:
+                path = "tuningJobs:optimizePrompt"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
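+        # The converter above moved routing fields into the synthetic "_url"
+        # and "_query" entries, which have already been consumed to build
+        # `path`; popping "config" here keeps it out of the JSON request body.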
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "post", path, request_dict, http_options
+        )
+
+        response_list = [] if not response.body else json.loads(response.body)
+
+        return_value = []
+
+        for response_dict in response_list:
+            response_value = types.OptimizeResponseEndpoint._from_response(
+                response=response_dict, kwargs=parameter_model.model_dump()
+            )
+            self._api_client._verify_response(response_value)
+            response_content = response_value.content
+            if response_content is not None:
+                parts = response_content.parts
+                if parts and parts[0].text is not None:
+                    return_value.append(parts[0].text)
+
+        output = "".join(return_value)
+        final_response = types.OptimizeResponse(raw_text_response=output)
+        try:
+            final_response.parsed_response = _prompt_optimizer_utils._parse(output)
+        except Exception as e:
+            logger.warning(
+                f"Failed to parse response: {e}. Returning only raw_text_response."
+            )
+        return final_response
+
+    @_common.experimental_warning(
+        "The Vertex SDK GenAI prompts.optimize method is "
+        "experimental, and may change in future versions."
+    )
+    async def optimize(
+        self,
+        *,
+        prompt: str,
+        config: Optional[types.OptimizeConfigOrDict] = None,
+    ) -> types.OptimizeResponse:
+        """Makes an async request to the optimizePrompt endpoint and returns an optimized prompt.
+
+        Example usage:
+            client = vertexai.Client(project=PROJECT_NAME, location='us-central1')
+            prompt = "Generate system instructions for analyzing medical articles"
+            response = await client.aio.prompts.optimize(prompt=prompt)
+
+        Args:
+            prompt: Required. The prompt to optimize.
+            config: Optional. The configuration for prompt optimization. To optimize
+                prompts for the Android API, provide
+                types.OptimizeConfig(
+                    optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO
+                )
+                For few-shot optimization, provide:
+                optim_target = types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS  # or types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE
+                types.OptimizeConfig(
+                    optimization_target=optim_target,
+                    examples_dataframe=dataframe
+                )
+                OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS indicates that the few-shot
+                examples include specific scoring rubrics and their corresponding
+                evaluations.
+                OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE indicates that the
+                few-shot examples include a ground-truth target response.
+
+        Returns:
+            The parsed response from the API request.
+        """
+        if isinstance(config, dict):
+            config = types.OptimizeConfig(**config)
+
+        optimization_target: Optional[types.OptimizeTarget] = None
+        if config is not None:
+            optimization_target = config.optimization_target
+
+        final_prompt = prompt
+        if optimization_target in (
+            types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS,
+            types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
+        ):
+            final_prompt = _prompt_optimizer_utils._get_few_shot_prompt(prompt, config)
+
+        # TODO: b/435653980 - replace the custom method with a generated method.
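+        # Copy the config before handing it to _custom_optimize, which strips
+        # examples_dataframe in place, so the caller's object stays untouched.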
+ config_for_api = config.model_copy() if config else None + return await self._custom_optimize( + content=genai_types.Content( + parts=[genai_types.Part(text=final_prompt)], role="user" + ), + config=config_for_api, + ) diff --git a/agentplatform/_genai/runtime_revisions.py b/agentplatform/_genai/runtime_revisions.py new file mode 100644 index 0000000000..c3be86259b --- /dev/null +++ b/agentplatform/_genai/runtime_revisions.py @@ -0,0 +1,1257 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +from typing import Any, AsyncIterator, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . import _agent_engines_utils +from . import types + +logger = logging.getLogger("agentplatform_genai.runtimerevisions") + +logger.setLevel(logging.INFO) + + +def _DeleteAgentEngineRuntimeRevisionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GetAgentEngineRuntimeRevisionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GetDeleteAgentEngineRuntimeRevisionOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _ListAgentEngineRuntimeRevisionsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineRuntimeRevisionsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, 
+) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineRuntimeRevisionsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _QueryAgentEngineRuntimeRevisionConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["class_method"]) is not None: + setv(parent_object, ["classMethod"], getv(from_object, ["class_method"])) + + if getv(from_object, ["input"]) is not None: + setv(parent_object, ["input"], getv(from_object, ["input"])) + + if getv(from_object, ["include_all_fields"]) is not None: + setv(to_object, ["includeAllFields"], getv(from_object, ["include_all_fields"])) + + return to_object + + +def _QueryAgentEngineRuntimeRevisionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _QueryAgentEngineRuntimeRevisionConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class RuntimeRevisions(_api_module.BaseModule): + + def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.ReasoningEngineRuntimeRevision: + """ + Get an agent engine runtime revision instance. + """ + + parameter_model = types._GetAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ReasoningEngineRuntimeRevision._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineRuntimeRevisionsConfigOrDict] = None, + ) -> types.ListReasoningEnginesRuntimeRevisionsResponse: + """ + Lists reasoning engine runtime revisions. + + Args: + name (str): Required. The name of the reasoning engine to list runtime revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineRuntimeRevisionsConfig): + Optional. Additional configurations for listing the reasoning engine runtime revisions. + + Returns: + ListReasoningEnginesRuntimeRevisionsResponse: The requested reasoning engine runtime revisions. + + """ + + parameter_model = types._ListAgentEngineRuntimeRevisionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineRuntimeRevisionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/runtimeRevisions".format_map(request_url_dict) + else: + path = "{name}/runtimeRevisions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
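+        # This private call fetches a single page; the public list() below
+        # wraps it in a Pager (keyed on "reasoning_engine_runtime_revisions")
+        # to walk all pages transparently.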
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = ( + types.ListReasoningEnginesRuntimeRevisionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + """ + Delete an Agent Engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (DeleteAgentEngineRuntimeRevisionConfig): + Optional. Additional configurations for deleting the Agent Engine runtime revision. + + Returns: + DeleteAgentEngineRuntimeRevisionOperation: The operation for deleting the Agent Engine runtime revision. + + """ + + parameter_model = types._DeleteAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
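+        # Deletion is asynchronous on the server: this returns a long-running
+        # operation, which the public delete() polls via
+        # _get_delete_runtime_revision_operation when wait_for_completion is
+        # set on the config.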
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineRuntimeRevisionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_delete_runtime_revision_operation( + self, + *, + operation_name: str, + config: Optional[ + types.GetDeleteAgentEngineRuntimeRevisionOperationConfigOrDict + ] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + parameter_model = types._GetDeleteAgentEngineRuntimeRevisionOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = ( + _GetDeleteAgentEngineRuntimeRevisionOperationParameters_to_vertex( + parameter_model + ) + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineRuntimeRevisionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.QueryReasoningEngineResponse: + """ + Query an Agent Engine runtime revision. 
+ """ + + parameter_model = types._QueryAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _QueryAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.AgentEngineRuntimeRevision: + """Gets an agent engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to get. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (GetAgentEngineRuntimeRevisionConfigOrDict): + Optional. Additional configurations for getting the Agent Engine runtime revision. + + Returns: + AgentEngineRuntimeRevision: The requested Agent Engine runtime revision instance. + """ + api_resource = self._get(name=name, config=config) + agent_engine_runtime_revision = types.AgentEngineRuntimeRevision( + api_client=self, + api_async_client=AsyncRuntimeRevisions(api_client_=self._api_client), + api_resource=api_resource, + ) + if api_resource.spec: + self._register_api_methods( + agent_engine_runtime_revision=agent_engine_runtime_revision + ) + return agent_engine_runtime_revision + + def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineRuntimeRevisionsConfigOrDict] = None, + ) -> Iterator[types.AgentEngineRuntimeRevision]: + """Lists all reasoning engine runtime revision instances matching the given query. + + Args: + name (str): Required. The name of the reasoning engine to list runtime revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineRuntimeRevisionsConfig): + Optional. Additional configurations for listing the reasoning engine runtime revisions. + + Returns: + Iterable[AgentEngineRuntimeRevision]: An iterable of runtime revisions. 
+ """ + list_pager: Pager[types.ReasoningEngineRuntimeRevision] = Pager( + "reasoning_engine_runtime_revisions", + functools.partial(self._list, name=name), + self._list(name=name, config=config), + config, + ) + + return ( + types.AgentEngineRuntimeRevision( + api_client=self, + api_async_client=AsyncRuntimeRevisions(api_client_=self._api_client), + api_resource=runtime_revision, + ) + for runtime_revision in list_pager + ) + + def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + """Delete an Agent Engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (DeleteAgentEngineRuntimeRevisionConfig): + Optional. Additional configurations for deleting the Agent Engine runtime revision. + + Returns: + DeleteAgentEngineRuntimeRevisionOperation: The operation for deleting the Agent Engine runtime revision. + """ + if config is None: + config = types.DeleteAgentEngineRuntimeRevisionConfig() + elif isinstance(config, dict): + config = types.DeleteAgentEngineRuntimeRevisionConfig.model_validate(config) + operation = self._delete( + name=name, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = _agent_engines_utils._await_operation( + operation_name=operation.name, + get_operation_fn=self._get_delete_runtime_revision_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError( + f"Failed to delete runtime revision: {operation.error}" + ) + return operation + + def _register_api_methods( + self, + *, + agent_engine_runtime_revision: types.AgentEngineRuntimeRevision, + ) -> types.AgentEngineRuntimeRevision: + """Registers the API methods for the agent engine runtime revision.""" + try: + _agent_engines_utils._register_api_methods_or_raise( + agent_engine=agent_engine_runtime_revision, + wrap_operation_fn={ + "": _agent_engines_utils._wrap_query_operation, # type: ignore[dict-item] + "async": _agent_engines_utils._wrap_async_query_operation, # type: ignore[dict-item] + "stream": _agent_engines_utils._wrap_stream_query_operation, # type: ignore[dict-item] + "async_stream": _agent_engines_utils._wrap_async_stream_query_operation, # type: ignore[dict-item] + "a2a_extension": _agent_engines_utils._wrap_a2a_operation, + }, + ) + except Exception as e: + logger.warning( + _agent_engines_utils._FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE, e + ) + return agent_engine_runtime_revision + + def _stream_query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> Iterator[Any]: + """Streams the response of the agent engine.""" + parameter_model = types._QueryAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + request_dict = _QueryAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) + else: + path = "{name}:streamQuery?alt=sse" + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
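+        # "alt=sse" requests a server-sent-events response, so chunks can be
+        # yielded to the caller as they arrive instead of after the full
+        # response completes.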
+ request_dict.pop("config", None) + http_options = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + for response in self._api_client.request_streamed( + "post", path, request_dict, http_options + ): + yield response + + async def _async_stream_query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> AsyncIterator[Any]: + """Streams the response of the agent engine.""" + parameter_model = types._QueryAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + request_dict = _QueryAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:streamQuery?alt=sse".format_map(request_url_dict) + else: + path = "{name}:streamQuery?alt=sse" + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + http_options = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + async_iterator = await self._api_client.async_request_streamed( + "post", path, request_dict, http_options + ) + async for response in async_iterator: + yield response + + +class AsyncRuntimeRevisions(_api_module.BaseModule): + + async def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.ReasoningEngineRuntimeRevision: + """ + Get an agent engine runtime revision instance. + """ + + parameter_model = types._GetAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
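+        # Async twin of _stream_query above: the transport hands back the
+        # stream object from a coroutine, so it is awaited once and then
+        # drained with `async for`.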
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ReasoningEngineRuntimeRevision._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineRuntimeRevisionsConfigOrDict] = None, + ) -> types.ListReasoningEnginesRuntimeRevisionsResponse: + """ + Lists reasoning engine runtime revisions. + + Args: + name (str): Required. The name of the reasoning engine to list runtime revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineRuntimeRevisionsConfig): + Optional. Additional configurations for listing the reasoning engine runtime revisions. + + Returns: + ListReasoningEnginesRuntimeRevisionsResponse: The requested reasoning engine runtime revisions. + + """ + + parameter_model = types._ListAgentEngineRuntimeRevisionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineRuntimeRevisionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/runtimeRevisions".format_map(request_url_dict) + else: + path = "{name}/runtimeRevisions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = ( + types.ListReasoningEnginesRuntimeRevisionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + """ + Delete an Agent Engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (DeleteAgentEngineRuntimeRevisionConfig): + Optional. Additional configurations for deleting the Agent Engine runtime revision. + + Returns: + DeleteAgentEngineRuntimeRevisionOperation: The operation for deleting the Agent Engine runtime revision. + + """ + + parameter_model = types._DeleteAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineRuntimeRevisionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_delete_runtime_revision_operation( + self, + *, + operation_name: str, + config: Optional[ + types.GetDeleteAgentEngineRuntimeRevisionOperationConfigOrDict + ] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + parameter_model = types._GetDeleteAgentEngineRuntimeRevisionOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = ( + _GetDeleteAgentEngineRuntimeRevisionOperationParameters_to_vertex( + parameter_model + ) + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineRuntimeRevisionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _query( + self, + *, + name: str, + config: Optional[types.QueryAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.QueryReasoningEngineResponse: + """ + Query an Agent Engine runtime revision. 
+ """ + + parameter_model = types._QueryAgentEngineRuntimeRevisionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _QueryAgentEngineRuntimeRevisionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:query".format_map(request_url_dict) + else: + path = "{name}:query" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.QueryReasoningEngineResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.AgentEngineRuntimeRevision: + """Gets an agent engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to get. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (GetAgentEngineRuntimeRevisionConfigOrDict): + Optional. Additional configurations for getting the Agent Engine runtime revision. + + Returns: + AgentEngineRuntimeRevision: The requested Agent Engine runtime revision instance. + """ + api_resource = await self._get(name=name, config=config) + agent_engine_runtime_revision = types.AgentEngineRuntimeRevision( + api_client=self, + api_async_client=AsyncRuntimeRevisions(api_client_=self._api_client), + api_resource=api_resource, + ) + if api_resource.spec: + self._register_api_methods( + agent_engine_runtime_revision=agent_engine_runtime_revision + ) + return agent_engine_runtime_revision + + async def list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineRuntimeRevisionsConfigOrDict] = None, + ) -> AsyncIterator[types.AgentEngineRuntimeRevision]: + """Lists reasoning engine runtime revisions. + + Args: + name (str): Required. The name of the reasoning engine to list runtime revisions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineRuntimeRevisionsConfig): + Optional. Additional configurations for listing the reasoning engine runtime revisions. 
+ + Returns: + AsyncIterator[AgentEngineRuntimeRevision]: An async iterator of runtime revisions. + """ + list_pager: AsyncPager[types.ReasoningEngineRuntimeRevision] = AsyncPager( + "reasoning_engine_runtime_revisions", + functools.partial(self._list, name=name), + await self._list(name=name, config=config), + config, + ) + + async for runtime_revision in list_pager: + yield types.AgentEngineRuntimeRevision( + api_client=self, + api_async_client=AsyncRuntimeRevisions(api_client_=self._api_client), + api_resource=runtime_revision, + ) + + async def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> types.DeleteAgentEngineRuntimeRevisionOperation: + """Delete an Agent Engine runtime revision. + + Args: + name (str): Required. The name of the Agent Engine runtime revision to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/runtimeRevisions/{runtime_revision_id}`. + config (DeleteAgentEngineRuntimeRevisionConfig): + Optional. Additional configurations for deleting the Agent Engine runtime revision. + + Returns: + DeleteAgentEngineRuntimeRevisionOperation: The operation for deleting the Agent Engine runtime revision. + """ + if config is None: + config = types.DeleteAgentEngineRuntimeRevisionConfig() + elif isinstance(config, dict): + config = types.DeleteAgentEngineRuntimeRevisionConfig.model_validate(config) + operation = await self._delete( + name=name, + config=config, + ) + if config.wait_for_completion and not operation.done: + operation = await _agent_engines_utils._await_async_operation( + operation_name=operation.name, + get_operation_fn=self._get_delete_runtime_revision_operation, + poll_interval_seconds=0.5, + ) + if operation.error: + raise RuntimeError( + f"Failed to delete runtime revision: {operation.error}" + ) + return operation + + def _register_api_methods( + self, + *, + agent_engine_runtime_revision: types.AgentEngineRuntimeRevision, + ) -> types.AgentEngineRuntimeRevision: + """Registers the API methods for the agent engine runtime revision.""" + try: + _agent_engines_utils._register_api_methods_or_raise( + agent_engine=agent_engine_runtime_revision, + wrap_operation_fn={ + "": _agent_engines_utils._wrap_query_operation, # type: ignore[dict-item] + "async": _agent_engines_utils._wrap_async_query_operation, # type: ignore[dict-item] + "stream": _agent_engines_utils._wrap_stream_query_operation, # type: ignore[dict-item] + "async_stream": _agent_engines_utils._wrap_async_stream_query_operation, # type: ignore[dict-item] + "a2a_extension": _agent_engines_utils._wrap_a2a_operation, + }, + ) + except Exception as e: + logger.warning( + _agent_engines_utils._FAILED_TO_REGISTER_API_METHODS_WARNING_TEMPLATE, e + ) + return agent_engine_runtime_revision diff --git a/agentplatform/_genai/runtimes.py b/agentplatform/_genai/runtimes.py new file mode 100644 index 0000000000..9f41816bf5 --- /dev/null +++ b/agentplatform/_genai/runtimes.py @@ -0,0 +1,78 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Handwritten placeholder code for the runtimes.py file. +# Should be replaced by the generated file. + +import importlib +import logging +import typing + +from google.genai import _api_module + + +if typing.TYPE_CHECKING: + from . import runtime_revisions as runtime_revisions_module + + _ = runtime_revisions_module + + +logger = logging.getLogger("agentplatform_genai.runtimes") + +logger.setLevel(logging.INFO) + + +class Runtimes(_api_module.BaseModule): + + _revisions = None + + @property + def revisions(self) -> "runtime_revisions_module.RuntimeRevisions": + if self._revisions is None: + try: + # We need to lazy load the revisions module to handle the + # possibility of ImportError when dependencies are not installed. + self._revisions = importlib.import_module( + ".runtime_revisions", __package__ + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines.runtimes.revisions' module requires " + "additional packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._revisions.RuntimeRevisions(self._api_client) # type: ignore[no-any-return] + + +class AsyncRuntimes(_api_module.BaseModule): + + _revisions = None + + @property + def revisions(self) -> "runtime_revisions_module.AsyncRuntimeRevisions": + if self._revisions is None: + try: + # We need to lazy load the revisions module to handle the + # possibility of ImportError when dependencies are not installed. + self._revisions = importlib.import_module( + ".runtime_revisions", __package__ + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines.runtimes.revisions' module requires " + "additional packages. Please install them using pip install " + "google-cloud-aiplatform[agent_engines]" + ) from e + return self._revisions.AsyncRuntimeRevisions(self._api_client) # type: ignore[no-any-return] diff --git a/agentplatform/_genai/sandbox_snapshots.py b/agentplatform/_genai/sandbox_snapshots.py new file mode 100644 index 0000000000..fe7265daef --- /dev/null +++ b/agentplatform/_genai/sandbox_snapshots.py @@ -0,0 +1,1019 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import Pager + +from . import _agent_engines_utils +from . 
import types + +logger = logging.getLogger("agentplatform_genai.sandboxsnapshots") + +logger.setLevel(logging.INFO) + + +def _CreateAgentEngineSandboxSnapshotConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["owner"]) is not None: + setv(parent_object, ["owner"], getv(from_object, ["owner"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + return to_object + + +def _CreateSandboxEnvironmentSnapshotRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["source_sandbox_environment_name"]) is not None: + setv( + to_object, + ["_url", "name"], + getv(from_object, ["source_sandbox_environment_name"]), + ) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _CreateAgentEngineSandboxSnapshotConfig_to_vertex( + getv(from_object, ["config"]), to_object + ), + ) + + return to_object + + +def _DeleteSandboxEnvironmentSnapshotRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GetAgentEngineSandboxSnapshotOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetSandboxEnvironmentSnapshotRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ListSandboxEnvironmentSnapshotsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListSandboxEnvironmentSnapshotsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListSandboxEnvironmentSnapshotsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class 
SandboxSnapshots(_api_module.BaseModule): + """Sandbox environment snapshot commands.""" + + def _create( + self, + *, + source_sandbox_environment_name: str, + config: Optional[types.CreateAgentEngineSandboxSnapshotConfigOrDict] = None, + ) -> types.AgentEngineSandboxSnapshotOperation: + """ + Snapshots an existing sandbox environment. + + Args: + source_sandbox_environment_name (str): + Required. The name of the sandbox environment to snapshot. + projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox_environment_id} + config (CreateAgentEngineSandboxSnapshotConfig): + Optional. The configuration for the sandbox snapshot. + + """ + + parameter_model = types._CreateSandboxEnvironmentSnapshotRequestParameters( + source_sandbox_environment_name=source_sandbox_environment_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateSandboxEnvironmentSnapshotRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:snapshot".format_map(request_url_dict) + else: + path = "{name}:snapshot" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxSnapshotOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete( + self, + *, + name: str, + config: Optional[types.DeleteSandboxEnvironmentSnapshotConfigOrDict] = None, + ) -> types.DeleteSandboxEnvironmentSnapshotOperation: + """ + Deletes a sandbox environment snapshot. + + """ + + parameter_model = types._DeleteSandboxEnvironmentSnapshotRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteSandboxEnvironmentSnapshotRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
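+        # The request converter above has already flattened the config
+        # fields into the request body; the leftover "config" entry is an
+        # empty placeholder, so pop it to keep the stray key out of the
+        # API payload.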
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteSandboxEnvironmentSnapshotOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get( + self, + *, + name: str, + config: Optional[types.GetSandboxEnvironmentSnapshotConfigOrDict] = None, + ) -> types.SandboxEnvironmentSnapshot: + """ + Gets a sandbox environment snapshot. + + """ + + parameter_model = types._GetSandboxEnvironmentSnapshotRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentSnapshotRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentSnapshot._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListSandboxEnvironmentSnapshotsConfigOrDict] = None, + ) -> types.ListSandboxEnvironmentSnapshotsResponse: + """ + Lists sandbox environment snapshots. 
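+
+        Returns a single page of results; the public list() method wraps
+        this call in a Pager to iterate across pages.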
+ + """ + + parameter_model = types._ListSandboxEnvironmentSnapshotsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListSandboxEnvironmentSnapshotsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironmentSnapshots".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironmentSnapshots" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListSandboxEnvironmentSnapshotsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def get_sandbox_snapshot_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSandboxSnapshotOperation: + parameter_model = types._GetAgentEngineSandboxSnapshotOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSandboxSnapshotOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.AgentEngineSandboxSnapshotOperation._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def create(
+        self,
+        *,
+        source_sandbox_environment_name: str,
+        config: Optional[types.CreateAgentEngineSandboxSnapshotConfigOrDict] = None,
+        poll_interval_seconds: float = 0.1,
+    ) -> types.AgentEngineSandboxSnapshotOperation:
+        """Snapshots an existing sandbox environment.
+
+        Args:
+            source_sandbox_environment_name (str):
+                Required. The name of the sandbox environment to snapshot.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox_environment_id}
+            config (CreateAgentEngineSandboxSnapshotConfig):
+                Optional. The configuration for the sandbox snapshot.
+            poll_interval_seconds (float):
+                Optional. Seconds to wait between polling for operation status. Defaults to 0.1.
+
+        Returns:
+            AgentEngineSandboxSnapshotOperation: The operation for creating the sandbox snapshot.
+        """
+        operation = self._create(
+            source_sandbox_environment_name=source_sandbox_environment_name,
+            config=config,
+        )
+        if config is None:
+            config = types.CreateAgentEngineSandboxSnapshotConfig()
+        elif isinstance(config, dict):
+            config = types.CreateAgentEngineSandboxSnapshotConfig.model_validate(config)
+        if config.wait_for_completion:
+            if not operation.done:
+                operation = _agent_engines_utils._await_operation(
+                    operation_name=operation.name,
+                    get_operation_fn=self.get_sandbox_snapshot_operation,
+                    poll_interval_seconds=poll_interval_seconds,
+                )
+            # We need to make a call to get the sandbox snapshot because the operation
+            # response might not contain the relevant fields.
+            if not operation.response:
+                raise ValueError("Error retrieving sandbox snapshot.")
+            operation.response = self.get(name=operation.response.name)
+        return operation
+
+    def list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListSandboxEnvironmentSnapshotsConfigOrDict] = None,
+    ) -> Iterator[types.SandboxEnvironmentSnapshot]:
+        """Lists Agent Engine sandbox snapshots.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to list sandbox snapshots for.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}
+            config (ListSandboxEnvironmentSnapshotsConfig):
+                Optional. The configuration for the sandbox snapshots to list.
+
+        Returns:
+            Iterable[SandboxEnvironmentSnapshot]: An iterable of agent engine sandbox snapshots.
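+
+        Example (a minimal sketch; "sandbox_snapshots" stands for however
+        this module is exposed on a configured client, and the resource
+        name below is a placeholder):
+
+            for snapshot in sandbox_snapshots.list(
+                name="projects/my-project/locations/us-central1/reasoningEngines/123",
+            ):
+                print(snapshot.name)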
+        """
+        return Pager(
+            "sandbox_environment_snapshots",
+            functools.partial(self._list, name=name),
+            self._list(name=name, config=config),
+            config,
+        )
+
+    def get(
+        self,
+        *,
+        name: str,
+        config: Optional[types.GetSandboxEnvironmentSnapshotConfigOrDict] = None,
+    ) -> types.SandboxEnvironmentSnapshot:
+        """Gets a sandbox snapshot in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. A fully-qualified resource name or ID such as
+                projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentSnapshots/{snapshot_id}
+                or a shortened name such as "reasoningEngines/{resource_id}/sandboxEnvironmentSnapshots/{snapshot_id}".
+            config (GetSandboxEnvironmentSnapshotConfigOrDict):
+                Optional. The configuration for the sandbox snapshot to get.
+        """
+        return self._get(name=name, config=config)
+
+    def delete(
+        self,
+        *,
+        name: str,
+        config: Optional[types.DeleteSandboxEnvironmentSnapshotConfigOrDict] = None,
+    ) -> types.DeleteSandboxEnvironmentSnapshotOperation:
+        """Deletes a sandbox snapshot in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. The name of the sandbox snapshot to delete.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentSnapshots/{snapshot_id}
+            config (DeleteSandboxEnvironmentSnapshotConfigOrDict):
+                Optional. Configuration for the delete operation.
+        """
+        return self._delete(name=name, config=config)
+
+
+class AsyncSandboxSnapshots(_api_module.BaseModule):
+    """Sandbox environment snapshot commands."""
+
+    async def _create(
+        self,
+        *,
+        source_sandbox_environment_name: str,
+        config: Optional[types.CreateAgentEngineSandboxSnapshotConfigOrDict] = None,
+    ) -> types.AgentEngineSandboxSnapshotOperation:
+        """
+        Snapshots an existing sandbox environment.
+
+        Args:
+            source_sandbox_environment_name (str):
+                Required. The name of the sandbox environment to snapshot.
+                projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox_environment_id}
+            config (CreateAgentEngineSandboxSnapshotConfig):
+                Optional. The configuration for the sandbox snapshot.
+
+        """
+
+        parameter_model = types._CreateSandboxEnvironmentSnapshotRequestParameters(
+            source_sandbox_environment_name=source_sandbox_environment_name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _CreateSandboxEnvironmentSnapshotRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}:snapshot".format_map(request_url_dict)
+            else:
+                path = "{name}:snapshot"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxSnapshotOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete( + self, + *, + name: str, + config: Optional[types.DeleteSandboxEnvironmentSnapshotConfigOrDict] = None, + ) -> types.DeleteSandboxEnvironmentSnapshotOperation: + """ + Deletes a sandbox environment snapshot. + + """ + + parameter_model = types._DeleteSandboxEnvironmentSnapshotRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteSandboxEnvironmentSnapshotRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteSandboxEnvironmentSnapshotOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get( + self, + *, + name: str, + config: Optional[types.GetSandboxEnvironmentSnapshotConfigOrDict] = None, + ) -> types.SandboxEnvironmentSnapshot: + """ + Gets a sandbox environment snapshot. 
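+
+        Asynchronous variant of SandboxSnapshots._get; the request is
+        issued through the client's async_request.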
+ + """ + + parameter_model = types._GetSandboxEnvironmentSnapshotRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentSnapshotRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentSnapshot._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListSandboxEnvironmentSnapshotsConfigOrDict] = None, + ) -> types.ListSandboxEnvironmentSnapshotsResponse: + """ + Lists sandbox environment snapshots. + + """ + + parameter_model = types._ListSandboxEnvironmentSnapshotsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListSandboxEnvironmentSnapshotsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironmentSnapshots".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironmentSnapshots" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListSandboxEnvironmentSnapshotsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def get_sandbox_snapshot_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSandboxSnapshotOperation: + parameter_model = types._GetAgentEngineSandboxSnapshotOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSandboxSnapshotOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxSnapshotOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value diff --git a/agentplatform/_genai/sandbox_templates.py b/agentplatform/_genai/sandbox_templates.py new file mode 100644 index 0000000000..26a5c184b1 --- /dev/null +++ b/agentplatform/_genai/sandbox_templates.py @@ -0,0 +1,1088 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import Pager + +from . import _agent_engines_utils +from . 
import types

logger = logging.getLogger("agentplatform_genai.sandboxtemplates")

logger.setLevel(logging.INFO)


+def _CreateSandboxEnvironmentTemplateConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+
+    if getv(from_object, ["custom_container_environment"]) is not None:
+        setv(
+            parent_object,
+            ["customContainerEnvironment"],
+            getv(from_object, ["custom_container_environment"]),
+        )
+
+    if getv(from_object, ["default_container_environment"]) is not None:
+        setv(
+            parent_object,
+            ["defaultContainerEnvironment"],
+            getv(from_object, ["default_container_environment"]),
+        )
+
+    if getv(from_object, ["egress_control_config"]) is not None:
+        setv(
+            parent_object,
+            ["egressControlConfig"],
+            getv(from_object, ["egress_control_config"]),
+        )
+
+    return to_object
+
+
+def _CreateSandboxEnvironmentTemplateRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["config"]) is not None:
+        _CreateSandboxEnvironmentTemplateConfig_to_vertex(
+            getv(from_object, ["config"]), to_object
+        )
+
+    if getv(from_object, ["display_name"]) is not None:
+        # displayName belongs in the request body (to_object); this
+        # converter is invoked without a parent_object at the top level,
+        # so writing to the parent would fail and displayName would never
+        # reach the request.
+        setv(to_object, ["displayName"], getv(from_object, ["display_name"]))
+
+    return to_object
+
+
+def _DeleteSandboxEnvironmentTemplateRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    return to_object
+
+
+def _GetSandboxEnvironmentTemplateOperationParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["operation_name"]) is not None:
+        setv(
+            to_object, ["_url", "operationName"], getv(from_object, ["operation_name"])
+        )
+
+    return to_object
+
+
+def _GetSandboxEnvironmentTemplateRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    return to_object
+
+
+def _ListSandboxEnvironmentTemplatesConfig_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+
+    if getv(from_object, ["page_size"]) is not None:
+        setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"]))
+
+    if getv(from_object, ["page_token"]) is not None:
+        setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"]))
+
+    if getv(from_object, ["filter"]) is not None:
+        setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"]))
+
+    return to_object
+
+
+def _ListSandboxEnvironmentTemplatesRequestParameters_to_vertex(
+    from_object: Union[dict[str, Any], object],
+    parent_object: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
+    to_object: dict[str, Any] = {}
+    if getv(from_object, ["name"]) is not None:
+        setv(to_object, ["_url", "name"], getv(from_object, ["name"]))
+
+    if getv(from_object, ["config"]) is not None:
+        _ListSandboxEnvironmentTemplatesConfig_to_vertex(
+            getv(from_object, ["config"]), to_object
+        )
+
+    return to_object
+
+
+class SandboxTemplates(_api_module.BaseModule):
+    """Sandbox environment templates commands."""
+
+    def _create(
+        self,
+        *,
+        name: str,
+        config: Optional[types.CreateSandboxEnvironmentTemplateConfigOrDict] = None,
+        display_name: str,
+    ) -> types.SandboxEnvironmentTemplateOperation:
+        """
+        Creates a new sandbox template in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to create the template under.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}
+            display_name (str):
+                Required. The display name of the sandbox template.
+            config (CreateSandboxEnvironmentTemplateConfig):
+                Optional. The configuration for the sandbox template.
+
+        """
+
+        parameter_model = types._CreateSandboxEnvironmentTemplateRequestParameters(
+            name=name,
+            config=config,
+            display_name=display_name,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _CreateSandboxEnvironmentTemplateRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}/sandboxEnvironmentTemplates".format_map(request_url_dict)
+            else:
+                path = "{name}/sandboxEnvironmentTemplates"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("post", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.SandboxEnvironmentTemplateOperation._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def _delete(
+        self,
+        *,
+        name: str,
+        config: Optional[types.DeleteSandboxEnvironmentTemplateConfigOrDict] = None,
+    ) -> types.DeleteSandboxEnvironmentTemplateOperation:
+        """
+        Deletes an Agent Engine sandbox template.
+
+        Args:
+            name (str):
+                Required. The name of the sandbox template to delete.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_environment_template}
+            config (DeleteSandboxEnvironmentTemplateConfig):
+                Optional. Configuration for the delete operation.
+ + """ + + parameter_model = types._DeleteSandboxEnvironmentTemplateRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteSandboxEnvironmentTemplateRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteSandboxEnvironmentTemplateOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get( + self, + *, + name: str, + config: Optional[types.GetSandboxEnvironmentTemplateConfigOrDict] = None, + ) -> types.SandboxEnvironmentTemplate: + """ + Gets an agent engine sandbox template. + + Args: + name (str): The resource name of the SandboxEnvironmentTemplate. + Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_environment_template}` + config (GetSandboxEnvironmentTemplateConfig): Configuration for the + request. + + Returns: + shared.SandboxEnvironmentTemplate: The retrieved sandbox environment + template. + + """ + + parameter_model = types._GetSandboxEnvironmentTemplateRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentTemplateRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentTemplate._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListSandboxEnvironmentTemplatesConfigOrDict] = None, + ) -> types.ListSandboxEnvironmentTemplatesResponse: + """ + Lists Agent Engine sandbox templates. + + Args: + name (str): Name of the agent engine. Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id} + config (ListSandboxEnvironmentTemplatesConfig): Configuration for listing sandbox templates. + + Returns: + ListSandboxEnvironmentTemplatesResponse: A list of sandbox templates. + + """ + + parameter_model = types._ListSandboxEnvironmentTemplatesRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListSandboxEnvironmentTemplatesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironmentTemplates".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironmentTemplates" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListSandboxEnvironmentTemplatesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def get_sandbox_environment_template_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.SandboxEnvironmentTemplateOperation: + parameter_model = types._GetSandboxEnvironmentTemplateOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentTemplateOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentTemplateOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def create( + self, + *, + name: str, + display_name: str, + config: Optional[types.CreateSandboxEnvironmentTemplateConfigOrDict] = None, + poll_interval_seconds: float = 0.1, + ) -> types.SandboxEnvironmentTemplateOperation: + """Creates a new sandbox template in the Agent Engine. + + Args: + name (str): + Required. The name of the agent engine to create sandbox template for. 
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}
+            display_name (str):
+                Required. The display name of the sandbox template.
+            config (CreateSandboxEnvironmentTemplateConfig):
+                Optional. The configuration for the sandbox template.
+            poll_interval_seconds (float):
+                Optional. Seconds to wait between polling for operation status. Defaults to 0.1.
+
+        Returns:
+            SandboxEnvironmentTemplateOperation: The operation for creating the sandbox template.
+        """
+        operation = self._create(
+            name=name,
+            display_name=display_name,
+            config=config,
+        )
+        if config is None:
+            config = types.CreateSandboxEnvironmentTemplateConfig()
+        elif isinstance(config, dict):
+            config = types.CreateSandboxEnvironmentTemplateConfig.model_validate(config)
+        if config.wait_for_completion:
+            if not operation.done:
+                operation = _agent_engines_utils._await_operation(
+                    operation_name=operation.name,
+                    get_operation_fn=self.get_sandbox_environment_template_operation,
+                    poll_interval_seconds=poll_interval_seconds,
+                )
+            # We need to make a call to get the sandbox template because the operation
+            # response might not contain the relevant fields.
+            if not operation.response:
+                raise ValueError("Error retrieving sandbox template.")
+            operation.response = self.get(name=operation.response.name)
+        return operation
+
+    def list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListSandboxEnvironmentTemplatesConfigOrDict] = None,
+    ) -> Iterator[types.SandboxEnvironmentTemplate]:
+        """Lists Agent Engine sandbox templates.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to list sandbox templates for.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}
+            config (ListSandboxEnvironmentTemplatesConfig):
+                Optional. The configuration for the sandbox templates to list.
+
+        Returns:
+            Iterable[SandboxEnvironmentTemplate]: An iterable of agent engine sandbox templates.
+        """
+        return Pager(
+            "sandbox_environment_templates",
+            functools.partial(self._list, name=name),
+            self._list(name=name, config=config),
+            config,
+        )
+
+    def get(
+        self,
+        *,
+        name: str,
+        config: Optional[types.GetSandboxEnvironmentTemplateConfigOrDict] = None,
+    ) -> types.SandboxEnvironmentTemplate:
+        """Gets a sandbox template in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. A fully-qualified resource name or ID such as
+                projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_template_id}
+                or a shortened name such as "reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_template_id}".
+            config (GetSandboxEnvironmentTemplateConfigOrDict):
+                Optional. The configuration for the sandbox template to get.
+        """
+        return self._get(name=name, config=config)
+
+    def delete(
+        self,
+        *,
+        name: str,
+        config: Optional[types.DeleteSandboxEnvironmentTemplateConfigOrDict] = None,
+    ) -> types.DeleteSandboxEnvironmentTemplateOperation:
+        """Deletes a sandbox template in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. The name of the sandbox template to delete.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_template_id}
+            config (DeleteSandboxEnvironmentTemplateConfig):
+                Optional. Configuration for the delete operation.
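+
+        Example (a minimal sketch; "sandbox_templates" stands for however
+        this module is exposed on a configured client, and the resource
+        name below is a placeholder):
+
+            operation = sandbox_templates.delete(
+                name="projects/my-project/locations/us-central1/reasoningEngines/123/sandboxEnvironmentTemplates/456",
+            )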
+        """
+        return self._delete(name=name, config=config)
+
+
+class AsyncSandboxTemplates(_api_module.BaseModule):
+    """Sandbox environment templates commands."""
+
+    async def _create(
+        self,
+        *,
+        name: str,
+        config: Optional[types.CreateSandboxEnvironmentTemplateConfigOrDict] = None,
+        display_name: str,
+    ) -> types.SandboxEnvironmentTemplateOperation:
+        """
+        Creates a new sandbox template in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to create the template under.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}
+            display_name (str):
+                Required. The display name of the sandbox template.
+            config (CreateSandboxEnvironmentTemplateConfig):
+                Optional. The configuration for the sandbox template.
+
+        """
+
+        parameter_model = types._CreateSandboxEnvironmentTemplateRequestParameters(
+            name=name,
+            config=config,
+            display_name=display_name,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _CreateSandboxEnvironmentTemplateRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}/sandboxEnvironmentTemplates".format_map(request_url_dict)
+            else:
+                path = "{name}/sandboxEnvironmentTemplates"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "post", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.SandboxEnvironmentTemplateOperation._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _delete(
+        self,
+        *,
+        name: str,
+        config: Optional[types.DeleteSandboxEnvironmentTemplateConfigOrDict] = None,
+    ) -> types.DeleteSandboxEnvironmentTemplateOperation:
+        """
+        Deletes an Agent Engine sandbox template.
+
+        Args:
+            name (str):
+                Required. The name of the sandbox template to delete.
+                Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_environment_template}
+            config (DeleteSandboxEnvironmentTemplateConfig):
+                Optional. Configuration for the delete operation.
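+
+        Asynchronous variant of SandboxTemplates._delete; the request is
+        issued through the client's async_request.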
+ + """ + + parameter_model = types._DeleteSandboxEnvironmentTemplateRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteSandboxEnvironmentTemplateRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteSandboxEnvironmentTemplateOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get( + self, + *, + name: str, + config: Optional[types.GetSandboxEnvironmentTemplateConfigOrDict] = None, + ) -> types.SandboxEnvironmentTemplate: + """ + Gets an agent engine sandbox template. + + Args: + name (str): The resource name of the SandboxEnvironmentTemplate. + Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironmentTemplates/{sandbox_environment_template}` + config (GetSandboxEnvironmentTemplateConfig): Configuration for the + request. + + Returns: + shared.SandboxEnvironmentTemplate: The retrieved sandbox environment + template. + + """ + + parameter_model = types._GetSandboxEnvironmentTemplateRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentTemplateRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentTemplate._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListSandboxEnvironmentTemplatesConfigOrDict] = None, + ) -> types.ListSandboxEnvironmentTemplatesResponse: + """ + Lists Agent Engine sandbox templates. + + Args: + name (str): Name of the agent engine. Format: projects/{project}/locations/{location}/reasoningEngines/{resource_id} + config (ListSandboxEnvironmentTemplatesConfig): Configuration for listing sandbox templates. + + Returns: + ListSandboxEnvironmentTemplatesResponse: A list of sandbox templates. + + """ + + parameter_model = types._ListSandboxEnvironmentTemplatesRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListSandboxEnvironmentTemplatesRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironmentTemplates".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironmentTemplates" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListSandboxEnvironmentTemplatesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def get_sandbox_environment_template_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.SandboxEnvironmentTemplateOperation: + parameter_model = types._GetSandboxEnvironmentTemplateOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSandboxEnvironmentTemplateOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SandboxEnvironmentTemplateOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value diff --git a/agentplatform/_genai/sandboxes.py b/agentplatform/_genai/sandboxes.py new file mode 100644 index 0000000000..84277458f0 --- /dev/null +++ b/agentplatform/_genai/sandboxes.py @@ -0,0 +1,1524 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import functools +import json +import logging +import mimetypes +import secrets +import time +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google import genai +from google.cloud import iam_credentials_v1 # type: ignore[attr-defined] +from google.genai import _api_module +from google.genai import _common +from google.genai import types as genai_types +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import Pager + +from . import _agent_engines_utils +from . import types + +logger = logging.getLogger("agentplatform_genai.sandboxes") + +logger.setLevel(logging.INFO) + + +def _CreateAgentEngineSandboxConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(parent_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + if getv(from_object, ["sandbox_environment_template"]) is not None: + setv( + parent_object, + ["sandboxEnvironmentTemplate"], + getv(from_object, ["sandbox_environment_template"]), + ) + + if getv(from_object, ["sandbox_environment_snapshot"]) is not None: + setv( + parent_object, + ["sandboxEnvironmentSnapshot"], + getv(from_object, ["sandbox_environment_snapshot"]), + ) + + if getv(from_object, ["owner"]) is not None: + setv(parent_object, ["owner"], getv(from_object, ["owner"])) + + return to_object + + +def _CreateAgentEngineSandboxRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["spec"]) is not None: + setv(to_object, ["spec"], getv(from_object, ["spec"])) + + if getv(from_object, ["config"]) is not None: + _CreateAgentEngineSandboxConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _DeleteAgentEngineSandboxRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ExecuteCodeAgentEngineSandboxRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, 
["name"])) + + if getv(from_object, ["inputs"]) is not None: + setv(to_object, ["inputs"], [item for item in getv(from_object, ["inputs"])]) + + return to_object + + +def _GetAgentEngineSandboxOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetAgentEngineSandboxRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ListAgentEngineSandboxesConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineSandboxesRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineSandboxesConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class Sandboxes(_api_module.BaseModule): + + def _create( + self, + *, + name: str, + spec: Optional[types.SandboxEnvironmentSpecOrDict] = None, + config: Optional[types.CreateAgentEngineSandboxConfigOrDict] = None, + ) -> types.AgentEngineSandboxOperation: + """ + Creates a new sandbox in the Agent Engine. + """ + + parameter_model = types._CreateAgentEngineSandboxRequestParameters( + name=name, + spec=spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironments".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironments" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSandboxConfigOrDict] = None, + ) -> types.DeleteAgentEngineSandboxOperation: + """ + Delete an Agent Engine sandbox. + + Args: + name (str): + Required. The name of the Agent Engine sandbox to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox}`. + + """ + + parameter_model = types._DeleteAgentEngineSandboxRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("delete", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _execute_code( + self, + *, + name: str, + inputs: Optional[list[types.ChunkOrDict]] = None, + config: Optional[types.ExecuteCodeAgentEngineSandboxConfigOrDict] = None, + ) -> types.ExecuteSandboxEnvironmentResponse: + """ + Execute code in an Agent Engine sandbox. 
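+
+        Note: `inputs` is a list of `Chunk` objects. The public
+        `execute_code` helper builds these chunks from a dict with optional
+        "code" and "files" keys, so most callers should use that method
+        instead of calling this one directly.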
+ """ + + parameter_model = types._ExecuteCodeAgentEngineSandboxRequestParameters( + name=name, + inputs=inputs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ExecuteCodeAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/:execute".format_map(request_url_dict) + else: + path = "{name}/:execute" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ExecuteSandboxEnvironmentResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineSandboxConfigOrDict] = None, + ) -> types.SandboxEnvironment: + """ + Gets an agent engine sandbox. + + Args: + name (str): Required. A fully-qualified resource name or ID such as + "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789" + or a shortened name such as "reasoningEngines/456/sandboxEnvironments/789". + + """ + + parameter_model = types._GetAgentEngineSandboxRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.SandboxEnvironment._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def _list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListAgentEngineSandboxesConfigOrDict] = None,
+    ) -> types.ListAgentEngineSandboxesResponse:
+        """
+        Lists Agent Engine sandboxes.
+
+        Args:
+          name (str): Required. The name of the Agent Engine to list sandboxes for. Format:
+            `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
+          config (ListAgentEngineSandboxesConfig):
+            Optional. Additional configurations for listing the Agent Engine sandboxes.
+
+        Returns:
+          ListAgentEngineSandboxesResponse: The requested Agent Engine sandboxes.
+
+        """
+
+        parameter_model = types._ListAgentEngineSandboxesRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _ListAgentEngineSandboxesRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}/sandboxEnvironments".format_map(request_url_dict)
+            else:
+                path = "{name}/sandboxEnvironments"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineSandboxesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_sandbox_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSandboxOperation: + parameter_model = types._GetAgentEngineSandboxOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSandboxOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + _templates = None + _snapshots = None + + @property + def templates(self) -> Any: + if self._templates is None: + try: + self._templates = __import__("importlib").import_module( + ".sandbox_templates", __package__ + ) + except ImportError as e: + raise ImportError( + "The 'agent_engines.sandboxes.templates' module requires " + "additional packages. 
Please install them using pip install "
+                    "google-cloud-aiplatform[agent_engines]"
+                ) from e
+        return self._templates.SandboxTemplates(self._api_client)
+
+    @property
+    def snapshots(self) -> Any:
+        if self._snapshots is None:
+            try:
+                self._snapshots = __import__("importlib").import_module(
+                    ".sandbox_snapshots", __package__
+                )
+            except ImportError as e:
+                raise ImportError(
+                    "The 'agent_engines.sandboxes.snapshots' module requires "
+                    "additional packages. Please install them using pip install "
+                    "google-cloud-aiplatform[sandbox_snapshots]"
+                ) from e
+        return self._snapshots.SandboxSnapshots(self._api_client)
+
+    def create(
+        self,
+        *,
+        name: str,
+        poll_interval_seconds: float = 0.1,
+        spec: Optional[types.SandboxEnvironmentSpecOrDict] = None,
+        config: Optional[types.CreateAgentEngineSandboxConfigOrDict] = None,
+    ) -> types.AgentEngineSandboxOperation:
+        """Creates a new sandbox in the Agent Engine.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to create a sandbox for. Format:
+                `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
+            poll_interval_seconds (float):
+                Optional. The interval in seconds to poll for sandbox creation
+                completion.
+            spec (SandboxEnvironmentSpec):
+                Optional. The specification for the sandbox to create.
+            config (CreateAgentEngineSandboxConfigOrDict):
+                Optional. The configuration for the sandbox.
+
+        Returns:
+            AgentEngineSandboxOperation: The operation for creating the sandbox.
+        """
+        if spec:
+            computer_use = False
+            if isinstance(spec, dict):
+                computer_use = spec.get("computer_use_environment") is not None
+            elif hasattr(spec, "computer_use_environment"):
+                computer_use = True
+
+            if computer_use:
+                logger.warning(
+                    "The computer_use_environment feature in the sandboxes module is experimental and may change in future versions."
+                )
+        operation = self._create(
+            name=name,
+            spec=spec,
+            config=config,
+        )
+        if config is None:
+            config = types.CreateAgentEngineSandboxConfig()
+        elif isinstance(config, dict):
+            config = types.CreateAgentEngineSandboxConfig.model_validate(config)
+        if config.wait_for_completion:
+            if not operation.done:
+                operation = _agent_engines_utils._await_operation(
+                    operation_name=operation.name,
+                    get_operation_fn=self._get_sandbox_operation,
+                    poll_interval_seconds=poll_interval_seconds,
+                )
+            # We need to make a call to get the sandbox because the operation
+            # response might not contain the relevant fields.
+            if not operation.response:
+                raise ValueError("Error retrieving sandbox.")
+            operation.response = self.get(name=operation.response.name)
+        return operation
+
+    def list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListAgentEngineSandboxesConfigOrDict] = None,
+    ) -> Iterator[types.SandboxEnvironment]:
+        """Lists Agent Engine sandboxes.
+
+        Args:
+            name (str):
+                Required. The name of the agent engine to list sandboxes for. Format:
+                `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
+            config (ListAgentEngineSandboxesConfig):
+                Optional. The configuration for the sandboxes to list.
+
+        Returns:
+            Iterator[SandboxEnvironment]: An iterator of agent engine sandboxes.
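+
+        Example (an illustrative sketch, not generated code; `sandboxes` is
+        assumed to be an instance of this module obtained from an initialized
+        client, and the resource name is a placeholder):
+
+            engine = "projects/my-project/locations/us-central1/reasoningEngines/123"
+            for sandbox in sandboxes.list(name=engine):
+                print(sandbox.name)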
+ """ + return Pager( + "sandbox_environments", + functools.partial(self._list, name=name), + self._list(name=name, config=config), + config, + ) + + def execute_code( + self, + *, + name: str, + input_data: dict[str, Any], + config: Optional[types.ExecuteCodeAgentEngineSandboxConfigOrDict] = None, + ) -> types.ExecuteSandboxEnvironmentResponse: + """Executes code in the Agent Engine sandbox. + + Args: + name (str): + Required. The name of the agent engine sandbox to run code in. + projects/{project}/locations/{location}/reasoningEngines/{resource_id}/SandboxEnvironments/{sandbox_id} + input_data (dict[str, Any]): + Required. The input to the code to execute. + config (ExecuteCodeAgentEngineSandboxConfigOrDict): + Optional. The configuration for the sandboxes to run code in. + + Returns: + ExecuteSandboxEnvironmentResponse: The response from executing the code. + """ + input_chunks = [] + + if input_data.get("code") is not None: + code = input_data.get("code", "") + json_code = json.dumps({"code": code}).encode("utf-8") + input_chunks.append( + types.Chunk( + mime_type="application/json", + data=json_code, + ) + ) + + for file in input_data.get("files", []): + file_name = file.get("name", "") + input_chunks.append( + types.Chunk( + mime_type=file.get("mimeType", ""), + data=file.get("content", b""), + metadata={"attributes": {"file_name": file_name.encode("utf-8")}}, + ) + ) + + response = self._execute_code( + name=name, + inputs=input_chunks, + config=config, + ) + + output_chunks = [] + if response.outputs is not None: + for output in response.outputs: + if output.mime_type is None: + # if mime_type is not available, try to guess the mime_type from the file_name. + if ( + output.metadata is not None + and output.metadata.attributes is not None + ): + file_name = output.metadata.attributes.get( + "file_name", b"" + ).decode("utf-8") + mime_type, _ = mimetypes.guess_type(file_name) + output.mime_type = mime_type + output_chunks.append(output) + + response = types.ExecuteSandboxEnvironmentResponse(outputs=output_chunks) + + return response + + def get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineSandboxConfigOrDict] = None, + ) -> types.SandboxEnvironment: + """Gets an agent engine sandbox. + Args: + name (str): + Required. A fully-qualified resource name or ID such as + projects/{project}/locations/{location}/reasoningEngines/{resource_id}/SandboxEnvironments/{sandbox_id} + or a shortened name such as "reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox_id}". + config (GetAgentEngineSandboxConfigOrDict): + Optional. The configuration for the sandbox to get. + + """ + return self._get(name=name, config=config) + + def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSandboxConfigOrDict] = None, + ) -> types.DeleteAgentEngineSandboxOperation: + """Deletes an agent engine sandbox. + Args: + name (str): + Required. A fully-qualified resource name or ID such as + projects/{project}/locations/{location}/reasoningEngines/{resource_id}/SandboxEnvironments/{sandbox_id} + or a shortened name such as "reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox_id}". + config (DeleteAgentEngineSandboxConfigOrDict): + Optional. The configuration for the sandbox to delete. + """ + return self._delete(name=name, config=config) + + def generate_access_token( + self, + service_account_email: str, + sandbox_hostname: str, + port: str = "8080", + timeout: int = 3600, + ) -> str: + """Signs a JWT with a Google Cloud service account. 
+
+        Args:
+            service_account_email (str):
+                Required. The email of the service account to use for signing.
+            sandbox_hostname (str):
+                Required. The hostname of the sandbox to generate a token for.
+            port (str):
+                Optional. The port to use for the token. Defaults to "8080".
+            timeout (int):
+                Optional. The timeout in seconds for the token. Defaults to 3600.
+
+        Returns:
+            str: The signed JWT.
+        """
+        client = iam_credentials_v1.IAMCredentialsClient()
+        name = f"projects/-/serviceAccounts/{service_account_email}"
+        custom_claims = {"hostname": sandbox_hostname, "port": port}
+        payload = {
+            "iat": int(time.time()),
+            "exp": int(time.time()) + timeout,
+            "iss": service_account_email,
+            "sub": service_account_email,
+            "nonce": secrets.randbelow(1000000000) + 1,
+            "aud": "https://aiplatform.googleapis.com/",  # default audience for sandbox proxy
+            **custom_claims,
+        }
+        request = iam_credentials_v1.SignJwtRequest(
+            name=name,
+            payload=json.dumps(payload),
+        )
+        response = client.sign_jwt(request=request)
+        return response.signed_jwt  # type: ignore[no-any-return]
+
+    def send_command(
+        self,
+        *,
+        http_method: str,
+        access_token: str,
+        sandbox_environment: types.SandboxEnvironment,
+        port: str = "8080",
+        path: Optional[str] = None,
+        query_params: Optional[dict[str, object]] = None,
+        headers: Optional[dict[str, str]] = None,
+        request_dict: Optional[dict[str, object]] = None,
+    ) -> genai_types.HttpResponse:
+        """Sends a command to the sandbox.
+
+        Args:
+            http_method (str):
+                Required. The HTTP method to use for the command.
+            access_token (str):
+                Required. The access token to use for authorization.
+            sandbox_environment (types.SandboxEnvironment):
+                Required. The sandbox environment to send the command to.
+            port (str):
+                Optional. The port to use for the command. Defaults to "8080".
+                This should be one of the ports specified during template creation.
+            path (str):
+                Optional. The path to send the command to.
+            query_params (dict[str, object]):
+                Optional. The query parameters to include in the command.
+            headers (dict[str, str]):
+                Optional. The headers to include in the command.
+            request_dict (dict[str, object]):
+                Optional. The request body to include in the command.
+
+        Returns:
+            genai_types.HttpResponse: The response from the sandbox.
+        """
+        headers = headers or {}
+        request_dict = request_dict or {}
+        connection_info = sandbox_environment.connection_info
+        if not connection_info:
+            raise ValueError("Connection info is not available.")
+        if connection_info.load_balancer_hostname:
+            endpoint = "https://" + connection_info.load_balancer_hostname
+        elif connection_info.load_balancer_ip:
+            endpoint = "http://" + connection_info.load_balancer_ip
+        else:
+            raise ValueError("Load balancer hostname or ip is not available.")
+
+        routing_token = connection_info.routing_token
+        if not routing_token:
+            raise ValueError("Routing token is not available.")
+
+        path = path or ""
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        headers["Authorization"] = f"Bearer {access_token}"
+        headers["X-Sandbox-Routing-Token"] = routing_token
+        headers["X-Sandbox-Port"] = port
+        endpoint = endpoint + path if path.startswith("/") else endpoint + "/" + path
+        http_options = genai_types.HttpOptions(headers=headers, base_url=endpoint)
+        http_client = genai.Client(vertexai=True, http_options=http_options)
+        # Full path is constructed in this function. The passed in path into request
+        # function will not be used.
+ response = http_client._api_client.request(http_method, path, request_dict) + return genai_types.HttpResponse( + headers=response.headers, + body=response.body, + ) + + def generate_browser_ws_headers( + self, + sandbox_environment: types.SandboxEnvironment, + service_account_email: str, + port: str = "8080", + timeout: int = 3600, + ) -> tuple[str, dict[str, str]]: + """Generates the websocket upgrade headers for the browser. + + Args: + sandbox_environment (types.SandboxEnvironment): + Required. The sandbox environment to generate websocket headers for. + service_account_email (str): + Required. The email of the service account to use for signing. + port (str): + Optional. The port to use for the CDP websocket endpoint url fetching. + Defaults to "8080". This should be one of the ports specified during template creation. + timeout (int): + Optional. The timeout in seconds for the token. Defaults to 3600. + Returns: + tuple[str, dict[str, str]]: A tuple containing the websocket URL and + the headers for websocket upgrade. + """ + if not sandbox_environment.connection_info: + raise ValueError("Connection info is not available.") + + connection_info = sandbox_environment.connection_info + if connection_info.load_balancer_hostname: + ws_base_url = "wss://" + connection_info.load_balancer_hostname + elif connection_info.load_balancer_ip: + ws_base_url = "ws://" + connection_info.load_balancer_ip + else: + raise ValueError("Load balancer hostname or ip is not available.") + + http_access_token = self.generate_access_token( + service_account_email, connection_info.load_balancer_hostname, port, timeout + ) + response = self.send_command( + http_method="GET", + access_token=http_access_token, + sandbox_environment=sandbox_environment, + port=port, + path="/cdp_ws_endpoint", + ) + if not response: + raise ValueError("Failed to get the websocket endpoint.") + body_dict = json.loads(response.body) + ws_path = body_dict["endpoint"] + ws_url = ws_base_url + "/" + ws_path + + # port 9222 is the default port for the browser websocket endpoint. + ws_access_token = self.generate_access_token( + service_account_email, + connection_info.load_balancer_hostname, + "9222", + timeout, + ) + + routing_token = connection_info.routing_token + + headers = {} + headers["Sec-WebSocket-Protocol"] = ( + f"v1.stream, {ws_access_token}, {routing_token}, 9222" + ) + return ws_url, headers + + +class AsyncSandboxes(_api_module.BaseModule): + + async def _create( + self, + *, + name: str, + spec: Optional[types.SandboxEnvironmentSpecOrDict] = None, + config: Optional[types.CreateAgentEngineSandboxConfigOrDict] = None, + ) -> types.AgentEngineSandboxOperation: + """ + Creates a new sandbox in the Agent Engine. + """ + + parameter_model = types._CreateAgentEngineSandboxRequestParameters( + name=name, + spec=spec, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sandboxEnvironments".format_map(request_url_dict) + else: + path = "{name}/sandboxEnvironments" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
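+        # Note: `AsyncSandboxes` mirrors the request-building logic of the
+        # synchronous `Sandboxes` methods above; the difference is that HTTP
+        # calls go through `self._api_client.async_request` and are awaited.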
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSandboxConfigOrDict] = None, + ) -> types.DeleteAgentEngineSandboxOperation: + """ + Delete an Agent Engine sandbox. + + Args: + name (str): + Required. The name of the Agent Engine sandbox to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sandboxEnvironments/{sandbox}`. + + """ + + parameter_model = types._DeleteAgentEngineSandboxRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "delete", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.DeleteAgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _execute_code( + self, + *, + name: str, + inputs: Optional[list[types.ChunkOrDict]] = None, + config: Optional[types.ExecuteCodeAgentEngineSandboxConfigOrDict] = None, + ) -> types.ExecuteSandboxEnvironmentResponse: + """ + Execute code in an Agent Engine sandbox. + """ + + parameter_model = types._ExecuteCodeAgentEngineSandboxRequestParameters( + name=name, + inputs=inputs, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ExecuteCodeAgentEngineSandboxRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/:execute".format_map(request_url_dict) + else: + path = "{name}/:execute" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ExecuteSandboxEnvironmentResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get( + self, + *, + name: str, + config: Optional[types.GetAgentEngineSandboxConfigOrDict] = None, + ) -> types.SandboxEnvironment: + """ + Gets an agent engine sandbox. + + Args: + name (str): Required. 
A fully-qualified resource name or ID such as
+            "projects/123/locations/us-central1/reasoningEngines/456/sandboxEnvironments/789"
+            or a shortened name such as "reasoningEngines/456/sandboxEnvironments/789".
+
+        """
+
+        parameter_model = types._GetAgentEngineSandboxRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _GetAgentEngineSandboxRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}".format_map(request_url_dict)
+            else:
+                path = "{name}"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = await self._api_client.async_request(
+            "get", path, request_dict, http_options
+        )
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.SandboxEnvironment._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    async def _list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListAgentEngineSandboxesConfigOrDict] = None,
+    ) -> types.ListAgentEngineSandboxesResponse:
+        """
+        Lists Agent Engine sandboxes.
+
+        Args:
+          name (str): Required. The name of the Agent Engine to list sandboxes for. Format:
+            `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
+          config (ListAgentEngineSandboxesConfig):
+            Optional. Additional configurations for listing the Agent Engine sandboxes.
+
+        Returns:
+          ListAgentEngineSandboxesResponse: The requested Agent Engine sandboxes.
+
+        """
+
+        parameter_model = types._ListAgentEngineSandboxesRequestParameters(
+            name=name,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+            )
+        else:
+            request_dict = _ListAgentEngineSandboxesRequestParameters_to_vertex(
+                parameter_model
+            )
+            request_url_dict = request_dict.get("_url")
+            if request_url_dict:
+                path = "{name}/sandboxEnvironments".format_map(request_url_dict)
+            else:
+                path = "{name}/sandboxEnvironments"
+
+        query_params = request_dict.get("_query")
+        if query_params:
+            path = f"{path}?{urlencode(query_params)}"
+        # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListAgentEngineSandboxesResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_sandbox_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSandboxOperation: + parameter_model = types._GetAgentEngineSandboxOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSandboxOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSandboxOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value diff --git a/agentplatform/_genai/session_events.py b/agentplatform/_genai/session_events.py new file mode 100644 index 0000000000..05b0156777 --- /dev/null +++ b/agentplatform/_genai/session_events.py @@ -0,0 +1,543 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import datetime +import functools +import json +import logging +from typing import Any, Iterator, Optional, Union +from urllib.parse import urlencode + +from google.genai import _api_module +from google.genai import _common +from google.genai._common import get_value_by_path as getv +from google.genai._common import set_value_by_path as setv +from google.genai.pagers import AsyncPager, Pager + +from . import types + +logger = logging.getLogger("agentplatform_genai.sessionevents") + +logger.setLevel(logging.INFO) + + +def _AppendAgentEngineSessionEventConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["content"]) is not None: + setv(parent_object, ["content"], getv(from_object, ["content"])) + + if getv(from_object, ["actions"]) is not None: + setv(parent_object, ["actions"], getv(from_object, ["actions"])) + + if getv(from_object, ["error_code"]) is not None: + setv(parent_object, ["errorCode"], getv(from_object, ["error_code"])) + + if getv(from_object, ["error_message"]) is not None: + setv(parent_object, ["errorMessage"], getv(from_object, ["error_message"])) + + if getv(from_object, ["event_metadata"]) is not None: + setv(parent_object, ["eventMetadata"], getv(from_object, ["event_metadata"])) + + if getv(from_object, ["raw_event"]) is not None: + setv(parent_object, ["rawEvent"], getv(from_object, ["raw_event"])) + + return to_object + + +def _AppendAgentEngineSessionEventRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["author"]) is not None: + setv(to_object, ["author"], getv(from_object, ["author"])) + + if getv(from_object, ["invocation_id"]) is not None: + setv(to_object, ["invocationId"], getv(from_object, ["invocation_id"])) + + if getv(from_object, ["timestamp"]) is not None: + setv(to_object, ["timestamp"], getv(from_object, ["timestamp"])) + + if getv(from_object, ["config"]) is not None: + _AppendAgentEngineSessionEventConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _ListAgentEngineSessionEventsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineSessionEventsRequestParameters_to_vertex( + 
from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineSessionEventsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class SessionEvents(_api_module.BaseModule): + + def append( + self, + *, + name: str, + author: str, + invocation_id: str, + timestamp: datetime.datetime, + config: Optional[types.AppendAgentEngineSessionEventConfigOrDict] = None, + ) -> types.AppendAgentEngineSessionEventResponse: + """ + Appends Agent Engine session event. + + Args: + name (str): Required. The name of the Agent Engine session to append the event to. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + author (str): Required. The author of the Agent Engine session event. + invocation_id (str): Required. The invocation ID of the Agent Engine session event. + timestamp (datetime.datetime): Required. The timestamp of the Agent Engine session event. + config (AppendAgentEngineSessionEventConfig): + Optional. Additional configurations for appending the Agent Engine session event. + + Returns: + AppendAgentEngineSessionEventResponse: The requested Agent Engine session event. + + """ + + parameter_model = types._AppendAgentEngineSessionEventRequestParameters( + name=name, + author=author, + invocation_id=invocation_id, + timestamp=timestamp, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _AppendAgentEngineSessionEventRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:appendEvent".format_map(request_url_dict) + else: + path = "{name}:appendEvent" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
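+        # ":appendEvent" is a custom method on the session resource, so it is
+        # appended to the resource name itself ("{name}:appendEvent") rather
+        # than forming a sub-collection path.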
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AppendAgentEngineSessionEventResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None, + ) -> types.ListAgentEngineSessionEventsResponse: + """ + Lists Agent Engine session events. + + Args: + name (str): Required. The name of the Agent Engine session to list events for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + config (ListAgentEngineSessionEventsConfig): + Optional. Additional configurations for listing the Agent Engine session events. + + Returns: + ListAgentEngineSessionEventsResponse: The requested Agent Engine session events. + + """ + + parameter_model = types._ListAgentEngineSessionEventsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineSessionEventsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/events".format_map(request_url_dict) + else: + path = "{name}/events" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+        request_dict.pop("config", None)
+
+        http_options: Optional[types.HttpOptions] = None
+        if (
+            parameter_model.config is not None
+            and parameter_model.config.http_options is not None
+        ):
+            http_options = parameter_model.config.http_options
+
+        request_dict = _common.convert_to_dict(request_dict)
+        request_dict = _common.encode_unserializable_types(request_dict)
+
+        response = self._api_client.request("get", path, request_dict, http_options)
+
+        response_dict = {} if not response.body else json.loads(response.body)
+
+        return_value = types.ListAgentEngineSessionEventsResponse._from_response(
+            response=response_dict,
+            kwargs=(
+                {
+                    "config": {
+                        "response_schema": getattr(
+                            parameter_model.config, "response_schema", None
+                        ),
+                        "response_json_schema": getattr(
+                            parameter_model.config, "response_json_schema", None
+                        ),
+                        "include_all_fields": getattr(
+                            parameter_model.config, "include_all_fields", None
+                        ),
+                    }
+                }
+                if getattr(parameter_model, "config", None)
+                else {}
+            ),
+        )
+
+        self._api_client._verify_response(return_value)
+        return return_value
+
+    def list(
+        self,
+        *,
+        name: str,
+        config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None,
+    ) -> Iterator[types.SessionEvent]:
+        """Lists Agent Engine session events.
+
+        Args:
+          name (str): Required. The name of the Agent Engine session to list
+            events for.
+          config (ListAgentEngineSessionEventsConfig): Optional. The configuration
+            for the session events to list. Currently, the `filter` field in
+            `config` only supports filtering by `timestamp`. The timestamp
+            value must be enclosed in double quotes and include the time zone
+            information. For example:
+            `config={'filter': 'timestamp>="2025-08-07T19:44:38.4Z"'}`.
+
+        Returns:
+          Iterator[SessionEvent]: An iterator of session events.
+        """
+
+        return Pager(
+            "session_events",
+            functools.partial(self._list, name=name),
+            self._list(name=name, config=config),
+            config,
+        )
+
+
+class AsyncSessionEvents(_api_module.BaseModule):
+
+    async def append(
+        self,
+        *,
+        name: str,
+        author: str,
+        invocation_id: str,
+        timestamp: datetime.datetime,
+        config: Optional[types.AppendAgentEngineSessionEventConfigOrDict] = None,
+    ) -> types.AppendAgentEngineSessionEventResponse:
+        """
+        Appends Agent Engine session event.
+
+        Args:
+          name (str): Required. The name of the Agent Engine session to append the event to. Format:
+            `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`.
+          author (str): Required. The author of the Agent Engine session event.
+          invocation_id (str): Required. The invocation ID of the Agent Engine session event.
+          timestamp (datetime.datetime): Required. The timestamp of the Agent Engine session event.
+          config (AppendAgentEngineSessionEventConfig):
+            Optional. Additional configurations for appending the Agent Engine session event.
+
+        Returns:
+          AppendAgentEngineSessionEventResponse: The requested Agent Engine session event.
+
+        """
+
+        parameter_model = types._AppendAgentEngineSessionEventRequestParameters(
+            name=name,
+            author=author,
+            invocation_id=invocation_id,
+            timestamp=timestamp,
+            config=config,
+        )
+
+        request_url_dict: Optional[dict[str, str]]
+        if not self._api_client.vertexai:
+            raise ValueError(
+                "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ ) + else: + request_dict = _AppendAgentEngineSessionEventRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}:appendEvent".format_map(request_url_dict) + else: + path = "{name}:appendEvent" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AppendAgentEngineSessionEventResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None, + ) -> types.ListAgentEngineSessionEventsResponse: + """ + Lists Agent Engine session events. + + Args: + name (str): Required. The name of the Agent Engine session to list events for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + config (ListAgentEngineSessionEventsConfig): + Optional. Additional configurations for listing the Agent Engine session events. + + Returns: + ListAgentEngineSessionEventsResponse: The requested Agent Engine session events. + + """ + + parameter_model = types._ListAgentEngineSessionEventsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineSessionEventsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/events".format_map(request_url_dict) + else: + path = "{name}/events" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = await self._api_client.async_request(
+ "get", path, request_dict, http_options
+ )
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.ListAgentEngineSessionEventsResponse._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ async def list(
+ self,
+ *,
+ name: str,
+ config: Optional[types.ListAgentEngineSessionEventsConfigOrDict] = None,
+ ) -> AsyncPager[types.SessionEvent]:
+ """Lists Agent Engine session events.
+
+ Args:
+ name (str): Required. The name of the Agent Engine session to list
+ events for.
+ config (ListAgentEngineSessionEventsConfig): Optional. The configuration
+ for the session events to list. Currently, the `filter` field in
+ `config` only supports filtering by `timestamp`. The timestamp
+ value must be enclosed in double quotes and include the time zone
+ information. For example:
+ `config={'filter': 'timestamp>="2025-08-07T19:44:38.4Z"'}`.
+
+ Returns:
+ AsyncPager[SessionEvent]: An async pager of session events.
+ """
+
+ return AsyncPager(
+ "session_events",
+ functools.partial(self._list, name=name),
+ await self._list(name=name, config=config),
+ config,
+ )
diff --git a/agentplatform/_genai/sessions.py b/agentplatform/_genai/sessions.py
new file mode 100644
index 0000000000..59de5bd9b3
--- /dev/null
+++ b/agentplatform/_genai/sessions.py
@@ -0,0 +1,1395 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+import functools
+import importlib
+import json
+import logging
+import typing
+from typing import Any, Iterator, Optional, Union
+from urllib.parse import urlencode
+
+from google.genai import _api_module
+from google.genai import _common
+from google.genai._common import get_value_by_path as getv
+from google.genai._common import set_value_by_path as setv
+from google.genai.pagers import AsyncPager, Pager
+
+from . import _agent_engines_utils
+from . import types
+
+if typing.TYPE_CHECKING:
+ from . 
import session_events as session_events_module + + _ = session_events_module + + +logger = logging.getLogger("agentplatform_genai.sessions") + +logger.setLevel(logging.INFO) + + +def _CreateAgentEngineSessionConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["session_state"]) is not None: + setv(parent_object, ["sessionState"], getv(from_object, ["session_state"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + if getv(from_object, ["expire_time"]) is not None: + setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["session_id"]) is not None: + setv(parent_object, ["_query", "sessionId"], getv(from_object, ["session_id"])) + + return to_object + + +def _CreateAgentEngineSessionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["user_id"]) is not None: + setv(to_object, ["userId"], getv(from_object, ["user_id"])) + + if getv(from_object, ["config"]) is not None: + _CreateAgentEngineSessionConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _DeleteAgentEngineSessionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _GetAgentEngineSessionOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetAgentEngineSessionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + return to_object + + +def _ListAgentEngineSessionsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["page_size"]) is not None: + setv(parent_object, ["_query", "pageSize"], getv(from_object, ["page_size"])) + + if getv(from_object, ["page_token"]) is not None: + setv(parent_object, ["_query", "pageToken"], getv(from_object, ["page_token"])) + + if getv(from_object, ["filter"]) is not None: + setv(parent_object, ["_query", "filter"], getv(from_object, ["filter"])) + + return to_object + + +def _ListAgentEngineSessionsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + 
parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _ListAgentEngineSessionsConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +def _UpdateAgentEngineSessionConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["display_name"]) is not None: + setv(parent_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["session_state"]) is not None: + setv(parent_object, ["sessionState"], getv(from_object, ["session_state"])) + + if getv(from_object, ["ttl"]) is not None: + setv(parent_object, ["ttl"], getv(from_object, ["ttl"])) + + if getv(from_object, ["expire_time"]) is not None: + setv(parent_object, ["expireTime"], getv(from_object, ["expire_time"])) + + if getv(from_object, ["labels"]) is not None: + setv(parent_object, ["labels"], getv(from_object, ["labels"])) + + if getv(from_object, ["session_id"]) is not None: + setv(parent_object, ["_query", "sessionId"], getv(from_object, ["session_id"])) + + if getv(from_object, ["update_mask"]) is not None: + setv( + parent_object, ["_query", "updateMask"], getv(from_object, ["update_mask"]) + ) + + if getv(from_object, ["user_id"]) is not None: + setv(parent_object, ["userId"], getv(from_object, ["user_id"])) + + return to_object + + +def _UpdateAgentEngineSessionRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + _UpdateAgentEngineSessionConfig_to_vertex( + getv(from_object, ["config"]), to_object + ) + + return to_object + + +class Sessions(_api_module.BaseModule): + + def _create( + self, + *, + name: str, + user_id: str, + config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + """ + Creates a new session in the Agent Engine. + + Args: + name (str): Required. The name of the Agent Engine to create the session under. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + user_id (str): Required. The user ID of the session. + config (CreateAgentEngineSessionConfig): + Optional. Additional configurations for creating the Agent Engine session. + + Returns: + AgentEngineSessionOperation: The operation for creating the Agent Engine session. + + """ + + parameter_model = types._CreateAgentEngineSessionRequestParameters( + name=name, + user_id=user_id, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _CreateAgentEngineSessionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sessions".format_map(request_url_dict) + else: + path = "{name}/sessions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSessionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSessionConfigOrDict] = None, + ) -> types.DeleteAgentEngineSessionOperation: + """ + Delete an Agent Engine session. + + Args: + name (str): Required. The name of the Agent Engine session to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + config (DeleteAgentEngineSessionConfig): + Optional. Additional configurations for deleting the Agent Engine session. + + Returns: + DeleteAgentEngineSessionOperation: The operation for deleting the Agent Engine session. + + """ + + parameter_model = types._DeleteAgentEngineSessionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineSessionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = self._api_client.request("delete", path, request_dict, http_options)
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.DeleteAgentEngineSessionOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ def get(
+ self,
+ *,
+ name: str,
+ config: Optional[types.GetAgentEngineSessionConfigOrDict] = None,
+ ) -> types.Session:
+ """
+ Gets an Agent Engine session.
+
+ Args:
+ name (str): Required. The name of the Agent Engine session to get. Format:
+ `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`.
+ config (GetAgentEngineSessionConfig):
+ Optional. Additional configurations for getting the Agent Engine session.
+
+ Returns:
+ Session: The requested Agent Engine session.
+
+ """
+
+ parameter_model = types._GetAgentEngineSessionRequestParameters(
+ name=name,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError(
+ "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ )
+ else:
+ request_dict = _GetAgentEngineSessionRequestParameters_to_vertex(
+ parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = "{name}".format_map(request_url_dict)
+ else:
+ path = "{name}"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Session._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None, + ) -> types.ListReasoningEnginesSessionsResponse: + """ + Lists Agent Engine sessions. + + Args: + name (str): Required. The name of the Agent Engine to list sessions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineSessionsConfig): + Optional. Additional configurations for listing the Agent Engine sessions. + + Returns: + ListReasoningEnginesSessionsResponse: The requested Agent Engine sessions. + + """ + + parameter_model = types._ListAgentEngineSessionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineSessionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sessions".format_map(request_url_dict) + else: + path = "{name}/sessions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
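+ # Note: `_list` fetches a single page; the public `list()` wraps it in a
+ # `Pager` that requests further pages via the page token in `config`.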
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListReasoningEnginesSessionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_session_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + parameter_model = types._GetAgentEngineSessionOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSessionOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSessionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _update( + self, + *, + name: str, + config: Optional[types.UpdateAgentEngineSessionConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + """ + Updates an Agent Engine session. + + Args: + name (str): Required. The name of the Agent Engine session to be updated. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + config (UpdateAgentEngineSessionConfig): + Optional. 
Additional configurations for updating the Agent Engine session.
+
+ Returns:
+ AgentEngineSessionOperation: The operation for updating the Agent Engine session.
+
+ """
+
+ parameter_model = types._UpdateAgentEngineSessionRequestParameters(
+ name=name,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError(
+ "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ )
+ else:
+ request_dict = _UpdateAgentEngineSessionRequestParameters_to_vertex(
+ parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = "{name}".format_map(request_url_dict)
+ else:
+ path = "{name}"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = self._api_client.request("patch", path, request_dict, http_options)
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.AgentEngineSessionOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ _events = None
+
+ @property
+ def events(self) -> "session_events_module.SessionEvents":
+ if self._events is None:
+ try:
+ # We need to lazy load the sessions.events module to handle the
+ # possibility of ImportError when dependencies are not installed.
+ self._events = importlib.import_module(".session_events", __package__)
+ except ImportError as e:
+ raise ImportError(
+ "The 'agent_engines.sessions.events' module requires "
+ "additional packages. Please install them using pip install "
+ "google-cloud-aiplatform[agent_engines]"
+ ) from e
+ return self._events.SessionEvents(self._api_client) # type: ignore[no-any-return]
+
+ def create(
+ self,
+ *,
+ name: str,
+ user_id: str,
+ config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None,
+ ) -> types.AgentEngineSessionOperation:
+ """Creates a new session in the Agent Engine.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to create the session for.
+ user_id (str):
+ Required. The user ID of the session.
+ config (CreateAgentEngineSessionConfig):
+ Optional. The configuration for the session to create.
+
+ Returns:
+ AgentEngineSessionOperation: The operation for creating the session.
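+
+ Example (illustrative; the exact client attribute path is an
+ assumption, shown here as `client.agent_engines.sessions`):
+
+ operation = client.agent_engines.sessions.create(
+ name="projects/{project}/locations/{location}/reasoningEngines/{resource_id}",
+ user_id="user-123",
+ config={"display_name": "my-session", "wait_for_completion": True},
+ )
+ session = operation.response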
+ """
+ if config is None:
+ config = types.CreateAgentEngineSessionConfig()
+ elif isinstance(config, dict):
+ config = types.CreateAgentEngineSessionConfig.model_validate(config)
+ operation = self._create(
+ name=name,
+ user_id=user_id,
+ config=config,
+ )
+ if config.wait_for_completion:
+ if not operation.done:
+ operation = _agent_engines_utils._await_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_session_operation,
+ poll_interval_seconds=0.5,
+ )
+ # We need to make a call to get the session because the operation
+ # response might not contain the relevant fields.
+ if operation.response:
+ operation.response = self.get(name=operation.response.name)
+ elif operation.error:
+ raise RuntimeError(f"Failed to create session: {operation.error}")
+ else:
+ raise RuntimeError(
+ "Error retrieving session from the operation response. "
+ f"Operation name: {operation.name}"
+ )
+ return operation
+
+ def list(
+ self,
+ *,
+ name: str,
+ config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None,
+ ) -> Iterator[types.Session]:
+ """Lists Agent Engine sessions.
+
+ Args:
+ name (str): Required. The name of the agent engine to list sessions
+ for.
+ config (ListAgentEngineSessionsConfig): Optional. The configuration
+ for the sessions to list.
+
+ Returns:
+ Iterator[Session]: An iterator over sessions.
+ """
+
+ return Pager(
+ "sessions",
+ functools.partial(self._list, name=name),
+ self._list(name=name, config=config),
+ config,
+ )
+
+
+class AsyncSessions(_api_module.BaseModule):
+
+ async def _create(
+ self,
+ *,
+ name: str,
+ user_id: str,
+ config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None,
+ ) -> types.AgentEngineSessionOperation:
+ """
+ Creates a new session in the Agent Engine.
+
+ Args:
+ name (str): Required. The name of the Agent Engine to create the session under. Format:
+ `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`.
+ user_id (str): Required. The user ID of the session.
+ config (CreateAgentEngineSessionConfig):
+ Optional. Additional configurations for creating the Agent Engine session.
+
+ Returns:
+ AgentEngineSessionOperation: The operation for creating the Agent Engine session.
+
+ """
+
+ parameter_model = types._CreateAgentEngineSessionRequestParameters(
+ name=name,
+ user_id=user_id,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError(
+ "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ )
+ else:
+ request_dict = _CreateAgentEngineSessionRequestParameters_to_vertex(
+ parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = "{name}/sessions".format_map(request_url_dict)
+ else:
+ path = "{name}/sessions"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSessionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def delete( + self, + *, + name: str, + config: Optional[types.DeleteAgentEngineSessionConfigOrDict] = None, + ) -> types.DeleteAgentEngineSessionOperation: + """ + Delete an Agent Engine session. + + Args: + name (str): Required. The name of the Agent Engine session to be deleted. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. + config (DeleteAgentEngineSessionConfig): + Optional. Additional configurations for deleting the Agent Engine session. + + Returns: + DeleteAgentEngineSessionOperation: The operation for deleting the Agent Engine session. + + """ + + parameter_model = types._DeleteAgentEngineSessionRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _DeleteAgentEngineSessionRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = await self._api_client.async_request(
+ "delete", path, request_dict, http_options
+ )
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.DeleteAgentEngineSessionOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ async def get(
+ self,
+ *,
+ name: str,
+ config: Optional[types.GetAgentEngineSessionConfigOrDict] = None,
+ ) -> types.Session:
+ """
+ Gets an Agent Engine session.
+
+ Args:
+ name (str): Required. The name of the Agent Engine session to get. Format:
+ `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`.
+ config (GetAgentEngineSessionConfig):
+ Optional. Additional configurations for getting the Agent Engine session.
+
+ Returns:
+ Session: The requested Agent Engine session.
+
+ """
+
+ parameter_model = types._GetAgentEngineSessionRequestParameters(
+ name=name,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError(
+ "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ )
+ else:
+ request_dict = _GetAgentEngineSessionRequestParameters_to_vertex(
+ parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = "{name}".format_map(request_url_dict)
+ else:
+ path = "{name}"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Session._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _list( + self, + *, + name: str, + config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None, + ) -> types.ListReasoningEnginesSessionsResponse: + """ + Lists Agent Engine sessions. + + Args: + name (str): Required. The name of the Agent Engine to list sessions for. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}`. + config (ListAgentEngineSessionsConfig): + Optional. Additional configurations for listing the Agent Engine sessions. + + Returns: + ListReasoningEnginesSessionsResponse: The requested Agent Engine sessions. + + """ + + parameter_model = types._ListAgentEngineSessionsRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _ListAgentEngineSessionsRequestParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/sessions".format_map(request_url_dict) + else: + path = "{name}/sessions" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.ListReasoningEnginesSessionsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_session_operation( + self, + *, + operation_name: str, + config: Optional[types.GetAgentEngineOperationConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + parameter_model = types._GetAgentEngineSessionOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetAgentEngineSessionOperationParameters_to_vertex( + parameter_model + ) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.AgentEngineSessionOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _update( + self, + *, + name: str, + config: Optional[types.UpdateAgentEngineSessionConfigOrDict] = None, + ) -> types.AgentEngineSessionOperation: + """ + Updates an Agent Engine session. + + Args: + name (str): Required. The name of the Agent Engine session to be updated. Format: + `projects/{project}/locations/{location}/reasoningEngines/{resource_id}/sessions/{session_id}`. 
+ config (UpdateAgentEngineSessionConfig):
+ Optional. Additional configurations for updating the Agent Engine session.
+
+ Returns:
+ AgentEngineSessionOperation: The operation for updating the Agent Engine session.
+
+ """
+
+ parameter_model = types._UpdateAgentEngineSessionRequestParameters(
+ name=name,
+ config=config,
+ )
+
+ request_url_dict: Optional[dict[str, str]]
+ if not self._api_client.vertexai:
+ raise ValueError(
+ "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client."
+ )
+ else:
+ request_dict = _UpdateAgentEngineSessionRequestParameters_to_vertex(
+ parameter_model
+ )
+ request_url_dict = request_dict.get("_url")
+ if request_url_dict:
+ path = "{name}".format_map(request_url_dict)
+ else:
+ path = "{name}"
+
+ query_params = request_dict.get("_query")
+ if query_params:
+ path = f"{path}?{urlencode(query_params)}"
+ # TODO: remove the hack that pops config.
+ request_dict.pop("config", None)
+
+ http_options: Optional[types.HttpOptions] = None
+ if (
+ parameter_model.config is not None
+ and parameter_model.config.http_options is not None
+ ):
+ http_options = parameter_model.config.http_options
+
+ request_dict = _common.convert_to_dict(request_dict)
+ request_dict = _common.encode_unserializable_types(request_dict)
+
+ response = await self._api_client.async_request(
+ "patch", path, request_dict, http_options
+ )
+
+ response_dict = {} if not response.body else json.loads(response.body)
+
+ return_value = types.AgentEngineSessionOperation._from_response(
+ response=response_dict,
+ kwargs=(
+ {
+ "config": {
+ "response_schema": getattr(
+ parameter_model.config, "response_schema", None
+ ),
+ "response_json_schema": getattr(
+ parameter_model.config, "response_json_schema", None
+ ),
+ "include_all_fields": getattr(
+ parameter_model.config, "include_all_fields", None
+ ),
+ }
+ }
+ if getattr(parameter_model, "config", None)
+ else {}
+ ),
+ )
+
+ self._api_client._verify_response(return_value)
+ return return_value
+
+ _events = None
+
+ @property
+ def events(self) -> "session_events_module.AsyncSessionEvents":
+ if self._events is None:
+ try:
+ # We need to lazy load the sessions.events module to handle the
+ # possibility of ImportError when dependencies are not installed.
+ self._events = importlib.import_module(".session_events", __package__)
+ except ImportError as e:
+ raise ImportError(
+ "The 'agent_engines.sessions.events' module requires "
+ "additional packages. Please install them using pip install "
+ "google-cloud-aiplatform[agent_engines]"
+ ) from e
+ return self._events.AsyncSessionEvents(self._api_client) # type: ignore[no-any-return]
+
+ async def create(
+ self,
+ *,
+ name: str,
+ user_id: str,
+ config: Optional[types.CreateAgentEngineSessionConfigOrDict] = None,
+ ) -> types.AgentEngineSessionOperation:
+ """Creates a new session in the Agent Engine.
+
+ Args:
+ name (str):
+ Required. The name of the agent engine to create the session for.
+ user_id (str):
+ Required. The user ID of the session.
+ config (CreateAgentEngineSessionConfig):
+ Optional. The configuration for the session to create.
+
+ Returns:
+ AgentEngineSessionOperation: The operation for creating the session.
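+
+ Example (illustrative; the exact async client attribute path is an
+ assumption, shown here as `client.agent_engines.sessions` on an
+ async client):
+
+ operation = await client.agent_engines.sessions.create(
+ name="projects/{project}/locations/{location}/reasoningEngines/{resource_id}",
+ user_id="user-123",
+ config={"display_name": "my-session", "wait_for_completion": True},
+ )
+ session = operation.response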
+ """
+ if config is None:
+ config = types.CreateAgentEngineSessionConfig()
+ elif isinstance(config, dict):
+ config = types.CreateAgentEngineSessionConfig.model_validate(config)
+ operation = await self._create(
+ name=name,
+ user_id=user_id,
+ config=config,
+ )
+ if config.wait_for_completion:
+ if not operation.done:
+ operation = await _agent_engines_utils._await_async_operation(
+ operation_name=operation.name,
+ get_operation_fn=self._get_session_operation,
+ poll_interval_seconds=0.5,
+ )
+ # We need to make a call to get the session because the operation
+ # response might not contain the relevant fields.
+ if operation.response:
+ operation.response = await self.get(name=operation.response.name)
+ elif operation.error:
+ raise RuntimeError(f"Failed to create session: {operation.error}")
+ else:
+ raise RuntimeError(
+ "Error retrieving session from the operation response. "
+ f"Operation name: {operation.name}"
+ )
+ return operation
+
+ async def list(
+ self,
+ *,
+ name: str,
+ config: Optional[types.ListAgentEngineSessionsConfigOrDict] = None,
+ ) -> AsyncPager[types.Session]:
+ """Lists Agent Engine sessions.
+
+ Args:
+ name (str): Required. The name of the agent engine to list sessions
+ for.
+ config (ListAgentEngineSessionsConfig): Optional. The configuration
+ for the sessions to list.
+
+ Returns:
+ AsyncPager[Session]: An async pager of sessions.
+ """
+
+ return AsyncPager(
+ "sessions",
+ functools.partial(self._list, name=name),
+ await self._list(name=name, config=config),
+ config,
+ )
diff --git a/agentplatform/_genai/skills.py b/agentplatform/_genai/skills.py
new file mode 100644
index 0000000000..835fbed2d6
--- /dev/null
+++ b/agentplatform/_genai/skills.py
@@ -0,0 +1,869 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Code generated by the Google Gen AI SDK generator DO NOT EDIT.
+
+import asyncio
+import base64
+import json
+import logging
+from typing import Any, Optional, Union
+from urllib.parse import urlencode
+
+from google.genai import _api_module
+from google.genai import _common
+from google.genai._common import get_value_by_path as getv
+from google.genai._common import set_value_by_path as setv
+
+from . import _skills_utils
+from . 
import types + +logger = logging.getLogger("agentplatform_genai.skills") + + +def _CreateSkillConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["zipped_filesystem"]) is not None: + setv( + parent_object, + ["zippedFilesystem"], + getv(from_object, ["zipped_filesystem"]), + ) + + if getv(from_object, ["skill_id"]) is not None: + setv(parent_object, ["_query", "skillId"], getv(from_object, ["skill_id"])) + + return to_object + + +def _CreateSkillRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["display_name"]) is not None: + setv(to_object, ["displayName"], getv(from_object, ["display_name"])) + + if getv(from_object, ["description"]) is not None: + setv(to_object, ["description"], getv(from_object, ["description"])) + + if getv(from_object, ["config"]) is not None: + _CreateSkillConfig_to_vertex(getv(from_object, ["config"]), to_object) + + return to_object + + +def _GetSkillOperationParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["operation_name"]) is not None: + setv( + to_object, ["_url", "operationName"], getv(from_object, ["operation_name"]) + ) + + return to_object + + +def _GetSkillRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["config"]) is not None: + setv(to_object, ["config"], getv(from_object, ["config"])) + + return to_object + + +def _RetrieveSkillsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["top_k"]) is not None: + setv(parent_object, ["_query", "topK"], getv(from_object, ["top_k"])) + + return to_object + + +def _RetrieveSkillsRequestParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["query"]) is not None: + setv(to_object, ["_query", "query"], getv(from_object, ["query"])) + + if getv(from_object, ["config"]) is not None: + setv( + to_object, + ["config"], + _RetrieveSkillsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) + + return to_object + + +class Skills(_api_module.BaseModule): + """Class for managing Skills in the Skill Registry.""" + + def get( + self, *, name: str, config: Optional[types.GetSkillConfigOrDict] = None + ) -> types.Skill: + """ + Gets a Skill. + """ + + parameter_model = types._GetSkillRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GetSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Skill._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def retrieve( + self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None + ) -> types.RetrieveSkillsResponse: + """ + Retrieves skills semantically matched to a query. + """ + + parameter_model = types._RetrieveSkillsRequestParameters( + query=query, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills:retrieve".format_map(request_url_dict) + else: + path = "skills:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
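+ # `query` and the optional `top_k` were placed under `_query` by the
+ # converters above, so they travel in the URL query string built here
+ # rather than in the request body.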
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _create( + self, + *, + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> types.SkillOperation: + """ + Creates a new Skill. + """ + + parameter_model = types._CreateSkillRequestParameters( + display_name=display_name, + description=description, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills".format_map(request_url_dict) + else: + path = "skills" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def _get_skill_operation( + self, + *, + operation_name: str, + config: Optional[types.GetSkillOperationConfigOrDict] = None, + ) -> types.SkillOperation: + parameter_model = types._GetSkillOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("get", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + def create( + self, + *, + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> Union[types.Skill, types.SkillOperation]: + """Creates a new Skill. + + Args: + display_name (str): + Required. The display name of the Skill. + description (str): + Required. The description of the Skill. + config (CreateSkillConfigOrDict): + Optional. The configuration for creating the Skill. + + Returns: + Skill: The created Skill if wait_for_completion is True. + SkillOperation: The operation for creating the Skill if + wait_for_completion is False. + """ + if config is None: + config = types.CreateSkillConfig() + elif isinstance(config, dict): + config = types.CreateSkillConfig.model_validate(config) + elif not isinstance(config, types.CreateSkillConfig): + raise TypeError( + f"config must be a dict or CreateSkillConfig, but got {type(config)}." + ) + + config = config.model_copy() + + local_path = config.local_path + zipped_filesystem = config.zipped_filesystem + + if local_path and zipped_filesystem: + raise ValueError( + "Only one of `local_path` or `zipped_filesystem` can be provided in config." + ) + if not local_path and not zipped_filesystem: + raise ValueError( + "Either `local_path` or `zipped_filesystem` must be provided in config." + ) + + if local_path: + zipped_filesystem_payload = _skills_utils.get_zipped_filesystem_payload( + local_path + ) + else: + # Narrow type for mypy + if zipped_filesystem is None: + raise ValueError( + "zipped_filesystem is required if local_path is not provided." 
+ ) + if isinstance(zipped_filesystem, bytes): + zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( + "utf-8" + ) + else: + zipped_filesystem_payload = zipped_filesystem + + # Mutate the config object to populate the zipped_filesystem payload + config.zipped_filesystem = zipped_filesystem_payload + + operation = self._create( + display_name=display_name, + description=description, + config=config, + ) + + if config.wait_for_completion: + operation = _skills_utils.await_operation( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to create Skill: {operation.error}") + # Fetch the fully populated Skill resource from the server + return self.get(name=operation.response.name) + + return operation + + +class AsyncSkills(_api_module.BaseModule): + """Class for managing Skills in the Skill Registry.""" + + async def get( + self, *, name: str, config: Optional[types.GetSkillConfigOrDict] = None + ) -> types.Skill: + """ + Gets a Skill. + """ + + parameter_model = types._GetSkillRequestParameters( + name=name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}".format_map(request_url_dict) + else: + path = "{name}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.Skill._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def retrieve( + self, *, query: str, config: Optional[types.RetrieveSkillsConfigOrDict] = None + ) -> types.RetrieveSkillsResponse: + """ + Retrieves skills semantically matched to a query. + """ + + parameter_model = types._RetrieveSkillsRequestParameters( + query=query, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." 
+ ) + else: + request_dict = _RetrieveSkillsRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills:retrieve".format_map(request_url_dict) + else: + path = "skills:retrieve" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.RetrieveSkillsResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _create( + self, + *, + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> types.SkillOperation: + """ + Creates a new Skill. + """ + + parameter_model = types._CreateSkillRequestParameters( + display_name=display_name, + description=description, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _CreateSkillRequestParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "skills".format_map(request_url_dict) + else: + path = "skills" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. 
+ request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def _get_skill_operation( + self, + *, + operation_name: str, + config: Optional[types.GetSkillOperationConfigOrDict] = None, + ) -> types.SkillOperation: + parameter_model = types._GetSkillOperationParameters( + operation_name=operation_name, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in the Gemini Enterprise Agent Platform (previously known as Vertex AI) client." + ) + else: + request_dict = _GetSkillOperationParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{operationName}".format_map(request_url_dict) + else: + path = "{operationName}" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "get", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.SkillOperation._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + + async def create( + self, + *, + display_name: str, + description: str, + config: Optional[types.CreateSkillConfigOrDict] = None, + ) -> Union[types.Skill, types.SkillOperation]: + """Creates a new Skill asynchronously. + + Args: + display_name (str): + Required. The display name of the Skill. + description (str): + Required. The description of the Skill. + config (CreateSkillConfigOrDict): + Optional. The configuration for creating the Skill. 
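+                Exactly one of `local_path` or `zipped_filesystem` must be
+                set on the config. Raw `bytes` payloads are base64-encoded
+                before upload, and `wait_for_completion` controls whether
+                this call polls the create operation and returns the
+                finished Skill instead of the operation.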
+ + Returns: + Skill: The created Skill if wait_for_completion is True. + SkillOperation: The operation for creating the Skill if + wait_for_completion is False. + """ + if config is None: + config = types.CreateSkillConfig() + elif isinstance(config, dict): + config = types.CreateSkillConfig.model_validate(config) + elif not isinstance(config, types.CreateSkillConfig): + raise TypeError( + f"config must be a dict or CreateSkillConfig, but got {type(config)}." + ) + + config = config.model_copy() + + local_path = config.local_path + zipped_filesystem = config.zipped_filesystem + + if local_path and zipped_filesystem: + raise ValueError( + "Only one of `local_path` or `zipped_filesystem` can be provided in config." + ) + if not local_path and not zipped_filesystem: + raise ValueError( + "Either `local_path` or `zipped_filesystem` must be provided in config." + ) + + if local_path: + loop = asyncio.get_running_loop() + zipped_filesystem_payload = await loop.run_in_executor( + None, _skills_utils.get_zipped_filesystem_payload, local_path + ) + else: + # Narrow type for mypy + if zipped_filesystem is None: + raise ValueError( + "zipped_filesystem is required if local_path is not provided." + ) + if isinstance(zipped_filesystem, bytes): + zipped_filesystem_payload = base64.b64encode(zipped_filesystem).decode( + "utf-8" + ) + else: + zipped_filesystem_payload = zipped_filesystem + + # Mutate the config object to populate the zipped_filesystem payload + config.zipped_filesystem = zipped_filesystem_payload + + operation = await self._create( + display_name=display_name, + description=description, + config=config, + ) + + if config.wait_for_completion: + operation = await _skills_utils.await_operation_async( + operation_name=operation.name, + get_operation_fn=self._get_skill_operation, + ) + if operation.error: + raise RuntimeError(f"Failed to create Skill: {operation.error}") + # Fetch the fully populated Skill resource asynchronously + return await self.get(name=operation.response.name) + + return operation diff --git a/agentplatform/_genai/types/__init__.py b/agentplatform/_genai/types/__init__.py new file mode 100644 index 0000000000..2e4acdba09 --- /dev/null +++ b/agentplatform/_genai/types/__init__.py @@ -0,0 +1,2789 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. +# flake8: noqa: F401 + +import importlib +import typing + +from . import agent_engines +from . import evals +from . 
import prompts
+from .common import _AppendAgentEngineSessionEventRequestParameters
+from .common import _AppendAgentEngineTaskEventRequestParameters
+from .common import _AssembleDatasetParameters
+from .common import _AssessDatasetParameters
+from .common import _CancelQueryJobAgentEngineRequestParameters
+from .common import _CheckQueryJobAgentEngineRequestParameters
+from .common import _CreateAgentEngineMemoryRequestParameters
+from .common import _CreateAgentEngineRequestParameters
+from .common import _CreateAgentEngineSandboxRequestParameters
+from .common import _CreateAgentEngineSessionRequestParameters
+from .common import _CreateAgentEngineTaskRequestParameters
+from .common import _CreateDatasetParameters
+from .common import _CreateDatasetVersionParameters
+from .common import _CreateEvaluationItemParameters
+from .common import _CreateEvaluationMetricParameters
+from .common import _CreateEvaluationRunParameters
+from .common import _CreateEvaluationSetParameters
+from .common import _CreateMultimodalDatasetParameters
+from .common import _CreateSandboxEnvironmentSnapshotRequestParameters
+from .common import _CreateSandboxEnvironmentTemplateRequestParameters
+from .common import _CreateSkillRequestParameters
+from .common import _CustomJobParameters
+from .common import _DeleteAgentEngineMemoryRequestParameters
+from .common import _DeleteAgentEngineRequestParameters
+from .common import _DeleteAgentEngineRuntimeRevisionRequestParameters
+from .common import _DeleteAgentEngineSandboxRequestParameters
+from .common import _DeleteAgentEngineSessionRequestParameters
+from .common import _DeleteAgentEngineTaskRequestParameters
+from .common import _DeleteDatasetRequestParameters
+from .common import _DeleteEvaluationMetricParameters
+from .common import _DeleteMultimodalDatasetRequestParameters
+from .common import _DeletePromptVersionRequestParameters
+from .common import _DeleteSandboxEnvironmentSnapshotRequestParameters
+from .common import _DeleteSandboxEnvironmentTemplateRequestParameters
+from .common import _EvaluateInstancesRequestParameters
+from .common import _ExecuteCodeAgentEngineSandboxRequestParameters
+from .common import _GenerateAgentEngineMemoriesRequestParameters
+from .common import _GenerateInstanceRubricsRequest
+from .common import _GenerateLossClustersParameters
+from .common import _GenerateUserScenariosParameters
+from .common import _GetAgentEngineGenerateMemoriesOperationParameters
+from .common import _GetAgentEngineMemoryOperationParameters
+from .common import _GetAgentEngineMemoryRequestParameters
+from .common import _GetAgentEngineMemoryRevisionRequestParameters
+from .common import _GetAgentEngineOperationParameters
+from .common import _GetAgentEngineRequestParameters
+from .common import _GetAgentEngineRuntimeRevisionRequestParameters
+from .common import _GetAgentEngineSandboxOperationParameters
+from .common import _GetAgentEngineSandboxRequestParameters
+from .common import _GetAgentEngineSandboxSnapshotOperationParameters
+from .common import _GetAgentEngineSessionOperationParameters
+from .common import _GetAgentEngineSessionRequestParameters
+from .common import _GetAgentEngineTaskRequestParameters
+from .common import _GetCustomJobParameters
+from .common import _GetDatasetOperationParameters
+from .common import _GetDatasetParameters
+from .common import _GetDatasetVersionParameters
+from .common import _GetDeleteAgentEngineRuntimeRevisionOperationParameters
+from .common import _GetEvaluationItemParameters
+from .common import _GetEvaluationMetricParameters
+from .common import _GetEvaluationRunParameters
+from .common import _GetEvaluationSetParameters
+from .common import _GetMultimodalDatasetOperationParameters
+from .common import _GetMultimodalDatasetParameters
+from .common import _GetSandboxEnvironmentSnapshotRequestParameters
+from .common import _GetSandboxEnvironmentTemplateOperationParameters
+from .common import _GetSandboxEnvironmentTemplateRequestParameters
+from .common import _GetSkillOperationParameters
+from .common import _GetSkillRequestParameters
+from .common import _IngestEventsRequestParameters
+from .common import _ListAgentEngineMemoryRequestParameters
+from .common import _ListAgentEngineMemoryRevisionsRequestParameters
+from .common import _ListAgentEngineRequestParameters
+from .common import _ListAgentEngineRuntimeRevisionsRequestParameters
+from .common import _ListAgentEngineSandboxesRequestParameters
+from .common import _ListAgentEngineSessionEventsRequestParameters
+from .common import _ListAgentEngineSessionsRequestParameters
+from .common import _ListAgentEngineTaskEventsRequestParameters
+from .common import _ListAgentEngineTasksRequestParameters
+from .common import _ListDatasetsRequestParameters
+from .common import _ListDatasetVersionsRequestParameters
+from .common import _ListEvaluationMetricsParameters
+from .common import _ListMultimodalDatasetsRequestParameters
+from .common import _ListSandboxEnvironmentSnapshotsRequestParameters
+from .common import _ListSandboxEnvironmentTemplatesRequestParameters
+from .common import _OptimizeRequestParameters
+from .common import _PurgeAgentEngineMemoriesRequestParameters
+from .common import _QueryAgentEngineRequestParameters
+from .common import _QueryAgentEngineRuntimeRevisionRequestParameters
+from .common import _RestoreVersionRequestParameters
+from .common import _RetrieveAgentEngineMemoriesRequestParameters
+from .common import _RetrieveMemoryProfilesRequestParameters
+from .common import _RetrieveSkillsRequestParameters
+from .common import _RollbackAgentEngineMemoryRequestParameters
+from .common import _RunQueryJobAgentEngineConfig
+from .common import _RunQueryJobAgentEngineConfigDict
+from .common import _RunQueryJobAgentEngineConfigOrDict
+from .common import _RunQueryJobAgentEngineRequestParameters
+from .common import _UpdateAgentEngineMemoryRequestParameters
+from .common import _UpdateAgentEngineRequestParameters
+from .common import _UpdateAgentEngineSessionRequestParameters
+from .common import _UpdateDatasetParameters
+from .common import _UpdateMultimodalDatasetParameters
+from .common import A2aTask
+from .common import A2aTaskDict
+from .common import A2aTaskOrDict
+from .common import A2aTaskState
+from .common import AcceleratorType
+from .common import AgentEngine
+from .common import AgentEngineConfig
+from .common import AgentEngineConfigDict
+from .common import AgentEngineConfigOrDict
+from .common import AgentEngineDict
+from .common import AgentEngineGenerateMemoriesOperation
+from .common import AgentEngineGenerateMemoriesOperationDict
+from .common import AgentEngineGenerateMemoriesOperationOrDict
+from .common import AgentEngineMemoryConfig
+from .common import AgentEngineMemoryConfigDict
+from .common import AgentEngineMemoryConfigOrDict
+from .common import AgentEngineMemoryOperation
+from .common import
AgentEngineMemoryOperationDict +from .common import AgentEngineMemoryOperationOrDict +from .common import AgentEngineOperation +from .common import AgentEngineOperationDict +from .common import AgentEngineOperationOrDict +from .common import AgentEngineOrDict +from .common import AgentEnginePurgeMemoriesOperation +from .common import AgentEnginePurgeMemoriesOperationDict +from .common import AgentEnginePurgeMemoriesOperationOrDict +from .common import AgentEngineRollbackMemoryOperation +from .common import AgentEngineRollbackMemoryOperationDict +from .common import AgentEngineRollbackMemoryOperationOrDict +from .common import AgentEngineRuntimeRevision +from .common import AgentEngineRuntimeRevisionDict +from .common import AgentEngineRuntimeRevisionOrDict +from .common import AgentEngineSandboxOperation +from .common import AgentEngineSandboxOperationDict +from .common import AgentEngineSandboxOperationOrDict +from .common import AgentEngineSandboxSnapshotOperation +from .common import AgentEngineSandboxSnapshotOperationDict +from .common import AgentEngineSandboxSnapshotOperationOrDict +from .common import AgentEngineSessionOperation +from .common import AgentEngineSessionOperationDict +from .common import AgentEngineSessionOperationOrDict +from .common import AgentRunConfig +from .common import AgentRunConfigDict +from .common import AgentRunConfigOrDict +from .common import AgentServerMode +from .common import AggregatedMetricResult +from .common import AggregatedMetricResultDict +from .common import AggregatedMetricResultOrDict +from .common import AppendAgentEngineSessionEventConfig +from .common import AppendAgentEngineSessionEventConfigDict +from .common import AppendAgentEngineSessionEventConfigOrDict +from .common import AppendAgentEngineSessionEventResponse +from .common import AppendAgentEngineSessionEventResponseDict +from .common import AppendAgentEngineSessionEventResponseOrDict +from .common import AppendAgentEngineTaskEventConfig +from .common import AppendAgentEngineTaskEventConfigDict +from .common import AppendAgentEngineTaskEventConfigOrDict +from .common import AppendAgentEngineTaskEventResponse +from .common import AppendAgentEngineTaskEventResponseDict +from .common import AppendAgentEngineTaskEventResponseOrDict +from .common import AssembleDataset +from .common import AssembleDatasetConfig +from .common import AssembleDatasetConfigDict +from .common import AssembleDatasetConfigOrDict +from .common import AssembleDatasetDict +from .common import AssembleDatasetOrDict +from .common import AssessDatasetConfig +from .common import AssessDatasetConfigDict +from .common import AssessDatasetConfigOrDict +from .common import BatchPredictionResourceUsageAssessmentConfig +from .common import BatchPredictionResourceUsageAssessmentConfigDict +from .common import BatchPredictionResourceUsageAssessmentConfigOrDict +from .common import BatchPredictionResourceUsageAssessmentResult +from .common import BatchPredictionResourceUsageAssessmentResultDict +from .common import BatchPredictionResourceUsageAssessmentResultOrDict +from .common import BatchPredictionValidationAssessmentConfig +from .common import BatchPredictionValidationAssessmentConfigDict +from .common import BatchPredictionValidationAssessmentConfigOrDict +from .common import BatchPredictionValidationAssessmentResult +from .common import BatchPredictionValidationAssessmentResultDict +from .common import BatchPredictionValidationAssessmentResultOrDict +from .common import BigQueryRequestSet +from .common import 
BigQueryRequestSetDict +from .common import BigQueryRequestSetOrDict +from .common import BleuInput +from .common import BleuInputDict +from .common import BleuInputOrDict +from .common import BleuInstance +from .common import BleuInstanceDict +from .common import BleuInstanceOrDict +from .common import BleuResults +from .common import BleuResultsDict +from .common import BleuResultsOrDict +from .common import CancelQueryJobAgentEngineConfig +from .common import CancelQueryJobAgentEngineConfigDict +from .common import CancelQueryJobAgentEngineConfigOrDict +from .common import CancelQueryJobResult +from .common import CancelQueryJobResultDict +from .common import CancelQueryJobResultOrDict +from .common import CandidateResponse +from .common import CandidateResponseDict +from .common import CandidateResponseOrDict +from .common import CandidateResult +from .common import CandidateResultDict +from .common import CheckQueryJobAgentEngineConfig +from .common import CheckQueryJobAgentEngineConfigDict +from .common import CheckQueryJobAgentEngineConfigOrDict +from .common import CheckQueryJobResponse +from .common import CheckQueryJobResponseDict +from .common import CheckQueryJobResponseOrDict +from .common import CheckQueryJobResult +from .common import CheckQueryJobResultDict +from .common import CheckQueryJobResultOrDict +from .common import Chunk +from .common import ChunkDict +from .common import ChunkOrDict +from .common import CodeExecutionMetric +from .common import CometResult +from .common import CometResultDict +from .common import CometResultOrDict +from .common import ContainerSpec +from .common import ContainerSpecDict +from .common import ContainerSpecOrDict +from .common import ContentMap +from .common import ContentMapContents +from .common import ContentMapContentsDict +from .common import ContentMapContentsOrDict +from .common import ContentMapDict +from .common import ContentMapOrDict +from .common import CreateAgentEngineConfig +from .common import CreateAgentEngineConfigDict +from .common import CreateAgentEngineConfigOrDict +from .common import CreateAgentEngineSandboxConfig +from .common import CreateAgentEngineSandboxConfigDict +from .common import CreateAgentEngineSandboxConfigOrDict +from .common import CreateAgentEngineSandboxSnapshotConfig +from .common import CreateAgentEngineSandboxSnapshotConfigDict +from .common import CreateAgentEngineSandboxSnapshotConfigOrDict +from .common import CreateAgentEngineSessionConfig +from .common import CreateAgentEngineSessionConfigDict +from .common import CreateAgentEngineSessionConfigOrDict +from .common import CreateAgentEngineTaskConfig +from .common import CreateAgentEngineTaskConfigDict +from .common import CreateAgentEngineTaskConfigOrDict +from .common import CreateDatasetConfig +from .common import CreateDatasetConfigDict +from .common import CreateDatasetConfigOrDict +from .common import CreateDatasetVersionConfig +from .common import CreateDatasetVersionConfigDict +from .common import CreateDatasetVersionConfigOrDict +from .common import CreateEvaluationItemConfig +from .common import CreateEvaluationItemConfigDict +from .common import CreateEvaluationItemConfigOrDict +from .common import CreateEvaluationMetricConfig +from .common import CreateEvaluationMetricConfigDict +from .common import CreateEvaluationMetricConfigOrDict +from .common import CreateEvaluationRunConfig +from .common import CreateEvaluationRunConfigDict +from .common import CreateEvaluationRunConfigOrDict +from .common import 
CreateEvaluationSetConfig +from .common import CreateEvaluationSetConfigDict +from .common import CreateEvaluationSetConfigOrDict +from .common import CreateMultimodalDatasetConfig +from .common import CreateMultimodalDatasetConfigDict +from .common import CreateMultimodalDatasetConfigOrDict +from .common import CreatePromptConfig +from .common import CreatePromptConfigDict +from .common import CreatePromptConfigOrDict +from .common import CreatePromptVersionConfig +from .common import CreatePromptVersionConfigDict +from .common import CreatePromptVersionConfigOrDict +from .common import CreateSandboxEnvironmentTemplateConfig +from .common import CreateSandboxEnvironmentTemplateConfigDict +from .common import CreateSandboxEnvironmentTemplateConfigOrDict +from .common import CreateSkillConfig +from .common import CreateSkillConfigDict +from .common import CreateSkillConfigOrDict +from .common import CustomCodeExecutionSpec +from .common import CustomCodeExecutionSpecDict +from .common import CustomCodeExecutionSpecOrDict +from .common import CustomJob +from .common import CustomJobDict +from .common import CustomJobOrDict +from .common import CustomJobSpec +from .common import CustomJobSpecDict +from .common import CustomJobSpecOrDict +from .common import Dataset +from .common import DatasetDict +from .common import DatasetOperation +from .common import DatasetOperationDict +from .common import DatasetOperationOrDict +from .common import DatasetOrDict +from .common import DatasetVersion +from .common import DatasetVersionDict +from .common import DatasetVersionOrDict +from .common import DefaultContainerCategory +from .common import DeleteAgentEngineConfig +from .common import DeleteAgentEngineConfigDict +from .common import DeleteAgentEngineConfigOrDict +from .common import DeleteAgentEngineMemoryConfig +from .common import DeleteAgentEngineMemoryConfigDict +from .common import DeleteAgentEngineMemoryConfigOrDict +from .common import DeleteAgentEngineMemoryOperation +from .common import DeleteAgentEngineMemoryOperationDict +from .common import DeleteAgentEngineMemoryOperationOrDict +from .common import DeleteAgentEngineOperation +from .common import DeleteAgentEngineOperationDict +from .common import DeleteAgentEngineOperationOrDict +from .common import DeleteAgentEngineRuntimeRevisionConfig +from .common import DeleteAgentEngineRuntimeRevisionConfigDict +from .common import DeleteAgentEngineRuntimeRevisionConfigOrDict +from .common import DeleteAgentEngineRuntimeRevisionOperation +from .common import DeleteAgentEngineRuntimeRevisionOperationDict +from .common import DeleteAgentEngineRuntimeRevisionOperationOrDict +from .common import DeleteAgentEngineSandboxConfig +from .common import DeleteAgentEngineSandboxConfigDict +from .common import DeleteAgentEngineSandboxConfigOrDict +from .common import DeleteAgentEngineSandboxOperation +from .common import DeleteAgentEngineSandboxOperationDict +from .common import DeleteAgentEngineSandboxOperationOrDict +from .common import DeleteAgentEngineSessionConfig +from .common import DeleteAgentEngineSessionConfigDict +from .common import DeleteAgentEngineSessionConfigOrDict +from .common import DeleteAgentEngineSessionOperation +from .common import DeleteAgentEngineSessionOperationDict +from .common import DeleteAgentEngineSessionOperationOrDict +from .common import DeleteAgentEngineTaskConfig +from .common import DeleteAgentEngineTaskConfigDict +from .common import DeleteAgentEngineTaskConfigOrDict +from .common import DeleteEvaluationMetricConfig 
+from .common import DeleteEvaluationMetricConfigDict +from .common import DeleteEvaluationMetricConfigOrDict +from .common import DeleteEvaluationMetricOperation +from .common import DeleteEvaluationMetricOperationDict +from .common import DeleteEvaluationMetricOperationOrDict +from .common import DeletePromptConfig +from .common import DeletePromptConfigDict +from .common import DeletePromptConfigOrDict +from .common import DeletePromptOperation +from .common import DeletePromptOperationDict +from .common import DeletePromptOperationOrDict +from .common import DeletePromptVersionOperation +from .common import DeletePromptVersionOperationDict +from .common import DeletePromptVersionOperationOrDict +from .common import DeleteSandboxEnvironmentSnapshotConfig +from .common import DeleteSandboxEnvironmentSnapshotConfigDict +from .common import DeleteSandboxEnvironmentSnapshotConfigOrDict +from .common import DeleteSandboxEnvironmentSnapshotOperation +from .common import DeleteSandboxEnvironmentSnapshotOperationDict +from .common import DeleteSandboxEnvironmentSnapshotOperationOrDict +from .common import DeleteSandboxEnvironmentTemplateConfig +from .common import DeleteSandboxEnvironmentTemplateConfigDict +from .common import DeleteSandboxEnvironmentTemplateConfigOrDict +from .common import DeleteSandboxEnvironmentTemplateOperation +from .common import DeleteSandboxEnvironmentTemplateOperationDict +from .common import DeleteSandboxEnvironmentTemplateOperationOrDict +from .common import DiskSpec +from .common import DiskSpecDict +from .common import DiskSpecOrDict +from .common import DnsPeeringConfig +from .common import DnsPeeringConfigDict +from .common import DnsPeeringConfigOrDict +from .common import EnvVar +from .common import EnvVarDict +from .common import EnvVarOrDict +from .common import EvalCase +from .common import EvalCaseDict +from .common import EvalCaseMetricResult +from .common import EvalCaseMetricResultDict +from .common import EvalCaseMetricResultOrDict +from .common import EvalCaseOrDict +from .common import EvalCaseResult +from .common import EvalCaseResultDict +from .common import EvalCaseResultOrDict +from .common import EvalRunInferenceConfig +from .common import EvalRunInferenceConfigDict +from .common import EvalRunInferenceConfigOrDict +from .common import EvaluateDatasetConfig +from .common import EvaluateDatasetConfigDict +from .common import EvaluateDatasetConfigOrDict +from .common import EvaluateDatasetOperation +from .common import EvaluateDatasetOperationDict +from .common import EvaluateDatasetOperationOrDict +from .common import EvaluateDatasetRequestParameters +from .common import EvaluateDatasetRequestParametersDict +from .common import EvaluateDatasetRequestParametersOrDict +from .common import EvaluateInstancesConfig +from .common import EvaluateInstancesConfigDict +from .common import EvaluateInstancesConfigOrDict +from .common import EvaluateInstancesResponse +from .common import EvaluateInstancesResponseDict +from .common import EvaluateInstancesResponseOrDict +from .common import EvaluateMethodConfig +from .common import EvaluateMethodConfigDict +from .common import EvaluateMethodConfigOrDict +from .common import EvaluationDataset +from .common import EvaluationDatasetDict +from .common import EvaluationDatasetOrDict +from .common import EvaluationInstance +from .common import EvaluationInstanceDict +from .common import EvaluationInstanceOrDict +from .common import EvaluationItem +from .common import EvaluationItemDict +from .common import 
EvaluationItemOrDict +from .common import EvaluationItemRequest +from .common import EvaluationItemRequestDict +from .common import EvaluationItemRequestOrDict +from .common import EvaluationItemResult +from .common import EvaluationItemResultDict +from .common import EvaluationItemResultOrDict +from .common import EvaluationItemType +from .common import EvaluationMetric +from .common import EvaluationMetricDict +from .common import EvaluationMetricOrDict +from .common import EvaluationPrompt +from .common import EvaluationPromptDict +from .common import EvaluationPromptOrDict +from .common import EvaluationResult +from .common import EvaluationResultDict +from .common import EvaluationResultOrDict +from .common import EvaluationRun +from .common import EvaluationRunAgentConfig +from .common import EvaluationRunAgentConfigDict +from .common import EvaluationRunAgentConfigOrDict +from .common import EvaluationRunConfig +from .common import EvaluationRunConfigDict +from .common import EvaluationRunConfigOrDict +from .common import EvaluationRunDataSource +from .common import EvaluationRunDataSourceDict +from .common import EvaluationRunDataSourceOrDict +from .common import EvaluationRunDict +from .common import EvaluationRunInferenceConfig +from .common import EvaluationRunInferenceConfigDict +from .common import EvaluationRunInferenceConfigOrDict +from .common import EvaluationRunMetadata +from .common import EvaluationRunMetadataDict +from .common import EvaluationRunMetadataOrDict +from .common import EvaluationRunMetric +from .common import EvaluationRunMetricDict +from .common import EvaluationRunMetricOrDict +from .common import EvaluationRunOrDict +from .common import EvaluationRunPromptTemplate +from .common import EvaluationRunPromptTemplateDict +from .common import EvaluationRunPromptTemplateOrDict +from .common import EvaluationRunResults +from .common import EvaluationRunResultsDict +from .common import EvaluationRunResultsOrDict +from .common import EvaluationRunState +from .common import EvaluationSet +from .common import EvaluationSetDict +from .common import EvaluationSetOrDict +from .common import Event +from .common import EventActions +from .common import EventActionsDict +from .common import EventActionsOrDict +from .common import EventDict +from .common import EventMetadata +from .common import EventMetadataDict +from .common import EventMetadataOrDict +from .common import ExactMatchInput +from .common import ExactMatchInputDict +from .common import ExactMatchInputOrDict +from .common import ExactMatchInstance +from .common import ExactMatchInstanceDict +from .common import ExactMatchInstanceOrDict +from .common import ExactMatchResults +from .common import ExactMatchResultsDict +from .common import ExactMatchResultsOrDict +from .common import ExactMatchSpec +from .common import ExactMatchSpecDict +from .common import ExactMatchSpecOrDict +from .common import ExecuteCodeAgentEngineSandboxConfig +from .common import ExecuteCodeAgentEngineSandboxConfigDict +from .common import ExecuteCodeAgentEngineSandboxConfigOrDict +from .common import ExecuteSandboxEnvironmentResponse +from .common import ExecuteSandboxEnvironmentResponseDict +from .common import ExecuteSandboxEnvironmentResponseOrDict +from .common import FailedRubric +from .common import FailedRubricDict +from .common import FailedRubricOrDict +from .common import Framework +from .common import GeminiExample +from .common import GeminiExampleDict +from .common import GeminiExampleOrDict +from .common import 
GeminiRequestReadConfig +from .common import GeminiRequestReadConfigDict +from .common import GeminiRequestReadConfigOrDict +from .common import GeminiTemplateConfig +from .common import GeminiTemplateConfigDict +from .common import GeminiTemplateConfigOrDict +from .common import GenerateAgentEngineMemoriesConfig +from .common import GenerateAgentEngineMemoriesConfigDict +from .common import GenerateAgentEngineMemoriesConfigOrDict +from .common import GenerateInstanceRubricsResponse +from .common import GenerateInstanceRubricsResponseDict +from .common import GenerateInstanceRubricsResponseOrDict +from .common import GenerateLossClustersConfig +from .common import GenerateLossClustersConfigDict +from .common import GenerateLossClustersConfigOrDict +from .common import GenerateLossClustersOperation +from .common import GenerateLossClustersOperationDict +from .common import GenerateLossClustersOperationOrDict +from .common import GenerateLossClustersResponse +from .common import GenerateLossClustersResponseDict +from .common import GenerateLossClustersResponseOrDict +from .common import GenerateMemoriesRequestDirectContentsSource +from .common import GenerateMemoriesRequestDirectContentsSourceDict +from .common import GenerateMemoriesRequestDirectContentsSourceEvent +from .common import GenerateMemoriesRequestDirectContentsSourceEventDict +from .common import GenerateMemoriesRequestDirectContentsSourceEventOrDict +from .common import GenerateMemoriesRequestDirectContentsSourceOrDict +from .common import GenerateMemoriesRequestDirectMemoriesSource +from .common import GenerateMemoriesRequestDirectMemoriesSourceDict +from .common import GenerateMemoriesRequestDirectMemoriesSourceDirectMemory +from .common import GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict +from .common import GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryOrDict +from .common import GenerateMemoriesRequestDirectMemoriesSourceOrDict +from .common import GenerateMemoriesRequestVertexSessionSource +from .common import GenerateMemoriesRequestVertexSessionSourceDict +from .common import GenerateMemoriesRequestVertexSessionSourceOrDict +from .common import GenerateMemoriesResponse +from .common import GenerateMemoriesResponseDict +from .common import GenerateMemoriesResponseGeneratedMemory +from .common import GenerateMemoriesResponseGeneratedMemoryAction +from .common import GenerateMemoriesResponseGeneratedMemoryDict +from .common import GenerateMemoriesResponseGeneratedMemoryOrDict +from .common import GenerateMemoriesResponseOrDict +from .common import GenerateUserScenariosConfig +from .common import GenerateUserScenariosConfigDict +from .common import GenerateUserScenariosConfigOrDict +from .common import GenerateUserScenariosResponse +from .common import GenerateUserScenariosResponseDict +from .common import GenerateUserScenariosResponseOrDict +from .common import GetAgentEngineConfig +from .common import GetAgentEngineConfigDict +from .common import GetAgentEngineConfigOrDict +from .common import GetAgentEngineMemoryConfig +from .common import GetAgentEngineMemoryConfigDict +from .common import GetAgentEngineMemoryConfigOrDict +from .common import GetAgentEngineMemoryRevisionConfig +from .common import GetAgentEngineMemoryRevisionConfigDict +from .common import GetAgentEngineMemoryRevisionConfigOrDict +from .common import GetAgentEngineOperationConfig +from .common import GetAgentEngineOperationConfigDict +from .common import GetAgentEngineOperationConfigOrDict +from .common import 
GetAgentEngineRuntimeRevisionConfig +from .common import GetAgentEngineRuntimeRevisionConfigDict +from .common import GetAgentEngineRuntimeRevisionConfigOrDict +from .common import GetAgentEngineSandboxConfig +from .common import GetAgentEngineSandboxConfigDict +from .common import GetAgentEngineSandboxConfigOrDict +from .common import GetAgentEngineSessionConfig +from .common import GetAgentEngineSessionConfigDict +from .common import GetAgentEngineSessionConfigOrDict +from .common import GetAgentEngineTaskConfig +from .common import GetAgentEngineTaskConfigDict +from .common import GetAgentEngineTaskConfigOrDict +from .common import GetDatasetOperationConfig +from .common import GetDatasetOperationConfigDict +from .common import GetDatasetOperationConfigOrDict +from .common import GetDeleteAgentEngineRuntimeRevisionOperationConfig +from .common import GetDeleteAgentEngineRuntimeRevisionOperationConfigDict +from .common import GetDeleteAgentEngineRuntimeRevisionOperationConfigOrDict +from .common import GetEvaluationItemConfig +from .common import GetEvaluationItemConfigDict +from .common import GetEvaluationItemConfigOrDict +from .common import GetEvaluationMetricConfig +from .common import GetEvaluationMetricConfigDict +from .common import GetEvaluationMetricConfigOrDict +from .common import GetEvaluationRunConfig +from .common import GetEvaluationRunConfigDict +from .common import GetEvaluationRunConfigOrDict +from .common import GetEvaluationSetConfig +from .common import GetEvaluationSetConfigDict +from .common import GetEvaluationSetConfigOrDict +from .common import GetMultimodalDatasetOperationConfig +from .common import GetMultimodalDatasetOperationConfigDict +from .common import GetMultimodalDatasetOperationConfigOrDict +from .common import GetPromptConfig +from .common import GetPromptConfigDict +from .common import GetPromptConfigOrDict +from .common import GetSandboxEnvironmentSnapshotConfig +from .common import GetSandboxEnvironmentSnapshotConfigDict +from .common import GetSandboxEnvironmentSnapshotConfigOrDict +from .common import GetSandboxEnvironmentTemplateConfig +from .common import GetSandboxEnvironmentTemplateConfigDict +from .common import GetSandboxEnvironmentTemplateConfigOrDict +from .common import GetSkillConfig +from .common import GetSkillConfigDict +from .common import GetSkillConfigOrDict +from .common import GetSkillOperationConfig +from .common import GetSkillOperationConfigDict +from .common import GetSkillOperationConfigOrDict +from .common import IdentityType +from .common import Importance +from .common import IngestEventsConfig +from .common import IngestEventsConfigDict +from .common import IngestEventsConfigOrDict +from .common import IngestionDirectContentsSource +from .common import IngestionDirectContentsSourceDict +from .common import IngestionDirectContentsSourceEvent +from .common import IngestionDirectContentsSourceEventDict +from .common import IngestionDirectContentsSourceEventOrDict +from .common import IngestionDirectContentsSourceOrDict +from .common import IntermediateExtractedMemory +from .common import IntermediateExtractedMemoryDict +from .common import IntermediateExtractedMemoryOrDict +from .common import JobState +from .common import KeepAliveProbe +from .common import KeepAliveProbeDict +from .common import KeepAliveProbeHttpGet +from .common import KeepAliveProbeHttpGetDict +from .common import KeepAliveProbeHttpGetOrDict +from .common import KeepAliveProbeOrDict +from .common import Language +from .common import 
ListAgentEngineConfig +from .common import ListAgentEngineConfigDict +from .common import ListAgentEngineConfigOrDict +from .common import ListAgentEngineMemoryConfig +from .common import ListAgentEngineMemoryConfigDict +from .common import ListAgentEngineMemoryConfigOrDict +from .common import ListAgentEngineMemoryRevisionsConfig +from .common import ListAgentEngineMemoryRevisionsConfigDict +from .common import ListAgentEngineMemoryRevisionsConfigOrDict +from .common import ListAgentEngineMemoryRevisionsResponse +from .common import ListAgentEngineMemoryRevisionsResponseDict +from .common import ListAgentEngineMemoryRevisionsResponseOrDict +from .common import ListAgentEngineRuntimeRevisionsConfig +from .common import ListAgentEngineRuntimeRevisionsConfigDict +from .common import ListAgentEngineRuntimeRevisionsConfigOrDict +from .common import ListAgentEngineSandboxesConfig +from .common import ListAgentEngineSandboxesConfigDict +from .common import ListAgentEngineSandboxesConfigOrDict +from .common import ListAgentEngineSandboxesResponse +from .common import ListAgentEngineSandboxesResponseDict +from .common import ListAgentEngineSandboxesResponseOrDict +from .common import ListAgentEngineSessionEventsConfig +from .common import ListAgentEngineSessionEventsConfigDict +from .common import ListAgentEngineSessionEventsConfigOrDict +from .common import ListAgentEngineSessionEventsResponse +from .common import ListAgentEngineSessionEventsResponseDict +from .common import ListAgentEngineSessionEventsResponseOrDict +from .common import ListAgentEngineSessionsConfig +from .common import ListAgentEngineSessionsConfigDict +from .common import ListAgentEngineSessionsConfigOrDict +from .common import ListAgentEngineTaskEventsConfig +from .common import ListAgentEngineTaskEventsConfigDict +from .common import ListAgentEngineTaskEventsConfigOrDict +from .common import ListAgentEngineTaskEventsResponse +from .common import ListAgentEngineTaskEventsResponseDict +from .common import ListAgentEngineTaskEventsResponseOrDict +from .common import ListAgentEngineTasksConfig +from .common import ListAgentEngineTasksConfigDict +from .common import ListAgentEngineTasksConfigOrDict +from .common import ListAgentEngineTasksResponse +from .common import ListAgentEngineTasksResponseDict +from .common import ListAgentEngineTasksResponseOrDict +from .common import ListDatasetsResponse +from .common import ListDatasetsResponseDict +from .common import ListDatasetsResponseOrDict +from .common import ListDatasetVersionsResponse +from .common import ListDatasetVersionsResponseDict +from .common import ListDatasetVersionsResponseOrDict +from .common import ListEvaluationMetricsConfig +from .common import ListEvaluationMetricsConfigDict +from .common import ListEvaluationMetricsConfigOrDict +from .common import ListEvaluationMetricsResponse +from .common import ListEvaluationMetricsResponseDict +from .common import ListEvaluationMetricsResponseOrDict +from .common import ListMultimodalDatasetsConfig +from .common import ListMultimodalDatasetsConfigDict +from .common import ListMultimodalDatasetsConfigOrDict +from .common import ListMultimodalDatasetsResponse +from .common import ListMultimodalDatasetsResponseDict +from .common import ListMultimodalDatasetsResponseOrDict +from .common import ListPromptsConfig +from .common import ListPromptsConfigDict +from .common import ListPromptsConfigOrDict +from .common import ListReasoningEnginesMemoriesResponse +from .common import ListReasoningEnginesMemoriesResponseDict +from 
.common import ListReasoningEnginesMemoriesResponseOrDict +from .common import ListReasoningEnginesResponse +from .common import ListReasoningEnginesResponseDict +from .common import ListReasoningEnginesResponseOrDict +from .common import ListReasoningEnginesRuntimeRevisionsResponse +from .common import ListReasoningEnginesRuntimeRevisionsResponseDict +from .common import ListReasoningEnginesRuntimeRevisionsResponseOrDict +from .common import ListReasoningEnginesSessionsResponse +from .common import ListReasoningEnginesSessionsResponseDict +from .common import ListReasoningEnginesSessionsResponseOrDict +from .common import ListSandboxEnvironmentSnapshotsConfig +from .common import ListSandboxEnvironmentSnapshotsConfigDict +from .common import ListSandboxEnvironmentSnapshotsConfigOrDict +from .common import ListSandboxEnvironmentSnapshotsResponse +from .common import ListSandboxEnvironmentSnapshotsResponseDict +from .common import ListSandboxEnvironmentSnapshotsResponseOrDict +from .common import ListSandboxEnvironmentTemplatesConfig +from .common import ListSandboxEnvironmentTemplatesConfigDict +from .common import ListSandboxEnvironmentTemplatesConfigOrDict +from .common import ListSandboxEnvironmentTemplatesResponse +from .common import ListSandboxEnvironmentTemplatesResponseDict +from .common import ListSandboxEnvironmentTemplatesResponseOrDict +from .common import LLMMetric +from .common import LossAnalysisConfig +from .common import LossAnalysisConfigDict +from .common import LossAnalysisConfigOrDict +from .common import LossAnalysisResult +from .common import LossAnalysisResultDict +from .common import LossAnalysisResultOrDict +from .common import LossCluster +from .common import LossClusterDict +from .common import LossClusterOrDict +from .common import LossExample +from .common import LossExampleDict +from .common import LossExampleOrDict +from .common import LossTaxonomyEntry +from .common import LossTaxonomyEntryDict +from .common import LossTaxonomyEntryOrDict +from .common import LustreMount +from .common import LustreMountDict +from .common import LustreMountOrDict +from .common import MachineConfig +from .common import MachineSpec +from .common import MachineSpecDict +from .common import MachineSpecOrDict +from .common import ManagedTopicEnum +from .common import MapInstance +from .common import MapInstanceDict +from .common import MapInstanceOrDict +from .common import Memory +from .common import MemoryBankCustomizationConfig +from .common import MemoryBankCustomizationConfigConsolidationConfig +from .common import MemoryBankCustomizationConfigConsolidationConfigDict +from .common import MemoryBankCustomizationConfigConsolidationConfigOrDict +from .common import MemoryBankCustomizationConfigDict +from .common import MemoryBankCustomizationConfigGenerateMemoriesExample +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSource, +) +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict, +) +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEvent, +) +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventDict, +) +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventOrDict, +) +from .common import ( + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceOrDict, +) +from .common import MemoryBankCustomizationConfigGenerateMemoriesExampleDict +from 
.common import MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory
+from .common import (
+    MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict,
+)
+from .common import (
+    MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryOrDict,
+)
+from .common import MemoryBankCustomizationConfigGenerateMemoriesExampleOrDict
+from .common import MemoryBankCustomizationConfigMemoryTopic
+from .common import MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic
+from .common import MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicDict
+from .common import MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicOrDict
+from .common import MemoryBankCustomizationConfigMemoryTopicDict
+from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopic
+from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicDict
+from .common import MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict
+from .common import MemoryBankCustomizationConfigMemoryTopicOrDict
+from .common import MemoryBankCustomizationConfigOrDict
+from .common import MemoryBankIngestEventsOperation
+from .common import MemoryBankIngestEventsOperationDict
+from .common import MemoryBankIngestEventsOperationOrDict
+from .common import MemoryConjunctionFilter
+from .common import MemoryConjunctionFilterDict
+from .common import MemoryConjunctionFilterOrDict
+from .common import MemoryDict
+from .common import MemoryFilter
+from .common import MemoryFilterDict
+from .common import MemoryFilterOrDict
+from .common import MemoryGenerationTriggerConfig
+from .common import MemoryGenerationTriggerConfigDict
+from .common import MemoryGenerationTriggerConfigGenerationTriggerRule
+from .common import MemoryGenerationTriggerConfigGenerationTriggerRuleDict
+from .common import MemoryGenerationTriggerConfigGenerationTriggerRuleOrDict
+from .common import MemoryGenerationTriggerConfigOrDict
+from .common import MemoryMetadataMergeStrategy
+from .common import MemoryMetadataValue
+from .common import MemoryMetadataValueDict
+from .common import MemoryMetadataValueOrDict
+from .common import MemoryOrDict
+from .common import MemoryProfile
+from .common import MemoryProfileDict
+from .common import MemoryProfileOrDict
+from .common import MemoryRevision
+from .common import MemoryRevisionDict
+from .common import MemoryRevisionOrDict
+from .common import MemoryStructuredContent
+from .common import MemoryStructuredContentDict
+from .common import MemoryStructuredContentOrDict
+from .common import MemoryTopicId
+from .common import MemoryTopicIdDict
+from .common import MemoryTopicIdOrDict
+from .common import MemoryType
+from .common import Message
+from .common import MessageDict
+from .common import Metadata
+from .common import MetadataDict
+from .common import MetadataOrDict
+from .common import Metric
+from .common import MetricDict
+from .common import MetricOrDict
+from .common import MetricPromptBuilder
+from .common import MetricResult
+from .common import MetricResultDict
+from .common import MetricResultOrDict
+from .common import MetricSource
+from .common import MetricSourceDict
+from .common import MetricSourceOrDict
+from .common import MetricxResult
+from .common import MetricxResultDict
+from .common import MetricxResultOrDict
+from .common import MultimodalDataset
+from .common import MultimodalDatasetDict
+from .common import MultimodalDatasetOperation
+from .common import MultimodalDatasetOperationDict
+from .common import MultimodalDatasetOperationOrDict
+from .common import MultimodalDatasetOrDict
+from .common import NfsMount
+from .common import NfsMountDict
+from .common import NfsMountOrDict
+from .common import ObservabilityEvalCase
+from .common import ObservabilityEvalCaseDict
+from .common import ObservabilityEvalCaseOrDict
+from .common import Operator
+from .common import OptimizationMethod
+from .common import OptimizeConfig
+from .common import OptimizeConfigDict
+from .common import OptimizeConfigOrDict
+from .common import OptimizeJobConfig
+from .common import OptimizeJobConfigDict
+from .common import OptimizeJobConfigOrDict
+from .common import OptimizeResponse
+from .common import OptimizeResponseDict
+from .common import OptimizeResponseEndpoint
+from .common import OptimizeResponseEndpointDict
+from .common import OptimizeResponseEndpointOrDict
+from .common import OptimizeResponseOrDict
+from .common import OptimizeTarget
+from .common import PairwiseMetricInput
+from .common import PairwiseMetricInputDict
+from .common import PairwiseMetricInputOrDict
+from .common import PairwiseMetricInstance
+from .common import PairwiseMetricInstanceDict
+from .common import PairwiseMetricInstanceOrDict
+from .common import ParsedResponseUnion
+from .common import PointwiseMetricInput
+from .common import PointwiseMetricInputDict
+from .common import PointwiseMetricInputOrDict
+from .common import PointwiseMetricInstance
+from .common import PointwiseMetricInstanceDict
+from .common import PointwiseMetricInstanceOrDict
+from .common import PostSnapshotAction
+from .common import Prompt
+from .common import PromptData
+from .common import PromptDataDict
+from .common import PromptDataOrDict
+from .common import PromptDict
+from .common import PromptOptimizerConfig
+from .common import PromptOptimizerConfigDict
+from .common import PromptOptimizerConfigOrDict
+from .common import PromptOptimizerMethod
+from .common import PromptOrDict
+from .common import PromptRef
+from .common import PromptRefDict
+from .common import PromptRefOrDict
+from .common import PromptTemplate
+from .common import PromptTemplateData
+from .common import PromptTemplateDataDict
+from .common import PromptTemplateDataOrDict
+from .common import PromptTemplateDict
+from .common import PromptTemplateOrDict
+from .common import PromptVersionRef
+from .common import PromptVersionRefDict
+from .common import PromptVersionRefOrDict
+from .common import Protocol
+from .common import PscInterfaceConfig
+from .common import PscInterfaceConfigDict
+from .common import PscInterfaceConfigOrDict
+from .common import PurgeAgentEngineMemoriesConfig
+from .common import PurgeAgentEngineMemoriesConfigDict
+from .common import PurgeAgentEngineMemoriesConfigOrDict
+from .common import PurgeMemoriesResponse
+from .common import PurgeMemoriesResponseDict
+from .common import PurgeMemoriesResponseOrDict
+from .common import PythonPackageSpec
+from .common import PythonPackageSpecDict
+from .common import PythonPackageSpecOrDict
+from .common import QueryAgentEngineConfig
+from .common import QueryAgentEngineConfigDict
+from .common import QueryAgentEngineConfigOrDict
+from .common import QueryAgentEngineRuntimeRevisionConfig
+from .common import QueryAgentEngineRuntimeRevisionConfigDict
+from .common import QueryAgentEngineRuntimeRevisionConfigOrDict
+from .common import QueryReasoningEngineResponse
+from .common import QueryReasoningEngineResponseDict
+from .common import QueryReasoningEngineResponseOrDict
+from .common import ReasoningEngine
+from .common import ReasoningEngineContextSpec
+from .common import ReasoningEngineContextSpecDict
+from .common import ReasoningEngineContextSpecMemoryBankConfig
+from .common import ReasoningEngineContextSpecMemoryBankConfigDict
+from .common import ReasoningEngineContextSpecMemoryBankConfigGenerationConfig
+from .common import ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict
+from .common import ReasoningEngineContextSpecMemoryBankConfigGenerationConfigOrDict
+from .common import ReasoningEngineContextSpecMemoryBankConfigOrDict
+from .common import ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig
+from .common import ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigDict
+from .common import (
+    ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigOrDict,
+)
+from .common import ReasoningEngineContextSpecMemoryBankConfigTtlConfig
+from .common import ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict
+from .common import ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfig
+from .common import (
+    ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigDict,
+)
+from .common import (
+    ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigOrDict,
+)
+from .common import ReasoningEngineContextSpecMemoryBankConfigTtlConfigOrDict
+from .common import ReasoningEngineContextSpecOrDict
+from .common import ReasoningEngineDict
+from .common import ReasoningEngineOrDict
+from .common import ReasoningEngineRuntimeRevision
+from .common import ReasoningEngineRuntimeRevisionDict
+from .common import ReasoningEngineRuntimeRevisionOrDict
+from .common import ReasoningEngineSpec
+from .common import ReasoningEngineSpecContainerSpec
+from .common import ReasoningEngineSpecContainerSpecDict
+from .common import ReasoningEngineSpecContainerSpecOrDict
+from .common import ReasoningEngineSpecDeploymentSpec
+from .common import ReasoningEngineSpecDeploymentSpecAgentGatewayConfig
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfig,
+)
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigDict,
+)
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigOrDict,
+)
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfig,
+)
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigDict,
+)
+from .common import (
+    ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigOrDict,
+)
+from .common import ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict
+from .common import ReasoningEngineSpecDeploymentSpecAgentGatewayConfigOrDict
+from .common import ReasoningEngineSpecDeploymentSpecDict
+from .common import ReasoningEngineSpecDeploymentSpecOrDict
+from .common import ReasoningEngineSpecDict
+from .common import ReasoningEngineSpecOrDict
+from .common import ReasoningEngineSpecPackageSpec
+from .common import ReasoningEngineSpecPackageSpecDict
+from .common import ReasoningEngineSpecPackageSpecOrDict
+from .common import ReasoningEngineSpecSourceCodeSpec
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSource
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfig
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigDict
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict
+from .common import ReasoningEngineSpecSourceCodeSpecAgentConfigSourceOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectSource
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceDict
+from .common import ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecDict
+from .common import ReasoningEngineSpecSourceCodeSpecImageSpec
+from .common import ReasoningEngineSpecSourceCodeSpecImageSpecDict
+from .common import ReasoningEngineSpecSourceCodeSpecImageSpecOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecInlineSource
+from .common import ReasoningEngineSpecSourceCodeSpecInlineSourceDict
+from .common import ReasoningEngineSpecSourceCodeSpecInlineSourceOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecOrDict
+from .common import ReasoningEngineSpecSourceCodeSpecPythonSpec
+from .common import ReasoningEngineSpecSourceCodeSpecPythonSpecDict
+from .common import ReasoningEngineSpecSourceCodeSpecPythonSpecOrDict
+from .common import ReasoningEngineTrafficConfig
+from .common import ReasoningEngineTrafficConfigDict
+from .common import ReasoningEngineTrafficConfigOrDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitAlwaysLatest
+from .common import ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestOrDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitManual
+from .common import ReasoningEngineTrafficConfigTrafficSplitManualDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitManualOrDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitManualTarget
+from .common import ReasoningEngineTrafficConfigTrafficSplitManualTargetDict
+from .common import ReasoningEngineTrafficConfigTrafficSplitManualTargetOrDict
+from .common import ReservationAffinity
+from .common import ReservationAffinityDict
+from .common import ReservationAffinityOrDict
+from .common import ResponseCandidate
+from .common import ResponseCandidateDict
+from .common import ResponseCandidateOrDict
+from .common import ResponseCandidateResult
+from .common import ResponseCandidateResultDict
+from .common import ResponseCandidateResultOrDict
+from .common import RestoreVersionConfig
+from .common import RestoreVersionConfigDict
+from .common import RestoreVersionConfigOrDict
+from .common import RestoreVersionOperation
+from .common import RestoreVersionOperationDict
+from .common import RestoreVersionOperationOrDict
+from .common import RetrieveAgentEngineMemoriesConfig
+from .common import RetrieveAgentEngineMemoriesConfigDict
+from .common import RetrieveAgentEngineMemoriesConfigOrDict
+from .common import RetrievedSkill
+from .common import RetrievedSkillDict
+from .common import RetrievedSkillOrDict
+from .common import RetrieveMemoriesRequestSimilaritySearchParams
+from .common import RetrieveMemoriesRequestSimilaritySearchParamsDict
+from .common import RetrieveMemoriesRequestSimilaritySearchParamsOrDict
+from .common import RetrieveMemoriesRequestSimpleRetrievalParams
+from .common import RetrieveMemoriesRequestSimpleRetrievalParamsDict
+from .common import RetrieveMemoriesRequestSimpleRetrievalParamsOrDict
+from .common import RetrieveMemoriesResponse
+from .common import RetrieveMemoriesResponseDict
+from .common import RetrieveMemoriesResponseOrDict
+from .common import RetrieveMemoriesResponseRetrievedMemory
+from .common import RetrieveMemoriesResponseRetrievedMemoryDict
+from .common import RetrieveMemoriesResponseRetrievedMemoryOrDict
+from .common import RetrieveMemoryProfilesConfig
+from .common import RetrieveMemoryProfilesConfigDict
+from .common import RetrieveMemoryProfilesConfigOrDict
+from .common import RetrieveProfilesResponse
+from .common import RetrieveProfilesResponseDict
+from .common import RetrieveProfilesResponseOrDict
+from .common import RetrieveSkillsConfig
+from .common import RetrieveSkillsConfigDict
+from .common import RetrieveSkillsConfigOrDict
+from .common import RetrieveSkillsResponse
+from .common import RetrieveSkillsResponseDict
+from .common import RetrieveSkillsResponseOrDict
+from .common import RollbackAgentEngineMemoryConfig
+from .common import RollbackAgentEngineMemoryConfigDict
+from .common import RollbackAgentEngineMemoryConfigOrDict
+from .common import RougeInput
+from .common import RougeInputDict
+from .common import RougeInputOrDict
+from .common import RougeInstance
+from .common import RougeInstanceDict
+from .common import RougeInstanceOrDict
+from .common import RougeResults
+from .common import RougeResultsDict
+from .common import RougeResultsOrDict
+from .common import Rubric
+from .common import RubricBasedMetricInput
+from .common import RubricBasedMetricInputDict
+from .common import RubricBasedMetricInputOrDict
+from .common import RubricBasedMetricInstance
+from .common import RubricBasedMetricInstanceDict
+from .common import RubricBasedMetricInstanceOrDict
+from .common import RubricBasedMetricResult
+from .common import RubricBasedMetricResultDict
+from .common import RubricBasedMetricResultOrDict
+from .common import RubricBasedMetricSpec
+from .common import RubricBasedMetricSpecDict
+from .common import RubricBasedMetricSpecOrDict
+from .common import RubricContent
+from .common import RubricContentDict
+from .common import RubricContentProperty
+from .common import RubricContentPropertyDict
+from .common import RubricDict
+from .common import RubricEnhancedContents
+from .common import RubricEnhancedContentsDict
+from .common import RubricEnhancedContentsOrDict
+from .common import RubricGenerationConfig
+from .common import RubricGenerationConfigDict
+from .common import RubricGenerationConfigOrDict
+from .common import RubricGroup
+from .common import RubricGroupDict
+from .common import RubricGroupOrDict
+from .common import RubricVerdict
+from .common import RubricVerdictDict
+from .common import RunQueryJobAgentEngineConfig
+from .common import RunQueryJobAgentEngineConfigDict
+from .common import RunQueryJobAgentEngineConfigOrDict
+from .common import RunQueryJobResult
+from .common import RunQueryJobResultDict
+from .common import RunQueryJobResultOrDict
+from .common import SamplingConfig
+from .common import SamplingConfigDict
+from .common import SamplingConfigOrDict
+from .common import SamplingMethod
+from .common import SandboxEnvironment
+from .common import SandboxEnvironmentConnectionInfo
+from .common import SandboxEnvironmentConnectionInfoDict
+from .common import SandboxEnvironmentConnectionInfoOrDict
+from .common import SandboxEnvironmentDict
+from .common import SandboxEnvironmentOrDict
+from .common import SandboxEnvironmentSnapshot
+from .common import SandboxEnvironmentSnapshotDict
+from .common import SandboxEnvironmentSnapshotOrDict
+from .common import SandboxEnvironmentSpec
+from .common import SandboxEnvironmentSpecCodeExecutionEnvironment
+from .common import SandboxEnvironmentSpecCodeExecutionEnvironmentDict
+from .common import SandboxEnvironmentSpecCodeExecutionEnvironmentOrDict
+from .common import SandboxEnvironmentSpecComputerUseEnvironment
+from .common import SandboxEnvironmentSpecComputerUseEnvironmentDict
+from .common import SandboxEnvironmentSpecComputerUseEnvironmentOrDict
+from .common import SandboxEnvironmentSpecDict
+from .common import SandboxEnvironmentSpecOrDict
+from .common import SandboxEnvironmentTemplate
+from .common import SandboxEnvironmentTemplateCustomContainerEnvironment
+from .common import SandboxEnvironmentTemplateCustomContainerEnvironmentDict
+from .common import SandboxEnvironmentTemplateCustomContainerEnvironmentOrDict
+from .common import SandboxEnvironmentTemplateCustomContainerSpec
+from .common import SandboxEnvironmentTemplateCustomContainerSpecDict
+from .common import SandboxEnvironmentTemplateCustomContainerSpecOrDict
+from .common import SandboxEnvironmentTemplateDefaultContainerEnvironment
+from .common import SandboxEnvironmentTemplateDefaultContainerEnvironmentDict
+from .common import SandboxEnvironmentTemplateDefaultContainerEnvironmentOrDict
+from .common import SandboxEnvironmentTemplateDict
+from .common import SandboxEnvironmentTemplateEgressControlConfig
+from .common import SandboxEnvironmentTemplateEgressControlConfigDict
+from .common import SandboxEnvironmentTemplateEgressControlConfigOrDict
+from .common import SandboxEnvironmentTemplateNetworkPort
+from .common import SandboxEnvironmentTemplateNetworkPortDict
+from .common import SandboxEnvironmentTemplateNetworkPortOrDict
+from .common import SandboxEnvironmentTemplateOperation
+from .common import SandboxEnvironmentTemplateOperationDict
+from .common import SandboxEnvironmentTemplateOperationOrDict
+from .common import SandboxEnvironmentTemplateOrDict
+from .common import SandboxEnvironmentTemplateResourceRequirements
+from .common import SandboxEnvironmentTemplateResourceRequirementsDict
+from .common import SandboxEnvironmentTemplateResourceRequirementsOrDict
+from .common import SandboxEnvironmentTemplateWarmPoolConfig
+from .common import SandboxEnvironmentTemplateWarmPoolConfigDict
+from .common import SandboxEnvironmentTemplateWarmPoolConfigOrDict
+from .common import SavedQuery
+from .common import SavedQueryDict
+from .common import SavedQueryOrDict
+from .common import Scheduling
+from .common import SchedulingDict
+from .common import SchedulingOrDict
+from .common import SchemaPredictParamsGroundingConfig
+from .common import SchemaPredictParamsGroundingConfigDict
+from .common import SchemaPredictParamsGroundingConfigOrDict
+from .common import SchemaPredictParamsGroundingConfigSourceEntry
+from .common import SchemaPredictParamsGroundingConfigSourceEntryDict
+from .common import SchemaPredictParamsGroundingConfigSourceEntryOrDict
+from .common import SchemaPromptApiSchema
+from .common import SchemaPromptApiSchemaDict
+from .common import SchemaPromptApiSchemaOrDict
+from .common import SchemaPromptInstancePromptExecution
+from .common import SchemaPromptInstancePromptExecutionDict
+from .common import SchemaPromptInstancePromptExecutionOrDict
+from .common import SchemaPromptInstanceVariableValue
+from .common import SchemaPromptInstanceVariableValueDict
+from .common import SchemaPromptInstanceVariableValueOrDict
+from .common import SchemaPromptSpecAppBuilderData
+from .common import SchemaPromptSpecAppBuilderDataDict
+from .common import SchemaPromptSpecAppBuilderDataLinkedResource
+from .common import SchemaPromptSpecAppBuilderDataLinkedResourceDict
+from .common import SchemaPromptSpecAppBuilderDataLinkedResourceOrDict
+from .common import SchemaPromptSpecAppBuilderDataOrDict
+from .common import SchemaPromptSpecMultimodalPrompt
+from .common import SchemaPromptSpecMultimodalPromptDict
+from .common import SchemaPromptSpecMultimodalPromptOrDict
+from .common import SchemaPromptSpecPartList
+from .common import SchemaPromptSpecPartListDict
+from .common import SchemaPromptSpecPartListOrDict
+from .common import SchemaPromptSpecPromptMessage
+from .common import SchemaPromptSpecPromptMessageDict
+from .common import SchemaPromptSpecPromptMessageOrDict
+from .common import SchemaPromptSpecReferenceSentencePair
+from .common import SchemaPromptSpecReferenceSentencePairDict
+from .common import SchemaPromptSpecReferenceSentencePairList
+from .common import SchemaPromptSpecReferenceSentencePairListDict
+from .common import SchemaPromptSpecReferenceSentencePairListOrDict
+from .common import SchemaPromptSpecReferenceSentencePairOrDict
+from .common import SchemaPromptSpecStructuredPrompt
+from .common import SchemaPromptSpecStructuredPromptDict
+from .common import SchemaPromptSpecStructuredPromptOrDict
+from .common import SchemaPromptSpecTranslationExample
+from .common import SchemaPromptSpecTranslationExampleDict
+from .common import SchemaPromptSpecTranslationExampleOrDict
+from .common import SchemaPromptSpecTranslationFileInputSource
+from .common import SchemaPromptSpecTranslationFileInputSourceDict
+from .common import SchemaPromptSpecTranslationFileInputSourceOrDict
+from .common import SchemaPromptSpecTranslationGcsInputSource
+from .common import SchemaPromptSpecTranslationGcsInputSourceDict
+from .common import SchemaPromptSpecTranslationGcsInputSourceOrDict
+from .common import SchemaPromptSpecTranslationOption
+from .common import SchemaPromptSpecTranslationOptionDict
+from .common import SchemaPromptSpecTranslationOptionOrDict
+from .common import SchemaPromptSpecTranslationPrompt
+from .common import SchemaPromptSpecTranslationPromptDict
+from .common import SchemaPromptSpecTranslationPromptOrDict
+from .common import SchemaPromptSpecTranslationSentenceFileInput
+from .common import SchemaPromptSpecTranslationSentenceFileInputDict
+from .common import SchemaPromptSpecTranslationSentenceFileInputOrDict
+from .common import SchemaTablesDatasetMetadata
+from .common import SchemaTablesDatasetMetadataBigQuerySource
+from .common import SchemaTablesDatasetMetadataBigQuerySourceDict
+from .common import SchemaTablesDatasetMetadataBigQuerySourceOrDict
+from .common import SchemaTablesDatasetMetadataDict
+from .common import SchemaTablesDatasetMetadataInputConfig
+from .common import SchemaTablesDatasetMetadataInputConfigDict
+from .common import SchemaTablesDatasetMetadataInputConfigOrDict
+from .common import SchemaTablesDatasetMetadataOrDict
+from .common import SchemaTextPromptDatasetMetadata
+from .common import SchemaTextPromptDatasetMetadataDict
+from .common import SchemaTextPromptDatasetMetadataOrDict
+from .common import SecretEnvVar
+from .common import SecretEnvVarDict
+from .common import SecretEnvVarOrDict
+from .common import SecretRef
+from .common import SecretRefDict
+from .common import SecretRefOrDict
+from .common import Session
+from .common import SessionDict
+from .common import SessionEvent
+from .common import SessionEventDict
+from .common import SessionEventOrDict
+from .common import SessionOrDict
+from .common import Skill
+from .common import SkillDict
+from .common import SkillOperation
+from .common import SkillOperationDict
+from .common import SkillOperationOrDict
+from .common import SkillOrDict
+from .common import SkillState
+from .common import State
+from .common import Strategy
+from .common import StructuredMemoryConfig
+from .common import StructuredMemoryConfigDict
+from .common import StructuredMemoryConfigOrDict
+from .common import StructuredMemorySchemaConfig
+from .common import StructuredMemorySchemaConfigDict
+from .common import StructuredMemorySchemaConfigOrDict
+from .common import SummaryMetric
+from .common import SummaryMetricDict
+from .common import SummaryMetricOrDict
+from .common import TaskArtifact
+from .common import TaskArtifactChange
+from .common import TaskArtifactChangeDict
+from .common import TaskArtifactChangeOrDict
+from .common import TaskArtifactDict
+from .common import TaskArtifactOrDict
+from .common import TaskEvent
+from .common import TaskEventData
+from .common import TaskEventDataDict
+from .common import TaskEventDataOrDict
+from .common import TaskEventDict
+from .common import TaskEventOrDict
+from .common import TaskMessage
+from .common import TaskMessageDict
+from .common import TaskMessageOrDict
+from .common import TaskMetadataChange
+from .common import TaskMetadataChangeDict
+from .common import TaskMetadataChangeOrDict
+from .common import TaskOutput
+from .common import TaskOutputChange
+from .common import TaskOutputChangeDict
+from .common import TaskOutputChangeOrDict
+from .common import TaskOutputDict
+from .common import TaskOutputOrDict
+from .common import TaskStateChange
+from .common import TaskStateChangeDict
+from .common import TaskStateChangeOrDict
+from .common import TaskStatusDetails
+from .common import TaskStatusDetailsChange
+from .common import TaskStatusDetailsChangeDict
+from .common import TaskStatusDetailsChangeOrDict
+from .common import TaskStatusDetailsDict
+from .common import TaskStatusDetailsOrDict
+from .common import ToolCallValidInput
+from .common import ToolCallValidInputDict
+from .common import ToolCallValidInputOrDict
+from .common import ToolCallValidInstance
+from .common import ToolCallValidInstanceDict
+from .common import ToolCallValidInstanceOrDict
+from .common import ToolCallValidMetricValue
+from .common import ToolCallValidMetricValueDict
+from .common import ToolCallValidMetricValueOrDict
+from .common import ToolCallValidResults
+from .common import ToolCallValidResultsDict
+from .common import ToolCallValidResultsOrDict
+from .common import ToolCallValidSpec
+from .common import ToolCallValidSpecDict
+from .common import ToolCallValidSpecOrDict
+from .common import ToolNameMatchInput
+from .common import ToolNameMatchInputDict
+from .common import ToolNameMatchInputOrDict
+from .common import ToolNameMatchInstance
+from .common import ToolNameMatchInstanceDict
+from .common import ToolNameMatchInstanceOrDict
+from .common import ToolNameMatchMetricValue
+from .common import ToolNameMatchMetricValueDict
+from .common import ToolNameMatchMetricValueOrDict
+from .common import ToolNameMatchResults
+from .common import ToolNameMatchResultsDict
+from .common import ToolNameMatchResultsOrDict
+from .common import ToolNameMatchSpec
+from .common import ToolNameMatchSpecDict
+from .common import ToolNameMatchSpecOrDict
+from .common import ToolParameterKeyMatchInput
+from .common import ToolParameterKeyMatchInputDict
+from .common import ToolParameterKeyMatchInputOrDict
+from .common import ToolParameterKeyMatchInstance
+from .common import ToolParameterKeyMatchInstanceDict
+from .common import ToolParameterKeyMatchInstanceOrDict
+from .common import ToolParameterKeyMatchMetricValue
+from .common import ToolParameterKeyMatchMetricValueDict
+from .common import ToolParameterKeyMatchMetricValueOrDict
+from .common import ToolParameterKeyMatchResults
+from .common import ToolParameterKeyMatchResultsDict
+from .common import ToolParameterKeyMatchResultsOrDict
+from .common import ToolParameterKeyMatchSpec
+from .common import ToolParameterKeyMatchSpecDict
+from .common import ToolParameterKeyMatchSpecOrDict
+from .common import ToolParameterKVMatchInput
+from .common import ToolParameterKVMatchInputDict
+from .common import ToolParameterKVMatchInputOrDict
+from .common import ToolParameterKVMatchInstance
+from .common import ToolParameterKVMatchInstanceDict
+from .common import ToolParameterKVMatchInstanceOrDict
+from .common import ToolParameterKVMatchMetricValue
+from .common import ToolParameterKVMatchMetricValueDict
+from .common import ToolParameterKVMatchMetricValueOrDict
+from .common import ToolParameterKVMatchResults
+from .common import ToolParameterKVMatchResultsDict
+from .common import ToolParameterKVMatchResultsOrDict
+from .common import ToolParameterKVMatchSpec
+from .common import ToolParameterKVMatchSpecDict
+from .common import ToolParameterKVMatchSpecOrDict
+from .common import TuningResourceUsageAssessmentConfig
+from .common import TuningResourceUsageAssessmentConfigDict
+from .common import TuningResourceUsageAssessmentConfigOrDict
+from .common import TuningResourceUsageAssessmentResult
+from .common import TuningResourceUsageAssessmentResultDict
+from .common import TuningResourceUsageAssessmentResultOrDict
+from .common import TuningValidationAssessmentConfig
+from .common import TuningValidationAssessmentConfigDict
+from .common import TuningValidationAssessmentConfigOrDict
+from .common import TuningValidationAssessmentResult
+from .common import TuningValidationAssessmentResultDict
+from .common import TuningValidationAssessmentResultOrDict
+from .common import Type
+from .common import UnifiedMetric
+from .common import UnifiedMetricDict
+from .common import UnifiedMetricOrDict
+from .common import UpdateAgentEngineConfig
+from .common import UpdateAgentEngineConfigDict
+from .common import UpdateAgentEngineConfigOrDict
+from .common import UpdateAgentEngineMemoryConfig
+from .common import UpdateAgentEngineMemoryConfigDict
+from .common import UpdateAgentEngineMemoryConfigOrDict
+from .common import UpdateAgentEngineSessionConfig
+from .common import UpdateAgentEngineSessionConfigDict
+from .common import UpdateAgentEngineSessionConfigOrDict
+from .common import UpdatePromptConfig
+from .common import UpdatePromptConfigDict
+from .common import UpdatePromptConfigOrDict
+from .common import VertexBaseConfig
+from .common import VertexBaseConfigDict
+from .common import VertexBaseConfigOrDict
+from .common import WinRateStats
+from .common import WinRateStatsDict
+from .common import WinRateStatsOrDict
+from .common import WorkerPoolSpec
+from .common import WorkerPoolSpecDict
+from .common import WorkerPoolSpecOrDict
+
+__all__ = [
+    "DeleteAgentEngineTaskConfig",
+    "DeleteAgentEngineTaskConfigDict",
+    "DeleteAgentEngineTaskConfigOrDict",
+    "GetAgentEngineTaskConfig",
+    "GetAgentEngineTaskConfigDict",
+    "GetAgentEngineTaskConfigOrDict",
+    "TaskArtifact",
+    "TaskArtifactDict",
+    "TaskArtifactOrDict",
+    "TaskOutput",
+    "TaskOutputDict",
+    "TaskOutputOrDict",
+    "TaskMessage",
+    "TaskMessageDict",
+    "TaskMessageOrDict",
+    "TaskStatusDetails",
+    "TaskStatusDetailsDict",
+    "TaskStatusDetailsOrDict",
+    "A2aTask",
+    "A2aTaskDict",
+    "A2aTaskOrDict",
+    "ListAgentEngineTasksConfig",
+    "ListAgentEngineTasksConfigDict",
+    "ListAgentEngineTasksConfigOrDict",
+    "ListAgentEngineTasksResponse",
+    "ListAgentEngineTasksResponseDict",
+    "ListAgentEngineTasksResponseOrDict",
+    "CreateAgentEngineTaskConfig",
+    "CreateAgentEngineTaskConfigDict",
+    "CreateAgentEngineTaskConfigOrDict",
+    "TaskMetadataChange",
+    "TaskMetadataChangeDict",
+    "TaskMetadataChangeOrDict",
+    "TaskArtifactChange",
+    "TaskArtifactChangeDict",
+    "TaskArtifactChangeOrDict",
+    "TaskOutputChange",
+    "TaskOutputChangeDict",
+    "TaskOutputChangeOrDict",
+    "TaskStateChange",
+    "TaskStateChangeDict",
+    "TaskStateChangeOrDict",
+    "TaskStatusDetailsChange",
+    "TaskStatusDetailsChangeDict",
+    "TaskStatusDetailsChangeOrDict",
+    "TaskEventData",
+    "TaskEventDataDict",
+    "TaskEventDataOrDict",
+    "TaskEvent",
+    "TaskEventDict",
+    "TaskEventOrDict",
+    "AppendAgentEngineTaskEventConfig",
+    "AppendAgentEngineTaskEventConfigDict",
+    "AppendAgentEngineTaskEventConfigOrDict",
+    "AppendAgentEngineTaskEventResponse",
+    "AppendAgentEngineTaskEventResponseDict",
+    "AppendAgentEngineTaskEventResponseOrDict",
+    "ListAgentEngineTaskEventsConfig",
+    "ListAgentEngineTaskEventsConfigDict",
+    "ListAgentEngineTaskEventsConfigOrDict",
+    "ListAgentEngineTaskEventsResponse",
+    "ListAgentEngineTaskEventsResponseDict",
+    "ListAgentEngineTaskEventsResponseOrDict",
+    "CreateEvaluationItemConfig",
+    "CreateEvaluationItemConfigDict",
+    "CreateEvaluationItemConfigOrDict",
+    "PromptTemplateData",
+    "PromptTemplateDataDict",
+    "PromptTemplateDataOrDict",
+    "EvaluationPrompt",
+    "EvaluationPromptDict",
+    "EvaluationPromptOrDict",
+    "CandidateResponse",
+    "CandidateResponseDict",
+    "CandidateResponseOrDict",
+    "EvaluationItemRequest",
+    "EvaluationItemRequestDict",
+    "EvaluationItemRequestOrDict",
+    "EvaluationItemResult",
+    "EvaluationItemResultDict",
+    "EvaluationItemResultOrDict",
+    "EvaluationItem",
+    "EvaluationItemDict",
+    "EvaluationItemOrDict",
+    "Metric",
+    "MetricDict",
+    "MetricOrDict",
+    "CreateEvaluationMetricConfig",
+    "CreateEvaluationMetricConfigDict",
+    "CreateEvaluationMetricConfigOrDict",
+    "CustomCodeExecutionSpec",
+    "CustomCodeExecutionSpecDict",
+    "CustomCodeExecutionSpecOrDict",
+    "UnifiedMetric",
+    "UnifiedMetricDict",
+    "UnifiedMetricOrDict",
+    "EvaluationMetric",
+    "EvaluationMetricDict",
+    "EvaluationMetricOrDict",
+    "SamplingConfig",
+    "SamplingConfigDict",
+    "SamplingConfigOrDict",
+    "BigQueryRequestSet",
+    "BigQueryRequestSetDict",
+    "BigQueryRequestSetOrDict",
+    "EvaluationRunDataSource",
+    "EvaluationRunDataSourceDict",
+    "EvaluationRunDataSourceOrDict",
+    "EvaluationRunMetric",
+    "EvaluationRunMetricDict",
+    "EvaluationRunMetricOrDict",
+    "EvaluationRunPromptTemplate",
+    "EvaluationRunPromptTemplateDict",
+    "EvaluationRunPromptTemplateOrDict",
+    "LossAnalysisConfig",
+    "LossAnalysisConfigDict",
+    "LossAnalysisConfigOrDict",
+    "EvaluationRunConfig",
+    "EvaluationRunConfigDict",
+    "EvaluationRunConfigOrDict",
+    "EvaluationRunAgentConfig",
+    "EvaluationRunAgentConfigDict",
+    "EvaluationRunAgentConfigOrDict",
+    "AgentRunConfig",
+    "AgentRunConfigDict",
+    "AgentRunConfigOrDict",
+    "EvaluationRunInferenceConfig",
+    "EvaluationRunInferenceConfigDict",
+    "EvaluationRunInferenceConfigOrDict",
+    "CreateEvaluationRunConfig",
+    "CreateEvaluationRunConfigDict",
+    "CreateEvaluationRunConfigOrDict",
+    "SummaryMetric",
+    "SummaryMetricDict",
+    "SummaryMetricOrDict",
+    "LossTaxonomyEntry",
+    "LossTaxonomyEntryDict",
+    "LossTaxonomyEntryOrDict",
+    "FailedRubric",
+    "FailedRubricDict",
+    "FailedRubricOrDict",
+    "LossExample",
+    "LossExampleDict",
+    "LossExampleOrDict",
+    "LossCluster",
+    "LossClusterDict",
+    "LossClusterOrDict",
+    "LossAnalysisResult",
+    "LossAnalysisResultDict",
+    "LossAnalysisResultOrDict",
+    "EvaluationRunResults",
+    "EvaluationRunResultsDict",
+    "EvaluationRunResultsOrDict",
+    "EvalCaseMetricResult",
+    "EvalCaseMetricResultDict",
+    "EvalCaseMetricResultOrDict",
+    "ResponseCandidateResult",
+    "ResponseCandidateResultDict",
+    "ResponseCandidateResultOrDict",
+    "EvalCaseResult",
+    "EvalCaseResultDict",
+    "EvalCaseResultOrDict",
+    "AggregatedMetricResult",
+    "AggregatedMetricResultDict",
+    "AggregatedMetricResultOrDict",
+    "WinRateStats",
+    "WinRateStatsDict",
+    "WinRateStatsOrDict",
+    "ResponseCandidate",
+    "ResponseCandidateDict",
+    "ResponseCandidateOrDict",
+    "EvalCase",
+    "EvalCaseDict",
+    "EvalCaseOrDict",
+    "EvaluationDataset",
+    "EvaluationDatasetDict",
+    "EvaluationDatasetOrDict",
+    "EvaluationRunMetadata",
+    "EvaluationRunMetadataDict",
+    "EvaluationRunMetadataOrDict",
+    "EvaluationResult",
+    "EvaluationResultDict",
+    "EvaluationResultOrDict",
+    "EvaluationRun",
+    "EvaluationRunDict",
+    "EvaluationRunOrDict",
+    "CreateEvaluationSetConfig",
+    "CreateEvaluationSetConfigDict",
+    "CreateEvaluationSetConfigOrDict",
+    "EvaluationSet",
+    "EvaluationSetDict",
+    "EvaluationSetOrDict",
+    "DeleteEvaluationMetricConfig",
+    "DeleteEvaluationMetricConfigDict",
+    "DeleteEvaluationMetricConfigOrDict",
+    "DeleteEvaluationMetricOperation",
+    "DeleteEvaluationMetricOperationDict",
+    "DeleteEvaluationMetricOperationOrDict",
+    "BleuInstance",
+    "BleuInstanceDict",
+    "BleuInstanceOrDict",
+    "BleuInput",
+    "BleuInputDict",
+    "BleuInputOrDict",
+    "ExactMatchInstance",
+    "ExactMatchInstanceDict",
+    "ExactMatchInstanceOrDict",
+    "ExactMatchSpec",
+    "ExactMatchSpecDict",
+    "ExactMatchSpecOrDict",
+    "ExactMatchInput",
+    "ExactMatchInputDict",
+    "ExactMatchInputOrDict",
+    "RougeInstance",
+    "RougeInstanceDict",
+    "RougeInstanceOrDict",
+    "RougeInput",
+    "RougeInputDict",
+    "RougeInputOrDict",
+    "ContentMap",
+    "ContentMapDict",
+    "ContentMapOrDict",
+    "PointwiseMetricInstance",
+    "PointwiseMetricInstanceDict",
+    "PointwiseMetricInstanceOrDict",
+    "PointwiseMetricInput",
+    "PointwiseMetricInputDict",
+    "PointwiseMetricInputOrDict",
+    "PairwiseMetricInstance",
+    "PairwiseMetricInstanceDict",
+    "PairwiseMetricInstanceOrDict",
+    "PairwiseMetricInput",
+    "PairwiseMetricInputDict",
+    "PairwiseMetricInputOrDict",
+    "ToolCallValidInstance",
+    "ToolCallValidInstanceDict",
+    "ToolCallValidInstanceOrDict",
+    "ToolCallValidSpec",
+    "ToolCallValidSpecDict",
+    "ToolCallValidSpecOrDict",
+    "ToolCallValidInput",
+    "ToolCallValidInputDict",
+    "ToolCallValidInputOrDict",
+    "ToolNameMatchInstance",
+    "ToolNameMatchInstanceDict",
+    "ToolNameMatchInstanceOrDict",
+    "ToolNameMatchSpec",
+    "ToolNameMatchSpecDict",
+    "ToolNameMatchSpecOrDict",
+    "ToolNameMatchInput",
+    "ToolNameMatchInputDict",
+    "ToolNameMatchInputOrDict",
+    "ToolParameterKeyMatchInstance",
+    "ToolParameterKeyMatchInstanceDict",
+    "ToolParameterKeyMatchInstanceOrDict",
+    "ToolParameterKeyMatchSpec",
+    "ToolParameterKeyMatchSpecDict",
+    "ToolParameterKeyMatchSpecOrDict",
+    "ToolParameterKeyMatchInput",
+    "ToolParameterKeyMatchInputDict",
+    "ToolParameterKeyMatchInputOrDict",
+    "ToolParameterKVMatchInstance",
+    "ToolParameterKVMatchInstanceDict",
+    "ToolParameterKVMatchInstanceOrDict",
+    "ToolParameterKVMatchSpec",
+    "ToolParameterKVMatchSpecDict",
+    "ToolParameterKVMatchSpecOrDict",
+    "ToolParameterKVMatchInput",
+    "ToolParameterKVMatchInputDict",
+    "ToolParameterKVMatchInputOrDict",
+    "MapInstance",
+    "MapInstanceDict",
+    "MapInstanceOrDict",
+    "EvaluationInstance",
+    "EvaluationInstanceDict",
+    "EvaluationInstanceOrDict",
+    "EvaluateInstancesConfig",
+    "EvaluateInstancesConfigDict",
+    "EvaluateInstancesConfigOrDict",
+    "RubricBasedMetricSpec",
+    "RubricBasedMetricSpecDict",
+    "RubricBasedMetricSpecOrDict",
+    "RubricEnhancedContents",
+    "RubricEnhancedContentsDict",
+    "RubricEnhancedContentsOrDict",
+    "RubricBasedMetricInstance",
+    "RubricBasedMetricInstanceDict",
+    "RubricBasedMetricInstanceOrDict",
+    "RubricBasedMetricInput",
+    "RubricBasedMetricInputDict",
+    "RubricBasedMetricInputOrDict",
+    "MetricSource",
+    "MetricSourceDict",
+    "MetricSourceOrDict",
+    "MetricResult",
+    "MetricResultDict",
+    "MetricResultOrDict",
+    "BleuResults",
+    "BleuResultsDict",
+    "BleuResultsOrDict",
+    "ExactMatchResults",
+    "ExactMatchResultsDict",
+    "ExactMatchResultsOrDict",
+    "RougeResults",
+    "RougeResultsDict",
+    "RougeResultsOrDict",
+    "RubricBasedMetricResult",
+    "RubricBasedMetricResultDict",
+    "RubricBasedMetricResultOrDict",
+    "CometResult",
+    "CometResultDict",
+    "CometResultOrDict",
+    "MetricxResult",
+    "MetricxResultDict",
+    "MetricxResultOrDict",
+    "ToolCallValidMetricValue",
+    "ToolCallValidMetricValueDict",
+    "ToolCallValidMetricValueOrDict",
+    "ToolCallValidResults",
+    "ToolCallValidResultsDict",
+    "ToolCallValidResultsOrDict",
+    "ToolNameMatchMetricValue",
+    "ToolNameMatchMetricValueDict",
+    "ToolNameMatchMetricValueOrDict",
+    "ToolNameMatchResults",
+    "ToolNameMatchResultsDict",
+    "ToolNameMatchResultsOrDict",
+    "ToolParameterKeyMatchMetricValue",
+    "ToolParameterKeyMatchMetricValueDict",
+    "ToolParameterKeyMatchMetricValueOrDict",
+    "ToolParameterKeyMatchResults",
+    "ToolParameterKeyMatchResultsDict",
+    "ToolParameterKeyMatchResultsOrDict",
+    "ToolParameterKVMatchMetricValue",
+    "ToolParameterKVMatchMetricValueDict",
+    "ToolParameterKVMatchMetricValueOrDict",
+    "ToolParameterKVMatchResults",
+    "ToolParameterKVMatchResultsDict",
+    "ToolParameterKVMatchResultsOrDict",
+    "EvaluateInstancesResponse",
+    "EvaluateInstancesResponseDict",
+    "EvaluateInstancesResponseOrDict",
+    "GenerateUserScenariosConfig",
+    "GenerateUserScenariosConfigDict",
+    "GenerateUserScenariosConfigOrDict",
+    "GenerateUserScenariosResponse",
+    "GenerateUserScenariosResponseDict",
+    "GenerateUserScenariosResponseOrDict",
+    "GenerateLossClustersConfig",
+    "GenerateLossClustersConfigDict",
+    "GenerateLossClustersConfigOrDict",
+    "GenerateLossClustersResponse",
+    "GenerateLossClustersResponseDict",
+    "GenerateLossClustersResponseOrDict",
+    "GenerateLossClustersOperation",
+    "GenerateLossClustersOperationDict",
+    "GenerateLossClustersOperationOrDict",
+    "RubricGenerationConfig",
+    "RubricGenerationConfigDict",
+    "RubricGenerationConfigOrDict",
+    "GenerateInstanceRubricsResponse",
+    "GenerateInstanceRubricsResponseDict",
+    "GenerateInstanceRubricsResponseOrDict",
+    "GetEvaluationMetricConfig",
+    "GetEvaluationMetricConfigDict",
+    "GetEvaluationMetricConfigOrDict",
+    "GetEvaluationRunConfig",
+    "GetEvaluationRunConfigDict",
+    "GetEvaluationRunConfigOrDict",
+    "GetEvaluationSetConfig",
+    "GetEvaluationSetConfigDict",
+    "GetEvaluationSetConfigOrDict",
+    "GetEvaluationItemConfig",
+    "GetEvaluationItemConfigDict",
+    "GetEvaluationItemConfigOrDict",
+    "ListEvaluationMetricsConfig",
+    "ListEvaluationMetricsConfigDict",
+    "ListEvaluationMetricsConfigOrDict",
+    "ListEvaluationMetricsResponse",
+    "ListEvaluationMetricsResponseDict",
+    "ListEvaluationMetricsResponseOrDict",
+    "OptimizeConfig",
+    "OptimizeConfigDict",
+    "OptimizeConfigOrDict",
+    "OptimizeResponseEndpoint",
+    "OptimizeResponseEndpointDict",
+    "OptimizeResponseEndpointOrDict",
+    "DnsPeeringConfig",
+    "DnsPeeringConfigDict",
+    "DnsPeeringConfigOrDict",
+    "PscInterfaceConfig",
+    "PscInterfaceConfigDict",
+    "PscInterfaceConfigOrDict",
+    "Scheduling",
+    "SchedulingDict",
+    "SchedulingOrDict",
+    "EnvVar",
+    "EnvVarDict",
+    "EnvVarOrDict",
+    "ContainerSpec",
+    "ContainerSpecDict",
+    "ContainerSpecOrDict",
+    "DiskSpec",
+    "DiskSpecDict",
+    "DiskSpecOrDict",
+    "LustreMount",
+    "LustreMountDict",
+    "LustreMountOrDict",
+    "ReservationAffinity",
+    "ReservationAffinityDict",
+    "ReservationAffinityOrDict",
+    "MachineSpec",
+    "MachineSpecDict",
+    "MachineSpecOrDict",
+    "NfsMount",
+    "NfsMountDict",
+    "NfsMountOrDict",
+    "PythonPackageSpec",
+    "PythonPackageSpecDict",
+    "PythonPackageSpecOrDict",
+    "WorkerPoolSpec",
+    "WorkerPoolSpecDict",
+    "WorkerPoolSpecOrDict",
+    "CustomJobSpec",
+    "CustomJobSpecDict",
+    "CustomJobSpecOrDict",
+    "CustomJob",
+    "CustomJobDict",
+    "CustomJobOrDict",
+    "VertexBaseConfig",
+    "VertexBaseConfigDict",
+    "VertexBaseConfigOrDict",
+    "CancelQueryJobAgentEngineConfig",
+    "CancelQueryJobAgentEngineConfigDict",
+    "CancelQueryJobAgentEngineConfigOrDict",
+    "CancelQueryJobResult",
+    "CancelQueryJobResultDict",
+    "CancelQueryJobResultOrDict",
+    "CheckQueryJobAgentEngineConfig",
+    "CheckQueryJobAgentEngineConfigDict",
+    "CheckQueryJobAgentEngineConfigOrDict",
+    "CheckQueryJobResult",
+    "CheckQueryJobResultDict",
+    "CheckQueryJobResultOrDict",
+    "_RunQueryJobAgentEngineConfig",
+    "_RunQueryJobAgentEngineConfigDict",
+    "_RunQueryJobAgentEngineConfigOrDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEvent",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventOrDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSource",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceOrDict",
+    "MemoryTopicId",
+    "MemoryTopicIdDict",
+    "MemoryTopicIdOrDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryOrDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExample",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleDict",
+    "MemoryBankCustomizationConfigGenerateMemoriesExampleOrDict",
+    "MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic",
+    "MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicDict",
+    "MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicOrDict",
+    "MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopic",
+    "MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicDict",
+    "MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict",
+    "MemoryBankCustomizationConfigMemoryTopic",
+    "MemoryBankCustomizationConfigMemoryTopicDict",
+    "MemoryBankCustomizationConfigMemoryTopicOrDict",
+    "MemoryBankCustomizationConfigConsolidationConfig",
+    "MemoryBankCustomizationConfigConsolidationConfigDict",
+    "MemoryBankCustomizationConfigConsolidationConfigOrDict",
+    "MemoryBankCustomizationConfig",
+    "MemoryBankCustomizationConfigDict",
+    "MemoryBankCustomizationConfigOrDict",
+    "MemoryGenerationTriggerConfigGenerationTriggerRule",
+    "MemoryGenerationTriggerConfigGenerationTriggerRuleDict",
+    "MemoryGenerationTriggerConfigGenerationTriggerRuleOrDict",
+    "MemoryGenerationTriggerConfig",
+    "MemoryGenerationTriggerConfigDict",
+    "MemoryGenerationTriggerConfigOrDict",
+    "ReasoningEngineContextSpecMemoryBankConfigGenerationConfig",
+    "ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict",
+    "ReasoningEngineContextSpecMemoryBankConfigGenerationConfigOrDict",
+    "ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig",
+    "ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigDict",
+    "ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigOrDict",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfig",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigDict",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigOrDict",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfig",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict",
+    "ReasoningEngineContextSpecMemoryBankConfigTtlConfigOrDict",
+    "StructuredMemorySchemaConfig",
+    "StructuredMemorySchemaConfigDict",
+    "StructuredMemorySchemaConfigOrDict",
+    "StructuredMemoryConfig",
+    "StructuredMemoryConfigDict",
+    "StructuredMemoryConfigOrDict",
+    "ReasoningEngineContextSpecMemoryBankConfig",
+    "ReasoningEngineContextSpecMemoryBankConfigDict",
+    "ReasoningEngineContextSpecMemoryBankConfigOrDict",
+    "ReasoningEngineContextSpec",
+    "ReasoningEngineContextSpecDict",
+    "ReasoningEngineContextSpecOrDict",
+    "SecretRef",
+    "SecretRefDict",
+    "SecretRefOrDict",
+    "SecretEnvVar",
+    "SecretEnvVarDict",
+    "SecretEnvVarOrDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfig",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigOrDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfig",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigOrDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfig",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict",
+    "ReasoningEngineSpecDeploymentSpecAgentGatewayConfigOrDict",
+    "KeepAliveProbeHttpGet",
+    "KeepAliveProbeHttpGetDict",
+    "KeepAliveProbeHttpGetOrDict",
+    "KeepAliveProbe",
+    "KeepAliveProbeDict",
+    "KeepAliveProbeOrDict",
+    "ReasoningEngineSpecDeploymentSpec",
+    "ReasoningEngineSpecDeploymentSpecDict",
+    "ReasoningEngineSpecDeploymentSpecOrDict",
+    "ReasoningEngineSpecPackageSpec",
+    "ReasoningEngineSpecPackageSpecDict",
+    "ReasoningEngineSpecPackageSpecOrDict",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfig",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigDict",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigOrDict",
+    "ReasoningEngineSpecSourceCodeSpecInlineSource",
+    "ReasoningEngineSpecSourceCodeSpecInlineSourceDict",
+    "ReasoningEngineSpecSourceCodeSpecInlineSourceOrDict",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSource",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict",
+    "ReasoningEngineSpecSourceCodeSpecAgentConfigSourceOrDict",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigOrDict",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectSource",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceDict",
+    "ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceOrDict",
+    "ReasoningEngineSpecSourceCodeSpecImageSpec",
+    "ReasoningEngineSpecSourceCodeSpecImageSpecDict",
+    "ReasoningEngineSpecSourceCodeSpecImageSpecOrDict",
+    "ReasoningEngineSpecSourceCodeSpecPythonSpec",
+    "ReasoningEngineSpecSourceCodeSpecPythonSpecDict",
+    "ReasoningEngineSpecSourceCodeSpecPythonSpecOrDict",
+    "ReasoningEngineSpecSourceCodeSpec",
+    "ReasoningEngineSpecSourceCodeSpecDict",
+    "ReasoningEngineSpecSourceCodeSpecOrDict",
+    "ReasoningEngineSpecContainerSpec",
+    "ReasoningEngineSpecContainerSpecDict",
+    "ReasoningEngineSpecContainerSpecOrDict",
+    "ReasoningEngineSpec",
+    "ReasoningEngineSpecDict",
+    "ReasoningEngineSpecOrDict",
+    "ReasoningEngineTrafficConfigTrafficSplitAlwaysLatest",
+    "ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestDict",
+    "ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestOrDict",
+    "ReasoningEngineTrafficConfigTrafficSplitManualTarget",
+    "ReasoningEngineTrafficConfigTrafficSplitManualTargetDict",
+    "ReasoningEngineTrafficConfigTrafficSplitManualTargetOrDict",
+    "ReasoningEngineTrafficConfigTrafficSplitManual",
+    "ReasoningEngineTrafficConfigTrafficSplitManualDict",
+    "ReasoningEngineTrafficConfigTrafficSplitManualOrDict",
+    "ReasoningEngineTrafficConfig",
+    "ReasoningEngineTrafficConfigDict",
+    "ReasoningEngineTrafficConfigOrDict",
+    "ReasoningEngine",
+    "ReasoningEngineDict",
+    "ReasoningEngineOrDict",
+    "AgentEngineOperation",
+    "AgentEngineOperationDict",
+    "AgentEngineOperationOrDict",
+    "CreateAgentEngineConfig",
+    "CreateAgentEngineConfigDict",
+    "CreateAgentEngineConfigOrDict",
+    "DeleteAgentEngineConfig",
+    "DeleteAgentEngineConfigDict",
+    "DeleteAgentEngineConfigOrDict",
+    "DeleteAgentEngineOperation",
+    "DeleteAgentEngineOperationDict",
+    "DeleteAgentEngineOperationOrDict",
+    "GetAgentEngineConfig",
+    "GetAgentEngineConfigDict",
+    "GetAgentEngineConfigOrDict",
+    "ListAgentEngineConfig",
+    "ListAgentEngineConfigDict",
+    "ListAgentEngineConfigOrDict",
+    "ListReasoningEnginesResponse",
+    "ListReasoningEnginesResponseDict",
+    "ListReasoningEnginesResponseOrDict",
+    "GetAgentEngineOperationConfig",
+    "GetAgentEngineOperationConfigDict",
+    "GetAgentEngineOperationConfigOrDict",
+    "QueryAgentEngineConfig",
+    "QueryAgentEngineConfigDict",
+    "QueryAgentEngineConfigOrDict",
+    "QueryReasoningEngineResponse",
+    "QueryReasoningEngineResponseDict",
+    "QueryReasoningEngineResponseOrDict",
+    "UpdateAgentEngineConfig",
+    "UpdateAgentEngineConfigDict",
+    "UpdateAgentEngineConfigOrDict",
+    "MemoryMetadataValue",
+    "MemoryMetadataValueDict",
+    "MemoryMetadataValueOrDict",
+    "AgentEngineMemoryConfig",
+    "AgentEngineMemoryConfigDict",
+    "AgentEngineMemoryConfigOrDict",
+    "MemoryStructuredContent",
+    "MemoryStructuredContentDict",
+    "MemoryStructuredContentOrDict",
+    "Memory",
+    "MemoryDict",
+    "MemoryOrDict",
+    "AgentEngineMemoryOperation",
+    "AgentEngineMemoryOperationDict",
+    "AgentEngineMemoryOperationOrDict",
+    "DeleteAgentEngineMemoryConfig",
+    "DeleteAgentEngineMemoryConfigDict",
+    "DeleteAgentEngineMemoryConfigOrDict",
+    "DeleteAgentEngineMemoryOperation",
+    "DeleteAgentEngineMemoryOperationDict",
+    "DeleteAgentEngineMemoryOperationOrDict",
+    "GenerateMemoriesRequestVertexSessionSource",
+    "GenerateMemoriesRequestVertexSessionSourceDict",
+    "GenerateMemoriesRequestVertexSessionSourceOrDict",
+    "GenerateMemoriesRequestDirectContentsSourceEvent",
+    "GenerateMemoriesRequestDirectContentsSourceEventDict",
+    "GenerateMemoriesRequestDirectContentsSourceEventOrDict",
+    "GenerateMemoriesRequestDirectContentsSource",
+    "GenerateMemoriesRequestDirectContentsSourceDict",
+    "GenerateMemoriesRequestDirectContentsSourceOrDict",
+    "GenerateMemoriesRequestDirectMemoriesSourceDirectMemory",
+    "GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict",
+    "GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryOrDict",
+    "GenerateMemoriesRequestDirectMemoriesSource",
+    "GenerateMemoriesRequestDirectMemoriesSourceDict",
+    "GenerateMemoriesRequestDirectMemoriesSourceOrDict",
+    "GenerateAgentEngineMemoriesConfig",
+    "GenerateAgentEngineMemoriesConfigDict",
+    "GenerateAgentEngineMemoriesConfigOrDict",
+    "GenerateMemoriesResponseGeneratedMemory",
+    "GenerateMemoriesResponseGeneratedMemoryDict",
+    "GenerateMemoriesResponseGeneratedMemoryOrDict",
+    "GenerateMemoriesResponse",
+    "GenerateMemoriesResponseDict",
+    "GenerateMemoriesResponseOrDict",
+    "AgentEngineGenerateMemoriesOperation",
+    "AgentEngineGenerateMemoriesOperationDict",
+    "AgentEngineGenerateMemoriesOperationOrDict",
+    "GetAgentEngineMemoryConfig",
+    "GetAgentEngineMemoryConfigDict",
+    "GetAgentEngineMemoryConfigOrDict",
+    "IngestionDirectContentsSourceEvent",
+    "IngestionDirectContentsSourceEventDict",
+    "IngestionDirectContentsSourceEventOrDict",
+    "IngestionDirectContentsSource",
+    "IngestionDirectContentsSourceDict",
+    "IngestionDirectContentsSourceOrDict",
+    "IngestEventsConfig",
+    "IngestEventsConfigDict",
+    "IngestEventsConfigOrDict",
+    "MemoryBankIngestEventsOperation",
+    "MemoryBankIngestEventsOperationDict",
+    "MemoryBankIngestEventsOperationOrDict",
+    "ListAgentEngineMemoryConfig",
+    "ListAgentEngineMemoryConfigDict",
+    "ListAgentEngineMemoryConfigOrDict",
+    "ListReasoningEnginesMemoriesResponse",
+    "ListReasoningEnginesMemoriesResponseDict",
+    "ListReasoningEnginesMemoriesResponseOrDict",
+    "RetrieveMemoriesRequestSimilaritySearchParams",
+    "RetrieveMemoriesRequestSimilaritySearchParamsDict",
+    "RetrieveMemoriesRequestSimilaritySearchParamsOrDict",
+    "RetrieveMemoriesRequestSimpleRetrievalParams",
+    "RetrieveMemoriesRequestSimpleRetrievalParamsDict",
+    "RetrieveMemoriesRequestSimpleRetrievalParamsOrDict",
+    "MemoryFilter",
+    "MemoryFilterDict",
+    "MemoryFilterOrDict",
+    "MemoryConjunctionFilter",
+    "MemoryConjunctionFilterDict",
+    "MemoryConjunctionFilterOrDict",
+    "RetrieveAgentEngineMemoriesConfig",
+    "RetrieveAgentEngineMemoriesConfigDict",
+    "RetrieveAgentEngineMemoriesConfigOrDict",
+    "RetrieveMemoriesResponseRetrievedMemory",
+    "RetrieveMemoriesResponseRetrievedMemoryDict",
+    "RetrieveMemoriesResponseRetrievedMemoryOrDict",
+    "RetrieveMemoriesResponse",
+    "RetrieveMemoriesResponseDict",
+    "RetrieveMemoriesResponseOrDict",
+    "RetrieveMemoryProfilesConfig",
+    "RetrieveMemoryProfilesConfigDict",
+    "RetrieveMemoryProfilesConfigOrDict",
+    "MemoryProfile",
+    "MemoryProfileDict",
+    "MemoryProfileOrDict",
+    "RetrieveProfilesResponse",
+    "RetrieveProfilesResponseDict",
+    "RetrieveProfilesResponseOrDict",
+    "RollbackAgentEngineMemoryConfig",
+    "RollbackAgentEngineMemoryConfigDict",
+    "RollbackAgentEngineMemoryConfigOrDict",
+    "AgentEngineRollbackMemoryOperation",
+    "AgentEngineRollbackMemoryOperationDict",
+    "AgentEngineRollbackMemoryOperationOrDict",
+    "UpdateAgentEngineMemoryConfig",
+    "UpdateAgentEngineMemoryConfigDict",
+    "UpdateAgentEngineMemoryConfigOrDict",
+    "PurgeAgentEngineMemoriesConfig",
+    "PurgeAgentEngineMemoriesConfigDict",
+    "PurgeAgentEngineMemoriesConfigOrDict",
+    "PurgeMemoriesResponse",
+    "PurgeMemoriesResponseDict",
+    "PurgeMemoriesResponseOrDict",
+    "AgentEnginePurgeMemoriesOperation",
+    "AgentEnginePurgeMemoriesOperationDict",
+    "AgentEnginePurgeMemoriesOperationOrDict",
+    "GetAgentEngineMemoryRevisionConfig",
+    "GetAgentEngineMemoryRevisionConfigDict",
+    "GetAgentEngineMemoryRevisionConfigOrDict",
+    "IntermediateExtractedMemory",
+    "IntermediateExtractedMemoryDict",
+    "IntermediateExtractedMemoryOrDict",
+    "MemoryRevision",
+    "MemoryRevisionDict",
+    "MemoryRevisionOrDict",
+    "ListAgentEngineMemoryRevisionsConfig",
+    "ListAgentEngineMemoryRevisionsConfigDict",
+    "ListAgentEngineMemoryRevisionsConfigOrDict",
+    "ListAgentEngineMemoryRevisionsResponse",
+    "ListAgentEngineMemoryRevisionsResponseDict",
+    "ListAgentEngineMemoryRevisionsResponseOrDict",
+    "GetAgentEngineRuntimeRevisionConfig",
+    "GetAgentEngineRuntimeRevisionConfigDict",
+    "GetAgentEngineRuntimeRevisionConfigOrDict",
+    "ReasoningEngineRuntimeRevision",
+    "ReasoningEngineRuntimeRevisionDict",
+    "ReasoningEngineRuntimeRevisionOrDict",
+    "ListAgentEngineRuntimeRevisionsConfig",
+    "ListAgentEngineRuntimeRevisionsConfigDict",
+    "ListAgentEngineRuntimeRevisionsConfigOrDict",
+    "ListReasoningEnginesRuntimeRevisionsResponse",
+    "ListReasoningEnginesRuntimeRevisionsResponseDict",
+    "ListReasoningEnginesRuntimeRevisionsResponseOrDict",
+    "DeleteAgentEngineRuntimeRevisionConfig",
+    "DeleteAgentEngineRuntimeRevisionConfigDict",
+    "DeleteAgentEngineRuntimeRevisionConfigOrDict",
+    "DeleteAgentEngineRuntimeRevisionOperation",
+    "DeleteAgentEngineRuntimeRevisionOperationDict",
+    "DeleteAgentEngineRuntimeRevisionOperationOrDict",
+    "GetDeleteAgentEngineRuntimeRevisionOperationConfig",
+    "GetDeleteAgentEngineRuntimeRevisionOperationConfigDict",
+    "GetDeleteAgentEngineRuntimeRevisionOperationConfigOrDict",
+    "QueryAgentEngineRuntimeRevisionConfig",
+    "QueryAgentEngineRuntimeRevisionConfigDict",
+    "QueryAgentEngineRuntimeRevisionConfigOrDict",
+    "SandboxEnvironmentSpecCodeExecutionEnvironment",
+    "SandboxEnvironmentSpecCodeExecutionEnvironmentDict",
+    "SandboxEnvironmentSpecCodeExecutionEnvironmentOrDict",
+    "SandboxEnvironmentSpecComputerUseEnvironment",
+    "SandboxEnvironmentSpecComputerUseEnvironmentDict",
+    "SandboxEnvironmentSpecComputerUseEnvironmentOrDict",
+    "SandboxEnvironmentSpec",
+    "SandboxEnvironmentSpecDict",
+    "SandboxEnvironmentSpecOrDict",
+    "CreateAgentEngineSandboxConfig",
+    "CreateAgentEngineSandboxConfigDict",
+    "CreateAgentEngineSandboxConfigOrDict",
+    "SandboxEnvironmentConnectionInfo",
+    "SandboxEnvironmentConnectionInfoDict",
+    "SandboxEnvironmentConnectionInfoOrDict",
+    "SandboxEnvironment",
+    "SandboxEnvironmentDict",
+    "SandboxEnvironmentOrDict",
+    "AgentEngineSandboxOperation",
+    "AgentEngineSandboxOperationDict",
+    "AgentEngineSandboxOperationOrDict",
+    "DeleteAgentEngineSandboxConfig",
+    "DeleteAgentEngineSandboxConfigDict",
+    "DeleteAgentEngineSandboxConfigOrDict",
+    "DeleteAgentEngineSandboxOperation",
+    "DeleteAgentEngineSandboxOperationDict",
+    "DeleteAgentEngineSandboxOperationOrDict",
+    "Metadata",
+    "MetadataDict",
+    "MetadataOrDict",
+    "Chunk",
+    "ChunkDict",
+    "ChunkOrDict",
+    "ExecuteCodeAgentEngineSandboxConfig",
+    "ExecuteCodeAgentEngineSandboxConfigDict",
+    "ExecuteCodeAgentEngineSandboxConfigOrDict",
+    "ExecuteSandboxEnvironmentResponse",
+    "ExecuteSandboxEnvironmentResponseDict",
+    "ExecuteSandboxEnvironmentResponseOrDict",
+    "GetAgentEngineSandboxConfig",
+    "GetAgentEngineSandboxConfigDict",
+    "GetAgentEngineSandboxConfigOrDict",
+    "ListAgentEngineSandboxesConfig",
+    "ListAgentEngineSandboxesConfigDict",
+    "ListAgentEngineSandboxesConfigOrDict",
+    "ListAgentEngineSandboxesResponse",
+    "ListAgentEngineSandboxesResponseDict",
+    "ListAgentEngineSandboxesResponseOrDict",
+    "SandboxEnvironmentTemplateCustomContainerSpec",
+    "SandboxEnvironmentTemplateCustomContainerSpecDict",
+    "SandboxEnvironmentTemplateCustomContainerSpecOrDict",
+    "SandboxEnvironmentTemplateNetworkPort",
+    "SandboxEnvironmentTemplateNetworkPortDict",
+    "SandboxEnvironmentTemplateNetworkPortOrDict",
+    "SandboxEnvironmentTemplateResourceRequirements",
+    "SandboxEnvironmentTemplateResourceRequirementsDict",
+    "SandboxEnvironmentTemplateResourceRequirementsOrDict",
+    "SandboxEnvironmentTemplateCustomContainerEnvironment",
+    "SandboxEnvironmentTemplateCustomContainerEnvironmentDict",
+    "SandboxEnvironmentTemplateCustomContainerEnvironmentOrDict",
+    "SandboxEnvironmentTemplateDefaultContainerEnvironment",
+    "SandboxEnvironmentTemplateDefaultContainerEnvironmentDict",
+    "SandboxEnvironmentTemplateDefaultContainerEnvironmentOrDict",
+    "SandboxEnvironmentTemplateEgressControlConfig",
+    "SandboxEnvironmentTemplateEgressControlConfigDict",
+    "SandboxEnvironmentTemplateEgressControlConfigOrDict",
+    "CreateSandboxEnvironmentTemplateConfig",
+    "CreateSandboxEnvironmentTemplateConfigDict",
+    "CreateSandboxEnvironmentTemplateConfigOrDict",
+    "SandboxEnvironmentTemplateWarmPoolConfig",
+    "SandboxEnvironmentTemplateWarmPoolConfigDict",
+    "SandboxEnvironmentTemplateWarmPoolConfigOrDict",
+    "SandboxEnvironmentTemplate",
+    "SandboxEnvironmentTemplateDict",
+    "SandboxEnvironmentTemplateOrDict",
+    "SandboxEnvironmentTemplateOperation",
+    "SandboxEnvironmentTemplateOperationDict",
+    "SandboxEnvironmentTemplateOperationOrDict",
+    "DeleteSandboxEnvironmentTemplateConfig",
+    "DeleteSandboxEnvironmentTemplateConfigDict",
+    "DeleteSandboxEnvironmentTemplateConfigOrDict",
+    "DeleteSandboxEnvironmentTemplateOperation",
+    "DeleteSandboxEnvironmentTemplateOperationDict",
+    "DeleteSandboxEnvironmentTemplateOperationOrDict",
+    "GetSandboxEnvironmentTemplateConfig",
+    "GetSandboxEnvironmentTemplateConfigDict",
+    "GetSandboxEnvironmentTemplateConfigOrDict",
+    "ListSandboxEnvironmentTemplatesConfig",
+    "ListSandboxEnvironmentTemplatesConfigDict",
+    "ListSandboxEnvironmentTemplatesConfigOrDict",
+    "ListSandboxEnvironmentTemplatesResponse",
+    "ListSandboxEnvironmentTemplatesResponseDict",
+    "ListSandboxEnvironmentTemplatesResponseOrDict",
+    "CreateAgentEngineSandboxSnapshotConfig",
+    "CreateAgentEngineSandboxSnapshotConfigDict",
+    "CreateAgentEngineSandboxSnapshotConfigOrDict",
+    "SandboxEnvironmentSnapshot",
+    "SandboxEnvironmentSnapshotDict",
+    "SandboxEnvironmentSnapshotOrDict",
+    "AgentEngineSandboxSnapshotOperation",
+    "AgentEngineSandboxSnapshotOperationDict",
+    "AgentEngineSandboxSnapshotOperationOrDict",
+    "DeleteSandboxEnvironmentSnapshotConfig",
+    "DeleteSandboxEnvironmentSnapshotConfigDict",
+    "DeleteSandboxEnvironmentSnapshotConfigOrDict",
+    "DeleteSandboxEnvironmentSnapshotOperation",
+    "DeleteSandboxEnvironmentSnapshotOperationDict",
+    "DeleteSandboxEnvironmentSnapshotOperationOrDict",
+    "GetSandboxEnvironmentSnapshotConfig",
+    "GetSandboxEnvironmentSnapshotConfigDict",
+    "GetSandboxEnvironmentSnapshotConfigOrDict",
+    "ListSandboxEnvironmentSnapshotsConfig",
+    "ListSandboxEnvironmentSnapshotsConfigDict",
+    "ListSandboxEnvironmentSnapshotsConfigOrDict",
+    "ListSandboxEnvironmentSnapshotsResponse",
+    "ListSandboxEnvironmentSnapshotsResponseDict",
+    "ListSandboxEnvironmentSnapshotsResponseOrDict",
+    "CreateAgentEngineSessionConfig",
+    "CreateAgentEngineSessionConfigDict",
+    "CreateAgentEngineSessionConfigOrDict",
+    "Session",
+    "SessionDict",
+    "SessionOrDict",
+    "AgentEngineSessionOperation",
+    "AgentEngineSessionOperationDict",
+    "AgentEngineSessionOperationOrDict",
+    "DeleteAgentEngineSessionConfig",
+    "DeleteAgentEngineSessionConfigDict",
+    "DeleteAgentEngineSessionConfigOrDict",
+    "DeleteAgentEngineSessionOperation",
+    "DeleteAgentEngineSessionOperationDict",
+    "DeleteAgentEngineSessionOperationOrDict",
+    "GetAgentEngineSessionConfig",
+    "GetAgentEngineSessionConfigDict",
+    "GetAgentEngineSessionConfigOrDict",
+    "ListAgentEngineSessionsConfig",
+    "ListAgentEngineSessionsConfigDict",
+    "ListAgentEngineSessionsConfigOrDict",
+    "ListReasoningEnginesSessionsResponse",
+    "ListReasoningEnginesSessionsResponseDict",
+    "ListReasoningEnginesSessionsResponseOrDict",
+    "UpdateAgentEngineSessionConfig",
+    "UpdateAgentEngineSessionConfigDict",
+    "UpdateAgentEngineSessionConfigOrDict",
+    "EventActions",
+    "EventActionsDict",
+    "EventActionsOrDict",
+    "EventMetadata",
+    "EventMetadataDict",
+    "EventMetadataOrDict",
+    "AppendAgentEngineSessionEventConfig",
+    "AppendAgentEngineSessionEventConfigDict",
+    "AppendAgentEngineSessionEventConfigOrDict",
+    "AppendAgentEngineSessionEventResponse",
+    "AppendAgentEngineSessionEventResponseDict",
+    "AppendAgentEngineSessionEventResponseOrDict",
+    "ListAgentEngineSessionEventsConfig",
+    "ListAgentEngineSessionEventsConfigDict",
+    "ListAgentEngineSessionEventsConfigOrDict",
+    "SessionEvent",
+    "SessionEventDict",
+    "SessionEventOrDict",
+    "ListAgentEngineSessionEventsResponse",
+    "ListAgentEngineSessionEventsResponseDict",
+    "ListAgentEngineSessionEventsResponseOrDict",
+    "GeminiExample",
+    "GeminiExampleDict",
+    "GeminiExampleOrDict",
+    "GeminiTemplateConfig",
+    "GeminiTemplateConfigDict",
+    "GeminiTemplateConfigOrDict",
+    "GeminiRequestReadConfig",
+    "GeminiRequestReadConfigDict",
+    "GeminiRequestReadConfigOrDict",
+    "AssembleDatasetConfig",
+    "AssembleDatasetConfigDict",
+    "AssembleDatasetConfigOrDict",
+    "MultimodalDatasetOperation",
+    "MultimodalDatasetOperationDict",
+    "MultimodalDatasetOperationOrDict",
+    "TuningResourceUsageAssessmentConfig",
+    "TuningResourceUsageAssessmentConfigDict",
+    "TuningResourceUsageAssessmentConfigOrDict",
+    "TuningValidationAssessmentConfig",
+    "TuningValidationAssessmentConfigDict",
+    "TuningValidationAssessmentConfigOrDict",
+    "BatchPredictionResourceUsageAssessmentConfig",
+    "BatchPredictionResourceUsageAssessmentConfigDict",
+    "BatchPredictionResourceUsageAssessmentConfigOrDict",
+    "BatchPredictionValidationAssessmentConfig",
+    "BatchPredictionValidationAssessmentConfigDict",
+    "BatchPredictionValidationAssessmentConfigOrDict",
+    "AssessDatasetConfig",
+    "AssessDatasetConfigDict",
+    "AssessDatasetConfigOrDict",
+    "SchemaTablesDatasetMetadataBigQuerySource",
+    "SchemaTablesDatasetMetadataBigQuerySourceDict",
+    "SchemaTablesDatasetMetadataBigQuerySourceOrDict",
+    "SchemaTablesDatasetMetadataInputConfig",
+    "SchemaTablesDatasetMetadataInputConfigDict",
+    "SchemaTablesDatasetMetadataInputConfigOrDict",
+    "SchemaTablesDatasetMetadata",
+    "SchemaTablesDatasetMetadataDict",
+    "SchemaTablesDatasetMetadataOrDict",
+    "CreateMultimodalDatasetConfig",
+    "CreateMultimodalDatasetConfigDict",
+    "CreateMultimodalDatasetConfigOrDict",
+    "MultimodalDataset",
+    "MultimodalDatasetDict",
+    "MultimodalDatasetOrDict",
+    "GetMultimodalDatasetOperationConfig",
+    "GetMultimodalDatasetOperationConfigDict",
+    "GetMultimodalDatasetOperationConfigOrDict",
+    "ListMultimodalDatasetsConfig",
+    "ListMultimodalDatasetsConfigDict",
+    "ListMultimodalDatasetsConfigOrDict",
+    "ListMultimodalDatasetsResponse",
+    "ListMultimodalDatasetsResponseDict",
+    "ListMultimodalDatasetsResponseOrDict",
+    "SchemaPredictParamsGroundingConfigSourceEntry",
+    "SchemaPredictParamsGroundingConfigSourceEntryDict",
+    "SchemaPredictParamsGroundingConfigSourceEntryOrDict",
+    "SchemaPredictParamsGroundingConfig",
+    "SchemaPredictParamsGroundingConfigDict",
+    "SchemaPredictParamsGroundingConfigOrDict",
+    "SchemaPromptInstancePromptExecution",
+    "SchemaPromptInstancePromptExecutionDict",
+    "SchemaPromptInstancePromptExecutionOrDict",
+    "SchemaPromptSpecPromptMessage",
+    "SchemaPromptSpecPromptMessageDict",
+    "SchemaPromptSpecPromptMessageOrDict",
+    "SchemaPromptSpecMultimodalPrompt",
+    "SchemaPromptSpecMultimodalPromptDict",
+    "SchemaPromptSpecMultimodalPromptOrDict",
+    "SchemaPromptSpecAppBuilderDataLinkedResource",
+    "SchemaPromptSpecAppBuilderDataLinkedResourceDict",
+    "SchemaPromptSpecAppBuilderDataLinkedResourceOrDict",
+    "SchemaPromptSpecAppBuilderData",
+    "SchemaPromptSpecAppBuilderDataDict",
+    "SchemaPromptSpecAppBuilderDataOrDict",
+    "SchemaPromptSpecPartList",
+    "SchemaPromptSpecPartListDict",
+    "SchemaPromptSpecPartListOrDict",
+    "SchemaPromptSpecStructuredPrompt",
+    "SchemaPromptSpecStructuredPromptDict",
+    "SchemaPromptSpecStructuredPromptOrDict",
+    "SchemaPromptSpecReferenceSentencePair",
+    "SchemaPromptSpecReferenceSentencePairDict",
+    "SchemaPromptSpecReferenceSentencePairOrDict",
+    "SchemaPromptSpecReferenceSentencePairList",
+    "SchemaPromptSpecReferenceSentencePairListDict",
+    "SchemaPromptSpecReferenceSentencePairListOrDict",
+    "SchemaPromptSpecTranslationFileInputSource",
+    "SchemaPromptSpecTranslationFileInputSourceDict",
+    "SchemaPromptSpecTranslationFileInputSourceOrDict",
+    "SchemaPromptSpecTranslationGcsInputSource",
+    "SchemaPromptSpecTranslationGcsInputSourceDict",
+    "SchemaPromptSpecTranslationGcsInputSourceOrDict",
+    "SchemaPromptSpecTranslationSentenceFileInput",
+    "SchemaPromptSpecTranslationSentenceFileInputDict",
+    "SchemaPromptSpecTranslationSentenceFileInputOrDict",
+    "SchemaPromptSpecTranslationExample",
+    "SchemaPromptSpecTranslationExampleDict",
+    "SchemaPromptSpecTranslationExampleOrDict",
+    "SchemaPromptSpecTranslationOption",
+    "SchemaPromptSpecTranslationOptionDict",
+    "SchemaPromptSpecTranslationOptionOrDict",
+    "SchemaPromptSpecTranslationPrompt",
+    "SchemaPromptSpecTranslationPromptDict",
+    "SchemaPromptSpecTranslationPromptOrDict",
+    "SchemaPromptApiSchema",
+    "SchemaPromptApiSchemaDict",
+    "SchemaPromptApiSchemaOrDict",
+    "SchemaTextPromptDatasetMetadata",
+    "SchemaTextPromptDatasetMetadataDict",
+    "SchemaTextPromptDatasetMetadataOrDict",
+    "CreateDatasetConfig",
+    "CreateDatasetConfigDict",
+    "CreateDatasetConfigOrDict",
+    "DatasetOperation",
+    "DatasetOperationDict",
+    "DatasetOperationOrDict",
+    "CreateDatasetVersionConfig",
+    "CreateDatasetVersionConfigDict",
+    "CreateDatasetVersionConfigOrDict",
+    "SavedQuery",
+    "SavedQueryDict",
+    "SavedQueryOrDict",
+    "Dataset",
+    "DatasetDict",
+    "DatasetOrDict",
+    "DatasetVersion",
+    "DatasetVersionDict",
+    "DatasetVersionOrDict",
+    "GetDatasetOperationConfig",
+    "GetDatasetOperationConfigDict",
+    "GetDatasetOperationConfigOrDict",
+    "ListPromptsConfig",
+    "ListPromptsConfigDict",
+    "ListPromptsConfigOrDict",
+    "ListDatasetsResponse",
+    "ListDatasetsResponseDict",
+    "ListDatasetsResponseOrDict",
+    "ListDatasetVersionsResponse",
+    "ListDatasetVersionsResponseDict",
+    "ListDatasetVersionsResponseOrDict",
+    "DeletePromptConfig",
+    "DeletePromptConfigDict",
+    "DeletePromptConfigOrDict",
+    "DeletePromptOperation",
+    "DeletePromptOperationDict",
+    "DeletePromptOperationOrDict",
+    "DeletePromptVersionOperation",
+    "DeletePromptVersionOperationDict",
+    "DeletePromptVersionOperationOrDict",
+    "RestoreVersionConfig",
+    "RestoreVersionConfigDict",
+    "RestoreVersionConfigOrDict",
+    "RestoreVersionOperation",
+    "RestoreVersionOperationDict",
+    "RestoreVersionOperationOrDict",
+    "UpdatePromptConfig",
+    "UpdatePromptConfigDict",
+    "UpdatePromptConfigOrDict",
+    "GetSkillConfig",
+    "GetSkillConfigDict",
+    "GetSkillConfigOrDict",
+    "Skill",
+    "SkillDict",
+    "SkillOrDict",
+    "RetrieveSkillsConfig",
+    "RetrieveSkillsConfigDict",
+    "RetrieveSkillsConfigOrDict",
+    "RetrievedSkill",
+    "RetrievedSkillDict",
+    "RetrievedSkillOrDict",
+    "RetrieveSkillsResponse",
+    "RetrieveSkillsResponseDict",
+    "RetrieveSkillsResponseOrDict",
+    "CreateSkillConfig",
+    "CreateSkillConfigDict",
+    "CreateSkillConfigOrDict",
+    "SkillOperation",
+    "SkillOperationDict",
+    "SkillOperationOrDict",
+    "GetSkillOperationConfig",
+    "GetSkillOperationConfigDict",
+    "GetSkillOperationConfigOrDict",
+    "PromptOptimizerConfig",
+    "PromptOptimizerConfigDict",
+    "PromptOptimizerConfigOrDict",
+    "OptimizeResponse",
+    "OptimizeResponseDict",
+    "OptimizeResponseOrDict",
+    "ContentMapContents",
+    "ContentMapContentsDict",
+    "ContentMapContentsOrDict",
+    "EvaluateMethodConfig",
+    "EvaluateMethodConfigDict",
+    "EvaluateMethodConfigOrDict",
+    "EvaluateDatasetConfig",
+    "EvaluateDatasetConfigDict",
+    "EvaluateDatasetConfigOrDict",
+    "EvaluateDatasetOperation",
+    "EvaluateDatasetOperationDict",
+    "EvaluateDatasetOperationOrDict",
+    "EvaluateDatasetRequestParameters",
+    "EvaluateDatasetRequestParametersDict",
+    "EvaluateDatasetRequestParametersOrDict",
+    "ObservabilityEvalCase",
+    "ObservabilityEvalCaseDict",
+    "ObservabilityEvalCaseOrDict",
+    "RubricGroup",
+    "RubricGroupDict",
+    "RubricGroupOrDict",
+    "PromptTemplate",
+    "PromptTemplateDict",
+    "PromptTemplateOrDict",
+    "EvalRunInferenceConfig",
+    "EvalRunInferenceConfigDict",
+    "EvalRunInferenceConfigOrDict",
+    "AgentEngine",
+    "AgentEngineDict",
+    "AgentEngineOrDict",
+    "AgentEngineConfig",
+    "AgentEngineConfigDict",
+    "AgentEngineConfigOrDict",
+    "RunQueryJobAgentEngineConfig",
+    "RunQueryJobAgentEngineConfigDict",
+    "RunQueryJobAgentEngineConfigOrDict",
+    "RunQueryJobResult",
+    "RunQueryJobResultDict",
+    "RunQueryJobResultOrDict",
+    "CheckQueryJobResponse",
+    "CheckQueryJobResponseDict",
+    "CheckQueryJobResponseOrDict",
+    "AssembleDataset",
+    "AssembleDatasetDict",
+    "AssembleDatasetOrDict",
+    "BatchPredictionResourceUsageAssessmentResult",
+    "BatchPredictionResourceUsageAssessmentResultDict",
+    "BatchPredictionResourceUsageAssessmentResultOrDict",
+    "BatchPredictionValidationAssessmentResult",
+    "BatchPredictionValidationAssessmentResultDict",
+    "BatchPredictionValidationAssessmentResultOrDict",
+    "TuningResourceUsageAssessmentResult",
+    "TuningResourceUsageAssessmentResultDict",
+    "TuningResourceUsageAssessmentResultOrDict",
+
"TuningValidationAssessmentResult", + "TuningValidationAssessmentResultDict", + "TuningValidationAssessmentResultOrDict", + "Prompt", + "PromptDict", + "PromptOrDict", + "SchemaPromptInstanceVariableValue", + "SchemaPromptInstanceVariableValueDict", + "SchemaPromptInstanceVariableValueOrDict", + "CreatePromptConfig", + "CreatePromptConfigDict", + "CreatePromptConfigOrDict", + "CreatePromptVersionConfig", + "CreatePromptVersionConfigDict", + "CreatePromptVersionConfigOrDict", + "GetPromptConfig", + "GetPromptConfigDict", + "GetPromptConfigOrDict", + "PromptRef", + "PromptRefDict", + "PromptRefOrDict", + "PromptVersionRef", + "PromptVersionRefDict", + "PromptVersionRefOrDict", + "OptimizeJobConfig", + "OptimizeJobConfigDict", + "OptimizeJobConfigOrDict", + "AgentEngineRuntimeRevision", + "AgentEngineRuntimeRevisionDict", + "AgentEngineRuntimeRevisionOrDict", + "A2aTaskState", + "State", + "Strategy", + "AcceleratorType", + "Type", + "JobState", + "ManagedTopicEnum", + "IdentityType", + "AgentServerMode", + "MemoryType", + "Operator", + "Language", + "MachineConfig", + "Protocol", + "DefaultContainerCategory", + "PostSnapshotAction", + "Framework", + "EvaluationItemType", + "SamplingMethod", + "EvaluationRunState", + "OptimizeTarget", + "MemoryMetadataMergeStrategy", + "GenerateMemoriesResponseGeneratedMemoryAction", + "SkillState", + "PromptOptimizerMethod", + "OptimizationMethod", + "PromptData", + "PromptDataDict", + "PromptDataOrDict", + "LLMMetric", + "CodeExecutionMetric", + "MetricPromptBuilder", + "RubricContentProperty", + "RubricContentPropertyDict", + "RubricContent", + "RubricContentDict", + "Rubric", + "RubricDict", + "RubricVerdict", + "RubricVerdictDict", + "CandidateResult", + "CandidateResultDict", + "Event", + "EventDict", + "Message", + "MessageDict", + "Importance", + "ParsedResponseUnion", + "_DeleteAgentEngineTaskRequestParameters", + "_GetAgentEngineTaskRequestParameters", + "_ListAgentEngineTasksRequestParameters", + "_CreateAgentEngineTaskRequestParameters", + "_AppendAgentEngineTaskEventRequestParameters", + "_ListAgentEngineTaskEventsRequestParameters", + "_CreateEvaluationItemParameters", + "_CreateEvaluationMetricParameters", + "_CreateEvaluationRunParameters", + "_CreateEvaluationSetParameters", + "_DeleteEvaluationMetricParameters", + "_EvaluateInstancesRequestParameters", + "_GenerateUserScenariosParameters", + "_GenerateLossClustersParameters", + "_GenerateInstanceRubricsRequest", + "_GetEvaluationMetricParameters", + "_GetEvaluationRunParameters", + "_GetEvaluationSetParameters", + "_GetEvaluationItemParameters", + "_ListEvaluationMetricsParameters", + "_OptimizeRequestParameters", + "_CustomJobParameters", + "_GetCustomJobParameters", + "_CancelQueryJobAgentEngineRequestParameters", + "_CheckQueryJobAgentEngineRequestParameters", + "_RunQueryJobAgentEngineRequestParameters", + "_CreateAgentEngineRequestParameters", + "_DeleteAgentEngineRequestParameters", + "_GetAgentEngineRequestParameters", + "_ListAgentEngineRequestParameters", + "_GetAgentEngineOperationParameters", + "_QueryAgentEngineRequestParameters", + "_UpdateAgentEngineRequestParameters", + "_CreateAgentEngineMemoryRequestParameters", + "_DeleteAgentEngineMemoryRequestParameters", + "_GenerateAgentEngineMemoriesRequestParameters", + "_GetAgentEngineMemoryRequestParameters", + "_IngestEventsRequestParameters", + "_ListAgentEngineMemoryRequestParameters", + "_GetAgentEngineMemoryOperationParameters", + "_GetAgentEngineGenerateMemoriesOperationParameters", + 
"_RetrieveAgentEngineMemoriesRequestParameters", + "_RetrieveMemoryProfilesRequestParameters", + "_RollbackAgentEngineMemoryRequestParameters", + "_UpdateAgentEngineMemoryRequestParameters", + "_PurgeAgentEngineMemoriesRequestParameters", + "_GetAgentEngineMemoryRevisionRequestParameters", + "_ListAgentEngineMemoryRevisionsRequestParameters", + "_GetAgentEngineRuntimeRevisionRequestParameters", + "_ListAgentEngineRuntimeRevisionsRequestParameters", + "_DeleteAgentEngineRuntimeRevisionRequestParameters", + "_GetDeleteAgentEngineRuntimeRevisionOperationParameters", + "_QueryAgentEngineRuntimeRevisionRequestParameters", + "_CreateAgentEngineSandboxRequestParameters", + "_DeleteAgentEngineSandboxRequestParameters", + "_ExecuteCodeAgentEngineSandboxRequestParameters", + "_GetAgentEngineSandboxRequestParameters", + "_ListAgentEngineSandboxesRequestParameters", + "_GetAgentEngineSandboxOperationParameters", + "_CreateSandboxEnvironmentTemplateRequestParameters", + "_DeleteSandboxEnvironmentTemplateRequestParameters", + "_GetSandboxEnvironmentTemplateRequestParameters", + "_ListSandboxEnvironmentTemplatesRequestParameters", + "_GetSandboxEnvironmentTemplateOperationParameters", + "_CreateSandboxEnvironmentSnapshotRequestParameters", + "_DeleteSandboxEnvironmentSnapshotRequestParameters", + "_GetSandboxEnvironmentSnapshotRequestParameters", + "_ListSandboxEnvironmentSnapshotsRequestParameters", + "_GetAgentEngineSandboxSnapshotOperationParameters", + "_CreateAgentEngineSessionRequestParameters", + "_DeleteAgentEngineSessionRequestParameters", + "_GetAgentEngineSessionRequestParameters", + "_ListAgentEngineSessionsRequestParameters", + "_GetAgentEngineSessionOperationParameters", + "_UpdateAgentEngineSessionRequestParameters", + "_AppendAgentEngineSessionEventRequestParameters", + "_ListAgentEngineSessionEventsRequestParameters", + "_AssembleDatasetParameters", + "_AssessDatasetParameters", + "_CreateMultimodalDatasetParameters", + "_DeleteMultimodalDatasetRequestParameters", + "_GetMultimodalDatasetParameters", + "_GetMultimodalDatasetOperationParameters", + "_ListMultimodalDatasetsRequestParameters", + "_UpdateMultimodalDatasetParameters", + "_CreateDatasetParameters", + "_CreateDatasetVersionParameters", + "_GetDatasetParameters", + "_GetDatasetVersionParameters", + "_GetDatasetOperationParameters", + "_ListDatasetsRequestParameters", + "_ListDatasetVersionsRequestParameters", + "_DeleteDatasetRequestParameters", + "_DeletePromptVersionRequestParameters", + "_RestoreVersionRequestParameters", + "_UpdateDatasetParameters", + "_CustomJobParameters", + "_GetCustomJobParameters", + "_OptimizeRequestParameters", + "_GetSkillRequestParameters", + "_RetrieveSkillsRequestParameters", + "_CreateSkillRequestParameters", + "_GetSkillOperationParameters", + "evals", + "agent_engines", + "prompts", + "PrebuiltMetric", + "RubricMetric", +] + + +def __getattr__(name: str) -> typing.Any: + if name == "PrebuiltMetric" or name == "RubricMetric": + module = importlib.import_module(".._evals_metric_loaders", __package__) + prebuilt_metric_obj = getattr(module, name) + globals()[name] = prebuilt_metric_obj + return prebuilt_metric_obj + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") diff --git a/agentplatform/_genai/types/agent_engines.py b/agentplatform/_genai/types/agent_engines.py new file mode 100644 index 0000000000..c47a08a11e --- /dev/null +++ b/agentplatform/_genai/types/agent_engines.py @@ -0,0 +1,16 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py new file mode 100644 index 0000000000..831ad80a02 --- /dev/null +++ b/agentplatform/_genai/types/common.py @@ -0,0 +1,20131 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. + +import datetime +import json +import logging +import os +import re +import typing +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, +) +from google.genai import _common +from google.genai import types as genai_types +from pydantic import ( + ConfigDict, + Field, + PrivateAttr, + computed_field, + field_validator, + model_validator, +) +from typing_extensions import TypeAlias, TypedDict +from . import evals as evals_types +from . import prompts as prompts_types + + +def camel_to_snake(camel_case_string: str) -> str: + snake_case_string = re.sub(r"(? Any: + """Converts all camelCase keys to snake_case in a dict or list.""" + if isinstance(message, dict): + return { + camel_to_snake(key): _camel_key_to_snake(value) + for key, value in message.items() + } + elif isinstance(message, list): + return [_camel_key_to_snake(value) for value in message] + else: + return message + + +if typing.TYPE_CHECKING: + import pandas as pd + + PandasDataFrame: TypeAlias = pd.DataFrame +else: + try: + import pandas as pd + + PandasDataFrame = pd.DataFrame + except ImportError: + pd = None + PandasDataFrame = Any +if typing.TYPE_CHECKING: + import yaml +else: + try: + import yaml + except ImportError: + yaml = None + +logger = logging.getLogger("vertexai_genai.types") + +MetricSubclass = TypeVar("MetricSubclass", bound="Metric") + + +class A2aTaskState(_common.CaseInSensitiveEnum): + """Output only. The state of the task. The state of a new task is SUBMITTED by default. The state of a task can only be updated via AppendA2aTaskEvents API.""" + + STATE_UNSPECIFIED = "STATE_UNSPECIFIED" + """Task state unspecified. 
Default value if not set.""" + SUBMITTED = "SUBMITTED" + """Task is submitted and waiting to be processed.""" + WORKING = "WORKING" + """Task is actively being processed.""" + COMPLETED = "COMPLETED" + """Task is finished.""" + CANCELLED = "CANCELLED" + """Task is cancelled.""" + FAILED = "FAILED" + """Task has failed.""" + REJECTED = "REJECTED" + """Task is rejected by the system.""" + INPUT_REQUIRED = "INPUT_REQUIRED" + """Task requires input from the user.""" + AUTH_REQUIRED = "AUTH_REQUIRED" + """Task requires auth (e.g. OAuth) from the user.""" + PAUSED = "PAUSED" + """Task is paused.""" + + +class State(_common.CaseInSensitiveEnum): + """The new state of the task.""" + + STATE_UNSPECIFIED = "STATE_UNSPECIFIED" + """Task state unspecified. Default value if not set.""" + SUBMITTED = "SUBMITTED" + """Task is submitted and waiting to be processed.""" + WORKING = "WORKING" + """Task is actively being processed.""" + COMPLETED = "COMPLETED" + """Task is finished.""" + CANCELLED = "CANCELLED" + """Task is cancelled.""" + FAILED = "FAILED" + """Task has failed.""" + REJECTED = "REJECTED" + """Task is rejected by the system.""" + INPUT_REQUIRED = "INPUT_REQUIRED" + """Task requires input from the user.""" + AUTH_REQUIRED = "AUTH_REQUIRED" + """Task requires auth (e.g. OAuth) from the user.""" + PAUSED = "PAUSED" + """Task is paused.""" + + +class Strategy(_common.CaseInSensitiveEnum): + """This determines which type of scheduling strategy to use.""" + + STRATEGY_UNSPECIFIED = "STRATEGY_UNSPECIFIED" + """Strategy will default to STANDARD.""" + ON_DEMAND = "ON_DEMAND" + """Deprecated. Regular on-demand provisioning strategy.""" + LOW_COST = "LOW_COST" + """Deprecated. Low cost by making potential use of spot resources.""" + STANDARD = "STANDARD" + """Standard provisioning strategy uses regular on-demand resources.""" + SPOT = "SPOT" + """Spot provisioning strategy uses spot resources.""" + FLEX_START = "FLEX_START" + """Flex Start strategy uses DWS to queue for resources.""" + + +class AcceleratorType(_common.CaseInSensitiveEnum): + """Immutable. 
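The enums in this file extend google.genai's CaseInSensitiveEnum, so lookups by value succeed regardless of casing. A minimal re-implementation of that behavior for illustration only (this is not the real google.genai class):

import enum

class _CaseInsensitiveEnum(str, enum.Enum):
    # enum's _missing_ hook runs when an exact lookup by value fails.
    @classmethod
    def _missing_(cls, value):
        if isinstance(value, str):
            for member in cls:
                if member.value.lower() == value.lower():
                    return member
        return None

class _DemoTaskState(_CaseInsensitiveEnum):
    SUBMITTED = "SUBMITTED"
    WORKING = "WORKING"

assert _DemoTaskState("working") is _DemoTaskState.WORKING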
The type of accelerator(s) that may be attached to the machine as per accelerator_count.""" + + ACCELERATOR_TYPE_UNSPECIFIED = "ACCELERATOR_TYPE_UNSPECIFIED" + """Unspecified accelerator type, which means no accelerator.""" + NVIDIA_TESLA_K80 = "NVIDIA_TESLA_K80" + """Deprecated: Nvidia Tesla K80 GPU has reached end of support, see https://cloud.google.com/compute/docs/eol/k80-eol.""" + NVIDIA_TESLA_P100 = "NVIDIA_TESLA_P100" + """Nvidia Tesla P100 GPU.""" + NVIDIA_TESLA_V100 = "NVIDIA_TESLA_V100" + """Nvidia Tesla V100 GPU.""" + NVIDIA_TESLA_P4 = "NVIDIA_TESLA_P4" + """Nvidia Tesla P4 GPU.""" + NVIDIA_TESLA_T4 = "NVIDIA_TESLA_T4" + """Nvidia Tesla T4 GPU.""" + NVIDIA_TESLA_A100 = "NVIDIA_TESLA_A100" + """Nvidia Tesla A100 GPU.""" + NVIDIA_A100_80GB = "NVIDIA_A100_80GB" + """Nvidia A100 80GB GPU.""" + NVIDIA_L4 = "NVIDIA_L4" + """Nvidia L4 GPU.""" + NVIDIA_H100_80GB = "NVIDIA_H100_80GB" + """Nvidia H100 80Gb GPU.""" + NVIDIA_H100_MEGA_80GB = "NVIDIA_H100_MEGA_80GB" + """Nvidia H100 Mega 80Gb GPU.""" + NVIDIA_H200_141GB = "NVIDIA_H200_141GB" + """Nvidia H200 141Gb GPU.""" + NVIDIA_B200 = "NVIDIA_B200" + """Nvidia B200 GPU.""" + NVIDIA_GB200 = "NVIDIA_GB200" + """Nvidia GB200 GPU.""" + NVIDIA_RTX_PRO_6000 = "NVIDIA_RTX_PRO_6000" + """Nvidia RTX Pro 6000 GPU.""" + TPU_V2 = "TPU_V2" + """TPU v2.""" + TPU_V3 = "TPU_V3" + """TPU v3.""" + TPU_V4_POD = "TPU_V4_POD" + """TPU v4.""" + TPU_V5_LITEPOD = "TPU_V5_LITEPOD" + """TPU v5.""" + + +class Type(_common.CaseInSensitiveEnum): + """Specifies the reservation affinity type.""" + + TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED" + """Default value. This should not be used.""" + NO_RESERVATION = "NO_RESERVATION" + """Do not consume from any reserved capacity, only use on-demand.""" + ANY_RESERVATION = "ANY_RESERVATION" + """Consume any reservation available, falling back to on-demand.""" + SPECIFIC_RESERVATION = "SPECIFIC_RESERVATION" + """Consume from a specific reservation. When chosen, the reservation must be identified via the `key` and `values` fields.""" + + +class JobState(_common.CaseInSensitiveEnum): + """Output only. The detailed state of the job.""" + + JOB_STATE_UNSPECIFIED = "JOB_STATE_UNSPECIFIED" + """The job state is unspecified.""" + JOB_STATE_QUEUED = "JOB_STATE_QUEUED" + """The job has been just created or resumed and processing has not yet begun.""" + JOB_STATE_PENDING = "JOB_STATE_PENDING" + """The service is preparing to run the job.""" + JOB_STATE_RUNNING = "JOB_STATE_RUNNING" + """The job is in progress.""" + JOB_STATE_SUCCEEDED = "JOB_STATE_SUCCEEDED" + """The job completed successfully.""" + JOB_STATE_FAILED = "JOB_STATE_FAILED" + """The job failed.""" + JOB_STATE_CANCELLING = "JOB_STATE_CANCELLING" + """The job is being cancelled. From this state the job may only go to either `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`.""" + JOB_STATE_CANCELLED = "JOB_STATE_CANCELLED" + """The job has been cancelled.""" + JOB_STATE_PAUSED = "JOB_STATE_PAUSED" + """The job has been stopped, and can be resumed.""" + JOB_STATE_EXPIRED = "JOB_STATE_EXPIRED" + """The job has expired.""" + JOB_STATE_UPDATING = "JOB_STATE_UPDATING" + """The job is being updated. Only jobs in the `RUNNING` state can be updated. 
After updating, the job goes back to the `RUNNING` state.""" + JOB_STATE_PARTIALLY_SUCCEEDED = "JOB_STATE_PARTIALLY_SUCCEEDED" + """The job partially succeeded; some results may be missing due to errors.""" + + +class ManagedTopicEnum(_common.CaseInSensitiveEnum): + """Represents the managed memory topic.""" + + MANAGED_TOPIC_ENUM_UNSPECIFIED = "MANAGED_TOPIC_ENUM_UNSPECIFIED" + """Represents an unspecified topic. This value should not be used.""" + USER_PERSONAL_INFO = "USER_PERSONAL_INFO" + """Represents significant personal information about the User like first names, relationships, hobbies, important dates.""" + USER_PREFERENCES = "USER_PREFERENCES" + """Represents stated or implied likes, dislikes, preferred styles, or patterns.""" + KEY_CONVERSATION_DETAILS = "KEY_CONVERSATION_DETAILS" + """Represents important milestones or conclusions within the dialogue.""" + EXPLICIT_INSTRUCTIONS = "EXPLICIT_INSTRUCTIONS" + """Represents information that the user explicitly requested to remember or forget.""" + + +class IdentityType(_common.CaseInSensitiveEnum): + """The identity type to use for the Reasoning Engine. If not specified, the `service_account` field will be used if set, otherwise the default Vertex AI Reasoning Engine Service Agent in the project will be used.""" + + IDENTITY_TYPE_UNSPECIFIED = "IDENTITY_TYPE_UNSPECIFIED" + """Default value. Use a custom service account if the `service_account` field is set, otherwise use the default Vertex AI Reasoning Engine Service Agent in the project. Same behavior as SERVICE_ACCOUNT.""" + SERVICE_ACCOUNT = "SERVICE_ACCOUNT" + """Use a custom service account if the `service_account` field is set, otherwise use the default Vertex AI Reasoning Engine Service Agent in the project.""" + AGENT_IDENTITY = "AGENT_IDENTITY" + """Use Agent Identity. The `service_account` field must not be set.""" + + +class AgentServerMode(_common.CaseInSensitiveEnum): + """The agent server mode.""" + + AGENT_SERVER_MODE_UNSPECIFIED = "AGENT_SERVER_MODE_UNSPECIFIED" + """Unspecified agent server mode. Do not use.""" + STABLE = "STABLE" + """Stable agent server mode. This mode includes the stable, well-tested features that Agent Engine offers.""" + EXPERIMENTAL = "EXPERIMENTAL" + """Experimental agent server mode. This mode contains experimental features.""" + + +class MemoryType(_common.CaseInSensitiveEnum): + """The type of the memory.""" + + MEMORY_TYPE_UNSPECIFIED = "MEMORY_TYPE_UNSPECIFIED" + """Represents an unspecified memory type. This value should not be used.""" + NATURAL_LANGUAGE_COLLECTION = "NATURAL_LANGUAGE_COLLECTION" + """Indicates belonging to a collection of natural language memories.""" + STRUCTURED_PROFILE = "STRUCTURED_PROFILE" + """Indicates belonging to a structured profile.""" + + +class Operator(_common.CaseInSensitiveEnum): + """Represents the operator to apply to the filter. If not set, then EQUAL will be used.""" + + OPERATOR_UNSPECIFIED = "OPERATOR_UNSPECIFIED" + """Represents an unspecified operator. Defaults to EQUAL.""" + EQUAL = "EQUAL" + """Equal to.""" + GREATER_THAN = "GREATER_THAN" + """Greater than.""" + LESS_THAN = "LESS_THAN" + """Less than.""" + + +class Language(_common.CaseInSensitiveEnum): + """The coding language supported in this environment.""" + + LANGUAGE_UNSPECIFIED = "LANGUAGE_UNSPECIFIED" + """The default value.
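Some JobState values above are terminal while others are transient, which matters when polling a long-running job. A hedged polling sketch (get_job stands in for whatever client call fetches the job; it is a hypothetical callable returning an object with a .state attribute):

import time

_TERMINAL_JOB_STATES = {
    JobState.JOB_STATE_SUCCEEDED,
    JobState.JOB_STATE_FAILED,
    JobState.JOB_STATE_CANCELLED,
    JobState.JOB_STATE_EXPIRED,
    JobState.JOB_STATE_PARTIALLY_SUCCEEDED,
}

def wait_for_job(get_job, poll_seconds=10.0):
    job = get_job()
    while job.state not in _TERMINAL_JOB_STATES:
        time.sleep(poll_seconds)
        job = get_job()
    return job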
This value is unused.""" + LANGUAGE_PYTHON = "LANGUAGE_PYTHON" + """The coding language is Python.""" + LANGUAGE_JAVASCRIPT = "LANGUAGE_JAVASCRIPT" + """The coding language is JavaScript.""" + + +class MachineConfig(_common.CaseInSensitiveEnum): + """The machine config of the code execution environment.""" + + MACHINE_CONFIG_UNSPECIFIED = "MACHINE_CONFIG_UNSPECIFIED" + """The default value: milligcu 2000, memory 1.5 GiB""" + MACHINE_CONFIG_VCPU4_RAM4GIB = "MACHINE_CONFIG_VCPU4_RAM4GIB" + """The 4-vCPU option: milligcu 4000, memory 4 GiB""" + + +class Protocol(_common.CaseInSensitiveEnum): + """Protocol for port. Defaults to TCP if not specified.""" + + PROTOCOL_UNSPECIFIED = "PROTOCOL_UNSPECIFIED" + """Unspecified protocol. Defaults to TCP.""" + TCP = "TCP" + """TCP protocol.""" + UDP = "UDP" + """UDP protocol.""" + + +class DefaultContainerCategory(_common.CaseInSensitiveEnum): + """The category of the default container image.""" + + DEFAULT_CONTAINER_CATEGORY_UNSPECIFIED = "DEFAULT_CONTAINER_CATEGORY_UNSPECIFIED" + """The default value. This value is unused.""" + DEFAULT_CONTAINER_CATEGORY_COMPUTER_USE = "DEFAULT_CONTAINER_CATEGORY_COMPUTER_USE" + """The default container image for Computer Use.""" + + +class PostSnapshotAction(_common.CaseInSensitiveEnum): + """Input only. Action to take on the source SandboxEnvironment after the snapshot is taken. This field is only used in CreateSandboxEnvironmentSnapshotRequest and it is not stored in the resource.""" + + POST_SNAPSHOT_ACTION_UNSPECIFIED = "POST_SNAPSHOT_ACTION_UNSPECIFIED" + """The default value. This value is unused.""" + RUNNING = "RUNNING" + """Sandbox environment will continue to run after snapshot is taken.""" + PAUSE = "PAUSE" + """Sandbox environment will be paused after snapshot is taken.""" + + +class Framework(_common.CaseInSensitiveEnum): + """Framework used to build the application.""" + + FRAMEWORK_UNSPECIFIED = "FRAMEWORK_UNSPECIFIED" + """Unspecified framework.""" + REACT = "REACT" + """React framework.""" + ANGULAR = "ANGULAR" + """Angular framework.""" + + +class EvaluationItemType(_common.CaseInSensitiveEnum): + """The type of the EvaluationItem.""" + + EVALUATION_ITEM_TYPE_UNSPECIFIED = "EVALUATION_ITEM_TYPE_UNSPECIFIED" + """The default value.
This value is unused.""" + REQUEST = "REQUEST" + """The EvaluationItem is a request to evaluate.""" + RESULT = "RESULT" + """The EvaluationItem is the result of evaluation.""" + + +class SamplingMethod(_common.CaseInSensitiveEnum): + """Represents the sampling method for a BigQuery request set.""" + + UNSPECIFIED = "UNSPECIFIED" + """Sampling method is unspecified.""" + RANDOM = "RANDOM" + """Sampling method is random.""" + + +class EvaluationRunState(_common.CaseInSensitiveEnum): + """Represents the state of an evaluation run.""" + + UNSPECIFIED = "UNSPECIFIED" + """Evaluation run state is unspecified.""" + PENDING = "PENDING" + """Evaluation run is pending.""" + RUNNING = "RUNNING" + """Evaluation run is in progress.""" + SUCCEEDED = "SUCCEEDED" + """Evaluation run has succeeded.""" + FAILED = "FAILED" + """Evaluation run failed.""" + CANCELLED = "CANCELLED" + """Evaluation run was cancelled.""" + INFERENCE = "INFERENCE" + """Evaluation run is performing inference.""" + GENERATING_RUBRICS = "GENERATING_RUBRICS" + """Evaluation run is performing rubric generation.""" + + +class OptimizeTarget(_common.CaseInSensitiveEnum): + """Specifies the method for calling the optimize_prompt.""" + + OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO" + """The data driven prompt optimizer designed for prompts from Android core API.""" + OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS = "OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS" + """The prompt optimizer based on user provided examples with rubrics.""" + OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE = ( + "OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE" + ) + """The prompt optimizer based on user provided examples with target responses.""" + + +class MemoryMetadataMergeStrategy(_common.CaseInSensitiveEnum): + """The strategy to use when applying metadata to existing memories during consolidation.""" + + METADATA_MERGE_STRATEGY_UNSPECIFIED = "METADATA_MERGE_STRATEGY_UNSPECIFIED" + """The metadata merge strategy is unspecified.""" + OVERWRITE = "OVERWRITE" + """Replace the metadata of the updated memories with the new metadata.""" + MERGE = "MERGE" + """Append new metadata to the existing metadata. If there are duplicate keys, the existing values will be overwritten.""" + REQUIRE_EXACT_MATCH = "REQUIRE_EXACT_MATCH" + """Restrict consolidation to memories that have exactly the same metadata as the request. If a memory doesn't have the same metadata, it is not eligible for consolidation.""" + + +class GenerateMemoriesResponseGeneratedMemoryAction(_common.CaseInSensitiveEnum): + """The action to take.""" + + ACTION_UNSPECIFIED = "ACTION_UNSPECIFIED" + """The action is unspecified.""" + CREATED = "CREATED" + """The memory was created.""" + UPDATED = "UPDATED" + """The memory was updated.
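The MERGE strategy described above amounts to dict-append semantics in which new values win on key collisions; a short worked example (the data is made up):

existing = {"topic": "travel", "source": "chat"}
incoming = {"source": "email", "lang": "en"}
assert {**existing, **incoming} == {"topic": "travel", "source": "email", "lang": "en"}  # MERGE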
The `fact` field may not be updated if the existing fact is still accurate.""" + DELETED = "DELETED" + """The memory was deleted.""" + + +class SkillState(_common.CaseInSensitiveEnum): + """State of the Skill.""" + + STATE_UNSPECIFIED = "STATE_UNSPECIFIED" + """The state of the Skill is unspecified.""" + ACTIVE = "ACTIVE" + """The Skill is active.""" + CREATING = "CREATING" + """The Skill is being created.""" + FAILED = "FAILED" + """The Skill was created, but failed to process.""" + DELETING = "DELETING" + """The Skill is being deleted.""" + + +class PromptOptimizerMethod(_common.CaseInSensitiveEnum): + """The method for data driven prompt optimization.""" + + VAPO = "VAPO" + """The default data driven Vertex AI Prompt Optimizer.""" + OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO" + """The data driven prompt optimizer designed for prompts from Android core API.""" + + +class OptimizationMethod(_common.CaseInSensitiveEnum): + """The method for data driven prompt optimization.""" + + VAPO = "VAPO" + """The default data driven Vertex AI Prompt Optimizer.""" + OPTIMIZATION_TARGET_GEMINI_NANO = "OPTIMIZATION_TARGET_GEMINI_NANO" + """The data driven prompt optimizer designed for prompts from Android core API.""" + + +class DeleteAgentEngineTaskConfig(_common.BaseModel): + """Config for deleting an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteAgentEngineTaskConfigDict(TypedDict, total=False): + """Config for deleting an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteAgentEngineTaskConfigOrDict = Union[ + DeleteAgentEngineTaskConfig, DeleteAgentEngineTaskConfigDict +] + + +class _DeleteAgentEngineTaskRequestParameters(_common.BaseModel): + """Parameters for deleting an agent engine task.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine task.""" + ) + config: Optional[DeleteAgentEngineTaskConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteAgentEngineTaskRequestParametersDict(TypedDict, total=False): + """Parameters for deleting an agent engine task.""" + + name: Optional[str] + """Name of the agent engine task.""" + + config: Optional[DeleteAgentEngineTaskConfigDict] + """""" + + +_DeleteAgentEngineTaskRequestParametersOrDict = Union[ + _DeleteAgentEngineTaskRequestParameters, _DeleteAgentEngineTaskRequestParametersDict +] + + +class GetAgentEngineTaskConfig(_common.BaseModel): + """Config for getting an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineTaskConfigDict(TypedDict, total=False): + """Config for getting an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineTaskConfigOrDict = Union[ + GetAgentEngineTaskConfig, GetAgentEngineTaskConfigDict +] + + +class _GetAgentEngineTaskRequestParameters(_common.BaseModel): + """Parameters for getting an agent engine task.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine task.""" + ) + config: Optional[GetAgentEngineTaskConfig] = Field(default=None, description="""""") + + +class _GetAgentEngineTaskRequestParametersDict(TypedDict, total=False): + """Parameters for getting an agent
engine task.""" + + name: Optional[str] + """Name of the agent engine task.""" + + config: Optional[GetAgentEngineTaskConfigDict] + """""" + + +_GetAgentEngineTaskRequestParametersOrDict = Union[ + _GetAgentEngineTaskRequestParameters, _GetAgentEngineTaskRequestParametersDict +] + + +class TaskArtifact(_common.BaseModel): + """The artifact of the task event.""" + + artifact_id: Optional[str] = Field( + default=None, + description="""Required. The unique identifier of the artifact within the task. This id is provided by the creator of the artifact.""", + ) + description: Optional[str] = Field( + default=None, + description="""Optional. A human readable description of the artifact.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Optional. The human-readable name of the artifact provided by the creator.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. Additional metadata for the artifact. For A2A, the URIs of the extensions that were used to produce this artifact will be stored here.""", + ) + parts: Optional[list[genai_types.Part]] = Field( + default=None, description="""The parts of the artifact.""" + ) + + +class TaskArtifactDict(TypedDict, total=False): + """The artifact of the task event.""" + + artifact_id: Optional[str] + """Required. The unique identifier of the artifact within the task. This id is provided by the creator of the artifact.""" + + description: Optional[str] + """Optional. A human readable description of the artifact.""" + + display_name: Optional[str] + """Optional. The human-readable name of the artifact provided by the creator.""" + + metadata: Optional[dict[str, Any]] + """Optional. Additional metadata for the artifact. For A2A, the URIs of the extensions that were used to produce this artifact will be stored here.""" + + parts: Optional[list[genai_types.PartDict]] + """The parts of the artifact.""" + + +TaskArtifactOrDict = Union[TaskArtifact, TaskArtifactDict] + + +class TaskOutput(_common.BaseModel): + """The output of the task event.""" + + artifacts: Optional[list[TaskArtifact]] = Field( + default=None, description="""The artifacts of the task event.""" + ) + + +class TaskOutputDict(TypedDict, total=False): + """The output of the task event.""" + + artifacts: Optional[list[TaskArtifactDict]] + """The artifacts of the task event.""" + + +TaskOutputOrDict = Union[TaskOutput, TaskOutputDict] + + +class TaskMessage(_common.BaseModel): + """The message of the task event.""" + + message_id: Optional[str] = Field( + default=None, description="""Required. The unique identifier of the message.""" + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. A2A message may have extension_uris or reference_task_ids. They will be stored under metadata.""", + ) + parts: Optional[list[genai_types.Part]] = Field( + default=None, description="""The parts of the message.""" + ) + role: Optional[str] = Field( + default=None, + description="""Required. The role of the sender of the message. e.g. "user", "agent".""", + ) + + +class TaskMessageDict(TypedDict, total=False): + """The message of the task event.""" + + message_id: Optional[str] + """Required. The unique identifier of the message.""" + + metadata: Optional[dict[str, Any]] + """Optional. A2A message may have extension_uris or reference_task_ids. They will be stored under metadata.""" + + parts: Optional[list[genai_types.PartDict]] + """The parts of the message.""" + + role: Optional[str] + """Required. 
The role of the sender of the message. e.g. "user", "agent".""" + + +TaskMessageOrDict = Union[TaskMessage, TaskMessageDict] + + +class TaskStatusDetails(_common.BaseModel): + """The status details of the task event.""" + + task_message: Optional[TaskMessage] = Field( + default=None, description="""The status of the task event.""" + ) + + +class TaskStatusDetailsDict(TypedDict, total=False): + """The status details of the task event.""" + + task_message: Optional[TaskMessageDict] + """The status of the task event.""" + + +TaskStatusDetailsOrDict = Union[TaskStatusDetails, TaskStatusDetailsDict] + + +class A2aTask(_common.BaseModel): + """A task.""" + + context_id: Optional[str] = Field( + default=None, + description="""Optional. A generic identifier for grouping related tasks (e.g., session_id, workflow_id).""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, description="""Output only. The creation timestamp of the task.""" + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, description="""Optional. Arbitrary, user-defined metadata.""" + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the task. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/a2aTasks/{a2a_task}`""", + ) + next_event_sequence_number: Optional[int] = Field( + default=None, + description="""Output only. The next event sequence number to be appended to the task. This value starts at 1 and is guaranteed to be monotonically increasing.""", + ) + output: Optional[TaskOutput] = Field( + default=None, description="""Optional. The final output of the task.""" + ) + state: Optional[A2aTaskState] = Field( + default=None, + description="""Output only. The state of the task. The state of a new task is SUBMITTED by default. The state of a task can only be updated via AppendA2aTaskEvents API.""", + ) + status_details: Optional[TaskStatusDetails] = Field( + default=None, description="""Optional. The status details of the task.""" + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The last update timestamp of the task.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Timestamp of when this task is considered expired. This is *always* provided on output, and is calculated based on the `ttl` if set on the request""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL (Time To Live) for the task. If not set, the task will expire in 24 hours by default. Valid range: (0 seconds, 1000 days]""", + ) + + +class A2aTaskDict(TypedDict, total=False): + """A task.""" + + context_id: Optional[str] + """Optional. A generic identifier for grouping related tasks (e.g., session_id, workflow_id).""" + + create_time: Optional[datetime.datetime] + """Output only. The creation timestamp of the task.""" + + metadata: Optional[dict[str, Any]] + """Optional. Arbitrary, user-defined metadata.""" + + name: Optional[str] + """Identifier. The resource name of the task. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/a2aTasks/{a2a_task}`""" + + next_event_sequence_number: Optional[int] + """Output only. The next event sequence number to be appended to the task. This value starts at 1 and is guaranteed to be monotonically increasing.""" + + output: Optional[TaskOutputDict] + """Optional. 
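Every model X in this file has a TypedDict twin XDict plus a Union alias XOrDict, so callers can pass either a pydantic instance or a plain dict. A sketch of how a method might normalize the union (the helper below is illustrative, not SDK code):

msg_model = TaskMessage(message_id="m-1", role="user")
msg_dict: TaskMessageDict = {"message_id": "m-1", "role": "user"}

def _as_model(message: TaskMessageOrDict) -> TaskMessage:
    # pydantic v2 model_validate accepts the dict form directly.
    return TaskMessage.model_validate(message) if isinstance(message, dict) else message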
The final output of the task.""" + + state: Optional[A2aTaskState] + """Output only. The state of the task. The state of a new task is SUBMITTED by default. The state of a task can only be updated via AppendA2aTaskEvents API.""" + + status_details: Optional[TaskStatusDetailsDict] + """Optional. The status details of the task.""" + + update_time: Optional[datetime.datetime] + """Output only. The last update timestamp of the task.""" + + expire_time: Optional[datetime.datetime] + """Optional. Timestamp of when this task is considered expired. This is *always* provided on output, and is calculated based on the `ttl` if set on the request""" + + ttl: Optional[str] + """Optional. Input only. The TTL (Time To Live) for the task. If not set, the task will expire in 24 hours by default. Valid range: (0 seconds, 1000 days]""" + + +A2aTaskOrDict = Union[A2aTask, A2aTaskDict] + + +class ListAgentEngineTasksConfig(_common.BaseModel): + """Config for listing agent engine tasks.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + order_by: Optional[str] = Field( + default=None, + description="""A comma-separated list of fields to order by, sorted in ascending order. + Use "desc" after a field name for descending. + If this field is omitted, the default ordering is `create_time` descending. + More detail in [AIP-132](https://google.aip.dev/132). + + Supported fields: + * `create_time` + * `update_time` + + Example: `create_time desc`.""", + ) + + +class ListAgentEngineTasksConfigDict(TypedDict, total=False): + """Config for listing agent engine tasks.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + order_by: Optional[str] + """A comma-separated list of fields to order by, sorted in ascending order. + Use "desc" after a field name for descending. + If this field is omitted, the default ordering is `create_time` descending. + More detail in [AIP-132](https://google.aip.dev/132). 
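Example values for the list config being defined here (a sketch; the filter expression is an assumption for illustration, following the field docstrings above):

config = ListAgentEngineTasksConfig(
    page_size=50,
    filter='state="WORKING"',      # snake_case or camelCase field names per the docstring
    order_by="create_time desc",   # default ordering made explicit
)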
+ + Supported fields: + * `create_time` + * `update_time` + + Example: `create_time desc`.""" + + +ListAgentEngineTasksConfigOrDict = Union[ + ListAgentEngineTasksConfig, ListAgentEngineTasksConfigDict +] + + +class _ListAgentEngineTasksRequestParameters(_common.BaseModel): + """Parameters for listing agent engine tasks.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[ListAgentEngineTasksConfig] = Field( + default=None, description="""""" + ) + + +class _ListAgentEngineTasksRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engine tasks.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[ListAgentEngineTasksConfigDict] + """""" + + +_ListAgentEngineTasksRequestParametersOrDict = Union[ + _ListAgentEngineTasksRequestParameters, _ListAgentEngineTasksRequestParametersDict +] + + +class ListAgentEngineTasksResponse(_common.BaseModel): + """Response for listing agent engine tasks.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + a2aTasks: Optional[list[A2aTask]] = Field( + default=None, description="""List of agent engine tasks.""" + ) + + +class ListAgentEngineTasksResponseDict(TypedDict, total=False): + """Response for listing agent engine tasks.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + a2aTasks: Optional[list[A2aTaskDict]] + """List of agent engine tasks.""" + + +ListAgentEngineTasksResponseOrDict = Union[ + ListAgentEngineTasksResponse, ListAgentEngineTasksResponseDict +] + + +class CreateAgentEngineTaskConfig(_common.BaseModel): + """Config for creating an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + context_id: Optional[str] = Field( + default=None, description="""The context id of the task to create.""" + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, description="""The metadata of the task to create.""" + ) + status_details: Optional[TaskStatusDetails] = Field( + default=None, description="""The status details of the task to create.""" + ) + output: Optional[TaskOutput] = Field( + default=None, description="""The output of the task to create.""" + ) + + +class CreateAgentEngineTaskConfigDict(TypedDict, total=False): + """Config for creating an Agent Engine Task.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + context_id: Optional[str] + """The context id of the task to create.""" + + metadata: Optional[dict[str, Any]] + """The metadata of the task to create.""" + + status_details: Optional[TaskStatusDetailsDict] + """The status details of the task to create.""" + + output: Optional[TaskOutputDict] + """The output of the task to create.""" + + +CreateAgentEngineTaskConfigOrDict = Union[ + CreateAgentEngineTaskConfig, CreateAgentEngineTaskConfigDict +] + + +class _CreateAgentEngineTaskRequestParameters(_common.BaseModel): + """Parameters for creating Agent Engine Tasks.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine to create the task under.""", + ) + a2a_task_id: Optional[str] = Field( + default=None, description="""The ID of the task.""" + ) + config:
Optional[CreateAgentEngineTaskConfig] = Field( + default=None, description="""""" + ) + + +class _CreateAgentEngineTaskRequestParametersDict(TypedDict, total=False): + """Parameters for creating Agent Engine Tasks.""" + + name: Optional[str] + """Name of the agent engine to create the task under.""" + + a2a_task_id: Optional[str] + """The ID of the task.""" + + config: Optional[CreateAgentEngineTaskConfigDict] + """""" + + +_CreateAgentEngineTaskRequestParametersOrDict = Union[ + _CreateAgentEngineTaskRequestParameters, _CreateAgentEngineTaskRequestParametersDict +] + + +class TaskMetadataChange(_common.BaseModel): + """An event representing a change to the task's top-level metadata. example: metadata_change: { new_metadata: { "name": "My task", } update_mask: { paths: "name" } }""" + + new_metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Required. The complete state of the metadata object *after* the change.""", + ) + update_mask: Optional[str] = Field( + default=None, + description="""Optional. A field mask indicating which paths in the Struct were changed. If not set, all fields will be updated. go/aip-internal/cloud-standard/2412""", + ) + + +class TaskMetadataChangeDict(TypedDict, total=False): + """An event representing a change to the task's top-level metadata. example: metadata_change: { new_metadata: { "name": "My task", } update_mask: { paths: "name" } }""" + + new_metadata: Optional[dict[str, Any]] + """Required. The complete state of the metadata object *after* the change.""" + + update_mask: Optional[str] + """Optional. A field mask indicating which paths in the Struct were changed. If not set, all fields will be updated. go/aip-internal/cloud-standard/2412""" + + +TaskMetadataChangeOrDict = Union[TaskMetadataChange, TaskMetadataChangeDict] + + +class TaskArtifactChange(_common.BaseModel): + """Describes changes to the artifact list.""" + + added_artifacts: Optional[list[TaskArtifact]] = Field( + default=None, + description="""Optional. A list of brand-new artifacts created in this event.""", + ) + deleted_artifact_ids: Optional[list[str]] = Field( + default=None, + description="""Optional. A list of artifact IDs that were removed in this event.""", + ) + updated_artifacts: Optional[list[TaskArtifact]] = Field( + default=None, + description="""Optional. A list of existing artifacts that were modified in this event.""", + ) + + +class TaskArtifactChangeDict(TypedDict, total=False): + """Describes changes to the artifact list.""" + + added_artifacts: Optional[list[TaskArtifactDict]] + """Optional. A list of brand-new artifacts created in this event.""" + + deleted_artifact_ids: Optional[list[str]] + """Optional. A list of artifact IDs that were removed in this event.""" + + updated_artifacts: Optional[list[TaskArtifactDict]] + """Optional. A list of existing artifacts that were modified in this event.""" + + +TaskArtifactChangeOrDict = Union[TaskArtifactChange, TaskArtifactChangeDict] + + +class TaskOutputChange(_common.BaseModel): + """An event representing a change to the task's outputs.""" + + task_artifact_change: Optional[TaskArtifactChange] = Field( + default=None, + description="""Required. A granular change to the list of artifacts.""", + ) + + +class TaskOutputChangeDict(TypedDict, total=False): + """An event representing a change to the task's outputs.""" + + task_artifact_change: Optional[TaskArtifactChangeDict] + """Required. 
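A sketch of the update-mask semantics described for TaskMetadataChange above: with no mask the new metadata is taken wholesale, otherwise only the named paths are copied over. The helper is illustrative, not SDK code, and treating the string mask as comma-separated top-level paths is an assumption.

def apply_metadata_change(current: dict, change: TaskMetadataChange) -> dict:
    new = change.new_metadata or {}
    if not change.update_mask:
        return dict(new)  # no mask: take the full new state
    updated = dict(current)
    for path in change.update_mask.split(","):
        key = path.strip()
        if key in new:
            updated[key] = new[key]  # copy only the masked paths
    return updated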
A granular change to the list of artifacts.""" + + +TaskOutputChangeOrDict = Union[TaskOutputChange, TaskOutputChangeDict] + + +class TaskStateChange(_common.BaseModel): + """A message representing a change in a task's state.""" + + new_state: Optional[State] = Field( + default=None, description="""Required. The new state of the task.""" + ) + + +class TaskStateChangeDict(TypedDict, total=False): + """A message representing a change in a task's state.""" + + new_state: Optional[State] + """Required. The new state of the task.""" + + +TaskStateChangeOrDict = Union[TaskStateChange, TaskStateChangeDict] + + +class TaskStatusDetailsChange(_common.BaseModel): + """Represents a change to the task's status details.""" + + new_task_status: Optional[TaskStatusDetails] = Field( + default=None, + description="""Required. The complete state of the task's status *after* the change.""", + ) + + +class TaskStatusDetailsChangeDict(TypedDict, total=False): + """Represents a change to the task's status details.""" + + new_task_status: Optional[TaskStatusDetailsDict] + """Required. The complete state of the task's status *after* the change.""" + + +TaskStatusDetailsChangeOrDict = Union[ + TaskStatusDetailsChange, TaskStatusDetailsChangeDict +] + + +class TaskEventData(_common.BaseModel): + """Data for a TaskEvent.""" + + metadata_change: Optional[TaskMetadataChange] = Field( + default=None, description="""Optional. A change to the task's metadata.""" + ) + output_change: Optional[TaskOutputChange] = Field( + default=None, description="""Optional. A change to the task's final outputs.""" + ) + state_change: Optional[TaskStateChange] = Field( + default=None, description="""Optional. A change in the task's state.""" + ) + status_details_change: Optional[TaskStatusDetailsChange] = Field( + default=None, + description="""Optional. A change to the framework-specific status details.""", + ) + + +class TaskEventDataDict(TypedDict, total=False): + """Data for a TaskEvent.""" + + metadata_change: Optional[TaskMetadataChangeDict] + """Optional. A change to the task's metadata.""" + + output_change: Optional[TaskOutputChangeDict] + """Optional. A change to the task's final outputs.""" + + state_change: Optional[TaskStateChangeDict] + """Optional. A change in the task's state.""" + + status_details_change: Optional[TaskStatusDetailsChangeDict] + """Optional. A change to the framework-specific status details.""" + + +TaskEventDataOrDict = Union[TaskEventData, TaskEventDataDict] + + +class TaskEvent(_common.BaseModel): + """A task event.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, description="""Output only. The create time of the event.""" + ) + event_data: Optional[TaskEventData] = Field( + default=None, description="""Required. The delta associated with the event.""" + ) + event_sequence_number: Optional[int] = Field( + default=None, + description="""Required. The sequence number of the event. This is used to uniquely identify the event within the task and order events chronologically. This is an id generated by the SDK.""", + ) + + +class TaskEventDict(TypedDict, total=False): + """A task event.""" + + create_time: Optional[datetime.datetime] + """Output only. The create time of the event.""" + + event_data: Optional[TaskEventDataDict] + """Required. The delta associated with the event.""" + + event_sequence_number: Optional[int] + """Required. The sequence number of the event. This is used to uniquely identify the event within the task and order events chronologically.
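A composition sketch for the event types above: each event wraps one change in its event_data, and event_sequence_number values increase monotonically per the field docs (the sequence values here are made up):

events = [
    TaskEvent(
        event_sequence_number=1,
        event_data=TaskEventData(state_change=TaskStateChange(new_state=State.WORKING)),
    ),
    TaskEvent(
        event_sequence_number=2,
        event_data=TaskEventData(state_change=TaskStateChange(new_state=State.COMPLETED)),
    ),
]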
This is an id generated by the SDK.""" + + +TaskEventOrDict = Union[TaskEvent, TaskEventDict] + + +class AppendAgentEngineTaskEventConfig(_common.BaseModel): + """Config for appending Agent Engine task events.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class AppendAgentEngineTaskEventConfigDict(TypedDict, total=False): + """Config for appending Agent Engine task events.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +AppendAgentEngineTaskEventConfigOrDict = Union[ + AppendAgentEngineTaskEventConfig, AppendAgentEngineTaskEventConfigDict +] + + +class _AppendAgentEngineTaskEventRequestParameters(_common.BaseModel): + """Parameters for appending Agent Engine task events.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the Agent Engine task to append the events to.""", + ) + task_events: Optional[list[TaskEvent]] = Field( + default=None, description="""The events to append to the task.""" + ) + config: Optional[AppendAgentEngineTaskEventConfig] = Field( + default=None, description="""""" + ) + + +class _AppendAgentEngineTaskEventRequestParametersDict(TypedDict, total=False): + """Parameters for appending Agent Engine task events.""" + + name: Optional[str] + """Name of the Agent Engine task to append the events to.""" + + task_events: Optional[list[TaskEventDict]] + """The events to append to the task.""" + + config: Optional[AppendAgentEngineTaskEventConfigDict] + """""" + + +_AppendAgentEngineTaskEventRequestParametersOrDict = Union[ + _AppendAgentEngineTaskEventRequestParameters, + _AppendAgentEngineTaskEventRequestParametersDict, +] + + +class AppendAgentEngineTaskEventResponse(_common.BaseModel): + """Response for appending Agent Engine task events.""" + + pass + + +class AppendAgentEngineTaskEventResponseDict(TypedDict, total=False): + """Response for appending Agent Engine task events.""" + + pass + + +AppendAgentEngineTaskEventResponseOrDict = Union[ + AppendAgentEngineTaskEventResponse, AppendAgentEngineTaskEventResponseDict +] + + +class ListAgentEngineTaskEventsConfig(_common.BaseModel): + """Config for listing agent engine task events.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + order_by: Optional[str] = Field( + default=None, + description="""A comma-separated list of fields to order by, sorted in ascending order. + Use "desc" after a field name for descending. + If this field is omitted, the default ordering is `create_time` descending. + More detail in [AIP-132](https://google.aip.dev/132). + + Supported fields: + * `create_time` + * `update_time` + + Example: `create_time desc`.""", + ) + + +class ListAgentEngineTaskEventsConfigDict(TypedDict, total=False): + """Config for listing agent engine task events.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request.
+
+
+class _ListAgentEngineTaskEventsRequestParameters(_common.BaseModel):
+    """Parameters for listing Agent Engine task events."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the Agent Engine task."""
+    )
+    config: Optional[ListAgentEngineTaskEventsConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _ListAgentEngineTaskEventsRequestParametersDict(TypedDict, total=False):
+    """Parameters for listing Agent Engine task events."""
+
+    name: Optional[str]
+    """Name of the Agent Engine task."""
+
+    config: Optional[ListAgentEngineTaskEventsConfigDict]
+    """"""
+
+
+_ListAgentEngineTaskEventsRequestParametersOrDict = Union[
+    _ListAgentEngineTaskEventsRequestParameters,
+    _ListAgentEngineTaskEventsRequestParametersDict,
+]
+
+
+class ListAgentEngineTaskEventsResponse(_common.BaseModel):
+    """Response for listing Agent Engine task events."""
+
+    sdk_http_response: Optional[genai_types.HttpResponse] = Field(
+        default=None, description="""Used to retain the full HTTP response."""
+    )
+    next_page_token: Optional[str] = Field(default=None, description="""""")
+    task_events: Optional[list[TaskEvent]] = Field(
+        default=None, description="""List of Agent Engine task events."""
+    )
+
+
+class ListAgentEngineTaskEventsResponseDict(TypedDict, total=False):
+    """Response for listing Agent Engine task events."""
+
+    sdk_http_response: Optional[genai_types.HttpResponseDict]
+    """Used to retain the full HTTP response."""
+
+    next_page_token: Optional[str]
+    """"""
+
+    task_events: Optional[list[TaskEventDict]]
+    """List of Agent Engine task events."""
+
+
+ListAgentEngineTaskEventsResponseOrDict = Union[
+    ListAgentEngineTaskEventsResponse, ListAgentEngineTaskEventsResponseDict
+]
+
+
+class CreateEvaluationItemConfig(_common.BaseModel):
+    """Config to create an evaluation item."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class CreateEvaluationItemConfigDict(TypedDict, total=False):
+    """Config to create an evaluation item."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+CreateEvaluationItemConfigOrDict = Union[
+    CreateEvaluationItemConfig, CreateEvaluationItemConfigDict
+]
+
+
+class _CreateEvaluationItemParameters(_common.BaseModel):
+    """Represents a job that creates an evaluation item."""
+
+    evaluation_item_type: Optional[str] = Field(default=None, description="""""")
+    gcs_uri: Optional[str] = Field(default=None, description="""""")
+    display_name: Optional[str] = Field(default=None, description="""""")
+    config: Optional[CreateEvaluationItemConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _CreateEvaluationItemParametersDict(TypedDict, total=False):
+    """Represents a job that creates an evaluation item."""
+
+    evaluation_item_type: Optional[str]
+    """"""
+
+    gcs_uri: Optional[str]
+    """"""
+
+    display_name: Optional[str]
+    """"""
+
+    config: Optional[CreateEvaluationItemConfigDict]
+    """"""
+
+
+_CreateEvaluationItemParametersOrDict = Union[
+    _CreateEvaluationItemParameters, _CreateEvaluationItemParametersDict
+]
+
+
+class PromptTemplateData(_common.BaseModel):
+    """Holds data for a prompt template.
+
+    Message to hold a prompt template and the values to populate the template.
+    """
+
+    values: Optional[dict[str, genai_types.Content]] = Field(
+        default=None, description="""The values for fields in the prompt template."""
+    )
+
+
+class PromptTemplateDataDict(TypedDict, total=False):
+    """Holds data for a prompt template.
+
+    Message to hold a prompt template and the values to populate the template.
+    """
+
+    values: Optional[dict[str, genai_types.ContentDict]]
+    """The values for fields in the prompt template."""
+
+
+PromptTemplateDataOrDict = Union[PromptTemplateData, PromptTemplateDataDict]
+
+
+class EvaluationPrompt(_common.BaseModel):
+    """Represents the prompt to be evaluated."""
+
+    text: Optional[str] = Field(default=None, description="""Text prompt.""")
+    value: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Fields and values that can be used to populate the prompt template.""",
+    )
+    prompt_template_data: Optional[PromptTemplateData] = Field(
+        default=None, description="""Prompt template data."""
+    )
+    user_scenario: Optional[evals_types.UserScenario] = Field(
+        default=None,
+        description="""User scenario to help simulate multi-turn agent run results.""",
+    )
+
+
+class EvaluationPromptDict(TypedDict, total=False):
+    """Represents the prompt to be evaluated."""
+
+    text: Optional[str]
+    """Text prompt."""
+
+    value: Optional[dict[str, Any]]
+    """Fields and values that can be used to populate the prompt template."""
+
+    prompt_template_data: Optional[PromptTemplateDataDict]
+    """Prompt template data."""
+
+    user_scenario: Optional[evals_types.UserScenario]
+    """User scenario to help simulate multi-turn agent run results."""
+
+
+EvaluationPromptOrDict = Union[EvaluationPrompt, EvaluationPromptDict]
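+
+
+# Example (illustrative): two ways to supply a prompt. Only one of the prompt
+# fields would normally be set on a single EvaluationPrompt; the template
+# variable name below is hypothetical.
+#
+#     plain = EvaluationPrompt(text="What is the capital of France?")
+#     templated = EvaluationPrompt(value={"question": "What is 2 + 2?"})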
+class CandidateResponse(_common.BaseModel):
+    """Responses from model or agent."""
+
+    candidate: Optional[str] = Field(
+        default=None,
+        description="""The name of the candidate that produced the response.""",
+    )
+    text: Optional[str] = Field(default=None, description="""The text response.""")
+    value: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Fields and values that can be used to populate the response template.""",
+    )
+    events: Optional[list[genai_types.Content]] = Field(
+        default=None,
+        description="""Intermediate events (such as tool calls and responses) that led to the final response.""",
+    )
+    agent_data: Optional[evals_types.AgentData] = Field(
+        default=None,
+        description="""Represents the complete execution trace of an agent conversation,
+      which can involve single or multiple agents. This field is used to
+      provide the full output of an agent's run, including all turns and
+      events, for direct evaluation.""",
+    )
+
+
+class CandidateResponseDict(TypedDict, total=False):
+    """Responses from model or agent."""
+
+    candidate: Optional[str]
+    """The name of the candidate that produced the response."""
+
+    text: Optional[str]
+    """The text response."""
+
+    value: Optional[dict[str, Any]]
+    """Fields and values that can be used to populate the response template."""
+
+    events: Optional[list[genai_types.ContentDict]]
+    """Intermediate events (such as tool calls and responses) that led to the final response."""
+
+    agent_data: Optional[evals_types.AgentData]
+    """Represents the complete execution trace of an agent conversation,
+      which can involve single or multiple agents. This field is used to
+      provide the full output of an agent's run, including all turns and
+      events, for direct evaluation."""
+
+
+CandidateResponseOrDict = Union[CandidateResponse, CandidateResponseDict]
+
+
+class EvaluationItemRequest(_common.BaseModel):
+    """Single evaluation request."""
+
+    prompt: Optional[EvaluationPrompt] = Field(
+        default=None, description="""The request/prompt to evaluate."""
+    )
+    golden_response: Optional[CandidateResponse] = Field(
+        default=None, description="""The ideal response or ground truth."""
+    )
+    rubrics: Optional[dict[str, "RubricGroup"]] = Field(
+        default=None,
+        description="""Named groups of rubrics associated with this prompt. The key is a user-defined name for the rubric group.""",
+    )
+    candidate_responses: Optional[list[CandidateResponse]] = Field(
+        default=None,
+        description="""Responses from model under test and other baseline models for comparison.""",
+    )
+
+
+class EvaluationItemRequestDict(TypedDict, total=False):
+    """Single evaluation request."""
+
+    prompt: Optional[EvaluationPromptDict]
+    """The request/prompt to evaluate."""
+
+    golden_response: Optional[CandidateResponseDict]
+    """The ideal response or ground truth."""
+
+    rubrics: Optional[dict[str, "RubricGroupDict"]]
+    """Named groups of rubrics associated with this prompt. The key is a user-defined name for the rubric group."""
+
+    candidate_responses: Optional[list[CandidateResponseDict]]
+    """Responses from model under test and other baseline models for comparison."""
+
+
+EvaluationItemRequestOrDict = Union[EvaluationItemRequest, EvaluationItemRequestDict]
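+
+
+# Example (illustrative): a request pairing a prompt with a golden response
+# and a single candidate response. The candidate names are hypothetical.
+#
+#     request = EvaluationItemRequest(
+#         prompt=EvaluationPrompt(text="What is 2 + 2?"),
+#         golden_response=CandidateResponse(candidate="reference", text="4"),
+#         candidate_responses=[
+#             CandidateResponse(candidate="model-under-test", text="The answer is 4."),
+#         ],
+#     )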
+class EvaluationItemResult(_common.BaseModel):
+    """Represents the result of an evaluation item."""
+
+    evaluation_request: Optional[str] = Field(
+        default=None, description="""The request item that was evaluated."""
+    )
+    evaluation_run: Optional[str] = Field(
+        default=None,
+        description="""The evaluation run that was used to generate the result.""",
+    )
+    request: Optional[EvaluationItemRequest] = Field(
+        default=None, description="""The request that was evaluated."""
+    )
+    metric: Optional[str] = Field(
+        default=None, description="""The metric that was evaluated."""
+    )
+    candidate_results: Optional[list[evals_types.CandidateResult]] = Field(
+        default=None, description="""The results for the metric."""
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None, description="""Metadata about the evaluation result."""
+    )
+
+
+class EvaluationItemResultDict(TypedDict, total=False):
+    """Represents the result of an evaluation item."""
+
+    evaluation_request: Optional[str]
+    """The request item that was evaluated."""
+
+    evaluation_run: Optional[str]
+    """The evaluation run that was used to generate the result."""
+
+    request: Optional[EvaluationItemRequestDict]
+    """The request that was evaluated."""
+
+    metric: Optional[str]
+    """The metric that was evaluated."""
+
+    candidate_results: Optional[list[evals_types.CandidateResult]]
+    """The results for the metric."""
+
+    metadata: Optional[dict[str, Any]]
+    """Metadata about the evaluation result."""
+
+
+EvaluationItemResultOrDict = Union[EvaluationItemResult, EvaluationItemResultDict]
+
+
+class EvaluationItem(_common.BaseModel):
+    """EvaluationItem is a single evaluation request or result.
+
+    The content of an EvaluationItem is immutable - it cannot be updated once
+    created. EvaluationItems can be deleted when no longer needed.
+    """
+
+    name: Optional[str] = Field(
+        default=None, description="""The resource name of the EvaluationItem."""
+    )
+    display_name: Optional[str] = Field(
+        default=None, description="""The display name of the EvaluationItem."""
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None, description="""Metadata for the EvaluationItem."""
+    )
+    labels: Optional[dict[str, str]] = Field(
+        default=None, description="""Labels for the EvaluationItem."""
+    )
+    evaluation_item_type: Optional[EvaluationItemType] = Field(
+        default=None, description="""The type of the EvaluationItem."""
+    )
+    evaluation_request: Optional[EvaluationItemRequest] = Field(
+        default=None, description="""The request to evaluate."""
+    )
+    evaluation_response: Optional[EvaluationItemResult] = Field(
+        default=None, description="""The response from evaluation."""
+    )
+    gcs_uri: Optional[str] = Field(
+        default=None,
+        description="""The Cloud Storage object where the request or response is stored.""",
+    )
+    create_time: Optional[datetime.datetime] = Field(
+        default=None, description="""Timestamp when this item was created."""
+    )
+    error: Optional[genai_types.GoogleRpcStatus] = Field(
+        default=None, description="""Error for the evaluation item."""
+    )
+
+    # TODO(b/448806531): Remove all the overridden _from_response methods once the
+    # ticket is resolved and published.
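+    # The REST layer returns camelCase keys (e.g. {"displayName": ...}), while
+    # the fields above are declared in snake_case; _camel_key_to_snake
+    # normalizes the response keys before Pydantic validation runs.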
+ @classmethod + def _from_response( + cls: typing.Type["EvaluationItem"], + *, + response: dict[str, object], + kwargs: dict[str, object], + ) -> "EvaluationItem": + """Converts a dictionary response into a EvaluationItem object.""" + + response = _camel_key_to_snake(response) + result = super()._from_response(response=response, kwargs=kwargs) + return result + + +class EvaluationItemDict(TypedDict, total=False): + """EvaluationItem is a single evaluation request or result. + + The content of an EvaluationItem is immutable - it cannot be updated once + created. EvaluationItems can be deleted when no longer needed. + """ + + name: Optional[str] + """The resource name of the EvaluationItem.""" + + display_name: Optional[str] + """The display name of the EvaluationItem.""" + + metadata: Optional[dict[str, Any]] + """Metadata for the EvaluationItem.""" + + labels: Optional[dict[str, str]] + """Labels for the EvaluationItem.""" + + evaluation_item_type: Optional[EvaluationItemType] + """The type of the EvaluationItem.""" + + evaluation_request: Optional[EvaluationItemRequestDict] + """The request to evaluate.""" + + evaluation_response: Optional[EvaluationItemResultDict] + """The response from evaluation.""" + + gcs_uri: Optional[str] + """The Cloud Storage object where the request or response is stored.""" + + create_time: Optional[datetime.datetime] + """Timestamp when this item was created.""" + + error: Optional[genai_types.GoogleRpcStatusDict] + """Error for the evaluation item.""" + + +EvaluationItemOrDict = Union[EvaluationItem, EvaluationItemDict] + + +class Metric(_common.BaseModel): + """The metric used for evaluation.""" + + name: Optional[str] = Field(default=None, description="""The name of the metric.""") + custom_function: Optional[Union[str, Callable[..., Any]]] = Field( + default=None, + description="""The custom function that defines the end-to-end logic for metric computation.""", + ) + prompt_template: Optional[str] = Field( + default=None, description="""The prompt template for the metric.""" + ) + judge_model_system_instruction: Optional[str] = Field( + default=None, description="""The system instruction for the judge model.""" + ) + return_raw_output: Optional[bool] = Field( + default=None, + description="""Whether to return the raw output from the judge model.""", + ) + parse_and_reduce_fn: Optional[Callable[..., Any]] = Field( + default=None, + description="""The parse and reduce function for the judge model.""", + ) + aggregate_summary_fn: Optional[Callable[..., Any]] = Field( + default=None, + description="""The aggregate summary function for the judge model.""", + ) + remote_custom_function: Optional[str] = Field( + default=None, + description="""The evaluation function for the custom code execution metric. 
This custom code is run remotely in the evaluation service.""",
+    )
+    judge_model: Optional[str] = Field(
+        default=None, description="""The judge model for the metric."""
+    )
+    judge_model_generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None,
+        description="""The generation config for the judge LLM (temperature, top_k, top_p, etc).""",
+    )
+    judge_model_sampling_count: Optional[int] = Field(
+        default=None, description="""The sampling count for the judge model."""
+    )
+    rubric_group_name: Optional[str] = Field(
+        default=None,
+        description="""The rubric group name for the rubric-based metric.""",
+    )
+    metric_spec_parameters: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Optional steering instruction parameters for the automated predefined metric.""",
+    )
+    metric_resource_name: Optional[str] = Field(
+        default=None,
+        description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
+    )
+
+    # Allow extra fields to support metric-specific config fields.
+    model_config = ConfigDict(extra="allow")
+
+    _is_predefined: bool = PrivateAttr(default=False)
+    """A boolean indicating whether the metric is predefined."""
+
+    _config_source: Optional[str] = PrivateAttr(default=None)
+    """An optional string indicating the source of the metric configuration."""
+
+    _version: Optional[str] = PrivateAttr(default=None)
+    """An optional string indicating the version of the metric."""
+
+    @model_validator(mode="after")
+    @classmethod
+    def validate_name(cls, model: "Metric") -> "Metric":
+        if not model.name:
+            raise ValueError("Metric name cannot be empty.")
+        model.name = model.name.lower()
+        return model
+
+    def to_yaml_file(self, file_path: str, version: Optional[str] = None) -> None:
+        """Dumps the metric object to a YAML file.
+
+        Args:
+            file_path: The path to the YAML file.
+            version: Optional version string to include in the YAML output.
+
+        Raises:
+            ImportError: If the pyyaml library is not installed.
+        """
+        if yaml is None:
+            raise ImportError(
+                "YAML serialization requires the pyyaml library. Please install"
+                " it using 'pip install google-cloud-aiplatform[evaluation]'."
+            )
+
+        fields_to_exclude = {
+            field_name
+            for field_name, field_info in self.model_fields.items()
+            if self.__getattribute__(field_name) is not None
+            and isinstance(self.__getattribute__(field_name), Callable)
+        }
+
+        data_to_dump = self.model_dump(
+            exclude_unset=True,
+            exclude_none=True,
+            mode="json",
+            exclude=fields_to_exclude if fields_to_exclude else None,
+        )
+
+        if version:
+            data_to_dump["version"] = version
+
+        with open(file_path, "w", encoding="utf-8") as f:
+            yaml.dump(data_to_dump, f, sort_keys=False, allow_unicode=True)
+
+
+class CodeExecutionMetric(Metric):
+    """A metric that executes custom Python code for evaluation."""
+
+    # This class is hand-written rather than generated, so standard Pydantic
+    # Field syntax is used directly for the raw Python source payload.
+    custom_function: Optional[str] = Field(
+        default=None,
+        description="""The Python function code to be executed on the server side.""",
+    )
+
+    # Hand-written validators and helper methods can likewise be added here.
+    @field_validator("custom_function")
+    @classmethod
+    def validate_code(cls, value: Optional[str]) -> Optional[str]:
+        if value and "def evaluate" not in value:
+            raise ValueError(
+                "custom_function must contain a 'def evaluate(instance):' signature."
+            )
+        return value
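+
+
+# Example (illustrative): a CodeExecutionMetric whose custom_function is the
+# raw source of an `evaluate` function, which the validator above requires.
+#
+#     exact_match = CodeExecutionMetric(
+#         name="exact_match",
+#         custom_function=(
+#             "def evaluate(instance):\n"
+#             "    if instance['response'] == instance['reference']:\n"
+#             "        return 1.0\n"
+#             "    return 0.0\n"
+#         ),
+#     )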
+
+
+class LLMMetric(Metric):
+    """A metric that uses LLM-as-a-judge for evaluation."""
+
+    rubric_group_name: Optional[str] = Field(
+        default=None,
+        description="""Optional. The name of the column in the EvaluationDataset containing the list of rubrics to use for this metric.""",
+    )
+
+    @field_validator("prompt_template", mode="before")
+    @classmethod
+    def validate_prompt_template(
+        cls, value: Optional[Union[str, "MetricPromptBuilder"]]
+    ) -> str:
+        """Validates prompt template to be a non-empty string."""
+        if value is None:
+            raise ValueError("Prompt template cannot be empty.")
+        if isinstance(value, MetricPromptBuilder):
+            value = str(value)
+        if not value.strip():
+            raise ValueError("Prompt template cannot be an empty string.")
+        return value
+
+    @field_validator("judge_model_sampling_count")
+    @classmethod
+    def validate_judge_model_sampling_count(cls, value: Optional[int]) -> Optional[int]:
+        """Validates judge_model_sampling_count to be between 1 and 32."""
+        if value is not None and (value < 1 or value > 32):
+            raise ValueError("judge_model_sampling_count must be between 1 and 32.")
+        return value
+
+    @classmethod
+    def load(cls, config_path: str, client: Optional[Any] = None) -> "LLMMetric":
+        """Loads a metric configuration from a YAML or JSON file.
+
+        This method allows for the creation of an LLMMetric instance from a
+        local file path or a Google Cloud Storage (GCS) URI. It will automatically
+        detect the file type (.yaml, .yml, or .json) and parse it accordingly.
+
+        Args:
+            config_path: The local path or GCS URI (e.g., 'gs://bucket/metric.yaml')
+              to the metric configuration file.
+            client: Optional. The Vertex AI client instance to use for authentication.
+              If not provided, Application Default Credentials (ADC) will be used.
+
+        Returns:
+            An instance of LLMMetric configured with the loaded data.
+
+        Raises:
+            ValueError: If the file path is invalid or the file content cannot be parsed.
+            ImportError: If a required library like 'PyYAML' or 'google-cloud-storage' is not installed.
+            IOError: If the file cannot be read from the specified path.
+        """
+        file_extension = os.path.splitext(config_path)[1].lower()
+        if file_extension not in [".yaml", ".yml", ".json"]:
+            raise ValueError(
+                "Unsupported file extension for metric config. Must be .yaml, .yml, or .json"
+            )
+
+        content_str: str
+        if config_path.startswith("gs://"):
+            try:
+                from google.cloud import storage  # type: ignore[attr-defined]
+
+                storage_client = storage.Client(
+                    credentials=client._api_client._credentials if client else None
+                )
+                path_without_prefix = config_path[len("gs://") :]
+                bucket_name, blob_path = path_without_prefix.split("/", 1)
+
+                bucket = storage_client.bucket(bucket_name)
+                blob = bucket.blob(blob_path)
+                content_str = blob.download_as_bytes().decode("utf-8")
+            except ImportError as e:
+                raise ImportError(
+                    "Reading from GCS requires the 'google-cloud-storage' library. Please install it with 'pip install google-cloud-aiplatform[evaluation]'."
+ ) from e + except Exception as e: + raise IOError(f"Failed to read from GCS path {config_path}: {e}") from e + else: + try: + with open(config_path, "r", encoding="utf-8") as f: + content_str = f.read() + except FileNotFoundError: + raise FileNotFoundError( + f"Local configuration file not found at: {config_path}" + ) + except Exception as e: + raise IOError(f"Failed to read local file {config_path}: {e}") from e + + data: Dict[str, Any] + + if file_extension in [".yaml", ".yml"]: + if yaml is None: + raise ImportError( + "YAML parsing requires the pyyaml library. Please install it with 'pip install google-cloud-aiplatform[evaluation]'." + ) + data = yaml.safe_load(content_str) + elif file_extension == ".json": + data = json.loads(content_str) + + if not isinstance(data, dict): + raise ValueError("Metric config content did not parse into a dictionary.") + + return cls.model_validate(data) + + +class MetricDict(TypedDict, total=False): + """The metric used for evaluation.""" + + name: Optional[str] + """The name of the metric.""" + + custom_function: Optional[Union[str, Callable[..., Any]]] + """The custom function that defines the end-to-end logic for metric computation.""" + + prompt_template: Optional[str] + """The prompt template for the metric.""" + + judge_model_system_instruction: Optional[str] + """The system instruction for the judge model.""" + + return_raw_output: Optional[bool] + """Whether to return the raw output from the judge model.""" + + parse_and_reduce_fn: Optional[Callable[..., Any]] + """The parse and reduce function for the judge model.""" + + aggregate_summary_fn: Optional[Callable[..., Any]] + """The aggregate summary function for the judge model.""" + + remote_custom_function: Optional[str] + """The evaluation function for the custom code execution metric. This custom code is run remotely in the evaluation service.""" + + judge_model: Optional[str] + """The judge model for the metric.""" + + judge_model_generation_config: Optional[genai_types.GenerationConfigDict] + """The generation config for the judge LLM (temperature, top_k, top_p, etc).""" + + judge_model_sampling_count: Optional[int] + """The sampling count for the judge model.""" + + rubric_group_name: Optional[str] + """The rubric group name for the rubric-based metric.""" + + metric_spec_parameters: Optional[dict[str, Any]] + """Optional steering instruction parameters for the automated predefined metric.""" + + metric_resource_name: Optional[str] + """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""" + + +MetricOrDict = Union[Metric, MetricDict] + + +class CreateEvaluationMetricConfig(_common.BaseModel): + """Config for creating an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CreateEvaluationMetricConfigDict(TypedDict, total=False): + """Config for creating an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CreateEvaluationMetricConfigOrDict = Union[ + CreateEvaluationMetricConfig, CreateEvaluationMetricConfigDict +] + + +class _CreateEvaluationMetricParameters(_common.BaseModel): + """Parameters for creating an evaluation metric.""" + + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the evaluation metric. 
+
+      The display name can be up to 128 characters long and can comprise any
+      UTF-8 characters.
+      """,
+    )
+    description: Optional[str] = Field(
+        default=None, description="""The description of the evaluation metric."""
+    )
+    metric: Optional[Metric] = Field(
+        default=None,
+        description="""The metric configuration of the evaluation metric.""",
+    )
+    config: Optional[CreateEvaluationMetricConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _CreateEvaluationMetricParametersDict(TypedDict, total=False):
+    """Parameters for creating an evaluation metric."""
+
+    display_name: Optional[str]
+    """The user-defined name of the evaluation metric.
+
+      The display name can be up to 128 characters long and can comprise any
+      UTF-8 characters.
+      """
+
+    description: Optional[str]
+    """The description of the evaluation metric."""
+
+    metric: Optional[MetricDict]
+    """The metric configuration of the evaluation metric."""
+
+    config: Optional[CreateEvaluationMetricConfigDict]
+    """"""
+
+
+_CreateEvaluationMetricParametersOrDict = Union[
+    _CreateEvaluationMetricParameters, _CreateEvaluationMetricParametersDict
+]
+
+
+class CustomCodeExecutionSpec(_common.BaseModel):
+    """Specifies a metric using remote Python function execution.
+
+    This metric is computed by running user-defined Python functions remotely.
+    """
+
+    evaluation_function: Optional[str] = Field(
+        default=None,
+        description="""Required. Python function. Expected user to define the following function, e.g.: def evaluate(instance: dict[str, Any]) -> float: Please include this function signature in the code snippet. Instance is the evaluation instance, any fields populated in the instance are available to the function as instance[field_name]. Example input: ``` instance = EvaluationInstance( response=EvaluationInstance.InstanceData(text="The answer is 4."), reference=EvaluationInstance.InstanceData(text="4") ) ``` Example converted input: ``` { 'response': {'text': 'The answer is 4.'}, 'reference': {'text': '4'} } ``` Example python function: ``` def evaluate(instance: dict[str, Any]) -> float: if instance['response'] == instance['reference']: return 1.0 return 0.0 ``` CustomCodeExecutionSpec is also supported in Batch Evaluation (EvalDataset RPC) and Tuning Evaluation. Each line in the input jsonl file will be converted to dict[str, Any] and passed to the evaluation function.""",
+    )
+    remote_custom_function: Optional[str] = Field(
+        default=None,
+        description="""A string representing a user-defined function for evaluation.
+      Expected user to define the following function, e.g.:
+      def evaluate(instance: dict[str, Any]) -> float:
+      Please include this function signature in the code snippet.
+      Instance is the evaluation instance, any fields populated in the instance
+      are available to the function as instance[field_name].""",
+    )
+
+
+class CustomCodeExecutionSpecDict(TypedDict, total=False):
+    """Specifies a metric using remote Python function execution.
+
+    This metric is computed by running user-defined Python functions remotely.
+    """
+
+    evaluation_function: Optional[str]
+    """Required. Python function. Expected user to define the following function, e.g.: def evaluate(instance: dict[str, Any]) -> float: Please include this function signature in the code snippet. Instance is the evaluation instance, any fields populated in the instance are available to the function as instance[field_name]. Example input: ``` instance = EvaluationInstance( response=EvaluationInstance.InstanceData(text="The answer is 4."), reference=EvaluationInstance.InstanceData(text="4") ) ``` Example converted input: ``` { 'response': {'text': 'The answer is 4.'}, 'reference': {'text': '4'} } ``` Example python function: ``` def evaluate(instance: dict[str, Any]) -> float: if instance['response'] == instance['reference']: return 1.0 return 0.0 ``` CustomCodeExecutionSpec is also supported in Batch Evaluation (EvalDataset RPC) and Tuning Evaluation. Each line in the input jsonl file will be converted to dict[str, Any] and passed to the evaluation function."""
+
+    remote_custom_function: Optional[str]
+    """A string representing a user-defined function for evaluation.
+      Expected user to define the following function, e.g.:
+      def evaluate(instance: dict[str, Any]) -> float:
+      Please include this function signature in the code snippet.
+      Instance is the evaluation instance, any fields populated in the instance
+      are available to the function as instance[field_name]."""
+
+
+CustomCodeExecutionSpecOrDict = Union[
+    CustomCodeExecutionSpec, CustomCodeExecutionSpecDict
+]
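+
+
+# Example (illustrative): a spec whose evaluation_function follows the
+# signature described above; each evaluation instance is passed to it as a
+# dict[str, Any].
+#
+#     spec = CustomCodeExecutionSpec(
+#         evaluation_function=(
+#             "def evaluate(instance: dict[str, Any]) -> float:\n"
+#             "    if instance['response'] == instance['reference']:\n"
+#             "        return 1.0\n"
+#             "    return 0.0\n"
+#         ),
+#     )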
+
+
+class UnifiedMetric(_common.BaseModel):
+    """The unified metric used for evaluation."""
+
+    bleu_spec: Optional[genai_types.BleuSpec] = Field(
+        default=None, description="""The BLEU metric spec."""
+    )
+    rouge_spec: Optional[genai_types.RougeSpec] = Field(
+        default=None, description="""The ROUGE metric spec."""
+    )
+    pointwise_metric_spec: Optional[genai_types.PointwiseMetricSpec] = Field(
+        default=None, description="""The pointwise metric spec."""
+    )
+    llm_based_metric_spec: Optional[genai_types.LLMBasedMetricSpec] = Field(
+        default=None, description="""The spec for an LLM based metric."""
+    )
+    custom_code_execution_spec: Optional[CustomCodeExecutionSpec] = Field(
+        default=None, description="""The spec for a custom code execution metric."""
+    )
+    predefined_metric_spec: Optional[genai_types.PredefinedMetricSpec] = Field(
+        default=None, description="""The spec for a pre-defined metric."""
+    )
+    computation_based_metric_spec: Optional[genai_types.ComputationBasedMetricSpec] = (
+        Field(default=None, description="""The spec for a computation based metric.""")
+    )
+
+
+class UnifiedMetricDict(TypedDict, total=False):
+    """The unified metric used for evaluation."""
+
+    bleu_spec: Optional[genai_types.BleuSpecDict]
+    """The BLEU metric spec."""
+
+    rouge_spec: Optional[genai_types.RougeSpecDict]
+    """The ROUGE metric spec."""
+
+    pointwise_metric_spec: Optional[genai_types.PointwiseMetricSpecDict]
+    """The pointwise metric spec."""
+
+    llm_based_metric_spec: Optional[genai_types.LLMBasedMetricSpecDict]
+    """The spec for an LLM based metric."""
+
+    custom_code_execution_spec: Optional[CustomCodeExecutionSpecDict]
+    """The spec for a custom code execution metric."""
+
+    predefined_metric_spec: Optional[genai_types.PredefinedMetricSpecDict]
+    """The spec for a pre-defined metric."""
+
+    computation_based_metric_spec: Optional[genai_types.ComputationBasedMetricSpecDict]
+    """The spec for a computation based metric."""
+
+
+UnifiedMetricOrDict = Union[UnifiedMetric, UnifiedMetricDict]
+
+
+class EvaluationMetric(_common.BaseModel):
+    """Represents an evaluation metric."""
+
+    name: Optional[str] = Field(
+        default=None, description="""The resource name of the evaluation metric."""
+    )
+    display_name: Optional[str] = Field(
+        default=None,
+        description="""The user-friendly display name for the EvaluationMetric.""",
+    )
+    description:
Optional[str] = Field( + default=None, description="""The description of the EvaluationMetric.""" + ) + metric: Optional[UnifiedMetric] = Field( + default=None, + description="""The metric configuration of the evaluation metric.""", + ) + + +class EvaluationMetricDict(TypedDict, total=False): + """Represents an evaluation metric.""" + + name: Optional[str] + """The resource name of the evaluation metric.""" + + display_name: Optional[str] + """The user-friendly display name for the EvaluationMetric.""" + + description: Optional[str] + """The description of the EvaluationMetric.""" + + metric: Optional[UnifiedMetricDict] + """The metric configuration of the evaluation metric.""" + + +EvaluationMetricOrDict = Union[EvaluationMetric, EvaluationMetricDict] + + +class SamplingConfig(_common.BaseModel): + """Sampling config for a BigQuery request set.""" + + sampling_count: Optional[int] = Field(default=None, description="""""") + sampling_method: Optional[SamplingMethod] = Field(default=None, description="""""") + sampling_duration: Optional[str] = Field(default=None, description="""""") + + +class SamplingConfigDict(TypedDict, total=False): + """Sampling config for a BigQuery request set.""" + + sampling_count: Optional[int] + """""" + + sampling_method: Optional[SamplingMethod] + """""" + + sampling_duration: Optional[str] + """""" + + +SamplingConfigOrDict = Union[SamplingConfig, SamplingConfigDict] + + +class BigQueryRequestSet(_common.BaseModel): + """Represents a BigQuery request set.""" + + uri: Optional[str] = Field(default=None, description="""""") + prompt_column: Optional[str] = Field( + default=None, + description="""The column name of the prompt in the BigQuery table. Used for EvaluationRun only.""", + ) + rubrics_column: Optional[str] = Field( + default=None, + description="""The column name of the rubrics in the BigQuery table. Used for EvaluationRun only.""", + ) + candidate_response_columns: Optional[dict[str, str]] = Field( + default=None, + description="""The column name of the response candidates in the BigQuery table. Used for EvaluationRun only.""", + ) + sampling_config: Optional[SamplingConfig] = Field( + default=None, + description="""The sampling config for the BigQuery request set. Used for EvaluationRun only.""", + ) + + +class BigQueryRequestSetDict(TypedDict, total=False): + """Represents a BigQuery request set.""" + + uri: Optional[str] + """""" + + prompt_column: Optional[str] + """The column name of the prompt in the BigQuery table. Used for EvaluationRun only.""" + + rubrics_column: Optional[str] + """The column name of the rubrics in the BigQuery table. Used for EvaluationRun only.""" + + candidate_response_columns: Optional[dict[str, str]] + """The column name of the response candidates in the BigQuery table. Used for EvaluationRun only.""" + + sampling_config: Optional[SamplingConfigDict] + """The sampling config for the BigQuery request set. 
Used for EvaluationRun only.""" + + +BigQueryRequestSetOrDict = Union[BigQueryRequestSet, BigQueryRequestSetDict] + + +class EvaluationRunDataSource(_common.BaseModel): + """Represents an evaluation run data source.""" + + evaluation_set: Optional[str] = Field(default=None, description="""""") + bigquery_request_set: Optional[BigQueryRequestSet] = Field( + default=None, description="""""" + ) + + +class EvaluationRunDataSourceDict(TypedDict, total=False): + """Represents an evaluation run data source.""" + + evaluation_set: Optional[str] + """""" + + bigquery_request_set: Optional[BigQueryRequestSetDict] + """""" + + +EvaluationRunDataSourceOrDict = Union[ + EvaluationRunDataSource, EvaluationRunDataSourceDict +] + + +class EvaluationRunMetric(_common.BaseModel): + """The metric used for evaluation run.""" + + metric: Optional[str] = Field( + default=None, description="""The name of the metric.""" + ) + metric_resource_name: Optional[str] = Field( + default=None, + description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""", + ) + metric_config: Optional[UnifiedMetric] = Field( + default=None, description="""The unified metric used for evaluation run.""" + ) + + +class EvaluationRunMetricDict(TypedDict, total=False): + """The metric used for evaluation run.""" + + metric: Optional[str] + """The name of the metric.""" + + metric_resource_name: Optional[str] + """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""" + + metric_config: Optional[UnifiedMetricDict] + """The unified metric used for evaluation run.""" + + +EvaluationRunMetricOrDict = Union[EvaluationRunMetric, EvaluationRunMetricDict] + + +class EvaluationRunPromptTemplate(_common.BaseModel): + """Prompt template used for inference. + + Only one of `prompt_template` or `gcs_uri` should be set. If both are + provided, an error will be raised. + """ + + prompt_template: Optional[str] = Field( + default=None, + description="""Inline prompt template. Template variables should be in the format + "{var_name}". Only one of `prompt_template` or `gcs_uri` should be set.""", + ) + gcs_uri: Optional[str] = Field( + default=None, + description="""Prompt template stored in Cloud Storage. Format: + "gs://my-bucket/file-name.txt". Only one of `prompt_template` or `gcs_uri` + should be set.""", + ) + + +class EvaluationRunPromptTemplateDict(TypedDict, total=False): + """Prompt template used for inference. + + Only one of `prompt_template` or `gcs_uri` should be set. If both are + provided, an error will be raised. + """ + + prompt_template: Optional[str] + """Inline prompt template. Template variables should be in the format + "{var_name}". Only one of `prompt_template` or `gcs_uri` should be set.""" + + gcs_uri: Optional[str] + """Prompt template stored in Cloud Storage. Format: + "gs://my-bucket/file-name.txt". Only one of `prompt_template` or `gcs_uri` + should be set.""" + + +EvaluationRunPromptTemplateOrDict = Union[ + EvaluationRunPromptTemplate, EvaluationRunPromptTemplateDict +] + + +class LossAnalysisConfig(_common.BaseModel): + """Configuration for the loss analysis job.""" + + metric: Optional[str] = Field( + default=None, + description="""Required. The metric to analyze (e.g., "multi_turn_tool_use_quality_v1").""", + ) + candidate: Optional[str] = Field( + default=None, + description="""Required. 
The candidate model/agent to analyze (e.g., "gemini-3.1-pro-preview"). This targets the specific CandidateResult within the EvaluationResult.""", + ) + predefined_taxonomy: Optional[str] = Field( + default=None, + description="""Optional. The identifier for the pre-defined taxonomy to use (e.g., "agent_taxonomy_v1", "tool_use_v2"). If not specified, the service may select a default based on the metric.""", + ) + max_top_cluster_count: Optional[int] = Field( + default=None, + description="""Optional. Limits the analysis to the top N clusters. If not specified or set to 0, all clusters are returned.""", + ) + + +class LossAnalysisConfigDict(TypedDict, total=False): + """Configuration for the loss analysis job.""" + + metric: Optional[str] + """Required. The metric to analyze (e.g., "multi_turn_tool_use_quality_v1").""" + + candidate: Optional[str] + """Required. The candidate model/agent to analyze (e.g., "gemini-3.1-pro-preview"). This targets the specific CandidateResult within the EvaluationResult.""" + + predefined_taxonomy: Optional[str] + """Optional. The identifier for the pre-defined taxonomy to use (e.g., "agent_taxonomy_v1", "tool_use_v2"). If not specified, the service may select a default based on the metric.""" + + max_top_cluster_count: Optional[int] + """Optional. Limits the analysis to the top N clusters. If not specified or set to 0, all clusters are returned.""" + + +LossAnalysisConfigOrDict = Union[LossAnalysisConfig, LossAnalysisConfigDict] + + +class EvaluationRunConfig(_common.BaseModel): + """The evaluation configuration used for the evaluation run.""" + + metrics: Optional[list[EvaluationRunMetric]] = Field( + default=None, + description="""The metrics to be calculated in the evaluation run.""", + ) + output_config: Optional[genai_types.OutputConfig] = Field( + default=None, description="""The output config for the evaluation run.""" + ) + autorater_config: Optional[genai_types.AutoraterConfig] = Field( + default=None, description="""The autorater config for the evaluation run.""" + ) + prompt_template: Optional[EvaluationRunPromptTemplate] = Field( + default=None, description="""The prompt template used for inference.""" + ) + loss_analysis_config: Optional[list[LossAnalysisConfig]] = Field( + default=None, + description="""Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns.""", + ) + + +class EvaluationRunConfigDict(TypedDict, total=False): + """The evaluation configuration used for the evaluation run.""" + + metrics: Optional[list[EvaluationRunMetricDict]] + """The metrics to be calculated in the evaluation run.""" + + output_config: Optional[genai_types.OutputConfigDict] + """The output config for the evaluation run.""" + + autorater_config: Optional[genai_types.AutoraterConfigDict] + """The autorater config for the evaluation run.""" + + prompt_template: Optional[EvaluationRunPromptTemplateDict] + """The prompt template used for inference.""" + + loss_analysis_config: Optional[list[LossAnalysisConfigDict]] + """Specifications for loss analysis. Each config specifies a metric and candidate to analyze for loss patterns.""" + + +EvaluationRunConfigOrDict = Union[EvaluationRunConfig, EvaluationRunConfigDict] + + +class EvaluationRunAgentConfig(_common.BaseModel): + """This field is experimental and may change in future versions. + + Agent config for an evaluation run. 
+ """ + + developer_instruction: Optional[genai_types.Content] = Field( + default=None, description="""The developer instruction for the agent.""" + ) + tools: Optional[list[genai_types.Tool]] = Field( + default=None, description="""The tools available to the agent.""" + ) + + +class EvaluationRunAgentConfigDict(TypedDict, total=False): + """This field is experimental and may change in future versions. + + Agent config for an evaluation run. + """ + + developer_instruction: Optional[genai_types.ContentDict] + """The developer instruction for the agent.""" + + tools: Optional[list[genai_types.ToolDict]] + """The tools available to the agent.""" + + +EvaluationRunAgentConfigOrDict = Union[ + EvaluationRunAgentConfig, EvaluationRunAgentConfigDict +] + + +class AgentRunConfig(_common.BaseModel): + """Configuration for an Agent Run.""" + + session_input: Optional[evals_types.SessionInput] = Field( + default=None, description="""The session input to get agent running results.""" + ) + agent_engine: Optional[str] = Field( + default=None, description="""The resource name of the Agent Engine.""" + ) + user_simulator_config: Optional[evals_types.UserSimulatorConfig] = Field( + default=None, + description="""Used for multi-turn agent run. + Contains configuration for a user simulator that + uses an LLM to generate messages on behalf of the user.""", + ) + + +class AgentRunConfigDict(TypedDict, total=False): + """Configuration for an Agent Run.""" + + session_input: Optional[evals_types.SessionInput] + """The session input to get agent running results.""" + + agent_engine: Optional[str] + """The resource name of the Agent Engine.""" + + user_simulator_config: Optional[evals_types.UserSimulatorConfig] + """Used for multi-turn agent run. + Contains configuration for a user simulator that + uses an LLM to generate messages on behalf of the user.""" + + +AgentRunConfigOrDict = Union[AgentRunConfig, AgentRunConfigDict] + + +class EvaluationRunInferenceConfig(_common.BaseModel): + """This field is experimental and may change in future versions. + + Configuration that describes an agent. + """ + + agent_config: Optional[EvaluationRunAgentConfig] = Field( + default=None, description="""The agent config.""" + ) + model: Optional[str] = Field( + default=None, + description="""The fully qualified name of the publisher model or endpoint to use for inference.""", + ) + prompt_template: Optional[EvaluationRunPromptTemplate] = Field( + default=None, description="""The prompt template used for inference.""" + ) + agent_run_config: Optional[AgentRunConfig] = Field( + default=None, + description="""Configuration for Agent Run in evaluation management service.""", + ) + agent_configs: Optional[dict[str, evals_types.AgentConfig]] = Field( + default=None, + description="""A map of agent IDs to their respective agent config.""", + ) + + +class EvaluationRunInferenceConfigDict(TypedDict, total=False): + """This field is experimental and may change in future versions. + + Configuration that describes an agent. 
+ """ + + agent_config: Optional[EvaluationRunAgentConfigDict] + """The agent config.""" + + model: Optional[str] + """The fully qualified name of the publisher model or endpoint to use for inference.""" + + prompt_template: Optional[EvaluationRunPromptTemplateDict] + """The prompt template used for inference.""" + + agent_run_config: Optional[AgentRunConfigDict] + """Configuration for Agent Run in evaluation management service.""" + + agent_configs: Optional[dict[str, evals_types.AgentConfig]] + """A map of agent IDs to their respective agent config.""" + + +EvaluationRunInferenceConfigOrDict = Union[ + EvaluationRunInferenceConfig, EvaluationRunInferenceConfigDict +] + + +class CreateEvaluationRunConfig(_common.BaseModel): + """Config to create an evaluation run.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CreateEvaluationRunConfigDict(TypedDict, total=False): + """Config to create an evaluation run.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CreateEvaluationRunConfigOrDict = Union[ + CreateEvaluationRunConfig, CreateEvaluationRunConfigDict +] + + +class _CreateEvaluationRunParameters(_common.BaseModel): + """Represents a job that creates an evaluation run.""" + + name: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + data_source: Optional[EvaluationRunDataSource] = Field( + default=None, description="""""" + ) + evaluation_config: Optional[EvaluationRunConfig] = Field( + default=None, description="""""" + ) + labels: Optional[dict[str, str]] = Field(default=None, description="""""") + inference_configs: Optional[dict[str, EvaluationRunInferenceConfig]] = Field( + default=None, description="""""" + ) + config: Optional[CreateEvaluationRunConfig] = Field( + default=None, description="""""" + ) + + +class _CreateEvaluationRunParametersDict(TypedDict, total=False): + """Represents a job that creates an evaluation run.""" + + name: Optional[str] + """""" + + display_name: Optional[str] + """""" + + data_source: Optional[EvaluationRunDataSourceDict] + """""" + + evaluation_config: Optional[EvaluationRunConfigDict] + """""" + + labels: Optional[dict[str, str]] + """""" + + inference_configs: Optional[dict[str, EvaluationRunInferenceConfigDict]] + """""" + + config: Optional[CreateEvaluationRunConfigDict] + """""" + + +_CreateEvaluationRunParametersOrDict = Union[ + _CreateEvaluationRunParameters, _CreateEvaluationRunParametersDict +] + + +class SummaryMetric(_common.BaseModel): + """Represents a summary metric for an evaluation run.""" + + metrics: Optional[dict[str, Any]] = Field( + default=None, description="""Map of metric name to metric value.""" + ) + total_items: Optional[int] = Field( + default=None, description="""The total number of items that were evaluated.""" + ) + failed_items: Optional[int] = Field( + default=None, description="""The number of items that failed to be evaluated.""" + ) + + +class SummaryMetricDict(TypedDict, total=False): + """Represents a summary metric for an evaluation run.""" + + metrics: Optional[dict[str, Any]] + """Map of metric name to metric value.""" + + total_items: Optional[int] + """The total number of items that were evaluated.""" + + failed_items: Optional[int] + """The number of items that failed to be evaluated.""" + + +SummaryMetricOrDict = Union[SummaryMetric, SummaryMetricDict] + + +class 
LossTaxonomyEntry(_common.BaseModel): + """A specific entry in the loss pattern taxonomy.""" + + l1_category: Optional[str] = Field( + default=None, + description="""The primary category of the loss (e.g., "Hallucination", "Tool Calling").""", + ) + l2_category: Optional[str] = Field( + default=None, + description="""The secondary category of the loss (e.g., "Hallucination of Action", "Incorrect Tool Selection").""", + ) + description: Optional[str] = Field( + default=None, + description="""A detailed description of this loss pattern. Example: "The agent verbally confirms an action without executing the tool." """, + ) + + +class LossTaxonomyEntryDict(TypedDict, total=False): + """A specific entry in the loss pattern taxonomy.""" + + l1_category: Optional[str] + """The primary category of the loss (e.g., "Hallucination", "Tool Calling").""" + + l2_category: Optional[str] + """The secondary category of the loss (e.g., "Hallucination of Action", "Incorrect Tool Selection").""" + + description: Optional[str] + """A detailed description of this loss pattern. Example: "The agent verbally confirms an action without executing the tool." """ + + +LossTaxonomyEntryOrDict = Union[LossTaxonomyEntry, LossTaxonomyEntryDict] + + +class FailedRubric(_common.BaseModel): + """A specific failed rubric and the associated analysis.""" + + rubric_id: Optional[str] = Field( + default=None, + description="""The unique ID of the rubric (if available from the metric source).""", + ) + classification_rationale: Optional[str] = Field( + default=None, + description="""The rationale provided by the Loss Analysis Classifier for why this failure maps to this specific Loss Cluster.""", + ) + + +class FailedRubricDict(TypedDict, total=False): + """A specific failed rubric and the associated analysis.""" + + rubric_id: Optional[str] + """The unique ID of the rubric (if available from the metric source).""" + + classification_rationale: Optional[str] + """The rationale provided by the Loss Analysis Classifier for why this failure maps to this specific Loss Cluster.""" + + +FailedRubricOrDict = Union[FailedRubric, FailedRubricDict] + + +class LossExample(_common.BaseModel): + """A specific example of a loss pattern.""" + + evaluation_item: Optional[str] = Field( + default=None, + description="""Reference to the persisted EvalItem resource name. Format: projects/.../locations/.../evaluationItems/{item_id}.""", + ) + evaluation_result: Optional[dict[str, Any]] = Field( + default=None, + description="""The full evaluation result object provided inline. Used when the analysis is performed on ephemeral data.""", + ) + failed_rubrics: Optional[list[FailedRubric]] = Field( + default=None, + description="""The specific rubric(s) that failed and caused this example to be classified here. An example might fail multiple rubrics, but only specific ones trigger this loss pattern.""", + ) + + +class LossExampleDict(TypedDict, total=False): + """A specific example of a loss pattern.""" + + evaluation_item: Optional[str] + """Reference to the persisted EvalItem resource name. Format: projects/.../locations/.../evaluationItems/{item_id}.""" + + evaluation_result: Optional[dict[str, Any]] + """The full evaluation result object provided inline. Used when the analysis is performed on ephemeral data.""" + + failed_rubrics: Optional[list[FailedRubricDict]] + """The specific rubric(s) that failed and caused this example to be classified here. 
An example might fail multiple rubrics, but only specific ones trigger this loss pattern.""" + + +LossExampleOrDict = Union[LossExample, LossExampleDict] + + +class LossCluster(_common.BaseModel): + """A semantic grouping of failures (e.g., "Hallucination of Action").""" + + cluster_id: Optional[str] = Field( + default=None, + description="""Unique identifier for the loss cluster within the scope of the analysis result.""", + ) + taxonomy_entry: Optional[LossTaxonomyEntry] = Field( + default=None, + description="""The structured definition of the loss taxonomy for this cluster.""", + ) + item_count: Optional[int] = Field( + default=None, + description="""The total number of EvaluationItems falling into this cluster.""", + ) + examples: Optional[list[LossExample]] = Field( + default=None, + description="""A list of examples that belong to this cluster. This links the cluster back to the specific EvaluationItems and Rubrics.""", + ) + + +class LossClusterDict(TypedDict, total=False): + """A semantic grouping of failures (e.g., "Hallucination of Action").""" + + cluster_id: Optional[str] + """Unique identifier for the loss cluster within the scope of the analysis result.""" + + taxonomy_entry: Optional[LossTaxonomyEntryDict] + """The structured definition of the loss taxonomy for this cluster.""" + + item_count: Optional[int] + """The total number of EvaluationItems falling into this cluster.""" + + examples: Optional[list[LossExampleDict]] + """A list of examples that belong to this cluster. This links the cluster back to the specific EvaluationItems and Rubrics.""" + + +LossClusterOrDict = Union[LossCluster, LossClusterDict] + + +class LossAnalysisResult(_common.BaseModel): + """The top-level result for loss analysis.""" + + config: Optional[LossAnalysisConfig] = Field( + default=None, + description="""The configuration used to generate this analysis.""", + ) + analysis_time: Optional[str] = Field( + default=None, description="""The timestamp when this analysis was performed.""" + ) + clusters: Optional[list[LossCluster]] = Field( + default=None, description="""The list of identified loss clusters.""" + ) + + def show(self) -> None: + """Shows the loss analysis result with rich HTML visualization.""" + from .. 
import _evals_visualization + + _evals_visualization.display_loss_analysis_result(self) + + +class LossAnalysisResultDict(TypedDict, total=False): + """The top-level result for loss analysis.""" + + config: Optional[LossAnalysisConfigDict] + """The configuration used to generate this analysis.""" + + analysis_time: Optional[str] + """The timestamp when this analysis was performed.""" + + clusters: Optional[list[LossClusterDict]] + """The list of identified loss clusters.""" + + +LossAnalysisResultOrDict = Union[LossAnalysisResult, LossAnalysisResultDict] + + +class EvaluationRunResults(_common.BaseModel): + """Represents the results of an evaluation run.""" + + evaluation_set: Optional[str] = Field( + default=None, + description="""The evaluation set where item level results are stored.""", + ) + summary_metrics: Optional[SummaryMetric] = Field( + default=None, description="""The summary metrics for the evaluation run.""" + ) + loss_analysis_results: Optional[list[LossAnalysisResult]] = Field( + default=None, + description="""The loss analysis results for the evaluation run.""", + ) + + +class EvaluationRunResultsDict(TypedDict, total=False): + """Represents the results of an evaluation run.""" + + evaluation_set: Optional[str] + """The evaluation set where item level results are stored.""" + + summary_metrics: Optional[SummaryMetricDict] + """The summary metrics for the evaluation run.""" + + loss_analysis_results: Optional[list[LossAnalysisResultDict]] + """The loss analysis results for the evaluation run.""" + + +EvaluationRunResultsOrDict = Union[EvaluationRunResults, EvaluationRunResultsDict] + + +class EvalCaseMetricResult(_common.BaseModel): + """Evaluation result for a single evaluation case for a single metric.""" + + metric_name: Optional[str] = Field( + default=None, description="""Name of the metric.""" + ) + score: Optional[float] = Field(default=None, description="""Score of the metric.""") + explanation: Optional[str] = Field( + default=None, description="""Explanation of the metric.""" + ) + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] = Field( + default=None, + description="""The details of all the rubrics and their verdicts for rubric-based metrics.""", + ) + raw_output: Optional[list[str]] = Field( + default=None, description="""Raw output of the metric.""" + ) + error_message: Optional[str] = Field( + default=None, description="""Error message for the metric.""" + ) + + +class EvalCaseMetricResultDict(TypedDict, total=False): + """Evaluation result for a single evaluation case for a single metric.""" + + metric_name: Optional[str] + """Name of the metric.""" + + score: Optional[float] + """Score of the metric.""" + + explanation: Optional[str] + """Explanation of the metric.""" + + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] + """The details of all the rubrics and their verdicts for rubric-based metrics.""" + + raw_output: Optional[list[str]] + """Raw output of the metric.""" + + error_message: Optional[str] + """Error message for the metric.""" + + +EvalCaseMetricResultOrDict = Union[EvalCaseMetricResult, EvalCaseMetricResultDict] + + +class ResponseCandidateResult(_common.BaseModel): + """Aggregated metric results for a single response candidate.""" + + response_index: Optional[int] = Field( + default=None, + description="""Index of the response candidate this result pertains to.""", + ) + metric_results: Optional[dict[str, EvalCaseMetricResult]] = Field( + default=None, + description="""A dictionary of metric results for this response 
candidate, keyed by metric name.""", + ) + + +class ResponseCandidateResultDict(TypedDict, total=False): + """Aggregated metric results for a single response candidate.""" + + response_index: Optional[int] + """Index of the response candidate this result pertains to.""" + + metric_results: Optional[dict[str, EvalCaseMetricResultDict]] + """A dictionary of metric results for this response candidate, keyed by metric name.""" + + +ResponseCandidateResultOrDict = Union[ + ResponseCandidateResult, ResponseCandidateResultDict +] + + +class EvalCaseResult(_common.BaseModel): + """Eval result for a single evaluation case.""" + + eval_case_index: Optional[int] = Field( + default=None, description="""Index of the evaluation case.""" + ) + response_candidate_results: Optional[list[ResponseCandidateResult]] = Field( + default=None, + description="""A list of results, one for each response candidate of the EvalCase.""", + ) + + +class EvalCaseResultDict(TypedDict, total=False): + """Eval result for a single evaluation case.""" + + eval_case_index: Optional[int] + """Index of the evaluation case.""" + + response_candidate_results: Optional[list[ResponseCandidateResultDict]] + """A list of results, one for each response candidate of the EvalCase.""" + + +EvalCaseResultOrDict = Union[EvalCaseResult, EvalCaseResultDict] + + +class AggregatedMetricResult(_common.BaseModel): + """Evaluation result for a single metric for an evaluation dataset.""" + + metric_name: Optional[str] = Field( + default=None, description="""Name of the metric.""" + ) + num_cases_total: Optional[int] = Field( + default=None, description="""Total number of cases in the dataset.""" + ) + num_cases_valid: Optional[int] = Field( + default=None, description="""Number of valid cases in the dataset.""" + ) + num_cases_error: Optional[int] = Field( + default=None, description="""Number of cases with errors in the dataset.""" + ) + mean_score: Optional[float] = Field( + default=None, description="""Mean score of the metric.""" + ) + stdev_score: Optional[float] = Field( + default=None, description="""Standard deviation of the metric.""" + ) + pass_rate: Optional[float] = Field( + default=None, + description="""Pass rate of the adaptive rubric metric. Calculated as the number of cases where all criteria passed divided by the total number of valid cases. A case is passing if it has a score of 1.0.""", + ) + + # Allow extra fields to support custom aggregation stats. + model_config = ConfigDict(extra="allow") + + +class AggregatedMetricResultDict(TypedDict, total=False): + """Evaluation result for a single metric for an evaluation dataset.""" + + metric_name: Optional[str] + """Name of the metric.""" + + num_cases_total: Optional[int] + """Total number of cases in the dataset.""" + + num_cases_valid: Optional[int] + """Number of valid cases in the dataset.""" + + num_cases_error: Optional[int] + """Number of cases with errors in the dataset.""" + + mean_score: Optional[float] + """Mean score of the metric.""" + + stdev_score: Optional[float] + """Standard deviation of the metric.""" + + pass_rate: Optional[float] + """Pass rate of the adaptive rubric metric. Calculated as the number of cases where all criteria passed divided by the total number of valid cases. 
+    A case is passing if it has a score of 1.0."""
+
+
+AggregatedMetricResultOrDict = Union[AggregatedMetricResult, AggregatedMetricResultDict]
+
+
+class WinRateStats(_common.BaseModel):
+    """Statistics for win rates for a single metric."""
+
+    win_rates: Optional[list[float]] = Field(
+        default=None,
+        description="""Win rates for the metric, one for each candidate.""",
+    )
+    tie_rate: Optional[float] = Field(
+        default=None, description="""Tie rate for the metric."""
+    )
+
+
+class WinRateStatsDict(TypedDict, total=False):
+    """Statistics for win rates for a single metric."""
+
+    win_rates: Optional[list[float]]
+    """Win rates for the metric, one for each candidate."""
+
+    tie_rate: Optional[float]
+    """Tie rate for the metric."""
+
+
+WinRateStatsOrDict = Union[WinRateStats, WinRateStatsDict]
+
+
+class ResponseCandidate(_common.BaseModel):
+    """A model-generated content to the prompt."""
+
+    response: Optional[genai_types.Content] = Field(
+        default=None,
+        description="""The final model-generated response to the `prompt`.""",
+    )
+
+
+class ResponseCandidateDict(TypedDict, total=False):
+    """A model-generated content to the prompt."""
+
+    response: Optional[genai_types.ContentDict]
+    """The final model-generated response to the `prompt`."""
+
+
+ResponseCandidateOrDict = Union[ResponseCandidate, ResponseCandidateDict]
+
+
+class EvalCase(_common.BaseModel):
+    """A comprehensive representation of a GenAI interaction for evaluation."""
+
+    prompt: Optional[genai_types.Content] = Field(
+        default=None, description="""The most recent user message (current input)."""
+    )
+    responses: Optional[list[ResponseCandidate]] = Field(
+        default=None,
+        description="""Model-generated replies to the last user message in a conversation. Multiple responses are allowed to support use cases such as comparing different model outputs.""",
+    )
+    reference: Optional[ResponseCandidate] = Field(
+        default=None,
+        description="""User-provided golden reference model reply to the prompt, in the context of the chat history; used as the reference for the last response in a conversation.""",
+    )
+    system_instruction: Optional[genai_types.Content] = Field(
+        default=None, description="""System instruction for the model."""
+    )
+    conversation_history: Optional[list[evals_types.Message]] = Field(
+        default=None,
+        description="""List of all prior messages in the conversation (chat history).""",
+    )
+    rubric_groups: Optional[dict[str, "RubricGroup"]] = Field(
+        default=None,
+        description="""Named groups of rubrics associated with this prompt. The key is a user-defined name for the rubric group.""",
+    )
+    eval_case_id: Optional[str] = Field(
+        default=None, description="""Unique identifier for the evaluation case."""
+    )
+    intermediate_events: Optional[list[evals_types.Event]] = Field(
+        default=None,
+        description="""This field is experimental and may change in future versions. Intermediate events of a single turn in an agent run, or intermediate events of the last turn for a multi-turn agent run.""",
+    )
+    agent_info: Optional[evals_types.AgentInfo] = Field(
+        default=None,
+        description="""This field is experimental and may change in future versions. The agent info of the agent under evaluation. This can be extended for multi-agent evaluation.""",
+    )
+    agent_data: Optional[evals_types.AgentData] = Field(
+        default=None,
+        description="""This field is experimental and may change in future versions. The agent data of the agent under evaluation.""",
+    )
+    user_scenario: Optional[evals_types.UserScenario] = Field(
+        default=None,
+        description="""This field is experimental and may change in future versions. The user scenario for the evaluation case.""",
+    )
+    # Allow extra fields to support custom metric prompts and stay backward compatible.
+    model_config = ConfigDict(frozen=True, extra="allow")
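+
+
+# Illustrative sketch (not part of the generated surface): an EvalCase that
+# compares two candidate responses against a reference. All field values below
+# are hypothetical, and Content/Part construction is assumed to follow the
+# google-genai types mirrored here as `genai_types`.
+#
+#   case = EvalCase(
+#       prompt=genai_types.Content(
+#           parts=[genai_types.Part(text="Summarize the article.")]
+#       ),
+#       responses=[
+#           ResponseCandidate(
+#               response=genai_types.Content(
+#                   parts=[genai_types.Part(text="Summary A...")]
+#               )
+#           ),
+#           ResponseCandidate(
+#               response=genai_types.Content(
+#                   parts=[genai_types.Part(text="Summary B...")]
+#               )
+#           ),
+#       ],
+#       reference=ResponseCandidate(
+#           response=genai_types.Content(
+#               parts=[genai_types.Part(text="Golden summary...")]
+#           )
+#       ),
+#       eval_case_id="case-001",
+#   )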
+
+
+class EvalCaseDict(TypedDict, total=False):
+    """A comprehensive representation of a GenAI interaction for evaluation."""
+
+    prompt: Optional[genai_types.ContentDict]
+    """The most recent user message (current input)."""
+
+    responses: Optional[list[ResponseCandidateDict]]
+    """Model-generated replies to the last user message in a conversation. Multiple responses are allowed to support use cases such as comparing different model outputs."""
+
+    reference: Optional[ResponseCandidateDict]
+    """User-provided golden reference model reply to the prompt, in the context of the chat history; used as the reference for the last response in a conversation."""
+
+    system_instruction: Optional[genai_types.ContentDict]
+    """System instruction for the model."""
+
+    conversation_history: Optional[list[evals_types.Message]]
+    """List of all prior messages in the conversation (chat history)."""
+
+    rubric_groups: Optional[dict[str, "RubricGroupDict"]]
+    """Named groups of rubrics associated with this prompt. The key is a user-defined name for the rubric group."""
+
+    eval_case_id: Optional[str]
+    """Unique identifier for the evaluation case."""
+
+    intermediate_events: Optional[list[evals_types.Event]]
+    """This field is experimental and may change in future versions. Intermediate events of a single turn in an agent run, or intermediate events of the last turn for a multi-turn agent run."""
+
+    agent_info: Optional[evals_types.AgentInfo]
+    """This field is experimental and may change in future versions. The agent info of the agent under evaluation. This can be extended for multi-agent evaluation."""
+
+    agent_data: Optional[evals_types.AgentData]
+    """This field is experimental and may change in future versions. The agent data of the agent under evaluation."""
+
+    user_scenario: Optional[evals_types.UserScenario]
+    """This field is experimental and may change in future versions. The user scenario for the evaluation case."""
+
+
+EvalCaseOrDict = Union[EvalCase, EvalCaseDict]
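+
+
+# Illustrative sketch (hypothetical values): two common ways to build the
+# EvaluationDataset defined below -- from in-memory eval cases, or from a
+# Pandas DataFrame. The DataFrame column names shown ("prompt", "response")
+# are an assumed layout, not a documented schema.
+#
+#   dataset = EvaluationDataset(
+#       eval_cases=[case],  # e.g. the EvalCase sketched above
+#       candidate_name="candidate-a",
+#   )
+#
+#   import pandas as pd
+#   dataset = EvaluationDataset(
+#       eval_dataset_df=pd.DataFrame(
+#           {"prompt": ["What is 2+2?"], "response": ["4"]}
+#       )
+#   )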
+
+
+class EvaluationDataset(_common.BaseModel):
+    """The dataset used for evaluation."""
+
+    bigquery_source: Optional[genai_types.BigQuerySource] = Field(
+        default=None, description="""The BigQuery source for the evaluation dataset."""
+    )
+    gcs_source: Optional[genai_types.GcsSource] = Field(
+        default=None, description="""The GCS source for the evaluation dataset."""
+    )
+    eval_cases: Optional[list[EvalCase]] = Field(
+        default=None, description="""The evaluation cases to be evaluated."""
+    )
+    eval_dataset_df: Optional[PandasDataFrame] = Field(
+        default=None,
+        description="""The evaluation dataset in the form of a Pandas DataFrame.""",
+    )
+    candidate_name: Optional[str] = Field(
+        default=None,
+        description="""The name of the candidate model or agent for this evaluation dataset.""",
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def _check_pandas_installed(cls, data: Any) -> Any:
+        if isinstance(data, dict) and data.get("eval_dataset_df") is not None:
+            if pd is None:
+                logger.warning(
+                    "Pandas is not installed, some evals features are not available."
+                    " Please install it with `pip install"
+                    " google-cloud-aiplatform[evaluation]`."
+                )
+        return data
+
+    @classmethod
+    def load_from_observability_eval_cases(
+        cls, cases: list["ObservabilityEvalCase"]
+    ) -> "EvaluationDataset":
+        """Fetches GenAI Observability data from GCS and parses it into a DataFrame."""
+        try:
+            import pandas as pd
+            from .. import _gcs_utils
+
+            formats = []
+            requests = []
+            responses = []
+            system_instructions = []
+
+            for case in cases:
+                gcs_utils = _gcs_utils.GcsUtils(
+                    case.api_client._api_client if case.api_client else None
+                )
+
+                # Tag each row with the "observability" data format.
+                formats.append("observability")
+
+                # Input source
+                request_data = gcs_utils.read_file_contents(case.input_src)
+                requests.append(request_data)
+
+                # Output source
+                response_data = gcs_utils.read_file_contents(case.output_src)
+                responses.append(response_data)
+
+                # System instruction source
+                system_instruction_data = ""
+                if case.system_instruction_src is not None:
+                    system_instruction_data = gcs_utils.read_file_contents(
+                        case.system_instruction_src
+                    )
+                system_instructions.append(system_instruction_data)
+
+            eval_dataset_df = pd.DataFrame(
+                {
+                    "format": formats,
+                    "request": requests,
+                    "response": responses,
+                    "system_instruction": system_instructions,
+                }
+            )
+
+        except ImportError as e:
+            raise ImportError("Pandas DataFrame library is required.") from e
+
+        return EvaluationDataset(eval_dataset_df=eval_dataset_df)
+
+    def show(self) -> None:
+        """Shows the evaluation dataset."""
+        from .. import _evals_visualization
+
+        _evals_visualization.display_evaluation_dataset(self)
+
+
+class EvaluationDatasetDict(TypedDict, total=False):
+    """The dataset used for evaluation."""
+
+    bigquery_source: Optional[genai_types.BigQuerySourceDict]
+    """The BigQuery source for the evaluation dataset."""
+
+    gcs_source: Optional[genai_types.GcsSourceDict]
+    """The GCS source for the evaluation dataset."""
+
+    eval_cases: Optional[list[EvalCaseDict]]
+    """The evaluation cases to be evaluated."""
+
+    eval_dataset_df: Optional[PandasDataFrame]
+    """The evaluation dataset in the form of a Pandas DataFrame."""
+
+    candidate_name: Optional[str]
+    """The name of the candidate model or agent for this evaluation dataset."""
+
+
+EvaluationDatasetOrDict = Union[EvaluationDataset, EvaluationDatasetDict]
+
+
+class EvaluationRunMetadata(_common.BaseModel):
+    """Metadata for an evaluation run."""
+
+    candidate_names: Optional[list[str]] = Field(
+        default=None,
+        description="""Name of the candidate(s) being evaluated in the evaluation run.""",
+    )
+    dataset_name: Optional[str] = Field(
+        default=None,
+        description="""Name of the evaluation dataset used for the evaluation run.""",
+    )
+    dataset_id: Optional[str] = Field(
+        default=None,
+        description="""Unique identifier for the evaluation dataset used for the evaluation run.""",
+    )
+    creation_timestamp: Optional[datetime.datetime] = Field(
+        default=None, description="""Creation timestamp of the evaluation run."""
+    )
+
+
+class EvaluationRunMetadataDict(TypedDict, total=False):
+    """Metadata for an evaluation run."""
+
+    candidate_names: Optional[list[str]]
+    """Name of the candidate(s) being evaluated in the evaluation run."""
+
+    dataset_name: Optional[str]
+    """Name of the evaluation dataset used for the evaluation run."""
+
+    dataset_id: Optional[str]
+    """Unique identifier for the evaluation dataset used for the evaluation run."""
+
+    creation_timestamp: Optional[datetime.datetime]
+    """Creation timestamp of the evaluation run."""
+
+
+EvaluationRunMetadataOrDict = Union[EvaluationRunMetadata,
EvaluationRunMetadataDict] + + +class EvaluationResult(_common.BaseModel): + """Result of an evaluation run for an evaluation dataset.""" + + eval_case_results: Optional[list[EvalCaseResult]] = Field( + default=None, + description="""A list of evaluation results for each evaluation case.""", + ) + summary_metrics: Optional[list[AggregatedMetricResult]] = Field( + default=None, + description="""A list of summary-level evaluation results for each metric.""", + ) + win_rates: Optional[dict[str, WinRateStats]] = Field( + default=None, + description="""A dictionary of win rates for each metric, only populated for multi-response evaluation runs.""", + ) + evaluation_dataset: Optional[list[EvaluationDataset]] = Field( + default=None, + description="""The input evaluation dataset(s) for the evaluation run.""", + ) + metadata: Optional[EvaluationRunMetadata] = Field( + default=None, description="""Metadata for the evaluation run.""" + ) + agent_info: Optional[evals_types.AgentInfo] = Field( + default=None, + description="""This field is experimental and may change in future versions. The agent info of the agent under evaluation. This can be extended for multi-agent evaluation.""", + ) + + def show(self, candidate_names: Optional[List[str]] = None) -> None: + """Shows the evaluation result. + + Args: + candidate_names: list of names for the evaluated candidates, used in + comparison reports. + """ + from .. import _evals_visualization + + _evals_visualization.display_evaluation_result(self, candidate_names) + + +class EvaluationResultDict(TypedDict, total=False): + """Result of an evaluation run for an evaluation dataset.""" + + eval_case_results: Optional[list[EvalCaseResultDict]] + """A list of evaluation results for each evaluation case.""" + + summary_metrics: Optional[list[AggregatedMetricResultDict]] + """A list of summary-level evaluation results for each metric.""" + + win_rates: Optional[dict[str, WinRateStatsDict]] + """A dictionary of win rates for each metric, only populated for multi-response evaluation runs.""" + + evaluation_dataset: Optional[list[EvaluationDatasetDict]] + """The input evaluation dataset(s) for the evaluation run.""" + + metadata: Optional[EvaluationRunMetadataDict] + """Metadata for the evaluation run.""" + + agent_info: Optional[evals_types.AgentInfo] + """This field is experimental and may change in future versions. The agent info of the agent under evaluation. 
+    This can be extended for multi-agent evaluation."""
+
+
+EvaluationResultOrDict = Union[EvaluationResult, EvaluationResultDict]
+
+
+class EvaluationRun(_common.BaseModel):
+    """Represents an evaluation run."""
+
+    name: Optional[str] = Field(default=None, description="""""")
+    display_name: Optional[str] = Field(default=None, description="""""")
+    metadata: Optional[dict[str, Any]] = Field(default=None, description="""""")
+    create_time: Optional[datetime.datetime] = Field(default=None, description="""""")
+    completion_time: Optional[datetime.datetime] = Field(
+        default=None, description=""""""
+    )
+    state: Optional[EvaluationRunState] = Field(default=None, description="""""")
+    evaluation_set_snapshot: Optional[str] = Field(default=None, description="""""")
+    error: Optional[genai_types.GoogleRpcStatus] = Field(
+        default=None, description=""""""
+    )
+    data_source: Optional[EvaluationRunDataSource] = Field(
+        default=None, description=""""""
+    )
+    evaluation_run_results: Optional[EvaluationRunResults] = Field(
+        default=None, description="""The evaluation run formatted results."""
+    )
+    evaluation_item_results: Optional[EvaluationResult] = Field(
+        default=None,
+        description="""The parsed EvaluationItem results for the evaluation run. This is only populated when include_evaluation_items is set to True.""",
+    )
+    evaluation_config: Optional[EvaluationRunConfig] = Field(
+        default=None, description="""The evaluation config for the evaluation run."""
+    )
+    inference_configs: Optional[dict[str, EvaluationRunInferenceConfig]] = Field(
+        default=None,
+        description="""This field is experimental and may change in future versions. The inference configs for the evaluation run.""",
+    )
+    labels: Optional[dict[str, str]] = Field(default=None, description="""""")
+
+    # TODO(b/448806531): Remove all the overridden _from_response methods once the
+    # ticket is resolved and published.
+    @classmethod
+    def _from_response(
+        cls: typing.Type["EvaluationRun"],
+        *,
+        response: dict[str, object],
+        kwargs: dict[str, object],
+    ) -> "EvaluationRun":
+        """Converts a dictionary response into an EvaluationRun object."""
+
+        snaked_response = _camel_key_to_snake(response)
+
+        evaluation_run_results = response.get("evaluation_run_results")
+
+        if (
+            isinstance(evaluation_run_results, dict)
+            and "summaryMetrics" in evaluation_run_results
+        ):
+            snaked_response["evaluation_run_results"]["summary_metrics"] = (
+                evaluation_run_results["summaryMetrics"]
+            )
+        result = super()._from_response(response=snaked_response, kwargs=kwargs)
+        return result
+
+    def show(self) -> None:
+        """Shows the evaluation result."""
+        from .. import _evals_visualization
+
+        if self.state == "SUCCEEDED":
+            if self.evaluation_item_results is not None:
+                _evals_visualization.display_evaluation_result(
+                    self.evaluation_item_results, None
+                )
+            else:
+                logger.warning(
+                    "Evaluation Run succeeded but no evaluation item results found. To display results, please set include_evaluation_items to True when calling get_evaluation_run()."
+                )
+            # Show loss analysis results if present on the evaluation run.
+            # Pass the eval item map so the visualization can enrich
+            # loss examples with scenario/rubric data.
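+            # Note: `_eval_item_map` is looked up defensively with getattr; it
+            # is an internal attribute that may be attached elsewhere (e.g.
+            # when the run is fetched together with its evaluation items). If
+            # absent, None is passed and the loss clusters are rendered
+            # without per-item enrichment.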
+            if (
+                self.evaluation_run_results
+                and self.evaluation_run_results.loss_analysis_results
+            ):
+                eval_item_map = getattr(self, "_eval_item_map", None)
+                _evals_visualization.display_loss_analysis_results(
+                    self.evaluation_run_results.loss_analysis_results,
+                    eval_item_map=eval_item_map,
+                )
+        else:
+            _evals_visualization.display_evaluation_run_status(self)
+
+
+class EvaluationRunDict(TypedDict, total=False):
+    """Represents an evaluation run."""
+
+    name: Optional[str]
+    """"""
+
+    display_name: Optional[str]
+    """"""
+
+    metadata: Optional[dict[str, Any]]
+    """"""
+
+    create_time: Optional[datetime.datetime]
+    """"""
+
+    completion_time: Optional[datetime.datetime]
+    """"""
+
+    state: Optional[EvaluationRunState]
+    """"""
+
+    evaluation_set_snapshot: Optional[str]
+    """"""
+
+    error: Optional[genai_types.GoogleRpcStatusDict]
+    """"""
+
+    data_source: Optional[EvaluationRunDataSourceDict]
+    """"""
+
+    evaluation_run_results: Optional[EvaluationRunResultsDict]
+    """The evaluation run formatted results."""
+
+    evaluation_item_results: Optional[EvaluationResultDict]
+    """The parsed EvaluationItem results for the evaluation run. This is only populated when include_evaluation_items is set to True."""
+
+    evaluation_config: Optional[EvaluationRunConfigDict]
+    """The evaluation config for the evaluation run."""
+
+    inference_configs: Optional[dict[str, EvaluationRunInferenceConfigDict]]
+    """This field is experimental and may change in future versions. The inference configs for the evaluation run."""
+
+    labels: Optional[dict[str, str]]
+    """"""
+
+
+EvaluationRunOrDict = Union[EvaluationRun, EvaluationRunDict]
+
+
+class CreateEvaluationSetConfig(_common.BaseModel):
+    """Config to create an evaluation set."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class CreateEvaluationSetConfigDict(TypedDict, total=False):
+    """Config to create an evaluation set."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+CreateEvaluationSetConfigOrDict = Union[
+    CreateEvaluationSetConfig, CreateEvaluationSetConfigDict
+]
+
+
+class _CreateEvaluationSetParameters(_common.BaseModel):
+    """Parameters for creating an evaluation set."""
+
+    evaluation_items: Optional[list[str]] = Field(default=None, description="""""")
+    display_name: Optional[str] = Field(default=None, description="""""")
+    config: Optional[CreateEvaluationSetConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _CreateEvaluationSetParametersDict(TypedDict, total=False):
+    """Parameters for creating an evaluation set."""
+
+    evaluation_items: Optional[list[str]]
+    """"""
+
+    display_name: Optional[str]
+    """"""
+
+    config: Optional[CreateEvaluationSetConfigDict]
+    """"""
+
+
+_CreateEvaluationSetParametersOrDict = Union[
+    _CreateEvaluationSetParameters, _CreateEvaluationSetParametersDict
+]
+
+
+class EvaluationSet(_common.BaseModel):
+    """Represents an evaluation set."""
+
+    name: Optional[str] = Field(
+        default=None, description="""The resource name of the evaluation set."""
+    )
+    display_name: Optional[str] = Field(
+        default=None, description="""The display name of the evaluation set."""
+    )
+    evaluation_items: Optional[list[str]] = Field(
+        default=None,
+        description="""The EvaluationItems that are part of this dataset.""",
+    )
+    create_time: Optional[datetime.datetime] = Field(
+        default=None, description="""The create time of the evaluation set."""
+    )
update_time: Optional[datetime.datetime] = Field( + default=None, description="""The update time of the evaluation set.""" + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, description="""The metadata of the evaluation set.""" + ) + + +class EvaluationSetDict(TypedDict, total=False): + """Represents an evaluation set.""" + + name: Optional[str] + """The resource name of the evaluation set.""" + + display_name: Optional[str] + """The display name of the evaluation set.""" + + evaluation_items: Optional[list[str]] + """The EvaluationItems that are part of this dataset.""" + + create_time: Optional[datetime.datetime] + """The create time of the evaluation set.""" + + update_time: Optional[datetime.datetime] + """The update time of the evaluation set.""" + + metadata: Optional[dict[str, Any]] + """The metadata of the evaluation set.""" + + +EvaluationSetOrDict = Union[EvaluationSet, EvaluationSetDict] + + +class DeleteEvaluationMetricConfig(_common.BaseModel): + """Config for deleting an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteEvaluationMetricConfigDict(TypedDict, total=False): + """Config for deleting an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteEvaluationMetricConfigOrDict = Union[ + DeleteEvaluationMetricConfig, DeleteEvaluationMetricConfigDict +] + + +class _DeleteEvaluationMetricParameters(_common.BaseModel): + """Parameters for deleting an evaluation metric.""" + + metric_resource_name: Optional[str] = Field(default=None, description="""""") + config: Optional[DeleteEvaluationMetricConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteEvaluationMetricParametersDict(TypedDict, total=False): + """Parameters for deleting an evaluation metric.""" + + metric_resource_name: Optional[str] + """""" + + config: Optional[DeleteEvaluationMetricConfigDict] + """""" + + +_DeleteEvaluationMetricParametersOrDict = Union[ + _DeleteEvaluationMetricParameters, _DeleteEvaluationMetricParametersDict +] + + +class DeleteEvaluationMetricOperation(_common.BaseModel): + """Operation for deleting an evaluation metric.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteEvaluationMetricOperationDict(TypedDict, total=False): + """Operation for deleting an evaluation metric.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteEvaluationMetricOperationOrDict = Union[ + DeleteEvaluationMetricOperation, DeleteEvaluationMetricOperationDict +] + + +class BleuInstance(_common.BaseModel): + """Bleu instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class BleuInstanceDict(TypedDict, total=False): + """Bleu instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +BleuInstanceOrDict = Union[BleuInstance, BleuInstanceDict] + + +class BleuInput(_common.BaseModel): + + instances: Optional[list[BleuInstance]] = Field( + default=None, description="""Required. Repeated bleu instances.""" + ) + metric_spec: Optional[genai_types.BleuSpec] = Field( + default=None, description="""Required. Spec for bleu score metric.""" + ) + + +class BleuInputDict(TypedDict, total=False): + + instances: Optional[list[BleuInstanceDict]] + """Required. Repeated bleu instances.""" + + metric_spec: Optional[genai_types.BleuSpecDict] + """Required. Spec for bleu score metric.""" + + +BleuInputOrDict = Union[BleuInput, BleuInputDict] + + +class ExactMatchInstance(_common.BaseModel): + """Exact match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ExactMatchInstanceDict(TypedDict, total=False): + """Exact match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. 
+    Ground truth used to compare against the prediction."""
+
+
+ExactMatchInstanceOrDict = Union[ExactMatchInstance, ExactMatchInstanceDict]
+
+
+class ExactMatchSpec(_common.BaseModel):
+    """Spec for exact match metric."""
+
+    pass
+
+
+class ExactMatchSpecDict(TypedDict, total=False):
+    """Spec for exact match metric."""
+
+    pass
+
+
+ExactMatchSpecOrDict = Union[ExactMatchSpec, ExactMatchSpecDict]
+
+
+class ExactMatchInput(_common.BaseModel):
+    """Exact match input."""
+
+    instances: Optional[list[ExactMatchInstance]] = Field(
+        default=None, description="""Required. Repeated exact match instances."""
+    )
+    metric_spec: Optional[ExactMatchSpec] = Field(
+        default=None, description="""Required. Spec for exact match metric."""
+    )
+
+
+class ExactMatchInputDict(TypedDict, total=False):
+    """Exact match input."""
+
+    instances: Optional[list[ExactMatchInstanceDict]]
+    """Required. Repeated exact match instances."""
+
+    metric_spec: Optional[ExactMatchSpecDict]
+    """Required. Spec for exact match metric."""
+
+
+ExactMatchInputOrDict = Union[ExactMatchInput, ExactMatchInputDict]
+
+
+class RougeInstance(_common.BaseModel):
+    """Rouge instance."""
+
+    prediction: Optional[str] = Field(
+        default=None, description="""Required. Output of the evaluated model."""
+    )
+    reference: Optional[str] = Field(
+        default=None,
+        description="""Required. Ground truth used to compare against the prediction.""",
+    )
+
+
+class RougeInstanceDict(TypedDict, total=False):
+    """Rouge instance."""
+
+    prediction: Optional[str]
+    """Required. Output of the evaluated model."""
+
+    reference: Optional[str]
+    """Required. Ground truth used to compare against the prediction."""
+
+
+RougeInstanceOrDict = Union[RougeInstance, RougeInstanceDict]
+
+
+class RougeInput(_common.BaseModel):
+    """Rouge input."""
+
+    instances: Optional[list[RougeInstance]] = Field(
+        default=None, description="""Required. Repeated rouge instances."""
+    )
+    metric_spec: Optional[genai_types.RougeSpec] = Field(
+        default=None, description="""Required. Spec for rouge score metric."""
+    )
+
+
+class RougeInputDict(TypedDict, total=False):
+    """Rouge input."""
+
+    instances: Optional[list[RougeInstanceDict]]
+    """Required. Repeated rouge instances."""
+
+    metric_spec: Optional[genai_types.RougeSpecDict]
+    """Required. Spec for rouge score metric."""
+
+
+RougeInputOrDict = Union[RougeInput, RougeInputDict]
+
+
+class ContentMap(_common.BaseModel):
+    """Map of placeholder in metric prompt template to contents of model input."""
+
+    values: Optional[dict[str, "ContentMapContents"]] = Field(
+        default=None, description="""Map of placeholder to contents."""
+    )
+
+
+class ContentMapDict(TypedDict, total=False):
+    """Map of placeholder in metric prompt template to contents of model input."""
+
+    values: Optional[dict[str, "ContentMapContents"]]
+    """Map of placeholder to contents."""
+
+
+ContentMapOrDict = Union[ContentMap, ContentMapDict]
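+
+
+# Illustrative sketch (hypothetical keys): populating `json_instance` on the
+# PointwiseMetricInstance class defined below. The JSON keys must match the
+# placeholders used in PointwiseMetricSpec.instance_prompt_template.
+#
+#   import json
+#   instance = PointwiseMetricInstance(
+#       json_instance=json.dumps({"prompt": "What is 2+2?", "response": "4"})
+#   )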
+
+
+class PointwiseMetricInstance(_common.BaseModel):
+    """Pointwise metric instance."""
+
+    json_instance: Optional[str] = Field(
+        default=None,
+        description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template.""",
+    )
+    content_map_instance: Optional[ContentMap] = Field(
+        default=None,
+        description="""Key-value contents for the multimodal input, including text, image, video, audio, pdf, etc. The key is a placeholder in the metric prompt template, and the value is the multimodal content.""",
+    )
+
+
+class PointwiseMetricInstanceDict(TypedDict, total=False):
+    """Pointwise metric instance."""
+
+    json_instance: Optional[str]
+    """Instance specified as a json string. String key-value pairs are expected in the json_instance to render PointwiseMetricSpec.instance_prompt_template."""
+
+    content_map_instance: Optional[ContentMapDict]
+    """Key-value contents for the multimodal input, including text, image, video, audio, pdf, etc. The key is a placeholder in the metric prompt template, and the value is the multimodal content."""
+
+
+PointwiseMetricInstanceOrDict = Union[
+    PointwiseMetricInstance, PointwiseMetricInstanceDict
+]
+
+
+class PointwiseMetricInput(_common.BaseModel):
+    """Pointwise metric input."""
+
+    instance: Optional[PointwiseMetricInstance] = Field(
+        default=None, description="""Required. Pointwise metric instance."""
+    )
+    metric_spec: Optional[genai_types.PointwiseMetricSpec] = Field(
+        default=None, description="""Required. Spec for pointwise metric."""
+    )
+
+
+class PointwiseMetricInputDict(TypedDict, total=False):
+    """Pointwise metric input."""
+
+    instance: Optional[PointwiseMetricInstanceDict]
+    """Required. Pointwise metric instance."""
+
+    metric_spec: Optional[genai_types.PointwiseMetricSpecDict]
+    """Required. Spec for pointwise metric."""
+
+
+PointwiseMetricInputOrDict = Union[PointwiseMetricInput, PointwiseMetricInputDict]
+
+
+class PairwiseMetricInstance(_common.BaseModel):
+    """Pairwise metric instance."""
+
+    json_instance: Optional[str] = Field(
+        default=None,
+        description="""Instance specified as a json string. String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template.""",
+    )
+
+
+class PairwiseMetricInstanceDict(TypedDict, total=False):
+    """Pairwise metric instance."""
+
+    json_instance: Optional[str]
+    """Instance specified as a json string. String key-value pairs are expected in the json_instance to render PairwiseMetricSpec.instance_prompt_template."""
+
+
+PairwiseMetricInstanceOrDict = Union[PairwiseMetricInstance, PairwiseMetricInstanceDict]
+
+
+class PairwiseMetricInput(_common.BaseModel):
+    """Pairwise metric input."""
+
+    instance: Optional[PairwiseMetricInstance] = Field(
+        default=None, description="""Required. Pairwise metric instance."""
+    )
+    metric_spec: Optional[genai_types.PairwiseMetricSpec] = Field(
+        default=None, description="""Required. Spec for pairwise metric."""
+    )
+
+
+class PairwiseMetricInputDict(TypedDict, total=False):
+    """Pairwise metric input."""
+
+    instance: Optional[PairwiseMetricInstanceDict]
+    """Required. Pairwise metric instance."""
+
+    metric_spec: Optional[genai_types.PairwiseMetricSpecDict]
+    """Required. Spec for pairwise metric."""
+
+
+PairwiseMetricInputOrDict = Union[PairwiseMetricInput, PairwiseMetricInputDict]
+
+
+class ToolCallValidInstance(_common.BaseModel):
+    """Tool call valid instance."""
+
+    prediction: Optional[str] = Field(
+        default=None, description="""Required. Output of the evaluated model."""
+    )
+    reference: Optional[str] = Field(
+        default=None,
+        description="""Required. Ground truth used to compare against the prediction.""",
+    )
+
+
+class ToolCallValidInstanceDict(TypedDict, total=False):
+    """Tool call valid instance."""
+
+    prediction: Optional[str]
+    """Required. Output of the evaluated model."""
+
+    reference: Optional[str]
+    """Required.
Ground truth used to compare against the prediction.""" + + +ToolCallValidInstanceOrDict = Union[ToolCallValidInstance, ToolCallValidInstanceDict] + + +class ToolCallValidSpec(_common.BaseModel): + """Spec for tool call valid metric.""" + + pass + + +class ToolCallValidSpecDict(TypedDict, total=False): + """Spec for tool call valid metric.""" + + pass + + +ToolCallValidSpecOrDict = Union[ToolCallValidSpec, ToolCallValidSpecDict] + + +class ToolCallValidInput(_common.BaseModel): + """Tool call valid input.""" + + instances: Optional[list[ToolCallValidInstance]] = Field( + default=None, description="""Required. Repeated tool call valid instances.""" + ) + metric_spec: Optional[ToolCallValidSpec] = Field( + default=None, description="""Required. Spec for tool call valid metric.""" + ) + + +class ToolCallValidInputDict(TypedDict, total=False): + """Tool call valid input.""" + + instances: Optional[list[ToolCallValidInstanceDict]] + """Required. Repeated tool call valid instances.""" + + metric_spec: Optional[ToolCallValidSpecDict] + """Required. Spec for tool call valid metric.""" + + +ToolCallValidInputOrDict = Union[ToolCallValidInput, ToolCallValidInputDict] + + +class ToolNameMatchInstance(_common.BaseModel): + """Tool name match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolNameMatchInstanceDict(TypedDict, total=False): + """Tool name match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolNameMatchInstanceOrDict = Union[ToolNameMatchInstance, ToolNameMatchInstanceDict] + + +class ToolNameMatchSpec(_common.BaseModel): + """Spec for tool name match metric.""" + + pass + + +class ToolNameMatchSpecDict(TypedDict, total=False): + """Spec for tool name match metric.""" + + pass + + +ToolNameMatchSpecOrDict = Union[ToolNameMatchSpec, ToolNameMatchSpecDict] + + +class ToolNameMatchInput(_common.BaseModel): + """Tool name match input.""" + + instances: Optional[list[ToolNameMatchInstance]] = Field( + default=None, description="""Required. Repeated tool name match instances.""" + ) + metric_spec: Optional[ToolNameMatchSpec] = Field( + default=None, description="""Required. Spec for tool name match metric.""" + ) + + +class ToolNameMatchInputDict(TypedDict, total=False): + """Tool name match input.""" + + instances: Optional[list[ToolNameMatchInstanceDict]] + """Required. Repeated tool name match instances.""" + + metric_spec: Optional[ToolNameMatchSpecDict] + """Required. Spec for tool name match metric.""" + + +ToolNameMatchInputOrDict = Union[ToolNameMatchInput, ToolNameMatchInputDict] + + +class ToolParameterKeyMatchInstance(_common.BaseModel): + """Tool parameter key match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolParameterKeyMatchInstanceDict(TypedDict, total=False): + """Tool parameter key match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. 
Ground truth used to compare against the prediction.""" + + +ToolParameterKeyMatchInstanceOrDict = Union[ + ToolParameterKeyMatchInstance, ToolParameterKeyMatchInstanceDict +] + + +class ToolParameterKeyMatchSpec(_common.BaseModel): + """Spec for tool parameter key match metric.""" + + pass + + +class ToolParameterKeyMatchSpecDict(TypedDict, total=False): + """Spec for tool parameter key match metric.""" + + pass + + +ToolParameterKeyMatchSpecOrDict = Union[ + ToolParameterKeyMatchSpec, ToolParameterKeyMatchSpecDict +] + + +class ToolParameterKeyMatchInput(_common.BaseModel): + """Tool parameter key match input.""" + + instances: Optional[list[ToolParameterKeyMatchInstance]] = Field( + default=None, + description="""Required. Repeated tool parameter key match instances.""", + ) + metric_spec: Optional[ToolParameterKeyMatchSpec] = Field( + default=None, + description="""Required. Spec for tool parameter key match metric.""", + ) + + +class ToolParameterKeyMatchInputDict(TypedDict, total=False): + """Tool parameter key match input.""" + + instances: Optional[list[ToolParameterKeyMatchInstanceDict]] + """Required. Repeated tool parameter key match instances.""" + + metric_spec: Optional[ToolParameterKeyMatchSpecDict] + """Required. Spec for tool parameter key match metric.""" + + +ToolParameterKeyMatchInputOrDict = Union[ + ToolParameterKeyMatchInput, ToolParameterKeyMatchInputDict +] + + +class ToolParameterKVMatchInstance(_common.BaseModel): + """Tool parameter kv match instance.""" + + prediction: Optional[str] = Field( + default=None, description="""Required. Output of the evaluated model.""" + ) + reference: Optional[str] = Field( + default=None, + description="""Required. Ground truth used to compare against the prediction.""", + ) + + +class ToolParameterKVMatchInstanceDict(TypedDict, total=False): + """Tool parameter kv match instance.""" + + prediction: Optional[str] + """Required. Output of the evaluated model.""" + + reference: Optional[str] + """Required. Ground truth used to compare against the prediction.""" + + +ToolParameterKVMatchInstanceOrDict = Union[ + ToolParameterKVMatchInstance, ToolParameterKVMatchInstanceDict +] + + +class ToolParameterKVMatchSpec(_common.BaseModel): + """Spec for tool parameter kv match metric.""" + + use_strict_string_match: Optional[bool] = Field( + default=None, + description="""Optional. Whether to use STRICT string match on parameter values.""", + ) + + +class ToolParameterKVMatchSpecDict(TypedDict, total=False): + """Spec for tool parameter kv match metric.""" + + use_strict_string_match: Optional[bool] + """Optional. Whether to use STRICT string match on parameter values.""" + + +ToolParameterKVMatchSpecOrDict = Union[ + ToolParameterKVMatchSpec, ToolParameterKVMatchSpecDict +] + + +class ToolParameterKVMatchInput(_common.BaseModel): + """Tool parameter kv match input.""" + + instances: Optional[list[ToolParameterKVMatchInstance]] = Field( + default=None, + description="""Required. Repeated tool parameter key value match instances.""", + ) + metric_spec: Optional[ToolParameterKVMatchSpec] = Field( + default=None, + description="""Required. Spec for tool parameter key value match metric.""", + ) + + +class ToolParameterKVMatchInputDict(TypedDict, total=False): + """Tool parameter kv match input.""" + + instances: Optional[list[ToolParameterKVMatchInstanceDict]] + """Required. Repeated tool parameter key value match instances.""" + + metric_spec: Optional[ToolParameterKVMatchSpecDict] + """Required. 
Spec for tool parameter key value match metric.""" + + +ToolParameterKVMatchInputOrDict = Union[ + ToolParameterKVMatchInput, ToolParameterKVMatchInputDict +] + + +class MapInstance(_common.BaseModel): + """Instance data specified as a map.""" + + map_instance: Optional[dict[str, evals_types.InstanceData]] = Field( + default=None, description="""Map of instance data.""" + ) + + +class MapInstanceDict(TypedDict, total=False): + """Instance data specified as a map.""" + + map_instance: Optional[dict[str, evals_types.InstanceData]] + """Map of instance data.""" + + +MapInstanceOrDict = Union[MapInstance, MapInstanceDict] + + +class EvaluationInstance(_common.BaseModel): + """A single instance to be evaluated.""" + + prompt: Optional[evals_types.InstanceData] = Field( + default=None, + description="""Data used to populate placeholder `prompt` in a metric prompt template.""", + ) + response: Optional[evals_types.InstanceData] = Field( + default=None, + description="""Data used to populate placeholder `response` in a metric prompt template.""", + ) + reference: Optional[evals_types.InstanceData] = Field( + default=None, + description="""Data used to populate placeholder `reference` in a metric prompt template.""", + ) + other_data: Optional[MapInstance] = Field( + default=None, + description="""Other data used to populate placeholders based on their key.""", + ) + agent_data: Optional[evals_types.AgentData] = Field( + default=None, description="""Data used for agent evaluation.""" + ) + rubric_groups: Optional[dict[str, "RubricGroup"]] = Field( + default=None, + description="""Named groups of rubrics associated with this prompt. The key is a user-defined name for the rubric group.""", + ) + + +class EvaluationInstanceDict(TypedDict, total=False): + """A single instance to be evaluated.""" + + prompt: Optional[evals_types.InstanceData] + """Data used to populate placeholder `prompt` in a metric prompt template.""" + + response: Optional[evals_types.InstanceData] + """Data used to populate placeholder `response` in a metric prompt template.""" + + reference: Optional[evals_types.InstanceData] + """Data used to populate placeholder `reference` in a metric prompt template.""" + + other_data: Optional[MapInstanceDict] + """Other data used to populate placeholders based on their key.""" + + agent_data: Optional[evals_types.AgentData] + """Data used for agent evaluation.""" + + rubric_groups: Optional[dict[str, "RubricGroupDict"]] + """Named groups of rubrics associated with this prompt. 
The key is a user-defined name for the rubric group.""" + + +EvaluationInstanceOrDict = Union[EvaluationInstance, EvaluationInstanceDict] + + +class EvaluateInstancesConfig(_common.BaseModel): + """Config for evaluate instances.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class EvaluateInstancesConfigDict(TypedDict, total=False): + """Config for evaluate instances.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +EvaluateInstancesConfigOrDict = Union[ + EvaluateInstancesConfig, EvaluateInstancesConfigDict +] + + +class RubricBasedMetricSpec(_common.BaseModel): + """Specification for a metric that is based on rubrics.""" + + metric_prompt_template: Optional[str] = Field( + default=None, + description="""Template for the prompt used by the judge model to evaluate against + rubrics.""", + ) + judge_autorater_config: Optional[genai_types.AutoraterConfig] = Field( + default=None, + description="""Optional configuration for the judge LLM (Autorater).""", + ) + inline_rubrics: Optional[list[evals_types.Rubric]] = Field( + default=None, description="""Use rubrics provided directly in the spec.""" + ) + rubric_group_key: Optional[str] = Field( + default=None, + description="""Use a pre-defined group of rubrics associated with the input content. + This refers to a key in the `rubric_groups` map of + `RubricEnhancedContents`.""", + ) + rubric_generation_spec: Optional[genai_types.RubricGenerationSpec] = Field( + default=None, + description="""Dynamically generate rubrics for evaluation using this specification.""", + ) + + +class RubricBasedMetricSpecDict(TypedDict, total=False): + """Specification for a metric that is based on rubrics.""" + + metric_prompt_template: Optional[str] + """Template for the prompt used by the judge model to evaluate against + rubrics.""" + + judge_autorater_config: Optional[genai_types.AutoraterConfigDict] + """Optional configuration for the judge LLM (Autorater).""" + + inline_rubrics: Optional[list[evals_types.Rubric]] + """Use rubrics provided directly in the spec.""" + + rubric_group_key: Optional[str] + """Use a pre-defined group of rubrics associated with the input content. + This refers to a key in the `rubric_groups` map of + `RubricEnhancedContents`.""" + + rubric_generation_spec: Optional[genai_types.RubricGenerationSpecDict] + """Dynamically generate rubrics for evaluation using this specification.""" + + +RubricBasedMetricSpecOrDict = Union[RubricBasedMetricSpec, RubricBasedMetricSpecDict] + + +class RubricEnhancedContents(_common.BaseModel): + """Rubric-enhanced contents for evaluation.""" + + prompt: Optional[list[genai_types.Content]] = Field( + default=None, + description="""User prompt, using the standard Content type from the Gen AI SDK.""", + ) + rubric_groups: Optional[dict[str, "RubricGroup"]] = Field( + default=None, + description="""Named groups of rubrics associated with this prompt. + The key is a user-defined name for the rubric group.""", + ) + response: Optional[list[genai_types.Content]] = Field( + default=None, + description="""Response, using the standard Content type from the Gen AI SDK.""", + ) + other_content: Optional[ContentMap] = Field( + default=None, + description="""Other contents needed for the metric. 
+        For example, if `reference` is needed for the metric, it can be provided
+        here.""",
+    )
+
+
+class RubricEnhancedContentsDict(TypedDict, total=False):
+    """Rubric-enhanced contents for evaluation."""
+
+    prompt: Optional[list[genai_types.ContentDict]]
+    """User prompt, using the standard Content type from the Gen AI SDK."""
+
+    rubric_groups: Optional[dict[str, "RubricGroupDict"]]
+    """Named groups of rubrics associated with this prompt.
+    The key is a user-defined name for the rubric group."""
+
+    response: Optional[list[genai_types.ContentDict]]
+    """Response, using the standard Content type from the Gen AI SDK."""
+
+    other_content: Optional[ContentMapDict]
+    """Other contents needed for the metric.
+    For example, if `reference` is needed for the metric, it can be provided
+    here."""
+
+
+RubricEnhancedContentsOrDict = Union[RubricEnhancedContents, RubricEnhancedContentsDict]
+
+
+class RubricBasedMetricInstance(_common.BaseModel):
+    """Defines an instance for Rubric-based metrics.
+
+    This class allows various input formats.
+    """
+
+    json_instance: Optional[str] = Field(
+        default=None,
+        description="""Specify evaluation fields and their string values in JSON format.""",
+    )
+    content_map_instance: Optional[ContentMap] = Field(
+        default=None,
+        description="""Specify evaluation fields and their content values using a ContentMap.""",
+    )
+    rubric_enhanced_contents: Optional[RubricEnhancedContents] = Field(
+        default=None,
+        description="""Provide input as Gemini Content along with one or more
+        associated rubric groups.""",
+    )
+
+
+class RubricBasedMetricInstanceDict(TypedDict, total=False):
+    """Defines an instance for Rubric-based metrics.
+
+    This class allows various input formats.
+    """
+
+    json_instance: Optional[str]
+    """Specify evaluation fields and their string values in JSON format."""
+
+    content_map_instance: Optional[ContentMapDict]
+    """Specify evaluation fields and their content values using a ContentMap."""
+
+    rubric_enhanced_contents: Optional[RubricEnhancedContentsDict]
+    """Provide input as Gemini Content along with one or more
+    associated rubric groups."""
+
+
+RubricBasedMetricInstanceOrDict = Union[
+    RubricBasedMetricInstance, RubricBasedMetricInstanceDict
+]
+
+
+class RubricBasedMetricInput(_common.BaseModel):
+    """Input for a rubric-based metric."""
+
+    metric_spec: Optional[RubricBasedMetricSpec] = Field(
+        default=None, description="""Specification for the rubric-based metric."""
+    )
+    instance: Optional[RubricBasedMetricInstance] = Field(
+        default=None, description="""The instance to be evaluated."""
+    )
+
+
+class RubricBasedMetricInputDict(TypedDict, total=False):
+    """Input for a rubric-based metric."""
+
+    metric_spec: Optional[RubricBasedMetricSpecDict]
+    """Specification for the rubric-based metric."""
+
+    instance: Optional[RubricBasedMetricInstanceDict]
+    """The instance to be evaluated."""
+
+
+RubricBasedMetricInputOrDict = Union[RubricBasedMetricInput, RubricBasedMetricInputDict]
+
+
+class MetricSource(_common.BaseModel):
+    """The metric source used for evaluation."""
+
+    metric: Optional[Metric] = Field(
+        default=None, description="""Inline metric config."""
+    )
+    metric_resource_name: Optional[str] = Field(
+        default=None,
+        description="""Resource name for registered metric. Example:
+        projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
+    )
+
+
+class MetricSourceDict(TypedDict, total=False):
+    """The metric source used for evaluation."""
+
+    metric: Optional[MetricDict]
+    """Inline metric config."""
+
+    metric_resource_name: Optional[str]
+    """Resource name for registered metric. Example:
+    projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}"""
+
+
+MetricSourceOrDict = Union[MetricSource, MetricSourceDict]
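+
+
+# Illustrative sketch: the two ways MetricSource (above) can reference a
+# metric -- inline, or by registered resource name. The metric fields and the
+# resource name below are hypothetical placeholders.
+#
+#   inline = MetricSource(metric=Metric(name="my_metric"))
+#   registered = MetricSource(
+#       metric_resource_name=(
+#           "projects/my-project/locations/us-central1/"
+#           "evaluationMetrics/my-metric-id"
+#       )
+#   )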
+
+
+class _EvaluateInstancesRequestParameters(_common.BaseModel):
+    """Parameters for evaluating instances."""
+
+    bleu_input: Optional[BleuInput] = Field(default=None, description="""""")
+    exact_match_input: Optional[ExactMatchInput] = Field(
+        default=None, description=""""""
+    )
+    rouge_input: Optional[RougeInput] = Field(default=None, description="""""")
+    pointwise_metric_input: Optional[PointwiseMetricInput] = Field(
+        default=None, description=""""""
+    )
+    pairwise_metric_input: Optional[PairwiseMetricInput] = Field(
+        default=None, description=""""""
+    )
+    tool_call_valid_input: Optional[ToolCallValidInput] = Field(
+        default=None, description=""""""
+    )
+    tool_name_match_input: Optional[ToolNameMatchInput] = Field(
+        default=None, description=""""""
+    )
+    tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInput] = Field(
+        default=None, description=""""""
+    )
+    tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInput] = Field(
+        default=None, description=""""""
+    )
+    rubric_based_metric_input: Optional[RubricBasedMetricInput] = Field(
+        default=None, description=""""""
+    )
+    autorater_config: Optional[genai_types.AutoraterConfig] = Field(
+        default=None, description=""""""
+    )
+    metrics: Optional[list[Metric]] = Field(
+        default=None,
+        description="""The metrics used for evaluation.
+        Currently, we only support evaluating a single metric. If multiple metrics
+        are provided, only the first one will be evaluated.""",
+    )
+    instance: Optional[EvaluationInstance] = Field(
+        default=None, description="""The instance to be evaluated."""
+    )
+    metric_sources: Optional[list[MetricSource]] = Field(
+        default=None, description="""The metrics used for evaluation."""
+    )
+    config: Optional[EvaluateInstancesConfig] = Field(default=None, description="""""")
+
+
+class _EvaluateInstancesRequestParametersDict(TypedDict, total=False):
+    """Parameters for evaluating instances."""
+
+    bleu_input: Optional[BleuInputDict]
+    """"""
+
+    exact_match_input: Optional[ExactMatchInputDict]
+    """"""
+
+    rouge_input: Optional[RougeInputDict]
+    """"""
+
+    pointwise_metric_input: Optional[PointwiseMetricInputDict]
+    """"""
+
+    pairwise_metric_input: Optional[PairwiseMetricInputDict]
+    """"""
+
+    tool_call_valid_input: Optional[ToolCallValidInputDict]
+    """"""
+
+    tool_name_match_input: Optional[ToolNameMatchInputDict]
+    """"""
+
+    tool_parameter_key_match_input: Optional[ToolParameterKeyMatchInputDict]
+    """"""
+
+    tool_parameter_kv_match_input: Optional[ToolParameterKVMatchInputDict]
+    """"""
+
+    rubric_based_metric_input: Optional[RubricBasedMetricInputDict]
+    """"""
+
+    autorater_config: Optional[genai_types.AutoraterConfigDict]
+    """"""
+
+    metrics: Optional[list[MetricDict]]
+    """The metrics used for evaluation.
+    Currently, we only support evaluating a single metric.
If multiple metrics + are provided, only the first one will be evaluated.""" + + instance: Optional[EvaluationInstanceDict] + """The instance to be evaluated.""" + + metric_sources: Optional[list[MetricSourceDict]] + """The metrics used for evaluation.""" + + config: Optional[EvaluateInstancesConfigDict] + """""" + + +_EvaluateInstancesRequestParametersOrDict = Union[ + _EvaluateInstancesRequestParameters, _EvaluateInstancesRequestParametersDict +] + + +class MetricResult(_common.BaseModel): + """Result for a single metric on a single instance.""" + + score: Optional[float] = Field( + default=None, + description="""The score for the metric. Please refer to each metric's documentation for the meaning of the score.""", + ) + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] = Field( + default=None, + description="""For rubric-based metrics, the verdicts for each rubric.""", + ) + explanation: Optional[str] = Field( + default=None, description="""The explanation for the metric result.""" + ) + error: Optional[genai_types.GoogleRpcStatus] = Field( + default=None, description="""The error status for the metric result.""" + ) + + +class MetricResultDict(TypedDict, total=False): + """Result for a single metric on a single instance.""" + + score: Optional[float] + """The score for the metric. Please refer to each metric's documentation for the meaning of the score.""" + + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] + """For rubric-based metrics, the verdicts for each rubric.""" + + explanation: Optional[str] + """The explanation for the metric result.""" + + error: Optional[genai_types.GoogleRpcStatusDict] + """The error status for the metric result.""" + + +MetricResultOrDict = Union[MetricResult, MetricResultDict] + + +class BleuResults(_common.BaseModel): + """Result of evaluating a bleu metric.""" + + bleu_metric_values: Optional[list[genai_types.BleuMetricValue]] = Field( + default=None, description="""Output only. Bleu metric values.""" + ) + + +class BleuResultsDict(TypedDict, total=False): + """Result of evaluating a bleu metric.""" + + bleu_metric_values: Optional[list[genai_types.BleuMetricValueDict]] + """Output only. Bleu metric values.""" + + +BleuResultsOrDict = Union[BleuResults, BleuResultsDict] + + +class ExactMatchResults(_common.BaseModel): + """Result of evaluating an exact match metric.""" + + exact_match_metric_values: Optional[list[genai_types.ExactMatchMetricValue]] = ( + Field(default=None, description="""Output only. Exact match metric values.""") + ) + + +class ExactMatchResultsDict(TypedDict, total=False): + """Result of evaluating an exact match metric.""" + + exact_match_metric_values: Optional[list[genai_types.ExactMatchMetricValueDict]] + """Output only. Exact match metric values.""" + + +ExactMatchResultsOrDict = Union[ExactMatchResults, ExactMatchResultsDict] + + +class RougeResults(_common.BaseModel): + """Result of evaluating a rouge metric.""" + + rouge_metric_values: Optional[list[genai_types.RougeMetricValue]] = Field( + default=None, description="""Output only. Rouge metric values.""" + ) + + +class RougeResultsDict(TypedDict, total=False): + """Result of evaluating a rouge metric.""" + + rouge_metric_values: Optional[list[genai_types.RougeMetricValueDict]] + """Output only. 
Rouge metric values.""" + + +RougeResultsOrDict = Union[RougeResults, RougeResultsDict] + + +class RubricBasedMetricResult(_common.BaseModel): + """Result for a rubric-based metric.""" + + score: Optional[float] = Field( + default=None, description="""Passing rate of all the rubrics.""" + ) + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] = Field( + default=None, + description="""The details of all the rubrics and their verdicts.""", + ) + + +class RubricBasedMetricResultDict(TypedDict, total=False): + """Result for a rubric-based metric.""" + + score: Optional[float] + """Passing rate of all the rubrics.""" + + rubric_verdicts: Optional[list[evals_types.RubricVerdict]] + """The details of all the rubrics and their verdicts.""" + + +RubricBasedMetricResultOrDict = Union[ + RubricBasedMetricResult, RubricBasedMetricResultDict +] + + +class CometResult(_common.BaseModel): + """Spec for Comet result - calculates the comet score for the given instance using the version specified in the spec.""" + + score: Optional[float] = Field( + default=None, + description="""Output only. Comet score. Range depends on version.""", + ) + + +class CometResultDict(TypedDict, total=False): + """Spec for Comet result - calculates the comet score for the given instance using the version specified in the spec.""" + + score: Optional[float] + """Output only. Comet score. Range depends on version.""" + + +CometResultOrDict = Union[CometResult, CometResultDict] + + +class MetricxResult(_common.BaseModel): + """Spec for MetricX result - calculates the MetricX score for the given instance using the version specified in the spec.""" + + score: Optional[float] = Field( + default=None, + description="""Output only. MetricX score. Range depends on version.""", + ) + + +class MetricxResultDict(TypedDict, total=False): + """Spec for MetricX result - calculates the MetricX score for the given instance using the version specified in the spec.""" + + score: Optional[float] + """Output only. MetricX score. Range depends on version.""" + + +MetricxResultOrDict = Union[MetricxResult, MetricxResultDict] + + +class ToolCallValidMetricValue(_common.BaseModel): + """Tool call valid metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Tool call valid score.""" + ) + + +class ToolCallValidMetricValueDict(TypedDict, total=False): + """Tool call valid metric value for an instance.""" + + score: Optional[float] + """Output only. Tool call valid score.""" + + +ToolCallValidMetricValueOrDict = Union[ + ToolCallValidMetricValue, ToolCallValidMetricValueDict +] + + +class ToolCallValidResults(_common.BaseModel): + """Results for tool call valid metric.""" + + tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValue]] = Field( + default=None, description="""Output only. Tool call valid metric values.""" + ) + + +class ToolCallValidResultsDict(TypedDict, total=False): + """Results for tool call valid metric.""" + + tool_call_valid_metric_values: Optional[list[ToolCallValidMetricValueDict]] + """Output only. Tool call valid metric values.""" + + +ToolCallValidResultsOrDict = Union[ToolCallValidResults, ToolCallValidResultsDict] + + +class ToolNameMatchMetricValue(_common.BaseModel): + """Tool name match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. 
Tool name match score.""" + ) + + +class ToolNameMatchMetricValueDict(TypedDict, total=False): + """Tool name match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool name match score.""" + + +ToolNameMatchMetricValueOrDict = Union[ + ToolNameMatchMetricValue, ToolNameMatchMetricValueDict +] + + +class ToolNameMatchResults(_common.BaseModel): + """Results for tool name match metric.""" + + tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValue]] = Field( + default=None, description="""Output only. Tool name match metric values.""" + ) + + +class ToolNameMatchResultsDict(TypedDict, total=False): + """Results for tool name match metric.""" + + tool_name_match_metric_values: Optional[list[ToolNameMatchMetricValueDict]] + """Output only. Tool name match metric values.""" + + +ToolNameMatchResultsOrDict = Union[ToolNameMatchResults, ToolNameMatchResultsDict] + + +class ToolParameterKeyMatchMetricValue(_common.BaseModel): + """Tool parameter key match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, description="""Output only. Tool parameter key match score.""" + ) + + +class ToolParameterKeyMatchMetricValueDict(TypedDict, total=False): + """Tool parameter key match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool parameter key match score.""" + + +ToolParameterKeyMatchMetricValueOrDict = Union[ + ToolParameterKeyMatchMetricValue, ToolParameterKeyMatchMetricValueDict +] + + +class ToolParameterKeyMatchResults(_common.BaseModel): + """Results for tool parameter key match metric.""" + + tool_parameter_key_match_metric_values: Optional[ + list[ToolParameterKeyMatchMetricValue] + ] = Field( + default=None, + description="""Output only. Tool parameter key match metric values.""", + ) + + +class ToolParameterKeyMatchResultsDict(TypedDict, total=False): + """Results for tool parameter key match metric.""" + + tool_parameter_key_match_metric_values: Optional[ + list[ToolParameterKeyMatchMetricValueDict] + ] + """Output only. Tool parameter key match metric values.""" + + +ToolParameterKeyMatchResultsOrDict = Union[ + ToolParameterKeyMatchResults, ToolParameterKeyMatchResultsDict +] + + +class ToolParameterKVMatchMetricValue(_common.BaseModel): + """Tool parameter key value match metric value for an instance.""" + + score: Optional[float] = Field( + default=None, + description="""Output only. Tool parameter key value match score.""", + ) + + +class ToolParameterKVMatchMetricValueDict(TypedDict, total=False): + """Tool parameter key value match metric value for an instance.""" + + score: Optional[float] + """Output only. Tool parameter key value match score.""" + + +ToolParameterKVMatchMetricValueOrDict = Union[ + ToolParameterKVMatchMetricValue, ToolParameterKVMatchMetricValueDict +] + + +class ToolParameterKVMatchResults(_common.BaseModel): + """Results for tool parameter key value match metric.""" + + tool_parameter_kv_match_metric_values: Optional[ + list[ToolParameterKVMatchMetricValue] + ] = Field( + default=None, + description="""Output only. Tool parameter key value match metric values.""", + ) + + +class ToolParameterKVMatchResultsDict(TypedDict, total=False): + """Results for tool parameter key value match metric.""" + + tool_parameter_kv_match_metric_values: Optional[ + list[ToolParameterKVMatchMetricValueDict] + ] + """Output only. 
Tool parameter key value match metric values.""" + + +ToolParameterKVMatchResultsOrDict = Union[ + ToolParameterKVMatchResults, ToolParameterKVMatchResultsDict +] + + +class EvaluateInstancesResponse(_common.BaseModel): + """Result of evaluating an LLM metric.""" + + rubric_based_metric_result: Optional[RubricBasedMetricResult] = Field( + default=None, description="""Result for rubric based metric.""" + ) + metric_results: Optional[list[MetricResult]] = Field( + default=None, + description="""A list of metric results for each evaluation case. The order of the metric results is guaranteed to be the same as the order of the instances in the request.""", + ) + bleu_results: Optional[BleuResults] = Field( + default=None, description="""Results for bleu metric.""" + ) + comet_result: Optional[CometResult] = Field( + default=None, description="""Translation metrics. Result for Comet metric.""" + ) + exact_match_results: Optional[ExactMatchResults] = Field( + default=None, + description="""Auto metric evaluation results. Results for exact match metric.""", + ) + metricx_result: Optional[MetricxResult] = Field( + default=None, description="""Result for Metricx metric.""" + ) + pairwise_metric_result: Optional[genai_types.PairwiseMetricResult] = Field( + default=None, description="""Result for pairwise metric.""" + ) + pointwise_metric_result: Optional[genai_types.PointwiseMetricResult] = Field( + default=None, description="""Generic metrics. Result for pointwise metric.""" + ) + rouge_results: Optional[RougeResults] = Field( + default=None, description="""Results for rouge metric.""" + ) + tool_call_valid_results: Optional[ToolCallValidResults] = Field( + default=None, + description="""Tool call metrics. Results for tool call valid metric.""", + ) + tool_name_match_results: Optional[ToolNameMatchResults] = Field( + default=None, description="""Results for tool name match metric.""" + ) + tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResults] = Field( + default=None, description="""Results for tool parameter key match metric.""" + ) + tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResults] = Field( + default=None, + description="""Results for tool parameter key value match metric.""", + ) + + +class EvaluateInstancesResponseDict(TypedDict, total=False): + """Result of evaluating an LLM metric.""" + + rubric_based_metric_result: Optional[RubricBasedMetricResultDict] + """Result for rubric based metric.""" + + metric_results: Optional[list[MetricResultDict]] + """A list of metric results for each evaluation case. The order of the metric results is guaranteed to be the same as the order of the instances in the request.""" + + bleu_results: Optional[BleuResultsDict] + """Results for bleu metric.""" + + comet_result: Optional[CometResultDict] + """Translation metrics. Result for Comet metric.""" + + exact_match_results: Optional[ExactMatchResultsDict] + """Auto metric evaluation results. Results for exact match metric.""" + + metricx_result: Optional[MetricxResultDict] + """Result for Metricx metric.""" + + pairwise_metric_result: Optional[genai_types.PairwiseMetricResultDict] + """Result for pairwise metric.""" + + pointwise_metric_result: Optional[genai_types.PointwiseMetricResultDict] + """Generic metrics. Result for pointwise metric.""" + + rouge_results: Optional[RougeResultsDict] + """Results for rouge metric.""" + + tool_call_valid_results: Optional[ToolCallValidResultsDict] + """Tool call metrics. 
Results for tool call valid metric.""" + + tool_name_match_results: Optional[ToolNameMatchResultsDict] + """Results for tool name match metric.""" + + tool_parameter_key_match_results: Optional[ToolParameterKeyMatchResultsDict] + """Results for tool parameter key match metric.""" + + tool_parameter_kv_match_results: Optional[ToolParameterKVMatchResultsDict] + """Results for tool parameter key value match metric.""" + + +EvaluateInstancesResponseOrDict = Union[ + EvaluateInstancesResponse, EvaluateInstancesResponseDict +] + + +class GenerateUserScenariosConfig(_common.BaseModel): + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GenerateUserScenariosConfigDict(TypedDict, total=False): + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GenerateUserScenariosConfigOrDict = Union[ + GenerateUserScenariosConfig, GenerateUserScenariosConfigDict +] + + +class _GenerateUserScenariosParameters(_common.BaseModel): + """Parameters for GenerateUserScenarios.""" + + location: Optional[str] = Field(default=None, description="""""") + agents: Optional[dict[str, evals_types.AgentConfig]] = Field( + default=None, description="""""" + ) + root_agent_id: Optional[str] = Field(default=None, description="""""") + user_scenario_generation_config: Optional[ + evals_types.UserScenarioGenerationConfig + ] = Field(default=None, description="""""") + config: Optional[GenerateUserScenariosConfig] = Field( + default=None, description="""""" + ) + allow_cross_region_model: Optional[bool] = Field( + default=None, + description="""Opt-in flag to authorize cross-region routing for LLM models.""", + ) + + +class _GenerateUserScenariosParametersDict(TypedDict, total=False): + """Parameters for GenerateUserScenarios.""" + + location: Optional[str] + """""" + + agents: Optional[dict[str, evals_types.AgentConfig]] + """""" + + root_agent_id: Optional[str] + """""" + + user_scenario_generation_config: Optional[evals_types.UserScenarioGenerationConfig] + """""" + + config: Optional[GenerateUserScenariosConfigDict] + """""" + + allow_cross_region_model: Optional[bool] + """Opt-in flag to authorize cross-region routing for LLM models.""" + + +_GenerateUserScenariosParametersOrDict = Union[ + _GenerateUserScenariosParameters, _GenerateUserScenariosParametersDict +] + + +class GenerateUserScenariosResponse(_common.BaseModel): + """Response message for DataFoundryService.GenerateUserScenarios.""" + + user_scenarios: Optional[list[evals_types.UserScenario]] = Field( + default=None, description="""""" + ) + + +class GenerateUserScenariosResponseDict(TypedDict, total=False): + """Response message for DataFoundryService.GenerateUserScenarios.""" + + user_scenarios: Optional[list[evals_types.UserScenario]] + """""" + + +GenerateUserScenariosResponseOrDict = Union[ + GenerateUserScenariosResponse, GenerateUserScenariosResponseDict +] + + +class GenerateLossClustersConfig(_common.BaseModel): + """Config for generating loss clusters.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GenerateLossClustersConfigDict(TypedDict, total=False): + """Config for generating loss clusters.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GenerateLossClustersConfigOrDict = Union[ + GenerateLossClustersConfig, GenerateLossClustersConfigDict 
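+# Each *Config model pairs with a *ConfigDict TypedDict, and the *OrDict
+# alias accepts either form. A hedged sketch (the dict spelling of the
+# HttpOptions fields, and `timeout` being milliseconds, are assumptions):
+#
+#   cfg: GenerateLossClustersConfigOrDict = {
+#       "http_options": {"timeout": 60_000},
+#   }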
+] + + +class _GenerateLossClustersParameters(_common.BaseModel): + """Parameters for GenerateLossClusters.""" + + location: Optional[str] = Field( + default=None, + description="""The resource name of the Location. Format: `projects/{project}/locations/{location}`.""", + ) + evaluation_set: Optional[str] = Field( + default=None, + description="""Reference to a persisted EvaluationSet. The service will read items from this set.""", + ) + inline_results: Optional[list[EvaluationResult]] = Field( + default=None, + description="""Inline evaluation results. Useful for ephemeral analysis in notebooks/SDKs where data isn't persisted.""", + ) + configs: Optional[list[LossAnalysisConfig]] = Field( + default=None, + description="""Configuration for the analysis algorithm. Analysis for multiple metrics and multiple candidates could be specified.""", + ) + config: Optional[GenerateLossClustersConfig] = Field( + default=None, description="""Config for generating loss clusters.""" + ) + + +class _GenerateLossClustersParametersDict(TypedDict, total=False): + """Parameters for GenerateLossClusters.""" + + location: Optional[str] + """The resource name of the Location. Format: `projects/{project}/locations/{location}`.""" + + evaluation_set: Optional[str] + """Reference to a persisted EvaluationSet. The service will read items from this set.""" + + inline_results: Optional[list[EvaluationResultDict]] + """Inline evaluation results. Useful for ephemeral analysis in notebooks/SDKs where data isn't persisted.""" + + configs: Optional[list[LossAnalysisConfigDict]] + """Configuration for the analysis algorithm. Analysis for multiple metrics and multiple candidates could be specified.""" + + config: Optional[GenerateLossClustersConfigDict] + """Config for generating loss clusters.""" + + +_GenerateLossClustersParametersOrDict = Union[ + _GenerateLossClustersParameters, _GenerateLossClustersParametersDict +] + + +class GenerateLossClustersResponse(_common.BaseModel): + """Response message for EvaluationAnalyticsService.GenerateLossClusters.""" + + analysis_time: Optional[str] = Field( + default=None, description="""The timestamp when this analysis was completed.""" + ) + results: Optional[list[LossAnalysisResult]] = Field( + default=None, + description="""The analysis results, one per config provided in the request.""", + ) + + def show(self) -> None: + """Shows the loss pattern analysis report with rich HTML visualization.""" + from .. import _evals_visualization + + _evals_visualization.display_loss_clusters_response(self) + + +class GenerateLossClustersResponseDict(TypedDict, total=False): + """Response message for EvaluationAnalyticsService.GenerateLossClusters.""" + + analysis_time: Optional[str] + """The timestamp when this analysis was completed.""" + + results: Optional[list[LossAnalysisResultDict]] + """The analysis results, one per config provided in the request.""" + + +GenerateLossClustersResponseOrDict = Union[ + GenerateLossClustersResponse, GenerateLossClustersResponseDict +] + + +class GenerateLossClustersOperation(_common.BaseModel): + """Long-running operation for generating loss clusters.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. 
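+
+# A hedged sketch of consuming this long-running operation once the service
+# returns it (`op` is a placeholder; polling and transport wiring are assumed):
+def _handle_loss_clusters_op(op: GenerateLossClustersOperation) -> None:
+    if not op.done:
+        return  # still running; re-fetch the operation and check again
+    if op.error:
+        raise RuntimeError(f"GenerateLossClusters failed: {op.error}")
+    if op.response is not None:
+        op.response.show()  # renders the rich HTML loss-pattern report
+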
If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[GenerateLossClustersResponse] = Field( + default=None, + description="""Response message for EvaluationAnalyticsService.GenerateLossClusters.""", + ) + + +class GenerateLossClustersOperationDict(TypedDict, total=False): + """Long-running operation for generating loss clusters.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[GenerateLossClustersResponseDict] + """Response message for EvaluationAnalyticsService.GenerateLossClusters.""" + + +GenerateLossClustersOperationOrDict = Union[ + GenerateLossClustersOperation, GenerateLossClustersOperationDict +] + + +class RubricGenerationConfig(_common.BaseModel): + """Config for generating rubrics.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class RubricGenerationConfigDict(TypedDict, total=False): + """Config for generating rubrics.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +RubricGenerationConfigOrDict = Union[RubricGenerationConfig, RubricGenerationConfigDict] + + +class _GenerateInstanceRubricsRequest(_common.BaseModel): + """Parameters for generating rubrics.""" + + contents: Optional[list[genai_types.Content]] = Field( + default=None, + description="""The prompt to generate rubrics from. For single-turn queries, this is a single instance. For multi-turn queries, this is a repeated field that contains conversation history + latest request.""", + ) + predefined_rubric_generation_spec: Optional[genai_types.PredefinedMetricSpec] = ( + Field( + default=None, + description="""Specification for using the rubric generation configs of a pre-defined + metric, e.g. "generic_quality_v1" and "instruction_following_v1". 
+ Some of the configs may only be used for rubric generation and may
+ not support evaluation, e.g. "fully_customized_generic_quality_v1".
+ If this field is set, the `rubric_generation_spec` field will be ignored.
+ """,
+ )
+ )
+ rubric_generation_spec: Optional[genai_types.RubricGenerationSpec] = Field(
+ default=None,
+ description="""Specification for how the rubrics should be generated.""",
+ )
+ metric_resource_name: Optional[str] = Field(
+ default=None,
+ description="""Registered metric resource name. If this field is set, the configuration provided in this field is used for rubric generation. The `predefined_rubric_generation_spec` and `rubric_generation_spec` fields will be ignored.""",
+ )
+ config: Optional[RubricGenerationConfig] = Field(default=None, description="""""")
+
+
+class _GenerateInstanceRubricsRequestDict(TypedDict, total=False):
+ """Parameters for generating rubrics."""
+
+ contents: Optional[list[genai_types.ContentDict]]
+ """The prompt to generate rubrics from. For single-turn queries, this is a single instance. For multi-turn queries, this is a repeated field that contains conversation history + latest request."""
+
+ predefined_rubric_generation_spec: Optional[genai_types.PredefinedMetricSpecDict]
+ """Specification for using the rubric generation configs of a pre-defined
+ metric, e.g. "generic_quality_v1" and "instruction_following_v1".
+ Some of the configs may only be used for rubric generation and may
+ not support evaluation, e.g. "fully_customized_generic_quality_v1".
+ If this field is set, the `rubric_generation_spec` field will be ignored.
+ """
+
+ rubric_generation_spec: Optional[genai_types.RubricGenerationSpecDict]
+ """Specification for how the rubrics should be generated."""
+
+ metric_resource_name: Optional[str]
+ """Registered metric resource name. If this field is set, the configuration provided in this field is used for rubric generation. 
The `predefined_rubric_generation_spec` and `rubric_generation_spec` fields will be ignored.""" + + config: Optional[RubricGenerationConfigDict] + """""" + + +_GenerateInstanceRubricsRequestOrDict = Union[ + _GenerateInstanceRubricsRequest, _GenerateInstanceRubricsRequestDict +] + + +class GenerateInstanceRubricsResponse(_common.BaseModel): + """Response for generating rubrics.""" + + generated_rubrics: Optional[list[evals_types.Rubric]] = Field( + default=None, description="""A list of generated rubrics.""" + ) + + +class GenerateInstanceRubricsResponseDict(TypedDict, total=False): + """Response for generating rubrics.""" + + generated_rubrics: Optional[list[evals_types.Rubric]] + """A list of generated rubrics.""" + + +GenerateInstanceRubricsResponseOrDict = Union[ + GenerateInstanceRubricsResponse, GenerateInstanceRubricsResponseDict +] + + +class GetEvaluationMetricConfig(_common.BaseModel): + """Config for getting an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetEvaluationMetricConfigDict(TypedDict, total=False): + """Config for getting an evaluation metric.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetEvaluationMetricConfigOrDict = Union[ + GetEvaluationMetricConfig, GetEvaluationMetricConfigDict +] + + +class _GetEvaluationMetricParameters(_common.BaseModel): + """Parameters for getting an evaluation metric.""" + + metric_resource_name: Optional[str] = Field(default=None, description="""""") + config: Optional[GetEvaluationMetricConfig] = Field( + default=None, description="""""" + ) + + +class _GetEvaluationMetricParametersDict(TypedDict, total=False): + """Parameters for getting an evaluation metric.""" + + metric_resource_name: Optional[str] + """""" + + config: Optional[GetEvaluationMetricConfigDict] + """""" + + +_GetEvaluationMetricParametersOrDict = Union[ + _GetEvaluationMetricParameters, _GetEvaluationMetricParametersDict +] + + +class GetEvaluationRunConfig(_common.BaseModel): + """Config for get evaluation run.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetEvaluationRunConfigDict(TypedDict, total=False): + """Config for get evaluation run.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetEvaluationRunConfigOrDict = Union[GetEvaluationRunConfig, GetEvaluationRunConfigDict] + + +class _GetEvaluationRunParameters(_common.BaseModel): + """Represents a job that runs evaluation.""" + + name: Optional[str] = Field(default=None, description="""""") + config: Optional[GetEvaluationRunConfig] = Field(default=None, description="""""") + + +class _GetEvaluationRunParametersDict(TypedDict, total=False): + """Represents a job that runs evaluation.""" + + name: Optional[str] + """""" + + config: Optional[GetEvaluationRunConfigDict] + """""" + + +_GetEvaluationRunParametersOrDict = Union[ + _GetEvaluationRunParameters, _GetEvaluationRunParametersDict +] + + +class GetEvaluationSetConfig(_common.BaseModel): + """Config for get evaluation set.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetEvaluationSetConfigDict(TypedDict, total=False): + """Config for get evaluation set.""" + + http_options: 
Optional[genai_types.HttpOptionsDict]
+ """Used to override HTTP request options."""
+
+
+GetEvaluationSetConfigOrDict = Union[GetEvaluationSetConfig, GetEvaluationSetConfigDict]
+
+
+class _GetEvaluationSetParameters(_common.BaseModel):
+ """Represents a job that gets an evaluation set."""
+
+ name: Optional[str] = Field(default=None, description="""""")
+ config: Optional[GetEvaluationSetConfig] = Field(default=None, description="""""")
+
+
+class _GetEvaluationSetParametersDict(TypedDict, total=False):
+ """Represents a job that gets an evaluation set."""
+
+ name: Optional[str]
+ """"""
+
+ config: Optional[GetEvaluationSetConfigDict]
+ """"""
+
+
+_GetEvaluationSetParametersOrDict = Union[
+ _GetEvaluationSetParameters, _GetEvaluationSetParametersDict
+]
+
+
+class GetEvaluationItemConfig(_common.BaseModel):
+ """Config for get evaluation item."""
+
+ http_options: Optional[genai_types.HttpOptions] = Field(
+ default=None, description="""Used to override HTTP request options."""
+ )
+
+
+class GetEvaluationItemConfigDict(TypedDict, total=False):
+ """Config for get evaluation item."""
+
+ http_options: Optional[genai_types.HttpOptionsDict]
+ """Used to override HTTP request options."""
+
+
+GetEvaluationItemConfigOrDict = Union[
+ GetEvaluationItemConfig, GetEvaluationItemConfigDict
+]
+
+
+class _GetEvaluationItemParameters(_common.BaseModel):
+ """Represents a job that gets an evaluation item."""
+
+ name: Optional[str] = Field(default=None, description="""""")
+ config: Optional[GetEvaluationItemConfig] = Field(default=None, description="""""")
+
+
+class _GetEvaluationItemParametersDict(TypedDict, total=False):
+ """Represents a job that gets an evaluation item."""
+
+ name: Optional[str]
+ """"""
+
+ config: Optional[GetEvaluationItemConfigDict]
+ """"""
+
+
+_GetEvaluationItemParametersOrDict = Union[
+ _GetEvaluationItemParameters, _GetEvaluationItemParametersDict
+]
+
+
+class ListEvaluationMetricsConfig(_common.BaseModel):
+ """Config for listing evaluation metrics."""
+
+ http_options: Optional[genai_types.HttpOptions] = Field(
+ default=None, description="""Used to override HTTP request options."""
+ )
+ page_size: Optional[int] = Field(default=None, description="""""")
+ page_token: Optional[str] = Field(default=None, description="""""")
+ filter: Optional[str] = Field(
+ default=None,
+ description="""An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ For more information about filter syntax, see
+ `AIP-160 <https://google.aip.dev/160>`_.""",
+ )
+ order_by: Optional[str] = Field(
+ default=None,
+ description="""A comma-separated list of fields to order by, sorted in ascending
+ order by default. Use ``desc`` after a field name for descending.
+ Example: ``"create_time desc"``.""",
+ )
+
+
+class ListEvaluationMetricsConfigDict(TypedDict, total=False):
+ """Config for listing evaluation metrics."""
+
+ http_options: Optional[genai_types.HttpOptionsDict]
+ """Used to override HTTP request options."""
+
+ page_size: Optional[int]
+ """"""
+
+ page_token: Optional[str]
+ """"""
+
+ filter: Optional[str]
+ """An expression for filtering the results of the request.
+ For field names both snake_case and camelCase are supported.
+ For more information about filter syntax, see
+ `AIP-160 <https://google.aip.dev/160>`_."""
+
+ order_by: Optional[str]
+ """A comma-separated list of fields to order by, sorted in ascending
+ order by default. Use ``desc`` after a field name for descending. 
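+# A hedged sketch of a list request config using the filter and order_by
+# fields described above (the filter field name `display_name` is an
+# assumption; the filter grammar itself follows AIP-160):
+list_config = ListEvaluationMetricsConfig(
+    page_size=50,
+    filter='display_name="my-metric"',
+    order_by="create_time desc",
+)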
+ Example: ``"create_time desc"``.""" + + +ListEvaluationMetricsConfigOrDict = Union[ + ListEvaluationMetricsConfig, ListEvaluationMetricsConfigDict +] + + +class _ListEvaluationMetricsParameters(_common.BaseModel): + """Parameters for listing evaluation metrics.""" + + config: Optional[ListEvaluationMetricsConfig] = Field( + default=None, description="""""" + ) + + +class _ListEvaluationMetricsParametersDict(TypedDict, total=False): + """Parameters for listing evaluation metrics.""" + + config: Optional[ListEvaluationMetricsConfigDict] + """""" + + +_ListEvaluationMetricsParametersOrDict = Union[ + _ListEvaluationMetricsParameters, _ListEvaluationMetricsParametersDict +] + + +class ListEvaluationMetricsResponse(_common.BaseModel): + """Response for listing evaluation metrics.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + evaluation_metrics: Optional[list[EvaluationMetric]] = Field( + default=None, + description="""List of evaluation metrics. + """, + ) + + +class ListEvaluationMetricsResponseDict(TypedDict, total=False): + """Response for listing evaluation metrics.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + evaluation_metrics: Optional[list[EvaluationMetricDict]] + """List of evaluation metrics. + """ + + +ListEvaluationMetricsResponseOrDict = Union[ + ListEvaluationMetricsResponse, ListEvaluationMetricsResponseDict +] + + +class OptimizeConfig(_common.BaseModel): + """Config for Prompt Optimizer.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + optimization_target: Optional[OptimizeTarget] = Field( + default=None, + description="""The optimization target for the prompt optimizer. It must be one of the OptimizeTarget enum values: OPTIMIZATION_TARGET_GEMINI_NANO for the prompts from Android core API, OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS for the few-shot prompt optimizer with rubrics, OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE for the few-shot prompt optimizer with target responses.""", + ) + examples_dataframe: Optional[PandasDataFrame] = Field( + default=None, + description="""The examples dataframe for the few-shot prompt optimizer. It must contain "prompt" and "model_response" columns. Depending on which optimization target is used, it also needs to contain "rubrics" and "rubrics_evaluations" or "target_response" columns.""", + ) + + +class OptimizeConfigDict(TypedDict, total=False): + """Config for Prompt Optimizer.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + optimization_target: Optional[OptimizeTarget] + """The optimization target for the prompt optimizer. It must be one of the OptimizeTarget enum values: OPTIMIZATION_TARGET_GEMINI_NANO for the prompts from Android core API, OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS for the few-shot prompt optimizer with rubrics, OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE for the few-shot prompt optimizer with target responses.""" + + examples_dataframe: Optional[PandasDataFrame] + """The examples dataframe for the few-shot prompt optimizer. It must contain "prompt" and "model_response" columns. 
Depending on which optimization target is used, it also needs to contain "rubrics" and "rubrics_evaluations" or "target_response" columns.""" + + +OptimizeConfigOrDict = Union[OptimizeConfig, OptimizeConfigDict] + + +class _OptimizeRequestParameters(_common.BaseModel): + """Request for the optimize_prompt method.""" + + content: Optional[genai_types.Content] = Field(default=None, description="""""") + config: Optional[OptimizeConfig] = Field(default=None, description="""""") + + +class _OptimizeRequestParametersDict(TypedDict, total=False): + """Request for the optimize_prompt method.""" + + content: Optional[genai_types.ContentDict] + """""" + + config: Optional[OptimizeConfigDict] + """""" + + +_OptimizeRequestParametersOrDict = Union[ + _OptimizeRequestParameters, _OptimizeRequestParametersDict +] + + +class OptimizeResponseEndpoint(_common.BaseModel): + """Response for the optimize_prompt method.""" + + content: Optional[genai_types.Content] = Field(default=None, description="""""") + + +class OptimizeResponseEndpointDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + content: Optional[genai_types.ContentDict] + """""" + + +OptimizeResponseEndpointOrDict = Union[ + OptimizeResponseEndpoint, OptimizeResponseEndpointDict +] + + +class DnsPeeringConfig(_common.BaseModel): + """DNS peering configuration. These configurations are used to create DNS peering zones in the Vertex tenant project VPC, enabling resolution of records within the specified domain hosted in the target network's Cloud DNS.""" + + domain: Optional[str] = Field( + default=None, + description="""Required. The DNS name suffix of the zone being peered to, e.g., "my-internal-domain.corp.". Must end with a dot.""", + ) + target_network: Optional[str] = Field( + default=None, + description="""Required. The VPC network name in the target_project where the DNS zone specified by 'domain' is visible.""", + ) + target_project: Optional[str] = Field( + default=None, + description="""Required. The project ID hosting the Cloud DNS managed zone that contains the 'domain'. The Vertex AI Service Agent requires the dns.peer role on this project.""", + ) + + +class DnsPeeringConfigDict(TypedDict, total=False): + """DNS peering configuration. These configurations are used to create DNS peering zones in the Vertex tenant project VPC, enabling resolution of records within the specified domain hosted in the target network's Cloud DNS.""" + + domain: Optional[str] + """Required. The DNS name suffix of the zone being peered to, e.g., "my-internal-domain.corp.". Must end with a dot.""" + + target_network: Optional[str] + """Required. The VPC network name in the target_project where the DNS zone specified by 'domain' is visible.""" + + target_project: Optional[str] + """Required. The project ID hosting the Cloud DNS managed zone that contains the 'domain'. The Vertex AI Service Agent requires the dns.peer role on this project.""" + + +DnsPeeringConfigOrDict = Union[DnsPeeringConfig, DnsPeeringConfigDict] + + +class PscInterfaceConfig(_common.BaseModel): + """Configuration for PSC-I.""" + + dns_peering_configs: Optional[list[DnsPeeringConfig]] = Field( + default=None, + description="""Optional. DNS peering configurations. When specified, Vertex AI will attempt to configure DNS peering zones in the tenant project VPC to resolve the specified domains using the target network's Cloud DNS. 
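+# A hedged sketch of a PSC-I configuration with DNS peering; all resource
+# names and the network-attachment path format are placeholders:
+psc_config = PscInterfaceConfig(
+    network_attachment=(
+        "projects/my-project/regions/us-central1/"
+        "networkAttachments/my-attachment"
+    ),
+    dns_peering_configs=[
+        DnsPeeringConfig(
+            domain="my-internal-domain.corp.",  # must end with a dot
+            target_project="dns-host-project",
+            target_network="dns-host-vpc",
+        )
+    ],
+)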
The user must grant the dns.peer role to the Vertex AI Service Agent on the target project.""", + ) + network_attachment: Optional[str] = Field( + default=None, + description="""Optional. The name of the Compute Engine [network attachment](https://cloud.google.com/vpc/docs/about-network-attachments) to attach to the resource within the region and user project. To specify this field, you must have already [created a network attachment] (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments). This field is only used for resources using PSC-I.""", + ) + + +class PscInterfaceConfigDict(TypedDict, total=False): + """Configuration for PSC-I.""" + + dns_peering_configs: Optional[list[DnsPeeringConfigDict]] + """Optional. DNS peering configurations. When specified, Vertex AI will attempt to configure DNS peering zones in the tenant project VPC to resolve the specified domains using the target network's Cloud DNS. The user must grant the dns.peer role to the Vertex AI Service Agent on the target project.""" + + network_attachment: Optional[str] + """Optional. The name of the Compute Engine [network attachment](https://cloud.google.com/vpc/docs/about-network-attachments) to attach to the resource within the region and user project. To specify this field, you must have already [created a network attachment] (https://cloud.google.com/vpc/docs/create-manage-network-attachments#create-network-attachments). This field is only used for resources using PSC-I.""" + + +PscInterfaceConfigOrDict = Union[PscInterfaceConfig, PscInterfaceConfigDict] + + +class Scheduling(_common.BaseModel): + """All parameters related to queuing and scheduling of custom jobs.""" + + disable_retries: Optional[bool] = Field( + default=None, + description="""Optional. Indicates if the job should retry for internal errors after the job starts running. If true, overrides `Scheduling.restart_job_on_worker_restart` to false.""", + ) + max_wait_duration: Optional[str] = Field( + default=None, + description="""Optional. This is the maximum duration that a job will wait for the requested resources to be provisioned if the scheduling strategy is set to [Strategy.DWS_FLEX_START]. If set to 0, the job will wait indefinitely. The default is 24 hours.""", + ) + restart_job_on_worker_restart: Optional[bool] = Field( + default=None, + description="""Optional. Restarts the entire CustomJob if a worker gets restarted. This feature can be used by distributed training jobs that are not resilient to workers leaving and joining a job.""", + ) + strategy: Optional[Strategy] = Field( + default=None, + description="""Optional. This determines which type of scheduling strategy to use.""", + ) + timeout: Optional[str] = Field( + default=None, + description="""Optional. The maximum job running time. The default is 7 days.""", + ) + + +class SchedulingDict(TypedDict, total=False): + """All parameters related to queuing and scheduling of custom jobs.""" + + disable_retries: Optional[bool] + """Optional. Indicates if the job should retry for internal errors after the job starts running. If true, overrides `Scheduling.restart_job_on_worker_restart` to false.""" + + max_wait_duration: Optional[str] + """Optional. This is the maximum duration that a job will wait for the requested resources to be provisioned if the scheduling strategy is set to [Strategy.DWS_FLEX_START]. If set to 0, the job will wait indefinitely. The default is 24 hours.""" + + restart_job_on_worker_restart: Optional[bool] + """Optional. 
Restarts the entire CustomJob if a worker gets restarted. This feature can be used by distributed training jobs that are not resilient to workers leaving and joining a job.""" + + strategy: Optional[Strategy] + """Optional. This determines which type of scheduling strategy to use.""" + + timeout: Optional[str] + """Optional. The maximum job running time. The default is 7 days.""" + + +SchedulingOrDict = Union[Scheduling, SchedulingDict] + + +class EnvVar(_common.BaseModel): + """Represents an environment variable present in a Container or Python Module.""" + + name: Optional[str] = Field( + default=None, + description="""Required. Name of the environment variable. Must be a valid C identifier.""", + ) + value: Optional[str] = Field( + default=None, + description="""Required. Variables that reference a $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not.""", + ) + + +class EnvVarDict(TypedDict, total=False): + """Represents an environment variable present in a Container or Python Module.""" + + name: Optional[str] + """Required. Name of the environment variable. Must be a valid C identifier.""" + + value: Optional[str] + """Required. Variables that reference a $(VAR_NAME) are expanded using the previous defined environment variables in the container and any service environment variables. If a variable cannot be resolved, the reference in the input string will be unchanged. The $(VAR_NAME) syntax can be escaped with a double $$, ie: $$(VAR_NAME). Escaped references will never be expanded, regardless of whether the variable exists or not.""" + + +EnvVarOrDict = Union[EnvVar, EnvVarDict] + + +class ContainerSpec(_common.BaseModel): + """The spec of a Container.""" + + args: Optional[list[str]] = Field( + default=None, + description="""The arguments to be passed when starting the container.""", + ) + command: Optional[list[str]] = Field( + default=None, + description="""The command to be invoked when the container is started. It overrides the entrypoint instruction in Dockerfile when provided.""", + ) + env: Optional[list[EnvVar]] = Field( + default=None, + description="""Environment variables to be passed to the container. Maximum limit is 100.""", + ) + image_uri: Optional[str] = Field( + default=None, + description="""Required. The URI of a container image in the Container Registry that is to be run on each worker replica.""", + ) + + +class ContainerSpecDict(TypedDict, total=False): + """The spec of a Container.""" + + args: Optional[list[str]] + """The arguments to be passed when starting the container.""" + + command: Optional[list[str]] + """The command to be invoked when the container is started. It overrides the entrypoint instruction in Dockerfile when provided.""" + + env: Optional[list[EnvVarDict]] + """Environment variables to be passed to the container. Maximum limit is 100.""" + + image_uri: Optional[str] + """Required. 
The URI of a container image in the Container Registry that is to be run on each worker replica.""" + + +ContainerSpecOrDict = Union[ContainerSpec, ContainerSpecDict] + + +class DiskSpec(_common.BaseModel): + """Represents the spec of disk options.""" + + boot_disk_size_gb: Optional[int] = Field( + default=None, description="""Size in GB of the boot disk (default is 100GB).""" + ) + boot_disk_type: Optional[str] = Field( + default=None, + description="""Type of the boot disk. For non-A3U machines, the default value is "pd-ssd", for A3U machines, the default value is "hyperdisk-balanced". Valid values: "pd-ssd" (Persistent Disk Solid State Drive), "pd-standard" (Persistent Disk Hard Disk Drive) or "hyperdisk-balanced".""", + ) + + +class DiskSpecDict(TypedDict, total=False): + """Represents the spec of disk options.""" + + boot_disk_size_gb: Optional[int] + """Size in GB of the boot disk (default is 100GB).""" + + boot_disk_type: Optional[str] + """Type of the boot disk. For non-A3U machines, the default value is "pd-ssd", for A3U machines, the default value is "hyperdisk-balanced". Valid values: "pd-ssd" (Persistent Disk Solid State Drive), "pd-standard" (Persistent Disk Hard Disk Drive) or "hyperdisk-balanced".""" + + +DiskSpecOrDict = Union[DiskSpec, DiskSpecDict] + + +class LustreMount(_common.BaseModel): + """Represents a mount configuration for Lustre file system.""" + + filesystem: Optional[str] = Field( + default=None, description="""Required. The name of the Lustre filesystem.""" + ) + instance_ip: Optional[str] = Field( + default=None, description="""Required. IP address of the Lustre instance.""" + ) + mount_point: Optional[str] = Field( + default=None, + description="""Required. Destination mount path. The Lustre file system will be mounted for the user under /mnt/lustre/""", + ) + volume_handle: Optional[str] = Field( + default=None, + description="""Required. The unique identifier of the Lustre volume.""", + ) + + +class LustreMountDict(TypedDict, total=False): + """Represents a mount configuration for Lustre file system.""" + + filesystem: Optional[str] + """Required. The name of the Lustre filesystem.""" + + instance_ip: Optional[str] + """Required. IP address of the Lustre instance.""" + + mount_point: Optional[str] + """Required. Destination mount path. The Lustre file system will be mounted for the user under /mnt/lustre/""" + + volume_handle: Optional[str] + """Required. The unique identifier of the Lustre volume.""" + + +LustreMountOrDict = Union[LustreMount, LustreMountDict] + + +class ReservationAffinity(_common.BaseModel): + """A ReservationAffinity can be used to configure a Vertex AI resource (e.g., a DeployedModel) to draw its Compute Engine resources from a Shared Reservation, or exclusively from on-demand capacity.""" + + key: Optional[str] = Field( + default=None, + description="""Optional. Corresponds to the label key of a reservation resource. To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key and specify the name of your reservation as its value.""", + ) + reservation_affinity_type: Optional[Type] = Field( + default=None, + description="""Required. Specifies the reservation affinity type.""", + ) + values: Optional[list[str]] = Field( + default=None, + description="""Optional. Corresponds to the label values of a reservation resource. 
This must be the full resource name of the reservation or reservation block.""", + ) + + +class ReservationAffinityDict(TypedDict, total=False): + """A ReservationAffinity can be used to configure a Vertex AI resource (e.g., a DeployedModel) to draw its Compute Engine resources from a Shared Reservation, or exclusively from on-demand capacity.""" + + key: Optional[str] + """Optional. Corresponds to the label key of a reservation resource. To target a SPECIFIC_RESERVATION by name, use `compute.googleapis.com/reservation-name` as the key and specify the name of your reservation as its value.""" + + reservation_affinity_type: Optional[Type] + """Required. Specifies the reservation affinity type.""" + + values: Optional[list[str]] + """Optional. Corresponds to the label values of a reservation resource. This must be the full resource name of the reservation or reservation block.""" + + +ReservationAffinityOrDict = Union[ReservationAffinity, ReservationAffinityDict] + + +class MachineSpec(_common.BaseModel): + """Specification of a single machine.""" + + accelerator_count: Optional[int] = Field( + default=None, + description="""The number of accelerators to attach to the machine. For accelerator optimized machine types (https://cloud.google.com/compute/docs/accelerator-optimized-machines), One may set the accelerator_count from 1 to N for machine with N GPUs. If accelerator_count is less than or equal to N / 2, Vertex will co-schedule the replicas of the model into the same VM to save cost. For example, if the machine type is a3-highgpu-8g, which has 8 H100 GPUs, one can set accelerator_count to 1 to 8. If accelerator_count is 1, 2, 3, or 4, Vertex will co-schedule 8, 4, 2, or 2 replicas of the model into the same VM to save cost. When co-scheduling, CPU, memory and storage on the VM will be distributed to replicas on the VM. For example, one can expect a co-scheduled replica requesting 2 GPUs out of a 8-GPU VM will receive 25% of the CPU, memory and storage of the VM. Note that the feature is not compatible with multihost_gpu_node_count. When multihost_gpu_node_count is set, the co-scheduling will not be enabled.""", + ) + accelerator_type: Optional[AcceleratorType] = Field( + default=None, + description="""Immutable. The type of accelerator(s) that may be attached to the machine as per accelerator_count.""", + ) + gpu_partition_size: Optional[str] = Field( + default=None, + description="""Optional. Immutable. The Nvidia GPU partition size. When specified, the requested accelerators will be partitioned into smaller GPU partitions. For example, if the request is for 8 units of NVIDIA A100 GPUs, and gpu_partition_size="1g.10gb", the service will create 8 * 7 = 56 partitioned MIG instances. The partition size must be a value supported by the requested accelerator. Refer to [Nvidia GPU Partitioning](https://cloud.google.com/kubernetes-engine/docs/how-to/gpus-multi#multi-instance_gpu_partitions) for the available partition sizes. If set, the accelerator_count should be set to 1.""", + ) + machine_type: Optional[str] = Field( + default=None, + description="""Immutable. The type of the machine. See the [list of machine types supported for prediction](https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#machine-types) See the [list of machine types supported for custom training](https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types). For DeployedModel this field is optional, and the default value is `n1-standard-2`. 
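+# A hedged sketch targeting a specific shared reservation, per the key
+# convention above (the `Type.SPECIFIC_RESERVATION` member name and the
+# reservation resource-name format are assumptions):
+affinity = ReservationAffinity(
+    reservation_affinity_type=Type.SPECIFIC_RESERVATION,
+    key="compute.googleapis.com/reservation-name",
+    values=["projects/my-project/zones/us-central1-a/reservations/my-reservation"],
+)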
For BatchPredictionJob or as part of WorkerPoolSpec this field is required.""", + ) + min_gpu_driver_version: Optional[str] = Field( + default=None, + description="""Optional. Immutable. The minimum GPU driver version that this machine requires. For example, "535.104.06". If not specified, the default GPU driver version will be used by the underlying infrastructure.""", + ) + multihost_gpu_node_count: Optional[int] = Field( + default=None, + description="""Optional. Immutable. The number of nodes per replica for multihost GPU deployments.""", + ) + reservation_affinity: Optional[ReservationAffinity] = Field( + default=None, + description="""Optional. Immutable. Configuration controlling how this resource pool consumes reservation.""", + ) + tpu_topology: Optional[str] = Field( + default=None, + description="""Immutable. The topology of the TPUs. Corresponds to the TPU topologies available from GKE. (Example: tpu_topology: "2x2x1").""", + ) + + +class MachineSpecDict(TypedDict, total=False): + """Specification of a single machine.""" + + accelerator_count: Optional[int] + """The number of accelerators to attach to the machine. For accelerator optimized machine types (https://cloud.google.com/compute/docs/accelerator-optimized-machines), One may set the accelerator_count from 1 to N for machine with N GPUs. If accelerator_count is less than or equal to N / 2, Vertex will co-schedule the replicas of the model into the same VM to save cost. For example, if the machine type is a3-highgpu-8g, which has 8 H100 GPUs, one can set accelerator_count to 1 to 8. If accelerator_count is 1, 2, 3, or 4, Vertex will co-schedule 8, 4, 2, or 2 replicas of the model into the same VM to save cost. When co-scheduling, CPU, memory and storage on the VM will be distributed to replicas on the VM. For example, one can expect a co-scheduled replica requesting 2 GPUs out of a 8-GPU VM will receive 25% of the CPU, memory and storage of the VM. Note that the feature is not compatible with multihost_gpu_node_count. When multihost_gpu_node_count is set, the co-scheduling will not be enabled.""" + + accelerator_type: Optional[AcceleratorType] + """Immutable. The type of accelerator(s) that may be attached to the machine as per accelerator_count.""" + + gpu_partition_size: Optional[str] + """Optional. Immutable. The Nvidia GPU partition size. When specified, the requested accelerators will be partitioned into smaller GPU partitions. For example, if the request is for 8 units of NVIDIA A100 GPUs, and gpu_partition_size="1g.10gb", the service will create 8 * 7 = 56 partitioned MIG instances. The partition size must be a value supported by the requested accelerator. Refer to [Nvidia GPU Partitioning](https://cloud.google.com/kubernetes-engine/docs/how-to/gpus-multi#multi-instance_gpu_partitions) for the available partition sizes. If set, the accelerator_count should be set to 1.""" + + machine_type: Optional[str] + """Immutable. The type of the machine. See the [list of machine types supported for prediction](https://cloud.google.com/vertex-ai/docs/predictions/configure-compute#machine-types) See the [list of machine types supported for custom training](https://cloud.google.com/vertex-ai/docs/training/configure-compute#machine-types). For DeployedModel this field is optional, and the default value is `n1-standard-2`. For BatchPredictionJob or as part of WorkerPoolSpec this field is required.""" + + min_gpu_driver_version: Optional[str] + """Optional. Immutable. The minimum GPU driver version that this machine requires. 
For example, "535.104.06". If not specified, the default GPU driver version will be used by the underlying infrastructure.""" + + multihost_gpu_node_count: Optional[int] + """Optional. Immutable. The number of nodes per replica for multihost GPU deployments.""" + + reservation_affinity: Optional[ReservationAffinityDict] + """Optional. Immutable. Configuration controlling how this resource pool consumes reservation.""" + + tpu_topology: Optional[str] + """Immutable. The topology of the TPUs. Corresponds to the TPU topologies available from GKE. (Example: tpu_topology: "2x2x1").""" + + +MachineSpecOrDict = Union[MachineSpec, MachineSpecDict] + + +class NfsMount(_common.BaseModel): + """Represents a mount configuration for Network File System (NFS) to mount.""" + + mount_point: Optional[str] = Field( + default=None, + description="""Required. Destination mount path. The NFS will be mounted for the user under /mnt/nfs/""", + ) + path: Optional[str] = Field( + default=None, + description="""Required. Source path exported from NFS server. Has to start with '/', and combined with the ip address, it indicates the source mount path in the form of `server:path`""", + ) + server: Optional[str] = Field( + default=None, description="""Required. IP address of the NFS server.""" + ) + + +class NfsMountDict(TypedDict, total=False): + """Represents a mount configuration for Network File System (NFS) to mount.""" + + mount_point: Optional[str] + """Required. Destination mount path. The NFS will be mounted for the user under /mnt/nfs/""" + + path: Optional[str] + """Required. Source path exported from NFS server. Has to start with '/', and combined with the ip address, it indicates the source mount path in the form of `server:path`""" + + server: Optional[str] + """Required. IP address of the NFS server.""" + + +NfsMountOrDict = Union[NfsMount, NfsMountDict] + + +class PythonPackageSpec(_common.BaseModel): + """The spec of a Python packaged code.""" + + args: Optional[list[str]] = Field( + default=None, + description="""Command line arguments to be passed to the Python task.""", + ) + env: Optional[list[EnvVar]] = Field( + default=None, + description="""Environment variables to be passed to the python module. Maximum limit is 100.""", + ) + executor_image_uri: Optional[str] = Field( + default=None, + description="""Required. The URI of a container image in Artifact Registry that will run the provided Python package. Vertex AI provides a wide range of executor images with pre-installed packages to meet users' various use cases. See the list of [pre-built containers for training](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers). You must use an image from this list.""", + ) + package_uris: Optional[list[str]] = Field( + default=None, + description="""Required. The Google Cloud Storage location of the Python package files which are the training program and its dependent packages. The maximum number of package URIs is 100.""", + ) + python_module: Optional[str] = Field( + default=None, + description="""Required. The Python module name to run after installing the packages.""", + ) + + +class PythonPackageSpecDict(TypedDict, total=False): + """The spec of a Python packaged code.""" + + args: Optional[list[str]] + """Command line arguments to be passed to the Python task.""" + + env: Optional[list[EnvVarDict]] + """Environment variables to be passed to the python module. Maximum limit is 100.""" + + executor_image_uri: Optional[str] + """Required. 
The URI of a container image in Artifact Registry that will run the provided Python package. Vertex AI provides a wide range of executor images with pre-installed packages to meet users' various use cases. See the list of [pre-built containers for training](https://cloud.google.com/vertex-ai/docs/training/pre-built-containers). You must use an image from this list.""" + + package_uris: Optional[list[str]] + """Required. The Google Cloud Storage location of the Python package files which are the training program and its dependent packages. The maximum number of package URIs is 100.""" + + python_module: Optional[str] + """Required. The Python module name to run after installing the packages.""" + + +PythonPackageSpecOrDict = Union[PythonPackageSpec, PythonPackageSpecDict] + + +class WorkerPoolSpec(_common.BaseModel): + """Represents the spec of a worker pool in a job.""" + + container_spec: Optional[ContainerSpec] = Field( + default=None, description="""The custom container task.""" + ) + disk_spec: Optional[DiskSpec] = Field(default=None, description="""Disk spec.""") + lustre_mounts: Optional[list[LustreMount]] = Field( + default=None, description="""Optional. List of Lustre mounts.""" + ) + machine_spec: Optional[MachineSpec] = Field( + default=None, + description="""Optional. Immutable. The specification of a single machine.""", + ) + nfs_mounts: Optional[list[NfsMount]] = Field( + default=None, description="""Optional. List of NFS mount spec.""" + ) + python_package_spec: Optional[PythonPackageSpec] = Field( + default=None, description="""The Python packaged task.""" + ) + replica_count: Optional[int] = Field( + default=None, + description="""Optional. The number of worker replicas to use for this worker pool.""", + ) + + +class WorkerPoolSpecDict(TypedDict, total=False): + """Represents the spec of a worker pool in a job.""" + + container_spec: Optional[ContainerSpecDict] + """The custom container task.""" + + disk_spec: Optional[DiskSpecDict] + """Disk spec.""" + + lustre_mounts: Optional[list[LustreMountDict]] + """Optional. List of Lustre mounts.""" + + machine_spec: Optional[MachineSpecDict] + """Optional. Immutable. The specification of a single machine.""" + + nfs_mounts: Optional[list[NfsMountDict]] + """Optional. List of NFS mount spec.""" + + python_package_spec: Optional[PythonPackageSpecDict] + """The Python packaged task.""" + + replica_count: Optional[int] + """Optional. The number of worker replicas to use for this worker pool.""" + + +WorkerPoolSpecOrDict = Union[WorkerPoolSpec, WorkerPoolSpecDict] + + +class CustomJobSpec(_common.BaseModel): + """Represents a job that runs custom workloads such as a Docker container or a Python package.""" + + base_output_directory: Optional[genai_types.GcsDestination] = Field( + default=None, + description="""The Cloud Storage location to store the output of this CustomJob or HyperparameterTuningJob. For HyperparameterTuningJob, the baseOutputDirectory of each child CustomJob backing a Trial is set to a subdirectory of name id under its parent HyperparameterTuningJob's baseOutputDirectory. 
The following Vertex AI environment variables will be passed to containers or python modules when this field is set: For CustomJob: * AIP_MODEL_DIR = `<base_output_directory>/model/` * AIP_CHECKPOINT_DIR = `<base_output_directory>/checkpoints/` * AIP_TENSORBOARD_LOG_DIR = `<base_output_directory>/logs/` For CustomJob backing a Trial of HyperparameterTuningJob: * AIP_MODEL_DIR = `<base_output_directory>/<trial_id>/model/` * AIP_CHECKPOINT_DIR = `<base_output_directory>/<trial_id>/checkpoints/` * AIP_TENSORBOARD_LOG_DIR = `<base_output_directory>/<trial_id>/logs/`""", + ) + enable_dashboard_access: Optional[bool] = Field( + default=None, + description="""Optional. Whether you want Vertex AI to enable access to the customized dashboard in training chief container. If set to `true`, you can access the dashboard at the URIs given by CustomJob.web_access_uris or Trial.web_access_uris (within HyperparameterTuningJob.trials).""", + ) + enable_web_access: Optional[bool] = Field( + default=None, + description="""Optional. Whether you want Vertex AI to enable [interactive shell access](https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell) to training containers. If set to `true`, you can access interactive shells at the URIs given by CustomJob.web_access_uris or Trial.web_access_uris (within HyperparameterTuningJob.trials).""", + ) + experiment: Optional[str] = Field( + default=None, + description="""Optional. The Experiment associated with this job. Format: `projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}`""", + ) + experiment_run: Optional[str] = Field( + default=None, + description="""Optional. The Experiment Run associated with this job. Format: `projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}-{experiment-run-name}`""", + ) + models: Optional[list[str]] = Field( + default=None, + description="""Optional. The name of the Model resources for which to generate a mapping to artifact URIs. Applicable only to some of the Google-provided custom jobs. Format: `projects/{project}/locations/{location}/models/{model}` In order to retrieve a specific version of the model, also provide the version ID or version alias. Example: `projects/{project}/locations/{location}/models/{model}@2` or `projects/{project}/locations/{location}/models/{model}@golden` If no version ID or alias is specified, the "default" version will be returned. The "default" version alias is created for the first version of the model, and can be moved to other versions later on. There will be exactly one default version.""", + ) + network: Optional[str] = Field( + default=None, + description="""Optional. The full name of the Compute Engine [network](/compute/docs/networks-and-firewalls#networks) to which the Job should be peered. For example, `projects/12345/global/networks/myVPC`. [Format](/compute/docs/reference/rest/v1/networks/insert) is of the form `projects/{project}/global/networks/{network}`. Where {project} is a project number, as in `12345`, and {network} is a network name. To specify this field, you must have already [configured VPC Network Peering for Vertex AI](https://cloud.google.com/vertex-ai/docs/general/vpc-peering). If this field is left unspecified, the job is not peered with any network.""", + ) + persistent_resource_id: Optional[str] = Field( + default=None, + description="""Optional. The ID of the PersistentResource in the same Project and Location in which to run the job. If this is specified, the job will be run on existing machines held by the PersistentResource instead of on-demand short-lived machines. 
The network and CMEK configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected.""", + ) + protected_artifact_location_id: Optional[str] = Field( + default=None, + description="""The ID of the location to store protected artifacts. e.g. us-central1. Populate only when the location is different from the CustomJob location. List of supported locations: https://cloud.google.com/vertex-ai/docs/general/locations""", + ) + psc_interface_config: Optional[PscInterfaceConfig] = Field( + default=None, description="""Optional. Configuration for PSC-I for CustomJob.""" + ) + reserved_ip_ranges: Optional[list[str]] = Field( + default=None, + description="""Optional. A list of names for the reserved ip ranges under the VPC network that can be used for this job. If set, we will deploy the job within the provided ip ranges. Otherwise, the job will be deployed to any ip ranges under the provided VPC network. Example: ['vertex-ai-ip-range'].""", + ) + scheduling: Optional[Scheduling] = Field( + default=None, description="""Scheduling options for a CustomJob.""" + ) + service_account: Optional[str] = Field( + default=None, + description="""Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. If unspecified, the [Vertex AI Custom Code Service Agent](https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents) for the CustomJob's project is used.""", + ) + tensorboard: Optional[str] = Field( + default=None, + description="""Optional. The name of a Vertex AI Tensorboard resource to which this CustomJob will upload Tensorboard logs. Format: `projects/{project}/locations/{location}/tensorboards/{tensorboard}`""", + ) + worker_pool_specs: Optional[list[WorkerPoolSpec]] = Field( + default=None, + description="""Required. The spec of the worker pools including machine type and Docker image. All worker pools except the first one are optional and can be skipped by providing an empty value.""", + ) + + class CustomJobSpecDict(TypedDict, total=False): + """Represents a job that runs custom workloads such as a Docker container or a Python package.""" + + base_output_directory: Optional[genai_types.GcsDestinationDict] + """The Cloud Storage location to store the output of this CustomJob or HyperparameterTuningJob. For HyperparameterTuningJob, the baseOutputDirectory of each child CustomJob backing a Trial is set to a subdirectory of name id under its parent HyperparameterTuningJob's baseOutputDirectory. The following Vertex AI environment variables will be passed to containers or python modules when this field is set: For CustomJob: * AIP_MODEL_DIR = `<base_output_directory>/model/` * AIP_CHECKPOINT_DIR = `<base_output_directory>/checkpoints/` * AIP_TENSORBOARD_LOG_DIR = `<base_output_directory>/logs/` For CustomJob backing a Trial of HyperparameterTuningJob: * AIP_MODEL_DIR = `<base_output_directory>/<trial_id>/model/` * AIP_CHECKPOINT_DIR = `<base_output_directory>/<trial_id>/checkpoints/` * AIP_TENSORBOARD_LOG_DIR = `<base_output_directory>/<trial_id>/logs/`""" + + enable_dashboard_access: Optional[bool] + """Optional. Whether you want Vertex AI to enable access to the customized dashboard in training chief container. If set to `true`, you can access the dashboard at the URIs given by CustomJob.web_access_uris or Trial.web_access_uris (within HyperparameterTuningJob.trials).""" + + enable_web_access: Optional[bool] + """Optional. Whether you want Vertex AI to enable [interactive shell access](https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell) to training containers. 
If set to `true`, you can access interactive shells at the URIs given by CustomJob.web_access_uris or Trial.web_access_uris (within HyperparameterTuningJob.trials).""" + + experiment: Optional[str] + """Optional. The Experiment associated with this job. Format: `projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}`""" + + experiment_run: Optional[str] + """Optional. The Experiment Run associated with this job. Format: `projects/{project}/locations/{location}/metadataStores/{metadataStores}/contexts/{experiment-name}-{experiment-run-name}`""" + + models: Optional[list[str]] + """Optional. The name of the Model resources for which to generate a mapping to artifact URIs. Applicable only to some of the Google-provided custom jobs. Format: `projects/{project}/locations/{location}/models/{model}` In order to retrieve a specific version of the model, also provide the version ID or version alias. Example: `projects/{project}/locations/{location}/models/{model}@2` or `projects/{project}/locations/{location}/models/{model}@golden` If no version ID or alias is specified, the "default" version will be returned. The "default" version alias is created for the first version of the model, and can be moved to other versions later on. There will be exactly one default version.""" + + network: Optional[str] + """Optional. The full name of the Compute Engine [network](/compute/docs/networks-and-firewalls#networks) to which the Job should be peered. For example, `projects/12345/global/networks/myVPC`. [Format](/compute/docs/reference/rest/v1/networks/insert) is of the form `projects/{project}/global/networks/{network}`. Where {project} is a project number, as in `12345`, and {network} is a network name. To specify this field, you must have already [configured VPC Network Peering for Vertex AI](https://cloud.google.com/vertex-ai/docs/general/vpc-peering). If this field is left unspecified, the job is not peered with any network.""" + + persistent_resource_id: Optional[str] + """Optional. The ID of the PersistentResource in the same Project and Location in which to run the job. If this is specified, the job will be run on existing machines held by the PersistentResource instead of on-demand short-lived machines. The network and CMEK configs on the job should be consistent with those on the PersistentResource, otherwise, the job will be rejected.""" + + protected_artifact_location_id: Optional[str] + """The ID of the location to store protected artifacts. e.g. us-central1. Populate only when the location is different from the CustomJob location. List of supported locations: https://cloud.google.com/vertex-ai/docs/general/locations""" + + psc_interface_config: Optional[PscInterfaceConfigDict] + """Optional. Configuration for PSC-I for CustomJob.""" + + reserved_ip_ranges: Optional[list[str]] + """Optional. A list of names for the reserved ip ranges under the VPC network that can be used for this job. If set, we will deploy the job within the provided ip ranges. Otherwise, the job will be deployed to any ip ranges under the provided VPC network. Example: ['vertex-ai-ip-range'].""" + + scheduling: Optional[SchedulingDict] + """Scheduling options for a CustomJob.""" + + service_account: Optional[str] + """Specifies the service account for workload run-as account. Users submitting jobs must have act-as permission on this run-as account. 
If unspecified, the [Vertex AI Custom Code Service Agent](https://cloud.google.com/vertex-ai/docs/general/access-control#service-agents) for the CustomJob's project is used.""" + + tensorboard: Optional[str] + """Optional. The name of a Vertex AI Tensorboard resource to which this CustomJob will upload Tensorboard logs. Format: `projects/{project}/locations/{location}/tensorboards/{tensorboard}`""" + + worker_pool_specs: Optional[list[WorkerPoolSpecDict]] + """Required. The spec of the worker pools including machine type and Docker image. All worker pools except the first one are optional and can be skipped by providing an empty value.""" + + +CustomJobSpecOrDict = Union[CustomJobSpec, CustomJobSpecDict] + + +class CustomJob(_common.BaseModel): + """Represents a job that runs custom workloads such as a Docker container or a Python package.""" + + display_name: Optional[str] = Field( + default=None, + description="""Required. The display name of the CustomJob. The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + job_spec: Optional[CustomJobSpec] = Field( + default=None, description="""Required. Job spec.""" + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key options for a CustomJob. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key.""", + ) + state: Optional[genai_types.JobState] = Field( + default=None, description="""Output only. The detailed state of the job.""" + ) + error: Optional[genai_types.GoogleRpcStatus] = Field( + default=None, + description="""Output only. Only populated when job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the CustomJob was created.""", + ) + end_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the CustomJob entered any of the following states: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""The labels with user-defined metadata to organize CustomJobs. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) + name: Optional[str] = Field( + default=None, description="""Output only. Resource name of a CustomJob.""" + ) + satisfies_pzi: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + satisfies_pzs: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + start_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the CustomJob for the first time entered the `JOB_STATE_RUNNING` state.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Time when the CustomJob was most recently updated.""", + ) + web_access_uris: Optional[dict[str, str]] = Field( + default=None, + description="""Output only. URIs for accessing [interactive shells](https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell) (one URI for each training node). Only available if job_spec.enable_web_access is `true`. 
The keys are names of each node in the training job; for example, `workerpool0-0` for the primary node, `workerpool1-0` for the first node in the second worker pool, and `workerpool1-1` for the second node in the second worker pool. The values are the URIs for each node's interactive shell.""", + ) + + class CustomJobDict(TypedDict, total=False): + """Represents a job that runs custom workloads such as a Docker container or a Python package.""" + + display_name: Optional[str] + """Required. The display name of the CustomJob. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + job_spec: Optional[CustomJobSpecDict] + """Required. Job spec.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key options for a CustomJob. If this is set, then all resources created by the CustomJob will be encrypted with the provided encryption key.""" + + state: Optional[genai_types.JobState] + """Output only. The detailed state of the job.""" + + error: Optional[genai_types.GoogleRpcStatusDict] + """Output only. Only populated when job's state is `JOB_STATE_FAILED` or `JOB_STATE_CANCELLED`.""" + + create_time: Optional[datetime.datetime] + """Output only. Time when the CustomJob was created.""" + + end_time: Optional[datetime.datetime] + """Output only. Time when the CustomJob entered any of the following states: `JOB_STATE_SUCCEEDED`, `JOB_STATE_FAILED`, `JOB_STATE_CANCELLED`.""" + + labels: Optional[dict[str, str]] + """The labels with user-defined metadata to organize CustomJobs. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + name: Optional[str] + """Output only. Resource name of a CustomJob.""" + + satisfies_pzi: Optional[bool] + """Output only. Reserved for future use.""" + + satisfies_pzs: Optional[bool] + """Output only. Reserved for future use.""" + + start_time: Optional[datetime.datetime] + """Output only. Time when the CustomJob for the first time entered the `JOB_STATE_RUNNING` state.""" + + update_time: Optional[datetime.datetime] + """Output only. Time when the CustomJob was most recently updated.""" + + web_access_uris: Optional[dict[str, str]] + """Output only. URIs for accessing [interactive shells](https://cloud.google.com/vertex-ai/docs/training/monitor-debug-interactive-shell) (one URI for each training node). Only available if job_spec.enable_web_access is `true`. The keys are names of each node in the training job; for example, `workerpool0-0` for the primary node, `workerpool1-0` for the first node in the second worker pool, and `workerpool1-1` for the second node in the second worker pool. The values are the URIs for each node's interactive shell.""" + + CustomJobOrDict = Union[CustomJob, CustomJobDict]
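+ # Illustrative sketch (not part of the generated surface): the models above
+ # compose into a full job description. All resource names and field values
+ # below are hypothetical placeholders.
+ #
+ #   job = CustomJob(
+ #       display_name="example-training-job",
+ #       job_spec=CustomJobSpec(
+ #           worker_pool_specs=[
+ #               WorkerPoolSpec(
+ #                   replica_count=1,
+ #                   machine_spec=MachineSpec(machine_type="n1-standard-4"),
+ #                   python_package_spec=PythonPackageSpec(
+ #                       executor_image_uri="us-docker.pkg.dev/.../training/...",
+ #                       package_uris=["gs://example-bucket/trainer-0.1.tar.gz"],
+ #                       python_module="trainer.task",
+ #                   ),
+ #               )
+ #           ],
+ #       ),
+ #   )
+ #
+ # Because every model has a TypedDict twin joined by an `OrDict` alias, the
+ # same job can equivalently be passed as a plain nested dict.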
+ class VertexBaseConfig(_common.BaseModel): + """Base config for Vertex AI.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + class VertexBaseConfigDict(TypedDict, total=False): + """Base config for Vertex AI.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + VertexBaseConfigOrDict = Union[VertexBaseConfig, VertexBaseConfigDict] + + class _CustomJobParameters(_common.BaseModel): + """Parameters for creating a custom job.""" + + custom_job: Optional[CustomJob] = Field(default=None, description="""""") + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + class _CustomJobParametersDict(TypedDict, total=False): + """Parameters for creating a custom job.""" + + custom_job: Optional[CustomJobDict] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + _CustomJobParametersOrDict = Union[_CustomJobParameters, _CustomJobParametersDict] + + class _GetCustomJobParameters(_common.BaseModel): + """Parameters for getting a custom job.""" + + name: Optional[str] = Field(default=None, description="""""") + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + class _GetCustomJobParametersDict(TypedDict, total=False): + """Parameters for getting a custom job.""" + + name: Optional[str] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + _GetCustomJobParametersOrDict = Union[ + _GetCustomJobParameters, _GetCustomJobParametersDict +] + + class CancelQueryJobAgentEngineConfig(_common.BaseModel): + """Config for canceling async querying agent engines.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + operation_name: Optional[str] = Field( + default=None, + description="""Name of the longrunning operation returned from run_query_job.""", + ) + + class CancelQueryJobAgentEngineConfigDict(TypedDict, total=False): + """Config for canceling async querying agent engines.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + operation_name: Optional[str] + """Name of the longrunning operation returned from run_query_job.""" + + CancelQueryJobAgentEngineConfigOrDict = Union[ + CancelQueryJobAgentEngineConfig, CancelQueryJobAgentEngineConfigDict +] + + class _CancelQueryJobAgentEngineRequestParameters(_common.BaseModel): + """Parameters for canceling async querying agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the reasoning engine resource.""" + ) + config: Optional[CancelQueryJobAgentEngineConfig] = Field( + default=None, description="""""" + ) + + class _CancelQueryJobAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for canceling async querying agent engines.""" + + name: Optional[str] + """Name of the reasoning engine resource.""" + + config: Optional[CancelQueryJobAgentEngineConfigDict] + """""" + + _CancelQueryJobAgentEngineRequestParametersOrDict = Union[ + _CancelQueryJobAgentEngineRequestParameters, + _CancelQueryJobAgentEngineRequestParametersDict, +]
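+ # Illustrative sketch: canceling a previously started query job. The
+ # operation name is a hypothetical placeholder for the value returned by
+ # run_query_job.
+ #
+ #   cancel_config = CancelQueryJobAgentEngineConfig(
+ #       operation_name="projects/.../locations/.../operations/...",
+ #   )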
+ class CancelQueryJobResult(_common.BaseModel): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + class CancelQueryJobResultDict(TypedDict, total=False): + """Result of canceling a query job.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + CancelQueryJobResultOrDict = Union[CancelQueryJobResult, CancelQueryJobResultDict] + + class CheckQueryJobAgentEngineConfig(_common.BaseModel): + """Config for checking async query jobs on agent engines.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + retrieve_result: Optional[bool] = Field( + default=None, + description="""Whether to retrieve the results of the query job.""", + ) + + class CheckQueryJobAgentEngineConfigDict(TypedDict, total=False): + """Config for checking async query jobs on agent engines.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + retrieve_result: Optional[bool] + """Whether to retrieve the results of the query job.""" + + CheckQueryJobAgentEngineConfigOrDict = Union[ + CheckQueryJobAgentEngineConfig, CheckQueryJobAgentEngineConfigDict +] + + class _CheckQueryJobAgentEngineRequestParameters(_common.BaseModel): + """Parameters for checking async query jobs on agent engines.""" + + name: Optional[str] = Field(default=None, description="""Name of the query job.""") + config: Optional[CheckQueryJobAgentEngineConfig] = Field( + default=None, description="""""" + ) + + class _CheckQueryJobAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for checking async query jobs on agent engines.""" + + name: Optional[str] + """Name of the query job.""" + + config: Optional[CheckQueryJobAgentEngineConfigDict] + """""" + + _CheckQueryJobAgentEngineRequestParametersOrDict = Union[ + _CheckQueryJobAgentEngineRequestParameters, + _CheckQueryJobAgentEngineRequestParametersDict, +] + + class CheckQueryJobResult(_common.BaseModel): + """Result of checking a query job.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + operation_name: Optional[str] = Field( + default=None, description="""Name of the agent engine operation.""" + ) + output_gcs_uri: Optional[str] = Field( + default=None, description="""The GCS URI of the output file.""" + ) + status: Optional[str] = Field( + default=None, description="""Status of the operation.""" + ) + result: Optional[str] = Field( + default=None, description="""JSON result of the operation.""" + ) + + class CheckQueryJobResultDict(TypedDict, total=False): + """Result of checking a query job.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + operation_name: Optional[str] + """Name of the agent engine operation.""" + + output_gcs_uri: Optional[str] + """The GCS URI of the output file.""" + + status: Optional[str] + """Status of the operation.""" + + result: Optional[str] + """JSON result of the operation.""" + + CheckQueryJobResultOrDict = Union[CheckQueryJobResult, CheckQueryJobResultDict] + + class _RunQueryJobAgentEngineConfig(_common.BaseModel): + """Config for running a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, 
description="""Used to override HTTP request options.""" + ) + input_gcs_uri: Optional[str] = Field( + default=None, description="""The GCS URI of the input file.""" + ) + output_gcs_uri: Optional[str] = Field( + default=None, description="""The GCS URI of the output file.""" + ) + + +class _RunQueryJobAgentEngineConfigDict(TypedDict, total=False): + """Config for running a query job on an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + input_gcs_uri: Optional[str] + """The GCS URI of the input file.""" + + output_gcs_uri: Optional[str] + """The GCS URI of the output file.""" + + +_RunQueryJobAgentEngineConfigOrDict = Union[ + _RunQueryJobAgentEngineConfig, _RunQueryJobAgentEngineConfigDict +] + + +class _RunQueryJobAgentEngineRequestParameters(_common.BaseModel): + """Parameters for running a query job on an agent engine.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[_RunQueryJobAgentEngineConfig] = Field( + default=None, description="""""" + ) + + +class _RunQueryJobAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for running a query job on an agent engine.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[_RunQueryJobAgentEngineConfigDict] + """""" + + +_RunQueryJobAgentEngineRequestParametersOrDict = Union[ + _RunQueryJobAgentEngineRequestParameters, + _RunQueryJobAgentEngineRequestParametersDict, +] + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEvent( + _common.BaseModel +): + """The conversation source event for generating memories.""" + + content: Optional[genai_types.Content] = Field( + default=None, description="""Required. Represents the content of the event.""" + ) + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventDict( + TypedDict, total=False +): + """The conversation source event for generating memories.""" + + content: Optional[genai_types.ContentDict] + """Required. Represents the content of the event.""" + + +MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventOrDict = ( + Union[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEvent, + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventDict, + ] +) + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSource( + _common.BaseModel +): + """A conversation source for the example. This is similar to `DirectContentsSource`.""" + + events: Optional[ + list[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEvent + ] + ] = Field( + default=None, + description="""Optional. Represents the input conversation events for the example.""", + ) + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict( + TypedDict, total=False +): + """A conversation source for the example. This is similar to `DirectContentsSource`.""" + + events: Optional[ + list[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceEventDict + ] + ] + """Optional. 
Represents the input conversation events for the example.""" + + +MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceOrDict = Union[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSource, + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict, +] + + +class MemoryTopicId(_common.BaseModel): + """The topic ID for a memory.""" + + custom_memory_topic_label: Optional[str] = Field( + default=None, + description="""Optional. Represents the custom memory topic label.""", + ) + managed_memory_topic: Optional[ManagedTopicEnum] = Field( + default=None, description="""Optional. Represents the managed memory topic.""" + ) + + +class MemoryTopicIdDict(TypedDict, total=False): + """The topic ID for a memory.""" + + custom_memory_topic_label: Optional[str] + """Optional. Represents the custom memory topic label.""" + + managed_memory_topic: Optional[ManagedTopicEnum] + """Optional. Represents the managed memory topic.""" + + +MemoryTopicIdOrDict = Union[MemoryTopicId, MemoryTopicIdDict] + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory( + _common.BaseModel +): + """A memory generated by the operation.""" + + fact: Optional[str] = Field( + default=None, + description="""Required. Represents the fact to generate a memory from.""", + ) + topics: Optional[list[MemoryTopicId]] = Field( + default=None, + description="""Optional. Represents the list of topics that the memory should be associated with. For example, use `custom_memory_topic_label = "jargon"` if the extracted memory is an example of memory extraction for the custom topic `jargon`.""", + ) + + +class MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict( + TypedDict, total=False +): + """A memory generated by the operation.""" + + fact: Optional[str] + """Required. Represents the fact to generate a memory from.""" + + topics: Optional[list[MemoryTopicIdDict]] + """Optional. Represents the list of topics that the memory should be associated with. For example, use `custom_memory_topic_label = "jargon"` if the extracted memory is an example of memory extraction for the custom topic `jargon`.""" + + +MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryOrDict = Union[ + MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory, + MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict, +] + + +class MemoryBankCustomizationConfigGenerateMemoriesExample(_common.BaseModel): + """An example of how to generate memories for a particular scope.""" + + conversation_source: Optional[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSource + ] = Field(default=None, description="""A conversation source for the example.""") + generated_memories: Optional[ + list[MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemory] + ] = Field( + default=None, + description="""Optional. Represents the memories that are expected to be generated from the input conversation. 
An empty list indicates that no memories are expected to be generated for the input conversation.""", + ) + + class MemoryBankCustomizationConfigGenerateMemoriesExampleDict(TypedDict, total=False): + """An example of how to generate memories for a particular scope.""" + + conversation_source: Optional[ + MemoryBankCustomizationConfigGenerateMemoriesExampleConversationSourceDict + ] + """A conversation source for the example.""" + + generated_memories: Optional[ + list[MemoryBankCustomizationConfigGenerateMemoriesExampleGeneratedMemoryDict] + ] + """Optional. Represents the memories that are expected to be generated from the input conversation. An empty list indicates that no memories are expected to be generated for the input conversation.""" + + MemoryBankCustomizationConfigGenerateMemoriesExampleOrDict = Union[ + MemoryBankCustomizationConfigGenerateMemoriesExample, + MemoryBankCustomizationConfigGenerateMemoriesExampleDict, +] + + class MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic(_common.BaseModel): + """A custom memory topic defined by the developer.""" + + label: Optional[str] = Field( + default=None, description="""Required. Represents the label of the topic.""" + ) + description: Optional[str] = Field( + default=None, + description="""Required. Represents the description of the memory topic. This should explain what information should be extracted for this topic.""", + ) + + class MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicDict( + TypedDict, total=False +): + """A custom memory topic defined by the developer.""" + + label: Optional[str] + """Required. Represents the label of the topic.""" + + description: Optional[str] + """Required. Represents the description of the memory topic. This should explain what information should be extracted for this topic.""" + + MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicOrDict = Union[ + MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic, + MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicDict, +] + + class MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopic(_common.BaseModel): + """A managed memory topic defined by the system.""" + + managed_topic_enum: Optional[ManagedTopicEnum] = Field( + default=None, description="""Required. Represents the managed topic.""" + ) + + class MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicDict( + TypedDict, total=False +): + """A managed memory topic defined by the system.""" + + managed_topic_enum: Optional[ManagedTopicEnum] + """Required. Represents the managed topic.""" + + MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicOrDict = Union[ + MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopic, + MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicDict, +]
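+ # Illustrative sketch: a developer-defined topic (the label and description
+ # are hypothetical; "jargon" echoes the example in the GeneratedMemory docs):
+ #
+ #   custom_topic = MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic(
+ #       label="jargon",
+ #       description="Project-specific terminology used by the user.",
+ #   )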
+ class MemoryBankCustomizationConfigMemoryTopic(_common.BaseModel): + """A topic of information that should be extracted from conversations and stored as memories.""" + + custom_memory_topic: Optional[ + MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopic + ] = Field( + default=None, description="""A custom memory topic defined by the developer.""" + ) + managed_memory_topic: Optional[ + MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopic + ] = Field( + default=None, description="""A managed memory topic defined by Memory Bank.""" + ) + + class MemoryBankCustomizationConfigMemoryTopicDict(TypedDict, total=False): + """A topic of information that should be extracted from conversations and stored as memories.""" + + custom_memory_topic: Optional[ + MemoryBankCustomizationConfigMemoryTopicCustomMemoryTopicDict + ] + """A custom memory topic defined by the developer.""" + + managed_memory_topic: Optional[ + MemoryBankCustomizationConfigMemoryTopicManagedMemoryTopicDict + ] + """A managed memory topic defined by Memory Bank.""" + + MemoryBankCustomizationConfigMemoryTopicOrDict = Union[ + MemoryBankCustomizationConfigMemoryTopic, + MemoryBankCustomizationConfigMemoryTopicDict, +] + + class MemoryBankCustomizationConfigConsolidationConfig(_common.BaseModel): + """Represents configuration for customizing how memories are consolidated.""" + + revisions_per_candidate_count: Optional[int] = Field( + default=None, + description="""Optional. Represents the maximum number of revisions to consider for each candidate memory. If not set, then the default value (1) will be used, which means that only the latest revision will be considered.""", + ) + + class MemoryBankCustomizationConfigConsolidationConfigDict(TypedDict, total=False): + """Represents configuration for customizing how memories are consolidated.""" + + revisions_per_candidate_count: Optional[int] + """Optional. Represents the maximum number of revisions to consider for each candidate memory. If not set, then the default value (1) will be used, which means that only the latest revision will be considered.""" + + MemoryBankCustomizationConfigConsolidationConfigOrDict = Union[ + MemoryBankCustomizationConfigConsolidationConfig, + MemoryBankCustomizationConfigConsolidationConfigDict, +] + + class MemoryBankCustomizationConfig(_common.BaseModel): + """Represents configuration for organizing natural language memories for a particular scope.""" + + enable_third_person_memories: Optional[bool] = Field( + default=None, + description="""Optional. Indicates whether the memories will be generated in the third person (i.e. "The user generates memories with Memory Bank."). By default, the memories will be generated in the first person (i.e. "I generate memories with Memory Bank.")""", + ) + generate_memories_examples: Optional[ + list[MemoryBankCustomizationConfigGenerateMemoriesExample] + ] = Field( + default=None, + description="""Optional. Provides examples of how to generate memories for a particular scope.""", + ) + memory_topics: Optional[list[MemoryBankCustomizationConfigMemoryTopic]] = Field( + default=None, + description="""Optional. Represents topics of information that should be extracted from conversations and stored as memories. 
If not set, then Memory Bank's default topics will be used.""", + ) + scope_keys: Optional[list[str]] = Field( + default=None, + description="""Optional. Represents the scope keys (i.e. 'user_id') for which to use this config. A request's scope must include all of the provided keys for the config to be used (order does not matter). If empty, then the config will be used for all requests that do not have a more specific config. Only one default config is allowed per Memory Bank.""", + ) + consolidation_config: Optional[MemoryBankCustomizationConfigConsolidationConfig] = ( + Field( + default=None, + description="""Optional. Represents configuration for customizing how memories are consolidated together.""", + ) + ) + disable_natural_language_memories: Optional[bool] = Field( + default=None, + description="""Optional. Indicates whether natural language memory generation should be disabled for all requests. By default, natural language memory generation is enabled. Set this to `true` when you only want to generate structured memories.""", + ) + + +class MemoryBankCustomizationConfigDict(TypedDict, total=False): + """Represents configuration for organizing natural language memories for a particular scope.""" + + enable_third_person_memories: Optional[bool] + """Optional. Indicates whether the memories will be generated in the third person (i.e. "The user generates memories with Memory Bank."). By default, the memories will be generated in the first person (i.e. "I generate memories with Memory Bank.")""" + + generate_memories_examples: Optional[ + list[MemoryBankCustomizationConfigGenerateMemoriesExampleDict] + ] + """Optional. Provides examples of how to generate memories for a particular scope.""" + + memory_topics: Optional[list[MemoryBankCustomizationConfigMemoryTopicDict]] + """Optional. Represents topics of information that should be extracted from conversations and stored as memories. If not set, then Memory Bank's default topics will be used.""" + + scope_keys: Optional[list[str]] + """Optional. Represents the scope keys (i.e. 'user_id') for which to use this config. A request's scope must include all of the provided keys for the config to be used (order does not matter). If empty, then the config will be used for all requests that do not have a more specific config. Only one default config is allowed per Memory Bank.""" + + consolidation_config: Optional[MemoryBankCustomizationConfigConsolidationConfigDict] + """Optional. Represents configuration for customizing how memories are consolidated together.""" + + disable_natural_language_memories: Optional[bool] + """Optional. Indicates whether natural language memory generation should be disabled for all requests. By default, natural language memory generation is enabled. Set this to `true` when you only want to generate structured memories.""" + + +MemoryBankCustomizationConfigOrDict = Union[ + MemoryBankCustomizationConfig, MemoryBankCustomizationConfigDict +] + + +class MemoryGenerationTriggerConfigGenerationTriggerRule(_common.BaseModel): + """Represents the active rule that determines when to flush the buffer.""" + + event_count: Optional[int] = Field( + default=None, + description="""Specifies to trigger generation when the event count reaches this limit.""", + ) + fixed_interval: Optional[str] = Field( + default=None, + description="""Specifies to trigger generation at a fixed interval. 
The duration must have a minute-level granularity.""", + ) + idle_duration: Optional[str] = Field( + default=None, + description="""Specifies to trigger generation if the stream is inactive for the specified duration after the most recent event. The duration must have a minute-level granularity.""", + ) + + class MemoryGenerationTriggerConfigGenerationTriggerRuleDict(TypedDict, total=False): + """Represents the active rule that determines when to flush the buffer.""" + + event_count: Optional[int] + """Specifies to trigger generation when the event count reaches this limit.""" + + fixed_interval: Optional[str] + """Specifies to trigger generation at a fixed interval. The duration must have a minute-level granularity.""" + + idle_duration: Optional[str] + """Specifies to trigger generation if the stream is inactive for the specified duration after the most recent event. The duration must have a minute-level granularity.""" + + MemoryGenerationTriggerConfigGenerationTriggerRuleOrDict = Union[ + MemoryGenerationTriggerConfigGenerationTriggerRule, + MemoryGenerationTriggerConfigGenerationTriggerRuleDict, +] + + class MemoryGenerationTriggerConfig(_common.BaseModel): + """The configuration for triggering memory generation for ingested events.""" + + generation_rule: Optional[MemoryGenerationTriggerConfigGenerationTriggerRule] = ( + Field( + default=None, + description="""Optional. Represents the active rule that determines when to flush the buffer. If not set, then the stream will be force flushed immediately.""", + ) + ) + + class MemoryGenerationTriggerConfigDict(TypedDict, total=False): + """The configuration for triggering memory generation for ingested events.""" + + generation_rule: Optional[MemoryGenerationTriggerConfigGenerationTriggerRuleDict] + """Optional. Represents the active rule that determines when to flush the buffer. If not set, then the stream will be force flushed immediately.""" + + MemoryGenerationTriggerConfigOrDict = Union[ + MemoryGenerationTriggerConfig, MemoryGenerationTriggerConfigDict +] + + class ReasoningEngineContextSpecMemoryBankConfigGenerationConfig(_common.BaseModel): + """Configuration for how to generate memories.""" + + model: Optional[str] = Field( + default=None, + description="""Optional. The model used to generate memories. Format: `projects/{project}/locations/{location}/publishers/google/models/{model}`.""", + ) + generation_trigger_config: Optional[MemoryGenerationTriggerConfig] = Field( + default=None, + description="""Optional. Specifies the default trigger configuration for generating memories using `IngestEvents`.""", + ) + + class ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict( + TypedDict, total=False +): + """Configuration for how to generate memories.""" + + model: Optional[str] + """Optional. The model used to generate memories. Format: `projects/{project}/locations/{location}/publishers/google/models/{model}`.""" + + generation_trigger_config: Optional[MemoryGenerationTriggerConfigDict] + """Optional. Specifies the default trigger configuration for generating memories using `IngestEvents`.""" + + ReasoningEngineContextSpecMemoryBankConfigGenerationConfigOrDict = Union[ + ReasoningEngineContextSpecMemoryBankConfigGenerationConfig, + ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict, +]
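+ # Illustrative sketch: a generation config that buffers ingested events and
+ # flushes once 20 events accumulate (the model name and threshold are
+ # hypothetical):
+ #
+ #   generation_config = ReasoningEngineContextSpecMemoryBankConfigGenerationConfig(
+ #       model="projects/.../locations/.../publishers/google/models/...",
+ #       generation_trigger_config=MemoryGenerationTriggerConfig(
+ #           generation_rule=MemoryGenerationTriggerConfigGenerationTriggerRule(
+ #               event_count=20,
+ #           ),
+ #       ),
+ #   )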
+ class ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig( + _common.BaseModel +): + """Configuration for how to perform similarity search on memories.""" + + embedding_model: Optional[str] = Field( + default=None, + description="""Required. The model used to generate embeddings to look up similar memories. Format: `projects/{project}/locations/{location}/publishers/google/models/{model}`.""", + ) + + class ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigDict( + TypedDict, total=False +): + """Configuration for how to perform similarity search on memories.""" + + embedding_model: Optional[str] + """Required. The model used to generate embeddings to look up similar memories. Format: `projects/{project}/locations/{location}/publishers/google/models/{model}`.""" + + ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigOrDict = Union[ + ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig, + ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigDict, +] + + class ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfig( + _common.BaseModel +): + """Configuration for TTL of the memories in the Memory Bank based on the action that created or updated the memory.""" + + create_ttl: Optional[str] = Field( + default=None, + description="""Optional. The TTL duration for memories uploaded via CreateMemory.""", + ) + generate_created_ttl: Optional[str] = Field( + default=None, + description="""Optional. The TTL duration for memories newly generated via GenerateMemories (GenerateMemoriesResponse.GeneratedMemory.Action.CREATED).""", + ) + generate_updated_ttl: Optional[str] = Field( + default=None, + description="""Optional. The TTL duration for memories updated via GenerateMemories (GenerateMemoriesResponse.GeneratedMemory.Action.UPDATED). In the case of an UPDATE action, the `expire_time` of the existing memory will be updated to the new value (now + TTL).""", + ) + + class ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigDict( + TypedDict, total=False +): + """Configuration for TTL of the memories in the Memory Bank based on the action that created or updated the memory.""" + + create_ttl: Optional[str] + """Optional. The TTL duration for memories uploaded via CreateMemory.""" + + generate_created_ttl: Optional[str] + """Optional. The TTL duration for memories newly generated via GenerateMemories (GenerateMemoriesResponse.GeneratedMemory.Action.CREATED).""" + + generate_updated_ttl: Optional[str] + """Optional. The TTL duration for memories updated via GenerateMemories (GenerateMemoriesResponse.GeneratedMemory.Action.UPDATED). 
In the case of an UPDATE action, the `expire_time` of the existing memory will be updated to the new value (now + TTL).""" + + ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigOrDict = Union[ + ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfig, + ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigDict, +] + + class ReasoningEngineContextSpecMemoryBankConfigTtlConfig(_common.BaseModel): + """Configuration for automatically setting the TTL ("time-to-live") of the memories in the Memory Bank.""" + + default_ttl: Optional[str] = Field( + default=None, + description="""Optional. The default TTL duration of the memories in the Memory Bank. This applies to all operations that create or update a memory.""", + ) + granular_ttl_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfig + ] = Field( + default=None, + description="""Optional. The granular TTL configuration of the memories in the Memory Bank.""", + ) + memory_revision_default_ttl: Optional[str] = Field( + default=None, + description="""Optional. The default TTL duration of the memory revisions in the Memory Bank. This applies to all operations that create a memory revision. If not set, a default TTL of 365 days will be used.""", + ) + + class ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict(TypedDict, total=False): + """Configuration for automatically setting the TTL ("time-to-live") of the memories in the Memory Bank.""" + + default_ttl: Optional[str] + """Optional. The default TTL duration of the memories in the Memory Bank. This applies to all operations that create or update a memory.""" + + granular_ttl_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigTtlConfigGranularTtlConfigDict + ] + """Optional. The granular TTL configuration of the memories in the Memory Bank.""" + + memory_revision_default_ttl: Optional[str] + """Optional. The default TTL duration of the memory revisions in the Memory Bank. This applies to all operations that create a memory revision. If not set, a default TTL of 365 days will be used.""" + + ReasoningEngineContextSpecMemoryBankConfigTtlConfigOrDict = Union[ + ReasoningEngineContextSpecMemoryBankConfigTtlConfig, + ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict, +] + + class StructuredMemorySchemaConfig(_common.BaseModel): + """Represents the OpenAPI schema of the structured memories.""" + + memory_schema: Optional[genai_types.Schema] = Field( + default=None, + description="""Required. Represents the OpenAPI schema of the structured memories.""", + ) + id: Optional[str] = Field( + default=None, + description="""Required. Represents the ID of the schema. Must be 1-63 characters, start with a lowercase letter, and consist of lowercase letters, numbers, and hyphens.""", + ) + memory_type: Optional[MemoryType] = Field( + default=None, + description="""Optional. Represents the type of the structured memories associated with the schema. If not set, then `STRUCTURED_PROFILE` will be used.""", + ) + + class StructuredMemorySchemaConfigDict(TypedDict, total=False): + """Represents the OpenAPI schema of the structured memories.""" + + memory_schema: Optional[genai_types.SchemaDict] + """Required. Represents the OpenAPI schema of the structured memories.""" + + id: Optional[str] + """Required. Represents the ID of the schema. Must be 1-63 characters, start with a lowercase letter, and consist of lowercase letters, numbers, and hyphens.""" + + memory_type: Optional[MemoryType] + """Optional. Represents the type of the structured memories associated with the schema. If not set, then `STRUCTURED_PROFILE` will be used.""" + + StructuredMemorySchemaConfigOrDict = Union[ + StructuredMemorySchemaConfig, StructuredMemorySchemaConfigDict +]
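+ # Illustrative sketch: registering a schema for structured memories (the id
+ # and schema contents are hypothetical):
+ #
+ #   schema_config = StructuredMemorySchemaConfig(
+ #       id="user-profile",
+ #       memory_schema=genai_types.Schema(
+ #           type="OBJECT",
+ #           properties={"display_name": genai_types.Schema(type="STRING")},
+ #       ),
+ #   )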
+ class StructuredMemoryConfig(_common.BaseModel): + """Configuration for organizing structured memories within a scope.""" + + schema_configs: Optional[list[StructuredMemorySchemaConfig]] = Field( + default=None, + description="""Optional. Represents configuration of the structured memories' schemas.""", + ) + scope_keys: Optional[list[str]] = Field( + default=None, + description="""Optional. Represents the scope keys (i.e. 'user_id') for which to use this config. A request's scope must include all of the provided keys for the config to be used (order does not matter). If empty, then the config will be used for all requests that do not have a more specific config. Only one default config is allowed per Memory Bank.""", + ) + + class StructuredMemoryConfigDict(TypedDict, total=False): + """Configuration for organizing structured memories within a scope.""" + + schema_configs: Optional[list[StructuredMemorySchemaConfigDict]] + """Optional. Represents configuration of the structured memories' schemas.""" + + scope_keys: Optional[list[str]] + """Optional. Represents the scope keys (i.e. 'user_id') for which to use this config. A request's scope must include all of the provided keys for the config to be used (order does not matter). If empty, then the config will be used for all requests that do not have a more specific config. Only one default config is allowed per Memory Bank.""" + + StructuredMemoryConfigOrDict = Union[StructuredMemoryConfig, StructuredMemoryConfigDict] + + class ReasoningEngineContextSpecMemoryBankConfig(_common.BaseModel): + """Specification for a Memory Bank.""" + + customization_configs: Optional[list[MemoryBankCustomizationConfig]] = Field( + default=None, + description="""Optional. Configuration for how to customize Memory Bank behavior for a particular scope.""", + ) + disable_memory_revisions: Optional[bool] = Field( + default=None, + description="""If true, no memory revisions will be created for any requests to the Memory Bank.""", + ) + generation_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigGenerationConfig + ] = Field( + default=None, + description="""Optional. Configuration for how to generate memories for the Memory Bank.""", + ) + similarity_search_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig + ] = Field( + default=None, + description="""Optional. Configuration for how to perform similarity search on memories. If not set, the Memory Bank will use the default embedding model `text-embedding-005`.""", + ) + ttl_config: Optional[ReasoningEngineContextSpecMemoryBankConfigTtlConfig] = Field( + default=None, + description="""Optional. Configuration for automatic TTL ("time-to-live") of the memories in the Memory Bank. If not set, TTL will not be applied automatically. The TTL can be explicitly set by modifying the `expire_time` of each Memory resource.""", + ) + structured_memory_configs: Optional[list[StructuredMemoryConfig]] = Field( + default=None, + description="""Optional. Configuration for organizing structured memories for a particular scope.""", + )
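+ # Illustrative sketch: a Memory Bank spec that pins the similarity-search
+ # embedding model and applies a default TTL (the values are hypothetical;
+ # the TTL is a duration string):
+ #
+ #   memory_bank_config = ReasoningEngineContextSpecMemoryBankConfig(
+ #       similarity_search_config=ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfig(
+ #           embedding_model="projects/.../locations/.../publishers/google/models/text-embedding-005",
+ #       ),
+ #       ttl_config=ReasoningEngineContextSpecMemoryBankConfigTtlConfig(
+ #           default_ttl="7776000s",  # 90 days, expressed in seconds
+ #       ),
+ #   )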
+ class ReasoningEngineContextSpecMemoryBankConfigDict(TypedDict, total=False): + """Specification for a Memory Bank.""" + + customization_configs: Optional[list[MemoryBankCustomizationConfigDict]] + """Optional. Configuration for how to customize Memory Bank behavior for a particular scope.""" + + disable_memory_revisions: Optional[bool] + """If true, no memory revisions will be created for any requests to the Memory Bank.""" + + generation_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigGenerationConfigDict + ] + """Optional. Configuration for how to generate memories for the Memory Bank.""" + + similarity_search_config: Optional[ + ReasoningEngineContextSpecMemoryBankConfigSimilaritySearchConfigDict + ] + """Optional. Configuration for how to perform similarity search on memories. If not set, the Memory Bank will use the default embedding model `text-embedding-005`.""" + + ttl_config: Optional[ReasoningEngineContextSpecMemoryBankConfigTtlConfigDict] + """Optional. Configuration for automatic TTL ("time-to-live") of the memories in the Memory Bank. If not set, TTL will not be applied automatically. The TTL can be explicitly set by modifying the `expire_time` of each Memory resource.""" + + structured_memory_configs: Optional[list[StructuredMemoryConfigDict]] + """Optional. Configuration for organizing structured memories for a particular scope.""" + + ReasoningEngineContextSpecMemoryBankConfigOrDict = Union[ + ReasoningEngineContextSpecMemoryBankConfig, + ReasoningEngineContextSpecMemoryBankConfigDict, +] + + class ReasoningEngineContextSpec(_common.BaseModel): + """Configuration for how Agent Engine sub-resources should manage context.""" + + memory_bank_config: Optional[ReasoningEngineContextSpecMemoryBankConfig] = Field( + default=None, + description="""Optional. Specification for a Memory Bank, which manages memories for the Agent Engine.""", + ) + + class ReasoningEngineContextSpecDict(TypedDict, total=False): + """Configuration for how Agent Engine sub-resources should manage context.""" + + memory_bank_config: Optional[ReasoningEngineContextSpecMemoryBankConfigDict] + """Optional. Specification for a Memory Bank, which manages memories for the Agent Engine.""" + + ReasoningEngineContextSpecOrDict = Union[ + ReasoningEngineContextSpec, ReasoningEngineContextSpecDict +] + + class SecretRef(_common.BaseModel): + """Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" + + secret: Optional[str] = Field( + default=None, + description="""Required. The name of the secret in Cloud Secret Manager. Format: {secret_name}.""", + ) + version: Optional[str] = Field( + default=None, + description="""The Cloud Secret Manager secret version. Can be 'latest' for the latest version, an integer for a specific version, or a version alias.""", + ) + + class SecretRefDict(TypedDict, total=False): + """Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" + + secret: Optional[str] + """Required. The name of the secret in Cloud Secret Manager. Format: {secret_name}.""" + + version: Optional[str] + """The Cloud Secret Manager secret version. Can be 'latest' for the latest version, an integer for a specific version, or a version alias.""" + + SecretRefOrDict = Union[SecretRef, SecretRefDict]
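+ # Illustrative sketch: a reference to a Secret Manager secret (the names
+ # are hypothetical); SecretRef pairs with the SecretEnvVar model defined
+ # below.
+ #
+ #   secret_ref = SecretRef(secret="example-api-key", version="latest")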
Can be 'latest' for the latest version, an integer for a specific version, or a version alias.""" + + +SecretRefOrDict = Union[SecretRef, SecretRefDict] + + +class SecretEnvVar(_common.BaseModel): + """Represents an environment variable where the value is a secret in Cloud Secret Manager.""" + + name: Optional[str] = Field( + default=None, + description="""Required. Name of the secret environment variable.""", + ) + secret_ref: Optional[SecretRef] = Field( + default=None, + description="""Required. Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""", + ) + + +class SecretEnvVarDict(TypedDict, total=False): + """Represents an environment variable where the value is a secret in Cloud Secret Manager.""" + + name: Optional[str] + """Required. Name of the secret environment variable.""" + + secret_ref: Optional[SecretRefDict] + """Required. Reference to a secret stored in the Cloud Secret Manager that will provide the value for this environment variable.""" + + +SecretEnvVarOrDict = Union[SecretEnvVar, SecretEnvVarDict] + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfig( + _common.BaseModel +): + """Configuration for traffic originating from a Reasoning Engine.""" + + agent_gateway: Optional[str] = Field( + default=None, + description="""Required. The resource name of the Agent Gateway for outbound traffic. It must be set to a Google-managed gateway whose `governed_access_path` is `AGENT_TO_ANYWHERE`. Format: `projects/{project}/locations/{location}/agentGateways/{agent_gateway}`""", + ) + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigDict( + TypedDict, total=False +): + """Configuration for traffic originating from a Reasoning Engine.""" + + agent_gateway: Optional[str] + """Required. The resource name of the Agent Gateway for outbound traffic. It must be set to a Google-managed gateway whose `governed_access_path` is `AGENT_TO_ANYWHERE`. Format: `projects/{project}/locations/{location}/agentGateways/{agent_gateway}`""" + + +ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigOrDict = Union[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfig, + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigDict, +] + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfig( + _common.BaseModel +): + """Configuration for traffic targeting a Reasoning Engine.""" + + agent_gateway: Optional[str] = Field( + default=None, + description="""Required. The resource name of the Agent Gateway to use for inbound traffic. It must be set to a Google-managed gateway whose `governed_access_path` is `CLIENT_TO_AGENT`. Format: `projects/{project}/locations/{location}/agentGateways/{agent_gateway}`""", + ) + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigDict( + TypedDict, total=False +): + """Configuration for traffic targeting a Reasoning Engine.""" + + agent_gateway: Optional[str] + """Required. The resource name of the Agent Gateway to use for inbound traffic. It must be set to a Google-managed gateway whose `governed_access_path` is `CLIENT_TO_AGENT`. 
Format: `projects/{project}/locations/{location}/agentGateways/{agent_gateway}`""" + + +ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigOrDict = Union[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfig, + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigDict, +] + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfig(_common.BaseModel): + """Agent Gateway configuration for a Reasoning Engine deployment.""" + + agent_to_anywhere_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfig + ] = Field( + default=None, + description="""Optional. Configuration for traffic originating from the Reasoning Engine. When unset, outgoing traffic is not routed through an Agent Gateway.""", + ) + client_to_agent_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfig + ] = Field( + default=None, + description="""Optional. Configuration for traffic targeting the Reasoning Engine. When unset, incoming traffic is not routed through an Agent Gateway.""", + ) + + +class ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict(TypedDict, total=False): + """Agent Gateway configuration for a Reasoning Engine deployment.""" + + agent_to_anywhere_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigAgentToAnywhereConfigDict + ] + """Optional. Configuration for traffic originating from the Reasoning Engine. When unset, outgoing traffic is not routed through an Agent Gateway.""" + + client_to_agent_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigClientToAgentConfigDict + ] + """Optional. Configuration for traffic targeting the Reasoning Engine. When unset, incoming traffic is not routed through an Agent Gateway.""" + + +ReasoningEngineSpecDeploymentSpecAgentGatewayConfigOrDict = Union[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfig, + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict, +] + + +class KeepAliveProbeHttpGet(_common.BaseModel): + """Specifies the HTTP GET configuration for the probe.""" + + path: Optional[str] = Field( + default=None, + description="""Required. Specifies the path of the HTTP GET request (e.g., "/is_busy").""", + ) + port: Optional[int] = Field( + default=None, + description="""Optional. Specifies the port number on the container to which the request is sent.""", + ) + + +class KeepAliveProbeHttpGetDict(TypedDict, total=False): + """Specifies the HTTP GET configuration for the probe.""" + + path: Optional[str] + """Required. Specifies the path of the HTTP GET request (e.g., "/is_busy").""" + + port: Optional[int] + """Optional. Specifies the port number on the container to which the request is sent.""" + + +KeepAliveProbeHttpGetOrDict = Union[KeepAliveProbeHttpGet, KeepAliveProbeHttpGetDict] + + +class KeepAliveProbe(_common.BaseModel): + """Represents the configuration for keep-alive probe. Contains configuration on a specified endpoint that a deployment host should use to keep the container alive based on the probe settings.""" + + http_get: Optional[KeepAliveProbeHttpGet] = Field( + default=None, + description="""Optional. Specifies the HTTP GET configuration for the probe.""", + ) + max_seconds: Optional[int] = Field( + default=None, + description="""Optional. Specifies the maximum duration (in seconds) to keep the instance alive via this probe. 
Can be a maximum of 3600 seconds (1 hour).""", + ) + + +class KeepAliveProbeDict(TypedDict, total=False): + """Represents the configuration for keep-alive probe. Contains configuration on a specified endpoint that a deployment host should use to keep the container alive based on the probe settings.""" + + http_get: Optional[KeepAliveProbeHttpGetDict] + """Optional. Specifies the HTTP GET configuration for the probe.""" + + max_seconds: Optional[int] + """Optional. Specifies the maximum duration (in seconds) to keep the instance alive via this probe. Can be a maximum of 3600 seconds (1 hour).""" + + +KeepAliveProbeOrDict = Union[KeepAliveProbe, KeepAliveProbeDict] + + +class ReasoningEngineSpecDeploymentSpec(_common.BaseModel): + """The specification of a Reasoning Engine deployment.""" + + agent_server_mode: Optional[AgentServerMode] = Field( + default=None, description="""The agent server mode.""" + ) + container_concurrency: Optional[int] = Field( + default=None, + description="""Optional. Concurrency for each container and agent server. Recommended value: 2 * cpu + 1. Defaults to 9.""", + ) + env: Optional[list[EnvVar]] = Field( + default=None, + description="""Optional. Environment variables to be set with the Reasoning Engine deployment. The environment variables can be updated through the UpdateReasoningEngine API.""", + ) + max_instances: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of application instances that can be launched to handle increased traffic. Defaults to 100. Range: [1, 1000]. If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100].""", + ) + min_instances: Optional[int] = Field( + default=None, + description="""Optional. The minimum number of application instances that will be kept running at all times. Defaults to 1. Range: [0, 75].""", + ) + psc_interface_config: Optional[PscInterfaceConfig] = Field( + default=None, description="""Optional. Configuration for PSC-I.""" + ) + resource_limits: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. Resource limits for each container. Only 'cpu' and 'memory' keys are supported. Defaults to {"cpu": "4", "memory": "4Gi"}. * The only supported values for CPU are '1', '2', '4', '6' and '8'. For more information, go to https://cloud.google.com/run/docs/configuring/cpu. * The only supported values for memory are '1Gi', '2Gi', ... '32Gi'. * For the required CPU for different memory values, go to https://cloud.google.com/run/docs/configuring/memory-limits""", + ) + secret_env: Optional[list[SecretEnvVar]] = Field( + default=None, + description="""Optional. Environment variables where the value is a secret in Cloud Secret Manager. To use this feature, add 'Secret Manager Secret Accessor' role (roles/secretmanager.secretAccessor) to AI Platform Reasoning Engine Service Agent.""", + ) + agent_gateway_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfig + ] = Field( + default=None, + description="""Optional. Agent Gateway configuration for the Reasoning Engine deployment.""", + ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. Contains configuration on a specified endpoint that a deployment host should use to keep the container alive based on the probe settings.""", + ) + + +class ReasoningEngineSpecDeploymentSpecDict(TypedDict, total=False): + """The specification of a Reasoning Engine deployment.""" + + agent_server_mode: Optional[AgentServerMode] + """The agent server mode.""" + + container_concurrency: Optional[int] + """Optional. Concurrency for each container and agent server. Recommended value: 2 * cpu + 1. Defaults to 9.""" + + env: Optional[list[EnvVarDict]] + """Optional. Environment variables to be set with the Reasoning Engine deployment. The environment variables can be updated through the UpdateReasoningEngine API.""" + + max_instances: Optional[int] + """Optional. The maximum number of application instances that can be launched to handle increased traffic. Defaults to 100. Range: [1, 1000]. If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100].""" + + min_instances: Optional[int] + """Optional. The minimum number of application instances that will be kept running at all times. Defaults to 1. Range: [0, 75].""" + + psc_interface_config: Optional[PscInterfaceConfigDict] + """Optional. Configuration for PSC-I.""" + + resource_limits: Optional[dict[str, str]] + """Optional. Resource limits for each container. Only 'cpu' and 'memory' keys are supported. Defaults to {"cpu": "4", "memory": "4Gi"}. * The only supported values for CPU are '1', '2', '4', '6' and '8'. For more information, go to https://cloud.google.com/run/docs/configuring/cpu. * The only supported values for memory are '1Gi', '2Gi', ... '32Gi'. * For the required CPU for different memory values, go to https://cloud.google.com/run/docs/configuring/memory-limits""" + + secret_env: Optional[list[SecretEnvVarDict]] + """Optional. Environment variables where the value is a secret in Cloud Secret Manager. To use this feature, add 'Secret Manager Secret Accessor' role (roles/secretmanager.secretAccessor) to AI Platform Reasoning Engine Service Agent.""" + + agent_gateway_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict + ] + """Optional. Agent Gateway configuration for the Reasoning Engine deployment.""" + + keep_alive_probe: Optional[KeepAliveProbeDict] + """Optional. Specifies the configuration for keep-alive probe. Contains configuration on a specified endpoint that a deployment host should use to keep the container alive based on the probe settings.""" + + +ReasoningEngineSpecDeploymentSpecOrDict = Union[ + ReasoningEngineSpecDeploymentSpec, ReasoningEngineSpecDeploymentSpecDict +]
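+ + +# Illustrative usage sketch (hypothetical values; not part of the generated +# surface): a deployment spec that pins resources, reads an API key from +# Secret Manager, and keeps instances warm via the probe types above. The +# secret name, probe path, and port are examples only. +def _example_deployment_spec() -> ReasoningEngineSpecDeploymentSpec: + return ReasoningEngineSpecDeploymentSpec( + container_concurrency=9, # 2 * cpu + 1 with the default 4-CPU limit + resource_limits={"cpu": "4", "memory": "4Gi"}, + secret_env=[ + SecretEnvVar( + name="API_KEY", + secret_ref=SecretRef(secret="my-api-key", version="latest"), + ) + ], + keep_alive_probe=KeepAliveProbe( + http_get=KeepAliveProbeHttpGet(path="/is_busy", port=8080), + max_seconds=3600, + ), + )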
+ + +class ReasoningEngineSpecPackageSpec(_common.BaseModel): + """User-provided package specification, containing pickled object and package requirements.""" + + dependency_files_gcs_uri: Optional[str] = Field( + default=None, + description="""Optional. The Cloud Storage URI of the dependency files in tar.gz format.""", + ) + pickle_object_gcs_uri: Optional[str] = Field( + default=None, + description="""Optional. The Cloud Storage URI of the pickled python object.""", + ) + python_version: Optional[str] = Field( + default=None, + description="""Optional. The Python version. Supported values are 3.9, 3.10, 3.11, 3.12, 3.13, 3.14. If not specified, the default value is 3.10.""", + ) + requirements_gcs_uri: Optional[str] = Field( + default=None, + description="""Optional. The Cloud Storage URI of the `requirements.txt` file.""", + ) + + +class ReasoningEngineSpecPackageSpecDict(TypedDict, total=False): + """User-provided package specification, containing pickled object and package requirements.""" + + dependency_files_gcs_uri: Optional[str] + """Optional. The Cloud Storage URI of the dependency files in tar.gz format.""" + + pickle_object_gcs_uri: Optional[str] + """Optional. The Cloud Storage URI of the pickled python object.""" + + python_version: Optional[str] + """Optional. The Python version. Supported values are 3.9, 3.10, 3.11, 3.12, 3.13, 3.14. If not specified, the default value is 3.10.""" + + requirements_gcs_uri: Optional[str] + """Optional. The Cloud Storage URI of the `requirements.txt` file.""" + + +ReasoningEngineSpecPackageSpecOrDict = Union[ + ReasoningEngineSpecPackageSpec, ReasoningEngineSpecPackageSpecDict +]
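+ + +# Illustrative usage sketch: a package spec pointing at pre-staged Cloud +# Storage artifacts. The bucket and object names are hypothetical. +def _example_package_spec() -> ReasoningEngineSpecPackageSpec: + return ReasoningEngineSpecPackageSpec( + pickle_object_gcs_uri="gs://my-bucket/agent_engine/agent_engine.pkl", + requirements_gcs_uri="gs://my-bucket/agent_engine/requirements.txt", + python_version="3.10", + )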
+ + +class ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfig(_common.BaseModel): + """Configuration for the Agent Development Kit (ADK).""" + + json_config: Optional[dict[str, Any]] = Field( + default=None, + description="""Required. The value of the ADK config in JSON format.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigDict( + TypedDict, total=False +): + """Configuration for the Agent Development Kit (ADK).""" + + json_config: Optional[dict[str, Any]] + """Required. The value of the ADK config in JSON format.""" + + +ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfig, + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigDict, +] + + +class ReasoningEngineSpecSourceCodeSpecInlineSource(_common.BaseModel): + """Specifies source code provided as a byte stream.""" + + source_archive: Optional[bytes] = Field( + default=None, + description="""Required. Input only. The application source code archive. It must be a compressed tarball (.tar.gz) file.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecInlineSourceDict(TypedDict, total=False): + """Specifies source code provided as a byte stream.""" + + source_archive: Optional[bytes] + """Required. Input only. The application source code archive. It must be a compressed tarball (.tar.gz) file.""" + + +ReasoningEngineSpecSourceCodeSpecInlineSourceOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecInlineSource, + ReasoningEngineSpecSourceCodeSpecInlineSourceDict, +] + + +class ReasoningEngineSpecSourceCodeSpecAgentConfigSource(_common.BaseModel): + """Specification for deploying from an agent config.""" + + adk_config: Optional[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfig + ] = Field(default=None, description="""Required. The ADK configuration.""") + inline_source: Optional[ReasoningEngineSpecSourceCodeSpecInlineSource] = Field( + default=None, + description="""Optional. Any additional files needed to interpret the config. If a `requirements.txt` file is present in the `inline_source`, the corresponding packages will be installed. If no `requirements.txt` file is present in `inline_source`, then the latest version of `google-adk` will be installed for interpreting the ADK config.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict(TypedDict, total=False): + """Specification for deploying from an agent config.""" + + adk_config: Optional[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceAdkConfigDict + ] + """Required. 
The ADK configuration.""" + + inline_source: Optional[ReasoningEngineSpecSourceCodeSpecInlineSourceDict] + """Optional. Any additional files needed to interpret the config. If a `requirements.txt` file is present in the `inline_source`, the corresponding packages will be installed. If no `requirements.txt` file is present in `inline_source`, then the latest version of `google-adk` will be installed for interpreting the ADK config.""" + + +ReasoningEngineSpecSourceCodeSpecAgentConfigSourceOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSource, + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict, +] + + +class ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig(_common.BaseModel): + """Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. + + This includes the repository, revision, and directory to use. + """ + + git_repository_link: Optional[str] = Field( + default=None, + description="""Required. The Developer Connect Git repository link, formatted as `projects/{project_id}/locations/{location_id}/connections/{connection_id}/gitRepositoryLink/{repository_link_id}`.""", + ) + dir: Optional[str] = Field( + default=None, + description="""Required. Directory, relative to the source root, in which to run the build.""", + ) + revision: Optional[str] = Field( + default=None, + description="""Required. The revision to fetch from the Git repository such as a branch, a tag, a commit SHA, or any Git ref.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict( + TypedDict, total=False +): + """Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. + + This includes the repository, revision, and directory to use. + """ + + git_repository_link: Optional[str] + """Required. The Developer Connect Git repository link, formatted as `projects/{project_id}/locations/{location_id}/connections/{connection_id}/gitRepositoryLink/{repository_link_id}`.""" + + dir: Optional[str] + """Required. Directory, relative to the source root, in which to run the build.""" + + revision: Optional[str] + """Required. The revision to fetch from the Git repository such as a branch, a tag, a commit SHA, or any Git ref.""" + + +ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig, + ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict, +] + + +class ReasoningEngineSpecSourceCodeSpecDeveloperConnectSource(_common.BaseModel): + """Specifies source code to be fetched from a Git repository managed through the Developer Connect service.""" + + config: Optional[ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig] = Field( + default=None, + description="""Required. The Developer Connect configuration that defines the specific repository, revision, and directory to use as the source code root.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceDict( + TypedDict, total=False +): + """Specifies source code to be fetched from a Git repository managed through the Developer Connect service.""" + + config: Optional[ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict] + """Required. 
The Developer Connect configuration that defines the specific repository, revision, and directory to use as the source code root.""" + + +ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectSource, + ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceDict, +] + + +class ReasoningEngineSpecSourceCodeSpecImageSpec(_common.BaseModel): + """The image spec for building an image (within a single build step), based on the config file (i.e. Dockerfile) in the source directory.""" + + build_args: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. Build arguments to be used. They will be passed through --build-arg flags.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecImageSpecDict(TypedDict, total=False): + """The image spec for building an image (within a single build step), based on the config file (i.e. Dockerfile) in the source directory.""" + + build_args: Optional[dict[str, str]] + """Optional. Build arguments to be used. They will be passed through --build-arg flags.""" + + +ReasoningEngineSpecSourceCodeSpecImageSpecOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecImageSpec, + ReasoningEngineSpecSourceCodeSpecImageSpecDict, +] + + +class ReasoningEngineSpecSourceCodeSpecPythonSpec(_common.BaseModel): + """Specification for running a Python application from source.""" + + entrypoint_module: Optional[str] = Field( + default=None, + description="""Optional. The Python module to load as the entrypoint, specified as a fully qualified module name. For example: path.to.agent. If not specified, defaults to "agent". The project root will be added to Python sys.path, allowing imports to be specified relative to the root. This field should not be set if the source is `agent_config_source`.""", + ) + entrypoint_object: Optional[str] = Field( + default=None, + description="""Optional. The name of the callable object within the `entrypoint_module` to use as the application. If not specified, defaults to "root_agent". This field should not be set if the source is `agent_config_source`.""", + ) + requirements_file: Optional[str] = Field( + default=None, + description="""Optional. The path to the requirements file, relative to the source root. If not specified, defaults to "requirements.txt".""", + ) + version: Optional[str] = Field( + default=None, + description="""Optional. The version of Python to use. Supported versions include 3.9, 3.10, 3.11, 3.12, 3.13, 3.14. If not specified, the default value is 3.10.""", + ) + + +class ReasoningEngineSpecSourceCodeSpecPythonSpecDict(TypedDict, total=False): + """Specification for running a Python application from source.""" + + entrypoint_module: Optional[str] + """Optional. The Python module to load as the entrypoint, specified as a fully qualified module name. For example: path.to.agent. If not specified, defaults to "agent". The project root will be added to Python sys.path, allowing imports to be specified relative to the root. This field should not be set if the source is `agent_config_source`.""" + + entrypoint_object: Optional[str] + """Optional. The name of the callable object within the `entrypoint_module` to use as the application. If not specified, defaults to "root_agent". This field should not be set if the source is `agent_config_source`.""" + + requirements_file: Optional[str] + """Optional. The path to the requirements file, relative to the source root. If not specified, defaults to "requirements.txt".""" + + version: Optional[str] + """Optional. The version of Python to use. Supported versions include 3.9, 3.10, 3.11, 3.12, 3.13, 3.14. If not specified, the default value is 3.10.""" + + +ReasoningEngineSpecSourceCodeSpecPythonSpecOrDict = Union[ + ReasoningEngineSpecSourceCodeSpecPythonSpec, + ReasoningEngineSpecSourceCodeSpecPythonSpecDict, +]
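+ + +# Illustrative usage sketch: running a Python agent from source, with the +# documented defaults spelled out. The module path is hypothetical; omitting +# these fields falls back to "agent", "root_agent", and "requirements.txt". +def _example_python_spec() -> ReasoningEngineSpecSourceCodeSpecPythonSpec: + return ReasoningEngineSpecSourceCodeSpecPythonSpec( + entrypoint_module="path.to.agent", + entrypoint_object="root_agent", + requirements_file="requirements.txt", + version="3.10", + )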
+ + +class ReasoningEngineSpecSourceCodeSpec(_common.BaseModel): + """Specification for deploying from source code.""" + + agent_config_source: Optional[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSource + ] = Field( + default=None, description="""Source code is generated from the agent config.""" + ) + developer_connect_source: Optional[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectSource + ] = Field( + default=None, + description="""Source code is in a Git repository managed by Developer Connect.""", + ) + image_spec: Optional[ReasoningEngineSpecSourceCodeSpecImageSpec] = Field( + default=None, + description="""Optional. Configuration for building an image with custom config file.""", + ) + inline_source: Optional[ReasoningEngineSpecSourceCodeSpecInlineSource] = Field( + default=None, description="""Source code is provided directly in the request.""" + ) + python_spec: Optional[ReasoningEngineSpecSourceCodeSpecPythonSpec] = Field( + default=None, description="""Configuration for a Python application.""" + ) + + +class ReasoningEngineSpecSourceCodeSpecDict(TypedDict, total=False): + """Specification for deploying from source code.""" + + agent_config_source: Optional[ + ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict + ] + """Source code is generated from the agent config.""" + + developer_connect_source: Optional[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectSourceDict + ] + """Source code is in a Git repository managed by Developer Connect.""" + + image_spec: Optional[ReasoningEngineSpecSourceCodeSpecImageSpecDict] + """Optional. Configuration for building an image with custom config file.""" + + inline_source: Optional[ReasoningEngineSpecSourceCodeSpecInlineSourceDict] + """Source code is provided directly in the request.""" + + python_spec: Optional[ReasoningEngineSpecSourceCodeSpecPythonSpecDict] + """Configuration for a Python application.""" + + +ReasoningEngineSpecSourceCodeSpecOrDict = Union[ + ReasoningEngineSpecSourceCodeSpec, ReasoningEngineSpecSourceCodeSpecDict +] + + +class ReasoningEngineSpecContainerSpec(_common.BaseModel): + """Specification for deploying from a container image.""" + + image_uri: Optional[str] = Field( + default=None, + description="""Required. The Artifact Registry Docker image URI (e.g., us-central1-docker.pkg.dev/my-project/my-repo/my-image:tag) of the container image that is to be run on each worker replica.""", + ) + + +class ReasoningEngineSpecContainerSpecDict(TypedDict, total=False): + """Specification for deploying from a container image.""" + + image_uri: Optional[str] + """Required. The Artifact Registry Docker image URI (e.g., us-central1-docker.pkg.dev/my-project/my-repo/my-image:tag) of the container image that is to be run on each worker replica.""" + + +ReasoningEngineSpecContainerSpecOrDict = Union[ + ReasoningEngineSpecContainerSpec, ReasoningEngineSpecContainerSpecDict +] + + +class ReasoningEngineSpec(_common.BaseModel): + """The specification of an agent engine.""" + + agent_card: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. The A2A Agent Card for the agent (if available). 
It follows the specification at https://a2a-protocol.org/latest/specification/#5-agent-discovery-the-agent-card.""", + ) + agent_framework: Optional[str] = Field( + default=None, + description="""Optional. The OSS agent framework used to develop the agent. Currently supported values: "google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom".""", + ) + class_methods: Optional[list[dict[str, Any]]] = Field( + default=None, + description="""Optional. Declarations for object class methods in OpenAPI specification format.""", + ) + deployment_spec: Optional[ReasoningEngineSpecDeploymentSpec] = Field( + default=None, + description="""Optional. The specification of a Reasoning Engine deployment.""", + ) + effective_identity: Optional[str] = Field( + default=None, + description="""Output only. The identity to use for the Reasoning Engine. It can contain one of the following values: * service-{project}@gcp-sa-aiplatform-re.googleapis.com (for SERVICE_AGENT identity type) * {name}@{project}.gserviceaccount.com (for SERVICE_ACCOUNT identity type) * agents.global.{org}.system.id.goog/resources/aiplatform/projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine} (for AGENT_IDENTITY identity type)""", + ) + identity_type: Optional[IdentityType] = Field( + default=None, + description="""Optional. The identity type to use for the Reasoning Engine. If not specified, the `service_account` field will be used if set, otherwise the default Vertex AI Reasoning Engine Service Agent in the project will be used.""", + ) + package_spec: Optional[ReasoningEngineSpecPackageSpec] = Field( + default=None, + description="""Optional. User provided package spec of the ReasoningEngine. Ignored when users directly specify a deployment image through `deployment_spec.first_party_image_override`, but keeping the field_behavior to avoid introducing breaking changes. The `deployment_source` field should not be set if `package_spec` is specified.""", + ) + service_account: Optional[str] = Field( + default=None, + description="""Optional. The service account that the Reasoning Engine artifact runs as. It should have "roles/storage.objectViewer" for reading the user project's Cloud Storage and "roles/aiplatform.user" for using Vertex extensions. If not specified, the Vertex AI Reasoning Engine Service Agent in the project will be used.""", + ) + source_code_spec: Optional[ReasoningEngineSpecSourceCodeSpec] = Field( + default=None, + description="""Deploy from source code files with a defined entrypoint.""", + ) + container_spec: Optional[ReasoningEngineSpecContainerSpec] = Field( + default=None, + description="""Deploy from a container image with a defined entrypoint and commands.""", + ) + + +class ReasoningEngineSpecDict(TypedDict, total=False): + """The specification of an agent engine.""" + + agent_card: Optional[dict[str, Any]] + """Optional. The A2A Agent Card for the agent (if available). It follows the specification at https://a2a-protocol.org/latest/specification/#5-agent-discovery-the-agent-card.""" + + agent_framework: Optional[str] + """Optional. The OSS agent framework used to develop the agent. Currently supported values: "google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom".""" + + class_methods: Optional[list[dict[str, Any]]] + """Optional. Declarations for object class methods in OpenAPI specification format.""" + + deployment_spec: Optional[ReasoningEngineSpecDeploymentSpecDict] + """Optional. 
The specification of a Reasoning Engine deployment.""" + + effective_identity: Optional[str] + """Output only. The identity to use for the Reasoning Engine. It can contain one of the following values: * service-{project}@gcp-sa-aiplatform-re.googleapis.com (for SERVICE_AGENT identity type) * {name}@{project}.gserviceaccount.com (for SERVICE_ACCOUNT identity type) * agents.global.{org}.system.id.goog/resources/aiplatform/projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine} (for AGENT_IDENTITY identity type)""" + + identity_type: Optional[IdentityType] + """Optional. The identity type to use for the Reasoning Engine. If not specified, the `service_account` field will be used if set, otherwise the default Vertex AI Reasoning Engine Service Agent in the project will be used.""" + + package_spec: Optional[ReasoningEngineSpecPackageSpecDict] + """Optional. User provided package spec of the ReasoningEngine. Ignored when users directly specify a deployment image through `deployment_spec.first_party_image_override`, but keeping the field_behavior to avoid introducing breaking changes. The `deployment_source` field should not be set if `package_spec` is specified.""" + + service_account: Optional[str] + """Optional. The service account that the Reasoning Engine artifact runs as. It should have "roles/storage.objectViewer" for reading the user project's Cloud Storage and "roles/aiplatform.user" for using Vertex extensions. If not specified, the Vertex AI Reasoning Engine Service Agent in the project will be used.""" + + source_code_spec: Optional[ReasoningEngineSpecSourceCodeSpecDict] + """Deploy from source code files with a defined entrypoint.""" + + container_spec: Optional[ReasoningEngineSpecContainerSpecDict] + """Deploy from a container image with a defined entrypoint and commands.""" + + +ReasoningEngineSpecOrDict = Union[ReasoningEngineSpec, ReasoningEngineSpecDict] + + +class ReasoningEngineTrafficConfigTrafficSplitAlwaysLatest(_common.BaseModel): + """Traffic distribution configuration, where all traffic is sent to the latest Runtime Revision.""" + + pass + + +class ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestDict(TypedDict, total=False): + """Traffic distribution configuration, where all traffic is sent to the latest Runtime Revision.""" + + pass + + +ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestOrDict = Union[ + ReasoningEngineTrafficConfigTrafficSplitAlwaysLatest, + ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestDict, +] + + +class ReasoningEngineTrafficConfigTrafficSplitManualTarget(_common.BaseModel): + """A single target for the traffic split, specifying a Runtime Revision and the percentage of traffic to send to it.""" + + percent: Optional[int] = Field( + default=None, + description="""Required. Specifies percent of the traffic to this Runtime Revision.""", + ) + runtime_revision_name: Optional[str] = Field( + default=None, + description="""Required. The Runtime Revision name to which to send this portion of traffic, if traffic allocation is by Runtime Revision.""", + ) + + +class ReasoningEngineTrafficConfigTrafficSplitManualTargetDict(TypedDict, total=False): + """A single target for the traffic split, specifying a Runtime Revision and the percentage of traffic to send to it.""" + + percent: Optional[int] + """Required. Specifies percent of the traffic to this Runtime Revision.""" + + runtime_revision_name: Optional[str] + """Required. 
The Runtime Revision name to which to send this portion of traffic, if traffic allocation is by Runtime Revision.""" + + +ReasoningEngineTrafficConfigTrafficSplitManualTargetOrDict = Union[ + ReasoningEngineTrafficConfigTrafficSplitManualTarget, + ReasoningEngineTrafficConfigTrafficSplitManualTargetDict, +] + + +class ReasoningEngineTrafficConfigTrafficSplitManual(_common.BaseModel): + """Manual traffic distribution configuration, where the user specifies the Runtime Revision IDs and the percentage of traffic to send to each.""" + + targets: Optional[list[ReasoningEngineTrafficConfigTrafficSplitManualTarget]] = ( + Field( + default=None, + description="""A list of traffic targets for the Runtime Revisions. The sum of the percentages must equal 100.""", + ) + ) + + +class ReasoningEngineTrafficConfigTrafficSplitManualDict(TypedDict, total=False): + """Manual traffic distribution configuration, where the user specifies the Runtime Revision IDs and the percentage of traffic to send to each.""" + + targets: Optional[list[ReasoningEngineTrafficConfigTrafficSplitManualTargetDict]] + """A list of traffic targets for the Runtime Revisions. The sum of the percentages must equal 100.""" + + +ReasoningEngineTrafficConfigTrafficSplitManualOrDict = Union[ + ReasoningEngineTrafficConfigTrafficSplitManual, + ReasoningEngineTrafficConfigTrafficSplitManualDict, +] + + +class ReasoningEngineTrafficConfig(_common.BaseModel): + """Traffic distribution configuration.""" + + traffic_split_always_latest: Optional[ + ReasoningEngineTrafficConfigTrafficSplitAlwaysLatest + ] = Field( + default=None, + description="""Optional. Traffic distribution configuration, where all traffic is sent to the latest Runtime Revision.""", + ) + traffic_split_manual: Optional[ReasoningEngineTrafficConfigTrafficSplitManual] = ( + Field( + default=None, + description="""Optional. Manual traffic distribution configuration, where the user specifies the Runtime Revision IDs and the percentage of traffic to send to each.""", + ) + ) + + +class ReasoningEngineTrafficConfigDict(TypedDict, total=False): + """Traffic distribution configuration.""" + + traffic_split_always_latest: Optional[ + ReasoningEngineTrafficConfigTrafficSplitAlwaysLatestDict + ] + """Optional. Traffic distribution configuration, where all traffic is sent to the latest Runtime Revision.""" + + traffic_split_manual: Optional[ReasoningEngineTrafficConfigTrafficSplitManualDict] + """Optional. Manual traffic distribution configuration, where the user specifies the Runtime Revision IDs and the percentage of traffic to send to each.""" + + +ReasoningEngineTrafficConfigOrDict = Union[ + ReasoningEngineTrafficConfig, ReasoningEngineTrafficConfigDict +]
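+ + +# Illustrative usage sketch: a manual 90/10 split between two Runtime +# Revisions. The revision names are hypothetical; the percentages must sum +# to 100. +def _example_traffic_config() -> ReasoningEngineTrafficConfig: + return ReasoningEngineTrafficConfig( + traffic_split_manual=ReasoningEngineTrafficConfigTrafficSplitManual( + targets=[ + ReasoningEngineTrafficConfigTrafficSplitManualTarget( + percent=90, runtime_revision_name="stable-revision" + ), + ReasoningEngineTrafficConfigTrafficSplitManualTarget( + percent=10, runtime_revision_name="canary-revision" + ), + ] + ) + )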
+ + +class ReasoningEngine(_common.BaseModel): + """An agent engine.""" + + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key spec for a ReasoningEngine. If set, this ReasoningEngine and all sub-resources of this ReasoningEngine will be secured by this key.""", + ) + context_spec: Optional[ReasoningEngineContextSpec] = Field( + default=None, + description="""Optional. Configuration for how Agent Engine sub-resources should manage context.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this ReasoningEngine was created.""", + ) + description: Optional[str] = Field( + default=None, + description="""Optional. The description of the ReasoningEngine.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Required. The display name of the ReasoningEngine.""", + ) + etag: Optional[str] = Field( + default=None, + description="""Optional. Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, description="""Labels for the ReasoningEngine.""" + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the ReasoningEngine. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`""", + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, description="""Optional. Configurations of the ReasoningEngine.""" + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this ReasoningEngine was most recently updated.""", + ) + traffic_config: Optional[ReasoningEngineTrafficConfig] = Field( + default=None, + description="""Optional. Traffic distribution configuration for the Reasoning Engine.""", + ) + + +class ReasoningEngineDict(TypedDict, total=False): + """An agent engine.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key spec for a ReasoningEngine. If set, this ReasoningEngine and all sub-resources of this ReasoningEngine will be secured by this key.""" + + context_spec: Optional[ReasoningEngineContextSpecDict] + """Optional. Configuration for how Agent Engine sub-resources should manage context.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this ReasoningEngine was created.""" + + description: Optional[str] + """Optional. The description of the ReasoningEngine.""" + + display_name: Optional[str] + """Required. The display name of the ReasoningEngine.""" + + etag: Optional[str] + """Optional. Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""" + + labels: Optional[dict[str, str]] + """Labels for the ReasoningEngine.""" + + name: Optional[str] + """Identifier. The resource name of the ReasoningEngine. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}`""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. Configurations of the ReasoningEngine.""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when this ReasoningEngine was most recently updated.""" + + traffic_config: Optional[ReasoningEngineTrafficConfigDict] + """Optional. Traffic distribution configuration for the Reasoning Engine.""" + + +ReasoningEngineOrDict = Union[ReasoningEngine, ReasoningEngineDict] + + +class AgentEngineOperation(_common.BaseModel): + """Operation that has an agent engine as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[ReasoningEngine] = Field( + default=None, description="""The created Agent Engine.""" + ) + + +class AgentEngineOperationDict(TypedDict, total=False): + """Operation that has an agent engine as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[ReasoningEngineDict] + """The created Agent Engine.""" + + +AgentEngineOperationOrDict = Union[AgentEngineOperation, AgentEngineOperationDict]
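+ + +# Illustrative sketch: unwrapping a completed long-running operation. Once +# `done` is true, exactly one of `error` or `response` is expected to be +# populated; the helper name is hypothetical. +def _example_unwrap_operation(op: AgentEngineOperation) -> Optional[ReasoningEngine]: + if not op.done: + return None # still in progress; poll again later + if op.error: + raise RuntimeError(f"operation {op.name} failed: {op.error}") + return op.response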
+ + +class CreateAgentEngineConfig(_common.BaseModel): + """Config for creating an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """, + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, description="""Optional. Configurations of the Agent Engine.""" + ) + context_spec: Optional[ReasoningEngineContextSpec] = Field( + default=None, + description="""Optional. The context spec to be used for the Agent Engine.""", + ) + psc_interface_config: Optional[PscInterfaceConfig] = Field( + default=None, + description="""Optional. The PSC interface config for PSC-I to be used for the + Agent Engine.""", + ) + min_instances: Optional[int] = Field( + default=None, + description="""The minimum number of instances to run for the Agent Engine. + Defaults to 1. Range: [0, 10]. + """, + ) + max_instances: Optional[int] = Field( + default=None, + description="""The maximum number of instances to run for the Agent Engine. + Defaults to 100. Range: [1, 1000]. + If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100]. + """, + ) + resource_limits: Optional[dict[str, str]] = Field( + default=None, + description="""The resource limits to be applied to the Agent Engine. + Required keys: 'cpu' and 'memory'. + Supported values for 'cpu': '1', '2', '4', '6', '8'. + Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'. + """, + ) + container_concurrency: Optional[int] = Field( + default=None, + description="""The container concurrency to be used for the Agent Engine. + Recommended value: 2 * cpu + 1. Defaults to 9. + """, + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""The encryption spec to be used for the Agent Engine.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, description="""The labels to be used for the Agent Engine.""" + ) + class_methods: Optional[list[dict[str, Any]]] = Field( + default=None, + description="""The class methods to be used for the Agent Engine. + If specified, they'll override the class methods that are autogenerated by + default. By default, methods are generated by inspecting the agent object + and generating a corresponding method for each method defined on the + agent class. + """, + ) + source_packages: Optional[list[str]] = Field( + default=None, + description="""The user-provided paths to the source packages (if any). + If specified, the files in the source packages will be packed into + a tarball file, uploaded to Agent Engine's API, and deployed to the + Agent Engine. + The following fields will be ignored: + - agent + - extra_packages + - staging_bucket + - requirements + The following fields will be used to install and use the agent from the + source packages: + - entrypoint_module (required) + - entrypoint_object (required) + - requirements_file (optional) + - class_methods (required) + """, + ) + developer_connect_source: Optional[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig + ] = Field( + default=None, + description="""Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use.""", + ) + entrypoint_module: Optional[str] = Field( + default=None, + description="""The entrypoint module to be used for the Agent Engine. + This field is only used when source_packages is specified.""", + ) + entrypoint_object: Optional[str] = Field( + default=None, + description="""The entrypoint object to be used for the Agent Engine. + This field is only used when source_packages is specified.""", + ) + requirements_file: Optional[str] = Field( + default=None, + description="""The user-provided path to the requirements file (if any). + This field is only used when source_packages is specified. + If not specified, agent engine will find and use the `requirements.txt` in + the source package. + """, + ) + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] = Field( + default=None, + description="""The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""", + ) + python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]] = Field( + default=None, + description="""The Python version to be used for the Agent Engine. + If not specified, it will use the current Python version of the environment. + Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14". + """, + ) + build_options: Optional[dict[str, list[str]]] = Field( + default=None, + description="""The build options for the Agent Engine. 
+ The following keys are supported: + - installation_scripts: + Optional. The paths to the installation scripts to be + executed in the Docker image. + The scripts must be located in the `installation_scripts` + subdirectory and the path must be added to `extra_packages`. + """, + ) + agent_gateway_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfig + ] = Field( + default=None, + description="""Agent Gateway configuration for a Reasoning Engine deployment.""", + ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""", + ) + + +class CreateAgentEngineConfigDict(TypedDict, total=False): + """Config for creating an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """ + + description: Optional[str] + """The description of the Agent Engine.""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. Configurations of the Agent Engine.""" + + context_spec: Optional[ReasoningEngineContextSpecDict] + """Optional. The context spec to be used for the Agent Engine.""" + + psc_interface_config: Optional[PscInterfaceConfigDict] + """Optional. The PSC interface config for PSC-I to be used for the + Agent Engine.""" + + min_instances: Optional[int] + """The minimum number of instances to run for the Agent Engine. + Defaults to 1. Range: [0, 10]. + """ + + max_instances: Optional[int] + """The maximum number of instances to run for the Agent Engine. + Defaults to 100. Range: [1, 1000]. + If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100]. + """ + + resource_limits: Optional[dict[str, str]] + """The resource limits to be applied to the Agent Engine. + Required keys: 'cpu' and 'memory'. + Supported values for 'cpu': '1', '2', '4', '6', '8'. + Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'. + """ + + container_concurrency: Optional[int] + """The container concurrency to be used for the Agent Engine. + Recommended value: 2 * cpu + 1. Defaults to 9. + """ + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """The encryption spec to be used for the Agent Engine.""" + + labels: Optional[dict[str, str]] + """The labels to be used for the Agent Engine.""" + + class_methods: Optional[list[dict[str, Any]]] + """The class methods to be used for the Agent Engine. + If specified, they'll override the class methods that are autogenerated by + default. By default, methods are generated by inspecting the agent object + and generating a corresponding method for each method defined on the + agent class. + """ + + source_packages: Optional[list[str]] + """The user-provided paths to the source packages (if any). + If specified, the files in the source packages will be packed into + a tarball file, uploaded to Agent Engine's API, and deployed to the + Agent Engine. 
+ The following fields will be ignored: + - agent + - extra_packages + - staging_bucket + - requirements + The following fields will be used to install and use the agent from the + source packages: + - entrypoint_module (required) + - entrypoint_object (required) + - requirements_file (optional) + - class_methods (required) + """ + + developer_connect_source: Optional[ + ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict + ] + """Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use.""" + + entrypoint_module: Optional[str] + """The entrypoint module to be used for the Agent Engine. + This field is only used when source_packages is specified.""" + + entrypoint_object: Optional[str] + """The entrypoint object to be used for the Agent Engine. + This field is only used when source_packages is specified.""" + + requirements_file: Optional[str] + """The user-provided path to the requirements file (if any). + This field is only used when source_packages is specified. + If not specified, agent engine will find and use the `requirements.txt` in + the source package. + """ + + agent_framework: Optional[ + Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"] + ] + """The agent framework to be used for the Agent Engine. + The OSS agent framework used to develop the agent. + Currently supported values: "google-adk", "langchain", "langgraph", + "ag2", "llama-index", "custom". + If not specified: + - If `agent` is specified, the agent framework will be auto-detected. + - If `source_packages` is specified, the agent framework will + default to "custom".""" + + python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]] + """The Python version to be used for the Agent Engine. + If not specified, it will use the current Python version of the environment. + Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14". + """ + + build_options: Optional[dict[str, list[str]]] + """The build options for the Agent Engine. + The following keys are supported: + - installation_scripts: + Optional. The paths to the installation scripts to be + executed in the Docker image. + The scripts must be located in the `installation_scripts` + subdirectory and the path must be added to `extra_packages`. + """ + + agent_gateway_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict + ] + """Agent Gateway configuration for a Reasoning Engine deployment.""" + + keep_alive_probe: Optional[KeepAliveProbeDict] + """Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""" + + +CreateAgentEngineConfigOrDict = Union[ + CreateAgentEngineConfig, CreateAgentEngineConfigDict +]
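+ + +# Illustrative usage sketch: a create config that deploys from user-provided +# source packages. The paths and entrypoint are hypothetical; per the field +# docs, `class_methods` would also be required alongside `source_packages` +# and is omitted here for brevity. +def _example_create_config() -> CreateAgentEngineConfig: + return CreateAgentEngineConfig( + display_name="my-agent", + source_packages=["src/my_agent"], + entrypoint_module="my_agent.agent", + entrypoint_object="root_agent", + agent_framework="custom", + )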
+ + +class _CreateAgentEngineRequestParameters(_common.BaseModel): + """Parameters for creating agent engines.""" + + config: Optional[CreateAgentEngineConfig] = Field(default=None, description="""""") + + +class _CreateAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for creating agent engines.""" + + config: Optional[CreateAgentEngineConfigDict] + """""" + + +_CreateAgentEngineRequestParametersOrDict = Union[ + _CreateAgentEngineRequestParameters, _CreateAgentEngineRequestParametersDict +] + + +class DeleteAgentEngineConfig(_common.BaseModel): + """Config for deleting an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteAgentEngineConfigDict(TypedDict, total=False): + """Config for deleting an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteAgentEngineConfigOrDict = Union[ + DeleteAgentEngineConfig, DeleteAgentEngineConfigDict +] + + +class _DeleteAgentEngineRequestParameters(_common.BaseModel): + """Parameters for deleting agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + force: Optional[bool] = Field( + default=False, + description="""If set to true, any child resources will also be deleted.""", + ) + config: Optional[DeleteAgentEngineConfig] = Field(default=None, description="""""") + + +class _DeleteAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for deleting agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + force: Optional[bool] + """If set to true, any child resources will also be deleted.""" + + config: Optional[DeleteAgentEngineConfigDict] + """""" + + +_DeleteAgentEngineRequestParametersOrDict = Union[ + _DeleteAgentEngineRequestParameters, _DeleteAgentEngineRequestParametersDict +] + + +class DeleteAgentEngineOperation(_common.BaseModel): + """Operation for deleting agent engines.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteAgentEngineOperationDict(TypedDict, total=False): + """Operation for deleting agent engines.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteAgentEngineOperationOrDict = Union[ + DeleteAgentEngineOperation, DeleteAgentEngineOperationDict +] + + +class GetAgentEngineConfig(_common.BaseModel): + """Config for getting an agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineConfigDict(TypedDict, total=False): + """Config for getting an agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineConfigOrDict = Union[GetAgentEngineConfig, GetAgentEngineConfigDict] + + +class _GetAgentEngineRequestParameters(_common.BaseModel): + """Parameters for getting agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[GetAgentEngineConfig] = Field(default=None, description="""""") + + +class _GetAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for getting agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[GetAgentEngineConfigDict] + """""" + + +_GetAgentEngineRequestParametersOrDict = Union[ + _GetAgentEngineRequestParameters, _GetAgentEngineRequestParametersDict +] + + +class ListAgentEngineConfig(_common.BaseModel): + """Config for listing agent engines.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineConfigDict(TypedDict, total=False): + """Config for listing agent engines.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. 
+ For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineConfigOrDict = Union[ListAgentEngineConfig, ListAgentEngineConfigDict] + + +class _ListAgentEngineRequestParameters(_common.BaseModel): + """Parameters for listing agent engines.""" + + config: Optional[ListAgentEngineConfig] = Field(default=None, description="""""") + + +class _ListAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engines.""" + + config: Optional[ListAgentEngineConfigDict] + """""" + + +_ListAgentEngineRequestParametersOrDict = Union[ + _ListAgentEngineRequestParameters, _ListAgentEngineRequestParametersDict +] + + +class ListReasoningEnginesResponse(_common.BaseModel): + """Response for listing agent engines.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + reasoning_engines: Optional[list[ReasoningEngine]] = Field( + default=None, + description="""List of agent engines. + """, + ) + + +class ListReasoningEnginesResponseDict(TypedDict, total=False): + """Response for listing agent engines.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + reasoning_engines: Optional[list[ReasoningEngineDict]] + """List of agent engines. + """ + + +ListReasoningEnginesResponseOrDict = Union[ + ListReasoningEnginesResponse, ListReasoningEnginesResponseDict +] + + +class GetAgentEngineOperationConfig(_common.BaseModel): + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineOperationConfigDict(TypedDict, total=False): + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineOperationConfigOrDict = Union[ + GetAgentEngineOperationConfig, GetAgentEngineOperationConfigDict +] + + +class _GetAgentEngineOperationParameters(_common.BaseModel): + """Parameters for getting an operation with an agent engine as a response.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetAgentEngineOperationParametersDict(TypedDict, total=False): + """Parameters for getting an operation with an agent engine as a response.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetAgentEngineOperationParametersOrDict = Union[ + _GetAgentEngineOperationParameters, _GetAgentEngineOperationParametersDict +] + + +class QueryAgentEngineConfig(_common.BaseModel): + """Config for querying agent engines.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + class_method: Optional[str] = Field( + default=None, description="""The class method to call.""" + ) + input: Optional[dict[str, Any]] = Field( + default=None, description="""The input to the class method.""" + ) + include_all_fields: Optional[bool] = Field(default=False, description="""""") + + +class 
QueryAgentEngineConfigDict(TypedDict, total=False): + """Config for querying agent engines.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + class_method: Optional[str] + """The class method to call.""" + + input: Optional[dict[str, Any]] + """The input to the class method.""" + + include_all_fields: Optional[bool] + """""" + + +QueryAgentEngineConfigOrDict = Union[QueryAgentEngineConfig, QueryAgentEngineConfigDict] + + +class _QueryAgentEngineRequestParameters(_common.BaseModel): + """Parameters for querying agent engines.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[QueryAgentEngineConfig] = Field(default=None, description="""""") + + +class _QueryAgentEngineRequestParametersDict(TypedDict, total=False): + """Parameters for querying agent engines.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[QueryAgentEngineConfigDict] + """""" + + +_QueryAgentEngineRequestParametersOrDict = Union[ + _QueryAgentEngineRequestParameters, _QueryAgentEngineRequestParametersDict +] + + +class QueryReasoningEngineResponse(_common.BaseModel): + """The response for querying an agent engine.""" + + output: Optional[Any] = Field( + default=None, + description="""Response provided by users in JSON object format.""", + ) + + +class QueryReasoningEngineResponseDict(TypedDict, total=False): + """The response for querying an agent engine.""" + + output: Optional[Any] + """Response provided by users in JSON object format.""" + + +QueryReasoningEngineResponseOrDict = Union[ + QueryReasoningEngineResponse, QueryReasoningEngineResponseDict +] + + +class UpdateAgentEngineConfig(_common.BaseModel): + """Config for updating agent engine.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """, + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, description="""Optional. Configurations of the Agent Engine.""" + ) + context_spec: Optional[ReasoningEngineContextSpec] = Field( + default=None, + description="""Optional. The context spec to be used for the Agent Engine.""", + ) + psc_interface_config: Optional[PscInterfaceConfig] = Field( + default=None, + description="""Optional. The PSC interface config for PSC-I to be used for the + Agent Engine.""", + ) + min_instances: Optional[int] = Field( + default=None, + description="""The minimum number of instances to run for the Agent Engine. + Defaults to 1. Range: [0, 10]. + """, + ) + max_instances: Optional[int] = Field( + default=None, + description="""The maximum number of instances to run for the Agent Engine. + Defaults to 100. Range: [1, 1000]. + If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100]. + """, + ) + resource_limits: Optional[dict[str, str]] = Field( + default=None, + description="""The resource limits to be applied to the Agent Engine. + Required keys: 'cpu' and 'memory'. + Supported values for 'cpu': '1', '2', '4', '6', '8'. + Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'. 
+    """,
+    )
+    container_concurrency: Optional[int] = Field(
+        default=None,
+        description="""The container concurrency to be used for the Agent Engine.
+    Recommended value: 2 * cpu + 1. Defaults to 9.
+    """,
+    )
+    encryption_spec: Optional[genai_types.EncryptionSpec] = Field(
+        default=None,
+        description="""The encryption spec to be used for the Agent Engine.""",
+    )
+    labels: Optional[dict[str, str]] = Field(
+        default=None, description="""The labels to be used for the Agent Engine."""
+    )
+    class_methods: Optional[list[dict[str, Any]]] = Field(
+        default=None,
+        description="""The class methods to be used for the Agent Engine.
+    If specified, they'll override the class methods that are autogenerated by
+    default. By default, methods are generated by inspecting the agent object
+    and generating a corresponding method for each method defined on the
+    agent class.
+    """,
+    )
+    source_packages: Optional[list[str]] = Field(
+        default=None,
+        description="""The user-provided paths to the source packages (if any).
+    If specified, the files in the source packages will be packed into a
+    tarball file, uploaded to Agent Engine's API, and deployed to the
+    Agent Engine.
+    The following fields will be ignored:
+    - agent
+    - extra_packages
+    - staging_bucket
+    - requirements
+    The following fields will be used to install and use the agent from the
+    source packages:
+    - entrypoint_module (required)
+    - entrypoint_object (required)
+    - requirements_file (optional)
+    - class_methods (required)
+    """,
+    )
+    developer_connect_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig
+    ] = Field(
+        default=None,
+        description="""Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use.""",
+    )
+    entrypoint_module: Optional[str] = Field(
+        default=None,
+        description="""The entrypoint module to be used for the Agent Engine.
+    This field is only used when source_packages is specified.""",
+    )
+    entrypoint_object: Optional[str] = Field(
+        default=None,
+        description="""The entrypoint object to be used for the Agent Engine.
+    This field is only used when source_packages is specified.""",
+    )
+    requirements_file: Optional[str] = Field(
+        default=None,
+        description="""The user-provided path to the requirements file (if any).
+    This field is only used when source_packages is specified.
+    If not specified, agent engine will find and use the `requirements.txt` in
+    the source package.
+    """,
+    )
+    agent_framework: Optional[
+        Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
+    ] = Field(
+        default=None,
+        description="""The agent framework to be used for the Agent Engine.
+    The OSS agent framework used to develop the agent.
+    Currently supported values: "google-adk", "langchain", "langgraph",
+    "ag2", "llama-index", "custom".
+    If not specified:
+    - If `agent` is specified, the agent framework will be auto-detected.
+    - If `source_packages` is specified, the agent framework will
+      default to "custom".""",
+    )
+    python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]] = Field(
+        default=None,
+        description="""The Python version to be used for the Agent Engine.
+    If not specified, it will use the current Python version of the environment.
+    Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14".
+    """,
+    )
+    build_options: Optional[dict[str, list[str]]] = Field(
+        default=None,
+        description="""The build options for the Agent Engine. 
+ The following keys are supported: + - installation_scripts: + Optional. The paths to the installation scripts to be + executed in the Docker image. + The scripts must be located in the `installation_scripts` + subdirectory and the path must be added to `extra_packages`. + """, + ) + agent_gateway_config: Optional[ + ReasoningEngineSpecDeploymentSpecAgentGatewayConfig + ] = Field( + default=None, + description="""Agent Gateway configuration for a Reasoning Engine deployment.""", + ) + keep_alive_probe: Optional[KeepAliveProbe] = Field( + default=None, + description="""Optional. Specifies the configuration for keep-alive probe. + Contains configuration on a specified endpoint that a deployment host + should use to keep the container alive based on the probe settings.""", + ) + update_mask: Optional[str] = Field( + default=None, + description="""The update mask to apply. For the `FieldMask` definition, see + https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""", + ) + traffic_config: Optional[ReasoningEngineTrafficConfig] = Field( + default=None, + description="""Traffic distribution configuration for the Reasoning Engine.""", + ) + + +class UpdateAgentEngineConfigDict(TypedDict, total=False): + """Config for updating agent engine.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The user-defined name of the Agent Engine. + + The display name can be up to 128 characters long and can comprise any + UTF-8 characters. + """ + + description: Optional[str] + """The description of the Agent Engine.""" + + spec: Optional[ReasoningEngineSpecDict] + """Optional. Configurations of the Agent Engine.""" + + context_spec: Optional[ReasoningEngineContextSpecDict] + """Optional. The context spec to be used for the Agent Engine.""" + + psc_interface_config: Optional[PscInterfaceConfigDict] + """Optional. The PSC interface config for PSC-I to be used for the + Agent Engine.""" + + min_instances: Optional[int] + """The minimum number of instances to run for the Agent Engine. + Defaults to 1. Range: [0, 10]. + """ + + max_instances: Optional[int] + """The maximum number of instances to run for the Agent Engine. + Defaults to 100. Range: [1, 1000]. + If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100]. + """ + + resource_limits: Optional[dict[str, str]] + """The resource limits to be applied to the Agent Engine. + Required keys: 'cpu' and 'memory'. + Supported values for 'cpu': '1', '2', '4', '6', '8'. + Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'. + """ + + container_concurrency: Optional[int] + """The container concurrency to be used for the Agent Engine. + Recommended value: 2 * cpu + 1. Defaults to 9. + """ + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """The encryption spec to be used for the Agent Engine.""" + + labels: Optional[dict[str, str]] + """The labels to be used for the Agent Engine.""" + + class_methods: Optional[list[dict[str, Any]]] + """The class methods to be used for the Agent Engine. + If specified, they'll override the class methods that are autogenerated by + default. By default, methods are generated by inspecting the agent object + and generating a corresponding method for each method defined on the + agent class. + """ + + source_packages: Optional[list[str]] + """The user-provided paths to the source packages (if any). 
If specified, the files in the source packages will be packed into a
+    tarball file, uploaded to Agent Engine's API, and deployed to the
+    Agent Engine.
+    The following fields will be ignored:
+    - agent
+    - extra_packages
+    - staging_bucket
+    - requirements
+    The following fields will be used to install and use the agent from the
+    source packages:
+    - entrypoint_module (required)
+    - entrypoint_object (required)
+    - requirements_file (optional)
+    - class_methods (required)
+    """
+
+    developer_connect_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict
+    ]
+    """Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use."""
+
+    entrypoint_module: Optional[str]
+    """The entrypoint module to be used for the Agent Engine.
+    This field is only used when source_packages is specified."""
+
+    entrypoint_object: Optional[str]
+    """The entrypoint object to be used for the Agent Engine.
+    This field is only used when source_packages is specified."""
+
+    requirements_file: Optional[str]
+    """The user-provided path to the requirements file (if any).
+    This field is only used when source_packages is specified.
+    If not specified, agent engine will find and use the `requirements.txt` in
+    the source package.
+    """
+
+    agent_framework: Optional[
+        Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
+    ]
+    """The agent framework to be used for the Agent Engine.
+    The OSS agent framework used to develop the agent.
+    Currently supported values: "google-adk", "langchain", "langgraph",
+    "ag2", "llama-index", "custom".
+    If not specified:
+    - If `agent` is specified, the agent framework will be auto-detected.
+    - If `source_packages` is specified, the agent framework will
+      default to "custom"."""
+
+    python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]]
+    """The Python version to be used for the Agent Engine.
+    If not specified, it will use the current Python version of the environment.
+    Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14".
+    """
+
+    build_options: Optional[dict[str, list[str]]]
+    """The build options for the Agent Engine.
+    The following keys are supported:
+    - installation_scripts:
+        Optional. The paths to the installation scripts to be
+        executed in the Docker image.
+        The scripts must be located in the `installation_scripts`
+        subdirectory and the path must be added to `extra_packages`.
+    """
+
+    agent_gateway_config: Optional[
+        ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict
+    ]
+    """Agent Gateway configuration for a Reasoning Engine deployment."""
+
+    keep_alive_probe: Optional[KeepAliveProbeDict]
+    """Optional. Specifies the configuration for keep-alive probe.
+    Contains configuration on a specified endpoint that a deployment host
+    should use to keep the container alive based on the probe settings."""
+
+    update_mask: Optional[str]
+    """The update mask to apply. 
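For the `FieldMask` definition, see
+    https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
+
+    traffic_config: Optional[ReasoningEngineTrafficConfigDict]
+    """Traffic distribution configuration for the Reasoning Engine."""
+
+
+UpdateAgentEngineConfigOrDict = Union[
+    UpdateAgentEngineConfig, UpdateAgentEngineConfigDict
+]
+
+
+# A minimal sketch of assembling an update config. Only the config fields are
+# defined in this module; the client call in the last comment is an assumption
+# shown for illustration, not part of this file's API surface.
+#
+#   config = UpdateAgentEngineConfig(
+#       display_name="my-agent",
+#       description="An updated description.",
+#       update_mask="display_name,description",  # comma-separated FieldMask paths
+#   )
+#   # e.g. client.agent_engines.update(name=engine_name, config=config)
+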
+class _UpdateAgentEngineRequestParameters(_common.BaseModel):
+    """Parameters for updating agent engines."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine."""
+    )
+    config: Optional[UpdateAgentEngineConfig] = Field(default=None, description="""""")
+
+
+class _UpdateAgentEngineRequestParametersDict(TypedDict, total=False):
+    """Parameters for updating agent engines."""
+
+    name: Optional[str]
+    """Name of the agent engine."""
+
+    config: Optional[UpdateAgentEngineConfigDict]
+    """"""
+
+
+_UpdateAgentEngineRequestParametersOrDict = Union[
+    _UpdateAgentEngineRequestParameters, _UpdateAgentEngineRequestParametersDict
+]
+
+
+class MemoryMetadataValue(_common.BaseModel):
+    """The metadata values for memories."""
+
+    bool_value: Optional[bool] = Field(
+        default=None, description="""Represents a boolean value."""
+    )
+    double_value: Optional[float] = Field(
+        default=None, description="""Represents a double value."""
+    )
+    string_value: Optional[str] = Field(
+        default=None, description="""Represents a string value."""
+    )
+    timestamp_value: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Represents a timestamp value. When filtering on timestamp values, only the seconds field will be compared.""",
+    )
+
+
+class MemoryMetadataValueDict(TypedDict, total=False):
+    """The metadata values for memories."""
+
+    bool_value: Optional[bool]
+    """Represents a boolean value."""
+
+    double_value: Optional[float]
+    """Represents a double value."""
+
+    string_value: Optional[str]
+    """Represents a string value."""
+
+    timestamp_value: Optional[datetime.datetime]
+    """Represents a timestamp value. When filtering on timestamp values, only the seconds field will be compared."""
+
+
+MemoryMetadataValueOrDict = Union[MemoryMetadataValue, MemoryMetadataValueDict]
+
+
+class AgentEngineMemoryConfig(_common.BaseModel):
+    """Config for creating a Memory."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    display_name: Optional[str] = Field(
+        default=None, description="""The display name of the memory."""
+    )
+    description: Optional[str] = Field(
+        default=None, description="""The description of the memory."""
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Waits for the operation to complete before returning.""",
+    )
+    ttl: Optional[str] = Field(
+        default=None,
+        description="""Optional. Input only. The TTL for this resource.
+
+    The expiration time is computed: now + TTL.""",
+    )
+    expire_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""",
+    )
+    revision_expire_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Optional. Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""",
+    )
+    revision_ttl: Optional[str] = Field(
+        default=None,
+        description="""Optional. Input only. 
The TTL for the revision. The expiration time is computed: now + TTL.""", + ) + disable_memory_revisions: Optional[bool] = Field( + default=None, + description="""Optional. Input only. If true, no revision will be created for this request.""", + ) + topics: Optional[list[MemoryTopicId]] = Field( + default=None, description="""Optional. The topics of the memory.""" + ) + metadata: Optional[dict[str, MemoryMetadataValue]] = Field( + default=None, + description="""Optional. User-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. It was not generated by Memory Bank.""", + ) + memory_id: Optional[str] = Field( + default=None, + description="""Optional. The user defined ID to use for memory, which will become the final component of the memory resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""", + ) + + +class AgentEngineMemoryConfigDict(TypedDict, total=False): + """Config for creating a Memory.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The display name of the memory.""" + + description: Optional[str] + """The description of the memory.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""" + + expire_time: Optional[datetime.datetime] + """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""" + + revision_expire_time: Optional[datetime.datetime] + """Optional. Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""" + + revision_ttl: Optional[str] + """Optional. Input only. The TTL for the revision. The expiration time is computed: now + TTL.""" + + disable_memory_revisions: Optional[bool] + """Optional. Input only. If true, no revision will be created for this request.""" + + topics: Optional[list[MemoryTopicIdDict]] + """Optional. The topics of the memory.""" + + metadata: Optional[dict[str, MemoryMetadataValueDict]] + """Optional. User-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. It was not generated by Memory Bank.""" + + memory_id: Optional[str] + """Optional. The user defined ID to use for memory, which will become the final component of the memory resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""" + + +AgentEngineMemoryConfigOrDict = Union[ + AgentEngineMemoryConfig, AgentEngineMemoryConfigDict +] + + +class _CreateAgentEngineMemoryRequestParameters(_common.BaseModel): + """Parameters for creating Agent Engine Memories.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine to create the memory under.""", + ) + fact: Optional[str] = Field( + default=None, + description="""The fact of the memory. 
+
+    This is the semantic knowledge extracted from the source content.""",
+    )
+    scope: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The scope of the memory.
+
+    Memories are isolated within their scope. The scope is defined when
+    creating or generating memories. Up to 5 key-value pairs are accepted,
+    and scope values cannot contain the wildcard character '*'.""",
+    )
+    config: Optional[AgentEngineMemoryConfig] = Field(default=None, description="""""")
+
+
+class _CreateAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for creating Agent Engine Memories."""
+
+    name: Optional[str]
+    """Name of the agent engine to create the memory under."""
+
+    fact: Optional[str]
+    """The fact of the memory.
+
+    This is the semantic knowledge extracted from the source content."""
+
+    scope: Optional[dict[str, str]]
+    """The scope of the memory.
+
+    Memories are isolated within their scope. The scope is defined when
+    creating or generating memories. Up to 5 key-value pairs are accepted,
+    and scope values cannot contain the wildcard character '*'."""
+
+    config: Optional[AgentEngineMemoryConfigDict]
+    """"""
+
+
+_CreateAgentEngineMemoryRequestParametersOrDict = Union[
+    _CreateAgentEngineMemoryRequestParameters,
+    _CreateAgentEngineMemoryRequestParametersDict,
+]
+
+
+class MemoryStructuredContent(_common.BaseModel):
+    """Represents the structured value of the memory."""
+
+    data: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Required. Represents the structured value of the memory.""",
+    )
+    schema_id: Optional[str] = Field(
+        default=None,
+        description="""Required. Represents the schema ID to which this structured memory belongs.""",
+    )
+
+
+class MemoryStructuredContentDict(TypedDict, total=False):
+    """Represents the structured value of the memory."""
+
+    data: Optional[dict[str, Any]]
+    """Required. Represents the structured value of the memory."""
+
+    schema_id: Optional[str]
+    """Required. Represents the schema ID to which this structured memory belongs."""
+
+
+MemoryStructuredContentOrDict = Union[
+    MemoryStructuredContent, MemoryStructuredContentDict
+]
+
+
+class Memory(_common.BaseModel):
+    """A memory."""
+
+    create_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Output only. Represents the timestamp when this Memory was created.""",
+    )
+    description: Optional[str] = Field(
+        default=None,
+        description="""Optional. Represents the description of the Memory.""",
+    )
+    disable_memory_revisions: Optional[bool] = Field(
+        default=None,
+        description="""Optional. Input only. Indicates whether no revision will be created for this request.""",
+    )
+    display_name: Optional[str] = Field(
+        default=None,
+        description="""Optional. Represents the display name of the Memory.""",
+    )
+    expire_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Optional. Represents the timestamp of when this resource is considered expired. This is *always* provided on output when `expiration` is set on input, regardless of whether `expire_time` or `ttl` was provided.""",
+    )
+    fact: Optional[str] = Field(
+        default=None,
+        description="""Optional. Represents semantic knowledge extracted from the source content.""",
+    )
+    metadata: Optional[dict[str, MemoryMetadataValue]] = Field(
+        default=None,
+        description="""Optional. Represents user-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. 
It was not generated by Memory Bank.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. Represents the resource name of the Memory. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}`""", + ) + revision_expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Input only. Represents the timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""", + ) + revision_labels: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. Input only. Represents the labels to apply to the Memory Revision created as a result of this request.""", + ) + revision_ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. Represents the TTL for the revision. The expiration time is computed: now + TTL.""", + ) + scope: Optional[dict[str, str]] = Field( + default=None, + description="""Required. Immutable. Represents the scope of the Memory. Memories are isolated within their scope. The scope is defined when creating or generating memories. Scope values cannot contain the wildcard character '*'.""", + ) + topics: Optional[list[MemoryTopicId]] = Field( + default=None, description="""Optional. Represents the Topics of the Memory.""" + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. Represents the TTL for this resource. The expiration time is computed: now + TTL.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Represents the timestamp when this Memory was most recently updated.""", + ) + memory_type: Optional[MemoryType] = Field( + default=None, + description="""Optional. Represents the type of the memory. If not set, the `NATURAL_LANGUAGE_COLLECTION` type is used. If `STRUCTURED_COLLECTION` or `STRUCTURED_PROFILE` is used, then `structured_data` must be provided.""", + ) + structured_content: Optional[MemoryStructuredContent] = Field( + default=None, + description="""Optional. Represents the structured content of the memory.""", + ) + + +class MemoryDict(TypedDict, total=False): + """A memory.""" + + create_time: Optional[datetime.datetime] + """Output only. Represents the timestamp when this Memory was created.""" + + description: Optional[str] + """Optional. Represents the description of the Memory.""" + + disable_memory_revisions: Optional[bool] + """Optional. Input only. Indicates whether no revision will be created for this request.""" + + display_name: Optional[str] + """Optional. Represents the display name of the Memory.""" + + expire_time: Optional[datetime.datetime] + """Optional. Represents the timestamp of when this resource is considered expired. This is *always* provided on output when `expiration` is set on input, regardless of whether `expire_time` or `ttl` was provided.""" + + fact: Optional[str] + """Optional. Represents semantic knowledge extracted from the source content.""" + + metadata: Optional[dict[str, MemoryMetadataValueDict]] + """Optional. Represents user-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. It was not generated by Memory Bank.""" + + name: Optional[str] + """Identifier. Represents the resource name of the Memory. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}`""" + + revision_expire_time: Optional[datetime.datetime] + """Optional. 
Input only. Represents the timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""" + + revision_labels: Optional[dict[str, str]] + """Optional. Input only. Represents the labels to apply to the Memory Revision created as a result of this request.""" + + revision_ttl: Optional[str] + """Optional. Input only. Represents the TTL for the revision. The expiration time is computed: now + TTL.""" + + scope: Optional[dict[str, str]] + """Required. Immutable. Represents the scope of the Memory. Memories are isolated within their scope. The scope is defined when creating or generating memories. Scope values cannot contain the wildcard character '*'.""" + + topics: Optional[list[MemoryTopicIdDict]] + """Optional. Represents the Topics of the Memory.""" + + ttl: Optional[str] + """Optional. Input only. Represents the TTL for this resource. The expiration time is computed: now + TTL.""" + + update_time: Optional[datetime.datetime] + """Output only. Represents the timestamp when this Memory was most recently updated.""" + + memory_type: Optional[MemoryType] + """Optional. Represents the type of the memory. If not set, the `NATURAL_LANGUAGE_COLLECTION` type is used. If `STRUCTURED_COLLECTION` or `STRUCTURED_PROFILE` is used, then `structured_data` must be provided.""" + + structured_content: Optional[MemoryStructuredContentDict] + """Optional. Represents the structured content of the memory.""" + + +MemoryOrDict = Union[Memory, MemoryDict] + + +class AgentEngineMemoryOperation(_common.BaseModel): + """Operation that has an agent engine memory as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[Memory] = Field( + default=None, description="""The Agent Engine Memory.""" + ) + + +class AgentEngineMemoryOperationDict(TypedDict, total=False): + """Operation that has an agent engine memory as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
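Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+    response: Optional[MemoryDict]
+    """The Agent Engine Memory."""
+
+
+AgentEngineMemoryOperationOrDict = Union[
+    AgentEngineMemoryOperation, AgentEngineMemoryOperationDict
+]
+
+
+# A minimal sketch of inspecting a completed AgentEngineMemoryOperation. The
+# `operation` value is assumed to come from a create-memory call elsewhere.
+#
+#   if operation.done:
+#       if operation.error:
+#           raise RuntimeError(f"Memory creation failed: {operation.error}")
+#       memory = operation.response  # a `Memory` instance
+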
+class DeleteAgentEngineMemoryConfig(_common.BaseModel):
+    """Config for deleting an Agent Engine Memory."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class DeleteAgentEngineMemoryConfigDict(TypedDict, total=False):
+    """Config for deleting an Agent Engine Memory."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+DeleteAgentEngineMemoryConfigOrDict = Union[
+    DeleteAgentEngineMemoryConfig, DeleteAgentEngineMemoryConfigDict
+]
+
+
+class _DeleteAgentEngineMemoryRequestParameters(_common.BaseModel):
+    """Parameters for deleting agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine memory to delete."""
+    )
+    config: Optional[DeleteAgentEngineMemoryConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _DeleteAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for deleting agent engine memories."""
+
+    name: Optional[str]
+    """Name of the agent engine memory to delete."""
+
+    config: Optional[DeleteAgentEngineMemoryConfigDict]
+    """"""
+
+
+_DeleteAgentEngineMemoryRequestParametersOrDict = Union[
+    _DeleteAgentEngineMemoryRequestParameters,
+    _DeleteAgentEngineMemoryRequestParametersDict,
+]
+
+
+class DeleteAgentEngineMemoryOperation(_common.BaseModel):
+    """Operation for deleting agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+
+
+class DeleteAgentEngineMemoryOperationDict(TypedDict, total=False):
+    """Operation for deleting agent engine memories."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. 
It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteAgentEngineMemoryOperationOrDict = Union[ + DeleteAgentEngineMemoryOperation, DeleteAgentEngineMemoryOperationDict +] + + +class GenerateMemoriesRequestVertexSessionSource(_common.BaseModel): + """The vertex session source for generating memories.""" + + end_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. End time (exclusive) of the time range. If not set, the end time is unbounded.""", + ) + session: Optional[str] = Field( + default=None, + description="""Required. The resource name of the Session to generate memories for. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}`""", + ) + start_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Time range to define which session events should be used to generate memories. Start time (inclusive) of the time range. If not set, the start time is unbounded.""", + ) + + +class GenerateMemoriesRequestVertexSessionSourceDict(TypedDict, total=False): + """The vertex session source for generating memories.""" + + end_time: Optional[datetime.datetime] + """Optional. End time (exclusive) of the time range. If not set, the end time is unbounded.""" + + session: Optional[str] + """Required. The resource name of the Session to generate memories for. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}`""" + + start_time: Optional[datetime.datetime] + """Optional. Time range to define which session events should be used to generate memories. Start time (inclusive) of the time range. If not set, the start time is unbounded.""" + + +GenerateMemoriesRequestVertexSessionSourceOrDict = Union[ + GenerateMemoriesRequestVertexSessionSource, + GenerateMemoriesRequestVertexSessionSourceDict, +] + + +class GenerateMemoriesRequestDirectContentsSourceEvent(_common.BaseModel): + + content: Optional[genai_types.Content] = Field( + default=None, + description="""Required. A single piece of content from which to generate memories.""", + ) + + +class GenerateMemoriesRequestDirectContentsSourceEventDict(TypedDict, total=False): + + content: Optional[genai_types.ContentDict] + """Required. A single piece of content from which to generate memories.""" + + +GenerateMemoriesRequestDirectContentsSourceEventOrDict = Union[ + GenerateMemoriesRequestDirectContentsSourceEvent, + GenerateMemoriesRequestDirectContentsSourceEventDict, +] + + +class GenerateMemoriesRequestDirectContentsSource(_common.BaseModel): + """The direct contents source for generating memories.""" + + events: Optional[list[GenerateMemoriesRequestDirectContentsSourceEvent]] = Field( + default=None, + description="""Required. The source content (i.e. 
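chat history) to generate memories from.""",
+    )
+
+
+class GenerateMemoriesRequestDirectContentsSourceDict(TypedDict, total=False):
+    """The direct contents source for generating memories."""
+
+    events: Optional[list[GenerateMemoriesRequestDirectContentsSourceEventDict]]
+    """Required. The source content (i.e. chat history) to generate memories from."""
+
+
+GenerateMemoriesRequestDirectContentsSourceOrDict = Union[
+    GenerateMemoriesRequestDirectContentsSource,
+    GenerateMemoriesRequestDirectContentsSourceDict,
+]
+
+
+# A minimal sketch of building a direct-contents source from chat history. It
+# assumes `genai_types` re-exports `Content` and `Part`; the role/text values
+# are illustrative only.
+#
+#   source = GenerateMemoriesRequestDirectContentsSource(
+#       events=[
+#           GenerateMemoriesRequestDirectContentsSourceEvent(
+#               content=genai_types.Content(
+#                   role="user",
+#                   parts=[genai_types.Part(text="I prefer window seats.")],
+#               )
+#           )
+#       ]
+#   )
+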
+class GenerateMemoriesRequestDirectMemoriesSourceDirectMemory(_common.BaseModel):
+    """A direct memory to upload to Memory Bank."""
+
+    fact: Optional[str] = Field(
+        default=None,
+        description="""Required. The fact to consolidate with existing memories.""",
+    )
+    topics: Optional[list[MemoryTopicId]] = Field(
+        default=None,
+        description="""Optional. The topics that the consolidated memories should be associated with.""",
+    )
+
+
+class GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict(
+    TypedDict, total=False
+):
+    """A direct memory to upload to Memory Bank."""
+
+    fact: Optional[str]
+    """Required. The fact to consolidate with existing memories."""
+
+    topics: Optional[list[MemoryTopicIdDict]]
+    """Optional. The topics that the consolidated memories should be associated with."""
+
+
+GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryOrDict = Union[
+    GenerateMemoriesRequestDirectMemoriesSourceDirectMemory,
+    GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict,
+]
+
+
+class GenerateMemoriesRequestDirectMemoriesSource(_common.BaseModel):
+    """The direct memories source for generating memories."""
+
+    direct_memories: Optional[
+        list[GenerateMemoriesRequestDirectMemoriesSourceDirectMemory]
+    ] = Field(
+        default=None,
+        description="""Required. The direct memories to upload to Memory Bank. At most 5 direct memories are allowed per request.""",
+    )
+
+
+class GenerateMemoriesRequestDirectMemoriesSourceDict(TypedDict, total=False):
+    """The direct memories source for generating memories."""
+
+    direct_memories: Optional[
+        list[GenerateMemoriesRequestDirectMemoriesSourceDirectMemoryDict]
+    ]
+    """Required. The direct memories to upload to Memory Bank. At most 5 direct memories are allowed per request."""
+
+
+GenerateMemoriesRequestDirectMemoriesSourceOrDict = Union[
+    GenerateMemoriesRequestDirectMemoriesSource,
+    GenerateMemoriesRequestDirectMemoriesSourceDict,
+]
+
+
+class GenerateAgentEngineMemoriesConfig(_common.BaseModel):
+    """Config for generating memories."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    disable_consolidation: Optional[bool] = Field(
+        default=None,
+        description="""Whether to disable consolidation of memories.
+
+    If true, generated memories will not be consolidated with existing
+    memories; all generated memories will be added as new memories regardless
+    of whether they are duplicates of or contradictory to existing memories.
+    By default, memory consolidation is enabled.""",
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Waits for the operation to complete before returning.""",
+    )
+    revision_labels: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""Labels to apply to the memory revision. For example, you can use this to label a revision with its data source.""",
+    )
+    revision_expire_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Optional. 
Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""", + ) + revision_ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for the revision. The expiration time is computed: now + TTL.""", + ) + disable_memory_revisions: Optional[bool] = Field( + default=None, + description="""Optional. Input only. If true, no revisions will be created for this request.""", + ) + metadata: Optional[dict[str, MemoryMetadataValue]] = Field( + default=None, + description="""Optional. User-provided metadata for the generated memories. This is not generated by Memory Bank.""", + ) + metadata_merge_strategy: Optional[MemoryMetadataMergeStrategy] = Field( + default=None, + description="""Optional. The strategy to use when applying metadata to existing memories.""", + ) + allowed_topics: Optional[list[MemoryTopicId]] = Field( + default=None, + description="""Optional. Restricts memory generation to a subset of memory topics.""", + ) + + +class GenerateAgentEngineMemoriesConfigDict(TypedDict, total=False): + """Config for generating memories.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + disable_consolidation: Optional[bool] + """Whether to disable consolidation of memories. + + If true, generated memories will not be consolidated with existing + memories; all generated memories will be added as new memories regardless + of whether they are duplicates of or contradictory to existing memories. + By default, memory consolidation is enabled.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + revision_labels: Optional[dict[str, str]] + """Labels to apply to the memory revision. For example, you can use this to label a revision with its data source.""" + + revision_expire_time: Optional[datetime.datetime] + """Optional. Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""" + + revision_ttl: Optional[str] + """Optional. Input only. The TTL for the revision. The expiration time is computed: now + TTL.""" + + disable_memory_revisions: Optional[bool] + """Optional. Input only. If true, no revisions will be created for this request.""" + + metadata: Optional[dict[str, MemoryMetadataValueDict]] + """Optional. User-provided metadata for the generated memories. This is not generated by Memory Bank.""" + + metadata_merge_strategy: Optional[MemoryMetadataMergeStrategy] + """Optional. The strategy to use when applying metadata to existing memories.""" + + allowed_topics: Optional[list[MemoryTopicIdDict]] + """Optional. 
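Restricts memory generation to a subset of memory topics."""
+
+
+GenerateAgentEngineMemoriesConfigOrDict = Union[
+    GenerateAgentEngineMemoriesConfig, GenerateAgentEngineMemoriesConfigDict
+]
+
+
+# A minimal sketch of a generate-memories config. The "<seconds>s" string for
+# `revision_ttl` is an assumption based on the TTL fields' "now + TTL" wording;
+# the label values are illustrative only.
+#
+#   config = GenerateAgentEngineMemoriesConfig(
+#       wait_for_completion=True,
+#       revision_labels={"source": "support-chat"},
+#       revision_ttl="86400s",  # keep the revision for one day
+#   )
+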
+class _GenerateAgentEngineMemoriesRequestParameters(_common.BaseModel):
+    """Parameters for generating agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""Name of the agent engine to generate memories for.""",
+    )
+    vertex_session_source: Optional[GenerateMemoriesRequestVertexSessionSource] = Field(
+        default=None,
+        description="""The vertex session source of the memories that should be generated.""",
+    )
+    direct_contents_source: Optional[GenerateMemoriesRequestDirectContentsSource] = (
+        Field(
+            default=None,
+            description="""The direct contents source of the memories that should be generated.""",
+        )
+    )
+    direct_memories_source: Optional[GenerateMemoriesRequestDirectMemoriesSource] = (
+        Field(
+            default=None,
+            description="""The direct memories source of the memories that should be generated.""",
+        )
+    )
+    scope: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The scope of the memories that should be generated.
+
+    Memories will be consolidated across memories with the same scope. Must be
+    provided unless the scope is defined in the source content. If `scope` is
+    provided, it will override the scope defined in the source content. Scope
+    values cannot contain the wildcard character '*'.""",
+    )
+    config: Optional[GenerateAgentEngineMemoriesConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _GenerateAgentEngineMemoriesRequestParametersDict(TypedDict, total=False):
+    """Parameters for generating agent engine memories."""
+
+    name: Optional[str]
+    """Name of the agent engine to generate memories for."""
+
+    vertex_session_source: Optional[GenerateMemoriesRequestVertexSessionSourceDict]
+    """The vertex session source of the memories that should be generated."""
+
+    direct_contents_source: Optional[GenerateMemoriesRequestDirectContentsSourceDict]
+    """The direct contents source of the memories that should be generated."""
+
+    direct_memories_source: Optional[GenerateMemoriesRequestDirectMemoriesSourceDict]
+    """The direct memories source of the memories that should be generated."""
+
+    scope: Optional[dict[str, str]]
+    """The scope of the memories that should be generated.
+
+    Memories will be consolidated across memories with the same scope. Must be
+    provided unless the scope is defined in the source content. If `scope` is
+    provided, it will override the scope defined in the source content. Scope
+    values cannot contain the wildcard character '*'."""
+
+    config: Optional[GenerateAgentEngineMemoriesConfigDict]
+    """"""
+
+
+_GenerateAgentEngineMemoriesRequestParametersOrDict = Union[
+    _GenerateAgentEngineMemoriesRequestParameters,
+    _GenerateAgentEngineMemoriesRequestParametersDict,
+]
+
+
+class GenerateMemoriesResponseGeneratedMemory(_common.BaseModel):
+    """A memory that was generated."""
+
+    memory: Optional[Memory] = Field(
+        default=None, description="""The generated memory."""
+    )
+    action: Optional[GenerateMemoriesResponseGeneratedMemoryAction] = Field(
+        default=None, description="""The action to take."""
+    )
+    previous_revision: Optional[str] = Field(
+        default=None,
+        description="""The previous revision of the Memory before the action was performed. This
+    field is only set if the action is `UPDATED` or `DELETED`. You can use
+    this to rollback the Memory to the previous revision, undoing the action. 
+ Format: + `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}/revisions/{revision}`""", + ) + + +class GenerateMemoriesResponseGeneratedMemoryDict(TypedDict, total=False): + """A memory that was generated.""" + + memory: Optional[MemoryDict] + """The generated memory.""" + + action: Optional[GenerateMemoriesResponseGeneratedMemoryAction] + """The action to take.""" + + previous_revision: Optional[str] + """The previous revision of the Memory before the action was performed. This + field is only set if the action is `UPDATED` or `DELETED`. You can use + this to rollback the Memory to the previous revision, undoing the action. + Format: + `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}/revisions/{revision}`""" + + +GenerateMemoriesResponseGeneratedMemoryOrDict = Union[ + GenerateMemoriesResponseGeneratedMemory, GenerateMemoriesResponseGeneratedMemoryDict +] + + +class GenerateMemoriesResponse(_common.BaseModel): + """The response for generating memories.""" + + generated_memories: Optional[list[GenerateMemoriesResponseGeneratedMemory]] = Field( + default=None, description="""The generated memories.""" + ) + + +class GenerateMemoriesResponseDict(TypedDict, total=False): + """The response for generating memories.""" + + generated_memories: Optional[list[GenerateMemoriesResponseGeneratedMemoryDict]] + """The generated memories.""" + + +GenerateMemoriesResponseOrDict = Union[ + GenerateMemoriesResponse, GenerateMemoriesResponseDict +] + + +class AgentEngineGenerateMemoriesOperation(_common.BaseModel): + """Operation that generates memories for an agent engine.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[GenerateMemoriesResponse] = Field( + default=None, description="""The response for generating memories.""" + ) + + +class AgentEngineGenerateMemoriesOperationDict(TypedDict, total=False): + """Operation that generates memories for an agent engine.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
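Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+    response: Optional[GenerateMemoriesResponseDict]
+    """The response for generating memories."""
+
+
+AgentEngineGenerateMemoriesOperationOrDict = Union[
+    AgentEngineGenerateMemoriesOperation, AgentEngineGenerateMemoriesOperationDict
+]
+
+
+# A minimal sketch of reading the result of a completed generate-memories
+# operation; `operation` is assumed to come from a generate-memories call
+# elsewhere.
+#
+#   if operation.done and operation.response:
+#       for generated in operation.response.generated_memories or []:
+#           fact = generated.memory.fact if generated.memory else None
+#           print(generated.action, fact)
+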
+class GetAgentEngineMemoryConfig(_common.BaseModel):
+    """Config for getting an Agent Engine Memory."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetAgentEngineMemoryConfigDict(TypedDict, total=False):
+    """Config for getting an Agent Engine Memory."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetAgentEngineMemoryConfigOrDict = Union[
+    GetAgentEngineMemoryConfig, GetAgentEngineMemoryConfigDict
+]
+
+
+class _GetAgentEngineMemoryRequestParameters(_common.BaseModel):
+    """Parameters for getting an agent engine memory."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine memory."""
+    )
+    config: Optional[GetAgentEngineMemoryConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _GetAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for getting an agent engine memory."""
+
+    name: Optional[str]
+    """Name of the agent engine memory."""
+
+    config: Optional[GetAgentEngineMemoryConfigDict]
+    """"""
+
+
+_GetAgentEngineMemoryRequestParametersOrDict = Union[
+    _GetAgentEngineMemoryRequestParameters, _GetAgentEngineMemoryRequestParametersDict
+]
+
+
+class IngestionDirectContentsSourceEvent(_common.BaseModel):
+    """The direct contents source event for ingesting events."""
+
+    content: Optional[genai_types.Content] = Field(
+        default=None, description="""Required. The content of the event."""
+    )
+    event_id: Optional[str] = Field(
+        default=None,
+        description="""Optional. A unique identifier for the event. If an event with the same event_id is ingested multiple times, it will be de-duplicated.""",
+    )
+    event_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Optional. The time at which the event occurred. If provided, this timestamp will be used for ordering events within a stream. If not provided, the server-side ingestion time will be used.""",
+    )
+
+
+class IngestionDirectContentsSourceEventDict(TypedDict, total=False):
+    """The direct contents source event for ingesting events."""
+
+    content: Optional[genai_types.ContentDict]
+    """Required. The content of the event."""
+
+    event_id: Optional[str]
+    """Optional. A unique identifier for the event. If an event with the same event_id is ingested multiple times, it will be de-duplicated."""
+
+    event_time: Optional[datetime.datetime]
+    """Optional. The time at which the event occurred. If provided, this timestamp will be used for ordering events within a stream. 
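If not provided, the server-side ingestion time will be used."""
+
+
+IngestionDirectContentsSourceEventOrDict = Union[
+    IngestionDirectContentsSourceEvent, IngestionDirectContentsSourceEventDict
+]
+
+
+# A minimal sketch of an event to ingest. It assumes `genai_types` re-exports
+# `Content` and `Part`; the identifier, timestamp, and text are illustrative
+# only.
+#
+#   event = IngestionDirectContentsSourceEvent(
+#       event_id="turn-42",
+#       event_time=datetime.datetime.now(datetime.timezone.utc),
+#       content=genai_types.Content(
+#           role="user", parts=[genai_types.Part(text="Book me an aisle seat.")]
+#       ),
+#   )
+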
+
+
+class IngestEventsConfig(_common.BaseModel):
+    """Config for ingesting events."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=False,
+        description="""Waits for the underlying memory generation operation to complete
+      before returning. Defaults to false.""",
+    )
+    force_flush: Optional[bool] = Field(
+        default=None,
+        description="""Optional. Forces a flush of all pending events in the stream and triggers memory generation immediately, bypassing any conditions configured in the `generation_trigger_config`.""",
+    )
+
+
+class IngestEventsConfigDict(TypedDict, total=False):
+    """Config for ingesting events."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    wait_for_completion: Optional[bool]
+    """Waits for the underlying memory generation operation to complete
+    before returning. Defaults to false."""
+
+    force_flush: Optional[bool]
+    """Optional. Forces a flush of all pending events in the stream and triggers memory generation immediately, bypassing any conditions configured in the `generation_trigger_config`."""
+
+
+IngestEventsConfigOrDict = Union[IngestEventsConfig, IngestEventsConfigDict]
+
+
+class _IngestEventsRequestParameters(_common.BaseModel):
+    """Parameters for ingesting events into an agent engine."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the Agent Engine to ingest events into."""
+    )
+    stream_id: Optional[str] = Field(
+        default=None, description="""The ID of the stream to ingest events into."""
+    )
+    direct_contents_source: Optional[IngestionDirectContentsSource] = Field(
+        default=None,
+        description="""The direct memories source of the events that should be ingested.""",
+    )
+    scope: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The scope of the memories that should be generated from the stream.
+
+      Memories will be consolidated across memories with the same scope. Scope
+      values cannot contain the wildcard character '*'.""",
+    )
+    generation_trigger_config: Optional[MemoryGenerationTriggerConfig] = Field(
+        default=None,
+        description="""The configuration for the memory generation trigger.""",
+    )
+    config: Optional[IngestEventsConfig] = Field(default=None, description="""""")
+
+
+class _IngestEventsRequestParametersDict(TypedDict, total=False):
+    """Parameters for ingesting events into an agent engine."""
+
+    name: Optional[str]
+    """Name of the Agent Engine to ingest events into."""
+
+    stream_id: Optional[str]
+    """The ID of the stream to ingest events into."""
+
+    direct_contents_source: Optional[IngestionDirectContentsSourceDict]
+    """The direct memories source of the events that should be ingested."""
+
+    scope: Optional[dict[str, str]]
+    """The scope of the memories that should be generated from the stream.
+
+    Memories will be consolidated across memories with the same scope. Scope
+    values cannot contain the wildcard character '*'."""
+
+    generation_trigger_config: Optional[MemoryGenerationTriggerConfigDict]
+    """The configuration for the memory generation trigger."""
+
+    config: Optional[IngestEventsConfigDict]
+    """"""
+
+
+_IngestEventsRequestParametersOrDict = Union[
+    _IngestEventsRequestParameters, _IngestEventsRequestParametersDict
+]
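+
+
+# Example (editor's sketch): an ingest-events config that blocks until the
+# triggered memory generation completes and force-flushes pending events.
+# Wherever an `...OrDict` union is accepted, the TypedDict form below is
+# interchangeable with the model form.
+#
+#     config = IngestEventsConfig(wait_for_completion=True, force_flush=True)
+#
+#     config_dict: IngestEventsConfigDict = {
+#         "wait_for_completion": True,
+#         "force_flush": True,
+#     }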
+
+
+class MemoryBankIngestEventsOperation(_common.BaseModel):
+    """Operation that ingests events into a memory bank."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+
+
+class MemoryBankIngestEventsOperationDict(TypedDict, total=False):
+    """Operation that ingests events into a memory bank."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress.
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +MemoryBankIngestEventsOperationOrDict = Union[ + MemoryBankIngestEventsOperation, MemoryBankIngestEventsOperationDict +] + + +class ListAgentEngineMemoryConfig(_common.BaseModel): + """Config for listing agent engine memories.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + order_by: Optional[str] = Field( + default=None, + description="""The standard list order by string. If not specified, the default + order is `create_time desc`. If specified, the default sorting order of + provided fields is ascending. More detail in + [AIP-132](https://google.aip.dev/132). + + Supported fields: + * `create_time` + * `update_time`""", + ) + + +class ListAgentEngineMemoryConfigDict(TypedDict, total=False): + """Config for listing agent engine memories.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + order_by: Optional[str] + """The standard list order by string. If not specified, the default + order is `create_time desc`. If specified, the default sorting order of + provided fields is ascending. More detail in + [AIP-132](https://google.aip.dev/132). 
+
+      Supported fields:
+      * `create_time`
+      * `update_time`"""
+
+
+ListAgentEngineMemoryConfigOrDict = Union[
+    ListAgentEngineMemoryConfig, ListAgentEngineMemoryConfigDict
+]
+
+
+class _ListAgentEngineMemoryRequestParameters(_common.BaseModel):
+    """Parameters for listing agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine."""
+    )
+    config: Optional[ListAgentEngineMemoryConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _ListAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for listing agent engine memories."""
+
+    name: Optional[str]
+    """Name of the agent engine."""
+
+    config: Optional[ListAgentEngineMemoryConfigDict]
+    """"""
+
+
+_ListAgentEngineMemoryRequestParametersOrDict = Union[
+    _ListAgentEngineMemoryRequestParameters, _ListAgentEngineMemoryRequestParametersDict
+]
+
+
+class ListReasoningEnginesMemoriesResponse(_common.BaseModel):
+    """Response for listing agent engine memories."""
+
+    sdk_http_response: Optional[genai_types.HttpResponse] = Field(
+        default=None, description="""Used to retain the full HTTP response."""
+    )
+    next_page_token: Optional[str] = Field(default=None, description="""""")
+    memories: Optional[list[Memory]] = Field(
+        default=None, description="""List of agent engine memories."""
+    )
+
+
+class ListReasoningEnginesMemoriesResponseDict(TypedDict, total=False):
+    """Response for listing agent engine memories."""
+
+    sdk_http_response: Optional[genai_types.HttpResponseDict]
+    """Used to retain the full HTTP response."""
+
+    next_page_token: Optional[str]
+    """"""
+
+    memories: Optional[list[MemoryDict]]
+    """List of agent engine memories."""
+
+
+ListReasoningEnginesMemoriesResponseOrDict = Union[
+    ListReasoningEnginesMemoriesResponse, ListReasoningEnginesMemoriesResponseDict
+]
+
+
+class _GetAgentEngineMemoryOperationParameters(_common.BaseModel):
+    """Parameters for getting an operation with a memory as a response."""
+
+    operation_name: Optional[str] = Field(
+        default=None, description="""The server-assigned name for the operation."""
+    )
+    config: Optional[GetAgentEngineOperationConfig] = Field(
+        default=None, description="""Used to override the default configuration."""
+    )
+
+
+class _GetAgentEngineMemoryOperationParametersDict(TypedDict, total=False):
+    """Parameters for getting an operation with a memory as a response."""
+
+    operation_name: Optional[str]
+    """The server-assigned name for the operation."""
+
+    config: Optional[GetAgentEngineOperationConfigDict]
+    """Used to override the default configuration."""
+
+
+_GetAgentEngineMemoryOperationParametersOrDict = Union[
+    _GetAgentEngineMemoryOperationParameters,
+    _GetAgentEngineMemoryOperationParametersDict,
+]
+
+
+class _GetAgentEngineGenerateMemoriesOperationParameters(_common.BaseModel):
+    """Parameters for getting an operation with generated memories as a response."""
+
+    operation_name: Optional[str] = Field(
+        default=None, description="""The server-assigned name for the operation."""
+    )
+    config: Optional[GetAgentEngineOperationConfig] = Field(
+        default=None, description="""Used to override the default configuration."""
+    )
+
+
+class _GetAgentEngineGenerateMemoriesOperationParametersDict(TypedDict, total=False):
+    """Parameters for getting an operation with generated memories as a response."""
+
+    operation_name: Optional[str]
+    """The server-assigned name for the operation."""
+
+    config: Optional[GetAgentEngineOperationConfigDict]
+    """Used to override the default configuration."""
+
+
+_GetAgentEngineGenerateMemoriesOperationParametersOrDict = Union[ + _GetAgentEngineGenerateMemoriesOperationParameters, + _GetAgentEngineGenerateMemoriesOperationParametersDict, +] + + +class RetrieveMemoriesRequestSimilaritySearchParams(_common.BaseModel): + """The parameters for semantic similarity search based retrieval.""" + + search_query: Optional[str] = Field( + default=None, + description="""Required. Query to use for similarity search retrieval. If provided, then the parent ReasoningEngine must have ReasoningEngineContextSpec.MemoryBankConfig.SimilaritySearchConfig set.""", + ) + top_k: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of memories to return. The service may return fewer than this value. If unspecified, at most 3 memories will be returned. The maximum value is 100; values above 100 will be coerced to 100.""", + ) + + +class RetrieveMemoriesRequestSimilaritySearchParamsDict(TypedDict, total=False): + """The parameters for semantic similarity search based retrieval.""" + + search_query: Optional[str] + """Required. Query to use for similarity search retrieval. If provided, then the parent ReasoningEngine must have ReasoningEngineContextSpec.MemoryBankConfig.SimilaritySearchConfig set.""" + + top_k: Optional[int] + """Optional. The maximum number of memories to return. The service may return fewer than this value. If unspecified, at most 3 memories will be returned. The maximum value is 100; values above 100 will be coerced to 100.""" + + +RetrieveMemoriesRequestSimilaritySearchParamsOrDict = Union[ + RetrieveMemoriesRequestSimilaritySearchParams, + RetrieveMemoriesRequestSimilaritySearchParamsDict, +] + + +class RetrieveMemoriesRequestSimpleRetrievalParams(_common.BaseModel): + """The parameters for simple (non-similarity search) retrieval.""" + + page_size: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of memories to return. The service may return fewer than this value. If unspecified, at most 3 memories will be returned. The maximum value is 100; values above 100 will be coerced to 100.""", + ) + page_token: Optional[str] = Field( + default=None, + description="""Optional. A page token, received from a previous `RetrieveMemories` call. Provide this to retrieve the subsequent page.""", + ) + + +class RetrieveMemoriesRequestSimpleRetrievalParamsDict(TypedDict, total=False): + """The parameters for simple (non-similarity search) retrieval.""" + + page_size: Optional[int] + """Optional. The maximum number of memories to return. The service may return fewer than this value. If unspecified, at most 3 memories will be returned. The maximum value is 100; values above 100 will be coerced to 100.""" + + page_token: Optional[str] + """Optional. A page token, received from a previous `RetrieveMemories` call. Provide this to retrieve the subsequent page.""" + + +RetrieveMemoriesRequestSimpleRetrievalParamsOrDict = Union[ + RetrieveMemoriesRequestSimpleRetrievalParams, + RetrieveMemoriesRequestSimpleRetrievalParamsDict, +] + + +class MemoryFilter(_common.BaseModel): + """Filter to apply when retrieving memories.""" + + key: Optional[str] = Field( + default=None, + description="""Represents the key of the filter. 
For example, "author" would apply to `metadata` entries with the key "author".""",
+    )
+    negate: Optional[bool] = Field(
+        default=None, description="""Indicates whether the filter will be negated."""
+    )
+    op: Optional[Operator] = Field(
+        default=None,
+        description="""Represents the operator to apply to the filter. If not set, then EQUAL will be used.""",
+    )
+    value: Optional[MemoryMetadataValue] = Field(
+        default=None, description="""Represents the value to compare to."""
+    )
+
+
+class MemoryFilterDict(TypedDict, total=False):
+    """Filter to apply when retrieving memories."""
+
+    key: Optional[str]
+    """Represents the key of the filter. For example, "author" would apply to `metadata` entries with the key "author"."""
+
+    negate: Optional[bool]
+    """Indicates whether the filter will be negated."""
+
+    op: Optional[Operator]
+    """Represents the operator to apply to the filter. If not set, then EQUAL will be used."""
+
+    value: Optional[MemoryMetadataValueDict]
+    """Represents the value to compare to."""
+
+
+MemoryFilterOrDict = Union[MemoryFilter, MemoryFilterDict]
+
+
+class MemoryConjunctionFilter(_common.BaseModel):
+    """The conjunction filter for memories."""
+
+    filters: Optional[list[MemoryFilter]] = Field(
+        default=None,
+        description="""Represents filters that will be combined using AND logic.""",
+    )
+
+
+class MemoryConjunctionFilterDict(TypedDict, total=False):
+    """The conjunction filter for memories."""
+
+    filters: Optional[list[MemoryFilterDict]]
+    """Represents filters that will be combined using AND logic."""
+
+
+MemoryConjunctionFilterOrDict = Union[
+    MemoryConjunctionFilter, MemoryConjunctionFilterDict
+]
+
+
+class RetrieveAgentEngineMemoriesConfig(_common.BaseModel):
+    """Config for retrieving memories."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    filter: Optional[str] = Field(
+        default=None,
+        description="""The standard list filter that will be applied to the retrieved
+      memories. More detail in [AIP-160](https://google.aip.dev/160).
+
+      Supported fields:
+      * `fact`
+      * `create_time`
+      * `update_time`
+      """,
+    )
+    filter_groups: Optional[list[MemoryConjunctionFilter]] = Field(
+        default=None,
+        description="""Metadata filters that will be applied to the retrieved memories'
+      `metadata` using OR logic. Filters are defined using disjunctive normal
+      form (OR of ANDs).
+
+      For example:
+      `filter_groups: [{filters: [{key: "author", value: {string_value: "agent
+      123"}, op: EQUAL}]}, {filters: [{key: "label", value: {string_value:
+      "travel"}, op: EQUAL}, {key: "author", value: {string_value: "agent 321"},
+      op: EQUAL}]}]`
+
+      would be equivalent to the logical expression:
+      `(metadata.author = "agent 123" OR (metadata.label = "travel" AND
+      metadata.author = "agent 321"))`.
+      """,
+    )
+    memory_types: Optional[list[MemoryType]] = Field(
+        default=None,
+        description="""Specifies the types of memories to retrieve. If this field is empty
+      or not provided, the request will default to retrieving only memories of
+      type `NATURAL_LANGUAGE_COLLECTION`. If populated, the request will
+      retrieve memories matching any of the specified `MemoryType` values.""",
+    )
+
+
+class RetrieveAgentEngineMemoriesConfigDict(TypedDict, total=False):
+    """Config for retrieving memories."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    filter: Optional[str]
+    """The standard list filter that will be applied to the retrieved
+    memories. More detail in [AIP-160](https://google.aip.dev/160).
+
+    Supported fields:
+    * `fact`
+    * `create_time`
+    * `update_time`
+    """
+
+    filter_groups: Optional[list[MemoryConjunctionFilterDict]]
+    """Metadata filters that will be applied to the retrieved memories'
+    `metadata` using OR logic. Filters are defined using disjunctive normal
+    form (OR of ANDs).
+
+    For example:
+    `filter_groups: [{filters: [{key: "author", value: {string_value: "agent
+    123"}, op: EQUAL}]}, {filters: [{key: "label", value: {string_value:
+    "travel"}, op: EQUAL}, {key: "author", value: {string_value: "agent 321"},
+    op: EQUAL}]}]`
+
+    would be equivalent to the logical expression:
+    `(metadata.author = "agent 123" OR (metadata.label = "travel" AND
+    metadata.author = "agent 321"))`.
+    """
+
+    memory_types: Optional[list[MemoryType]]
+    """Specifies the types of memories to retrieve. If this field is empty
+    or not provided, the request will default to retrieving only memories of
+    type `NATURAL_LANGUAGE_COLLECTION`. If populated, the request will
+    retrieve memories matching any of the specified `MemoryType` values."""
+
+
+RetrieveAgentEngineMemoriesConfigOrDict = Union[
+    RetrieveAgentEngineMemoriesConfig, RetrieveAgentEngineMemoriesConfigDict
+]
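+
+
+# Example (editor's sketch): the disjunctive-normal-form metadata filter from
+# the docstring above, expressed with the filter types defined in this module.
+# It encodes: author = "agent 123" OR (label = "travel" AND author = "agent
+# 321"). `op` is omitted, so EQUAL is used per the field docs.
+#
+#     config = RetrieveAgentEngineMemoriesConfig(
+#         filter_groups=[
+#             MemoryConjunctionFilter(
+#                 filters=[
+#                     MemoryFilter(
+#                         key="author",
+#                         value=MemoryMetadataValue(string_value="agent 123"),
+#                     )
+#                 ]
+#             ),
+#             MemoryConjunctionFilter(
+#                 filters=[
+#                     MemoryFilter(
+#                         key="label",
+#                         value=MemoryMetadataValue(string_value="travel"),
+#                     ),
+#                     MemoryFilter(
+#                         key="author",
+#                         value=MemoryMetadataValue(string_value="agent 321"),
+#                     ),
+#                 ]
+#             ),
+#         ]
+#     )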
+
+
+class _RetrieveAgentEngineMemoriesRequestParameters(_common.BaseModel):
+    """Parameters for retrieving agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""Name of the agent engine to retrieve memories from.""",
+    )
+    scope: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The scope of the memories to retrieve.
+
+      A memory must have exactly the same scope as the scope provided here to be
+      retrieved (i.e. same keys and values). Order does not matter, but it is
+      case-sensitive.""",
+    )
+    similarity_search_params: Optional[
+        RetrieveMemoriesRequestSimilaritySearchParams
+    ] = Field(
+        default=None,
+        description="""Parameters for semantic similarity search based retrieval.""",
+    )
+    simple_retrieval_params: Optional[RetrieveMemoriesRequestSimpleRetrievalParams] = (
+        Field(
+            default=None,
+            description="""Parameters for simple (non-similarity search) retrieval.""",
+        )
+    )
+    config: Optional[RetrieveAgentEngineMemoriesConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _RetrieveAgentEngineMemoriesRequestParametersDict(TypedDict, total=False):
+    """Parameters for retrieving agent engine memories."""
+
+    name: Optional[str]
+    """Name of the agent engine to retrieve memories from."""
+
+    scope: Optional[dict[str, str]]
+    """The scope of the memories to retrieve.
+
+    A memory must have exactly the same scope as the scope provided here to be
+    retrieved (i.e. same keys and values).
Order does not matter, but it is + case-sensitive.""" + + similarity_search_params: Optional[ + RetrieveMemoriesRequestSimilaritySearchParamsDict + ] + """Parameters for semantic similarity search based retrieval.""" + + simple_retrieval_params: Optional[RetrieveMemoriesRequestSimpleRetrievalParamsDict] + """Parameters for simple (non-similarity search) retrieval.""" + + config: Optional[RetrieveAgentEngineMemoriesConfigDict] + """""" + + +_RetrieveAgentEngineMemoriesRequestParametersOrDict = Union[ + _RetrieveAgentEngineMemoriesRequestParameters, + _RetrieveAgentEngineMemoriesRequestParametersDict, +] + + +class RetrieveMemoriesResponseRetrievedMemory(_common.BaseModel): + """A retrieved memory.""" + + distance: Optional[float] = Field( + default=None, + description="""The distance between the query and the retrieved Memory. Smaller values indicate more similar memories. This is only set if similarity search was used for retrieval.""", + ) + memory: Optional[Memory] = Field( + default=None, description="""The retrieved Memory.""" + ) + + +class RetrieveMemoriesResponseRetrievedMemoryDict(TypedDict, total=False): + """A retrieved memory.""" + + distance: Optional[float] + """The distance between the query and the retrieved Memory. Smaller values indicate more similar memories. This is only set if similarity search was used for retrieval.""" + + memory: Optional[MemoryDict] + """The retrieved Memory.""" + + +RetrieveMemoriesResponseRetrievedMemoryOrDict = Union[ + RetrieveMemoriesResponseRetrievedMemory, RetrieveMemoriesResponseRetrievedMemoryDict +] + + +class RetrieveMemoriesResponse(_common.BaseModel): + """The response for retrieving memories.""" + + next_page_token: Optional[str] = Field( + default=None, + description="""A token that can be sent as `page_token` to retrieve the next page. If this field is omitted, there are no subsequent pages. This token is not set if similarity search was used for retrieval.""", + ) + retrieved_memories: Optional[list[RetrieveMemoriesResponseRetrievedMemory]] = Field( + default=None, description="""The retrieved memories.""" + ) + + +class RetrieveMemoriesResponseDict(TypedDict, total=False): + """The response for retrieving memories.""" + + next_page_token: Optional[str] + """A token that can be sent as `page_token` to retrieve the next page. If this field is omitted, there are no subsequent pages. 
This token is not set if similarity search was used for retrieval.""" + + retrieved_memories: Optional[list[RetrieveMemoriesResponseRetrievedMemoryDict]] + """The retrieved memories.""" + + +RetrieveMemoriesResponseOrDict = Union[ + RetrieveMemoriesResponse, RetrieveMemoriesResponseDict +] + + +class RetrieveMemoryProfilesConfig(_common.BaseModel): + """Config for retrieving memory profiles.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class RetrieveMemoryProfilesConfigDict(TypedDict, total=False): + """Config for retrieving memory profiles.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +RetrieveMemoryProfilesConfigOrDict = Union[ + RetrieveMemoryProfilesConfig, RetrieveMemoryProfilesConfigDict +] + + +class _RetrieveMemoryProfilesRequestParameters(_common.BaseModel): + """Parameters for retrieving agent engine memory profiles.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine to retrieve memory profiles from.""", + ) + scope: Optional[dict[str, str]] = Field( + default=None, + description="""The scope of the memories to retrieve. + + A memory must have exactly the same scope as the scope provided here to be + retrieved (i.e. same keys and values). Order does not matter, but it is + case-sensitive.""", + ) + config: Optional[RetrieveMemoryProfilesConfig] = Field( + default=None, description="""""" + ) + + +class _RetrieveMemoryProfilesRequestParametersDict(TypedDict, total=False): + """Parameters for retrieving agent engine memory profiles.""" + + name: Optional[str] + """Name of the agent engine to retrieve memory profiles from.""" + + scope: Optional[dict[str, str]] + """The scope of the memories to retrieve. + + A memory must have exactly the same scope as the scope provided here to be + retrieved (i.e. same keys and values). Order does not matter, but it is + case-sensitive.""" + + config: Optional[RetrieveMemoryProfilesConfigDict] + """""" + + +_RetrieveMemoryProfilesRequestParametersOrDict = Union[ + _RetrieveMemoryProfilesRequestParameters, + _RetrieveMemoryProfilesRequestParametersDict, +] + + +class MemoryProfile(_common.BaseModel): + """A memory profile.""" + + schema_id: Optional[str] = Field( + default=None, + description="""Represents the ID of the schema. This ID corresponds to the `schema_id` defined inside the SchemaConfig, under StructuredMemoryCustomizationConfig.""", + ) + profile: Optional[dict[str, Any]] = Field( + default=None, description="""Represents the profile data.""" + ) + + +class MemoryProfileDict(TypedDict, total=False): + """A memory profile.""" + + schema_id: Optional[str] + """Represents the ID of the schema. This ID corresponds to the `schema_id` defined inside the SchemaConfig, under StructuredMemoryCustomizationConfig.""" + + profile: Optional[dict[str, Any]] + """Represents the profile data.""" + + +MemoryProfileOrDict = Union[MemoryProfile, MemoryProfileDict] + + +class RetrieveProfilesResponse(_common.BaseModel): + """The response for retrieving memory profiles.""" + + profiles: Optional[dict[str, MemoryProfile]] = Field( + default=None, + description="""The retrieved structured profiles, which match the schemas under the + requested scope. 
The key is the ID of the schema that the profile is
+      linked with, which corresponds to the `schema_id` defined inside the
+      `SchemaConfig`, under `StructuredMemoryCustomizationConfig`.""",
+    )
+
+
+class RetrieveProfilesResponseDict(TypedDict, total=False):
+    """The response for retrieving memory profiles."""
+
+    profiles: Optional[dict[str, MemoryProfileDict]]
+    """The retrieved structured profiles, which match the schemas under the
+    requested scope. The key is the ID of the schema that the profile is
+    linked with, which corresponds to the `schema_id` defined inside the
+    `SchemaConfig`, under `StructuredMemoryCustomizationConfig`."""
+
+
+RetrieveProfilesResponseOrDict = Union[
+    RetrieveProfilesResponse, RetrieveProfilesResponseDict
+]
+
+
+class RollbackAgentEngineMemoryConfig(_common.BaseModel):
+    """Config for rolling back a memory."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Waits for the operation to complete before returning.""",
+    )
+
+
+class RollbackAgentEngineMemoryConfigDict(TypedDict, total=False):
+    """Config for rolling back a memory."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    wait_for_completion: Optional[bool]
+    """Waits for the operation to complete before returning."""
+
+
+RollbackAgentEngineMemoryConfigOrDict = Union[
+    RollbackAgentEngineMemoryConfig, RollbackAgentEngineMemoryConfigDict
+]
+
+
+class _RollbackAgentEngineMemoryRequestParameters(_common.BaseModel):
+    """Parameters for rolling back an agent engine memory."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine memory to roll back."""
+    )
+    target_revision_id: Optional[str] = Field(
+        default=None, description="""The ID of the revision to roll back to."""
+    )
+    config: Optional[RollbackAgentEngineMemoryConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _RollbackAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for rolling back an agent engine memory."""
+
+    name: Optional[str]
+    """Name of the agent engine memory to roll back."""
+
+    target_revision_id: Optional[str]
+    """The ID of the revision to roll back to."""
+
+    config: Optional[RollbackAgentEngineMemoryConfigDict]
+    """"""
+
+
+_RollbackAgentEngineMemoryRequestParametersOrDict = Union[
+    _RollbackAgentEngineMemoryRequestParameters,
+    _RollbackAgentEngineMemoryRequestParametersDict,
+]
+
+
+class AgentEngineRollbackMemoryOperation(_common.BaseModel):
+    """Operation that rolls back a memory."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress.
If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class AgentEngineRollbackMemoryOperationDict(TypedDict, total=False): + """Operation that rolls back a memory.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +AgentEngineRollbackMemoryOperationOrDict = Union[ + AgentEngineRollbackMemoryOperation, AgentEngineRollbackMemoryOperationDict +] + + +class UpdateAgentEngineMemoryConfig(_common.BaseModel): + """Config for updating agent engine memory.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the memory.""" + ) + description: Optional[str] = Field( + default=None, description="""The description of the memory.""" + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""", + ) + revision_expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted.""", + ) + revision_ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for the revision. The expiration time is computed: now + TTL.""", + ) + disable_memory_revisions: Optional[bool] = Field( + default=None, + description="""Optional. Input only. If true, no revision will be created for this request.""", + ) + topics: Optional[list[MemoryTopicId]] = Field( + default=None, description="""Optional. The topics of the memory.""" + ) + metadata: Optional[dict[str, MemoryMetadataValue]] = Field( + default=None, + description="""Optional. User-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. It was not generated by Memory Bank.""", + ) + memory_id: Optional[str] = Field( + default=None, + description="""Optional. 
The user-defined ID to use for the memory, which will become the final component of the memory resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""",
+    )
+    update_mask: Optional[str] = Field(
+        default=None,
+        description="""The update mask to apply. For the `FieldMask` definition, see
+      https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""",
+    )
+
+
+class UpdateAgentEngineMemoryConfigDict(TypedDict, total=False):
+    """Config for updating agent engine memory."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    display_name: Optional[str]
+    """The display name of the memory."""
+
+    description: Optional[str]
+    """The description of the memory."""
+
+    wait_for_completion: Optional[bool]
+    """Waits for the operation to complete before returning."""
+
+    ttl: Optional[str]
+    """Optional. Input only. The TTL for this resource.
+
+    The expiration time is computed: now + TTL."""
+
+    expire_time: Optional[datetime.datetime]
+    """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input."""
+
+    revision_expire_time: Optional[datetime.datetime]
+    """Optional. Input only. Timestamp of when the revision is considered expired. If not set, the memory revision will be kept until manually deleted."""
+
+    revision_ttl: Optional[str]
+    """Optional. Input only. The TTL for the revision. The expiration time is computed: now + TTL."""
+
+    disable_memory_revisions: Optional[bool]
+    """Optional. Input only. If true, no revision will be created for this request."""
+
+    topics: Optional[list[MemoryTopicIdDict]]
+    """Optional. The topics of the memory."""
+
+    metadata: Optional[dict[str, MemoryMetadataValueDict]]
+    """Optional. User-provided metadata for the Memory. This information was provided when creating, updating, or generating the Memory. It was not generated by Memory Bank."""
+
+    memory_id: Optional[str]
+    """Optional. The user-defined ID to use for the memory, which will become the final component of the memory resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number."""
+
+    update_mask: Optional[str]
+    """The update mask to apply. For the `FieldMask` definition, see
+    https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask."""
+
+
+UpdateAgentEngineMemoryConfigOrDict = Union[
+    UpdateAgentEngineMemoryConfig, UpdateAgentEngineMemoryConfigDict
+]
+
+
+class _UpdateAgentEngineMemoryRequestParameters(_common.BaseModel):
+    """Parameters for updating agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine memory to update."""
+    )
+    fact: Optional[str] = Field(
+        default=None,
+        description="""The updated fact of the memory.
+
+      This is the semantic knowledge extracted from the source content.""",
+    )
+    scope: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The updated scope of the memory.
+
+      Memories are isolated within their scope. The scope is defined when
+      creating or generating memories. Up to 5 key-value pairs are accepted,
+      and scope values cannot contain the wildcard character '*'.""",
+    )
+    config: Optional[UpdateAgentEngineMemoryConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _UpdateAgentEngineMemoryRequestParametersDict(TypedDict, total=False):
+    """Parameters for updating agent engine memories."""
+
+    name: Optional[str]
+    """Name of the agent engine memory to update."""
+
+    fact: Optional[str]
+    """The updated fact of the memory.
+
+    This is the semantic knowledge extracted from the source content."""
+
+    scope: Optional[dict[str, str]]
+    """The updated scope of the memory.
+
+    Memories are isolated within their scope. The scope is defined when
+    creating or generating memories. Up to 5 key-value pairs are accepted,
+    and scope values cannot contain the wildcard character '*'."""
+
+    config: Optional[UpdateAgentEngineMemoryConfigDict]
+    """"""
+
+
+_UpdateAgentEngineMemoryRequestParametersOrDict = Union[
+    _UpdateAgentEngineMemoryRequestParameters,
+    _UpdateAgentEngineMemoryRequestParametersDict,
+]
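+
+
+# Example (editor's sketch): an update config that rewrites only the `fact`
+# field and attaches a 24-hour TTL. `ttl` is input only; per the field docs
+# above, the server derives `expire_time` as now + TTL. The "86400s" duration
+# string is an assumption based on the usual protobuf Duration JSON encoding.
+#
+#     config = UpdateAgentEngineMemoryConfig(
+#         update_mask="fact",
+#         ttl="86400s",
+#         wait_for_completion=True,
+#     )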
+
+
+class PurgeAgentEngineMemoriesConfig(_common.BaseModel):
+    """Config for purging memories."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Waits for the operation to complete before returning.""",
+    )
+
+
+class PurgeAgentEngineMemoriesConfigDict(TypedDict, total=False):
+    """Config for purging memories."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    wait_for_completion: Optional[bool]
+    """Waits for the operation to complete before returning."""
+
+
+PurgeAgentEngineMemoriesConfigOrDict = Union[
+    PurgeAgentEngineMemoriesConfig, PurgeAgentEngineMemoriesConfigDict
+]
+
+
+class _PurgeAgentEngineMemoriesRequestParameters(_common.BaseModel):
+    """Parameters for purging agent engine memories."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the Agent Engine to purge memories from."""
+    )
+    filter: Optional[str] = Field(
+        default=None,
+        description="""The standard list filter to determine which memories to purge.
+      More detail in [AIP-160](https://google.aip.dev/160).""",
+    )
+    filter_groups: Optional[list[MemoryConjunctionFilter]] = Field(
+        default=None,
+        description="""Metadata filters that will be applied to the memories'
+      `metadata` using OR logic. Filters are defined using disjunctive normal
+      form (OR of ANDs).
+
+      For example:
+      `filter_groups: [{filters: [{key: "author", value: {string_value: "agent
+      123"}, op: EQUAL}]}, {filters: [{key: "label", value: {string_value:
+      "travel"}, op: EQUAL}, {key: "author", value: {string_value: "agent 321"},
+      op: EQUAL}]}]`
+
+      would be equivalent to the logical expression:
+      `(metadata.author = "agent 123" OR (metadata.label = "travel" AND
+      metadata.author = "agent 321"))`.
+      """,
+    )
+    force: Optional[bool] = Field(
+        default=None,
+        description="""If true, the memories will actually be purged. If false, the purge request will be validated but not executed.""",
+    )
+    config: Optional[PurgeAgentEngineMemoriesConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _PurgeAgentEngineMemoriesRequestParametersDict(TypedDict, total=False):
+    """Parameters for purging agent engine memories."""
+
+    name: Optional[str]
+    """Name of the Agent Engine to purge memories from."""
+
+    filter: Optional[str]
+    """The standard list filter to determine which memories to purge.
+    More detail in [AIP-160](https://google.aip.dev/160)."""
+
+    filter_groups: Optional[list[MemoryConjunctionFilterDict]]
+    """Metadata filters that will be applied to the memories'
+    `metadata` using OR logic. Filters are defined using disjunctive normal
+    form (OR of ANDs).
+
+    For example:
+    `filter_groups: [{filters: [{key: "author", value: {string_value: "agent
+    123"}, op: EQUAL}]}, {filters: [{key: "label", value: {string_value:
+    "travel"}, op: EQUAL}, {key: "author", value: {string_value: "agent 321"},
+    op: EQUAL}]}]`
+
+    would be equivalent to the logical expression:
+    `(metadata.author = "agent 123" OR (metadata.label = "travel" AND
+    metadata.author = "agent 321"))`.
+    """
+
+    force: Optional[bool]
+    """If true, the memories will actually be purged. If false, the purge request will be validated but not executed."""
+
+    config: Optional[PurgeAgentEngineMemoriesConfigDict]
+    """"""
+
+
+_PurgeAgentEngineMemoriesRequestParametersOrDict = Union[
+    _PurgeAgentEngineMemoriesRequestParameters,
+    _PurgeAgentEngineMemoriesRequestParametersDict,
+]
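+
+
+# Example (editor's sketch): purge parameters for a validation-only run. These
+# request-parameter classes are module-private and normally assembled by the
+# SDK; the construction below only illustrates how the fields fit together,
+# and the resource name and filter string are hypothetical.
+# With `force=False` the request is validated but nothing is deleted.
+#
+#     params = _PurgeAgentEngineMemoriesRequestParameters(
+#         name="projects/my-project/locations/us-central1/reasoningEngines/123",
+#         filter='fact:"window seats"',
+#         force=False,
+#         config=PurgeAgentEngineMemoriesConfig(wait_for_completion=True),
+#     )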
+
+
+class PurgeMemoriesResponse(_common.BaseModel):
+    """The response for purging memories."""
+
+    purge_count: Optional[int] = Field(
+        default=None, description="""The number of memories that were purged."""
+    )
+
+
+class PurgeMemoriesResponseDict(TypedDict, total=False):
+    """The response for purging memories."""
+
+    purge_count: Optional[int]
+    """The number of memories that were purged."""
+
+
+PurgeMemoriesResponseOrDict = Union[PurgeMemoriesResponse, PurgeMemoriesResponseDict]
+
+
+class AgentEnginePurgeMemoriesOperation(_common.BaseModel):
+    """Operation that purges memories from an agent engine."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+    response: Optional[PurgeMemoriesResponse] = Field(
+        default=None, description="""The response for purging memories."""
+    )
+
+
+class AgentEnginePurgeMemoriesOperationDict(TypedDict, total=False):
+    """Operation that purges memories from an agent engine."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+    response: Optional[PurgeMemoriesResponseDict]
+    """The response for purging memories."""
+
+
+AgentEnginePurgeMemoriesOperationOrDict = Union[
+    AgentEnginePurgeMemoriesOperation, AgentEnginePurgeMemoriesOperationDict
+]
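+
+
+# Example (editor's sketch): reading the outcome of a completed purge
+# operation. `purge_count` reports how many memories were purged by the
+# request.
+#
+#     if operation.done:
+#         if operation.error:
+#             raise RuntimeError(f"Purge failed: {operation.error}")
+#         if operation.response is not None:
+#             print(f"Purged {operation.response.purge_count} memories")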
+
+
+class GetAgentEngineMemoryRevisionConfig(_common.BaseModel):
+    """Config for getting an Agent Engine Memory Revision."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetAgentEngineMemoryRevisionConfigDict(TypedDict, total=False):
+    """Config for getting an Agent Engine Memory Revision."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetAgentEngineMemoryRevisionConfigOrDict = Union[
+    GetAgentEngineMemoryRevisionConfig, GetAgentEngineMemoryRevisionConfigDict
+]
+
+
+class _GetAgentEngineMemoryRevisionRequestParameters(_common.BaseModel):
+    """Parameters for getting an Agent Engine memory revision."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the Agent Engine memory revision."""
+    )
+    config: Optional[GetAgentEngineMemoryRevisionConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _GetAgentEngineMemoryRevisionRequestParametersDict(TypedDict, total=False):
+    """Parameters for getting an Agent Engine memory revision."""
+
+    name: Optional[str]
+    """Name of the Agent Engine memory revision."""
+
+    config: Optional[GetAgentEngineMemoryRevisionConfigDict]
+    """"""
+
+
+_GetAgentEngineMemoryRevisionRequestParametersOrDict = Union[
+    _GetAgentEngineMemoryRevisionRequestParameters,
+    _GetAgentEngineMemoryRevisionRequestParametersDict,
+]
+
+
+class IntermediateExtractedMemory(_common.BaseModel):
+    """An extracted memory that is the intermediate result before consolidation."""
+
+    fact: Optional[str] = Field(
+        default=None,
+        description="""Output only. Represents the fact of the extracted memory.""",
+    )
+    context: Optional[str] = Field(
+        default=None,
+        description="""Output only.
Represents the explanation of why the information was extracted from the source content.""", + ) + structured_data: Optional[dict[str, Any]] = Field( + default=None, + description="""Output only. Represents the structured value of the extracted memory.""", + ) + + +class IntermediateExtractedMemoryDict(TypedDict, total=False): + """An extracted memory that is the intermediate result before consolidation.""" + + fact: Optional[str] + """Output only. Represents the fact of the extracted memory.""" + + context: Optional[str] + """Output only. Represents the explanation of why the information was extracted from the source content.""" + + structured_data: Optional[dict[str, Any]] + """Output only. Represents the structured value of the extracted memory.""" + + +IntermediateExtractedMemoryOrDict = Union[ + IntermediateExtractedMemory, IntermediateExtractedMemoryDict +] + + +class MemoryRevision(_common.BaseModel): + """A memory revision.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Represents the timestamp when this Memory Revision was created.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Represents the timestamp of when this resource is considered expired.""", + ) + extracted_memories: Optional[list[IntermediateExtractedMemory]] = Field( + default=None, + description="""Output only. Represents the extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""", + ) + fact: Optional[str] = Field( + default=None, + description="""Output only. Represents the fact of the Memory Revision. This corresponds to the `fact` field of the parent Memory at the time of revision creation.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Output only. Represents the labels of the Memory Revision. These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. Represents the resource name of the Memory Revision. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}/revisions/{memory_revision}`""", + ) + structured_data: Optional[dict[str, Any]] = Field( + default=None, + description="""Output only. Represents the structured value of the memory at the time of revision creation.""", + ) + + +class MemoryRevisionDict(TypedDict, total=False): + """A memory revision.""" + + create_time: Optional[datetime.datetime] + """Output only. Represents the timestamp when this Memory Revision was created.""" + + expire_time: Optional[datetime.datetime] + """Output only. Represents the timestamp of when this resource is considered expired.""" + + extracted_memories: Optional[list[IntermediateExtractedMemoryDict]] + """Output only. Represents the extracted memories from the source content before consolidation when the memory was updated via GenerateMemories. This information was used to modify an existing Memory via Consolidation.""" + + fact: Optional[str] + """Output only. Represents the fact of the Memory Revision. This corresponds to the `fact` field of the parent Memory at the time of revision creation.""" + + labels: Optional[dict[str, str]] + """Output only. Represents the labels of the Memory Revision. 
These labels are applied to the MemoryRevision when it is created based on `GenerateMemoriesRequest.revision_labels`."""
+
+    name: Optional[str]
+    """Identifier. Represents the resource name of the Memory Revision. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/memories/{memory}/revisions/{memory_revision}`"""
+
+    structured_data: Optional[dict[str, Any]]
+    """Output only. Represents the structured value of the memory at the time of revision creation."""
+
+
+MemoryRevisionOrDict = Union[MemoryRevision, MemoryRevisionDict]
+
+
+class ListAgentEngineMemoryRevisionsConfig(_common.BaseModel):
+    """Config for listing Agent Engine memory revisions."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    page_size: Optional[int] = Field(default=None, description="""""")
+    page_token: Optional[str] = Field(default=None, description="""""")
+    filter: Optional[str] = Field(
+        default=None,
+        description="""An expression for filtering the results of the request.
+      For field names both snake_case and camelCase are supported.""",
+    )
+
+
+class ListAgentEngineMemoryRevisionsConfigDict(TypedDict, total=False):
+    """Config for listing Agent Engine memory revisions."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    page_size: Optional[int]
+    """"""
+
+    page_token: Optional[str]
+    """"""
+
+    filter: Optional[str]
+    """An expression for filtering the results of the request.
+    For field names both snake_case and camelCase are supported."""
+
+
+ListAgentEngineMemoryRevisionsConfigOrDict = Union[
+    ListAgentEngineMemoryRevisionsConfig, ListAgentEngineMemoryRevisionsConfigDict
+]
+
+
+class _ListAgentEngineMemoryRevisionsRequestParameters(_common.BaseModel):
+    """Parameters for listing Agent Engine memory revisions."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the Agent Engine memory."""
+    )
+    config: Optional[ListAgentEngineMemoryRevisionsConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _ListAgentEngineMemoryRevisionsRequestParametersDict(TypedDict, total=False):
+    """Parameters for listing Agent Engine memory revisions."""
+
+    name: Optional[str]
+    """Name of the Agent Engine memory."""
+
+    config: Optional[ListAgentEngineMemoryRevisionsConfigDict]
+    """"""
+
+
+_ListAgentEngineMemoryRevisionsRequestParametersOrDict = Union[
+    _ListAgentEngineMemoryRevisionsRequestParameters,
+    _ListAgentEngineMemoryRevisionsRequestParametersDict,
+]
+
+
+class ListAgentEngineMemoryRevisionsResponse(_common.BaseModel):
+    """Response for listing agent engine memory revisions."""
+
+    sdk_http_response: Optional[genai_types.HttpResponse] = Field(
+        default=None, description="""Used to retain the full HTTP response."""
+    )
+    next_page_token: Optional[str] = Field(default=None, description="""""")
+    memory_revisions: Optional[list[MemoryRevision]] = Field(
+        default=None, description="""List of memory revisions."""
+    )
+
+
+class ListAgentEngineMemoryRevisionsResponseDict(TypedDict, total=False):
+    """Response for listing agent engine memory revisions."""
+
+    sdk_http_response: Optional[genai_types.HttpResponseDict]
+    """Used to retain the full HTTP response."""
+
+    next_page_token: Optional[str]
+    """"""
+
+    memory_revisions: Optional[list[MemoryRevisionDict]]
+    """List of memory revisions."""
+
+
+ListAgentEngineMemoryRevisionsResponseOrDict = Union[
+    ListAgentEngineMemoryRevisionsResponse,
ListAgentEngineMemoryRevisionsResponseDict +] + + +class GetAgentEngineRuntimeRevisionConfig(_common.BaseModel): + """Config for getting an Agent Engine Runtime Revision.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineRuntimeRevisionConfigDict(TypedDict, total=False): + """Config for getting an Agent Engine Runtime Revision.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineRuntimeRevisionConfigOrDict = Union[ + GetAgentEngineRuntimeRevisionConfig, GetAgentEngineRuntimeRevisionConfigDict +] + + +class _GetAgentEngineRuntimeRevisionRequestParameters(_common.BaseModel): + """Parameters for getting an agent engine runtime revision.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine runtime revision.""" + ) + config: Optional[GetAgentEngineRuntimeRevisionConfig] = Field( + default=None, description="""""" + ) + + +class _GetAgentEngineRuntimeRevisionRequestParametersDict(TypedDict, total=False): + """Parameters for getting an agent engine runtime revision.""" + + name: Optional[str] + """Name of the agent engine runtime revision.""" + + config: Optional[GetAgentEngineRuntimeRevisionConfigDict] + """""" + + +_GetAgentEngineRuntimeRevisionRequestParametersOrDict = Union[ + _GetAgentEngineRuntimeRevisionRequestParameters, + _GetAgentEngineRuntimeRevisionRequestParametersDict, +] + + +class ReasoningEngineRuntimeRevision(_common.BaseModel): + """A runtime revision.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this ReasoningEngineRuntimeRevision was created.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the ReasoningEngineRuntimeRevision. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/runtimeRevisions/{runtime_revision}`""", + ) + spec: Optional[ReasoningEngineSpec] = Field( + default=None, + description="""Immutable. Configurations of the ReasoningEngineRuntimeRevision. Contains only revision specific fields.""", + ) + state: Optional[State] = Field( + default=None, description="""Output only. The state of the revision.""" + ) + + +class ReasoningEngineRuntimeRevisionDict(TypedDict, total=False): + """A runtime revision.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this ReasoningEngineRuntimeRevision was created.""" + + name: Optional[str] + """Identifier. The resource name of the ReasoningEngineRuntimeRevision. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/runtimeRevisions/{runtime_revision}`""" + + spec: Optional[ReasoningEngineSpecDict] + """Immutable. Configurations of the ReasoningEngineRuntimeRevision. Contains only revision specific fields.""" + + state: Optional[State] + """Output only. 
The state of the revision.""" + + +ReasoningEngineRuntimeRevisionOrDict = Union[ + ReasoningEngineRuntimeRevision, ReasoningEngineRuntimeRevisionDict +] + + +class ListAgentEngineRuntimeRevisionsConfig(_common.BaseModel): + """Config for listing reasoning engine runtime revisions.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineRuntimeRevisionsConfigDict(TypedDict, total=False): + """Config for listing reasoning engine runtime revisions.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineRuntimeRevisionsConfigOrDict = Union[ + ListAgentEngineRuntimeRevisionsConfig, ListAgentEngineRuntimeRevisionsConfigDict +] + + +class _ListAgentEngineRuntimeRevisionsRequestParameters(_common.BaseModel): + """Parameters for listing reasoning engine runtime revisions.""" + + name: Optional[str] = Field( + default=None, description="""Name of the reasoning engine.""" + ) + config: Optional[ListAgentEngineRuntimeRevisionsConfig] = Field( + default=None, description="""""" + ) + + +class _ListAgentEngineRuntimeRevisionsRequestParametersDict(TypedDict, total=False): + """Parameters for listing reasoning engine runtime revisions.""" + + name: Optional[str] + """Name of the reasoning engine.""" + + config: Optional[ListAgentEngineRuntimeRevisionsConfigDict] + """""" + + +_ListAgentEngineRuntimeRevisionsRequestParametersOrDict = Union[ + _ListAgentEngineRuntimeRevisionsRequestParameters, + _ListAgentEngineRuntimeRevisionsRequestParametersDict, +] + + +class ListReasoningEnginesRuntimeRevisionsResponse(_common.BaseModel): + """Response for listing agent engine runtime revisions.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + reasoning_engine_runtime_revisions: Optional[ + list[ReasoningEngineRuntimeRevision] + ] = Field( + default=None, description="""List of reasoning engine runtime revisions.""" + ) + + +class ListReasoningEnginesRuntimeRevisionsResponseDict(TypedDict, total=False): + """Response for listing agent engine runtime revisions.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + reasoning_engine_runtime_revisions: Optional[ + list[ReasoningEngineRuntimeRevisionDict] + ] + """List of reasoning engine runtime revisions.""" + + +ListReasoningEnginesRuntimeRevisionsResponseOrDict = Union[ + ListReasoningEnginesRuntimeRevisionsResponse, + ListReasoningEnginesRuntimeRevisionsResponseDict, +] + + +class DeleteAgentEngineRuntimeRevisionConfig(_common.BaseModel): + """Config for deleting an Agent Engine Runtime Revision.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + 
default=None, description="""Used to override HTTP request options.""" + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + + +class DeleteAgentEngineRuntimeRevisionConfigDict(TypedDict, total=False): + """Config for deleting an Agent Engine Runtime Revision.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + +DeleteAgentEngineRuntimeRevisionConfigOrDict = Union[ + DeleteAgentEngineRuntimeRevisionConfig, DeleteAgentEngineRuntimeRevisionConfigDict +] + + +class _DeleteAgentEngineRuntimeRevisionRequestParameters(_common.BaseModel): + """Parameters for deleting agent engine runtime revisions.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine runtime revision to delete.""", + ) + config: Optional[DeleteAgentEngineRuntimeRevisionConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteAgentEngineRuntimeRevisionRequestParametersDict(TypedDict, total=False): + """Parameters for deleting agent engine runtime revisions.""" + + name: Optional[str] + """Name of the agent engine runtime revision to delete.""" + + config: Optional[DeleteAgentEngineRuntimeRevisionConfigDict] + """""" + + +_DeleteAgentEngineRuntimeRevisionRequestParametersOrDict = Union[ + _DeleteAgentEngineRuntimeRevisionRequestParameters, + _DeleteAgentEngineRuntimeRevisionRequestParametersDict, +] + + +class DeleteAgentEngineRuntimeRevisionOperation(_common.BaseModel): + """Operation for deleting agent engine runtime revisions.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteAgentEngineRuntimeRevisionOperationDict(TypedDict, total=False): + """Operation for deleting agent engine runtime revisions.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+
+DeleteAgentEngineRuntimeRevisionOperationOrDict = Union[
+    DeleteAgentEngineRuntimeRevisionOperation,
+    DeleteAgentEngineRuntimeRevisionOperationDict,
+]
+
+
+class GetDeleteAgentEngineRuntimeRevisionOperationConfig(_common.BaseModel):
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetDeleteAgentEngineRuntimeRevisionOperationConfigDict(TypedDict, total=False):
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetDeleteAgentEngineRuntimeRevisionOperationConfigOrDict = Union[
+    GetDeleteAgentEngineRuntimeRevisionOperationConfig,
+    GetDeleteAgentEngineRuntimeRevisionOperationConfigDict,
+]
+
+
+class _GetDeleteAgentEngineRuntimeRevisionOperationParameters(_common.BaseModel):
+    """Parameters for getting an operation that deletes an agent engine runtime revision."""
+
+    operation_name: Optional[str] = Field(
+        default=None, description="""The server-assigned name for the operation."""
+    )
+    config: Optional[GetDeleteAgentEngineRuntimeRevisionOperationConfig] = Field(
+        default=None, description="""Used to override the default configuration."""
+    )
+
+
+class _GetDeleteAgentEngineRuntimeRevisionOperationParametersDict(
+    TypedDict, total=False
+):
+    """Parameters for getting an operation that deletes an agent engine runtime revision."""
+
+    operation_name: Optional[str]
+    """The server-assigned name for the operation."""
+
+    config: Optional[GetDeleteAgentEngineRuntimeRevisionOperationConfigDict]
+    """Used to override the default configuration."""
+
+
+_GetDeleteAgentEngineRuntimeRevisionOperationParametersOrDict = Union[
+    _GetDeleteAgentEngineRuntimeRevisionOperationParameters,
+    _GetDeleteAgentEngineRuntimeRevisionOperationParametersDict,
+]
+
+
+class QueryAgentEngineRuntimeRevisionConfig(_common.BaseModel):
+    """Config for querying agent engine runtime revisions."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    class_method: Optional[str] = Field(
+        default=None, description="""The class method to call."""
+    )
+    input: Optional[dict[str, Any]] = Field(
+        default=None, description="""The input to the class method."""
+    )
+    include_all_fields: Optional[bool] = Field(default=False, description="""""")
+
+
+class QueryAgentEngineRuntimeRevisionConfigDict(TypedDict, total=False):
+    """Config for querying agent engine runtime revisions."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    class_method: Optional[str]
+    """The class method to call."""
+
+    input: Optional[dict[str, Any]]
+    """The input to the class method."""
+
+    include_all_fields: Optional[bool]
+    """"""
+
+
+QueryAgentEngineRuntimeRevisionConfigOrDict = Union[
+    QueryAgentEngineRuntimeRevisionConfig, QueryAgentEngineRuntimeRevisionConfigDict
+]
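+
+
+# Example (illustrative only, not part of the generated types): the query config
+# names the class method to invoke on the deployed agent and carries its payload;
+# a sketch, assuming the revision exposes a `query` method that takes an `input`
+# string:
+#
+#     config = QueryAgentEngineRuntimeRevisionConfig(
+#         class_method="query",
+#         input={"input": "What is the exchange rate for USD to SEK?"},
+#     )
+#     # Anywhere a *OrDict union is accepted, the TypedDict form works too:
+#     config_dict: QueryAgentEngineRuntimeRevisionConfigDict = {
+#         "class_method": "query",
+#         "input": {"input": "What is the exchange rate for USD to SEK?"},
+#     }
+
+
+class _QueryAgentEngineRuntimeRevisionRequestParameters(_common.BaseModel):
+    """Parameters for querying agent engine runtime revisions."""
+
+    name: Optional[str] = Field(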
default=None, description="""Name of the agent engine runtime revision.""" + ) + config: Optional[QueryAgentEngineRuntimeRevisionConfig] = Field( + default=None, description="""""" + ) + + +class _QueryAgentEngineRuntimeRevisionRequestParametersDict(TypedDict, total=False): + """Parameters for querying agent engine runtime revisions.""" + + name: Optional[str] + """Name of the agent engine runtime revision.""" + + config: Optional[QueryAgentEngineRuntimeRevisionConfigDict] + """""" + + +_QueryAgentEngineRuntimeRevisionRequestParametersOrDict = Union[ + _QueryAgentEngineRuntimeRevisionRequestParameters, + _QueryAgentEngineRuntimeRevisionRequestParametersDict, +] + + +class SandboxEnvironmentSpecCodeExecutionEnvironment(_common.BaseModel): + """The code execution environment with customized settings.""" + + code_language: Optional[Language] = Field( + default=None, + description="""The coding language supported in this environment.""", + ) + machine_config: Optional[MachineConfig] = Field( + default=None, + description="""The machine config of the code execution environment.""", + ) + + +class SandboxEnvironmentSpecCodeExecutionEnvironmentDict(TypedDict, total=False): + """The code execution environment with customized settings.""" + + code_language: Optional[Language] + """The coding language supported in this environment.""" + + machine_config: Optional[MachineConfig] + """The machine config of the code execution environment.""" + + +SandboxEnvironmentSpecCodeExecutionEnvironmentOrDict = Union[ + SandboxEnvironmentSpecCodeExecutionEnvironment, + SandboxEnvironmentSpecCodeExecutionEnvironmentDict, +] + + +class SandboxEnvironmentSpecComputerUseEnvironment(_common.BaseModel): + """The computer use environment with customized settings.""" + + pass + + +class SandboxEnvironmentSpecComputerUseEnvironmentDict(TypedDict, total=False): + """The computer use environment with customized settings.""" + + pass + + +SandboxEnvironmentSpecComputerUseEnvironmentOrDict = Union[ + SandboxEnvironmentSpecComputerUseEnvironment, + SandboxEnvironmentSpecComputerUseEnvironmentDict, +] + + +class SandboxEnvironmentSpec(_common.BaseModel): + """The specification of a sandbox environment.""" + + code_execution_environment: Optional[ + SandboxEnvironmentSpecCodeExecutionEnvironment + ] = Field(default=None, description="""Optional. The code execution environment.""") + computer_use_environment: Optional[SandboxEnvironmentSpecComputerUseEnvironment] = ( + Field(default=None, description="""Optional. The computer use environment.""") + ) + + +class SandboxEnvironmentSpecDict(TypedDict, total=False): + """The specification of a sandbox environment.""" + + code_execution_environment: Optional[ + SandboxEnvironmentSpecCodeExecutionEnvironmentDict + ] + """Optional. The code execution environment.""" + + computer_use_environment: Optional[SandboxEnvironmentSpecComputerUseEnvironmentDict] + """Optional. 
The computer use environment.""" + + +SandboxEnvironmentSpecOrDict = Union[SandboxEnvironmentSpec, SandboxEnvironmentSpecDict] + + +class CreateAgentEngineSandboxConfig(_common.BaseModel): + """Config for creating a Sandbox.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the sandbox.""" + ) + description: Optional[str] = Field( + default=None, description="""The description of the sandbox.""" + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""The TTL for this resource. The expiration time is computed: now + TTL.""", + ) + sandbox_environment_template: Optional[str] = Field( + default=None, + description="""The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""", + ) + sandbox_environment_snapshot: Optional[str] = Field( + default=None, + description="""The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""", + ) + owner: Optional[str] = Field( + default=None, + description="""Owner information for this sandbox environment. A sandbox can only be restored from a snapshot belonging to the same owner.""", + ) + + +class CreateAgentEngineSandboxConfigDict(TypedDict, total=False): + """Config for creating a Sandbox.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The display name of the sandbox.""" + + description: Optional[str] + """The description of the sandbox.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + ttl: Optional[str] + """The TTL for this resource. The expiration time is computed: now + TTL.""" + + sandbox_environment_template: Optional[str] + """The name of the sandbox environment template to create the sandbox from. The sandbox environment template should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}""" + + sandbox_environment_snapshot: Optional[str] + """The name of the sandbox environment snapshot to restore the sandbox from. The sandbox environment snapshot should be in the format: + projects/{project}/locations/{location}/agentEngines/{agent_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}""" + + owner: Optional[str] + """Owner information for this sandbox environment. 
A sandbox can only be restored from a snapshot belonging to the same owner."""
+
+
+CreateAgentEngineSandboxConfigOrDict = Union[
+    CreateAgentEngineSandboxConfig, CreateAgentEngineSandboxConfigDict
+]
+
+
+class _CreateAgentEngineSandboxRequestParameters(_common.BaseModel):
+    """Parameters for creating Agent Engine Sandboxes."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""Name of the agent engine to create the sandbox under.""",
+    )
+    spec: Optional[SandboxEnvironmentSpec] = Field(
+        default=None, description="""The specification of the sandbox."""
+    )
+    config: Optional[CreateAgentEngineSandboxConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _CreateAgentEngineSandboxRequestParametersDict(TypedDict, total=False):
+    """Parameters for creating Agent Engine Sandboxes."""
+
+    name: Optional[str]
+    """Name of the agent engine to create the sandbox under."""
+
+    spec: Optional[SandboxEnvironmentSpecDict]
+    """The specification of the sandbox."""
+
+    config: Optional[CreateAgentEngineSandboxConfigDict]
+    """"""
+
+
+_CreateAgentEngineSandboxRequestParametersOrDict = Union[
+    _CreateAgentEngineSandboxRequestParameters,
+    _CreateAgentEngineSandboxRequestParametersDict,
+]
+
+
+class SandboxEnvironmentConnectionInfo(_common.BaseModel):
+    """The connection information of the SandboxEnvironment."""
+
+    load_balancer_hostname: Optional[str] = Field(
+        default=None, description="""Output only. The hostname of the load balancer."""
+    )
+    load_balancer_ip: Optional[str] = Field(
+        default=None,
+        description="""Output only. The IP address of the load balancer.""",
+    )
+    sandbox_internal_ip: Optional[str] = Field(
+        default=None,
+        description="""Output only. The internal IP address of the SandboxEnvironment.""",
+    )
+    sandbox_hostname: Optional[str] = Field(
+        default=None,
+        description="""Output only. The hostname of the SandboxEnvironment.""",
+    )
+    routing_token: Optional[str] = Field(
+        default=None,
+        description="""Output only. The routing token for the SandboxEnvironment.""",
+    )
+
+
+class SandboxEnvironmentConnectionInfoDict(TypedDict, total=False):
+    """The connection information of the SandboxEnvironment."""
+
+    load_balancer_hostname: Optional[str]
+    """Output only. The hostname of the load balancer."""
+
+    load_balancer_ip: Optional[str]
+    """Output only. The IP address of the load balancer."""
+
+    sandbox_internal_ip: Optional[str]
+    """Output only. The internal IP address of the SandboxEnvironment."""
+
+    sandbox_hostname: Optional[str]
+    """Output only. The hostname of the SandboxEnvironment."""
+
+    routing_token: Optional[str]
+    """Output only. The routing token for the SandboxEnvironment."""
+
+
+SandboxEnvironmentConnectionInfoOrDict = Union[
+    SandboxEnvironmentConnectionInfo, SandboxEnvironmentConnectionInfoDict
+]
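+
+
+# Example (illustrative only, not part of the generated types): a sketch of a
+# sandbox-creation config; the template resource name is a placeholder, and the
+# TTL is assumed to be a Duration-style string per the field description
+# ("expiration time is computed: now + TTL"):
+#
+#     config = CreateAgentEngineSandboxConfig(
+#         display_name="my-sandbox",
+#         ttl="3600s",  # assumed format; expires one hour after creation
+#         sandbox_environment_template=(
+#             "projects/my-project/locations/us-central1/agentEngines/"
+#             "my-engine/sandboxEnvironmentTemplates/my-template"
+#         ),
+#     )
+#     # Alternatively, an inline SandboxEnvironmentSpec can be supplied via the
+#     # request's `spec` field instead of referencing a template.
+
+
+class SandboxEnvironment(_common.BaseModel):
+    """A sandbox environment."""
+
+    expire_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Expiration time of the sandbox environment.
+      """,
+    )
+    connection_info: Optional[SandboxEnvironmentConnectionInfo] = Field(
+        default=None,
+        description="""Output only. The connection information of the SandboxEnvironment.""",
+    )
+    create_time: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Output only. The timestamp when this SandboxEnvironment was created.""",
+    )
+    display_name: Optional[str] = Field(
+        default=None,
+        description="""Required.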
The display name of the SandboxEnvironment.""", + ) + name: Optional[str] = Field( + default=None, description="""Identifier. The name of the SandboxEnvironment.""" + ) + spec: Optional[SandboxEnvironmentSpec] = Field( + default=None, + description="""Optional. The configuration of the SandboxEnvironment.""", + ) + state: Optional[State] = Field( + default=None, + description="""Output only. The runtime state of the SandboxEnvironment.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for the sandbox environment. The expiration time is computed: now + TTL.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The timestamp when this SandboxEnvironment was most recently updated.""", + ) + latest_sandbox_environment_snapshot: Optional[str] = Field( + default=None, + description="""Output only. The resource name of the latest snapshot taken for this SandboxEnvironment.""", + ) + owner: Optional[str] = Field( + default=None, + description="""Optional. Owner information for this sandbox environment. A Sandbox can only be restored from a snapshot that belongs to the same owner. If not set, sandbox will be created as the default owner.""", + ) + sandbox_environment_snapshot: Optional[str] = Field( + default=None, + description="""Optional. The resource name of the SandboxEnvironmentSnapshot to use for creating this SandboxEnvironment. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}`""", + ) + sandbox_environment_template: Optional[str] = Field( + default=None, + description="""Optional. The name of the SandboxEnvironmentTemplate specified in the parent Agent Engine resource that this SandboxEnvironment is created from. Only one of `sandbox_environment_template` and `spec` should be set.""", + ) + + +class SandboxEnvironmentDict(TypedDict, total=False): + """A sandbox environment.""" + + expire_time: Optional[datetime.datetime] + """Expiration time of the sandbox environment. + """ + + connection_info: Optional[SandboxEnvironmentConnectionInfoDict] + """Output only. The connection information of the SandboxEnvironment.""" + + create_time: Optional[datetime.datetime] + """Output only. The timestamp when this SandboxEnvironment was created.""" + + display_name: Optional[str] + """Required. The display name of the SandboxEnvironment.""" + + name: Optional[str] + """Identifier. The name of the SandboxEnvironment.""" + + spec: Optional[SandboxEnvironmentSpecDict] + """Optional. The configuration of the SandboxEnvironment.""" + + state: Optional[State] + """Output only. The runtime state of the SandboxEnvironment.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for the sandbox environment. The expiration time is computed: now + TTL.""" + + update_time: Optional[datetime.datetime] + """Output only. The timestamp when this SandboxEnvironment was most recently updated.""" + + latest_sandbox_environment_snapshot: Optional[str] + """Output only. The resource name of the latest snapshot taken for this SandboxEnvironment.""" + + owner: Optional[str] + """Optional. Owner information for this sandbox environment. A Sandbox can only be restored from a snapshot that belongs to the same owner. If not set, sandbox will be created as the default owner.""" + + sandbox_environment_snapshot: Optional[str] + """Optional. 
The resource name of the SandboxEnvironmentSnapshot to use for creating this SandboxEnvironment. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}`"""
+
+    sandbox_environment_template: Optional[str]
+    """Optional. The name of the SandboxEnvironmentTemplate specified in the parent Agent Engine resource that this SandboxEnvironment is created from. Only one of `sandbox_environment_template` and `spec` should be set."""
+
+
+SandboxEnvironmentOrDict = Union[SandboxEnvironment, SandboxEnvironmentDict]
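+
+
+# Example (illustrative only, not part of the generated types): connection
+# details for a created sandbox are output-only fields on the model; a sketch of
+# reading them off a `SandboxEnvironment` returned by the service:
+#
+#     sandbox: SandboxEnvironment = ...  # as returned by the service
+#     if sandbox.connection_info is not None:
+#         host = sandbox.connection_info.sandbox_hostname
+#         token = sandbox.connection_info.routing_token
+#         # e.g. route traffic to `host`, attaching `token` where required
+
+
+class AgentEngineSandboxOperation(_common.BaseModel):
+    """Operation that has an agent engine sandbox as a response."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+    response: Optional[SandboxEnvironment] = Field(
+        default=None, description="""The Agent Engine Sandbox."""
+    )
+
+
+class AgentEngineSandboxOperationDict(TypedDict, total=False):
+    """Operation that has an agent engine sandbox as a response."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress.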
If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+    response: Optional[SandboxEnvironmentDict]
+    """The Agent Engine Sandbox."""
+
+
+AgentEngineSandboxOperationOrDict = Union[
+    AgentEngineSandboxOperation, AgentEngineSandboxOperationDict
+]
+
+
+class DeleteAgentEngineSandboxConfig(_common.BaseModel):
+    """Config for deleting an Agent Engine Sandbox."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class DeleteAgentEngineSandboxConfigDict(TypedDict, total=False):
+    """Config for deleting an Agent Engine Sandbox."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+DeleteAgentEngineSandboxConfigOrDict = Union[
+    DeleteAgentEngineSandboxConfig, DeleteAgentEngineSandboxConfigDict
+]
+
+
+class _DeleteAgentEngineSandboxRequestParameters(_common.BaseModel):
+    """Parameters for deleting agent engine sandboxes."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine sandbox to delete."""
+    )
+    config: Optional[DeleteAgentEngineSandboxConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _DeleteAgentEngineSandboxRequestParametersDict(TypedDict, total=False):
+    """Parameters for deleting agent engine sandboxes."""
+
+    name: Optional[str]
+    """Name of the agent engine sandbox to delete."""
+
+    config: Optional[DeleteAgentEngineSandboxConfigDict]
+    """"""
+
+
+_DeleteAgentEngineSandboxRequestParametersOrDict = Union[
+    _DeleteAgentEngineSandboxRequestParameters,
+    _DeleteAgentEngineSandboxRequestParametersDict,
+]
+
+
+class DeleteAgentEngineSandboxOperation(_common.BaseModel):
+    """Operation for deleting agent engine sandboxes."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+
+
+class DeleteAgentEngineSandboxOperationDict(TypedDict, total=False):
+    """Operation for deleting agent engine sandboxes."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata.
Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteAgentEngineSandboxOperationOrDict = Union[ + DeleteAgentEngineSandboxOperation, DeleteAgentEngineSandboxOperationDict +] + + +class Metadata(_common.BaseModel): + """Metadata for a chunk.""" + + attributes: Optional[dict[str, bytes]] = Field( + default=None, + description="""Optional. Attributes attached to the data. The keys have semantic conventions and the consumers of the attributes should know how to deserialize the value bytes based on the keys.""", + ) + + +class MetadataDict(TypedDict, total=False): + """Metadata for a chunk.""" + + attributes: Optional[dict[str, bytes]] + """Optional. Attributes attached to the data. The keys have semantic conventions and the consumers of the attributes should know how to deserialize the value bytes based on the keys.""" + + +MetadataOrDict = Union[Metadata, MetadataDict] + + +class Chunk(_common.BaseModel): + """A chunk of data.""" + + data: Optional[bytes] = Field( + default=None, description="""Required. The data in the chunk.""" + ) + metadata: Optional[Metadata] = Field( + default=None, + description="""Optional. Metadata that is associated with the data in the payload.""", + ) + mime_type: Optional[str] = Field( + default=None, + description="""Required. Mime type of the chunk data. See https://www.iana.org/assignments/media-types/media-types.xhtml for the full list.""", + ) + + +class ChunkDict(TypedDict, total=False): + """A chunk of data.""" + + data: Optional[bytes] + """Required. The data in the chunk.""" + + metadata: Optional[MetadataDict] + """Optional. Metadata that is associated with the data in the payload.""" + + mime_type: Optional[str] + """Required. Mime type of the chunk data. 
See https://www.iana.org/assignments/media-types/media-types.xhtml for the full list."""
+
+
+ChunkOrDict = Union[Chunk, ChunkDict]
+
+
+class ExecuteCodeAgentEngineSandboxConfig(_common.BaseModel):
+    """Config for executing code in an Agent Engine sandbox."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class ExecuteCodeAgentEngineSandboxConfigDict(TypedDict, total=False):
+    """Config for executing code in an Agent Engine sandbox."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+ExecuteCodeAgentEngineSandboxConfigOrDict = Union[
+    ExecuteCodeAgentEngineSandboxConfig, ExecuteCodeAgentEngineSandboxConfigDict
+]
+
+
+class _ExecuteCodeAgentEngineSandboxRequestParameters(_common.BaseModel):
+    """Parameters for executing code in an agent engine sandbox."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""Name of the agent engine sandbox to execute code in.""",
+    )
+    inputs: Optional[list[Chunk]] = Field(
+        default=None, description="""Inputs to the code execution."""
+    )
+    config: Optional[ExecuteCodeAgentEngineSandboxConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _ExecuteCodeAgentEngineSandboxRequestParametersDict(TypedDict, total=False):
+    """Parameters for executing code in an agent engine sandbox."""
+
+    name: Optional[str]
+    """Name of the agent engine sandbox to execute code in."""
+
+    inputs: Optional[list[ChunkDict]]
+    """Inputs to the code execution."""
+
+    config: Optional[ExecuteCodeAgentEngineSandboxConfigDict]
+    """"""
+
+
+_ExecuteCodeAgentEngineSandboxRequestParametersOrDict = Union[
+    _ExecuteCodeAgentEngineSandboxRequestParameters,
+    _ExecuteCodeAgentEngineSandboxRequestParametersDict,
+]
+
+
+class ExecuteSandboxEnvironmentResponse(_common.BaseModel):
+    """The response for executing a sandbox environment."""
+
+    outputs: Optional[list[Chunk]] = Field(
+        default=None, description="""The outputs from the sandbox environment."""
+    )
+
+
+class ExecuteSandboxEnvironmentResponseDict(TypedDict, total=False):
+    """The response for executing a sandbox environment."""
+
+    outputs: Optional[list[ChunkDict]]
+    """The outputs from the sandbox environment."""
+
+
+ExecuteSandboxEnvironmentResponseOrDict = Union[
+    ExecuteSandboxEnvironmentResponse, ExecuteSandboxEnvironmentResponseDict
+]
+
+
+class GetAgentEngineSandboxConfig(_common.BaseModel):
+    """Config for getting an Agent Engine Sandbox."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetAgentEngineSandboxConfigDict(TypedDict, total=False):
+    """Config for getting an Agent Engine Sandbox."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetAgentEngineSandboxConfigOrDict = Union[
+    GetAgentEngineSandboxConfig, GetAgentEngineSandboxConfigDict
+]
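+
+
+# Example (illustrative only, not part of the generated types): code reaches a
+# sandbox as a list of Chunk values and results come back the same way; a sketch
+# of one Python source chunk (the mime type shown is an assumption, not taken
+# from this module):
+#
+#     code = Chunk(
+#         data=b"print('hello from the sandbox')",
+#         mime_type="text/x-python",  # assumed; any registered media type works
+#     )
+#     # and, reading results off an ExecuteSandboxEnvironmentResponse:
+#     response: ExecuteSandboxEnvironmentResponse = ...  # from the execute call
+#     for chunk in response.outputs or []:
+#         print(chunk.mime_type, chunk.data)
+
+
+class _GetAgentEngineSandboxRequestParameters(_common.BaseModel):
+    """Parameters for getting an agent engine sandbox."""
+
+    name: Optional[str] = Field(
+        default=None, description="""Name of the agent engine sandbox."""
+    )
+    config: Optional[GetAgentEngineSandboxConfig] = Field(
+        default=None, description=""""""
+    )
+
+
+class _GetAgentEngineSandboxRequestParametersDict(TypedDict, total=False):
+    """Parameters for getting an agent engine sandbox."""
+
+    name: Optional[str]
+    """Name of the agent engine sandbox."""
+
+    config: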
Optional[GetAgentEngineSandboxConfigDict] + """""" + + +_GetAgentEngineSandboxRequestParametersOrDict = Union[ + _GetAgentEngineSandboxRequestParameters, _GetAgentEngineSandboxRequestParametersDict +] + + +class ListAgentEngineSandboxesConfig(_common.BaseModel): + """Config for listing agent engine sandboxes.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineSandboxesConfigDict(TypedDict, total=False): + """Config for listing agent engine sandboxes.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineSandboxesConfigOrDict = Union[ + ListAgentEngineSandboxesConfig, ListAgentEngineSandboxesConfigDict +] + + +class _ListAgentEngineSandboxesRequestParameters(_common.BaseModel): + """Parameters for listing agent engine sandboxes.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[ListAgentEngineSandboxesConfig] = Field( + default=None, description="""""" + ) + + +class _ListAgentEngineSandboxesRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engine sandboxes.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[ListAgentEngineSandboxesConfigDict] + """""" + + +_ListAgentEngineSandboxesRequestParametersOrDict = Union[ + _ListAgentEngineSandboxesRequestParameters, + _ListAgentEngineSandboxesRequestParametersDict, +] + + +class ListAgentEngineSandboxesResponse(_common.BaseModel): + """Response for listing agent engine sandboxes.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + sandbox_environments: Optional[list[SandboxEnvironment]] = Field( + default=None, description="""List of agent engine sandboxes.""" + ) + + +class ListAgentEngineSandboxesResponseDict(TypedDict, total=False): + """Response for listing agent engine sandboxes.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + sandbox_environments: Optional[list[SandboxEnvironmentDict]] + """List of agent engine sandboxes.""" + + +ListAgentEngineSandboxesResponseOrDict = Union[ + ListAgentEngineSandboxesResponse, ListAgentEngineSandboxesResponseDict +] + + +class _GetAgentEngineSandboxOperationParameters(_common.BaseModel): + """Parameters for getting an operation with a sandbox as a response.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class 
_GetAgentEngineSandboxOperationParametersDict(TypedDict, total=False): + """Parameters for getting an operation with a sandbox as a response.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetAgentEngineSandboxOperationParametersOrDict = Union[ + _GetAgentEngineSandboxOperationParameters, + _GetAgentEngineSandboxOperationParametersDict, +] + + +class SandboxEnvironmentTemplateCustomContainerSpec(_common.BaseModel): + """Specification for deploying from a custom container image.""" + + image_uri: Optional[str] = Field( + default=None, + description="""Required. The Artifact Registry Docker image URI (e.g., us-central1-docker.pkg.dev/my-project/my-repo/my-image:tag) of the container image that is to be run on each worker replica.""", + ) + + +class SandboxEnvironmentTemplateCustomContainerSpecDict(TypedDict, total=False): + """Specification for deploying from a custom container image.""" + + image_uri: Optional[str] + """Required. The Artifact Registry Docker image URI (e.g., us-central1-docker.pkg.dev/my-project/my-repo/my-image:tag) of the container image that is to be run on each worker replica.""" + + +SandboxEnvironmentTemplateCustomContainerSpecOrDict = Union[ + SandboxEnvironmentTemplateCustomContainerSpec, + SandboxEnvironmentTemplateCustomContainerSpecDict, +] + + +class SandboxEnvironmentTemplateNetworkPort(_common.BaseModel): + """Represents a network port in a container.""" + + port: Optional[int] = Field( + default=None, + description="""Optional. Port number to expose. This must be a valid port number, between 1 and 65535.""", + ) + protocol: Optional[Protocol] = Field( + default=None, + description="""Optional. Protocol for port. Defaults to TCP if not specified.""", + ) + + +class SandboxEnvironmentTemplateNetworkPortDict(TypedDict, total=False): + """Represents a network port in a container.""" + + port: Optional[int] + """Optional. Port number to expose. This must be a valid port number, between 1 and 65535.""" + + protocol: Optional[Protocol] + """Optional. Protocol for port. Defaults to TCP if not specified.""" + + +SandboxEnvironmentTemplateNetworkPortOrDict = Union[ + SandboxEnvironmentTemplateNetworkPort, SandboxEnvironmentTemplateNetworkPortDict +] + + +class SandboxEnvironmentTemplateResourceRequirements(_common.BaseModel): + """Message to define resource requests and limits (mirroring Kubernetes) for each sandbox instance created from this template.""" + + limits: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The maximum amounts of compute resources allowed. Keys are resource names (e.g., "cpu", "memory"). Values are quantities (e.g., "500m", "1Gi").""", + ) + requests: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The requested amounts of compute resources. Keys are resource names (e.g., "cpu", "memory"). Values are quantities (e.g., "250m", "512Mi").""", + ) + + +class SandboxEnvironmentTemplateResourceRequirementsDict(TypedDict, total=False): + """Message to define resource requests and limits (mirroring Kubernetes) for each sandbox instance created from this template.""" + + limits: Optional[dict[str, str]] + """Optional. The maximum amounts of compute resources allowed. Keys are resource names (e.g., "cpu", "memory"). Values are quantities (e.g., "500m", "1Gi").""" + + requests: Optional[dict[str, str]] + """Optional. 
The requested amounts of compute resources. Keys are resource names (e.g., "cpu", "memory"). Values are quantities (e.g., "250m", "512Mi")."""
+
+
+SandboxEnvironmentTemplateResourceRequirementsOrDict = Union[
+    SandboxEnvironmentTemplateResourceRequirements,
+    SandboxEnvironmentTemplateResourceRequirementsDict,
+]
+
+
+class SandboxEnvironmentTemplateCustomContainerEnvironment(_common.BaseModel):
+    """The customized sandbox runtime environment for BYOC."""
+
+    custom_container_spec: Optional[SandboxEnvironmentTemplateCustomContainerSpec] = (
+        Field(
+            default=None,
+            description="""The specification of the custom container environment.""",
+        )
+    )
+    ports: Optional[list[SandboxEnvironmentTemplateNetworkPort]] = Field(
+        default=None, description="""Ports to expose from the container."""
+    )
+    resources: Optional[SandboxEnvironmentTemplateResourceRequirements] = Field(
+        default=None, description="""Resource requests and limits for the container."""
+    )
+
+
+class SandboxEnvironmentTemplateCustomContainerEnvironmentDict(TypedDict, total=False):
+    """The customized sandbox runtime environment for BYOC."""
+
+    custom_container_spec: Optional[SandboxEnvironmentTemplateCustomContainerSpecDict]
+    """The specification of the custom container environment."""
+
+    ports: Optional[list[SandboxEnvironmentTemplateNetworkPortDict]]
+    """Ports to expose from the container."""
+
+    resources: Optional[SandboxEnvironmentTemplateResourceRequirementsDict]
+    """Resource requests and limits for the container."""
+
+
+SandboxEnvironmentTemplateCustomContainerEnvironmentOrDict = Union[
+    SandboxEnvironmentTemplateCustomContainerEnvironment,
+    SandboxEnvironmentTemplateCustomContainerEnvironmentDict,
+]
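+
+
+# Example (illustrative only, not part of the generated types): a sketch of a
+# BYOC environment with one exposed port and Kubernetes-style resource
+# quantities; the image URI is a placeholder:
+#
+#     env = SandboxEnvironmentTemplateCustomContainerEnvironment(
+#         custom_container_spec=SandboxEnvironmentTemplateCustomContainerSpec(
+#             image_uri="us-central1-docker.pkg.dev/my-project/my-repo/my-image:tag"
+#         ),
+#         ports=[SandboxEnvironmentTemplateNetworkPort(port=8080)],  # TCP default
+#         resources=SandboxEnvironmentTemplateResourceRequirements(
+#             requests={"cpu": "250m", "memory": "512Mi"},
+#             limits={"cpu": "500m", "memory": "1Gi"},
+#         ),
+#     )
+
+
+class SandboxEnvironmentTemplateDefaultContainerEnvironment(_common.BaseModel):
+    """The default sandbox runtime environment for default container workloads."""
+
+    default_container_category: Optional[DefaultContainerCategory] = Field(
+        default=None,
+        description="""Required. The category of the default container image.""",
+    )
+
+
+class SandboxEnvironmentTemplateDefaultContainerEnvironmentDict(TypedDict, total=False):
+    """The default sandbox runtime environment for default container workloads."""
+
+    default_container_category: Optional[DefaultContainerCategory]
+    """Required. The category of the default container image."""
+
+
+SandboxEnvironmentTemplateDefaultContainerEnvironmentOrDict = Union[
+    SandboxEnvironmentTemplateDefaultContainerEnvironment,
+    SandboxEnvironmentTemplateDefaultContainerEnvironmentDict,
+]
+
+
+class SandboxEnvironmentTemplateEgressControlConfig(_common.BaseModel):
+    """Configuration for egress control of sandbox instances."""
+
+    internet_access: Optional[bool] = Field(
+        default=None, description="""Optional. Whether to allow internet access."""
+    )
+
+
+class SandboxEnvironmentTemplateEgressControlConfigDict(TypedDict, total=False):
+    """Configuration for egress control of sandbox instances."""
+
+    internet_access: Optional[bool]
+    """Optional.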
Whether to allow internet access.""" + + +SandboxEnvironmentTemplateEgressControlConfigOrDict = Union[ + SandboxEnvironmentTemplateEgressControlConfig, + SandboxEnvironmentTemplateEgressControlConfigDict, +] + + +class CreateSandboxEnvironmentTemplateConfig(_common.BaseModel): + """Config for creating a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + custom_container_environment: Optional[ + SandboxEnvironmentTemplateCustomContainerEnvironment + ] = Field( + default=None, + description="""The custom container environment for the sandbox template.""", + ) + default_container_environment: Optional[ + SandboxEnvironmentTemplateDefaultContainerEnvironment + ] = Field( + default=None, + description="""The default container environment for the sandbox template.""", + ) + egress_control_config: Optional[SandboxEnvironmentTemplateEgressControlConfig] = ( + Field( + default=None, + description="""The egress control config for the sandbox template.""", + ) + ) + + +class CreateSandboxEnvironmentTemplateConfigDict(TypedDict, total=False): + """Config for creating a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + custom_container_environment: Optional[ + SandboxEnvironmentTemplateCustomContainerEnvironmentDict + ] + """The custom container environment for the sandbox template.""" + + default_container_environment: Optional[ + SandboxEnvironmentTemplateDefaultContainerEnvironmentDict + ] + """The default container environment for the sandbox template.""" + + egress_control_config: Optional[SandboxEnvironmentTemplateEgressControlConfigDict] + """The egress control config for the sandbox template.""" + + +CreateSandboxEnvironmentTemplateConfigOrDict = Union[ + CreateSandboxEnvironmentTemplateConfig, CreateSandboxEnvironmentTemplateConfigDict +] + + +class _CreateSandboxEnvironmentTemplateRequestParameters(_common.BaseModel): + """Parameters for creating Sandbox Environment Templates.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine to create the template under.""", + ) + config: Optional[CreateSandboxEnvironmentTemplateConfig] = Field( + default=None, description="""""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the sandbox template.""" + ) + + +class _CreateSandboxEnvironmentTemplateRequestParametersDict(TypedDict, total=False): + """Parameters for creating Sandbox Environment Templates.""" + + name: Optional[str] + """Name of the agent engine to create the template under.""" + + config: Optional[CreateSandboxEnvironmentTemplateConfigDict] + """""" + + display_name: Optional[str] + """The display name of the sandbox template.""" + + +_CreateSandboxEnvironmentTemplateRequestParametersOrDict = Union[ + _CreateSandboxEnvironmentTemplateRequestParameters, + _CreateSandboxEnvironmentTemplateRequestParametersDict, +] + + +class SandboxEnvironmentTemplateWarmPoolConfig(_common.BaseModel): + """Configuration for a warm pool of sandbox instances.""" + + target_instance_count: Optional[int] = Field( + default=None, + description="""Optional. 
The target number of pre-warmed instances to maintain.""", + ) + + +class SandboxEnvironmentTemplateWarmPoolConfigDict(TypedDict, total=False): + """Configuration for a warm pool of sandbox instances.""" + + target_instance_count: Optional[int] + """Optional. The target number of pre-warmed instances to maintain.""" + + +SandboxEnvironmentTemplateWarmPoolConfigOrDict = Union[ + SandboxEnvironmentTemplateWarmPoolConfig, + SandboxEnvironmentTemplateWarmPoolConfigDict, +] + + +class SandboxEnvironmentTemplate(_common.BaseModel): + """A sandbox environment template.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The timestamp when this SandboxEnvironmentTemplate was created.""", + ) + custom_container_environment: Optional[ + SandboxEnvironmentTemplateCustomContainerEnvironment + ] = Field( + default=None, + description="""The sandbox environment for custom container workloads.""", + ) + default_container_environment: Optional[ + SandboxEnvironmentTemplateDefaultContainerEnvironment + ] = Field( + default=None, + description="""The sandbox environment for default container workloads.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Required. The display name of the SandboxEnvironmentTemplate.""", + ) + egress_control_config: Optional[SandboxEnvironmentTemplateEgressControlConfig] = ( + Field( + default=None, + description="""Optional. The configuration for egress control of this template.""", + ) + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the SandboxEnvironmentTemplate. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}`""", + ) + state: Optional[ + Literal[ + "UNSPECIFIED", + "PROVISIONING", + "ACTIVE", + "DEPROVISIONING", + "DELETED", + "FAILED", + ] + ] = Field( + default=None, + description="""Output only. The state of the sandbox environment template.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The timestamp when this SandboxEnvironmentTemplate was most recently updated.""", + ) + warm_pool_config: Optional[SandboxEnvironmentTemplateWarmPoolConfig] = Field( + default=None, + description="""Optional. The configuration for the warm pool of this template.""", + ) + + +class SandboxEnvironmentTemplateDict(TypedDict, total=False): + """A sandbox environment template.""" + + create_time: Optional[datetime.datetime] + """Output only. The timestamp when this SandboxEnvironmentTemplate was created.""" + + custom_container_environment: Optional[ + SandboxEnvironmentTemplateCustomContainerEnvironmentDict + ] + """The sandbox environment for custom container workloads.""" + + default_container_environment: Optional[ + SandboxEnvironmentTemplateDefaultContainerEnvironmentDict + ] + """The sandbox environment for default container workloads.""" + + display_name: Optional[str] + """Required. The display name of the SandboxEnvironmentTemplate.""" + + egress_control_config: Optional[SandboxEnvironmentTemplateEgressControlConfigDict] + """Optional. The configuration for egress control of this template.""" + + name: Optional[str] + """Identifier. The resource name of the SandboxEnvironmentTemplate. 
Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentTemplates/{sandbox_environment_template}`"""
+
+    state: Optional[
+        Literal[
+            "UNSPECIFIED",
+            "PROVISIONING",
+            "ACTIVE",
+            "DEPROVISIONING",
+            "DELETED",
+            "FAILED",
+        ]
+    ]
+    """Output only. The state of the sandbox environment template."""
+
+    update_time: Optional[datetime.datetime]
+    """Output only. The timestamp when this SandboxEnvironmentTemplate was most recently updated."""
+
+    warm_pool_config: Optional[SandboxEnvironmentTemplateWarmPoolConfigDict]
+    """Optional. The configuration for the warm pool of this template."""
+
+
+SandboxEnvironmentTemplateOrDict = Union[
+    SandboxEnvironmentTemplate, SandboxEnvironmentTemplateDict
+]
+
+
+class SandboxEnvironmentTemplateOperation(_common.BaseModel):
+    """Operation that has a sandbox environment template as a response."""
+
+    name: Optional[str] = Field(
+        default=None,
+        description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""",
+    )
+    metadata: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""",
+    )
+    done: Optional[bool] = Field(
+        default=None,
+        description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""",
+    )
+    error: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""The error result of the operation in case of failure or cancellation.""",
+    )
+    response: Optional[SandboxEnvironmentTemplate] = Field(
+        default=None, description="""The Agent Engine Sandbox Template."""
+    )
+
+
+class SandboxEnvironmentTemplateOperationDict(TypedDict, total=False):
+    """Operation that has a sandbox environment template as a response."""
+
+    name: Optional[str]
+    """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`."""
+
+    metadata: Optional[dict[str, Any]]
+    """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress.
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[SandboxEnvironmentTemplateDict] + """The Agent Engine Sandbox Template.""" + + +SandboxEnvironmentTemplateOperationOrDict = Union[ + SandboxEnvironmentTemplateOperation, SandboxEnvironmentTemplateOperationDict +] + + +class DeleteSandboxEnvironmentTemplateConfig(_common.BaseModel): + """Config for deleting a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteSandboxEnvironmentTemplateConfigDict(TypedDict, total=False): + """Config for deleting a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteSandboxEnvironmentTemplateConfigOrDict = Union[ + DeleteSandboxEnvironmentTemplateConfig, DeleteSandboxEnvironmentTemplateConfigDict +] + + +class _DeleteSandboxEnvironmentTemplateRequestParameters(_common.BaseModel): + """Parameters for deleting sandbox templates.""" + + name: Optional[str] = Field( + default=None, description="""Name of the sandbox template to delete.""" + ) + config: Optional[DeleteSandboxEnvironmentTemplateConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteSandboxEnvironmentTemplateRequestParametersDict(TypedDict, total=False): + """Parameters for deleting sandbox templates.""" + + name: Optional[str] + """Name of the sandbox template to delete.""" + + config: Optional[DeleteSandboxEnvironmentTemplateConfigDict] + """""" + + +_DeleteSandboxEnvironmentTemplateRequestParametersOrDict = Union[ + _DeleteSandboxEnvironmentTemplateRequestParameters, + _DeleteSandboxEnvironmentTemplateRequestParametersDict, +] + + +class DeleteSandboxEnvironmentTemplateOperation(_common.BaseModel): + """Operation for deleting sandbox templates.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteSandboxEnvironmentTemplateOperationDict(TypedDict, total=False): + """Operation for deleting sandbox templates.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. 
It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteSandboxEnvironmentTemplateOperationOrDict = Union[ + DeleteSandboxEnvironmentTemplateOperation, + DeleteSandboxEnvironmentTemplateOperationDict, +] + + +class GetSandboxEnvironmentTemplateConfig(_common.BaseModel): + """Config for getting a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetSandboxEnvironmentTemplateConfigDict(TypedDict, total=False): + """Config for getting a Sandbox Template.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetSandboxEnvironmentTemplateConfigOrDict = Union[ + GetSandboxEnvironmentTemplateConfig, GetSandboxEnvironmentTemplateConfigDict +] + + +class _GetSandboxEnvironmentTemplateRequestParameters(_common.BaseModel): + """Parameters for getting a sandbox template.""" + + name: Optional[str] = Field( + default=None, description="""Name of the sandbox template.""" + ) + config: Optional[GetSandboxEnvironmentTemplateConfig] = Field( + default=None, description="""""" + ) + + +class _GetSandboxEnvironmentTemplateRequestParametersDict(TypedDict, total=False): + """Parameters for getting a sandbox template.""" + + name: Optional[str] + """Name of the sandbox template.""" + + config: Optional[GetSandboxEnvironmentTemplateConfigDict] + """""" + + +_GetSandboxEnvironmentTemplateRequestParametersOrDict = Union[ + _GetSandboxEnvironmentTemplateRequestParameters, + _GetSandboxEnvironmentTemplateRequestParametersDict, +] + + +class ListSandboxEnvironmentTemplatesConfig(_common.BaseModel): + """Config for listing sandbox templates.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request.""", + ) + + +class ListSandboxEnvironmentTemplatesConfigDict(TypedDict, total=False): + """Config for listing sandbox templates.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request.""" + + +ListSandboxEnvironmentTemplatesConfigOrDict = Union[ + ListSandboxEnvironmentTemplatesConfig, ListSandboxEnvironmentTemplatesConfigDict +] + + +class _ListSandboxEnvironmentTemplatesRequestParameters(_common.BaseModel): + """Parameters for listing sandbox templates.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[ListSandboxEnvironmentTemplatesConfig] = Field( + default=None, description="""""" + ) + + +class 
_ListSandboxEnvironmentTemplatesRequestParametersDict(TypedDict, total=False): + """Parameters for listing sandbox templates.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[ListSandboxEnvironmentTemplatesConfigDict] + """""" + + +_ListSandboxEnvironmentTemplatesRequestParametersOrDict = Union[ + _ListSandboxEnvironmentTemplatesRequestParameters, + _ListSandboxEnvironmentTemplatesRequestParametersDict, +] + + +class ListSandboxEnvironmentTemplatesResponse(_common.BaseModel): + """Response for listing sandbox templates.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + sandbox_environment_templates: Optional[list[SandboxEnvironmentTemplate]] = Field( + default=None, description="""List of sandbox templates.""" + ) + + +class ListSandboxEnvironmentTemplatesResponseDict(TypedDict, total=False): + """Response for listing sandbox templates.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + sandbox_environment_templates: Optional[list[SandboxEnvironmentTemplateDict]] + """List of sandbox templates.""" + + +ListSandboxEnvironmentTemplatesResponseOrDict = Union[ + ListSandboxEnvironmentTemplatesResponse, ListSandboxEnvironmentTemplatesResponseDict +] + + +class _GetSandboxEnvironmentTemplateOperationParameters(_common.BaseModel): + """Parameters for getting an operation with a sandbox template as a response.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetSandboxEnvironmentTemplateOperationParametersDict(TypedDict, total=False): + """Parameters for getting an operation with a sandbox template as a response.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetSandboxEnvironmentTemplateOperationParametersOrDict = Union[ + _GetSandboxEnvironmentTemplateOperationParameters, + _GetSandboxEnvironmentTemplateOperationParametersDict, +] + + +class CreateAgentEngineSandboxSnapshotConfig(_common.BaseModel): + """Config for creating a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the sandbox snapshot.""" + ) + owner: Optional[str] = Field( + default=None, description="""The owner of the sandbox snapshot.""" + ) + ttl: Optional[str] = Field( + default=None, + description="""The TTL for this resource. 
The expiration time is computed: now + TTL.""", + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + + +class CreateAgentEngineSandboxSnapshotConfigDict(TypedDict, total=False): + """Config for creating a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The display name of the sandbox snapshot.""" + + owner: Optional[str] + """The owner of the sandbox snapshot.""" + + ttl: Optional[str] + """The TTL for this resource. The expiration time is computed: now + TTL.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + +CreateAgentEngineSandboxSnapshotConfigOrDict = Union[ + CreateAgentEngineSandboxSnapshotConfig, CreateAgentEngineSandboxSnapshotConfigDict +] + + +class _CreateSandboxEnvironmentSnapshotRequestParameters(_common.BaseModel): + """Parameters for creating a sandbox environment snapshot.""" + + source_sandbox_environment_name: Optional[str] = Field( + default=None, description="""Name of the sandbox environment to snapshot.""" + ) + config: Optional[CreateAgentEngineSandboxSnapshotConfig] = Field( + default=None, description="""""" + ) + + +class _CreateSandboxEnvironmentSnapshotRequestParametersDict(TypedDict, total=False): + """Parameters for creating a sandbox environment snapshot.""" + + source_sandbox_environment_name: Optional[str] + """Name of the sandbox environment to snapshot.""" + + config: Optional[CreateAgentEngineSandboxSnapshotConfigDict] + """""" + + +_CreateSandboxEnvironmentSnapshotRequestParametersOrDict = Union[ + _CreateSandboxEnvironmentSnapshotRequestParameters, + _CreateSandboxEnvironmentSnapshotRequestParametersDict, +] + + +class SandboxEnvironmentSnapshot(_common.BaseModel): + """A sandbox environment snapshot.""" + + display_name: Optional[str] = Field( + default=None, + description="""The display name of the sandbox environment snapshot.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Expiration time of the sandbox environment snapshot. + """, + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The timestamp when this SandboxEnvironmentSnapshot was created.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the SandboxEnvironmentSnapshot. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}`""", + ) + owner: Optional[str] = Field( + default=None, + description="""Optional. Owner information for this sandbox snapshot. Different owners will have isolations on snapshot storage and identity. If not set, snapshot will be created as the default owner.""", + ) + parent_snapshot: Optional[str] = Field( + default=None, + description="""Output only. The resource name of the parent SandboxEnvironmentSnapshot. Empty if this is a root Snapshot (the first snapshot from a newly created sandbox). Can be used to reconstruct the whole ancestry tree of snapshots.""", + ) + post_snapshot_action: Optional[PostSnapshotAction] = Field( + default=None, + description="""Optional. Input only. Action to take on the source SandboxEnvironment after the snapshot is taken. 
This field is only used in CreateSandboxEnvironmentSnapshotRequest and it is not stored in the resource.""", + ) + size_bytes: Optional[int] = Field( + default=None, + description="""Optional. Output only. Size of the snapshot data in bytes.""", + ) + source_sandbox_environment: Optional[str] = Field( + default=None, + description="""Required. The resource name of the source SandboxEnvironment this snapshot was taken from.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for the sandbox environment snapshot. The expiration time is computed: now + TTL.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. The timestamp when this SandboxEnvironmentSnapshot was most recently updated.""", + ) + + +class SandboxEnvironmentSnapshotDict(TypedDict, total=False): + """A sandbox environment snapshot.""" + + display_name: Optional[str] + """The display name of the sandbox environment snapshot.""" + + expire_time: Optional[datetime.datetime] + """Expiration time of the sandbox environment snapshot. + """ + + create_time: Optional[datetime.datetime] + """Output only. The timestamp when this SandboxEnvironmentSnapshot was created.""" + + name: Optional[str] + """Identifier. The resource name of the SandboxEnvironmentSnapshot. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sandboxEnvironmentSnapshots/{sandbox_environment_snapshot}`""" + + owner: Optional[str] + """Optional. Owner information for this sandbox snapshot. Different owners will have isolations on snapshot storage and identity. If not set, snapshot will be created as the default owner.""" + + parent_snapshot: Optional[str] + """Output only. The resource name of the parent SandboxEnvironmentSnapshot. Empty if this is a root Snapshot (the first snapshot from a newly created sandbox). Can be used to reconstruct the whole ancestry tree of snapshots.""" + + post_snapshot_action: Optional[PostSnapshotAction] + """Optional. Input only. Action to take on the source SandboxEnvironment after the snapshot is taken. This field is only used in CreateSandboxEnvironmentSnapshotRequest and it is not stored in the resource.""" + + size_bytes: Optional[int] + """Optional. Output only. Size of the snapshot data in bytes.""" + + source_sandbox_environment: Optional[str] + """Required. The resource name of the source SandboxEnvironment this snapshot was taken from.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for the sandbox environment snapshot. The expiration time is computed: now + TTL.""" + + update_time: Optional[datetime.datetime] + """Output only. The timestamp when this SandboxEnvironmentSnapshot was most recently updated.""" + + +SandboxEnvironmentSnapshotOrDict = Union[ + SandboxEnvironmentSnapshot, SandboxEnvironmentSnapshotDict +] + + +class AgentEngineSandboxSnapshotOperation(_common.BaseModel): + """Operation that has an agent engine sandbox snapshot as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[SandboxEnvironmentSnapshot] = Field( + default=None, description="""The Agent Engine Sandbox Snapshot.""" + ) + + +class AgentEngineSandboxSnapshotOperationDict(TypedDict, total=False): + """Operation that has an agent engine sandbox snapshot as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[SandboxEnvironmentSnapshotDict] + """The Agent Engine Sandbox Snapshot.""" + + +AgentEngineSandboxSnapshotOperationOrDict = Union[ + AgentEngineSandboxSnapshotOperation, AgentEngineSandboxSnapshotOperationDict +] + + +class DeleteSandboxEnvironmentSnapshotConfig(_common.BaseModel): + """Config for deleting a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteSandboxEnvironmentSnapshotConfigDict(TypedDict, total=False): + """Config for deleting a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteSandboxEnvironmentSnapshotConfigOrDict = Union[ + DeleteSandboxEnvironmentSnapshotConfig, DeleteSandboxEnvironmentSnapshotConfigDict +] + + +class _DeleteSandboxEnvironmentSnapshotRequestParameters(_common.BaseModel): + """Parameters for deleting sandbox environment snapshots.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the sandbox environment snapshot to delete.""", + ) + config: Optional[DeleteSandboxEnvironmentSnapshotConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteSandboxEnvironmentSnapshotRequestParametersDict(TypedDict, total=False): + """Parameters for deleting sandbox environment snapshots.""" + + name: Optional[str] + """Name of the sandbox environment snapshot to delete.""" + + config: Optional[DeleteSandboxEnvironmentSnapshotConfigDict] + """""" + + +_DeleteSandboxEnvironmentSnapshotRequestParametersOrDict = Union[ + _DeleteSandboxEnvironmentSnapshotRequestParameters, + _DeleteSandboxEnvironmentSnapshotRequestParametersDict, +] + + +class DeleteSandboxEnvironmentSnapshotOperation(_common.BaseModel): + """Operation for deleting sandbox environment snapshots.""" + + name: 
Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteSandboxEnvironmentSnapshotOperationDict(TypedDict, total=False): + """Operation for deleting sandbox environment snapshots.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteSandboxEnvironmentSnapshotOperationOrDict = Union[ + DeleteSandboxEnvironmentSnapshotOperation, + DeleteSandboxEnvironmentSnapshotOperationDict, +] + + +class GetSandboxEnvironmentSnapshotConfig(_common.BaseModel): + """Config for getting a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetSandboxEnvironmentSnapshotConfigDict(TypedDict, total=False): + """Config for getting a Sandbox Environment Snapshot.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetSandboxEnvironmentSnapshotConfigOrDict = Union[ + GetSandboxEnvironmentSnapshotConfig, GetSandboxEnvironmentSnapshotConfigDict +] + + +class _GetSandboxEnvironmentSnapshotRequestParameters(_common.BaseModel): + """Parameters for getting a sandbox environment snapshot.""" + + name: Optional[str] = Field( + default=None, description="""Name of the sandbox environment snapshot.""" + ) + config: Optional[GetSandboxEnvironmentSnapshotConfig] = Field( + default=None, description="""""" + ) + + +class _GetSandboxEnvironmentSnapshotRequestParametersDict(TypedDict, total=False): + """Parameters for getting a sandbox environment snapshot.""" + + name: Optional[str] + """Name of the sandbox environment snapshot.""" + + config: Optional[GetSandboxEnvironmentSnapshotConfigDict] + """""" + + +_GetSandboxEnvironmentSnapshotRequestParametersOrDict = Union[ + _GetSandboxEnvironmentSnapshotRequestParameters, + _GetSandboxEnvironmentSnapshotRequestParametersDict, +] + + +class ListSandboxEnvironmentSnapshotsConfig(_common.BaseModel): + """Config for listing sandbox environment snapshots.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request.""", + ) + + +class ListSandboxEnvironmentSnapshotsConfigDict(TypedDict, total=False): + """Config for listing sandbox environment snapshots.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request.""" + + +ListSandboxEnvironmentSnapshotsConfigOrDict = Union[ + ListSandboxEnvironmentSnapshotsConfig, ListSandboxEnvironmentSnapshotsConfigDict +] + + +class _ListSandboxEnvironmentSnapshotsRequestParameters(_common.BaseModel): + """Parameters for listing sandbox environment snapshots.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the reasoning engine to list snapshots from.""", + ) + config: Optional[ListSandboxEnvironmentSnapshotsConfig] = Field( + default=None, description="""""" + ) + + +class _ListSandboxEnvironmentSnapshotsRequestParametersDict(TypedDict, total=False): + """Parameters for listing sandbox environment snapshots.""" + + name: Optional[str] + """Name of the reasoning engine to list snapshots from.""" + + config: 
Optional[ListSandboxEnvironmentSnapshotsConfigDict] + """""" + + +_ListSandboxEnvironmentSnapshotsRequestParametersOrDict = Union[ + _ListSandboxEnvironmentSnapshotsRequestParameters, + _ListSandboxEnvironmentSnapshotsRequestParametersDict, +] + + +class ListSandboxEnvironmentSnapshotsResponse(_common.BaseModel): + """Response for listing sandbox environment snapshots.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + sandbox_environment_snapshots: Optional[list[SandboxEnvironmentSnapshot]] = Field( + default=None, description="""List of sandbox environment snapshots.""" + ) + + +class ListSandboxEnvironmentSnapshotsResponseDict(TypedDict, total=False): + """Response for listing sandbox environment snapshots.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + sandbox_environment_snapshots: Optional[list[SandboxEnvironmentSnapshotDict]] + """List of sandbox environment snapshots.""" + + +ListSandboxEnvironmentSnapshotsResponseOrDict = Union[ + ListSandboxEnvironmentSnapshotsResponse, ListSandboxEnvironmentSnapshotsResponseDict +] + + +class _GetAgentEngineSandboxSnapshotOperationParameters(_common.BaseModel): + """Parameters for getting an operation with a sandbox snapshot as a response.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetAgentEngineSandboxSnapshotOperationParametersDict(TypedDict, total=False): + """Parameters for getting an operation with a sandbox snapshot as a response.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetAgentEngineSandboxSnapshotOperationParametersOrDict = Union[ + _GetAgentEngineSandboxSnapshotOperationParameters, + _GetAgentEngineSandboxSnapshotOperationParametersDict, +] + + +class CreateAgentEngineSessionConfig(_common.BaseModel): + """Config for creating a Session.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the session.""" + ) + session_state: Optional[dict[str, Any]] = Field( + default=None, + description="""Session state which stores key conversation points.""", + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The labels with user-defined metadata to organize your Sessions. 
Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) + session_id: Optional[str] = Field( + default=None, + description="""Optional. The user defined ID to use for session, which will become the final component of the session resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""", + ) + + +class CreateAgentEngineSessionConfigDict(TypedDict, total=False): + """Config for creating a Session.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The display name of the session.""" + + session_state: Optional[dict[str, Any]] + """Session state which stores key conversation points.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""" + + expire_time: Optional[datetime.datetime] + """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""" + + labels: Optional[dict[str, str]] + """Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + session_id: Optional[str] + """Optional. The user defined ID to use for session, which will become the final component of the session resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""" + + +CreateAgentEngineSessionConfigOrDict = Union[ + CreateAgentEngineSessionConfig, CreateAgentEngineSessionConfigDict +] + + +class _CreateAgentEngineSessionRequestParameters(_common.BaseModel): + """Parameters for creating Agent Engine Sessions.""" + + name: Optional[str] = Field( + default=None, + description="""Name of the agent engine to create the session under.""", + ) + user_id: Optional[str] = Field( + default=None, description="""The user ID of the session.""" + ) + config: Optional[CreateAgentEngineSessionConfig] = Field( + default=None, description="""""" + ) + + +class _CreateAgentEngineSessionRequestParametersDict(TypedDict, total=False): + """Parameters for creating Agent Engine Sessions.""" + + name: Optional[str] + """Name of the agent engine to create the session under.""" + + user_id: Optional[str] + """The user ID of the session.""" + + config: Optional[CreateAgentEngineSessionConfigDict] + """""" + + +_CreateAgentEngineSessionRequestParametersOrDict = Union[ + _CreateAgentEngineSessionRequestParameters, + _CreateAgentEngineSessionRequestParametersDict, +] + + +class Session(_common.BaseModel): + """A session.""" + + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. 
Timestamp when the session was created.""", + ) + display_name: Optional[str] = Field( + default=None, description="""Optional. The display name of the session.""" + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input. The minimum value is 24 hours from the time of creation.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""", + ) + session_state: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. Session specific memory which stores key conversation points.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for this session. The minimum value is 24 hours.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when the session was updated.""", + ) + user_id: Optional[str] = Field( + default=None, + description="""Required. Immutable. String id provided by the user""", + ) + + +class SessionDict(TypedDict, total=False): + """A session.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when the session was created.""" + + display_name: Optional[str] + """Optional. The display name of the session.""" + + expire_time: Optional[datetime.datetime] + """Optional. Timestamp of when this session is considered expired. This is *always* provided on output, regardless of what was sent on input. The minimum value is 24 hours from the time of creation.""" + + labels: Optional[dict[str, str]] + """The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + name: Optional[str] + """Identifier. The resource name of the session. Format: 'projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}'.""" + + session_state: Optional[dict[str, Any]] + """Optional. Session specific memory which stores key conversation points.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for this session. The minimum value is 24 hours.""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when the session was updated.""" + + user_id: Optional[str] + """Required. Immutable. String id provided by the user""" + + +SessionOrDict = Union[Session, SessionDict] + + +class AgentEngineSessionOperation(_common.BaseModel): + """Operation that has an agent engine session as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. 
If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[Session] = Field( + default=None, description="""The Agent Engine Session.""" + ) + + +class AgentEngineSessionOperationDict(TypedDict, total=False): + """Operation that has an agent engine session as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[SessionDict] + """The Agent Engine Session.""" + + +AgentEngineSessionOperationOrDict = Union[ + AgentEngineSessionOperation, AgentEngineSessionOperationDict +] + + +class DeleteAgentEngineSessionConfig(_common.BaseModel): + """Config for deleting an Agent Engine Session.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class DeleteAgentEngineSessionConfigDict(TypedDict, total=False): + """Config for deleting an Agent Engine Session.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +DeleteAgentEngineSessionConfigOrDict = Union[ + DeleteAgentEngineSessionConfig, DeleteAgentEngineSessionConfigDict +] + + +class _DeleteAgentEngineSessionRequestParameters(_common.BaseModel): + """Parameters for deleting agent engine sessions.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine session to delete.""" + ) + config: Optional[DeleteAgentEngineSessionConfig] = Field( + default=None, description="""""" + ) + + +class _DeleteAgentEngineSessionRequestParametersDict(TypedDict, total=False): + """Parameters for deleting agent engine sessions.""" + + name: Optional[str] + """Name of the agent engine session to delete.""" + + config: Optional[DeleteAgentEngineSessionConfigDict] + """""" + + +_DeleteAgentEngineSessionRequestParametersOrDict = Union[ + _DeleteAgentEngineSessionRequestParameters, + _DeleteAgentEngineSessionRequestParametersDict, +] + 
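+# Illustrative usage sketch (an editorial note, not part of the generated surface): +# every public type in this module ships with a TypedDict twin plus an `*OrDict` +# alias, so the same config can be passed as a Pydantic model or a plain dict. +# The values below are hypothetical. +# +#   config: CreateAgentEngineSessionConfigOrDict = CreateAgentEngineSessionConfig( +#       display_name="demo-session", +#       session_state={"topic": "onboarding"}, +#   ) +#   # ...or the equivalent TypedDict form: +#   config_dict: CreateAgentEngineSessionConfigDict = { +#       "display_name": "demo-session", +#       "session_state": {"topic": "onboarding"}, +#   }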
+ +class DeleteAgentEngineSessionOperation(_common.BaseModel): + """Operation for deleting agent engine sessions.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeleteAgentEngineSessionOperationDict(TypedDict, total=False): + """Operation for deleting agent engine sessions.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeleteAgentEngineSessionOperationOrDict = Union[ + DeleteAgentEngineSessionOperation, DeleteAgentEngineSessionOperationDict +] + + +class GetAgentEngineSessionConfig(_common.BaseModel): + """Config for getting an Agent Engine Session.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetAgentEngineSessionConfigDict(TypedDict, total=False): + """Config for getting an Agent Engine Session.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetAgentEngineSessionConfigOrDict = Union[ + GetAgentEngineSessionConfig, GetAgentEngineSessionConfigDict +] + + +class _GetAgentEngineSessionRequestParameters(_common.BaseModel): + """Parameters for getting an agent engine session.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine session.""" + ) + config: Optional[GetAgentEngineSessionConfig] = Field( + default=None, description="""""" + ) + + +class _GetAgentEngineSessionRequestParametersDict(TypedDict, total=False): + """Parameters for getting an agent engine session.""" + + name: Optional[str] + """Name of the agent engine session.""" + + config: Optional[GetAgentEngineSessionConfigDict] + """""" + + +_GetAgentEngineSessionRequestParametersOrDict = Union[ + _GetAgentEngineSessionRequestParameters, _GetAgentEngineSessionRequestParametersDict +] + + +class ListAgentEngineSessionsConfig(_common.BaseModel): + """Config for listing agent engine sessions.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineSessionsConfigDict(TypedDict, total=False): + """Config for listing agent engine sessions.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. 
+ For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineSessionsConfigOrDict = Union[ + ListAgentEngineSessionsConfig, ListAgentEngineSessionsConfigDict +] + + +class _ListAgentEngineSessionsRequestParameters(_common.BaseModel): + """Parameters for listing agent engine sessions.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine.""" + ) + config: Optional[ListAgentEngineSessionsConfig] = Field( + default=None, description="""""" + ) + + +class _ListAgentEngineSessionsRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engine sessions.""" + + name: Optional[str] + """Name of the agent engine.""" + + config: Optional[ListAgentEngineSessionsConfigDict] + """""" + + +_ListAgentEngineSessionsRequestParametersOrDict = Union[ + _ListAgentEngineSessionsRequestParameters, + _ListAgentEngineSessionsRequestParametersDict, +] + + +class ListReasoningEnginesSessionsResponse(_common.BaseModel): + """Response for listing agent engine sessions.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + sessions: Optional[list[Session]] = Field( + default=None, description="""List of agent engine sessions.""" + ) + + +class ListReasoningEnginesSessionsResponseDict(TypedDict, total=False): + """Response for listing agent engine sessions.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + sessions: Optional[list[SessionDict]] + """List of agent engine sessions.""" + + +ListReasoningEnginesSessionsResponseOrDict = Union[ + ListReasoningEnginesSessionsResponse, ListReasoningEnginesSessionsResponseDict +] + + +class _GetAgentEngineSessionOperationParameters(_common.BaseModel): + """Parameters for getting an operation with a session as a response.""" + + operation_name: Optional[str] = Field( + default=None, description="""The server-assigned name for the operation.""" + ) + config: Optional[GetAgentEngineOperationConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _GetAgentEngineSessionOperationParametersDict(TypedDict, total=False): + """Parameters for getting an operation with a session as a response.""" + + operation_name: Optional[str] + """The server-assigned name for the operation.""" + + config: Optional[GetAgentEngineOperationConfigDict] + """Used to override the default configuration.""" + + +_GetAgentEngineSessionOperationParametersOrDict = Union[ + _GetAgentEngineSessionOperationParameters, + _GetAgentEngineSessionOperationParametersDict, +]
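+ + +# Illustrative sketch (an editorial note, not part of the generated surface): +# the update config defined just below pairs new field values with an +# `update_mask` naming the fields to change. The `sessions.update(...)` call in +# the trailing comment is an assumed client method, shown only for orientation. +# +#   config = UpdateAgentEngineSessionConfig( +#       display_name="renamed-session",  # hypothetical value +#       update_mask="display_name", +#   ) +#   # operation = client.agent_engines.sessions.update( +#   #     name=session_name, config=config +#   # )  # expected to resolve to an AgentEngineSessionOperation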
+ + +class UpdateAgentEngineSessionConfig(_common.BaseModel): + """Config for updating an agent engine session.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the session.""" + ) + session_state: Optional[dict[str, Any]] = Field( + default=None, + description="""Session state which stores key conversation points.""", + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Waits for the operation to complete before returning.""", + ) + ttl: Optional[str] = Field( + default=None, + description="""Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""", + ) + expire_time: Optional[datetime.datetime] = Field( + default=None, + description="""Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""", + ) + session_id: Optional[str] = Field( + default=None, + description="""Optional. The user defined ID to use for session, which will become the final component of the session resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""", + ) + update_mask: Optional[str] = Field( + default=None, + description="""The update mask to apply. For the `FieldMask` definition, see + https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""", + ) + user_id: Optional[str] = Field( + default=None, description="""User ID of the agent engine session to update.""" + ) + + +class UpdateAgentEngineSessionConfigDict(TypedDict, total=False): + """Config for updating an agent engine session.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + display_name: Optional[str] + """The display name of the session.""" + + session_state: Optional[dict[str, Any]] + """Session state which stores key conversation points.""" + + wait_for_completion: Optional[bool] + """Waits for the operation to complete before returning.""" + + ttl: Optional[str] + """Optional. Input only. The TTL for this resource. + + The expiration time is computed: now + TTL.""" + + expire_time: Optional[datetime.datetime] + """Optional. Timestamp of when this resource is considered expired. This is *always* provided on output, regardless of what `expiration` was sent on input.""" + + labels: Optional[dict[str, str]] + """Optional. The labels with user-defined metadata to organize your Sessions. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. See https://goo.gl/xmQnxf for more information and examples of labels.""" + + session_id: Optional[str] + """Optional. The user defined ID to use for session, which will become the final component of the session resource name. If not provided, Vertex AI will generate a value for this ID. This value may be up to 63 characters, and valid characters are `[a-z0-9-]`. The first character must be a letter, and the last character must be a letter or number.""" + + update_mask: Optional[str] + """The update mask to apply. 
For the `FieldMask` definition, see + https://protobuf.dev/reference/protobuf/google.protobuf/#field-mask.""" + + user_id: Optional[str] + """User ID of the agent engine session to update.""" + + +UpdateAgentEngineSessionConfigOrDict = Union[ + UpdateAgentEngineSessionConfig, UpdateAgentEngineSessionConfigDict +] + + +class _UpdateAgentEngineSessionRequestParameters(_common.BaseModel): + """Parameters for updating agent engine sessions.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine session to update.""" + ) + config: Optional[UpdateAgentEngineSessionConfig] = Field( + default=None, description="""""" + ) + + +class _UpdateAgentEngineSessionRequestParametersDict(TypedDict, total=False): + """Parameters for updating agent engine sessions.""" + + name: Optional[str] + """Name of the agent engine session to update.""" + + config: Optional[UpdateAgentEngineSessionConfigDict] + """""" + + +_UpdateAgentEngineSessionRequestParametersOrDict = Union[ + _UpdateAgentEngineSessionRequestParameters, + _UpdateAgentEngineSessionRequestParametersDict, +] + + +class EventActions(_common.BaseModel): + """Actions are parts of events that are executed by the agent.""" + + artifact_delta: Optional[dict[str, int]] = Field( + default=None, + description="""Optional. Indicates that the event is updating an artifact. Key is the filename, value is the version.""", + ) + escalate: Optional[bool] = Field( + default=None, + description="""Optional. The agent is escalating to a higher level agent.""", + ) + requested_auth_configs: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. Will only be set by a tool response indicating tool request euc. Struct key is the function call id since one function call response (from model) could correspond to multiple function calls. Struct value is the required auth config, which can be another struct.""", + ) + skip_summarization: Optional[bool] = Field( + default=None, + description="""Optional. If true, it won't call the model to summarize the function response. Only used for function_response event.""", + ) + state_delta: Optional[dict[str, Any]] = Field( + default=None, + description="""Optional. Indicates that the event is updating the state with the given delta.""", + ) + transfer_agent: Optional[str] = Field( + default=None, + description="""Optional. If set, the event transfers to the specified agent.""", + ) + + +class EventActionsDict(TypedDict, total=False): + """Actions are parts of events that are executed by the agent.""" + + artifact_delta: Optional[dict[str, int]] + """Optional. Indicates that the event is updating an artifact. Key is the filename, value is the version.""" + + escalate: Optional[bool] + """Optional. The agent is escalating to a higher level agent.""" + + requested_auth_configs: Optional[dict[str, Any]] + """Optional. Will only be set by a tool response indicating tool request euc. Struct key is the function call id since one function call response (from model) could correspond to multiple function calls. Struct value is the required auth config, which can be another struct.""" + + skip_summarization: Optional[bool] + """Optional. If true, it won't call the model to summarize the function response. Only used for function_response event.""" + + state_delta: Optional[dict[str, Any]] + """Optional. Indicates that the event is updating the state with the given delta.""" + + transfer_agent: Optional[str] + """Optional. 
If set, the event transfers to the specified agent.""" + + +EventActionsOrDict = Union[EventActions, EventActionsDict] + + +class EventMetadata(_common.BaseModel): + """Metadata relating to an LLM response event.""" + + grounding_metadata: Optional[genai_types.GroundingMetadata] = Field( + default=None, + description="""Optional. Metadata returned to client when grounding is enabled.""", + ) + branch: Optional[str] = Field( + default=None, + description="""Optional. The branch of the event. The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of agent_2, and agent_2 is the parent of agent_3. Branch is used when multiple child agents shouldn't see their siblings' conversation history.""", + ) + custom_metadata: Optional[dict[str, Any]] = Field( + default=None, description="""The custom metadata of the LlmResponse.""" + ) + interrupted: Optional[bool] = Field( + default=None, + description="""Optional. Flag indicating that the LLM was interrupted when generating the content. Usually it's due to user interruption during bidi streaming.""", + ) + long_running_tool_ids: Optional[list[str]] = Field( + default=None, + description="""Optional. Set of ids of the long running function calls. Agent client will know from this field about which function call is long running. Only valid for function call event.""", + ) + partial: Optional[bool] = Field( + default=None, + description="""Optional. Indicates whether the text content is part of an unfinished text stream. Only used for streaming mode and when the content is plain text.""", + ) + turn_complete: Optional[bool] = Field( + default=None, + description="""Optional. Indicates whether the response from the model is complete. Only used for streaming mode.""", + ) + input_transcription: Optional[genai_types.Transcription] = Field( + default=None, description="""Optional. Audio transcription of user input.""" + ) + output_transcription: Optional[genai_types.Transcription] = Field( + default=None, description="""Optional. Audio transcription of model output.""" + ) + + +class EventMetadataDict(TypedDict, total=False): + """Metadata relating to an LLM response event.""" + + grounding_metadata: Optional[genai_types.GroundingMetadataDict] + """Optional. Metadata returned to client when grounding is enabled.""" + + branch: Optional[str] + """Optional. The branch of the event. The format is like agent_1.agent_2.agent_3, where agent_1 is the parent of agent_2, and agent_2 is the parent of agent_3. Branch is used when multiple child agents shouldn't see their siblings' conversation history.""" + + custom_metadata: Optional[dict[str, Any]] + """The custom metadata of the LlmResponse.""" + + interrupted: Optional[bool] + """Optional. Flag indicating that the LLM was interrupted when generating the content. Usually it's due to user interruption during bidi streaming.""" + + long_running_tool_ids: Optional[list[str]] + """Optional. Set of ids of the long running function calls. Agent client will know from this field about which function call is long running. Only valid for function call event.""" + + partial: Optional[bool] + """Optional. Indicates whether the text content is part of an unfinished text stream. Only used for streaming mode and when the content is plain text.""" + + turn_complete: Optional[bool] + """Optional. Indicates whether the response from the model is complete. Only used for streaming mode.""" + + input_transcription: Optional[genai_types.TranscriptionDict] + """Optional. 
Audio transcription of user input.""" + + output_transcription: Optional[genai_types.TranscriptionDict] + """Optional. Audio transcription of model output.""" + + +EventMetadataOrDict = Union[EventMetadata, EventMetadataDict] + + +class AppendAgentEngineSessionEventConfig(_common.BaseModel): + """Config for appending an agent engine session event.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + content: Optional[genai_types.Content] = Field( + default=None, description="""The content of the session event.""" + ) + actions: Optional[EventActions] = Field( + default=None, + description="""Actions are parts of events that are related to the session event.""", + ) + error_code: Optional[str] = Field( + default=None, description="""The error code of the session event.""" + ) + error_message: Optional[str] = Field( + default=None, description="""The error message of the session event.""" + ) + event_metadata: Optional[EventMetadata] = Field( + default=None, description="""Metadata relating to the session event.""" + ) + raw_event: Optional[dict[str, Any]] = Field( + default=None, + description="""Weakly typed raw event data in proto struct format.""", + ) + + +class AppendAgentEngineSessionEventConfigDict(TypedDict, total=False): + """Config for appending an agent engine session event.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + content: Optional[genai_types.ContentDict] + """The content of the session event.""" + + actions: Optional[EventActionsDict] + """Actions are parts of events that are related to the session event.""" + + error_code: Optional[str] + """The error code of the session event.""" + + error_message: Optional[str] + """The error message of the session event.""" + + event_metadata: Optional[EventMetadataDict] + """Metadata relating to the session event.""" + + raw_event: Optional[dict[str, Any]] + """Weakly typed raw event data in proto struct format.""" + + +AppendAgentEngineSessionEventConfigOrDict = Union[ + AppendAgentEngineSessionEventConfig, AppendAgentEngineSessionEventConfigDict +] + + +class _AppendAgentEngineSessionEventRequestParameters(_common.BaseModel): + """Parameters for appending agent engine session events.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine session.""" + ) + author: Optional[str] = Field( + default=None, description="""Author of the agent engine session event.""" + ) + invocation_id: Optional[str] = Field( + default=None, description="""Invocation ID of the agent engine session event.""" + ) + timestamp: Optional[datetime.datetime] = Field( + default=None, description="""Timestamp indicating when the event was created.""" + ) + config: Optional[AppendAgentEngineSessionEventConfig] = Field( + default=None, description="""""" + ) + + +class _AppendAgentEngineSessionEventRequestParametersDict(TypedDict, total=False): + """Parameters for appending agent engine session events.""" + + name: Optional[str] + """Name of the agent engine session.""" + + author: Optional[str] + """Author of the agent engine session event.""" + + invocation_id: Optional[str] + """Invocation ID of the agent engine session event.""" + + timestamp: Optional[datetime.datetime] + """Timestamp indicating when the event was created.""" + + config: Optional[AppendAgentEngineSessionEventConfigDict] + """"""
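+ + +# Illustrative sketch (an editorial note, not part of the generated surface): +# appending an event to a session combines a `Content` payload with optional +# `EventActions`, such as a state delta. The `append_event(...)` call in the +# trailing comment is an assumed client method, shown only for orientation. +# +#   config = AppendAgentEngineSessionEventConfig( +#       content=genai_types.Content( +#           role="user", parts=[genai_types.Part(text="Hello!")] +#       ), +#       actions=EventActions(state_delta={"greeted": True}),  # hypothetical delta +#   ) +#   # client.agent_engines.sessions.append_event( +#   #     name=session_name, author="user", config=config +#   # )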
+ + +_AppendAgentEngineSessionEventRequestParametersOrDict = Union[ + _AppendAgentEngineSessionEventRequestParameters, + _AppendAgentEngineSessionEventRequestParametersDict, +] + + +class AppendAgentEngineSessionEventResponse(_common.BaseModel): + """Response for appending an agent engine session event.""" + + pass + + +class AppendAgentEngineSessionEventResponseDict(TypedDict, total=False): + """Response for appending an agent engine session event.""" + + pass + + +AppendAgentEngineSessionEventResponseOrDict = Union[ + AppendAgentEngineSessionEventResponse, AppendAgentEngineSessionEventResponseDict +] + + +class ListAgentEngineSessionEventsConfig(_common.BaseModel): + """Config for listing agent engine session events.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListAgentEngineSessionEventsConfigDict(TypedDict, total=False): + """Config for listing agent engine session events.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + +ListAgentEngineSessionEventsConfigOrDict = Union[ + ListAgentEngineSessionEventsConfig, ListAgentEngineSessionEventsConfigDict +] + + +class _ListAgentEngineSessionEventsRequestParameters(_common.BaseModel): + """Parameters for listing agent engine session events.""" + + name: Optional[str] = Field( + default=None, description="""Name of the agent engine session.""" + ) + config: Optional[ListAgentEngineSessionEventsConfig] = Field( + default=None, description="""""" + ) + + +class _ListAgentEngineSessionEventsRequestParametersDict(TypedDict, total=False): + """Parameters for listing agent engine session events.""" + + name: Optional[str] + """Name of the agent engine session.""" + + config: Optional[ListAgentEngineSessionEventsConfigDict] + """""" + + +_ListAgentEngineSessionEventsRequestParametersOrDict = Union[ + _ListAgentEngineSessionEventsRequestParameters, + _ListAgentEngineSessionEventsRequestParametersDict, +] + + +class SessionEvent(_common.BaseModel): + """A session event.""" + + content: Optional[genai_types.Content] = Field( + default=None, + description="""Optional. Content of the event provided by the author.""", + ) + actions: Optional[EventActions] = Field( + default=None, description="""Optional. Actions executed by the agent.""" + ) + author: Optional[str] = Field( + default=None, + description="""Required. The name of the agent that sent the event, or user.""", + ) + error_code: Optional[str] = Field( + default=None, + description="""Optional. Error code if the response is an error. Code varies by model.""", + ) + error_message: Optional[str] = Field( + default=None, + description="""Optional. Error message if the response is an error.""", + ) + event_metadata: Optional[EventMetadata] = Field( + default=None, description="""Optional. Metadata relating to this event.""" + ) + invocation_id: Optional[str] = Field( + default=None, + description="""Required. 
The invocation ID of the event; multiple events can have the same invocation ID.""",
+    )
+    name: Optional[str] = Field(
+        default=None,
+        description="""Identifier. The resource name of the event. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}`.""",
+    )
+    timestamp: Optional[datetime.datetime] = Field(
+        default=None,
+        description="""Required. Timestamp when the event was created on the client side.""",
+    )
+    raw_event: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="""Optional. Weakly typed raw event data in proto struct format.""",
+    )
+
+
+class SessionEventDict(TypedDict, total=False):
+    """A session event."""
+
+    content: Optional[genai_types.ContentDict]
+    """Optional. Content of the event provided by the author."""
+
+    actions: Optional[EventActionsDict]
+    """Optional. Actions executed by the agent."""
+
+    author: Optional[str]
+    """Required. The name of the agent that sent the event, or user."""
+
+    error_code: Optional[str]
+    """Optional. Error code if the response is an error. Code varies by model."""
+
+    error_message: Optional[str]
+    """Optional. Error message if the response is an error."""
+
+    event_metadata: Optional[EventMetadataDict]
+    """Optional. Metadata relating to this event."""
+
+    invocation_id: Optional[str]
+    """Required. The invocation ID of the event; multiple events can have the same invocation ID."""
+
+    name: Optional[str]
+    """Identifier. The resource name of the event. Format: `projects/{project}/locations/{location}/reasoningEngines/{reasoning_engine}/sessions/{session}/events/{event}`."""
+
+    timestamp: Optional[datetime.datetime]
+    """Required. Timestamp when the event was created on the client side."""
+
+    raw_event: Optional[dict[str, Any]]
+    """Optional. 
Weakly typed raw event data in proto struct format."""
+
+
+SessionEventOrDict = Union[SessionEvent, SessionEventDict]
+
+
+class ListAgentEngineSessionEventsResponse(_common.BaseModel):
+    """Response for listing agent engine session events."""
+
+    sdk_http_response: Optional[genai_types.HttpResponse] = Field(
+        default=None, description="""Used to retain the full HTTP response."""
+    )
+    next_page_token: Optional[str] = Field(default=None, description="""""")
+    session_events: Optional[list[SessionEvent]] = Field(
+        default=None, description="""List of session events."""
+    )
+
+
+class ListAgentEngineSessionEventsResponseDict(TypedDict, total=False):
+    """Response for listing agent engine session events."""
+
+    sdk_http_response: Optional[genai_types.HttpResponseDict]
+    """Used to retain the full HTTP response."""
+
+    next_page_token: Optional[str]
+    """"""
+
+    session_events: Optional[list[SessionEventDict]]
+    """List of session events."""
+
+
+ListAgentEngineSessionEventsResponseOrDict = Union[
+    ListAgentEngineSessionEventsResponse, ListAgentEngineSessionEventsResponseDict
+]
+
+
+class GeminiExample(_common.BaseModel):
+    """Represents a Gemini example."""
+
+    model: Optional[str] = Field(
+        default=None, description="""The model used to generate the Gemini example."""
+    )
+    contents: Optional[list[genai_types.Content]] = Field(
+        default=None, description="""Contents of the Gemini example."""
+    )
+    system_instruction: Optional[genai_types.Content] = Field(
+        default=None, description="""System instruction for the Gemini example."""
+    )
+    cached_content: Optional[str] = Field(
+        default=None, description="""Cached content for the Gemini example."""
+    )
+    tools: Optional[list[genai_types.Tool]] = Field(
+        default=None, description="""Tools for the Gemini example."""
+    )
+    tool_config: Optional[genai_types.ToolConfig] = Field(
+        default=None, description="""Tool config for the Gemini example."""
+    )
+    safety_settings: Optional[list[genai_types.SafetySetting]] = Field(
+        default=None, description="""Safety settings for the Gemini example."""
+    )
+    generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None, description="""Generation config for the Gemini example."""
+    )
+    model_armor_config: Optional[genai_types.ModelArmorConfig] = Field(
+        default=None,
+        description="""Optional. Settings for prompt and response sanitization using the Model Armor service. If supplied, safety_settings must not be supplied.""",
+    )
+
+
+class GeminiExampleDict(TypedDict, total=False):
+    """Represents a Gemini example."""
+
+    model: Optional[str]
+    """The model used to generate the Gemini example."""
+
+    contents: Optional[list[genai_types.ContentDict]]
+    """Contents of the Gemini example."""
+
+    system_instruction: Optional[genai_types.ContentDict]
+    """System instruction for the Gemini example."""
+
+    cached_content: Optional[str]
+    """Cached content for the Gemini example."""
+
+    tools: Optional[list[genai_types.ToolDict]]
+    """Tools for the Gemini example."""
+
+    tool_config: Optional[genai_types.ToolConfigDict]
+    """Tool config for the Gemini example."""
+
+    safety_settings: Optional[list[genai_types.SafetySettingDict]]
+    """Safety settings for the Gemini example."""
+
+    generation_config: Optional[genai_types.GenerationConfigDict]
+    """Generation config for the Gemini example."""
+
+    model_armor_config: Optional[genai_types.ModelArmorConfigDict]
+    """Optional. Settings for prompt and response sanitization using the Model Armor service. 
If supplied, safety_settings must not be supplied.""" + + +GeminiExampleOrDict = Union[GeminiExample, GeminiExampleDict] + + +class GeminiTemplateConfig(_common.BaseModel): + """Represents a Gemini template config.""" + + gemini_example: Optional[GeminiExample] = Field( + default=None, + description="""Required. The template that will be used for assembling the request to use for downstream applications.""", + ) + field_mapping: Optional[dict[str, str]] = Field( + default=None, + description="""Required. Map of template parameters to the columns in the dataset table.""", + ) + + +class GeminiTemplateConfigDict(TypedDict, total=False): + """Represents a Gemini template config.""" + + gemini_example: Optional[GeminiExampleDict] + """Required. The template that will be used for assembling the request to use for downstream applications.""" + + field_mapping: Optional[dict[str, str]] + """Required. Map of template parameters to the columns in the dataset table.""" + + +GeminiTemplateConfigOrDict = Union[GeminiTemplateConfig, GeminiTemplateConfigDict] + + +class GeminiRequestReadConfig(_common.BaseModel): + """Represents the config for reading Gemini requests.""" + + template_config: Optional[GeminiTemplateConfig] = Field( + default=None, description="""Gemini request template with placeholders.""" + ) + assembled_request_column_name: Optional[str] = Field( + default=None, + description="""Column name in the underlying BigQuery table that contains already fully assembled Gemini requests.""", + ) + + @classmethod + def single_turn_template( + cls, + *, + prompt: str, + response: Optional[str] = None, + system_instruction: Optional[str] = None, + model: Optional[str] = None, + cached_content: Optional[str] = None, + tools: Optional[list[Union[genai_types.Tool, dict[str, Any]]]] = None, + tool_config: Optional[Union[genai_types.ToolConfig, dict[str, Any]]] = None, + safety_settings: Optional[ + list[Union[genai_types.SafetySetting, dict[str, Any]]] + ] = None, + generation_config: Optional[ + Union[genai_types.GenerationConfig, dict[str, Any]] + ] = None, + field_mapping: Optional[dict[str, str]] = None, + ) -> "GeminiRequestReadConfig": + """Constructs a GeminiRequestReadConfig object for single-turn cases. + + Example: + read_config = GeminiRequestReadConfig.single_turn_template( + prompt="Which flower is this {flower_image}?", + response="This is a {label}.", + system_instruction="You are a botanical classifier." + ) + + Args: + prompt: Required. User input. + response: Optional. Model response to user input. + system_instruction: Optional. System instructions for the model. + model: Optional. The model to use for the GeminiExample. + cached_content: Optional. The cached content to use for the GeminiExample. + tools: Optional. The tools to use for the GeminiExample. + tool_config: Optional. The tool config to use for the GeminiExample. + safety_settings: Optional. The safety settings to use for the GeminiExample. + generation_config: Optional. The generation config to use for the GeminiExample. + field_mapping: Optional. Mapping of placeholders to dataset columns. + + Returns: + A GeminiRequestReadConfig object. 
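+
+        Note: Placeholders such as `{flower_image}` and `{label}` in the
+            example above are bound to dataset columns via `field_mapping`;
+            an illustrative mapping (assuming the backing table has columns
+            named `image_col` and `label_col`) would be:
+            field_mapping={"flower_image": "image_col", "label": "label_col"}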
+ """ + contents = [] + contents.append( + genai_types.Content( + role="user", + parts=[ + genai_types.Part.from_text(text=prompt), + ], + ) + ) + if response: + contents.append( + genai_types.Content( + role="model", + parts=[ + genai_types.Part.from_text(text=response), + ], + ) + ) + + system_instruction_content = None + if system_instruction: + system_instruction_content = genai_types.Content( + parts=[ + genai_types.Part.from_text(text=system_instruction), + ], + ) + + return cls( + template_config=GeminiTemplateConfig( + gemini_example=GeminiExample( + model=model, + contents=contents, + system_instruction=system_instruction_content, + cached_content=cached_content, + tools=tools, + tool_config=tool_config, + safety_settings=safety_settings, + generation_config=generation_config, + ), + field_mapping=field_mapping, + ), + ) + + +class GeminiRequestReadConfigDict(TypedDict, total=False): + """Represents the config for reading Gemini requests.""" + + template_config: Optional[GeminiTemplateConfigDict] + """Gemini request template with placeholders.""" + + assembled_request_column_name: Optional[str] + """Column name in the underlying BigQuery table that contains already fully assembled Gemini requests.""" + + +GeminiRequestReadConfigOrDict = Union[ + GeminiRequestReadConfig, GeminiRequestReadConfigDict +] + + +class AssembleDatasetConfig(_common.BaseModel): + """Config for assembling a multimodal dataset resource.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the assemble dataset request in seconds. If not + set, the default timeout is 90 seconds.""", + ) + + +class AssembleDatasetConfigDict(TypedDict, total=False): + """Config for assembling a multimodal dataset resource.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + timeout: Optional[int] + """The timeout for the assemble dataset request in seconds. If not + set, the default timeout is 90 seconds.""" + + +AssembleDatasetConfigOrDict = Union[AssembleDatasetConfig, AssembleDatasetConfigDict] + + +class _AssembleDatasetParameters(_common.BaseModel): + """Parameters for assembling a multimodal dataset resource.""" + + name: Optional[str] = Field(default=None, description="""""") + gemini_request_read_config: Optional[GeminiRequestReadConfig] = Field( + default=None, description="""""" + ) + config: Optional[AssembleDatasetConfig] = Field(default=None, description="""""") + + +class _AssembleDatasetParametersDict(TypedDict, total=False): + """Parameters for assembling a multimodal dataset resource.""" + + name: Optional[str] + """""" + + gemini_request_read_config: Optional[GeminiRequestReadConfigDict] + """""" + + config: Optional[AssembleDatasetConfigDict] + """""" + + +_AssembleDatasetParametersOrDict = Union[ + _AssembleDatasetParameters, _AssembleDatasetParametersDict +] + + +class MultimodalDatasetOperation(_common.BaseModel): + """Represents the create dataset operation.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. 
It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[dict[str, Any]] = Field( + default=None, description="""The result of the dataset operation.""" + ) + + +class MultimodalDatasetOperationDict(TypedDict, total=False): + """Represents the create dataset operation.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[dict[str, Any]] + """The result of the dataset operation.""" + + +MultimodalDatasetOperationOrDict = Union[ + MultimodalDatasetOperation, MultimodalDatasetOperationDict +] + + +class TuningResourceUsageAssessmentConfig(_common.BaseModel): + """Config for tuning resource usage assessment.""" + + model_name: Optional[str] = Field(default=None, description="""""") + + +class TuningResourceUsageAssessmentConfigDict(TypedDict, total=False): + """Config for tuning resource usage assessment.""" + + model_name: Optional[str] + """""" + + +TuningResourceUsageAssessmentConfigOrDict = Union[ + TuningResourceUsageAssessmentConfig, TuningResourceUsageAssessmentConfigDict +] + + +class TuningValidationAssessmentConfig(_common.BaseModel): + """Config for tuning validation assessment.""" + + model_name: Optional[str] = Field(default=None, description="""""") + dataset_usage: Optional[str] = Field(default=None, description="""""") + + +class TuningValidationAssessmentConfigDict(TypedDict, total=False): + """Config for tuning validation assessment.""" + + model_name: Optional[str] + """""" + + dataset_usage: Optional[str] + """""" + + +TuningValidationAssessmentConfigOrDict = Union[ + TuningValidationAssessmentConfig, TuningValidationAssessmentConfigDict +] + + +class BatchPredictionResourceUsageAssessmentConfig(_common.BaseModel): + """Config for batch prediction resource usage assessment.""" + + model_name: Optional[str] = Field(default=None, description="""""") + + +class BatchPredictionResourceUsageAssessmentConfigDict(TypedDict, total=False): + """Config for batch prediction resource usage assessment.""" + + model_name: Optional[str] + """""" + + +BatchPredictionResourceUsageAssessmentConfigOrDict = Union[ + BatchPredictionResourceUsageAssessmentConfig, + 
BatchPredictionResourceUsageAssessmentConfigDict, +] + + +class BatchPredictionValidationAssessmentConfig(_common.BaseModel): + """Config for batch prediction validation assessment.""" + + model_name: Optional[str] = Field(default=None, description="""""") + + +class BatchPredictionValidationAssessmentConfigDict(TypedDict, total=False): + """Config for batch prediction validation assessment.""" + + model_name: Optional[str] + """""" + + +BatchPredictionValidationAssessmentConfigOrDict = Union[ + BatchPredictionValidationAssessmentConfig, + BatchPredictionValidationAssessmentConfigDict, +] + + +class AssessDatasetConfig(_common.BaseModel): + """Config for assessing a multimodal dataset resource.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the assess dataset request in seconds. If not set, + the default timeout is 90 seconds.""", + ) + + +class AssessDatasetConfigDict(TypedDict, total=False): + """Config for assessing a multimodal dataset resource.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + timeout: Optional[int] + """The timeout for the assess dataset request in seconds. If not set, + the default timeout is 90 seconds.""" + + +AssessDatasetConfigOrDict = Union[AssessDatasetConfig, AssessDatasetConfigDict] + + +class _AssessDatasetParameters(_common.BaseModel): + """Parameters for assessing a multimodal dataset resource.""" + + name: Optional[str] = Field(default=None, description="""""") + gemini_request_read_config: Optional[GeminiRequestReadConfig] = Field( + default=None, description="""""" + ) + tuning_resource_usage_assessment_config: Optional[ + TuningResourceUsageAssessmentConfig + ] = Field(default=None, description="""""") + tuning_validation_assessment_config: Optional[TuningValidationAssessmentConfig] = ( + Field(default=None, description="""""") + ) + batch_prediction_resource_usage_assessment_config: Optional[ + BatchPredictionResourceUsageAssessmentConfig + ] = Field(default=None, description="""""") + batch_prediction_validation_assessment_config: Optional[ + BatchPredictionValidationAssessmentConfig + ] = Field(default=None, description="""""") + config: Optional[AssessDatasetConfig] = Field(default=None, description="""""") + + +class _AssessDatasetParametersDict(TypedDict, total=False): + """Parameters for assessing a multimodal dataset resource.""" + + name: Optional[str] + """""" + + gemini_request_read_config: Optional[GeminiRequestReadConfigDict] + """""" + + tuning_resource_usage_assessment_config: Optional[ + TuningResourceUsageAssessmentConfigDict + ] + """""" + + tuning_validation_assessment_config: Optional[TuningValidationAssessmentConfigDict] + """""" + + batch_prediction_resource_usage_assessment_config: Optional[ + BatchPredictionResourceUsageAssessmentConfigDict + ] + """""" + + batch_prediction_validation_assessment_config: Optional[ + BatchPredictionValidationAssessmentConfigDict + ] + """""" + + config: Optional[AssessDatasetConfigDict] + """""" + + +_AssessDatasetParametersOrDict = Union[ + _AssessDatasetParameters, _AssessDatasetParametersDict +] + + +class SchemaTablesDatasetMetadataBigQuerySource(_common.BaseModel): + """Represents the BigQuery source for multimodal dataset metadata.""" + + uri: Optional[str] = Field( + default=None, + description="""The URI of the BigQuery table. 
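For example, `bq://my-project.my_dataset.my_table`.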
This accepts the table name with or without the bq:// prefix.""", + ) + + +class SchemaTablesDatasetMetadataBigQuerySourceDict(TypedDict, total=False): + """Represents the BigQuery source for multimodal dataset metadata.""" + + uri: Optional[str] + """The URI of the BigQuery table. This accepts the table name with or without the bq:// prefix.""" + + +SchemaTablesDatasetMetadataBigQuerySourceOrDict = Union[ + SchemaTablesDatasetMetadataBigQuerySource, + SchemaTablesDatasetMetadataBigQuerySourceDict, +] + + +class SchemaTablesDatasetMetadataInputConfig(_common.BaseModel): + """Represents the input config for multimodal dataset metadata.""" + + bigquery_source: Optional[SchemaTablesDatasetMetadataBigQuerySource] = Field( + default=None, + description="""The BigQuery source for multimodal dataset metadata.""", + ) + + +class SchemaTablesDatasetMetadataInputConfigDict(TypedDict, total=False): + """Represents the input config for multimodal dataset metadata.""" + + bigquery_source: Optional[SchemaTablesDatasetMetadataBigQuerySourceDict] + """The BigQuery source for multimodal dataset metadata.""" + + +SchemaTablesDatasetMetadataInputConfigOrDict = Union[ + SchemaTablesDatasetMetadataInputConfig, SchemaTablesDatasetMetadataInputConfigDict +] + + +class SchemaTablesDatasetMetadata(_common.BaseModel): + """Represents the metadata schema for multimodal dataset metadata.""" + + input_config: Optional[SchemaTablesDatasetMetadataInputConfig] = Field( + default=None, + description="""The input config for multimodal dataset metadata.""", + ) + gemini_request_read_config: Optional[GeminiRequestReadConfig] = Field( + default=None, + description="""The Gemini request read config for the multimodal dataset.""", + ) + + +class SchemaTablesDatasetMetadataDict(TypedDict, total=False): + """Represents the metadata schema for multimodal dataset metadata.""" + + input_config: Optional[SchemaTablesDatasetMetadataInputConfigDict] + """The input config for multimodal dataset metadata.""" + + gemini_request_read_config: Optional[GeminiRequestReadConfigDict] + """The Gemini request read config for the multimodal dataset.""" + + +SchemaTablesDatasetMetadataOrDict = Union[ + SchemaTablesDatasetMetadata, SchemaTablesDatasetMetadataDict +] + + +class CreateMultimodalDatasetConfig(_common.BaseModel): + """Config for creating a dataset resource to store multimodal dataset.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the create dataset request in seconds. If not set, + the default timeout is 90 seconds.""", + ) + + +class CreateMultimodalDatasetConfigDict(TypedDict, total=False): + """Config for creating a dataset resource to store multimodal dataset.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + timeout: Optional[int] + """The timeout for the create dataset request in seconds. 
If not set, + the default timeout is 90 seconds.""" + + +CreateMultimodalDatasetConfigOrDict = Union[ + CreateMultimodalDatasetConfig, CreateMultimodalDatasetConfigDict +] + + +class _CreateMultimodalDatasetParameters(_common.BaseModel): + """Parameters for creating a dataset resource to store multimodal dataset.""" + + name: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + metadata_schema_uri: Optional[str] = Field(default=None, description="""""") + metadata: Optional[SchemaTablesDatasetMetadata] = Field( + default=None, description="""""" + ) + description: Optional[str] = Field(default=None, description="""""") + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, description="""""" + ) + config: Optional[CreateMultimodalDatasetConfig] = Field( + default=None, description="""""" + ) + + +class _CreateMultimodalDatasetParametersDict(TypedDict, total=False): + """Parameters for creating a dataset resource to store multimodal dataset.""" + + name: Optional[str] + """""" + + display_name: Optional[str] + """""" + + metadata_schema_uri: Optional[str] + """""" + + metadata: Optional[SchemaTablesDatasetMetadataDict] + """""" + + description: Optional[str] + """""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """""" + + config: Optional[CreateMultimodalDatasetConfigDict] + """""" + + +_CreateMultimodalDatasetParametersOrDict = Union[ + _CreateMultimodalDatasetParameters, _CreateMultimodalDatasetParametersDict +] + + +class _DeleteMultimodalDatasetRequestParameters(_common.BaseModel): + """Parameters for deleting a multimodal dataset.""" + + name: Optional[str] = Field( + default=None, description="""ID of the dataset to be deleted.""" + ) + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + +class _DeleteMultimodalDatasetRequestParametersDict(TypedDict, total=False): + """Parameters for deleting a multimodal dataset.""" + + name: Optional[str] + """ID of the dataset to be deleted.""" + + config: Optional[VertexBaseConfigDict] + """""" + + +_DeleteMultimodalDatasetRequestParametersOrDict = Union[ + _DeleteMultimodalDatasetRequestParameters, + _DeleteMultimodalDatasetRequestParametersDict, +] + + +class _GetMultimodalDatasetParameters(_common.BaseModel): + """Parameters for getting a multimodal dataset resource.""" + + name: Optional[str] = Field(default=None, description="""""") + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + +class _GetMultimodalDatasetParametersDict(TypedDict, total=False): + """Parameters for getting a multimodal dataset resource.""" + + name: Optional[str] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + +_GetMultimodalDatasetParametersOrDict = Union[ + _GetMultimodalDatasetParameters, _GetMultimodalDatasetParametersDict +] + + +class MultimodalDataset(_common.BaseModel): + """Represents a multimodal dataset.""" + + name: Optional[str] = Field( + default=None, description="""The ID of the multimodal dataset.""" + ) + display_name: Optional[str] = Field( + default=None, description="""The display name of the multimodal dataset.""" + ) + metadata: Optional[SchemaTablesDatasetMetadata] = Field( + default=None, description="""The metadata of the multimodal dataset.""" + ) + description: Optional[str] = Field( + default=None, description="""The description of the multimodal dataset.""" + ) + + @property + def read_config(self) -> Optional[GeminiRequestReadConfig]: + """Gets 
the read config from the dataset metadata. Returns None if it's not set.""" + if self.metadata is None or self.metadata.gemini_request_read_config is None: + return None + return self.metadata.gemini_request_read_config + + def set_read_config( + self, + *, + read_config: GeminiRequestReadConfigOrDict, + ) -> None: + """Sets the read config in the dataset metadata.""" + if isinstance(read_config, dict): + read_config = GeminiRequestReadConfig(**read_config) + + if self.metadata is None: + self.metadata = SchemaTablesDatasetMetadata() + self.metadata.gemini_request_read_config = read_config + + @property + def bigquery_uri( + self, + ) -> Optional[str]: + """Gets the bigquery uri from the dataset metadata. Returns None if it's not set.""" + if ( + self.metadata is None + or self.metadata.input_config is None + or self.metadata.input_config.bigquery_source is None + ): + return None + return str(self.metadata.input_config.bigquery_source.uri) + + def set_bigquery_uri( + self, + bigquery_uri: str, + ) -> None: + """Sets the bigquery uri in the dataset metadata. Prepends 'bq://' if it's not already present.""" + if not bigquery_uri.startswith("bq://"): + bigquery_uri = f"bq://{bigquery_uri}" + metadata = ( + SchemaTablesDatasetMetadata() if self.metadata is None else self.metadata + ) + input_config = ( + SchemaTablesDatasetMetadataInputConfig() + if metadata.input_config is None + else metadata.input_config + ) + bigquery_source = ( + SchemaTablesDatasetMetadataBigQuerySource() + if input_config.bigquery_source is None + else input_config.bigquery_source + ) + bigquery_source.uri = bigquery_uri + input_config.bigquery_source = bigquery_source + metadata.input_config = input_config + self.metadata = metadata + + def to_bigframes( + self, + ) -> "bigframes.pandas.DataFrame": # type: ignore # noqa: F821 + """Converts the multimodal dataset to a BigFrames dataframe. + + This is the preferred method to inspect the multimodal dataset in a + notebook. + + Returns: + A BigFrames dataframe. + """ + from .. 
import _datasets_utils + + bigframes = _datasets_utils._try_import_bigframes() + + if self.bigquery_uri is None: + raise ValueError("Multimodal dataset bigquery source uri is not set.") + return bigframes.pandas.read_gbq_table(self.bigquery_uri.removeprefix("bq://")) + + +class MultimodalDatasetDict(TypedDict, total=False): + """Represents a multimodal dataset.""" + + name: Optional[str] + """The ID of the multimodal dataset.""" + + display_name: Optional[str] + """The display name of the multimodal dataset.""" + + metadata: Optional[SchemaTablesDatasetMetadataDict] + """The metadata of the multimodal dataset.""" + + description: Optional[str] + """The description of the multimodal dataset.""" + + +MultimodalDatasetOrDict = Union[MultimodalDataset, MultimodalDatasetDict] + + +class GetMultimodalDatasetOperationConfig(_common.BaseModel): + """Config for getting a multimodal dataset operation.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetMultimodalDatasetOperationConfigDict(TypedDict, total=False): + """Config for getting a multimodal dataset operation.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetMultimodalDatasetOperationConfigOrDict = Union[ + GetMultimodalDatasetOperationConfig, GetMultimodalDatasetOperationConfigDict +] + + +class _GetMultimodalDatasetOperationParameters(_common.BaseModel): + """Parameters for getting a dataset operation.""" + + dataset_id: Optional[str] = Field(default=None, description="""""") + operation_id: Optional[str] = Field(default=None, description="""""") + config: Optional[GetMultimodalDatasetOperationConfig] = Field( + default=None, description="""""" + ) + + +class _GetMultimodalDatasetOperationParametersDict(TypedDict, total=False): + """Parameters for getting a dataset operation.""" + + dataset_id: Optional[str] + """""" + + operation_id: Optional[str] + """""" + + config: Optional[GetMultimodalDatasetOperationConfigDict] + """""" + + +_GetMultimodalDatasetOperationParametersOrDict = Union[ + _GetMultimodalDatasetOperationParameters, + _GetMultimodalDatasetOperationParametersDict, +] + + +class ListMultimodalDatasetsConfig(_common.BaseModel): + """Config for listing multimodal datasets.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListMultimodalDatasetsConfigDict(TypedDict, total=False): + """Config for listing multimodal datasets.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. 
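+    For example, `display_name="my-dataset"`.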
+ For field names both snake_case and camelCase are supported.""" + + +ListMultimodalDatasetsConfigOrDict = Union[ + ListMultimodalDatasetsConfig, ListMultimodalDatasetsConfigDict +] + + +class _ListMultimodalDatasetsRequestParameters(_common.BaseModel): + """Parameters for listing multimodal datasets.""" + + config: Optional[ListMultimodalDatasetsConfig] = Field( + default=None, description="""""" + ) + + +class _ListMultimodalDatasetsRequestParametersDict(TypedDict, total=False): + """Parameters for listing multimodal datasets.""" + + config: Optional[ListMultimodalDatasetsConfigDict] + """""" + + +_ListMultimodalDatasetsRequestParametersOrDict = Union[ + _ListMultimodalDatasetsRequestParameters, + _ListMultimodalDatasetsRequestParametersDict, +] + + +class ListMultimodalDatasetsResponse(_common.BaseModel): + """Response for listing multimodal datasets.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the list datasets request in seconds. If not set, + the default timeout is 90 seconds.""", + ) + datasets: Optional[list[MultimodalDataset]] = Field( + default=None, + description="""List of datasets for the project. + """, + ) + + +class ListMultimodalDatasetsResponseDict(TypedDict, total=False): + """Response for listing multimodal datasets.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + timeout: Optional[int] + """The timeout for the list datasets request in seconds. If not set, + the default timeout is 90 seconds.""" + + datasets: Optional[list[MultimodalDatasetDict]] + """List of datasets for the project. + """ + + +ListMultimodalDatasetsResponseOrDict = Union[ + ListMultimodalDatasetsResponse, ListMultimodalDatasetsResponseDict +] + + +class _UpdateMultimodalDatasetParameters(_common.BaseModel): + """Parameters for updating a multimodal dataset resource.""" + + name: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + metadata: Optional[SchemaTablesDatasetMetadata] = Field( + default=None, description="""""" + ) + description: Optional[str] = Field(default=None, description="""""") + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, description="""""" + ) + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + +class _UpdateMultimodalDatasetParametersDict(TypedDict, total=False): + """Parameters for updating a multimodal dataset resource.""" + + name: Optional[str] + """""" + + display_name: Optional[str] + """""" + + metadata: Optional[SchemaTablesDatasetMetadataDict] + """""" + + description: Optional[str] + """""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + +_UpdateMultimodalDatasetParametersOrDict = Union[ + _UpdateMultimodalDatasetParameters, _UpdateMultimodalDatasetParametersDict +] + + +class SchemaPredictParamsGroundingConfigSourceEntry(_common.BaseModel): + """Single source entry for the grounding checking.""" + + enterprise_datastore: Optional[str] = Field( + default=None, + description="""The uri of the Vertex AI Search data source. Deprecated. 
Use vertex_ai_search_datastore instead.""",
+    )
+    inline_context: Optional[str] = Field(
+        default=None,
+        description="""The grounding text passed inline with the Predict API. It can support up to 1 million bytes.""",
+    )
+    type: Optional[
+        Literal["UNSPECIFIED", "WEB", "ENTERPRISE", "VERTEX_AI_SEARCH", "INLINE"]
+    ] = Field(
+        default=None, description="""The type of the grounding checking source."""
+    )
+    vertex_ai_search_datastore: Optional[str] = Field(
+        default=None, description="""The uri of the Vertex AI Search data source."""
+    )
+
+
+class SchemaPredictParamsGroundingConfigSourceEntryDict(TypedDict, total=False):
+    """Single source entry for the grounding checking."""
+
+    enterprise_datastore: Optional[str]
+    """The uri of the Vertex AI Search data source. Deprecated. Use vertex_ai_search_datastore instead."""
+
+    inline_context: Optional[str]
+    """The grounding text passed inline with the Predict API. It can support up to 1 million bytes."""
+
+    type: Optional[
+        Literal["UNSPECIFIED", "WEB", "ENTERPRISE", "VERTEX_AI_SEARCH", "INLINE"]
+    ]
+    """The type of the grounding checking source."""
+
+    vertex_ai_search_datastore: Optional[str]
+    """The uri of the Vertex AI Search data source."""
+
+
+SchemaPredictParamsGroundingConfigSourceEntryOrDict = Union[
+    SchemaPredictParamsGroundingConfigSourceEntry,
+    SchemaPredictParamsGroundingConfigSourceEntryDict,
+]
+
+
+class SchemaPredictParamsGroundingConfig(_common.BaseModel):
+    """The configuration for grounding checking."""
+
+    disable_attribution: Optional[bool] = Field(
+        default=None,
+        description="""If set, skip finding claim attributions (i.e., do not generate grounding citations).""",
+    )
+    sources: Optional[list[SchemaPredictParamsGroundingConfigSourceEntry]] = Field(
+        default=None, description="""The sources for the grounding checking."""
+    )
+
+
+class SchemaPredictParamsGroundingConfigDict(TypedDict, total=False):
+    """The configuration for grounding checking."""
+
+    disable_attribution: Optional[bool]
+    """If set, skip finding claim attributions (i.e., do not generate grounding citations)."""
+
+    sources: Optional[list[SchemaPredictParamsGroundingConfigSourceEntryDict]]
+    """The sources for the grounding checking."""
+
+
+SchemaPredictParamsGroundingConfigOrDict = Union[
+    SchemaPredictParamsGroundingConfig, SchemaPredictParamsGroundingConfigDict
+]
+
+
+class SchemaPromptInstancePromptExecution(_common.BaseModel):
+    """A prompt instance's parameter set containing variable values."""
+
+    arguments: Optional[dict[str, "SchemaPromptInstanceVariableValue"]] = Field(
+        default=None, description="""Maps variable names to their value."""
+    )
+
+
+class SchemaPromptInstancePromptExecutionDict(TypedDict, total=False):
+    """A prompt instance's parameter set containing variable values."""
+
+    arguments: Optional[dict[str, "SchemaPromptInstanceVariableValueDict"]]
+    """Maps variable names to their value."""
+
+
+SchemaPromptInstancePromptExecutionOrDict = Union[
+    SchemaPromptInstancePromptExecution, SchemaPromptInstancePromptExecutionDict
+]
+
+
+class SchemaPromptSpecPromptMessage(_common.BaseModel):
+    """Represents a prompt message."""
+
+    generation_config: Optional[genai_types.GenerationConfig] = Field(
+        default=None, description="""Generation config."""
+    )
+    tool_config: Optional[genai_types.FunctionCallingConfig] = Field(
+        default=None,
+        description="""Tool config. 
This config is shared for all tools provided in the request.""", + ) + tools: Optional[list[genai_types.Tool]] = Field( + default=None, + description="""A list of `Tools` the model may use to generate the next response. A `Tool` is a piece of code that enables the system to interact with external systems to perform an action, or set of actions, outside of knowledge and scope of the model.""", + ) + safety_settings: Optional[list[genai_types.SafetySetting]] = Field( + default=None, + description="""Per request settings for blocking unsafe content. Enforced on GenerateContentResponse.candidates.""", + ) + contents: Optional[list[genai_types.Content]] = Field( + default=None, + description="""The content of the current conversation with the model. For single-turn queries, this is a single instance. For multi-turn queries, this is a repeated field that contains conversation history + latest request.""", + ) + system_instruction: Optional[genai_types.Content] = Field( + default=None, + description="""The user provided system instructions for the model. Note: only text should be used in parts and content in each part will be in a separate paragraph.""", + ) + variables: Optional[list[dict[str, genai_types.Part]]] = Field( + default=None, description="""""" + ) + model: Optional[str] = Field(default=None, description="""The model name.""") + + +class SchemaPromptSpecPromptMessageDict(TypedDict, total=False): + """Represents a prompt message.""" + + generation_config: Optional[genai_types.GenerationConfigDict] + """Generation config.""" + + tool_config: Optional[genai_types.FunctionCallingConfigDict] + """Tool config. This config is shared for all tools provided in the request.""" + + tools: Optional[list[genai_types.ToolDict]] + """A list of `Tools` the model may use to generate the next response. A `Tool` is a piece of code that enables the system to interact with external systems to perform an action, or set of actions, outside of knowledge and scope of the model.""" + + safety_settings: Optional[list[genai_types.SafetySettingDict]] + """Per request settings for blocking unsafe content. Enforced on GenerateContentResponse.candidates.""" + + contents: Optional[list[genai_types.ContentDict]] + """The content of the current conversation with the model. For single-turn queries, this is a single instance. For multi-turn queries, this is a repeated field that contains conversation history + latest request.""" + + system_instruction: Optional[genai_types.ContentDict] + """The user provided system instructions for the model. 
Note: only text should be used in parts and content in each part will be in a separate paragraph.""" + + variables: Optional[list[dict[str, genai_types.PartDict]]] + """""" + + model: Optional[str] + """The model name.""" + + +SchemaPromptSpecPromptMessageOrDict = Union[ + SchemaPromptSpecPromptMessage, SchemaPromptSpecPromptMessageDict +] + + +class SchemaPromptSpecMultimodalPrompt(_common.BaseModel): + """Prompt variation that embeds preambles to prompt string.""" + + prompt_message: Optional[SchemaPromptSpecPromptMessage] = Field( + default=None, description="""The prompt message.""" + ) + + +class SchemaPromptSpecMultimodalPromptDict(TypedDict, total=False): + """Prompt variation that embeds preambles to prompt string.""" + + prompt_message: Optional[SchemaPromptSpecPromptMessageDict] + """The prompt message.""" + + +SchemaPromptSpecMultimodalPromptOrDict = Union[ + SchemaPromptSpecMultimodalPrompt, SchemaPromptSpecMultimodalPromptDict +] + + +class SchemaPromptSpecAppBuilderDataLinkedResource(_common.BaseModel): + """A linked resource attached to the application by the user.""" + + display_name: Optional[str] = Field( + default=None, + description="""A user-friendly name for the data source shown in the UI.""", + ) + name: Optional[str] = Field( + default=None, + description="""The unique resource name of the data source. The format is determined by the 'type' field. For type "SAVED_PROMPT": projects/{project}/locations/{location}/datasets/{dataset} For type "AI_AGENT": projects/{project}/locations/{location}/agents/{agent}""", + ) + type: Optional[str] = Field( + default=None, + description="""The type of the linked resource. e.g., "SAVED_PROMPT", "AI_AGENT" This string corresponds to the name of the LinkedResourceType enum member. See: google3/cloud/console/web/ai/platform/llm/prompts/build/services/specs_repository_service/linked_resources/linked_resource.ts""", + ) + + +class SchemaPromptSpecAppBuilderDataLinkedResourceDict(TypedDict, total=False): + """A linked resource attached to the application by the user.""" + + display_name: Optional[str] + """A user-friendly name for the data source shown in the UI.""" + + name: Optional[str] + """The unique resource name of the data source. The format is determined by the 'type' field. For type "SAVED_PROMPT": projects/{project}/locations/{location}/datasets/{dataset} For type "AI_AGENT": projects/{project}/locations/{location}/agents/{agent}""" + + type: Optional[str] + """The type of the linked resource. e.g., "SAVED_PROMPT", "AI_AGENT" This string corresponds to the name of the LinkedResourceType enum member. See: google3/cloud/console/web/ai/platform/llm/prompts/build/services/specs_repository_service/linked_resources/linked_resource.ts""" + + +SchemaPromptSpecAppBuilderDataLinkedResourceOrDict = Union[ + SchemaPromptSpecAppBuilderDataLinkedResource, + SchemaPromptSpecAppBuilderDataLinkedResourceDict, +] + + +class SchemaPromptSpecAppBuilderData(_common.BaseModel): + """Defines data for an application builder.""" + + code_repository_state: Optional[str] = Field( + default=None, + description="""Serialized state of the code repository. This string will typically contain a JSON representation of the UI's CodeRepositoryService state (files, folders, content, and any metadata). The UI is responsible for serialization and deserialization.""", + ) + framework: Optional[Framework] = Field( + default=None, + description="""Optional. 
Framework used to build the application.""", + ) + linked_resources: Optional[list[SchemaPromptSpecAppBuilderDataLinkedResource]] = ( + Field( + default=None, + description="""Linked resources attached to the application by the user.""", + ) + ) + + +class SchemaPromptSpecAppBuilderDataDict(TypedDict, total=False): + """Defines data for an application builder.""" + + code_repository_state: Optional[str] + """Serialized state of the code repository. This string will typically contain a JSON representation of the UI's CodeRepositoryService state (files, folders, content, and any metadata). The UI is responsible for serialization and deserialization.""" + + framework: Optional[Framework] + """Optional. Framework used to build the application.""" + + linked_resources: Optional[list[SchemaPromptSpecAppBuilderDataLinkedResourceDict]] + """Linked resources attached to the application by the user.""" + + +SchemaPromptSpecAppBuilderDataOrDict = Union[ + SchemaPromptSpecAppBuilderData, SchemaPromptSpecAppBuilderDataDict +] + + +class SchemaPromptSpecPartList(_common.BaseModel): + """Represents a prompt spec part list.""" + + parts: Optional[list[genai_types.Part]] = Field( + default=None, description="""A list of elements that can be part of a prompt.""" + ) + + +class SchemaPromptSpecPartListDict(TypedDict, total=False): + """Represents a prompt spec part list.""" + + parts: Optional[list[genai_types.PartDict]] + """A list of elements that can be part of a prompt.""" + + +SchemaPromptSpecPartListOrDict = Union[ + SchemaPromptSpecPartList, SchemaPromptSpecPartListDict +] + + +class SchemaPromptSpecStructuredPrompt(_common.BaseModel): + """Represents a structured prompt.""" + + context: Optional[genai_types.Content] = Field( + default=None, description="""Preamble: The context of the prompt.""" + ) + app_builder_data: Optional[SchemaPromptSpecAppBuilderData] = Field( + default=None, description="""Data for app builder use case.""" + ) + examples: Optional[list[SchemaPromptSpecPartList]] = Field( + default=None, + description="""Preamble: A set of examples for expected model response.""", + ) + infill_prefix: Optional[str] = Field( + default=None, + description="""Preamble: For infill prompt, the prefix before expected model response.""", + ) + infill_suffix: Optional[str] = Field( + default=None, + description="""Preamble: For infill prompt, the suffix after expected model response.""", + ) + input_prefixes: Optional[list[str]] = Field( + default=None, + description="""Preamble: The input prefixes before each example input.""", + ) + output_prefixes: Optional[list[str]] = Field( + default=None, + description="""Preamble: The output prefixes before each example output.""", + ) + prediction_inputs: Optional[list[SchemaPromptSpecPartList]] = Field( + default=None, + description="""Preamble: The input test data for prediction. 
Each PartList in this field represents one text-only input set for a single model request.""", + ) + prompt_message: Optional[SchemaPromptSpecPromptMessage] = Field( + default=None, description="""The prompt message.""" + ) + + +class SchemaPromptSpecStructuredPromptDict(TypedDict, total=False): + """Represents a structured prompt.""" + + context: Optional[genai_types.ContentDict] + """Preamble: The context of the prompt.""" + + app_builder_data: Optional[SchemaPromptSpecAppBuilderDataDict] + """Data for app builder use case.""" + + examples: Optional[list[SchemaPromptSpecPartListDict]] + """Preamble: A set of examples for expected model response.""" + + infill_prefix: Optional[str] + """Preamble: For infill prompt, the prefix before expected model response.""" + + infill_suffix: Optional[str] + """Preamble: For infill prompt, the suffix after expected model response.""" + + input_prefixes: Optional[list[str]] + """Preamble: The input prefixes before each example input.""" + + output_prefixes: Optional[list[str]] + """Preamble: The output prefixes before each example output.""" + + prediction_inputs: Optional[list[SchemaPromptSpecPartListDict]] + """Preamble: The input test data for prediction. Each PartList in this field represents one text-only input set for a single model request.""" + + prompt_message: Optional[SchemaPromptSpecPromptMessageDict] + """The prompt message.""" + + +SchemaPromptSpecStructuredPromptOrDict = Union[ + SchemaPromptSpecStructuredPrompt, SchemaPromptSpecStructuredPromptDict +] + + +class SchemaPromptSpecReferenceSentencePair(_common.BaseModel): + """A pair of sentences used as reference in source and target languages.""" + + source_sentence: Optional[str] = Field( + default=None, description="""Source sentence in the sentence pair.""" + ) + target_sentence: Optional[str] = Field( + default=None, description="""Target sentence in the sentence pair.""" + ) + + +class SchemaPromptSpecReferenceSentencePairDict(TypedDict, total=False): + """A pair of sentences used as reference in source and target languages.""" + + source_sentence: Optional[str] + """Source sentence in the sentence pair.""" + + target_sentence: Optional[str] + """Target sentence in the sentence pair.""" + + +SchemaPromptSpecReferenceSentencePairOrDict = Union[ + SchemaPromptSpecReferenceSentencePair, SchemaPromptSpecReferenceSentencePairDict +] + + +class SchemaPromptSpecReferenceSentencePairList(_common.BaseModel): + """A list of reference sentence pairs.""" + + reference_sentence_pairs: Optional[list[SchemaPromptSpecReferenceSentencePair]] = ( + Field(default=None, description="""Reference sentence pairs.""") + ) + + +class SchemaPromptSpecReferenceSentencePairListDict(TypedDict, total=False): + """A list of reference sentence pairs.""" + + reference_sentence_pairs: Optional[list[SchemaPromptSpecReferenceSentencePairDict]] + """Reference sentence pairs.""" + + +SchemaPromptSpecReferenceSentencePairListOrDict = Union[ + SchemaPromptSpecReferenceSentencePairList, + SchemaPromptSpecReferenceSentencePairListDict, +] + + +class SchemaPromptSpecTranslationFileInputSource(_common.BaseModel): + + content: Optional[str] = Field(default=None, description="""The file's contents.""") + display_name: Optional[str] = Field( + default=None, description="""The file's display name.""" + ) + mime_type: Optional[str] = Field( + default=None, description="""The file's mime type.""" + ) + + +class SchemaPromptSpecTranslationFileInputSourceDict(TypedDict, total=False): + + content: Optional[str] + """The file's 
contents.""" + + display_name: Optional[str] + """The file's display name.""" + + mime_type: Optional[str] + """The file's mime type.""" + + +SchemaPromptSpecTranslationFileInputSourceOrDict = Union[ + SchemaPromptSpecTranslationFileInputSource, + SchemaPromptSpecTranslationFileInputSourceDict, +] + + +class SchemaPromptSpecTranslationGcsInputSource(_common.BaseModel): + + input_uri: Optional[str] = Field( + default=None, + description="""Source data URI. For example, `gs://my_bucket/my_object`.""", + ) + + +class SchemaPromptSpecTranslationGcsInputSourceDict(TypedDict, total=False): + + input_uri: Optional[str] + """Source data URI. For example, `gs://my_bucket/my_object`.""" + + +SchemaPromptSpecTranslationGcsInputSourceOrDict = Union[ + SchemaPromptSpecTranslationGcsInputSource, + SchemaPromptSpecTranslationGcsInputSourceDict, +] + + +class SchemaPromptSpecTranslationSentenceFileInput(_common.BaseModel): + + file_input_source: Optional[SchemaPromptSpecTranslationFileInputSource] = Field( + default=None, description="""Inlined file source.""" + ) + gcs_input_source: Optional[SchemaPromptSpecTranslationGcsInputSource] = Field( + default=None, description="""Cloud Storage file source.""" + ) + + +class SchemaPromptSpecTranslationSentenceFileInputDict(TypedDict, total=False): + + file_input_source: Optional[SchemaPromptSpecTranslationFileInputSourceDict] + """Inlined file source.""" + + gcs_input_source: Optional[SchemaPromptSpecTranslationGcsInputSourceDict] + """Cloud Storage file source.""" + + +SchemaPromptSpecTranslationSentenceFileInputOrDict = Union[ + SchemaPromptSpecTranslationSentenceFileInput, + SchemaPromptSpecTranslationSentenceFileInputDict, +] + + +class SchemaPromptSpecTranslationExample(_common.BaseModel): + """The translation example that contains reference sentences from various sources.""" + + reference_sentence_pair_lists: Optional[ + list[SchemaPromptSpecReferenceSentencePairList] + ] = Field(default=None, description="""The reference sentences from inline text.""") + reference_sentences_file_inputs: Optional[ + list[SchemaPromptSpecTranslationSentenceFileInput] + ] = Field(default=None, description="""The reference sentences from file.""") + + +class SchemaPromptSpecTranslationExampleDict(TypedDict, total=False): + """The translation example that contains reference sentences from various sources.""" + + reference_sentence_pair_lists: Optional[ + list[SchemaPromptSpecReferenceSentencePairListDict] + ] + """The reference sentences from inline text.""" + + reference_sentences_file_inputs: Optional[ + list[SchemaPromptSpecTranslationSentenceFileInputDict] + ] + """The reference sentences from file.""" + + +SchemaPromptSpecTranslationExampleOrDict = Union[ + SchemaPromptSpecTranslationExample, SchemaPromptSpecTranslationExampleDict +] + + +class SchemaPromptSpecTranslationOption(_common.BaseModel): + """Optional settings for translation prompt.""" + + number_of_shots: Optional[int] = Field( + default=None, description="""How many shots to use.""" + ) + + +class SchemaPromptSpecTranslationOptionDict(TypedDict, total=False): + """Optional settings for translation prompt.""" + + number_of_shots: Optional[int] + """How many shots to use.""" + + +SchemaPromptSpecTranslationOptionOrDict = Union[ + SchemaPromptSpecTranslationOption, SchemaPromptSpecTranslationOptionDict +] + + +class SchemaPromptSpecTranslationPrompt(_common.BaseModel): + """Prompt variation for Translation use case.""" + + example: Optional[SchemaPromptSpecTranslationExample] = Field( + default=None, 
description="""The translation example.""" + ) + option: Optional[SchemaPromptSpecTranslationOption] = Field( + default=None, description="""The translation option.""" + ) + prompt_message: Optional[SchemaPromptSpecPromptMessage] = Field( + default=None, description="""The prompt message.""" + ) + source_language_code: Optional[str] = Field( + default=None, description="""The source language code.""" + ) + target_language_code: Optional[str] = Field( + default=None, description="""The target language code.""" + ) + + +class SchemaPromptSpecTranslationPromptDict(TypedDict, total=False): + """Prompt variation for Translation use case.""" + + example: Optional[SchemaPromptSpecTranslationExampleDict] + """The translation example.""" + + option: Optional[SchemaPromptSpecTranslationOptionDict] + """The translation option.""" + + prompt_message: Optional[SchemaPromptSpecPromptMessageDict] + """The prompt message.""" + + source_language_code: Optional[str] + """The source language code.""" + + target_language_code: Optional[str] + """The target language code.""" + + +SchemaPromptSpecTranslationPromptOrDict = Union[ + SchemaPromptSpecTranslationPrompt, SchemaPromptSpecTranslationPromptDict +] + + +class SchemaPromptApiSchema(_common.BaseModel): + """The A2 schema of a prompt.""" + + api_schema_version: Optional[str] = Field( + default=None, + description="""The Schema version that represents changes to the API behavior.""", + ) + executions: Optional[list[SchemaPromptInstancePromptExecution]] = Field( + default=None, + description="""A list of execution instances for constructing a ready-to-use prompt.""", + ) + multimodal_prompt: Optional[SchemaPromptSpecMultimodalPrompt] = Field( + default=None, + description="""Multimodal prompt which embeds preambles to prompt string.""", + ) + structured_prompt: Optional[SchemaPromptSpecStructuredPrompt] = Field( + default=None, + description="""The prompt variation that stores preambles in separate fields.""", + ) + translation_prompt: Optional[SchemaPromptSpecTranslationPrompt] = Field( + default=None, description="""The prompt variation for Translation use case.""" + ) + + +class SchemaPromptApiSchemaDict(TypedDict, total=False): + """The A2 schema of a prompt.""" + + api_schema_version: Optional[str] + """The Schema version that represents changes to the API behavior.""" + + executions: Optional[list[SchemaPromptInstancePromptExecutionDict]] + """A list of execution instances for constructing a ready-to-use prompt.""" + + multimodal_prompt: Optional[SchemaPromptSpecMultimodalPromptDict] + """Multimodal prompt which embeds preambles to prompt string.""" + + structured_prompt: Optional[SchemaPromptSpecStructuredPromptDict] + """The prompt variation that stores preambles in separate fields.""" + + translation_prompt: Optional[SchemaPromptSpecTranslationPromptDict] + """The prompt variation for Translation use case.""" + + +SchemaPromptApiSchemaOrDict = Union[SchemaPromptApiSchema, SchemaPromptApiSchemaDict] + + +class SchemaTextPromptDatasetMetadata(_common.BaseModel): + """Represents the text prompt dataset metadata.""" + + candidate_count: Optional[int] = Field( + default=None, description="""Number of candidates.""" + ) + gcs_uri: Optional[str] = Field( + default=None, + description="""The Google Cloud Storage URI that stores the prompt data.""", + ) + grounding_config: Optional[SchemaPredictParamsGroundingConfig] = Field( + default=None, description="""Grounding checking configuration.""" + ) + has_prompt_variable: Optional[bool] = Field( + default=None, 
description="""Whether the prompt dataset has prompt variable.""" + ) + logprobs: Optional[bool] = Field( + default=None, + description="""Whether or not the user has enabled logit probabilities in the model parameters.""", + ) + max_output_tokens: Optional[int] = Field( + default=None, + description="""Value of the maximum number of tokens generated set when the dataset was saved.""", + ) + note: Optional[str] = Field( + default=None, + description="""User-created prompt note. Note size limit is 2KB.""", + ) + prompt_api_schema: Optional[SchemaPromptApiSchema] = Field( + default=None, + description="""The API schema of the prompt to support both UI and SDK usages.""", + ) + prompt_type: Optional[str] = Field( + default=None, description="""Type of the prompt dataset.""" + ) + seed_enabled: Optional[bool] = Field( + default=None, + description="""Seeding enables model to return a deterministic response on a best effort basis. Determinism isn't guaranteed. This field determines whether or not seeding is enabled.""", + ) + seed_value: Optional[int] = Field( + default=None, description="""The actual value of the seed.""" + ) + stop_sequences: Optional[list[str]] = Field( + default=None, description="""Customized stop sequences.""" + ) + system_instruction: Optional[str] = Field( + default=None, + description="""The content of the prompt dataset system instruction.""", + ) + system_instruction_gcs_uri: Optional[str] = Field( + default=None, + description="""The Google Cloud Storage URI that stores the system instruction, starting with gs://.""", + ) + temperature: Optional[float] = Field( + default=None, + description="""Temperature value used for sampling set when the dataset was saved. This value is used to tune the degree of randomness.""", + ) + text: Optional[str] = Field( + default=None, description="""The content of the prompt dataset.""" + ) + top_k: Optional[int] = Field( + default=None, + description="""Top K value set when the dataset was saved. This value determines how many candidates with highest probability from the vocab would be selected for each decoding step.""", + ) + top_p: Optional[float] = Field( + default=None, + description="""Top P value set when the dataset was saved. Given topK tokens for decoding, top candidates will be selected until the sum of their probabilities is topP.""", + ) + + +class SchemaTextPromptDatasetMetadataDict(TypedDict, total=False): + """Represents the text prompt dataset metadata.""" + + candidate_count: Optional[int] + """Number of candidates.""" + + gcs_uri: Optional[str] + """The Google Cloud Storage URI that stores the prompt data.""" + + grounding_config: Optional[SchemaPredictParamsGroundingConfigDict] + """Grounding checking configuration.""" + + has_prompt_variable: Optional[bool] + """Whether the prompt dataset has prompt variable.""" + + logprobs: Optional[bool] + """Whether or not the user has enabled logit probabilities in the model parameters.""" + + max_output_tokens: Optional[int] + """Value of the maximum number of tokens generated set when the dataset was saved.""" + + note: Optional[str] + """User-created prompt note. Note size limit is 2KB.""" + + prompt_api_schema: Optional[SchemaPromptApiSchemaDict] + """The API schema of the prompt to support both UI and SDK usages.""" + + prompt_type: Optional[str] + """Type of the prompt dataset.""" + + seed_enabled: Optional[bool] + """Seeding enables model to return a deterministic response on a best effort basis. Determinism isn't guaranteed. 
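+
+ # Example (illustrative sketch): assembling metadata for a text prompt dataset;
+ # the field values below are arbitrary sample settings.
+ #
+ #     metadata = SchemaTextPromptDatasetMetadata(
+ #         text="Summarize the following article: {article}",
+ #         has_prompt_variable=True,
+ #         temperature=0.2,
+ #         top_p=0.95,
+ #         max_output_tokens=1024,
+ #         stop_sequences=["\n\n"],
+ #     )
+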
This field determines whether or not seeding is enabled.""" + + seed_value: Optional[int] + """The actual value of the seed.""" + + stop_sequences: Optional[list[str]] + """Customized stop sequences.""" + + system_instruction: Optional[str] + """The content of the prompt dataset system instruction.""" + + system_instruction_gcs_uri: Optional[str] + """The Google Cloud Storage URI that stores the system instruction, starting with gs://.""" + + temperature: Optional[float] + """Temperature value used for sampling set when the dataset was saved. This value is used to tune the degree of randomness.""" + + text: Optional[str] + """The content of the prompt dataset.""" + + top_k: Optional[int] + """Top K value set when the dataset was saved. This value determines how many candidates with highest probability from the vocab would be selected for each decoding step.""" + + top_p: Optional[float] + """Top P value set when the dataset was saved. Given topK tokens for decoding, top candidates will be selected until the sum of their probabilities is topP.""" + + +SchemaTextPromptDatasetMetadataOrDict = Union[ + SchemaTextPromptDatasetMetadata, SchemaTextPromptDatasetMetadataDict +] + + +class CreateDatasetConfig(_common.BaseModel): + """Config for creating a dataset resource to store prompts.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CreateDatasetConfigDict(TypedDict, total=False): + """Config for creating a dataset resource to store prompts.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CreateDatasetConfigOrDict = Union[CreateDatasetConfig, CreateDatasetConfigDict] + + +class _CreateDatasetParameters(_common.BaseModel): + """Parameters for creating a dataset resource to store prompts.""" + + name: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + metadata_schema_uri: Optional[str] = Field(default=None, description="""""") + metadata: Optional[SchemaTextPromptDatasetMetadata] = Field( + default=None, description="""""" + ) + description: Optional[str] = Field(default=None, description="""""") + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, description="""""" + ) + model_reference: Optional[str] = Field(default=None, description="""""") + config: Optional[CreateDatasetConfig] = Field(default=None, description="""""") + + +class _CreateDatasetParametersDict(TypedDict, total=False): + """Parameters for creating a dataset resource to store prompts.""" + + name: Optional[str] + """""" + + display_name: Optional[str] + """""" + + metadata_schema_uri: Optional[str] + """""" + + metadata: Optional[SchemaTextPromptDatasetMetadataDict] + """""" + + description: Optional[str] + """""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """""" + + model_reference: Optional[str] + """""" + + config: Optional[CreateDatasetConfigDict] + """""" + + +_CreateDatasetParametersOrDict = Union[ + _CreateDatasetParameters, _CreateDatasetParametersDict +] + + +class DatasetOperation(_common.BaseModel): + """Represents the create dataset operation.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. 
If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[dict[str, Any]] = Field( + default=None, description="""The result of the dataset operation.""" + ) + + +class DatasetOperationDict(TypedDict, total=False): + """Represents the create dataset operation.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[dict[str, Any]] + """The result of the dataset operation.""" + + +DatasetOperationOrDict = Union[DatasetOperation, DatasetOperationDict] + + +class CreateDatasetVersionConfig(_common.BaseModel): + """Config for creating a dataset version resource to store prompts.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class CreateDatasetVersionConfigDict(TypedDict, total=False): + """Config for creating a dataset version resource to store prompts.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +CreateDatasetVersionConfigOrDict = Union[ + CreateDatasetVersionConfig, CreateDatasetVersionConfigDict +] + + +class _CreateDatasetVersionParameters(_common.BaseModel): + """Represents the create dataset version parameters.""" + + dataset_name: Optional[str] = Field(default=None, description="""""") + metadata: Optional[SchemaTextPromptDatasetMetadata] = Field( + default=None, description="""""" + ) + model_reference: Optional[str] = Field(default=None, description="""""") + parent: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + config: Optional[CreateDatasetVersionConfig] = Field( + default=None, description="""""" + ) + + +class _CreateDatasetVersionParametersDict(TypedDict, total=False): + """Represents the create dataset version parameters.""" + + dataset_name: Optional[str] + """""" + + metadata: 
Optional[SchemaTextPromptDatasetMetadataDict] + """""" + + model_reference: Optional[str] + """""" + + parent: Optional[str] + """""" + + display_name: Optional[str] + """""" + + config: Optional[CreateDatasetVersionConfigDict] + """""" + + +_CreateDatasetVersionParametersOrDict = Union[ + _CreateDatasetVersionParameters, _CreateDatasetVersionParametersDict +] + + +class _GetDatasetParameters(_common.BaseModel): + """Parameters for getting a dataset resource to store prompts.""" + + name: Optional[str] = Field(default=None, description="""""") + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + +class _GetDatasetParametersDict(TypedDict, total=False): + """Parameters for getting a dataset resource to store prompts.""" + + name: Optional[str] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + +_GetDatasetParametersOrDict = Union[_GetDatasetParameters, _GetDatasetParametersDict] + + +class SavedQuery(_common.BaseModel): + """A SavedQuery is a view of the dataset. It references a subset of annotations by problem type and filters.""" + + annotation_filter: Optional[str] = Field( + default=None, + description="""Output only. Filters on the Annotations in the dataset.""", + ) + annotation_spec_count: Optional[int] = Field( + default=None, + description="""Output only. Number of AnnotationSpecs in the context of the SavedQuery.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this SavedQuery was created.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Required. The user-defined name of the SavedQuery. The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + etag: Optional[str] = Field( + default=None, + description="""Used to perform a consistent read-modify-write update. If not set, a blind "overwrite" update happens.""", + ) + metadata: Optional[Any] = Field( + default=None, + description="""Some additional information about the SavedQuery.""", + ) + name: Optional[str] = Field( + default=None, description="""Output only. Resource name of the SavedQuery.""" + ) + problem_type: Optional[str] = Field( + default=None, + description="""Required. Problem type of the SavedQuery. Allowed values: * IMAGE_CLASSIFICATION_SINGLE_LABEL * IMAGE_CLASSIFICATION_MULTI_LABEL * IMAGE_BOUNDING_POLY * IMAGE_BOUNDING_BOX * TEXT_CLASSIFICATION_SINGLE_LABEL * TEXT_CLASSIFICATION_MULTI_LABEL * TEXT_EXTRACTION * TEXT_SENTIMENT * VIDEO_CLASSIFICATION * VIDEO_OBJECT_TRACKING""", + ) + support_automl_training: Optional[bool] = Field( + default=None, + description="""Output only. If the Annotations belonging to the SavedQuery can be used for AutoML training.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when SavedQuery was last updated.""", + ) + + +class SavedQueryDict(TypedDict, total=False): + """A SavedQuery is a view of the dataset. It references a subset of annotations by problem type and filters.""" + + annotation_filter: Optional[str] + """Output only. Filters on the Annotations in the dataset.""" + + annotation_spec_count: Optional[int] + """Output only. Number of AnnotationSpecs in the context of the SavedQuery.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this SavedQuery was created.""" + + display_name: Optional[str] + """Required. The user-defined name of the SavedQuery. 
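+
+ # Example (illustrative sketch): DatasetOperation follows the standard
+ # long-running-operation shape, so callers poll `done` and then read either
+ # `error` or `response`. `get_operation` is a hypothetical helper here.
+ #
+ #     operation = get_operation(operation_name)
+ #     while not operation.done:
+ #         time.sleep(5)
+ #         operation = get_operation(operation.name)
+ #     if operation.error:
+ #         raise RuntimeError(operation.error)
+ #     result = operation.response
+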
The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + etag: Optional[str] + """Used to perform a consistent read-modify-write update. If not set, a blind "overwrite" update happens.""" + + metadata: Optional[Any] + """Some additional information about the SavedQuery.""" + + name: Optional[str] + """Output only. Resource name of the SavedQuery.""" + + problem_type: Optional[str] + """Required. Problem type of the SavedQuery. Allowed values: * IMAGE_CLASSIFICATION_SINGLE_LABEL * IMAGE_CLASSIFICATION_MULTI_LABEL * IMAGE_BOUNDING_POLY * IMAGE_BOUNDING_BOX * TEXT_CLASSIFICATION_SINGLE_LABEL * TEXT_CLASSIFICATION_MULTI_LABEL * TEXT_EXTRACTION * TEXT_SENTIMENT * VIDEO_CLASSIFICATION * VIDEO_OBJECT_TRACKING""" + + support_automl_training: Optional[bool] + """Output only. If the Annotations belonging to the SavedQuery can be used for AutoML training.""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when SavedQuery was last updated.""" + + SavedQueryOrDict = Union[SavedQuery, SavedQueryDict] + + class Dataset(_common.BaseModel): + """Represents a dataset resource to store prompts.""" + + metadata: Optional[SchemaTextPromptDatasetMetadata] = Field( + default=None, + description="""Required. Additional information about the Dataset.""", + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key spec for a Dataset. If set, this Dataset and all sub-resources of this Dataset will be secured by this key.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this Dataset was created.""", + ) + data_item_count: Optional[int] = Field( + default=None, + description="""Output only. The number of DataItems in this Dataset. Only applies to non-structured Datasets.""", + ) + description: Optional[str] = Field( + default=None, description="""The description of the Dataset.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""Required. The user-defined name of the Dataset. The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + etag: Optional[str] = Field( + default=None, + description="""Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""", + ) + labels: Optional[dict[str, str]] = Field( + default=None, + description="""The labels with user-defined metadata to organize your Datasets. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. No more than 64 user labels can be associated with one Dataset (System labels are excluded). See https://goo.gl/xmQnxf for more information and examples of labels. System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. The following system labels exist for each Dataset: * "aiplatform.googleapis.com/dataset_metadata_schema": output only, its value is the metadata_schema's title.""", + ) + metadata_artifact: Optional[str] = Field( + default=None, + description="""Output only. The resource name of the Artifact that was created in MetadataStore when creating the Dataset. The Artifact resource name pattern is `projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}`.""", + ) + metadata_schema_uri: Optional[str] = Field( + default=None, + description="""Required.
Points to a YAML file stored on Google Cloud Storage describing additional information about the Dataset. The schema is defined as an OpenAPI 3.0.2 Schema Object. The schema files that can be used here are found in gs://google-cloud-aiplatform/schema/dataset/metadata/.""", + ) + model_reference: Optional[str] = Field( + default=None, + description="""Optional. Reference to the public base model last used by the dataset. Only set for prompt datasets.""", + ) + name: Optional[str] = Field( + default=None, + description="""Output only. Identifier. The resource name of the Dataset. Format: `projects/{project}/locations/{location}/datasets/{dataset}`""", + ) + satisfies_pzi: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + satisfies_pzs: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + saved_queries: Optional[list[SavedQuery]] = Field( + default=None, + description="""All SavedQueries belonging to the Dataset will be returned in the List/Get Dataset response. The annotation_specs field will not be populated except for UI cases which will only use annotation_spec_count. In a CreateDataset request, a SavedQuery is created together with the Dataset if this field is set; up to one SavedQuery can be set in a CreateDatasetRequest. The SavedQuery should not contain any AnnotationSpec.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this Dataset was last updated.""", + ) + + # TODO(b/448806531): Remove all the overridden _from_response methods once the + # ticket is resolved and published. + @classmethod + def _from_response( + cls: typing.Type["Dataset"], + *, + response: dict[str, object], + kwargs: dict[str, object], + ) -> "Dataset": + """Converts a dictionary response into a Dataset object.""" + + response = _camel_key_to_snake(response) + result = super()._from_response(response=response, kwargs=kwargs) + return result + + class DatasetDict(TypedDict, total=False): + """Represents a dataset resource to store prompts.""" + + metadata: Optional[SchemaTextPromptDatasetMetadataDict] + """Required. Additional information about the Dataset.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key spec for a Dataset. If set, this Dataset and all sub-resources of this Dataset will be secured by this key.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this Dataset was created.""" + + data_item_count: Optional[int] + """Output only. The number of DataItems in this Dataset. Only applies to non-structured Datasets.""" + + description: Optional[str] + """The description of the Dataset.""" + + display_name: Optional[str] + """Required. The user-defined name of the Dataset. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + etag: Optional[str] + """Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""" + + labels: Optional[dict[str, str]] + """The labels with user-defined metadata to organize your Datasets. Label keys and values can be no longer than 64 characters (Unicode codepoints), can only contain lowercase letters, numeric characters, underscores and dashes. International characters are allowed. No more than 64 user labels can be associated with one Dataset (System labels are excluded). See https://goo.gl/xmQnxf for more information and examples of labels.
System reserved label keys are prefixed with "aiplatform.googleapis.com/" and are immutable. The following system labels exist for each Dataset: * "aiplatform.googleapis.com/dataset_metadata_schema": output only, its value is the metadata_schema's title.""" + + metadata_artifact: Optional[str] + """Output only. The resource name of the Artifact that was created in MetadataStore when creating the Dataset. The Artifact resource name pattern is `projects/{project}/locations/{location}/metadataStores/{metadata_store}/artifacts/{artifact}`.""" + + metadata_schema_uri: Optional[str] + """Required. Points to a YAML file stored on Google Cloud Storage describing additional information about the Dataset. The schema is defined as an OpenAPI 3.0.2 Schema Object. The schema files that can be used here are found in gs://google-cloud-aiplatform/schema/dataset/metadata/.""" + + model_reference: Optional[str] + """Optional. Reference to the public base model last used by the dataset. Only set for prompt datasets.""" + + name: Optional[str] + """Output only. Identifier. The resource name of the Dataset. Format: `projects/{project}/locations/{location}/datasets/{dataset}`""" + + satisfies_pzi: Optional[bool] + """Output only. Reserved for future use.""" + + satisfies_pzs: Optional[bool] + """Output only. Reserved for future use.""" + + saved_queries: Optional[list[SavedQueryDict]] + """All SavedQueries belonging to the Dataset will be returned in the List/Get Dataset response. The annotation_specs field will not be populated except for UI cases which will only use annotation_spec_count. In a CreateDataset request, a SavedQuery is created together with the Dataset if this field is set; up to one SavedQuery can be set in a CreateDatasetRequest. The SavedQuery should not contain any AnnotationSpec.""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when this Dataset was last updated.""" + + DatasetOrDict = Union[Dataset, DatasetDict] + + class _GetDatasetVersionParameters(_common.BaseModel): + """Parameters for getting a dataset version resource to store prompts.""" + + dataset_id: Optional[str] = Field(default=None, description="""""") + dataset_version_id: Optional[str] = Field(default=None, description="""""") + config: Optional[VertexBaseConfig] = Field(default=None, description="""""") + + class _GetDatasetVersionParametersDict(TypedDict, total=False): + """Parameters for getting a dataset version resource to store prompts.""" + + dataset_id: Optional[str] + """""" + + dataset_version_id: Optional[str] + """""" + + config: Optional[VertexBaseConfigDict] + """""" + + _GetDatasetVersionParametersOrDict = Union[ + _GetDatasetVersionParameters, _GetDatasetVersionParametersDict + ] + + class DatasetVersion(_common.BaseModel): + """Represents a dataset version resource to store prompts.""" + + metadata: Optional[SchemaTextPromptDatasetMetadata] = Field( + default=None, + description="""Required. Output only. Additional information about the DatasetVersion.""", + ) + big_query_dataset_name: Optional[str] = Field( + default=None, + description="""Output only. Name of the associated BigQuery dataset.""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this DatasetVersion was created.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the DatasetVersion.
The name can be up to 128 characters long and can consist of any UTF-8 characters.""", + ) + etag: Optional[str] = Field( + default=None, + description="""Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""", + ) + model_reference: Optional[str] = Field( + default=None, + description="""Output only. Reference to the public base model last used by the dataset version. Only set for prompt dataset versions.""", + ) + name: Optional[str] = Field( + default=None, + description="""Output only. Identifier. The resource name of the DatasetVersion. Format: `projects/{project}/locations/{location}/datasets/{dataset}/datasetVersions/{dataset_version}`""", + ) + satisfies_pzi: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + satisfies_pzs: Optional[bool] = Field( + default=None, description="""Output only. Reserved for future use.""" + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this DatasetVersion was last updated.""", + ) + + # TODO(b/448806531): Remove all the overridden _from_response methods once the + # ticket is resolved and published. + @classmethod + def _from_response( + cls: typing.Type["DatasetVersion"], + *, + response: dict[str, object], + kwargs: dict[str, object], + ) -> "DatasetVersion": + """Converts a dictionary response into a DatasetVersion object.""" + + response = _camel_key_to_snake(response) + result = super()._from_response(response=response, kwargs=kwargs) + return result + + +class DatasetVersionDict(TypedDict, total=False): + """Represents a dataset version resource to store prompts.""" + + metadata: Optional[SchemaTextPromptDatasetMetadataDict] + """Required. Output only. Additional information about the DatasetVersion.""" + + big_query_dataset_name: Optional[str] + """Output only. Name of the associated BigQuery dataset.""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this DatasetVersion was created.""" + + display_name: Optional[str] + """The user-defined name of the DatasetVersion. The name can be up to 128 characters long and can consist of any UTF-8 characters.""" + + etag: Optional[str] + """Used to perform consistent read-modify-write updates. If not set, a blind "overwrite" update happens.""" + + model_reference: Optional[str] + """Output only. Reference to the public base model last used by the dataset version. Only set for prompt dataset versions.""" + + name: Optional[str] + """Output only. Identifier. The resource name of the DatasetVersion. Format: `projects/{project}/locations/{location}/datasets/{dataset}/datasetVersions/{dataset_version}`""" + + satisfies_pzi: Optional[bool] + """Output only. Reserved for future use.""" + + satisfies_pzs: Optional[bool] + """Output only. Reserved for future use.""" + + update_time: Optional[datetime.datetime] + """Output only. 
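+
+ # Illustrative sketch: `_camel_key_to_snake` is defined elsewhere in this
+ # module; the overridden _from_response methods above assume it recursively
+ # rewrites camelCase response keys into the snake_case field names, roughly:
+ #
+ #     def _camel_key_to_snake(obj):
+ #         if isinstance(obj, dict):
+ #             return {
+ #                 re.sub(r"(?<!^)(?=[A-Z])", "_", key).lower(): _camel_key_to_snake(value)
+ #                 for key, value in obj.items()
+ #             }
+ #         if isinstance(obj, list):
+ #             return [_camel_key_to_snake(item) for item in obj]
+ #         return obj
+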
Timestamp when this DatasetVersion was last updated.""" + + +DatasetVersionOrDict = Union[DatasetVersion, DatasetVersionDict] + + +class GetDatasetOperationConfig(_common.BaseModel): + """Config for getting a dataset version operation.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class GetDatasetOperationConfigDict(TypedDict, total=False): + """Config for getting a dataset version operation.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +GetDatasetOperationConfigOrDict = Union[ + GetDatasetOperationConfig, GetDatasetOperationConfigDict +] + + +class _GetDatasetOperationParameters(_common.BaseModel): + """Parameters for getting a dataset operation.""" + + dataset_id: Optional[str] = Field(default=None, description="""""") + operation_id: Optional[str] = Field(default=None, description="""""") + config: Optional[GetDatasetOperationConfig] = Field( + default=None, description="""""" + ) + + +class _GetDatasetOperationParametersDict(TypedDict, total=False): + """Parameters for getting a dataset operation.""" + + dataset_id: Optional[str] + """""" + + operation_id: Optional[str] + """""" + + config: Optional[GetDatasetOperationConfigDict] + """""" + + +_GetDatasetOperationParametersOrDict = Union[ + _GetDatasetOperationParameters, _GetDatasetOperationParametersDict +] + + +class ListPromptsConfig(_common.BaseModel): + """Config for listing prompt datasets and dataset versions.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + page_size: Optional[int] = Field(default=None, description="""""") + page_token: Optional[str] = Field(default=None, description="""""") + filter: Optional[str] = Field( + default=None, + description="""An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""", + ) + + +class ListPromptsConfigDict(TypedDict, total=False): + """Config for listing prompt datasets and dataset versions.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + page_size: Optional[int] + """""" + + page_token: Optional[str] + """""" + + filter: Optional[str] + """An expression for filtering the results of the request. + For field names both snake_case and camelCase are supported.""" + + +ListPromptsConfigOrDict = Union[ListPromptsConfig, ListPromptsConfigDict] + + +class _ListDatasetsRequestParameters(_common.BaseModel): + """Parameters for listing prompt datasets.""" + + config: Optional[ListPromptsConfig] = Field(default=None, description="""""") + + +class _ListDatasetsRequestParametersDict(TypedDict, total=False): + """Parameters for listing prompt datasets.""" + + config: Optional[ListPromptsConfigDict] + """""" + + +_ListDatasetsRequestParametersOrDict = Union[ + _ListDatasetsRequestParameters, _ListDatasetsRequestParametersDict +] + + +class ListDatasetsResponse(_common.BaseModel): + """Response for listing prompt datasets.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + datasets: Optional[list[Dataset]] = Field( + default=None, + description="""List of datasets for the project. 
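+
+ # Example (illustrative sketch): ListPromptsConfig supports page-by-page
+ # iteration via `page_token`; `list_datasets` is a hypothetical method that
+ # returns a ListDatasetsResponse.
+ #
+ #     config = ListPromptsConfig(page_size=50, filter='displayName="my-prompt"')
+ #     while True:
+ #         response = list_datasets(config=config)
+ #         for dataset in response.datasets or []:
+ #             print(dataset.name)
+ #         if not response.next_page_token:
+ #             break
+ #         config.page_token = response.next_page_token
+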
+ """, + ) + + # TODO(b/448806531): Remove all the overridden _from_response methods once the + # ticket is resolved and published. + @classmethod + def _from_response( + cls: typing.Type["ListDatasetsResponse"], + *, + response: dict[str, object], + kwargs: dict[str, object], + ) -> "ListDatasetsResponse": + """Converts a dictionary response into a ListDatasetsResponse object.""" + + response = _camel_key_to_snake(response) + result = super()._from_response(response=response, kwargs=kwargs) + return result + + class ListDatasetsResponseDict(TypedDict, total=False): + """Response for listing prompt datasets.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + datasets: Optional[list[DatasetDict]] + """List of datasets for the project. + """ + + ListDatasetsResponseOrDict = Union[ListDatasetsResponse, ListDatasetsResponseDict] + + class _ListDatasetVersionsRequestParameters(_common.BaseModel): + """Parameters for listing dataset versions.""" + + read_mask: Optional[str] = Field(default=None, description="""""") + dataset_id: Optional[str] = Field(default=None, description="""""") + config: Optional[ListPromptsConfig] = Field(default=None, description="""""") + + class _ListDatasetVersionsRequestParametersDict(TypedDict, total=False): + """Parameters for listing dataset versions.""" + + read_mask: Optional[str] + """""" + + dataset_id: Optional[str] + """""" + + config: Optional[ListPromptsConfigDict] + """""" + + _ListDatasetVersionsRequestParametersOrDict = Union[ + _ListDatasetVersionsRequestParameters, _ListDatasetVersionsRequestParametersDict + ] + + class ListDatasetVersionsResponse(_common.BaseModel): + """Response for listing prompt dataset versions.""" + + sdk_http_response: Optional[genai_types.HttpResponse] = Field( + default=None, description="""Used to retain the full HTTP response.""" + ) + next_page_token: Optional[str] = Field(default=None, description="""""") + dataset_versions: Optional[list[DatasetVersion]] = Field( + default=None, + description="""List of dataset versions for the dataset. + """, + ) + + # TODO(b/448806531): Remove all the overridden _from_response methods once the + # ticket is resolved and published. + @classmethod + def _from_response( + cls: typing.Type["ListDatasetVersionsResponse"], + *, + response: dict[str, object], + kwargs: dict[str, object], + ) -> "ListDatasetVersionsResponse": + """Converts a dictionary response into a ListDatasetVersionsResponse object.""" + + response = _camel_key_to_snake(response) + result = super()._from_response(response=response, kwargs=kwargs) + return result + + class ListDatasetVersionsResponseDict(TypedDict, total=False): + """Response for listing prompt dataset versions.""" + + sdk_http_response: Optional[genai_types.HttpResponseDict] + """Used to retain the full HTTP response.""" + + next_page_token: Optional[str] + """""" + + dataset_versions: Optional[list[DatasetVersionDict]] + """List of dataset versions for the dataset. + """ + + ListDatasetVersionsResponseOrDict = Union[ + ListDatasetVersionsResponse, ListDatasetVersionsResponseDict + ] + + class DeletePromptConfig(_common.BaseModel): + """Config for deleting a prompt.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + timeout: Optional[int] = Field( + default=90, + description="""Timeout for the delete prompt operation in seconds.
Defaults to 90.""", + ) + + +class DeletePromptConfigDict(TypedDict, total=False): + """Config for deleting a prompt.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + timeout: Optional[int] + """Timeout for the delete prompt operation in seconds. Defaults to 90.""" + + +DeletePromptConfigOrDict = Union[DeletePromptConfig, DeletePromptConfigDict] + + +class _DeleteDatasetRequestParameters(_common.BaseModel): + """Parameters for deleting a prompt dataset.""" + + prompt_id: Optional[str] = Field( + default=None, description="""ID of the prompt dataset to be deleted.""" + ) + config: Optional[DeletePromptConfig] = Field(default=None, description="""""") + + +class _DeleteDatasetRequestParametersDict(TypedDict, total=False): + """Parameters for deleting a prompt dataset.""" + + prompt_id: Optional[str] + """ID of the prompt dataset to be deleted.""" + + config: Optional[DeletePromptConfigDict] + """""" + + +_DeleteDatasetRequestParametersOrDict = Union[ + _DeleteDatasetRequestParameters, _DeleteDatasetRequestParametersDict +] + + +class DeletePromptOperation(_common.BaseModel): + """Operation for deleting prompts.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeletePromptOperationDict(TypedDict, total=False): + """Operation for deleting prompts.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. 
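+
+ # Example (illustrative sketch): overriding the default 90-second timeout when
+ # deleting a prompt; `delete_prompt` is a hypothetical method that returns a
+ # DeletePromptOperation, and the prompt ID is a placeholder.
+ #
+ #     operation = delete_prompt(
+ #         prompt_id="1234567890",
+ #         config=DeletePromptConfig(timeout=180),
+ #     )
+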
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeletePromptOperationOrDict = Union[DeletePromptOperation, DeletePromptOperationDict] + + +class _DeletePromptVersionRequestParameters(_common.BaseModel): + """Parameters for deleting a prompt version.""" + + prompt_id: Optional[str] = Field( + default=None, description="""ID of the prompt to be deleted.""" + ) + version_id: Optional[str] = Field( + default=None, + description="""ID of the prompt version to be deleted within the provided prompt_id.""", + ) + config: Optional[DeletePromptConfig] = Field(default=None, description="""""") + + +class _DeletePromptVersionRequestParametersDict(TypedDict, total=False): + """Parameters for deleting a prompt version.""" + + prompt_id: Optional[str] + """ID of the prompt to be deleted.""" + + version_id: Optional[str] + """ID of the prompt version to be deleted within the provided prompt_id.""" + + config: Optional[DeletePromptConfigDict] + """""" + + +_DeletePromptVersionRequestParametersOrDict = Union[ + _DeletePromptVersionRequestParameters, _DeletePromptVersionRequestParametersDict +] + + +class DeletePromptVersionOperation(_common.BaseModel): + """Operation for deleting prompt versions.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class DeletePromptVersionOperationDict(TypedDict, total=False): + """Operation for deleting prompt versions.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + +DeletePromptVersionOperationOrDict = Union[ + DeletePromptVersionOperation, DeletePromptVersionOperationDict +] + + +class RestoreVersionConfig(_common.BaseModel): + """Config for restoring a prompt version.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class RestoreVersionConfigDict(TypedDict, total=False): + """Config for restoring a prompt version.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +RestoreVersionConfigOrDict = Union[RestoreVersionConfig, RestoreVersionConfigDict] + + +class _RestoreVersionRequestParameters(_common.BaseModel): + """Parameters for restoring a prompt version.""" + + dataset_id: Optional[str] = Field( + default=None, description="""ID of the prompt dataset to be restored.""" + ) + version_id: Optional[str] = Field( + default=None, description="""ID of the prompt dataset version to be restored.""" + ) + config: Optional[RestoreVersionConfig] = Field(default=None, description="""""") + + +class _RestoreVersionRequestParametersDict(TypedDict, total=False): + """Parameters for restoring a prompt version.""" + + dataset_id: Optional[str] + """ID of the prompt dataset to be restored.""" + + version_id: Optional[str] + """ID of the prompt dataset version to be restored.""" + + config: Optional[RestoreVersionConfigDict] + """""" + + +_RestoreVersionRequestParametersOrDict = Union[ + _RestoreVersionRequestParameters, _RestoreVersionRequestParametersDict +] + + +class RestoreVersionOperation(_common.BaseModel): + """Represents the restore version operation.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + + +class RestoreVersionOperationDict(TypedDict, total=False): + """Represents the restore version operation.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. 
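+
+ # Example (illustrative sketch): rolling a prompt dataset back to an earlier
+ # version; `restore_version` is a hypothetical method that returns a
+ # RestoreVersionOperation, and both IDs are placeholders.
+ #
+ #     operation = restore_version(dataset_id="1234567890", version_id="2")
+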
Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + RestoreVersionOperationOrDict = Union[ + RestoreVersionOperation, RestoreVersionOperationDict + ] + + class UpdatePromptConfig(_common.BaseModel): + """Config for updating a dataset resource that stores prompts.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + prompt_display_name: Optional[str] = Field( + default=None, description="""The updated display name for the prompt.""" + ) + version_display_name: Optional[str] = Field( + default=None, + description="""The updated display name for the prompt version. If not set, a default name with a timestamp will be used.""", + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the update_dataset_resource request in seconds. If not set, the default timeout is 90 seconds.""", + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key spec for a prompt dataset. If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key.""", + ) + + class UpdatePromptConfigDict(TypedDict, total=False): + """Config for updating a dataset resource that stores prompts.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + prompt_display_name: Optional[str] + """The updated display name for the prompt.""" + + version_display_name: Optional[str] + """The updated display name for the prompt version. If not set, a default name with a timestamp will be used.""" + + timeout: Optional[int] + """The timeout for the update_dataset_resource request in seconds. If not set, the default timeout is 90 seconds.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key spec for a prompt dataset. If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key."""
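+
+ # Example (illustrative sketch): renaming a prompt and its latest version while
+ # attaching a customer-managed encryption key; `update_prompt` is a
+ # hypothetical method and the KMS key name is a placeholder.
+ #
+ #     config = UpdatePromptConfig(
+ #         prompt_display_name="my-prompt-v2",
+ #         version_display_name="tuned-for-production",
+ #         encryption_spec=genai_types.EncryptionSpec(
+ #             kms_key_name="projects/p/locations/l/keyRings/r/cryptoKeys/k"
+ #         ),
+ #     )
+ #     update_prompt(prompt_id="1234567890", config=config)
+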
UpdatePromptConfigOrDict = Union[UpdatePromptConfig, UpdatePromptConfigDict] + + class _UpdateDatasetParameters(_common.BaseModel): + """Parameters for updating a dataset resource that stores prompts.""" + + name: Optional[str] = Field(default=None, description="""""") + dataset_id: Optional[str] = Field(default=None, description="""""") + display_name: Optional[str] = Field(default=None, description="""""") + metadata: Optional[SchemaTextPromptDatasetMetadata] = Field( + default=None, description="""""" + ) + description: Optional[str] = Field(default=None, description="""""") + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, description="""""" + ) + model_reference: Optional[str] = Field(default=None, description="""""") + config: Optional[UpdatePromptConfig] = Field(default=None, description="""""") + + class _UpdateDatasetParametersDict(TypedDict, total=False): + """Parameters for updating a dataset resource that stores prompts.""" + + name: Optional[str] + """""" + + dataset_id: Optional[str] + """""" + + display_name: Optional[str] + """""" + + metadata: Optional[SchemaTextPromptDatasetMetadataDict] + """""" + + description: Optional[str] + """""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """""" + + model_reference: Optional[str] + """""" + + config: Optional[UpdatePromptConfigDict] + """""" + + _UpdateDatasetParametersOrDict = Union[ + _UpdateDatasetParameters, _UpdateDatasetParametersDict + ] + + class GetSkillConfig(_common.BaseModel): + """Config for getting a skill.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + class GetSkillConfigDict(TypedDict, total=False): + """Config for getting a skill.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + GetSkillConfigOrDict = Union[GetSkillConfig, GetSkillConfigDict] + + class _GetSkillRequestParameters(_common.BaseModel): + """Parameters for GetSkillRequest.""" + + name: Optional[str] = Field( + default=None, + description="""The resource name of the Skill to retrieve. Format: projects/{project}/locations/{location}/skills/{skill}""", + ) + config: Optional[GetSkillConfig] = Field(default=None, description="""""") + + class _GetSkillRequestParametersDict(TypedDict, total=False): + """Parameters for GetSkillRequest.""" + + name: Optional[str] + """The resource name of the Skill to retrieve. Format: projects/{project}/locations/{location}/skills/{skill}""" + + config: Optional[GetSkillConfigDict] + """""" + + _GetSkillRequestParametersOrDict = Union[ + _GetSkillRequestParameters, _GetSkillRequestParametersDict + ] + + class Skill(_common.BaseModel): + """Represents a Skill resource. + + Patches the type from the discovery document. + """ + + name: Optional[str] = Field( + default=None, + description="""Identifier. The resource name of the Skill. Format: `projects/{project}/locations/{location}/skills/{skill}`""", + ) + create_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this Skill was created.""", + ) + update_time: Optional[datetime.datetime] = Field( + default=None, + description="""Output only. Timestamp when this Skill was most recently updated.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""Required.
Provides the display name of the Skill. This should align with `name` in the `SKILL.md` file.""", + ) + description: Optional[str] = Field( + default=None, + description="""Required. Describes the Skill. Should describe both what the skill does and when to use it. Should include specific keywords that help agents identify relevant tasks. This should align with `description` in the `SKILL.md` file.""", + ) + license: Optional[str] = Field( + default=None, + description="""Optional. Specifies the license of the Skill. This should be an SPDX license identifier (e.g., "MIT", "Apache-2.0"). See https://spdx.org/licenses/. This should align with `license` in the `SKILL.md` file.""", + ) + compatibility: Optional[str] = Field( + default=None, + description="""Optional. Specifies the compatibility of the Skill. Indicates environment requirements (intended product, system packages, network access, etc.). This should align with `compatibility` in the `SKILL.md` file.""", + ) + zipped_filesystem: Optional[str] = Field( + default=None, + description="""Required. Provides the zipped filesystem of the Skill. This should contain the `SKILL.md` file at the root of the zip and optional directories for scripts, references, and assets. Directory should align with the directory structure specified at https://agentskills.io/specification#directory-structure.""", + ) + state: Optional[SkillState] = Field( + default=None, description="""Output only. The state of the Skill.""" + ) + + +class SkillDict(TypedDict, total=False): + """Represents a Skill resource. + + Patches the type from the discovery document. + """ + + name: Optional[str] + """Identifier. The resource name of the Skill. Format: `projects/{project}/locations/{location}/skills/{skill}`""" + + create_time: Optional[datetime.datetime] + """Output only. Timestamp when this Skill was created.""" + + update_time: Optional[datetime.datetime] + """Output only. Timestamp when this Skill was most recently updated.""" + + display_name: Optional[str] + """Required. Provides the display name of the Skill. This should align with `name` in the `SKILL.md` file.""" + + description: Optional[str] + """Required. Describes the Skill. Should describe both what the skill does and when to use it. Should include specific keywords that help agents identify relevant tasks. This should align with `description` in the `SKILL.md` file.""" + + license: Optional[str] + """Optional. Specifies the license of the Skill. This should be an SPDX license identifier (e.g., "MIT", "Apache-2.0"). See https://spdx.org/licenses/. This should align with `license` in the `SKILL.md` file.""" + + compatibility: Optional[str] + """Optional. Specifies the compatibility of the Skill. Indicates environment requirements (intended product, system packages, network access, etc.). This should align with `compatibility` in the `SKILL.md` file.""" + + zipped_filesystem: Optional[str] + """Required. Provides the zipped filesystem of the Skill. This should contain the `SKILL.md` file at the root of the zip and optional directories for scripts, references, and assets. Directory should align with the directory structure specified at https://agentskills.io/specification#directory-structure.""" + + state: Optional[SkillState] + """Output only. 
The state of the Skill.""" + + +SkillOrDict = Union[Skill, SkillDict] + + +class RetrieveSkillsConfig(_common.BaseModel): + """Config for retrieving skills.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + top_k: Optional[int] = Field( + default=None, + description="""Optional. The maximum number of skills to return. The service may + return fewer than this value. If unspecified, at most 10 skills will be + returned. The maximum value is 100. + """, + ) + + +class RetrieveSkillsConfigDict(TypedDict, total=False): + """Config for retrieving skills.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + top_k: Optional[int] + """Optional. The maximum number of skills to return. The service may + return fewer than this value. If unspecified, at most 10 skills will be + returned. The maximum value is 100. + """ + + +RetrieveSkillsConfigOrDict = Union[RetrieveSkillsConfig, RetrieveSkillsConfigDict] + + +class _RetrieveSkillsRequestParameters(_common.BaseModel): + """Parameters for retrieving skills.""" + + query: Optional[str] = Field( + default=None, description="""Required. The query to find matching skills.""" + ) + config: Optional[RetrieveSkillsConfig] = Field(default=None, description="""""") + + +class _RetrieveSkillsRequestParametersDict(TypedDict, total=False): + """Parameters for retrieving skills.""" + + query: Optional[str] + """Required. The query to find matching skills.""" + + config: Optional[RetrieveSkillsConfigDict] + """""" + + +_RetrieveSkillsRequestParametersOrDict = Union[ + _RetrieveSkillsRequestParameters, _RetrieveSkillsRequestParametersDict +] + + +class RetrievedSkill(_common.BaseModel): + """A retrieved skill from semantic search.""" + + skill_name: Optional[str] = Field( + default=None, description="""The resource name of the skill.""" + ) + description: Optional[str] = Field( + default=None, description="""The description of the skill.""" + ) + + +class RetrievedSkillDict(TypedDict, total=False): + """A retrieved skill from semantic search.""" + + skill_name: Optional[str] + """The resource name of the skill.""" + + description: Optional[str] + """The description of the skill.""" + + +RetrievedSkillOrDict = Union[RetrievedSkill, RetrievedSkillDict] + + +class RetrieveSkillsResponse(_common.BaseModel): + """Response for retrieving skills.""" + + retrieved_skills: Optional[list[RetrievedSkill]] = Field( + default=None, description="""List of retrieved skills ranked by similarity.""" + ) + + +class RetrieveSkillsResponseDict(TypedDict, total=False): + """Response for retrieving skills.""" + + retrieved_skills: Optional[list[RetrievedSkillDict]] + """List of retrieved skills ranked by similarity.""" + + +RetrieveSkillsResponseOrDict = Union[RetrieveSkillsResponse, RetrieveSkillsResponseDict] + + +class CreateSkillConfig(_common.BaseModel): + """Config for creating a skill.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + wait_for_completion: Optional[bool] = Field( + default=True, + description="""Whether to wait for the long running operation to complete.""", + ) + local_path: Optional[str] = Field( + default=None, + description="""Optional. The local path to the directory containing the Skill to + be zipped and uploaded. + """, + ) + zipped_filesystem: Optional[Any] = Field( + default=None, description="""Optional. 
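+
+ # Example (illustrative sketch): semantic skill retrieval with a capped result
+ # count; `retrieve_skills` is a hypothetical method that returns a
+ # RetrieveSkillsResponse.
+ #
+ #     response = retrieve_skills(
+ #         query="extract tables from PDF invoices",
+ #         config=RetrieveSkillsConfig(top_k=5),
+ #     )
+ #     for skill in response.retrieved_skills or []:
+ #         print(skill.skill_name, skill.description)
+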
The zipped filesystem of the Skill.""" + ) + skill_id: Optional[str] = Field( + default=None, + description="""Optional. The ID to use for the Skill, which will become the final + component of the Skill's resource name. + """, + ) + + +class CreateSkillConfigDict(TypedDict, total=False): + """Config for creating a skill.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + wait_for_completion: Optional[bool] + """Whether to wait for the long running operation to complete.""" + + local_path: Optional[str] + """Optional. The local path to the directory containing the Skill to + be zipped and uploaded. + """ + + zipped_filesystem: Optional[Any] + """Optional. The zipped filesystem of the Skill.""" + + skill_id: Optional[str] + """Optional. The ID to use for the Skill, which will become the final + component of the Skill's resource name. + """ + + +CreateSkillConfigOrDict = Union[CreateSkillConfig, CreateSkillConfigDict] + + +class _CreateSkillRequestParameters(_common.BaseModel): + """Parameters for creating a skill.""" + + display_name: Optional[str] = Field( + default=None, description="""Required. The display name of the Skill.""" + ) + description: Optional[str] = Field( + default=None, description="""Required. The description of the Skill.""" + ) + config: Optional[CreateSkillConfig] = Field(default=None, description="""""") + + +class _CreateSkillRequestParametersDict(TypedDict, total=False): + """Parameters for creating a skill.""" + + display_name: Optional[str] + """Required. The display name of the Skill.""" + + description: Optional[str] + """Required. The description of the Skill.""" + + config: Optional[CreateSkillConfigDict] + """""" + + +_CreateSkillRequestParametersOrDict = Union[ + _CreateSkillRequestParameters, _CreateSkillRequestParametersDict +] + + +class SkillOperation(_common.BaseModel): + """Operation that has a skill as a response.""" + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[Skill] = Field( + default=None, description="""The created Skill.""" + ) + + +class SkillOperationDict(TypedDict, total=False): + """Operation that has a skill as a response.""" + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. 
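+
+ # Example (illustrative sketch): creating a skill from a local directory whose
+ # root contains a SKILL.md; `create_skill` is a hypothetical method that
+ # returns a SkillOperation, and presumably only one of `local_path` or
+ # `zipped_filesystem` is supplied.
+ #
+ #     operation = create_skill(
+ #         display_name="invoice-parser",
+ #         description="Parses line items out of PDF invoices.",
+ #         config=CreateSkillConfig(
+ #             local_path="./skills/invoice_parser",
+ #             skill_id="invoice-parser",
+ #             wait_for_completion=True,
+ #         ),
+ #     )
+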
It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any."""
+
+    done: Optional[bool]
+    """If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available."""
+
+    error: Optional[dict[str, Any]]
+    """The error result of the operation in case of failure or cancellation."""
+
+    response: Optional[SkillDict]
+    """The created Skill."""
+
+
+SkillOperationOrDict = Union[SkillOperation, SkillOperationDict]
+
+
+class GetSkillOperationConfig(_common.BaseModel):
+    """Config for getting a skill operation."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetSkillOperationConfigDict(TypedDict, total=False):
+    """Config for getting a skill operation."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetSkillOperationConfigOrDict = Union[
+    GetSkillOperationConfig, GetSkillOperationConfigDict
+]
+
+
+class _GetSkillOperationParameters(_common.BaseModel):
+    """Parameters for getting an operation."""
+
+    operation_name: Optional[str] = Field(
+        default=None, description="""The server-assigned name for the operation."""
+    )
+    config: Optional[GetSkillOperationConfig] = Field(
+        default=None, description="""Used to override the default configuration."""
+    )
+
+
+class _GetSkillOperationParametersDict(TypedDict, total=False):
+    """Parameters for getting an operation."""
+
+    operation_name: Optional[str]
+    """The server-assigned name for the operation."""
+
+    config: Optional[GetSkillOperationConfigDict]
+    """Used to override the default configuration."""
+
+
+_GetSkillOperationParametersOrDict = Union[
+    _GetSkillOperationParameters, _GetSkillOperationParametersDict
+]
+
+
+class PromptOptimizerConfig(_common.BaseModel):
+    """VAPO Prompt Optimizer Config."""
+
+    config_path: Optional[str] = Field(
+        default=None,
+        description="""The GCS path to the config file, e.g. gs://bucket/config.json.""",
+    )
+    service_account: Optional[str] = Field(
+        default=None,
+        description="""The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number.""",
+    )
+    service_account_project_number: Optional[Union[int, str]] = Field(
+        default=None,
+        description="""The project number used to construct the default service account: {service_account_project_number}-compute@developer.gserviceaccount.com. Cannot be provided at the same time as "service_account".""",
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Whether to wait for the job to complete. Ignored for async jobs.""",
+    )
+    optimizer_job_display_name: Optional[str] = Field(
+        default=None,
+        description="""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used.""",
+    )
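+
+
+# Illustrative usage sketch only (not part of the generated API surface; the
+# bucket path is a placeholder): a config that points the optimizer at a GCS
+# config file and blocks until the optimization job finishes.
+#
+#     config = PromptOptimizerConfig(
+#         config_path="gs://my-bucket/vapo_config.json",
+#         wait_for_completion=True,
+#     )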
+
+
+class PromptOptimizerConfigDict(TypedDict, total=False):
+    """VAPO Prompt Optimizer Config."""
+
+    config_path: Optional[str]
+    """The GCS path to the config file, e.g. gs://bucket/config.json."""
+
+    service_account: Optional[str]
+    """The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number."""
+
+    service_account_project_number: Optional[Union[int, str]]
+    """The project number used to construct the default service account: {service_account_project_number}-compute@developer.gserviceaccount.com. Cannot be provided at the same time as "service_account"."""
+
+    wait_for_completion: Optional[bool]
+    """Whether to wait for the job to complete. Ignored for async jobs."""
+
+    optimizer_job_display_name: Optional[str]
+    """The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used."""
+
+
+PromptOptimizerConfigOrDict = Union[PromptOptimizerConfig, PromptOptimizerConfigDict]
+
+
+class OptimizeResponse(_common.BaseModel):
+    """Response for the optimize_prompt method."""
+
+    raw_text_response: Optional[str] = Field(default=None, description="""""")
+    parsed_response: Optional["ParsedResponseUnion"] = Field(
+        default=None, description=""""""
+    )
+
+
+class OptimizeResponseDict(TypedDict, total=False):
+    """Response for the optimize_prompt method."""
+
+    raw_text_response: Optional[str]
+    """"""
+
+    parsed_response: Optional["ParsedResponseUnionDict"]
+    """"""
+
+
+OptimizeResponseOrDict = Union[OptimizeResponse, OptimizeResponseDict]
+
+
+class ContentMapContents(_common.BaseModel):
+    """Map of placeholder in metric prompt template to contents of model input."""
+
+    contents: Optional[list[genai_types.Content]] = Field(
+        default=None, description="""Contents of the model input."""
+    )
+
+
+class ContentMapContentsDict(TypedDict, total=False):
+    """Map of placeholder in metric prompt template to contents of model input."""
+
+    contents: Optional[list[genai_types.ContentDict]]
+    """Contents of the model input."""
+
+
+ContentMapContentsOrDict = Union[ContentMapContents, ContentMapContentsDict]
+
+
+class EvaluateMethodConfig(_common.BaseModel):
+    """Optional parameters for the evaluate method."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    dataset_schema: Optional[Literal["GEMINI", "FLATTEN", "OPENAI"]] = Field(
+        default=None,
+        description="""The schema to use for the dataset.
+        If not specified, the dataset schema will be inferred from the first
+        example in the dataset.""",
+    )
+    dest: Optional[str] = Field(
+        default=None, description="""The destination path for the evaluation results."""
+    )
+    evaluation_service_qps: Optional[float] = Field(
+        default=None,
+        description="""The rate limit (queries per second) for calls to the
+        evaluation service. Defaults to 10.
Increase this value if your + project has a higher EvaluateInstances API quota.""" + + +EvaluateMethodConfigOrDict = Union[EvaluateMethodConfig, EvaluateMethodConfigDict] + + +class EvaluateDatasetConfig(_common.BaseModel): + """Config for evaluate instances.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class EvaluateDatasetConfigDict(TypedDict, total=False): + """Config for evaluate instances.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +EvaluateDatasetConfigOrDict = Union[EvaluateDatasetConfig, EvaluateDatasetConfigDict] + + +class EvaluateDatasetOperation(_common.BaseModel): + + name: Optional[str] = Field( + default=None, + description="""The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""", + ) + metadata: Optional[dict[str, Any]] = Field( + default=None, + description="""Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""", + ) + done: Optional[bool] = Field( + default=None, + description="""If the value is `false`, it means the operation is still in progress. If `true`, the operation is completed, and either `error` or `response` is available.""", + ) + error: Optional[dict[str, Any]] = Field( + default=None, + description="""The error result of the operation in case of failure or cancellation.""", + ) + response: Optional[EvaluationDataset] = Field(default=None, description="""""") + + +class EvaluateDatasetOperationDict(TypedDict, total=False): + + name: Optional[str] + """The server-assigned name, which is only unique within the same service that originally returns it. If you use the default HTTP mapping, the `name` should be a resource name ending with `operations/{unique_id}`.""" + + metadata: Optional[dict[str, Any]] + """Service-specific metadata associated with the operation. It typically contains progress information and common metadata such as create time. Some services might not provide such metadata. Any method that returns a long-running operation should document the metadata type, if any.""" + + done: Optional[bool] + """If the value is `false`, it means the operation is still in progress. 
If `true`, the operation is completed, and either `error` or `response` is available.""" + + error: Optional[dict[str, Any]] + """The error result of the operation in case of failure or cancellation.""" + + response: Optional[EvaluationDatasetDict] + """""" + + +EvaluateDatasetOperationOrDict = Union[ + EvaluateDatasetOperation, EvaluateDatasetOperationDict +] + + +class EvaluateDatasetRequestParameters(_common.BaseModel): + """Parameters for batch dataset evaluation.""" + + dataset: Optional[EvaluationDataset] = Field(default=None, description="""""") + metrics: Optional[list[Metric]] = Field(default=None, description="""""") + output_config: Optional[genai_types.OutputConfig] = Field( + default=None, description="""""" + ) + autorater_config: Optional[genai_types.AutoraterConfig] = Field( + default=None, description="""""" + ) + config: Optional[EvaluateDatasetConfig] = Field(default=None, description="""""") + + +class EvaluateDatasetRequestParametersDict(TypedDict, total=False): + """Parameters for batch dataset evaluation.""" + + dataset: Optional[EvaluationDatasetDict] + """""" + + metrics: Optional[list[MetricDict]] + """""" + + output_config: Optional[genai_types.OutputConfigDict] + """""" + + autorater_config: Optional[genai_types.AutoraterConfigDict] + """""" + + config: Optional[EvaluateDatasetConfigDict] + """""" + + +EvaluateDatasetRequestParametersOrDict = Union[ + EvaluateDatasetRequestParameters, EvaluateDatasetRequestParametersDict +] + + +class ObservabilityEvalCase(_common.BaseModel): + """A single evaluation case instance for data stored in GCP Observability.""" + + input_src: Optional[str] = Field( + default=None, + description="""String containing the GCS reference to the GenAI input content.""", + ) + output_src: Optional[str] = Field( + default=None, + description="""String containing the GCS reference to the GenAI response content.""", + ) + system_instruction_src: Optional[str] = Field( + default=None, + description="""An optional string containing the GCS reference to the GenAI system instruction.""", + ) + api_client: Optional[Any] = Field( + default=None, description="""The underlying API client.""" + ) + + +class ObservabilityEvalCaseDict(TypedDict, total=False): + """A single evaluation case instance for data stored in GCP Observability.""" + + input_src: Optional[str] + """String containing the GCS reference to the GenAI input content.""" + + output_src: Optional[str] + """String containing the GCS reference to the GenAI response content.""" + + system_instruction_src: Optional[str] + """An optional string containing the GCS reference to the GenAI system instruction.""" + + api_client: Optional[Any] + """The underlying API client.""" + + +ObservabilityEvalCaseOrDict = Union[ObservabilityEvalCase, ObservabilityEvalCaseDict] + + +class RubricGroup(_common.BaseModel): + """A group of rubrics. + + Used for grouping rubrics based on a metric or a version. + """ + + group_id: Optional[str] = Field( + default=None, description="""Unique identifier for the group.""" + ) + display_name: Optional[str] = Field( + default=None, + description="""Human-readable name for the group. This should be unique + within a given context if used for display or selection. + Example: "Instruction Following V1", "Content Quality - Summarization + Task".""", + ) + rubrics: Optional[list[evals_types.Rubric]] = Field( + default=None, description="""Rubrics that are part of this group.""" + ) + + +class RubricGroupDict(TypedDict, total=False): + """A group of rubrics. 
+ + Used for grouping rubrics based on a metric or a version. + """ + + group_id: Optional[str] + """Unique identifier for the group.""" + + display_name: Optional[str] + """Human-readable name for the group. This should be unique + within a given context if used for display or selection. + Example: "Instruction Following V1", "Content Quality - Summarization + Task".""" + + rubrics: Optional[list[evals_types.Rubric]] + """Rubrics that are part of this group.""" + + +RubricGroupOrDict = Union[RubricGroup, RubricGroupDict] + + +class PromptTemplate(_common.BaseModel): + """A prompt template for creating prompts with variables.""" + + text: Optional[str] = Field( + default=None, description="""The prompt template text.""" + ) + _VARIABLE_NAME_REGEX: ClassVar[str] = r"\{([_a-zA-Z][_a-zA-Z0-9]*)\}" + + @field_validator("text") + @classmethod + def text_must_not_be_empty(cls, value: str) -> str: + if not value.strip(): + raise ValueError( + "Prompt template text cannot be empty or consist only of" " whitespace." + ) + return value + + @computed_field # type: ignore[prop-decorator] + @property + def variables(self) -> set[str]: + return set(re.findall(self._VARIABLE_NAME_REGEX, self.text)) + + def _split_template_by_variables(self) -> list[Tuple[str, str]]: + parts = [] + last_end = 0 + for match in re.finditer(self._VARIABLE_NAME_REGEX, self.text): + start, end = match.span() + var_name = match.group(1) + if start > last_end and self.text: + parts.append(("text", self.text[last_end:start])) + parts.append(("var", var_name)) + last_end = end + if last_end < len(self.text) and self.text: + parts.append(("text", self.text[last_end:])) + return parts + + def _merge_adjacent_text_parts( + self, parts: list[genai_types.Part] + ) -> list[genai_types.Part]: + if not parts: + return [] + + merged = [] + current_text_buffer = [] + + for part in parts: + is_purely_text = part.text is not None and all( + getattr(part, field) is None + for field in part.model_fields + if field != "text" + ) + + if is_purely_text: + current_text_buffer.append(part.text) + else: + if current_text_buffer: + merged.append(genai_types.Part(text="".join(current_text_buffer))) + current_text_buffer = [] + merged.append(part) + + if current_text_buffer: + merged.append(genai_types.Part(text="".join(current_text_buffer))) + + return merged + + def _is_multimodal_json_string( + self, + value: Any, + ) -> bool: + """Checks if the input value is a multimodal JSON string.""" + if not isinstance(value, str): + return False + try: + data = json.loads(value) + # Check for the specific structure: {"contents": [{"parts": [...]}]} + # or {"parts": [...]} if assemble returns a single Content JSON + if isinstance(data, dict): + if "contents" in data and isinstance(data["contents"], list): + if not data["contents"]: + return False + first_content = data["contents"][0] + if isinstance(first_content, dict) and "parts" in first_content: + try: + genai_types.Content.model_validate(first_content) + return True + except ValueError: + return False + # Adding a check if 'data' itself is a Content-like object with parts + elif "parts" in data and isinstance(data["parts"], list): + try: + genai_types.Content.model_validate(data) + return True + except ValueError: + return False + return False + except json.JSONDecodeError: + return False + + def _parse_multimodal_json_string_into_parts( + self, + value: str, + ) -> list[genai_types.Part]: + """Parses a multimodal JSON string and returns its list of Parts.""" + try: + content = 
genai_types.Content.model_validate_json(value) + return content.parts if content.parts is not None else [genai_types.Part()] + except Exception: + return [genai_types.Part(text=value)] + + def assemble(self, **kwargs: Any) -> str: + """Assembles the prompt template with the given keyword arguments. + + Supports both text and multimodal content. The `assemble` method + substitutes variables from the prompt template text with provided + values. + + Key Behaviors of `assemble()`: + 1. Variable Substitution: Replaces all defined variables with their + corresponding keyword argument values. Raises ValueError if a + template + variable is missing a value or if an extraneous kwarg is provided. + 2. Multimodal Handling: + - Detects if any variable's value is a JSON string representing + multimodal content (specifically, `{"contents": [{"parts": [...]}]}` + or `{"role": "user", "parts": [...]}`). + - If multimodal content is detected for a variable, its `Part` + objects + are extracted and inserted into the assembled sequence. + - Text segments from the template and simple text variable values + become `Part(text=...)`. + 3. Output Format: + - If ALL substituted variables were simple text AND the assembled + result (after merging adjacent text parts) consists of a single, + purely textual `Part`, `assemble()` returns a raw Python string. + - Otherwise (if any variable was multimodal, or if the assembly + results in multiple parts or non-textual parts), `assemble()` + returns + a JSON string representing a single `google.genai.types.Content` + object with `role="user"` and the assembled parts. + 4. Text Part Merging: Consecutively assembled text parts are + automatically merged into a single text `Part` to create a more + concise list of parts. + + This dual output format (raw string or JSON string of `Content`) allows + the downstream inference functions to seamlessly handle both simple text + prompts and more complex multimodal prompts generated from the same + templating mechanism. + """ + current_variables = self.variables + for var_name_in_kwarg in kwargs: + if var_name_in_kwarg not in current_variables: + raise ValueError( + f"Invalid variable name '{var_name_in_kwarg}' provided to" + " assemble. Valid variables in template are:" + f" {current_variables}" + ) + # Check if all template variables are provided in kwargs + for tpl_var in current_variables: + if tpl_var not in kwargs: + raise ValueError(f"Missing value for template variable '{tpl_var}'.") + + template_segments = self._split_template_by_variables() + + raw_assembled_parts: list[genai_types.Part] = [] + contains_multimodal_variable_type = False + + for segment_type, segment_value in template_segments: + if segment_type == "text": + if segment_value: + raw_assembled_parts.append(genai_types.Part(text=segment_value)) + elif segment_type == "var": + var_value = kwargs.get(segment_value) + + str_var_value = str(var_value) + + if self._is_multimodal_json_string(str_var_value): + multimodal_parts = self._parse_multimodal_json_string_into_parts( + str_var_value + ) + if multimodal_parts: + contains_multimodal_variable_type = True + raw_assembled_parts.extend(multimodal_parts) + else: + raw_assembled_parts.append(genai_types.Part(text=str_var_value)) + else: + raw_assembled_parts.append(genai_types.Part(text=str_var_value)) + + final_assembled_parts = self._merge_adjacent_text_parts(raw_assembled_parts) + + # Condition for returning raw text string: + # 1. No multimodal variable was *originally* a multimodal JSON string. + # 2. 
After merging, there's exactly one part. + # 3. That single part is purely textual. + if ( + not contains_multimodal_variable_type + and len(final_assembled_parts) == 1 + and final_assembled_parts[0].text is not None + and all( + getattr(final_assembled_parts[0], field) is None + for field in final_assembled_parts[0].model_fields + if field not in ["text", "role"] + ) + ): + return final_assembled_parts[0].text + + # Otherwise, construct a Content object (as JSON string). + final_content_obj = genai_types.Content(parts=final_assembled_parts) + return final_content_obj.model_dump_json(exclude_none=True) + + def __str__(self) -> str: + return self.text if self.text else "" + + def __repr__(self) -> str: + return f"PromptTemplate(text='{self.text}')" + + +class MetricPromptBuilder(PromptTemplate): + """Builder class for structured LLM-based metric prompt template.""" + + criteria: Optional[dict[str, str]] = Field( + None, + description="""A dictionary of criteria used to evaluate the model responses. + The keys are criterion names, and the values are the corresponding + criterion definitions. + """, + ) + + rating_scores: Optional[dict[str, str]] = Field( + None, + description="""A dictionary mapping of rating score names to their definitions.""", + ) + + @staticmethod + def _get_default_instruction() -> str: + """Returns the default instruction for evaluation.""" + return ( + "You are an expert evaluator. Your task is to evaluate the quality" + " of the responses generated by AI models. We will provide you with" + " the user prompt and an AI-generated responses.\nYou should first" + " read the user input carefully for analyzing the task, and then" + " evaluate the quality of the responses based on the Criteria" + " provided in the Evaluation section below.\nYou will assign the" + " response a rating following the Rating Scores and Evaluation" + " Steps. Give step by step explanations for your rating, and only" + " choose ratings from the Rating Scores." + ) + + instruction: Optional[str] = Field( + default_factory=lambda: MetricPromptBuilder._get_default_instruction(), + description="""The general instruction to guide the model in performing the evaluation. + If not provided, a default instruction for evaluation will be used. + """, + ) + + metric_definition: Optional[str] = Field( + None, + description="""An optional high-level description of the metric to be evaluated. + If not provided, this field will not be included in the prompt template. + """, + ) + + @staticmethod + def _get_default_evaluation_steps() -> dict[str, str]: + """Returns the default evaluation steps for metric evaluation.""" + return { + "Step 1": ( + "Assess the response in aspects of all criteria provided. Provide" + " assessment according to each criterion." + ), + "Step 2": ( + "Score based on the Rating Scores. Give a brief rationale to" + " explain your evaluation considering each individual criterion." + ), + } + + evaluation_steps: Optional[dict[str, str]] = Field( + default_factory=lambda: MetricPromptBuilder._get_default_evaluation_steps(), + description="""An optional dictionary of evaluation steps. + The keys are the names of the evaluation steps, and the values are + descriptions of the corresponding evaluation steps. If not provided, + default metric evaluation steps will be used. + """, + ) + + few_shot_examples: Optional[list[str]] = Field( + None, + description="""An optional list of few-shot examples to guide the model's evaluation. 
+    These examples demonstrate how to apply the criteria, rating scores,
+    and evaluation steps to assess model responses. Providing few-shot examples
+    can improve the accuracy of the evaluation. If not provided, this field
+    will not be included in the prompt template.
+    """,
+    )
+
+    @staticmethod
+    def _serialize_dict_in_order(elements: Optional[dict[str, str]]) -> str:
+        """Serializes dictionary to ordered string value without brackets."""
+        if elements is None:
+            return ""
+        return "\n".join(f"{key}: {value}" for key, value in sorted(elements.items()))
+
+    @model_validator(mode="before")
+    @classmethod
+    def _prepare_fields_and_construct_text(cls, data: Any) -> Any:
+        """Pydantic model validator (before mode) to prepare and construct prompt text.
+
+        This validator performs the following actions:
+        1. Apply default logic for fields (instruction, evaluation_steps).
+        2. Construct the 'text' string from all components.
+        3. Ensure 'text' is in the data dictionary for PromptTemplate
+           initialization.
+
+        Args:
+            data: Input data for the model, either a dictionary or an existing model
+              instance.
+
+        Returns:
+            Processed data dictionary with the 'text' field constructed.
+        """
+        if not isinstance(data, dict):
+            return data
+
+        if "text" in data:
+            raise ValueError(
+                "The 'text' field is automatically constructed and should not be"
+                " provided manually."
+            )
+
+        if data.get("criteria") is None or data.get("rating_scores") is None:
+            raise ValueError(
+                "Both 'criteria' and 'rating_scores' are required to construct the"
+                " LLM-based metric prompt template text."
+            )
+
+        instruction = data.get("instruction", cls._get_default_instruction())
+        metric_definition = data.get("metric_definition")
+        evaluation_steps = data.get(
+            "evaluation_steps", cls._get_default_evaluation_steps()
+        )
+        criteria = data.get("criteria")
+        rating_scores = data.get("rating_scores")
+        few_shot_examples = data.get("few_shot_examples")
+
+        template_parts = [
+            "# Instruction",
+            instruction,
+            "\n",
+            "# Evaluation",
+        ]
+
+        sections = {
+            "Metric Definition": metric_definition,
+            "Criteria": cls._serialize_dict_in_order(criteria),
+            "Rating Scores": cls._serialize_dict_in_order(rating_scores),
+            "Evaluation Steps": cls._serialize_dict_in_order(evaluation_steps),
+            "Evaluation Examples": (
+                "\n".join(few_shot_examples) if few_shot_examples else None
+            ),
+        }
+
+        for title, content in sections.items():
+            if content:
+                template_parts.extend([f"## {title}", f"{content}\n"])
+
+        template_parts.extend(
+            [
+                "\n",
+                "# User Inputs and AI-generated Response",
+                "## User Prompt",
+                "{prompt}",
+                "\n",
+                "## AI-generated Response",
+                "{response}",
+            ]
+        )
+
+        constructed_text = "\n".join(template_parts)
+
+        data["text"] = constructed_text
+        return data
+
+    def __str__(self) -> str:
+        """Returns the fully constructed prompt template text."""
+        return self.text if self.text else ""
+
+
+class PromptTemplateDict(TypedDict, total=False):
+    """A prompt template for creating prompts with variables."""
+
+    text: Optional[str]
+    """The prompt template text."""
+
+
+PromptTemplateOrDict = Union[PromptTemplate, PromptTemplateDict]
+
+
+class EvalRunInferenceConfig(_common.BaseModel):
+    """Optional parameters for inference."""
+
+    dest: Optional[str] = Field(
+        default=None, description="""The destination path for the inference results."""
+    )
+    prompt_template: Optional[Union[str, PromptTemplate]] = Field(
+        default=None, description="""The prompt template to use for inference."""
+    )
+    generate_content_config:
Optional[genai_types.GenerateContentConfig] = Field( + default=None, description="""The config for the generate content call.""" + ) + user_simulator_config: Optional[evals_types.UserSimulatorConfig] = Field( + default=None, + description="""Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains + conversation plans, user simulation will be triggered.""", + ) + allow_cross_region_model: Optional[bool] = Field( + default=None, + description="""Opt-in flag to authorize cross-region routing for LLM models.""", + ) + + +class EvalRunInferenceConfigDict(TypedDict, total=False): + """Optional parameters for inference.""" + + dest: Optional[str] + """The destination path for the inference results.""" + + prompt_template: Optional[Union[str, PromptTemplateDict]] + """The prompt template to use for inference.""" + + generate_content_config: Optional[genai_types.GenerateContentConfigDict] + """The config for the generate content call.""" + + user_simulator_config: Optional[evals_types.UserSimulatorConfig] + """Configuration for user simulation in multi-turn agent scraping. If provided, and the dataset contains + conversation plans, user simulation will be triggered.""" + + allow_cross_region_model: Optional[bool] + """Opt-in flag to authorize cross-region routing for LLM models.""" + + +EvalRunInferenceConfigOrDict = Union[EvalRunInferenceConfig, EvalRunInferenceConfigDict] + + +class AgentEngine(_common.BaseModel): + """An agent engine instance.""" + + api_client: Optional[Any] = Field( + default=None, description="""The underlying API client.""" + ) + api_async_client: Optional[Any] = Field( + default=None, + description="""The underlying API client for asynchronous operations.""", + ) + api_resource: Optional[ReasoningEngine] = Field( + default=None, + description="""The underlying API resource (i.e. ReasoningEngine).""", + ) + + # Allows dynamic binding of methods based on the registered operations. + model_config = ConfigDict(extra="allow") + + def __repr__(self) -> str: + return ( + f"AgentEngine(api_resource.name='{self.api_resource.name}')" + if self.api_resource is not None + else "AgentEngine(api_resource.name=None)" + ) + + def operation_schemas(self) -> Optional[list[Dict[str, Any]]]: + """Returns the schemas of all registered operations for the agent.""" + if not isinstance(self.api_resource, ReasoningEngine): + raise ValueError("api_resource is not initialized.") + if not self.api_resource.spec: + raise ValueError("api_resource.spec is not initialized.") + return self.api_resource.spec.class_methods + + def delete( + self, + force: bool = False, + config: Optional[DeleteAgentEngineConfigOrDict] = None, + ) -> None: + """Deletes the agent engine. + + Args: + force (bool): + Optional. If set to True, child resources will also be deleted. + Otherwise, the request will fail with FAILED_PRECONDITION error when + the Agent Engine has undeleted child resources. Defaults to False. + config (DeleteAgentEngineConfig): + Optional. Additional configurations for deleting the Agent Engine. 
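+
+        Example usage (an illustrative sketch; assumes an `engine` instance
+        obtained from the client, e.g. `engine = client.agent_engines.get(...)`):
+
+            engine.delete(force=True)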
+ """ + if not isinstance(self.api_resource, ReasoningEngine): + raise ValueError("api_resource is not initialized.") + self.api_client.delete(name=self.api_resource.name, force=force, config=config) # type: ignore[union-attr] + + +RubricContentProperty = evals_types.RubricContentProperty +RubricContentPropertyDict = evals_types.RubricContentPropertyDict +RubricContentPropertyDictOrDict = evals_types.RubricContentPropertyOrDict + +RubricContent = evals_types.RubricContent +RubricContentDict = evals_types.RubricContentDict +RubricContentDictOrDict = evals_types.RubricContentOrDict + +Rubric = evals_types.Rubric +RubricDict = evals_types.RubricDict +RubricDictOrDict = evals_types.RubricOrDict + +RubricVerdict = evals_types.RubricVerdict +RubricVerdictDict = evals_types.RubricVerdictDict +RubricVerdictDictOrDict = evals_types.RubricVerdictOrDict + +CandidateResult = evals_types.CandidateResult +CandidateResultDict = evals_types.CandidateResultDict +CandidateResultDictOrDict = evals_types.CandidateResultOrDict + +Event = evals_types.Event +EventDict = evals_types.EventDict +EventDictOrDict = evals_types.EventOrDict + +Message = evals_types.Message +MessageDict = evals_types.MessageDict +MessageDictOrDict = evals_types.MessageOrDict + +Importance = evals_types.Importance + + +class AgentEngineDict(TypedDict, total=False): + """An agent engine instance.""" + + api_client: Optional[Any] + """The underlying API client.""" + + api_async_client: Optional[Any] + """The underlying API client for asynchronous operations.""" + + api_resource: Optional[ReasoningEngineDict] + """The underlying API resource (i.e. ReasoningEngine).""" + + +AgentEngineOrDict = Union[AgentEngine, AgentEngineDict] + + +class AgentEngineConfig(_common.BaseModel): + """Config for agent engine methods.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + staging_bucket: Optional[str] = Field( + default=None, + description="""The GCS bucket to use for staging the artifacts needed. + + It must be a valid GCS bucket name, e.g. "gs://bucket-name". It is + required if `agent_engine` is specified.""", + ) + requirements: Optional[Any] = Field( + default=None, + description="""The set of PyPI dependencies needed. + + It can either be the path to a single file (requirements.txt), or an + ordered list of strings corresponding to each line of the requirements + file.""", + ) + display_name: Optional[str] = Field( + default=None, + description="""The user-defined name of the Agent Engine. + + The name can be up to 128 characters long and can comprise any UTF-8 + character.""", + ) + description: Optional[str] = Field( + default=None, description="""The description of the Agent Engine.""" + ) + gcs_dir_name: Optional[str] = Field( + default=None, + description="""The GCS bucket directory under `staging_bucket` to use for staging + the artifacts needed.""", + ) + extra_packages: Optional[list[str]] = Field( + default=None, + description="""The set of extra user-provided packages (if any).""", + ) + env_vars: Optional[Any] = Field( + default=None, + description="""The environment variables to be set when running the Agent Engine. + + If it is a dictionary, the keys are the environment variable names, and + the values are the corresponding values.""", + ) + service_account: Optional[str] = Field( + default=None, + description="""The service account to be used for the Agent Engine. 
+
+        If not specified, the default Reasoning Engine P6SA service agent will be used.""",
+    )
+    identity_type: Optional[IdentityType] = Field(
+        default=None, description="""The identity type to use for the Agent Engine."""
+    )
+    context_spec: Optional[ReasoningEngineContextSpec] = Field(
+        default=None,
+        description="""The context spec to be used for the Agent Engine.""",
+    )
+    psc_interface_config: Optional[PscInterfaceConfig] = Field(
+        default=None,
+        description="""The PSC interface config for PSC-I to be used for the Agent Engine.""",
+    )
+    min_instances: Optional[int] = Field(
+        default=None,
+        description="""The minimum number of instances to run for the Agent Engine.
+        Defaults to 1. Range: [0, 10].
+        """,
+    )
+    max_instances: Optional[int] = Field(
+        default=None,
+        description="""The maximum number of instances to run for the Agent Engine.
+        Defaults to 100. Range: [1, 1000].
+        If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100].
+        """,
+    )
+    resource_limits: Optional[dict[str, str]] = Field(
+        default=None,
+        description="""The resource limits to be applied to the Agent Engine.
+        Required keys: 'cpu' and 'memory'.
+        Supported values for 'cpu': '1', '2', '4', '6', '8'.
+        Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'.
+        """,
+    )
+    container_concurrency: Optional[int] = Field(
+        default=None,
+        description="""The container concurrency to be used for the Agent Engine.
+        Recommended value: 2 * cpu + 1. Defaults to 9.
+        """,
+    )
+    encryption_spec: Optional[genai_types.EncryptionSpec] = Field(
+        default=None,
+        description="""The encryption spec to be used for the Agent Engine.""",
+    )
+    labels: Optional[dict[str, str]] = Field(
+        default=None, description="""The labels to be used for the Agent Engine."""
+    )
+    agent_server_mode: Optional[AgentServerMode] = Field(
+        default=None, description="""The agent server mode to use for deployment."""
+    )
+    class_methods: Optional[list[dict[str, Any]]] = Field(
+        default=None,
+        description="""The class methods to be used for the Agent Engine.
+        If specified, they'll override the class methods that are autogenerated by
+        default. By default, methods are generated by inspecting the agent object
+        and generating a corresponding method for each method defined on the
+        agent class.
+        """,
+    )
+    source_packages: Optional[list[str]] = Field(
+        default=None,
+        description="""The user-provided paths to the source packages (if any).
+        If specified, the files in the source packages will be packed into a
+        tarball file, uploaded to Agent Engine's API, and deployed to the
+        Agent Engine.
+        The following fields will be ignored:
+        - agent
+        - extra_packages
+        - staging_bucket
+        - requirements
+        The following fields will be used to install and use the agent from the
+        source packages:
+        - entrypoint_module (required)
+        - entrypoint_object (required)
+        - requirements_file (optional)
+        - class_methods (required)
+        """,
+    )
+    developer_connect_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfig
+    ] = Field(
+        default=None,
+        description="""Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use.""",
+    )
+    entrypoint_module: Optional[str] = Field(
+        default=None,
+        description="""The entrypoint module to be used for the Agent Engine.
+        This field is only used when source_packages is specified.""",
+    )
+    entrypoint_object: Optional[str] = Field(
+        default=None,
+        description="""The entrypoint object to be used for the Agent Engine.
+        This field is only used when source_packages is specified.""",
+    )
+    requirements_file: Optional[str] = Field(
+        default=None,
+        description="""The user-provided path to the requirements file (if any).
+        This field is only used when source_packages is specified.
+        If not specified, Agent Engine will find and use the `requirements.txt` in
+        the source package.
+        """,
+    )
+    agent_framework: Optional[
+        Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
+    ] = Field(
+        default=None,
+        description="""The agent framework to be used for the Agent Engine.
+        The OSS agent framework used to develop the agent.
+        Currently supported values: "google-adk", "langchain", "langgraph",
+        "ag2", "llama-index", "custom".
+        If not specified:
+        - If `agent` is specified, the agent framework will be auto-detected.
+        - If `source_packages` is specified, the agent framework will
+          default to "custom".""",
+    )
+    python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]] = Field(
+        default=None,
+        description="""The Python version to be used for the Agent Engine.
+        If not specified, it will use the current Python version of the environment.
+        Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14".
+        """,
+    )
+    build_options: Optional[dict[str, list[str]]] = Field(
+        default=None,
+        description="""The build options for the Agent Engine.
+        The following keys are supported:
+        - installation_scripts:
+            Optional. The paths to the installation scripts to be
+            executed in the Docker image.
+            The scripts must be located in the `installation_scripts`
+            subdirectory and the path must be added to `extra_packages`.
+        """,
+    )
+    image_spec: Optional[ReasoningEngineSpecSourceCodeSpecImageSpec] = Field(
+        default=None, description="""The image spec for the Agent Engine."""
+    )
+    agent_config_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecAgentConfigSource
+    ] = Field(
+        default=None, description="""The agent config source for the Agent Engine."""
+    )
+    container_spec: Optional[ReasoningEngineSpecContainerSpec] = Field(
+        default=None, description="""The container spec for the Agent Engine."""
+    )
+    agent_gateway_config: Optional[
+        ReasoningEngineSpecDeploymentSpecAgentGatewayConfig
+    ] = Field(
+        default=None,
+        description="""Agent Gateway configuration for a Reasoning Engine deployment.""",
+    )
+    keep_alive_probe: Optional[KeepAliveProbe] = Field(
+        default=None,
+        description="""Optional. Specifies the configuration for keep-alive probe.
+        Contains configuration on a specified endpoint that a deployment host
+        should use to keep the container alive based on the probe settings.""",
+    )
+    traffic_config: Optional[ReasoningEngineTrafficConfig] = Field(
+        default=None, description="""The traffic config for the Agent Engine."""
+    )
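+
+
+# Illustrative sketch only (field values are placeholders, not defaults): a
+# minimal AgentEngineConfig for deploying an agent from the local environment.
+#
+#     config = AgentEngineConfig(
+#         staging_bucket="gs://my-staging-bucket",
+#         display_name="my-agent",
+#         requirements=["google-cloud-aiplatform[agent_engines]"],
+#     )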
+
+
+class AgentEngineConfigDict(TypedDict, total=False):
+    """Config for agent engine methods."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    staging_bucket: Optional[str]
+    """The GCS bucket to use for staging the artifacts needed.
+
+    It must be a valid GCS bucket name, e.g. "gs://bucket-name". It is
+    required if `agent_engine` is specified."""
+
+    requirements: Optional[Any]
+    """The set of PyPI dependencies needed.
+
+    It can either be the path to a single file (requirements.txt), or an
+    ordered list of strings corresponding to each line of the requirements
+    file."""
+
+    display_name: Optional[str]
+    """The user-defined name of the Agent Engine.
+
+    The name can be up to 128 characters long and can comprise any UTF-8
+    character."""
+
+    description: Optional[str]
+    """The description of the Agent Engine."""
+
+    gcs_dir_name: Optional[str]
+    """The GCS bucket directory under `staging_bucket` to use for staging
+    the artifacts needed."""
+
+    extra_packages: Optional[list[str]]
+    """The set of extra user-provided packages (if any)."""
+
+    env_vars: Optional[Any]
+    """The environment variables to be set when running the Agent Engine.
+
+    If it is a dictionary, the keys are the environment variable names, and
+    the values are the corresponding values."""
+
+    service_account: Optional[str]
+    """The service account to be used for the Agent Engine.
+
+    If not specified, the default Reasoning Engine P6SA service agent will be used."""
+
+    identity_type: Optional[IdentityType]
+    """The identity type to use for the Agent Engine."""
+
+    context_spec: Optional[ReasoningEngineContextSpecDict]
+    """The context spec to be used for the Agent Engine."""
+
+    psc_interface_config: Optional[PscInterfaceConfigDict]
+    """The PSC interface config for PSC-I to be used for the Agent Engine."""
+
+    min_instances: Optional[int]
+    """The minimum number of instances to run for the Agent Engine.
+    Defaults to 1. Range: [0, 10].
+    """
+
+    max_instances: Optional[int]
+    """The maximum number of instances to run for the Agent Engine.
+    Defaults to 100. Range: [1, 1000].
+    If VPC-SC or PSC-I is enabled, the acceptable range is [1, 100].
+    """
+
+    resource_limits: Optional[dict[str, str]]
+    """The resource limits to be applied to the Agent Engine.
+    Required keys: 'cpu' and 'memory'.
+    Supported values for 'cpu': '1', '2', '4', '6', '8'.
+    Supported values for 'memory': '1Gi', '2Gi', ..., '32Gi'.
+    """
+
+    container_concurrency: Optional[int]
+    """The container concurrency to be used for the Agent Engine.
+    Recommended value: 2 * cpu + 1. Defaults to 9.
+    """
+
+    encryption_spec: Optional[genai_types.EncryptionSpecDict]
+    """The encryption spec to be used for the Agent Engine."""
+
+    labels: Optional[dict[str, str]]
+    """The labels to be used for the Agent Engine."""
+
+    agent_server_mode: Optional[AgentServerMode]
+    """The agent server mode to use for deployment."""
+
+    class_methods: Optional[list[dict[str, Any]]]
+    """The class methods to be used for the Agent Engine.
+    If specified, they'll override the class methods that are autogenerated by
+    default. By default, methods are generated by inspecting the agent object
+    and generating a corresponding method for each method defined on the
+    agent class.
+    """
+
+    source_packages: Optional[list[str]]
+    """The user-provided paths to the source packages (if any).
+    If specified, the files in the source packages will be packed into a
+    tarball file, uploaded to Agent Engine's API, and deployed to the
+    Agent Engine.
+    The following fields will be ignored:
+    - agent
+    - extra_packages
+    - staging_bucket
+    - requirements
+    The following fields will be used to install and use the agent from the
+    source packages:
+    - entrypoint_module (required)
+    - entrypoint_object (required)
+    - requirements_file (optional)
+    - class_methods (required)
+    """
+
+    developer_connect_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecDeveloperConnectConfigDict
+    ]
+    """Specifies the configuration for fetching source code from a Git repository that is managed by Developer Connect. This includes the repository, revision, and directory to use."""
+
+    entrypoint_module: Optional[str]
+    """The entrypoint module to be used for the Agent Engine.
+    This field is only used when source_packages is specified."""
+
+    entrypoint_object: Optional[str]
+    """The entrypoint object to be used for the Agent Engine.
+    This field is only used when source_packages is specified."""
+
+    requirements_file: Optional[str]
+    """The user-provided path to the requirements file (if any).
+    This field is only used when source_packages is specified.
+    If not specified, Agent Engine will find and use the `requirements.txt` in
+    the source package.
+    """
+
+    agent_framework: Optional[
+        Literal["google-adk", "langchain", "langgraph", "ag2", "llama-index", "custom"]
+    ]
+    """The agent framework to be used for the Agent Engine.
+    The OSS agent framework used to develop the agent.
+    Currently supported values: "google-adk", "langchain", "langgraph",
+    "ag2", "llama-index", "custom".
+    If not specified:
+    - If `agent` is specified, the agent framework will be auto-detected.
+    - If `source_packages` is specified, the agent framework will
+      default to "custom"."""
+
+    python_version: Optional[Literal["3.10", "3.11", "3.12", "3.13", "3.14"]]
+    """The Python version to be used for the Agent Engine.
+    If not specified, it will use the current Python version of the environment.
+    Supported versions: "3.10", "3.11", "3.12", "3.13", "3.14".
+    """
+
+    build_options: Optional[dict[str, list[str]]]
+    """The build options for the Agent Engine.
+    The following keys are supported:
+    - installation_scripts:
+        Optional. The paths to the installation scripts to be
+        executed in the Docker image.
+        The scripts must be located in the `installation_scripts`
+        subdirectory and the path must be added to `extra_packages`.
+    """
+
+    image_spec: Optional[ReasoningEngineSpecSourceCodeSpecImageSpecDict]
+    """The image spec for the Agent Engine."""
+
+    agent_config_source: Optional[
+        ReasoningEngineSpecSourceCodeSpecAgentConfigSourceDict
+    ]
+    """The agent config source for the Agent Engine."""
+
+    container_spec: Optional[ReasoningEngineSpecContainerSpecDict]
+    """The container spec for the Agent Engine."""
+
+    agent_gateway_config: Optional[
+        ReasoningEngineSpecDeploymentSpecAgentGatewayConfigDict
+    ]
+    """Agent Gateway configuration for a Reasoning Engine deployment."""
+
+    keep_alive_probe: Optional[KeepAliveProbeDict]
+    """Optional. Specifies the configuration for keep-alive probe.
+    Contains configuration on a specified endpoint that a deployment host
+    should use to keep the container alive based on the probe settings."""
+
+    traffic_config: Optional[ReasoningEngineTrafficConfigDict]
+    """The traffic config for the Agent Engine."""
+
+
+AgentEngineConfigOrDict = Union[AgentEngineConfig, AgentEngineConfigDict]
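+
+
+# Illustrative sketch only (the bucket path is a placeholder): pointing
+# `output_gcs_uri` at a directory (trailing slash) lets the service generate
+# the output file name, per the field documentation below.
+#
+#     config = RunQueryJobAgentEngineConfig(
+#         query="Summarize today's support tickets.",
+#         output_gcs_uri="gs://my-bucket/query-jobs/",
+#     )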
+
+
+class RunQueryJobAgentEngineConfig(_common.BaseModel):
+    """Config for running a query job on an agent engine."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    query: Optional[str] = Field(
+        default=None, description="""The query to send to the agent engine."""
+    )
+    output_gcs_uri: Optional[str] = Field(
+        default=None,
+        description="""The GCS URI to use for the output.
+        If it is a file, the system uses this file to store the response.
+        If it represents a directory, the system automatically generates a file
+        for the response.
+        In both cases, the input query will be stored in the same directory under
+        the same file name prefix as the output file.""",
+    )
+
+
+class RunQueryJobAgentEngineConfigDict(TypedDict, total=False):
+    """Config for running a query job on an agent engine."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    query: Optional[str]
+    """The query to send to the agent engine."""
+
+    output_gcs_uri: Optional[str]
+    """The GCS URI to use for the output.
+    If it is a file, the system uses this file to store the response.
+    If it represents a directory, the system automatically generates a file
+    for the response.
+    In both cases, the input query will be stored in the same directory under
+    the same file name prefix as the output file."""
+
+
+RunQueryJobAgentEngineConfigOrDict = Union[
+    RunQueryJobAgentEngineConfig, RunQueryJobAgentEngineConfigDict
+]
+
+
+class RunQueryJobResult(_common.BaseModel):
+    """Result of running a query job."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    job_name: Optional[str] = Field(
+        default=None,
+        description="""Name of the agent engine operation to later check for status.""",
+    )
+    input_gcs_uri: Optional[str] = Field(
+        default=None, description="""The GCS URI of the input file."""
+    )
+    output_gcs_uri: Optional[str] = Field(
+        default=None, description="""The GCS URI of the output file."""
+    )
+
+
+class RunQueryJobResultDict(TypedDict, total=False):
+    """Result of running a query job."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    job_name: Optional[str]
+    """Name of the agent engine operation to later check for status."""
+
+    input_gcs_uri: Optional[str]
+    """The GCS URI of the input file."""
+
+    output_gcs_uri: Optional[str]
+    """The GCS URI of the output file."""
+
+
+RunQueryJobResultOrDict = Union[RunQueryJobResult, RunQueryJobResultDict]
+
+
+class CheckQueryJobResponse(_common.BaseModel):
+    """Response from the query job LRO."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+    output_gcs_uri: Optional[str] = Field(
+        default=None, description="""The GCS URI of the output file."""
+    )
+
+
+class CheckQueryJobResponseDict(TypedDict, total=False):
+    """Response from the query job LRO."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+    output_gcs_uri: Optional[str]
+
"""The GCS URI of the output file.""" + + +CheckQueryJobResponseOrDict = Union[CheckQueryJobResponse, CheckQueryJobResponseDict] + + +class AssembleDataset(_common.BaseModel): + """Represents the assembled dataset.""" + + bigquery_destination: Optional[str] = Field( + default=None, + description="""The BigQuery destination of the assembled dataset.""", + ) + + +class AssembleDatasetDict(TypedDict, total=False): + """Represents the assembled dataset.""" + + bigquery_destination: Optional[str] + """The BigQuery destination of the assembled dataset.""" + + +AssembleDatasetOrDict = Union[AssembleDataset, AssembleDatasetDict] + + +class BatchPredictionResourceUsageAssessmentResult(_common.BaseModel): + """Result of batch prediction resource usage assessment.""" + + token_count: Optional[int] = Field( + default=None, description="""Number of tokens in the dataset.""" + ) + audio_token_count: Optional[int] = Field( + default=None, description="""Number of audio tokens in the dataset.""" + ) + + +class BatchPredictionResourceUsageAssessmentResultDict(TypedDict, total=False): + """Result of batch prediction resource usage assessment.""" + + token_count: Optional[int] + """Number of tokens in the dataset.""" + + audio_token_count: Optional[int] + """Number of audio tokens in the dataset.""" + + +BatchPredictionResourceUsageAssessmentResultOrDict = Union[ + BatchPredictionResourceUsageAssessmentResult, + BatchPredictionResourceUsageAssessmentResultDict, +] + + +class BatchPredictionValidationAssessmentResult(_common.BaseModel): + """Result of batch prediction validation assessment.""" + + errors: Optional[list[str]] = Field( + default=None, description="""The list of errors found in the dataset.""" + ) + + +class BatchPredictionValidationAssessmentResultDict(TypedDict, total=False): + """Result of batch prediction validation assessment.""" + + errors: Optional[list[str]] + """The list of errors found in the dataset.""" + + +BatchPredictionValidationAssessmentResultOrDict = Union[ + BatchPredictionValidationAssessmentResult, + BatchPredictionValidationAssessmentResultDict, +] + + +class TuningResourceUsageAssessmentResult(_common.BaseModel): + """Result of tuning resource usage assessment.""" + + token_count: Optional[int] = Field( + default=None, description="""The number of tokens in the dataset.""" + ) + billable_character_count: Optional[int] = Field( + default=None, + description="""The number of billable characters in the dataset.""", + ) + + +class TuningResourceUsageAssessmentResultDict(TypedDict, total=False): + """Result of tuning resource usage assessment.""" + + token_count: Optional[int] + """The number of tokens in the dataset.""" + + billable_character_count: Optional[int] + """The number of billable characters in the dataset.""" + + +TuningResourceUsageAssessmentResultOrDict = Union[ + TuningResourceUsageAssessmentResult, TuningResourceUsageAssessmentResultDict +] + + +class TuningValidationAssessmentResult(_common.BaseModel): + """The result of a tuning validation assessment.""" + + errors: Optional[list[str]] = Field( + default=None, description="""The list of errors found in the dataset.""" + ) + + +class TuningValidationAssessmentResultDict(TypedDict, total=False): + """The result of a tuning validation assessment.""" + + errors: Optional[list[str]] + """The list of errors found in the dataset.""" + + +TuningValidationAssessmentResultOrDict = Union[ + TuningValidationAssessmentResult, TuningValidationAssessmentResultDict +] + + +class Prompt(_common.BaseModel): + """Represents a 
prompt."""
+
+    prompt_data: Optional["PromptData"] = Field(default=None, description="""""")
+    _dataset: Optional["Dataset"] = PrivateAttr(default=None)
+    _dataset_version: Optional["DatasetVersion"] = PrivateAttr(default=None)
+
+    @property
+    def dataset(self) -> "Dataset":
+        return self._dataset  # type: ignore[return-value]
+
+    @property
+    def dataset_version(self) -> "DatasetVersion":
+        return self._dataset_version  # type: ignore[return-value]
+
+    @property
+    def prompt_id(self) -> Optional[str]:
+        """Returns the ID associated with the prompt resource."""
+        if self._dataset and self._dataset.name:
+            return self._dataset.name.split("/")[-1]
+        elif not self._dataset and (
+            self._dataset_version and self._dataset_version.name
+        ):
+            return self._dataset_version.name.split("datasets/")[1].split("/")[0]
+        return None
+
+    @property
+    def version_id(self) -> Optional[str]:
+        """Returns the ID associated with the prompt version resource."""
+        if self._dataset_version and self._dataset_version.name:
+            return self._dataset_version.name.split("/")[-1]
+        return None
+
+    def assemble_contents(self) -> list[genai_types.Content]:
+        """Transforms a Prompt object into a list with a single genai_types.Content object.
+
+        This method replaces the variables in the prompt template with the values provided in prompt.prompt_data.variables.
+        If no variables are provided, prompt.prompt_data.contents is returned as is. Only single-turn prompts are supported.
+
+        This can be used to call generate_content() in the Gen AI SDK.
+
+        Example usage:
+
+            my_prompt = types.Prompt(
+                prompt_data=types.PromptData(
+                    model="gemini-2.0-flash-001",
+                    contents=[
+                        genai_types.Content(
+                            parts=[
+                                genai_types.Part(text="Hello {name}!"),
+                            ],
+                        ),
+                    ],
+                    variables=[
+                        {
+                            "name": genai_types.Part(text="Alice"),
+                        },
+                    ],
+                ),
+            )
+
+            from google import genai
+
+            genai_client = genai.Client(vertexai=True, project="my-project", location="us-central1")
+            genai_client.models.generate_content(
+                model=my_prompt.prompt_data.model,
+                contents=my_prompt.assemble_contents(),
+            )
+
+        Returns:
+            A list with a single Content object that can be used to call
+            generate_content().
+        """
+        if not self.prompt_data or not self.prompt_data.contents:
+            return []
+
+        if not self.prompt_data.variables:
+            return self.prompt_data.contents
+
+        if len(self.prompt_data.contents) > 1:
+            raise ValueError(
+                "Multiple Content items are not supported: assemble_contents()"
+                " requires prompt_data with a single Content item."
+ ) + + parts_to_process = self.prompt_data.contents[0].parts + if parts_to_process is None: + return [] + if not isinstance(parts_to_process, list): + parts_to_process = [parts_to_process] + + has_placeholders = False + variable_regex = r"\{.*?\}" + for item in parts_to_process: + part = ( + item + if isinstance(item, genai_types.Part) + else genai_types.Part(text=str(item)) + ) + if part.text and re.search(variable_regex, part.text): + has_placeholders = True + break + + if not has_placeholders: + return [genai_types.Content(parts=parts_to_process)] + + all_rendered_parts: list[genai_types.Part] = [] + + for var_dict in self.prompt_data.variables: + for template_item in parts_to_process: + template_part = ( + template_item + if isinstance(template_item, genai_types.Part) + else genai_types.Part(text=str(template_item)) + ) + if template_part.text: + rendered_text = template_part.text + + for key, value in var_dict.items(): + placeholder = f"{{{key}}}" + replacement_text = None + + if isinstance(value, str): + replacement_text = value + elif isinstance(value, genai_types.Part): + if value.text: + replacement_text = value.text + else: + all_rendered_parts.append(value) + if ( + replacement_text is not None + and placeholder in rendered_text + ): + rendered_text = rendered_text.replace( + placeholder, replacement_text + ) + all_rendered_parts.append(genai_types.Part(text=rendered_text)) + else: + all_rendered_parts.append(template_part) + return [genai_types.Content(parts=all_rendered_parts, role="user")] + + +PromptData = SchemaPromptSpecPromptMessage +PromptDataDict = SchemaPromptSpecPromptMessageDict +PromptDataOrDict = Union[PromptData, PromptDataDict] + +ParsedResponseUnion = Union[ + prompts_types.ParsedResponse, prompts_types.ParsedResponseFewShot +] +ParsedResponseUnionDict = Union[ + prompts_types.ParsedResponseDict, prompts_types.ParsedResponseFewShotDict +] + + +class PromptDict(TypedDict, total=False): + """Represents a prompt.""" + + prompt_data: Optional["PromptDataDict"] + """""" + + +PromptOrDict = Union[Prompt, PromptDict] + + +class SchemaPromptInstanceVariableValue(_common.BaseModel): + """Represents a prompt instance variable.""" + + part_list: Optional[SchemaPromptSpecPartList] = Field( + default=None, description="""The parts of the variable value.""" + ) + + +class SchemaPromptInstanceVariableValueDict(TypedDict, total=False): + """Represents a prompt instance variable.""" + + part_list: Optional[SchemaPromptSpecPartListDict] + """The parts of the variable value.""" + + +SchemaPromptInstanceVariableValueOrDict = Union[ + SchemaPromptInstanceVariableValue, SchemaPromptInstanceVariableValueDict +] + + +class CreatePromptConfig(_common.BaseModel): + """Config for creating a prompt.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + prompt_display_name: Optional[str] = Field( + default=None, + description="""The display name for the prompt. If not set, a default name with a timestamp will be used.""", + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the create_version request in seconds. If not set, the default timeout is 90 seconds.""", + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key spec for a prompt dataset. 
If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key.""", + ) + version_display_name: Optional[str] = Field( + default=None, + description="""The display name for the prompt version. If not set, a default name with a timestamp will be used.""", + ) + + +class CreatePromptConfigDict(TypedDict, total=False): + """Config for creating a prompt.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + prompt_display_name: Optional[str] + """The display name for the prompt. If not set, a default name with a timestamp will be used.""" + + timeout: Optional[int] + """The timeout for the create_version request in seconds. If not set, the default timeout is 90 seconds.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key spec for a prompt dataset. If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key.""" + + version_display_name: Optional[str] + """The display name for the prompt version. If not set, a default name with a timestamp will be used.""" + + +CreatePromptConfigOrDict = Union[CreatePromptConfig, CreatePromptConfigDict] + + +class CreatePromptVersionConfig(_common.BaseModel): + """Config for creating a prompt version.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + version_display_name: Optional[str] = Field( + default=None, + description="""The display name for the prompt version. If not set, a default name with a timestamp will be used.""", + ) + timeout: Optional[int] = Field( + default=90, + description="""The timeout for the create_version request in seconds. If not set, the default timeout is 90 seconds.""", + ) + prompt_display_name: Optional[str] = Field( + default=None, + description="""The display name for the prompt. If not set, a default name with a timestamp will be used.""", + ) + encryption_spec: Optional[genai_types.EncryptionSpec] = Field( + default=None, + description="""Customer-managed encryption key spec for a prompt dataset. If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key.""", + ) + + +class CreatePromptVersionConfigDict(TypedDict, total=False): + """Config for creating a prompt version.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + version_display_name: Optional[str] + """The display name for the prompt version. If not set, a default name with a timestamp will be used.""" + + timeout: Optional[int] + """The timeout for the create_version request in seconds. If not set, the default timeout is 90 seconds.""" + + prompt_display_name: Optional[str] + """The display name for the prompt. If not set, a default name with a timestamp will be used.""" + + encryption_spec: Optional[genai_types.EncryptionSpecDict] + """Customer-managed encryption key spec for a prompt dataset. 
If set, this prompt dataset and all sub-resources of this prompt dataset will be secured by this key."""
+
+
+CreatePromptVersionConfigOrDict = Union[
+    CreatePromptVersionConfig, CreatePromptVersionConfigDict
+]
+
+
+class GetPromptConfig(_common.BaseModel):
+    """Config for getting a prompt."""
+
+    http_options: Optional[genai_types.HttpOptions] = Field(
+        default=None, description="""Used to override HTTP request options."""
+    )
+
+
+class GetPromptConfigDict(TypedDict, total=False):
+    """Config for getting a prompt."""
+
+    http_options: Optional[genai_types.HttpOptionsDict]
+    """Used to override HTTP request options."""
+
+
+GetPromptConfigOrDict = Union[GetPromptConfig, GetPromptConfigDict]
+
+
+class PromptRef(_common.BaseModel):
+    """Reference to a prompt."""
+
+    prompt_id: Optional[str] = Field(default=None, description="""""")
+    model: Optional[str] = Field(default=None, description="""""")
+
+
+class PromptRefDict(TypedDict, total=False):
+    """Reference to a prompt."""
+
+    prompt_id: Optional[str]
+    """"""
+
+    model: Optional[str]
+    """"""
+
+
+PromptRefOrDict = Union[PromptRef, PromptRefDict]
+
+
+class PromptVersionRef(_common.BaseModel):
+    """Reference to a prompt version."""
+
+    prompt_id: Optional[str] = Field(default=None, description="""""")
+    version_id: Optional[str] = Field(default=None, description="""""")
+    model: Optional[str] = Field(default=None, description="""""")
+
+
+class PromptVersionRefDict(TypedDict, total=False):
+    """Reference to a prompt version."""
+
+    prompt_id: Optional[str]
+    """"""
+
+    version_id: Optional[str]
+    """"""
+
+    model: Optional[str]
+    """"""
+
+
+PromptVersionRefOrDict = Union[PromptVersionRef, PromptVersionRefDict]
+
+
+class OptimizeJobConfig(_common.BaseModel):
+    """VAPO Prompt Optimizer Config."""
+
+    config_path: Optional[str] = Field(
+        default=None,
+        description="""The GCS path to the config file, e.g. gs://bucket/config.json.""",
+    )
+    service_account: Optional[str] = Field(
+        default=None,
+        description="""The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number.""",
+    )
+    service_account_project_number: Optional[Union[int, str]] = Field(
+        default=None,
+        description="""The project number used to construct the default service account: {service_account_project_number}-compute@developer.gserviceaccount.com. Cannot be provided at the same time as "service_account".""",
+    )
+    wait_for_completion: Optional[bool] = Field(
+        default=True,
+        description="""Whether to wait for the job to complete. Ignored for async jobs.""",
+    )
+    optimizer_job_display_name: Optional[str] = Field(
+        default=None,
+        description="""The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used.""",
+    )
+
+
+class OptimizeJobConfigDict(TypedDict, total=False):
+    """VAPO Prompt Optimizer Config."""
+
+    config_path: Optional[str]
+    """The GCS path to the config file, e.g. gs://bucket/config.json."""
+
+    service_account: Optional[str]
+    """The service account to use for the custom job. Cannot be provided at the same time as service_account_project_number."""
+
+    service_account_project_number: Optional[Union[int, str]]
+    """The project number used to construct the default service account: {service_account_project_number}-compute@developer.gserviceaccount.com. Cannot be provided at the same time as "service_account"."""
+
+    wait_for_completion: Optional[bool]
+    """Whether to wait for the job to complete.
Ignored for async jobs.""" + + optimizer_job_display_name: Optional[str] + """The display name of the optimization job. If not provided, a display name in the format of "vapo-optimizer-{timestamp}" will be used.""" + + +OptimizeJobConfigOrDict = Union[OptimizeJobConfig, OptimizeJobConfigDict] + + +class AgentEngineRuntimeRevision(_common.BaseModel): + """An agent engine runtime revision instance.""" + + api_client: Optional[Any] = Field( + default=None, description="""The underlying API client.""" + ) + api_async_client: Optional[Any] = Field( + default=None, + description="""The underlying API client for asynchronous operations.""", + ) + api_resource: Optional[ReasoningEngineRuntimeRevision] = Field( + default=None, + description="""The underlying API resource (i.e. ReasoningEngineRuntimeRevision).""", + ) + + # Allows dynamic binding of methods based on the registered operations. + model_config = ConfigDict(extra="allow") + + def __repr__(self) -> str: + return ( + f"AgentEngineRuntimeRevision(api_resource.name='{self.api_resource.name}')" + if self.api_resource is not None + else "AgentEngineRuntimeRevision(api_resource.name=None)" + ) + + def operation_schemas(self) -> Optional[list[Dict[str, Any]]]: + """Returns the schemas of all registered operations for the agent.""" + if not isinstance(self.api_resource, ReasoningEngineRuntimeRevision): + raise ValueError("api_resource is not initialized.") + if not self.api_resource.spec: + raise ValueError("api_resource.spec is not initialized.") + return self.api_resource.spec.class_methods + + def delete( + self, + config: Optional[DeleteAgentEngineRuntimeRevisionConfigOrDict] = None, + ) -> None: + """Deletes the agent engine runtime revision. + + Args: + config (DeleteAgentEngineRuntimeRevisionConfig): + Optional. Additional configurations for deleting the Agent Engine Runtime Revision. + """ + if not isinstance(self.api_resource, ReasoningEngineRuntimeRevision): + raise ValueError("api_resource is not initialized.") + self.api_client.delete(name=self.api_resource.name, config=config) # type: ignore[union-attr] + + +class AgentEngineRuntimeRevisionDict(TypedDict, total=False): + """An agent engine runtime revision instance.""" + + api_client: Optional[Any] + """The underlying API client.""" + + api_async_client: Optional[Any] + """The underlying API client for asynchronous operations.""" + + api_resource: Optional[ReasoningEngineRuntimeRevisionDict] + """The underlying API resource (i.e. ReasoningEngineRuntimeRevision).""" + + +AgentEngineRuntimeRevisionOrDict = Union[ + AgentEngineRuntimeRevision, AgentEngineRuntimeRevisionDict +] diff --git a/agentplatform/_genai/types/evals.py b/agentplatform/_genai/types/evals.py new file mode 100644 index 0000000000..e4e5ce3ab8 --- /dev/null +++ b/agentplatform/_genai/types/evals.py @@ -0,0 +1,977 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
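+
+# A minimal usage sketch of the types defined below (illustrative only;
+# `my_agent` is assumed to be an ADK LlmAgent instance):
+#
+#   agent_info = AgentInfo.load_from_agent(agent=my_agent)
+#   rubric = Rubric(
+#       rubric_id="r1",
+#       content=RubricContent(
+#           property=RubricContentProperty(
+#               description="The response is grammatically correct."
+#           )
+#       ),
+#       importance=Importance.HIGH,
+#   )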
+ +import datetime +from typing import Any, Optional, Union +from google.genai import _common +from google.genai import types as genai_types +from pydantic import Field +from typing_extensions import TypedDict + + +class Importance(_common.CaseInSensitiveEnum): + """Importance level of the rubric.""" + + IMPORTANCE_UNSPECIFIED = "IMPORTANCE_UNSPECIFIED" + """Importance is not specified.""" + HIGH = "HIGH" + """High importance.""" + MEDIUM = "MEDIUM" + """Medium importance.""" + LOW = "LOW" + """Low importance.""" + + +class AgentConfig(_common.BaseModel): + """Represents configuration for an Agent.""" + + agent_id: Optional[str] = Field( + default=None, + description="""Unique identifier of the agent. + This ID is used to refer to this agent, e.g., in AgentEvent.author, or in + the `sub_agents` field. It must be unique within the `agents` map.""", + ) + agent_type: Optional[str] = Field( + default=None, + description="""The type or class of the agent (e.g., "LlmAgent", "RouterAgent", + "ToolUseAgent"). Useful for the autorater to understand the expected + behavior of the agent.""", + ) + description: Optional[str] = Field( + default=None, + description="""A high-level description of the agent's role and responsibilities. + Critical for evaluating if the agent is routing tasks correctly.""", + ) + instruction: Optional[str] = Field( + default=None, + description="""The instructions for the LLM model, guiding the agent's behavior. + Can be static or dynamic. Dynamic instructions can contain placeholders + like {variable_name} that will be resolved at runtime using the + `AgentEvent.state_delta` field.""", + ) + tools: Optional[list[genai_types.Tool]] = Field( + default=None, description="""The list of tools available to this agent.""" + ) + sub_agents: Optional[list[str]] = Field( + default=None, + description="""The list of valid agent IDs that this agent can delegate to. + This defines the directed edges in the multi-agent system graph topology.""", + ) + + @staticmethod + def _get_tool_declarations_from_agent(agent: Any) -> genai_types.ToolListUnion: + """Gets tool declarations from an agent. + + Args: + agent: The agent to get the tool declarations from. Data type is google.adk.agents.LLMAgent type. + + Returns: + The tool declarations of the agent. + """ + tool_declarations: genai_types.ToolListUnion = [] + for tool in agent.tools: + # ADK tools (e.g. AgentTool) provide their own declaration via + # _get_declaration(). Use it when available to avoid calling + # typing.get_type_hints() on tool instances whose classes use + # `from __future__ import annotations`, which causes NameError. + if hasattr(tool, "_get_declaration") and callable(tool._get_declaration): + declaration = tool._get_declaration() + if declaration is not None: + tool_declarations.append({"function_declarations": [declaration]}) + continue + + tool_declarations.append( + { + "function_declarations": [ + genai_types.FunctionDeclaration.from_callable_with_api_option( + callable=tool + ) + ] + } + ) + return tool_declarations + + @classmethod + def from_agent(cls, agent: Any) -> "AgentConfig": + """Creates an AgentConfig from an ADK agent. + + Args: + agent: The agent to get the agent info from, data type is google.adk.agents.LLMAgent type. + + Returns: + An AgentConfig populated with the agent's metadata for evaluation. 
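+
+        Example (a minimal sketch; `my_agent` is assumed to be an ADK
+        LlmAgent with a name set):
+            ```
+            config = AgentConfig.from_agent(agent=my_agent)
+            ```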
+ """ + agent_id = getattr(agent, "name", None) + if not agent_id: + raise ValueError(f"Agent {agent} must have a name.") + return cls( # pytype: disable=missing-parameter + agent_id=agent_id, + agent_type=agent.__class__.__name__, + description=getattr(agent, "description", None), + instruction=getattr(agent, "instruction", None), + tools=AgentConfig._get_tool_declarations_from_agent(agent), + sub_agents=[ + str(getattr(sub_agent, "name")) + for sub_agent in getattr(agent, "sub_agents", []) + if getattr(sub_agent, "name", None) is not None + ], + ) + + +class AgentConfigDict(TypedDict, total=False): + """Represents configuration for an Agent.""" + + agent_id: Optional[str] + """Unique identifier of the agent. + This ID is used to refer to this agent, e.g., in AgentEvent.author, or in + the `sub_agents` field. It must be unique within the `agents` map.""" + + agent_type: Optional[str] + """The type or class of the agent (e.g., "LlmAgent", "RouterAgent", + "ToolUseAgent"). Useful for the autorater to understand the expected + behavior of the agent.""" + + description: Optional[str] + """A high-level description of the agent's role and responsibilities. + Critical for evaluating if the agent is routing tasks correctly.""" + + instruction: Optional[str] + """The instructions for the LLM model, guiding the agent's behavior. + Can be static or dynamic. Dynamic instructions can contain placeholders + like {variable_name} that will be resolved at runtime using the + `AgentEvent.state_delta` field.""" + + tools: Optional[list[genai_types.ToolDict]] + """The list of tools available to this agent.""" + + sub_agents: Optional[list[str]] + """The list of valid agent IDs that this agent can delegate to. + This defines the directed edges in the multi-agent system graph topology.""" + + +AgentConfigOrDict = Union[AgentConfig, AgentConfigDict] + + +class AgentEvent(_common.BaseModel): + """A single event in the execution trace.""" + + author: Optional[str] = Field( + default=None, + description="""The ID of the agent or entity that generated this event. + Use "user" to denote events generated by the end-user.""", + ) + content: Optional[genai_types.Content] = Field( + default=None, description="""The content of the event.""" + ) + event_time: Optional[datetime.datetime] = Field( + default=None, description="""The timestamp when the event occurred.""" + ) + state_delta: Optional[dict[str, Any]] = Field( + default=None, + description="""The change in the session state caused by this event. + This is a key-value map of fields that were modified or added by the event.""", + ) + active_tools: Optional[list[genai_types.Tool]] = Field( + default=None, + description="""The list of tools that were active/available to the agent at the + time of this event. This overrides the `AgentConfig.tools` if set.""", + ) + + +class AgentEventDict(TypedDict, total=False): + """A single event in the execution trace.""" + + author: Optional[str] + """The ID of the agent or entity that generated this event. + Use "user" to denote events generated by the end-user.""" + + content: Optional[genai_types.ContentDict] + """The content of the event.""" + + event_time: Optional[datetime.datetime] + """The timestamp when the event occurred.""" + + state_delta: Optional[dict[str, Any]] + """The change in the session state caused by this event. 
+ This is a key-value map of fields that were modified or added by the event.""" + + active_tools: Optional[list[genai_types.ToolDict]] + """The list of tools that were active/available to the agent at the + time of this event. This overrides the `AgentConfig.tools` if set.""" + + +AgentEventOrDict = Union[AgentEvent, AgentEventDict] + + +class ConversationTurn(_common.BaseModel): + """Represents a single turn/invocation in the conversation.""" + + turn_index: Optional[int] = Field( + default=None, + description="""The 0-based index of the turn in the conversation sequence.""", + ) + turn_id: Optional[str] = Field( + default=None, description="""A unique identifier for the turn.""" + ) + events: Optional[list[AgentEvent]] = Field( + default=None, + description="""The list of events that occurred during this turn.""", + ) + + +class ConversationTurnDict(TypedDict, total=False): + """Represents a single turn/invocation in the conversation.""" + + turn_index: Optional[int] + """The 0-based index of the turn in the conversation sequence.""" + + turn_id: Optional[str] + """A unique identifier for the turn.""" + + events: Optional[list[AgentEventDict]] + """The list of events that occurred during this turn.""" + + +ConversationTurnOrDict = Union[ConversationTurn, ConversationTurnDict] + + +class AgentData(_common.BaseModel): + """Represents data specific to multi-turn agent evaluations.""" + + agents: Optional[dict[str, AgentConfig]] = Field( + default=None, + description="""A map containing the static configurations for each agent in the system. + Key: agent_id (matches the `author` field in events). + Value: The static configuration of the agent.""", + ) + turns: Optional[list[ConversationTurn]] = Field( + default=None, + description="""A chronological list of conversation turns. + Each turn represents a logical execution cycle (e.g., User Input -> Agent + Response).""", + ) + + @classmethod + def get_agents_map(cls, agent: Any) -> dict[str, AgentConfig]: + """Recursively gets all agent configs from an agent and its sub-agents. + + Args: + agent: The agent to get the agent info from, data type is google.adk.agents.LLMAgent type. + + Returns: + A dict mapping agent_id to AgentConfig. + """ + agent_config = AgentConfig.from_agent(agent) + agent_id = agent_config.agent_id + if not agent_id: + raise ValueError(f"Agent {agent} must have a name.") + agents_map = {agent_id: agent_config} + + for sub_agent in getattr(agent, "sub_agents", []): + agents_map.update(cls.get_agents_map(sub_agent)) + + return agents_map + + @classmethod + def from_session(cls, agent: Any, session_history: list[Any]) -> "AgentData": + """Creates an AgentData object from a session history. + + Segments the flat list of session events into ConversationTurns. A new turn + is initiated by a User message. + + Args: + agent: The agent instance used in the session. + session_history: A list of raw events/messages from the session. + + Returns: + An AgentData object containing the segmented history and agent config. 
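+
+        Example (a minimal sketch; `my_agent` and `session_events` are
+        assumed to come from an ADK runner session):
+            ```
+            agent_data = AgentData.from_session(
+                agent=my_agent, session_history=session_events
+            )
+            ```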
+ """ + agents_map = cls.get_agents_map(agent) + agent_id = agent.name + + turns: list[ConversationTurn] = [] + current_turn_events: list[AgentEvent] = [] + + for event in session_history: + is_user = False + if isinstance(event, dict): + if event.get("role") == "user": + is_user = True + elif ( + isinstance(event.get("content"), dict) + and event["content"].get("role") == "user" + ): + is_user = True + elif hasattr(event, "role") and event.role == "user": + is_user = True + + if is_user and current_turn_events: + turns.append( + ConversationTurn( # pytype: disable=missing-parameter + turn_index=len(turns), + turn_id=f"turn_{len(turns)}", + events=current_turn_events, + ) + ) + current_turn_events = [] + + author = "user" if is_user else agent_id + + content = None + if isinstance(event, dict): + if "content" in event: + raw_content = event["content"] + if isinstance(raw_content, genai_types.Content): + content = raw_content + elif isinstance(raw_content, dict): + try: + content = genai_types.Content.model_validate(raw_content) + except Exception as e: + raise ValueError( + f"Failed to validate Content from dictionary in session history: {raw_content}" + ) from e + elif isinstance(raw_content, str): + content = genai_types.Content( + parts=[genai_types.Part(text=raw_content)] + ) + elif "parts" in event: + try: + content = genai_types.Content.model_validate(event) + except Exception as e: + raise ValueError( + f"Failed to validate Content from event with 'parts': {event}" + ) from e + elif hasattr(event, "content") and isinstance( + event.content, genai_types.Content + ): + content = event.content + + agent_event = AgentEvent( # pytype: disable=missing-parameter + author=author, + content=content, + ) + current_turn_events.append(agent_event) + + if current_turn_events: + turns.append( + ConversationTurn( # pytype: disable=missing-parameter + turn_index=len(turns), + turn_id=f"turn_{len(turns)}", + events=current_turn_events, + ) + ) + + return cls(agents=agents_map, turns=turns) # pytype: disable=missing-parameter + + +class AgentDataDict(TypedDict, total=False): + """Represents data specific to multi-turn agent evaluations.""" + + agents: Optional[dict[str, AgentConfigDict]] + """A map containing the static configurations for each agent in the system. + Key: agent_id (matches the `author` field in events). + Value: The static configuration of the agent.""" + + turns: Optional[list[ConversationTurnDict]] + """A chronological list of conversation turns. + Each turn represents a logical execution cycle (e.g., User Input -> Agent + Response).""" + + +AgentDataOrDict = Union[AgentData, AgentDataDict] + + +class AgentInfo(_common.BaseModel): + """The agent info of an agent system, used for agent evaluation.""" + + name: Optional[str] = Field( + default=None, description="""Agent candidate name, used as an identifier.""" + ) + agents: Optional[dict[str, AgentConfig]] = Field( + default=None, + description="""A map containing the static configurations for each agent in the system. + Key: agent_id (matches the `author` field in events). + Value: The static configuration of the agent.""", + ) + root_agent_id: Optional[str] = Field( + default=None, description="""The agent ID of the root agent.""" + ) + + @classmethod + def load_from_agent(cls, agent: Any) -> "AgentInfo": + """Loads agent info from an ADK agent. + + Args: + agent: The root agent to get the agent info from, data type is google.adk.agents.LLMAgent type. + + Returns: + The agent info of the agent system. 
+ + Example: + ``` + from vertexai._genai import types + + agent_info = types.evals.AgentInfo.load_from_agent(agent=my_agent) + ``` + """ + agent_name = getattr(agent, "name", None) + if not agent_name: + raise ValueError(f"Agent {agent} must have a name.") + return cls( # pytype: disable=missing-parameter + name=agent_name, + agents=AgentData.get_agents_map(agent), + root_agent_id=agent_name, + ) + + +class AgentInfoDict(TypedDict, total=False): + """The agent info of an agent system, used for agent evaluation.""" + + name: Optional[str] + """Agent candidate name, used as an identifier.""" + + agents: Optional[dict[str, AgentConfigDict]] + """A map containing the static configurations for each agent in the system. + Key: agent_id (matches the `author` field in events). + Value: The static configuration of the agent.""" + + root_agent_id: Optional[str] + """The agent ID of the root agent.""" + + +AgentInfoOrDict = Union[AgentInfo, AgentInfoDict] + + +class SessionInput(_common.BaseModel): + """This field is experimental and may change in future versions. + + Input to initialize a session and run an agent, used for agent evaluation. + """ + + user_id: Optional[str] = Field(default=None, description="""The user id.""") + state: Optional[dict[str, str]] = Field( + default=None, description="""The state of the session.""" + ) + app_name: Optional[str] = Field( + default=None, + description="""The name of the app, used for local ADK agent run Runner and Session.""", + ) + + +class SessionInputDict(TypedDict, total=False): + """This field is experimental and may change in future versions. + + Input to initialize a session and run an agent, used for agent evaluation. + """ + + user_id: Optional[str] + """The user id.""" + + state: Optional[dict[str, str]] + """The state of the session.""" + + app_name: Optional[str] + """The name of the app, used for local ADK agent run Runner and Session.""" + + +SessionInputOrDict = Union[SessionInput, SessionInputDict] + + +class UserScenario(_common.BaseModel): + """User scenario to help simulate multi-turn agent run results.""" + + starting_prompt: Optional[str] = Field( + default=None, + description="""The prompt that starts the conversation between the simulated user and the agent under test.""", + ) + conversation_plan: Optional[str] = Field( + default=None, + description="""The plan for the conversation, used to drive the multi-turn agent run and generate the simulated agent evaluation dataset.""", + ) + + +class UserScenarioDict(TypedDict, total=False): + """User scenario to help simulate multi-turn agent run results.""" + + starting_prompt: Optional[str] + """The prompt that starts the conversation between the simulated user and the agent under test.""" + + conversation_plan: Optional[str] + """The plan for the conversation, used to drive the multi-turn agent run and generate the simulated agent evaluation dataset.""" + + +UserScenarioOrDict = Union[UserScenario, UserScenarioDict] + + +class UserScenarioGenerationConfig(_common.BaseModel): + """User scenario generation configuration.""" + + model_name: Optional[str] = Field( + default=None, + description="""The model name to use for user scenario generation.""", + ) + count: Optional[int] = Field( + default=None, + description="""The number of user scenarios to generate. 
The maximum number of scenarios that can be generated is 100.""",
+    )
+    generation_instruction: Optional[str] = Field(
+        default=None,
+        description="""Instruction to guide the conversation scenario generation.""",
+    )
+    environment_context: Optional[str] = Field(
+        default=None,
+        description="""Environment context to drive simulation. For example, for a QA agent, this could be the docs queried by the tools.""",
+    )
+
+
+class UserScenarioGenerationConfigDict(TypedDict, total=False):
+    """User scenario generation configuration."""
+
+    model_name: Optional[str]
+    """The model name to use for user scenario generation."""
+
+    count: Optional[int]
+    """The number of user scenarios to generate. The maximum number of scenarios that can be generated is 100."""
+
+    generation_instruction: Optional[str]
+    """Instruction to guide the conversation scenario generation."""
+
+    environment_context: Optional[str]
+    """Environment context to drive simulation. For example, for a QA agent, this could be the docs queried by the tools."""
+
+
+UserScenarioGenerationConfigOrDict = Union[
+    UserScenarioGenerationConfig, UserScenarioGenerationConfigDict
+]
+
+
+class UserSimulatorConfig(_common.BaseModel):
+    """Configuration for a user simulator.
+
+    Uses an LLM to generate multi-turn messages that simulate a user.
+    """
+
+    model_name: Optional[str] = Field(
+        default=None,
+        description="""The model name used to generate the next user message in a multi-turn agent run.""",
+    )
+    model_configuration: Optional[genai_types.GenerateContentConfig] = Field(
+        default=None, description="""The configuration for the model."""
+    )
+    max_turn: Optional[int] = Field(
+        default=None,
+        description="""Maximum number of invocations allowed in the multi-turn agent
+      run. This property lets us stop a runaway conversation in which
+      the agent and the user simulator get into a never-ending loop.
+      The initial fixed prompt is also counted as an invocation.""",
+    )
+
+
+class UserSimulatorConfigDict(TypedDict, total=False):
+    """Configuration for a user simulator.
+
+    Uses an LLM to generate multi-turn messages that simulate a user.
+    """
+
+    model_name: Optional[str]
+    """The model name used to generate the next user message in a multi-turn agent run."""
+
+    model_configuration: Optional[genai_types.GenerateContentConfigDict]
+    """The configuration for the model."""
+
+    max_turn: Optional[int]
+    """Maximum number of invocations allowed in the multi-turn agent
+    run. This property lets us stop a runaway conversation in which
+    the agent and the user simulator get into a never-ending loop.
+    The initial fixed prompt is also counted as an invocation."""
+
+
+UserSimulatorConfigOrDict = Union[UserSimulatorConfig, UserSimulatorConfigDict]
+
+
+class Event(_common.BaseModel):
+    """Represents an event in a conversation between agents and users.
+
+    It is used to store the content of the conversation, as well as the actions
+    taken by the agents like function calls, function responses, intermediate NL
+    responses etc.
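+
+    Example (a minimal construction sketch; the values are illustrative):
+        ```
+        event = Event(
+            event_id="event-1",
+            author="user",
+            content=genai_types.Content(parts=[genai_types.Part(text="Hi")]),
+        )
+        ```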
+ """ + + event_id: Optional[str] = Field( + default=None, description="""Unique identifier for the agent event.""" + ) + content: Optional[genai_types.Content] = Field( + default=None, description="""Content of the event.""" + ) + creation_timestamp: Optional[datetime.datetime] = Field( + default=None, description="""The creation timestamp of the event.""" + ) + author: Optional[str] = Field( + default=None, description="""Name of the entity that produced the event.""" + ) + + +class EventDict(TypedDict, total=False): + """Represents an event in a conversation between agents and users. + + It is used to store the content of the conversation, as well as the actions + taken by the agents like function calls, function responses, intermediate NL + responses etc. + """ + + event_id: Optional[str] + """Unique identifier for the agent event.""" + + content: Optional[genai_types.ContentDict] + """Content of the event.""" + + creation_timestamp: Optional[datetime.datetime] + """The creation timestamp of the event.""" + + author: Optional[str] + """Name of the entity that produced the event.""" + + +EventOrDict = Union[Event, EventDict] + + +class Message(_common.BaseModel): + """Represents a single message turn in a conversation.""" + + turn_id: Optional[str] = Field( + default=None, description="""Unique identifier for the message turn.""" + ) + content: Optional[genai_types.Content] = Field( + default=None, description="""Content of the message, including function call.""" + ) + creation_timestamp: Optional[datetime.datetime] = Field( + default=None, + description="""Timestamp indicating when the message was created.""", + ) + author: Optional[str] = Field( + default=None, description="""Name of the entity that produced the message.""" + ) + + +class MessageDict(TypedDict, total=False): + """Represents a single message turn in a conversation.""" + + turn_id: Optional[str] + """Unique identifier for the message turn.""" + + content: Optional[genai_types.ContentDict] + """Content of the message, including function call.""" + + creation_timestamp: Optional[datetime.datetime] + """Timestamp indicating when the message was created.""" + + author: Optional[str] + """Name of the entity that produced the message.""" + + +MessageOrDict = Union[Message, MessageDict] + + +class Events(_common.BaseModel): + """This field is experimental and will be removed in future versions. + + Represents a list of events for an agent. + """ + + event: Optional[list[genai_types.Content]] = Field( + default=None, description="""A list of events.""" + ) + + +class EventsDict(TypedDict, total=False): + """This field is experimental and will be removed in future versions. + + Represents a list of events for an agent. + """ + + event: Optional[list[genai_types.ContentDict]] + """A list of events.""" + + +EventsOrDict = Union[Events, EventsDict] + + +class InstanceDataContents(_common.BaseModel): + """This field is experimental and will be removed in future versions. + + List of standard Content messages from Gemini API. + """ + + contents: Optional[list[genai_types.Content]] = Field( + default=None, description="""Repeated contents.""" + ) + + +class InstanceDataContentsDict(TypedDict, total=False): + """This field is experimental and will be removed in future versions. + + List of standard Content messages from Gemini API. 
+ """ + + contents: Optional[list[genai_types.ContentDict]] + """Repeated contents.""" + + +InstanceDataContentsOrDict = Union[InstanceDataContents, InstanceDataContentsDict] + + +class InstanceData(_common.BaseModel): + """This field is experimental and will be removed in future versions. + + Instance data used to populate placeholders in a metric prompt template. + """ + + text: Optional[str] = Field(default=None, description="""Text data.""") + contents: Optional[InstanceDataContents] = Field( + default=None, description="""List of Gemini content data.""" + ) + + +class InstanceDataDict(TypedDict, total=False): + """This field is experimental and will be removed in future versions. + + Instance data used to populate placeholders in a metric prompt template. + """ + + text: Optional[str] + """Text data.""" + + contents: Optional[InstanceDataContentsDict] + """List of Gemini content data.""" + + +InstanceDataOrDict = Union[InstanceData, InstanceDataDict] + + +class Tools(_common.BaseModel): + """This field is experimental and will be removed in future versions. + + Represents a list of tools for an agent. + """ + + tool: Optional[list[genai_types.Tool]] = Field( + default=None, + description="""List of tools: each tool can have multiple function declarations.""", + ) + + +class ToolsDict(TypedDict, total=False): + """This field is experimental and will be removed in future versions. + + Represents a list of tools for an agent. + """ + + tool: Optional[list[genai_types.ToolDict]] + """List of tools: each tool can have multiple function declarations.""" + + +ToolsOrDict = Union[Tools, ToolsDict] + + +class RubricContentProperty(_common.BaseModel): + """Defines criteria based on a specific property.""" + + description: Optional[str] = Field( + default=None, + description="""Description of the property being evaluated. + Example: "The model's response is grammatically correct." """, + ) + + +class RubricContentPropertyDict(TypedDict, total=False): + """Defines criteria based on a specific property.""" + + description: Optional[str] + """Description of the property being evaluated. + Example: "The model's response is grammatically correct." """ + + +RubricContentPropertyOrDict = Union[RubricContentProperty, RubricContentPropertyDict] + + +class RubricContent(_common.BaseModel): + """Content of the rubric, defining the testable criteria.""" + + property: Optional[RubricContentProperty] = Field( + default=None, + description="""Evaluation criteria based on a specific property.""", + ) + + +class RubricContentDict(TypedDict, total=False): + """Content of the rubric, defining the testable criteria.""" + + property: Optional[RubricContentPropertyDict] + """Evaluation criteria based on a specific property.""" + + +RubricContentOrDict = Union[RubricContent, RubricContentDict] + + +class Rubric(_common.BaseModel): + """Message representing a single testable criterion for evaluation. + + One input prompt could have multiple rubrics. + """ + + rubric_id: Optional[str] = Field( + default=None, + description="""Required. Unique identifier for the rubric. + This ID is used to refer to this rubric, e.g., in RubricVerdict.""", + ) + content: Optional[RubricContent] = Field( + default=None, + description="""Required. The actual testable criteria for the rubric.""", + ) + type: Optional[str] = Field( + default=None, + description="""Optional. A type designator for the rubric, which can inform how it's + evaluated or interpreted by systems or users. 
+ It's recommended to use consistent, well-defined, upper snake_case strings. + Examples: "SUMMARIZATION_QUALITY", "SAFETY_HARMFUL_CONTENT", + "INSTRUCTION_ADHERENCE".""", + ) + importance: Optional[Importance] = Field( + default=None, + description="""Optional. The relative importance of this rubric.""", + ) + + +class RubricDict(TypedDict, total=False): + """Message representing a single testable criterion for evaluation. + + One input prompt could have multiple rubrics. + """ + + rubric_id: Optional[str] + """Required. Unique identifier for the rubric. + This ID is used to refer to this rubric, e.g., in RubricVerdict.""" + + content: Optional[RubricContentDict] + """Required. The actual testable criteria for the rubric.""" + + type: Optional[str] + """Optional. A type designator for the rubric, which can inform how it's + evaluated or interpreted by systems or users. + It's recommended to use consistent, well-defined, upper snake_case strings. + Examples: "SUMMARIZATION_QUALITY", "SAFETY_HARMFUL_CONTENT", + "INSTRUCTION_ADHERENCE".""" + + importance: Optional[Importance] + """Optional. The relative importance of this rubric.""" + + +RubricOrDict = Union[Rubric, RubricDict] + + +class RubricVerdict(_common.BaseModel): + """Represents the verdict of an evaluation against a single rubric.""" + + evaluated_rubric: Optional[Rubric] = Field( + default=None, + description="""Required. The full rubric definition that was evaluated. + Storing this ensures the verdict is self-contained and understandable, + especially if the original rubric definition changes or was dynamically + generated.""", + ) + verdict: Optional[bool] = Field( + default=None, + description="""Required. Outcome of the evaluation against the rubric, represented as a + boolean. `true` indicates a "Pass", `false` indicates a "Fail".""", + ) + reasoning: Optional[str] = Field( + default=None, + description="""Optional. Human-readable reasoning or explanation for the verdict. + This can include specific examples or details from the evaluated content + that justify the given verdict.""", + ) + + +class RubricVerdictDict(TypedDict, total=False): + """Represents the verdict of an evaluation against a single rubric.""" + + evaluated_rubric: Optional[RubricDict] + """Required. The full rubric definition that was evaluated. + Storing this ensures the verdict is self-contained and understandable, + especially if the original rubric definition changes or was dynamically + generated.""" + + verdict: Optional[bool] + """Required. Outcome of the evaluation against the rubric, represented as a + boolean. `true` indicates a "Pass", `false` indicates a "Fail".""" + + reasoning: Optional[str] + """Optional. Human-readable reasoning or explanation for the verdict. + This can include specific examples or details from the evaluated content + that justify the given verdict.""" + + +RubricVerdictOrDict = Union[RubricVerdict, RubricVerdictDict] + + +class CandidateResult(_common.BaseModel): + """Result for a single candidate.""" + + candidate: Optional[str] = Field( + default=None, + description="""The candidate that is being evaluated. 
The value is the same as the candidate name in the EvaluationRequest.""", + ) + metric: Optional[str] = Field( + default=None, description="""The metric that was evaluated.""" + ) + score: Optional[float] = Field( + default=None, description="""The score of the metric.""" + ) + explanation: Optional[str] = Field( + default=None, description="""The explanation for the metric.""" + ) + rubric_verdicts: Optional[list[RubricVerdict]] = Field( + default=None, description="""The rubric verdicts for the metric.""" + ) + additional_results: Optional[dict[str, Any]] = Field( + default=None, description="""Additional results for the metric.""" + ) + + +class CandidateResultDict(TypedDict, total=False): + """Result for a single candidate.""" + + candidate: Optional[str] + """The candidate that is being evaluated. The value is the same as the candidate name in the EvaluationRequest.""" + + metric: Optional[str] + """The metric that was evaluated.""" + + score: Optional[float] + """The score of the metric.""" + + explanation: Optional[str] + """The explanation for the metric.""" + + rubric_verdicts: Optional[list[RubricVerdictDict]] + """The rubric verdicts for the metric.""" + + additional_results: Optional[dict[str, Any]] + """Additional results for the metric.""" + + +CandidateResultOrDict = Union[CandidateResult, CandidateResultDict] diff --git a/agentplatform/_genai/types/prompt_optimizer.py b/agentplatform/_genai/types/prompt_optimizer.py new file mode 100644 index 0000000000..52c6c3058f --- /dev/null +++ b/agentplatform/_genai/types/prompt_optimizer.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
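+
+# A minimal sketch of consuming a parsed optimize_prompt response (illustrative
+# only; `parsed` is assumed to be a ParsedResponse instance obtained elsewhere):
+#
+#   for guideline in parsed.applicable_guidelines or []:
+#       print(guideline.applicable_guideline, guideline.suggested_improvement)
+#   print(parsed.suggested_prompt)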
+ +from typing import Optional, Union +from google.genai import _common +from pydantic import Field +from typing_extensions import TypedDict + + +class ApplicableGuideline(_common.BaseModel): + """Applicable guideline for the optimize_prompt method.""" + + applicable_guideline: Optional[str] = Field(default=None, description="""""") + suggested_improvement: Optional[str] = Field(default=None, description="""""") + text_before_change: Optional[str] = Field(default=None, description="""""") + text_after_change: Optional[str] = Field(default=None, description="""""") + + +class ApplicableGuidelineDict(TypedDict, total=False): + """Applicable guideline for the optimize_prompt method.""" + + applicable_guideline: Optional[str] + """""" + + suggested_improvement: Optional[str] + """""" + + text_before_change: Optional[str] + """""" + + text_after_change: Optional[str] + """""" + + +ApplicableGuidelineOrDict = Union[ApplicableGuideline, ApplicableGuidelineDict] + + +class ParsedResponse(_common.BaseModel): + """Response for the optimize_prompt method.""" + + optimization_type: Optional[str] = Field(default=None, description="""""") + applicable_guidelines: Optional[list[ApplicableGuideline]] = Field( + default=None, description="""""" + ) + original_prompt: Optional[str] = Field(default=None, description="""""") + suggested_prompt: Optional[str] = Field(default=None, description="""""") + + +class ParsedResponseDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + optimization_type: Optional[str] + """""" + + applicable_guidelines: Optional[list[ApplicableGuidelineDict]] + """""" + + original_prompt: Optional[str] + """""" + + suggested_prompt: Optional[str] + """""" + + +ParsedResponseOrDict = Union[ParsedResponse, ParsedResponseDict] + + +class ParsedResponseFewShot(_common.BaseModel): + """Response for the optimize_prompt method.""" + + suggested_modifications: Optional[list[ApplicableGuideline]] = Field( + default=None, description="""""" + ) + original_system_instructions: Optional[str] = Field( + default=None, description="""""" + ) + new_system_instructions: Optional[str] = Field(default=None, description="""""") + + +class ParsedResponseFewShotDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + suggested_modifications: Optional[list[ApplicableGuidelineDict]] + """""" + + original_system_instructions: Optional[str] + """""" + + new_system_instructions: Optional[str] + """""" + + +ParsedResponseFewShotOrDict = Union[ParsedResponseFewShot, ParsedResponseFewShotDict] diff --git a/agentplatform/_genai/types/prompts.py b/agentplatform/_genai/types/prompts.py new file mode 100644 index 0000000000..52c6c3058f --- /dev/null +++ b/agentplatform/_genai/types/prompts.py @@ -0,0 +1,107 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Code generated by the Google Gen AI SDK generator DO NOT EDIT. 
+ +from typing import Optional, Union +from google.genai import _common +from pydantic import Field +from typing_extensions import TypedDict + + +class ApplicableGuideline(_common.BaseModel): + """Applicable guideline for the optimize_prompt method.""" + + applicable_guideline: Optional[str] = Field(default=None, description="""""") + suggested_improvement: Optional[str] = Field(default=None, description="""""") + text_before_change: Optional[str] = Field(default=None, description="""""") + text_after_change: Optional[str] = Field(default=None, description="""""") + + +class ApplicableGuidelineDict(TypedDict, total=False): + """Applicable guideline for the optimize_prompt method.""" + + applicable_guideline: Optional[str] + """""" + + suggested_improvement: Optional[str] + """""" + + text_before_change: Optional[str] + """""" + + text_after_change: Optional[str] + """""" + + +ApplicableGuidelineOrDict = Union[ApplicableGuideline, ApplicableGuidelineDict] + + +class ParsedResponse(_common.BaseModel): + """Response for the optimize_prompt method.""" + + optimization_type: Optional[str] = Field(default=None, description="""""") + applicable_guidelines: Optional[list[ApplicableGuideline]] = Field( + default=None, description="""""" + ) + original_prompt: Optional[str] = Field(default=None, description="""""") + suggested_prompt: Optional[str] = Field(default=None, description="""""") + + +class ParsedResponseDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + optimization_type: Optional[str] + """""" + + applicable_guidelines: Optional[list[ApplicableGuidelineDict]] + """""" + + original_prompt: Optional[str] + """""" + + suggested_prompt: Optional[str] + """""" + + +ParsedResponseOrDict = Union[ParsedResponse, ParsedResponseDict] + + +class ParsedResponseFewShot(_common.BaseModel): + """Response for the optimize_prompt method.""" + + suggested_modifications: Optional[list[ApplicableGuideline]] = Field( + default=None, description="""""" + ) + original_system_instructions: Optional[str] = Field( + default=None, description="""""" + ) + new_system_instructions: Optional[str] = Field(default=None, description="""""") + + +class ParsedResponseFewShotDict(TypedDict, total=False): + """Response for the optimize_prompt method.""" + + suggested_modifications: Optional[list[ApplicableGuidelineDict]] + """""" + + original_system_instructions: Optional[str] + """""" + + new_system_instructions: Optional[str] + """""" + + +ParsedResponseFewShotOrDict = Union[ParsedResponseFewShot, ParsedResponseFewShotDict] diff --git a/agentplatform/frameworks/__init__.py b/agentplatform/frameworks/__init__.py new file mode 100644 index 0000000000..346cdc96c9 --- /dev/null +++ b/agentplatform/frameworks/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
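+#
+# Example (illustrative): the framework wrappers re-exported below can be
+# imported directly:
+#
+#   from agentplatform.frameworks import AdkApp, LangchainAgent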
+# +"""Classes for working with agent platform.""" + +from agentplatform.frameworks import a2a +from agentplatform.frameworks import adk +from agentplatform.frameworks import ag2 +from agentplatform.frameworks import langchain +from agentplatform.frameworks import langgraph +from agentplatform.frameworks import llama_index + + +A2aAgent = a2a.A2aAgent +AdkApp = adk.AdkApp +AG2Agent = ag2.AG2Agent +LangchainAgent = langchain.LangchainAgent +LanggraphAgent = langgraph.LanggraphAgent +LlamaIndexQueryPipelineAgent = llama_index.LlamaIndexQueryPipelineAgent + +__all__ = ( + "A2aAgent", + "AdkApp", + "AG2Agent", + "LangchainAgent", + "LanggraphAgent", + "LlamaIndexQueryPipelineAgent", +) diff --git a/agentplatform/frameworks/a2a.py b/agentplatform/frameworks/a2a.py new file mode 100644 index 0000000000..2a5ba9bcc7 --- /dev/null +++ b/agentplatform/frameworks/a2a.py @@ -0,0 +1,575 @@ +# -*- coding: utf-8 -*- +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from collections.abc import AsyncIterator +import os +from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING + + +if TYPE_CHECKING: + try: + from a2a.server.request_handlers import RequestHandler + from a2a.server.tasks import TaskStore + from a2a.types import AgentCard, AgentSkill + from a2a.server.agent_execution import AgentExecutor + from a2a.server.context import ServerCallContext + from a2a.types import ( + SendMessageRequest, + CancelTaskRequest, + GetTaskRequest, + GetExtendedAgentCardRequest, + SubscribeToTaskRequest, + ListTasksRequest, + ListTasksResponse, + TaskPushNotificationConfig, + GetTaskPushNotificationConfigRequest, + ListTaskPushNotificationConfigsRequest, + ListTaskPushNotificationConfigsResponse, + DeleteTaskPushNotificationConfigRequest, + Message, + Task, + ) + from a2a.server.events.event_queue import Event + + RequestHandler = RequestHandler + TaskStore = TaskStore + AgentCard = AgentCard + AgentSkill = AgentSkill + AgentExecutor = AgentExecutor + ServerCallContext = ServerCallContext + SendMessageRequest = SendMessageRequest + CancelTaskRequest = CancelTaskRequest + GetTaskRequest = GetTaskRequest + GetExtendedAgentCardRequest = GetExtendedAgentCardRequest + SubscribeToTaskRequest = SubscribeToTaskRequest + ListTasksRequest = ListTasksRequest + ListTasksResponse = ListTasksResponse + TaskPushNotificationConfig = TaskPushNotificationConfig + GetTaskPushNotificationConfigRequest = GetTaskPushNotificationConfigRequest + ListTaskPushNotificationConfigsRequest = ListTaskPushNotificationConfigsRequest + ListTaskPushNotificationConfigsResponse = ( + ListTaskPushNotificationConfigsResponse + ) + DeleteTaskPushNotificationConfigRequest = ( + DeleteTaskPushNotificationConfigRequest + ) + Message = Message + Task = Task + Event = Event + except (ImportError, AttributeError): + RequestHandler = Any + TaskStore = Any + AgentCard = Any + AgentSkill = Any + AgentExecutor = Any + ServerCallContext = Any + SendMessageRequest = Any + CancelTaskRequest = Any + GetTaskRequest = Any + 
        GetExtendedAgentCardRequest = Any
+        SubscribeToTaskRequest = Any
+        ListTasksRequest = Any
+        ListTasksResponse = Any
+        TaskPushNotificationConfig = Any
+        GetTaskPushNotificationConfigRequest = Any
+        ListTaskPushNotificationConfigsRequest = Any
+        ListTaskPushNotificationConfigsResponse = Any
+        DeleteTaskPushNotificationConfigRequest = Any
+        Message = Any
+        Task = Any
+        Event = Any
+
+
+def create_agent_card(
+    agent_name: Optional[str] = None,
+    description: Optional[str] = None,
+    skills: Optional[List["AgentSkill"]] = None,
+    agent_card: Optional[Dict[str, Any]] = None,
+    default_input_modes: Optional[List[str]] = None,
+    default_output_modes: Optional[List[str]] = None,
+    streaming: bool = False,
+) -> "AgentCard":
+    """Creates an AgentCard object.
+
+    The function can be called in two ways:
+    1. By providing the individual parameters: agent_name, description, and
+       skills.
+    2. By providing a single dictionary containing all the data.
+
+    If a dictionary is provided, the other parameters are ignored.
+
+    Args:
+        agent_name (Optional[str]): The name of the agent.
+        description (Optional[str]): A description of the agent.
+        skills (Optional[List[AgentSkill]]): A list of AgentSkills.
+        agent_card (Optional[Dict[str, Any]]): Agent Card as a dictionary.
+        default_input_modes (Optional[List[str]]): A list of input modes, default
+            to ["text/plain"].
+        default_output_modes (Optional[List[str]]): A list of output modes,
+            default to ["application/json"].
+        streaming (bool): Whether to enable streaming for the agent. Defaults to
+            False.
+
+    Returns:
+        AgentCard: A fully constructed AgentCard object.
+
+    Raises:
+        ValueError: If neither a dictionary nor the required parameters are
+            provided.
+    """
+    # pylint: disable=g-import-not-at-top
+    from a2a.types import AgentCard, AgentCapabilities, AgentInterface
+    from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
+
+    # Check if a dictionary was provided.
+    if agent_card:
+        return AgentCard(**agent_card)
+
+    # If no dictionary, use the individual parameters.
+    elif agent_name and description and skills:
+        return AgentCard(
+            name=agent_name,
+            description=description,
+            version="1.0.0",
+            default_input_modes=default_input_modes or ["text/plain"],
+            default_output_modes=default_output_modes or ["application/json"],
+            capabilities=AgentCapabilities(
+                streaming=streaming, extended_agent_card=True
+            ),
+            skills=skills,
+            supported_interfaces=[
+                AgentInterface(
+                    url="http://localhost:9999/",
+                    protocol_binding=TransportProtocol.HTTP_JSON,
+                    protocol_version=PROTOCOL_VERSION_CURRENT,
+                )
+            ],
+        )
+
+    # Raise an error if insufficient data is provided.
+    else:
+        raise ValueError(
+            "Please provide either an agent_card or all of the required "
+            "parameters (agent_name, description, and skills)."
+        )
+
+
+def default_a2a_agent() -> "A2aAgent":
+    """Creates a default A2aAgent instance."""
+    # pylint: disable=g-import-not-at-top
+    from a2a.server.agent_execution import AgentExecutor, RequestContext
+    from a2a.types import AgentSkill
+    from a2a.server.events import EventQueue
+    from a2a.helpers.proto_helpers import new_text_message
+
+    skill = AgentSkill(
+        id="hello_world",
+        name="Returns hello world",
+        description="just returns hello world",
+        tags=["hello world"],
+        examples=["hi", "hello world"],
+    )
+    agent_card = create_agent_card(
+        agent_name="Hello World Agent",
+        description="Just a hello world agent",
+        skills=[skill],
+    )
+
+    class HelloWorldAgentExecutor(AgentExecutor):
+        """Hello World Agent Executor."""
+
+        def get_agent_response(self) -> str:
+            return "Hello World"
+
+        async def execute(
+            self,
+            context: RequestContext,
+            event_queue: EventQueue,
+        ) -> None:
+            result = self.get_agent_response()
+            await event_queue.enqueue_event(new_text_message(result))
+
+        async def cancel(
+            self, context: RequestContext, event_queue: EventQueue
+        ) -> None:
+            raise Exception("cancel not supported")
+
+    return A2aAgent(
+        agent_card=agent_card,
+        agent_executor_builder=HelloWorldAgentExecutor,
+    )
+
+
+def _is_version_enabled(agent_card: "AgentCard", version: str) -> bool:
+    """Checks if a specific version compatibility should be enabled for the A2aAgent."""
+    # pylint: disable=g-import-not-at-top
+    from a2a.utils.constants import TransportProtocol
+
+    if not getattr(agent_card, "supported_interfaces", None):
+        return False
+    for interface in agent_card.supported_interfaces:
+        if (
+            interface.protocol_version == version
+            and interface.protocol_binding == TransportProtocol.HTTP_JSON
+        ):
+            return True
+    return False
+
+
+class A2aAgent:
+    """A class to initialize and set up an Agent-to-Agent application."""
+
+    agent_framework = "a2a"
+
+    # TODO: Add instrumentation for the A2A agent.
+    def __init__(
+        self,
+        *,
+        agent_card: "AgentCard",
+        task_store_builder: Optional[Callable[..., "TaskStore"]] = None,
+        task_store_kwargs: Optional[Mapping[str, Any]] = None,
+        agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
+        agent_executor_builder: Optional[Callable[..., "AgentExecutor"]] = None,
+        request_handler_kwargs: Optional[Mapping[str, Any]] = None,
+        request_handler_builder: Optional[Callable[..., "RequestHandler"]] = None,
+        extended_agent_card: Optional["AgentCard"] = None,
+    ):
+        """Initializes the A2A agent."""
+        # pylint: disable=g-import-not-at-top
+        from google.cloud.aiplatform import initializer
+        from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT
+
+        if (
+            agent_card.supported_interfaces
+            and agent_card.supported_interfaces[0].protocol_binding
+            != TransportProtocol.HTTP_JSON
+        ):
+            raise ValueError(
+                "Only HTTP+JSON is supported for the primary interface on the"
+                " agent card."
+            )
+        if not _is_version_enabled(agent_card, PROTOCOL_VERSION_CURRENT):
+            raise ValueError(
+                "A2A protocol version 1.0 is required but not enabled on the"
+                " agent card."
+ ) + + self._tmpl_attrs: dict[str, Any] = { + "agent_card": agent_card, + "agent_executor": None, + "agent_executor_kwargs": agent_executor_kwargs or {}, + "agent_executor_builder": agent_executor_builder, + "task_store": None, + "task_store_kwargs": task_store_kwargs or {}, + "task_store_builder": task_store_builder, + "request_handler": None, + "request_handler_kwargs": request_handler_kwargs or {}, + "request_handler_builder": request_handler_builder, + "extended_agent_card": extended_agent_card, + } + self.agent_card = agent_card + self.request_handler = None + self.task_store = None + self.agent_executor = None + + def clone(self) -> "A2aAgent": + """Clones the A2A agent.""" + import copy + + return A2aAgent( + agent_card=copy.deepcopy(self.agent_card), + task_store_builder=self._tmpl_attrs.get("task_store_builder"), + task_store_kwargs=self._tmpl_attrs.get("task_store_kwargs"), + agent_executor_kwargs=self._tmpl_attrs.get("agent_executor_kwargs"), + agent_executor_builder=self._tmpl_attrs.get("agent_executor_builder"), + request_handler_kwargs=self._tmpl_attrs.get("request_handler_kwargs"), + request_handler_builder=self._tmpl_attrs.get("request_handler_builder"), + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) + + def set_up(self): + """Sets up the A2A application.""" + # pylint: disable=g-import-not-at-top + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.routes.rest_routes import create_rest_routes + from a2a.server.tasks import InMemoryTaskStore + + os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + location = os.environ.get("GOOGLE_CLOUD_LOCATION") + agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", "test-agent-engine") + version = "v1beta1" + + new_url = f"https://{location}-aiplatform.googleapis.com/{version}/projects/{project}/locations/{location}/reasoningEngines/{agent_engine_id}/a2a" + if not self.agent_card.supported_interfaces: + from a2a.types import AgentInterface + from a2a.utils.constants import TransportProtocol, PROTOCOL_VERSION_CURRENT + + self.agent_card.supported_interfaces.append( + AgentInterface( + url=new_url, + protocol_binding=TransportProtocol.HTTP_JSON, + protocol_version=PROTOCOL_VERSION_CURRENT, + ) + ) + else: + # primary interface must be HTTP+JSON + self.agent_card.supported_interfaces[0].url = new_url + self._tmpl_attrs["agent_card"] = self.agent_card + + # Create the agent executor if a builder is provided. + agent_executor_builder = self._tmpl_attrs.get("agent_executor_builder") + if agent_executor_builder: + self._tmpl_attrs["agent_executor"] = agent_executor_builder( + **self._tmpl_attrs.get("agent_executor_kwargs") + ) + self.agent_executor = self._tmpl_attrs.get("agent_executor") + + # Create the task store if a builder is provided. + task_store_builder = self._tmpl_attrs.get("task_store_builder") + if task_store_builder: + self.task_store = task_store_builder( + **self._tmpl_attrs.get("task_store_kwargs") + ) + else: + # Use the default task store if not provided. This could potentially + # lead to unexpected behavior if the agent is running on + # multiple instances. + self.task_store = InMemoryTaskStore() + + self._tmpl_attrs["task_store"] = self.task_store + + # Create the request handler if a builder is provided. 
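Reviewer note: to make the URL rewriting in `set_up()` (completed just below) concrete, here is a hedged sketch of the environment the runtime is expected to provide. All values are placeholders; the resulting URL simply mirrors the f-string above.

```python
import os

# Placeholder values; on Agent Engine these are injected by the runtime.
os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"
os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ID"] = "1234567890"

agent = default_a2a_agent()
agent.set_up()
# set_up() points the primary HTTP+JSON interface at the reasoningEngines
# a2a endpoint:
# https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project
#     /locations/us-central1/reasoningEngines/1234567890/a2a
print(agent.agent_card.supported_interfaces[0].url)
```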
+ request_handler_builder = self._tmpl_attrs.get("request_handler_builder") + if request_handler_builder: + self.request_handler = request_handler_builder( + **self._tmpl_attrs.get("request_handler_kwargs") + ) + else: + # Use the default request handler if not provided. + self.request_handler = DefaultRequestHandler( + agent_executor=self._tmpl_attrs.get("agent_executor"), + task_store=self.task_store, + agent_card=self.agent_card, + extended_agent_card=self._tmpl_attrs.get("extended_agent_card"), + ) + + self._tmpl_attrs["request_handler"] = self.request_handler + + # Support native Starlette routes. + enable_v0_3 = _is_version_enabled(self.agent_card, "0.3") + self.rest_routes = create_rest_routes( + request_handler=self, + enable_v0_3_compat=enable_v0_3, + path_prefix="/a2a", + ) + + def __getattr__(self, name: str) -> Any: + """Delegates all missing RequestHandler methods to the underlying request_handler.""" + if not self.request_handler: + raise AttributeError( + f"'A2aAgent' has no attribute '{name}' and request_handler is not initialized." + ) + return getattr(self.request_handler, name) + + async def on_message_send( + self, + request: "SendMessageRequest", + context: "ServerCallContext", + ) -> "Task | Message": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_message_send(request, context) + + async def on_cancel_task( + self, + request: "CancelTaskRequest", + context: "ServerCallContext", + ) -> "Task | None": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_cancel_task(request, context) + + async def on_get_task( + self, + request: "GetTaskRequest", + context: "ServerCallContext", + ) -> "Task | None": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_get_task(request, context) + + async def on_list_tasks( + self, + request: "ListTasksRequest", + context: "ServerCallContext", + ) -> "ListTasksResponse": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_list_tasks(request, context) + + async def on_create_task_push_notification_config( + self, + request: "TaskPushNotificationConfig", + context: "ServerCallContext", + ) -> "TaskPushNotificationConfig": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_create_task_push_notification_config( + request, context + ) + + async def on_get_task_push_notification_config( + self, + request: "GetTaskPushNotificationConfigRequest", + context: "ServerCallContext", + ) -> "TaskPushNotificationConfig": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_get_task_push_notification_config( + request, context + ) + + async def on_list_task_push_notification_configs( + self, + request: "ListTaskPushNotificationConfigsRequest", + context: "ServerCallContext", + ) -> "ListTaskPushNotificationConfigsResponse": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_list_task_push_notification_configs( + request, context + ) + + async def on_delete_task_push_notification_config( + self, + request: "DeleteTaskPushNotificationConfigRequest", + context: "ServerCallContext", + ) -> None: + if not 
self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_delete_task_push_notification_config( + request, context + ) + + async def on_get_extended_agent_card( + self, + request: "GetExtendedAgentCardRequest", + context: "ServerCallContext", + ) -> "AgentCard": + if not self.request_handler: + raise NotImplementedError("request_handler not available.") + return await self.request_handler.on_get_extended_agent_card(request, context) + + def register_operations(self) -> Dict[str, List[str]]: + """Registers the operations of the A2A Agent.""" + routes = { + "a2a_extension": [ + "on_message_send", + "on_get_task", + "on_list_tasks", + "on_cancel_task", + "on_create_task_push_notification_config", + "on_get_task_push_notification_config", + "on_list_task_push_notification_configs", + "on_delete_task_push_notification_config", + ] + } + if self.agent_card.capabilities and self.agent_card.capabilities.streaming: + routes["a2a_extension"].append("on_message_send_stream") + routes["a2a_extension"].append("on_subscribe_to_task") + if ( + self.agent_card.capabilities + and self.agent_card.capabilities.extended_agent_card + ): + routes["a2a_extension"].append("on_get_extended_agent_card") + return routes + + async def on_message_send_stream( + self, + request: "SendMessageRequest", + context: "ServerCallContext", + ) -> AsyncIterator["Event"]: + """Handles A2A streaming requests via SSE.""" + async for chunk in self.request_handler.on_message_send_stream( + request, context + ): + yield chunk + + async def on_subscribe_to_task( + self, + request: "SubscribeToTaskRequest", + context: "ServerCallContext", + ) -> AsyncIterator["Event"]: + """Handles A2A task resubscription requests via SSE.""" + async for chunk in self.request_handler.on_subscribe_to_task(request, context): + yield chunk + + def __getstate__(self): + """Serializes AgentCard proto to a dictionary.""" + from google.protobuf import json_format + import json + + state = self.__dict__.copy() + + def _to_dict_if_proto(obj): + if hasattr(obj, "DESCRIPTOR"): + return { + "__protobuf_AgentCard__": json.loads(json_format.MessageToJson(obj)) + } + return obj + + state["agent_card"] = _to_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + tmpl_attrs = state["_tmpl_attrs"].copy() + tmpl_attrs["agent_card"] = _to_dict_if_proto(tmpl_attrs.get("agent_card")) + tmpl_attrs["extended_agent_card"] = _to_dict_if_proto( + tmpl_attrs.get("extended_agent_card") + ) + state["_tmpl_attrs"] = tmpl_attrs + + return state + + def __setstate__(self, state): + """Deserializes AgentCard proto from a dictionary.""" + from google.protobuf import json_format + from a2a.types import AgentCard + + def _from_dict_if_proto(obj): + if isinstance(obj, dict) and "__protobuf_AgentCard__" in obj: + agent_card = AgentCard() + json_format.ParseDict(obj["__protobuf_AgentCard__"], agent_card) + return agent_card + return obj + + state["agent_card"] = _from_dict_if_proto(state.get("agent_card")) + if "_tmpl_attrs" in state: + state["_tmpl_attrs"]["agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("agent_card") + ) + state["_tmpl_attrs"]["extended_agent_card"] = _from_dict_if_proto( + state["_tmpl_attrs"].get("extended_agent_card") + ) + + self.__dict__.update(state) diff --git a/agentplatform/frameworks/adk.py b/agentplatform/frameworks/adk.py new file mode 100644 index 0000000000..35e5e40079 --- /dev/null +++ b/agentplatform/frameworks/adk.py @@ -0,0 +1,1849 @@ +# -*- coding: utf-8 -*- +# 
Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterable, + Callable, + Dict, + List, + Optional, + Union, +) + +import asyncio +from collections.abc import Awaitable +import queue +import sys +import threading +import warnings + +if TYPE_CHECKING: + try: + from google.adk.events.event import Event + + Event = Event + except (ImportError, AttributeError): + Event = Any + + try: + from google.adk.apps import App + + App = App + except (ImportError, AttributeError): + App = Any + + try: + from google.adk.agents import BaseAgent + + BaseAgent = BaseAgent + except (ImportError, AttributeError): + BaseAgent = Any + + try: + from google.adk.plugins.base_plugin import BasePlugin + + BasePlugin = BasePlugin + except (ImportError, AttributeError): + BasePlugin = Any + + try: + from google.adk.sessions import BaseSessionService + + BaseSessionService = BaseSessionService + except (ImportError, AttributeError): + BaseSessionService = Any + + try: + from google.adk.artifacts import BaseArtifactService + + BaseArtifactService = BaseArtifactService + except (ImportError, AttributeError): + BaseArtifactService = Any + + try: + from google.adk.memory import BaseMemoryService + + BaseMemoryService = BaseMemoryService + except (ImportError, AttributeError): + BaseMemoryService = Any + + try: + from google.adk.auth.credential_service.base_credential_service import ( + BaseCredentialService, + ) + + BaseCredentialService = BaseCredentialService + except (ImportError, AttributeError): + BaseCredentialService = Any + + try: + from opentelemetry.sdk import trace + + TracerProvider = trace.TracerProvider + SpanProcessor = trace.SpanProcessor + SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor + except (ImportError, AttributeError): + TracerProvider = Any + SpanProcessor = Any + SynchronousMultiSpanProcessor = Any + + +_DEFAULT_APP_NAME = "default_app_name" +_DEFAULT_USER_ID = "default-user-id" +_TELEMETRY_API_DISABLED_WARNING = ( + "Tracing integration for Agent Engine has migrated to a new API.\n" + "The 'telemetry.googleapis.com' has not been enabled in project %s. \n" + "**Impact:** Until this API is enabled, telemetry data will not be stored." + "\n" + "**Action:** Please enable the API by visiting " + "https://console.developers.google.com/apis/api/telemetry.googleapis.com/overview?project=%s." + "\n" + "(If you enabled this API recently, you can safely ignore this warning.)" +) + + +def get_adk_version() -> Optional[str]: + """Returns the version of the ADK package.""" + try: + from google.adk import version + + return version.__version__ + except (ImportError, AttributeError): + return None + + +def is_version_sufficient(version_to_check: str) -> bool: + """Compares the existing version of ADK with the required version. + + Args: + version_to_check: The version string to check. + + Returns: + True if the existing version is sufficient, False otherwise. 
+ """ + try: + from packaging.version import parse + + return parse(get_adk_version()) >= parse(version_to_check) + except (AttributeError, ImportError): + return False + + +class _ArtifactVersion: + def __init__(self, **kwargs): + from google.genai import types + + self.version: Optional[str] = kwargs.get("version") + data = kwargs.get("data") + self.data: Optional[types.Part] = ( + types.Part.model_validate(data) if isinstance(data, dict) else data + ) + + def dump(self) -> Dict[str, Any]: + result = {} + if self.version: + result["version"] = self.version + if self.data: + result["data"] = self.data + return result + + +class _Artifact: + def __init__(self, **kwargs): + self.file_name: Optional[str] = kwargs.get("file_name") + self.versions: List[_ArtifactVersion] = kwargs.get("versions") + + def dump(self) -> Dict[str, Any]: + result = {} + if self.file_name: + result["file_name"] = self.file_name + if self.versions: + result["versions"] = [version.dump() for version in self.versions] + return result + + +class _Authorization: + def __init__(self, **kwargs): + self.access_token: Optional[str] = kwargs.get("access_token") or kwargs.get( + "accessToken" + ) + + +class _StreamRunRequest: + """Request object for `streaming_agent_run_with_events` method.""" + + def __init__(self, **kwargs): + from google.adk.events.event import Event + from google.genai import types + + self.message: Optional[types.Content] = kwargs.get("message") + # The new message to be processed by the agent. + + self.events: Optional[List[Event]] = kwargs.get("events") + # List of preceding events happened in the same session. + + self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts") + # List of artifacts belonging to the session. + + self.authorizations: Dict[str, _Authorization] = kwargs.get( + "authorizations", {} + ) + # The authorizations of the user, keyed by authorization ID. + + self.user_id: Optional[str] = kwargs.get("user_id") or kwargs.get( + "userId", _DEFAULT_USER_ID + ) + # The user ID. + + self.session_id: Optional[str] = kwargs.get("session_id") or kwargs.get( + "sessionId" + ) + # The session ID. + + +class _StreamingRunResponse: + """Response object for `streaming_agent_run_with_events` method. + + It contains the generated events together with the belonging artifacts. + """ + + def __init__(self, **kwargs): + self.events: Optional[List["Event"]] = kwargs.get("events") + # List of generated events. + self.artifacts: Optional[List[_Artifact]] = kwargs.get("artifacts") + # List of artifacts belonging to the session. + self.session_id: Optional[str] = kwargs.get("session_id") + # The session ID. 
+ + def dump(self) -> Dict[str, Any]: + from agentplatform._genai import _agent_engines_utils + + result = {} + if self.events: + result["events"] = [] + for event in self.events: + event_dict = _agent_engines_utils.dump_event_for_json(event) + event_dict["invocation_id"] = event_dict.get("invocation_id", "") + result["events"].append(event_dict) + if self.artifacts: + result["artifacts"] = [artifact.dump() for artifact in self.artifacts] + if self.session_id: + result["session_id"] = self.session_id + return result + + +def _warn(msg: str): + if not hasattr(_warn, "_LOGGER"): + from google.cloud.aiplatform import base + + _warn._LOGGER = base.Logger( + __name__ + ) # pyright: ignore[reportFunctionMemberAccess] + + _warn._LOGGER.warning(msg) # pyright: ignore[reportFunctionMemberAccess] + + +async def _force_flush_otel(tracing_enabled: bool, logging_enabled: bool): + try: + import opentelemetry.trace + import opentelemetry._logs + except (ImportError, AttributeError): + _warn( + "Could not force flush telemetry data. opentelemetry-api is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + try: + import opentelemetry.sdk.trace + import opentelemetry.sdk._logs + except (ImportError, AttributeError): + _warn( + "Could not force flush telemetry data. opentelemetry-sdk is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + coros: List[Awaitable[bool]] = [] + + if tracing_enabled: + tracer_provider = opentelemetry.trace.get_tracer_provider() + if isinstance(tracer_provider, opentelemetry.sdk.trace.TracerProvider): + coros.append(asyncio.to_thread(tracer_provider.force_flush)) + + if logging_enabled: + logger_provider = opentelemetry._logs.get_logger_provider() + if isinstance(logger_provider, opentelemetry.sdk._logs.LoggerProvider): + coros.append(asyncio.to_thread(logger_provider.force_flush)) + + await asyncio.gather(*coros, return_exceptions=True) + + +def _default_instrumentor_builder( + project_id: Optional[str], + *, + enable_tracing: bool = False, + enable_logging: bool = False, +): + if not enable_tracing and not enable_logging: + return None + + if project_id is None: + _warn( + "telemetry is only supported when project is specified, proceeding with no telemetry" + ) + return None + + import os + + def _warn_missing_dependency( + package: str, + *, + needed_for_logging: bool = False, + needed_for_tracing: bool = False, + ) -> None: + _warn( + f"{package} is not installed. Please call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + MISSING_TRACE_IMPORT_ERROR_MESSAGE = "proceeding with tracing disabled because not all packages (i.e. `google-cloud-trace`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-trace`) for tracing have been installed" + MISSING_LOGGING_IMPORT_ERROR_MESSAGE = "proceeding with logging disabled because not all packages (i.e. 
`google-cloud-logging`, `opentelemetry-sdk`, `opentelemetry-exporter-gcp-logging`) for tracing have been installed" + + if needed_for_tracing and enable_tracing: + _warn(MISSING_TRACE_IMPORT_ERROR_MESSAGE) + if needed_for_logging and enable_logging: + _warn(MISSING_LOGGING_IMPORT_ERROR_MESSAGE) + return None + + def _detect_cloud_resource_id(project_id: str) -> Optional[str]: + location = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", "") or os.getenv( + "GOOGLE_CLOUD_LOCATION", "" + ) + agent_engine_id = os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID") + if all(v is not None for v in (location, agent_engine_id)): + return f"//aiplatform.googleapis.com/projects/{project_id}/locations/{location}/reasoningEngines/{agent_engine_id}" + return None + + try: + import opentelemetry + import opentelemetry.trace + import opentelemetry._logs + import opentelemetry._events + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-api", needed_for_tracing=True, needed_for_logging=True + ) + + try: + import opentelemetry.sdk.resources + import opentelemetry.sdk.trace + import opentelemetry.sdk.trace.export + import opentelemetry.sdk._logs + import opentelemetry.sdk._logs.export + import opentelemetry.sdk._events + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-sdk", needed_for_tracing=True, needed_for_logging=True + ) + + import uuid + + # Provide a set of resource attributes but allow to override them with env + # variables like OTEL_RESOURCE_ATTRIBUTES and OTEL_SERVICE_NAME. + cloud_resource_id = _detect_cloud_resource_id(project_id) + resource = opentelemetry.sdk.resources.Resource.create( + attributes={ + "gcp.project_id": project_id, + "cloud.account.id": project_id, + "cloud.provider": "gcp", + "cloud.platform": "gcp.agent_engine", + "service.name": os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_ID", ""), + "service.instance.id": f"{uuid.uuid4().hex}-{os.getpid()}", + "cloud.region": ( + os.getenv("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", "") + or os.getenv("GOOGLE_CLOUD_LOCATION", "") + ), + } + | ( + {"cloud.resource_id": cloud_resource_id} + if cloud_resource_id is not None + else {} + ) + ).merge(opentelemetry.sdk.resources.OTELResourceDetector().detect()) + + if enable_tracing: + try: + import opentelemetry.exporter.otlp.proto.http.version + import opentelemetry.exporter.otlp.proto.http.trace_exporter + import google.auth.transport.requests + from google.cloud.aiplatform import version as aip_version + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-exporter-otlp-proto-http", needed_for_tracing=True + ) + + import google.auth + + credentials, _ = google.auth.default() + vertex_sdk_version = aip_version.__version__ + otlp_http_version = opentelemetry.exporter.otlp.proto.http.version.__version__ + user_agent = f"Vertex-Agent-Engine/{vertex_sdk_version} OTel-OTLP-Exporter-Python/{otlp_http_version}" + + span_exporter = ( + opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter( + session=google.auth.transport.requests.AuthorizedSession( + credentials=credentials + ), + endpoint="https://telemetry.googleapis.com/v1/traces", + headers={"User-Agent": user_agent}, + ) + ) + span_processor = opentelemetry.sdk.trace.export.BatchSpanProcessor( + span_exporter=span_exporter, + ) + tracer_provider = opentelemetry.trace.get_tracer_provider() + # Get the appropriate tracer provider: + # 1. If _TRACER_PROVIDER is already set, use that. + # 2. 
Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment + # variable is set, use that. + # 3. As a final fallback, use _PROXY_TRACER_PROVIDER. + # If none of the above is set, we log a warning, and + # create a tracer provider. + if not tracer_provider: + _warn( + "No tracer provider. By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry.sdk.trace.TracerProvider(resource=resource) + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + from agentplatform._genai import _agent_engines_utils + + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry.sdk.trace.TracerProvider(resource=resource) + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry.sdk.trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + + if enable_logging: + try: + import opentelemetry.exporter.cloud_logging + except (ImportError, AttributeError): + return _warn_missing_dependency( + "opentelemetry-exporter-gcp-logging", needed_for_logging=True + ) + + class _SimpleLogRecordProcessor( + opentelemetry.sdk._logs.export.SimpleLogRecordProcessor + ): + def force_flush( + self, timeout_millis: int = 30000 + ) -> bool: # pylint: disable=no-self-use + sys.stdout.flush() + sys.stderr.flush() + return True + + logger_provider = opentelemetry.sdk._logs.LoggerProvider(resource=resource) + # Use the legacy log processor when experimental semconv is enabled. + # Exporting JSON logs to stdout is bugged; Agent Engine fails to + # correctly parse the `gen_ai.client.inference.operation.details` + # messages. + # TODO: b/480102541 - Unify both branches once the regression is fixed. + if "gen_ai_latest_experimental" in os.getenv( + "OTEL_SEMCONV_STABILITY_OPT_IN", "" + ).split(","): + logger_provider.add_log_record_processor( + opentelemetry.sdk._logs.export.BatchLogRecordProcessor( + opentelemetry.exporter.cloud_logging.CloudLoggingExporter( + project_id=project_id, + default_log_name=os.getenv( + "GCP_DEFAULT_LOG_NAME", "adk-on-agent-engine" + ), + ), + ) + ) + else: + logger_provider.add_log_record_processor( + _SimpleLogRecordProcessor( + opentelemetry.exporter.cloud_logging.CloudLoggingExporter( + project_id=project_id, + default_log_name=os.getenv( + "GCP_DEFAULT_LOG_NAME", "adk-on-agent-engine" + ), + structured_json_file=sys.stdout, + ), + ) + ) + event_logger_provider = opentelemetry.sdk._events.EventLoggerProvider( + logger_provider=logger_provider + ) + + opentelemetry._logs.set_logger_provider(logger_provider=logger_provider) + opentelemetry._events.set_event_logger_provider( + event_logger_provider=event_logger_provider + ) + + try: + from opentelemetry.instrumentation import google_genai + + google_genai.GoogleGenAiSdkInstrumentor().instrument() + except (ImportError, AttributeError): + _warn( + "telemetry enabled but proceeding without GenAI instrumentation, because not all packages (i.e. 
opentelemetry-instrumentation-google-genai) have been installed" + ) + + return None + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple LangchainAgents in the same environment, + it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. + For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new LangchainAgent is created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. + """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +def _validate_run_config(run_config: Optional[Dict[str, Any]]): + """Validates the run config.""" + from google.adk.agents.run_config import RunConfig + + if run_config is None: + return None + elif isinstance(run_config, Dict): + return RunConfig.model_validate(run_config) + raise TypeError("run_config must be a dictionary representing a RunConfig object.") + + +def _warn_if_telemetry_api_disabled(): + """Warn if telemetry API is disabled.""" + try: + import google.auth.transport.requests + import google.auth + except (ImportError, AttributeError): + return + credentials, project = google.auth.default() + session = google.auth.transport.requests.AuthorizedSession(credentials=credentials) + r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + if "Telemetry API has not been used in project" in r.text: + _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) + + +class AdkApp: + """An ADK Application.""" + + agent_framework = "google-adk" + + def __init__( + self, + *, + app: "App" = None, + agent: "BaseAgent" = None, + app_name: Optional[str] = None, + plugins: Optional[List["BasePlugin"]] = None, + enable_tracing: Optional[bool] = None, + session_service_builder: Optional[Callable[..., "BaseSessionService"]] = None, + artifact_service_builder: Optional[Callable[..., "BaseArtifactService"]] = None, + memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None, + credential_service_builder: Optional[ + Callable[..., "BaseCredentialService"] + ] = None, + instrumentor_builder: Optional[Callable[..., Any]] = None, + ): + """An ADK Application. + + See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/adk + for details on how to develop ADK applications on Agent Engine. + + Args: + agent (google.adk.agents.BaseAgent): + Required. The ADK agent to run. + app_name (str): + Optional. The name of the ADK application. Defaults to + "default-app-name" when running locally, and to the + corresponding agent engine ID when deployed on Agent Engine. + plugins (List[BasePlugin]): + Optional. The plugins to use for the ADK application. + Defaults to an empty list. + enable_tracing (bool): + Optional. Whether to enable tracing in Cloud Trace. Defaults to + False. 
+ session_service_builder (Callable[..., BaseSessionService]): + Optional. A callable that returns an ADK session service. + Defaults to a callable that returns InMemorySessionService + when running locally and VertexAiSessionService when running + on Agent Engine. + artifact_service_builder (Callable[..., BaseArtifactService]): + Optional. A callable that returns an ADK artifact service. + Defaults to a callable that returns InMemoryArtifactService. + memory_service_builder (Callable[..., BaseMemoryService]): + Optional. A callable that returns an ADK memory service. + Defaults to a callable that returns InMemoryMemoryService + when running locally and VertexAiMemoryBankService when running + on Agent Engine. + credential_service_builder (Callable[..., BaseCredentialService]): + Optional. A callable that returns an ADK credential service. + Defaults to a callable that returns InMemoryCredentialService. + instrumentor_builder (Callable[..., Any]): + Optional. Callable that returns a new instrumentor. This can be + used for customizing the instrumentation logic of the Agent. + If not provided, a default instrumentor builder will be used. + This parameter is ignored if `enable_tracing` is False. + """ + import os + from google.cloud.aiplatform import initializer + + adk_version = get_adk_version() + if not is_version_sufficient("1.5.0"): + msg = ( + f"Unsupported google-adk version: {adk_version}, please use " + "google-adk>=1.5.0 for AdkApp deployment on Agent Engine." + ) + raise ValueError(msg) + + if not agent and not app: + raise ValueError("One of `agent` or `app` must be provided.") + if app: + if app_name: + raise ValueError( + "When app is provided, app_name should not be provided, " + "since it will be derived from app.name." + ) + if agent: + raise ValueError("When app is provided, agent should not be provided.") + if plugins: + raise ValueError( + "When app is provided, plugins should not be provided and" + " should be provided in the app instead." 
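Reviewer note: a hedged construction sketch of the `agent` / `app` exclusivity rules enforced above. The agent itself is illustrative, and `Agent` is assumed to be exported by `google.adk.agents` in google-adk>=1.5.0; `my_app` is a hypothetical prebuilt App.

```python
from google.adk.agents import Agent

root_agent = Agent(
    name="root_agent",
    model="gemini-2.0-flash",  # illustrative model name
    instruction="Answer user questions.",
)

app = AdkApp(agent=root_agent, app_name="demo-app")  # OK
# Each of the following raises ValueError per the checks above:
#   AdkApp()                                 # neither agent nor app
#   AdkApp(app=my_app, agent=root_agent)     # both provided
#   AdkApp(app=my_app, app_name="demo-app")  # app_name comes from app.name
```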
+ ) + + self._tmpl_attrs: Dict[str, Any] = { + "agent": agent, + "app": app, + "app_name": app_name, + "plugins": plugins, + "enable_tracing": enable_tracing, + "session_service_builder": session_service_builder, + "artifact_service_builder": artifact_service_builder, + "memory_service_builder": memory_service_builder, + "credential_service_builder": credential_service_builder, + "instrumentor_builder": instrumentor_builder, + } + + def _serialize(self, obj: Any) -> Any: + """Serializes an object to be JSON compatible.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + elif hasattr(obj, "dict"): + return self._serialize(obj.dict()) + elif isinstance(obj, dict): + return {k: self._serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._serialize(v) for v in obj] + return obj + + def _app_name(self) -> str: + """Returns the app name.""" + app = self._tmpl_attrs.get("app") + return app.name if app else self._tmpl_attrs.get("app_name") + + async def _init_session( + self, + session_service: "BaseSessionService", + artifact_service: "BaseArtifactService", + request: _StreamRunRequest, + ): + """Initializes the session, and returns the session id.""" + from google.adk.events.event import Event + + session_state = None + if request.authorizations: + session_state = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + session_state[auth_id] = auth.access_token + + session = await session_service.create_session( + app_name=self._app_name(), + user_id=request.user_id, + state=session_state, + ) + if not session: + raise RuntimeError("Create session failed.") + if request.events: + for event in request.events: + await session_service.append_event(session, Event(**event)) + if request.artifacts: + await self._save_artifacts(session.id, artifact_service, request) + return session + + async def _save_artifacts( + self, + session_id: str, + artifact_service: "BaseArtifactService", + request: _StreamRunRequest, + ): + """Saves the artifacts.""" + if request.artifacts: + for artifact in request.artifacts: + artifact = _Artifact(**artifact) + for version_data in sorted( + artifact.versions, key=lambda x: x["version"] + ): + version_data = _ArtifactVersion(**version_data) + saved_version = await artifact_service.save_artifact( + app_name=self._app_name(), + user_id=request.user_id, + session_id=session_id, + filename=artifact.file_name, + artifact=version_data.data, + ) + if saved_version != version_data.version: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.debug( + "Artifact '%s' saved at version %s instead of %s", + artifact.file_name, + saved_version, + version_data.version, + ) + + async def _convert_response_events( + self, + user_id: str, + session_id: str, + events: List["Event"], + artifact_service: Optional["BaseArtifactService"], + ) -> _StreamingRunResponse: + """Converts the events to the streaming run response object.""" + import collections + + result = _StreamingRunResponse( + events=events, artifacts=[], session_id=session_id + ) + + # Save the generated artifacts into the result object. 
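Reviewer note: for illustration, a hedged sketch of the artifact payload that `_save_artifacts` walks. File name and contents are placeholders; each `data` entry must validate as a `google.genai` `Part`.

```python
request = _StreamRunRequest(
    message={"role": "user", "parts": [{"text": "summarize my notes"}]},
    user_id="user-1",
    artifacts=[
        {
            "file_name": "notes.txt",
            "versions": [
                {"version": 0, "data": {"text": "draft one"}},
                {"version": 1, "data": {"text": "draft two"}},
            ],
        }
    ],
)
# _save_artifacts replays the versions in ascending order, so the
# artifact service ends up holding the latest version.
```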
+ artifact_versions = collections.defaultdict(list) + for event in events: + if event.actions and event.actions.artifact_delta: + for key, version in event.actions.artifact_delta.items(): + artifact_versions[key].append(version) + + for key, versions in artifact_versions.items(): + result.artifacts.append( + _Artifact( + file_name=key, + versions=[ + _ArtifactVersion( + version=version, + data=await artifact_service.load_artifact( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + filename=key, + version=version, + ), + ) + for version in versions + ], + ) + ) + + return result.dump() + + def clone(self): + """Returns a clone of the ADK application.""" + import copy + + return self.__class__( + app=copy.deepcopy(self._tmpl_attrs.get("app")), + enable_tracing=self._tmpl_attrs.get("enable_tracing"), + agent=( + None + if self._tmpl_attrs.get("app") + else copy.deepcopy(self._tmpl_attrs.get("agent")) + ), + app_name=( + None + if self._tmpl_attrs.get("app") + else self._tmpl_attrs.get("app_name") + ), + plugins=( + None + if self._tmpl_attrs.get("app") + else copy.deepcopy(self._tmpl_attrs.get("plugins")) + ), + session_service_builder=self._tmpl_attrs.get("session_service_builder"), + artifact_service_builder=self._tmpl_attrs.get("artifact_service_builder"), + memory_service_builder=self._tmpl_attrs.get("memory_service_builder"), + credential_service_builder=self._tmpl_attrs.get( + "credential_service_builder" + ), + instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), + ) + + def set_up(self): + """Sets up the ADK application.""" + import os + from google.adk.runners import Runner + from google.adk.sessions.in_memory_session_service import InMemorySessionService + from google.adk.artifacts.in_memory_artifact_service import ( + InMemoryArtifactService, + ) + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + from google.adk.auth.credential_service.in_memory_credential_service import ( + InMemoryCredentialService, + ) + + os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" + ) or os.environ.get("GOOGLE_CLOUD_LOCATION") + if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location + if "GOOGLE_CLOUD_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_LOCATION"] = location + agent_engine_location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", # the runtime env var (if set) + location, # the location set in the AdkApp template + ) + express_mode_api_key = os.environ.get("GOOGLE_API_KEY") + if express_mode_api_key and not project: + os.environ["GOOGLE_API_KEY"] = express_mode_api_key + # Clear location and project env vars if express mode api key is provided. + os.environ.pop("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_PROJECT", None) + location = None + + # Disable content capture in custom ADK spans unless user enabled + # tracing explicitly with the old flag + # (this is to preserve compatibility with old behavior). 
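Reviewer note: a hedged sketch of swapping in a custom session store via the builder hooks documented in `__init__` above. It assumes `DatabaseSessionService` is available in the installed google-adk; the `db_url` is a placeholder, and `root_agent` is the illustrative agent from the earlier sketch.

```python
def build_session_service():
    # Deferred import so the dependency is only needed at set_up() time.
    from google.adk.sessions import DatabaseSessionService

    return DatabaseSessionService(db_url="sqlite:///./sessions.db")

app = AdkApp(
    agent=root_agent,
    session_service_builder=build_session_service,
)
```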
+ if self._tmpl_attrs.get("enable_tracing"): + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" + else: + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" + + enable_logging = bool(self._telemetry_enabled()) + + custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder") + + if self._tmpl_attrs.get("enable_tracing"): + _warn_if_telemetry_api_disabled() + + if self._tmpl_attrs.get("enable_tracing") is False: + _warn( + ( + "Your 'enable_tracing=False' setting is being deprecated " + "and will be removed in a future release.\n" + "This legacy setting overrides the new Cloud Console " + "toggle and environment variable controls.\n" + "Impact: The Cloud Console may incorrectly show telemetry " + "as 'On' when it is actually 'Off', and the UI toggle will " + "not work.\n" + "Action: To fix this and control telemetry, please remove " + "the 'enable_tracing' parameter from your deployment " + "code.\n" + "You can then use the " + "'GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY' " + "environment variable:\n" + "agent_engines.create(\n" + " env_vars={\n" + ' "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY": true|false\n' + " }\n" + ")\n" + "or the toggle in the Cloud Console: " + "https://console.cloud.google.com/vertex-ai/agents." + ), + ) + + project_id = self._get_project_id(project) + if custom_instrumentor and self._tracing_enabled(): + self._tmpl_attrs["instrumentor"] = custom_instrumentor(project_id) + + if not custom_instrumentor: + self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( + project_id, + enable_tracing=self._tracing_enabled(), + enable_logging=enable_logging, + ) + + if not self._tmpl_attrs.get("app_name"): + if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + self._tmpl_attrs["app_name"] = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_ID", + ) + else: + self._tmpl_attrs["app_name"] = _DEFAULT_APP_NAME + + artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") + if artifact_service_builder: + self._tmpl_attrs["artifact_service"] = artifact_service_builder() + else: + self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() + + session_service_builder = self._tmpl_attrs.get("session_service_builder") + if session_service_builder: + self._tmpl_attrs["session_service"] = session_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + try: + from google.adk.sessions.vertex_ai_session_service import ( + VertexAiSessionService, + ) + + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. + self._tmpl_attrs["session_service"] = VertexAiSessionService( + project=project, + location=agent_engine_location, + agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), + ) + except (ImportError, AttributeError): + from google.adk.sessions.vertex_ai_session_service_g3 import ( + VertexAiSessionService, + ) + + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. 
+ self._tmpl_attrs["session_service"] = VertexAiSessionService( + project=project, + location=agent_engine_location, + agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), + ) + + else: + self._tmpl_attrs["session_service"] = InMemorySessionService() + + memory_service_builder = self._tmpl_attrs.get("memory_service_builder") + if memory_service_builder: + self._tmpl_attrs["memory_service"] = memory_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( + "1.5.0" + ): + try: + from google.adk.memory.vertex_ai_memory_bank_service import ( + VertexAiMemoryBankService, + ) + + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + project=project, + location=agent_engine_location, + agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), + ) + except (ImportError, AttributeError): + from google.adk.memory.vertex_ai_memory_bank_service_g3 import ( + VertexAiMemoryBankService, + ) + + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + project=project, + location=agent_engine_location, + agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), + ) + else: + self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + + credential_service_builder = self._tmpl_attrs.get("credential_service_builder") + if credential_service_builder: + self._tmpl_attrs["credential_service"] = credential_service_builder() + else: + self._tmpl_attrs["credential_service"] = InMemoryCredentialService() + + self._tmpl_attrs["runner"] = Runner( + app=self._tmpl_attrs.get("app"), + agent=( + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") + ), + app_name=( + None + if self._tmpl_attrs.get("app") + else self._tmpl_attrs.get("app_name") + ), + plugins=( + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins") + ), + session_service=self._tmpl_attrs.get("session_service"), + artifact_service=self._tmpl_attrs.get("artifact_service"), + memory_service=self._tmpl_attrs.get("memory_service"), + ) + self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() + self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() + self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() + self._tmpl_attrs["in_memory_runner"] = Runner( + app=self._tmpl_attrs.get("app"), + app_name=( + None + if self._tmpl_attrs.get("app") + else self._tmpl_attrs.get("app_name") + ), + agent=( + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") + ), + plugins=( + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("plugins") + ), + session_service=self._tmpl_attrs.get("in_memory_session_service"), + artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), + memory_service=self._tmpl_attrs.get("in_memory_memory_service"), + credential_service=self._tmpl_attrs.get("credential_service"), + ) + + async def async_stream_query( + self, + *, + message: Union[str, Dict[str, Any]], + user_id: str, + session_id: Optional[str] = None, + session_events: Optional[List[Dict[str, Any]]] = None, + run_config: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> AsyncIterable[Dict[str, Any]]: + """Streams responses asynchronously from the ADK application. + + Args: + message (str): + Required. 
The message to stream responses for. + user_id (str): + Required. The ID of the user. + session_id (str): + Optional. The ID of the session. If not provided, a new + session will be created for the user. If this is specified, then + `session_events` will be ignored. + session_events (Optional[List[Dict[str, Any]]]): + Optional. The session events to use for the query. This will be + used to initialize the session if `session_id` is not provided. + run_config (Optional[Dict[str, Any]]): + Optional. The run config to use for the query. If you want to + pass in a `run_config` pydantic object, you can pass in a dict + representing it as `run_config.model_dump(mode="json")`. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + runner. + + Yields: + Event dictionaries asynchronously. + + Raises: + TypeError: If message is not a string or a dictionary representing + a Content object. + ValueError: If both session_id and session_events are specified. + """ + from agentplatform._genai import _agent_engines_utils + from google.genai import types + + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( + "message must be a string or a dictionary representing" + " a Content object." + ) + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if session_id and session_events: + raise ValueError( + "Only one of session_id and session_events should be specified." + ) + if not session_id: + session = await self.async_create_session(user_id=user_id) + session_id = session["id"] + if session_events is not None: + # We allow for session_events to be an empty list. + from google.adk.events.event import Event + + session_service = self._tmpl_attrs.get("session_service") + for event in session_events: + if not isinstance(event, Event): + event = Event.model_validate(event) + await session_service.append_event( + session=session, + event=event, + ) + + run_config = _validate_run_config(run_config) + if run_config: + events_async = self._tmpl_attrs.get("runner").run_async( + user_id=user_id, + session_id=session_id, + new_message=content, + run_config=run_config, + **kwargs, + ) + else: + events_async = self._tmpl_attrs.get("runner").run_async( + user_id=user_id, + session_id=session_id, + new_message=content, + **kwargs, + ) + + try: + async for event in events_async: + # Yield the event data as a dictionary + yield _agent_engines_utils.dump_event_for_json(event) + finally: + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( + tracing_enabled=self._tracing_enabled(), + logging_enabled=bool(self._telemetry_enabled()), + ) + + def stream_query( + self, + *, + message: Union[str, Dict[str, Any]], + user_id: str, + session_id: Optional[str] = None, + run_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Deprecated. Use async_stream_query instead. + + Streams responses from the ADK application in response to a message. + + Args: + message (Union[str, Dict[str, Any]]): + Required. The message to stream responses for. + user_id (str): + Required. The ID of the user. + session_id (str): + Optional. The ID of the session. If not provided, a new + session will be created for the user. + run_config (Optional[Dict[str, Any]]): + Optional. The run config to use for the query. 
If you want to + pass in a `run_config` pydantic object, you can pass in a dict + representing it as `run_config.model_dump(mode="json")`. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + runner. + + Yields: + The output of querying the ADK application. + """ + warnings.warn( + ( + "AdkApp.stream_query(...) is deprecated. " + "Use AdkApp.async_stream_query(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#stream-responses " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + from agentplatform._genai import _agent_engines_utils + from google.genai import types + + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( + "message must be a string or a dictionary representing" + " a Content object." + ) + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = self.create_session(user_id=user_id) + session_id = session["id"] + run_config = _validate_run_config(run_config) + if run_config: + for event in self._tmpl_attrs.get("runner").run( + user_id=user_id, + session_id=session_id, + new_message=content, + run_config=run_config, + **kwargs, + ): + yield _agent_engines_utils.dump_event_for_json(event) + else: + for event in self._tmpl_attrs.get("runner").run( + user_id=user_id, + session_id=session_id, + new_message=content, + **kwargs, + ): + yield _agent_engines_utils.dump_event_for_json(event) + + async def streaming_agent_run_with_events(self, request_json: str): + """Streams responses asynchronously from the ADK application. + + In general, you should use `async_stream_query` instead, as it has a + more structured API and works with the respective ADK services that + you have defined for the AdkApp. This method is primarily meant for + invocation from AgentSpace. + + Args: + request_json (str): + Required. The request to stream responses for. + """ + + import json + from google.genai import types + from google.genai.errors import ClientError + + request = _StreamRunRequest(**json.loads(request_json)) + if not any( + self._tmpl_attrs.get(service) + for service in ( + "in_memory_runner", + "runner", + "in_memory_artifact_service", + "artifact_service", + "in_memory_session_service", + "session_service", + "in_memory_memory_service", + "memory_service", + ) + ): + self.set_up() + + # Try to get the session, if it doesn't exist, create a new one. + if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") + session = None + try: + session = await session_service.get_session( + app_name=self._app_name(), + user_id=request.user_id, + session_id=request.session_id, + ) + if session: + await self._save_artifacts( + session_id=request.session_id, + artifact_service=artifact_service, + request=request, + ) + except ClientError: + pass + if not session: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + else: + # Not providing a session ID will create a new in-memory session. 
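Reviewer note: a hedged end-to-end sketch of the preferred streaming entry point, `async_stream_query` (the deprecated `stream_query` above drives the same runner synchronously). It assumes google-adk>=1.5.0 plus application-default credentials, and reuses the illustrative `root_agent` from the earlier sketch; yielded events are plain dicts produced by `dump_event_for_json`.

```python
import asyncio


async def main() -> None:
    app = AdkApp(agent=root_agent)
    async for event in app.async_stream_query(
        message="What can you do?",
        user_id="user-1",  # no session_id, so a new session is created
    ):
        print(event)


asyncio.run(main())
```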
+ session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await self._init_session( + session_service=session_service, + artifact_service=artifact_service, + request=request, + ) + if not session: + raise RuntimeError("Session initialization failed.") + + # Run the agent + message_for_agent = types.Content(**request.message) + try: + async for event in runner.run_async( + user_id=request.user_id, + session_id=session.id, + new_message=message_for_agent, + ): + converted_event = await self._convert_response_events( + user_id=request.user_id, + session_id=session.id, + events=[event], + artifact_service=artifact_service, + ) + yield converted_event + finally: + if session and not request.session_id: + await session_service.delete_session( + app_name=self._app_name(), + user_id=request.user_id, + session_id=session.id, + ) + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( + tracing_enabled=self._tracing_enabled(), + logging_enabled=bool(self._telemetry_enabled()), + ) + + async def bidi_stream_query( + self, + request_queue: Any, + ) -> AsyncIterable[Any]: + """Bidi streaming query the ADK application. + + Args: + request_queue: + The queue of requests to stream responses for, with the type of + asyncio.Queue[Any]. + + Raises: + TypeError: If the request_queue is not an asyncio.Queue instance. + ValueError: If the first request does not have a user_id. + ValidationError: If failed to convert to LiveRequest. + + Yields: + The stream responses of querying the ADK application. + """ + from google.adk.agents.live_request_queue import LiveRequest + from google.adk.agents.live_request_queue import LiveRequestQueue + from agentplatform._genai import _agent_engines_utils + + # Manual type check needed as Pydantic doesn't support asyncio.Queue. + if not isinstance(request_queue, asyncio.Queue): + raise TypeError("request_queue must be an asyncio.Queue instance.") + + first_request = await request_queue.get() + user_id = first_request.get("user_id") + if not user_id: + raise ValueError("The first request must have a user_id.") + + session_id = first_request.get("session_id") + run_config = first_request.get("run_config") + first_live_request = first_request.get("live_request") + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + state = first_request.get("state") + session = await self.async_create_session(user_id=user_id, state=state) + session_id = session["id"] if isinstance(session, dict) else session.id + run_config = _validate_run_config(run_config) + + live_request_queue = LiveRequestQueue() + + if first_live_request and isinstance(first_live_request, Dict): + live_request_queue.send(LiveRequest.model_validate(first_live_request)) + + # Forwards live requests to the agent. + async def _forward_requests(): + while True: + request = await request_queue.get() + live_request = LiveRequest.model_validate(request) + live_request_queue.send(live_request) + + # Forwards events to the client. 
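Reviewer note: a hedged sketch of driving `bidi_stream_query` with an `asyncio.Queue`. Per the code above, the first item must carry `user_id`; each subsequent item must validate as an ADK `LiveRequest` dict, whose exact schema (including the `close` flag used to end the stream) is assumed from google-adk.

```python
import asyncio


async def bidi_demo(app: AdkApp) -> None:
    requests: asyncio.Queue = asyncio.Queue()
    # First request bootstraps the session; user_id is required.
    await requests.put({"user_id": "user-1"})
    # Subsequent requests are validated as LiveRequest objects.
    await requests.put(
        {"content": {"role": "user", "parts": [{"text": "hello"}]}}
    )
    # Assumed LiveRequest.close flag; ends the live stream.
    await requests.put({"close": True})
    async for event in app.bidi_stream_query(requests):
        print(event)
```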
+ async def _forward_events(): + if run_config: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + run_config=run_config, + ) + else: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + ) + async for event in events_async: + yield _agent_engines_utils.dump_event_for_json(event) + + requests_task = asyncio.create_task(_forward_requests()) + + try: + async for event in _forward_events(): + yield event + finally: + requests_task.cancel() + try: + await requests_task + except asyncio.CancelledError: + pass + + async def async_get_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Get a session for the given user. + + Args: + user_id (str): + Required. The ID of the user. + session_id (str): + Required. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + session service. + + Returns: + Session: The session instance (if any). It returns None if the + session is not found. + + Raises: + RuntimeError: If the session is not found. + """ + if not self._tmpl_attrs.get("session_service"): + self.set_up() + session = await self._tmpl_attrs.get("session_service").get_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + if not session: + raise RuntimeError( + "Session not found. Please create it using .create_session()" + ) + return session + + def get_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Deprecated. Use async_get_session instead. + + Get a session for the given user. + """ + warnings.warn( + ( + "AdkApp.get_session(...) is deprecated. " + "Use AdkApp.async_get_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#get-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_get_session(): + return await self.async_get_session( + user_id=user_id, session_id=session_id, **kwargs + ) + + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_get_session()) + event_queue.put(result) + except Exception as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + + # Wait for the thread to finish + thread.join() + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError( + "Session not found. Please create it using .create_session()" + ) from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + + async def async_list_sessions(self, *, user_id: str, **kwargs): + """List sessions for the given user. + + Args: + user_id (str): + Required. The ID of the user. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + session service. + + Returns: + ListSessionsResponse: The list of sessions. + """ + if not self._tmpl_attrs.get("session_service"): + self.set_up() + return await self._tmpl_attrs.get("session_service").list_sessions( + app_name=self._app_name(), + user_id=user_id, + **kwargs, + ) + + def list_sessions(self, *, user_id: str, **kwargs): + """Deprecated. Use async_list_sessions instead. + + List sessions for the given user. + """ + warnings.warn( + ( + "AdkApp.list_sessions(...) is deprecated. " + "Use AdkApp.async_list_sessions(...) instead. 
See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#list-sessions " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue() + + async def _invoke_async_list_sessions(): + try: + response = await self.async_list_sessions(user_id=user_id, **kwargs) + event_queue.put(response) + except RuntimeError as e: + event_queue.put(e) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_list_sessions()) + finally: + event_queue.put(None) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + try: + return event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to list sessions.") from None + + async def async_create_session( + self, + *, + user_id: str, + session_id: Optional[str] = None, + state: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Creates a new session. + + Args: + user_id (str): + Required. The ID of the user. + session_id (str): + Optional. The ID of the session. If not provided, an ID + will be be generated for the session. + state (dict[str, Any]): + Optional. The initial state of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + session service. + + Returns: + Session: The newly created session instance. + """ + if not self._tmpl_attrs.get("session_service"): + self.set_up() + session = await self._tmpl_attrs.get("session_service").create_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + state=state, + **kwargs, + ) + return self._serialize(session) + + def create_session( + self, + *, + user_id: str, + session_id: Optional[str] = None, + state: Optional[Dict[str, Any]] = None, + **kwargs, + ): + """Deprecated. Use async_create_session instead. + + Creates a new session. + """ + warnings.warn( + ( + "AdkApp.create_session(...) is deprecated. " + "Use AdkApp.async_create_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#create-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_create_session(): + return await self.async_create_session( + user_id=user_id, + session_id=session_id, + state=state, + **kwargs, + ) + + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_create_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to create session.") from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + + async def async_delete_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Deletes a session for the given user. + + Args: + user_id (str): + Required. The ID of the user. + session_id (str): + Required. The ID of the session. + **kwargs (dict[str, Any]): + Optional. Additional keyword arguments to pass to the + session service. 
+ """ + if not self._tmpl_attrs.get("session_service"): + self.set_up() + await self._tmpl_attrs.get("session_service").delete_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + + def delete_session( + self, + *, + user_id: str, + session_id: str, + **kwargs, + ): + """Deprecated. Use async_delete_session instead. + + Deletes a session for the given user. + """ + warnings.warn( + ( + "AdkApp.delete_session(...) is deprecated. " + "Use AdkApp.async_delete_session(...) instead. See " + "https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/use/adk#delete-session " + "for more details." + ), + DeprecationWarning, + stacklevel=2, + ) + event_queue = queue.Queue(maxsize=1) + + async def _invoke_async_delete_session(): + await self.async_delete_session( + user_id=user_id, session_id=session_id, **kwargs + ) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_delete_session()) + event_queue.put(None) + except RuntimeError as e: + event_queue.put(e) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + + outcome = event_queue.get(timeout=10) + if isinstance(outcome, RuntimeError): + raise outcome from None + + async def async_add_session_to_memory(self, *, session: Dict[str, Any]): + """Generates memories. + + Args: + session (Dict[str, Any]): + Required. The session to use for generating memories. It should + be a dictionary representing an ADK Session object, e.g. + session.model_dump(mode="json"). + """ + from google.adk.sessions.session import Session + + if isinstance(session, Dict): + session = Session.model_validate(session) + elif not isinstance(session, Session): + raise TypeError("session must be a Session object.") + if not session.events: + # Get the latest version of the session in case it was updated. + session = await self.async_get_session( + user_id=session.user_id, + session_id=session.id, + ) + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").add_session_to_memory( + session=session, + ) + + async def async_search_memory(self, *, user_id: str, query: str): + """Searches memories for the given user. + + Args: + user_id: The id of the user. + query: The query to match the memories on. + + Returns: + A SearchMemoryResponse containing the matching memories. + """ + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").search_memory( + app_name=self._app_name(), + user_id=user_id, + query=query, + ) + + def register_operations(self) -> Dict[str, List[str]]: + """Registers the operations of the ADK application.""" + return { + "": [ + "get_session", + "list_sessions", + "create_session", + "delete_session", + ], + "async": [ + "async_get_session", + "async_list_sessions", + "async_create_session", + "async_delete_session", + "async_add_session_to_memory", + "async_search_memory", + ], + "stream": ["stream_query"], + "async_stream": [ + "async_stream_query", + "streaming_agent_run_with_events", + ], + "bidi_stream": ["bidi_stream_query"], + } + + def _telemetry_enabled(self) -> Optional[bool]: + """Return status of telemetry enablement depending on enablement env variable. + + In detail: + - Logging is always enabled when telemetry is enabled. + - Tracing is enabled depending on the truth table seen in `_tracing_enabled` method, in order to not break existing user enablement. 
+
+        Returns:
+            True if telemetry is enabled, False if telemetry is disabled, or
+            None if telemetry enablement is not set (i.e. old deployments that
+            do not support this env variable).
+        """
+        import os
+
+        GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = (
+            "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"
+        )
+
+        env_value = os.getenv(
+            GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "unspecified"
+        ).lower()
+
+        if env_value in ("true", "1"):
+            return True
+        if env_value in ("false", "0"):
+            return False
+        return None
+
+    def _tracing_enabled(self) -> bool:
+        """Tracing enablement follows the truth table:
+
+        | enable_tracing | enable_telemetry(env) | tracing_actually_enabled |
+        |----------------|-----------------------|--------------------------|
+        | false          | false                 | false                    |
+        | false          | true                  | false                    |
+        | false          | None                  | false                    |
+        | true           | false                 | false                    |
+        | true           | true                  | true                     |
+        | true           | None                  | true                     |
+        | None(default)  | false                 | false                    |
+        | None(default)  | true                  | adk_version >= 1.17      |
+        | None(default)  | None                  | false                    |
+        """
+        enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing")
+        enable_telemetry: Optional[bool] = self._telemetry_enabled()
+
+        return (enable_tracing is True and enable_telemetry is not False) or (
+            enable_tracing is None
+            and enable_telemetry is True
+            and is_version_sufficient("1.17.0")
+        )
+
+    def _get_project_id(self, project: str) -> Optional[str]:
+        if project:
+            try:
+                from google.cloud.aiplatform.utils import (
+                    resource_manager_utils,
+                )
+                from google.api_core import exceptions
+
+                return resource_manager_utils.get_project_id(project)
+            # Fail open as a temporary workaround for the identity_type config
+            # parameter.
+            except (exceptions.PermissionDenied, exceptions.Unauthenticated):
+                return project
+
+        return None
diff --git a/agentplatform/frameworks/ag2.py b/agentplatform/frameworks/ag2.py
new file mode 100644
index 0000000000..30cbfe1a24
--- /dev/null
+++ b/agentplatform/frameworks/ag2.py
@@ -0,0 +1,501 @@
+# -*- coding: utf-8 -*-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
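Aside on the tracing logic above: the `_tracing_enabled` truth table can be hard to eyeball, so here is a minimal standalone sketch of the same decision. The names `tracing_enabled` and `adk_version_ok` are illustrative stand-ins for the method and its `is_version_sufficient("1.17.0")` check; they are not part of the diff.

```python
from typing import Optional


def tracing_enabled(
    enable_tracing: Optional[bool],
    enable_telemetry: Optional[bool],
    adk_version_ok: bool,
) -> bool:
    # Explicit opt-in wins unless telemetry is explicitly disabled.
    if enable_tracing is True:
        return enable_telemetry is not False
    # With no explicit setting, tracing rides on telemetry plus ADK >= 1.17.
    if enable_tracing is None:
        return enable_telemetry is True and adk_version_ok
    # enable_tracing is False: tracing stays off regardless of telemetry.
    return False


# Spot checks against the truth table in the docstring.
assert tracing_enabled(True, None, adk_version_ok=False) is True
assert tracing_enabled(None, True, adk_version_ok=True) is True
assert tracing_enabled(None, True, adk_version_ok=False) is False
assert tracing_enabled(False, True, adk_version_ok=True) is False
```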
+#
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)
+import os
+import copy
+
+if TYPE_CHECKING:
+    try:
+        from autogen import agentchat
+
+        ConversableAgent = agentchat.ConversableAgent
+        ChatResult = agentchat.ChatResult
+    except ImportError:
+        ConversableAgent = Any
+        ChatResult = Any
+
+    try:
+        from opentelemetry.sdk import trace
+
+        TracerProvider = trace.TracerProvider
+        SpanProcessor = trace.SpanProcessor
+        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
+    except ImportError:
+        TracerProvider = Any
+        SpanProcessor = Any
+        SynchronousMultiSpanProcessor = Any
+
+
+def _prepare_runnable_kwargs(
+    runnable_kwargs: Mapping[str, Any],
+    system_instruction: str,
+    runnable_name: str,
+    llm_config: Mapping[str, Any],
+) -> Mapping[str, Any]:
+    """Prepares the configuration for a runnable, applying defaults and enforcing constraints."""
+    if runnable_kwargs is None:
+        runnable_kwargs = {}
+
+    if (
+        "human_input_mode" in runnable_kwargs
+        and runnable_kwargs["human_input_mode"] != "NEVER"
+    ):
+        from google.cloud.aiplatform import base
+
+        _LOGGER = base.Logger(__name__)
+        _LOGGER.warning(
+            f"human_input_mode={runnable_kwargs['human_input_mode']} "
+            "is not supported. Will be enforced to 'NEVER'."
+        )
+        runnable_kwargs["human_input_mode"] = "NEVER"
+
+    if "system_message" not in runnable_kwargs and system_instruction:
+        runnable_kwargs["system_message"] = system_instruction
+
+    if "name" not in runnable_kwargs:
+        runnable_kwargs["name"] = runnable_name
+
+    if "llm_config" not in runnable_kwargs:
+        runnable_kwargs["llm_config"] = llm_config
+
+    return runnable_kwargs
+
+
+def _default_runnable_builder(
+    **runnable_kwargs: Any,
+) -> "ConversableAgent":
+    from autogen import agentchat
+
+    return agentchat.ConversableAgent(**runnable_kwargs)
+
+
+def _default_instrumentor_builder(project_id: str):
+    from agentplatform._genai import _agent_engines_utils
+
+    cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn()
+    cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn()
+    openinference_autogen = _agent_engines_utils._import_openinference_autogen_or_warn()
+    opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn()
+    opentelemetry_sdk_trace = (
+        _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn()
+    )
+    if all(
+        (
+            cloud_trace_exporter,
+            cloud_trace_v2,
+            openinference_autogen,
+            opentelemetry,
+            opentelemetry_sdk_trace,
+        )
+    ):
+        import google.auth
+
+        credentials, _ = google.auth.default()
+        span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
+            project_id=project_id,
+            client=cloud_trace_v2.TraceServiceClient(
+                credentials=credentials.with_quota_project(project_id),
+            ),
+        )
+        span_processor: SpanProcessor = (
+            opentelemetry_sdk_trace.export.SimpleSpanProcessor(
+                span_exporter=span_exporter,
+            )
+        )
+        tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider()
+        # Get the appropriate tracer provider:
+        # 1. If _TRACER_PROVIDER is already set, use that.
+        # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
+        #    variable is set, use that.
+        # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
+        # If none of the above is set, we log a warning, and
+        # create a tracer provider.
+        if not tracer_provider:
+            from google.cloud.aiplatform import base
+
+            _LOGGER = base.Logger(__name__)
+            _LOGGER.warning(
+                "No tracer provider. By default, "
By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. + # When creating multiple AG2Agents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. + instrumentor = openinference_autogen.AutogenInstrumentor() + instrumentor.uninstrument() + instrumentor.instrument() + return instrumentor + else: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "enable_tracing=True but proceeding with tracing disabled " + "because not all packages for tracing have been installed" + ) + return None + + +def _validate_callable_parameters_are_annotated(callable: Callable): + """Validates that the parameters of the callable have type annotations. + + This ensures that they can be used for constructing AG2 tools that are + usable with Gemini function calling. + """ + import inspect + + parameters = dict(inspect.signature(callable).parameters) + for name, parameter in parameters.items(): + if parameter.annotation == inspect.Parameter.empty: + raise TypeError( + f"Callable={callable.__name__} has untyped input_arg={name}. " + f"Please specify a type when defining it, e.g. `{name}: str`." + ) + + +def _validate_tools(tools: Sequence[Callable[..., Any]]): + """Validates that the tools are usable for tool calling.""" + for tool in tools: + if isinstance(tool, Callable): + _validate_callable_parameters_are_annotated(tool) + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple AG2Agents in the same environment, + it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. + For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new AG2Agent is created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. 
+ """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +class AG2Agent: + """An AG2 Agent.""" + + agent_framework = "ag2" + + def __init__( + self, + model: str, + runnable_name: str, + *, + api_type: Optional[str] = None, + llm_config: Optional[Mapping[str, Any]] = None, + system_instruction: Optional[str] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, + runnable_builder: Optional[Callable[..., "ConversableAgent"]] = None, + tools: Optional[Sequence[Callable[..., Any]]] = None, + enable_tracing: bool = False, + instrumentor_builder: Optional[Callable[..., Any]] = None, + ): + """Initializes the AG2 Agent. + + Under-the-hood, assuming .set_up() is called, this will correspond to + ```python + # runnable_builder + runnable = runnable_builder( + llm_config=llm_config, + system_message=system_instruction, + **runnable_kwargs, + ) + ``` + + When everything is based on their default values, this corresponds to + ```python + # llm_config + llm_config = { + "config_list": [{ + "project_id": initializer.global_config.project, + "location": initializer.global_config.location, + "model": "gemini-1.0-pro-001", + "api_type": "google", + }] + } + + # runnable_builder + runnable = ConversableAgent( + llm_config=llm_config, + name="Default AG2 Agent" + system_message="You are a helpful AI Assistant.", + human_input_mode="NEVER", + ) + ``` + + By default, if `llm_config` is not specified, a default configuration + will be created using the provided `model` and `api_type`. + + If `runnable_builder` is not specified, a default runnable builder will + be used, configured with the `system_instruction`, `runnable_name` and + `runnable_kwargs`. + + Args: + model (str): + Required. The name of the model (e.g. "gemini-1.0-pro"). + Used to create a default `llm_config` if one is not provided. + This parameter is ignored if `llm_config` is provided. + runnable_name (str): + Required. The name of the runnable. + This name is used as the default `runnable_kwargs["name"]` + unless `runnable_kwargs` already contains a "name", in which + case the provided `runnable_kwargs["name"]` will be used. + api_type (str): + Optional. The API type to use for the language model. + Used to create a default `llm_config` if one is not provided. + This parameter is ignored if `llm_config` is provided. + llm_config (Mapping[str, Any]): + Optional. Configuration dictionary for the language model. + If provided, this configuration will be used directly. + Otherwise, a default `llm_config` will be created using `model` + and `api_type`. This `llm_config` is used as the default + `runnable_kwargs["llm_config"]` unless `runnable_kwargs` already + contains a "llm_config", in which case the provided + `runnable_kwargs["llm_config"]` will be used. + system_instruction (str): + Optional. The system instruction for the agent. + This instruction is used as the default + `runnable_kwargs["system_message"]` unless `runnable_kwargs` + already contains a "system_message", in which case the provided + `runnable_kwargs["system_message"]` will be used. + runnable_kwargs (Mapping[str, Any]): + Optional. Additional keyword arguments for the constructor of + the runnable. Details of the kwargs can be found in + https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent. + `runnable_kwargs` only supports `human_input_mode="NEVER"`. + Other `human_input_mode` values will trigger a warning. 
+            runnable_builder (Callable[..., "ConversableAgent"]):
+                Optional. Callable that returns a new runnable. This can be used
+                for customizing the orchestration logic of the Agent.
+                If not provided, a default runnable builder will be used.
+            tools (Sequence[Callable[..., Any]]):
+                Optional. The tools for the agent to be able to use. All input
+                callables (e.g. function or class method) will be converted
+                to an AG2 tool. Defaults to None.
+            enable_tracing (bool):
+                Optional. Whether to enable tracing in Cloud Trace. Defaults to
+                False.
+            instrumentor_builder (Callable[..., Any]):
+                Optional. Callable that returns a new instrumentor. This can be
+                used for customizing the instrumentation logic of the Agent.
+                If not provided, a default instrumentor builder will be used.
+                This parameter is ignored if `enable_tracing` is False.
+        """
+        from google.cloud.aiplatform import initializer
+
+        self._tmpl_attrs: dict[str, Any] = {
+            "model_name": model,
+            "api_type": api_type or "google",
+            "system_instruction": system_instruction,
+            "runnable_name": runnable_name,
+            "tools": [],
+            "ag2_tool_objects": [],
+            "runnable": None,
+            "runnable_builder": runnable_builder,
+            "instrumentor": None,
+            "instrumentor_builder": instrumentor_builder,
+            "enable_tracing": enable_tracing,
+            "provided_llm_config": copy.deepcopy(llm_config),
+            "provided_runnable_kwargs": copy.deepcopy(runnable_kwargs),
+        }
+        if tools:
+            # We validate tools at initialization for actionable feedback before
+            # they are deployed.
+            _validate_tools(tools)
+            self._tmpl_attrs["tools"] = tools
+
+    def set_up(self):
+        """Sets up the agent for execution of queries at runtime.
+
+        It initializes the runnable and binds the runnable with tools.
+
+        This method should not be called for an object that is being passed to
+        the ReasoningEngine service for deployment, as it initializes clients
+        that cannot be serialized.
+        """
+        project = os.environ.get("GOOGLE_CLOUD_PROJECT")
+        location = os.environ.get(
+            "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"
+        ) or os.environ.get("GOOGLE_CLOUD_LOCATION")
+        llm_config = {
+            "config_list": [
+                {
+                    "project_id": project,
+                    "location": location,
+                    "model": self._tmpl_attrs.get("model_name"),
+                    "api_type": self._tmpl_attrs.get("api_type"),
+                }
+            ]
+        }
+        if self._tmpl_attrs.get("provided_llm_config"):
+            llm_config = self._tmpl_attrs.get("provided_llm_config")
+
+        runnable_kwargs = _prepare_runnable_kwargs(
+            runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"),
+            llm_config=llm_config,
+            system_instruction=self._tmpl_attrs.get("system_instruction"),
+            runnable_name=self._tmpl_attrs.get("runnable_name"),
+        )
+
+        if self._tmpl_attrs.get("enable_tracing"):
+            instrumentor_builder = (
+                self._tmpl_attrs.get("instrumentor_builder")
+                or _default_instrumentor_builder
+            )
+            self._tmpl_attrs["instrumentor"] = instrumentor_builder(project_id=project)
+
+        # Set up tools.
+        tools = self._tmpl_attrs.get("tools")
+        ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects")
+        if tools and not ag2_tool_objects:
+            from agentplatform._genai import (
+                _agent_engines_utils,
+            )
+
+            autogen_tools = _agent_engines_utils._import_autogen_tools_or_warn()
+            if autogen_tools:
+                for tool in tools:
+                    ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool))
+
+        # Set up runnable.
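+        # The default builder below instantiates autogen's ConversableAgent
+        # with the prepared runnable_kwargs; a custom runnable_builder is
+        # called with the same kwargs.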
+        runnable_builder = (
+            self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
+        )
+        self._tmpl_attrs["runnable"] = runnable_builder(**runnable_kwargs)
+
+    def clone(self) -> "AG2Agent":
+        """Returns a clone of the AG2Agent."""
+        import copy
+
+        return AG2Agent(
+            model=self._tmpl_attrs.get("model_name"),
+            api_type=self._tmpl_attrs.get("api_type"),
+            llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")),
+            system_instruction=self._tmpl_attrs.get("system_instruction"),
+            runnable_name=self._tmpl_attrs.get("runnable_name"),
+            tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
+            runnable_kwargs=copy.deepcopy(
+                self._tmpl_attrs.get("provided_runnable_kwargs")
+            ),
+            runnable_builder=self._tmpl_attrs.get("runnable_builder"),
+            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
+            instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
+        )
+
+    def query(
+        self,
+        *,
+        input: Union[str, Mapping[str, Any]],
+        max_turns: Optional[int] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Queries the Agent with the given input.
+
+        Args:
+            input (Union[str, Mapping[str, Any]]):
+                Required. The input to be passed to the Agent.
+            max_turns (int):
+                Optional. The maximum number of turns to run the agent for.
+                If not provided, the agent will run indefinitely.
+                If `max_turns` is a `float`, it will be converted to `int`
+                through rounding.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to the
+                `.run()` method of the corresponding runnable.
+                Details of the kwargs can be found in
+                https://docs.ag2.ai/docs/api-reference/autogen/ConversableAgent#run.
+                The `user_input` parameter defaults to `False`, and should not
+                be passed through `kwargs`.
+
+        Returns:
+            The output of querying the Agent with the given input.
+        """
+        if isinstance(input, str):
+            input = {"content": input}
+
+        if max_turns and isinstance(max_turns, float):
+            # Supporting auto-conversion of float to int.
+            max_turns = round(max_turns)
+
+        if "user_input" in kwargs:
+            from google.cloud.aiplatform import base
+
+            _LOGGER = base.Logger(__name__)
+            _LOGGER.warning(
+                "The `user_input` parameter should not be passed through "
+                "kwargs. The `user_input` defaults to `False`."
+            )
+            kwargs.pop("user_input")
+
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+
+        response = self._tmpl_attrs.get("runnable").run(
+            message=input,
+            user_input=False,
+            tools=self._tmpl_attrs.get("ag2_tool_objects"),
+            max_turns=max_turns,
+            **kwargs,
+        )
+
+        from agentplatform._genai import _agent_engines_utils
+
+        return _agent_engines_utils.to_json_serializable_autogen_object(response)
diff --git a/agentplatform/frameworks/langchain.py b/agentplatform/frameworks/langchain.py
new file mode 100644
index 0000000000..7e8caa9eab
--- /dev/null
+++ b/agentplatform/frameworks/langchain.py
@@ -0,0 +1,715 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
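To make the AG2 query surface above concrete, here is a minimal, hypothetical usage sketch. The model name, runnable name, and tool are placeholders; it assumes the `autogen`/AG2 dependencies are installed and that `GOOGLE_CLOUD_PROJECT` and `GOOGLE_CLOUD_LOCATION` are set, since `set_up()` reads them from the environment.

```python
from agentplatform.frameworks.ag2 import AG2Agent


def get_exchange_rate(currency_from: str, currency_to: str) -> str:
    """Returns a stubbed exchange rate; parameters must be type-annotated."""
    return f"1 {currency_from} = 0.9 {currency_to}"


agent = AG2Agent(
    model="gemini-1.0-pro",          # placeholder model name
    runnable_name="currency-agent",  # becomes the ConversableAgent name
    system_instruction="You are a helpful currency assistant.",
    tools=[get_exchange_rate],       # validated at __init__ time
)

# set_up() is invoked lazily on the first query if the runnable is missing.
response = agent.query(input="What is the USD to EUR rate?", max_turns=2)
print(response)  # JSON-serializable via to_json_serializable_autogen_object
```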
+# +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Union, +) + +if TYPE_CHECKING: + try: + from langchain_core import runnables + from langchain_core import tools as lc_tools + from langchain_core.language_models import base as lc_language_models + + BaseTool = lc_tools.BaseTool + BaseLanguageModel = lc_language_models.BaseLanguageModel + GetSessionHistoryCallable = runnables.history.GetSessionHistoryCallable + RunnableConfig = runnables.RunnableConfig + RunnableSerializable = runnables.RunnableSerializable + except ImportError: + BaseTool = Any + BaseLanguageModel = Any + GetSessionHistoryCallable = Any + RunnableConfig = Any + RunnableSerializable = Any + + try: + from langchain_google_genai.functions_utils import _ToolsType + except ImportError: + try: + from langchain_google_vertexai.functions_utils import _ToolsType + except ImportError: + _ToolsType = Any + + try: + from opentelemetry.sdk import trace + + TracerProvider = trace.TracerProvider + SpanProcessor = trace.SpanProcessor + SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor + except ImportError: + TracerProvider = Any + SpanProcessor = Any + SynchronousMultiSpanProcessor = Any + + +def _default_runnable_kwargs(has_history: bool) -> Mapping[str, Any]: + # https://github.com/langchain-ai/langchain/blob/5784dfed001730530637793bea1795d9d5a7c244/libs/core/langchain_core/runnables/history.py#L237-L241 + runnable_kwargs = { + # input_messages_key (str): Must be specified if the underlying + # agent accepts a dict as input. + "input_messages_key": "input", + # output_messages_key (str): Must be specified if the underlying + # agent returns a dict as output. + "output_messages_key": "output", + } + if has_history: + # history_messages_key (str): Must be specified if the underlying + # agent accepts a dict as input and a separate key for historical + # messages. + runnable_kwargs["history_messages_key"] = "history" + return runnable_kwargs + + +def _default_output_parser(): + try: + from langchain_classic.agents.output_parsers.tools import ToolsAgentOutputParser + except (ModuleNotFoundError, ImportError): + try: + from langchain.agents.output_parsers.tools import ToolsAgentOutputParser + except (ModuleNotFoundError, ImportError): + # Fallback to an older version if needed. 
+            from langchain.agents.output_parsers.openai_tools import (
+                OpenAIToolsAgentOutputParser as ToolsAgentOutputParser,
+            )
+    return ToolsAgentOutputParser()
+
+
+def _default_model_builder(
+    model_name: str,
+    *,
+    project: str,
+    location: str,
+    model_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "BaseLanguageModel":
+    model_kwargs = model_kwargs or {}
+    try:
+        from langchain_google_genai import ChatGoogleGenerativeAI
+
+        model = ChatGoogleGenerativeAI(
+            model=model_name,
+            project=project,
+            location=location,
+            vertexai=True,
+            **model_kwargs,
+        )
+        return model
+    except ImportError:
+        from langchain_google_vertexai import ChatVertexAI
+
+        model = ChatVertexAI(
+            model_name=model_name, project=project, location=location, **model_kwargs
+        )
+        return model
+
+
+def _default_runnable_builder(
+    model: "BaseLanguageModel",
+    *,
+    system_instruction: Optional[str] = None,
+    tools: Optional["_ToolsType"] = None,
+    prompt: Optional["RunnableSerializable"] = None,
+    output_parser: Optional["RunnableSerializable"] = None,
+    chat_history: Optional["GetSessionHistoryCallable"] = None,
+    model_tool_kwargs: Optional[Mapping[str, Any]] = None,
+    agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
+    runnable_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "RunnableSerializable":
+    from langchain_core import tools as lc_tools
+
+    try:
+        from langchain_classic.agents import AgentExecutor
+    except ImportError:
+        from langchain.agents import AgentExecutor
+
+    try:
+        from langchain_core.tools import StructuredTool
+    except ImportError:
+        from langchain.tools.base import StructuredTool
+
+    # The prompt template and runnable_kwargs need to be customized depending
+    # on whether the user intends for the agent to have history. The way the
+    # user would reflect that is by setting chat_history (which defaults to
+    # None).
+ has_history: bool = chat_history is not None + prompt = prompt or _default_prompt( + has_history=has_history, + system_instruction=system_instruction, + ) + output_parser = output_parser or _default_output_parser() + model_tool_kwargs = model_tool_kwargs or {} + agent_executor_kwargs = agent_executor_kwargs or {} + runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history) + if tools: + model = model.bind_tools(tools=tools, **model_tool_kwargs) + else: + tools = [] + agent_executor = AgentExecutor( + agent=prompt | model | output_parser, + tools=[ + ( + tool + if isinstance(tool, lc_tools.BaseTool) + else StructuredTool.from_function(tool) + ) + for tool in tools + if isinstance(tool, (Callable, lc_tools.BaseTool)) + ], + **agent_executor_kwargs, + ) + if has_history: + from langchain_core.runnables.history import RunnableWithMessageHistory + + return RunnableWithMessageHistory( + runnable=agent_executor, + get_session_history=chat_history, + **runnable_kwargs, + ) + return agent_executor + + +def _default_instrumentor_builder(project_id: str): + from agentplatform._genai import _agent_engines_utils + + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_langchain = ( + _agent_engines_utils._import_openinference_langchain_or_warn() + ) + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = ( + _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() + ) + if all( + ( + cloud_trace_exporter, + cloud_trace_v2, + openinference_langchain, + opentelemetry, + opentelemetry_sdk_trace, + ) + ): + import google.auth + + credentials, _ = google.auth.default() + span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( + project_id=project_id, + client=cloud_trace_v2.TraceServiceClient( + credentials=credentials.with_quota_project(project_id), + ), + ) + span_processor: SpanProcessor = ( + opentelemetry_sdk_trace.export.SimpleSpanProcessor( + span_exporter=span_exporter, + ) + ) + tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider() + # Get the appropriate tracer provider: + # 1. If _TRACER_PROVIDER is already set, use that. + # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment + # variable is set, use that. + # 3. As a final fallback, use _PROXY_TRACER_PROVIDER. + # If none of the above is set, we log a warning, and + # create a tracer provider. + if not tracer_provider: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "No tracer provider. By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. 
+ # When creating multiple LangchainAgents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. + instrumentor = openinference_langchain.LangChainInstrumentor() + if instrumentor.is_instrumented_by_opentelemetry: + instrumentor.uninstrument() + instrumentor.instrument() + return instrumentor + else: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "enable_tracing=True but proceeding with tracing disabled " + "because not all packages for tracing have been installed" + ) + return None + + +def _default_prompt( + has_history: bool, + system_instruction: Optional[str] = None, +) -> "RunnableSerializable": + from langchain_core import prompts + + try: + from langchain_classic.agents.format_scratchpad.tools import ( + format_to_tool_messages, + ) + except (ModuleNotFoundError, ImportError): + try: + from langchain.agents.format_scratchpad.tools import format_to_tool_messages + except (ModuleNotFoundError, ImportError): + from langchain.agents.format_scratchpad.openai_tools import ( + format_to_openai_tool_messages as format_to_tool_messages, + ) + + system_instructions = [] + if system_instruction: + system_instructions = [("system", system_instruction)] + + if has_history: + return { + "history": lambda x: x["history"], + "input": lambda x: x["input"], + "agent_scratchpad": ( + lambda x: format_to_tool_messages(x["intermediate_steps"]) + ), + } | prompts.ChatPromptTemplate.from_messages( + system_instructions + + [ + prompts.MessagesPlaceholder(variable_name="history"), + ("user", "{input}"), + prompts.MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + else: + return { + "input": lambda x: x["input"], + "agent_scratchpad": ( + lambda x: format_to_tool_messages(x["intermediate_steps"]) + ), + } | prompts.ChatPromptTemplate.from_messages( + system_instructions + + [ + ("user", "{input}"), + prompts.MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + + +def _validate_callable_parameters_are_annotated(callable: Callable): + """Validates that the parameters of the callable have type annotations. + + This ensures that they can be used for constructing LangChain tools that are + usable with Gemini function calling. + """ + import inspect + + parameters = dict(inspect.signature(callable).parameters) + for name, parameter in parameters.items(): + if parameter.annotation == inspect.Parameter.empty: + raise TypeError( + f"Callable={callable.__name__} has untyped input_arg={name}. " + f"Please specify a type when defining it, e.g. `{name}: str`." + ) + + +def _validate_tools(tools: "_ToolsType"): + """Validates that the tools are usable for tool calling.""" + for tool in tools: + if isinstance(tool, Callable): + _validate_callable_parameters_are_annotated(tool) + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple LangchainAgents in the same environment, + it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. 
+ For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new LangchainAgent is created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. + """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +class LangchainAgent: + """A Langchain Agent. + + See https://cloud.google.com/vertex-ai/generative-ai/docs/reasoning-engine/develop + for details. + """ + + agent_framework = "langchain" + + def __init__( + self, + model: str, + *, + system_instruction: Optional[str] = None, + prompt: Optional["RunnableSerializable"] = None, + tools: Optional["_ToolsType"] = None, + output_parser: Optional["RunnableSerializable"] = None, + chat_history: Optional["GetSessionHistoryCallable"] = None, + model_kwargs: Optional[Mapping[str, Any]] = None, + model_tool_kwargs: Optional[Mapping[str, Any]] = None, + agent_executor_kwargs: Optional[Mapping[str, Any]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, + model_builder: Optional[Callable] = None, + runnable_builder: Optional[Callable] = None, + enable_tracing: bool = False, + instrumentor_builder: Optional[Callable[..., Any]] = None, + ): + """Initializes the LangchainAgent. + + Under-the-hood, assuming .set_up() is called, this will correspond to + + ``` + model = model_builder(model_name=model, model_kwargs=model_kwargs) + runnable = runnable_builder( + prompt=prompt, + model=model, + tools=tools, + output_parser=output_parser, + chat_history=chat_history, + agent_executor_kwargs=agent_executor_kwargs, + runnable_kwargs=runnable_kwargs, + ) + ``` + + When everything is based on their default values, this corresponds to + ``` + # model_builder + from langchain_google_vertexai import ChatVertexAI + llm = ChatVertexAI(model_name=model, **model_kwargs) + + # runnable_builder + from langchain import agents + from langchain_core.runnables.history import RunnableWithMessageHistory + llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs) + agent_executor = agents.AgentExecutor( + agent=prompt | llm_with_tools | output_parser, + tools=tools, + **agent_executor_kwargs, + ) + runnable = RunnableWithMessageHistory( + runnable=agent_executor, + get_session_history=chat_history, + **runnable_kwargs, + ) + ``` + + Args: + model (str): + Optional. The name of the model (e.g. "gemini-1.0-pro"). + system_instruction (str): + Optional. The system instruction to use for the agent. This + argument should not be specified if `prompt` is specified. + prompt (langchain_core.runnables.RunnableSerializable): + Optional. The prompt template for the model. Defaults to a + ChatPromptTemplate. + tools (Sequence[langchain_core.tools.BaseTool, Callable]): + Optional. The tools for the agent to be able to use. All input + callables (e.g. function or class method) will be converted + to a langchain.tools.base.StructuredTool. Defaults to None. + output_parser (langchain_core.runnables.RunnableSerializable): + Optional. The output parser for the model. Defaults to an + output parser that works with Gemini function-calling. 
+            chat_history (langchain_core.runnables.history.GetSessionHistoryCallable):
+                Optional. Callable that returns a new BaseChatMessageHistory.
+                Defaults to None, i.e. chat_history is not preserved.
+            model_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the constructor of
+                chat_models.ChatVertexAI. An example would be
+                ```
+                {
+                    # temperature (float): Sampling temperature, it controls the
+                    # degree of randomness in token selection.
+                    "temperature": 0.28,
+                    # max_output_tokens (int): Token limit determines the
+                    # maximum amount of text output from one prompt.
+                    "max_output_tokens": 1000,
+                    # top_p (float): Tokens are selected from most probable to
+                    # least, until the sum of their probabilities equals the
+                    # top_p value.
+                    "top_p": 0.95,
+                    # top_k (int): How the model selects tokens for output, the
+                    # next token is selected from among the top_k most probable
+                    # tokens.
+                    "top_k": 40,
+                }
+                ```
+            model_tool_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments when binding tools to the
+                model using `model.bind_tools()`.
+            agent_executor_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the constructor of
+                langchain.agents.AgentExecutor. An example would be
+                ```
+                {
+                    # Whether to return the agent's trajectory of intermediate
+                    # steps at the end in addition to the final output.
+                    "return_intermediate_steps": False,
+                    # The maximum number of steps to take before ending the
+                    # execution loop.
+                    "max_iterations": 15,
+                    # The method to use for early stopping if the agent never
+                    # returns `AgentFinish`. Either 'force' or 'generate'.
+                    "early_stopping_method": "force",
+                    # How to handle errors raised by the agent's output parser.
+                    # Defaults to `False`, which raises the error.
+                    "handle_parsing_errors": False,
+                }
+                ```
+            runnable_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the constructor of
+                langchain.runnables.history.RunnableWithMessageHistory if
+                chat_history is specified. If chat_history is None, this will be
+                ignored.
+            model_builder (Callable):
+                Optional. Callable that returns a new language model. Defaults
+                to a callable that returns ChatVertexAI based on `model`,
+                `model_kwargs` and the parameters in `vertexai.init`.
+            runnable_builder (Callable):
+                Optional. Callable that returns a new runnable. This can be used
+                for customizing the orchestration logic of the Agent based on
+                the model returned by `model_builder` and the rest of the input
+                arguments.
+            enable_tracing (bool):
+                Optional. Whether to enable tracing in Cloud Trace. Defaults to
+                False.
+            instrumentor_builder (Callable[..., Any]):
+                Optional. Callable that returns a new instrumentor. This can be
+                used for customizing the instrumentation logic of the Agent.
+                If not provided, a default instrumentor builder will be used.
+                This parameter is ignored if `enable_tracing` is False.
+
+        Raises:
+            ValueError: If both `prompt` and `system_instruction` are specified.
+            TypeError: If there is an invalid tool (e.g. function with an input
+                that did not specify its type).
+ """ + from google.cloud.aiplatform import initializer + + self._tmpl_attrs: dict[str, Any] = { + "tools": [], + "model_name": model, + "system_instruction": system_instruction, + "prompt": prompt, + "output_parser": output_parser, + "chat_history": chat_history, + "model_kwargs": model_kwargs, + "model_tool_kwargs": model_tool_kwargs, + "agent_executor_kwargs": agent_executor_kwargs, + "runnable_kwargs": runnable_kwargs, + "model_builder": model_builder, + "runnable_builder": runnable_builder, + "enable_tracing": enable_tracing, + "model": None, + "runnable": None, + "instrumentor": None, + "instrumentor_builder": instrumentor_builder, + } + if tools: + # We validate tools at initialization for actionable feedback before + # they are deployed. + _validate_tools(tools) + self._tmpl_attrs["tools"] = tools + if prompt and system_instruction: + raise ValueError( + "Only one of `prompt` or `system_instruction` should be specified. " + "Consider incorporating the system instruction into the prompt " + "rather than passing it separately as an argument." + ) + + def set_up(self): + """Sets up the agent for execution of queries at runtime. + + It initializes the model, binds the model with tools, and connects it + with the prompt template and output parser. + + This method should not be called for an object being passed to the + service for deployment, as it might initialize clients that can not be + serialized. + """ + import os + + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" + ) or os.environ.get("GOOGLE_CLOUD_LOCATION") + if self._tmpl_attrs.get("enable_tracing"): + instrumentor_builder = ( + self._tmpl_attrs.get("instrumentor_builder") + or _default_instrumentor_builder + ) + self._tmpl_attrs["instrumentor"] = instrumentor_builder(project_id=project) + model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder + self._tmpl_attrs["model"] = model_builder( + model_name=self._tmpl_attrs.get("model_name"), + model_kwargs=self._tmpl_attrs.get("model_kwargs"), + project=project, + location=location, + ) + runnable_builder = ( + self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder + ) + self._tmpl_attrs["runnable"] = runnable_builder( + prompt=self._tmpl_attrs.get("prompt"), + model=self._tmpl_attrs.get("model"), + tools=self._tmpl_attrs.get("tools"), + system_instruction=self._tmpl_attrs.get("system_instruction"), + output_parser=self._tmpl_attrs.get("output_parser"), + chat_history=self._tmpl_attrs.get("chat_history"), + model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"), + agent_executor_kwargs=self._tmpl_attrs.get("agent_executor_kwargs"), + runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"), + ) + + def clone(self) -> "LangchainAgent": + """Returns a clone of the LangchainAgent.""" + import copy + + return LangchainAgent( + model=self._tmpl_attrs.get("model_name"), + system_instruction=self._tmpl_attrs.get("system_instruction"), + prompt=copy.deepcopy(self._tmpl_attrs.get("prompt")), + tools=copy.deepcopy(self._tmpl_attrs.get("tools")), + output_parser=copy.deepcopy(self._tmpl_attrs.get("output_parser")), + chat_history=copy.deepcopy(self._tmpl_attrs.get("chat_history")), + model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")), + model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")), + agent_executor_kwargs=copy.deepcopy( + self._tmpl_attrs.get("agent_executor_kwargs") + ), + 
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
+            model_builder=self._tmpl_attrs.get("model_builder"),
+            runnable_builder=self._tmpl_attrs.get("runnable_builder"),
+            enable_tracing=self._tmpl_attrs.get("enable_tracing"),
+            instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
+        )
+
+    def query(
+        self,
+        *,
+        input: Union[str, Mapping[str, Any]],
+        config: Optional["RunnableConfig"] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Queries the Agent with the given input and config.
+
+        Args:
+            input (Union[str, Mapping[str, Any]]):
+                Required. The input to be passed to the Agent.
+            config (langchain_core.runnables.RunnableConfig):
+                Optional. The config (if any) to be used for invoking the Agent.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to the
+                `.invoke()` method of the corresponding AgentExecutor.
+
+        Returns:
+            The output of querying the Agent with the given input and config.
+        """
+        try:
+            from langchain_core.load import dumpd
+        except ImportError:
+            from langchain.load import dump as langchain_load_dump
+
+            dumpd = langchain_load_dump.dumpd
+
+        if isinstance(input, str):
+            input = {"input": input}
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        return dumpd(
+            self._tmpl_attrs.get("runnable").invoke(
+                input=input, config=config, **kwargs
+            )
+        )
+
+    def stream_query(
+        self,
+        *,
+        input: Union[str, Mapping[str, Any]],
+        config: Optional["RunnableConfig"] = None,
+        **kwargs,
+    ) -> Iterable[Any]:
+        """Stream queries the Agent with the given input and config.
+
+        Args:
+            input (Union[str, Mapping[str, Any]]):
+                Required. The input to be passed to the Agent.
+            config (langchain_core.runnables.RunnableConfig):
+                Optional. The config (if any) to be used for invoking the Agent.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to the
+                `.stream()` method of the corresponding AgentExecutor.
+
+        Yields:
+            The output of querying the Agent with the given input and config.
+        """
+        try:
+            from langchain_core.load import dumpd
+        except ImportError:
+            from langchain.load import dump as langchain_load_dump
+
+            dumpd = langchain_load_dump.dumpd
+
+        if isinstance(input, str):
+            input = {"input": input}
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        for chunk in self._tmpl_attrs.get("runnable").stream(
+            input=input,
+            config=config,
+            **kwargs,
+        ):
+            yield dumpd(chunk)
diff --git a/agentplatform/frameworks/langgraph.py b/agentplatform/frameworks/langgraph.py
new file mode 100644
index 0000000000..0a1cef5ce1
--- /dev/null
+++ b/agentplatform/frameworks/langgraph.py
@@ -0,0 +1,711 @@
+# -*- coding: utf-8 -*-
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
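As a quick illustration of the LangchainAgent surface defined above — a minimal, hypothetical usage sketch. The model name and tool are placeholders; it assumes the LangChain extras are installed and the project/location environment variables that `set_up()` reads are set.

```python
from agentplatform.frameworks.langchain import LangchainAgent


def get_weather(city: str) -> str:
    """Returns a stubbed weather report; parameters must be type-annotated."""
    return f"It is sunny in {city}."


agent = LangchainAgent(
    model="gemini-1.0-pro",  # placeholder model name
    system_instruction="You answer weather questions.",  # exclusive with `prompt`
    tools=[get_weather],
)

# Blocking call: returns the dumpd()-serialized AgentExecutor output.
print(agent.query(input="Weather in Paris?"))

# Streaming call: yields one serialized chunk per AgentExecutor step.
for chunk in agent.stream_query(input="Weather in Berlin?"):
    print(chunk)
```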
+# +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Union, +) + +if TYPE_CHECKING: + try: + from langchain_core.language_models import base as lc_language_models + + BaseLanguageModel = lc_language_models.BaseLanguageModel + except ImportError: + BaseLanguageModel = Any + + try: + from langchain_google_genai.functions_utils import _ToolsType + + _ToolLike = _ToolsType + except ImportError: + try: + from langchain_google_vertexai.functions_utils import _ToolsType + + _ToolLike = _ToolsType + except ImportError: + _ToolLike = Any + + try: + from opentelemetry.sdk import trace + + TracerProvider = trace.TracerProvider + SpanProcessor = trace.SpanProcessor + SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor + except ImportError: + TracerProvider = Any + SpanProcessor = Any + SynchronousMultiSpanProcessor = Any + + try: + from langgraph_checkpoint.checkpoint import base + + BaseCheckpointSaver = base.BaseCheckpointSaver + except ImportError: + try: + from langgraph.checkpoint import base + + BaseCheckpointSaver = base.BaseCheckpointSaver + except ImportError: + BaseCheckpointSaver = Any + + +def _default_model_builder( + model_name: str, + *, + project: str, + location: str, + model_kwargs: Optional[Mapping[str, Any]] = None, +) -> "BaseLanguageModel": + """Default callable for building a language model. + + Args: + model_name (str): + Required. The name of the model (e.g. "gemini-1.0-pro"). + project (str): + Required. The Google Cloud project ID. + location (str): + Required. The Google Cloud location. + model_kwargs (Mapping[str, Any]): + Optional. Additional keyword arguments for the constructor of + chat_models.ChatVertexAI. + + Returns: + BaseLanguageModel: The language model. + """ + model_kwargs = model_kwargs or {} + try: + from langchain_google_genai import ChatGoogleGenerativeAI + + model = ChatGoogleGenerativeAI( + model=model_name, + project=project, + location=location, + vertexai=True, + **model_kwargs, + ) + return model + except ImportError: + from langchain_google_vertexai import ChatVertexAI + + model = ChatVertexAI( + model_name=model_name, project=project, location=location, **model_kwargs + ) + return model + + +def _default_runnable_builder( + model: "BaseLanguageModel", + *, + tools: Optional[Sequence["_ToolLike"]] = None, + checkpointer: Optional[Any] = None, + model_tool_kwargs: Optional[Mapping[str, Any]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, +): + """Default callable for building a runnable. + + Args: + model (BaseLanguageModel): + Required. The language model. + tools (Optional[Sequence[_ToolLike]]): + Optional. The tools for the agent to be able to use. + checkpointer (Optional[Checkpointer]): + Optional. The checkpointer for the agent. + model_tool_kwargs (Optional[Mapping[str, Any]]): + Optional. Additional keyword arguments when binding tools to the model. + runnable_kwargs (Optional[Mapping[str, Any]]): + Optional. Additional keyword arguments for the runnable. + + Returns: + RunnableSerializable: The runnable. 
+ """ + from langgraph import prebuilt as langgraph_prebuilt + + model_tool_kwargs = model_tool_kwargs or {} + runnable_kwargs = runnable_kwargs or {} + if tools: + model = model.bind_tools(tools=tools, **model_tool_kwargs) + else: + tools = [] + if checkpointer: + if "checkpointer" in runnable_kwargs: + from google.cloud.aiplatform import base + + base.Logger(__name__).warning( + "checkpointer is being specified in both checkpointer_builder " + "and runnable_kwargs. Please specify it in only one of them. " + "Overriding the checkpointer in runnable_kwargs." + ) + runnable_kwargs["checkpointer"] = checkpointer + return langgraph_prebuilt.create_react_agent( + model, + tools=tools, + **runnable_kwargs, + ) + + +def _default_instrumentor_builder(project_id: str): + from agentplatform._genai import _agent_engines_utils + + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_langchain = ( + _agent_engines_utils._import_openinference_langchain_or_warn() + ) + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = ( + _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() + ) + if all( + ( + cloud_trace_exporter, + cloud_trace_v2, + openinference_langchain, + opentelemetry, + opentelemetry_sdk_trace, + ) + ): + import google.auth + + credentials, _ = google.auth.default() + span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( + project_id=project_id, + client=cloud_trace_v2.TraceServiceClient( + credentials=credentials.with_quota_project(project_id), + ), + ) + span_processor: SpanProcessor = ( + opentelemetry_sdk_trace.export.SimpleSpanProcessor( + span_exporter=span_exporter, + ) + ) + tracer_provider: TracerProvider = opentelemetry.trace.get_tracer_provider() + # Get the appropriate tracer provider: + # 1. If _TRACER_PROVIDER is already set, use that. + # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment + # variable is set, use that. + # 3. As a final fallback, use _PROXY_TRACER_PROVIDER. + # If none of the above is set, we log a warning, and + # create a tracer provider. + if not tracer_provider: + from google.cloud.aiplatform import base + + base.Logger(__name__).warning( + "No tracer provider. By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. + # When creating multiple LangchainAgents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. 
+        instrumentor = openinference_langchain.LangChainInstrumentor()
+        if instrumentor.is_instrumented_by_opentelemetry:
+            instrumentor.uninstrument()
+        instrumentor.instrument()
+        return instrumentor
+    else:
+        from google.cloud.aiplatform import base
+
+        _LOGGER = base.Logger(__name__)
+        _LOGGER.warning(
+            "enable_tracing=True but proceeding with tracing disabled "
+            "because not all packages for tracing have been installed"
+        )
+        return None
+
+
+def _validate_callable_parameters_are_annotated(callable: Callable):
+    """Validates that the parameters of the callable have type annotations.
+
+    This ensures that they can be used for constructing LangChain tools that
+    are usable with Gemini function calling.
+
+    Args:
+        callable (Callable): The callable to validate.
+
+    Raises:
+        TypeError: If any parameter is not annotated.
+    """
+    import inspect
+
+    parameters = dict(inspect.signature(callable).parameters)
+    for name, parameter in parameters.items():
+        if parameter.annotation == inspect.Parameter.empty:
+            raise TypeError(
+                f"Callable={callable.__name__} has untyped input_arg={name}. "
+                f"Please specify a type when defining it, e.g. `{name}: str`."
+            )
+
+
+def _validate_tools(tools: Sequence["_ToolLike"]):
+    """Validates that the tools are usable for tool calling.
+
+    Args:
+        tools (Sequence[_ToolLike]): The tools to validate.
+
+    Raises:
+        TypeError: If any tool is a callable with untyped parameters.
+    """
+    for tool in tools:
+        if isinstance(tool, Callable):
+            _validate_callable_parameters_are_annotated(tool)
+
+
+def _override_active_span_processor(
+    tracer_provider: "TracerProvider",
+    active_span_processor: "SynchronousMultiSpanProcessor",
+):
+    """Overrides the active span processor.
+
+    When working with multiple LanggraphAgents in the same environment,
+    it's crucial to manage trace exports carefully.
+    Each agent needs its own span processor tied to a unique project ID.
+    While we add a new span processor for each agent, this can lead to
+    unexpected behavior.
+    For instance, with two agents linked to different projects, traces from
+    the second agent might be sent to both projects.
+    To prevent this and guarantee traces go to the correct project, we
+    overwrite the active span processor whenever a new LanggraphAgent is
+    created.
+
+    Args:
+        tracer_provider (TracerProvider):
+            The tracer provider to use for the project.
+        active_span_processor (SynchronousMultiSpanProcessor):
+            The span processor that replaces the tracer provider's currently
+            active span processor.
+    """
+    if tracer_provider._active_span_processor:
+        tracer_provider._active_span_processor.shutdown()
+    tracer_provider._active_span_processor = active_span_processor
+
+
+class LanggraphAgent:
+    """A LangGraph Agent."""
+
+    agent_framework = "langgraph"
+
+    def __init__(
+        self,
+        model: str,
+        *,
+        tools: Optional[Sequence["_ToolLike"]] = None,
+        model_kwargs: Optional[Mapping[str, Any]] = None,
+        model_tool_kwargs: Optional[Mapping[str, Any]] = None,
+        model_builder: Optional[Callable[..., "BaseLanguageModel"]] = None,
+        runnable_kwargs: Optional[Mapping[str, Any]] = None,
+        runnable_builder: Optional[Callable[..., Any]] = None,
+        checkpointer_kwargs: Optional[Mapping[str, Any]] = None,
+        checkpointer_builder: Optional[Callable[..., "BaseCheckpointSaver"]] = None,
+        enable_tracing: bool = False,
+        instrumentor_builder: Optional[Callable[..., Any]] = None,
+    ):
+        """Initializes the LangGraph Agent.
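+
+        A minimal usage sketch (illustrative only; assumes the langgraph
+        dependencies are installed and that `GOOGLE_CLOUD_PROJECT` and
+        `GOOGLE_CLOUD_LOCATION` are set in the environment):
+        ```python
+        agent = LanggraphAgent(model="gemini-1.0-pro")
+        agent.set_up()
+        print(agent.query(input="What is 1 + 1?"))
+        ```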
+
+        Under-the-hood, assuming .set_up() is called, this will correspond to
+        ```python
+        model = model_builder(model_name=model, model_kwargs=model_kwargs)
+        runnable = runnable_builder(
+            model=model,
+            tools=tools,
+            model_tool_kwargs=model_tool_kwargs,
+            runnable_kwargs=runnable_kwargs,
+        )
+        ```
+
+        When everything is based on their default values, this corresponds to
+        ```python
+        # model_builder
+        from langchain_google_vertexai import ChatVertexAI
+        llm = ChatVertexAI(model_name=model, **model_kwargs)
+
+        # runnable_builder
+        from langgraph.prebuilt import create_react_agent
+        llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
+        runnable = create_react_agent(
+            llm_with_tools,
+            tools=tools,
+            **runnable_kwargs,
+        )
+        ```
+
+        By default, no checkpointer is used (i.e. there is no state history).
+        To enable checkpointing, provide a `checkpointer_builder` function
+        that returns a checkpointer instance.
+
+        **Example using Spanner:**
+        ```python
+        def checkpointer_builder(instance_id, database_id, project_id, **kwargs):
+            from langchain_google_spanner import SpannerCheckpointSaver
+
+            checkpointer = SpannerCheckpointSaver(instance_id, database_id, project_id)
+            with checkpointer.cursor() as cur:
+                cur.execute("DROP TABLE IF EXISTS checkpoints")
+                cur.execute("DROP TABLE IF EXISTS checkpoint_writes")
+            checkpointer.setup()
+
+            return checkpointer
+        ```
+
+        **Example using an in-memory checkpointer:**
+        ```python
+        def checkpointer_builder(**kwargs):
+            from langgraph.checkpoint.memory import MemorySaver
+
+            return MemorySaver()
+        ```
+
+        The `checkpointer_builder` function will be called with the keyword
+        arguments in `checkpointer_kwargs`. Ensure your
+        `checkpointer_builder` function accepts `**kwargs` to handle these
+        arguments, even if unused.
+
+        Args:
+            model (str):
+                Required. The name of the model (e.g. "gemini-1.0-pro").
+            tools (Sequence[langchain_core.tools.BaseTool, Callable]):
+                Optional. The tools for the agent to be able to use. All input
+                callables (e.g. function or class method) will be converted
+                to a langchain.tools.base.StructuredTool. Defaults to None.
+            model_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the constructor of
+                chat_models.ChatVertexAI. An example would be
+                ```
+                {
+                    # temperature (float): Sampling temperature, it controls the
+                    # degree of randomness in token selection.
+                    "temperature": 0.28,
+                    # max_output_tokens (int): Token limit determines the
+                    # maximum amount of text output from one prompt.
+                    "max_output_tokens": 1000,
+                    # top_p (float): Tokens are selected from most probable to
+                    # least, until the sum of their probabilities equals the
+                    # top_p value.
+                    "top_p": 0.95,
+                    # top_k (int): How the model selects tokens for output, the
+                    # next token is selected from among the top_k most probable
+                    # tokens.
+                    "top_k": 40,
+                }
+                ```
+            model_tool_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments when binding tools to
+                the model using `model.bind_tools()`.
+            model_builder (Callable[..., BaseLanguageModel]):
+                Optional. Callable that returns a new language model. Defaults
+                to a callable that returns ChatVertexAI based on `model`,
+                `model_kwargs` and the parameters in `vertexai.init`.
+            runnable_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the runnable
+                returned by `runnable_builder` (by default, these are passed
+                through to `langgraph.prebuilt.create_react_agent`).
+            runnable_builder (Callable[..., RunnableSerializable]):
+                Optional. Callable that returns a new runnable. This can be
+                used for customizing the orchestration logic of the Agent
+                based on the model returned by `model_builder` and the rest
+                of the input arguments.
+            checkpointer_kwargs (Mapping[str, Any]):
+                Optional. Additional keyword arguments for the constructor of
+                the checkpointer returned by `checkpointer_builder`.
+            checkpointer_builder (Callable[..., "BaseCheckpointSaver"]):
+                Optional. Callable that returns a checkpointer. This can be
+                used for defining the checkpointer of the Agent. Defaults to
+                None.
+            enable_tracing (bool):
+                Optional. Whether to enable tracing in Cloud Trace. Defaults
+                to False.
+            instrumentor_builder (Callable[..., Any]):
+                Optional. Callable that returns a new instrumentor. This can
+                be used for customizing the instrumentation logic of the
+                Agent. If not provided, a default instrumentor builder will
+                be used. This parameter is ignored if `enable_tracing` is
+                False.
+
+        Raises:
+            TypeError: If there is an invalid tool (e.g. function with an
+                input that did not specify its type).
+        """
+        from google.cloud.aiplatform import initializer
+
+        self._tmpl_attrs: dict[str, Any] = {
+            "tools": [],
+            "model_name": model,
+            "model_kwargs": model_kwargs,
+            "model_tool_kwargs": model_tool_kwargs,
+            "runnable_kwargs": runnable_kwargs,
+            "checkpointer_kwargs": checkpointer_kwargs,
+            "model": None,
+            "model_builder": model_builder,
+            "runnable": None,
+            "runnable_builder": runnable_builder,
+            "checkpointer": None,
+            "checkpointer_builder": checkpointer_builder,
+            "enable_tracing": enable_tracing,
+            "instrumentor": None,
+            "instrumentor_builder": instrumentor_builder,
+        }
+        if tools:
+            # We validate tools at initialization for actionable feedback
+            # before they are deployed.
+            _validate_tools(tools)
+            self._tmpl_attrs["tools"] = tools
+
+    def set_up(self):
+        """Sets up the agent for execution of queries at runtime.
+
+        It initializes the model, binds the model with tools, and connects it
+        with the prompt template and output parser.
+
+        This method should not be called for an object that is being passed
+        to the ReasoningEngine service for deployment, as it initializes
+        clients that cannot be serialized.
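+
+        A sketch of the environment this method expects (the variable names
+        match the lookups in the method body; the values are illustrative):
+        ```python
+        import os
+
+        os.environ["GOOGLE_CLOUD_PROJECT"] = "my-project"    # assumed project ID
+        os.environ["GOOGLE_CLOUD_LOCATION"] = "us-central1"  # assumed location
+        agent.set_up()
+        ```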
+ """ + import os + + project = os.environ.get("GOOGLE_CLOUD_PROJECT") + location = os.environ.get( + "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" + ) or os.environ.get("GOOGLE_CLOUD_LOCATION") + if self._tmpl_attrs.get("enable_tracing"): + instrumentor_builder = ( + self._tmpl_attrs.get("instrumentor_builder") + or _default_instrumentor_builder + ) + self._tmpl_attrs["instrumentor"] = instrumentor_builder(project_id=project) + model_builder = self._tmpl_attrs.get("model_builder") or _default_model_builder + self._tmpl_attrs["model"] = model_builder( + model_name=self._tmpl_attrs.get("model_name"), + model_kwargs=self._tmpl_attrs.get("model_kwargs"), + project=project, + location=location, + ) + checkpointer_builder = self._tmpl_attrs.get("checkpointer_builder") + if checkpointer_builder: + checkpointer_kwargs = self._tmpl_attrs.get("checkpointer_kwargs") or {} + self._tmpl_attrs["checkpointer"] = checkpointer_builder( + **checkpointer_kwargs + ) + runnable_builder = ( + self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder + ) + self._tmpl_attrs["runnable"] = runnable_builder( + model=self._tmpl_attrs.get("model"), + tools=self._tmpl_attrs.get("tools"), + checkpointer=self._tmpl_attrs.get("checkpointer"), + model_tool_kwargs=self._tmpl_attrs.get("model_tool_kwargs"), + runnable_kwargs=self._tmpl_attrs.get("runnable_kwargs"), + ) + + def clone(self) -> "LanggraphAgent": + """Returns a clone of the LanggraphAgent.""" + import copy + + return LanggraphAgent( + model=self._tmpl_attrs.get("model_name"), + tools=copy.deepcopy(self._tmpl_attrs.get("tools")), + model_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_kwargs")), + model_tool_kwargs=copy.deepcopy(self._tmpl_attrs.get("model_tool_kwargs")), + runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")), + checkpointer_kwargs=copy.deepcopy( + self._tmpl_attrs.get("checkpointer_kwargs") + ), + model_builder=self._tmpl_attrs.get("model_builder"), + runnable_builder=self._tmpl_attrs.get("runnable_builder"), + checkpointer_builder=self._tmpl_attrs.get("checkpointer_builder"), + enable_tracing=self._tmpl_attrs.get("enable_tracing"), + instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), + ) + + def query( + self, + *, + input: Union[str, Mapping[str, Any]], + config: Optional[dict[str, Any]] = None, + **kwargs: Any, + ) -> Dict[str, Any]: + """Queries the Agent with the given input and config. + + Args: + input (Union[str, Mapping[str, Any]]): + Required. The input to be passed to the Agent. + config (langchain_core.runnables.RunnableConfig): + Optional. The config (if any) to be used for invoking the Agent. + **kwargs: + Optional. Any additional keyword arguments to be passed to the + `.invoke()` method of the corresponding AgentExecutor. + + Returns: + The output of querying the Agent with the given input and config. + """ + try: + from langchain_core.load import dumpd + except ImportError: + from langchain.load.dump import dumpd + + if isinstance(input, str): + input = {"input": input, "messages": [("user", input)]} + if not self._tmpl_attrs.get("runnable"): + self.set_up() + return dumpd( + self._tmpl_attrs.get("runnable").invoke( + input=input, config=config, **kwargs + ) + ) + + def stream_query( + self, + *, + input: Union[str, Mapping[str, Any]], + config: Optional[dict[str, Any]] = None, + **kwargs, + ) -> Iterable[Any]: + """Stream queries the Agent with the given input and config. + + Args: + input (Union[str, Mapping[str, Any]]): + Required. The input to be passed to the Agent. 
+            config (langchain_core.runnables.RunnableConfig):
+                Optional. The config (if any) to be used for invoking the
+                Agent.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to
+                the `.stream()` method of the underlying runnable.
+
+        Yields:
+            The output of querying the Agent with the given input and config.
+        """
+        try:
+            from langchain_core.load import dumpd
+        except ImportError:
+            from langchain.load.dump import dumpd
+
+        if isinstance(input, str):
+            input = {"input": input, "messages": [("user", input)]}
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        for chunk in self._tmpl_attrs.get("runnable").stream(
+            input=input,
+            config=config,
+            **kwargs,
+        ):
+            yield dumpd(chunk)
+
+    def get_state_history(
+        self,
+        config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Iterable[Any]:
+        """Gets the state history of the Agent.
+
+        Args:
+            config (Optional[RunnableConfig]):
+                Optional. The config for invoking the Agent.
+            **kwargs:
+                Optional. Additional keyword arguments for the runnable's
+                `.get_state_history()` method.
+
+        Yields:
+            Dict[str, Any]: The state history of the Agent.
+        """
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        for state_snapshot in self._tmpl_attrs.get("runnable").get_state_history(
+            config=config,
+            **kwargs,
+        ):
+            yield state_snapshot._asdict()
+
+    def get_state(
+        self,
+        config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Gets the current state of the Agent.
+
+        Args:
+            config (Optional[RunnableConfig]):
+                Optional. The config for invoking the Agent.
+            **kwargs:
+                Optional. Additional keyword arguments for the runnable's
+                `.get_state()` method.
+
+        Returns:
+            Dict[str, Any]: The current state of the Agent.
+        """
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        return (
+            self._tmpl_attrs.get("runnable")
+            .get_state(config=config, **kwargs)
+            ._asdict()
+        )
+
+    def update_state(
+        self,
+        config: Optional[dict[str, Any]] = None,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """Updates the state of the Agent.
+
+        Args:
+            config (Optional[RunnableConfig]):
+                Optional. The config for invoking the Agent.
+            **kwargs:
+                Optional. Additional keyword arguments for the runnable's
+                `.update_state()` method.
+
+        Returns:
+            Dict[str, Any]: The updated state of the Agent.
+        """
+        if not self._tmpl_attrs.get("runnable"):
+            self.set_up()
+        return self._tmpl_attrs.get("runnable").update_state(config=config, **kwargs)
+
+    def register_operations(self) -> Mapping[str, Sequence[str]]:
+        """Registers the operations of the Agent.
+
+        This mapping defines how different operation modes (e.g., "" and
+        "stream") are implemented by specific methods of the Agent. The
+        default mode, represented by the empty string "", is associated with
+        the `query` API, while the "stream" mode is associated with the
+        `stream_query` API.
+
+        Returns:
+            Mapping[str, Sequence[str]]: A mapping of operation modes to a
+            list of method names that implement those operation modes.
+        """
+        return {
+            "": ["query", "get_state", "update_state"],
+            "stream": ["stream_query", "get_state_history"],
+        }
diff --git a/agentplatform/frameworks/llama_index.py b/agentplatform/frameworks/llama_index.py
new file mode 100644
index 0000000000..c8cbbda813
--- /dev/null
+++ b/agentplatform/frameworks/llama_index.py
@@ -0,0 +1,558 @@
+# -*- coding: utf-8 -*-
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)
+
+if TYPE_CHECKING:
+    try:
+        from llama_index.core.base.query_pipeline import query
+        from llama_index.core.llms import function_calling
+        from llama_index.core import query_pipeline
+
+        FunctionCallingLLM = function_calling.FunctionCallingLLM
+        QueryComponent = query.QUERY_COMPONENT_TYPE
+        QueryPipeline = query_pipeline.QueryPipeline
+    except ImportError:
+        FunctionCallingLLM = Any
+        QueryComponent = Any
+        QueryPipeline = Any
+
+    try:
+        from opentelemetry.sdk import trace
+
+        TracerProvider = trace.TracerProvider
+        SpanProcessor = trace.SpanProcessor
+        SynchronousMultiSpanProcessor = trace.SynchronousMultiSpanProcessor
+    except ImportError:
+        TracerProvider = Any
+        SpanProcessor = Any
+        SynchronousMultiSpanProcessor = Any
+
+
+def _default_model_builder(
+    model_name: str,
+    *,
+    project: str,
+    location: str,
+    model_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "FunctionCallingLLM":
+    """Default builder that creates a LlamaIndex model."""
+    from llama_index.llms import google_genai
+
+    model_kwargs = model_kwargs or {}
+    model = google_genai.GoogleGenAI(
+        model=model_name,
+        vertexai_config={"project": project, "location": location},
+        **model_kwargs,
+    )
+    return model
+
+
+def _default_runnable_builder(
+    model: "FunctionCallingLLM",
+    *,
+    system_instruction: Optional[str] = None,
+    prompt: Optional["QueryComponent"] = None,
+    retriever: Optional["QueryComponent"] = None,
+    response_synthesizer: Optional["QueryComponent"] = None,
+    runnable_kwargs: Optional[Mapping[str, Any]] = None,
+) -> "QueryPipeline":
+    """Default builder that creates a LlamaIndex query pipeline."""
+    try:
+        from llama_index.core.query_pipeline import QueryPipeline
+    except ImportError:
+        raise ImportError(
+            "Please run 'pip install google-cloud-aiplatform[llama_index]'."
+        )
+
+    runnable_kwargs = runnable_kwargs or {}
+    prompt = prompt or _default_prompt(
+        system_instruction=system_instruction,
+    )
+    pipeline = QueryPipeline(**runnable_kwargs)
+    pipeline_modules = {
+        "prompt": prompt,
+        "model": model,
+    }
+    if retriever:
+        pipeline_modules["retriever"] = retriever
+    if response_synthesizer:
+        pipeline_modules["response_synthesizer"] = response_synthesizer
+
+    pipeline.add_modules(pipeline_modules)
+    pipeline.add_link("prompt", "model")
+    if "retriever" in pipeline_modules:
+        pipeline.add_link("model", "retriever")
+    if "response_synthesizer" in pipeline_modules:
+        pipeline.add_link("model", "response_synthesizer", dest_key="query_str")
+        if "retriever" in pipeline_modules:
+            pipeline.add_link("retriever", "response_synthesizer", dest_key="nodes")
+
+    return pipeline
+
+
+def _default_prompt(
+    system_instruction: Optional[str] = None,
+) -> "QueryComponent":
+    """Creates a default prompt template for LlamaIndex.
+
+    Handles both the system instruction and the user input.
+
+    Args:
+        system_instruction (str, optional): The system instruction to use.
+
+    Returns:
+        QueryComponent: The LlamaIndex QueryComponent.
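+
+    Example (a sketch; assumes `llama-index-core` is installed and the
+    template is rendered with an `input` string):
+    ```python
+    prompt = _default_prompt(system_instruction="Answer concisely.")
+    messages = prompt.format_messages(input="What is 1 + 1?")
+    ```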
+ """ + try: + from llama_index.core import prompts + from llama_index.core.base.llms import types + except ImportError: + raise ImportError( + "Please call 'pip install google-cloud-aiplatform[llama_index]'." + ) + + # Define a prompt template + message_templates = [] + if system_instruction: + message_templates.append( + types.ChatMessage(role=types.MessageRole.SYSTEM, content=system_instruction) + ) + # Add user input message + message_templates.append( + types.ChatMessage(role=types.MessageRole.USER, content="{input}") + ) + + # Create the prompt template + return prompts.ChatPromptTemplate(message_templates=message_templates) + + +def _override_active_span_processor( + tracer_provider: "TracerProvider", + active_span_processor: "SynchronousMultiSpanProcessor", +): + """Overrides the active span processor. + + When working with multiple LlamaIndexQueryPipelineAgents in the same + environment, it's crucial to manage trace exports carefully. + Each agent needs its own span processor tied to a unique project ID. + While we add a new span processor for each agent, this can lead to + unexpected behavior. + For instance, with two agents linked to different projects, traces from the + second agent might be sent to both projects. + To prevent this and guarantee traces go to the correct project, we overwrite + the active span processor whenever a new LlamaIndexQueryPipelineAgent is + created. + + Args: + tracer_provider (TracerProvider): + The tracer provider to use for the project. + active_span_processor (SynchronousMultiSpanProcessor): + The active span processor overrides the tracer provider's + active span processor. + """ + if tracer_provider._active_span_processor: + tracer_provider._active_span_processor.shutdown() + tracer_provider._active_span_processor = active_span_processor + + +class LlamaIndexQueryPipelineAgent: + """A LlamaIndex Query Pipeline Agent. + + This agent uses a query pipeline for LLAIndex, including prompt, model, + retrieval and summarization steps. More details can be found in + https://docs.llamaindex.ai/en/stable/module_guides/querying/pipeline/. + """ + + agent_framework = "llama-index" + + def __init__( + self, + model: str, + *, + system_instruction: Optional[str] = None, + prompt: Optional["QueryComponent"] = None, + model_kwargs: Optional[Mapping[str, Any]] = None, + model_builder: Optional[Callable[..., "FunctionCallingLLM"]] = None, + retriever_kwargs: Optional[Mapping[str, Any]] = None, + retriever_builder: Optional[Callable[..., "QueryComponent"]] = None, + response_synthesizer_kwargs: Optional[Mapping[str, Any]] = None, + response_synthesizer_builder: Optional[Callable[..., "QueryComponent"]] = None, + runnable_kwargs: Optional[Mapping[str, Any]] = None, + runnable_builder: Optional[Callable[..., "QueryPipeline"]] = None, + enable_tracing: bool = False, + ): + """Initializes the LlamaIndexQueryPipelineAgent. 
+
+        Under-the-hood, assuming .set_up() is called, this will correspond to
+        ```python
+        # model_builder
+        model = model_builder(model_name, project, location, model_kwargs)
+
+        # runnable_builder
+        runnable = runnable_builder(
+            prompt=prompt,
+            model=model,
+            retriever=retriever_builder(model, retriever_kwargs),
+            response_synthesizer=response_synthesizer_builder(
+                model, response_synthesizer_kwargs
+            ),
+            runnable_kwargs=runnable_kwargs,
+        )
+        ```
+
+        When everything is based on their default values, this corresponds
+        to a query pipeline `Prompt - Model`:
+        ```python
+        # Default Model Builder
+        model = google_genai.GoogleGenAI(
+            model=model_name,
+            vertexai_config={
+                "project": initializer.global_config.project,
+                "location": initializer.global_config.location,
+            },
+        )
+
+        # Default Prompt Builder
+        prompt = prompts.ChatPromptTemplate(
+            message_templates=[
+                types.ChatMessage(
+                    role=types.MessageRole.USER,
+                    content="{input}",
+                ),
+            ],
+        )
+
+        # Default Runnable Builder
+        runnable = QueryPipeline(
+            modules={
+                "prompt": prompt,
+                "model": model,
+            },
+        )
+        runnable.add_link("prompt", "model")
+        ```
+
+        When `system_instruction` is specified, the prompt will be updated
+        to include the system instruction.
+        ```python
+        # Updated Prompt Builder
+        prompt = prompts.ChatPromptTemplate(
+            message_templates=[
+                types.ChatMessage(
+                    role=types.MessageRole.SYSTEM,
+                    content=system_instruction,
+                ),
+                types.ChatMessage(
+                    role=types.MessageRole.USER,
+                    content="{input}",
+                ),
+            ],
+        )
+        ```
+
+        When all inputs are specified, this corresponds to a query pipeline
+        `Prompt - Model - Retriever - Summarizer`:
+        ```python
+        runnable = QueryPipeline(
+            modules={
+                "prompt": prompt,
+                "model": model,
+                "retriever": retriever_builder(retriever_kwargs),
+                "response_synthesizer": response_synthesizer_builder(
+                    response_synthesizer_kwargs
+                ),
+            },
+        )
+        runnable.add_link("prompt", "model")
+        runnable.add_link("model", "retriever")
+        runnable.add_link("model", "response_synthesizer", dest_key="query_str")
+        runnable.add_link("retriever", "response_synthesizer", dest_key="nodes")
+        ```
+
+        Args:
+            model (str):
+                Required. The name of the model (e.g. "gemini-1.0-pro").
+            system_instruction (str):
+                Optional. The system instruction to use for the agent.
+            prompt (llama_index.core.base.query_pipeline.query.QUERY_COMPONENT_TYPE):
+                Optional. The prompt template for the model.
+            model_kwargs (Mapping[str, Any]):
+                Optional. Keyword arguments for the constructor of
+                google_genai.GoogleGenAI. An example of model_kwargs:
+                ```python
+                {
+                    # api_key (string): The API key for the GoogleGenAI model.
+                    # The API key can also be fetched from the GOOGLE_API_KEY
+                    # environment variable. If `vertexai_config` is provided,
+                    # the API key is ignored.
+                    "api_key": "your_api_key",
+                    # temperature (float): Sampling temperature, it controls the
+                    # degree of randomness in token selection. If not provided,
+                    # the default temperature is 0.1.
+                    "temperature": 0.1,
+                    # context_window (int): The context window of the model.
+                    # If not provided, the default context window is 200000.
+                    "context_window": 200000,
+                    # max_tokens (int): Token limit determines the maximum
+                    # amount of text output from one prompt. If not provided,
+                    # the default max_tokens is 256.
+                    "max_tokens": 256,
+                    # is_function_calling_model (bool): Whether the model is a
+                    # function calling model. If not provided, the default
+                    # is_function_calling_model is True.
+                    "is_function_calling_model": True,
+                }
+                ```
+            model_builder (Callable):
+                Optional. Callable that returns a language model.
+            retriever_kwargs (Mapping[str, Any]):
+                Optional. Keyword arguments for the retriever constructor.
+            retriever_builder (Callable):
+                Optional. Callable that returns a retriever object.
+            response_synthesizer_kwargs (Mapping[str, Any]):
+                Optional. Keyword arguments for the response synthesizer
+                constructor.
+            response_synthesizer_builder (Callable):
+                Optional. Callable that returns a response_synthesizer object.
+            runnable_kwargs (Mapping[str, Any]):
+                Optional. Keyword arguments for the runnable constructor.
+            runnable_builder (Callable):
+                Optional. Callable that returns a runnable (query pipeline).
+            enable_tracing (bool):
+                Optional. Whether to enable tracing. Defaults to False.
+        """
+        self._model_name = model
+        self._system_instruction = system_instruction
+        self._prompt = prompt
+
+        self._model = None
+        self._model_kwargs = model_kwargs or {}
+        self._model_builder = model_builder
+
+        self._retriever = None
+        self._retriever_kwargs = retriever_kwargs or {}
+        self._retriever_builder = retriever_builder
+
+        self._response_synthesizer = None
+        self._response_synthesizer_kwargs = response_synthesizer_kwargs or {}
+        self._response_synthesizer_builder = response_synthesizer_builder
+
+        self._runnable = None
+        self._runnable_kwargs = runnable_kwargs or {}
+        self._runnable_builder = runnable_builder
+
+        self._instrumentor = None
+        self._enable_tracing = enable_tracing
+
+    def set_up(self):
+        """Sets up the agent for execution of queries at runtime.
+
+        It initializes the model and connects it with the prompt template,
+        retriever and response_synthesizer.
+
+        This method should not be called for an object that is being passed
+        to the ReasoningEngine service for deployment, as it initializes
+        clients that cannot be serialized.
+        """
+        import os
+
+        project = os.environ.get("GOOGLE_CLOUD_PROJECT")
+        location = os.environ.get("GOOGLE_CLOUD_LOCATION")
+        if self._enable_tracing:
+            from agentplatform._genai import (
+                _agent_engines_utils,
+            )
+
+            cloud_trace_exporter = (
+                _agent_engines_utils._import_cloud_trace_exporter_or_warn()
+            )
+            cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn()
+            openinference_llama_index = (
+                _agent_engines_utils._import_openinference_llama_index_or_warn()
+            )
+            opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn()
+            opentelemetry_sdk_trace = (
+                _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn()
+            )
+            if all(
+                (
+                    cloud_trace_exporter,
+                    cloud_trace_v2,
+                    openinference_llama_index,
+                    opentelemetry,
+                    opentelemetry_sdk_trace,
+                )
+            ):
+                import google.auth
+
+                credentials, _ = google.auth.default()
+                span_exporter = cloud_trace_exporter.CloudTraceSpanExporter(
+                    project_id=project,
+                    client=cloud_trace_v2.TraceServiceClient(
+                        credentials=credentials.with_quota_project(project),
+                    ),
+                )
+                span_processor: SpanProcessor = (
+                    opentelemetry_sdk_trace.export.SimpleSpanProcessor(
+                        span_exporter=span_exporter,
+                    )
+                )
+                tracer_provider: TracerProvider = (
+                    opentelemetry.trace.get_tracer_provider()
+                )
+                # Get the appropriate tracer provider:
+                # 1. If _TRACER_PROVIDER is already set, use that.
+                # 2. Otherwise, if the OTEL_PYTHON_TRACER_PROVIDER environment
+                #    variable is set, use that.
+                # 3. As a final fallback, use _PROXY_TRACER_PROVIDER.
+                # If none of the above is set, we log a warning and create a
+                # new tracer provider.
+                if not tracer_provider:
+                    from google.cloud.aiplatform import base
+
+                    _LOGGER = base.Logger(__name__)
+                    _LOGGER.warning(
+                        "No tracer provider. By default, "
By default, " + "we should get one of the following providers: " + "OTEL_PYTHON_TRACER_PROVIDER, _TRACER_PROVIDER, " + "or _PROXY_TRACER_PROVIDER." + ) + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids AttributeError: + # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no + # attribute 'add_span_processor'. + if _agent_engines_utils.is_noop_or_proxy_tracer_provider( + tracer_provider + ): + tracer_provider = opentelemetry_sdk_trace.TracerProvider() + opentelemetry.trace.set_tracer_provider(tracer_provider) + # Avoids OpenTelemetry client already exists error. + _override_active_span_processor( + tracer_provider, + opentelemetry_sdk_trace.SynchronousMultiSpanProcessor(), + ) + tracer_provider.add_span_processor(span_processor) + # Keep the instrumentation up-to-date. + # When creating multiple LlamaIndexQueryPipelineAgents, + # we need to keep the instrumentation up-to-date. + # We deliberately override the instrument each time, + # so that if different agents end up using different + # instrumentations, we guarantee that the user is always + # working with the most recent agent's instrumentation. + self._instrumentor = openinference_llama_index.LlamaIndexInstrumentor() + if self._instrumentor.is_instrumented_by_opentelemetry: + self._instrumentor.uninstrument() + self._instrumentor.instrument() + else: + from google.cloud.aiplatform import base + + _LOGGER = base.Logger(__name__) + _LOGGER.warning( + "enable_tracing=True but proceeding with tracing disabled " + "because not all packages for tracing have been installed" + ) + + model_builder = self._model_builder or _default_model_builder + self._model = model_builder( + model_name=self._model_name, + model_kwargs=self._model_kwargs, + project=project, + location=location, + ) + + if self._retriever_builder: + self._retriever = self._retriever_builder( + model=self._model, + retriever_kwargs=self._retriever_kwargs, + ) + + if self._response_synthesizer_builder: + self._response_synthesizer = self._response_synthesizer_builder( + model=self._model, + response_synthesizer_kwargs=self._response_synthesizer_kwargs, + ) + + runnable_builder = self._runnable_builder or _default_runnable_builder + self._runnable = runnable_builder( + prompt=self._prompt, + model=self._model, + system_instruction=self._system_instruction, + retriever=self._retriever, + response_synthesizer=self._response_synthesizer, + runnable_kwargs=self._runnable_kwargs, + ) + + def clone(self) -> "LlamaIndexQueryPipelineAgent": + """Returns a clone of the LlamaIndexQueryPipelineAgent.""" + import copy + + return LlamaIndexQueryPipelineAgent( + model=self._model_name, + system_instruction=self._system_instruction, + prompt=copy.deepcopy(self._prompt), + model_kwargs=copy.deepcopy(self._model_kwargs), + model_builder=self._model_builder, + retriever_kwargs=copy.deepcopy(self._retriever_kwargs), + retriever_builder=self._retriever_builder, + response_synthesizer_kwargs=copy.deepcopy( + self._response_synthesizer_kwargs + ), + response_synthesizer_builder=self._response_synthesizer_builder, + runnable_kwargs=copy.deepcopy(self._runnable_kwargs), + runnable_builder=self._runnable_builder, + enable_tracing=self._enable_tracing, + ) + + def query( + self, + input: Union[str, Mapping[str, Any]], + **kwargs: Any, + ) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]: + """Queries the Agent with the given input and config. 
+
+        Args:
+            input (Union[str, Mapping[str, Any]]):
+                Required. The input to be passed to the Agent.
+            **kwargs:
+                Optional. Any additional keyword arguments to be passed to
+                the `.run()` method of the underlying QueryPipeline.
+
+        Returns:
+            The output of querying the Agent with the given input.
+        """
+        from agentplatform._genai import _agent_engines_utils
+
+        if isinstance(input, str):
+            input = {"input": input}
+
+        if not self._runnable:
+            self.set_up()
+
+        if kwargs.get("batch"):
+            nest_asyncio = _agent_engines_utils._import_nest_asyncio_or_warn()
+            nest_asyncio.apply()
+
+        return _agent_engines_utils.to_json_serializable_llama_index_object(
+            self._runnable.run(**input, **kwargs)
+        )
diff --git a/tests/unit/vertexai/genai/replays/conftest.py b/tests/unit/agentplatform/genai/replays/conftest.py
similarity index 97%
rename from tests/unit/vertexai/genai/replays/conftest.py
rename to tests/unit/agentplatform/genai/replays/conftest.py
index 0f9eed0c5b..e61b247de1 100644
--- a/tests/unit/vertexai/genai/replays/conftest.py
+++ b/tests/unit/agentplatform/genai/replays/conftest.py
@@ -19,11 +19,11 @@
 import os
 from unittest import mock
 
-from vertexai._genai import (
-    client as vertexai_genai_client_module,
+from agentplatform._genai import (
+    client as agentplatform_genai_client_module,
 )
-from vertexai._genai import _agent_engines_utils
-from vertexai._genai.client import (
+from agentplatform._genai import _agent_engines_utils
+from agentplatform._genai.client import (
     _GENAI_MODULES_TELEMETRY_HEADER,
 )
 from google.cloud.aiplatform import version as aip_version
@@ -31,8 +31,8 @@
 from google.genai import _replay_api_client
 from google.genai import types as genai_types
 from google.genai import client as google_genai_client_module
-from vertexai._genai import _gcs_utils
-from vertexai._genai import prompt_optimizer
+from agentplatform._genai import _gcs_utils
+from agentplatform._genai import prompt_optimizer
 import pytest
@@ -272,7 +272,7 @@ def client(use_vertex, replays_prefix, http_options, request):
     replay_client = _replay_api_client.ReplayApiClient(
         mode=mode,
         replay_id=replay_id,
         vertexai=use_vertex,
         http_options=http_options,
     )
@@ -280,7 +280,7 @@
             google_genai_client_module.Client, "_get_api_client"
         ) as patch_method:
             patch_method.return_value = replay_client
-            google_genai_client = vertexai_genai_client_module.Client()
+            google_genai_client = agentplatform_genai_client_module.Client()
 
             if mode != "replay":
                 yield google_genai_client
@@ -313,7 +313,9 @@
         ) as mock_job_wait:
             mock_job_wait.return_value = None
-            google_genai_client = vertexai_genai_client_module.Client()
+            google_genai_client = (
+                agentplatform_genai_client_module.Client()
+            )
 
             # Yield the client so that cleanup can be completed at the end of the test.
yield google_genai_client diff --git a/tests/unit/vertexai/genai/replays/credentials.json b/tests/unit/agentplatform/genai/replays/credentials.json similarity index 100% rename from tests/unit/vertexai/genai/replays/credentials.json rename to tests/unit/agentplatform/genai/replays/credentials.json diff --git a/tests/unit/vertexai/genai/replays/pytest_helper.py b/tests/unit/agentplatform/genai/replays/pytest_helper.py similarity index 100% rename from tests/unit/vertexai/genai/replays/pytest_helper.py rename to tests/unit/agentplatform/genai/replays/pytest_helper.py diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_delete.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_delete.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_memories_delete.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_delete.py index cf7dc1189f..d41a534be4 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_delete.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_get.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_get.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_memories_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_get.py index 4b5fde1f8c..f1633b360f 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_create.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_create.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_create.py index 1aba8565e2..aeffc1be38 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_create.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_create_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_generate.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_generate.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_generate.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_generate.py index c577b20277..015766a646 100644 --- 
a/tests/unit/vertexai/genai/replays/test_ae_memories_private_generate.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_generate.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_generate_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_get_generate_memories_operation.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_generate_memories_operation.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_get_generate_memories_operation.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_generate_memories_operation.py index b888f606ef..54fda3511e 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_get_generate_memories_operation.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_generate_memories_operation.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get_generate_memories_operation(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_get_memory_operation.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_memory_operation.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_get_memory_operation.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_memory_operation.py index 587b1f7806..17bf74f6bc 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_get_memory_operation.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_get_memory_operation.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get_memory_operation(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_list.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_list.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_list.py index 0b564591c5..ece015a636 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_list_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_purge.py 
b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_purge.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_purge.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_purge.py index bb4a4a473c..4fb0909aee 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_purge.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_purge.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_purge(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_retrieve.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_retrieve.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_retrieve.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_retrieve.py index 9c7ebda52b..e09a5f960d 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_retrieve.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_retrieve.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_retrieve(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_rollback.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_rollback.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_rollback.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_rollback.py index e9ff84fdb1..f4f1a1a5e6 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_rollback.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_rollback.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_rollback(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memories_private_update.py b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_update.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memories_private_update.py rename to tests/unit/agentplatform/genai/replays/test_ae_memories_private_update.py index c11387bc36..4a1554cc43 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memories_private_update.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memories_private_update.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_update_memory(client): diff --git 
a/tests/unit/vertexai/genai/replays/test_ae_memory_revisions_get.py b/tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_get.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memory_revisions_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_get.py index 73d7ea049d..b932273474 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memory_revisions_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_memory_revisions(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_memory_revisions_private_list.py b/tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_private_list.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_memory_revisions_private_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_private_list.py index a160745c28..e026e48eb7 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_memory_revisions_private_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_memory_revisions_private_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_list_memory_revisions(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_create.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_create.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_create.py index ee6b935dbc..70ffd10c02 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_create.py @@ -16,8 +16,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring,bad-indentation -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_create_sandbox_snapshot(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_delete.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_delete.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_delete.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_delete.py index d3c0b43ca0..70bbcab6f3 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_delete.py @@ -16,8 +16,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring,bad-indentation -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from 
google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_sandbox_snapshot(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_get.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_get.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_get.py index e9e6435449..1f56fe2f38 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_get.py @@ -16,8 +16,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring,bad-indentation -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_sandbox_snapshot(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_list.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_list.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_list.py index cd5676bf51..37f3f4cf74 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_snapshots_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_snapshots_list.py @@ -16,8 +16,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring,bad-indentation -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_list_sandbox_snapshots(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_byoc_create.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_byoc_create.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_byoc_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_byoc_create.py index 72a0f69e95..a268b9dccf 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_byoc_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_byoc_create.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_sandbox_templates_byoc_create(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_default_create.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_default_create.py similarity index 92% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_default_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_default_create.py index d251ca2933..8d0d14e140 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_default_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_default_create.py @@ -14,8 +14,8 @@ # # pylint: 
disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_sandbox_templates_default_create(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_delete.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_delete.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_delete.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_delete.py index e4c7511f0e..ae33132e1a 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_delete.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_sandbox_templates_delete(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get.py index 6cc8340f95..f5a062e70c 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_sandbox_templates_get(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py index ddbea1a715..a7e6de09ad 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_get_sandbox_template_operation.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_sandbox_template_operation(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_list.py b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_list.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_list.py index 
d488c41ff3..d1ec4b8240 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandbox_templates_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandbox_templates_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_sandbox_templates_list(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_create.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_create.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_create.py index 7a6f822bee..63bb4b16ee 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_create.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_create(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_delete.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_delete.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_delete.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_delete.py index 45c3eb4cb9..23d35628ce 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_delete.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_delete(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_execute_code.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_execute_code.py similarity index 93% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_execute_code.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_execute_code.py index ee07284663..46116410b3 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_execute_code.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_execute_code.py @@ -16,8 +16,8 @@ import json -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_execute_code(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get.py 
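Every rename hunk in this run applies the same two-line import swap, moving the replay helper and the types module under the aiplatform/agentplatform tree; the shared pattern, shown once:

-from tests.unit.vertexai.genai.replays import pytest_helper
-from vertexai._genai import types
+from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper
+from agentplatform._genai import types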
index 35a3c41fb1..34f6d10de7 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py index bdf1202372..0613482517 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_get_sandbox_operation.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get_operation(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_list.py b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_list.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_list.py index 0457b49067..d513d1e514 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_sandboxes_private_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_sandboxes_private_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_list(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_delete.py b/tests/unit/agentplatform/genai/replays/test_ae_session_delete.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_session_delete.py rename to tests/unit/agentplatform/genai/replays/test_ae_session_delete.py index f077cd26bf..faba1f7cee 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_delete.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_session_non_blocking(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_events_append.py b/tests/unit/agentplatform/genai/replays/test_ae_session_events_append.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_session_events_append.py rename to 
tests/unit/agentplatform/genai/replays/test_ae_session_events_append.py index 4319a78533..f86aef23b5 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_events_append.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_events_append.py @@ -16,8 +16,8 @@ import datetime -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_append_session_event(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_events_private_list.py b/tests/unit/agentplatform/genai/replays/test_ae_session_events_private_list.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_session_events_private_list.py rename to tests/unit/agentplatform/genai/replays/test_ae_session_events_private_list.py index fb903364cc..52407b5c40 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_events_private_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_events_private_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_list_session_events(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_private_create.py b/tests/unit/agentplatform/genai/replays/test_ae_session_private_create.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_session_private_create.py rename to tests/unit/agentplatform/genai/replays/test_ae_session_private_create.py index 1c4d17013f..20455437c4 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_private_create.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_private_create.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_create_session(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_private_get.py b/tests/unit/agentplatform/genai/replays/test_ae_session_private_get.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_session_private_get.py rename to tests/unit/agentplatform/genai/replays/test_ae_session_private_get.py index 5ac70531fd..c761dc2326 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_private_get.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_private_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get_session_operation(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_private_list.py b/tests/unit/agentplatform/genai/replays/test_ae_session_private_list.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_ae_session_private_list.py rename to 
tests/unit/agentplatform/genai/replays/test_ae_session_private_list.py index 4dd8e8c7d2..5ad07423d5 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_private_list.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_private_list.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_list_session(client): diff --git a/tests/unit/vertexai/genai/replays/test_ae_session_private_update.py b/tests/unit/agentplatform/genai/replays/test_ae_session_private_update.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_ae_session_private_update.py rename to tests/unit/agentplatform/genai/replays/test_ae_session_private_update.py index 646e12bd5c..59cbbbe225 100644 --- a/tests/unit/vertexai/genai/replays/test_ae_session_private_update.py +++ b/tests/unit/agentplatform/genai/replays/test_ae_session_private_update.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_update_session(client): diff --git a/tests/unit/vertexai/genai/replays/test_agent_engine_a2a_methods.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_methods.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_agent_engine_a2a_methods.py rename to tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_methods.py index e057a8ed48..69f145eed2 100644 --- a/tests/unit/vertexai/genai/replays/test_agent_engine_a2a_methods.py +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_a2a_methods.py @@ -16,8 +16,8 @@ from unittest import mock -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import _api_client import httpx import pytest diff --git a/tests/unit/vertexai/genai/replays/test_agent_engine_private_create.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_create.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_agent_engine_private_create.py rename to tests/unit/agentplatform/genai/replays/test_agent_engine_private_create.py index 12c9100318..2c2545662d 100644 --- a/tests/unit/vertexai/genai/replays/test_agent_engine_private_create.py +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_create.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_create_with_labels(client): diff --git a/tests/unit/vertexai/genai/replays/test_agent_engine_private_delete.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_delete.py similarity index 88% rename from tests/unit/vertexai/genai/replays/test_agent_engine_private_delete.py rename 
to tests/unit/agentplatform/genai/replays/test_agent_engine_private_delete.py index 3a6f7d82fa..1c47ae8492 100644 --- a/tests/unit/vertexai/genai/replays/test_agent_engine_private_delete.py +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_delete.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_delete(client): diff --git a/tests/unit/vertexai/genai/replays/test_agent_engine_private_get.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_get.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_agent_engine_private_get.py rename to tests/unit/agentplatform/genai/replays/test_agent_engine_private_get.py index 9c7b08d6ac..90e056d6ec 100644 --- a/tests/unit/vertexai/genai/replays/test_agent_engine_private_get.py +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_get.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_get(client): diff --git a/tests/unit/vertexai/genai/replays/test_agent_engine_private_update.py b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_update.py similarity index 89% rename from tests/unit/vertexai/genai/replays/test_agent_engine_private_update.py rename to tests/unit/agentplatform/genai/replays/test_agent_engine_private_update.py index b4b0914bdc..76b09075b1 100644 --- a/tests/unit/vertexai/genai/replays/test_agent_engine_private_update.py +++ b/tests/unit/agentplatform/genai/replays/test_agent_engine_private_update.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_private_update(client): diff --git a/tests/unit/vertexai/genai/replays/test_append_agent_engine_a2a_task_events.py b/tests/unit/agentplatform/genai/replays/test_append_agent_engine_a2a_task_events.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_append_agent_engine_a2a_task_events.py rename to tests/unit/agentplatform/genai/replays/test_append_agent_engine_a2a_task_events.py index c9d4dbb3ae..38da90438d 100644 --- a/tests/unit/vertexai/genai/replays/test_append_agent_engine_a2a_task_events.py +++ b/tests/unit/agentplatform/genai/replays/test_append_agent_engine_a2a_task_events.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_append_agent_engine_session_event.py b/tests/unit/agentplatform/genai/replays/test_append_agent_engine_session_event.py similarity index 95% rename from 
tests/unit/vertexai/genai/replays/test_append_agent_engine_session_event.py rename to tests/unit/agentplatform/genai/replays/test_append_agent_engine_session_event.py index b00c199c4b..6c27eecbcb 100644 --- a/tests/unit/vertexai/genai/replays/test_append_agent_engine_session_event.py +++ b/tests/unit/agentplatform/genai/replays/test_append_agent_engine_session_event.py @@ -16,7 +16,7 @@ import datetime -from tests.unit.vertexai.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper def test_append_session_event(client): diff --git a/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py new file mode 100644 index 0000000000..f6e00ced52 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_assemble_multimodal_datasets.py @@ -0,0 +1,100 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +METADATA_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" +) +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" +DATASET = "projects/vertex-sdk-dev/locations/us-central1/datasets/8810841321427173376" + + +def test_assemble_dataset(client): + operation = client.datasets._assemble_multimodal_dataset( + name=DATASET, + gemini_request_read_config={ + "template_config": { + "field_mapping": {"question": "questionColumn"}, + }, + }, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + + +def test_assemble_dataset_public(client): + bigquery_destination = client.datasets.assemble( + name=DATASET, + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + model="gemini-1.5-flash", + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + } + ], + ), + ) + ), + ) + assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}") + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_assemble_dataset_async(client): + operation = await client.aio.datasets._assemble_multimodal_dataset( + name=DATASET, + gemini_request_read_config={ + "template_config": { + "field_mapping": {"question": "questionColumn"}, + }, + }, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + + +@pytest.mark.asyncio +async def test_assemble_dataset_public_async(client): + bigquery_destination = await client.aio.datasets.assemble( + name=DATASET, + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + 
gemini_example=types.GeminiExample( + model="gemini-1.5-flash", + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + } + ], + ), + ) + ), + ) + assert bigquery_destination.startswith(f"bq://{BIGQUERY_TABLE_NAME}") diff --git a/tests/unit/agentplatform/genai/replays/test_assess_multimodal_dataset.py b/tests/unit/agentplatform/genai/replays/test_assess_multimodal_dataset.py new file mode 100644 index 0000000000..0e41ef608b --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_assess_multimodal_dataset.py @@ -0,0 +1,269 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +METADATA_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" +) +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" +DATASET = "projects/vertex-sdk-dev/locations/us-central1/datasets/8810841321427173376" + + +def test_assess_dataset(client): + operation = client.datasets._assess_multimodal_dataset( + name=DATASET, + tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( + model_name="gemini-2.5-flash-001" + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + } + ], + ), + ), + ), + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + + +def test_assess_tuning_resources(client): + response = client.datasets.assess_tuning_resources( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + } + ], + ), + ) + ), + ) + assert isinstance(response, types.TuningResourceUsageAssessmentResult) + + +def test_assess_tuning_validity(client): + response = client.datasets.assess_tuning_validity( + dataset_name=DATASET, + dataset_usage="SFT_VALIDATION", + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.TuningValidationAssessmentResult) + + +def test_assess_batch_prediction_resources(client): + response = client.datasets.assess_batch_prediction_resources( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + 
gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult) + + +def test_assess_batch_prediction_validity(client): + response = client.datasets.assess_batch_prediction_validity( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.BatchPredictionValidationAssessmentResult) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_assess_dataset_async(client): + operation = await client.aio.datasets._assess_multimodal_dataset( + name=DATASET, + tuning_resource_usage_assessment_config=types.TuningResourceUsageAssessmentConfig( + model_name="gemini-2.5-flash-001" + ), + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + ], + ), + ), + ), + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + + +@pytest.mark.asyncio +async def test_assess_tuning_resources_async(client): + response = await client.aio.datasets.assess_tuning_resources( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + } + ], + ), + ) + ), + ) + assert isinstance(response, types.TuningResourceUsageAssessmentResult) + + +@pytest.mark.asyncio +async def test_assess_tuning_validity_async(client): + response = await client.aio.datasets.assess_tuning_validity( + dataset_name=DATASET, + dataset_usage="SFT_VALIDATION", + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.TuningValidationAssessmentResult) + + +@pytest.mark.asyncio +async def test_assess_batch_prediction_resources_async(client): + response = await client.aio.datasets.assess_batch_prediction_resources( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.BatchPredictionResourceUsageAssessmentResult) + + 
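The assemble and assessment replay tests in this file and the preceding one all drive the same client.datasets calling convention; a minimal sketch of that convention outside the replay harness, assuming an initialized client and the DATASET fixture value used by these tests:

from agentplatform._genai import types

read_config = types.GeminiRequestReadConfig(
    template_config=types.GeminiTemplateConfig(
        gemini_example=types.GeminiExample(
            contents=[
                {"role": "user", "parts": [{"text": "What is the capital of {name}?"}]},
            ],
        ),
    )
)
# assemble() materializes the templated Gemini requests into BigQuery and
# returns a bq:// destination string (see test_assemble_dataset_public above).
destination = client.datasets.assemble(
    name=DATASET, gemini_request_read_config=read_config
)
# The assess_* helpers take the same read config plus a model name.
usage = client.datasets.assess_tuning_resources(
    dataset_name=DATASET,
    model_name="gemini-2.5-flash-001",
    gemini_request_read_config=read_config,
)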
+@pytest.mark.asyncio +async def test_assess_batch_prediction_validity_async(client): + response = await client.aio.datasets.assess_batch_prediction_validity( + dataset_name=DATASET, + model_name="gemini-2.5-flash-001", + gemini_request_read_config=types.GeminiRequestReadConfig( + template_config=types.GeminiTemplateConfig( + gemini_example=types.GeminiExample( + contents=[ + { + "role": "user", + "parts": [{"text": "What is the capital of {name}?"}], + }, + { + "role": "model", + "parts": [{"text": "{capital}"}], + }, + ], + ), + ) + ), + ) + assert isinstance(response, types.BatchPredictionValidationAssessmentResult) diff --git a/tests/unit/vertexai/genai/replays/test_batch_evaluate.py b/tests/unit/agentplatform/genai/replays/test_batch_evaluate.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_batch_evaluate.py rename to tests/unit/agentplatform/genai/replays/test_batch_evaluate.py index bddd044818..493d2ddf41 100644 --- a/tests/unit/vertexai/genai/replays/test_batch_evaluate.py +++ b/tests/unit/agentplatform/genai/replays/test_batch_evaluate.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine.py index aeb6599968..e7759f6d2b 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine.py @@ -18,8 +18,8 @@ import re import sys -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_a2a_task.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_a2a_task.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_a2a_task.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_a2a_task.py index cde695d0fe..5a52dcbfff 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_a2a_task.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_a2a_task.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_developer_connect.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_developer_connect.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_developer_connect.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_developer_connect.py index 
d2bc37ee25..aab9f92f4d 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_developer_connect.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_developer_connect.py @@ -16,8 +16,8 @@ import sys -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_docker.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_docker.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_docker.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_docker.py index c4c4cb342f..f76ccd9c61 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_docker.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_docker.py @@ -16,7 +16,7 @@ import sys -from tests.unit.vertexai.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper _TEST_CLASS_METHODS = [ {"name": "async_stream_query", "api_mode": "async_stream"}, diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_memory.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_memory.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_memory.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_memory.py index b441469eff..57f5250ec5 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_memory.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_memory.py @@ -16,8 +16,8 @@ import datetime -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_create_memory_with_ttl(client): diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_sandbox.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_sandbox.py similarity index 93% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_sandbox.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_sandbox.py index 70c3e73269..e849f14eab 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_sandbox.py +++ b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_sandbox.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_create_sandbox(client): diff --git a/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_session.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py rename to tests/unit/agentplatform/genai/replays/test_create_agent_engine_session.py index a1a8b8930f..666b6d6ad7 100644 --- a/tests/unit/vertexai/genai/replays/test_create_agent_engine_session.py +++ 
b/tests/unit/agentplatform/genai/replays/test_create_agent_engine_session.py @@ -16,8 +16,8 @@ import datetime -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_create_session_with_ttl(client): diff --git a/tests/unit/agentplatform/genai/replays/test_create_evaluation_item.py b/tests/unit/agentplatform/genai/replays/test_create_evaluation_item.py new file mode 100644 index 0000000000..202133a493 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_create_evaluation_item.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import pytest + +GCS_URI = ( + "gs://lakeyk-limited-bucket/agora_eval_080525/request_4813679498589372416.json" +) +DISPLAY_NAME = "test_eval_item" + + +def test_create_eval_item(client): + """Tests that create_evaluation_item() returns a correctly structured EvaluationItem.""" + evaluation_item = client.evals.create_evaluation_item( + evaluation_item_type=types.EvaluationItemType.REQUEST, + gcs_uri=GCS_URI, + display_name=DISPLAY_NAME, + ) + # Retrieve the evaluation item to check that it was created correctly. + retrieved_evaluation_item = client.evals.get_evaluation_item( + name=evaluation_item.name + ) + check_evaluation_item( + evaluation_item, + retrieved_evaluation_item, + ) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_create_eval_item_async(client): + """Tests that create_evaluation_item() returns a correctly structured EvaluationItem.""" + evaluation_item = await client.aio.evals.create_evaluation_item( + evaluation_item_type=types.EvaluationItemType.REQUEST, + gcs_uri=GCS_URI, + display_name=DISPLAY_NAME, + ) + # Retrieve the evaluation item to check that it was created correctly. + retrieved_evaluation_item = await client.aio.evals.get_evaluation_item( + name=evaluation_item.name + ) + check_evaluation_item( + evaluation_item, + retrieved_evaluation_item, + ) + + +def check_evaluation_item( + evaluation_item: types.EvaluationItem, + retrieved_evaluation_item: types.EvaluationItem, +): + assert isinstance(evaluation_item, types.EvaluationItem) + assert evaluation_item.gcs_uri == GCS_URI + assert evaluation_item.evaluation_item_type == types.EvaluationItemType.REQUEST + assert evaluation_item.display_name == DISPLAY_NAME + assert retrieved_evaluation_item.gcs_uri == GCS_URI + assert ( + retrieved_evaluation_item.evaluation_item_type + == types.EvaluationItemType.REQUEST + ) + assert retrieved_evaluation_item.display_name == DISPLAY_NAME + # Check the request data. 
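+    # The same round trip outside the replay harness is, in sketch form
+    # (assuming an initialized client; GCS_URI and DISPLAY_NAME are this
+    # file's fixture values):
+    #     item = client.evals.create_evaluation_item(
+    #         evaluation_item_type=types.EvaluationItemType.REQUEST,
+    #         gcs_uri=GCS_URI,
+    #         display_name=DISPLAY_NAME,
+    #     )
+    #     fetched = client.evals.get_evaluation_item(name=item.name)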
+ request = retrieved_evaluation_item.evaluation_request + assert ( + "If your ball is curving during flight from left to right" + in request.prompt.text + ) + # Check the first candidate response. + assert request.candidate_responses[0].candidate == "gemini-2.0-flash-001@default" + assert ( + "Keep your knees bent during the backswing" + in request.candidate_responses[0].text + ) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.create_evaluation_item", +) diff --git a/tests/unit/vertexai/genai/replays/test_create_evaluation_run.py b/tests/unit/agentplatform/genai/replays/test_create_evaluation_run.py similarity index 99% rename from tests/unit/vertexai/genai/replays/test_create_evaluation_run.py rename to tests/unit/agentplatform/genai/replays/test_create_evaluation_run.py index 3fb251d16d..790441fae9 100644 --- a/tests/unit/vertexai/genai/replays/test_create_evaluation_run.py +++ b/tests/unit/agentplatform/genai/replays/test_create_evaluation_run.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types from google.genai import types as genai_types import pandas as pd import pytest diff --git a/tests/unit/agentplatform/genai/replays/test_create_evaluation_set.py b/tests/unit/agentplatform/genai/replays/test_create_evaluation_set.py new file mode 100644 index 0000000000..9105609061 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_create_evaluation_set.py @@ -0,0 +1,58 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
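The evaluation-set file opening above groups previously created evaluation items by resource name; a minimal sketch of the call it exercises, assuming an initialized client and item names shaped like the EVAL_ITEMS fixture:

from agentplatform import types

evaluation_set = client.evals.create_evaluation_set(
    evaluation_items=EVAL_ITEMS,  # full evaluationItems resource names
    display_name="test_eval_set",
)
assert isinstance(evaluation_set, types.EvaluationSet)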
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import pytest + + +EVAL_ITEMS = [ + "projects/503583131166/locations/us-central1/evaluationItems/4411504533427978240", + "projects/503583131166/locations/us-central1/evaluationItems/8621947972554326016", +] +DISPLAY_NAME = "test_eval_set" + + +def test_create_eval_set(client): + """Tests that create_evaluation_set() returns a correctly structured EvaluationSet.""" + evaluation_set = client.evals.create_evaluation_set( + evaluation_items=EVAL_ITEMS, display_name=DISPLAY_NAME + ) + assert isinstance(evaluation_set, types.EvaluationSet) + assert evaluation_set.display_name == DISPLAY_NAME + assert evaluation_set.evaluation_items == EVAL_ITEMS + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_create_eval_set_async(client): + """Tests that create_evaluation_set() returns a correctly structured EvaluationSet.""" + evaluation_set = await client.aio.evals.create_evaluation_set( + evaluation_items=EVAL_ITEMS, + display_name=DISPLAY_NAME, + ) + assert isinstance(evaluation_set, types.EvaluationSet) + assert evaluation_set.display_name == DISPLAY_NAME + assert evaluation_set.evaluation_items == EVAL_ITEMS + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.create_evaluation_set", +) diff --git a/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py new file mode 100644 index 0000000000..f6a44d36c2 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_create_multimodal_datasets.py @@ -0,0 +1,551 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
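The multimodal-dataset file that begins above exercises three construction paths (an existing BigQuery table, a pandas DataFrame, a bigframes DataFrame); a minimal sketch of the first two, assuming an initialized client and the BIGQUERY_TABLE_NAME fixture value:

import pandas as pd

# Wrap an existing BigQuery table as a multimodal dataset.
dataset = client.datasets.create_from_bigquery(
    bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
)

# Or upload an in-memory DataFrame to a target table, then wrap it.
df = pd.DataFrame({"question": ["What is the capital of France?"]})
dataset = client.datasets.create_from_pandas(
    dataframe=df,
    target_table_id=BIGQUERY_TABLE_NAME,
    multimodal_dataset={"display_name": "test-from-pandas"},
)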
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +import sys +from unittest import mock + +from google.cloud import bigquery +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import _datasets_utils +from agentplatform._genai import types +import pandas as pd +import pytest + + +METADATA_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" +) +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" + + +@pytest.fixture +def mock_bigquery_client(is_replay_mode): + if is_replay_mode: + with mock.patch.object( + _datasets_utils, "_try_import_bigquery" + ) as mock_try_import_bigquery: + mock_dataset = mock.MagicMock() + mock_dataset.location = "us-central1" + + mock_client = mock.MagicMock() + mock_client.get_dataset.return_value = mock_dataset + + mock_try_import_bigquery.return_value.Client.return_value = mock_client + mock_try_import_bigquery.return_value.TableReference = ( + bigquery.TableReference + ) + + yield mock_try_import_bigquery + else: + yield None + + +@pytest.fixture +def mock_import_bigframes(is_replay_mode): + if is_replay_mode: + with mock.patch.object( + _datasets_utils, "_try_import_bigframes" + ) as mock_import_bigframes: + session = mock.MagicMock() + session.read_pandas.return_value = mock.MagicMock() + + bigframes = mock.MagicMock() + bigframes.connect.return_value = mock.MagicMock() + + mock_import_bigframes.return_value = bigframes + + yield mock_import_bigframes + else: + yield None + + +@pytest.fixture +def mock_generate_multimodal_dataset_display_name(): + with mock.patch.object( + _datasets_utils, "generate_multimodal_dataset_display_name" + ) as mock_generate: + mock_generate.return_value = "test-generated-name" + yield mock_generate + + +def test_create_dataset(client): + create_dataset_operation = client.datasets._create_multimodal_dataset( + name="projects/vertex-sdk-dev/locations/us-central1", + display_name="test-display-name", + metadata_schema_uri=METADATA_SCHEMA_URI, + metadata={ + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + ) + assert isinstance(create_dataset_operation, types.MultimodalDatasetOperation) + assert create_dataset_operation + + +def test_create_dataset_from_bigquery(client): + dataset = client.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "description": "test-description-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +def test_create_dataset_from_bigquery_with_uri(client): + dataset = client.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +def test_create_dataset_from_bigquery_preserves_other_metadata(client): + dataset = client.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + multimodal_dataset={ + "display_name": "test-from-bigquery-uri", + "metadata": { + "gemini_request_read_config": { + 
"assembled_request_column_name": "test_column" + } + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery-uri" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +def test_create_dataset_from_bigquery_no_display_name(client): + dataset = client.datasets.create_from_bigquery( + multimodal_dataset={ + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-generated-name" + + +def test_create_dataset_from_bigquery_raises_if_neither(client): + with pytest.raises( + ValueError, match="At least one of `bigquery_uri` or `multimodal_dataset`" + ): + client.datasets.create_from_bigquery() + + +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +def test_create_dataset_from_pandas(client, is_replay_mode): + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + + dataset = client.datasets.create_from_pandas( + dataframe=dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-pandas", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-pandas" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher" +) +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +def test_create_dataset_from_bigframes(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = client.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal( + rows.to_dataframe(), dataframe, check_index_type=False + ) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="bigframes requires python 3.10 or higher", +) +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +def 
test_create_dataset_from_bigframes_preserves_other_metadata(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = client.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_create_dataset_async(client): + create_dataset_operation = await client.aio.datasets._create_multimodal_dataset( + name="projects/vertex-sdk-dev/locations/us-central1", + display_name="test-display-name", + metadata_schema_uri=METADATA_SCHEMA_URI, + metadata={ + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + ) + assert isinstance(create_dataset_operation, types.MultimodalDatasetOperation) + assert create_dataset_operation + + +@pytest.mark.asyncio +async def test_create_dataset_from_bigquery_async(client): + dataset = await client.aio.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "description": "test-description-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +async def test_create_dataset_from_bigquery_with_uri_async(client): + dataset = await client.aio.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.asyncio +async def test_create_dataset_from_bigquery_preserves_other_metadata_async( + client, +): + dataset = await client.aio.datasets.create_from_bigquery( + bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}", + multimodal_dataset={ + "display_name": "test-from-bigquery-uri", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery-uri" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name") +async def test_create_dataset_from_bigquery_no_display_name_async(client): + dataset 
= await client.aio.datasets.create_from_bigquery( + multimodal_dataset={ + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-generated-name" + + +@pytest.mark.asyncio +async def test_create_dataset_from_bigquery_raises_if_neither_async(client): + with pytest.raises( + ValueError, match="At least one of `bigquery_uri` or `multimodal_dataset`" + ): + await client.aio.datasets.create_from_bigquery() + + +@pytest.mark.asyncio +async def test_create_dataset_from_bigquery_async_with_timeout(client): + dataset = await client.aio.datasets.create_from_bigquery( + config=types.CreateMultimodalDatasetConfig(timeout=120), + multimodal_dataset={ + "display_name": "test-from-bigquery", + "description": "test-description-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigquery" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +async def test_create_dataset_from_pandas_async(client, is_replay_mode): + dataframe = pd.DataFrame( + { + "col1": ["col1row1", "col1row2"], + "col2": ["col2row1", "col2row2"], + } + ) + + dataset = await client.aio.datasets.create_from_pandas( + dataframe=dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-pandas", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-pandas" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher" +) +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +async def test_create_dataset_from_bigframes_async(client, is_replay_mode): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = await client.aio.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) + if not is_replay_mode: + bigquery_client = bigquery.Client( + project=client._api_client.project, + location=client._api_client.location, + credentials=client._api_client._credentials, + ) + rows = bigquery_client.list_rows( + dataset.metadata.input_config.bigquery_source.uri[5:] + ) + pd.testing.assert_frame_equal( + 
rows.to_dataframe(), dataframe, check_index_type=False + ) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="bigframes requires python 3.10 or higher", +) +@pytest.mark.asyncio +@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes") +async def test_create_dataset_from_bigframes_preserves_other_metadata_async( + client, is_replay_mode +): + import bigframes.pandas + + dataframe = pd.DataFrame( + { + "col1": ["col1"], + "col2": ["col2"], + } + ) + if is_replay_mode: + bf_dataframe = mock.MagicMock() + bf_dataframe.to_gbq.return_value = "temp_table_id" + else: + bf_dataframe = bigframes.pandas.DataFrame(dataframe) + + dataset = await client.aio.datasets.create_from_bigframes( + dataframe=bf_dataframe, + target_table_id=BIGQUERY_TABLE_NAME, + multimodal_dataset={ + "display_name": "test-from-bigframes", + "metadata": { + "gemini_request_read_config": { + "assembled_request_column_name": "test_column" + } + }, + }, + ) + + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-from-bigframes" + assert ( + dataset.metadata.gemini_request_read_config.assembled_request_column_name + == "test_column" + ) + assert dataset.metadata.input_config.bigquery_source.uri == ( + f"bq://{BIGQUERY_TABLE_NAME}" + ) diff --git a/tests/unit/vertexai/genai/replays/test_create_prompt.py b/tests/unit/agentplatform/genai/replays/test_create_prompt.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_create_prompt.py rename to tests/unit/agentplatform/genai/replays/test_create_prompt.py index 1545881a6b..584c575201 100644 --- a/tests/unit/vertexai/genai/replays/test_create_prompt.py +++ b/tests/unit/agentplatform/genai/replays/test_create_prompt.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_custom_code_execution_metric.py b/tests/unit/agentplatform/genai/replays/test_custom_code_execution_metric.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_custom_code_execution_metric.py rename to tests/unit/agentplatform/genai/replays/test_custom_code_execution_metric.py index f3422f4484..93ea2f4210 100644 --- a/tests/unit/vertexai/genai/replays/test_custom_code_execution_metric.py +++ b/tests/unit/agentplatform/genai/replays/test_custom_code_execution_metric.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pandas as pd import pytest diff --git a/tests/unit/vertexai/genai/replays/test_delete_ae_runtime_revision.py b/tests/unit/agentplatform/genai/replays/test_delete_ae_runtime_revision.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_delete_ae_runtime_revision.py rename to tests/unit/agentplatform/genai/replays/test_delete_ae_runtime_revision.py index 5054503d70..ee95b88ce6 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_ae_runtime_revision.py +++ 
b/tests/unit/agentplatform/genai/replays/test_delete_ae_runtime_revision.py @@ -17,8 +17,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine.py b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine.py similarity index 93% rename from tests/unit/vertexai/genai/replays/test_delete_agent_engine.py rename to tests/unit/agentplatform/genai/replays/test_delete_agent_engine.py index ea83bdeb2f..5ba141f5e2 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine.py @@ -18,8 +18,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_agent_engine_delete(client, caplog): diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_a2a_task.py b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_a2a_task.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_delete_agent_engine_a2a_task.py rename to tests/unit/agentplatform/genai/replays/test_delete_agent_engine_a2a_task.py index 1312a34400..030768b74f 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_a2a_task.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_a2a_task.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import errors import pytest diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_memory.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py rename to tests/unit/agentplatform/genai/replays/test_delete_agent_engine_memory.py index a9ab9b714c..30c65ee22e 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_memory.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_memory.py @@ -17,8 +17,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_sandbox.py b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_sandbox.py similarity index 92% rename from tests/unit/vertexai/genai/replays/test_delete_agent_engine_sandbox.py rename to tests/unit/agentplatform/genai/replays/test_delete_agent_engine_sandbox.py index a85a9b0f7b..3b74eca585 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_sandbox.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_sandbox.py @@ -13,8 +13,8 @@ # limitations under the License. 
# # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_sandbox(client): diff --git a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_session.py b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_session.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_delete_agent_engine_session.py rename to tests/unit/agentplatform/genai/replays/test_delete_agent_engine_session.py index a0d8e52aff..ef7bc6d62d 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_agent_engine_session.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_agent_engine_session.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_delete_session(client): diff --git a/tests/unit/agentplatform/genai/replays/test_delete_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_delete_multimodal_datasets.py new file mode 100644 index 0000000000..3621b522b5 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_delete_multimodal_datasets.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" + + +def test_delete_dataset(client): + dataset = client.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + + operation = client.datasets._delete_multimodal_dataset( + name=dataset.name, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + assert operation.done + + +def test_delete_dataset_with_public_method(client): + dataset = client.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + + operation = client.datasets.delete_multimodal_dataset( + name=dataset.name, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + assert operation.done + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_delete_dataset_async(client): + dataset = await client.aio.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + + operation = await client.aio.datasets._delete_multimodal_dataset( + name=dataset.name, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + assert operation.done + + +@pytest.mark.asyncio +async def test_delete_dataset_with_public_method_async(client): + dataset = await client.aio.datasets.create_from_bigquery( + multimodal_dataset={ + "display_name": "test-from-bigquery", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + + operation = await client.aio.datasets.delete_multimodal_dataset( + name=dataset.name, + ) + assert isinstance(operation, types.MultimodalDatasetOperation) + assert operation.done diff --git a/tests/unit/vertexai/genai/replays/test_delete_prompt.py b/tests/unit/agentplatform/genai/replays/test_delete_prompt.py similarity index 95% rename from tests/unit/vertexai/genai/replays/test_delete_prompt.py rename to tests/unit/agentplatform/genai/replays/test_delete_prompt.py index dc46c38df4..593de489da 100644 --- a/tests/unit/vertexai/genai/replays/test_delete_prompt.py +++ b/tests/unit/agentplatform/genai/replays/test_delete_prompt.py @@ -16,8 +16,8 @@ import logging -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_evaluate.py b/tests/unit/agentplatform/genai/replays/test_evaluate.py similarity index 99% rename from tests/unit/vertexai/genai/replays/test_evaluate.py rename to tests/unit/agentplatform/genai/replays/test_evaluate.py index 812e4a66ac..8d73cd6d0c 100644 --- a/tests/unit/vertexai/genai/replays/test_evaluate.py +++ b/tests/unit/agentplatform/genai/replays/test_evaluate.py @@ -14,8 +14,8 @@ # # 
pylint: disable=protected-access,bad-continuation,missing-function-docstring import re -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pandas as pd diff --git a/tests/unit/vertexai/genai/replays/test_evaluate_instances.py b/tests/unit/agentplatform/genai/replays/test_evaluate_instances.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_evaluate_instances.py rename to tests/unit/agentplatform/genai/replays/test_evaluate_instances.py index 5a1d02a9e8..b5c5be671f 100644 --- a/tests/unit/vertexai/genai/replays/test_evaluate_instances.py +++ b/tests/unit/agentplatform/genai/replays/test_evaluate_instances.py @@ -16,8 +16,8 @@ import json -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pandas as pd import pytest diff --git a/tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py b/tests/unit/agentplatform/genai/replays/test_evaluate_predefined_metrics.py similarity index 99% rename from tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py rename to tests/unit/agentplatform/genai/replays/test_evaluate_predefined_metrics.py index dfebc95022..5217f119d8 100644 --- a/tests/unit/vertexai/genai/replays/test_evaluate_predefined_metrics.py +++ b/tests/unit/agentplatform/genai/replays/test_evaluate_predefined_metrics.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types import pandas as pd diff --git a/tests/unit/vertexai/genai/replays/test_evaluation_metric.py b/tests/unit/agentplatform/genai/replays/test_evaluation_metric.py similarity index 95% rename from tests/unit/vertexai/genai/replays/test_evaluation_metric.py rename to tests/unit/agentplatform/genai/replays/test_evaluation_metric.py index b456c184ca..87be9e26b3 100644 --- a/tests/unit/vertexai/genai/replays/test_evaluation_metric.py +++ b/tests/unit/agentplatform/genai/replays/test_evaluation_metric.py @@ -15,8 +15,8 @@ # pylint: disable=protected-access,bad-continuation,missing-function-docstring import re -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import errors import pytest diff --git a/tests/unit/vertexai/genai/replays/test_execute_code_agent_engine_sandbox.py b/tests/unit/agentplatform/genai/replays/test_execute_code_agent_engine_sandbox.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_execute_code_agent_engine_sandbox.py rename to tests/unit/agentplatform/genai/replays/test_execute_code_agent_engine_sandbox.py index e1f20ef3cd..8100be01c2 100644 --- a/tests/unit/vertexai/genai/replays/test_execute_code_agent_engine_sandbox.py +++ b/tests/unit/agentplatform/genai/replays/test_execute_code_agent_engine_sandbox.py @@ -14,8 +14,8 @@ # # pylint: 
disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_execute_code_sandbox(client): diff --git a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py b/tests/unit/agentplatform/genai/replays/test_generate_agent_engine_memories.py similarity index 99% rename from tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py rename to tests/unit/agentplatform/genai/replays/test_generate_agent_engine_memories.py index 5a432b9196..fe8801f32e 100644 --- a/tests/unit/vertexai/genai/replays/test_generate_agent_engine_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_generate_agent_engine_memories.py @@ -18,8 +18,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types diff --git a/tests/unit/agentplatform/genai/replays/test_generate_conversation_scenarios.py b/tests/unit/agentplatform/genai/replays/test_generate_conversation_scenarios.py new file mode 100644 index 0000000000..7a362a44df --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_generate_conversation_scenarios.py @@ -0,0 +1,110 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import pytest + + +def test_gen_conversation_scenarios(client): + """Tests that generate_conversation_scenarios() correctly calls the API and parses the response.""" + eval_dataset = client.evals.generate_conversation_scenarios( + agent_info=types.evals.AgentInfo( + agents={ + "booking-agent": types.evals.AgentConfig( + agent_id="booking-agent", + agent_type="service_agent", + description="An agent capable of booking flights and hotels.", + instruction="You are a helpful travel assistant. Use tools to find flights.", + tools=[ + { + "function_declarations": [ + { + "name": "search_flights", + "description": "Search for available flights.", + } + ] + } + ], + ) + }, + root_agent_id="booking-agent", + ), + config=types.evals.UserScenarioGenerationConfig( + count=2, + generation_instruction=( + "Generate scenarios where the user tries to book a flight but" + " changes their mind about the destination." + ), + environment_context="Today is Monday. 
Flights to Paris are available.", + model_name="gemini-2.5-flash", + ), + ) + assert isinstance(eval_dataset, types.EvaluationDataset) + assert len(eval_dataset.eval_cases) == 2 + assert eval_dataset.eval_cases[0].user_scenario.starting_prompt + assert eval_dataset.eval_cases[0].user_scenario.conversation_plan + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_gen_conversation_scenarios_async(client): + """Tests that generate_conversation_scenarios() async correctly calls the API and parses the response.""" + eval_dataset = await client.aio.evals.generate_conversation_scenarios( + agent_info=types.evals.AgentInfo( + agents={ + "booking-agent": types.evals.AgentConfig( + agent_id="booking-agent", + agent_type="service_agent", + description="An agent capable of booking flights and hotels.", + instruction="You are a helpful travel assistant. Use tools to find flights.", + tools=[ + { + "function_declarations": [ + { + "name": "search_flights", + "description": "Search for available flights.", + } + ] + } + ], + ) + }, + root_agent_id="booking-agent", + ), + config=types.evals.UserScenarioGenerationConfig( + count=2, + generation_instruction=( + "Generate scenarios where the user tries to book a flight but" + " changes their mind about the destination." + ), + environment_context="Today is Monday. Flights to Paris are available.", + model_name="gemini-2.5-flash", + ), + ) + assert isinstance(eval_dataset, types.EvaluationDataset) + assert len(eval_dataset.eval_cases) == 2 + assert eval_dataset.eval_cases[1].user_scenario.starting_prompt + assert eval_dataset.eval_cases[1].user_scenario.conversation_plan + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.generate_conversation_scenarios", +) diff --git a/tests/unit/vertexai/genai/replays/test_generate_loss_clusters.py b/tests/unit/agentplatform/genai/replays/test_generate_loss_clusters.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_generate_loss_clusters.py rename to tests/unit/agentplatform/genai/replays/test_generate_loss_clusters.py index 8611c7a5fd..d990268c76 100644 --- a/tests/unit/vertexai/genai/replays/test_generate_loss_clusters.py +++ b/tests/unit/agentplatform/genai/replays/test_generate_loss_clusters.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_get_ae_runtime_revision.py b/tests/unit/agentplatform/genai/replays/test_get_ae_runtime_revision.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_get_ae_runtime_revision.py rename to tests/unit/agentplatform/genai/replays/test_get_ae_runtime_revision.py index b356b613e4..6bd0e92d58 100644 --- a/tests/unit/vertexai/genai/replays/test_get_ae_runtime_revision.py +++ b/tests/unit/agentplatform/genai/replays/test_get_ae_runtime_revision.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git 
a/tests/unit/vertexai/genai/replays/test_get_agent_engine_a2a_task.py b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_a2a_task.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_get_agent_engine_a2a_task.py rename to tests/unit/agentplatform/genai/replays/test_get_agent_engine_a2a_task.py index 11ef480f4b..e33ec85977 100644 --- a/tests/unit/vertexai/genai/replays/test_get_agent_engine_a2a_task.py +++ b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_a2a_task.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_memory.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py rename to tests/unit/agentplatform/genai/replays/test_get_agent_engine_memory.py index 04ef6b50bd..eb93764b83 100644 --- a/tests/unit/vertexai/genai/replays/test_get_agent_engine_memory.py +++ b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_memory.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_memory(client): diff --git a/tests/unit/vertexai/genai/replays/test_get_agent_engine_sandbox.py b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_sandbox.py similarity index 92% rename from tests/unit/vertexai/genai/replays/test_get_agent_engine_sandbox.py rename to tests/unit/agentplatform/genai/replays/test_get_agent_engine_sandbox.py index b82a74a397..a69c3b6ea1 100644 --- a/tests/unit/vertexai/genai/replays/test_get_agent_engine_sandbox.py +++ b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_sandbox.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_sandbox(client): diff --git a/tests/unit/vertexai/genai/replays/test_get_agent_engine_session.py b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_session.py similarity index 90% rename from tests/unit/vertexai/genai/replays/test_get_agent_engine_session.py rename to tests/unit/agentplatform/genai/replays/test_get_agent_engine_session.py index 4dcce0e36c..a26d55d269 100644 --- a/tests/unit/vertexai/genai/replays/test_get_agent_engine_session.py +++ b/tests/unit/agentplatform/genai/replays/test_get_agent_engine_session.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_get_session(client): diff --git a/tests/unit/agentplatform/genai/replays/test_get_evaluation_item.py 
b/tests/unit/agentplatform/genai/replays/test_get_evaluation_item.py new file mode 100644 index 0000000000..641f7647bf --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_evaluation_item.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import datetime +import pytest + + +def test_get_eval_item_response(client): + """Tests that get_evaluation_item() returns a correctly structured EvaluationItem.""" + evaluation_item_name = "projects/503583131166/locations/us-central1/evaluationItems/1486082323915997184" + evaluation_item = client.evals.get_evaluation_item(name=evaluation_item_name) + assert isinstance(evaluation_item, types.EvaluationItem) + check_item_1486082323915997184(evaluation_item, evaluation_item_name) + + +def test_get_eval_item_request(client): + """Tests that get_evaluation_item() returns a correctly structured EvaluationItem with request.""" + evaluation_item_name = "projects/503583131166/locations/us-central1/evaluationItems/4813679498589372416" + evaluation_item = client.evals.get_evaluation_item(name=evaluation_item_name) + assert isinstance(evaluation_item, types.EvaluationItem) + check_item_4813679498589372416(evaluation_item, evaluation_item_name) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_get_eval_item_response_async(client): + """Tests that get_evaluation_item() returns a correctly structured EvaluationItem.""" + eval_item_id = "1486082323915997184" + evaluation_item_name = ( + f"projects/503583131166/locations/us-central1/evaluationItems/{eval_item_id}" + ) + evaluation_item = await client.aio.evals.get_evaluation_item(name=eval_item_id) + check_item_1486082323915997184(evaluation_item, evaluation_item_name) + + +@pytest.mark.asyncio +async def test_get_eval_item_request_async(client): + """Tests that get_evaluation_item() returns a correctly structured EvaluationItem with request.""" + eval_item_id = "4813679498589372416" + evaluation_item_name = ( + f"projects/503583131166/locations/us-central1/evaluationItems/{eval_item_id}" + ) + evaluation_item = await client.aio.evals.get_evaluation_item(name=eval_item_id) + check_item_4813679498589372416(evaluation_item, evaluation_item_name) + + +def check_item_1486082323915997184( + evaluation_item: types.EvaluationItem, evaluation_item_name: str +): + assert evaluation_item.name == evaluation_item_name + assert evaluation_item.display_name == "universal result for 7119522507803066368" + assert evaluation_item.evaluation_item_type == types.EvaluationItemType.RESULT + assert ( + evaluation_item.gcs_uri + == "gs://lakeyk-limited-bucket/agora_eval_080525/result_1486082323915997184.json" + ) + assert evaluation_item.create_time == datetime.datetime( + 2025, 9, 8, 20, 55, 46, 713792, tzinfo=datetime.timezone.utc + ) + assert 
isinstance(evaluation_item.evaluation_response, types.EvaluationItemResult) + assert ( + evaluation_item.evaluation_response.evaluation_request + == "projects/503583131166/locations/us-central1/evaluationItems/7119522507803066368" + ) + assert ( + evaluation_item.evaluation_response.evaluation_run + == "projects/503583131166/locations/us-central1/evaluationRuns/1957799200510967808" + ) + # Check the first candidate result. + candidate_result = evaluation_item.evaluation_response.candidate_results[0] + assert candidate_result.candidate == "gemini-2.0-flash-001@default" + assert candidate_result.metric == "universal" + assert candidate_result.score == 0.2857143 + # Check the first rubric verdict. + rubric_verdict = candidate_result.rubric_verdicts[0] + assert rubric_verdict.verdict + assert ( + rubric_verdict.reasoning + == "The entire response is written in the English language." + ) + assert rubric_verdict.evaluated_rubric.type == "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE" + assert rubric_verdict.evaluated_rubric.importance == "HIGH" + assert ( + rubric_verdict.evaluated_rubric.content.property.description + == "The response is in English." + ) + # Check the request. + request = evaluation_item.evaluation_response.request + assert ( + "There is a wide range of potato varieties to choose from" + in request.prompt.text + ) + assert request.candidate_responses[0].candidate == "gemini-2.0-flash-001@default" + assert "Pick out your potato variety" in request.candidate_responses[0].text + + +def check_item_4813679498589372416( + evaluation_item: types.EvaluationItem, evaluation_item_name: str +): + assert evaluation_item.name == evaluation_item_name + assert evaluation_item.display_name == "4813679498589372416" + assert evaluation_item.evaluation_item_type == types.EvaluationItemType.REQUEST + assert ( + evaluation_item.gcs_uri + == "gs://lakeyk-limited-bucket/agora_eval_080525/request_4813679498589372416.json" + ) + assert evaluation_item.create_time == datetime.datetime( + 2025, 9, 8, 20, 55, 46, 338353, tzinfo=datetime.timezone.utc + ) + assert isinstance(evaluation_item.evaluation_request, types.EvaluationItemRequest) + # Check the request. + request = evaluation_item.evaluation_request + assert ( + "If your ball is curving during flight from left to right" + in request.prompt.text + ) + # Check the first candidate response. + assert request.candidate_responses[0].candidate == "gemini-2.0-flash-001@default" + assert ( + "Keep your knees bent during the backswing" + in request.candidate_responses[0].text + ) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.get_evaluation_item", +) diff --git a/tests/unit/agentplatform/genai/replays/test_get_evaluation_run.py b/tests/unit/agentplatform/genai/replays/test_get_evaluation_run.py new file mode 100644 index 0000000000..cc8b7bbde1 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_evaluation_run.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import datetime +import pytest + + +def test_get_eval_run(client): + """Tests that get_evaluation_run() with include_evaluation_items=True returns the run together with its evaluation item results.""" + client._api_client._http_options.api_version = "v1beta1" + evaluation_run_name = ( + "projects/977012026409/locations/us-central1/evaluationRuns/3940878372367761408" + ) + evaluation_run = client.evals.get_evaluation_run( + name=evaluation_run_name, include_evaluation_items=True + ) + check_run_3940878372367761408(client, evaluation_run, evaluation_run_name) + check_run_3940878372367761408_evaluation_item_results( + client, evaluation_run, evaluation_run_name + ) + + +def test_get_eval_run_include_evaluation_items_false(client): + """Tests that get_evaluation_run() omits evaluation item results when include_evaluation_items is not set.""" + client._api_client._http_options.api_version = "v1beta1" + evaluation_run_name = ( + "projects/977012026409/locations/us-central1/evaluationRuns/3940878372367761408" + ) + evaluation_run = client.evals.get_evaluation_run(name=evaluation_run_name) + check_run_3940878372367761408(client, evaluation_run, evaluation_run_name) + assert evaluation_run.evaluation_item_results is None + + +def test_get_eval_run_bq_source(client): + """Tests that get_evaluation_run() correctly parses a run backed by a BigQuery request set.""" + evaluation_run_name = ( + "projects/503583131166/locations/us-central1/evaluationRuns/1968424880881795072" + ) + evaluation_run = client.evals.get_evaluation_run( + name=evaluation_run_name, include_evaluation_items=True + ) + assert isinstance(evaluation_run, types.EvaluationRun) + assert evaluation_run.name == evaluation_run_name + assert evaluation_run.display_name == "test1" + assert evaluation_run.data_source.bigquery_request_set == types.BigQueryRequestSet( + uri="bq://lakeyk-test-limited.inference_batch_prediction_input.1317387725199900672_1b", + prompt_column="request", + rubrics_column="rubric", + candidate_response_columns={ + "baseline_model_response": "baseline_model_response", + "checkpoint_1": "checkpoint_1", + "checkpoint_2": "checkpoint_2", + }, + sampling_config=types.SamplingConfig( + sampling_count=100, + sampling_method=types.SamplingMethod.RANDOM, + sampling_duration="60s", + ), + ) + + +def test_get_eval_run_eval_set_source(client): + """Tests that get_evaluation_run() correctly parses a failed run backed by an evaluation set, including its error message.""" + evaluation_run_name = ( + "projects/503583131166/locations/us-central1/evaluationRuns/6903525647549726720" + ) + evaluation_run = client.evals.get_evaluation_run( + name=evaluation_run_name, include_evaluation_items=True + ) + assert isinstance(evaluation_run, types.EvaluationRun) + assert evaluation_run.name == evaluation_run_name + assert evaluation_run.display_name == "test3" + assert evaluation_run.data_source.evaluation_set == ( + "projects/503583131166/locations/us-central1/evaluationSets/6619939608513740800" + ) + assert evaluation_run.state == types.EvaluationRunState.FAILED + assert evaluation_run.error.message == ( + "code=INVALID_ARGUMENT, message=EvaluationRun 6903525647549726720 has no " + "items, cause=null" + ) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_get_eval_run_async(client): + """Tests that get_evaluation_run()
returns a correctly structured EvaluationRun.""" + client._api_client._http_options.api_version = "v1beta1" + eval_run_id = "3940878372367761408" + evaluation_run_name = ( + f"projects/977012026409/locations/us-central1/evaluationRuns/{eval_run_id}" + ) + evaluation_run = await client.aio.evals.get_evaluation_run(name=eval_run_id) + check_run_3940878372367761408(client, evaluation_run, evaluation_run_name) + assert evaluation_run.evaluation_item_results is None + + +def check_run_3940878372367761408( + client, evaluation_run: types.EvaluationRun, evaluation_run_name: str +): + assert isinstance(evaluation_run, types.EvaluationRun) + assert evaluation_run.name == evaluation_run_name + assert ( + evaluation_run.display_name + == "evaluation_run_9a464a39-6d40-4d4e-a5e2-a4ceabea4b15" + ) + assert evaluation_run.metadata == {"pipeline_id": "8162140658019074048"} + assert evaluation_run.create_time == datetime.datetime( + 2026, 3, 18, 1, 10, 13, 360535, tzinfo=datetime.timezone.utc + ) + assert evaluation_run.completion_time == datetime.datetime( + 2026, 3, 18, 1, 11, 0, 448191, tzinfo=datetime.timezone.utc + ) + assert evaluation_run.state == types.EvaluationRunState.SUCCEEDED + assert evaluation_run.evaluation_set_snapshot == ( + "projects/977012026409/locations/us-central1/evaluationSets/3885168317211607040" + ) + assert ( + evaluation_run.data_source.evaluation_set + == "projects/977012026409/locations/us-central1/evaluationSets/3991900109943078912" + ) + assert evaluation_run.evaluation_run_results.evaluation_set == ( + "projects/977012026409/locations/us-central1/evaluationSets/3885168317211607040" + ) + assert evaluation_run.evaluation_run_results.summary_metrics.total_items == 2 + assert evaluation_run.error is None + + +def check_run_3940878372367761408_evaluation_item_results( + client, evaluation_run: types.EvaluationRun, evaluation_run_name: str +): + eval_result = evaluation_run.evaluation_item_results + assert isinstance(eval_result, types.EvaluationResult) + assert eval_result.summary_metrics == [ + types.AggregatedMetricResult( + metric_name="general_quality_v1", + mean_score=0.13333333656191826, + stdev_score=0.03333333507180214, + ), + ] + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.get_evaluation_run", +) diff --git a/tests/unit/agentplatform/genai/replays/test_get_evaluation_set.py b/tests/unit/agentplatform/genai/replays/test_get_evaluation_set.py new file mode 100644 index 0000000000..0f0da59aa5 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_evaluation_set.py @@ -0,0 +1,89 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform import types +import datetime +import pytest + + +def test_get_eval_set(client): + """Tests that get_evaluation_set() returns a correctly structured EvaluationSet.""" + evaluation_set_name = ( + "projects/503583131166/locations/us-central1/evaluationSets/102386522778501120" + ) + evaluation_set = client.evals.get_evaluation_set(name=evaluation_set_name) + assert isinstance(evaluation_set, types.EvaluationSet) + check_set_102386522778501120(evaluation_set, evaluation_set_name) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_get_eval_set_async(client): + """Tests that get_evaluation_set() returns a correctly structured EvaluationSet.""" + eval_set_id = "102386522778501120" + evaluation_set_name = ( + f"projects/503583131166/locations/us-central1/evaluationSets/{eval_set_id}" + ) + evaluation_set = await client.aio.evals.get_evaluation_set(name=eval_set_id) + check_set_102386522778501120(evaluation_set, evaluation_set_name) + + +def check_set_102386522778501120( + evaluation_set: types.EvaluationSet, evaluation_set_name: str +): + assert evaluation_set.name == evaluation_set_name + assert ( + evaluation_set.display_name + == "Results Set for EvaluationRun 1957799200510967808" + ) + assert evaluation_set.evaluation_items == [ + "projects/503583131166/locations/us-central1/evaluationItems/2748216119486578688", + "projects/503583131166/locations/us-central1/evaluationItems/1486082323915997184", + "projects/503583131166/locations/us-central1/evaluationItems/2219043163270545408", + "projects/503583131166/locations/us-central1/evaluationItems/8570244537769787392", + "projects/503583131166/locations/us-central1/evaluationItems/2112082672120496128", + "projects/503583131166/locations/us-central1/evaluationItems/8192505119024087040", + "projects/503583131166/locations/us-central1/evaluationItems/1383625432393318400", + "projects/503583131166/locations/us-central1/evaluationItems/5832267070561058816", + "projects/503583131166/locations/us-central1/evaluationItems/1733991409653907456", + "projects/503583131166/locations/us-central1/evaluationItems/2549142942207967232", + "projects/503583131166/locations/us-central1/evaluationItems/8565740938142416896", + "projects/503583131166/locations/us-central1/evaluationItems/6069620844672319488", + "projects/503583131166/locations/us-central1/evaluationItems/7777822109585113088", + "projects/503583131166/locations/us-central1/evaluationItems/5656415578861076480", + "projects/503583131166/locations/us-central1/evaluationItems/5926842662735839232", + "projects/503583131166/locations/us-central1/evaluationItems/648623899457617920", + "projects/503583131166/locations/us-central1/evaluationItems/4349245787016790016", + "projects/503583131166/locations/us-central1/evaluationItems/1119038954285301760", + "projects/503583131166/locations/us-central1/evaluationItems/5741983971781115904", + ] + assert evaluation_set.create_time == datetime.datetime( + 2025, 9, 8, 20, 55, 46, 413954, tzinfo=datetime.timezone.utc + ) + assert evaluation_set.update_time == datetime.datetime( + 2025, 9, 8, 20, 55, 46, 413954, tzinfo=datetime.timezone.utc + ) + assert evaluation_set.metadata is None + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.get_evaluation_set", +) diff --git 
a/tests/unit/agentplatform/genai/replays/test_get_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_get_multimodal_datasets.py new file mode 100644 index 0000000000..ee7e2f406d --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_multimodal_datasets.py @@ -0,0 +1,88 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" +DATASET = "projects/964831358985/locations/us-central1/datasets/8810841321427173376" + + +def test_get_dataset(client): + dataset = client.datasets._get_multimodal_dataset( + name=DATASET, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + +def test_get_dataset_from_public_method(client): + dataset = client.datasets.get_multimodal_dataset( + name=DATASET, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + +def test_get_dataset_by_id(client): + dataset = client.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_get_dataset_async(client): + dataset = await client.aio.datasets._get_multimodal_dataset( + name=DATASET, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + +@pytest.mark.asyncio +async def test_get_dataset_from_public_method_async(client): + dataset = await client.aio.datasets.get_multimodal_dataset( + name=DATASET, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" + + +@pytest.mark.asyncio +async def test_get_dataset_by_id_async(client): + dataset = await client.aio.datasets.get_multimodal_dataset( + name="8810841321427173376", + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.name == DATASET + assert dataset.display_name == "test-display-name" diff --git a/tests/unit/agentplatform/genai/replays/test_get_prompt_operation.py b/tests/unit/agentplatform/genai/replays/test_get_prompt_operation.py new file mode 100644 index 0000000000..67948dce61 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_prompt_operation.py @@ -0,0 +1,32 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance 
with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper + + +def test_get_dataset_operation(client): + dataset_operation = client.prompts._get_dataset_operation( + dataset_id="6550997480673116160", + operation_id="5108504762664353792", + ) + assert dataset_operation.name is not None + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompts._get_dataset_operation", +) diff --git a/tests/unit/agentplatform/genai/replays/test_get_prompt_resource.py b/tests/unit/agentplatform/genai/replays/test_get_prompt_resource.py new file mode 100644 index 0000000000..8ab082dea2 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_get_prompt_resource.py @@ -0,0 +1,109 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +from google.genai import types as genai_types + +import pytest + + +TEST_PROMPT_DATASET_ID = "6550997480673116160" +TEST_PROMPT_VERSION_ID = "2" + + +def test_get_dataset(client): + dataset = client.prompts._get_dataset_resource(name=TEST_PROMPT_DATASET_ID) + assert isinstance(dataset, types.Dataset) + + +def test_get_prompt(client): + prompt = client.prompts.get(prompt_id=TEST_PROMPT_DATASET_ID) + assert isinstance(prompt, types.Prompt) + assert isinstance(prompt.dataset, types.Dataset) + assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID) + assert ( + prompt.prompt_data + == prompt.dataset.metadata.prompt_api_schema.multimodal_prompt.prompt_message + ) + assert isinstance(prompt.prompt_data, types.SchemaPromptSpecPromptMessage) + + contents = prompt.assemble_contents() + assert isinstance(contents[0], genai_types.Content) + + +def test_get_prompt_version(client): + prompt = client.prompts.get_version( + prompt_id=TEST_PROMPT_DATASET_ID, + version_id=TEST_PROMPT_VERSION_ID, + ) + assert isinstance(prompt, types.Prompt) + assert isinstance(prompt.dataset, types.Dataset) + assert isinstance(prompt.dataset_version, types.DatasetVersion) + assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID) + assert prompt.dataset_version.name.endswith(TEST_PROMPT_VERSION_ID) + + +def test_get_prompt_with_variables_and_assemble_contents(client): + prompt = client.prompts.get( + prompt_id="4505721135056289792", + ) + assert isinstance(prompt.prompt_data, types.SchemaPromptSpecPromptMessage) + assembled_contents = prompt.assemble_contents() + assert isinstance(assembled_contents, list) + assert len(assembled_contents) == 1 + assert isinstance(assembled_contents[0], genai_types.Content) + assert assembled_contents[0].parts[0].text == "Hello, Alice! How are you?" 
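+
+# Usage sketch (illustrative only; not exercised by the recorded replay, and the
+# prompt id below is a hypothetical placeholder). The flow these tests cover is
+# fetch-then-assemble, mirroring the calls in test_get_prompt() and
+# test_get_prompt_with_variables_and_assemble_contents() above:
+#
+#   prompt = client.prompts.get(prompt_id="<your-prompt-id>")
+#   contents = prompt.assemble_contents()  # substitutes stored variable values
+#   # each entry is a genai_types.Content that can be passed to a model call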
+ + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompts._get_dataset_resource", +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_get_prompt_async(client): + prompt = await client.aio.prompts.get(prompt_id=TEST_PROMPT_DATASET_ID) + assert isinstance(prompt, types.Prompt) + assert isinstance(prompt.dataset, types.Dataset) + assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID) + assert ( + prompt.prompt_data + == prompt.dataset.metadata.prompt_api_schema.multimodal_prompt.prompt_message + ) + assert isinstance(prompt.prompt_data, types.SchemaPromptSpecPromptMessage) + + contents = prompt.assemble_contents() + assert isinstance(contents[0], genai_types.Content) + + +@pytest.mark.asyncio +async def test_get_prompt_version_async(client): + prompt = await client.aio.prompts.get_version( + prompt_id=TEST_PROMPT_DATASET_ID, version_id=TEST_PROMPT_VERSION_ID + ) + assert isinstance(prompt, types.Prompt) + assert isinstance(prompt.dataset, types.Dataset) + assert prompt.dataset.name.endswith(TEST_PROMPT_DATASET_ID) + assert ( + prompt.prompt_data + == prompt.dataset.metadata.prompt_api_schema.multimodal_prompt.prompt_message + ) + assert isinstance(prompt.prompt_data, types.SchemaPromptSpecPromptMessage) diff --git a/tests/unit/vertexai/genai/replays/test_ingest_events_memory_bank.py b/tests/unit/agentplatform/genai/replays/test_ingest_events_memory_bank.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_ingest_events_memory_bank.py rename to tests/unit/agentplatform/genai/replays/test_ingest_events_memory_bank.py index b4a4ccb820..05b5e393aa 100644 --- a/tests/unit/vertexai/genai/replays/test_ingest_events_memory_bank.py +++ b/tests/unit/agentplatform/genai/replays/test_ingest_events_memory_bank.py @@ -14,7 +14,7 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper def test_ingest_events(client): diff --git a/tests/unit/agentplatform/genai/replays/test_internal_generate_rubrics.py b/tests/unit/agentplatform/genai/replays/test_internal_generate_rubrics.py new file mode 100644 index 0000000000..747858e370 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_internal_generate_rubrics.py @@ -0,0 +1,170 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from google.genai import types as genai_types + +_TEST_RUBRIC_GENERATION_PROMPT = """SPECIAL INSTRUCTION: think silently. Silent thinking token budget: 16384. + +You are a teacher who is responsible for scoring a student\'s response to a prompt. In order to score that response, you must write down a rubric for each prompt. 
That rubric states what properties the response must have in order to be a valid response to the prompt. Properties are weighted by importance via the "importance" field. + +Rubric requirements: +- Properties either exist or don\'t exist. +- Properties can be either implicit in the prompt or made explicit by the prompt. +- Make sure to always include the correct expected human language as one of the properties. If the prompt asks for code, the programming language should be covered by a separate property. +- The correct expected language may be explicit in the text of the prompt but is usually simply implicit in the prompt itself. +- Be as comprehensive as possible with the list of properties in the rubric. +- All properties in the rubric must be in English, regardless of the language of the prompt. +- Rubric properties should not specify correct answers in their descriptions, e.g. to math and factoid questions if the prompt calls for such an answer. Rather, it should check that the response contains an answer and optional supporting evidence if relevant, and assume some other process will later validate correctness. A rubric property should however call out any false premises present in the prompt. + +About importance: +- Most properties will be of medium importance by default. +- Properties of high importance are critical to be fulfilled in a good response. +- Properties of low importance are considered optional or supplementary nice-to-haves. + +You will see prompts in many different languages, not just English. For each prompt you see, you will write down this rubric in JSON format. + +IMPORTANT: Never respond to the prompt given. Only write a rubric. + +Example: +What is the tallest building in the world? + +```json +{ + "criteria":[ + { + "rubric_id": "00001", + "property": "The response is in English.", + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "high" + }, + { + "rubric_id": "00002", + "property": "Contains the name of the tallest building in the world.", + "type": "QA_ANSWER:FACTOID", + "importance": "high" + }, + { + "rubric_id": "00003", + "property": "Contains the exact height of the tallest building.", + "type": "QA_SUPPORTING_EVIDENCE:HEIGHT", + "importance": "low" + }, + { + "rubric_id": "00004", + "property": "Contains the location of the tallest building.", + "type": "QA_SUPPORTING_EVIDENCE:LOCATION", + "importance": "low" + }, + ... + ] +} +``` + +Write me a letter to my HOA asking them to reconsider the fees they are asking me to pay because I haven\'t mowed my lawn on time. I have been very busy at work. 
+```json +{ + "criteria": [ + { + "rubric_id": "00001", + "property": "The response is in English.", + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "high" + }, + { + "rubric_id": "00002", + "property": "The response is formatted as a letter.", + "type": "FORMAT_REQUIREMENT:FORMAL_LETTER", + "importance": "medium" + }, + { + "rubric_id": "00003", + "property": "The letter is addressed to the Homeowners Association (HOA).", + "type": "CONTENT_REQUIREMENT:ADDRESSEE", + "importance": "medium" + }, + { + "rubric_id": "00004", + "property": "The letter explains that the sender has not mowed their lawn on time.", + "type": "CONTENT_REQUIREMENT:BACKGROUND_CONTEXT:TARDINESS", + "importance": "medium" + }, + { + "rubric_id": "00005", + "property": "The letter provides a reason for not mowing the lawn, specifically being busy at work.", + "type": "CONTENT_REQUIREMENT:EXPLANATION:EXCUSE:BUSY", + "importance": "medium" + }, + { + "rubric_id": "00006", + "property": "The letter discusses that the sender has been in compliance until now.", + "type": "OPTIONAL_CONTENT:SUPPORTING_EVIDENCE:COMPLIANCE", + "importance": "low" + }, + { + "rubric_id": "00007", + "property": "The letter requests that the HOA reconsider the fees associated with not mowing the lawn on time.", + "type": "CONTENT_REQUIREMENT:REQUEST:FEE_WAIVER", + "importance": "high" + }, + { + "rubric_id": "00008", + "property": "The letter maintains a polite and respectful tone.", + "type": "CONTENT_REQUIREMENT:FORMALITY:FORMAL", + "importance": "high" + }, + { + "rubric_id": "00009", + "property": "The letter includes a closing (e.g., \'Sincerely\') and the sender\'s name.", + "type": "CONTENT_REQUIREMENT:SIGNATURE", + "importance": "medium" + } + ] +} +``` + +Now write a rubric for the following user prompt. Remember to write only the rubric, NOT response to the prompt. 
+ +User prompt: +{prompt}""" + + +def test_internal_method_generate_rubrics(client): + """Tests the internal _generate_rubrics method.""" + test_contents = [ + genai_types.Content( + parts=[ + genai_types.Part( + text="Generate a short story about a friendly dragon.", + ), + ], + ) + ] + response = client.evals._generate_rubrics( + contents=test_contents, + rubric_generation_spec=genai_types.RubricGenerationSpec( + prompt_template=_TEST_RUBRIC_GENERATION_PROMPT, + ), + ) + assert len(response.generated_rubrics) >= 1 + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals._generate_rubrics", +) diff --git a/tests/unit/vertexai/genai/replays/test_list_ae_runtime_revisions.py b/tests/unit/agentplatform/genai/replays/test_list_ae_runtime_revisions.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_list_ae_runtime_revisions.py rename to tests/unit/agentplatform/genai/replays/test_list_ae_runtime_revisions.py index 87c276cec0..0bb782cfe0 100644 --- a/tests/unit/vertexai/genai/replays/test_list_ae_runtime_revisions.py +++ b/tests/unit/agentplatform/genai/replays/test_list_ae_runtime_revisions.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_task_events.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_task_events.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_task_events.py rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_task_events.py index ef3e483d8c..257b4e2939 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_task_events.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_task_events.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_tasks.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_tasks.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_tasks.py rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_tasks.py index dd758bc46d..9e25683d78 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_a2a_tasks.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_a2a_tasks.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_memories.py similarity index 95% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py 
rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_memories.py index 0e84b5c64e..95b65bfd7a 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_memories.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_list_memories(client): diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_sandboxes.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_sandboxes.py similarity index 92% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_sandboxes.py rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_sandboxes.py index fd8e72421a..bef059fd8e 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_sandboxes.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_sandboxes.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_list_sandboxes(client): diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_session_events.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_session_events.py similarity index 95% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_session_events.py rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_session_events.py index cd0295aeed..e0ad42953f 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_session_events.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_session_events.py @@ -17,8 +17,8 @@ import datetime import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_list_session_events(client): diff --git a/tests/unit/vertexai/genai/replays/test_list_agent_engine_sessions.py b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_sessions.py similarity index 94% rename from tests/unit/vertexai/genai/replays/test_list_agent_engine_sessions.py rename to tests/unit/agentplatform/genai/replays/test_list_agent_engine_sessions.py index 1b296ce992..3b98b28eeb 100644 --- a/tests/unit/vertexai/genai/replays/test_list_agent_engine_sessions.py +++ b/tests/unit/agentplatform/genai/replays/test_list_agent_engine_sessions.py @@ -17,8 +17,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_list_sessions(client): diff --git a/tests/unit/agentplatform/genai/replays/test_list_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_list_multimodal_datasets.py new file mode 100644 index 0000000000..5fa33dce9c --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_list_multimodal_datasets.py @@ -0,0 +1,41 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache 
License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" + + +def test_list_dataset(client): + datasets = client.datasets._list_multimodal_datasets() + assert isinstance(datasets, types.ListMultimodalDatasetsResponse) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_list_dataset_async(client): + datasets = await client.aio.datasets._list_multimodal_datasets() + assert isinstance(datasets, types.ListMultimodalDatasetsResponse) diff --git a/tests/unit/vertexai/genai/replays/test_list_prompts.py b/tests/unit/agentplatform/genai/replays/test_list_prompts.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_list_prompts.py rename to tests/unit/agentplatform/genai/replays/test_list_prompts.py index 59fe537326..fd812b2c56 100644 --- a/tests/unit/vertexai/genai/replays/test_list_prompts.py +++ b/tests/unit/agentplatform/genai/replays/test_list_prompts.py @@ -13,8 +13,8 @@ # limitations under the License. # -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types import pytest diff --git a/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_async_optimize_prompt_return_type.py b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_async_optimize_prompt_return_type.py new file mode 100644 index 0000000000..d3d7cdec66 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_async_optimize_prompt_return_type.py @@ -0,0 +1,105 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
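+# Note: the async replay tests below are assumed to mirror the synchronous +# optimize_prompt return-type tests; they rely on `client.aio.prompt_optimizer` +# exposing the same surface as the synchronous client, with matching replay +# fixtures recorded for the async code path.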
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +import pandas as pd +import pytest + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_optimize_prompt(client): + """Tests the async prompt_optimizer.optimize_prompt method.""" + + test_prompt = "Generate system instructions for analyzing medical articles" + response = await client.aio.prompt_optimizer.optimize_prompt(prompt=test_prompt) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +@pytest.mark.asyncio +async def test_optimize_prompt_w_optimization_target(client): + """Tests the async prompt_optimizer.optimize_prompt method with an explicit optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + response = await client.aio.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +@pytest.mark.asyncio +async def test_optimize_prompt_w_few_shot_optimization_target(client): + """Tests the async prompt_optimizer.optimize_prompt method with the few-shot target-response optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "target_response": ["target1", "target2"], + } + ) + response = await client.aio.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponseFewShot) + + +@pytest.mark.asyncio +async def test_optimize_prompt_w_few_shot_optimization_rubrics(client): + """Tests the async prompt_optimizer.optimize_prompt method with the few-shot rubrics optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "rubrics": ["rubric1", "rubric2"], + "rubrics_evaluations": ["[True, True]", "[True, False]"], + } + ) + response = await client.aio.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponseFewShot) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompt_optimizer.optimize_prompt", +) diff --git a/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_job_state.py b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_job_state.py new file mode 100644 index 0000000000..f0e4676994 --- /dev/null +++ 
b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_job_state.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +import logging +import os +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +from google.genai import types as genai_types +import pytest + + +def _raise_for_unset_env_vars() -> None: + if not os.environ.get("VAPO_CONFIG_PATH"): + raise ValueError("VAPO_CONFIG_PATH environment variable is not set.") + if not os.environ.get("VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"): + raise ValueError( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER environment variable is not set." + ) + + +# If you re-record this test, you will need to update the replay file to +# include the placeholder values for the config path and service account. +def test_optimize(client): + """Tests the prompt_optimizer.optimize method with wait_for_completion.""" + + _raise_for_unset_env_vars() + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + wait_for_completion=True, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = client.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED + + +def test_optimize_nano(client): + """Tests the prompt_optimizer.optimize method for the Gemini Nano target.""" + + _raise_for_unset_env_vars() + + config_path = os.environ.get("VAPO_CONFIG_PATH") + root, ext = os.path.splitext(config_path) + nano_path = f"{root}_nano{ext}" + + config = types.PromptOptimizerConfig( + config_path=nano_path, + wait_for_completion=True, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + + job = client.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompt_optimizer.optimize", +) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_optimize_async(client): + _raise_for_unset_env_vars() + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = await client.aio.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + + +@pytest.mark.asyncio
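+# This Nano variant derives a "_nano"-suffixed config path from +# VAPO_CONFIG_PATH; like the other async cases, the coroutine is expected to +# return before the job finishes, so the replay asserts JOB_STATE_PENDING +# rather than JOB_STATE_SUCCEEDED.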
+async def test_optimize_nano_async(client): + _raise_for_unset_env_vars() + config_path = os.environ.get("VAPO_CONFIG_PATH") + root, ext = os.path.splitext(config_path) + nano_path = f"{root}_nano{ext}" + + config = types.PromptOptimizerConfig( + config_path=nano_path, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = await client.aio.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + + +@pytest.mark.asyncio +async def test_optimize_async_with_config_wait_for_completion(client, caplog): + _raise_for_unset_env_vars() + caplog.set_level(logging.INFO) + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + wait_for_completion=True, + ) + job = await client.aio.prompt_optimizer.optimize( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + assert "Ignoring wait_for_completion=True" in caplog.text diff --git a/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py new file mode 100644 index 0000000000..dc0001b069 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_prompt_optimizer_optimize_prompt_return_type.py @@ -0,0 +1,97 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
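+# Note: synchronous counterpart of the async optimize_prompt return-type +# tests; both assert the same OptimizeResponse contract (raw_text_response +# always populated, parsed_response checked when present for the few-shot +# targets).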
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +import pandas as pd + + +def test_optimize_prompt(client): + """Tests the prompt_optimizer.optimize_prompt method.""" + + test_prompt = "Generate system instructions for analyzing medical articles" + response = client.prompt_optimizer.optimize_prompt(prompt=test_prompt) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +def test_optimize_prompt_w_optimization_target(client): + """Tests the prompt_optimizer.optimize_prompt method with an explicit optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + response = client.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +def test_optimize_prompt_w_few_shot_optimization_target(client): + """Tests the prompt_optimizer.optimize_prompt method with the few-shot target-response optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "target_response": ["target1", "target2"], + } + ) + response = client.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponseFewShot) + + +def test_optimize_prompt_w_few_shot_optimization_rubrics(client): + """Tests the prompt_optimizer.optimize_prompt method with the few-shot rubrics optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "rubrics": ["rubric1", "rubric2"], + "rubrics_evaluations": ["[True, True]", "[True, False]"], + } + ) + response = client.prompt_optimizer.optimize_prompt( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponseFewShot) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompt_optimizer.optimize_prompt", +) diff --git a/tests/unit/agentplatform/genai/replays/test_prompts_launch_job.py b/tests/unit/agentplatform/genai/replays/test_prompts_launch_job.py new file mode 100644 index 0000000000..c139087476 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_prompts_launch_job.py @@ -0,0 +1,153 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +import logging +import os +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +from google.genai import types as genai_types +import pytest + + +def _raise_for_unset_env_vars() -> None: + if not os.environ.get("VAPO_CONFIG_PATH"): + raise ValueError("VAPO_CONFIG_PATH environment variable is not set.") + if not os.environ.get("VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER"): + raise ValueError( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER environment variable is not set." + ) + + +# If you re-record this test, you will need to update the replay file to +# include the placeholder values for the config path and service account. +def test_launch_job(client): + """Tests the prompts.launch_optimization_job method.""" + + _raise_for_unset_env_vars() + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + wait_for_completion=True, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = client.prompts.launch_optimization_job( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED + + +def test_launch_job_nano(client): + """Tests the prompts.launch_optimization_job method for the Gemini Nano target.""" + + _raise_for_unset_env_vars() + + config_path = os.environ.get("VAPO_CONFIG_PATH") + root, ext = os.path.splitext(config_path) + nano_path = f"{root}_nano{ext}" + + config = types.PromptOptimizerConfig( + config_path=nano_path, + wait_for_completion=True, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + + job = client.prompts.launch_optimization_job( + method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_SUCCEEDED + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompts.launch_optimization_job", +) + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_launch_job_async(client): + _raise_for_unset_env_vars() + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = await client.aio.prompts.launch_optimization_job( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + + +@pytest.mark.asyncio +async def test_launch_job_nano_async(client): + _raise_for_unset_env_vars() + config_path = os.environ.get("VAPO_CONFIG_PATH") + root, ext = os.path.splitext(config_path) + nano_path = f"{root}_nano{ext}" + + config = 
types.PromptOptimizerConfig( + config_path=nano_path, + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + ) + job = await client.aio.prompts.launch_optimization_job( + method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + + +@pytest.mark.asyncio +async def test_optimize_async_with_config_wait_for_completion(client, caplog): + _raise_for_unset_env_vars() + caplog.set_level(logging.INFO) + + config = types.PromptOptimizerConfig( + config_path=os.environ.get("VAPO_CONFIG_PATH"), + service_account_project_number=os.environ.get( + "VAPO_SERVICE_ACCOUNT_PROJECT_NUMBER" + ), + optimizer_job_display_name="optimizer_job_test", + wait_for_completion=True, + ) + job = await client.aio.prompts.launch_optimization_job( + method=types.PromptOptimizerMethod.VAPO, + config=config, + ) + assert isinstance(job, types.CustomJob) + assert job.state == genai_types.JobState.JOB_STATE_PENDING + assert "Ignoring wait_for_completion=True" in caplog.text diff --git a/tests/unit/agentplatform/genai/replays/test_prompts_optimize.py b/tests/unit/agentplatform/genai/replays/test_prompts_optimize.py new file mode 100644 index 0000000000..72fa3c1b22 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_prompts_optimize.py @@ -0,0 +1,176 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
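+# Note: these tests exercise the prompts.optimize surface, which appears to +# share the OptimizeResponse contract with prompt_optimizer.optimize_prompt; +# the few-shot cases here check for types.prompts.ParsedResponse when a +# parsed response is returned.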
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +import pandas as pd + +import pytest + + +def test_optimize_prompt(client): + """Tests the prompts.optimize method.""" + + test_prompt = "Generate system instructions for analyzing medical articles" + response = client.prompts.optimize(prompt=test_prompt) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +def test_optimize_prompt_w_optimization_target(client): + """Tests the prompts.optimize method with an explicit optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + response = client.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +def test_optimize_prompt_w_few_shot_optimization_target(client): + """Tests the prompts.optimize method with the few-shot target-response optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "target_response": ["target1", "target2"], + } + ) + response = client.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponse) + + +def test_optimize_prompt_w_few_shot_optimization_rubrics(client): + """Tests the prompts.optimize method with the few-shot rubrics optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "rubrics": ["rubric1", "rubric2"], + "rubrics_evaluations": ["[True, True]", "[True, False]"], + } + ) + response = client.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponse) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompts.optimize", +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_async_optimize_prompt(client): + """Tests the async prompts.optimize method.""" + + test_prompt = "Generate system instructions for analyzing medical articles" + response = await client.aio.prompts.optimize(prompt=test_prompt) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +@pytest.mark.asyncio +async def test_async_optimize_prompt_w_optimization_target(client): + """Tests the async prompts.optimize method with an explicit optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + response = await client.aio.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + + +@pytest.mark.asyncio +async def test_async_optimize_prompt_w_few_shot_optimization_target(client): + """Tests the async prompts.optimize method with the few-shot target-response optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "target_response": ["target1", "target2"], + } + ) + response = await client.aio.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponse) + + +@pytest.mark.asyncio +async def test_async_optimize_prompt_w_few_shot_optimization_rubrics(client): + """Tests the async prompts.optimize method with the few-shot rubrics optimization target.""" + test_prompt = "Generate system instructions for analyzing medical articles" + df = pd.DataFrame( + { + "prompt": ["prompt1", "prompt2"], + "model_response": ["response1", "response2"], + "rubrics": ["rubric1", "rubric2"], + "rubrics_evaluations": ["[True, True]", "[True, False]"], + } + ) + response = await client.aio.prompts.optimize( + prompt=test_prompt, + config=types.OptimizeConfig( + optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_RUBRICS, + examples_dataframe=df, + ), + ) + assert isinstance(response, types.OptimizeResponse) + assert response.raw_text_response + assert isinstance(response.raw_text_response, str) + if response.parsed_response: + assert isinstance(response.parsed_response, types.prompts.ParsedResponse) diff --git a/tests/unit/agentplatform/genai/replays/test_public_generate_rubrics.py b/tests/unit/agentplatform/genai/replays/test_public_generate_rubrics.py new file mode 100644 index 0000000000..c03351e3cc --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_public_generate_rubrics.py @@ -0,0 +1,215 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types +import pandas as pd + +_TEST_RUBRIC_GENERATION_PROMPT = """SPECIAL INSTRUCTION: think silently. Silent thinking token budget: 16384. + +You are a teacher who is responsible for scoring a student\'s response to a prompt. 
In order to score that response, you must write down a rubric for each prompt. That rubric states what properties the response must have in order to be a valid response to the prompt. Properties are weighted by importance via the "importance" field. + +Rubric requirements: +- Properties either exist or don\'t exist. +- Properties can be either implicit in the prompt or made explicit by the prompt. +- Make sure to always include the correct expected human language as one of the properties. If the prompt asks for code, the programming language should be covered by a separate property. +- The correct expected language may be explicit in the text of the prompt but is usually simply implicit in the prompt itself. +- Be as comprehensive as possible with the list of properties in the rubric. +- All properties in the rubric must be in English, regardless of the language of the prompt. +- Rubric properties should not specify correct answers in their descriptions, e.g. to math and factoid questions if the prompt calls for such an answer. Rather, it should check that the response contains an answer and optional supporting evidence if relevant, and assume some other process will later validate correctness. A rubric property should however call out any false premises present in the prompt. + +About importance: +- Most properties will be of medium importance by default. +- Properties of high importance are critical to be fulfilled in a good response. +- Properties of low importance are considered optional or supplementary nice-to-haves. + +You will see prompts in many different languages, not just English. For each prompt you see, you will write down this rubric in JSON format. + +IMPORTANT: Never respond to the prompt given. Only write a rubric. + +Example: +What is the tallest building in the world? + +```json +{ + "criteria":[ + { + "rubric_id": "00001", + "property": "The response is in English.", + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "high" + }, + { + "rubric_id": "00002", + "property": "Contains the name of the tallest building in the world.", + "type": "QA_ANSWER:FACTOID", + "importance": "high" + }, + { + "rubric_id": "00003", + "property": "Contains the exact height of the tallest building.", + "type": "QA_SUPPORTING_EVIDENCE:HEIGHT", + "importance": "low" + }, + { + "rubric_id": "00004", + "property": "Contains the location of the tallest building.", + "type": "QA_SUPPORTING_EVIDENCE:LOCATION", + "importance": "low" + }, + ... + ] +} +``` + +Write me a letter to my HOA asking them to reconsider the fees they are asking me to pay because I haven\'t mowed my lawn on time. I have been very busy at work. 
+```json +{ + "criteria": [ + { + "rubric_id": "00001", + "property": "The response is in English.", + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "high" + }, + { + "rubric_id": "00002", + "property": "The response is formatted as a letter.", + "type": "FORMAT_REQUIREMENT:FORMAL_LETTER", + "importance": "medium" + }, + { + "rubric_id": "00003", + "property": "The letter is addressed to the Homeowners Association (HOA).", + "type": "CONTENT_REQUIREMENT:ADDRESSEE", + "importance": "medium" + }, + { + "rubric_id": "00004", + "property": "The letter explains that the sender has not mowed their lawn on time.", + "type": "CONTENT_REQUIREMENT:BACKGROUND_CONTEXT:TARDINESS", + "importance": "medium" + }, + { + "rubric_id": "00005", + "property": "The letter provides a reason for not mowing the lawn, specifically being busy at work.", + "type": "CONTENT_REQUIREMENT:EXPLANATION:EXCUSE:BUSY", + "importance": "medium" + }, + { + "rubric_id": "00006", + "property": "The letter discusses that the sender has been in compliance until now.", + "type": "OPTIONAL_CONTENT:SUPPORTING_EVIDENCE:COMPLIANCE", + "importance": "low" + }, + { + "rubric_id": "00007", + "property": "The letter requests that the HOA reconsider the fees associated with not mowing the lawn on time.", + "type": "CONTENT_REQUIREMENT:REQUEST:FEE_WAIVER", + "importance": "high" + }, + { + "rubric_id": "00008", + "property": "The letter maintains a polite and respectful tone.", + "type": "CONTENT_REQUIREMENT:FORMALITY:FORMAL", + "importance": "high" + }, + { + "rubric_id": "00009", + "property": "The letter includes a closing (e.g., \'Sincerely\') and the sender\'s name.", + "type": "CONTENT_REQUIREMENT:SIGNATURE", + "importance": "medium" + } + ] +} +``` + +Now write a rubric for the following user prompt. Remember to write only the rubric, NOT response to the prompt. 
+ +User prompt: +{prompt}""" + +_PROMPTS_DF = pd.DataFrame( + { + "prompt": [ + "Explain the theory of relativity in one sentence.", + "Write a short poem about a cat.", + ] + } +) + + +def test_public_method_generate_rubrics(client): + """Tests the public generate_rubrics method.""" + + eval_dataset = client.evals.generate_rubrics( + src=_PROMPTS_DF, + prompt_template=_TEST_RUBRIC_GENERATION_PROMPT, + rubric_group_name="text_quality_rubrics", + ) + eval_dataset_df = eval_dataset.eval_dataset_df + + # Assertions focus on the returned DataFrame + assert isinstance(eval_dataset, types.EvaluationDataset) + assert isinstance(eval_dataset_df, pd.DataFrame) + assert "rubric_groups" in eval_dataset_df.columns + assert len(eval_dataset_df) == 2 + + # Check the structure of the first row's rubric_groups + first_rubric_group = eval_dataset_df["rubric_groups"][0] + assert isinstance(first_rubric_group, dict) + assert "text_quality_rubrics" in first_rubric_group + assert isinstance(first_rubric_group["text_quality_rubrics"], list) + assert first_rubric_group["text_quality_rubrics"] + assert isinstance(first_rubric_group["text_quality_rubrics"][0], types.evals.Rubric) + + +def test_public_method_generate_rubrics_with_metric(client): + """Tests the public generate_rubrics method with a metric.""" + client._api_client._http_options.api_version = "v1beta1" + client._api_client._http_options.base_url = ( + "https://us-central1-staging-aiplatform.sandbox.googleapis.com/" + ) + metric_resource_name = "projects/977012026409/locations/us-central1/evaluationMetrics/6048334299558576128" + metric = types.Metric( + name="my_custom_metric", metric_resource_name=metric_resource_name + ) + eval_dataset = client.evals.generate_rubrics( + src=_PROMPTS_DF, rubric_group_name="my_registered_rubrics", metric=metric + ) + eval_dataset_df = eval_dataset.eval_dataset_df + + assert isinstance(eval_dataset, types.EvaluationDataset) + assert isinstance(eval_dataset_df, pd.DataFrame) + assert "rubric_groups" in eval_dataset_df.columns + assert len(eval_dataset_df) == 2 + + first_rubric_group = eval_dataset_df["rubric_groups"][0] + assert isinstance(first_rubric_group, dict) + assert "my_registered_rubrics" in first_rubric_group + assert isinstance(first_rubric_group["my_registered_rubrics"], list) + assert first_rubric_group["my_registered_rubrics"] + assert isinstance( + first_rubric_group["my_registered_rubrics"][0], types.evals.Rubric + ) + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="evals.generate_rubrics", +) diff --git a/tests/unit/vertexai/genai/replays/test_purge_agent_engine_memories.py b/tests/unit/agentplatform/genai/replays/test_purge_agent_engine_memories.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_purge_agent_engine_memories.py rename to tests/unit/agentplatform/genai/replays/test_purge_agent_engine_memories.py index 55a8bd5005..5de7d3bcd9 100644 --- a/tests/unit/vertexai/genai/replays/test_purge_agent_engine_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_purge_agent_engine_memories.py @@ -17,7 +17,7 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper def test_purge_memories(client): diff --git a/tests/unit/agentplatform/genai/replays/test_resources/mock_eval_config.yaml b/tests/unit/agentplatform/genai/replays/test_resources/mock_eval_config.yaml new file mode 100644 index 0000000000..960b646fdf --- 
/dev/null +++ b/tests/unit/agentplatform/genai/replays/test_resources/mock_eval_config.yaml @@ -0,0 +1,92 @@ +name: text_quality +prompt_template: ' + + # Instruction + + You are an expert evaluator. Your task is to evaluate the quality of the responses + generated by AI models. + + We will provide you with the user input and an AI-generated response. + + You should first read the user input carefully for analyzing the task, and then + evaluate the quality of the responses based on the Criteria provided in the Evaluation + section below. + + + You will assign the response a rating following the Rating Rubric and Evaluation + Steps. Give step-by-step explanations for your rating, and only choose ratings from + the Rating Rubric. + + + # Evaluation + + ## Metric Definition + + You will be assessing Text Quality, which measures how effectively the text conveys + clear, accurate, and engaging information that directly addresses the user''s prompt, + considering factors like fluency, coherence, relevance, and conciseness. + + + ## Criteria + + Coherence: The response presents ideas in a logical and organized manner, with clear + transitions and a consistent focus, making it easy to follow and understand. + + Fluency: The text flows smoothly and naturally, adhering to grammatical rules and + using appropriate vocabulary. + + Instruction following: The response demonstrates a clear understanding of the task + instructions, satisfying all of the instruction''s requirements. + + Groundedness: The response contains information included only in the context. The + response does not reference any outside information. + + Verbosity: The response is appropriately concise, providing sufficient detail without + using complex language to thoroughly address the prompt without being overly wordy + or excessively brief. + + + ## Rating Rubric + + 5: (Very good). Exceptionally clear, coherent, fluent, and concise. Fully adheres + to instructions and stays grounded. + + 4: (Good). Well-written, coherent, and fluent. Mostly adheres to instructions and + stays grounded. Minor room for improvement. + + 3: (Ok). Adequate writing with decent coherence and fluency. Partially fulfills + instructions and may contain minor ungrounded information. Could be more concise. + + 2: (Bad). Poorly written, lacking coherence and fluency. Struggles to adhere to + instructions and may include ungrounded information. Issues with conciseness. + + 1: (Very bad). Very poorly written, incoherent, and non-fluent. Fails to follow + instructions and contains substantial ungrounded information. Severely lacking in + conciseness. + + + + ## Evaluation Steps + + STEP 1: Assess the response in aspects of all criteria provided. Provide assessment + according to each criterion. + + STEP 2: Score based on the rating rubric. Give a brief rationale to explain your + evaluation considering each individual criterion. 
+ + + # User Inputs and AI-generated Response + + ## User Inputs + + ### Prompt + + {prompt} + + + ## AI-generated Response + + {response} + + ' +version: v1 \ No newline at end of file diff --git a/tests/unit/agentplatform/genai/replays/test_resources/request_4813679498589372416.json b/tests/unit/agentplatform/genai/replays/test_resources/request_4813679498589372416.json new file mode 100644 index 0000000000..d2f085be36 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_resources/request_4813679498589372416.json @@ -0,0 +1,9 @@ +{ + "prompt": { + "text": "If your ball is curving during flight from left to right (for a right-handed golfer), try to keep your knees bent and flexed during the backswing. It\u0027s natural to want to straighten out your back knee during the backswing but try to avoid the impulse. Don\u0027t let your knee travel backward either; keep it flexed in position and underneath the hip. A hook is a ball that travels slightly to the right (for a right-handed golfer) and then dramatically to the left. This happens when the ball has a counterclockwise spin, meaning that it\u0027s being hit from right to left instead of from back to front. Try looking at your grip. If you\u0027re a right-handed golfer and more than two knuckles on your left hand are visible when you hold the club, turn to a \"weaker\" grip and make sure only two knuckles are visible. Make sure your stance isn\u0027t aiming too far to the left. You can try to overcompensate a little to the right, but this can also make the hooking motion worse if you overcompensate too much. Place a golf club down on the ground to make sure you\u0027re aiming straight at your target. \" Sometimes your swing is \"fat,\" others it\u0027s \"thin,\" and your drive doesn\u0027t get as much distance as you\u0027d like. The most common remedy for this problem is keeping your head down and your eye on the ball throughout the backswing. When you move your head back in the backswing, you\u0027re actually increasing the distance between the base of the neck and the bottom of the ball. This makes it much tougher to hit the ball right in your wheelhouse. Keep your eye on the ball and you should be driving longer and more consistently.\n\nProvide a summary of the article in two or three sentences:\n\n" + }, + "candidateResponses": [{ + "candidate": "gemini-2.0-flash-001@default", + "text": "Keep your knees bent during the backswing. Check your grip. Check your aim. Keep your head down." + }] +} \ No newline at end of file diff --git a/tests/unit/agentplatform/genai/replays/test_resources/result_1486082323915997184.json b/tests/unit/agentplatform/genai/replays/test_resources/result_1486082323915997184.json new file mode 100644 index 0000000000..09cdf19419 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_resources/result_1486082323915997184.json @@ -0,0 +1,282 @@ +{ + "evaluationRequest": "projects/503583131166/locations/us-central1/evaluationItems/7119522507803066368", + "candidateResults": [{ + "candidate": "gemini-2.0-flash-001@default", + "metric": "universal", + "score": 0.2857143, + "rubricVerdicts": [{ + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is in English." + } + }, + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The entire response is written in the English language." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is limited to two or three sentences." 
+ } + }, + "type": "FORMAT_REQUIREMENT:LENGTH_CONSTRAINT", + "importance": "HIGH" + }, + "reasoning": "The response contains 9 sentences, exceeding the specified limit of two or three sentences." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response summarizes the provided article about potato planting." + } + }, + "type": "CONTENT_REQUIREMENT:SUMMARY", + "importance": "HIGH" + }, + "reasoning": "The response provides a list of procedural steps rather than a comprehensive summary of the article, which includes information on potato categories and planting seasons in addition to the planting steps." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary includes information regarding different potato categories and their respective planting/harvesting seasons." + } + }, + "type": "CONTENT_REQUIREMENT:SUMMARY_DETAIL:POTATO_CATEGORIES", + "importance": "MEDIUM" + }, + "reasoning": "The response only makes general statements about picking a variety and planting at the right time, without including any specific details about different potato categories or their respective planting/harvesting seasons as described in the article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary includes key steps for planting potatoes in containers, such as soil preparation, fertilization, spacing, and watering." + } + }, + "type": "CONTENT_REQUIREMENT:SUMMARY_DETAIL:CONTAINER_PLANTING", + "importance": "MEDIUM" + }, + "reasoning": "While the response includes steps for soil preparation, fertilization, and watering, it omits the key step of spacing the potatoes, which was detailed in the article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response does not introduce information not present in the provided article." + } + }, + "type": "CONTENT_REQUIREMENT:NO_EXTERNAL_INFORMATION", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "Every piece of information presented in the response can be directly found or inferred from the provided article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary is clear, concise, and coherent." + } + }, + "type": "QUALITY:CLARITY_COHERENCE", + "importance": "MEDIUM" + }, + "reasoning": "The response is a list of instructions rather than a coherent narrative summary. While individual points are clear and concise, the response as a whole does not provide a comprehensive or flowing summary of the article\u0027s content." + }] + }, { + "candidate": "checkpoint_1", + "metric": "universal", + "score": 1.0, + "rubricVerdicts": [{ + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is in English." + } + }, + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The entire response is written in the English language." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is a summary of the provided article about planting potatoes." + } + }, + "type": "CONTENT_REQUIREMENT:SUMMARY", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The response condenses the key information from the provided article, which is about planting potatoes in pots." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response consists of exactly two or three sentences." 
+ } + }, + "type": "FORMAT_REQUIREMENT:LENGTH", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The response contains two sentences, which meets the specified length requirement." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary accurately reflects the key information presented in the article." + } + }, + "type": "CONTENT_REQUIREMENT:ACCURACY", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "All the points mentioned in the summary are directly and accurately derived from the content of the provided article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary covers the main aspects of potato planting and care discussed in the article, such as variety categories/timing, initial planting steps (e.g., soil, fertilizer, spacing), and ongoing care (e.g., watering, adding soil)." + } + }, + "type": "CONTENT_REQUIREMENT:COVERAGE", + "importance": "MEDIUM" + }, + "verdict": true, + "reasoning": "The summary addresses potato varieties and planting timelines, initial planting steps like soil preparation, fertilizer application, and spacing, and ongoing care including watering and maintaining soil moisture. While \"adding soil\" as an ongoing step isn\u0027t explicitly listed, the other major aspects are covered." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response does not introduce any information not present in the original article." + } + }, + "type": "CONTENT_REQUIREMENT:NO_EXTERNAL_INFORMATION", + "importance": "MEDIUM" + }, + "verdict": true, + "reasoning": "All information presented in the summary is directly sourced from the provided article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is concise and avoids unnecessary detail." + } + }, + "type": "FORMAT_REQUIREMENT:CONCISENESS", + "importance": "MEDIUM" + }, + "verdict": true, + "reasoning": "The response is only two sentences long and successfully summarizes the main points of the article without including excessive specific details like exact measurements or timelines." + }] + }, { + "candidate": "checkpoint_2", + "metric": "universal", + "score": 1.0, + "rubricVerdicts": [{ + "evaluatedRubric": { + "content": { + "property": { + "description": "The response is in English." + } + }, + "type": "LANGUAGE:PRIMARY_RESPONSE_LANGUAGE", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The entire response is written in the English language." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response consists of exactly two or three sentences." + } + }, + "type": "FORMAT_REQUIREMENT:LENGTH_CONSTRAINT", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The response contains three sentences, which meets the requirement of being exactly two or three sentences." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The response summarizes the provided article." + } + }, + "type": "CONTENT_REQUIREMENT:SUMMARY", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The response accurately condenses the main topics and instructions provided in the original article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary accurately reflects the information presented in the article." 
+ } + }, + "type": "CONTENT_REQUIREMENT:ACCURACY", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "Every point mentioned in the summary is directly supported by and accurately reflects the content of the original article." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary includes information about potato categories and their corresponding planting and harvesting times." + } + }, + "type": "CONTENT_REQUIREMENT:KEY_INFORMATION:POTATO_TYPES_TIMING", + "importance": "MEDIUM" + }, + "verdict": true, + "reasoning": "The first sentence of the summary explicitly mentions \"different maturity categories (early, mid-season, and late) and their respective planting times.\"" + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary covers key aspects of planting potatoes, specifically in containers, such as soil preparation, fertilization, tuber spacing, and watering." + } + }, + "type": "CONTENT_REQUIREMENT:KEY_INFORMATION:PLANTING_INSTRUCTIONS", + "importance": "HIGH" + }, + "verdict": true, + "reasoning": "The summary explicitly mentions \"preparing the soil, spacing tubers, fertilizing, and watering\" and refers to \"potato cultivation in pots,\" covering all specified aspects." + }, { + "evaluatedRubric": { + "content": { + "property": { + "description": "The summary is coherent and concise." + } + }, + "type": "QUALITY:COHERENCE_CONCISENESS", + "importance": "MEDIUM" + }, + "verdict": true, + "reasoning": "The summary is well-structured, easy to read, and effectively condenses the article\u0027s main points into three sentences without being verbose." + }] + }], + "evaluationRun": "projects/503583131166/locations/us-central1/evaluationRuns/1957799200510967808", + "request": { + "prompt": { + "text": "There is a wide range of potato varieties to choose from, but they generally fall under five basic categories: first early, second early, early main crop, main crop, and late main crop. Knowing what category your selected potato variety falls under will tell you when to plant and harvest it. Plant early varieties as early as March or April. They\u0027ll take between 75-90 days to mature for harvest. However, if you plant them too early, they could be damaged by frost. Plant mid-season potatoes from May to July. They\u0027ll take between 85-110 days to mature for harvest. Mid-season potatoes grow best in warmer climates and temperatures. Plant late season potatoes from July to August. They\u0027ll take between 120-135 days to mature for harvest. These potatoes usually tolerate winter temperatures and frost better than early varieties. The entire bottom of the pot should be covered. Lightly pat the soil down with your hands to make sure that it is firmly packed. If the soil is not solid enough, the potatoes will sink to the very bottom of the pot. Both the fiberglass screen and the pebble/stone layer should be fully covered and no longer visible after you add the soil. Use a scoop or spade to lightly sprinkle granular organic fertilizer on top of the first soil layer in the pot. The amount that you use will depend on the type of fertilizer that you buy and the size of your pot. Read the instructions on the fertilizer\u0027s package carefully and follow them directly. Choose a fertilizer that is explicitly for container planting. An organic or vegetable-specific blend is even better. Once your potatoes are established, you\u0027ll want to feed them liquid fertilizer as often as every 2 weeks. 
Again, reference the fertilizer\u0027s package for instructions. Space the tubers at least 6 inches (15 cm) apart. Ensure that they are not touching one another or the edge of the pot. Do not overcrowd each pot by planting more than the recommended amount of seed potatoes. When the potatoes are in position, press them gently into the soil and fertilizer mix. It\u0027s likely that your seed potatoes will have eyes all over their skin, so just position the side with the most eyes facing upwards. Use a spade to gently add more soil into the pot, creating an even layer. When you are finished, your seed potatoes will be completely hidden. Use your hands to gently pat down the soil, being careful not to crush your tubers by pressing too hard. Depending on the size of your pot, you may need to add even more growing medium. Use your best judgment and fill it up most of the way, but not all the way to the rim. You will add more soil as the potatoes grow, so account for this as you fill. Use a watering can or hose to slowly add water to the freshly planted pot. Try to cover all of the soil\u0027s exposed surface as you are watering. When water starts to flow out of the pot\u0027s bottom drainage holes, stop watering. Potatoes do best with soil that is moist, but not soggy. Water-logged soil can actually lead to the growth of fungus. You can add a layer of straw to the top of your soil to keep it moist and shaded throughout the day.\n\nProvide a summary of the article in two or three sentences:\n\n" + }, + "candidateResponses": [{ + "candidate": "gemini-2.0-flash-001@default", + "text": "Pick out your potato variety.\nPlant the potatoes at the right time of year.\nLay the fiberglass screen.\nPour the pebbles.\nPour the soil.\nAdd the fertilizer.\nPlace the seed potatoes in the pot.\nAdd a layer of soil.\nWater your potatoes.\n\n" + }, { + "candidate": "checkpoint_1", + "text": "This article provides a guide to planting potatoes in pots, outlining different potato varieties and their planting timelines. It details the steps for planting, including soil preparation, adding fertilizer, spacing tubers, and watering, while emphasizing the importance of proper soil moisture and providing advice on protecting the soil from the sun.\n" + }, { + "candidate": "checkpoint_2", + "text": "This article provides instructions on how to plant potatoes, covering the different maturity categories (early, mid-season, and late) and their respective planting times. The guide details essential steps such as preparing the soil, spacing tubers, fertilizing, and watering to ensure successful potato cultivation in pots. It emphasizes the importance of proper soil moisture and the option of using straw to maintain ideal growing conditions.\n" + }] + } +} \ No newline at end of file diff --git a/tests/unit/agentplatform/genai/replays/test_restore_prompt_version.py b/tests/unit/agentplatform/genai/replays/test_restore_prompt_version.py new file mode 100644 index 0000000000..27bab87097 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_restore_prompt_version.py @@ -0,0 +1,108 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import ( + test_create_prompt, +) +from agentplatform._genai import types +import pytest + + +TEST_PROMPT_DATASET_ID = "6550997480673116160" +TEST_PROMPT_VERSION_ID = "2" + + +prompt_contents_2 = test_create_prompt.TEST_PROMPT.model_copy(deep=True) +prompt_contents_2.prompt_data.variables = None +prompt_contents_2.prompt_data.contents[0].parts[0].text = "Is this Alice?" + + +def test_restore_version(client): + # Create 2 versions on my_prompt + prompt_v1 = client.prompts.create_version( + prompt=test_create_prompt.TEST_PROMPT.model_copy(deep=True), + config=types.CreatePromptVersionConfig( + prompt_display_name="my_prompt", + version_display_name="my_prompt_v1", + ), + ) + prompt_v2 = client.prompts.update( + prompt_id=prompt_v1.prompt_id, + prompt=prompt_contents_2, + config=types.CreatePromptVersionConfig( + prompt_display_name="my_prompt", + version_display_name="my_prompt_v2", + ), + ) + my_prompt_v1_id = prompt_v1.dataset_version.name.split("/")[-1] + my_prompt_v2_id = prompt_v2.dataset_version.name.split("/")[-1] + assert my_prompt_v2_id != my_prompt_v1_id + assert ( + prompt_v1.prompt_data.contents[0].parts[0].text == "Hello, {name}! How are you?" + ) + assert prompt_v1.prompt_data.variables is not None + assert prompt_v2.prompt_data.contents[0].parts[0].text == "Is this Alice?" + assert not prompt_v2.prompt_data.variables + + # Restore version to my_prompt_v1_id + restored_prompt = client.prompts.restore_version( + prompt_id=prompt_v1.prompt_id, + version_id=my_prompt_v1_id, + ) + assert restored_prompt.dataset_version.name.split("/")[-1] == my_prompt_v1_id + assert ( + restored_prompt.prompt_data.contents[0].parts[0].text + == "Hello, {name}! How are you?" 
+ ) + assert restored_prompt.prompt_data.variables is not None + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), + test_method="prompts.restore_version", +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_restore_version_async(client): + # Create 2 versions on my_prompt + prompt_v1 = client.prompts.create_version( + prompt=test_create_prompt.TEST_PROMPT.model_copy(deep=True), + config=types.CreatePromptVersionConfig( + prompt_display_name="my_prompt", version_display_name="my_prompt_v1" + ), + ) + prompt_v2 = client.prompts.update( + prompt_id=prompt_v1.prompt_id, + prompt=prompt_contents_2, + config=types.CreatePromptVersionConfig( + prompt_display_name="my_prompt", version_display_name="my_prompt_v2" + ), + ) + my_prompt_v1_id = prompt_v1.dataset_version.name.split("/")[-1] + my_prompt_v2_id = prompt_v2.dataset_version.name.split("/")[-1] + assert my_prompt_v2_id != my_prompt_v1_id + + # Restore version to my_prompt_v1_id + restored_prompt = await client.aio.prompts.restore_version( + prompt_id=prompt_v1.prompt_id, + version_id=my_prompt_v1_id, + ) + assert restored_prompt.dataset_version.name.split("/")[-1] == my_prompt_v1_id diff --git a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py b/tests/unit/agentplatform/genai/replays/test_retrieve_agent_engine_memories.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py rename to tests/unit/agentplatform/genai/replays/test_retrieve_agent_engine_memories.py index 7b9cef3bc7..8609735f7a 100644 --- a/tests/unit/vertexai/genai/replays/test_retrieve_agent_engine_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_retrieve_agent_engine_memories.py @@ -18,8 +18,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import pagers diff --git a/tests/unit/vertexai/genai/replays/test_run_inference.py b/tests/unit/agentplatform/genai/replays/test_run_inference.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_run_inference.py rename to tests/unit/agentplatform/genai/replays/test_run_inference.py index c090e68272..22e7d42d40 100644 --- a/tests/unit/vertexai/genai/replays/test_run_inference.py +++ b/tests/unit/agentplatform/genai/replays/test_run_inference.py @@ -16,8 +16,8 @@ import pytest -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types pytest.importorskip( diff --git a/tests/unit/vertexai/genai/replays/test_skills_get.py b/tests/unit/agentplatform/genai/replays/test_skills_get.py similarity index 91% rename from tests/unit/vertexai/genai/replays/test_skills_get.py rename to tests/unit/agentplatform/genai/replays/test_skills_get.py index 824c0921f1..d069a1cf15 100644 --- a/tests/unit/vertexai/genai/replays/test_skills_get.py +++ b/tests/unit/agentplatform/genai/replays/test_skills_get.py @@ -1,7 +1,7 @@ """Tests the skills.get() method against the autopush endpoint.""" from google.api_core import exceptions -from tests.unit.vertexai.genai.replays import pytest_helper +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import 
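For readers skimming the restore flow above: the new tests pin down a create_version -> update -> restore_version round trip. A minimal usage sketch outside the replay harness, assuming a configured project and region — the project, location, and prompt objects below are placeholders, not values from this change:

import agentplatform
from agentplatform._genai import types

client = agentplatform.Client(project="my-project", location="us-central1")

# Create version 1, then overwrite the prompt with a second version.
v1 = client.prompts.create_version(
    prompt=my_prompt,  # placeholder: a Prompt object such as TEST_PROMPT above
    config=types.CreatePromptVersionConfig(
        prompt_display_name="my_prompt",
        version_display_name="my_prompt_v1",
    ),
)
v2 = client.prompts.update(
    prompt_id=v1.prompt_id,
    prompt=my_edited_prompt,  # placeholder: a modified copy of the prompt
    config=types.CreatePromptVersionConfig(
        prompt_display_name="my_prompt",
        version_display_name="my_prompt_v2",
    ),
)

# Roll back: the restored prompt points at version 1 and carries its contents.
v1_id = v1.dataset_version.name.split("/")[-1]
restored = client.prompts.restore_version(prompt_id=v1.prompt_id, version_id=v1_id)
assert restored.dataset_version.name.split("/")[-1] == v1_id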
pytest_helper import pytest PROJECT_ID = "demo-project" diff --git a/tests/unit/vertexai/genai/replays/test_structured_memories.py b/tests/unit/agentplatform/genai/replays/test_structured_memories.py similarity index 96% rename from tests/unit/vertexai/genai/replays/test_structured_memories.py rename to tests/unit/agentplatform/genai/replays/test_structured_memories.py index ce5df9fa45..0c9fe7356c 100644 --- a/tests/unit/vertexai/genai/replays/test_structured_memories.py +++ b/tests/unit/agentplatform/genai/replays/test_structured_memories.py @@ -14,8 +14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_generate_and_retrieve_profile(client): diff --git a/tests/unit/vertexai/genai/replays/test_update_agent_engine.py b/tests/unit/agentplatform/genai/replays/test_update_agent_engine.py similarity index 92% rename from tests/unit/vertexai/genai/replays/test_update_agent_engine.py rename to tests/unit/agentplatform/genai/replays/test_update_agent_engine.py index edc71a112d..a5e237e162 100644 --- a/tests/unit/vertexai/genai/replays/test_update_agent_engine.py +++ b/tests/unit/agentplatform/genai/replays/test_update_agent_engine.py @@ -15,8 +15,8 @@ # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types def test_agent_engines_update(client): diff --git a/tests/unit/agentplatform/genai/replays/test_update_multimodal_datasets.py b/tests/unit/agentplatform/genai/replays/test_update_multimodal_datasets.py new file mode 100644 index 0000000000..3134a7e13a --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_update_multimodal_datasets.py @@ -0,0 +1,104 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
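The multimodal-dataset tests that follow exercise both the internal (_update_multimodal_dataset) and public (update_multimodal_dataset) update paths with the same display_name/description/metadata fields. A minimal sketch of the public call, under assumed placeholder project, location, dataset, and BigQuery table names:

import agentplatform
from agentplatform._genai import types

client = agentplatform.Client(project="my-project", location="us-central1")

updated = client.datasets.update_multimodal_dataset(
    multimodal_dataset={
        # Placeholder resource name of an existing multimodal dataset.
        "name": "projects/my-project/locations/us-central1/datasets/1234567890",
        "display_name": "updated display name",
        "description": "updated description",
        "metadata": {
            "inputConfig": {
                # Placeholder BigQuery table backing the dataset.
                "bigquerySource": {"uri": "bq://my-project.my_dataset.my_table"},
            },
        },
    }
)
assert isinstance(updated, types.MultimodalDataset)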
+# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types + +import pytest + +METADATA_SCHEMA_URI = ( + "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml" +) +BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table" +DATASET = "projects/vertex-sdk-dev/locations/us-central1/datasets/8810841321427173376" + + +def test_update_dataset(client): + dataset = client.datasets._update_multimodal_dataset( + name=DATASET, + display_name="test-display-name (updated with internal method)", + description="test-description (updated with internal method)", + metadata={ + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-display-name (updated with internal method)" + assert dataset.description == "test-description (updated with internal method)" + + +def test_update_dataset_with_public_method(client): + dataset = client.datasets.update_multimodal_dataset( + multimodal_dataset={ + "name": DATASET, + "display_name": "test-display-name (updated with public method)", + "description": "test-description (updated with public method)", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-display-name (updated with public method)" + assert dataset.description == "test-description (updated with public method)" + + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_update_dataset_async(client): + dataset = await client.aio.datasets._update_multimodal_dataset( + name=DATASET, + display_name="test-display-name (updated with internal method)", + description="test-description (updated with internal method)", + metadata={ + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-display-name (updated with internal method)" + assert dataset.description == "test-description (updated with internal method)" + + +@pytest.mark.asyncio +async def test_update_dataset_with_public_method_async(client): + dataset = await client.aio.datasets.update_multimodal_dataset( + multimodal_dataset={ + "name": DATASET, + "display_name": "test-display-name (updated with public method)", + "description": "test-description (updated with public method)", + "metadata": { + "inputConfig": { + "bigquerySource": {"uri": f"bq://{BIGQUERY_TABLE_NAME}"}, + }, + }, + } + ) + assert isinstance(dataset, types.MultimodalDataset) + assert dataset.display_name == "test-display-name (updated with public method)" + assert dataset.description == "test-description (updated with public method)" diff --git a/tests/unit/vertexai/genai/replays/test_update_prompt.py b/tests/unit/agentplatform/genai/replays/test_update_prompt.py similarity index 97% rename from tests/unit/vertexai/genai/replays/test_update_prompt.py rename to tests/unit/agentplatform/genai/replays/test_update_prompt.py index 44ef65668e..7c830a65a9 100644 --- a/tests/unit/vertexai/genai/replays/test_update_prompt.py +++ b/tests/unit/agentplatform/genai/replays/test_update_prompt.py @@ -14,8 
+14,8 @@ # # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types from google.genai import types as genai_types import pytest diff --git a/tests/unit/vertexai/genai/replays/test_update_traffic_agent_engine.py b/tests/unit/agentplatform/genai/replays/test_update_traffic_agent_engine.py similarity index 98% rename from tests/unit/vertexai/genai/replays/test_update_traffic_agent_engine.py rename to tests/unit/agentplatform/genai/replays/test_update_traffic_agent_engine.py index 41438e5e95..d60a7cf3d2 100644 --- a/tests/unit/vertexai/genai/replays/test_update_traffic_agent_engine.py +++ b/tests/unit/agentplatform/genai/replays/test_update_traffic_agent_engine.py @@ -15,8 +15,8 @@ # pylint: disable=protected-access,bad-continuation,missing-function-docstring -from tests.unit.vertexai.genai.replays import pytest_helper -from vertexai._genai import types +from google.cloud.aiplatform.tests.unit.agentplatform.genai.replays import pytest_helper +from agentplatform._genai import types _TEST_CLASS_METHODS = [ {"name": "query", "api_mode": ""}, diff --git a/tests/unit/vertexai/genai/run_replay_tests.sh b/tests/unit/agentplatform/genai/run_replay_tests.sh similarity index 94% rename from tests/unit/vertexai/genai/run_replay_tests.sh rename to tests/unit/agentplatform/genai/run_replay_tests.sh index 821a426395..ab0a909a89 100755 --- a/tests/unit/vertexai/genai/run_replay_tests.sh +++ b/tests/unit/agentplatform/genai/run_replay_tests.sh @@ -3,10 +3,10 @@ # This script runs replay tests for the Vertex SDK GenAI client # It is intended to be used from the google3 directory of a CitC client. # You can provide a specific test file to run, or it will run all the replay tests -# in third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/replays/ +# in third_party/py/google/cloud/aiplatform/tests/unit/agentplatform/genai/replays/ # # Example: -# ./third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/run_replay_tests.sh test_evals.py +# ./third_party/py/google/cloud/aiplatform/tests/unit/agentplatform/genai/run_replay_tests.sh test_evals.py # Supported flags: # --mode : Specifies the run mode. 
Options are: @@ -135,7 +135,7 @@ fi # Set pytest path for which tests to run -DEFAULT_TEST_PATH="tests/unit/vertexai/genai/replays/" +DEFAULT_TEST_PATH="tests/unit/agentplatform/genai/replays/" if [ -n "$TEST_FILE_ARG" ]; then PYTEST_PATH="${DEFAULT_TEST_PATH}${TEST_FILE_ARG}" diff --git a/tests/unit/vertexai/genai/test_agent_engine_runtime_revisions.py b/tests/unit/agentplatform/genai/test_agent_engine_runtime_revisions.py similarity index 99% rename from tests/unit/vertexai/genai/test_agent_engine_runtime_revisions.py rename to tests/unit/agentplatform/genai/test_agent_engine_runtime_revisions.py index c5ab408d10..a2e3b5798e 100644 --- a/tests/unit/vertexai/genai/test_agent_engine_runtime_revisions.py +++ b/tests/unit/agentplatform/genai/test_agent_engine_runtime_revisions.py @@ -25,11 +25,11 @@ from google import auth from google.auth import credentials as auth_credentials from google.cloud import aiplatform -import vertexai +import agentplatform from google.cloud.aiplatform import initializer -from vertexai._genai import _agent_engines_utils -from vertexai._genai import runtime_revisions -from vertexai._genai import types as _genai_types +from agentplatform._genai import _agent_engines_utils +from agentplatform._genai import runtime_revisions +from agentplatform._genai import types as _genai_types from google.genai import types as genai_types import pytest @@ -830,10 +830,10 @@ class TestRuntimeRevisionsHelpers: def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) + importlib.reload(agentplatform) importlib.reload(os) os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE - self.client = vertexai.Client( + self.client = agentplatform.Client( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, @@ -917,10 +917,10 @@ class TestRuntimeRevisions: def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) + importlib.reload(agentplatform) importlib.reload(os) os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE - self.client = vertexai.Client( + self.client = agentplatform.Client( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, @@ -1139,10 +1139,10 @@ class TestAsyncRuntimeRevisions: def setup_method(self): importlib.reload(initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) + importlib.reload(agentplatform) importlib.reload(os) os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE - self.client = vertexai.Client( + self.client = agentplatform.Client( project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=_TEST_CREDENTIALS, diff --git a/tests/unit/vertexai/genai/test_agent_engines.py b/tests/unit/agentplatform/genai/test_agent_engines.py similarity index 100% rename from tests/unit/vertexai/genai/test_agent_engines.py rename to tests/unit/agentplatform/genai/test_agent_engines.py diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/agentplatform/genai/test_evals.py similarity index 92% rename from tests/unit/vertexai/genai/test_evals.py rename to tests/unit/agentplatform/genai/test_evals.py index 1f15ac95a3..00fc7f2dda 100644 --- a/tests/unit/vertexai/genai/test_evals.py +++ b/tests/unit/agentplatform/genai/test_evals.py @@ -25,19 +25,21 @@ import google.auth.credentials from google.cloud import aiplatform -import vertexai +import agentplatform from google.cloud.aiplatform import initializer as aiplatform_initializer -from vertexai import 
_genai -from vertexai._genai import _evals_data_converters -from vertexai._genai import _evals_metric_handlers -from vertexai._genai import _evals_visualization -from vertexai._genai import _evals_metric_loaders -from vertexai._genai import _gcs_utils -from vertexai._genai import _observability_data_converter -from vertexai._genai import _transformers -from vertexai._genai import evals -from vertexai._genai import types as vertexai_genai_types -from vertexai._genai.types import common as common_types +from agentplatform import _genai +from agentplatform._genai import _evals_data_converters +from agentplatform._genai import _evals_metric_handlers +from agentplatform._genai import _evals_visualization +from agentplatform._genai import _evals_metric_loaders +from agentplatform._genai import _gcs_utils +from agentplatform._genai import _observability_data_converter +from agentplatform._genai import _transformers +from agentplatform._genai import evals +from agentplatform._genai import ( + types as agentplatform_genai_types, +) +from agentplatform._genai.types import common as common_types from google.genai import client from google.genai import errors as genai_errors from google.genai import types as genai_types @@ -126,13 +128,13 @@ def mock_eval_dependencies(mock_api_client_fixture): mock.patch("google.cloud.storage.Client") as mock_storage_client, mock.patch("google.cloud.bigquery.Client") as mock_bq_client, mock.patch( - "vertexai._genai.evals.Evals.evaluate_instances" + "google.cloud.aiplatform.agentplatform._genai.evals.Evals.evaluate_instances" ) as mock_evaluate_instances, mock.patch( - "vertexai._genai._gcs_utils.GcsUtils.upload_json_to_prefix" + "google.cloud.aiplatform.agentplatform._genai._gcs_utils.GcsUtils.upload_json_to_prefix" ) as mock_upload_to_gcs, mock.patch( - "vertexai._genai._evals_metric_loaders.LazyLoadedPrebuiltMetric._fetch_and_parse" + "google.cloud.aiplatform.agentplatform._genai._evals_metric_loaders.LazyLoadedPrebuiltMetric._fetch_and_parse" ) as mock_fetch_prebuilt_metric, ): # fmt: on @@ -140,31 +142,31 @@ def mock_eval_dependencies(mock_api_client_fixture): def mock_evaluate_instances_side_effect(*args, **kwargs): metric_config = kwargs.get("metric_config", {}) if "exact_match_input" in metric_config: - return vertexai_genai_types.EvaluateInstancesResponse( - exact_match_results=vertexai_genai_types.ExactMatchResults( + return agentplatform_genai_types.EvaluateInstancesResponse( + exact_match_results=agentplatform_genai_types.ExactMatchResults( exact_match_metric_values=[ genai_types.ExactMatchMetricValue(score=1.0) ] ) ) elif "rouge_input" in metric_config: - return vertexai_genai_types.EvaluateInstancesResponse( - rouge_results=vertexai_genai_types.RougeResults( + return agentplatform_genai_types.EvaluateInstancesResponse( + rouge_results=agentplatform_genai_types.RougeResults( rouge_metric_values=[genai_types.RougeMetricValue(score=0.8)] ) ) elif "comet_input" in metric_config: - return vertexai_genai_types.EvaluateInstancesResponse( - comet_result=vertexai_genai_types.CometResult(score=0.75) + return agentplatform_genai_types.EvaluateInstancesResponse( + comet_result=agentplatform_genai_types.CometResult(score=0.75) ) - return vertexai_genai_types.EvaluateInstancesResponse() + return agentplatform_genai_types.EvaluateInstancesResponse() mock_evaluate_instances.side_effect = mock_evaluate_instances_side_effect mock_upload_to_gcs.return_value = ( "gs://mock-bucket/mock_path/evaluation_result_timestamp.json" ) - mock_prebuilt_fluency_metric = 
vertexai_genai_types.LLMMetric( + mock_prebuilt_fluency_metric = agentplatform_genai_types.LLMMetric( name="fluency", prompt_template="Is this fluent? {response}" ) mock_prebuilt_fluency_metric._is_predefined = True @@ -186,43 +188,49 @@ def mock_evaluate_instances_side_effect(*args, **kwargs): class TestGetApiClientWithLocation: - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_get_api_client_with_location_override( - self, mock_vertexai_client, mock_api_client_fixture + self, mock_agentplatform_client, mock_api_client_fixture ): mock_api_client_fixture.location = "us-central1" new_location = "europe-west1" _evals_common._get_api_client_with_location( mock_api_client_fixture, new_location ) - mock_vertexai_client.assert_called_once_with( + mock_agentplatform_client.assert_called_once_with( project=mock_api_client_fixture.project, location=new_location, credentials=mock_api_client_fixture._credentials, http_options=mock_api_client_fixture._http_options, ) - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_get_api_client_with_same_location( - self, mock_vertexai_client, mock_api_client_fixture + self, mock_agentplatform_client, mock_api_client_fixture ): mock_api_client_fixture.location = "us-central1" new_location = "us-central1" _evals_common._get_api_client_with_location( mock_api_client_fixture, new_location ) - mock_vertexai_client.assert_not_called() + mock_agentplatform_client.assert_not_called() - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_get_api_client_with_none_location( - self, mock_vertexai_client, mock_api_client_fixture + self, mock_agentplatform_client, mock_api_client_fixture ): mock_api_client_fixture.location = "us-central1" new_location = None _evals_common._get_api_client_with_location( mock_api_client_fixture, new_location ) - mock_vertexai_client.assert_not_called() + mock_agentplatform_client.assert_not_called() class TestTransformers: @@ -292,13 +300,13 @@ def test_t_inline_results_sanitizes_agent_data(self): common_types.EvaluationDataset( eval_cases=[ common_types.EvalCase( - agent_data=vertexai_genai_types.evals.AgentData( + agent_data=agentplatform_genai_types.evals.AgentData( turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, turn_id="turn_0", events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="user", content=genai_types.Content( role="user", @@ -307,7 +315,7 @@ def test_t_inline_results_sanitizes_agent_data(self): ], ), ), - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="model", content=genai_types.Content( role="model", @@ -321,7 +329,7 @@ def test_t_inline_results_sanitizes_agent_data(self): ], ), ), - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="model", content=genai_types.Content( role="model", @@ -469,9 +477,9 @@ def test_t_inline_results_strips_none_tool_fields(self): common_types.EvaluationDataset( eval_cases=[ common_types.EvalCase( - agent_data=vertexai_genai_types.evals.AgentData( + agent_data=agentplatform_genai_types.evals.AgentData( agents={ - "agent_0": 
vertexai_genai_types.evals.AgentConfig( + "agent_0": agentplatform_genai_types.evals.AgentConfig( agent_id="agent_0", agent_type="LlmAgent", instruction="You are a helper.", @@ -488,10 +496,10 @@ def test_t_inline_results_strips_none_tool_fields(self): ) }, turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="user", content=genai_types.Content( parts=[genai_types.Part(text="Hi")], @@ -607,7 +615,7 @@ def test_response_structure(self): def test_get_loss_analysis_html(self): """Tests that _get_loss_analysis_html generates valid HTML with data.""" - from vertexai._genai import _evals_visualization + from agentplatform._genai import _evals_visualization import json data = { @@ -678,7 +686,7 @@ def test_get_loss_analysis_html(self): def test_display_loss_clusters_response_no_ipython(self): """Tests graceful fallback when not in IPython.""" - from vertexai._genai import _evals_visualization + from agentplatform._genai import _evals_visualization from unittest import mock response = common_types.GenerateLossClustersResponse( @@ -709,7 +717,7 @@ def test_display_loss_clusters_response_no_ipython(self): def test_display_loss_analysis_result_no_ipython(self): """Tests graceful fallback for individual result when not in IPython.""" - from vertexai._genai import _evals_visualization + from agentplatform._genai import _evals_visualization from unittest import mock result = common_types.LossAnalysisResult( @@ -735,7 +743,7 @@ def test_display_loss_analysis_result_no_ipython(self): def test_enrich_scenario_from_agent_data_in_eval_cases(self): """Tests scenario extraction from agent_data in eval_cases.""" - from vertexai._genai import _evals_utils + from agentplatform._genai import _evals_utils # API response: evaluation_result has NO request (real API behavior) api_response = common_types.GenerateLossClustersResponse( @@ -797,12 +805,12 @@ def test_enrich_scenario_from_agent_data_in_eval_cases(self): common_types.EvaluationDataset( eval_cases=[ common_types.EvalCase( - agent_data=vertexai_genai_types.evals.AgentData( + agent_data=agentplatform_genai_types.evals.AgentData( turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="user", content={ "parts": [ @@ -836,7 +844,7 @@ def test_enrich_scenario_from_agent_data_in_eval_cases(self): def test_enrich_scenario_from_user_scenario_starting_prompt(self): """Tests scenario extraction from user_scenario.starting_prompt.""" - from vertexai._genai import _evals_utils + from agentplatform._genai import _evals_utils api_response = common_types.GenerateLossClustersResponse( results=[ @@ -891,7 +899,7 @@ def test_enrich_scenario_from_user_scenario_starting_prompt(self): common_types.EvaluationDataset( eval_cases=[ common_types.EvalCase( - user_scenario=vertexai_genai_types.evals.UserScenario( + user_scenario=agentplatform_genai_types.evals.UserScenario( starting_prompt="I want to book a hotel in Tokyo.", conversation_plan="User asks to book a hotel.", ) @@ -917,7 +925,7 @@ def test_enrich_scenario_from_user_scenario_starting_prompt(self): def test_enrich_scenario_from_dataframe_agent_data(self): """Tests scenario extraction from DataFrame agent_data column.""" import pandas as pd - from vertexai._genai import _evals_utils + from 
agentplatform._genai import _evals_utils api_response = common_types.GenerateLossClustersResponse( results=[ @@ -952,12 +960,12 @@ def test_enrich_scenario_from_dataframe_agent_data(self): ], ) # eval_result with agent_data in DataFrame (run_inference output) - agent_data_obj = vertexai_genai_types.evals.AgentData( + agent_data_obj = agentplatform_genai_types.evals.AgentData( turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="user", content={"parts": [{"text": "Find flights to London"}]}, ), @@ -998,14 +1006,14 @@ def test_enrich_scenario_from_dataframe_agent_data(self): def test_enrich_scenario_e2e_simulation(self): """Simulates the full e2e flow: generate_scenarios -> run_inference -> evaluate -> loss_clusters.""" import pandas as pd - from vertexai._genai import _evals_utils + from agentplatform._genai import _evals_utils # Step 1: Simulate generate_conversation_scenarios output # This creates eval_cases with user_scenario but no agent_data scenario_dataset = common_types.EvaluationDataset( eval_cases=[ common_types.EvalCase( - user_scenario=vertexai_genai_types.evals.UserScenario( + user_scenario=agentplatform_genai_types.evals.UserScenario( starting_prompt="I need to book a flight from NYC to Paris for next Friday.", conversation_plan="User books a flight.", ) @@ -1024,17 +1032,17 @@ def test_enrich_scenario_e2e_simulation(self): # Step 2: Simulate run_inference output # run_inference extracts eval_dataset_df from the input, runs inference, # then returns a NEW EvaluationDataset with only eval_dataset_df (no eval_cases) - agent_data_obj = vertexai_genai_types.evals.AgentData( + agent_data_obj = agentplatform_genai_types.evals.AgentData( agents={ - "travel_agent": vertexai_genai_types.evals.AgentConfig( + "travel_agent": agentplatform_genai_types.evals.AgentConfig( agent_id="travel_agent", ) }, turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="user", content=genai_types.Content( parts=[ @@ -1045,7 +1053,7 @@ def test_enrich_scenario_e2e_simulation(self): role="user", ), ), - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="travel_agent", content=genai_types.Content( parts=[ @@ -1177,7 +1185,7 @@ def test_enrich_scenario_e2e_simulation(self): def test_enrich_scenario_from_dataframe_starting_prompt(self): """Tests scenario extraction from DataFrame starting_prompt column.""" import pandas as pd - from vertexai._genai import _evals_utils + from agentplatform._genai import _evals_utils api_response = common_types.GenerateLossClustersResponse( results=[ @@ -1973,23 +1981,27 @@ class TestEvals: def setup_method(self): importlib.reload(aiplatform_initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) - vertexai.init( + importlib.reload(agentplatform) + agentplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, ) - self.client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + self.client = agentplatform.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) @pytest.mark.usefixtures("google_auth_mock") @mock.patch.object(client.Client, "_get_api_client") @mock.patch.object(evals.Evals, "batch_evaluate") def test_eval_batch_evaluate(self, mock_evaluate, 
mock_get_api_client): - test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_client = agentplatform.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) test_client.evals.batch_evaluate( - dataset=vertexai_genai_types.EvaluationDataset(), - metrics=[vertexai_genai_types.Metric(name="test")], + dataset=agentplatform_genai_types.EvaluationDataset(), + metrics=[agentplatform_genai_types.Metric(name="test")], dest="gs://bucket/output", - config=vertexai_genai_types.EvaluateDatasetConfig(), + config=agentplatform_genai_types.EvaluateDatasetConfig(), ) mock_evaluate.assert_called_once() @@ -1997,7 +2009,7 @@ def test_eval_batch_evaluate(self, mock_evaluate, mock_get_api_client): @mock.patch.object(_evals_common, "_execute_evaluation") def test_eval_evaluate_with_agent_info(self, mock_execute_evaluation): """Tests that agent_info is passed to _execute_evaluation.""" - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame([{"prompt": "p1", "response": "r1"}]) ) agent_info = { @@ -2006,7 +2018,7 @@ def test_eval_evaluate_with_agent_info(self, mock_execute_evaluation): } self.client.evals.evaluate( dataset=dataset, - metrics=[vertexai_genai_types.Metric(name="exact_match")], + metrics=[agentplatform_genai_types.Metric(name="exact_match")], agent_info=agent_info, ) mock_execute_evaluation.assert_called_once() @@ -2017,7 +2029,7 @@ def test_eval_evaluate_with_agent_info(self, mock_execute_evaluation): class TestEvalsVisualization: @mock.patch( - "vertexai._genai._evals_visualization._is_ipython_env", + "google.cloud.aiplatform.agentplatform._genai._evals_visualization._is_ipython_env", return_value=True, ) def test_display_evaluation_result_with_agent_trace_prefixes(self, mock_is_ipython): @@ -2058,17 +2070,17 @@ def test_display_evaluation_result_with_agent_trace_prefixes(self, mock_is_ipyth }, ] ) - eval_dataset = vertexai_genai_types.EvaluationDataset( + eval_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - eval_result = vertexai_genai_types.EvaluationResult( + eval_result = agentplatform_genai_types.EvaluationResult( evaluation_dataset=[eval_dataset], - agent_info=vertexai_genai_types.evals.AgentInfo(name="test_agent"), + agent_info=agentplatform_genai_types.evals.AgentInfo(name="test_agent"), eval_case_results=[ - vertexai_genai_types.EvalCaseResult( + agentplatform_genai_types.EvalCaseResult( eval_case_index=0, response_candidate_results=[ - vertexai_genai_types.ResponseCandidateResult( + agentplatform_genai_types.ResponseCandidateResult( response_index=0, metric_results={} ) ], @@ -2090,7 +2102,7 @@ def test_display_evaluation_result_with_agent_trace_prefixes(self, mock_is_ipyth del sys.modules["IPython.display"] @mock.patch( - "vertexai._genai._evals_visualization._is_ipython_env", + "google.cloud.aiplatform.agentplatform._genai._evals_visualization._is_ipython_env", return_value=True, ) def test_display_evaluation_result_with_non_ascii_character(self, mock_is_ipython): @@ -2109,16 +2121,16 @@ def test_display_evaluation_result_with_non_ascii_character(self, mock_is_ipytho }, ] ) - eval_dataset = vertexai_genai_types.EvaluationDataset( + eval_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - eval_result = vertexai_genai_types.EvaluationResult( + eval_result = agentplatform_genai_types.EvaluationResult( evaluation_dataset=[eval_dataset], eval_case_results=[ - vertexai_genai_types.EvalCaseResult( + 
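For orientation, the batch_evaluate test above only asserts that the call is forwarded to the mocked service; the call shape itself looks like the following minimal sketch, with placeholder project, bucket, and dataset values (batch evaluation writes results under the given GCS destination rather than returning them inline):

import agentplatform
from agentplatform._genai import types

client = agentplatform.Client(project="my-project", location="us-central1")

client.evals.batch_evaluate(
    # In practice the dataset would reference real eval data; the test above
    # passes an empty EvaluationDataset because the service call is mocked.
    dataset=types.EvaluationDataset(),
    metrics=[types.Metric(name="exact_match")],
    dest="gs://my-bucket/eval-output",  # placeholder output prefix
    config=types.EvaluateDatasetConfig(),
)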
agentplatform_genai_types.EvalCaseResult( eval_case_index=0, response_candidate_results=[ - vertexai_genai_types.ResponseCandidateResult( + agentplatform_genai_types.ResponseCandidateResult( response_index=0, metric_results={} ) ], @@ -2157,9 +2169,9 @@ class TestEvalsRunInference: def setup_method(self): importlib.reload(aiplatform_initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) + importlib.reload(agentplatform) importlib.reload(_genai.client) - importlib.reload(vertexai_genai_types) + importlib.reload(agentplatform_genai_types) importlib.reload(_evals_utils) importlib.reload(_evals_data_converters) importlib.reload(_evals_common) @@ -2169,11 +2181,13 @@ def setup_method(self): if hasattr(_evals_common._thread_local_data, "agent_engine_instances"): del _evals_common._thread_local_data.agent_engine_instances - vertexai.init( + agentplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, ) - self.client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + self.client = agentplatform.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) @mock.patch.object(_evals_common, "Models") @mock.patch.object(_evals_utils, "EvalDatasetLoader") @@ -2423,7 +2437,7 @@ def test_inference_with_prompt_template( mock_generate_content_response ) - config = vertexai_genai_types.EvalRunInferenceConfig( + config = agentplatform_genai_types.EvalRunInferenceConfig( prompt_template="Hello {text_input}" ) inference_result = self.client.evals.run_inference( @@ -2472,7 +2486,7 @@ def test_inference_with_gcs_destination( ) gcs_dest_dir = "gs://bucket/output" - config = vertexai_genai_types.EvalRunInferenceConfig(dest=gcs_dest_dir) + config = agentplatform_genai_types.EvalRunInferenceConfig(dest=gcs_dest_dir) inference_result = self.client.evals.run_inference( model="gemini-pro", src=mock_df, config=config @@ -2531,7 +2545,9 @@ def test_inference_with_local_destination( ) with tempfile.TemporaryDirectory() as local_dest_dir: - config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_dir) + config = agentplatform_genai_types.EvalRunInferenceConfig( + dest=local_dest_dir + ) inference_result = self.client.evals.run_inference( model="gemini-pro", src=mock_df, config=config @@ -2592,7 +2608,9 @@ def test_inference_from_request_column_save_to_local_dir( ) with tempfile.TemporaryDirectory() as local_dest_dir: - config = vertexai_genai_types.EvalRunInferenceConfig(dest=local_dest_dir) + config = agentplatform_genai_types.EvalRunInferenceConfig( + dest=local_dest_dir + ) inference_result = self.client.evals.run_inference( model="gemini-pro", src=mock_df, config=config @@ -3009,7 +3027,7 @@ def test_inference_with_multimodal_content( mock_generate_content_response ) - config = vertexai_genai_types.EvalRunInferenceConfig( + config = agentplatform_genai_types.EvalRunInferenceConfig( prompt_template="multimodal prompt: {media_content}{text_input}" ) inference_result = self.client.evals.run_inference( @@ -3048,10 +3066,12 @@ def test_inference_with_multimodal_content( assert inference_result.gcs_source is None @mock.patch.object(_evals_utils, "EvalDatasetLoader") - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_run_inference_with_agent_engine_and_session_inputs_dict( self, - mock_vertexai_client, + mock_agentplatform_client, mock_eval_dataset_loader, ): mock_df = pd.DataFrame( @@ -3087,7 +3107,7 @@ def 
test_run_inference_with_agent_engine_and_session_inputs_dict( ] mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) - mock_vertexai_client.return_value.agent_engines.get.return_value = ( + mock_agentplatform_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -3097,7 +3117,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( ) mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df) - mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( + mock_agentplatform_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) mock_agent_engine.create_session.assert_called_once_with( @@ -3161,10 +3181,12 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict( assert inference_result.gcs_source is None @mock.patch.object(_evals_utils, "EvalDatasetLoader") - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_run_inference_with_agent_engine_and_session_inputs_literal_string( self, - mock_vertexai_client, + mock_agentplatform_client, mock_eval_dataset_loader, ): session_inputs_str = '{"user_id": "123", "state": {"a": "1"}}' @@ -3196,7 +3218,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( ] mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) - mock_vertexai_client.return_value.agent_engines.get.return_value = ( + mock_agentplatform_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -3206,7 +3228,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( ) mock_eval_dataset_loader.return_value.load.assert_called_once_with(mock_df) - mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with( + mock_agentplatform_client.return_value.agent_engines.get.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123" ) mock_agent_engine.create_session.assert_called_once_with( @@ -3265,10 +3287,12 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string( assert inference_result.gcs_source is None @mock.patch.object(_evals_utils, "EvalDatasetLoader") - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_run_inference_with_agent_engine_with_response_column_raises_error( self, - mock_vertexai_client, + mock_agentplatform_client, mock_eval_dataset_loader, ): mock_df = pd.DataFrame( @@ -3288,7 +3312,7 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error( ) mock_agent_engine = mock.Mock() - mock_vertexai_client.return_value.agent_engines.get.return_value = ( + mock_agentplatform_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -3303,10 +3327,12 @@ def test_run_inference_with_agent_engine_with_response_column_raises_error( ) in str(excinfo.value) @mock.patch.object(_evals_utils, "EvalDatasetLoader") - @mock.patch("vertexai._genai._evals_common.vertexai.Client") + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.agentplatform.Client" + ) def test_run_inference_with_agent_engine_falls_back_to_managed_sessions_api( self, - mock_vertexai_client, + mock_agentplatform_client, mock_eval_dataset_loader, ): """Tests that run_inference 
falls back to the managed Sessions API @@ -3360,7 +3386,7 @@ def test_run_inference_with_agent_engine_falls_back_to_managed_sessions_api( }, ] mock_agent_engine.stream_query.return_value = iter(stream_query_return_value) - mock_vertexai_client.return_value.agent_engines.get.return_value = ( + mock_agentplatform_client.return_value.agent_engines.get.return_value = ( mock_agent_engine ) @@ -3373,7 +3399,7 @@ def test_run_inference_with_agent_engine_falls_back_to_managed_sessions_api( mock_agent_engine.api_client.sessions.create.assert_called_once_with( name="projects/test-project/locations/us-central1/reasoningEngines/123", user_id="123", - config=vertexai_genai_types.CreateAgentEngineSessionConfig( + config=agentplatform_genai_types.CreateAgentEngineSessionConfig( session_state={"a": "1"}, ), ) @@ -3619,9 +3645,9 @@ def test_run_inference_with_litellm_string_prompt_format( """Tests inference with LiteLLM using a simple prompt string.""" # fmt: off with mock.patch( - "vertexai._genai._evals_common.litellm" + "google.cloud.aiplatform.agentplatform._genai._evals_common.litellm" ) as mock_litellm, mock.patch( - "vertexai._genai._evals_common._call_litellm_completion" + "google.cloud.aiplatform.agentplatform._genai._evals_common._call_litellm_completion" ) as mock_call_litellm_completion: mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None , None) prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}]) @@ -3677,10 +3703,10 @@ def test_run_inference_with_litellm_openai_request_format( # fmt: off with ( mock.patch( - "vertexai._genai._evals_common.litellm" + "google.cloud.aiplatform.agentplatform._genai._evals_common.litellm" ) as mock_litellm, mock.patch( - "vertexai._genai._evals_common._call_litellm_completion" + "google.cloud.aiplatform.agentplatform._genai._evals_common._call_litellm_completion" ) as mock_call_litellm_completion, ): # fmt: on @@ -3756,7 +3782,7 @@ def test_run_inference_with_unsupported_model_string( mock_api_client_fixture, ): with mock.patch( - "vertexai._genai._evals_common.litellm" + "google.cloud.aiplatform.agentplatform._genai._evals_common.litellm" ) as mock_litellm_package: mock_litellm_package.get_llm_provider.side_effect = ValueError( "unsupported model" @@ -3769,7 +3795,9 @@ def test_run_inference_with_unsupported_model_string( model="some-random-model/name", src=prompt_df ) - @mock.patch("vertexai._genai._evals_common.litellm", None) + @mock.patch( + "google.cloud.aiplatform.agentplatform._genai._evals_common.litellm", None + ) def test_run_inference_with_litellm_import_error(self, mock_api_client_fixture): evals_module = evals.Evals(api_client_=mock_api_client_fixture) prompt_df = pd.DataFrame([{"prompt": "test"}]) @@ -3825,7 +3853,7 @@ def test_run_inference_with_litellm_parsing( ] mock_run_litellm_inference.return_value = raw_responses # fmt: off - with mock.patch("vertexai._genai._evals_common.litellm") as mock_litellm: + with mock.patch("google.cloud.aiplatform.agentplatform._genai._evals_common.litellm") as mock_litellm: # fmt: on mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None , None) inference_result = self.client.evals.run_inference( @@ -3893,7 +3921,7 @@ class TestEvalsMetricHandlers: def test_has_tool_call_with_tool_call(self): events = [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="1", content=genai_types.Content( parts=[ @@ -3910,7 +3938,7 @@ def test_has_tool_call_with_tool_call(self): def test_has_tool_call_no_tool_call(self): events = [ - 
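The inference tests above all reduce to one call shape: client.evals.run_inference(model=..., src=..., config=...), where src is a pandas DataFrame and an optional prompt_template is filled from each row's columns. A minimal sketch with placeholder project and data, reusing the model name and template string from the tests above:

import pandas as pd
import agentplatform
from agentplatform._genai import types

client = agentplatform.Client(project="my-project", location="us-central1")

src = pd.DataFrame([{"text_input": "world"}])
config = types.EvalRunInferenceConfig(prompt_template="Hello {text_input}")

# Each row's columns fill the template; per the tests above, model outputs
# land in a "response" column of the returned dataset.
result = client.evals.run_inference(model="gemini-pro", src=src, config=config)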
vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="1", content=genai_types.Content(parts=[genai_types.Part(text="hello")]), ) @@ -3925,11 +3953,11 @@ def test_has_tool_call_none_events(self): def test_has_tool_call_mixed_events(self): events = [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="1", content=genai_types.Content(parts=[genai_types.Part(text="hello")]), ), - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="2", content=genai_types.Content( parts=[ @@ -3946,7 +3974,7 @@ def test_has_tool_call_mixed_events(self): def test_has_tool_call_with_agent_event(self): events = [ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="model", content=genai_types.Content( parts=[ @@ -3971,7 +3999,7 @@ def test_run_agent_rewrites_gemini_3_model_name( self, mock_execute_inference_concurrently, mock_api_client_fixture ): mock_execute_inference_concurrently.return_value = [] - user_simulator_config = vertexai_genai_types.evals.UserSimulatorConfig( + user_simulator_config = agentplatform_genai_types.evals.UserSimulatorConfig( model_name="gemini-3-preview" ) prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]}) @@ -4003,7 +4031,7 @@ def mock_execute(*args, **kwargs): def test_run_agent_raises_error_if_gemini_3_and_allow_cross_region_model_false( self, mock_execute_inference_concurrently, mock_api_client_fixture ): - user_simulator_config = vertexai_genai_types.evals.UserSimulatorConfig( + user_simulator_config = agentplatform_genai_types.evals.UserSimulatorConfig( model_name="gemini-3-preview" ) prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]}) @@ -4028,7 +4056,7 @@ def test_run_agent_rewrites_gemini_3_model_name_empty_env( self, mock_execute_inference_concurrently, mock_api_client_fixture ): mock_execute_inference_concurrently.return_value = [] - user_simulator_config = vertexai_genai_types.evals.UserSimulatorConfig( + user_simulator_config = agentplatform_genai_types.evals.UserSimulatorConfig( model_name="gemini-3-preview" ) prompt_dataset = pd.DataFrame({"prompt": ["prompt1"]}) @@ -4060,7 +4088,7 @@ class TestRunAgentInternal: """Unit tests for the _run_agent_internal function.""" def setup_method(self): - importlib.reload(vertexai_genai_types) + importlib.reload(agentplatform_genai_types) importlib.reload(_evals_common) @mock.patch.object(_evals_common, "_run_agent") @@ -4346,7 +4374,9 @@ class TestIsMultiTurnAgentSimulation: """Unit tests for the _is_multi_turn_agent_simulation function.""" def test_is_multi_turn_agent_simulation_with_config(self): - config = vertexai_genai_types.evals.UserSimulatorConfig(model_name="gemini-pro") + config = agentplatform_genai_types.evals.UserSimulatorConfig( + model_name="gemini-pro" + ) assert _evals_common._is_multi_turn_agent_simulation( user_simulator_config=config, prompt_dataset=pd.DataFrame() ) @@ -4370,7 +4400,7 @@ class TestMetricPromptBuilder: def test_metric_prompt_builder_minimal_fields(self): criteria = {"criterion1": "definition1"} rating_scores = {"score1": "description1"} - builder = vertexai_genai_types.MetricPromptBuilder( + builder = agentplatform_genai_types.MetricPromptBuilder( criteria=criteria, rating_scores=rating_scores, metric_definition=None, @@ -4410,7 +4440,7 @@ def test_metric_prompt_builder_all_fields(self): metric_definition = "Custom metric definition." 
evaluation_steps = {"step1": "custom step 1"} few_shot_examples = ["example1", "example2"] - builder = vertexai_genai_types.MetricPromptBuilder( + builder = agentplatform_genai_types.MetricPromptBuilder( criteria=criteria, rating_scores=rating_scores, instruction=instruction, @@ -4432,7 +4462,7 @@ def test_metric_prompt_builder_all_fields(self): def test_metric_prompt_builder_default_instruction_and_steps_in_text(self): criteria = {"c1": "v1"} rating_scores = {"s1": "d1"} - builder = vertexai_genai_types.MetricPromptBuilder( + builder = agentplatform_genai_types.MetricPromptBuilder( criteria=criteria, rating_scores=rating_scores, metric_definition=None, @@ -4468,7 +4498,7 @@ def test_metric_prompt_builder_custom_instruction_and_steps_in_text(self): rating_scores = {"s1": "d1"} custom_instruction = "My custom instructions." custom_steps = {"Step 1": "Do this first.", "Step 2": "Then do that."} - builder = vertexai_genai_types.MetricPromptBuilder( + builder = agentplatform_genai_types.MetricPromptBuilder( criteria=criteria, rating_scores=rating_scores, instruction=custom_instruction, @@ -4485,7 +4515,7 @@ def test_metric_prompt_builder_missing_criteria_raises_error(self): ValueError, match="Both 'criteria' and 'rating_scores' are required", ): - vertexai_genai_types.MetricPromptBuilder( + agentplatform_genai_types.MetricPromptBuilder( rating_scores={"score1": "description1"}, criteria=None, metric_definition=None, @@ -4497,7 +4527,7 @@ def test_metric_prompt_builder_missing_rating_scores_raises_error(self): ValueError, match="Both 'criteria' and 'rating_scores' are required", ): - vertexai_genai_types.MetricPromptBuilder( + agentplatform_genai_types.MetricPromptBuilder( criteria={"criterion1": "definition1"}, rating_scores=None, metric_definition=None, @@ -4509,24 +4539,24 @@ class TestPromptTemplate: """Unit tests for the PromptTemplate class.""" def test_prompt_template_variables(self): - template = vertexai_genai_types.PromptTemplate( + template = agentplatform_genai_types.PromptTemplate( text="Hello {name}, welcome to {place}!" ) assert template.variables == {"name", "place"} def test_prompt_template_assemble_simple(self): - template = vertexai_genai_types.PromptTemplate(text="Hello {name}.") + template = agentplatform_genai_types.PromptTemplate(text="Hello {name}.") assert template.assemble(name="World") == "Hello World." 
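The PromptTemplate contract that the surrounding tests pin down, in one self-contained sketch:

from agentplatform._genai import types

template = types.PromptTemplate(text="Hello {name}, welcome to {place}!")

# Variables are parsed from the template text as a set of names.
assert template.variables == {"name", "place"}

# assemble() substitutes keyword arguments into the template.
assert (
    template.assemble(name="Ada", place="Zurich")
    == "Hello Ada, welcome to Zurich!"
)
# Missing or unexpected variables raise ValueError, as tested below.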
def test_prompt_template_assemble_missing_variable_raises_error(self): - template = vertexai_genai_types.PromptTemplate(text="Hello {name}.") + template = agentplatform_genai_types.PromptTemplate(text="Hello {name}.") with pytest.raises( ValueError, match="Missing value for template variable 'name'" ): template.assemble() def test_prompt_template_assemble_extra_variable_raises_error(self): - template = vertexai_genai_types.PromptTemplate(text="Hello {name}.") + template = agentplatform_genai_types.PromptTemplate(text="Hello {name}.") with pytest.raises( ValueError, match="Invalid variable name 'extra_var' provided" ): @@ -4534,25 +4564,25 @@ def test_prompt_template_assemble_extra_variable_raises_error(self): def test_prompt_template_text_must_not_be_empty(self): with pytest.raises(ValueError, match="Prompt template text cannot be empty"): - vertexai_genai_types.PromptTemplate(text=" ") + agentplatform_genai_types.PromptTemplate(text=" ") def test_prompt_template_assemble_all_text_single_part_returns_string(self): - template = vertexai_genai_types.PromptTemplate(text="{greeting}, {name}.") + template = agentplatform_genai_types.PromptTemplate(text="{greeting}, {name}.") result = template.assemble(greeting="Hi", name="There") assert result == "Hi, There." def test_prompt_template_str_representation(self): template_text = "This is a template: {var}" - template = vertexai_genai_types.PromptTemplate(text=template_text) + template = agentplatform_genai_types.PromptTemplate(text=template_text) assert str(template) == template_text def test_prompt_template_repr_representation(self): template_text = "Test {repr}" - template = vertexai_genai_types.PromptTemplate(text=template_text) + template = agentplatform_genai_types.PromptTemplate(text=template_text) assert repr(template) == f"PromptTemplate(text='{template_text}')" def test_prompt_template_assemble_multimodal_output(self): - template = vertexai_genai_types.PromptTemplate( + template = agentplatform_genai_types.PromptTemplate( text="Context: {image_data} Question: {query}" ) image_content_json = genai_types.Content( @@ -4583,7 +4613,7 @@ def test_prompt_template_assemble_multimodal_output(self): def test_prompt_template_assemble_multimodal_variable_integration(self): template_str = "Observe: {media_part} and then answer: {txt_part}" - template = vertexai_genai_types.PromptTemplate(text=template_str) + template = agentplatform_genai_types.PromptTemplate(text=template_str) media_var_value = genai_types.Content( parts=[ @@ -4646,7 +4676,7 @@ def test_convert_simple_prompt_response(self): } ] result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 1 eval_case = result_dataset.eval_cases[0] @@ -4882,7 +4912,7 @@ def test_convert_with_raw_string_response(self): } ] result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 1 eval_case = result_dataset.eval_cases[0] @@ -4910,7 +4940,7 @@ def test_convert_simple_prompt_response(self): ) raw_data = raw_data_df.to_dict(orient="records") result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, 
agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 1 eval_case = result_dataset.eval_cases[0] @@ -5096,7 +5126,7 @@ def test_convert_with_intermediate_events_as_event_objects(self): "response": ["Hi"], "intermediate_events": [ [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] @@ -5133,7 +5163,7 @@ def test_convert_simple_prompt_response(self): } ] result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 1 eval_case = result_dataset.eval_cases[0] @@ -5280,7 +5310,7 @@ def test_convert_simple_request_response(self): ] result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 1 eval_case = result_dataset.eval_cases[0] @@ -5359,14 +5389,18 @@ def test_convert_with_conversation_history(self): ) assert len(eval_case.conversation_history) == 2 - assert eval_case.conversation_history[0] == vertexai_genai_types.evals.Message( + assert eval_case.conversation_history[ + 0 + ] == agentplatform_genai_types.evals.Message( content=genai_types.Content( parts=[genai_types.Part(text="Hello")], role="user" ), turn_id="0", author="user", ) - assert eval_case.conversation_history[1] == vertexai_genai_types.evals.Message( + assert eval_case.conversation_history[ + 1 + ] == agentplatform_genai_types.evals.Message( content=genai_types.Content( parts=[genai_types.Part(text="Hi")], role="system" ), @@ -5403,7 +5437,7 @@ def test_convert_multiple_request_response(self): ] result_dataset = self.converter.convert(raw_data) - assert isinstance(result_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(result_dataset, agentplatform_genai_types.EvaluationDataset) assert len(result_dataset.eval_cases) == 2 eval_case = result_dataset.eval_cases[0] @@ -5552,10 +5586,10 @@ def test_agent_info_creation(self): ) ] ) - agent_info = vertexai_genai_types.evals.AgentInfo( + agent_info = agentplatform_genai_types.evals.AgentInfo( name="agent_system", agents={ - "agent1": vertexai_genai_types.evals.AgentConfig( + "agent1": agentplatform_genai_types.evals.AgentConfig( agent_id="agent1", instruction="instruction1", description="description1", @@ -5585,7 +5619,7 @@ def my_search_tool(query: str) -> str: mock_agent.tools = [my_search_tool] mock_agent.sub_agents = [] - agent_info = vertexai_genai_types.evals.AgentInfo.load_from_agent( + agent_info = agentplatform_genai_types.evals.AgentInfo.load_from_agent( agent=mock_agent, ) @@ -5613,7 +5647,7 @@ def test_load_from_agent_with_get_declaration_tool(self): mock_agent.tools = [mock_tool] mock_agent.sub_agents = [] - agent_info = vertexai_genai_types.evals.AgentInfo.load_from_agent( + agent_info = agentplatform_genai_types.evals.AgentInfo.load_from_agent( agent=mock_agent, ) @@ -5647,7 +5681,7 @@ def my_plain_tool(query: str) -> str: mock_agent.tools = [mock_adk_tool, my_plain_tool] mock_agent.sub_agents = [] - agent_info = vertexai_genai_types.evals.AgentInfo.load_from_agent( + agent_info = agentplatform_genai_types.evals.AgentInfo.load_from_agent( agent=mock_agent, ) @@ -5683,7 +5717,7 @@ def 
test_load_from_agent_with_none_declaration_falls_back(self): mock_callable_declaration = mock.Mock(spec=genai_types.FunctionDeclaration) mock_from_callable.return_value = mock_callable_declaration - agent_info = vertexai_genai_types.evals.AgentInfo.load_from_agent( + agent_info = agentplatform_genai_types.evals.AgentInfo.load_from_agent( agent=mock_agent, ) @@ -5699,7 +5733,7 @@ class TestValidateDatasetAgentData: """Unit tests for the _validate_dataset_agent_data function.""" def test_valid_agent_data_in_df(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [ { @@ -5711,7 +5745,7 @@ def test_valid_agent_data_in_df(self): "agent_data": '{"turns": [{"turn_index": 0, "turn_id": "2", "events": []}]}' }, { - "agent_data": vertexai_genai_types.evals.AgentData( + "agent_data": agentplatform_genai_types.evals.AgentData( turns=[{"turn_index": 0, "turn_id": "3", "events": []}] ) }, @@ -5721,20 +5755,20 @@ def test_valid_agent_data_in_df(self): _evals_utils._validate_dataset_agent_data(dataset) def test_valid_agent_data_in_eval_cases(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_cases=[ - vertexai_genai_types.EvalCase( + agentplatform_genai_types.EvalCase( agent_data={ "turns": [{"turn_index": 0, "turn_id": "1", "events": []}] } ), - vertexai_genai_types.EvalCase( + agentplatform_genai_types.EvalCase( agent_data=json.loads( '{"turns": [{"turn_index": 0, "turn_id": "2", "events": []}]}' ) ), - vertexai_genai_types.EvalCase( - agent_data=vertexai_genai_types.evals.AgentData( + agentplatform_genai_types.EvalCase( + agent_data=agentplatform_genai_types.evals.AgentData( turns=[{"turn_index": 0, "turn_id": "3", "events": []}] ) ), @@ -5743,21 +5777,21 @@ def test_valid_agent_data_in_eval_cases(self): _evals_utils._validate_dataset_agent_data(dataset) def test_invalid_json_string_raises_error(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame([{"agent_data": '{"turns":'}]) ) with pytest.raises(ValueError, match="is not valid JSON"): _evals_utils._validate_dataset_agent_data(dataset) def test_invalid_dict_raises_error(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame([{"agent_data": {"agents": 123}}]) ) with pytest.raises(ValueError, match="is inconsistent with AgentData type"): _evals_utils._validate_dataset_agent_data(dataset) def test_valid_agent_data_with_error_in_dict(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"agent_data": {"error": "some error message"}}] ) @@ -5765,7 +5799,7 @@ def test_valid_agent_data_with_error_in_dict(self): _evals_utils._validate_dataset_agent_data(dataset) def test_valid_agent_data_with_error_in_string(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"agent_data": '{"error": "some error message"}'}] ) @@ -5773,14 +5807,14 @@ def test_valid_agent_data_with_error_in_string(self): _evals_utils._validate_dataset_agent_data(dataset) def test_invalid_agent_data_type_raises_error(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( 
eval_dataset_df=pd.DataFrame([{"agent_data": 123}]) ) with pytest.raises(ValueError, match="is inconsistent with AgentData type"): _evals_utils._validate_dataset_agent_data(dataset) def test_conflict_with_inference_configs_raises_error(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [ { @@ -5802,7 +5836,7 @@ def test_conflict_with_inference_configs_raises_error(self): _evals_utils._validate_dataset_agent_data(dataset, inference_configs) def test_no_conflict_with_inference_configs(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame([{"agent_data": {"turns": []}}]) ) inference_configs = { @@ -5811,7 +5845,7 @@ def test_no_conflict_with_inference_configs(self): _evals_utils._validate_dataset_agent_data(dataset, inference_configs) def test_no_conflict_if_inference_configs_has_no_agent_configs(self): - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [ { @@ -5831,7 +5865,7 @@ class TestEvent: """Unit tests for the Event class.""" def test_event_creation(self): - event = vertexai_genai_types.evals.Event( + event = agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] @@ -5859,10 +5893,10 @@ def test_eval_case_with_agent_eval_fields(self): ) ] ) - agent_info = vertexai_genai_types.evals.AgentInfo( + agent_info = agentplatform_genai_types.evals.AgentInfo( name="agent_system", agents={ - "agent1": vertexai_genai_types.evals.AgentConfig( + "agent1": agentplatform_genai_types.evals.AgentConfig( agent_id="agent1", instruction="instruction1", tools=[tool], @@ -5870,17 +5904,17 @@ def test_eval_case_with_agent_eval_fields(self): }, ) intermediate_events = [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] ), ) ] - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -5896,7 +5930,7 @@ class TestSessionInput: """Unit tests for the SessionInput class.""" def test_session_input_creation(self): - session_input = vertexai_genai_types.evals.SessionInput( + session_input = agentplatform_genai_types.evals.SessionInput( user_id="user1", state={"key": "value"}, ) @@ -5909,32 +5943,32 @@ class TestBuildEvaluationInstance: def setup_method(self): importlib.reload(aiplatform_initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) - importlib.reload(vertexai_genai_types) + importlib.reload(agentplatform) + importlib.reload(agentplatform_genai_types) importlib.reload(_evals_data_converters) importlib.reload(_evals_metric_handlers) - vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + agentplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) self.mock_api_client = mock.Mock(spec=client.Client) self.mock_evals_module = evals.Evals(api_client_=self.mock_api_client) def test_build_evaluation_instance_basic_filtering_and_fields(self): - metric = vertexai_genai_types.LLMMetric( + metric = agentplatform_genai_types.LLMMetric( 
name="test_quality", prompt_template="Eval: {prompt} with {response}. Context: {custom_context}. Ref: {reference}", ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content( parts=[genai_types.Part(text="User prompt text")] ), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content( parts=[genai_types.Part(text="Model response text")] ) ) ], - reference=vertexai_genai_types.ResponseCandidate( + reference=agentplatform_genai_types.ResponseCandidate( response=genai_types.Content( parts=[genai_types.Part(text="Ground truth text")] ) @@ -5969,26 +6003,26 @@ def test_build_evaluation_instance_basic_filtering_and_fields(self): assert "extra_field_not_in_template" not in instance.other_data.map_instance def test_build_evaluation_instance_various_field_types(self): - metric = vertexai_genai_types.LLMMetric( + metric = agentplatform_genai_types.LLMMetric( name="test_various_fields", prompt_template="{prompt}{response}{conversation_history}{system_instruction}{dict_field}{list_field}{int_field}{bool_field}", ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="The Prompt")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content( parts=[genai_types.Part(text="The Response")] ) ) ], conversation_history=[ - vertexai_genai_types.evals.Message( + agentplatform_genai_types.evals.Message( content=genai_types.Content( parts=[genai_types.Part(text="Turn 1 user")], role="user" ) ), - vertexai_genai_types.evals.Message( + agentplatform_genai_types.evals.Message( content=genai_types.Content( parts=[genai_types.Part(text="Turn 1 model")], role="model" ) @@ -6036,7 +6070,7 @@ class TestMetric: """Unit tests for the Metric class.""" def test_metric_creation_success(self): - metric = vertexai_genai_types.Metric(name="TestMetric") + metric = agentplatform_genai_types.Metric(name="TestMetric") assert metric.name == "testmetric" assert metric.custom_function is None @@ -6044,7 +6078,7 @@ def test_metric_creation_with_custom_function(self): def my_custom_function(data: dict): return 1.0 - metric = vertexai_genai_types.Metric( + metric = agentplatform_genai_types.Metric( name="custom_metric", custom_function=my_custom_function ) assert metric.name == "custom_metric" @@ -6052,27 +6086,31 @@ def my_custom_function(data: dict): def test_metric_name_validation_empty_raises_error(self): with pytest.raises(ValueError, match="Metric name cannot be empty."): - vertexai_genai_types.Metric(name="") + agentplatform_genai_types.Metric(name="") with pytest.raises(ValueError, match="Metric name cannot be empty."): - vertexai_genai_types.Metric(name=None) + agentplatform_genai_types.Metric(name=None) def test_llm_metric_prompt_template_validation_empty_raises_error(self): with pytest.raises(ValueError, match="Prompt template cannot be empty."): - vertexai_genai_types.LLMMetric(name="test_metric", prompt_template=None) + agentplatform_genai_types.LLMMetric( + name="test_metric", prompt_template=None + ) with pytest.raises( ValueError, match="Prompt template cannot be an empty string." ): - vertexai_genai_types.LLMMetric(name="test_metric", prompt_template="") + agentplatform_genai_types.LLMMetric(name="test_metric", prompt_template="") with pytest.raises( ValueError, match="Prompt template cannot be an empty string." 
): - vertexai_genai_types.LLMMetric(name="test_metric", prompt_template=" ") + agentplatform_genai_types.LLMMetric( + name="test_metric", prompt_template=" " + ) def test_llm_metric_sampling_count_validation_raise_errors(self): with pytest.raises( ValueError, match="judge_model_sampling_count must be between 1 and 32." ): - vertexai_genai_types.LLMMetric( + agentplatform_genai_types.LLMMetric( name="test_metric", prompt_template="test_prompt_template", judge_model_sampling_count=0, @@ -6080,7 +6118,7 @@ def test_llm_metric_sampling_count_validation_raise_errors(self): with pytest.raises( ValueError, match="judge_model_sampling_count must be between 1 and 32." ): - vertexai_genai_types.LLMMetric( + agentplatform_genai_types.LLMMetric( name="test_metric", prompt_template="test_prompt_template", judge_model_sampling_count=-1, @@ -6088,22 +6126,22 @@ def test_llm_metric_sampling_count_validation_raise_errors(self): with pytest.raises( ValueError, match="judge_model_sampling_count must be between 1 and 32." ): - vertexai_genai_types.LLMMetric( + agentplatform_genai_types.LLMMetric( name="test_metric", prompt_template="test_prompt_template", judge_model_sampling_count=100, ) def test_metric_name_validation_lowercase(self): - metric = vertexai_genai_types.Metric(name="UPPERCASEMetric") + metric = agentplatform_genai_types.Metric(name="UPPERCASEMetric") assert metric.name == "uppercasemetric" - @mock.patch("vertexai._genai.types.common.yaml.dump") + @mock.patch("google.cloud.aiplatform.agentplatform._genai.types.common.yaml.dump") @mock.patch("builtins.open", new_callable=mock.mock_open) def test_metric_to_yaml_file_with_version_and_set_fields( self, mock_open_file, mock_yaml_dump ): - metric_obj = vertexai_genai_types.Metric( + metric_obj = agentplatform_genai_types.Metric( name="MyMetricToDump", prompt_template="Evaluate: {input}", judge_model="gemini-1.5-pro", @@ -6135,12 +6173,12 @@ def test_metric_to_yaml_file_with_version_and_set_fields( allow_unicode=True, ) - @mock.patch("vertexai._genai.types.common.yaml.dump") + @mock.patch("google.cloud.aiplatform.agentplatform._genai.types.common.yaml.dump") @mock.patch("builtins.open", new_callable=mock.mock_open) def test_metric_to_yaml_file_without_version_minimal_fields( self, mock_open_file, mock_yaml_dump ): - metric_obj = vertexai_genai_types.Metric(name="MinimalMetric") + metric_obj = agentplatform_genai_types.Metric(name="MinimalMetric") test_file_path = "/fake/path/minimal_metric.yaml" metric_obj.to_yaml_file(test_file_path) @@ -6156,9 +6194,9 @@ def test_metric_to_yaml_file_without_version_minimal_fields( allow_unicode=True, ) - @mock.patch("vertexai._genai.types.common.yaml", None) + @mock.patch("google.cloud.aiplatform.agentplatform._genai.types.common.yaml", None) def test_metric_to_yaml_file_raises_importerror_if_yaml_is_none(self): - metric_obj = vertexai_genai_types.Metric(name="ErrorMetric") + metric_obj = agentplatform_genai_types.Metric(name="ErrorMetric") with pytest.raises( ImportError, match="YAML serialization requires the pyyaml library" ): @@ -6864,10 +6902,10 @@ def test_eval_case_to_agent_data(self): ) ] ) - agent_info = vertexai_genai_types.evals.AgentInfo( + agent_info = agentplatform_genai_types.evals.AgentInfo( name="agent_system", agents={ - "agent1": vertexai_genai_types.evals.AgentConfig( + "agent1": agentplatform_genai_types.evals.AgentConfig( agent_id="agent1", instruction="instruction1", tools=[tool], @@ -6875,7 +6913,7 @@ def test_eval_case_to_agent_data(self): }, ) intermediate_events = [ - 
vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] @@ -6883,10 +6921,10 @@ def test_eval_case_to_agent_data(self): author="agent1", ) ] - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -6910,17 +6948,17 @@ def test_eval_case_to_agent_data(self): def test_eval_case_to_agent_data_events_only(self): intermediate_events = [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] ), ) ] - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -6937,15 +6975,15 @@ def test_eval_case_to_agent_data_events_only(self): def test_eval_case_to_agent_data_empty_event_content(self): intermediate_events = [ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=None, ) ] - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -6959,10 +6997,10 @@ def test_eval_case_to_agent_data_empty_event_content(self): assert agent_data.turns[0].events[0].content is None def test_eval_case_to_agent_data_empty_intermediate_events_list(self): - agent_info = vertexai_genai_types.evals.AgentInfo( + agent_info = agentplatform_genai_types.evals.AgentInfo( name="agent_system", agents={ - "agent1": vertexai_genai_types.evals.AgentConfig( + "agent1": agentplatform_genai_types.evals.AgentConfig( agent_id="agent1", instruction="instruction1", tools=[], @@ -6970,10 +7008,10 @@ def test_eval_case_to_agent_data_empty_intermediate_events_list(self): }, ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -6989,20 +7027,20 @@ def test_eval_case_to_agent_data_empty_intermediate_events_list(self): assert len(agent_data.turns[0].events) == 2 def test_eval_case_to_agent_data_agent_info_empty_tools(self): - agent_info = vertexai_genai_types.evals.AgentInfo( + agent_info = agentplatform_genai_types.evals.AgentInfo( name="agent_system", agents={ - "agent1": vertexai_genai_types.evals.AgentConfig( + "agent1": agentplatform_genai_types.evals.AgentConfig( agent_id="agent1", instruction="instruction1", tools=[], ) }, ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + 
agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -7021,17 +7059,17 @@ def test_eval_case_to_agent_data_agent_info_empty_tools(self): def test_eval_case_to_agent_data_agent_info_empty(self): intermediate_events = [ - vertexai_genai_types.Event( + agentplatform_genai_types.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] ), ) ] - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], @@ -7052,20 +7090,20 @@ def test_tool_use_quality_metric_no_tool_call_logs_warning( self, mock_warning, mock_api_client_fixture ): """Tests that PredefinedMetricHandler warns for tool_use_quality_v1 if no tool call.""" - metric = vertexai_genai_types.Metric(name="tool_use_quality_v1") + metric = agentplatform_genai_types.Metric(name="tool_use_quality_v1") handler = _evals_metric_handlers.PredefinedMetricHandler( module=evals.Evals(api_client_=mock_api_client_fixture), metric=metric ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( eval_case_id="case-no-tool-call", prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], intermediate_events=[ - vertexai_genai_types.evals.Event( + agentplatform_genai_types.evals.Event( event_id="event1", content=genai_types.Content( parts=[genai_types.Part(text="intermediate event")] @@ -7085,25 +7123,25 @@ def test_build_request_payload_tool_use_quality_v1_with_agent_data_tool_call( self, mock_warning, mock_api_client_fixture ): """Tests that PredefinedMetricHandler does not warn if tool call is in agent_data.""" - metric = vertexai_genai_types.Metric(name="tool_use_quality_v1") + metric = agentplatform_genai_types.Metric(name="tool_use_quality_v1") handler = _evals_metric_handlers.PredefinedMetricHandler( module=evals.Evals(api_client_=mock_api_client_fixture), metric=metric ) - eval_case = vertexai_genai_types.EvalCase( + eval_case = agentplatform_genai_types.EvalCase( eval_case_id="case-with-agent-data-tool-call", prompt=genai_types.Content(parts=[genai_types.Part(text="Hello")]), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content(parts=[genai_types.Part(text="Hi")]) ) ], - agent_data=vertexai_genai_types.evals.AgentData( + agent_data=agentplatform_genai_types.evals.AgentData( turns=[ - vertexai_genai_types.evals.ConversationTurn( + agentplatform_genai_types.evals.ConversationTurn( turn_index=0, turn_id="turn_0", events=[ - vertexai_genai_types.evals.AgentEvent( + agentplatform_genai_types.evals.AgentEvent( author="model", content=genai_types.Content( parts=[ @@ -7416,10 +7454,10 @@ def test_execute_evaluation_computation_metric( } ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - computation_metric = vertexai_genai_types.Metric(name="exact_match") + computation_metric = agentplatform_genai_types.Metric(name="exact_match") result = _evals_common._execute_evaluation( 
api_client=mock_api_client_fixture, @@ -7427,7 +7465,7 @@ def test_execute_evaluation_computation_metric( metrics=[computation_metric], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -7458,7 +7496,7 @@ def test_execute_evaluation_with_agent_info( } ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) predefined_metric = genai_types.PredefinedMetricSpec( @@ -7495,7 +7533,7 @@ def test_execute_evaluation_with_agent_info( agent_info=agent_info, ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert len(result.eval_case_results) == 1 assert result.agent_info.name == "agent_system" assert "agent1" in result.agent_info.agents @@ -7527,17 +7565,17 @@ def test_execute_evaluation_translation_metric( } ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - translation_metric = vertexai_genai_types.Metric(name="comet") + translation_metric = agentplatform_genai_types.Metric(name="comet") result = _evals_common._execute_evaluation( api_client=mock_api_client_fixture, dataset=input_dataset, metrics=[translation_metric], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -7553,20 +7591,20 @@ def test_execute_evaluation_llm_metric( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="text_quality", prompt_template="Evaluate: {response}" ) with mock.patch( - "vertexai._genai.evals.Evals._evaluate_instances" + "google.cloud.aiplatform.agentplatform._genai.evals.Evals._evaluate_instances" ) as mock_evaluate_instances_unified: mock_evaluate_instances_unified.return_value = ( - vertexai_genai_types.EvaluateInstancesResponse( + agentplatform_genai_types.EvaluateInstancesResponse( metric_results=[ - vertexai_genai_types.MetricResult( + agentplatform_genai_types.MetricResult( score=0.9, explanation="Mocked LLM explanation", error=None, @@ -7581,7 +7619,7 @@ def test_execute_evaluation_llm_metric( dataset=input_dataset, metrics=[llm_metric], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -7601,7 +7639,7 @@ def test_execute_evaluation_hallucination_metric(self, mock_api_client_fixture): dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) @@ -7609,11 +7647,11 @@ def 
test_execute_evaluation_hallucination_metric(self, mock_api_client_fixture): api_client=mock_api_client_fixture, dataset=input_dataset, metrics=[ - vertexai_genai_types.RubricMetric.HALLUCINATION, - vertexai_genai_types.RubricMetric.TOOL_USE_QUALITY, + agentplatform_genai_types.RubricMetric.HALLUCINATION, + agentplatform_genai_types.RubricMetric.TOOL_USE_QUALITY, ], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 2 assert result.summary_metrics[0].metric_name == "hallucination_v1" @@ -7648,19 +7686,19 @@ def test_execute_evaluation_with_openai_schema( }, } ] - converted_eval_case = vertexai_genai_types.EvalCase( + converted_eval_case = agentplatform_genai_types.EvalCase( prompt=genai_types.Content( parts=[genai_types.Part(text="OpenAI Prompt")], role="user" ), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content( parts=[genai_types.Part(text="Candidate Response")] ) ) ], ) - mock_converted_dataset = vertexai_genai_types.EvaluationDataset( + mock_converted_dataset = agentplatform_genai_types.EvaluationDataset( eval_cases=[converted_eval_case] ) @@ -7670,10 +7708,10 @@ def test_execute_evaluation_with_openai_schema( mock_converter_instance.convert.return_value = mock_converted_dataset mock_get_converter.return_value = mock_converter_instance - input_dataset_for_loader = vertexai_genai_types.EvaluationDataset( + input_dataset_for_loader = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame(mock_openai_raw_data) ) - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="test_metric", prompt_template="Evaluate: {response}" ) @@ -7685,7 +7723,7 @@ def test_execute_evaluation_with_openai_schema( _evals_metric_handlers.LLMMetricHandler, "get_metric_result" ) as mock_llm_process: mock_llm_process.return_value = ( - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="test_metric", score=0.75 ) ) @@ -7705,7 +7743,7 @@ def test_execute_evaluation_with_openai_schema( ) mock_converter_instance.convert.assert_called_once_with(mock_openai_raw_data) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] assert summary_metric.metric_name == "test_metric" @@ -7717,14 +7755,14 @@ def test_execute_evaluation_custom_metric( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) def my_custom_metric_fn(data: dict): return 0.5 - custom_metric = vertexai_genai_types.Metric( + custom_metric = agentplatform_genai_types.Metric( name="my_custom", custom_function=my_custom_metric_fn ) @@ -7733,7 +7771,7 @@ def my_custom_metric_fn(data: dict): dataset=input_dataset, metrics=[custom_metric], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -7751,24 +7789,24 @@ def 
test_llm_metric_default_aggregation_mixed_results( {"prompt": "P3", "response": "R3"}, ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="quality", prompt_template="Rate: {response}" ) with mock.patch( - "vertexai._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" ) as mock_llm_process: mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="quality", score=0.8, explanation="Good" ), - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="quality", score=0.6, explanation="Okay" ), - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="quality", error_message="Processing failed" ), ] @@ -7796,31 +7834,33 @@ def test_llm_metric_custom_aggregation_success( dataset_df = pd.DataFrame( [{"prompt": "P1", "response": "R1"}, {"prompt": "P2", "response": "R2"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - def custom_agg_fn(results: list[vertexai_genai_types.EvalCaseMetricResult]): + def custom_agg_fn( + results: list[agentplatform_genai_types.EvalCaseMetricResult], + ): return { "my_custom_stat": 123, "mean_score": 0.75, "num_cases_valid": len(results), } - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="custom_quality", prompt_template="Rate: {response}", aggregate_summary_fn=custom_agg_fn, ) with mock.patch( - "vertexai._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" ) as mock_llm_process: mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="custom_quality", score=0.8 ), - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="custom_quality", score=0.7 ), ] @@ -7846,28 +7886,28 @@ def test_llm_metric_custom_aggregation_error_fallback( dataset_df = pd.DataFrame( [{"prompt": "P1", "response": "R1"}, {"prompt": "P2", "response": "R2"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) def custom_agg_fn_error( - results: list[vertexai_genai_types.EvalCaseMetricResult], + results: list[agentplatform_genai_types.EvalCaseMetricResult], ): raise ValueError("Custom aggregation failed") - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="error_fallback_quality", prompt_template="Rate: {response}", aggregate_summary_fn=custom_agg_fn_error, ) with mock.patch( - "vertexai._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" ) as mock_llm_process: mock_llm_process.side_effect = [ - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="error_fallback_quality", 
score=0.9 ), - vertexai_genai_types.EvalCaseMetricResult( + agentplatform_genai_types.EvalCaseMetricResult( metric_name="error_fallback_quality", score=0.5 ), ] @@ -7890,25 +7930,27 @@ def test_llm_metric_custom_aggregation_invalid_return_type_fallback( self, mock_api_client_fixture, mock_eval_dependencies ): dataset_df = pd.DataFrame([{"prompt": "P1", "response": "R1"}]) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) def custom_agg_fn_invalid_type( - results: list[vertexai_genai_types.EvalCaseMetricResult], + results: list[agentplatform_genai_types.EvalCaseMetricResult], ): return "not a dict" - llm_metric = vertexai_genai_types.LLMMetric( + llm_metric = agentplatform_genai_types.LLMMetric( name="invalid_type_fallback", prompt_template="Rate: {response}", aggregate_summary_fn=custom_agg_fn_invalid_type, ) with mock.patch( - "vertexai._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers.LLMMetricHandler.get_metric_result" ) as mock_llm_process: - mock_llm_process.return_value = vertexai_genai_types.EvalCaseMetricResult( - metric_name="invalid_type_fallback", score=0.8 + mock_llm_process.return_value = ( + agentplatform_genai_types.EvalCaseMetricResult( + metric_name="invalid_type_fallback", score=0.8 + ) ) result = _evals_common._execute_evaluation( api_client=mock_api_client_fixture, @@ -7926,7 +7968,7 @@ def test_execute_evaluation_lazy_loaded_prebuilt_metric_instance( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) @@ -7935,12 +7977,12 @@ def test_execute_evaluation_lazy_loaded_prebuilt_metric_instance( ) with mock.patch( - "vertexai._genai.evals.Evals._evaluate_instances" + "google.cloud.aiplatform.agentplatform._genai.evals.Evals._evaluate_instances" ) as mock_evaluate_instances_unified: mock_evaluate_instances_unified.return_value = ( - vertexai_genai_types.EvaluateInstancesResponse( + agentplatform_genai_types.EvaluateInstancesResponse( metric_results=[ - vertexai_genai_types.MetricResult( + agentplatform_genai_types.MetricResult( score=0.9, explanation="Mocked LLM explanation", error=None, @@ -7959,7 +8001,7 @@ def test_execute_evaluation_lazy_loaded_prebuilt_metric_instance( mock_eval_dependencies[ "mock_fetch_prebuilt_metric" ].assert_called_once_with(mock_api_client_fixture) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -7972,19 +8014,19 @@ def test_execute_evaluation_prebuilt_metric_via_loader( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - prebuilt_metric = vertexai_genai_types.RubricMetric.FLUENCY + prebuilt_metric = agentplatform_genai_types.RubricMetric.FLUENCY with mock.patch( - "vertexai._genai.evals.Evals._evaluate_instances" + "google.cloud.aiplatform.agentplatform._genai.evals.Evals._evaluate_instances" ) as mock_evaluate_instances_unified: mock_evaluate_instances_unified.return_value = 
( - vertexai_genai_types.EvaluateInstancesResponse( + agentplatform_genai_types.EvaluateInstancesResponse( metric_results=[ - vertexai_genai_types.MetricResult( + agentplatform_genai_types.MetricResult( score=0.9, explanation="Mocked LLM explanation", error=None, @@ -8003,7 +8045,7 @@ def test_execute_evaluation_prebuilt_metric_via_loader( mock_eval_dependencies[ "mock_fetch_prebuilt_metric" ].assert_called_once_with(mock_api_client_fixture) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert result.evaluation_dataset == [input_dataset] assert len(result.summary_metrics) == 1 summary_metric = result.summary_metrics[0] @@ -8022,10 +8064,10 @@ def test_execute_evaluation_with_gcs_destination( } ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - metric = vertexai_genai_types.Metric(name="exact_match") + metric = agentplatform_genai_types.Metric(name="exact_match") gcs_dest = "gs://my-bucket/eval_results/" result = _evals_common._execute_evaluation( @@ -8048,9 +8090,9 @@ def test_execute_evaluation_multiple_datasets( ): df1 = pd.DataFrame([{"prompt": "p1", "response": "r1a", "reference": "ref1"}]) df2 = pd.DataFrame([{"prompt": "p1", "response": "r1b", "reference": "ref1"}]) - dataset1 = vertexai_genai_types.EvaluationDataset(eval_dataset_df=df1) - dataset2 = vertexai_genai_types.EvaluationDataset(eval_dataset_df=df2) - metric = vertexai_genai_types.Metric(name="exact_match") + dataset1 = agentplatform_genai_types.EvaluationDataset(eval_dataset_df=df1) + dataset2 = agentplatform_genai_types.EvaluationDataset(eval_dataset_df=df2) + metric = agentplatform_genai_types.Metric(name="exact_match") result = _evals_common._execute_evaluation( api_client=mock_api_client_fixture, @@ -8058,7 +8100,7 @@ def test_execute_evaluation_multiple_datasets( metrics=[metric], ) - assert isinstance(result, vertexai_genai_types.EvaluationResult) + assert isinstance(result, agentplatform_genai_types.EvaluationResult) assert len(result.eval_case_results) == 1 case_result = result.eval_case_results[0] assert len(case_result.response_candidate_results) == 2 @@ -8089,19 +8131,19 @@ def test_execute_evaluation_deduplicates_candidate_names( self, mock_api_client_fixture, mock_eval_dependencies ): """Tests that duplicate candidate names are indexed.""" - dataset1 = vertexai_genai_types.EvaluationDataset( + dataset1 = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"prompt": "p1", "response": "r1", "reference": "ref1"}] ), candidate_name="gemini-pro", ) - dataset2 = vertexai_genai_types.EvaluationDataset( + dataset2 = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"prompt": "p1", "response": "r2", "reference": "ref1"}] ), candidate_name="gemini-flash", ) - dataset3 = vertexai_genai_types.EvaluationDataset( + dataset3 = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"prompt": "p1", "response": "r3", "reference": "ref1"}] ), @@ -8109,8 +8151,8 @@ def test_execute_evaluation_deduplicates_candidate_names( ) mock_eval_dependencies["mock_evaluate_instances"].return_value = ( - vertexai_genai_types.EvaluateInstancesResponse( - exact_match_results=vertexai_genai_types.ExactMatchResults( + agentplatform_genai_types.EvaluateInstancesResponse( + exact_match_results=agentplatform_genai_types.ExactMatchResults( exact_match_metric_values=[ 
genai_types.ExactMatchMetricValue(score=1.0) ] @@ -8121,7 +8163,7 @@ def test_execute_evaluation_deduplicates_candidate_names( result = _evals_common._execute_evaluation( api_client=mock_api_client_fixture, dataset=[dataset1, dataset2, dataset3], - metrics=[vertexai_genai_types.Metric(name="exact_match")], + metrics=[agentplatform_genai_types.Metric(name="exact_match")], ) assert result.metadata.candidate_names == [ @@ -8130,7 +8172,7 @@ def test_execute_evaluation_deduplicates_candidate_names( "gemini-pro #2", ] - @mock.patch("vertexai._genai._evals_common.datetime") + @mock.patch("google.cloud.aiplatform.agentplatform._genai._evals_common.datetime") def test_execute_evaluation_adds_creation_timestamp( self, mock_datetime, mock_api_client_fixture, mock_eval_dependencies ): @@ -8142,12 +8184,12 @@ def test_execute_evaluation_adds_creation_timestamp( ) mock_datetime.datetime.now.return_value = mock_now - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=pd.DataFrame( [{"prompt": "p", "response": "r", "reference": "r"}] ) ) - metric = vertexai_genai_types.Metric(name="exact_match") + metric = agentplatform_genai_types.Metric(name="exact_match") result = _evals_common._execute_evaluation( api_client=mock_api_client_fixture, @@ -8159,11 +8201,11 @@ def test_execute_evaluation_adds_creation_timestamp( assert result.metadata.creation_timestamp == mock_now @mock.patch( - "vertexai._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS", + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS", frozenset(["summarization_quality"]), ) @mock.patch("time.sleep", return_value=None) - @mock.patch("vertexai._genai.evals.Evals._evaluate_instances") # fmt: skip + @mock.patch("google.cloud.aiplatform.agentplatform._genai.evals.Evals._evaluate_instances") # fmt: skip def test_predefined_metric_retry_on_resource_exhausted( self, mock_private_evaluate_instances, @@ -8173,11 +8215,11 @@ def test_predefined_metric_retry_on_resource_exhausted( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - metric = vertexai_genai_types.Metric(name="summarization_quality") - metric_result = vertexai_genai_types.MetricResult( + metric = agentplatform_genai_types.Metric(name="summarization_quality") + metric_result = agentplatform_genai_types.MetricResult( score=0.9, explanation="Mocked predefined explanation", rubric_verdicts=[], @@ -8193,7 +8235,7 @@ def test_predefined_metric_retry_on_resource_exhausted( mock_private_evaluate_instances.side_effect = [ genai_errors.ClientError(code=429, response_json=error_response_json), genai_errors.ClientError(code=429, response_json=error_response_json), - vertexai_genai_types.EvaluateInstancesResponse( + agentplatform_genai_types.EvaluateInstancesResponse( metric_results=[metric_result] ), ] @@ -8212,11 +8254,11 @@ def test_predefined_metric_retry_on_resource_exhausted( assert summary_metric.mean_score == 0.9 @mock.patch( - "vertexai._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS", + "google.cloud.aiplatform.agentplatform._genai._evals_metric_handlers._evals_constant.SUPPORTED_PREDEFINED_METRICS", frozenset(["summarization_quality"]), ) @mock.patch("time.sleep", return_value=None) - 
@mock.patch("vertexai._genai.evals.Evals._evaluate_instances") # fmt: skip + @mock.patch("google.cloud.aiplatform.agentplatform._genai.evals.Evals._evaluate_instances") # fmt: skip def test_predefined_metric_retry_fail_on_resource_exhausted( self, mock_private_evaluate_instances, @@ -8226,7 +8268,7 @@ def test_predefined_metric_retry_fail_on_resource_exhausted( dataset_df = pd.DataFrame( [{"prompt": "Test prompt", "response": "Test response"}] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) error_response_json = { @@ -8236,7 +8278,7 @@ def test_predefined_metric_retry_fail_on_resource_exhausted( "status": "RESOURCE_EXHAUSTED", } } - metric = vertexai_genai_types.Metric(name="summarization_quality") + metric = agentplatform_genai_types.Metric(name="summarization_quality") mock_private_evaluate_instances.side_effect = [ genai_errors.ClientError(code=429, response_json=error_response_json), genai_errors.ClientError(code=429, response_json=error_response_json), @@ -8282,16 +8324,14 @@ def read_file_contents_side_effect(src: str) -> str: ) eval_cases = [ - vertexai_genai_types.ObservabilityEvalCase( + agentplatform_genai_types.ObservabilityEvalCase( input_src="gs://project/input.json", output_src="gs://project/output.json", system_instruction_src="gs://project/system_instruction.json", ) ] - result = ( - vertexai_genai_types.EvaluationDataset.load_from_observability_eval_cases( - eval_cases - ) + result = agentplatform_genai_types.EvaluationDataset.load_from_observability_eval_cases( + eval_cases ) mock_gcs_utils.return_value.read_file_contents.assert_has_calls( @@ -8336,15 +8376,13 @@ def read_file_contents_side_effect(src: str) -> str: ) eval_cases = [ - vertexai_genai_types.ObservabilityEvalCase( + agentplatform_genai_types.ObservabilityEvalCase( input_src="gs://project/input.json", output_src="gs://project/output.json", ) ] - result = ( - vertexai_genai_types.EvaluationDataset.load_from_observability_eval_cases( - eval_cases - ) + result = agentplatform_genai_types.EvaluationDataset.load_from_observability_eval_cases( + eval_cases ) mock_gcs_utils.return_value.read_file_contents.assert_has_calls( @@ -8392,21 +8430,19 @@ def read_file_contents_side_effect(src: str) -> str: ) eval_cases = [ - vertexai_genai_types.ObservabilityEvalCase( + agentplatform_genai_types.ObservabilityEvalCase( input_src="gs://project/input_1.json", output_src="gs://project/output_1.json", system_instruction_src="gs://project/system_instruction_1.json", ), - vertexai_genai_types.ObservabilityEvalCase( + agentplatform_genai_types.ObservabilityEvalCase( input_src="gs://project/input_2.json", output_src="gs://project/output_2.json", system_instruction_src="gs://project/system_instruction_2.json", ), ] - result = ( - vertexai_genai_types.EvaluationDataset.load_from_observability_eval_cases( - eval_cases - ) + result = agentplatform_genai_types.EvaluationDataset.load_from_observability_eval_cases( + eval_cases ) assert result.eval_dataset_df is not None @@ -8431,7 +8467,7 @@ class TestEvalsGenerateConversationScenarios: def setup_method(self, method): self.mock_client = mock.MagicMock(spec=client.Client) - self.mock_client.vertexai = True + self.mock_client.agentplatform = True self.mock_api_client = mock.MagicMock() self.mock_client._api_client = self.mock_api_client @@ -8451,14 +8487,14 @@ def test_generate_conversation_scenarios(self): evals_module = evals.Evals(api_client_=self.mock_api_client) eval_dataset = 
evals_module.generate_conversation_scenarios( - agent_info=vertexai_genai_types.evals.AgentInfo( + agent_info=agentplatform_genai_types.evals.AgentInfo( agents={"agent_1": {}}, root_agent_id="agent_1", ), config={"count": 2}, allow_cross_region_model=True, ) - assert isinstance(eval_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(eval_dataset, agentplatform_genai_types.EvaluationDataset) assert len(eval_dataset.eval_cases) == 2 assert eval_dataset.eval_cases[0].user_scenario.starting_prompt == "Prompt 1" assert eval_dataset.eval_cases[0].user_scenario.conversation_plan == "Plan 1" @@ -8487,14 +8523,14 @@ async def test_async_generate_conversation_scenarios(self): async_evals_module = evals.AsyncEvals(api_client_=self.mock_api_client) eval_dataset = await async_evals_module.generate_conversation_scenarios( - agent_info=vertexai_genai_types.evals.AgentInfo( + agent_info=agentplatform_genai_types.evals.AgentInfo( agents={"agent_1": {}}, root_agent_id="agent_1", ), config={"count": 2}, allow_cross_region_model=True, ) - assert isinstance(eval_dataset, vertexai_genai_types.EvaluationDataset) + assert isinstance(eval_dataset, agentplatform_genai_types.EvaluationDataset) assert len(eval_dataset.eval_cases) == 2 assert eval_dataset.eval_cases[0].user_scenario.starting_prompt == "Prompt 1" @@ -8551,24 +8587,24 @@ class TestConvertRequestToDatasetRow: """Unit tests for the _convert_request_to_dataset_row function.""" def test_convert_request_to_dataset_row_with_prompt_and_golden(self): - request = vertexai_genai_types.EvaluationItemRequest( - prompt=vertexai_genai_types.EvaluationPrompt(text="test prompt"), - golden_response=vertexai_genai_types.CandidateResponse( + request = agentplatform_genai_types.EvaluationItemRequest( + prompt=agentplatform_genai_types.EvaluationPrompt(text="test prompt"), + golden_response=agentplatform_genai_types.CandidateResponse( text="golden response" ), ) result = _evals_common._convert_request_to_dataset_row(request) assert result["prompt"] == "test prompt" - assert result["reference"] == vertexai_genai_types.CandidateResponse( + assert result["reference"] == agentplatform_genai_types.CandidateResponse( text="golden response" ) assert result["intermediate_events"] == [] assert result["agent_data"] is None def test_convert_request_to_dataset_row_with_user_scenario(self): - request = vertexai_genai_types.EvaluationItemRequest( - prompt=vertexai_genai_types.EvaluationPrompt( - user_scenario=vertexai_genai_types.evals.UserScenario( + request = agentplatform_genai_types.EvaluationItemRequest( + prompt=agentplatform_genai_types.EvaluationPrompt( + user_scenario=agentplatform_genai_types.evals.UserScenario( starting_prompt="start prompt", conversation_plan="convo plan" ) ) @@ -8579,9 +8615,9 @@ def test_convert_request_to_dataset_row_with_user_scenario(self): assert result["prompt"] is None def test_convert_request_to_dataset_row_with_candidate_events(self): - request = vertexai_genai_types.EvaluationItemRequest( + request = agentplatform_genai_types.EvaluationItemRequest( candidate_responses=[ - vertexai_genai_types.CandidateResponse( + agentplatform_genai_types.CandidateResponse( candidate="test-candidate", text="candidate text", events=[ @@ -8607,9 +8643,9 @@ def test_convert_request_to_dataset_row_with_candidate_events(self): def test_convert_request_to_dataset_row_with_agent_data(self): mock_agent_data = {"turns": []} - request = vertexai_genai_types.EvaluationItemRequest( + request = agentplatform_genai_types.EvaluationItemRequest( 
candidate_responses=[ - vertexai_genai_types.CandidateResponse( + agentplatform_genai_types.CandidateResponse( candidate="test-candidate", agent_data=mock_agent_data ) ] @@ -8914,24 +8950,24 @@ def test_resolve_dataset_preserves_conversation_history( role="model", parts=[genai_types.Part(text="Old model msg")] ) - dataset = vertexai_genai_types.EvaluationDataset( + dataset = agentplatform_genai_types.EvaluationDataset( eval_cases=[ - vertexai_genai_types.EvalCase( + agentplatform_genai_types.EvalCase( prompt=genai_types.Content( parts=[genai_types.Part(text="test prompt")] ), responses=[ - vertexai_genai_types.ResponseCandidate( + agentplatform_genai_types.ResponseCandidate( response=genai_types.Content( parts=[genai_types.Part(text="test response")] ) ) ], conversation_history=[ - vertexai_genai_types.evals.Message( + agentplatform_genai_types.evals.Message( turn_id="0", content=history_content_1 ), - vertexai_genai_types.evals.Message( + agentplatform_genai_types.evals.Message( turn_id="1", content=history_content_2 ), ], @@ -9085,7 +9121,7 @@ class TestComputationMetricRetry: @mock.patch("time.sleep", return_value=None) # fmt: off @mock.patch( - "vertexai._genai.evals.Evals.evaluate_instances" + "google.cloud.aiplatform.agentplatform._genai.evals.Evals.evaluate_instances" ) # fmt: on def test_computation_metric_retry_on_resource_exhausted( @@ -9104,10 +9140,10 @@ def test_computation_metric_retry_on_resource_exhausted( } ] ) - input_dataset = vertexai_genai_types.EvaluationDataset( + input_dataset = agentplatform_genai_types.EvaluationDataset( eval_dataset_df=dataset_df ) - metric = vertexai_genai_types.Metric(name="bleu") + metric = agentplatform_genai_types.Metric(name="bleu") error_response_json = { "error": { "code": 429, diff --git a/tests/unit/vertexai/genai/test_genai_client.py b/tests/unit/agentplatform/genai/test_genai_client.py similarity index 68% rename from tests/unit/vertexai/genai/test_genai_client.py rename to tests/unit/agentplatform/genai/test_genai_client.py index 16c7068b78..fbccec7422 100644 --- a/tests/unit/vertexai/genai/test_genai_client.py +++ b/tests/unit/agentplatform/genai/test_genai_client.py @@ -20,8 +20,8 @@ from unittest import mock from google.cloud import aiplatform -import vertexai -from vertexai._genai import client as vertexai_client +import agentplatform +from agentplatform._genai import client as agentplatform_client from google.cloud.aiplatform import initializer as aiplatform_initializer @@ -38,47 +38,53 @@ class TestGenAiClient: def setup_method(self): importlib.reload(aiplatform_initializer) importlib.reload(aiplatform) - importlib.reload(vertexai) - vertexai.init( + importlib.reload(agentplatform) + agentplatform.init( project=_TEST_PROJECT, location=_TEST_LOCATION, ) @pytest.mark.usefixtures("google_auth_mock") def test_genai_client(self): - test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) + test_client = agentplatform.Client( + project=_TEST_PROJECT, location=_TEST_LOCATION + ) assert test_client is not None - assert test_client._api_client.vertexai + assert test_client._api_client.agentplatform assert test_client._api_client.project == _TEST_PROJECT assert test_client._api_client.location == _TEST_LOCATION @pytest.mark.asyncio @pytest.mark.usefixtures("google_auth_mock") async def test_async_client(self): - test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION) - assert isinstance(test_client.aio, vertexai._genai.client.AsyncClient) + test_client = agentplatform.Client( + 
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
+        assert isinstance(test_client.aio, agentplatform._genai.client.AsyncClient)
 
     @pytest.mark.usefixtures("google_auth_mock")
     def test_live_client(self):
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_async_client = test_client.aio
-        assert isinstance(test_async_client.live, vertexai._genai.live.AsyncLive)
+        assert isinstance(test_async_client.live, agentplatform._genai.live.AsyncLive)
 
     @pytest.mark.usefixtures("google_auth_mock")
     def test_types(self):
-        assert vertexai.types is not None
-        assert vertexai.types.LLMMetric is not None
+        assert agentplatform.types is not None
+        assert agentplatform.types.LLMMetric is not None
 
     @pytest.mark.asyncio
     @pytest.mark.usefixtures("google_auth_mock")
     async def test_async_content_manager(self):
         with mock.patch.object(
-            vertexai_client.AsyncClient, "aclose", autospec=True
+            agentplatform_client.AsyncClient, "aclose", autospec=True
         ) as mock_aclose:
-            async with vertexai.Client(
+            async with agentplatform.Client(
                 project=_TEST_PROJECT, location=_TEST_LOCATION
             ).aio as async_client:
-                assert isinstance(async_client, vertexai_client.AsyncClient)
+                assert isinstance(async_client, agentplatform_client.AsyncClient)
         mock_aclose.assert_called_once()
@@ -86,9 +92,9 @@ async def test_async_content_manager(self):
     @pytest.mark.usefixtures("google_auth_mock")
     async def test_call_aclose_async_client(self):
         with mock.patch.object(
-            vertexai_client.AsyncClient, "aclose", autospec=True
+            agentplatform_client.AsyncClient, "aclose", autospec=True
         ) as mock_aclose:
-            async_client = vertexai.Client(
+            async_client = agentplatform.Client(
                 project=_TEST_PROJECT, location=_TEST_LOCATION
             ).aio
             await async_client.aclose()
diff --git a/tests/unit/vertexai/genai/test_genai_skills.py b/tests/unit/agentplatform/genai/test_genai_skills.py
similarity index 97%
rename from tests/unit/vertexai/genai/test_genai_skills.py
rename to tests/unit/agentplatform/genai/test_genai_skills.py
index cc6a54ecaf..a8901a114b 100644
--- a/tests/unit/vertexai/genai/test_genai_skills.py
+++ b/tests/unit/agentplatform/genai/test_genai_skills.py
@@ -1,4 +1,4 @@
-# //third_party/py/google/cloud/aiplatform/tests/unit/vertexai/genai/test_genai_skills.py
+# //third_party/py/google/cloud/aiplatform/tests/unit/agentplatform/genai/test_genai_skills.py
 import json
 from unittest import mock
 import google.auth.credentials
@@ -12,7 +12,7 @@ def skills_client():
     creds = mock.create_autospec(google.auth.credentials.Credentials, instance=True)
     creds.token = "test_token"
-    client = vertexai_client.Client(
+    client = agentplatform_client.Client(
         project="test-project", location="test-location", credentials=creds
     )
     return client.skills
diff --git a/tests/unit/vertexai/genai/test_live_agent_engines.py b/tests/unit/agentplatform/genai/test_live_agent_engines.py
similarity index 93%
rename from tests/unit/vertexai/genai/test_live_agent_engines.py
rename to tests/unit/agentplatform/genai/test_live_agent_engines.py
index 2d57df860e..8af10749a3 100644
--- a/tests/unit/vertexai/genai/test_live_agent_engines.py
+++ b/tests/unit/agentplatform/genai/test_live_agent_engines.py
@@ -19,9 +19,9 @@
 import google.auth
 import google.auth.credentials
 from google.cloud import aiplatform
-import vertexai
+import agentplatform
 from google.cloud.aiplatform import initializer as aiplatform_initializer
-from vertexai._genai import live_agent_engines
+from agentplatform._genai import live_agent_engines
 import pytest
@@ -36,8 +36,8 @@ class TestLiveAgentEngines:
     def setup_method(self):
         importlib.reload(aiplatform_initializer)
         importlib.reload(aiplatform)
-        importlib.reload(vertexai)
-        vertexai.init(
+        importlib.reload(agentplatform)
+        agentplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
         )
@@ -56,7 +56,9 @@ async def test_async_live_agent_engines_connect(
         mock_creds.valid = True
         mock_auth_default.return_value = (mock_creds, None)
 
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
 
         mock_ws = mock.AsyncMock()
         mock_ws_connect.return_value.__aenter__.return_value = mock_ws
diff --git a/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py b/tests/unit/agentplatform/genai/test_multimodal_datasets_genai.py
similarity index 98%
rename from tests/unit/vertexai/genai/test_multimodal_datasets_genai.py
rename to tests/unit/agentplatform/genai/test_multimodal_datasets_genai.py
index c120bcc95c..21a9daeda4 100644
--- a/tests/unit/vertexai/genai/test_multimodal_datasets_genai.py
+++ b/tests/unit/agentplatform/genai/test_multimodal_datasets_genai.py
@@ -15,8 +15,8 @@
 """Tests for multimodal datasets."""
 
 from unittest import mock
-from vertexai._genai import _datasets_utils
-from vertexai._genai import types
+from agentplatform._genai import _datasets_utils
+from agentplatform._genai import types
 from google.genai import types as genai_types
 import pytest
diff --git a/tests/unit/vertexai/genai/test_prompt_optimizer.py b/tests/unit/agentplatform/genai/test_prompt_optimizer.py
similarity index 85%
rename from tests/unit/vertexai/genai/test_prompt_optimizer.py
rename to tests/unit/agentplatform/genai/test_prompt_optimizer.py
index 2f34c7447c..58b360f03a 100644
--- a/tests/unit/vertexai/genai/test_prompt_optimizer.py
+++ b/tests/unit/agentplatform/genai/test_prompt_optimizer.py
@@ -17,10 +17,10 @@
 import importlib
 from unittest import mock
 
-import vertexai
-from vertexai._genai import prompt_optimizer
-from vertexai._genai import prompts
-from vertexai._genai import types
+import agentplatform
+from agentplatform._genai import prompt_optimizer
+from agentplatform._genai import prompts
+from agentplatform._genai import types
 from google.genai import client
 import pandas as pd
 import pytest
@@ -39,15 +39,17 @@ class TestPromptOptimizer:
     """Unit tests for the Prompt Optimizer client."""
 
     def setup_method(self):
-        importlib.reload(vertexai)
-        vertexai.init(
+        importlib.reload(agentplatform)
+        agentplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
         )
 
     @pytest.mark.usefixtures("google_auth_mock")
     def test_prompt_optimizer_client(self):
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         assert test_client.prompt_optimizer is not None
 
     @mock.patch.object(client.Client, "_get_api_client")
@@ -57,7 +59,9 @@ def test_prompt_optimizer_optimize(
         self, mock_custom_job_prompts, mock_custom_job_prompt_optimizer, mock_client
     ):
         """Test that prompt_optimizer.optimize method creates a custom job."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompt_optimizer.optimize(
             method=types.PromptOptimizerMethod.VAPO,
             config=types.PromptOptimizerConfig(
@@ -76,7 +80,9 @@ def test_prompt_optimizer_optimize_nano(
         self, mock_custom_job_prompts, mock_custom_job_prompt_optimizer, mock_client
     ):
         """Test that prompt_optimizer.optimize method creates a custom job."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompt_optimizer.optimize(
             method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
             config=types.PromptOptimizerConfig(
@@ -94,7 +100,9 @@ def test_prompt_optimizer_optimize_prompt(
         self, mock_custom_optimize_prompt, mock_client
     ):
         """Test that prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompt_optimizer.optimize_prompt(prompt="test_prompt")
         mock_client.assert_called_once()
         mock_custom_optimize_prompt.assert_called_once()
@@ -104,22 +112,26 @@ class TestPrompts:
     """Unit tests for the Prompts client."""
 
     def setup_method(self):
-        importlib.reload(vertexai)
-        vertexai.init(
+        importlib.reload(agentplatform)
+        agentplatform.init(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
         )
 
     @pytest.mark.usefixtures("google_auth_mock")
     def test_prompt_optimizer_client(self):
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         assert test_client.prompts is not None
 
     @mock.patch.object(client.Client, "_get_api_client")
     @mock.patch.object(prompts.Prompts, "_create_custom_job_resource")
     def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
         """Test that prompt_optimizer.optimize method creates a custom job."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompts.launch_optimization_job(
             method=types.PromptOptimizerMethod.VAPO,
             config=types.PromptOptimizerConfig(
@@ -135,7 +147,9 @@ def test_prompt_optimizer_optimize(self, mock_custom_job, mock_client):
     @mock.patch.object(prompts.Prompts, "_create_custom_job_resource")
     def test_prompt_optimizer_optimize_nano(self, mock_custom_job, mock_client):
         """Test that prompt_optimizer.optimize method creates a custom job."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompts.launch_optimization_job(
             method=types.PromptOptimizerMethod.OPTIMIZATION_TARGET_GEMINI_NANO,
             config=types.PromptOptimizerConfig(
@@ -153,7 +167,9 @@ def test_prompt_optimizer_optimize_prompt(
         self, mock_custom_optimize_prompt, mock_client
     ):
         """Test that prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_client.prompts.optimize(prompt="test_prompt")
         mock_client.assert_called_once()
         mock_custom_optimize_prompt.assert_called_once()
@@ -168,7 +184,9 @@ def test_prompt_optimizer_optimize_few_shot(self, mock_custom_optimize_prompt):
                 "target_response": ["target1", "target2"],
             }
         )
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         test_config = types.OptimizeConfig(
             optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_FEW_SHOT_TARGET_RESPONSE,
             examples_dataframe=df,
@@ -191,7 +209,9 @@ def test_prompt_optimizer_optimize_prompt_with_optimization_target(
         self, mock_custom_optimize_prompt
     ):
         """Test that prompt_optimizer.optimize_prompt method calls _custom_optimize_prompt with optimization_target."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         config = types.OptimizeConfig(
             optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
         )
@@ -210,7 +230,9 @@ async def test_async_prompt_optimizer_optimize_prompt(
         self, mock_custom_optimize_prompt
     ):
         """Test that async prompt_optimizer.optimize_prompt method calls optimize_prompt API."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         await test_client.aio.prompts.optimize(prompt="test_prompt")
         mock_custom_optimize_prompt.assert_called_once()
@@ -220,7 +242,9 @@ async def test_async_prompt_optimizer_optimize_prompt_with_optimization_target(
         self, mock_custom_optimize_prompt
     ):
         """Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with optimization_target."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         config = types.OptimizeConfig(
             optimization_target=types.OptimizeTarget.OPTIMIZATION_TARGET_GEMINI_NANO,
         )
@@ -239,7 +263,9 @@ async def test_async_prompt_optimizer_optimize_prompt_few_shot_target_response(
         self, mock_custom_optimize_prompt
     ):
         """Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot target response."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         df = pd.DataFrame(
             {
                 "prompt": ["prompt1", "prompt2"],
@@ -266,7 +292,9 @@ async def test_async_prompt_optimizer_optimize_prompt_few_shot_rubrics(
         self, mock_custom_optimize_prompt
     ):
         """Test that async prompt_optimizer.optimize_prompt calls optimize_prompt with few shot rubrics."""
-        test_client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        test_client = agentplatform.Client(
+            project=_TEST_PROJECT, location=_TEST_LOCATION
+        )
         df = pd.DataFrame(
             {
                 "prompt": ["prompt1", "prompt2"],
diff --git a/tests/unit/vertexai/genai/test_sandbox.py b/tests/unit/agentplatform/genai/test_sandbox.py
similarity index 97%
rename from tests/unit/vertexai/genai/test_sandbox.py
rename to tests/unit/agentplatform/genai/test_sandbox.py
index 120a794182..8f3f331d2c 100644
--- a/tests/unit/vertexai/genai/test_sandbox.py
+++ b/tests/unit/agentplatform/genai/test_sandbox.py
@@ -20,9 +20,9 @@
 from google import auth
 from google.auth import credentials as auth_credentials
 from google.cloud import aiplatform
-import vertexai
+import agentplatform
 from google.cloud.aiplatform import initializer
-from vertexai._genai import sandboxes
+from agentplatform._genai import sandboxes
 from google.genai import client
 from google.genai import types as genai_types
 import pytest
@@ -60,9 +60,9 @@ class TestSandbox:
     def setup_method(self):
         importlib.reload(initializer)
         importlib.reload(aiplatform)
-        importlib.reload(vertexai)
+        importlib.reload(agentplatform)
         os.environ[_TEST_AGENT_ENGINE_ENV_KEY] = _TEST_AGENT_ENGINE_ENV_VALUE
-        self.client = vertexai.Client(
+        self.client = agentplatform.Client(
             project=_TEST_PROJECT,
             location=_TEST_LOCATION,
             credentials=_TEST_CREDENTIALS,
diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
index f7eb3e1ade..9d9bb37cc6 100644
--- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
+++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py
@@ -25,16 +25,16 @@
 from google import auth
 from google.auth import credentials as auth_credentials
 from google.cloud import storage
-import vertexai
+import agentplatform
 from google.cloud import aiplatform
 from google.cloud.aiplatform_v1 import types as aip_types
 from google.cloud.aiplatform_v1.services import reasoning_engine_service
 from google.cloud.aiplatform import base
 from google.cloud.aiplatform import initializer
-from vertexai.agent_engines import _utils
-from vertexai import agent_engines
-from vertexai.agent_engines.templates import adk as adk_template
-from vertexai.agent_engines import _agent_engines
+from agentplatform.agent_engines import _utils
+from agentplatform import agent_engines
+from agentplatform.frameworks import adk as adk_template
+from agentplatform.agent_engines import _agent_engines
 from google.api_core import operation as ga_operation
 from google.genai import types
 import pytest
@@ -150,9 +150,9 @@ def google_auth_mock():
 
 
 @pytest.fixture
-def vertexai_init_mock():
-    with mock.patch.object(vertexai, "init") as vertexai_init_mock:
-        yield vertexai_init_mock
+def agentplatform_init_mock():
+    with mock.patch.object(agentplatform, "init") as agentplatform_init_mock:
+        yield agentplatform_init_mock
 
 
 @pytest.fixture
@@ -204,7 +204,7 @@ def logger_provider_force_flush_mock():
 
 @pytest.fixture
 def default_instrumentor_builder_mock():
     with mock.patch(
-        "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._default_instrumentor_builder"
+        "google.cloud.aiplatform.agentplatform.frameworks.adk._default_instrumentor_builder"
     ) as default_instrumentor_builder_mock:
         yield default_instrumentor_builder_mock
@@ -220,7 +220,7 @@ def simple_span_processor_mock():
 
 @pytest.fixture
 def adk_version_mock():
     with mock.patch(
-        "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version"
+        "google.cloud.aiplatform.agentplatform.frameworks.adk.get_adk_version"
     ) as adk_version_mock:
         yield adk_version_mock
@@ -228,7 +228,7 @@ def adk_version_mock():
 
 @pytest.fixture
 def is_version_sufficient_mock():
     with mock.patch(
-        "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.is_version_sufficient"
+        "google.cloud.aiplatform.agentplatform.frameworks.adk.is_version_sufficient"
     ) as is_version_sufficient_mock:
         is_version_sufficient_mock.return_value = True
@@ -245,7 +245,7 @@ def get_project_id_mock():
 
 @pytest.fixture
 def warn_if_telemetry_api_disabled_mock():
     with mock.patch(
-        "google.cloud.aiplatform.vertexai.agent_engines.templates.adk._warn_if_telemetry_api_disabled"
+        "google.cloud.aiplatform.agentplatform.frameworks.adk._warn_if_telemetry_api_disabled"
     ) as warn_if_telemetry_api_disabled_mock:
         yield warn_if_telemetry_api_disabled_mock
@@ -312,7 +312,7 @@ async def run_async(self, *args, **kwargs):
 class TestAdkApp:
     def test_adk_version(self):
         with mock.patch(
-            "google.cloud.aiplatform.vertexai.agent_engines.templates.adk.get_adk_version",
+            "google.cloud.aiplatform.agentplatform.frameworks.adk.get_adk_version",
             return_value="0.5.0",
         ):
             with pytest.raises(
@@ -326,8 +326,8 @@ def test_adk_version(self):
 
     def setup_method(self):
         importlib.reload(initializer)
-        importlib.reload(vertexai)
-        vertexai.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
+        importlib.reload(agentplatform)
+        agentplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)
 
     def teardown_method(self):
         initializer.global_pool.shutdown(wait=True)
@@ -937,9 +937,9 @@ def test_dump_event_for_json():
 
 # def test_adk_app_initialization_with_api_key():
 #     importlib.reload(initializer)
-#     importlib.reload(vertexai)
+#     importlib.reload(agentplatform)
 #     try:
-#         vertexai.init(api_key=_TEST_API_KEY)
+#         agentplatform.init(api_key=_TEST_API_KEY)
 #         app = agent_engines.AdkApp(agent=_TEST_AGENT)
 #         assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY
 #         assert app._tmpl_attrs.get("runner") is None