Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion plugins/communication_protocols/gql/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "utcp-gql"
version = "1.1.0"
version = "1.1.1"
authors = [
{ name = "UTCP Contributors" },
]
Expand All @@ -14,6 +14,7 @@ requires-python = ">=3.10"
dependencies = [
"pydantic>=2.0",
"gql>=3.0",
"aiohttp>=3.8",
"utcp>=1.1"
]
classifiers = [
Expand Down
219 changes: 219 additions & 0 deletions plugins/communication_protocols/gql/src/utcp_gql/_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
"""URL validation for the GraphQL communication protocol.

Mirror of ``utcp_http._security`` -- intentionally duplicated rather
than cross-plugin-imported so ``utcp-gql`` does not gain a runtime
dependency on ``utcp-http``. Keep the two files in sync when changing
the validator behavior. Backs GHSA-ppx3-28rw-8fpf (the original CVE
fix did not reach this plugin) and GHSA-9qhg-99ww-9mqc (redirect
SSRF on the GraphQL endpoint).
"""

from __future__ import annotations

from contextlib import asynccontextmanager
from ipaddress import ip_address
from typing import Any, AsyncIterator, Optional
from urllib.parse import urljoin, urlparse

# Hostnames considered safe to talk to over plain HTTP. ``is_secure_url``
# and ``is_loopback_url`` test the lowercased ``urlparse(url).hostname``
# for membership in this set; loopback IP literals that are not listed
# here are recognised separately through ``ip_address(...).is_loopback``.
_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1", "[::1]"})


def is_secure_url(url: str) -> bool:
    """Decide whether ``url`` may be fetched by a UTCP HTTP-based protocol.

    Accepted:
      - every ``https://`` URL that carries a host;
      - ``http://`` URLs whose parsed hostname is listed in
        ``_LOOPBACK_HOSTNAMES`` or is an IP literal that ``ip_address``
        reports as loopback.

    Rejected:
      - non-string or empty input, and input ``urlparse`` refuses;
      - any scheme other than ``http`` / ``https`` (``file://``,
        ``gopher://``, ``javascript:``, ...);
      - URLs with no host;
      - plain ``http://`` to any non-loopback host (MITM exposure),
        including look-alikes such as ``http://localhost.evil.com`` or
        ``http://127.0.0.1.attacker.example`` -- the decision is made on
        the parsed hostname, never on a string prefix.

    Returns:
        ``True`` when the URL is acceptable, ``False`` otherwise. Never
        raises for malformed input.
    """
    if not (isinstance(url, str) and url):
        return False

    try:
        parts = urlparse(url)
    except ValueError:
        return False

    scheme = (parts.scheme or "").lower()
    if scheme != "http" and scheme != "https":
        return False

    hostname = (parts.hostname or "").lower()
    if not hostname:
        return False

    # https is always acceptable; plain http only for known loopback names.
    if scheme == "https" or hostname in _LOOPBACK_HOSTNAMES:
        return True

    # Remaining chance for plain http: the host is some other loopback
    # IP literal. Anything ``ip_address`` cannot parse is a regular
    # hostname and is refused.
    try:
        literal = ip_address(hostname)
    except ValueError:
        return False
    return literal.is_loopback


def is_loopback_url(url: str) -> bool:
    """Report whether the host of ``url`` is a literal loopback address.

    The scheme is not inspected; only the parsed hostname matters. A host
    counts as loopback when it is listed in ``_LOOPBACK_HOSTNAMES`` or is
    an IP literal for which ``ip_address(...).is_loopback`` is true.
    Because the comparison is on the hostname rather than a string
    prefix, ``http://localhost.evil.com`` yields ``False``.

    Intended for callers (such as an OpenAPI converter) that need to spot
    a remote document pointing tool invocation back at the machine
    running the agent, e.g. ``servers: [{ url: "http://127.0.0.1:..." }]``.

    Returns:
        ``True`` for a loopback host; ``False`` for anything else,
        including non-string, empty, unparseable or host-less input.
    """
    if not (isinstance(url, str) and url):
        return False

    try:
        parts = urlparse(url)
    except ValueError:
        return False

    hostname = (parts.hostname or "").lower()
    if not hostname:
        return False

    if hostname not in _LOOPBACK_HOSTNAMES:
        # Not a well-known name: accept only a loopback IP literal.
        try:
            return ip_address(hostname).is_loopback
        except ValueError:
            return False
    return True


def ensure_secure_url(url: str, *, context: Optional[str] = None) -> None:
    """Validate ``url`` with ``is_secure_url`` and raise if it is refused.

    Args:
        url: The URL about to be fetched.
        context: Optional short label (``"manual discovery"``,
            ``"tool invocation"``, ...) woven into the error message so a
            log reader can tell which trust boundary was breached.

    Raises:
        ValueError: If ``is_secure_url(url)`` is false.
    """
    if not is_secure_url(url):
        where = ""
        if context:
            where = f" during {context}"
        raise ValueError(
            f"Security error{where}: URL must use HTTPS or be a literal loopback "
            f"address (localhost / 127.0.0.1 / ::1). Got: {url!r}. "
            "Plain HTTP to any other host is rejected to prevent MITM attacks "
            "and SSRF into internal services."
        )


# HTTP statuses where the server expects the client to re-issue the request
# against the URL given in the ``Location`` header. 303 forces a GET; the
# rest preserve the original method. ``safe_request_with_redirects`` checks
# ``response.status`` against this set to decide whether a response is a
# hop to follow or the final answer.
_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})


