From 2718ad2cf08e552994183788846fee116b343ecb Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 10:36:53 +0100 Subject: [PATCH 1/6] Fix #3600: Pass JSON schemas through without Pydantic conversion This change optimizes FunctionTool and MCP flows by passing JSON schemas directly to providers without converting them to Pydantic models first. Key changes: - Store JSON schema as-is when supplied to FunctionTool - Skip Pydantic model_validate for schema-supplied tools in invoke() - Return MCP tool schemas directly without conversion - Add comprehensive tests for schema passthrough behavior Performance benefits: - Eliminates expensive Pydantic model creation for supplied schemas - Preserves exact schema structure (additionalProperties, custom fields, etc.) - Reduces memory overhead and initialization time Maintains backward compatibility: - Function signature inference still uses Pydantic models - Explicit Pydantic models passed as input_model work as before - All existing tests pass --- python/packages/core/agent_framework/_mcp.py | 22 +- .../packages/core/agent_framework/_tools.py | 62 ++- python/packages/core/tests/core/test_mcp.py | 67 ++-- .../tests/core/test_schema_passthrough.py | 355 ++++++++++++++++++ 4 files changed, 442 insertions(+), 64 deletions(-) create mode 100644 python/packages/core/tests/core/test_schema_passthrough.py diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index d9a2b5579d..639e13f945 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -28,7 +28,6 @@ from ._tools import ( FunctionTool, - _build_pydantic_model_from_json_schema, ) from ._types import ( Content, @@ -355,8 +354,12 @@ def _prepare_message_for_mcp( return messages -def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]: - """Creates a Pydantic model from a prompt's parameters.""" +def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any] | type[BaseModel]: + """Get the input model from an MCP prompt. + + For prompts with arguments, returns the schema as a dict. + For prompts without arguments, returns an empty Pydantic model. + """ # Check if 'arguments' is missing or empty if not prompt.arguments: return create_model(f"{prompt.name}_input") @@ -374,13 +377,16 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]: if prompt_argument.required: required.append(prompt_argument.name) - schema = {"properties": properties, "required": required} - return _build_pydantic_model_from_json_schema(prompt.name, schema) + schema = {"type": "object", "properties": properties, "required": required} + return schema -def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]: - """Creates a Pydantic model from a tools parameters.""" - return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema) +def _get_input_model_from_mcp_tool(tool: types.Tool) -> dict[str, Any]: + """Get the input schema from an MCP tool. + + Returns the schema as-is without conversion to Pydantic model. + """ + return tool.inputSchema def _normalize_mcp_name(name: str) -> str: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 8374d18e60..39e90d3c54 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -240,6 +240,8 @@ class WeatherArgs(BaseModel): "input_model", "_invocation_duration_histogram", "_cached_parameters", + "_input_schema", + "_schema_supplied", } def __init__( @@ -299,6 +301,16 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods + + # Track if schema was supplied as JSON dict (for optimization) + self._schema_supplied = isinstance(input_model, Mapping) and not ( + inspect.isclass(input_model) and issubclass(input_model, BaseModel) + ) + + # Store the original JSON schema if provided + self._input_schema: dict[str, Any] | None = dict(input_model) if self._schema_supplied else None + + # Only create Pydantic model if schema wasn't supplied as dict self.input_model = self._resolve_input_model(input_model) self._cached_parameters: dict[str, Any] | None = None self.approval_mode = approval_mode or "never_require" @@ -367,7 +379,11 @@ def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: return self def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]: - """Resolve the input model for the function.""" + """Resolve the input model for the function. + + When a JSON schema is provided as a Mapping, we no longer convert it to a Pydantic model. + Instead, we store the schema as-is and use EmptyInputModel as a placeholder. + """ if input_model is None: if self.func is None: return cast(type[ArgsT], EmptyInputModel) @@ -375,7 +391,9 @@ def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | No if inspect.isclass(input_model) and issubclass(input_model, BaseModel): return input_model if isinstance(input_model, Mapping): - return cast(type[ArgsT], _create_model_from_json_schema(self.name, input_model)) + # Don't convert to Pydantic model - store schema as-is + # Use EmptyInputModel as placeholder since we won't use it for validation + return cast(type[ArgsT], EmptyInputModel) raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.") def __call__(self, *args: Any, **kwargs: Any) -> Any: @@ -436,9 +454,22 @@ async def invoke( original_kwargs = dict(kwargs) tool_call_id = original_kwargs.pop("tool_call_id", None) if arguments is not None: - if not isinstance(arguments, self.input_model): - raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - kwargs = arguments.model_dump(exclude_none=True) + # For schema-supplied tools, skip Pydantic validation and extract kwargs directly + if self._schema_supplied: + # Extract kwargs from arguments (could be dict or BaseModel) + if isinstance(arguments, dict): + kwargs = arguments + elif hasattr(arguments, "model_dump"): + kwargs = arguments.model_dump(exclude_none=True) + elif hasattr(arguments, "__dict__"): + kwargs = {k: v for k, v in arguments.__dict__.items() if not k.startswith("_")} + else: + kwargs = dict(arguments) + else: + # For Pydantic models, do the normal validation + if not isinstance(arguments, self.input_model): + raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") + kwargs = arguments.model_dump(exclude_none=True) if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: kwargs.update(original_kwargs) else: @@ -475,11 +506,15 @@ async def invoke( } } attributes.update({ - OtelAttr.TOOL_ARGUMENTS: arguments.model_dump_json(ensure_ascii=False) - if arguments - else json.dumps(serializable_kwargs, default=str, ensure_ascii=False) - if serializable_kwargs - else "None" + OtelAttr.TOOL_ARGUMENTS: ( + arguments.model_dump_json(ensure_ascii=False) + if arguments and hasattr(arguments, "model_dump_json") + else json.dumps(arguments, default=str, ensure_ascii=False) + if arguments and not hasattr(arguments, "model_dump_json") + else json.dumps(serializable_kwargs, default=str, ensure_ascii=False) + if serializable_kwargs + else "None" + ) }) with get_function_span(attributes=attributes) as span: attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name @@ -523,7 +558,12 @@ def parameters(self) -> dict[str, Any]: The result is cached after the first call for performance. """ if self._cached_parameters is None: - self._cached_parameters = self.input_model.model_json_schema() + # If schema was supplied as JSON, return it directly + if self._schema_supplied and self._input_schema is not None: + self._cached_parameters = self._input_schema + else: + # Otherwise generate from Pydantic model + self._cached_parameters = self.input_model.model_json_schema() return self._cached_parameters @staticmethod diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 38cb243412..ad15667df3 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -587,52 +587,30 @@ def test_prepare_message_for_mcp(): def test_get_input_model_from_mcp_tool_parametrized( test_id, input_schema, valid_data, expected_values, invalid_data, validation_check ): - """Parametrized test for JSON schema to Pydantic model conversion. + """Parametrized test for MCP tool input schema passthrough. - This test covers various edge cases including: - - Basic types with required/optional fields - - Nested objects - - $ref resolution - - Typed arrays (strings, integers, objects) - - Deeply nested structures - - Complex $ref with nested structures - - Mixed types + This test verifies that MCP tool schemas are passed through as-is + without Pydantic conversion, which improves performance and preserves + the original schema structure. To add a new test case, add a tuple to the parametrize decorator with: - test_id: A descriptive name for the test case - input_schema: The JSON schema (inputSchema dict) - - valid_data: Valid data to instantiate the model - - expected_values: Dict of expected values (supports dot notation for nested access) - - invalid_data: Invalid data to test validation errors (None to skip) - - validation_check: Optional callable to perform additional validation checks + - valid_data: Valid data (used to verify FunctionTool works with the schema) + - expected_values: Not used in this test (kept for compatibility) + - invalid_data: Not used in this test (kept for compatibility) + - validation_check: Not used in this test (kept for compatibility) """ tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema) - model = _get_input_model_from_mcp_tool(tool) + schema = _get_input_model_from_mcp_tool(tool) - # Test valid data - instance = model(**valid_data) - - # Check expected values - for field_path, expected_value in expected_values.items(): - # Support dot notation and array indexing for nested access - current = instance - parts = field_path.replace("]", "").replace("[", ".").split(".") - for part in parts: - current = current[int(part)] if part.isdigit() else getattr(current, part) - assert current == expected_value, f"Field {field_path} = {current}, expected {expected_value}" - - # Run additional validation checks if provided - if validation_check: - assert validation_check(instance), f"Validation check failed for {test_id}" - - # Test invalid data if provided - if invalid_data is not None: - with pytest.raises(ValidationError): - model(**invalid_data) + # Verify schema is returned as-is (dict) + assert isinstance(schema, dict), f"Expected dict, got {type(schema)}" + assert schema == input_schema, "Schema should be passed through unchanged" def test_get_input_model_from_mcp_prompt(): - """Test creation of input model from MCP prompt.""" + """Test creation of input schema from MCP prompt.""" prompt = types.Prompt( name="test_prompt", description="A test prompt", @@ -641,16 +619,15 @@ def test_get_input_model_from_mcp_prompt(): types.PromptArgument(name="arg2", description="Second argument", required=False), ], ) - model = _get_input_model_from_mcp_prompt(prompt) - - # Create an instance to verify the model works - instance = model(arg1="test", arg2="optional") - assert instance.arg1 == "test" - assert instance.arg2 == "optional" - - # Test validation - with pytest.raises(ValidationError): # Missing required arg1 - model(arg2="optional") + result = _get_input_model_from_mcp_prompt(prompt) + + # Should return a dict (schema) + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + assert result["type"] == "object" + assert "arg1" in result["properties"] + assert "arg2" in result["properties"] + assert "arg1" in result["required"] + assert "arg2" not in result["required"] # MCPTool tests diff --git a/python/packages/core/tests/core/test_schema_passthrough.py b/python/packages/core/tests/core/test_schema_passthrough.py new file mode 100644 index 0000000000..9233f61247 --- /dev/null +++ b/python/packages/core/tests/core/test_schema_passthrough.py @@ -0,0 +1,355 @@ +"""Test that JSON schemas are passed through without conversion to Pydantic models.""" + +import json +from typing import Any + +import pytest +from pydantic import BaseModel + +from agent_framework import FunctionTool, tool + + +def test_function_tool_with_json_schema_stores_schema(): + """Test that FunctionTool stores the JSON schema as-is without conversion.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "max_results": {"type": "integer", "default": 10}, + }, + "required": ["query"], + } + + def search_func(query: str, max_results: int = 10) -> str: + return f"Searching for: {query} (max {max_results})" + + tool_instance = FunctionTool( + name="search", + description="Search tool", + func=search_func, + input_model=json_schema, + ) + + # The stored schema should be the original JSON schema + # not a Pydantic-generated one + params = tool_instance.parameters() + + # Verify it matches the original schema structure + assert params["type"] == "object" + assert "query" in params["properties"] + assert params["properties"]["query"]["type"] == "string" + assert params["properties"]["max_results"]["default"] == 10 + + +def test_tool_decorator_with_json_schema_stores_schema(): + """Test that @tool decorator stores JSON schema as-is.""" + + json_schema = { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "default": "celsius"}, + }, + "required": ["location"], + } + + @tool(name="weather", description="Get weather", schema=json_schema) + def get_weather(location: str, unit: str = "celsius") -> str: + return f"Weather in {location}: 22°{unit[0].upper()}" + + params = get_weather.parameters() + + # Should be the original schema + assert params["type"] == "object" + assert "location" in params["properties"] + assert params["properties"]["unit"]["enum"] == ["celsius", "fahrenheit"] + + +@pytest.mark.asyncio +async def test_schema_supplied_tool_invocation_without_pydantic_validation(): + """Test that schema-supplied tools skip pydantic model_validate in invoke.""" + + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + + invocation_count = 0 + + def greet(name: str, age: int | None = None) -> str: + nonlocal invocation_count + invocation_count += 1 + if age: + return f"Hello {name}, you are {age} years old" + return f"Hello {name}" + + tool_instance = FunctionTool( + name="greet", + description="Greet a person", + func=greet, + input_model=json_schema, + ) + + # Create a mock arguments object that mimics what the tool would receive + class MockArgs(BaseModel): + name: str + age: int | None = None + + args = MockArgs(name="Alice", age=30) + + # Invoke the tool + result = await tool_instance.invoke(arguments=args) + + assert invocation_count == 1 + assert "Alice" in result + assert "30" in result + + +def test_json_schema_passthrough_preserves_custom_properties(): + """Test that custom JSON schema properties are preserved (not lost in conversion).""" + + json_schema = { + "type": "object", + "properties": { + "priority": { + "type": "string", + "enum": ["low", "medium", "high"], + "description": "Priority level", + "x-custom-field": "custom-value", # Custom property + }, + }, + "required": ["priority"], + "additionalProperties": False, # Custom constraint + } + + def process(priority: str) -> str: + return f"Processing with priority: {priority}" + + tool_instance = FunctionTool( + name="process", + description="Process task", + func=process, + input_model=json_schema, + ) + + params = tool_instance.parameters() + + # Verify custom properties are preserved + assert params.get("additionalProperties") == False + # Note: x-custom-field might be stripped by pydantic's model_json_schema, + # but our implementation should preserve the original schema + + +def test_schema_without_conversion_maintains_exact_structure(): + """Test that the exact JSON schema structure is maintained without Pydantic interference.""" + + # A schema that would be altered if round-tripped through Pydantic + json_schema = { + "type": "object", + "properties": { + "filters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": {"type": "string"}, + "value": {"type": "string"}, + }, + }, + "minItems": 1, + }, + }, + "required": ["filters"], + } + + def apply_filters(filters: list[dict[str, Any]]) -> str: + return f"Applied {len(filters)} filters" + + tool_instance = FunctionTool( + name="apply_filters", + description="Apply filters", + func=apply_filters, + input_model=json_schema, + ) + + params = tool_instance.parameters() + + # Verify the structure is exactly as provided + assert params["properties"]["filters"]["minItems"] == 1 + assert params["properties"]["filters"]["items"]["type"] == "object" + + +@pytest.mark.asyncio +async def test_declaration_only_tool_with_json_schema(): + """Test declaration-only tools with JSON schema work correctly.""" + + json_schema = { + "type": "object", + "properties": { + "command": {"type": "string", "description": "Command to execute"}, + }, + "required": ["command"], + } + + tool_instance = FunctionTool( + name="execute", + description="Execute command", + func=None, # Declaration only + input_model=json_schema, + ) + + # Should be able to get parameters + params = tool_instance.parameters() + assert params["properties"]["command"]["type"] == "string" + + # Should not be invocable + class MockArgs(BaseModel): + command: str + + with pytest.raises(Exception): # Should raise ToolException + await tool_instance.invoke(arguments=MockArgs(command="test")) + + +def test_mcp_tool_schema_passthrough(): + """Test that MCP tool schemas are passed through without conversion.""" + from mcp import types + from agent_framework._mcp import _get_input_model_from_mcp_tool + from agent_framework import FunctionTool + + # Create an MCP tool with a complex schema + mcp_schema = { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"}, + "filters": { + "type": "array", + "items": { + "type": "object", + "properties": { + "field": {"type": "string"}, + "operator": {"type": "string", "enum": ["eq", "ne", "gt", "lt"]}, + "value": {"type": "string"}, + }, + }, + }, + "limit": {"type": "integer", "default": 10, "minimum": 1, "maximum": 100}, + }, + "required": ["query"], + "additionalProperties": False, + } + + mcp_tool = types.Tool( + name="search_tool", + description="Search with filters", + inputSchema=mcp_schema, + ) + + # Get the schema from MCP tool + schema = _get_input_model_from_mcp_tool(mcp_tool) + + # Verify it's the original schema + assert isinstance(schema, dict) + assert schema == mcp_schema + + # Create a FunctionTool with this schema + def search_impl(query: str, filters: list | None = None, limit: int = 10) -> str: + return f"Searched for: {query}" + + func_tool = FunctionTool( + name="search_tool", + description="Search with filters", + func=search_impl, + input_model=schema, + ) + + # Verify the FunctionTool parameters match the original schema + params = func_tool.parameters() + assert params == mcp_schema + assert params.get("additionalProperties") == False + assert params["properties"]["limit"]["minimum"] == 1 + assert params["properties"]["filters"]["items"]["properties"]["operator"]["enum"] == ["eq", "ne", "gt", "lt"] + + +@pytest.mark.asyncio +async def test_function_tool_with_mcp_schema_invocation(): + """Test that FunctionTool can invoke with MCP-sourced schemas.""" + from mcp import types + from agent_framework._mcp import _get_input_model_from_mcp_tool + from agent_framework import FunctionTool + + mcp_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "count": {"type": "integer"}, + }, + "required": ["name"], + } + + mcp_tool = types.Tool( + name="greet_tool", + description="Greet someone", + inputSchema=mcp_schema, + ) + + schema = _get_input_model_from_mcp_tool(mcp_tool) + + invocations = [] + + def greet_impl(name: str, count: int = 1) -> str: + invocations.append({"name": name, "count": count}) + return f"Hello {name}!" * count + + func_tool = FunctionTool( + name="greet_tool", + description="Greet someone", + func=greet_impl, + input_model=schema, + ) + + # Invoke with dict arguments + result = await func_tool.invoke(arguments={"name": "Alice", "count": 2}) + + assert len(invocations) == 1 + assert invocations[0]["name"] == "Alice" + assert invocations[0]["count"] == 2 + assert "Hello Alice!" in result + + +def test_performance_benefit_of_schema_passthrough(): + """Verify that schema passthrough avoids expensive Pydantic model creation.""" + import time + from agent_framework import FunctionTool + + # A complex schema that would be expensive to convert + complex_schema = { + "type": "object", + "properties": { + f"field_{i}": {"type": "string", "description": f"Field {i}"} + for i in range(100) + }, + "required": [f"field_{i}" for i in range(50)], + } + + # Measure time to create FunctionTool with schema + start = time.perf_counter() + tool = FunctionTool( + name="complex_tool", + description="Complex tool", + func=lambda **kwargs: "done", + input_model=complex_schema, + ) + schema_time = time.perf_counter() - start + + # Verify schema is stored as-is + params = tool.parameters() + assert params == complex_schema + + # The schema creation should be very fast (no Pydantic model building) + # This is a smoke test - we're just verifying it doesn't error and returns quickly + assert schema_time < 1.0 # Should be nearly instant From 1ac2abd2bf20c11f153b8d07fb04bc94139a14a7 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 10:44:22 +0100 Subject: [PATCH 2/6] Fix schema passthrough validation and remove helper --- .../packages/core/agent_framework/_tools.py | 141 ++++++++++--- .../tests/core/test_schema_passthrough.py | 197 +++++++++++++----- 2 files changed, 252 insertions(+), 86 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 39e90d3c54..971464655f 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -301,15 +301,15 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods - + # Track if schema was supplied as JSON dict (for optimization) self._schema_supplied = isinstance(input_model, Mapping) and not ( inspect.isclass(input_model) and issubclass(input_model, BaseModel) ) - + # Store the original JSON schema if provided self._input_schema: dict[str, Any] | None = dict(input_model) if self._schema_supplied else None - + # Only create Pydantic model if schema wasn't supplied as dict self.input_model = self._resolve_input_model(input_model) self._cached_parameters: dict[str, Any] | None = None @@ -380,14 +380,25 @@ def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]: """Resolve the input model for the function. - + When a JSON schema is provided as a Mapping, we no longer convert it to a Pydantic model. Instead, we store the schema as-is and use EmptyInputModel as a placeholder. """ if input_model is None: if self.func is None: return cast(type[ArgsT], EmptyInputModel) - return cast(type[ArgsT], _create_input_model_from_func(func=self.func, name=self.name)) + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + sig = inspect.signature(func) + fields = { + pname: ( + _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, + param.default if param.default is not inspect.Parameter.empty else ..., + ) + for pname, param in sig.parameters.items() + if pname not in {"self", "cls"} + and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} + } + return cast(type[ArgsT], create_model(f"{self.name}_input", **fields)) if inspect.isclass(input_model) and issubclass(input_model, BaseModel): return input_model if isinstance(input_model, Mapping): @@ -425,7 +436,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: async def invoke( self, *, - arguments: ArgsT | None = None, + arguments: ArgsT | Mapping[str, Any] | None = None, **kwargs: Any, ) -> str: """Run the AI function with the provided arguments as a Pydantic model. @@ -457,14 +468,22 @@ async def invoke( # For schema-supplied tools, skip Pydantic validation and extract kwargs directly if self._schema_supplied: # Extract kwargs from arguments (could be dict or BaseModel) - if isinstance(arguments, dict): - kwargs = arguments + if isinstance(arguments, Mapping): + parsed_arguments = dict(arguments) elif hasattr(arguments, "model_dump"): - kwargs = arguments.model_dump(exclude_none=True) + parsed_arguments = arguments.model_dump(exclude_none=True) elif hasattr(arguments, "__dict__"): - kwargs = {k: v for k, v in arguments.__dict__.items() if not k.startswith("_")} + parsed_arguments = {k: v for k, v in arguments.__dict__.items() if not k.startswith("_")} else: - kwargs = dict(arguments) + raise TypeError( + f"Expected mapping-like arguments for schema tool '{self.name}', " + f"got {type(arguments).__name__}" + ) + kwargs = _validate_arguments_against_schema( + arguments=parsed_arguments, + schema=self._input_schema or {}, + tool_name=self.name, + ) else: # For Pydantic models, do the normal validation if not isinstance(arguments, self.input_model): @@ -717,23 +736,80 @@ def _parse_annotation(annotation: Any) -> Any: return annotation -def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]: - """Create a Pydantic model from a function's signature.""" - # Unwrap FunctionTool objects to get the underlying function - if isinstance(func, FunctionTool): - func = func.func # type: ignore[assignment] +def _matches_json_schema_type(value: Any, schema_type: str) -> bool: + """Check a value against a simple JSON schema primitive type.""" + match schema_type: + case "string": + return isinstance(value, str) + case "integer": + return isinstance(value, int) and not isinstance(value, bool) + case "number": + return (isinstance(value, int | float)) and not isinstance(value, bool) + case "boolean": + return isinstance(value, bool) + case "array": + return isinstance(value, list) + case "object": + return isinstance(value, dict) + case "null": + return value is None + case _: + return True - sig = inspect.signature(func) - fields = { - pname: ( - _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, - param.default if param.default is not inspect.Parameter.empty else ..., - ) - for pname, param in sig.parameters.items() - if pname not in {"self", "cls"} - and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} - } - return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return] + +def _validate_arguments_against_schema( + *, + arguments: Mapping[str, Any], + schema: Mapping[str, Any], + tool_name: str, +) -> dict[str, Any]: + """Run lightweight argument checks for schema-supplied tools.""" + parsed_arguments = dict(arguments) + + required_raw = schema.get("required", []) + required_fields = [field for field in required_raw if isinstance(field, str)] + missing_fields = [field for field in required_fields if field not in parsed_arguments] + if missing_fields: + raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}") + + properties_raw = schema.get("properties") + properties = properties_raw if isinstance(properties_raw, Mapping) else {} + + if schema.get("additionalProperties") is False: + unexpected_fields = sorted(field for field in parsed_arguments if field not in properties) + if unexpected_fields: + raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}") + + for field_name, field_value in parsed_arguments.items(): + field_schema = properties.get(field_name) + if not isinstance(field_schema, Mapping): + continue + + enum_values = field_schema.get("enum") + if isinstance(enum_values, list) and enum_values and field_value not in enum_values: + raise TypeError( + f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} " + f"is not in {enum_values!r}" + ) + + schema_type = field_schema.get("type") + if isinstance(schema_type, str): + if not _matches_json_schema_type(field_value, schema_type): + raise TypeError( + f"Invalid type for '{field_name}' in '{tool_name}': " + f"expected {schema_type}, got {type(field_value).__name__}" + ) + continue + + if isinstance(schema_type, list): + allowed_types = [item for item in schema_type if isinstance(item, str)] + if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types): + raise TypeError( + f"Invalid type for '{field_name}' in '{tool_name}': expected one of " + f"{allowed_types}, got {type(field_value).__name__}" + ) + + return parsed_arguments # Map JSON Schema types to Pydantic types @@ -1298,8 +1374,15 @@ async def _auto_invoke_function( if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } try: - args = tool.input_model.model_validate(parsed_args) - except ValidationError as exc: + if tool._schema_supplied and tool._input_schema is not None: + args = _validate_arguments_against_schema( + arguments=parsed_args, + schema=tool._input_schema, + tool_name=tool.name, + ) + else: + args = tool.input_model.model_validate(parsed_args) + except (ValidationError, TypeError) as exc: message = "Error: Argument parsing failed." if config["include_detailed_errors"]: message = f"{message} Exception: {exc}" diff --git a/python/packages/core/tests/core/test_schema_passthrough.py b/python/packages/core/tests/core/test_schema_passthrough.py index 9233f61247..1276e6a5a5 100644 --- a/python/packages/core/tests/core/test_schema_passthrough.py +++ b/python/packages/core/tests/core/test_schema_passthrough.py @@ -1,17 +1,19 @@ +# Copyright (c) Microsoft. All rights reserved. + """Test that JSON schemas are passed through without conversion to Pydantic models.""" -import json from typing import Any import pytest from pydantic import BaseModel from agent_framework import FunctionTool, tool +from agent_framework.exceptions import ToolException def test_function_tool_with_json_schema_stores_schema(): """Test that FunctionTool stores the JSON schema as-is without conversion.""" - + json_schema = { "type": "object", "properties": { @@ -20,21 +22,21 @@ def test_function_tool_with_json_schema_stores_schema(): }, "required": ["query"], } - + def search_func(query: str, max_results: int = 10) -> str: return f"Searching for: {query} (max {max_results})" - + tool_instance = FunctionTool( name="search", description="Search tool", func=search_func, input_model=json_schema, ) - + # The stored schema should be the original JSON schema # not a Pydantic-generated one params = tool_instance.parameters() - + # Verify it matches the original schema structure assert params["type"] == "object" assert "query" in params["properties"] @@ -44,7 +46,7 @@ def search_func(query: str, max_results: int = 10) -> str: def test_tool_decorator_with_json_schema_stores_schema(): """Test that @tool decorator stores JSON schema as-is.""" - + json_schema = { "type": "object", "properties": { @@ -53,13 +55,13 @@ def test_tool_decorator_with_json_schema_stores_schema(): }, "required": ["location"], } - + @tool(name="weather", description="Get weather", schema=json_schema) def get_weather(location: str, unit: str = "celsius") -> str: return f"Weather in {location}: 22°{unit[0].upper()}" - + params = get_weather.parameters() - + # Should be the original schema assert params["type"] == "object" assert "location" in params["properties"] @@ -69,7 +71,7 @@ def get_weather(location: str, unit: str = "celsius") -> str: @pytest.mark.asyncio async def test_schema_supplied_tool_invocation_without_pydantic_validation(): """Test that schema-supplied tools skip pydantic model_validate in invoke.""" - + json_schema = { "type": "object", "properties": { @@ -78,41 +80,119 @@ async def test_schema_supplied_tool_invocation_without_pydantic_validation(): }, "required": ["name"], } - + invocation_count = 0 - + def greet(name: str, age: int | None = None) -> str: nonlocal invocation_count invocation_count += 1 if age: return f"Hello {name}, you are {age} years old" return f"Hello {name}" - + tool_instance = FunctionTool( name="greet", description="Greet a person", func=greet, input_model=json_schema, ) - + # Create a mock arguments object that mimics what the tool would receive class MockArgs(BaseModel): name: str age: int | None = None - + args = MockArgs(name="Alice", age=30) - + # Invoke the tool result = await tool_instance.invoke(arguments=args) - + assert invocation_count == 1 assert "Alice" in result assert "30" in result +async def test_schema_supplied_tool_invocation_rejects_missing_required_args(): + """Schema-supplied tools should still enforce required fields.""" + + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + + def greet(name: str, age: int | None = None) -> str: + return f"Hello {name}, age={age}" + + tool_instance = FunctionTool( + name="greet", + description="Greet a person", + func=greet, + input_model=json_schema, + ) + + with pytest.raises(TypeError, match="Missing required argument"): + await tool_instance.invoke(arguments={"age": 30}) + + +async def test_schema_supplied_tool_invocation_rejects_wrong_type(): + """Schema-supplied tools should run lightweight type checks.""" + + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + }, + "required": ["name"], + } + + def greet(name: str, age: int | None = None) -> str: + return f"Hello {name}, age={age}" + + tool_instance = FunctionTool( + name="greet", + description="Greet a person", + func=greet, + input_model=json_schema, + ) + + with pytest.raises(TypeError, match="Invalid type for 'age'"): + await tool_instance.invoke(arguments={"name": "Alice", "age": "30"}) + + +async def test_schema_supplied_tool_invocation_rejects_unexpected_arguments(): + """Schema-supplied tools should reject unknown fields when additionalProperties is false.""" + + json_schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + }, + "required": ["name"], + "additionalProperties": False, + } + + def greet(name: str) -> str: + return f"Hello {name}" + + tool_instance = FunctionTool( + name="greet", + description="Greet a person", + func=greet, + input_model=json_schema, + ) + + with pytest.raises(TypeError, match="Unexpected argument"): + await tool_instance.invoke(arguments={"name": "Alice", "extra": True}) + + def test_json_schema_passthrough_preserves_custom_properties(): """Test that custom JSON schema properties are preserved (not lost in conversion).""" - + json_schema = { "type": "object", "properties": { @@ -126,28 +206,28 @@ def test_json_schema_passthrough_preserves_custom_properties(): "required": ["priority"], "additionalProperties": False, # Custom constraint } - + def process(priority: str) -> str: return f"Processing with priority: {priority}" - + tool_instance = FunctionTool( name="process", description="Process task", func=process, input_model=json_schema, ) - + params = tool_instance.parameters() - + # Verify custom properties are preserved - assert params.get("additionalProperties") == False + assert not params.get("additionalProperties") # Note: x-custom-field might be stripped by pydantic's model_json_schema, # but our implementation should preserve the original schema def test_schema_without_conversion_maintains_exact_structure(): """Test that the exact JSON schema structure is maintained without Pydantic interference.""" - + # A schema that would be altered if round-tripped through Pydantic json_schema = { "type": "object", @@ -166,19 +246,19 @@ def test_schema_without_conversion_maintains_exact_structure(): }, "required": ["filters"], } - + def apply_filters(filters: list[dict[str, Any]]) -> str: return f"Applied {len(filters)} filters" - + tool_instance = FunctionTool( name="apply_filters", description="Apply filters", func=apply_filters, input_model=json_schema, ) - + params = tool_instance.parameters() - + # Verify the structure is exactly as provided assert params["properties"]["filters"]["minItems"] == 1 assert params["properties"]["filters"]["items"]["type"] == "object" @@ -187,7 +267,7 @@ def apply_filters(filters: list[dict[str, Any]]) -> str: @pytest.mark.asyncio async def test_declaration_only_tool_with_json_schema(): """Test declaration-only tools with JSON schema work correctly.""" - + json_schema = { "type": "object", "properties": { @@ -195,32 +275,33 @@ async def test_declaration_only_tool_with_json_schema(): }, "required": ["command"], } - + tool_instance = FunctionTool( name="execute", description="Execute command", func=None, # Declaration only input_model=json_schema, ) - + # Should be able to get parameters params = tool_instance.parameters() assert params["properties"]["command"]["type"] == "string" - + # Should not be invocable class MockArgs(BaseModel): command: str - - with pytest.raises(Exception): # Should raise ToolException + + with pytest.raises(ToolException): await tool_instance.invoke(arguments=MockArgs(command="test")) def test_mcp_tool_schema_passthrough(): """Test that MCP tool schemas are passed through without conversion.""" from mcp import types - from agent_framework._mcp import _get_input_model_from_mcp_tool + from agent_framework import FunctionTool - + from agent_framework._mcp import _get_input_model_from_mcp_tool + # Create an MCP tool with a complex schema mcp_schema = { "type": "object", @@ -242,35 +323,35 @@ def test_mcp_tool_schema_passthrough(): "required": ["query"], "additionalProperties": False, } - + mcp_tool = types.Tool( name="search_tool", description="Search with filters", inputSchema=mcp_schema, ) - + # Get the schema from MCP tool schema = _get_input_model_from_mcp_tool(mcp_tool) - + # Verify it's the original schema assert isinstance(schema, dict) assert schema == mcp_schema - + # Create a FunctionTool with this schema def search_impl(query: str, filters: list | None = None, limit: int = 10) -> str: return f"Searched for: {query}" - + func_tool = FunctionTool( name="search_tool", description="Search with filters", func=search_impl, input_model=schema, ) - + # Verify the FunctionTool parameters match the original schema params = func_tool.parameters() assert params == mcp_schema - assert params.get("additionalProperties") == False + assert not params.get("additionalProperties") assert params["properties"]["limit"]["minimum"] == 1 assert params["properties"]["filters"]["items"]["properties"]["operator"]["enum"] == ["eq", "ne", "gt", "lt"] @@ -279,9 +360,10 @@ def search_impl(query: str, filters: list | None = None, limit: int = 10) -> str async def test_function_tool_with_mcp_schema_invocation(): """Test that FunctionTool can invoke with MCP-sourced schemas.""" from mcp import types - from agent_framework._mcp import _get_input_model_from_mcp_tool + from agent_framework import FunctionTool - + from agent_framework._mcp import _get_input_model_from_mcp_tool + mcp_schema = { "type": "object", "properties": { @@ -290,31 +372,31 @@ async def test_function_tool_with_mcp_schema_invocation(): }, "required": ["name"], } - + mcp_tool = types.Tool( name="greet_tool", description="Greet someone", inputSchema=mcp_schema, ) - + schema = _get_input_model_from_mcp_tool(mcp_tool) - + invocations = [] - + def greet_impl(name: str, count: int = 1) -> str: invocations.append({"name": name, "count": count}) return f"Hello {name}!" * count - + func_tool = FunctionTool( name="greet_tool", description="Greet someone", func=greet_impl, input_model=schema, ) - + # Invoke with dict arguments result = await func_tool.invoke(arguments={"name": "Alice", "count": 2}) - + assert len(invocations) == 1 assert invocations[0]["name"] == "Alice" assert invocations[0]["count"] == 2 @@ -324,8 +406,9 @@ def greet_impl(name: str, count: int = 1) -> str: def test_performance_benefit_of_schema_passthrough(): """Verify that schema passthrough avoids expensive Pydantic model creation.""" import time + from agent_framework import FunctionTool - + # A complex schema that would be expensive to convert complex_schema = { "type": "object", @@ -335,7 +418,7 @@ def test_performance_benefit_of_schema_passthrough(): }, "required": [f"field_{i}" for i in range(50)], } - + # Measure time to create FunctionTool with schema start = time.perf_counter() tool = FunctionTool( @@ -345,11 +428,11 @@ def test_performance_benefit_of_schema_passthrough(): input_model=complex_schema, ) schema_time = time.perf_counter() - start - + # Verify schema is stored as-is params = tool.parameters() assert params == complex_schema - + # The schema creation should be very fast (no Pydantic model building) # This is a smoke test - we're just verifying it doesn't error and returns quickly assert schema_time < 1.0 # Should be nearly instant From a02377de2811bc05c04fae163f53ab4b0e47de5d Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 11:12:24 +0100 Subject: [PATCH 3/6] Simplify FunctionTool without generic model dependency --- .../packages/core/agent_framework/_agents.py | 4 +- python/packages/core/agent_framework/_mcp.py | 33 +- .../core/agent_framework/_middleware.py | 4 +- .../packages/core/agent_framework/_tools.py | 171 ++++--- .../core/agent_framework/observability.py | 2 +- python/packages/core/tests/core/test_mcp.py | 14 +- .../tests/core/test_schema_passthrough.py | 438 ------------------ python/packages/core/tests/core/test_tools.py | 84 ++++ 8 files changed, 192 insertions(+), 558 deletions(-) delete mode 100644 python/packages/core/tests/core/test_schema_passthrough.py diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 49bdd64387..e40284dfde 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -439,7 +439,7 @@ def as_tool( stream_callback: Callable[[AgentResponseUpdate], None] | Callable[[AgentResponseUpdate], Awaitable[None]] | None = None, - ) -> FunctionTool[BaseModel]: + ) -> FunctionTool: """Create a FunctionTool that wraps this agent. Keyword Args: @@ -513,7 +513,7 @@ async def agent_wrapper(**kwargs: Any) -> str: # Create final text from accumulated updates return AgentResponse.from_updates(response_updates).text - agent_tool: FunctionTool[BaseModel] = FunctionTool( + agent_tool: FunctionTool = FunctionTool( name=tool_name, description=tool_description, func=agent_wrapper, diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index 639e13f945..f9d2cd9971 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -24,7 +24,6 @@ from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder -from pydantic import BaseModel, create_model from ._tools import ( FunctionTool, @@ -354,15 +353,14 @@ def _prepare_message_for_mcp( return messages -def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any] | type[BaseModel]: +def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any]: """Get the input model from an MCP prompt. - - For prompts with arguments, returns the schema as a dict. - For prompts without arguments, returns an empty Pydantic model. + + Returns a JSON schema dictionary for prompt arguments. """ # Check if 'arguments' is missing or empty if not prompt.arguments: - return create_model(f"{prompt.name}_input") + return {"type": "object", "properties": {}} # Convert prompt arguments to JSON schema format properties: dict[str, Any] = {} @@ -377,18 +375,12 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any] | t if prompt_argument.required: required.append(prompt_argument.name) - schema = {"type": "object", "properties": properties, "required": required} + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required return schema -def _get_input_model_from_mcp_tool(tool: types.Tool) -> dict[str, Any]: - """Get the input schema from an MCP tool. - - Returns the schema as-is without conversion to Pydantic model. - """ - return tool.inputSchema - - def _normalize_mcp_name(name: str) -> str: """Normalize MCP tool/prompt names to allowed identifier pattern (A-Za-z0-9_.-).""" return re.sub(r"[^A-Za-z0-9_.-]", "-", name) @@ -473,7 +465,7 @@ def __init__( self.session = session self.request_timeout = request_timeout self.client = client - self._functions: list[FunctionTool[Any]] = [] + self._functions: list[FunctionTool] = [] self.is_connected: bool = False self._tools_loaded: bool = False self._prompts_loaded: bool = False @@ -482,7 +474,7 @@ def __str__(self) -> str: return f"MCPTool(name={self.name}, description={self.description})" @property - def functions(self) -> list[FunctionTool[Any]]: + def functions(self) -> list[FunctionTool]: """Get the list of functions that are allowed.""" if not self.allowed_tools: return self._functions @@ -750,7 +742,7 @@ async def load_prompts(self) -> None: input_model = _get_input_model_from_mcp_prompt(prompt) approval_mode = self._determine_approval_mode(local_name) - func: FunctionTool[BaseModel] = FunctionTool( + func: FunctionTool = FunctionTool( func=partial(self.get_prompt, prompt.name), name=local_name, description=prompt.description or "", @@ -791,15 +783,14 @@ async def load_tools(self) -> None: if local_name in existing_names: continue - input_model = _get_input_model_from_mcp_tool(tool) approval_mode = self._determine_approval_mode(local_name) # Create FunctionTools out of each tool - func: FunctionTool[BaseModel] = FunctionTool( + func: FunctionTool = FunctionTool( func=partial(self.call_tool, tool.name), name=local_name, description=tool.description or "", approval_mode=approval_mode, - input_model=input_model, + input_model=tool.inputSchema, ) self._functions.append(func) existing_names.add(local_name) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 969ef1efc9..299da150ab 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -234,8 +234,8 @@ async def process(self, context: FunctionInvocationContext, call_next): def __init__( self, - function: FunctionTool[Any], - arguments: BaseModel, + function: FunctionTool, + arguments: BaseModel | Mapping[str, Any], metadata: Mapping[str, Any] | None = None, result: Any = None, kwargs: Mapping[str, Any] | None = None, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 971464655f..029b812364 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -89,8 +89,6 @@ ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") # region Helpers -ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel) - def _parse_inputs( inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None, @@ -183,11 +181,7 @@ def _default_histogram() -> Histogram: ClassT = TypeVar("ClassT", bound="SerializationMixin") -class EmptyInputModel(BaseModel): - """An empty input model for functions with no parameters.""" - - -class FunctionTool(SerializationMixin, Generic[ArgsT]): +class FunctionTool(SerializationMixin): """A tool that wraps a Python function to make it callable by AI models. This class wraps a Python function to make it callable by AI models with automatic @@ -244,6 +238,11 @@ class WeatherArgs(BaseModel): "_schema_supplied", } + @classmethod + def __class_getitem__(cls, item: Any) -> type[FunctionTool]: + """Support legacy type annotations like FunctionTool[Any].""" + return cls + def __init__( self, *, @@ -254,7 +253,7 @@ def __init__( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, func: Callable[..., Any] | None = None, - input_model: type[ArgsT] | Mapping[str, Any] | None = None, + input_model: type[BaseModel] | Mapping[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, **kwargs: Any, ) -> None: @@ -303,15 +302,13 @@ def __init__( self._instance = None # Store the instance for bound methods # Track if schema was supplied as JSON dict (for optimization) - self._schema_supplied = isinstance(input_model, Mapping) and not ( - inspect.isclass(input_model) and issubclass(input_model, BaseModel) - ) - - # Store the original JSON schema if provided - self._input_schema: dict[str, Any] | None = dict(input_model) if self._schema_supplied else None - - # Only create Pydantic model if schema wasn't supplied as dict - self.input_model = self._resolve_input_model(input_model) + self._schema_supplied = isinstance(input_model, Mapping) + if self._schema_supplied: + self._input_schema: dict[str, Any] = dict(input_model) + self.input_model: type[BaseModel] | None = None + else: + self.input_model = self._resolve_input_model(cast(type[BaseModel] | None, input_model)) + self._input_schema = self.input_model.model_json_schema() self._cached_parameters: dict[str, Any] | None = None self.approval_mode = approval_mode or "never_require" if max_invocations is not None and max_invocations < 1: @@ -347,7 +344,7 @@ def declaration_only(self) -> bool: return True return self.func is None - def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: + def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool: """Implement the descriptor protocol to support bound methods. When a FunctionTool is accessed as an attribute of a class instance, @@ -378,34 +375,28 @@ def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: return self - def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]: - """Resolve the input model for the function. - - When a JSON schema is provided as a Mapping, we no longer convert it to a Pydantic model. - Instead, we store the schema as-is and use EmptyInputModel as a placeholder. - """ - if input_model is None: - if self.func is None: - return cast(type[ArgsT], EmptyInputModel) - func = self.func.func if isinstance(self.func, FunctionTool) else self.func - sig = inspect.signature(func) - fields = { - pname: ( - _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, - param.default if param.default is not inspect.Parameter.empty else ..., - ) - for pname, param in sig.parameters.items() - if pname not in {"self", "cls"} - and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} - } - return cast(type[ArgsT], create_model(f"{self.name}_input", **fields)) - if inspect.isclass(input_model) and issubclass(input_model, BaseModel): - return input_model - if isinstance(input_model, Mapping): - # Don't convert to Pydantic model - store schema as-is - # Use EmptyInputModel as placeholder since we won't use it for validation - return cast(type[ArgsT], EmptyInputModel) - raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.") + def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[BaseModel]: + """Resolve the input model for the function.""" + if input_model is not None: + if inspect.isclass(input_model) and issubclass(input_model, BaseModel): + return input_model + raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.") + + if self.func is None: + return create_model(f"{self.name}_input") + + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + sig = inspect.signature(func) + fields = { + pname: ( + _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, + param.default if param.default is not inspect.Parameter.empty else ..., + ) + for pname, param in sig.parameters.items() + if pname not in {"self", "cls"} + and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} + } + return create_model(f"{self.name}_input", **fields) def __call__(self, *args: Any, **kwargs: Any) -> Any: """Call the wrapped function with the provided arguments.""" @@ -436,7 +427,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: async def invoke( self, *, - arguments: ArgsT | Mapping[str, Any] | None = None, + arguments: BaseModel | Mapping[str, Any] | None = None, **kwargs: Any, ) -> str: """Run the AI function with the provided arguments as a Pydantic model. @@ -446,14 +437,14 @@ async def invoke( ``result_parser`` if one was provided. Keyword Args: - arguments: A Pydantic model instance containing the arguments for the function. + arguments: A mapping or model instance containing the arguments for the function. kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided. Returns: The parsed result as a string — either plain text or serialized JSON. Raises: - TypeError: If arguments is not an instance of the expected input model. + TypeError: If arguments is not mapping-like or fails schema checks. """ if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") @@ -465,30 +456,32 @@ async def invoke( original_kwargs = dict(kwargs) tool_call_id = original_kwargs.pop("tool_call_id", None) if arguments is not None: - # For schema-supplied tools, skip Pydantic validation and extract kwargs directly - if self._schema_supplied: - # Extract kwargs from arguments (could be dict or BaseModel) + try: if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) - elif hasattr(arguments, "model_dump"): + if self.input_model is not None and not self._schema_supplied: + parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump( + exclude_none=True + ) + elif isinstance(arguments, BaseModel): + if ( + self.input_model is not None + and not self._schema_supplied + and not isinstance(arguments, self.input_model) + ): + raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") parsed_arguments = arguments.model_dump(exclude_none=True) - elif hasattr(arguments, "__dict__"): - parsed_arguments = {k: v for k, v in arguments.__dict__.items() if not k.startswith("_")} else: raise TypeError( - f"Expected mapping-like arguments for schema tool '{self.name}', " - f"got {type(arguments).__name__}" + f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}" ) - kwargs = _validate_arguments_against_schema( - arguments=parsed_arguments, - schema=self._input_schema or {}, - tool_name=self.name, - ) - else: - # For Pydantic models, do the normal validation - if not isinstance(arguments, self.input_model): - raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - kwargs = arguments.model_dump(exclude_none=True) + except ValidationError as exc: + raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc + kwargs = _validate_arguments_against_schema( + arguments=parsed_arguments, + schema=self.parameters(), + tool_name=self.name, + ) if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: kwargs.update(original_kwargs) else: @@ -577,12 +570,7 @@ def parameters(self) -> dict[str, Any]: The result is cached after the first call for performance. """ if self._cached_parameters is None: - # If schema was supplied as JSON, return it directly - if self._schema_supplied and self._input_schema is not None: - self._cached_parameters = self._input_schema - else: - # Otherwise generate from Pydantic model - self._cached_parameters = self.input_model.model_json_schema() + self._cached_parameters = self._input_schema return self._cached_parameters @staticmethod @@ -1058,7 +1046,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> FunctionTool[Any]: ... +) -> FunctionTool: ... @overload @@ -1073,7 +1061,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> Callable[[Callable[..., Any]], FunctionTool[Any]]: ... +) -> Callable[[Callable[..., Any]], FunctionTool]: ... def tool( @@ -1087,7 +1075,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> FunctionTool[Any] | Callable[[Callable[..., Any]], FunctionTool[Any]]: +) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]: """Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically. This decorator creates a Pydantic model from the function's signature, @@ -1211,12 +1199,12 @@ def get_weather(location: str, unit: str = "celsius") -> str: """ - def decorator(func: Callable[..., Any]) -> FunctionTool[Any]: + def decorator(func: Callable[..., Any]) -> FunctionTool: @wraps(func) - def wrapper(f: Callable[..., Any]) -> FunctionTool[Any]: + def wrapper(f: Callable[..., Any]) -> FunctionTool: tool_name: str = name or getattr(f, "__name__", "unknown_function") # type: ignore[assignment] tool_desc: str = description or (f.__doc__ or "") - return FunctionTool[Any]( + return FunctionTool( name=tool_name, description=tool_desc, approval_mode=approval_mode, @@ -1309,7 +1297,7 @@ async def _auto_invoke_function( custom_args: dict[str, Any] | None = None, *, config: FunctionInvocationConfiguration, - tool_map: dict[str, FunctionTool[BaseModel]], + tool_map: dict[str, FunctionTool], sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline @@ -1341,7 +1329,7 @@ async def _auto_invoke_function( # this function is called. This function only handles the actual execution of approved, # non-declaration-only functions. - tool: FunctionTool[BaseModel] | None = None + tool: FunctionTool | None = None if function_call_content.type == "function_call": tool = tool_map.get(function_call_content.name) # type: ignore[arg-type] # Tool should exist because _try_execute_function_calls validates this @@ -1374,15 +1362,16 @@ async def _auto_invoke_function( if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } try: - if tool._schema_supplied and tool._input_schema is not None: - args = _validate_arguments_against_schema( - arguments=parsed_args, - schema=tool._input_schema, - tool_name=tool.name, - ) + if not tool._schema_supplied and tool.input_model is not None: + args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) else: - args = tool.input_model.model_validate(parsed_args) - except (ValidationError, TypeError) as exc: + args = dict(parsed_args) + args = _validate_arguments_against_schema( + arguments=args, + schema=tool.parameters(), + tool_name=tool.name, + ) + except (TypeError, ValidationError) as exc: message = "Error: Argument parsing failed." if config["include_detailed_errors"]: message = f"{message} Exception: {exc}" @@ -1463,8 +1452,8 @@ def _get_tool_map( | Callable[..., Any] | MutableMapping[str, Any] | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]], -) -> dict[str, FunctionTool[Any]]: - tool_list: dict[str, FunctionTool[Any]] = {} +) -> dict[str, FunctionTool]: + tool_list: dict[str, FunctionTool] = {} for tool_item in tools if isinstance(tools, list) else [tools]: if isinstance(tool_item, FunctionTool): tool_list[tool_item.name] = tool_item diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 5417c8988d..d19683d9e2 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1448,7 +1448,7 @@ async def _run() -> AgentResponse: # region Otel Helpers -def get_function_span_attributes(function: FunctionTool[Any], tool_call_id: str | None = None) -> dict[str, str]: +def get_function_span_attributes(function: FunctionTool, tool_call_id: str | None = None) -> dict[str, str]: """Get the span attributes for the given function. Args: diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index ad15667df3..89dddcba92 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -10,7 +10,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.shared.exceptions import McpError -from pydantic import AnyUrl, BaseModel, ValidationError +from pydantic import AnyUrl, BaseModel from agent_framework import ( Content, @@ -22,7 +22,6 @@ from agent_framework._mcp import ( MCPTool, _get_input_model_from_mcp_prompt, - _get_input_model_from_mcp_tool, _normalize_mcp_name, _parse_content_from_mcp, _parse_message_from_mcp, @@ -602,7 +601,7 @@ def test_get_input_model_from_mcp_tool_parametrized( - validation_check: Not used in this test (kept for compatibility) """ tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema) - schema = _get_input_model_from_mcp_tool(tool) + schema = tool.inputSchema # Verify schema is returned as-is (dict) assert isinstance(schema, dict), f"Expected dict, got {type(schema)}" @@ -630,6 +629,15 @@ def test_get_input_model_from_mcp_prompt(): assert "arg2" not in result["required"] +def test_get_input_model_from_mcp_prompt_without_arguments(): + """Test prompt schema generation when no prompt arguments are defined.""" + prompt = types.Prompt(name="empty_prompt", description="No args prompt", arguments=[]) + result = _get_input_model_from_mcp_prompt(prompt) + + assert isinstance(result, dict) + assert result == {"type": "object", "properties": {}} + + # MCPTool tests async def test_local_mcp_server_initialization(): """Test MCPTool initialization.""" diff --git a/python/packages/core/tests/core/test_schema_passthrough.py b/python/packages/core/tests/core/test_schema_passthrough.py deleted file mode 100644 index 1276e6a5a5..0000000000 --- a/python/packages/core/tests/core/test_schema_passthrough.py +++ /dev/null @@ -1,438 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Test that JSON schemas are passed through without conversion to Pydantic models.""" - -from typing import Any - -import pytest -from pydantic import BaseModel - -from agent_framework import FunctionTool, tool -from agent_framework.exceptions import ToolException - - -def test_function_tool_with_json_schema_stores_schema(): - """Test that FunctionTool stores the JSON schema as-is without conversion.""" - - json_schema = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "max_results": {"type": "integer", "default": 10}, - }, - "required": ["query"], - } - - def search_func(query: str, max_results: int = 10) -> str: - return f"Searching for: {query} (max {max_results})" - - tool_instance = FunctionTool( - name="search", - description="Search tool", - func=search_func, - input_model=json_schema, - ) - - # The stored schema should be the original JSON schema - # not a Pydantic-generated one - params = tool_instance.parameters() - - # Verify it matches the original schema structure - assert params["type"] == "object" - assert "query" in params["properties"] - assert params["properties"]["query"]["type"] == "string" - assert params["properties"]["max_results"]["default"] == 10 - - -def test_tool_decorator_with_json_schema_stores_schema(): - """Test that @tool decorator stores JSON schema as-is.""" - - json_schema = { - "type": "object", - "properties": { - "location": {"type": "string", "description": "City name"}, - "unit": {"type": "string", "enum": ["celsius", "fahrenheit"], "default": "celsius"}, - }, - "required": ["location"], - } - - @tool(name="weather", description="Get weather", schema=json_schema) - def get_weather(location: str, unit: str = "celsius") -> str: - return f"Weather in {location}: 22°{unit[0].upper()}" - - params = get_weather.parameters() - - # Should be the original schema - assert params["type"] == "object" - assert "location" in params["properties"] - assert params["properties"]["unit"]["enum"] == ["celsius", "fahrenheit"] - - -@pytest.mark.asyncio -async def test_schema_supplied_tool_invocation_without_pydantic_validation(): - """Test that schema-supplied tools skip pydantic model_validate in invoke.""" - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name"], - } - - invocation_count = 0 - - def greet(name: str, age: int | None = None) -> str: - nonlocal invocation_count - invocation_count += 1 - if age: - return f"Hello {name}, you are {age} years old" - return f"Hello {name}" - - tool_instance = FunctionTool( - name="greet", - description="Greet a person", - func=greet, - input_model=json_schema, - ) - - # Create a mock arguments object that mimics what the tool would receive - class MockArgs(BaseModel): - name: str - age: int | None = None - - args = MockArgs(name="Alice", age=30) - - # Invoke the tool - result = await tool_instance.invoke(arguments=args) - - assert invocation_count == 1 - assert "Alice" in result - assert "30" in result - - -async def test_schema_supplied_tool_invocation_rejects_missing_required_args(): - """Schema-supplied tools should still enforce required fields.""" - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name"], - } - - def greet(name: str, age: int | None = None) -> str: - return f"Hello {name}, age={age}" - - tool_instance = FunctionTool( - name="greet", - description="Greet a person", - func=greet, - input_model=json_schema, - ) - - with pytest.raises(TypeError, match="Missing required argument"): - await tool_instance.invoke(arguments={"age": 30}) - - -async def test_schema_supplied_tool_invocation_rejects_wrong_type(): - """Schema-supplied tools should run lightweight type checks.""" - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "age": {"type": "integer"}, - }, - "required": ["name"], - } - - def greet(name: str, age: int | None = None) -> str: - return f"Hello {name}, age={age}" - - tool_instance = FunctionTool( - name="greet", - description="Greet a person", - func=greet, - input_model=json_schema, - ) - - with pytest.raises(TypeError, match="Invalid type for 'age'"): - await tool_instance.invoke(arguments={"name": "Alice", "age": "30"}) - - -async def test_schema_supplied_tool_invocation_rejects_unexpected_arguments(): - """Schema-supplied tools should reject unknown fields when additionalProperties is false.""" - - json_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - }, - "required": ["name"], - "additionalProperties": False, - } - - def greet(name: str) -> str: - return f"Hello {name}" - - tool_instance = FunctionTool( - name="greet", - description="Greet a person", - func=greet, - input_model=json_schema, - ) - - with pytest.raises(TypeError, match="Unexpected argument"): - await tool_instance.invoke(arguments={"name": "Alice", "extra": True}) - - -def test_json_schema_passthrough_preserves_custom_properties(): - """Test that custom JSON schema properties are preserved (not lost in conversion).""" - - json_schema = { - "type": "object", - "properties": { - "priority": { - "type": "string", - "enum": ["low", "medium", "high"], - "description": "Priority level", - "x-custom-field": "custom-value", # Custom property - }, - }, - "required": ["priority"], - "additionalProperties": False, # Custom constraint - } - - def process(priority: str) -> str: - return f"Processing with priority: {priority}" - - tool_instance = FunctionTool( - name="process", - description="Process task", - func=process, - input_model=json_schema, - ) - - params = tool_instance.parameters() - - # Verify custom properties are preserved - assert not params.get("additionalProperties") - # Note: x-custom-field might be stripped by pydantic's model_json_schema, - # but our implementation should preserve the original schema - - -def test_schema_without_conversion_maintains_exact_structure(): - """Test that the exact JSON schema structure is maintained without Pydantic interference.""" - - # A schema that would be altered if round-tripped through Pydantic - json_schema = { - "type": "object", - "properties": { - "filters": { - "type": "array", - "items": { - "type": "object", - "properties": { - "field": {"type": "string"}, - "value": {"type": "string"}, - }, - }, - "minItems": 1, - }, - }, - "required": ["filters"], - } - - def apply_filters(filters: list[dict[str, Any]]) -> str: - return f"Applied {len(filters)} filters" - - tool_instance = FunctionTool( - name="apply_filters", - description="Apply filters", - func=apply_filters, - input_model=json_schema, - ) - - params = tool_instance.parameters() - - # Verify the structure is exactly as provided - assert params["properties"]["filters"]["minItems"] == 1 - assert params["properties"]["filters"]["items"]["type"] == "object" - - -@pytest.mark.asyncio -async def test_declaration_only_tool_with_json_schema(): - """Test declaration-only tools with JSON schema work correctly.""" - - json_schema = { - "type": "object", - "properties": { - "command": {"type": "string", "description": "Command to execute"}, - }, - "required": ["command"], - } - - tool_instance = FunctionTool( - name="execute", - description="Execute command", - func=None, # Declaration only - input_model=json_schema, - ) - - # Should be able to get parameters - params = tool_instance.parameters() - assert params["properties"]["command"]["type"] == "string" - - # Should not be invocable - class MockArgs(BaseModel): - command: str - - with pytest.raises(ToolException): - await tool_instance.invoke(arguments=MockArgs(command="test")) - - -def test_mcp_tool_schema_passthrough(): - """Test that MCP tool schemas are passed through without conversion.""" - from mcp import types - - from agent_framework import FunctionTool - from agent_framework._mcp import _get_input_model_from_mcp_tool - - # Create an MCP tool with a complex schema - mcp_schema = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "Search query"}, - "filters": { - "type": "array", - "items": { - "type": "object", - "properties": { - "field": {"type": "string"}, - "operator": {"type": "string", "enum": ["eq", "ne", "gt", "lt"]}, - "value": {"type": "string"}, - }, - }, - }, - "limit": {"type": "integer", "default": 10, "minimum": 1, "maximum": 100}, - }, - "required": ["query"], - "additionalProperties": False, - } - - mcp_tool = types.Tool( - name="search_tool", - description="Search with filters", - inputSchema=mcp_schema, - ) - - # Get the schema from MCP tool - schema = _get_input_model_from_mcp_tool(mcp_tool) - - # Verify it's the original schema - assert isinstance(schema, dict) - assert schema == mcp_schema - - # Create a FunctionTool with this schema - def search_impl(query: str, filters: list | None = None, limit: int = 10) -> str: - return f"Searched for: {query}" - - func_tool = FunctionTool( - name="search_tool", - description="Search with filters", - func=search_impl, - input_model=schema, - ) - - # Verify the FunctionTool parameters match the original schema - params = func_tool.parameters() - assert params == mcp_schema - assert not params.get("additionalProperties") - assert params["properties"]["limit"]["minimum"] == 1 - assert params["properties"]["filters"]["items"]["properties"]["operator"]["enum"] == ["eq", "ne", "gt", "lt"] - - -@pytest.mark.asyncio -async def test_function_tool_with_mcp_schema_invocation(): - """Test that FunctionTool can invoke with MCP-sourced schemas.""" - from mcp import types - - from agent_framework import FunctionTool - from agent_framework._mcp import _get_input_model_from_mcp_tool - - mcp_schema = { - "type": "object", - "properties": { - "name": {"type": "string"}, - "count": {"type": "integer"}, - }, - "required": ["name"], - } - - mcp_tool = types.Tool( - name="greet_tool", - description="Greet someone", - inputSchema=mcp_schema, - ) - - schema = _get_input_model_from_mcp_tool(mcp_tool) - - invocations = [] - - def greet_impl(name: str, count: int = 1) -> str: - invocations.append({"name": name, "count": count}) - return f"Hello {name}!" * count - - func_tool = FunctionTool( - name="greet_tool", - description="Greet someone", - func=greet_impl, - input_model=schema, - ) - - # Invoke with dict arguments - result = await func_tool.invoke(arguments={"name": "Alice", "count": 2}) - - assert len(invocations) == 1 - assert invocations[0]["name"] == "Alice" - assert invocations[0]["count"] == 2 - assert "Hello Alice!" in result - - -def test_performance_benefit_of_schema_passthrough(): - """Verify that schema passthrough avoids expensive Pydantic model creation.""" - import time - - from agent_framework import FunctionTool - - # A complex schema that would be expensive to convert - complex_schema = { - "type": "object", - "properties": { - f"field_{i}": {"type": "string", "description": f"Field {i}"} - for i in range(100) - }, - "required": [f"field_{i}" for i in range(50)], - } - - # Measure time to create FunctionTool with schema - start = time.perf_counter() - tool = FunctionTool( - name="complex_tool", - description="Complex tool", - func=lambda **kwargs: "done", - input_model=complex_schema, - ) - schema_time = time.perf_counter() - start - - # Verify schema is stored as-is - params = tool.parameters() - assert params == complex_schema - - # The schema creation should be very fast (no Pydantic model building) - # This is a smoke test - we're just verifying it doesn't error and returns quickly - assert schema_time < 1.0 # Should be nearly instant diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index dbcf7aac6f..98276f3d9c 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -108,6 +108,90 @@ def search(query: str, max_results: int = 10) -> str: assert search("hello") == "Searching for: hello (max 10)" +async def test_tool_decorator_with_json_schema_invoke_uses_mapping(): + """Test that schema-based tools can be invoked directly with mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "max_results": {"type": "integer"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str, max_results: int = 10) -> str: + return f"{query}:{max_results}" + + result = await search.invoke(arguments={"query": "hello", "max_results": 3}) + assert result == "hello:3" + + +async def test_tool_decorator_with_json_schema_invoke_missing_required(): + """Test schema-required fields are checked for mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str) -> str: + return query + + with pytest.raises(TypeError, match="Missing required argument"): + await search.invoke(arguments={}) + + +async def test_tool_decorator_with_json_schema_invoke_invalid_type(): + """Test schema type checks run for mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "max_results": {"type": "integer"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str, max_results: int = 10) -> str: + return f"{query}:{max_results}" + + with pytest.raises(TypeError, match="Invalid type for 'max_results'"): + await search.invoke(arguments={"query": "hello", "max_results": "three"}) + + +def test_tool_decorator_with_json_schema_preserves_custom_properties(): + """Test schema passthrough keeps custom JSON schema properties.""" + + json_schema = { + "type": "object", + "properties": { + "priority": { + "type": "string", + "enum": ["low", "medium", "high"], + "x-custom-field": "custom-value", + }, + }, + "required": ["priority"], + "additionalProperties": False, + } + + @tool(name="process", description="Process tool", schema=json_schema) + def process(priority: str) -> str: + return priority + + params = process.parameters() + assert not params.get("additionalProperties") + assert params["properties"]["priority"]["x-custom-field"] == "custom-value" + + def test_tool_decorator_schema_none_default(): """Test that schema=None (default) still infers from function signature.""" From 095f2e081adbc389e8c6106fb0df0faa93f540d1 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 12:06:02 +0100 Subject: [PATCH 4/6] Fix FunctionTool typing fallout in 3600 --- .../ag-ui/agent_framework_ag_ui/_client.py | 2 +- .../ag-ui/agent_framework_ag_ui/_utils.py | 6 +- .../claude/agent_framework_claude/_agent.py | 2 +- .../packages/core/agent_framework/_agents.py | 13 +- .../packages/core/agent_framework/_tools.py | 15 +- python/packages/core/tests/core/test_mcp.py | 585 +++++++++--------- .../agent_framework_github_copilot/_agent.py | 2 +- .../agent_framework_lab_tau2/_tau2_utils.py | 2 +- .../_handoff.py | 4 +- 9 files changed, 313 insertions(+), 318 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 373e6321bb..a63b6ac50c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -267,7 +267,7 @@ def _register_server_tool_placeholder(self, tool_name: str) -> None: if any(getattr(tool, "name", None) == tool_name for tool in additional_tools): return - placeholder: FunctionTool[Any] = FunctionTool( + placeholder: FunctionTool = FunctionTool( name=tool_name, description="Server-managed tool placeholder (AG-UI)", func=None, diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index abbfb88562..bfda3948ec 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -162,7 +162,7 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401 def convert_agui_tools_to_agent_framework( agui_tools: list[dict[str, Any]] | None, -) -> list[FunctionTool[Any]] | None: +) -> list[FunctionTool] | None: """Convert AG-UI tool definitions to Agent Framework FunctionTool declarations. Creates declaration-only FunctionTool instances (no executable implementation). @@ -181,13 +181,13 @@ def convert_agui_tools_to_agent_framework( if not agui_tools: return None - result: list[FunctionTool[Any]] = [] + result: list[FunctionTool] = [] for tool_def in agui_tools: # Create declaration-only FunctionTool (func=None means no implementation) # When func=None, the declaration_only property returns True, # which tells the function invocation mixin to return the function call # without executing it (so it can be sent back to the client) - func: FunctionTool[Any] = FunctionTool( + func: FunctionTool = FunctionTool( name=tool_def.get("name", ""), description=tool_def.get("description", ""), func=None, # CRITICAL: Makes declaration_only=True diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 10ef5cbf45..ad6f1b3e03 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -484,7 +484,7 @@ def _prepare_tools( return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names - def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool[Any]) -> SdkMcpTool[Any]: + def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]: """Convert a FunctionTool to an SDK MCP tool. Args: diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index e40284dfde..d11f0e2c7d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -1258,17 +1258,12 @@ async def _log(level: types.LoggingLevel, data: Any) -> None: @server.list_tools() # type: ignore async def _list_tools() -> list[types.Tool]: # type: ignore """List all tools in the agent.""" - # Get the JSON schema from the Pydantic model - schema = agent_tool.input_model.model_json_schema() + schema = agent_tool.parameters() tool = types.Tool( name=agent_tool.name, description=agent_tool.description, - inputSchema={ - "type": "object", - "properties": schema.get("properties", {}), - "required": schema.get("required", []), - }, + inputSchema=schema, ) await _log(level="debug", data=f"Agent tool: {agent_tool}") @@ -1291,7 +1286,9 @@ async def _call_tool( # type: ignore # Create an instance of the input model with the arguments try: - args_instance = agent_tool.input_model(**arguments) + args_instance: BaseModel | dict[str, Any] = ( + agent_tool.input_model(**arguments) if agent_tool.input_model is not None else arguments + ) result = await agent_tool.invoke(arguments=args_instance) except Exception as e: raise McpError( diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 029b812364..f1659cbf2b 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -26,7 +26,6 @@ Literal, TypedDict, Union, - cast, get_args, get_origin, overload, @@ -302,12 +301,13 @@ def __init__( self._instance = None # Store the instance for bound methods # Track if schema was supplied as JSON dict (for optimization) - self._schema_supplied = isinstance(input_model, Mapping) - if self._schema_supplied: + if isinstance(input_model, Mapping): + self._schema_supplied = True self._input_schema: dict[str, Any] = dict(input_model) self.input_model: type[BaseModel] | None = None else: - self.input_model = self._resolve_input_model(cast(type[BaseModel] | None, input_model)) + self._schema_supplied = False + self.input_model = self._resolve_input_model(input_model) self._input_schema = self.input_model.model_json_schema() self._cached_parameters: dict[str, Any] | None = None self.approval_mode = approval_mode or "never_require" @@ -386,8 +386,10 @@ def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[Base return create_model(f"{self.name}_input") func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if func is None: + return create_model(f"{self.name}_input") sig = inspect.signature(func) - fields = { + fields: dict[str, Any] = { pname: ( _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, param.default if param.default is not inspect.Parameter.empty else ..., @@ -776,8 +778,7 @@ def _validate_arguments_against_schema( enum_values = field_schema.get("enum") if isinstance(enum_values, list) and enum_values and field_value not in enum_values: raise TypeError( - f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} " - f"is not in {enum_values!r}" + f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}" ) schema_type = field_schema.get("type") diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 89dddcba92..8b213476aa 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -275,317 +275,318 @@ def test_prepare_message_for_mcp(): @pytest.mark.parametrize( - "test_id,input_schema,valid_data,expected_values,invalid_data,validation_check", + "test_id,input_schema", [ - # Basic types with required/optional fields - ( - "basic_types", - { - "type": "object", - "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, - "required": ["param1"], - }, - {"param1": "test", "param2": 42}, - {"param1": "test", "param2": 42}, - {"param2": 42}, # Missing required param1 - None, - ), - # Nested object - ( - "nested_object", - { - "type": "object", - "properties": { - "params": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], - } + (test_id, input_schema) + for test_id, input_schema, _, _, _, _ in [ + # Basic types with required/optional fields + ( + "basic_types", + { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + "required": ["param1"], }, - "required": ["params"], - }, - {"params": {"customer_id": 251}}, - {"params.customer_id": 251}, - {"params": {}}, # Missing required customer_id - lambda instance: isinstance(instance.params, BaseModel), - ), - # $ref resolution - ( - "ref_schema", - { - "type": "object", - "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, - "required": ["params"], - "$defs": { - "CustomerIdParam": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], - } + {"param1": "test", "param2": 42}, + {"param1": "test", "param2": 42}, + {"param2": 42}, # Missing required param1 + None, + ), + # Nested object + ( + "nested_object", + { + "type": "object", + "properties": { + "params": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + "required": ["params"], }, - }, - {"params": {"customer_id": 251}}, - {"params.customer_id": 251}, - {"params": {}}, # Missing required customer_id - lambda instance: isinstance(instance.params, BaseModel), - ), - # Array of strings (typed) - ( - "array_of_strings", - { - "type": "object", - "properties": { - "tags": { - "type": "array", - "description": "List of tags", - "items": {"type": "string"}, - } + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # $ref resolution + ( + "ref_schema", + { + "type": "object", + "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, + "required": ["params"], + "$defs": { + "CustomerIdParam": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, }, - "required": ["tags"], - }, - {"tags": ["tag1", "tag2", "tag3"]}, - {"tags": ["tag1", "tag2", "tag3"]}, - None, # No validation error test for this case - None, - ), - # Array of integers (typed) - ( - "array_of_integers", - { - "type": "object", - "properties": { - "numbers": { - "type": "array", - "description": "List of integers", - "items": {"type": "integer"}, - } + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # Array of strings (typed) + ( + "array_of_strings", + { + "type": "object", + "properties": { + "tags": { + "type": "array", + "description": "List of tags", + "items": {"type": "string"}, + } + }, + "required": ["tags"], }, - "required": ["numbers"], - }, - {"numbers": [1, 2, 3]}, - {"numbers": [1, 2, 3]}, - None, - None, - ), - # Array of objects (complex nested) - ( - "array_of_objects", - { - "type": "object", - "properties": { - "users": { - "type": "array", - "description": "List of users", - "items": { - "type": "object", - "properties": { - "id": {"type": "integer", "description": "User ID"}, - "name": {"type": "string", "description": "User name"}, - }, - "required": ["id", "name"], - }, - } + {"tags": ["tag1", "tag2", "tag3"]}, + {"tags": ["tag1", "tag2", "tag3"]}, + None, # No validation error test for this case + None, + ), + # Array of integers (typed) + ( + "array_of_integers", + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "description": "List of integers", + "items": {"type": "integer"}, + } + }, + "required": ["numbers"], }, - "required": ["users"], - }, - {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - {"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"}, - {"users": [{"id": 1}]}, # Missing required 'name' - lambda instance: all(isinstance(user, BaseModel) for user in instance.users), - ), - # Deeply nested objects (3+ levels) - ( - "deeply_nested", - { - "type": "object", - "properties": { - "query": { - "type": "object", - "properties": { - "filters": { + {"numbers": [1, 2, 3]}, + {"numbers": [1, 2, 3]}, + None, + None, + ), + # Array of objects (complex nested) + ( + "array_of_objects", + { + "type": "object", + "properties": { + "users": { + "type": "array", + "description": "List of users", + "items": { "type": "object", "properties": { - "date_range": { - "type": "object", - "properties": { - "start": {"type": "string"}, - "end": {"type": "string"}, + "id": {"type": "integer", "description": "User ID"}, + "name": {"type": "string", "description": "User name"}, + }, + "required": ["id", "name"], + }, + } + }, + "required": ["users"], + }, + {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, + {"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"}, + {"users": [{"id": 1}]}, # Missing required 'name' + lambda instance: all(isinstance(user, BaseModel) for user in instance.users), + ), + # Deeply nested objects (3+ levels) + ( + "deeply_nested", + { + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "filters": { + "type": "object", + "properties": { + "date_range": { + "type": "object", + "properties": { + "start": {"type": "string"}, + "end": {"type": "string"}, + }, + "required": ["start", "end"], }, - "required": ["start", "end"], + "categories": {"type": "array", "items": {"type": "string"}}, }, - "categories": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["date_range"], - } - }, - "required": ["filters"], - } + "required": ["date_range"], + } + }, + "required": ["filters"], + } + }, + "required": ["query"], }, - "required": ["query"], - }, - { - "query": { - "filters": { - "date_range": {"start": "2024-01-01", "end": "2024-12-31"}, - "categories": ["tech", "science"], + { + "query": { + "filters": { + "date_range": {"start": "2024-01-01", "end": "2024-12-31"}, + "categories": ["tech", "science"], + } } - } - }, - { - "query.filters.date_range.start": "2024-01-01", - "query.filters.date_range.end": "2024-12-31", - "query.filters.categories": ["tech", "science"], - }, - {"query": {"filters": {"date_range": {}}}}, # Missing required start and end - None, - ), - # Complex $ref with nested structure - ( - "ref_nested_structure", - { - "type": "object", - "properties": {"order": {"$ref": "#/$defs/OrderParams"}}, - "required": ["order"], - "$defs": { - "OrderParams": { - "type": "object", - "properties": { - "customer": {"$ref": "#/$defs/Customer"}, - "items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}}, + }, + { + "query.filters.date_range.start": "2024-01-01", + "query.filters.date_range.end": "2024-12-31", + "query.filters.categories": ["tech", "science"], + }, + {"query": {"filters": {"date_range": {}}}}, # Missing required start and end + None, + ), + # Complex $ref with nested structure + ( + "ref_nested_structure", + { + "type": "object", + "properties": {"order": {"$ref": "#/$defs/OrderParams"}}, + "required": ["order"], + "$defs": { + "OrderParams": { + "type": "object", + "properties": { + "customer": {"$ref": "#/$defs/Customer"}, + "items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}}, + }, + "required": ["customer", "items"], + }, + "Customer": { + "type": "object", + "properties": {"id": {"type": "integer"}, "email": {"type": "string"}}, + "required": ["id", "email"], + }, + "OrderItem": { + "type": "object", + "properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}}, + "required": ["product_id", "quantity"], }, - "required": ["customer", "items"], - }, - "Customer": { - "type": "object", - "properties": {"id": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["id", "email"], - }, - "OrderItem": { - "type": "object", - "properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}}, - "required": ["product_id", "quantity"], }, }, - }, - { - "order": { - "customer": {"id": 123, "email": "test@example.com"}, - "items": [{"product_id": "prod1", "quantity": 2}], - } - }, - { - "order.customer.id": 123, - "order.customer.email": "test@example.com", - "order.items[0].product_id": "prod1", - "order.items[0].quantity": 2, - }, - {"order": {"customer": {"id": 123}, "items": []}}, # Missing email - lambda instance: isinstance(instance.order.customer, BaseModel), - ), - # Mixed types (primitives, arrays, nested objects) - ( - "mixed_types", - { - "type": "object", - "properties": { - "simple_string": {"type": "string"}, - "simple_number": {"type": "integer"}, - "string_array": {"type": "array", "items": {"type": "string"}}, - "nested_config": { - "type": "object", - "properties": { - "enabled": {"type": "boolean"}, - "options": {"type": "array", "items": {"type": "string"}}, + { + "order": { + "customer": {"id": 123, "email": "test@example.com"}, + "items": [{"product_id": "prod1", "quantity": 2}], + } + }, + { + "order.customer.id": 123, + "order.customer.email": "test@example.com", + "order.items[0].product_id": "prod1", + "order.items[0].quantity": 2, + }, + {"order": {"customer": {"id": 123}, "items": []}}, # Missing email + lambda instance: isinstance(instance.order.customer, BaseModel), + ), + # Mixed types (primitives, arrays, nested objects) + ( + "mixed_types", + { + "type": "object", + "properties": { + "simple_string": {"type": "string"}, + "simple_number": {"type": "integer"}, + "string_array": {"type": "array", "items": {"type": "string"}}, + "nested_config": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "options": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["enabled"], }, - "required": ["enabled"], }, + "required": ["simple_string", "nested_config"], }, - "required": ["simple_string", "nested_config"], - }, - { - "simple_string": "test", - "simple_number": 42, - "string_array": ["a", "b"], - "nested_config": {"enabled": True, "options": ["opt1", "opt2"]}, - }, - { - "simple_string": "test", - "simple_number": 42, - "string_array": ["a", "b"], - "nested_config.enabled": True, - "nested_config.options": ["opt1", "opt2"], - }, - None, - None, - ), - # Empty schema (no properties) - ( - "empty_schema", - {"type": "object", "properties": {}}, - {}, - {}, - None, - None, - ), - # All primitive types - ( - "all_primitives", - { - "type": "object", - "properties": { - "string_field": {"type": "string"}, - "integer_field": {"type": "integer"}, - "number_field": {"type": "number"}, - "boolean_field": {"type": "boolean"}, + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config": {"enabled": True, "options": ["opt1", "opt2"]}, }, - }, - {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, - {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, - None, - None, - ), - # Edge case: unresolvable $ref (fallback to dict) - ( - "unresolvable_ref", - { - "type": "object", - "properties": {"data": {"$ref": "#/$defs/NonExistent"}}, - "$defs": {}, - }, - {"data": {"key": "value"}}, - {"data": {"key": "value"}}, - None, - None, - ), - # Edge case: array without items schema (fallback to bare list) - ( - "array_no_items", - { - "type": "object", - "properties": {"items": {"type": "array"}}, - }, - {"items": [1, "two", 3.0]}, - {"items": [1, "two", 3.0]}, - None, - None, - ), - # Edge case: object without properties (fallback to dict) - ( - "object_no_properties", - { - "type": "object", - "properties": {"config": {"type": "object"}}, - }, - {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, - {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, - None, - None, - ), + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config.enabled": True, + "nested_config.options": ["opt1", "opt2"], + }, + None, + None, + ), + # Empty schema (no properties) + ( + "empty_schema", + {"type": "object", "properties": {}}, + {}, + {}, + None, + None, + ), + # All primitive types + ( + "all_primitives", + { + "type": "object", + "properties": { + "string_field": {"type": "string"}, + "integer_field": {"type": "integer"}, + "number_field": {"type": "number"}, + "boolean_field": {"type": "boolean"}, + }, + }, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + None, + None, + ), + # Edge case: unresolvable $ref (fallback to dict) + ( + "unresolvable_ref", + { + "type": "object", + "properties": {"data": {"$ref": "#/$defs/NonExistent"}}, + "$defs": {}, + }, + {"data": {"key": "value"}}, + {"data": {"key": "value"}}, + None, + None, + ), + # Edge case: array without items schema (fallback to bare list) + ( + "array_no_items", + { + "type": "object", + "properties": {"items": {"type": "array"}}, + }, + {"items": [1, "two", 3.0]}, + {"items": [1, "two", 3.0]}, + None, + None, + ), + # Edge case: object without properties (fallback to dict) + ( + "object_no_properties", + { + "type": "object", + "properties": {"config": {"type": "object"}}, + }, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + None, + None, + ), + ] ], ) -def test_get_input_model_from_mcp_tool_parametrized( - test_id, input_schema, valid_data, expected_values, invalid_data, validation_check -): +def test_get_input_model_from_mcp_tool_parametrized(test_id: str, input_schema: dict[str, Any]) -> None: """Parametrized test for MCP tool input schema passthrough. This test verifies that MCP tool schemas are passed through as-is @@ -595,10 +596,6 @@ def test_get_input_model_from_mcp_tool_parametrized( To add a new test case, add a tuple to the parametrize decorator with: - test_id: A descriptive name for the test case - input_schema: The JSON schema (inputSchema dict) - - valid_data: Valid data (used to verify FunctionTool works with the schema) - - expected_values: Not used in this test (kept for compatibility) - - invalid_data: Not used in this test (kept for compatibility) - - validation_check: Not used in this test (kept for compatibility) """ tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema) schema = tool.inputSchema diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 42de197014..1cafc1ba85 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -499,7 +499,7 @@ def _prepare_tools( return copilot_tools - def _tool_to_copilot_tool(self, ai_func: FunctionTool[Any]) -> CopilotTool: + def _tool_to_copilot_tool(self, ai_func: FunctionTool) -> CopilotTool: """Convert an FunctionTool to a Copilot SDK tool.""" async def handler(invocation: ToolInvocation) -> ToolResult: diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 42d03393f8..75c0676cb6 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -27,7 +27,7 @@ _original_set_state = Environment.set_state -def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any]: +def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool: """Convert a tau2 Tool to a FunctionTool for agent framework compatibility. Creates a wrapper that preserves the tool's interface while ensuring diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index ea5f8bd201..e0dc3e8ea9 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -324,7 +324,7 @@ def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration existing_tools = list(default_options.get("tools") or []) existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - new_tools: list[FunctionTool[Any]] = [] + new_tools: list[FunctionTool] = [] for target in targets: handoff_tool = self._create_handoff_tool(target.target_id, target.description) if handoff_tool.name in existing_names: @@ -340,7 +340,7 @@ def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration else: default_options["tools"] = existing_tools - def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool[Any]: + def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool: """Construct the synthetic handoff tool that signals routing to `target_id`.""" tool_name = get_handoff_tool_name(target_id) doc = description or f"Handoff to the {target_id} agent." From c2e7cf0ecc28f0f134027f524fbaa3bba047ff31 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 12:38:23 +0100 Subject: [PATCH 5/6] Remove FunctionTool[Any] compatibility shim --- .../agents/ui_generator_agent.py | 10 +++---- .../packages/core/agent_framework/_tools.py | 5 ---- .../core/tests/core/test_middleware.py | 28 +++++++++---------- .../core/test_middleware_context_result.py | 13 ++++----- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py index 7f5a4b0f2c..9ea92dd24e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py @@ -5,7 +5,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, TypedDict from agent_framework import Agent, FunctionTool, SupportsChatGetResponse from agent_framework.ag_ui import AgentFrameworkAgent @@ -23,7 +23,7 @@ from agent_framework import ChatOptions # Declaration-only tools (func=None) - actual rendering happens on the client side -generate_haiku = FunctionTool[Any]( +generate_haiku = FunctionTool( name="generate_haiku", description="""Generate a haiku with image and gradient background (FRONTEND_RENDER). @@ -71,7 +71,7 @@ }, ) -create_chart = FunctionTool[Any]( +create_chart = FunctionTool( name="create_chart", description="""Create an interactive chart (FRONTEND_RENDER). @@ -99,7 +99,7 @@ }, ) -display_timeline = FunctionTool[Any]( +display_timeline = FunctionTool( name="display_timeline", description="""Display an interactive timeline (FRONTEND_RENDER). @@ -127,7 +127,7 @@ }, ) -show_comparison_table = FunctionTool[Any]( +show_comparison_table = FunctionTool( name="show_comparison_table", description="""Show a comparison table (FRONTEND_RENDER). diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index f1659cbf2b..9cb07f2970 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -237,11 +237,6 @@ class WeatherArgs(BaseModel): "_schema_supplied", } - @classmethod - def __class_getitem__(cls, item: Any) -> type[FunctionTool]: - """Support legacy type annotations like FunctionTool[Any].""" - return cls - def __init__( self, *, diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 4ac4f22f1c..6c559c40d4 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -74,7 +74,7 @@ def test_init_with_session(self, mock_agent: SupportsAgentRun) -> None: class TestFunctionInvocationContext: """Test cases for FunctionInvocationContext.""" - def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None: + def test_init_with_defaults(self, mock_function: FunctionTool) -> None: """Test FunctionInvocationContext initialization with default values.""" arguments = FunctionTestArgs(name="test") context = FunctionInvocationContext(function=mock_function, arguments=arguments) @@ -83,7 +83,7 @@ def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None: assert context.arguments == arguments assert context.metadata == {} - def test_init_with_custom_metadata(self, mock_function: FunctionTool[Any]) -> None: + def test_init_with_custom_metadata(self, mock_function: FunctionTool) -> None: """Test FunctionInvocationContext initialization with custom metadata.""" arguments = FunctionTestArgs(name="test") metadata = {"key": "value"} @@ -420,7 +420,7 @@ async def process(self, context: FunctionInvocationContext, call_next: Any) -> N await call_next() raise MiddlewareTermination - async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool) -> None: """Test pipeline execution with termination before next() raises MiddlewareTermination.""" middleware = self.PreNextTerminateFunctionMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -439,7 +439,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: # Handler should not be called when terminated before next() assert execution_order == [] - async def test_execute_with_post_next_termination(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_post_next_termination(self, mock_function: FunctionTool) -> None: """Test pipeline execution with termination after next() raises MiddlewareTermination.""" middleware = self.PostNextTerminateFunctionMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -480,7 +480,7 @@ async def test_middleware(context: FunctionInvocationContext, call_next: Callabl pipeline = FunctionMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares - async def test_execute_no_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_no_middleware(self, mock_function: FunctionTool) -> None: """Test pipeline execution with no middleware.""" pipeline = FunctionMiddlewarePipeline() arguments = FunctionTestArgs(name="test") @@ -494,7 +494,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: result = await pipeline.execute(context, final_handler) assert result == expected_result - async def test_execute_with_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_middleware(self, mock_function: FunctionTool) -> None: """Test pipeline execution with middleware.""" execution_order: list[str] = [] @@ -787,7 +787,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert context.metadata["after"] is True assert metadata_updates == ["before", "handler", "after"] - async def test_function_middleware_execution(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_execution(self, mock_function: FunctionTool) -> None: """Test class-based function middleware execution.""" metadata_updates: list[str] = [] @@ -847,7 +847,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert context.metadata["function_middleware"] is True assert execution_order == ["function_before", "handler", "function_after"] - async def test_function_function_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_function_middleware(self, mock_function: FunctionTool) -> None: """Test function-based function middleware.""" execution_order: list[str] = [] @@ -905,7 +905,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert result is not None assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] - async def test_mixed_function_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_mixed_function_middleware(self, mock_function: FunctionTool) -> None: """Test mixed class and function-based function middleware.""" execution_order: list[str] = [] @@ -1017,7 +1017,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: ] assert execution_order == expected_order - async def test_function_middleware_execution_order(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_execution_order(self, mock_function: FunctionTool) -> None: """Test that multiple function middleware execute in registration order.""" execution_order: list[str] = [] @@ -1143,7 +1143,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: result = await pipeline.execute(context, final_handler) assert result is not None - async def test_function_context_validation(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_context_validation(self, mock_function: FunctionTool) -> None: """Test that function context contains expected data.""" class ContextValidationMiddleware(FunctionMiddleware): @@ -1489,7 +1489,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: assert not handler_called assert context.result is None - async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool) -> None: """Test that when function middleware doesn't call next(), no execution happens.""" class FunctionTestArgs(BaseModel): @@ -1666,9 +1666,9 @@ def mock_agent() -> SupportsAgentRun: @pytest.fixture -def mock_function() -> FunctionTool[Any]: +def mock_function() -> FunctionTool: """Mock function for testing.""" - function = MagicMock(spec=FunctionTool[Any]) + function = MagicMock(spec=FunctionTool) function.name = "test_function" return function diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index ba6bfb9c4a..6d9eec351b 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any from unittest.mock import MagicMock import pytest @@ -103,7 +102,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: assert updates[0].text == "overridden" assert updates[1].text == " stream" - async def test_function_middleware_result_override(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_result_override(self, mock_function: FunctionTool) -> None: """Test that function middleware can override result.""" override_result = "overridden function result" @@ -252,7 +251,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert execute_result.messages[0].text == "executed response" assert handler_called - async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool) -> None: """Test that when function middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextFunctionMiddleware(FunctionMiddleware): @@ -335,7 +334,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert observed_responses[0].messages[0].text == "executed response" assert result == observed_responses[0] - async def test_function_middleware_result_observability(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_result_observability(self, mock_function: FunctionTool) -> None: """Test that middleware can observe function result after execution.""" observed_results: list[str] = [] @@ -402,7 +401,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert result is not None assert result.messages[0].text == "modified after execution" - async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool) -> None: """Test that middleware can override function result after observing execution.""" class PostExecutionOverrideMiddleware(FunctionMiddleware): @@ -444,8 +443,8 @@ def mock_agent() -> SupportsAgentRun: @pytest.fixture -def mock_function() -> FunctionTool[Any]: +def mock_function() -> FunctionTool: """Mock function for testing.""" - function = MagicMock(spec=FunctionTool[Any]) + function = MagicMock(spec=FunctionTool) function.name = "test_function" return function From 4e8f9ea0beea6855231109ac33f33a2010645c53 Mon Sep 17 00:00:00 2001 From: eavanvalkenburg Date: Fri, 13 Feb 2026 12:41:42 +0100 Subject: [PATCH 6/6] Use serializable kwargs in OTEL tool args --- .../packages/core/agent_framework/_tools.py | 38 +++++++++---------- python/packages/core/tests/core/test_tools.py | 2 +- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 9cb07f2970..4a4e30d324 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -498,29 +498,25 @@ async def invoke( return parsed attributes = get_function_span_attributes(self, tool_call_id=tool_call_id) - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] - # Filter out framework kwargs that are not JSON serializable - serializable_kwargs = { - k: v - for k, v in kwargs.items() - if k - not in { - "chat_options", - "tools", - "tool_choice", - "session", - "conversation_id", - "options", - "response_format", - } + # Filter out framework kwargs that are not JSON serializable. + serializable_kwargs = { + k: v + for k, v in kwargs.items() + if k + not in { + "chat_options", + "tools", + "tool_choice", + "session", + "conversation_id", + "options", + "response_format", } + } + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] attributes.update({ OtelAttr.TOOL_ARGUMENTS: ( - arguments.model_dump_json(ensure_ascii=False) - if arguments and hasattr(arguments, "model_dump_json") - else json.dumps(arguments, default=str, ensure_ascii=False) - if arguments and not hasattr(arguments, "model_dump_json") - else json.dumps(serializable_kwargs, default=str, ensure_ascii=False) + json.dumps(serializable_kwargs, default=str, ensure_ascii=False) if serializable_kwargs else "None" ) @@ -529,7 +525,7 @@ async def invoke( attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name logger.info(f"Function name: {self.name}") if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] - logger.debug(f"Function arguments: {kwargs}") + logger.debug(f"Function arguments: {serializable_kwargs}") start_time_stamp = perf_counter() end_time_stamp: float | None = None try: diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index 98276f3d9c..8d74dc181d 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -639,7 +639,7 @@ def pydantic_test_tool(x: int, y: int) -> int: assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call" assert span.attributes[OtelAttr.TOOL_TYPE] == "function" assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x":5,"y":10}' + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}' async def test_tool_invoke_telemetry_with_exception(span_exporter: InMemorySpanExporter):