diff --git a/plugins/communication_protocols/gql/pyproject.toml b/plugins/communication_protocols/gql/pyproject.toml index 4377268..59e2392 100644 --- a/plugins/communication_protocols/gql/pyproject.toml +++ b/plugins/communication_protocols/gql/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-gql" -version = "1.1.0" +version = "1.1.1" authors = [ { name = "UTCP Contributors" }, ] @@ -14,6 +14,7 @@ requires-python = ">=3.10" dependencies = [ "pydantic>=2.0", "gql>=3.0", + "aiohttp>=3.8", "utcp>=1.1" ] classifiers = [ diff --git a/plugins/communication_protocols/gql/src/utcp_gql/_security.py b/plugins/communication_protocols/gql/src/utcp_gql/_security.py new file mode 100644 index 0000000..f344267 --- /dev/null +++ b/plugins/communication_protocols/gql/src/utcp_gql/_security.py @@ -0,0 +1,219 @@ +"""URL validation for the GraphQL communication protocol. + +Mirror of ``utcp_http._security`` -- intentionally duplicated rather +than cross-plugin-imported so ``utcp-gql`` does not gain a runtime +dependency on ``utcp-http``. Keep the two files in sync when changing +the validator behavior. Backs GHSA-ppx3-28rw-8fpf (the original CVE +fix did not reach this plugin) and GHSA-9qhg-99ww-9mqc (redirect +SSRF on the GraphQL endpoint). +""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from ipaddress import ip_address +from typing import Any, AsyncIterator, Optional +from urllib.parse import urljoin, urlparse + +# Hostnames considered safe to talk to over plain HTTP. +_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) + + +def is_secure_url(url: str) -> bool: + """Return True if ``url`` is safe to fetch from a UTCP HTTP protocol. + + Allowed: + - Any ``https://`` URL. + - ``http://`` URLs whose host is exactly ``localhost``, ``127.0.0.1``, + or ``::1``. + + Disallowed: + - Plain ``http://`` to any other host (MITM exposure). + - URLs whose hostname *starts* with ``localhost`` / ``127.0.0.1`` but + isn't actually loopback (e.g. ``http://localhost.evil.com``, + ``http://127.0.0.1.attacker.example``). The earlier ``startswith`` + check let these through. + - Anything without a scheme/host (file://, gopher://, javascript:, ...). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"http", "https"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "https": + return True + + # http:// is only allowed for loopback. + if host in _LOOPBACK_HOSTNAMES: + return True + + # Catch any other literal loopback IP that urlparse normalised + # (e.g. ``http://127.000.000.001``). + try: + return ip_address(host).is_loopback + except ValueError: + return False + + +def is_loopback_url(url: str) -> bool: + """Return True if ``url``'s host is a literal loopback address. + + Used by the OpenAPI converter to detect the SSRF case where a remote spec + declares ``servers: [{ url: "http://127.0.0.1:..." }]`` to redirect tool + invocation at the host running the agent. Hostname-based — not a string + prefix — so ``http://localhost.evil.com`` returns False. + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if host in _LOOPBACK_HOSTNAMES: + return True + + try: + return ip_address(host).is_loopback + except ValueError: + return False + + +def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to fetch. + + ``context`` is a short label (``"manual discovery"``, ``"tool invocation"``, + etc.) included in the error so log readers can tell which trust boundary + was breached. + """ + if is_secure_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: URL must use HTTPS or be a literal loopback " + f"address (localhost / 127.0.0.1 / ::1). Got: {url!r}. " + "Plain HTTP to any other host is rejected to prevent MITM attacks " + "and SSRF into internal services." + ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py index 16b945c..694ff99 100644 --- a/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py +++ b/plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py @@ -15,6 +15,7 @@ from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql._security import ensure_secure_url, safe_request_with_redirects if TYPE_CHECKING: from utcp.utcp_client import UtcpClient @@ -40,22 +41,21 @@ class GraphQLCommunicationProtocol(CommunicationProtocol): def __init__(self) -> None: self._oauth_tokens: Dict[str, Dict[str, Any]] = {} - def _enforce_https_or_localhost(self, url: str) -> None: - if not ( - url.startswith("https://") - or url.startswith("http://localhost") - or url.startswith("http://127.0.0.1") - ): - raise ValueError( - "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. " - "Non-secure URLs are vulnerable to man-in-the-middle attacks. " - f"Got: {url}." - ) - async def _handle_oauth2(self, auth: OAuth2Auth) -> str: + """Fetch an OAuth2 access token. + + Validates the token URL with the hostname-based ``ensure_secure_url`` + helper before any credential bytes leave the process, and follows + redirects only after re-validating each hop -- defends against the + sibling SSRF / credential-exfiltration patterns in + GHSA-8cp3-qxj6-px34 and GHSA-9qhg-99ww-9mqc. + """ client_id = auth.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + + ensure_secure_url(auth.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: data = { "grant_type": "client_credentials", @@ -63,7 +63,13 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: "client_secret": auth.client_secret, "scope": auth.scope, } - async with session.post(auth.token_url, data=data) as resp: + async with safe_request_with_redirects( + session, + "POST", + auth.token_url, + context="OAuth2 token fetch", + data=data, + ) as resp: resp.raise_for_status() token_response = await resp.json() self._oauth_tokens[client_id] = token_response @@ -94,17 +100,45 @@ async def _prepare_headers( return headers + @staticmethod + def _disable_transport_redirects(transport: AIOHTTPTransport) -> None: + """Patch the underlying aiohttp session used by AIOHTTPTransport + so its ``session.post`` refuses to follow 3xx responses. + + gql's AIOHTTPTransport does not expose ``allow_redirects`` and + the default ClientSession setting would let an attacker- + controlled GraphQL endpoint 302 the client into an internal + service after the URL had already passed ``ensure_secure_url``. + See GHSA-9qhg-99ww-9mqc / GHSA-ppx3-28rw-8fpf. + """ + aio_session = getattr(transport, "session", None) + if aio_session is None: + return + original_post = aio_session.post + + def _no_redirect_post(*args: Any, **kwargs: Any): + kwargs["allow_redirects"] = False + return original_post(*args, **kwargs) + + aio_session.post = _no_redirect_post # type: ignore[method-assign] + async def register_manual( self, caller: "UtcpClient", manual_call_template: CallTemplate ) -> RegisterManualResult: if not isinstance(manual_call_template, GraphQLCallTemplate): raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") - self._enforce_https_or_localhost(manual_call_template.url) + # Hostname-based validation -- replaces the broken ``startswith`` + # prefix check that let ``http://127.0.0.1.attacker.example`` + # through (GHSA-ppx3-28rw-8fpf). + ensure_secure_url( + manual_call_template.url, context="GraphQL manual discovery" + ) try: headers = await self._prepare_headers(manual_call_template) transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + self._disable_transport_redirects(transport) schema = session.client.schema tools: List[Tool] = [] @@ -178,11 +212,14 @@ async def call_tool( ) -> Any: if not isinstance(tool_call_template, GraphQLCallTemplate): raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template") - self._enforce_https_or_localhost(tool_call_template.url) + ensure_secure_url( + tool_call_template.url, context="GraphQL tool invocation" + ) headers = await self._prepare_headers(tool_call_template, tool_args) transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers) async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session: + self._disable_transport_redirects(transport) # Filter out header fields from GraphQL variables; these are sent via HTTP headers header_fields = tool_call_template.header_fields or [] filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields} diff --git a/plugins/communication_protocols/gql/tests/test_gql_security.py b/plugins/communication_protocols/gql/tests/test_gql_security.py new file mode 100644 index 0000000..da93008 --- /dev/null +++ b/plugins/communication_protocols/gql/tests/test_gql_security.py @@ -0,0 +1,111 @@ +"""Security tests for the GraphQL communication protocol (utcp-gql). + +Pin the fixes for GHSA-ppx3-28rw-8fpf (the original CVE-2026-44661 +URL hardening missed this plugin) and the OAuth2 / redirect halves +of GHSA-8cp3-qxj6-px34 / GHSA-9qhg-99ww-9mqc. +""" + +import pytest + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_gql._security import ( + ensure_secure_url, + is_secure_url, +) +from utcp_gql.gql_call_template import GraphQLCallTemplate +from utcp_gql.gql_communication_protocol import GraphQLCommunicationProtocol + + +# --------------------------------------------------------------------------- +# Hostname-based validator must reject the same prefix bypass as utcp-http. +# --------------------------------------------------------------------------- + + +class TestUrlValidatorRejectsPrefixBypass: + @pytest.mark.parametrize( + "url", + [ + "http://localhost.evil.com/graphql", + "http://127.0.0.1.attacker.example/graphql", + "http://169.254.169.254/graphql", + "http://10.0.0.5/graphql", + "http://internal.service.local/graphql", + "http://example.com/graphql", + ], + ) + def test_bypass_url_rejected(self, url: str) -> None: + assert is_secure_url(url) is False + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + ensure_secure_url(url) + + @pytest.mark.parametrize( + "url", + [ + "https://api.example.com/graphql", + "http://localhost/graphql", + "http://127.0.0.1:9090/graphql", + "http://[::1]:9090/graphql", + ], + ) + def test_legitimate_url_accepted(self, url: str) -> None: + assert is_secure_url(url) is True + ensure_secure_url(url) # must not raise + + +# --------------------------------------------------------------------------- +# register_manual + call_tool: URL validation is now hostname-based. +# --------------------------------------------------------------------------- + + +class TestRegisterAndCallRejectBypass: + @pytest.mark.asyncio + async def test_register_manual_rejects_prefix_bypass(self) -> None: + proto = GraphQLCommunicationProtocol() + tpl = GraphQLCallTemplate( + name="evil", + url="http://127.0.0.1.attacker.example/graphql", + ) + # The validator runs before register_manual's try/except so the + # ValueError propagates rather than being captured in the + # result. + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + await proto.register_manual(None, tpl) + + @pytest.mark.asyncio + async def test_call_tool_rejects_prefix_bypass(self) -> None: + proto = GraphQLCommunicationProtocol() + tpl = GraphQLCallTemplate( + name="evil", + url="http://localhost.evil.com/graphql", + ) + with pytest.raises(ValueError, match="HTTPS or be a literal loopback"): + await proto.call_tool(None, "x", {}, tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL is validated before credential bytes leave the process. +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected(self) -> None: + proto = GraphQLCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = GraphQLCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) diff --git a/plugins/communication_protocols/http/pyproject.toml b/plugins/communication_protocols/http/pyproject.toml index db0af80..ca64fe0 100644 --- a/plugins/communication_protocols/http/pyproject.toml +++ b/plugins/communication_protocols/http/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-http" -version = "1.1.3" +version = "1.1.4" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/http/src/utcp_http/_security.py b/plugins/communication_protocols/http/src/utcp_http/_security.py index db98cc5..8023fc0 100644 --- a/plugins/communication_protocols/http/src/utcp_http/_security.py +++ b/plugins/communication_protocols/http/src/utcp_http/_security.py @@ -9,9 +9,10 @@ from __future__ import annotations +from contextlib import asynccontextmanager from ipaddress import ip_address -from typing import Optional -from urllib.parse import urlparse +from typing import Any, AsyncIterator, Optional +from urllib.parse import urljoin, urlparse # Hostnames considered safe to talk to over plain HTTP. _LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) @@ -110,3 +111,108 @@ def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: "Plain HTTP to any other host is rejected to prevent MITM attacks " "and SSRF into internal services." ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py index 7eb8aa0..bab7214 100644 --- a/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/http_communication_protocol.py @@ -33,7 +33,7 @@ from utcp_http.http_call_template import HttpCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth from utcp_http.openapi_converter import OpenApiConverter -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import logging logging.basicConfig( @@ -153,7 +153,7 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R # Set content-type header if body is provided and header not already set if body_content is not None and "Content-Type" not in request_headers: request_headers["Content-Type"] = manual_call_template.content_type - + # Prepare body content based on content type data = None json_data = None @@ -162,20 +162,24 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R json_data = body_content else: data = body_content - - # Make the request with the call template's HTTP method - method = manual_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). + method = manual_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", params=query_params, headers=request_headers, auth=auth, json=json_data, data=data, cookies=cookies, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() # Raise exception for 4XX/5XX responses @@ -306,19 +310,24 @@ async def call_tool(self, caller, tool_name: str, tool_args: Dict[str, Any], too else: data = body_content - # Make the request with the appropriate HTTP method - method = tool_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + # Re-validate every redirect hop -- aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled tool endpoint 302 us into an + # internal service and hand its body back to the + # caller (GHSA-9qhg-99ww-9mqc). + method = tool_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="tool invocation", params=query_params, headers=request_headers, auth=auth, json=json_data, data=data, cookies=cookies, - timeout=aiohttp.ClientTimeout(total=30.0) + timeout=aiohttp.ClientTimeout(total=30.0), ) as response: response.raise_for_status() @@ -356,13 +365,27 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, yield result async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + The token URL ultimately comes from a call template, and call + templates can be sourced from attacker-controlled OpenAPI specs + (the ``OpenApiConverter`` copies ``tokenUrl`` from the spec). + Validate it before posting credentials so an attacker spec + cannot redirect ``client_id`` / ``client_secret`` exfiltration + through this protocol -- see GHSA-8cp3-qxj6-px34. The redirect + helper also blocks the post-issue redirect SSRF + (GHSA-9qhg-99ww-9mqc) on the token endpoint itself. """ - Handles OAuth2 client credentials flow, trying both body and auth header methods.""" client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: # Method 1: Send credentials in the request body try: @@ -373,7 +396,13 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: 'client_secret': auth_details.client_secret, 'scope': auth_details.scope } - async with session.post(auth_details.token_url, data=body_data) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=body_data, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response @@ -389,7 +418,14 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: 'grant_type': 'client_credentials', 'scope': auth_details.scope } - async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=header_data, + auth=header_auth, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response diff --git a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py index d53bdfd..c29e926 100644 --- a/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py +++ b/plugins/communication_protocols/http/src/utcp_http/openapi_converter.py @@ -27,7 +27,7 @@ from utcp.data.utcp_manual import UtcpManual from utcp.data.tool import Tool, JsonSchema from utcp_http.http_call_template import HttpCallTemplate -from utcp_http._security import is_loopback_url +from utcp_http._security import ensure_secure_url, is_loopback_url class OpenApiConverter: """REQUIRED @@ -368,6 +368,17 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> if flow_type in ["authorizationCode", "accessCode", "clientCredentials", "application"]: token_url = flow_config.get("tokenUrl") if token_url: + # Reject obviously-internal or plain-HTTP + # token URLs at conversion time so an + # attacker-controlled OpenAPI spec cannot + # smuggle a credential-exfiltration sink + # into the generated call template. The + # runtime check in ``_handle_oauth2`` also + # enforces this -- see + # GHSA-8cp3-qxj6-px34. + ensure_secure_url( + token_url, context="OAuth2 tokenUrl in OpenAPI spec" + ) # Use the current counter value for both placeholders client_id_placeholder = self._get_placeholder("CLIENT_ID") client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") @@ -379,12 +390,15 @@ def _create_auth_from_scheme(self, scheme: Dict[str, Any], scheme_name: str) -> client_secret=client_secret_placeholder, scope=" ".join(flow_config.get("scopes", {}).keys()) or None ) - + # OpenAPI 2.0 format (flows directly in scheme) else: flow_type = scheme.get("flow", "") token_url = scheme.get("tokenUrl") if token_url and flow_type in ["accessCode", "application", "clientCredentials"]: + ensure_secure_url( + token_url, context="OAuth2 tokenUrl in OpenAPI spec" + ) # Use the current counter value for both placeholders client_id_placeholder = self._get_placeholder("CLIENT_ID") client_secret_placeholder = self._get_placeholder("CLIENT_SECRET") diff --git a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py index b9ab964..fa5c3d9 100644 --- a/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/sse_communication_protocol.py @@ -17,7 +17,7 @@ from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_http.sse_call_template import SseCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import traceback import logging @@ -116,19 +116,23 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R else: data = body_content - # Make the request (typically GET for discovery, but respect configuration) + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). method = "GET" # Default to GET for discovery - request_method = getattr(session, method.lower()) - - async with request_method( + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", headers=request_headers, auth=auth, params=query_params, cookies=cookies, json=json_data, data=data, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() response_data = await response.json() @@ -208,10 +212,26 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, data = body_content if "application/json" not in request_headers.get("Content-Type", "") else None json_data = body_content if "application/json" in request_headers.get("Content-Type", "") else None + # SSE handshake must not follow redirects: the streaming + # response has to stay open for the lifetime of the tool + # call, which is incompatible with the per-hop validator's + # release semantics, and SSE redirects are pathological in + # practice. Reject 3xx outright so an attacker-controlled + # endpoint cannot redirect the handshake into an internal + # service (GHSA-9qhg-99ww-9mqc). response = await session.request( method, url, params=query_params, headers=request_headers, - auth=auth, cookies=cookies, json=json_data, data=data, timeout=None + auth=auth, cookies=cookies, json=json_data, data=data, + timeout=None, allow_redirects=False, ) + if 300 <= response.status < 400: + response.release() + raise RuntimeError( + f"SSE endpoint at {url!r} returned a {response.status} " + f"redirect. Redirects are not followed during SSE " + f"handshakes; update the call template to point at " + f"the final URL directly." + ) response.raise_for_status() async for event in self._process_sse_stream(response, tool_call_template.event_type): yield event @@ -275,26 +295,52 @@ async def _process_sse_stream(self, response: aiohttp.ClientResponse, event_type pass # Session is managed and closed by deregister_tool_provider async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: - """Handles OAuth2 client credentials flow, trying both body and auth header methods.""" + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + Validates the token URL before posting credentials so an + attacker-controlled OpenAPI spec cannot redirect ``client_id`` / + ``client_secret`` exfiltration through this protocol + (GHSA-8cp3-qxj6-px34). The redirect helper also blocks the + post-issue redirect SSRF (GHSA-9qhg-99ww-9mqc) on the token + endpoint itself. + """ client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: try: # Method 1: Credentials in body body_data = {'grant_type': 'client_credentials', 'client_id': client_id, 'client_secret': auth_details.client_secret, 'scope': auth_details.scope} - async with session.post(auth_details.token_url, data=body_data) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=body_data, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response return token_response["access_token"] except aiohttp.ClientError as e: logger.error(f"OAuth2 with body failed: {e}. Trying Basic Auth.") - + try: # Method 2: Credentials in header header_auth = aiohttp.BasicAuth(client_id, auth_details.client_secret) header_data = {'grant_type': 'client_credentials', 'scope': auth_details.scope} - async with session.post(auth_details.token_url, data=header_data, auth=header_auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data=header_data, + auth=header_auth, + ) as response: response.raise_for_status() token_response = await response.json() self._oauth_tokens[client_id] = token_response diff --git a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py index 72fb2f2..668735c 100644 --- a/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py +++ b/plugins/communication_protocols/http/src/utcp_http/streamable_http_communication_protocol.py @@ -15,7 +15,7 @@ from utcp.data.auth_implementations import OAuth2Auth from utcp_http.streamable_http_call_template import StreamableHttpCallTemplate from aiohttp import ClientSession, BasicAuth as AiohttpBasicAuth, ClientResponse -from utcp_http._security import ensure_secure_url +from utcp_http._security import ensure_secure_url, safe_request_with_redirects import logging logging.basicConfig( @@ -119,19 +119,23 @@ async def register_manual(self, caller, manual_call_template: CallTemplate) -> R else: data = body_content - # Make the request with the template's HTTP method - method = manual_call_template.http_method.lower() - request_method = getattr(session, method) - - async with request_method( + # Re-validate every redirect hop. aiohttp's default + # ``allow_redirects=True`` would otherwise let an + # attacker-controlled discovery URL 302 us into an + # internal service (GHSA-9qhg-99ww-9mqc). + method = manual_call_template.http_method.upper() + async with safe_request_with_redirects( + session, + method, url, + context="manual discovery", headers=request_headers, auth=auth, params=query_params, cookies=cookies, json=json_data, data=data, - timeout=aiohttp.ClientTimeout(total=10.0) + timeout=aiohttp.ClientTimeout(total=10.0), ) as response: response.raise_for_status() response_data = await response.json() @@ -248,6 +252,12 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, else: data = body_content + # Streaming handshake must not follow redirects: the + # response has to stay open for the lifetime of the tool + # call, which is incompatible with the per-hop validator's + # release semantics. Reject 3xx outright so an + # attacker-controlled endpoint cannot redirect us into an + # internal service (GHSA-9qhg-99ww-9mqc). response = await session.request( method=tool_call_template.http_method, url=url, @@ -257,8 +267,17 @@ async def call_tool_streaming(self, caller, tool_name: str, tool_args: Dict[str, cookies=cookies, json=json_data, data=data, - timeout=timeout + timeout=timeout, + allow_redirects=False, ) + if 300 <= response.status < 400: + response.release() + raise RuntimeError( + f"Streamable HTTP endpoint at {url!r} returned a " + f"{response.status} redirect. Redirects are not " + f"followed during streaming handshakes; update the " + f"call template to point at the final URL directly." + ) response.raise_for_status() async for chunk in self._process_http_stream(response, tool_call_template.chunk_size, tool_call_template.name): @@ -314,16 +333,40 @@ async def _process_http_stream(self, response: ClientResponse, chunk_size: Optio pass async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: - """Handles OAuth2 client credentials flow, trying both body and auth header methods.""" + """Handle OAuth2 client credentials flow, trying both body and + auth header methods. + + Validates the token URL before posting credentials so an + attacker-controlled OpenAPI spec cannot redirect ``client_id`` / + ``client_secret`` exfiltration through this protocol + (GHSA-8cp3-qxj6-px34). The redirect helper also blocks the + post-issue redirect SSRF (GHSA-9qhg-99ww-9mqc) on the token + endpoint itself. + """ client_id = auth_details.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + # Reject obviously-internal or plain-HTTP non-loopback token + # endpoints before any credential bytes leave the process. + ensure_secure_url(auth_details.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: # Method 1: Credentials in body try: logger.info(f"Attempting OAuth2 token fetch for '{client_id}' with credentials in body.") - async with session.post(auth_details.token_url, data={'grant_type': 'client_credentials', 'client_id': client_id, 'client_secret': auth_details.client_secret, 'scope': auth_details.scope}) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data={ + 'grant_type': 'client_credentials', + 'client_id': client_id, + 'client_secret': auth_details.client_secret, + 'scope': auth_details.scope, + }, + ) as response: response.raise_for_status() token_data = await response.json() self._oauth_tokens[client_id] = token_data @@ -335,7 +378,17 @@ async def _handle_oauth2(self, auth_details: OAuth2Auth) -> str: try: logger.info(f"Attempting OAuth2 token fetch for '{client_id}' with Basic Auth header.") auth = AiohttpBasicAuth(client_id, auth_details.client_secret) - async with session.post(auth_details.token_url, data={'grant_type': 'client_credentials', 'scope': auth_details.scope}, auth=auth) as response: + async with safe_request_with_redirects( + session, + "POST", + auth_details.token_url, + context="OAuth2 token fetch", + data={ + 'grant_type': 'client_credentials', + 'scope': auth_details.scope, + }, + auth=auth, + ) as response: response.raise_for_status() token_data = await response.json() self._oauth_tokens[client_id] = token_data diff --git a/plugins/communication_protocols/http/tests/test_redirect_security.py b/plugins/communication_protocols/http/tests/test_redirect_security.py new file mode 100644 index 0000000..f7bb498 --- /dev/null +++ b/plugins/communication_protocols/http/tests/test_redirect_security.py @@ -0,0 +1,291 @@ +"""Tests for the redirect + OAuth2 token-URL hardening landing in +utcp-http 1.1.4. + +Pin the fixes for: +- GHSA-9qhg-99ww-9mqc: aiohttp's default ``allow_redirects=True`` let + attacker-controlled tool/manual endpoints 302 the client into + internal services that ``ensure_secure_url`` was supposed to block. +- GHSA-8cp3-qxj6-px34: OAuth2 ``tokenUrl`` from a remote OpenAPI spec + was used verbatim, so an attacker spec could POST the victim's + ``client_id`` / ``client_secret`` to any URL. +""" + +import pytest +from aiohttp import web + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_http._security import safe_request_with_redirects +from utcp_http.http_communication_protocol import HttpCommunicationProtocol +from utcp_http.http_call_template import HttpCallTemplate +from utcp_http.openapi_converter import OpenApiConverter + + +# --------------------------------------------------------------------------- +# safe_request_with_redirects: behaviour table. +# --------------------------------------------------------------------------- + + +class TestSafeRequestWithRedirects: + @pytest.mark.asyncio + async def test_initial_url_validated(self) -> None: + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(ValueError, match="manual discovery"): + async with safe_request_with_redirects( + session, + "GET", + "http://169.254.169.254/latest/meta-data/", + context="manual discovery", + ): + pass + + @pytest.mark.asyncio + async def test_redirect_to_internal_target_is_blocked( + self, aiohttp_server + ) -> None: + """Attacker-controlled origin 302s to a non-loopback plain-HTTP + URL. The helper must reject before the second hop is issued. + """ + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("http://169.254.169.254/latest/meta-data/") + + app = web.Application() + app.router.add_get("/tool", _redirect) + server = await aiohttp_server(app) + attacker_url = str(server.make_url("/tool")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(ValueError, match="redirect target"): + async with safe_request_with_redirects( + session, + "GET", + attacker_url, + context="tool invocation", + ): + pass + + @pytest.mark.asyncio + async def test_redirect_to_loopback_is_allowed( + self, aiohttp_server + ) -> None: + """Legit loopback-to-loopback redirect is followed.""" + async def _final(request: web.Request) -> web.Response: + return web.json_response({"hop": "final"}) + + app = web.Application() + app.router.add_get("/final", _final) + + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/final") + + app.router.add_get("/start", _redirect) + server = await aiohttp_server(app) + start_url = str(server.make_url("/start")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + async with safe_request_with_redirects( + session, "GET", start_url, context="tool invocation" + ) as response: + payload = await response.json() + assert payload == {"hop": "final"} + + @pytest.mark.asyncio + async def test_redirect_loop_is_capped(self, aiohttp_server) -> None: + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("/loop") + + app = web.Application() + app.router.add_get("/loop", _redirect) + server = await aiohttp_server(app) + loop_url = str(server.make_url("/loop")) + + import aiohttp + + async with aiohttp.ClientSession() as session: + with pytest.raises(RuntimeError, match="Too many redirects"): + async with safe_request_with_redirects( + session, + "GET", + loop_url, + context="tool invocation", + max_redirects=3, + ): + pass + + +# --------------------------------------------------------------------------- +# End-to-end: HttpCommunicationProtocol.call_tool must not exfiltrate +# internal responses via a 302. +# --------------------------------------------------------------------------- + + +class TestCallToolRedirectExfiltration: + @pytest.mark.asyncio + async def test_attacker_redirect_to_internal_blocked( + self, aiohttp_server + ) -> None: + # Internal "metadata" service -- on loopback for the test so we + # can stand it up, but the validator rejects it because the + # OUTER tool URL is non-loopback (it would in production live + # on 169.254.169.254). We instead point the 302 at the + # canonical metadata URL to assert the rejection mechanism. + async def _redirect(request: web.Request) -> web.Response: + raise web.HTTPFound("http://169.254.169.254/latest/meta-data/") + + app = web.Application() + app.router.add_get("/tool", _redirect) + server = await aiohttp_server(app) + attacker_url = str(server.make_url("/tool")) + + proto = HttpCommunicationProtocol() + tpl = HttpCallTemplate( + name="lookup", url=attacker_url, http_method="GET" + ) + + with pytest.raises(ValueError, match="redirect target"): + await proto.call_tool(None, "lookup", {}, tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL must be validated before any credential bytes leave +# the process. +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected_at_runtime(self) -> None: + proto = HttpCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/oauth/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = HttpCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + +class TestOAuth2TokenUrlExtractedFromOpenApiSpec: + """Reject malicious tokenUrl at OpenAPI conversion time so the bad + URL never makes it into a generated HttpCallTemplate. + """ + + def test_internal_token_url_in_oauth2_clientcredentials_rejected( + self, + ) -> None: + malicious_spec = { + "openapi": "3.0.0", + "info": {"title": "evil", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"evilOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "evilOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "http://169.254.169.254/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + malicious_spec, spec_url="https://attacker.example/openapi.json" + ) + with pytest.raises(ValueError, match="OAuth2 tokenUrl"): + converter.convert() + + def test_plain_http_token_url_to_attacker_rejected(self) -> None: + malicious_spec = { + "openapi": "3.0.0", + "info": {"title": "evil", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"evilOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "evilOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "http://attacker.example/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + malicious_spec, spec_url="https://api.example.com/openapi.json" + ) + with pytest.raises(ValueError, match="OAuth2 tokenUrl"): + converter.convert() + + def test_legitimate_https_token_url_accepted(self) -> None: + good_spec = { + "openapi": "3.0.0", + "info": {"title": "good", "version": "1.0"}, + "servers": [{"url": "https://api.example.com"}], + "paths": { + "/x": { + "get": { + "operationId": "x", + "security": [{"goodOAuth2": ["read"]}], + "responses": {"200": {"description": "ok"}}, + } + } + }, + "components": { + "securitySchemes": { + "goodOAuth2": { + "type": "oauth2", + "flows": { + "clientCredentials": { + "tokenUrl": "https://auth.example.com/token", + "scopes": {"read": "read access"}, + } + }, + } + } + }, + } + converter = OpenApiConverter( + good_spec, spec_url="https://api.example.com/openapi.json" + ) + manual = converter.convert() + assert len(manual.tools) == 1 diff --git a/plugins/communication_protocols/websocket/pyproject.toml b/plugins/communication_protocols/websocket/pyproject.toml index 09ce85c..952873f 100644 --- a/plugins/communication_protocols/websocket/pyproject.toml +++ b/plugins/communication_protocols/websocket/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "utcp-websocket" -version = "1.1.0" +version = "1.1.1" authors = [ { name = "UTCP Contributors" }, ] diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py new file mode 100644 index 0000000..cff8c06 --- /dev/null +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/_security.py @@ -0,0 +1,276 @@ +"""URL validation for the WebSocket communication protocol. + +Mirror of ``utcp_http._security`` -- intentionally duplicated rather +than cross-plugin-imported so ``utcp-websocket`` does not gain a +runtime dependency on ``utcp-http``. Keep in sync when changing the +validator behavior. Backs GHSA-ppx3-28rw-8fpf (the WebSocket plugin +was missing the URL check entirely, despite its docstrings claiming +"WSS or localhost only"). + +WebSocket URLs use the ``ws://`` and ``wss://`` schemes, so this +module exposes :func:`is_secure_ws_url` / :func:`ensure_secure_ws_url` +in addition to the HTTP-scheme helpers. ``wss://`` is always allowed; +``ws://`` is allowed only for literal loopback hosts. +""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from ipaddress import ip_address +from typing import Any, AsyncIterator, Optional +from urllib.parse import urljoin, urlparse + +# Hostnames considered safe to talk to over plain HTTP. +_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"}) + + +def _hostname_is_loopback(host: str) -> bool: + if host in _LOOPBACK_HOSTNAMES: + return True + try: + return ip_address(host).is_loopback + except ValueError: + return False + + +def is_secure_url(url: str) -> bool: + """Return True if ``url`` is safe to fetch from a UTCP HTTP protocol. + + Allowed: + - Any ``https://`` URL. + - ``http://`` URLs whose host is exactly ``localhost``, ``127.0.0.1``, + or ``::1``. + + Disallowed: + - Plain ``http://`` to any other host (MITM exposure). + - URLs whose hostname *starts* with ``localhost`` / ``127.0.0.1`` but + isn't actually loopback (e.g. ``http://localhost.evil.com``, + ``http://127.0.0.1.attacker.example``). The earlier ``startswith`` + check let these through. + - Anything without a scheme/host (file://, gopher://, javascript:, ...). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"http", "https"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "https": + return True + + # http:// is only allowed for loopback. + return _hostname_is_loopback(host) + + +def is_secure_ws_url(url: str) -> bool: + """Return True if ``url`` is safe to open as a WebSocket connection. + + Allowed: + - Any ``wss://`` URL. + - ``ws://`` URLs whose host is a literal loopback address. + + Mirrors :func:`is_secure_url` for the WebSocket schemes. Backs the + "WSS or localhost only" guarantee that the WebSocket plugin's + docstrings advertise but the code did not previously enforce + (GHSA-ppx3-28rw-8fpf). + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + scheme = (parsed.scheme or "").lower() + if scheme not in {"ws", "wss"}: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if scheme == "wss": + return True + + return _hostname_is_loopback(host) + + +def ensure_secure_ws_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to open as a WebSocket. + + Companion to :func:`ensure_secure_url` for WebSocket schemes. + """ + if is_secure_ws_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: WebSocket URL must use WSS or be a literal " + f"loopback address (ws://localhost / ws://127.0.0.1 / ws://[::1]). " + f"Got: {url!r}. Plain WS to any other host is rejected to prevent " + "MITM attacks and SSRF into internal services." + ) + + +def is_loopback_url(url: str) -> bool: + """Return True if ``url``'s host is a literal loopback address. + + Used by the OpenAPI converter to detect the SSRF case where a remote spec + declares ``servers: [{ url: "http://127.0.0.1:..." }]`` to redirect tool + invocation at the host running the agent. Hostname-based — not a string + prefix — so ``http://localhost.evil.com`` returns False. + """ + if not isinstance(url, str) or not url: + return False + + try: + parsed = urlparse(url) + except ValueError: + return False + + host = (parsed.hostname or "").lower() + if not host: + return False + + if host in _LOOPBACK_HOSTNAMES: + return True + + try: + return ip_address(host).is_loopback + except ValueError: + return False + + +def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None: + """Raise ``ValueError`` if ``url`` is not safe to fetch. + + ``context`` is a short label (``"manual discovery"``, ``"tool invocation"``, + etc.) included in the error so log readers can tell which trust boundary + was breached. + """ + if is_secure_url(url): + return + + where = f" during {context}" if context else "" + raise ValueError( + f"Security error{where}: URL must use HTTPS or be a literal loopback " + f"address (localhost / 127.0.0.1 / ::1). Got: {url!r}. " + "Plain HTTP to any other host is rejected to prevent MITM attacks " + "and SSRF into internal services." + ) + + +# HTTP statuses where the server expects the client to re-issue the request +# against the URL given in the ``Location`` header. 303 forces a GET; the +# rest preserve the original method. +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) + + +@asynccontextmanager +async def safe_request_with_redirects( + session: Any, + method: str, + url: str, + *, + context: str, + max_redirects: int = 5, + **kwargs: Any, +) -> AsyncIterator[Any]: + """Issue an aiohttp request that re-validates every redirect hop. + + Closes the residual SSRF window left by ``ensure_secure_url`` (which + only inspects the initial URL): aiohttp by default follows 3xx + redirects without rechecking, so an attacker-controlled server could + 302 the client into ``http://169.254.169.254/...`` (cloud metadata) + or any internal HTTP service and the response body would be handed + back to the caller. Backs GHSA-9qhg-99ww-9mqc. + + Behavior: + * Calls ``ensure_secure_url(url, context=context)`` on the initial + URL. + * Disables aiohttp's auto-follow (``allow_redirects=False``). + * On a 3xx response with a ``Location`` header, resolves the + target against the current URL and runs ``ensure_secure_url`` + on it before issuing the next hop. Rejection raises and the + redirect chain is aborted with the connection released. + * Caps the chain at ``max_redirects`` hops. Exceeding that raises + ``RuntimeError``. + * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops + any request body; 301/302/307/308 preserve method and body. + + Usage: + ```python + async with safe_request_with_redirects( + session, "GET", url, context="tool invocation", params=... + ) as response: + response.raise_for_status() + ... + ``` + """ + ensure_secure_url(url, context=context) + # We control redirect behavior ourselves; refuse to let callers override. + kwargs.pop("allow_redirects", None) + + current_url = url + current_method = method + hops = 0 + final_response = None + + try: + while True: + response = await session.request( + current_method, + current_url, + allow_redirects=False, + **kwargs, + ) + if response.status not in _REDIRECT_STATUSES: + final_response = response + break + + location = response.headers.get("Location") + if not location: + # 3xx with no Location header — nothing to follow. Let + # the caller handle the unusual response. + final_response = response + break + + if hops >= max_redirects: + response.release() + raise RuntimeError( + f"Too many redirects (>{max_redirects}) during {context} " + f"starting from {url!r}." + ) + + next_url = urljoin(current_url, location) + try: + ensure_secure_url( + next_url, context=f"{context} (redirect target)" + ) + except Exception: + response.release() + raise + + response.release() + if response.status == 303: + current_method = "GET" + kwargs.pop("json", None) + kwargs.pop("data", None) + current_url = next_url + hops += 1 + + yield final_response + finally: + if final_response is not None: + final_response.release() diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py index 81dbb2c..4ce6dfe 100644 --- a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_call_template.py @@ -83,10 +83,23 @@ class WebSocketCallTemplate(CallTemplate): @field_validator("url") @classmethod def validate_url(cls, v: str) -> str: - """Validate WebSocket URL format.""" - if not (v.startswith("wss://") or v.startswith("ws://localhost") or v.startswith("ws://127.0.0.1")): + """Validate WebSocket URL format. + + Uses the hostname-based ``is_secure_ws_url`` helper rather than + a ``startswith`` prefix match: the prefix form let + ``ws://localhost.evil.com`` and ``ws://127.0.0.1.attacker.example`` + through, which is the bypass tracked in GHSA-ppx3-28rw-8fpf. + """ + # Local import keeps the call-template module free of an + # always-on import of the validator (and matches how the HTTP + # plugins handle the same concern). + from utcp_websocket._security import is_secure_ws_url + + if not is_secure_ws_url(v): raise ValueError( - f"WebSocket URL must use wss:// or start with ws://localhost or ws://127.0.0.1. Got: {v}" + f"WebSocket URL must use wss:// or be a literal loopback " + f"address (ws://localhost / ws://127.0.0.1 / ws://[::1]). " + f"Got: {v!r}." ) return v diff --git a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py index 48a1d21..17fb2db 100644 --- a/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py +++ b/plugins/communication_protocols/websocket/src/utcp_websocket/websocket_communication_protocol.py @@ -29,6 +29,11 @@ from utcp.data.auth_implementations.basic_auth import BasicAuth from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth from utcp_websocket.websocket_call_template import WebSocketCallTemplate +from utcp_websocket._security import ( + ensure_secure_url, + ensure_secure_ws_url, + safe_request_with_redirects, +) logging.basicConfig( level=logging.INFO, @@ -136,11 +141,20 @@ def _format_tool_call_message( return json.dumps(arguments) async def _handle_oauth2(self, auth: OAuth2Auth) -> str: - """Handle OAuth2 authentication and token management.""" + """Handle OAuth2 authentication and token management. + + Validates the token URL with ``ensure_secure_url`` before any + credential bytes leave the process, and re-validates every + redirect hop. Closes the sibling SSRF / credential-exfiltration + patterns in GHSA-8cp3-qxj6-px34 and GHSA-9qhg-99ww-9mqc on the + OAuth2 path used by this plugin. + """ client_id = auth.client_id if client_id in self._oauth_tokens: return self._oauth_tokens[client_id]["access_token"] + ensure_secure_url(auth.token_url, context="OAuth2 token URL") + async with aiohttp.ClientSession() as session: data = { 'grant_type': 'client_credentials', @@ -148,7 +162,13 @@ async def _handle_oauth2(self, auth: OAuth2Auth) -> str: 'client_secret': auth.client_secret, 'scope': auth.scope } - async with session.post(auth.token_url, data=data) as resp: + async with safe_request_with_redirects( + session, + "POST", + auth.token_url, + context="OAuth2 token fetch", + data=data, + ) as resp: resp.raise_for_status() token_response = await resp.json() self._oauth_tokens[client_id] = token_response @@ -175,7 +195,21 @@ async def _prepare_headers(self, call_template: WebSocketCallTemplate) -> Dict[s return headers async def _get_connection(self, call_template: WebSocketCallTemplate) -> ClientWebSocketResponse: - """Get or create a WebSocket connection for the call template.""" + """Get or create a WebSocket connection for the call template. + + Enforces the "WSS or loopback" guarantee that the module + docstring advertises. The previous implementation skipped this + check entirely, letting any URL through, which is the + WebSocket half of GHSA-ppx3-28rw-8fpf. Also disables + redirect-following on the upgrade request to prevent a + post-validation redirect from steering the handshake into an + internal service (GHSA-9qhg-99ww-9mqc). + """ + # Hostname-based validation -- never let attacker-controlled or + # plain-WS-to-non-loopback URLs through, regardless of headers + # already configured on the call template. + ensure_secure_ws_url(call_template.url, context="WebSocket connection") + provider_key = f"{call_template.name}_{call_template.url}" # Check if we have an active connection @@ -198,7 +232,11 @@ async def _get_connection(self, call_template: WebSocketCallTemplate) -> ClientW call_template.url, headers=headers, protocols=[call_template.protocol] if call_template.protocol else None, - heartbeat=30 if call_template.keep_alive else None + heartbeat=30 if call_template.keep_alive else None, + # aiohttp's ws_connect defaults to following HTTP + # redirects on the upgrade handshake; refuse so a + # 3xx response cannot land us on a different host. + allow_redirects=False, ) self._connections[provider_key] = ws logger.info(f"WebSocket connected to {call_template.url}") diff --git a/plugins/communication_protocols/websocket/tests/test_websocket_security.py b/plugins/communication_protocols/websocket/tests/test_websocket_security.py new file mode 100644 index 0000000..9fb5d58 --- /dev/null +++ b/plugins/communication_protocols/websocket/tests/test_websocket_security.py @@ -0,0 +1,172 @@ +"""Security tests for the WebSocket communication protocol +(utcp-websocket). + +Pin the fixes for GHSA-ppx3-28rw-8fpf: the previous implementation +did NO URL validation at all despite its docstrings advertising +"WSS or localhost only", letting any ``ws://`` URL connect (with +credentials attached) to an attacker-controlled host. Also covers +the OAuth2 / redirect halves of GHSA-8cp3-qxj6-px34 and +GHSA-9qhg-99ww-9mqc. +""" + +import pytest + +from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth +from utcp_websocket._security import ( + ensure_secure_url, + ensure_secure_ws_url, + is_secure_url, + is_secure_ws_url, +) +from utcp_websocket.websocket_call_template import WebSocketCallTemplate +from utcp_websocket.websocket_communication_protocol import ( + WebSocketCommunicationProtocol, +) + + +# --------------------------------------------------------------------------- +# WebSocket-scheme validator: ws:// is loopback-only, wss:// always OK. +# --------------------------------------------------------------------------- + + +class TestWebSocketUrlValidator: + @pytest.mark.parametrize( + "url", + [ + "wss://api.example.com/socket", + "ws://localhost/socket", + "ws://127.0.0.1:9090/socket", + "ws://[::1]:9090/socket", + ], + ) + def test_secure_ws_url_accepted(self, url: str) -> None: + assert is_secure_ws_url(url) is True + ensure_secure_ws_url(url) + + @pytest.mark.parametrize( + "url", + [ + # Plain ws:// to non-loopback host (MITM + SSRF surface). + "ws://169.254.169.254/socket", + "ws://internal.service.local/socket", + "ws://10.0.0.5/socket", + "ws://example.com/socket", + # The localhost.evil.com / 127.0.0.1.attacker.example bypass: + # not loopback even though the prefix looks like it. + "ws://localhost.evil.com/socket", + "ws://127.0.0.1.attacker.example/socket", + # HTTP schemes are not WebSocket URLs. + "http://localhost/socket", + "https://api.example.com/socket", + # Junk inputs. + "", + "not-a-url", + "javascript:alert(1)", + ], + ) + def test_insecure_ws_url_rejected(self, url: str) -> None: + assert is_secure_ws_url(url) is False + with pytest.raises(ValueError, match="WebSocket URL"): + ensure_secure_ws_url(url) + + +# --------------------------------------------------------------------------- +# _get_connection enforces ensure_secure_ws_url -- the plugin used to +# accept any URL silently. +# --------------------------------------------------------------------------- + + +class TestTemplateRejectsBypass: + """The Pydantic field validator on WebSocketCallTemplate is the + first line of defence -- with the new hostname-based check it + catches the prefix bypass that the original ``startswith`` form + let through. + """ + + @pytest.mark.parametrize( + "url", + [ + "ws://169.254.169.254/", + "ws://localhost.evil.com/socket", + "ws://127.0.0.1.attacker.example/socket", + "ws://example.com/socket", + "http://localhost/socket", # not a WebSocket scheme + ], + ) + def test_template_rejects_bypass(self, url: str) -> None: + with pytest.raises(Exception) as exc_info: + WebSocketCallTemplate(name="ws", url=url) + # Pydantic wraps the message inside its own ValidationError -- + # the underlying ValueError text must still be present so + # operators can see what was rejected. + assert "WebSocket URL" in str(exc_info.value) + + +class TestGetConnectionRejectsLoopbackBypass: + """Defence in depth: ``_get_connection`` itself runs the same + hostname-based check so a template that bypassed the Pydantic + validator (e.g. constructed without ``model_validate``) still + cannot open the WebSocket. + """ + + @pytest.mark.asyncio + async def test_connection_rejected_when_template_bypassed(self) -> None: + proto = WebSocketCommunicationProtocol() + # Construct a template that *would* fail the field validator, + # but skip validation by going through ``model_construct``. + tpl = WebSocketCallTemplate.model_construct( + name="ws", + url="ws://localhost.evil.com/socket", + call_template_type="websocket", + keep_alive=True, + timeout=30, + ) + with pytest.raises(ValueError, match="WebSocket URL"): + await proto._get_connection(tpl) + + +# --------------------------------------------------------------------------- +# OAuth2 token URL is validated (the WebSocket plugin's OAuth2 path +# goes over HTTP, so it uses ensure_secure_url not ensure_secure_ws_url). +# --------------------------------------------------------------------------- + + +class TestOAuth2TokenUrlValidation: + @pytest.mark.asyncio + async def test_internal_token_url_rejected(self) -> None: + proto = WebSocketCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://169.254.169.254/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + @pytest.mark.asyncio + async def test_plain_http_non_loopback_token_url_rejected(self) -> None: + proto = WebSocketCommunicationProtocol() + auth = OAuth2Auth( + token_url="http://attacker.example/token", + client_id="victim-id", + client_secret="victim-secret", + ) + with pytest.raises(ValueError, match="OAuth2 token URL"): + await proto._handle_oauth2(auth) + + +# --------------------------------------------------------------------------- +# Sanity: the HTTP-scheme validator is also re-exported (the OAuth2 +# token endpoint goes over HTTP/HTTPS). +# --------------------------------------------------------------------------- + + +class TestHttpUrlValidator: + def test_https_accepted(self) -> None: + assert is_secure_url("https://api.example.com/oauth/token") is True + ensure_secure_url("https://api.example.com/oauth/token") + + def test_internal_rejected(self) -> None: + assert is_secure_url("http://169.254.169.254/token") is False + with pytest.raises(ValueError): + ensure_secure_url("http://169.254.169.254/token")