diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index f3751206a8..2710c3894c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index d857da9635..e31db15788 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index 9133092c3f..4a2b2129c0 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -35,6 +35,7 @@ from pydantic import Field from pydantic import field_validator from pydantic import model_validator +from pydantic import TypeAdapter from typing_extensions import override from typing_extensions import TypeAlias @@ -48,6 +49,7 @@ from ..models.llm_response import LlmResponse from ..models.registry import LLMRegistry from ..planners.base_planner import BasePlanner +from ..tools._gemini_schema_util import validate_and_dump_schema from ..tools.base_tool import BaseTool from ..tools.base_toolset import BaseToolset from ..tools.function_tool import FunctionTool @@ -314,9 +316,9 @@ class LlmAgent(BaseAgent): """ # Controlled input/output configurations - Start - input_schema: Optional[type[BaseModel]] = None + input_schema: Optional[Any] = None """The input schema when agent is used as a tool.""" - output_schema: Optional[type[BaseModel]] = None + output_schema: Optional[Any] = None """The output schema when agent replies. NOTE: @@ -833,9 +835,7 @@ def __maybe_save_output_to_state(self, event: Event): # Do not attempt to parse it as JSON. if not result.strip(): return - result = self.output_schema.model_validate_json(result).model_dump( - exclude_none=True - ) + result = validate_and_dump_schema(self.output_schema, result) event.actions.state_delta[self.output_key] = result @model_validator(mode='after') diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index 1b9559b29c..a3c1fccaf6 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -140,31 +140,6 @@ def _is_builtin_primitive_or_compound( return annotation in _py_builtin_type_to_schema_type.keys() -def _raise_for_any_of_if_mldev(schema: types.Schema): - if schema.any_of: - raise ValueError( - 'AnyOf is not supported in function declaration schema for Google AI.' - ) - - -def _update_for_default_if_mldev(schema: types.Schema): - if schema.default is not None: - # TODO(kech): Remove this workaround once mldev supports default value. - schema.default = None - logger.warning( - 'Default value is not supported in function declaration schema for' - ' Google AI.' - ) - - -def _raise_if_schema_unsupported( - variant: GoogleLLMVariant, schema: types.Schema -): - if variant == GoogleLLMVariant.GEMINI_API: - _raise_for_any_of_if_mldev(schema) - # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value - - def _is_default_value_compatible( default_value: Any, annotation: inspect.Parameter.annotation ) -> bool: @@ -230,7 +205,6 @@ def _parse_schema_from_parameter( raise ValueError(default_value_error_msg) schema.default = param.default schema.type = _py_builtin_type_to_schema_type[param.annotation] - _raise_if_schema_unsupported(variant, schema) return schema if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): schema.type = types.Type.STRING @@ -244,7 +218,6 @@ def _parse_schema_from_parameter( if default_value not in schema.enum: raise ValueError(default_value_error_msg) schema.default = default_value - _raise_if_schema_unsupported(variant, schema) return schema if ( get_origin(param.annotation) is Union @@ -285,7 +258,6 @@ def _parse_schema_from_parameter( if not _is_default_value_compatible(param.default, param.annotation): raise ValueError(default_value_error_msg) schema.default = param.default - _raise_if_schema_unsupported(variant, schema) return schema if isinstance(param.annotation, _GenericAlias) or isinstance( param.annotation, typing_types.GenericAlias @@ -298,7 +270,6 @@ def _parse_schema_from_parameter( if not _is_default_value_compatible(param.default, param.annotation): raise ValueError(default_value_error_msg) schema.default = param.default - _raise_if_schema_unsupported(variant, schema) return schema if origin is Literal: if not all(isinstance(arg, str) for arg in args): @@ -311,7 +282,6 @@ def _parse_schema_from_parameter( if not _is_default_value_compatible(param.default, param.annotation): raise ValueError(default_value_error_msg) schema.default = param.default - _raise_if_schema_unsupported(variant, schema) return schema if origin is list: schema.type = types.Type.ARRAY @@ -328,7 +298,6 @@ def _parse_schema_from_parameter( if not _is_default_value_compatible(param.default, param.annotation): raise ValueError(default_value_error_msg) schema.default = param.default - _raise_if_schema_unsupported(variant, schema) return schema if origin is Union: schema.any_of = [] @@ -374,7 +343,6 @@ def _parse_schema_from_parameter( if not _is_default_value_compatible(param.default, param.annotation): raise ValueError(default_value_error_msg) schema.default = param.default - _raise_if_schema_unsupported(variant, schema) return schema # all other generic alias will be invoked in raise branch if ( @@ -399,7 +367,6 @@ def _parse_schema_from_parameter( ), func_name, ) - _raise_if_schema_unsupported(variant, schema) return schema if inspect.isclass(param.annotation) and issubclass( param.annotation, ToolContext @@ -413,7 +380,6 @@ def _parse_schema_from_parameter( # null is not a valid type in schema, use object instead. schema.type = types.Type.OBJECT schema.nullable = True - _raise_if_schema_unsupported(variant, schema) return schema raise ValueError( f'Failed to parse the parameter {param} of function {func_name} for' diff --git a/src/google/adk/tools/_gemini_schema_util.py b/src/google/adk/tools/_gemini_schema_util.py index 07df5379d8..3e0bfceaad 100644 --- a/src/google/adk/tools/_gemini_schema_util.py +++ b/src/google/adk/tools/_gemini_schema_util.py @@ -20,7 +20,9 @@ from google.genai.types import JSONSchema from google.genai.types import Schema +from pydantic import BaseModel from pydantic import Field +from pydantic import TypeAdapter from ..utils.variant_utils import get_google_llm_variant @@ -208,3 +210,11 @@ def _to_gemini_schema(openapi_schema: dict[str, Any]) -> Schema: json_schema=_ExtendedJSONSchema.model_validate(sanitized_schema), api_option=get_google_llm_variant(), ) + + +def validate_and_dump_schema(schema: Any, json_data: str) -> Any: + """Validates json data against a schema and returns a serializable object.""" + validated_result = TypeAdapter(schema).validate_json(json_data) + if isinstance(validated_result, BaseModel): + return validated_result.model_dump(exclude_none=True) + return validated_result diff --git a/src/google/adk/tools/agent_tool.py b/src/google/adk/tools/agent_tool.py index 91135dce5f..399e9d7cac 100644 --- a/src/google/adk/tools/agent_tool.py +++ b/src/google/adk/tools/agent_tool.py @@ -21,6 +21,7 @@ from google.genai import types from pydantic import BaseModel from pydantic import model_validator +from pydantic import TypeAdapter from typing_extensions import override from . import _automatic_function_calling_util @@ -30,6 +31,7 @@ from ..memory.in_memory_memory_service import InMemoryMemoryService from ..utils.context_utils import Aclosing from ._forwarding_artifact_service import ForwardingArtifactService +from ._gemini_schema_util import validate_and_dump_schema from .base_tool import BaseTool from .tool_configs import BaseToolConfig from .tool_configs import ToolArgsConfig @@ -202,11 +204,7 @@ async def run_async( input_value = input_schema.model_validate(args) content = types.Content( role='user', - parts=[ - types.Part.from_text( - text=input_value.model_dump_json(exclude_none=True) - ) - ], + parts=[types.Part.from_text(text=text)], ) else: content = types.Content( diff --git a/tests/unittests/agents/test_llm_agent_fields.py b/tests/unittests/agents/test_llm_agent_fields.py index 8a3623cb70..076ab0bdd6 100644 --- a/tests/unittests/agents/test_llm_agent_fields.py +++ b/tests/unittests/agents/test_llm_agent_fields.py @@ -16,13 +16,17 @@ import logging from typing import Any +from typing import cast +from typing import Literal from typing import Optional +from typing import Union from unittest import mock from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext from google.adk.agents.llm_agent import LlmAgent from google.adk.agents.readonly_context import ReadonlyContext +from google.adk.events.event import Event from google.adk.models.anthropic_llm import Claude from google.adk.models.google_llm import Gemini from google.adk.models.lite_llm import LiteLlm @@ -34,9 +38,12 @@ from google.adk.tools.google_search_tool import GoogleSearchTool from google.adk.tools.vertex_ai_search_tool import VertexAiSearchTool from google.genai import types +from google.genai.types import Part from pydantic import BaseModel import pytest +from .. import testing_utils + async def _create_readonly_context( agent: LlmAgent, state: Optional[dict[str, Any]] = None @@ -519,3 +526,51 @@ def test_builtin_planner_overwrite_logging(caplog): 'Overwriting `thinking_config` from `generate_content_config`' in caplog.text ) + + +def test_output_schema_with_union(): + """Tests if agent can have a Union type in output_schema.""" + + class CustomOutput1(BaseModel): + custom_output1: str + + class CustomOutput2(BaseModel): + custom_output2: str + + agent = LlmAgent( + name='test_agent', + output_schema=Union[CustomOutput1, CustomOutput2, Literal['option3']], + output_key='test_output', + ) + + # Test with the first type + event1 = Event( + author='test_agent', + content=types.Content( + parts=[Part(text='{"custom_output1": "response1"}')] + ), + ) + agent._LlmAgent__maybe_save_output_to_state(event1) + assert event1.actions.state_delta['test_output'] == { + 'custom_output1': 'response1' + } + + # Test with the second type + event2 = Event( + author='test_agent', + content=types.Content( + parts=[Part(text='{"custom_output2": "response2"}')] + ), + ) + agent._LlmAgent__maybe_save_output_to_state(event2) + assert event2.actions.state_delta['test_output'] == { + 'custom_output2': 'response2' + } + + # Test with the literal type + event3 = Event( + author='test_agent', + content=types.Content(parts=[Part(text='"option3"')]), + ) + agent._LlmAgent__maybe_save_output_to_state(event3) + assert event3.actions.state_delta['test_output'] == 'option3' diff --git a/tests/unittests/tools/test_agent_tool.py b/tests/unittests/tools/test_agent_tool.py index b5f59be0fc..ea9af8eb00 100644 --- a/tests/unittests/tools/test_agent_tool.py +++ b/tests/unittests/tools/test_agent_tool.py @@ -13,7 +13,9 @@ # limitations under the License. from typing import Any +from typing import Literal from typing import Optional +from typing import Union from google.adk.agents.callback_context import CallbackContext from google.adk.agents.invocation_context import InvocationContext @@ -31,7 +33,6 @@ from google.adk.sessions.in_memory_session_service import InMemorySessionService from google.adk.tools.agent_tool import AgentTool from google.adk.tools.tool_context import ToolContext -from google.adk.utils.variant_utils import GoogleLLMVariant from google.genai import types from google.genai.types import Part from pydantic import BaseModel @@ -421,6 +422,165 @@ class CustomOutput(BaseModel): assert mock_model.requests[1].config.response_mime_type == 'application/json' +@mark.parametrize( + 'env_variables', + ['GOOGLE_AI', 'VERTEX'], + indirect=True, +) +def test_custom_schema_with_union(env_variables): + """Tests if agent can have a Union type in output_schema.""" + + class CustomInput(BaseModel): + custom_input: str + + class CustomOutput(BaseModel): + custom_output: str + + class CustomOutput2(BaseModel): + custom_output2: str + + # --- Test Case 1: Model returns the first type in the Union --- + mock_model_1 = testing_utils.MockModel.create( + responses=[ + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_1'} + ), + '{"custom_output": "response_union_1"}', # Matches CustomOutput + 'Final response after union output 1.', + ] + ) + + tool_agent = Agent( + name='tool_agent', + model=mock_model_1, + input_schema=CustomInput, + output_schema=Union[CustomOutput, CustomOutput2, Literal['option3']], + output_key='tool_output_union', + ) + + root_agent = Agent( + name='root_agent', + model=mock_model_1, + tools=[AgentTool(agent=tool_agent)], + ) + + runner = testing_utils.InMemoryRunner(root_agent) + events = runner.run('test_union_input_1') + + simplified_events = testing_utils.simplify_events(events) + assert simplified_events == [ + ( + 'root_agent', + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_1'} + ), + ), + ( + 'root_agent', + Part.from_function_response( + name='tool_agent', + response={'custom_output': 'response_union_1'}, + ), + ), + ('root_agent', 'Final response after union output 1.'), + ] + assert runner.session.state['tool_output_union'] == { + 'custom_output': 'response_union_1' + } + assert len(mock_model_1.requests) == 3 + assert ( + mock_model_1.requests[1].config.response_schema + == Union[CustomOutput, CustomOutput2, Literal['option3']] + ) + + # --- Test Case 2: Model returns the second type in the Union --- + mock_model_2 = testing_utils.MockModel.create( + responses=[ + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_2'} + ), + '{"custom_output2": "response_union_2"}', # Matches CustomOutput2 + 'Final response after union output 2.', + ] + ) + + # Re-configure the agent with the new mock model + tool_agent.model = mock_model_2 + root_agent.model = mock_model_2 + + runner_2 = testing_utils.InMemoryRunner(root_agent) + events_2 = runner_2.run('test_union_input_2') + + simplified_events_2 = testing_utils.simplify_events(events_2) + assert simplified_events_2 == [ + ( + 'root_agent', + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_2'} + ), + ), + ( + 'root_agent', + Part.from_function_response( + name='tool_agent', + response={'custom_output2': 'response_union_2'}, + ), + ), + ('root_agent', 'Final response after union output 2.'), + ] + assert runner_2.session.state['tool_output_union'] == { + 'custom_output2': 'response_union_2' + } + assert len(mock_model_2.requests) == 3 + assert ( + mock_model_2.requests[1].config.response_schema + == Union[CustomOutput, CustomOutput2, Literal['option3']] + ) + + # --- Test Case 3: Model returns the literal type in the Union --- + mock_model_3 = testing_utils.MockModel.create( + responses=[ + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_3'} + ), + '"option3"', # Matches Literal['option3'] + 'Final response after literal output.', + ] + ) + + # Re-configure the agent with the new mock model + tool_agent.model = mock_model_3 + root_agent.model = mock_model_3 + + runner_3 = testing_utils.InMemoryRunner(root_agent) + events_3 = runner_3.run('test_union_input_3') + + simplified_events_3 = testing_utils.simplify_events(events_3) + assert simplified_events_3 == [ + ( + 'root_agent', + Part.from_function_call( + name='tool_agent', args={'custom_input': 'test_union_3'} + ), + ), + ( + 'root_agent', + Part.from_function_response( + name='tool_agent', + response={'result': 'option3'}, + ), + ), + ('root_agent', 'Final response after literal output.'), + ] + # When the result is not a BaseModel, it's stored directly. + assert runner_3.session.state['tool_output_union'] == 'option3' + assert len(mock_model_3.requests) == 3 + assert ( + mock_model_3.requests[1].config.response_schema + == Union[CustomOutput, CustomOutput2, Literal['option3']] + ) + + @mark.parametrize( 'env_variables', [