diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index e522d792..a47ff86c 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -81,6 +81,12 @@ class ParameterSchema(BaseModel): authSources: Optional[list[str]] = None items: Optional["ParameterSchema"] = None additionalProperties: Optional[Union[bool, AdditionalPropertiesSchema]] = None + default: Optional[Any] = None + + @property + def has_default(self) -> bool: + """Returns True if `default` was explicitly provided in schema input.""" + return "default" in self.model_fields_set def __get_type(self) -> Type: base_type: Type @@ -103,11 +109,17 @@ def __get_type(self) -> Type: return base_type def to_param(self) -> Parameter: + default_value = Parameter.empty + if self.has_default: + default_value = self.default + elif not self.required: + default_value = None + return Parameter( self.name, Parameter.POSITIONAL_OR_KEYWORD, annotation=self.__get_type(), - default=Parameter.empty if self.required else None, + default=default_value, ) diff --git a/packages/toolbox-core/src/toolbox_core/tool.py b/packages/toolbox-core/src/toolbox_core/tool.py index 0c72c39b..128e57cd 100644 --- a/packages/toolbox-core/src/toolbox_core/tool.py +++ b/packages/toolbox-core/src/toolbox_core/tool.py @@ -13,9 +13,8 @@ # limitations under the License. import copy -import itertools from collections import OrderedDict -from inspect import Signature +from inspect import Parameter, Signature from types import MappingProxyType from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union from warnings import warn @@ -86,13 +85,17 @@ def __init__( self.__params = params self.__pydantic_model = params_to_pydantic_model(name, self.__params) - # Separate parameters into required (no default) and optional (with - # default) to prevent the "non-default argument follows default + # Separate parameters into those without a default and those with a + # default to prevent the "non-default argument follows default # argument" error when creating the function signature. - required_params = (p for p in self.__params if p.required) - optional_params = (p for p in self.__params if not p.required) - ordered_params = itertools.chain(required_params, optional_params) - inspect_type_params = [param.to_param() for param in ordered_params] + inspect_type_params = [param.to_param() for param in self.__params] + params_no_default = [ + p for p in inspect_type_params if p.default is Parameter.empty + ] + params_with_default = [ + p for p in inspect_type_params if p.default is not Parameter.empty + ] + inspect_type_params = params_no_default + params_with_default # the following properties are set to help anyone that might inspect it determine usage self.__name__ = name diff --git a/packages/toolbox-core/src/toolbox_core/utils.py b/packages/toolbox-core/src/toolbox_core/utils.py index 08a87a45..d68f585a 100644 --- a/packages/toolbox-core/src/toolbox_core/utils.py +++ b/packages/toolbox-core/src/toolbox_core/utils.py @@ -111,10 +111,12 @@ def params_to_pydantic_model( field_definitions = {} for field in params: - # Determine the default value based on the 'required' flag. + # Determine the default value based on the 'required' flag and the 'default' field. # '...' (Ellipsis) signifies a required field in Pydantic. - # 'None' makes the field optional with a default value of None. + # If a default value is provided in the schema, it should be used. default_value = ... if field.required else None + if field.has_default: + default_value = field.default field_definitions[field.name] = cast( Any, diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index 8dd60e3f..fe6320cd 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -289,3 +289,55 @@ def test_parameter_schema_map_unsupported_value_type_error(): expected_error_msg = f"Unsupported schema type: {unsupported_type}" with pytest.raises(ValueError, match=expected_error_msg): schema._ParameterSchema__get_type() + + +def test_parameter_schema_with_default(): + """Tests ParameterSchema with a default value provided.""" + schema = ParameterSchema( + name="limit", + type="integer", + description="Limit results", + required=False, + default=10, + ) + expected_type = Optional[int] + + assert schema._ParameterSchema__get_type() == expected_type + + param = schema.to_param() + assert isinstance(param, Parameter) + assert param.name == "limit" + assert param.annotation == expected_type + assert param.default == 10 + + +def test_parameter_schema_required_with_default(): + """Tests ParameterSchema with default value, ignoring required=True implies it is optional in python signature.""" + # Although illogical in some schemas, if default is present, it should be used as default. + schema = ParameterSchema( + name="retry_count", + type="integer", + description="Retries", + required=True, + default=3, + ) + + # get_type still respects required=True for type hint + assert schema._ParameterSchema__get_type() == int + + param = schema.to_param() + assert param.default == 3 + + +def test_parameter_schema_required_with_explicit_none_default(): + """Tests explicit default=None is treated as a provided default.""" + schema = ParameterSchema( + name="opt_in", + type="boolean", + description="Optional flag", + required=True, + default=None, + ) + + param = schema.to_param() + assert param.default is None diff --git a/packages/toolbox-core/tests/test_utils.py b/packages/toolbox-core/tests/test_utils.py index b3ddd7c3..5c9dd29d 100644 --- a/packages/toolbox-core/tests/test_utils.py +++ b/packages/toolbox-core/tests/test_utils.py @@ -35,6 +35,8 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock: param_mock.name = name param_mock.description = description param_mock.required = True + param_mock.default = None + param_mock.has_default = False mock_param_info = Mock() mock_param_info.annotation = annotation @@ -422,6 +424,24 @@ def test_params_to_pydantic_model_with_params(): Model(name="Bob", age="thirty", is_active=True) +def test_params_to_pydantic_model_uses_explicit_default_none(): + """Test that explicit default=None is honored for required schema fields.""" + tool_name = "MyToolWithExplicitNoneDefault" + params = [ + ParameterSchema( + name="message", + type="string", + description="Message value", + required=True, + default=None, + ) + ] + Model = params_to_pydantic_model(tool_name, params) + + assert "message" in Model.model_fields + assert Model.model_fields["message"].default is None + + @pytest.mark.asyncio async def test_resolve_value_plain_value(): """Test resolving a plain, non-callable value."""