Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion packages/toolbox-core/src/toolbox_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ class ParameterSchema(BaseModel):
authSources: Optional[list[str]] = None
items: Optional["ParameterSchema"] = None
additionalProperties: Optional[Union[bool, AdditionalPropertiesSchema]] = None
default: Optional[Any] = None

@property
def has_default(self) -> bool:
"""Returns True if `default` was explicitly provided in schema input."""
return "default" in self.model_fields_set

def __get_type(self) -> Type:
base_type: Type
Expand All @@ -103,11 +109,17 @@ def __get_type(self) -> Type:
return base_type

def to_param(self) -> Parameter:
default_value = Parameter.empty
if self.has_default:
default_value = self.default
elif not self.required:
default_value = None

return Parameter(
self.name,
Parameter.POSITIONAL_OR_KEYWORD,
annotation=self.__get_type(),
default=Parameter.empty if self.required else None,
default=default_value,
)


Expand Down
19 changes: 11 additions & 8 deletions packages/toolbox-core/src/toolbox_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
# limitations under the License.

import copy
import itertools
from collections import OrderedDict
from inspect import Signature
from inspect import Parameter, Signature
from types import MappingProxyType
from typing import Any, Awaitable, Callable, Mapping, Optional, Sequence, Union
from warnings import warn
Expand Down Expand Up @@ -86,13 +85,17 @@ def __init__(
self.__params = params
self.__pydantic_model = params_to_pydantic_model(name, self.__params)

# Separate parameters into required (no default) and optional (with
# default) to prevent the "non-default argument follows default
# Separate parameters into those without a default and those with a
# default to prevent the "non-default argument follows default
# argument" error when creating the function signature.
required_params = (p for p in self.__params if p.required)
optional_params = (p for p in self.__params if not p.required)
ordered_params = itertools.chain(required_params, optional_params)
inspect_type_params = [param.to_param() for param in ordered_params]
inspect_type_params = [param.to_param() for param in self.__params]
params_no_default = [
p for p in inspect_type_params if p.default is Parameter.empty
]
params_with_default = [
p for p in inspect_type_params if p.default is not Parameter.empty
]
inspect_type_params = params_no_default + params_with_default

# the following properties are set to help anyone that might inspect it determine usage
self.__name__ = name
Expand Down
6 changes: 4 additions & 2 deletions packages/toolbox-core/src/toolbox_core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,12 @@ def params_to_pydantic_model(
field_definitions = {}
for field in params:

# Determine the default value based on the 'required' flag.
# Determine the default value based on the 'required' flag and the 'default' field.
# '...' (Ellipsis) signifies a required field in Pydantic.
# 'None' makes the field optional with a default value of None.
# If a default value is provided in the schema, it should be used.
default_value = ... if field.required else None
if field.has_default:
default_value = field.default

field_definitions[field.name] = cast(
Any,
Expand Down
52 changes: 52 additions & 0 deletions packages/toolbox-core/tests/test_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,55 @@ def test_parameter_schema_map_unsupported_value_type_error():
expected_error_msg = f"Unsupported schema type: {unsupported_type}"
with pytest.raises(ValueError, match=expected_error_msg):
schema._ParameterSchema__get_type()


def test_parameter_schema_with_default():
"""Tests ParameterSchema with a default value provided."""
schema = ParameterSchema(
name="limit",
type="integer",
description="Limit results",
required=False,
default=10,
)
expected_type = Optional[int]

assert schema._ParameterSchema__get_type() == expected_type

param = schema.to_param()
assert isinstance(param, Parameter)
assert param.name == "limit"
assert param.annotation == expected_type
assert param.default == 10


def test_parameter_schema_required_with_default():
"""Tests ParameterSchema with default value, ignoring required=True implies it is optional in python signature."""
# Although illogical in some schemas, if default is present, it should be used as default.
schema = ParameterSchema(
name="retry_count",
type="integer",
description="Retries",
required=True,
default=3,
)

# get_type still respects required=True for type hint
assert schema._ParameterSchema__get_type() == int

param = schema.to_param()
assert param.default == 3


def test_parameter_schema_required_with_explicit_none_default():
"""Tests explicit default=None is treated as a provided default."""
schema = ParameterSchema(
name="opt_in",
type="boolean",
description="Optional flag",
required=True,
default=None,
)

param = schema.to_param()
assert param.default is None
20 changes: 20 additions & 0 deletions packages/toolbox-core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ def create_param_mock(name: str, description: str, annotation: Type) -> Mock:
param_mock.name = name
param_mock.description = description
param_mock.required = True
param_mock.default = None
param_mock.has_default = False

mock_param_info = Mock()
mock_param_info.annotation = annotation
Expand Down Expand Up @@ -422,6 +424,24 @@ def test_params_to_pydantic_model_with_params():
Model(name="Bob", age="thirty", is_active=True)


def test_params_to_pydantic_model_uses_explicit_default_none():
"""Test that explicit default=None is honored for required schema fields."""
tool_name = "MyToolWithExplicitNoneDefault"
params = [
ParameterSchema(
name="message",
type="string",
description="Message value",
required=True,
default=None,
)
]
Model = params_to_pydantic_model(tool_name, params)

assert "message" in Model.model_fields
assert Model.model_fields["message"].default is None


@pytest.mark.asyncio
async def test_resolve_value_plain_value():
"""Test resolving a plain, non-callable value."""
Expand Down
Loading