@asynccontextmanager
async def safe_request_with_redirects(
    session: Any,
    method: str,
    url: str,
    *,
    context: str,
    max_redirects: int = 5,
    **kwargs: Any,
) -> AsyncIterator[Any]:
    """Issue an aiohttp request that re-validates every redirect hop.

    Closes the residual SSRF window left by ``ensure_secure_url`` (which
    only inspects the initial URL): aiohttp by default follows 3xx
    redirects without rechecking, so an attacker-controlled server could
    302 the client into ``http://169.254.169.254/...`` (cloud metadata)
    or any internal HTTP service and the response body would be handed
    back to the caller. Backs GHSA-9qhg-99ww-9mqc.

    Behavior:
      * Calls ``ensure_secure_url(url, context=context)`` on the initial
        URL.
      * Disables aiohttp's auto-follow (``allow_redirects=False``); a
        caller-supplied ``allow_redirects`` is discarded.
      * On a 3xx response with a ``Location`` header, resolves the
        target against the current URL and runs ``ensure_secure_url``
        on it before issuing the next hop. Rejection raises and the
        redirect chain is aborted with the connection released.
      * Caps the chain at ``max_redirects`` hops. Exceeding that raises
        ``RuntimeError``.
      * Mirrors RFC 7231 method semantics: 303 forces ``GET`` and drops
        any request body; 301/302/307/308 preserve method and body.
      * When a hop changes origin (scheme, host or port), the ``auth``
        argument and any ``Authorization`` header passed by the caller
        are dropped for the rest of the chain, so credentials meant for
        the first server are not replayed to the redirect target.

    Args:
        session: Object exposing an awaitable
            ``request(method, url, **kwargs)`` (an aiohttp
            ``ClientSession`` in practice).
        method: HTTP method for the first request.
        url: Initial URL.
        context: Label included in error messages.
        max_redirects: Maximum number of redirects that will be followed.
        **kwargs: Forwarded to ``session.request`` on every hop.

    Yields:
        The final (non-redirect, or ``Location``-less 3xx) response. It
        is released when the ``async with`` block exits.

    Raises:
        ValueError: If the initial URL or a redirect target is refused.
        RuntimeError: If more than ``max_redirects`` redirects occur.

    Usage:
        ```python
        async with safe_request_with_redirects(
            session, "GET", url, context="tool invocation", params=...
        ) as response:
            response.raise_for_status()
            ...
        ```
    """
    ensure_secure_url(url, context=context)
    # We control redirect behavior ourselves; refuse to let callers override.
    kwargs.pop("allow_redirects", None)

    current_url = url
    current_method = method
    hops = 0
    final_response = None

    try:
        while True:
            response = await session.request(
                current_method,
                current_url,
                allow_redirects=False,
                **kwargs,
            )
            if response.status not in _REDIRECT_STATUSES:
                final_response = response
                break

            location = response.headers.get("Location")
            if not location:
                # 3xx with no Location header -- nothing to follow. Let
                # the caller handle the unusual response.
                final_response = response
                break

            # The intermediate response is released on every path out of
            # this section, including a ``urljoin`` failure on a
            # malformed ``Location`` value.
            try:
                if hops >= max_redirects:
                    raise RuntimeError(
                        f"Too many redirects (>{max_redirects}) during {context} "
                        f"starting from {url!r}."
                    )
                next_url = urljoin(current_url, location)
                # NOTE(review): ensure_secure_url accepts plain-HTTP
                # loopback targets, so a remote origin may still redirect
                # to http://localhost -- confirm that is intended.
                ensure_secure_url(
                    next_url, context=f"{context} (redirect target)"
                )
            finally:
                response.release()

            if response.status == 303:
                current_method = "GET"
                kwargs.pop("json", None)
                kwargs.pop("data", None)

            # Do not replay caller credentials to a different origin.
            target_origin = _redirect_origin(next_url)
            if target_origin is None or target_origin != _redirect_origin(
                current_url
            ):
                kwargs.pop("auth", None)
                if "headers" in kwargs:
                    kwargs["headers"] = _without_authorization(
                        kwargs["headers"]
                    )

            current_url = next_url
            hops += 1

        yield final_response
    finally:
        if final_response is not None:
            final_response.release()


