diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 60f02bc27..49a2c2af5 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -28,6 +28,7 @@ McpHttpTransportV20250326, McpHttpTransportV20250618, McpHttpTransportV20251125, + McpHttpTransportV20260618, ) from .protocol import Protocol, ToolSchema from .tool import ToolboxTool @@ -86,6 +87,15 @@ def __init__( ) match protocol: + case Protocol.MCP_v20260618: + self.__transport = McpHttpTransportV20260618( + url, + session, + protocol, + client_name, + client_version, + telemetry_enabled=telemetry_enabled, + ) case Protocol.MCP_v20251125: self.__transport = McpHttpTransportV20251125( url, diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py index 95a93a79f..ca5d0217f 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py @@ -16,10 +16,12 @@ from .v20250326.mcp import McpHttpTransportV20250326 from .v20250618.mcp import McpHttpTransportV20250618 from .v20251125.mcp import McpHttpTransportV20251125 +from .v20260618.mcp import McpHttpTransportV20260618 __all__ = [ "McpHttpTransportV20241105", "McpHttpTransportV20250326", "McpHttpTransportV20250618", "McpHttpTransportV20251125", + "McpHttpTransportV20260618", ] diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py new file mode 100644 index 000000000..d29681df5 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py @@ -0,0 +1,299 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Mapping, Optional, TypeVar + +from pydantic import BaseModel + +from ... import version +from ...protocol import ManifestSchema, TelemetryAttributes +from .. import telemetry +from ..transport_base import _McpHttpTransportBase +from . import types + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +class McpHttpTransportV20260618(_McpHttpTransportBase): + """Transport for the MCP draft Request-Metadata (v2026-06-18) protocol.""" + + async def _send_request( + self, + url: str, + request: types.MCPRequest[ReceiveResultT] | types.MCPNotification, + headers: Optional[Mapping[str, str]] = None, + is_retry: bool = False, + ) -> ReceiveResultT | None: + """Sends a JSON-RPC request to the MCP server with version negotiation retry.""" + req_headers = dict(headers or {}) + req_headers["MCP-Protocol-Version"] = self._protocol_version + + # Dynamically update the _meta protocol version in the parameters model + if hasattr(request, "params") and request.params is not None: + if ( + hasattr(request.params, "field_meta") + and request.params.field_meta is not None + ): + request.params.field_meta.protocol_version = self._protocol_version + + params = ( + request.params.model_dump(mode="json", exclude_none=True, by_alias=True) + if isinstance(request.params, BaseModel) + else request.params + ) + + rpc_msg: BaseModel + if isinstance(request, types.MCPNotification): + rpc_msg = types.JSONRPCNotification(method=request.method, params=params) + else: + rpc_msg = types.JSONRPCRequest(method=request.method, params=params) + + payload = rpc_msg.model_dump(mode="json", exclude_none=True) + + async with self._session.post( + url, json=payload, headers=req_headers + ) as response: + if response.status == 400: + try: + json_resp = await response.json() + if ( + "error" in json_resp + and json_resp["error"].get("code") == -32001 + ): + if is_retry: + raise RuntimeError( + "Protocol negotiation failed: server rejected negotiated version" + ) + + server_supported = ( + json_resp["error"].get("data", {}).get("supported", []) + ) + from ...protocol import Protocol + + client_supported = Protocol.get_supported_mcp_versions() + mutually_supported = [ + v for v in client_supported if v in server_supported + ] + + if mutually_supported: + self._protocol_version = mutually_supported[0] + return await self._send_request( + url, request, headers=headers, is_retry=True + ) + else: + raise RuntimeError( + "No mutually supported protocol version. " + f"Client supports: {client_supported}, " + f"Server supports: {server_supported}" + ) + except Exception as e: + if isinstance(e, RuntimeError): + raise e + + if not response.ok: + error_text = await response.text() + raise RuntimeError( + "API request failed with status" + f" {response.status} ({response.reason}). Server response:" + f" {error_text}" + ) + + if response.status == 204 or response.content.at_eof(): + return None + + json_resp = await response.json() + + # Check for JSON-RPC Error + if "error" in json_resp: + try: + err = types.JSONRPCError.model_validate(json_resp).error + raise RuntimeError( + f"MCP request failed with code {err.code}: {err.message}" + ) + except Exception: + # Fallback if the error doesn't match our schema exactly + raw_error = json_resp.get("error", {}) + raise RuntimeError(f"MCP request failed: {raw_error}") + + # Parse Result + if isinstance(request, types.MCPRequest): + try: + rpc_resp = types.JSONRPCResponse.model_validate(json_resp) + return request.get_result_model().model_validate(rpc_resp.result) + except Exception as e: + raise RuntimeError(f"Failed to parse JSON-RPC response: {e}") + return None + + async def _initialize_session( + self, headers: Optional[Mapping[str, str]] = None + ) -> None: + """No-op for stateless transport since there is no session handshake.""" + pass + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self._ensure_initialized(headers=headers) + + url = self._mcp_base_url + (toolset_name if toolset_name else "") + + meta = types.MCPMeta( + protocol_version=self._protocol_version, + client_info=types.Implementation( + name=self._client_name or "toolbox-core-python", + version=self._client_version or version.__version__, + ), + client_capabilities=types.ClientCapabilities(), + ) + + if self._telemetry_enabled: + operation_start = time.time() + span, traceparent, tracestate = telemetry.start_span( + self._tracer, + "tools/list", + self._protocol_version, + url, + network_transport="tcp", + ) + if span is not None: + meta.traceparent = traceparent or None + meta.tracestate = tracestate or None + + error: Optional[Exception] = None + try: + result = await self._send_request( + url=url, + request=types.ListToolsRequest( + params=types.ListToolsRequestParams(field_meta=meta) + ), + headers=headers, + ) + if result is None: + raise RuntimeError("Failed to list tools: No response from server.") + + tools_map = {t["name"]: self._convert_tool_schema(t) for t in result.tools} + + return ManifestSchema( + serverVersion="1.0.0", + tools=tools_map, + ) + except Exception as e: + error = e + raise + finally: + if self._telemetry_enabled: + operation_duration = time.time() - operation_start + telemetry.record_operation_duration( + self._operation_duration_histogram, + operation_duration, + "tools/list", + self._protocol_version, + url, + network_transport="tcp", + error=error, + ) + telemetry.end_span(span, error=error) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + manifest = await self.tools_list(headers=headers) + + if tool_name not in manifest.tools: + raise ValueError(f"Tool '{tool_name}' not found.") + + return ManifestSchema( + serverVersion=manifest.serverVersion, + tools={tool_name: manifest.tools[tool_name]}, + ) + + async def tool_invoke( + self, + tool_name: str, + arguments: dict, + headers: Optional[Mapping[str, str]], + telemetry_attributes: Optional[TelemetryAttributes] = None, + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self._ensure_initialized(headers=headers) + + payload = self._build_telemetry_payload(telemetry_attributes) + + meta = types.MCPMeta( + protocol_version=self._protocol_version, + client_info=types.Implementation( + name=self._client_name or "toolbox-core-python", + version=self._client_version or version.__version__, + ), + client_capabilities=types.ClientCapabilities(), + telemetry_attributes=payload, + ) + + span = None + if self._telemetry_enabled: + operation_start = time.time() + span, traceparent, tracestate = telemetry.start_span( + self._tracer, + "tools/call", + self._protocol_version, + self._mcp_base_url, + tool_name=tool_name, + network_transport="tcp", + ) + meta.traceparent = traceparent or None + meta.tracestate = tracestate or None + if span is not None and payload: + for key, value in payload.items(): + span.set_attribute(key, value) + + error: Optional[Exception] = None + try: + result = await self._send_request( + url=self._mcp_base_url, + request=types.CallToolRequest( + params=types.CallToolRequestParams( + name=tool_name, arguments=arguments, field_meta=meta + ) + ), + headers=headers, + ) + + if result is None: + raise RuntimeError( + f"Failed to invoke tool '{tool_name}': No response from server." + ) + + return self._process_tool_result_content(result.content) + except Exception as e: + error = e + raise + finally: + if self._telemetry_enabled: + operation_duration = time.time() - operation_start + telemetry.record_operation_duration( + self._operation_duration_histogram, + operation_duration, + "tools/call", + self._protocol_version, + self._mcp_base_url, + tool_name=tool_name, + network_transport="tcp", + error=error, + ) + telemetry.end_span(span, error=error) diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py new file mode 100644 index 000000000..79448aed9 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py @@ -0,0 +1,146 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import Any, Generic, Literal, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + + +class _BaseMCPModel(BaseModel): + """Base model with common configuration.""" + + model_config = ConfigDict(extra="allow") + + +class JSONRPCRequest(_BaseMCPModel): + jsonrpc: Literal["2.0"] = "2.0" + id: str | int = Field(default_factory=lambda: str(uuid.uuid4())) + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(_BaseMCPModel): + """A notification which does not expect a response (no ID).""" + + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + result: dict[str, Any] + + +class ErrorData(_BaseMCPModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + + +class ClientCapabilities(_BaseMCPModel): + tools: dict[str, Any] = Field(default_factory=dict) + + +class Implementation(_BaseMCPModel): + name: str + version: str + + +class MCPMeta(_BaseMCPModel): + """Metadata for MCP requests. + + Carries the three required fields in io.modelcontextprotocol/* namespace. + """ + + protocol_version: str = Field( + ..., serialization_alias="io.modelcontextprotocol/protocolVersion" + ) + client_info: Implementation = Field( + ..., serialization_alias="io.modelcontextprotocol/clientInfo" + ) + client_capabilities: ClientCapabilities = Field( + ..., serialization_alias="io.modelcontextprotocol/clientCapabilities" + ) + + # Tracing and attributes + traceparent: str | None = None + tracestate: str | None = None + telemetry_attributes: dict[str, Any] | None = Field( + default=None, serialization_alias="dev.mcp-toolbox/telemetry" + ) + + +class ListToolsResult(_BaseMCPModel): + tools: list[dict[str, Any]] + + +class TextContent(_BaseMCPModel): + type: Literal["text"] + text: str + + +class CallToolResult(_BaseMCPModel): + content: list[TextContent] + isError: bool = False + + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class MCPRequest(_BaseMCPModel, Generic[ResultT]): + method: str + params: dict[str, Any] | BaseModel | None = None + + def get_result_model(self) -> Type[ResultT]: + raise NotImplementedError + + +class MCPNotification(_BaseMCPModel): + method: str + params: dict[str, Any] | BaseModel | None = None + + +class ListToolsRequestParams(_BaseMCPModel): + field_meta: MCPMeta = Field(..., serialization_alias="_meta") + + +class ListToolsRequest(MCPRequest[ListToolsResult]): + method: Literal["tools/list"] = "tools/list" + params: ListToolsRequestParams + + def get_result_model(self) -> Type[ListToolsResult]: + return ListToolsResult + + +class CallToolRequestParams(_BaseMCPModel): + name: str + arguments: dict[str, Any] + field_meta: MCPMeta = Field(..., serialization_alias="_meta") + + +class CallToolRequest(MCPRequest[CallToolResult]): + method: Literal["tools/call"] = "tools/call" + params: CallToolRequestParams + + def get_result_model(self) -> Type[CallToolResult]: + return CallToolResult diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index 191287e3d..6d2abdf12 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -47,17 +47,19 @@ def _empty_string_to_none(cls, value: Any) -> Any: class Protocol(str, Enum): """Defines how the client should choose between communication protocols.""" + MCP_v20260618 = "DRAFT-2026-v1" MCP_v20250618 = "2025-06-18" MCP_v20250326 = "2025-03-26" MCP_v20241105 = "2024-11-05" MCP_v20251125 = "2025-11-25" MCP = MCP_v20250618 - MCP_LATEST = MCP_v20251125 + MCP_LATEST = MCP_v20260618 @staticmethod def get_supported_mcp_versions() -> list[str]: """Returns a list of supported MCP protocol versions.""" return [ + Protocol.MCP_v20260618.value, Protocol.MCP_v20251125.value, Protocol.MCP_v20250618.value, Protocol.MCP_v20250326.value, diff --git a/packages/toolbox-core/tests/conformance/client.py b/packages/toolbox-core/tests/conformance/client.py index 9ab58812c..5d59d593a 100644 --- a/packages/toolbox-core/tests/conformance/client.py +++ b/packages/toolbox-core/tests/conformance/client.py @@ -18,9 +18,16 @@ import sys from toolbox_core.client import ToolboxClient +from toolbox_core.protocol import Protocol async def main(): + """Harness main execution block. + + NOTE: All non-protocol outputs (logs, traces, errors) must be directed to + sys.stderr. The test runner captures stdout for protocol messages only, + printing other content to stdout will pollute the stream and crash the runner. + """ if len(sys.argv) < 2: print("Usage: client.py ", file=sys.stderr) sys.exit(1) @@ -41,7 +48,13 @@ async def main(): client_headers = {"Accept": "application/json, text/event-stream"} - async with ToolboxClient(server_url, client_headers=client_headers) as client: + protocol = Protocol.MCP + if scenario == "request-metadata": + protocol = Protocol.MCP_v20260618 + + async with ToolboxClient( + server_url, client_headers=client_headers, protocol=protocol + ) as client: if scenario == "initialize": await client.load_toolset() print("Client initialization test completed", file=sys.stderr) @@ -51,6 +64,10 @@ async def main(): await add_numbers(a=1, b=2) print("Invoked add_numbers(a=1, b=2)", file=sys.stderr) + elif scenario == "request-metadata": + await client.load_toolset() + print("Client request-metadata test completed", file=sys.stderr) + else: # Default behavior: load default toolset to trigger standard interactions await client.load_toolset() diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20260618.py b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py new file mode 100644 index 000000000..a68e843cd --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py @@ -0,0 +1,280 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.v20260618 import types +from toolbox_core.mcp_transport.v20260618.mcp import McpHttpTransportV20260618 +from toolbox_core.protocol import ManifestSchema, Protocol + + +def create_fake_tools_list_result(): + return types.ListToolsResult( + tools=[ + { + "name": "get_weather", + "description": "Gets the weather.", + "inputSchema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + } + ] + ) + + +@pytest_asyncio.fixture( + params=[False, True], ids=["telemetry_disabled", "telemetry_enabled"] +) +async def transport(request, mocker): + if request.param: + mocker.patch("toolbox_core.mcp_transport.telemetry.TELEMETRY_AVAILABLE", True) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.get_tracer", return_value=MagicMock() + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.get_meter", return_value=MagicMock() + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.create_operation_duration_histogram", + return_value=MagicMock(), + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.create_session_duration_histogram", + return_value=MagicMock(), + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.start_span", + return_value=(MagicMock(), "00-traceparent", ""), + ) + mocker.patch("toolbox_core.mcp_transport.telemetry.end_span") + mocker.patch("toolbox_core.mcp_transport.telemetry.record_operation_duration") + mocker.patch("toolbox_core.mcp_transport.telemetry.record_session_duration") + mock_session = AsyncMock(spec=ClientSession) + transport = McpHttpTransportV20260618( + "http://fake-server.com", + session=mock_session, + protocol=Protocol.MCP_v20260618, + telemetry_enabled=request.param, + ) + yield transport + await transport.close() + + +@pytest.mark.asyncio +class TestMcpHttpTransportV20260618: + + # --- Request Sending Tests (Standard + Header) --- + + async def test_send_request_success(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult() + + async def test_send_request_adds_protocol_header(self, transport): + """Test that the MCP-Protocol-Version header is added.""" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + await transport._send_request("url", TestRequest()) + + call_args = transport._session.post.call_args + headers = call_args.kwargs["headers"] + assert headers["MCP-Protocol-Version"] == "DRAFT-2026-v1" + + # --- Version Negotiation Tests --- + + async def test_version_negotiation_retry_success(self, transport): + """Tests that the client retries when the server rejects initial version.""" + mock_response_reject = AsyncMock() + mock_response_reject.ok = False + mock_response_reject.status = 400 + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32001, + "message": "Unsupported protocol version", + "data": {"supported": ["DRAFT-2026-v1"]}, + }, + } + + mock_response_accept = AsyncMock() + mock_response_accept.ok = True + mock_response_accept.status = 200 + mock_response_accept.content = Mock() + mock_response_accept.content.at_eof.return_value = False + mock_response_accept.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "result": {}, + } + + # Configure first call to return reject response, second call to succeed + transport._session.post.return_value.__aenter__.side_effect = [ + mock_response_reject, + mock_response_accept, + ] + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult() + assert transport._session.post.call_count == 2 + + async def test_version_negotiation_loop_prevention(self, transport): + """Tests that the client raises an error if the retry gets rejected (loop prevention).""" + mock_response_reject = AsyncMock() + mock_response_reject.ok = False + mock_response_reject.status = 400 + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32001, + "message": "Unsupported protocol version", + "data": {"supported": ["DRAFT-2026-v1"]}, + }, + } + + # Return rejection repeatedly + transport._session.post.return_value.__aenter__.side_effect = [ + mock_response_reject, + mock_response_reject, + ] + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises( + RuntimeError, + match="Protocol negotiation failed: server rejected negotiated version", + ): + await transport._send_request("url", TestRequest()) + + assert transport._session.post.call_count == 2 + + async def test_version_negotiation_empty_intersection(self, transport): + """Tests that the client errors immediately without retrying when there is no mutual version.""" + mock_response_reject = AsyncMock() + mock_response_reject.ok = False + mock_response_reject.status = 400 + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32001, + "message": "Unsupported protocol version", + "data": {"supported": ["UNSUPPORTED-VERSION"]}, + }, + } + + transport._session.post.return_value.__aenter__.return_value = ( + mock_response_reject + ) + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises( + RuntimeError, match="No mutually supported protocol version" + ): + await transport._send_request("url", TestRequest()) + + assert transport._session.post.call_count == 1 + + # --- Tool Management Tests --- + + async def test_tools_list_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + manifest = await transport.tools_list() + assert isinstance(manifest, ManifestSchema) + assert "get_weather" in manifest.tools + + async def test_tool_invoke_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text="Result")] + ), + ) + result = await transport.tool_invoke("tool", {}, {}) + assert result == "Result" diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index 6acaeaded..d6f47c5eb 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -24,9 +24,11 @@ from toolbox_core.tool import ToolboxTool +# TODO: Include draft versions in E2E integration tests once the server +# supports SEP-2575 (stateless MCP / Request-Metadata). @pytest_asyncio.fixture( scope="function", - params=Protocol.get_supported_mcp_versions(), + params=[v for v in Protocol.get_supported_mcp_versions() if "DRAFT" not in v], ) async def toolbox(request): """Creates a ToolboxClient instance shared by all tests in this module.""" diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index 3f7200e08..c467975c1 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -77,7 +77,13 @@ def test_get_supported_mcp_versions(): Tests that get_supported_mcp_versions returns the correct list of versions, sorted from newest to oldest. """ - expected_versions = ["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"] + expected_versions = [ + "DRAFT-2026-v1", + "2025-11-25", + "2025-06-18", + "2025-03-26", + "2024-11-05", + ] supported_versions = Protocol.get_supported_mcp_versions() assert supported_versions == expected_versions