diff --git a/plugins/communication_protocols/gnmi/README.md b/plugins/communication_protocols/gnmi/README.md new file mode 100644 index 0000000..50e9e1a --- /dev/null +++ b/plugins/communication_protocols/gnmi/README.md @@ -0,0 +1,94 @@ +# UTCP gNMI Plugin + +This plugin adds a gNMI (gRPC) communication protocol compatible with UTCP 1.0. It follows UTCP’s plugin architecture: a CallTemplate and serializer, a CommunicationProtocol for discovery and execution, and registration via the `utcp.plugins` entry point. + +## Installation + +- Ensure you have Python 3.10+ +- Dependencies: `utcp`, `grpcio`, `protobuf`, `pydantic`, `aiohttp` +- Install in your environment (example if published): + +``` +pip install utcp-gnmi +``` + +## Registration + +Register the plugin into UTCP’s registries: + +``` +from utcp_gnmi import register +register() +``` + +This registers: +- Protocol: `gnmi` +- Call template serializer: `gnmi` + +## Configuration (UTCP 1.0) + +Use `UtcpClientConfig.manual_call_templates` to declare gNMI providers and tools. + +Example: + +``` +{ + "manual_call_templates": [ + { + "name": "routerA", + "call_template_type": "gnmi", + "target": "localhost:50051", + "use_tls": false, + "metadata": {"authorization": "Bearer ${API_TOKEN}"}, + "metadata_fields": ["tenant-id"], + "operation": "get", + "stub_module": "gnmi_pb2_grpc", + "message_module": "gnmi_pb2" + } + ] +} +``` + +Fields: +- `call_template_type`: must be `gnmi` +- `target`: gRPC host:port +- `use_tls`: boolean; TLS required unless localhost/127.0.0.1 +- `metadata`: static key/value pairs added to gRPC metadata +- `metadata_fields`: dynamic keys populated from tool args +- `operation`: one of `capabilities`, `get`, `set`, `subscribe` +- `stub_module`/`message_module`: import paths to generated Python stubs + +## Security + +- Enforces TLS (`grpc.aio.secure_channel`) unless `target` is `localhost` or `127.0.0.1` +- Do not use insecure channels over public networks +- Prefer mTLS for production environments (future enhancement adds cert fields) + +## Authentication + +Supported via UTCP `Auth` model: +- API Key: injects into `authorization` (or custom) metadata +- Basic: `authorization: Basic ` +- OAuth2: client credentials; token fetched via `aiohttp` and cached + +## Operations + +- `capabilities`: unary `Capabilities` RPC +- `get`: unary `Get` RPC; maps `paths` list into `GetRequest.path` +- `set`: unary `Set` RPC; maps `updates` list into `SetRequest.update` +- `subscribe`: streaming `Subscribe` RPC; yields responses as dicts + +## Testing + +Run tests: + +``` +python -m pytest plugins/communication_protocols/gnmi/tests/test_gnmi_plugin.py -q +``` + +The tests validate manual registration, tool presence (including `subscribe`), and serializer round-trip. + +## Notes + +- Tool discovery registers canonical gNMI operations (`capabilities/get/set/subscribe`) +- Reflection-based discovery and mTLS configuration can be added in follow-up PRs \ No newline at end of file diff --git a/plugins/communication_protocols/gnmi/pyproject.toml b/plugins/communication_protocols/gnmi/pyproject.toml new file mode 100644 index 0000000..1f5d151 --- /dev/null +++ b/plugins/communication_protocols/gnmi/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=61.0"] +build-backend = "setuptools.build_meta" + +[project] +name = "utcp-gnmi" +version = "1.0.0" +authors = [ + { name = "UTCP Contributors" }, +] +description = "UTCP gNMI communication protocol plugin over gRPC" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "pydantic>=2.0", + "protobuf>=4.21", + "grpcio>=1.60", + "utcp>=1.0", + "aiohttp>=3.8" +] +license = "MPL-2.0" + +[project.optional-dependencies] +dev = [ + "build", + "pytest", + "pytest-asyncio", + "pytest-cov", +] + +[project.entry-points."utcp.plugins"] +gnmi = "utcp_gnmi:register" \ No newline at end of file diff --git a/plugins/communication_protocols/gnmi/src/utcp_gnmi/__init__.py b/plugins/communication_protocols/gnmi/src/utcp_gnmi/__init__.py new file mode 100644 index 0000000..f36dc6d --- /dev/null +++ b/plugins/communication_protocols/gnmi/src/utcp_gnmi/__init__.py @@ -0,0 +1,12 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template +from utcp_gnmi.gnmi_communication_protocol import GnmiCommunicationProtocol +from utcp_gnmi.gnmi_call_template import GnmiCallTemplateSerializer + +def register(): + register_communication_protocol("gnmi", GnmiCommunicationProtocol()) + register_call_template("gnmi", GnmiCallTemplateSerializer()) + +__all__ = [ + "GnmiCommunicationProtocol", + "GnmiCallTemplateSerializer", +] \ No newline at end of file diff --git a/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_call_template.py b/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_call_template.py new file mode 100644 index 0000000..ec5f897 --- /dev/null +++ b/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_call_template.py @@ -0,0 +1,26 @@ +from typing import Optional, Dict, List, Literal + +from utcp.data.call_template import CallTemplate +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback + +class GnmiCallTemplate(CallTemplate): + call_template_type: Literal["gnmi"] = "gnmi" + target: str + use_tls: bool = True + metadata: Optional[Dict[str, str]] = None + metadata_fields: Optional[List[str]] = None + operation: Literal["capabilities", "get", "set", "subscribe"] = "get" + stub_module: str = "gnmi_pb2_grpc" + message_module: str = "gnmi_pb2" + +class GnmiCallTemplateSerializer(Serializer[GnmiCallTemplate]): + def to_dict(self, obj: GnmiCallTemplate) -> dict: + return obj.model_dump() + + def validate_dict(self, obj: dict) -> GnmiCallTemplate: + try: + return GnmiCallTemplate.model_validate(obj) + except Exception as e: + raise UtcpSerializerValidationError("Invalid GnmiCallTemplate: " + traceback.format_exc()) from e \ No newline at end of file diff --git a/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_communication_protocol.py b/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_communication_protocol.py new file mode 100644 index 0000000..a3d074e --- /dev/null +++ b/plugins/communication_protocols/gnmi/src/utcp_gnmi/gnmi_communication_protocol.py @@ -0,0 +1,299 @@ +import importlib +import logging +from typing import Dict, Any, List, AsyncGenerator + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool, JsonSchema +from utcp.data.utcp_manual import UtcpManual +from utcp.data.register_manual_response import RegisterManualResult +from utcp_gnmi.gnmi_call_template import GnmiCallTemplate +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth + +class GnmiCommunicationProtocol(CommunicationProtocol): + def __init__(self): + self._oauth_tokens: Dict[str, Dict[str, Any]] = {} + def _load_gnmi_modules(self, tool_call_template: GnmiCallTemplate) -> tuple[Any, Any, Any, Any, Any]: + grpc = importlib.import_module("grpc") + aio = importlib.import_module("grpc.aio") + json_format = importlib.import_module("google.protobuf.json_format") + stub_mod = importlib.import_module(tool_call_template.stub_module) + msg_mod = importlib.import_module(tool_call_template.message_module) + return grpc, aio, json_format, stub_mod, msg_mod + def _create_grpc_channel(self, grpc, aio, target: str, use_tls: bool) -> Any: + if use_tls: + creds = grpc.ssl_channel_credentials() + return aio.secure_channel(target, creds) + return aio.insecure_channel(target) + def _create_grpc_stub(self, stub_mod, channel) -> Any: + stub = None + for attr in dir(stub_mod): + if attr.endswith("Stub"): + stub_cls = getattr(stub_mod, attr) + stub = stub_cls(channel) + break + if stub is None: + raise ValueError("gNMI stub not found in stub_module") + return stub + async def _build_metadata(self, tool_call_template: GnmiCallTemplate, tool_args: Dict[str, Any]) -> List[tuple[str, str]]: + metadata: List[tuple[str, str]] = [] + if tool_call_template.metadata: + metadata.extend([(k, v) for k, v in tool_call_template.metadata.items()]) + if tool_call_template.metadata_fields: + for k in tool_call_template.metadata_fields: + if k in tool_args: + metadata.append((k, str(tool_args[k]))) + if tool_call_template.auth: + if isinstance(tool_call_template.auth, ApiKeyAuth): + if tool_call_template.auth.api_key: + metadata.append((tool_call_template.auth.var_name or "authorization", tool_call_template.auth.api_key)) + elif isinstance(tool_call_template.auth, BasicAuth): + import base64 + token = base64.b64encode(f"{tool_call_template.auth.username}:{tool_call_template.auth.password}".encode()).decode() + metadata.append(("authorization", f"Basic {token}")) + elif isinstance(tool_call_template.auth, OAuth2Auth): + token = await self._handle_oauth2(tool_call_template.auth) + metadata.append(("authorization", f"Bearer {token}")) + return metadata + async def register_manual(self, caller, manual_call_template: CallTemplate) -> RegisterManualResult: + if not isinstance(manual_call_template, GnmiCallTemplate): + raise ValueError("GnmiCommunicationProtocol can only be used with GnmiCallTemplate") + + target = manual_call_template.target + if not manual_call_template.use_tls: + host = target + if host.startswith("[") and "]" in host: + host = host[1:host.index("]")] + else: + host = host.split(":")[0] + is_local = host == "localhost" + try: + from ipaddress import ip_address + is_loopback = ip_address(host).is_loopback + except Exception: + is_loopback = False + if not (is_local or is_loopback): + return RegisterManualResult( + success=False, + manual_call_template=manual_call_template, + manual=UtcpManual(manual_version="0.0.0", tools=[]), + errors=["Insecure channel only allowed for localhost or loopback addresses"] + ) + + tools: List[Tool] = [] + ops = ["capabilities", "get", "set", "subscribe"] + for op in ops: + tct = GnmiCallTemplate( + name=manual_call_template.name, + call_template_type="gnmi", + auth=manual_call_template.auth, + target=manual_call_template.target, + use_tls=manual_call_template.use_tls, + metadata=manual_call_template.metadata, + metadata_fields=manual_call_template.metadata_fields, + operation=op, + stub_module=manual_call_template.stub_module, + message_module=manual_call_template.message_module, + ) + inputs = JsonSchema(type="object", properties={}) + outputs = JsonSchema(type="object", properties={}) + tool = Tool( + name=op, + description="", + inputs=inputs, + outputs=outputs, + tags=["gnmi", op], + tool_call_template=tct, + ) + tools.append(tool) + + manual = UtcpManual(manual_version="1.0.0", tools=tools) + return RegisterManualResult( + success=True, + manual_call_template=manual_call_template, + manual=manual, + errors=[], + ) + + async def deregister_manual(self, caller, manual_call_template: CallTemplate) -> None: + return None + + async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> Any: + if not isinstance(tool_call_template, GnmiCallTemplate): + raise ValueError("GnmiCommunicationProtocol can only be used with GnmiCallTemplate") + + op = tool_call_template.operation + target = tool_call_template.target + + metadata = await self._build_metadata(tool_call_template, tool_args) + + grpc, aio, json_format, stub_mod, msg_mod = self._load_gnmi_modules(tool_call_template) + channel = self._create_grpc_channel(grpc, aio, target, tool_call_template.use_tls) + + try: + stub = self._create_grpc_stub(stub_mod, channel) + + if op == "capabilities": + req = getattr(msg_mod, "CapabilityRequest")() + resp = await stub.Capabilities(req, metadata=metadata) + elif op == "get": + req = getattr(msg_mod, "GetRequest")() + paths = tool_args.get("paths", []) + for p in paths: + path_msg = getattr(msg_mod, "Path")() + for elem in [e for e in p.strip("/").split("/") if e]: + pe = getattr(msg_mod, "PathElem")(name=elem) + path_msg.elem.append(pe) + req.path.append(path_msg) + resp = await stub.Get(req, metadata=metadata) + elif op == "set": + req = getattr(msg_mod, "SetRequest")() + updates = tool_args.get("updates", []) + for upd in updates: + path_msg = getattr(msg_mod, "Path")() + for elem in [e for e in str(upd.get("path", "")).strip("/").split("/") if e]: + pe = getattr(msg_mod, "PathElem")(name=elem) + path_msg.elem.append(pe) + v = upd.get("value", "") + val = None + try: + import json + if isinstance(v, (dict, list)): + val = getattr(msg_mod, "TypedValue")(json_ietf_val=json.dumps(v).encode("utf-8")) + elif isinstance(v, bool): + val = getattr(msg_mod, "TypedValue")(bool_val=v) + elif isinstance(v, int) and not isinstance(v, bool): + val = getattr(msg_mod, "TypedValue")(int_val=v) + elif isinstance(v, float): + val = getattr(msg_mod, "TypedValue")(float_val=v) + elif isinstance(v, str): + val = getattr(msg_mod, "TypedValue")(string_val=v) + else: + val = getattr(msg_mod, "TypedValue")(json_ietf_val=json.dumps(v).encode("utf-8")) + except Exception: + val = getattr(msg_mod, "TypedValue")(string_val=str(v)) + update_msg = getattr(msg_mod, "Update")(path=path_msg, val=val) + req.update.append(update_msg) + resp = await stub.Set(req, metadata=metadata) + elif op == "subscribe": + raise ValueError("Unsupported gNMI operation") + else: + raise ValueError("Unsupported gNMI operation") + + return json_format.MessageToDict(resp) + finally: + await channel.close() + + async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate) -> AsyncGenerator[Any, None]: + if not isinstance(tool_call_template, GnmiCallTemplate): + raise ValueError("GnmiCommunicationProtocol can only be used with GnmiCallTemplate") + if tool_call_template.operation != "subscribe": + result = await self.call_tool(caller, tool_name, tool_args, tool_call_template) + yield result + return + grpc, aio, json_format, stub_mod, msg_mod = self._load_gnmi_modules(tool_call_template) + target = tool_call_template.target + channel = self._create_grpc_channel(grpc, aio, target, tool_call_template.use_tls) + try: + stub = self._create_grpc_stub(stub_mod, channel) + metadata = await self._build_metadata(tool_call_template, tool_args) + req = getattr(msg_mod, "SubscribeRequest")() + sub_list = getattr(msg_mod, "SubscriptionList")() + mode_str = str(tool_args.get("mode", "STREAM")).upper() + try: + sub_list.mode = getattr(msg_mod, "SubscriptionList").Mode.Value(mode_str) + except Exception: + mode_map = {"STREAM": 0, "ONCE": 1, "POLL": 2} + sub_list.mode = mode_map.get(mode_str, 0) + paths = tool_args.get("paths", []) + for p in paths: + path_msg = getattr(msg_mod, "Path")() + for elem in [e for e in p.strip("/").split("/") if e]: + pe = getattr(msg_mod, "PathElem")(name=elem) + path_msg.elem.append(pe) + sub = getattr(msg_mod, "Subscription")(path=path_msg) + sub_list.subscription.append(sub) + req.subscribe.CopyFrom(sub_list) + call = stub.Subscribe(req, metadata=metadata) + async for resp in call: + yield json_format.MessageToDict(resp) + finally: + await channel.close() + + async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: + import aiohttp + import time + key = f"{auth_details.token_url}|{auth_details.client_id}|{auth_details.scope}" + now = time.time() + cached = self._oauth_tokens.get(key) + if cached and cached.get("access_token") and cached.get("expires_at", now + 1) > now: + return cached["access_token"] + async with aiohttp.ClientSession() as session: + try: + body_data = { + "grant_type": "client_credentials", + "client_id": auth_details.client_id, + "client_secret": auth_details.client_secret, + "scope": auth_details.scope, + } + async with session.post(auth_details.token_url, data=body_data) as response: + response.raise_for_status() + token_response = await response.json() + access_token = token_response.get("access_token") + expires_in = token_response.get("expires_in") + ttl = expires_in if isinstance(expires_in, (int, float)) else 300 + self._oauth_tokens[key] = {"access_token": access_token, "expires_at": now + ttl - 10} + return access_token + except aiohttp.ClientResponseError as e: + logging.getLogger(__name__).warning(f"OAuth2 client_credentials failed: {e.status}") + from aiohttp import BasicAuth as AiohttpBasicAuth + header_auth = AiohttpBasicAuth(auth_details.client_id, auth_details.client_secret) + header_data = { + "grant_type": "client_credentials", + "scope": auth_details.scope, + } + async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + response.raise_for_status() + token_response = await response.json() + access_token = token_response.get("access_token") + expires_in = token_response.get("expires_in") + ttl = expires_in if isinstance(expires_in, (int, float)) else 300 + self._oauth_tokens[key] = {"access_token": access_token, "expires_at": time.time() + ttl - 10} + return access_token + except aiohttp.ClientError as e: + logging.getLogger(__name__).warning(f"OAuth2 request error: {e}") + from aiohttp import BasicAuth as AiohttpBasicAuth + header_auth = AiohttpBasicAuth(auth_details.client_id, auth_details.client_secret) + header_data = { + "grant_type": "client_credentials", + "scope": auth_details.scope, + } + async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + response.raise_for_status() + token_response = await response.json() + access_token = token_response.get("access_token") + expires_in = token_response.get("expires_in") + ttl = expires_in if isinstance(expires_in, (int, float)) else 300 + self._oauth_tokens[key] = {"access_token": access_token, "expires_at": time.time() + ttl - 10} + return access_token + except Exception as e: + logging.getLogger(__name__).error(f"OAuth2 unexpected error: {e}") + from aiohttp import BasicAuth as AiohttpBasicAuth + header_auth = AiohttpBasicAuth(auth_details.client_id, auth_details.client_secret) + header_data = { + "grant_type": "client_credentials", + "scope": auth_details.scope, + } + async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + response.raise_for_status() + token_response = await response.json() + access_token = token_response.get("access_token") + expires_in = token_response.get("expires_in") + ttl = expires_in if isinstance(expires_in, (int, float)) else 300 + self._oauth_tokens[key] = {"access_token": access_token, "expires_at": time.time() + ttl - 10} + return access_token + + async def close(self) -> None: + self._oauth_tokens.clear() \ No newline at end of file diff --git a/plugins/communication_protocols/gnmi/tests/test_gnmi_plugin.py b/plugins/communication_protocols/gnmi/tests/test_gnmi_plugin.py new file mode 100644 index 0000000..4126fb5 --- /dev/null +++ b/plugins/communication_protocols/gnmi/tests/test_gnmi_plugin.py @@ -0,0 +1,50 @@ +import sys +from pathlib import Path +import pytest + +plugin_src = Path(__file__).parent.parent / "src" +sys.path.insert(0, str(plugin_src)) + +core_src = Path(__file__).parent.parent.parent.parent.parent / "core" / "src" +sys.path.insert(0, str(core_src)) + +from utcp.utcp_client import UtcpClient +from utcp_gnmi import register + +@pytest.mark.asyncio +async def test_register_manual_and_tools(): + register() + client = await UtcpClient.create(config={ + "manual_call_templates": [ + { + "name": "routerA", + "call_template_type": "gnmi", + "target": "localhost:50051", + "use_tls": False, + "operation": "get" + } + ] + }) + tools = await client.config.tool_repository.get_tools() + names = [t.name for t in tools] + assert any(n.startswith("routerA.") for n in names) + assert any(n.endswith("subscribe") for n in names) + +def test_serializer_roundtrip(): + from utcp_gnmi.gnmi_call_template import GnmiCallTemplateSerializer + serializer = GnmiCallTemplateSerializer() + data = { + "name": "routerB", + "call_template_type": "gnmi", + "target": "localhost:50051", + "use_tls": False, + "metadata": {"authorization": "Bearer token"}, + "metadata_fields": ["tenant-id"], + "operation": "set", + "stub_module": "gnmi_pb2_grpc", + "message_module": "gnmi_pb2" + } + obj = serializer.validate_dict(data) + out = serializer.to_dict(obj) + assert out["call_template_type"] == "gnmi" + assert out["operation"] == "set" \ No newline at end of file diff --git a/plugins/communication_protocols/gql/README.md b/plugins/communication_protocols/gql/README.md index 8febb5a..3aad98e 100644 --- a/plugins/communication_protocols/gql/README.md +++ b/plugins/communication_protocols/gql/README.md @@ -1 +1,47 @@ -Find the UTCP readme at https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file + +# UTCP GraphQL Communication Protocol Plugin + +This plugin integrates GraphQL as a UTCP 1.0 communication protocol and call template. It supports discovery via schema introspection, authenticated calls, and header handling. + +## Getting Started + +### Installation + +```bash +pip install utcp-gql +``` + +### Registration + +```python +import utcp_gql +utcp_gql.register() +``` + +## How To Use + +- Ensure the plugin is imported and registered: `import utcp_gql; utcp_gql.register()`. +- Add a manual in your client config: + ```json + { + "name": "my_graph", + "call_template_type": "graphql", + "url": "https://your.graphql/endpoint", + "operation_type": "query", + "headers": { "x-client": "utcp" }, + "header_fields": ["x-session-id"] + } + ``` +- Call a tool: + ```python + await client.call_tool("my_graph.someQuery", {"id": "123", "x-session-id": "abc"}) + ``` + +## Notes + +- Tool names are prefixed by the manual name (e.g., `my_graph.someQuery`). +- Headers merge static `headers` plus whitelisted dynamic fields from `header_fields`. +- Supported auth: API key, Basic auth, OAuth2 (client-credentials). +- Security: only `https://` or `http://localhost`/`http://127.0.0.1` endpoints. + +For UTCP core docs, see https://github.com/universal-tool-calling-protocol/python-utcp. \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py index e69de29..7362502 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/__init__.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/__init__.py @@ -0,0 +1,9 @@ +from utcp.plugins.discovery import register_communication_protocol, register_call_template + +from .gql_communication_protocol import GraphQLCommunicationProtocol +from .gql_call_template import GraphQLProvider, GraphQLProviderSerializer + + +def register(): + register_communication_protocol("graphql", GraphQLCommunicationProtocol()) + register_call_template("graphql", GraphQLProviderSerializer()) \ No newline at end of file diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py index dfe5b07..3848d29 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_call_template.py @@ -1,7 +1,10 @@ from utcp.data.call_template import CallTemplate -from utcp.data.auth import Auth +from utcp.data.auth import Auth, AuthSerializer +from utcp.interfaces.serializer import Serializer +from utcp.exceptions import UtcpSerializerValidationError +import traceback from typing import Dict, List, Optional, Literal -from pydantic import Field +from pydantic import Field, field_serializer, field_validator class GraphQLProvider(CallTemplate): """Provider configuration for GraphQL-based tools. @@ -27,3 +30,31 @@ class GraphQLProvider(CallTemplate): auth: Optional[Auth] = None headers: Optional[Dict[str, str]] = None header_fields: Optional[List[str]] = Field(default=None, description="List of input fields to be sent as request headers for the initial connection.") + + @field_serializer("auth") + def serialize_auth(self, auth: Optional[Auth]): + if auth is None: + return None + return AuthSerializer().to_dict(auth) + + @field_validator("auth", mode="before") + @classmethod + def validate_auth(cls, v: Optional[Auth | dict]): + if v is None: + return None + if isinstance(v, Auth): + return v + return AuthSerializer().validate_dict(v) + + +class GraphQLProviderSerializer(Serializer[GraphQLProvider]): + def to_dict(self, obj: GraphQLProvider) -> dict: + return obj.model_dump() + + def validate_dict(self, data: dict) -> GraphQLProvider: + try: + return GraphQLProvider.model_validate(data) + except Exception as e: + raise UtcpSerializerValidationError( + f"Invalid GraphQLProvider: {e}\n{traceback.format_exc()}" + ) diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index f27f803..82811ab 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -1,36 +1,55 @@ -import sys -from typing import Dict, Any, List, Optional, Callable +import logging +from typing import Dict, Any, List, Optional, AsyncGenerator, TYPE_CHECKING + import aiohttp -import asyncio -import ssl from gql import Client as GqlClient, gql as gql_query from gql.transport.aiohttp import AIOHTTPTransport -from utcp.client.client_transport_interface import ClientTransportInterface -from utcp.shared.provider import Provider, GraphQLProvider -from utcp.shared.tool import Tool, ToolInputOutputSchema -from utcp.shared.auth import ApiKeyAuth, BasicAuth, OAuth2Auth -import logging + +from utcp.interfaces.communication_protocol import CommunicationProtocol +from utcp.data.call_template import CallTemplate +from utcp.data.tool import Tool, JsonSchema +from utcp.data.utcp_manual import UtcpManual +from utcp.data.register_manual_response import RegisterManualResult +from utcp.data.auth_implementations.api_key_auth import ApiKeyAuth +from utcp.data.auth_implementations.basic_auth import BasicAuth +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth + +from utcp_gql.gql_call_template import GraphQLProvider + +if TYPE_CHECKING: + from utcp.utcp_client import UtcpClient + logging.basicConfig( level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s" + format="%(asctime)s [%(levelname)s] %(filename)s:%(lineno)d - %(message)s", ) logger = logging.getLogger(__name__) -class GraphQLClientTransport(ClientTransportInterface): - """ - Simple, robust, production-ready GraphQL transport using gql. - Stateless, per-operation. Supports all GraphQL features. + +class GraphQLCommunicationProtocol(CommunicationProtocol): + """GraphQL protocol implementation for UTCP 1.0. + + - Discovers tools via GraphQL schema introspection. + - Executes per-call sessions using `gql` over HTTP(S). + - Supports `ApiKeyAuth`, `BasicAuth`, and `OAuth2Auth`. + - Enforces HTTPS or localhost for security. """ - def __init__(self): + + def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str): - if not (url.startswith("https://") or url.startswith("http://localhost") or url.startswith("http://127.0.0.1")): + def _enforce_https_or_localhost(self, url: str) -> None: + if not ( + url.startswith("https://") + or url.startswith("http://localhost") + or url.startswith("http://127.0.0.1") + ): raise ValueError( - f"Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. Got: {url}. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks." + "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " + "Non-secure URLs are vulnerable to man-in-the-middle attacks. " + f"Got: {url}." ) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: @@ -39,10 +58,10 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: return self._oauth_tokens[client_id]["access_token"] async with aiohttp.ClientSession() as session: data = { - 'grant_type': 'client_credentials', - 'client_id': client_id, - 'client_secret': auth.client_secret, - 'scope': auth.scope + "grant_type": "client_credentials", + "client_id": client_id, + "client_secret": auth.client_secret, + "scope": auth.scope, } async with session.post(auth.token_url, data=data) as resp: resp.raise_for_status() @@ -50,87 +69,161 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: self._oauth_tokens[client_id] = token_response return token_response["access_token"] - async def _prepare_headers(self, provider: GraphQLProvider) -> Dict[str, str]: - headers = provider.headers.copy() if provider.headers else {} + async def _prepare_headers( + self, provider: GraphQLProvider, tool_args: Optional[Dict[str, Any]] = None + ) -> Dict[str, str]: + headers: Dict[str, str] = provider.headers.copy() if provider.headers else {} if provider.auth: if isinstance(provider.auth, ApiKeyAuth): - if provider.auth.api_key: - if provider.auth.location == "header": - headers[provider.auth.var_name] = provider.auth.api_key - # (query/cookie not supported for GraphQL by default) + if provider.auth.api_key and provider.auth.location == "header": + headers[provider.auth.var_name] = provider.auth.api_key elif isinstance(provider.auth, BasicAuth): import base64 + userpass = f"{provider.auth.username}:{provider.auth.password}" headers["Authorization"] = "Basic " + base64.b64encode(userpass.encode()).decode() elif isinstance(provider.auth, OAuth2Auth): token = await self._handle_oauth2(provider.auth) headers["Authorization"] = f"Bearer {token}" + + # Map selected tool_args into headers if requested + if tool_args and provider.header_fields: + for field in provider.header_fields: + if field in tool_args and isinstance(tool_args[field], str): + headers[field] = tool_args[field] + return headers - async def register_tool_provider(self, manual_provider: Provider) -> List[Tool]: - if not isinstance(manual_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(manual_provider.url) - headers = await self._prepare_headers(manual_provider) - transport = AIOHTTPTransport(url=manual_provider.url, headers=headers) - async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - schema = session.client.schema - tools = [] - # Queries - if hasattr(schema, 'query_type') and schema.query_type: - for name, field in schema.query_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Mutations - if hasattr(schema, 'mutation_type') and schema.mutation_type: - for name, field in schema.mutation_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - # Subscriptions (listed, but not called here) - if hasattr(schema, 'subscription_type') and schema.subscription_type: - for name, field in schema.subscription_type.fields.items(): - tools.append(Tool( - name=name, - description=getattr(field, 'description', '') or '', - inputs=ToolInputOutputSchema(required=None), - tool_provider=manual_provider - )) - return tools - - async def deregister_tool_provider(self, manual_provider: Provider) -> None: - # Stateless: nothing to do - pass - - async def call_tool(self, tool_name: str, tool_args: Dict[str, Any], tool_provider: Provider, query: Optional[str] = None) -> Any: - if not isinstance(tool_provider, GraphQLProvider): - raise ValueError("GraphQLClientTransport can only be used with GraphQLProvider") - self._enforce_https_or_localhost(tool_provider.url) - headers = await self._prepare_headers(tool_provider) - transport = AIOHTTPTransport(url=tool_provider.url, headers=headers) + async def register_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> RegisterManualResult: + if not isinstance(manual_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(manual_call_template.url) + + try: + headers = await self._prepare_headers(manual_call_template) + transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) + async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + schema = session.client.schema + tools: List[Tool] = [] + + # Queries + if hasattr(schema, "query_type") and schema.query_type: + for name, field in schema.query_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Mutations + if hasattr(schema, "mutation_type") and schema.mutation_type: + for name, field in schema.mutation_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + # Subscriptions (listed for completeness) + if hasattr(schema, "subscription_type") and schema.subscription_type: + for name, field in schema.subscription_type.fields.items(): + tools.append( + Tool( + name=name, + description=getattr(field, "description", "") or "", + inputs=JsonSchema(type="object"), + outputs=JsonSchema(type="object"), + tool_call_template=manual_call_template, + ) + ) + + manual = UtcpManual(tools=tools) + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=manual, + success=True, + errors=[], + ) + except Exception as e: + logger.error(f"GraphQL manual registration failed for '{manual_call_template.name}': {e}") + return RegisterManualResult( + manual_call_template=manual_call_template, + manual=UtcpManual(manual_version="0.0.0", tools=[]), + success=False, + errors=[str(e)], + ) + + async def deregister_manual( + self, caller: "UtcpClient", manual_call_template: CallTemplate + ) -> None: + # Stateless: nothing to clean up + return None + + async def call_tool( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> Any: + if not isinstance(tool_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + self._enforce_https_or_localhost(tool_call_template.url) + + headers = await self._prepare_headers(tool_call_template, tool_args) + transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: - if query is not None: - document = gql_query(query) - result = await session.execute(document, variable_values=tool_args) - return result - # If no query provided, build a simple query - # Default to query operation - op_type = getattr(tool_provider, 'operation_type', 'query') - arg_str = ', '.join(f"${k}: String" for k in tool_args.keys()) + op_type = getattr(tool_call_template, "operation_type", "query") + # Strip manual prefix if present (client prefixes at save time) + base_tool_name = tool_name.split(".", 1)[-1] if "." in tool_name else tool_name + # Filter out header fields from GraphQL variables; these are sent via HTTP headers + header_fields = tool_call_template.header_fields or [] + filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields} + + defs = [] + for k, v in filtered_args.items(): + if isinstance(v, bool): + t = "Boolean" + elif isinstance(v, int) and not isinstance(v, bool): + t = "Int" + elif isinstance(v, float): + t = "Float" + else: + t = "String" + defs.append(f"${k}: {t}") + arg_str = ", ".join(defs) var_defs = f"({arg_str})" if arg_str else "" - arg_pass = ', '.join(f"{k}: ${k}" for k in tool_args.keys()) + arg_pass = ", ".join(f"{k}: ${k}" for k in filtered_args.keys()) arg_pass = f"({arg_pass})" if arg_pass else "" - gql_str = f"{op_type} {var_defs} {{ {tool_name}{arg_pass} }}" + + gql_str = f"{op_type} {var_defs} {{ {base_tool_name}{arg_pass} }}" document = gql_query(gql_str) - result = await session.execute(document, variable_values=tool_args) + result = await session.execute(document, variable_values=filtered_args) return result + async def call_tool_streaming( + self, + caller: "UtcpClient", + tool_name: str, + tool_args: Dict[str, Any], + tool_call_template: CallTemplate, + ) -> AsyncGenerator[Any, None]: + if not isinstance(tool_call_template, GraphQLProvider): + raise ValueError("GraphQLCommunicationProtocol requires a GraphQLProvider call template") + if getattr(tool_call_template, "operation_type", "query") == "subscription": + raise ValueError("GraphQL subscription streaming is not implemented") + result = await self.call_tool(caller, tool_name, tool_args, tool_call_template) + yield result + async def close(self) -> None: - self._oauth_tokens.clear() + self._oauth_tokens.clear() \ No newline at end of file diff --git a/plugins/communication_protocols/gql/tests/test_graphql_protocol.py b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py new file mode 100644 index 0000000..1b1bb74 --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_graphql_protocol.py @@ -0,0 +1,110 @@ +import os +import sys +import types +import pytest + + +# Ensure plugin src is importable +PLUGIN_SRC = os.path.join(os.path.dirname(__file__), "..", "src") +PLUGIN_SRC = os.path.abspath(PLUGIN_SRC) +if PLUGIN_SRC not in sys.path: + sys.path.append(PLUGIN_SRC) + +import utcp_gql +# Simplify imports: use the main module and assign local aliases +GraphQLProvider = utcp_gql.gql_call_template.GraphQLProvider +gql_module = utcp_gql.gql_communication_protocol + +from utcp.data.utcp_manual import UtcpManual +from utcp.utcp_client import UtcpClient +from utcp.implementations.utcp_client_implementation import UtcpClientImplementation + + +class FakeSchema: + def __init__(self): + # Minimal field objects with descriptions + self.query_type = types.SimpleNamespace( + fields={ + "hello": types.SimpleNamespace(description="Returns greeting"), + } + ) + self.mutation_type = types.SimpleNamespace( + fields={ + "add": types.SimpleNamespace(description="Adds numbers"), + } + ) + self.subscription_type = None + + +class FakeClientObj: + def __init__(self): + self.client = types.SimpleNamespace(schema=FakeSchema()) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + async def execute(self, document, variable_values=None): + # document is a gql query; we can base behavior on variable_values + variable_values = variable_values or {} + # Determine operation by presence of variables used + if "hello" in str(document): + name = variable_values.get("name", "") + return {"hello": f"Hello {name}"} + if "add" in str(document): + a = int(variable_values.get("a", 0)) + b = int(variable_values.get("b", 0)) + return {"add": a + b} + return {"ok": True} + + +class FakeTransport: + def __init__(self, url: str, headers: dict | None = None): + self.url = url + self.headers = headers or {} + + +@pytest.mark.asyncio +async def test_graphql_register_and_call(monkeypatch): + # Patch gql client/transport used by protocol to avoid needing a real server + monkeypatch.setattr(gql_module, "GqlClient", lambda *args, **kwargs: FakeClientObj()) + monkeypatch.setattr(gql_module, "AIOHTTPTransport", FakeTransport) + # Avoid real GraphQL parsing; pass-through document string to fake execute + monkeypatch.setattr(gql_module, "gql_query", lambda s: s) + + # Register plugin (call_template serializer + protocol) + utcp_gql.register() + + # Create protocol and manual call template + protocol = gql_module.GraphQLCommunicationProtocol() + provider = GraphQLProvider( + name="mock_graph", + call_template_type="graphql", + url="http://localhost/graphql", + operation_type="query", + headers={"x-client": "utcp"}, + header_fields=["x-session-id"], + ) + + # Minimal UTCP client implementation for caller context + client: UtcpClient = await UtcpClientImplementation.create() + client.config.variables = {} + + # Register and discover tools + reg = await protocol.register_manual(client, provider) + assert reg.success is True + assert isinstance(reg.manual, UtcpManual) + tool_names = sorted(t.name for t in reg.manual.tools) + assert "hello" in tool_names + assert "add" in tool_names + + # Call hello + res = await protocol.call_tool(client, "mock_graph.hello", {"name": "UTCP", "x-session-id": "abc"}, provider) + assert res == {"hello": "Hello UTCP"} + + # Call add (mutation) + provider.operation_type = "mutation" + res2 = await protocol.call_tool(client, "mock_graph.add", {"a": 2, "b": 3}, provider) + assert res2 == {"add": 5} \ No newline at end of file diff --git a/plugins/communication_protocols/mcp/pyproject.toml b/plugins/communication_protocols/mcp/pyproject.toml index 9232cd5..2efd4c3 100644 --- a/plugins/communication_protocols/mcp/pyproject.toml +++ b/plugins/communication_protocols/mcp/pyproject.toml @@ -16,7 +16,7 @@ dependencies = [ "mcp>=1.12", "utcp>=1.0", "mcp-use>=1.3", - "langchain==0.3.27", + "langchain>=0.3.27,<0.4.0", ] classifiers = [ "Development Status :: 4 - Beta", diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py index 10fc1d6..8b27d1c 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_call_template.py @@ -68,6 +68,10 @@ class TCPProvider(CallTemplate): default='\x00', description="Delimiter to detect end of TCP response (e.g., '\n', '\r\n', '\x00'). Used with 'delimiter' framing." ) + interpret_escape_sequences: bool = Field( + default=True, + description="If True, interpret Python-style escape sequences in message_delimiter (e.g., '\\n', '\\r\\n', '\\x00'). If False, use the delimiter literally as provided." + ) # Fixed-length framing options fixed_message_length: Optional[int] = Field( default=None, diff --git a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py index d5d64ac..b2f08c3 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/tcp_communication_protocol.py @@ -148,9 +148,14 @@ def _encode_message_with_framing(self, message: str, provider: TCPProvider) -> b elif provider.framing_strategy == "delimiter": # Add delimiter after the message delimiter = provider.message_delimiter or "\x00" - # Handle escape sequences - delimiter = delimiter.encode('utf-8').decode('unicode_escape') - return message_bytes + delimiter.encode('utf-8') + if provider.interpret_escape_sequences: + # Handle escape sequences (e.g., "\n", "\r\n", "\x00") + delimiter = delimiter.encode('utf-8').decode('unicode_escape') + delimiter_bytes = delimiter.encode('utf-8') + else: + # Use delimiter literally as provided + delimiter_bytes = delimiter.encode('utf-8') + return message_bytes + delimiter_bytes elif provider.framing_strategy in ("fixed_length", "stream"): # No additional framing needed @@ -202,8 +207,19 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid elif provider.framing_strategy == "delimiter": # Read until delimiter is found + # Delimiter handling: + # The code supports both literal delimiters (e.g., "\\x00") and escape-sequence interpreted delimiters (e.g., "\x00") + # via the `interpret_escape_sequences` flag in TCPProvider. This ensures compatibility with both legacy and updated + # wire protocols. The delimiter is interpreted according to the flag, so no breaking change occurs unless the flag + # is set differently than expected by the server/client. + # Example: + # If interpret_escape_sequences is True, "\\x00" becomes a null byte; if False, it remains four literal bytes. + # delimiter = delimiter.encode('utf-8') delimiter = provider.message_delimiter or "\x00" - delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + if provider.interpret_escape_sequences: + delimiter_bytes = delimiter.encode('utf-8').decode('unicode_escape').encode('utf-8') + else: + delimiter_bytes = delimiter.encode('utf-8') response_data = b"" while True: @@ -213,9 +229,9 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid response_data += chunk # Check if we've received the delimiter - if response_data.endswith(delimiter): + if response_data.endswith(delimiter_bytes): # Remove delimiter from response - return response_data[:-len(delimiter)] + return response_data[:-len(delimiter_bytes)] elif provider.framing_strategy == "fixed_length": # Read exactly fixed_message_length bytes @@ -246,6 +262,13 @@ def _decode_response_with_framing(self, sock: socket.socket, provider: TCPProvid break return response_data + + else: + # Copilot AI (5 days ago): + # The else branch for unknown framing strategies was previously removed, + # which could cause silent fallthrough and confusing behavior. Add explicit + # validation to raise a descriptive error when an unsupported strategy is provided. + raise ValueError(f"Unknown framing strategy: {provider.framing_strategy!r}") async def _send_tcp_message( self, diff --git a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py index b59ef37..b612d40 100644 --- a/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py +++ b/plugins/communication_protocols/socket/src/utcp_socket/udp_communication_protocol.py @@ -15,6 +15,7 @@ from utcp.data.call_template import CallTemplate, CallTemplateSerializer from utcp.data.register_manual_response import RegisterManualResult from utcp.data.utcp_manual import UtcpManual +from utcp.exceptions import UtcpSerializerValidationError import logging logger = logging.getLogger(__name__) @@ -98,8 +99,8 @@ def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_temp try: ctpl = CallTemplateSerializer().validate_dict(normalized["tool_call_template"]) # type: ignore normalized["tool_call_template"] = ctpl - except Exception: - # Fallback to manual template if validation fails + except (UtcpSerializerValidationError, ValueError) as e: + logger.exception(f"Failed to validate existing tool_call_template; falling back to manual template: {e}") normalized["tool_call_template"] = manual_call_template elif "tool_provider" in normalized and normalized["tool_provider"] is not None: # Convert legacy provider -> call template @@ -107,12 +108,15 @@ def _ensure_tool_call_template(self, tool_data: Dict[str, Any], manual_call_temp ctpl = UDPProviderSerializer().validate_dict(normalized["tool_provider"]) # type: ignore normalized.pop("tool_provider", None) normalized["tool_call_template"] = ctpl - except Exception: + except (UtcpSerializerValidationError, ValueError) as e: + logger.exception(f"Failed to convert legacy tool_provider to call template; falling back to manual template: {e}") normalized.pop("tool_provider", None) normalized["tool_call_template"] = manual_call_template else: normalized["tool_call_template"] = manual_call_template except Exception: + # Any unexpected error during normalization should be logged + logger.exception("Unexpected error normalizing tool definition; falling back to manual template") normalized["tool_call_template"] = manual_call_template return normalized @@ -321,5 +325,12 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too self._log_error(f"Error calling UDP tool '{tool_name}': {traceback.format_exc()}") raise + # Copilot AI (5 days ago): + # The call_tool_streaming method wraps a generator function but doesn't use the async def syntax for the method itself. + # While this works, it's inconsistent with the other implementation in tcp_communication_protocol.py (lines 384-387) which properly uses async def with an inner generator. + # For consistency and clarity, this should also use async def directly: + # + # async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): + # yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, Any], tool_call_template: CallTemplate): yield await self.call_tool(caller, tool_name, tool_args, tool_call_template) diff --git a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py index 1b6ffb2..d359fd9 100644 --- a/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py +++ b/plugins/communication_protocols/socket/tests/test_tcp_communication_protocol.py @@ -15,6 +15,7 @@ async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): # Read any incoming data to simulate request handling await reader.read(1024) except Exception: + # Ignore exceptions during read (e.g., client disconnects), as this is a test server. pass # Send response and close connection writer.write(response_container["bytes"]) @@ -23,6 +24,7 @@ async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): writer.close() await writer.wait_closed() except Exception: + # Ignore exceptions during writer close; connection may already be closed or in error state. pass server = await asyncio.start_server(handle, host="127.0.0.1", port=0) diff --git a/scripts/socket_sanity.py b/scripts/socket_sanity.py index 40b2c16..5ac6028 100644 --- a/scripts/socket_sanity.py +++ b/scripts/socket_sanity.py @@ -1,7 +1,5 @@ import sys -import os import json -import time import socket import threading import asyncio @@ -39,6 +37,7 @@ def run(): try: parsed = json.loads(msg) except Exception: + # Ignore JSON parsing errors; non-JSON input will be handled below parsed = None if isinstance(parsed, dict) and parsed.get("type") == "utcp": manual = { @@ -182,6 +181,7 @@ def handle_client(conn: socket.socket, addr): try: conn.shutdown(socket.SHUT_RDWR) except Exception: + # Ignore errors if socket is already closed or shutdown fails pass conn.close() @@ -259,7 +259,7 @@ def ensure_dict(s): assert udp_resp.get("ok") is True and udp_resp.get("echo") == "hello" assert tcp_resp.get("ok") is True and tcp_resp.get("echo") == "world" - print("Sanity passed: UDP/TCP discovery and calls work with tool_call_template normalization.") + print("Sanity check passed: UDP/TCP discovery and calls work with tool_call_template normalization.") if __name__ == "__main__": asyncio.run(run_sanity()) \ No newline at end of file