def _redirect_origin(url: str) -> Optional[tuple]:
    """Return ``(scheme, host, port)`` for ``url``, or ``None`` if unparseable."""
    try:
        parts = urlparse(url)
        port = parts.port  # raises ValueError on a malformed port
    except ValueError:
        return None
    scheme = (parts.scheme or "").lower()
    if port is None:
        port = {"http": 80, "https": 443}.get(scheme)
    return scheme, (parts.hostname or "").lower(), port


def _without_authorization(headers: Any) -> Any:
    """Return a copy of ``headers`` with every ``Authorization`` entry removed."""
    if not headers:
        return headers
    pairs = headers.items() if hasattr(headers, "items") else headers
    kept = [
        (name, value)
        for name, value in pairs
        if str(name).lower() != "authorization"
    ]
    # Plain dicts stay dicts; other header containers become a list of pairs.
    return dict(kept) if isinstance(headers, dict) else kept
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from utcp.data.auth_implementations.oauth2_auth import OAuth2Auth

from utcp_gql.gql_call_template import GraphQLCallTemplate
from utcp_gql._security import ensure_secure_url, safe_request_with_redirects

if TYPE_CHECKING:
from utcp.utcp_client import UtcpClient
Expand All @@ -40,30 +41,35 @@ class GraphQLCommunicationProtocol(CommunicationProtocol):
def __init__(self) -> None:
self._oauth_tokens: Dict[str, Dict[str, Any]] = {}

def _enforce_https_or_localhost(self, url: str) -> None:
    """Raise ``ValueError`` unless ``url`` is HTTPS or plain HTTP to loopback.

    The decision is made on the parsed hostname, not on a string prefix,
    so look-alike hosts such as ``http://localhost.evil.com`` or
    ``http://127.0.0.1.attacker.example`` are rejected.

    Args:
        url: The endpoint URL to validate.

    Raises:
        ValueError: If ``url`` is not a string, cannot be parsed, has no
            host, uses a scheme other than ``http`` / ``https``, or uses
            plain ``http`` with a host that is neither ``localhost`` nor
            a loopback IP literal.
    """
    # Imported locally so the method stays self-contained.
    from ipaddress import ip_address
    from urllib.parse import urlparse

    def _is_allowed(candidate: Any) -> bool:
        if not isinstance(candidate, str) or not candidate:
            return False
        try:
            parts = urlparse(candidate)
        except ValueError:
            return False
        scheme = (parts.scheme or "").lower()
        host = (parts.hostname or "").lower()
        if not host or scheme not in ("http", "https"):
            return False
        if scheme == "https" or host == "localhost":
            return True
        # Plain http: only a loopback IP literal (127.0.0.1, ::1, ...).
        try:
            return ip_address(host).is_loopback
        except ValueError:
            return False

    if not _is_allowed(url):
        raise ValueError(
            "Security error: URL must use HTTPS or start with 'http://localhost' or 'http://127.0.0.1'. "
            "Non-secure URLs are vulnerable to man-in-the-middle attacks. "
            f"Got: {url}."
        )

async def _handle_oauth2(self, auth: OAuth2Auth) -> str:
"""Fetch an OAuth2 access token.

Validates the token URL with the hostname-based ``ensure_secure_url``
helper before any credential bytes leave the process, and follows
redirects only after re-validating each hop -- defends against the
sibling SSRF / credential-exfiltration patterns in
GHSA-8cp3-qxj6-px34 and GHSA-9qhg-99ww-9mqc.
"""
client_id = auth.client_id
if client_id in self._oauth_tokens:
return self._oauth_tokens[client_id]["access_token"]

