diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index a0168fbe21..9eda7a078c 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -15,6 +15,7 @@ from __future__ import annotations +from enum import Enum import inspect import logging import types as typing_types @@ -75,7 +76,7 @@ def _raise_if_schema_unsupported( ): if variant == GoogleLLMVariant.GEMINI_API: _raise_for_any_of_if_mldev(schema) - _update_for_default_if_mldev(schema) + # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value def _is_default_value_compatible( @@ -145,6 +146,20 @@ def _parse_schema_from_parameter( schema.type = _py_builtin_type_to_schema_type[param.annotation] _raise_if_schema_unsupported(variant, schema) return schema + if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): + schema.type = types.Type.STRING + schema.enum = [e.value for e in param.annotation] + if param.default is not inspect.Parameter.empty: + default_value = ( + param.default.value + if isinstance(param.default, Enum) + else param.default + ) + if default_value not in schema.enum: + raise ValueError(default_value_error_msg) + schema.default = default_value + _raise_if_schema_unsupported(variant, schema) + return schema if ( get_origin(param.annotation) is Union # only parse simple UnionType, example int | str | float | bool diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index edf3c7128e..8be1f86520 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum from typing import Dict from typing import List @@ -22,6 +23,7 @@ # TODO: crewai requires python 3.10 as minimum # from crewai_tools import FileReadTool from pydantic import BaseModel +import pytest def test_string_input(): @@ -220,6 +222,34 @@ def simple_function( assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT' +def test_enums(): + + class InputEnum(Enum): + AGENT = 'agent' + TOOL = 'tool' + + def simple_function(input: InputEnum = InputEnum.AGENT): + return input.value + + function_decl = _automatic_function_calling_util.build_function_declaration( + func=simple_function + ) + + assert function_decl.name == 'simple_function' + assert function_decl.parameters.type == 'OBJECT' + assert function_decl.parameters.properties['input'].type == 'STRING' + assert function_decl.parameters.properties['input'].default == 'agent' + assert function_decl.parameters.properties['input'].enum == ['agent', 'tool'] + + def simple_function_with_wrong_enum(input: InputEnum = 'WRONG_ENUM'): + return input.value + + with pytest.raises(ValueError): + _automatic_function_calling_util.build_function_declaration( + func=simple_function_with_wrong_enum + ) + + def test_basemodel_list(): class ChildInput(BaseModel): input_str: str