diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 373e6321bb..a63b6ac50c 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -267,7 +267,7 @@ def _register_server_tool_placeholder(self, tool_name: str) -> None: if any(getattr(tool, "name", None) == tool_name for tool in additional_tools): return - placeholder: FunctionTool[Any] = FunctionTool( + placeholder: FunctionTool = FunctionTool( name=tool_name, description="Server-managed tool placeholder (AG-UI)", func=None, diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py index abbfb88562..bfda3948ec 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_utils.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_utils.py @@ -162,7 +162,7 @@ def make_json_safe(obj: Any) -> Any: # noqa: ANN401 def convert_agui_tools_to_agent_framework( agui_tools: list[dict[str, Any]] | None, -) -> list[FunctionTool[Any]] | None: +) -> list[FunctionTool] | None: """Convert AG-UI tool definitions to Agent Framework FunctionTool declarations. Creates declaration-only FunctionTool instances (no executable implementation). @@ -181,13 +181,13 @@ def convert_agui_tools_to_agent_framework( if not agui_tools: return None - result: list[FunctionTool[Any]] = [] + result: list[FunctionTool] = [] for tool_def in agui_tools: # Create declaration-only FunctionTool (func=None means no implementation) # When func=None, the declaration_only property returns True, # which tells the function invocation mixin to return the function call # without executing it (so it can be sent back to the client) - func: FunctionTool[Any] = FunctionTool( + func: FunctionTool = FunctionTool( name=tool_def.get("name", ""), description=tool_def.get("description", ""), func=None, # CRITICAL: Makes declaration_only=True diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py index 7f5a4b0f2c..9ea92dd24e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/ui_generator_agent.py @@ -5,7 +5,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any, TypedDict +from typing import TYPE_CHECKING, TypedDict from agent_framework import Agent, FunctionTool, SupportsChatGetResponse from agent_framework.ag_ui import AgentFrameworkAgent @@ -23,7 +23,7 @@ from agent_framework import ChatOptions # Declaration-only tools (func=None) - actual rendering happens on the client side -generate_haiku = FunctionTool[Any]( +generate_haiku = FunctionTool( name="generate_haiku", description="""Generate a haiku with image and gradient background (FRONTEND_RENDER). @@ -71,7 +71,7 @@ }, ) -create_chart = FunctionTool[Any]( +create_chart = FunctionTool( name="create_chart", description="""Create an interactive chart (FRONTEND_RENDER). @@ -99,7 +99,7 @@ }, ) -display_timeline = FunctionTool[Any]( +display_timeline = FunctionTool( name="display_timeline", description="""Display an interactive timeline (FRONTEND_RENDER). @@ -127,7 +127,7 @@ }, ) -show_comparison_table = FunctionTool[Any]( +show_comparison_table = FunctionTool( name="show_comparison_table", description="""Show a comparison table (FRONTEND_RENDER). diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 10ef5cbf45..ad6f1b3e03 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -484,7 +484,7 @@ def _prepare_tools( return create_sdk_mcp_server(name=TOOLS_MCP_SERVER_NAME, tools=sdk_tools), tool_names - def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool[Any]) -> SdkMcpTool[Any]: + def _function_tool_to_sdk_mcp_tool(self, func_tool: FunctionTool) -> SdkMcpTool[Any]: """Convert a FunctionTool to an SDK MCP tool. Args: diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 49bdd64387..d11f0e2c7d 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -439,7 +439,7 @@ def as_tool( stream_callback: Callable[[AgentResponseUpdate], None] | Callable[[AgentResponseUpdate], Awaitable[None]] | None = None, - ) -> FunctionTool[BaseModel]: + ) -> FunctionTool: """Create a FunctionTool that wraps this agent. Keyword Args: @@ -513,7 +513,7 @@ async def agent_wrapper(**kwargs: Any) -> str: # Create final text from accumulated updates return AgentResponse.from_updates(response_updates).text - agent_tool: FunctionTool[BaseModel] = FunctionTool( + agent_tool: FunctionTool = FunctionTool( name=tool_name, description=tool_description, func=agent_wrapper, @@ -1258,17 +1258,12 @@ async def _log(level: types.LoggingLevel, data: Any) -> None: @server.list_tools() # type: ignore async def _list_tools() -> list[types.Tool]: # type: ignore """List all tools in the agent.""" - # Get the JSON schema from the Pydantic model - schema = agent_tool.input_model.model_json_schema() + schema = agent_tool.parameters() tool = types.Tool( name=agent_tool.name, description=agent_tool.description, - inputSchema={ - "type": "object", - "properties": schema.get("properties", {}), - "required": schema.get("required", []), - }, + inputSchema=schema, ) await _log(level="debug", data=f"Agent tool: {agent_tool}") @@ -1291,7 +1286,9 @@ async def _call_tool( # type: ignore # Create an instance of the input model with the arguments try: - args_instance = agent_tool.input_model(**arguments) + args_instance: BaseModel | dict[str, Any] = ( + agent_tool.input_model(**arguments) if agent_tool.input_model is not None else arguments + ) result = await agent_tool.invoke(arguments=args_instance) except Exception as e: raise McpError( diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index d9a2b5579d..f9d2cd9971 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -24,11 +24,9 @@ from mcp.shared.context import RequestContext from mcp.shared.exceptions import McpError from mcp.shared.session import RequestResponder -from pydantic import BaseModel, create_model from ._tools import ( FunctionTool, - _build_pydantic_model_from_json_schema, ) from ._types import ( Content, @@ -355,11 +353,14 @@ def _prepare_message_for_mcp( return messages -def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]: - """Creates a Pydantic model from a prompt's parameters.""" +def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> dict[str, Any]: + """Get the input model from an MCP prompt. + + Returns a JSON schema dictionary for prompt arguments. + """ # Check if 'arguments' is missing or empty if not prompt.arguments: - return create_model(f"{prompt.name}_input") + return {"type": "object", "properties": {}} # Convert prompt arguments to JSON schema format properties: dict[str, Any] = {} @@ -374,13 +375,10 @@ def _get_input_model_from_mcp_prompt(prompt: types.Prompt) -> type[BaseModel]: if prompt_argument.required: required.append(prompt_argument.name) - schema = {"properties": properties, "required": required} - return _build_pydantic_model_from_json_schema(prompt.name, schema) - - -def _get_input_model_from_mcp_tool(tool: types.Tool) -> type[BaseModel]: - """Creates a Pydantic model from a tools parameters.""" - return _build_pydantic_model_from_json_schema(tool.name, tool.inputSchema) + schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + schema["required"] = required + return schema def _normalize_mcp_name(name: str) -> str: @@ -467,7 +465,7 @@ def __init__( self.session = session self.request_timeout = request_timeout self.client = client - self._functions: list[FunctionTool[Any]] = [] + self._functions: list[FunctionTool] = [] self.is_connected: bool = False self._tools_loaded: bool = False self._prompts_loaded: bool = False @@ -476,7 +474,7 @@ def __str__(self) -> str: return f"MCPTool(name={self.name}, description={self.description})" @property - def functions(self) -> list[FunctionTool[Any]]: + def functions(self) -> list[FunctionTool]: """Get the list of functions that are allowed.""" if not self.allowed_tools: return self._functions @@ -744,7 +742,7 @@ async def load_prompts(self) -> None: input_model = _get_input_model_from_mcp_prompt(prompt) approval_mode = self._determine_approval_mode(local_name) - func: FunctionTool[BaseModel] = FunctionTool( + func: FunctionTool = FunctionTool( func=partial(self.get_prompt, prompt.name), name=local_name, description=prompt.description or "", @@ -785,15 +783,14 @@ async def load_tools(self) -> None: if local_name in existing_names: continue - input_model = _get_input_model_from_mcp_tool(tool) approval_mode = self._determine_approval_mode(local_name) # Create FunctionTools out of each tool - func: FunctionTool[BaseModel] = FunctionTool( + func: FunctionTool = FunctionTool( func=partial(self.call_tool, tool.name), name=local_name, description=tool.description or "", approval_mode=approval_mode, - input_model=input_model, + input_model=tool.inputSchema, ) self._functions.append(func) existing_names.add(local_name) diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index 969ef1efc9..299da150ab 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -234,8 +234,8 @@ async def process(self, context: FunctionInvocationContext, call_next): def __init__( self, - function: FunctionTool[Any], - arguments: BaseModel, + function: FunctionTool, + arguments: BaseModel | Mapping[str, Any], metadata: Mapping[str, Any] | None = None, result: Any = None, kwargs: Mapping[str, Any] | None = None, diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 8374d18e60..4a4e30d324 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -26,7 +26,6 @@ Literal, TypedDict, Union, - cast, get_args, get_origin, overload, @@ -89,8 +88,6 @@ ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") # region Helpers -ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel) - def _parse_inputs( inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None, @@ -183,11 +180,7 @@ def _default_histogram() -> Histogram: ClassT = TypeVar("ClassT", bound="SerializationMixin") -class EmptyInputModel(BaseModel): - """An empty input model for functions with no parameters.""" - - -class FunctionTool(SerializationMixin, Generic[ArgsT]): +class FunctionTool(SerializationMixin): """A tool that wraps a Python function to make it callable by AI models. This class wraps a Python function to make it callable by AI models with automatic @@ -240,6 +233,8 @@ class WeatherArgs(BaseModel): "input_model", "_invocation_duration_histogram", "_cached_parameters", + "_input_schema", + "_schema_supplied", } def __init__( @@ -252,7 +247,7 @@ def __init__( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, func: Callable[..., Any] | None = None, - input_model: type[ArgsT] | Mapping[str, Any] | None = None, + input_model: type[BaseModel] | Mapping[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, **kwargs: Any, ) -> None: @@ -299,7 +294,16 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods - self.input_model = self._resolve_input_model(input_model) + + # Track if schema was supplied as JSON dict (for optimization) + if isinstance(input_model, Mapping): + self._schema_supplied = True + self._input_schema: dict[str, Any] = dict(input_model) + self.input_model: type[BaseModel] | None = None + else: + self._schema_supplied = False + self.input_model = self._resolve_input_model(input_model) + self._input_schema = self.input_model.model_json_schema() self._cached_parameters: dict[str, Any] | None = None self.approval_mode = approval_mode or "never_require" if max_invocations is not None and max_invocations < 1: @@ -335,7 +339,7 @@ def declaration_only(self) -> bool: return True return self.func is None - def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: + def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool: """Implement the descriptor protocol to support bound methods. When a FunctionTool is accessed as an attribute of a class instance, @@ -366,17 +370,30 @@ def __get__(self, obj: Any, objtype: type | None = None) -> FunctionTool[ArgsT]: return self - def _resolve_input_model(self, input_model: type[ArgsT] | Mapping[str, Any] | None) -> type[ArgsT]: + def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[BaseModel]: """Resolve the input model for the function.""" - if input_model is None: - if self.func is None: - return cast(type[ArgsT], EmptyInputModel) - return cast(type[ArgsT], _create_input_model_from_func(func=self.func, name=self.name)) - if inspect.isclass(input_model) and issubclass(input_model, BaseModel): - return input_model - if isinstance(input_model, Mapping): - return cast(type[ArgsT], _create_model_from_json_schema(self.name, input_model)) - raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.") + if input_model is not None: + if inspect.isclass(input_model) and issubclass(input_model, BaseModel): + return input_model + raise TypeError("input_model must be a Pydantic BaseModel subclass or a JSON schema dict.") + + if self.func is None: + return create_model(f"{self.name}_input") + + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if func is None: + return create_model(f"{self.name}_input") + sig = inspect.signature(func) + fields: dict[str, Any] = { + pname: ( + _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, + param.default if param.default is not inspect.Parameter.empty else ..., + ) + for pname, param in sig.parameters.items() + if pname not in {"self", "cls"} + and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} + } + return create_model(f"{self.name}_input", **fields) def __call__(self, *args: Any, **kwargs: Any) -> Any: """Call the wrapped function with the provided arguments.""" @@ -407,7 +424,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: async def invoke( self, *, - arguments: ArgsT | None = None, + arguments: BaseModel | Mapping[str, Any] | None = None, **kwargs: Any, ) -> str: """Run the AI function with the provided arguments as a Pydantic model. @@ -417,14 +434,14 @@ async def invoke( ``result_parser`` if one was provided. Keyword Args: - arguments: A Pydantic model instance containing the arguments for the function. + arguments: A mapping or model instance containing the arguments for the function. kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided. Returns: The parsed result as a string — either plain text or serialized JSON. Raises: - TypeError: If arguments is not an instance of the expected input model. + TypeError: If arguments is not mapping-like or fails schema checks. """ if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") @@ -436,9 +453,32 @@ async def invoke( original_kwargs = dict(kwargs) tool_call_id = original_kwargs.pop("tool_call_id", None) if arguments is not None: - if not isinstance(arguments, self.input_model): - raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") - kwargs = arguments.model_dump(exclude_none=True) + try: + if isinstance(arguments, Mapping): + parsed_arguments = dict(arguments) + if self.input_model is not None and not self._schema_supplied: + parsed_arguments = self.input_model.model_validate(parsed_arguments).model_dump( + exclude_none=True + ) + elif isinstance(arguments, BaseModel): + if ( + self.input_model is not None + and not self._schema_supplied + and not isinstance(arguments, self.input_model) + ): + raise TypeError(f"Expected {self.input_model.__name__}, got {type(arguments).__name__}") + parsed_arguments = arguments.model_dump(exclude_none=True) + else: + raise TypeError( + f"Expected mapping-like arguments for tool '{self.name}', got {type(arguments).__name__}" + ) + except ValidationError as exc: + raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc + kwargs = _validate_arguments_against_schema( + arguments=parsed_arguments, + schema=self.parameters(), + tool_name=self.name, + ) if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: kwargs.update(original_kwargs) else: @@ -458,34 +498,34 @@ async def invoke( return parsed attributes = get_function_span_attributes(self, tool_call_id=tool_call_id) - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] - # Filter out framework kwargs that are not JSON serializable - serializable_kwargs = { - k: v - for k, v in kwargs.items() - if k - not in { - "chat_options", - "tools", - "tool_choice", - "session", - "conversation_id", - "options", - "response_format", - } + # Filter out framework kwargs that are not JSON serializable. + serializable_kwargs = { + k: v + for k, v in kwargs.items() + if k + not in { + "chat_options", + "tools", + "tool_choice", + "session", + "conversation_id", + "options", + "response_format", } + } + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] attributes.update({ - OtelAttr.TOOL_ARGUMENTS: arguments.model_dump_json(ensure_ascii=False) - if arguments - else json.dumps(serializable_kwargs, default=str, ensure_ascii=False) - if serializable_kwargs - else "None" + OtelAttr.TOOL_ARGUMENTS: ( + json.dumps(serializable_kwargs, default=str, ensure_ascii=False) + if serializable_kwargs + else "None" + ) }) with get_function_span(attributes=attributes) as span: attributes[OtelAttr.MEASUREMENT_FUNCTION_TAG_NAME] = self.name logger.info(f"Function name: {self.name}") if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] - logger.debug(f"Function arguments: {kwargs}") + logger.debug(f"Function arguments: {serializable_kwargs}") start_time_stamp = perf_counter() end_time_stamp: float | None = None try: @@ -523,7 +563,7 @@ def parameters(self) -> dict[str, Any]: The result is cached after the first call for performance. """ if self._cached_parameters is None: - self._cached_parameters = self.input_model.model_json_schema() + self._cached_parameters = self._input_schema return self._cached_parameters @staticmethod @@ -677,23 +717,79 @@ def _parse_annotation(annotation: Any) -> Any: return annotation -def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[BaseModel]: - """Create a Pydantic model from a function's signature.""" - # Unwrap FunctionTool objects to get the underlying function - if isinstance(func, FunctionTool): - func = func.func # type: ignore[assignment] +def _matches_json_schema_type(value: Any, schema_type: str) -> bool: + """Check a value against a simple JSON schema primitive type.""" + match schema_type: + case "string": + return isinstance(value, str) + case "integer": + return isinstance(value, int) and not isinstance(value, bool) + case "number": + return (isinstance(value, int | float)) and not isinstance(value, bool) + case "boolean": + return isinstance(value, bool) + case "array": + return isinstance(value, list) + case "object": + return isinstance(value, dict) + case "null": + return value is None + case _: + return True - sig = inspect.signature(func) - fields = { - pname: ( - _parse_annotation(param.annotation) if param.annotation is not inspect.Parameter.empty else str, - param.default if param.default is not inspect.Parameter.empty else ..., - ) - for pname, param in sig.parameters.items() - if pname not in {"self", "cls"} - and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} - } - return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return] + +def _validate_arguments_against_schema( + *, + arguments: Mapping[str, Any], + schema: Mapping[str, Any], + tool_name: str, +) -> dict[str, Any]: + """Run lightweight argument checks for schema-supplied tools.""" + parsed_arguments = dict(arguments) + + required_raw = schema.get("required", []) + required_fields = [field for field in required_raw if isinstance(field, str)] + missing_fields = [field for field in required_fields if field not in parsed_arguments] + if missing_fields: + raise TypeError(f"Missing required argument(s) for '{tool_name}': {', '.join(sorted(missing_fields))}") + + properties_raw = schema.get("properties") + properties = properties_raw if isinstance(properties_raw, Mapping) else {} + + if schema.get("additionalProperties") is False: + unexpected_fields = sorted(field for field in parsed_arguments if field not in properties) + if unexpected_fields: + raise TypeError(f"Unexpected argument(s) for '{tool_name}': {', '.join(unexpected_fields)}") + + for field_name, field_value in parsed_arguments.items(): + field_schema = properties.get(field_name) + if not isinstance(field_schema, Mapping): + continue + + enum_values = field_schema.get("enum") + if isinstance(enum_values, list) and enum_values and field_value not in enum_values: + raise TypeError( + f"Invalid value for '{field_name}' in '{tool_name}': {field_value!r} is not in {enum_values!r}" + ) + + schema_type = field_schema.get("type") + if isinstance(schema_type, str): + if not _matches_json_schema_type(field_value, schema_type): + raise TypeError( + f"Invalid type for '{field_name}' in '{tool_name}': " + f"expected {schema_type}, got {type(field_value).__name__}" + ) + continue + + if isinstance(schema_type, list): + allowed_types = [item for item in schema_type if isinstance(item, str)] + if allowed_types and not any(_matches_json_schema_type(field_value, item) for item in allowed_types): + raise TypeError( + f"Invalid type for '{field_name}' in '{tool_name}': expected one of " + f"{allowed_types}, got {type(field_value).__name__}" + ) + + return parsed_arguments # Map JSON Schema types to Pydantic types @@ -942,7 +1038,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> FunctionTool[Any]: ... +) -> FunctionTool: ... @overload @@ -957,7 +1053,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> Callable[[Callable[..., Any]], FunctionTool[Any]]: ... +) -> Callable[[Callable[..., Any]], FunctionTool]: ... def tool( @@ -971,7 +1067,7 @@ def tool( max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, result_parser: Callable[[Any], str] | None = None, -) -> FunctionTool[Any] | Callable[[Callable[..., Any]], FunctionTool[Any]]: +) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]: """Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically. This decorator creates a Pydantic model from the function's signature, @@ -1095,12 +1191,12 @@ def get_weather(location: str, unit: str = "celsius") -> str: """ - def decorator(func: Callable[..., Any]) -> FunctionTool[Any]: + def decorator(func: Callable[..., Any]) -> FunctionTool: @wraps(func) - def wrapper(f: Callable[..., Any]) -> FunctionTool[Any]: + def wrapper(f: Callable[..., Any]) -> FunctionTool: tool_name: str = name or getattr(f, "__name__", "unknown_function") # type: ignore[assignment] tool_desc: str = description or (f.__doc__ or "") - return FunctionTool[Any]( + return FunctionTool( name=tool_name, description=tool_desc, approval_mode=approval_mode, @@ -1193,7 +1289,7 @@ async def _auto_invoke_function( custom_args: dict[str, Any] | None = None, *, config: FunctionInvocationConfiguration, - tool_map: dict[str, FunctionTool[BaseModel]], + tool_map: dict[str, FunctionTool], sequence_index: int | None = None, request_index: int | None = None, middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline @@ -1225,7 +1321,7 @@ async def _auto_invoke_function( # this function is called. This function only handles the actual execution of approved, # non-declaration-only functions. - tool: FunctionTool[BaseModel] | None = None + tool: FunctionTool | None = None if function_call_content.type == "function_call": tool = tool_map.get(function_call_content.name) # type: ignore[arg-type] # Tool should exist because _try_execute_function_calls validates this @@ -1258,8 +1354,16 @@ async def _auto_invoke_function( if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } try: - args = tool.input_model.model_validate(parsed_args) - except ValidationError as exc: + if not tool._schema_supplied and tool.input_model is not None: + args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) + else: + args = dict(parsed_args) + args = _validate_arguments_against_schema( + arguments=args, + schema=tool.parameters(), + tool_name=tool.name, + ) + except (TypeError, ValidationError) as exc: message = "Error: Argument parsing failed." if config["include_detailed_errors"]: message = f"{message} Exception: {exc}" @@ -1340,8 +1444,8 @@ def _get_tool_map( | Callable[..., Any] | MutableMapping[str, Any] | Sequence[FunctionTool | Callable[..., Any] | MutableMapping[str, Any]], -) -> dict[str, FunctionTool[Any]]: - tool_list: dict[str, FunctionTool[Any]] = {} +) -> dict[str, FunctionTool]: + tool_list: dict[str, FunctionTool] = {} for tool_item in tools if isinstance(tools, list) else [tools]: if isinstance(tool_item, FunctionTool): tool_list[tool_item.name] = tool_item diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 5417c8988d..d19683d9e2 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -1448,7 +1448,7 @@ async def _run() -> AgentResponse: # region Otel Helpers -def get_function_span_attributes(function: FunctionTool[Any], tool_call_id: str | None = None) -> dict[str, str]: +def get_function_span_attributes(function: FunctionTool, tool_call_id: str | None = None) -> dict[str, str]: """Get the span attributes for the given function. Args: diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 38cb243412..8b213476aa 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -10,7 +10,7 @@ from mcp import types from mcp.client.session import ClientSession from mcp.shared.exceptions import McpError -from pydantic import AnyUrl, BaseModel, ValidationError +from pydantic import AnyUrl, BaseModel from agent_framework import ( Content, @@ -22,7 +22,6 @@ from agent_framework._mcp import ( MCPTool, _get_input_model_from_mcp_prompt, - _get_input_model_from_mcp_tool, _normalize_mcp_name, _parse_content_from_mcp, _parse_message_from_mcp, @@ -276,363 +275,338 @@ def test_prepare_message_for_mcp(): @pytest.mark.parametrize( - "test_id,input_schema,valid_data,expected_values,invalid_data,validation_check", + "test_id,input_schema", [ - # Basic types with required/optional fields - ( - "basic_types", - { - "type": "object", - "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, - "required": ["param1"], - }, - {"param1": "test", "param2": 42}, - {"param1": "test", "param2": 42}, - {"param2": 42}, # Missing required param1 - None, - ), - # Nested object - ( - "nested_object", - { - "type": "object", - "properties": { - "params": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], - } + (test_id, input_schema) + for test_id, input_schema, _, _, _, _ in [ + # Basic types with required/optional fields + ( + "basic_types", + { + "type": "object", + "properties": {"param1": {"type": "string"}, "param2": {"type": "number"}}, + "required": ["param1"], }, - "required": ["params"], - }, - {"params": {"customer_id": 251}}, - {"params.customer_id": 251}, - {"params": {}}, # Missing required customer_id - lambda instance: isinstance(instance.params, BaseModel), - ), - # $ref resolution - ( - "ref_schema", - { - "type": "object", - "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, - "required": ["params"], - "$defs": { - "CustomerIdParam": { - "type": "object", - "properties": {"customer_id": {"type": "integer"}}, - "required": ["customer_id"], - } + {"param1": "test", "param2": 42}, + {"param1": "test", "param2": 42}, + {"param2": 42}, # Missing required param1 + None, + ), + # Nested object + ( + "nested_object", + { + "type": "object", + "properties": { + "params": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, + "required": ["params"], }, - }, - {"params": {"customer_id": 251}}, - {"params.customer_id": 251}, - {"params": {}}, # Missing required customer_id - lambda instance: isinstance(instance.params, BaseModel), - ), - # Array of strings (typed) - ( - "array_of_strings", - { - "type": "object", - "properties": { - "tags": { - "type": "array", - "description": "List of tags", - "items": {"type": "string"}, - } + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # $ref resolution + ( + "ref_schema", + { + "type": "object", + "properties": {"params": {"$ref": "#/$defs/CustomerIdParam"}}, + "required": ["params"], + "$defs": { + "CustomerIdParam": { + "type": "object", + "properties": {"customer_id": {"type": "integer"}}, + "required": ["customer_id"], + } + }, }, - "required": ["tags"], - }, - {"tags": ["tag1", "tag2", "tag3"]}, - {"tags": ["tag1", "tag2", "tag3"]}, - None, # No validation error test for this case - None, - ), - # Array of integers (typed) - ( - "array_of_integers", - { - "type": "object", - "properties": { - "numbers": { - "type": "array", - "description": "List of integers", - "items": {"type": "integer"}, - } + {"params": {"customer_id": 251}}, + {"params.customer_id": 251}, + {"params": {}}, # Missing required customer_id + lambda instance: isinstance(instance.params, BaseModel), + ), + # Array of strings (typed) + ( + "array_of_strings", + { + "type": "object", + "properties": { + "tags": { + "type": "array", + "description": "List of tags", + "items": {"type": "string"}, + } + }, + "required": ["tags"], }, - "required": ["numbers"], - }, - {"numbers": [1, 2, 3]}, - {"numbers": [1, 2, 3]}, - None, - None, - ), - # Array of objects (complex nested) - ( - "array_of_objects", - { - "type": "object", - "properties": { - "users": { - "type": "array", - "description": "List of users", - "items": { - "type": "object", - "properties": { - "id": {"type": "integer", "description": "User ID"}, - "name": {"type": "string", "description": "User name"}, - }, - "required": ["id", "name"], - }, - } + {"tags": ["tag1", "tag2", "tag3"]}, + {"tags": ["tag1", "tag2", "tag3"]}, + None, # No validation error test for this case + None, + ), + # Array of integers (typed) + ( + "array_of_integers", + { + "type": "object", + "properties": { + "numbers": { + "type": "array", + "description": "List of integers", + "items": {"type": "integer"}, + } + }, + "required": ["numbers"], }, - "required": ["users"], - }, - {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - {"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"}, - {"users": [{"id": 1}]}, # Missing required 'name' - lambda instance: all(isinstance(user, BaseModel) for user in instance.users), - ), - # Deeply nested objects (3+ levels) - ( - "deeply_nested", - { - "type": "object", - "properties": { - "query": { - "type": "object", - "properties": { - "filters": { + {"numbers": [1, 2, 3]}, + {"numbers": [1, 2, 3]}, + None, + None, + ), + # Array of objects (complex nested) + ( + "array_of_objects", + { + "type": "object", + "properties": { + "users": { + "type": "array", + "description": "List of users", + "items": { "type": "object", "properties": { - "date_range": { - "type": "object", - "properties": { - "start": {"type": "string"}, - "end": {"type": "string"}, + "id": {"type": "integer", "description": "User ID"}, + "name": {"type": "string", "description": "User name"}, + }, + "required": ["id", "name"], + }, + } + }, + "required": ["users"], + }, + {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, + {"users[0].id": 1, "users[0].name": "Alice", "users[1].id": 2, "users[1].name": "Bob"}, + {"users": [{"id": 1}]}, # Missing required 'name' + lambda instance: all(isinstance(user, BaseModel) for user in instance.users), + ), + # Deeply nested objects (3+ levels) + ( + "deeply_nested", + { + "type": "object", + "properties": { + "query": { + "type": "object", + "properties": { + "filters": { + "type": "object", + "properties": { + "date_range": { + "type": "object", + "properties": { + "start": {"type": "string"}, + "end": {"type": "string"}, + }, + "required": ["start", "end"], }, - "required": ["start", "end"], + "categories": {"type": "array", "items": {"type": "string"}}, }, - "categories": {"type": "array", "items": {"type": "string"}}, - }, - "required": ["date_range"], - } - }, - "required": ["filters"], - } + "required": ["date_range"], + } + }, + "required": ["filters"], + } + }, + "required": ["query"], }, - "required": ["query"], - }, - { - "query": { - "filters": { - "date_range": {"start": "2024-01-01", "end": "2024-12-31"}, - "categories": ["tech", "science"], + { + "query": { + "filters": { + "date_range": {"start": "2024-01-01", "end": "2024-12-31"}, + "categories": ["tech", "science"], + } } - } - }, - { - "query.filters.date_range.start": "2024-01-01", - "query.filters.date_range.end": "2024-12-31", - "query.filters.categories": ["tech", "science"], - }, - {"query": {"filters": {"date_range": {}}}}, # Missing required start and end - None, - ), - # Complex $ref with nested structure - ( - "ref_nested_structure", - { - "type": "object", - "properties": {"order": {"$ref": "#/$defs/OrderParams"}}, - "required": ["order"], - "$defs": { - "OrderParams": { - "type": "object", - "properties": { - "customer": {"$ref": "#/$defs/Customer"}, - "items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}}, + }, + { + "query.filters.date_range.start": "2024-01-01", + "query.filters.date_range.end": "2024-12-31", + "query.filters.categories": ["tech", "science"], + }, + {"query": {"filters": {"date_range": {}}}}, # Missing required start and end + None, + ), + # Complex $ref with nested structure + ( + "ref_nested_structure", + { + "type": "object", + "properties": {"order": {"$ref": "#/$defs/OrderParams"}}, + "required": ["order"], + "$defs": { + "OrderParams": { + "type": "object", + "properties": { + "customer": {"$ref": "#/$defs/Customer"}, + "items": {"type": "array", "items": {"$ref": "#/$defs/OrderItem"}}, + }, + "required": ["customer", "items"], + }, + "Customer": { + "type": "object", + "properties": {"id": {"type": "integer"}, "email": {"type": "string"}}, + "required": ["id", "email"], + }, + "OrderItem": { + "type": "object", + "properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}}, + "required": ["product_id", "quantity"], }, - "required": ["customer", "items"], - }, - "Customer": { - "type": "object", - "properties": {"id": {"type": "integer"}, "email": {"type": "string"}}, - "required": ["id", "email"], - }, - "OrderItem": { - "type": "object", - "properties": {"product_id": {"type": "string"}, "quantity": {"type": "integer"}}, - "required": ["product_id", "quantity"], }, }, - }, - { - "order": { - "customer": {"id": 123, "email": "test@example.com"}, - "items": [{"product_id": "prod1", "quantity": 2}], - } - }, - { - "order.customer.id": 123, - "order.customer.email": "test@example.com", - "order.items[0].product_id": "prod1", - "order.items[0].quantity": 2, - }, - {"order": {"customer": {"id": 123}, "items": []}}, # Missing email - lambda instance: isinstance(instance.order.customer, BaseModel), - ), - # Mixed types (primitives, arrays, nested objects) - ( - "mixed_types", - { - "type": "object", - "properties": { - "simple_string": {"type": "string"}, - "simple_number": {"type": "integer"}, - "string_array": {"type": "array", "items": {"type": "string"}}, - "nested_config": { - "type": "object", - "properties": { - "enabled": {"type": "boolean"}, - "options": {"type": "array", "items": {"type": "string"}}, + { + "order": { + "customer": {"id": 123, "email": "test@example.com"}, + "items": [{"product_id": "prod1", "quantity": 2}], + } + }, + { + "order.customer.id": 123, + "order.customer.email": "test@example.com", + "order.items[0].product_id": "prod1", + "order.items[0].quantity": 2, + }, + {"order": {"customer": {"id": 123}, "items": []}}, # Missing email + lambda instance: isinstance(instance.order.customer, BaseModel), + ), + # Mixed types (primitives, arrays, nested objects) + ( + "mixed_types", + { + "type": "object", + "properties": { + "simple_string": {"type": "string"}, + "simple_number": {"type": "integer"}, + "string_array": {"type": "array", "items": {"type": "string"}}, + "nested_config": { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + "options": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["enabled"], }, - "required": ["enabled"], }, + "required": ["simple_string", "nested_config"], }, - "required": ["simple_string", "nested_config"], - }, - { - "simple_string": "test", - "simple_number": 42, - "string_array": ["a", "b"], - "nested_config": {"enabled": True, "options": ["opt1", "opt2"]}, - }, - { - "simple_string": "test", - "simple_number": 42, - "string_array": ["a", "b"], - "nested_config.enabled": True, - "nested_config.options": ["opt1", "opt2"], - }, - None, - None, - ), - # Empty schema (no properties) - ( - "empty_schema", - {"type": "object", "properties": {}}, - {}, - {}, - None, - None, - ), - # All primitive types - ( - "all_primitives", - { - "type": "object", - "properties": { - "string_field": {"type": "string"}, - "integer_field": {"type": "integer"}, - "number_field": {"type": "number"}, - "boolean_field": {"type": "boolean"}, + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config": {"enabled": True, "options": ["opt1", "opt2"]}, }, - }, - {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, - {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, - None, - None, - ), - # Edge case: unresolvable $ref (fallback to dict) - ( - "unresolvable_ref", - { - "type": "object", - "properties": {"data": {"$ref": "#/$defs/NonExistent"}}, - "$defs": {}, - }, - {"data": {"key": "value"}}, - {"data": {"key": "value"}}, - None, - None, - ), - # Edge case: array without items schema (fallback to bare list) - ( - "array_no_items", - { - "type": "object", - "properties": {"items": {"type": "array"}}, - }, - {"items": [1, "two", 3.0]}, - {"items": [1, "two", 3.0]}, - None, - None, - ), - # Edge case: object without properties (fallback to dict) - ( - "object_no_properties", - { - "type": "object", - "properties": {"config": {"type": "object"}}, - }, - {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, - {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, - None, - None, - ), + { + "simple_string": "test", + "simple_number": 42, + "string_array": ["a", "b"], + "nested_config.enabled": True, + "nested_config.options": ["opt1", "opt2"], + }, + None, + None, + ), + # Empty schema (no properties) + ( + "empty_schema", + {"type": "object", "properties": {}}, + {}, + {}, + None, + None, + ), + # All primitive types + ( + "all_primitives", + { + "type": "object", + "properties": { + "string_field": {"type": "string"}, + "integer_field": {"type": "integer"}, + "number_field": {"type": "number"}, + "boolean_field": {"type": "boolean"}, + }, + }, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + {"string_field": "test", "integer_field": 42, "number_field": 3.14, "boolean_field": True}, + None, + None, + ), + # Edge case: unresolvable $ref (fallback to dict) + ( + "unresolvable_ref", + { + "type": "object", + "properties": {"data": {"$ref": "#/$defs/NonExistent"}}, + "$defs": {}, + }, + {"data": {"key": "value"}}, + {"data": {"key": "value"}}, + None, + None, + ), + # Edge case: array without items schema (fallback to bare list) + ( + "array_no_items", + { + "type": "object", + "properties": {"items": {"type": "array"}}, + }, + {"items": [1, "two", 3.0]}, + {"items": [1, "two", 3.0]}, + None, + None, + ), + # Edge case: object without properties (fallback to dict) + ( + "object_no_properties", + { + "type": "object", + "properties": {"config": {"type": "object"}}, + }, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + {"config": {"arbitrary": "data", "nested": {"key": "value"}}}, + None, + None, + ), + ] ], ) -def test_get_input_model_from_mcp_tool_parametrized( - test_id, input_schema, valid_data, expected_values, invalid_data, validation_check -): - """Parametrized test for JSON schema to Pydantic model conversion. - - This test covers various edge cases including: - - Basic types with required/optional fields - - Nested objects - - $ref resolution - - Typed arrays (strings, integers, objects) - - Deeply nested structures - - Complex $ref with nested structures - - Mixed types +def test_get_input_model_from_mcp_tool_parametrized(test_id: str, input_schema: dict[str, Any]) -> None: + """Parametrized test for MCP tool input schema passthrough. + + This test verifies that MCP tool schemas are passed through as-is + without Pydantic conversion, which improves performance and preserves + the original schema structure. To add a new test case, add a tuple to the parametrize decorator with: - test_id: A descriptive name for the test case - input_schema: The JSON schema (inputSchema dict) - - valid_data: Valid data to instantiate the model - - expected_values: Dict of expected values (supports dot notation for nested access) - - invalid_data: Invalid data to test validation errors (None to skip) - - validation_check: Optional callable to perform additional validation checks """ tool = types.Tool(name="test_tool", description="A test tool", inputSchema=input_schema) - model = _get_input_model_from_mcp_tool(tool) - - # Test valid data - instance = model(**valid_data) + schema = tool.inputSchema - # Check expected values - for field_path, expected_value in expected_values.items(): - # Support dot notation and array indexing for nested access - current = instance - parts = field_path.replace("]", "").replace("[", ".").split(".") - for part in parts: - current = current[int(part)] if part.isdigit() else getattr(current, part) - assert current == expected_value, f"Field {field_path} = {current}, expected {expected_value}" - - # Run additional validation checks if provided - if validation_check: - assert validation_check(instance), f"Validation check failed for {test_id}" - - # Test invalid data if provided - if invalid_data is not None: - with pytest.raises(ValidationError): - model(**invalid_data) + # Verify schema is returned as-is (dict) + assert isinstance(schema, dict), f"Expected dict, got {type(schema)}" + assert schema == input_schema, "Schema should be passed through unchanged" def test_get_input_model_from_mcp_prompt(): - """Test creation of input model from MCP prompt.""" + """Test creation of input schema from MCP prompt.""" prompt = types.Prompt( name="test_prompt", description="A test prompt", @@ -641,16 +615,24 @@ def test_get_input_model_from_mcp_prompt(): types.PromptArgument(name="arg2", description="Second argument", required=False), ], ) - model = _get_input_model_from_mcp_prompt(prompt) + result = _get_input_model_from_mcp_prompt(prompt) + + # Should return a dict (schema) + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + assert result["type"] == "object" + assert "arg1" in result["properties"] + assert "arg2" in result["properties"] + assert "arg1" in result["required"] + assert "arg2" not in result["required"] + - # Create an instance to verify the model works - instance = model(arg1="test", arg2="optional") - assert instance.arg1 == "test" - assert instance.arg2 == "optional" +def test_get_input_model_from_mcp_prompt_without_arguments(): + """Test prompt schema generation when no prompt arguments are defined.""" + prompt = types.Prompt(name="empty_prompt", description="No args prompt", arguments=[]) + result = _get_input_model_from_mcp_prompt(prompt) - # Test validation - with pytest.raises(ValidationError): # Missing required arg1 - model(arg2="optional") + assert isinstance(result, dict) + assert result == {"type": "object", "properties": {}} # MCPTool tests diff --git a/python/packages/core/tests/core/test_middleware.py b/python/packages/core/tests/core/test_middleware.py index 4ac4f22f1c..6c559c40d4 100644 --- a/python/packages/core/tests/core/test_middleware.py +++ b/python/packages/core/tests/core/test_middleware.py @@ -74,7 +74,7 @@ def test_init_with_session(self, mock_agent: SupportsAgentRun) -> None: class TestFunctionInvocationContext: """Test cases for FunctionInvocationContext.""" - def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None: + def test_init_with_defaults(self, mock_function: FunctionTool) -> None: """Test FunctionInvocationContext initialization with default values.""" arguments = FunctionTestArgs(name="test") context = FunctionInvocationContext(function=mock_function, arguments=arguments) @@ -83,7 +83,7 @@ def test_init_with_defaults(self, mock_function: FunctionTool[Any]) -> None: assert context.arguments == arguments assert context.metadata == {} - def test_init_with_custom_metadata(self, mock_function: FunctionTool[Any]) -> None: + def test_init_with_custom_metadata(self, mock_function: FunctionTool) -> None: """Test FunctionInvocationContext initialization with custom metadata.""" arguments = FunctionTestArgs(name="test") metadata = {"key": "value"} @@ -420,7 +420,7 @@ async def process(self, context: FunctionInvocationContext, call_next: Any) -> N await call_next() raise MiddlewareTermination - async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_pre_next_termination(self, mock_function: FunctionTool) -> None: """Test pipeline execution with termination before next() raises MiddlewareTermination.""" middleware = self.PreNextTerminateFunctionMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -439,7 +439,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: # Handler should not be called when terminated before next() assert execution_order == [] - async def test_execute_with_post_next_termination(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_post_next_termination(self, mock_function: FunctionTool) -> None: """Test pipeline execution with termination after next() raises MiddlewareTermination.""" middleware = self.PostNextTerminateFunctionMiddleware() pipeline = FunctionMiddlewarePipeline(middleware) @@ -480,7 +480,7 @@ async def test_middleware(context: FunctionInvocationContext, call_next: Callabl pipeline = FunctionMiddlewarePipeline(test_middleware) assert pipeline.has_middlewares - async def test_execute_no_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_no_middleware(self, mock_function: FunctionTool) -> None: """Test pipeline execution with no middleware.""" pipeline = FunctionMiddlewarePipeline() arguments = FunctionTestArgs(name="test") @@ -494,7 +494,7 @@ async def final_handler(ctx: FunctionInvocationContext) -> str: result = await pipeline.execute(context, final_handler) assert result == expected_result - async def test_execute_with_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_execute_with_middleware(self, mock_function: FunctionTool) -> None: """Test pipeline execution with middleware.""" execution_order: list[str] = [] @@ -787,7 +787,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert context.metadata["after"] is True assert metadata_updates == ["before", "handler", "after"] - async def test_function_middleware_execution(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_execution(self, mock_function: FunctionTool) -> None: """Test class-based function middleware execution.""" metadata_updates: list[str] = [] @@ -847,7 +847,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert context.metadata["function_middleware"] is True assert execution_order == ["function_before", "handler", "function_after"] - async def test_function_function_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_function_middleware(self, mock_function: FunctionTool) -> None: """Test function-based function middleware.""" execution_order: list[str] = [] @@ -905,7 +905,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert result is not None assert execution_order == ["class_before", "function_before", "handler", "function_after", "class_after"] - async def test_mixed_function_middleware(self, mock_function: FunctionTool[Any]) -> None: + async def test_mixed_function_middleware(self, mock_function: FunctionTool) -> None: """Test mixed class and function-based function middleware.""" execution_order: list[str] = [] @@ -1017,7 +1017,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: ] assert execution_order == expected_order - async def test_function_middleware_execution_order(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_execution_order(self, mock_function: FunctionTool) -> None: """Test that multiple function middleware execute in registration order.""" execution_order: list[str] = [] @@ -1143,7 +1143,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: result = await pipeline.execute(context, final_handler) assert result is not None - async def test_function_context_validation(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_context_validation(self, mock_function: FunctionTool) -> None: """Test that function context contains expected data.""" class ContextValidationMiddleware(FunctionMiddleware): @@ -1489,7 +1489,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: assert not handler_called assert context.result is None - async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_no_next_no_execution(self, mock_function: FunctionTool) -> None: """Test that when function middleware doesn't call next(), no execution happens.""" class FunctionTestArgs(BaseModel): @@ -1666,9 +1666,9 @@ def mock_agent() -> SupportsAgentRun: @pytest.fixture -def mock_function() -> FunctionTool[Any]: +def mock_function() -> FunctionTool: """Mock function for testing.""" - function = MagicMock(spec=FunctionTool[Any]) + function = MagicMock(spec=FunctionTool) function.name = "test_function" return function diff --git a/python/packages/core/tests/core/test_middleware_context_result.py b/python/packages/core/tests/core/test_middleware_context_result.py index ba6bfb9c4a..6d9eec351b 100644 --- a/python/packages/core/tests/core/test_middleware_context_result.py +++ b/python/packages/core/tests/core/test_middleware_context_result.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any from unittest.mock import MagicMock import pytest @@ -103,7 +102,7 @@ async def _stream() -> AsyncIterable[AgentResponseUpdate]: assert updates[0].text == "overridden" assert updates[1].text == " stream" - async def test_function_middleware_result_override(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_result_override(self, mock_function: FunctionTool) -> None: """Test that function middleware can override result.""" override_result = "overridden function result" @@ -252,7 +251,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert execute_result.messages[0].text == "executed response" assert handler_called - async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_conditional_no_next(self, mock_function: FunctionTool) -> None: """Test that when function middleware conditionally doesn't call next(), no execution happens.""" class ConditionalNoNextFunctionMiddleware(FunctionMiddleware): @@ -335,7 +334,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert observed_responses[0].messages[0].text == "executed response" assert result == observed_responses[0] - async def test_function_middleware_result_observability(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_result_observability(self, mock_function: FunctionTool) -> None: """Test that middleware can observe function result after execution.""" observed_results: list[str] = [] @@ -402,7 +401,7 @@ async def final_handler(ctx: AgentContext) -> AgentResponse: assert result is not None assert result.messages[0].text == "modified after execution" - async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool[Any]) -> None: + async def test_function_middleware_post_execution_override(self, mock_function: FunctionTool) -> None: """Test that middleware can override function result after observing execution.""" class PostExecutionOverrideMiddleware(FunctionMiddleware): @@ -444,8 +443,8 @@ def mock_agent() -> SupportsAgentRun: @pytest.fixture -def mock_function() -> FunctionTool[Any]: +def mock_function() -> FunctionTool: """Mock function for testing.""" - function = MagicMock(spec=FunctionTool[Any]) + function = MagicMock(spec=FunctionTool) function.name = "test_function" return function diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index dbcf7aac6f..8d74dc181d 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -108,6 +108,90 @@ def search(query: str, max_results: int = 10) -> str: assert search("hello") == "Searching for: hello (max 10)" +async def test_tool_decorator_with_json_schema_invoke_uses_mapping(): + """Test that schema-based tools can be invoked directly with mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "max_results": {"type": "integer"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str, max_results: int = 10) -> str: + return f"{query}:{max_results}" + + result = await search.invoke(arguments={"query": "hello", "max_results": 3}) + assert result == "hello:3" + + +async def test_tool_decorator_with_json_schema_invoke_missing_required(): + """Test schema-required fields are checked for mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str) -> str: + return query + + with pytest.raises(TypeError, match="Missing required argument"): + await search.invoke(arguments={}) + + +async def test_tool_decorator_with_json_schema_invoke_invalid_type(): + """Test schema type checks run for mapping arguments.""" + + json_schema = { + "type": "object", + "properties": { + "query": {"type": "string"}, + "max_results": {"type": "integer"}, + }, + "required": ["query"], + } + + @tool(name="search", description="Search tool", schema=json_schema) + def search(query: str, max_results: int = 10) -> str: + return f"{query}:{max_results}" + + with pytest.raises(TypeError, match="Invalid type for 'max_results'"): + await search.invoke(arguments={"query": "hello", "max_results": "three"}) + + +def test_tool_decorator_with_json_schema_preserves_custom_properties(): + """Test schema passthrough keeps custom JSON schema properties.""" + + json_schema = { + "type": "object", + "properties": { + "priority": { + "type": "string", + "enum": ["low", "medium", "high"], + "x-custom-field": "custom-value", + }, + }, + "required": ["priority"], + "additionalProperties": False, + } + + @tool(name="process", description="Process tool", schema=json_schema) + def process(priority: str) -> str: + return priority + + params = process.parameters() + assert not params.get("additionalProperties") + assert params["properties"]["priority"]["x-custom-field"] == "custom-value" + + def test_tool_decorator_schema_none_default(): """Test that schema=None (default) still infers from function signature.""" @@ -555,7 +639,7 @@ def pydantic_test_tool(x: int, y: int) -> int: assert span.attributes[OtelAttr.TOOL_CALL_ID] == "pydantic_call" assert span.attributes[OtelAttr.TOOL_TYPE] == "function" assert span.attributes[OtelAttr.TOOL_DESCRIPTION] == "A test tool with Pydantic args" - assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x":5,"y":10}' + assert span.attributes[OtelAttr.TOOL_ARGUMENTS] == '{"x": 5, "y": 10}' async def test_tool_invoke_telemetry_with_exception(span_exporter: InMemorySpanExporter): diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 42de197014..1cafc1ba85 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -499,7 +499,7 @@ def _prepare_tools( return copilot_tools - def _tool_to_copilot_tool(self, ai_func: FunctionTool[Any]) -> CopilotTool: + def _tool_to_copilot_tool(self, ai_func: FunctionTool) -> CopilotTool: """Convert an FunctionTool to a Copilot SDK tool.""" async def handler(invocation: ToolInvocation) -> ToolResult: diff --git a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py index 42d03393f8..75c0676cb6 100644 --- a/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py +++ b/python/packages/lab/tau2/agent_framework_lab_tau2/_tau2_utils.py @@ -27,7 +27,7 @@ _original_set_state = Environment.set_state -def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool[Any]: +def convert_tau2_tool_to_function_tool(tau2_tool: Tool) -> FunctionTool: """Convert a tau2 Tool to a FunctionTool for agent framework compatibility. Creates a wrapper that preserves the tool's interface while ensuring diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index ea5f8bd201..e0dc3e8ea9 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -324,7 +324,7 @@ def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration existing_tools = list(default_options.get("tools") or []) existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - new_tools: list[FunctionTool[Any]] = [] + new_tools: list[FunctionTool] = [] for target in targets: handoff_tool = self._create_handoff_tool(target.target_id, target.description) if handoff_tool.name in existing_names: @@ -340,7 +340,7 @@ def _apply_auto_tools(self, agent: Agent, targets: Sequence[HandoffConfiguration else: default_options["tools"] = existing_tools - def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool[Any]: + def _create_handoff_tool(self, target_id: str, description: str | None = None) -> FunctionTool: """Construct the synthetic handoff tool that signals routing to `target_id`.""" tool_name = get_handoff_tool_name(target_id) doc = description or f"Handoff to the {target_id} agent."