ensure_secure_url(auth.token_url, context="OAuth2 token URL")

async with aiohttp.ClientSession() as session:
data = {
"grant_type": "client_credentials",
"client_id": client_id,
"client_secret": auth.client_secret,
"scope": auth.scope,
}
async with session.post(auth.token_url, data=data) as resp:
async with safe_request_with_redirects(
session,
"POST",
auth.token_url,
context="OAuth2 token fetch",
data=data,
) as resp:
resp.raise_for_status()
token_response = await resp.json()
self._oauth_tokens[client_id] = token_response
Expand Down Expand Up @@ -94,17 +100,45 @@ async def _prepare_headers(

return headers

@staticmethod
def _disable_transport_redirects(transport: AIOHTTPTransport) -> None:
"""Patch the underlying aiohttp session used by AIOHTTPTransport
so its ``session.post`` refuses to follow 3xx responses.

gql's AIOHTTPTransport does not expose ``allow_redirects`` and
the default ClientSession setting would let an attacker-
controlled GraphQL endpoint 302 the client into an internal
service after the URL had already passed ``ensure_secure_url``.
See GHSA-9qhg-99ww-9mqc / GHSA-ppx3-28rw-8fpf.
"""
aio_session = getattr(transport, "session", None)
if aio_session is None:
return
original_post = aio_session.post

def _no_redirect_post(*args: Any, **kwargs: Any):
kwargs["allow_redirects"] = False
return original_post(*args, **kwargs)

aio_session.post = _no_redirect_post # type: ignore[method-assign]

async def register_manual(
self, caller: "UtcpClient", manual_call_template: CallTemplate
) -> RegisterManualResult:
if not isinstance(manual_call_template, GraphQLCallTemplate):
raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template")
self._enforce_https_or_localhost(manual_call_template.url)
# Hostname-based validation -- replaces the broken ``startswith``
# prefix check that let ``http://127.0.0.1.attacker.example``
# through (GHSA-ppx3-28rw-8fpf).
ensure_secure_url(
manual_call_template.url, context="GraphQL manual discovery"
)

try:
headers = await self._prepare_headers(manual_call_template)
transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers)
async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
self._disable_transport_redirects(transport)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1: Redirect protection is applied too late: it runs after entering the GqlClient context, but schema fetch can already happen during __aenter__. This leaves an SSRF/credential-leak path on the initial GraphQL request.

Prompt for AI agents
Check if this issue is valid — if so, understand the root cause and fix it. At plugins/communication_protocols/gql/src/utcp_gql/gql_communication_protocol.py, line 141:

<comment>Redirect protection is applied too late: it runs after entering the GqlClient context, but schema fetch can already happen during `__aenter__`. This leaves an SSRF/credential-leak path on the initial GraphQL request.</comment>

<file context>
@@ -94,17 +100,45 @@ async def _prepare_headers(
             headers = await self._prepare_headers(manual_call_template)
             transport = AIOHTTPTransport(url=manual_call_template.url, headers=headers)
             async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
+                self._disable_transport_redirects(transport)
                 schema = session.client.schema
                 tools: List[Tool] = []
</file context>

schema = session.client.schema
tools: List[Tool] = []

Expand Down Expand Up @@ -178,11 +212,14 @@ async def call_tool(
) -> Any:
if not isinstance(tool_call_template, GraphQLCallTemplate):
raise ValueError("GraphQLCommunicationProtocol requires a GraphQLCallTemplate call template")
self._enforce_https_or_localhost(tool_call_template.url)
ensure_secure_url(
tool_call_template.url, context="GraphQL tool invocation"
)

headers = await self._prepare_headers(tool_call_template, tool_args)
transport = AIOHTTPTransport(url=tool_call_template.url, headers=headers)
async with GqlClient(transport=transport, fetch_schema_from_transport=True) as session:
self._disable_transport_redirects(transport)
# Filter out header fields from GraphQL variables; these are sent via HTTP headers
header_fields = tool_call_template.header_fields or []
filtered_args = {k: v for k, v in tool_args.items() if k not in header_fields}
Expand Down
Loading
Loading