Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions server/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ requires-python = ">=3.12"
# duplicate module conflict when galileo extras are installed
dependencies = [
"fastapi>=0.109.0",
"prometheus-client>=0.20.0",
"starlette-exporter>=0.23.0",
"uvicorn[standard]>=0.27.0",
"httpx>=0.27.0", # For auth_framework HTTP providers
Expand Down
89 changes: 79 additions & 10 deletions server/src/agent_control_server/auth_framework/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
_UPSTREAM_TOKEN_HEADER_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_SERVICE_TOKEN_HEADER"
_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_EXTRA_FORWARD_HEADERS"
_UPSTREAM_CA_FILE_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_CA_FILE"
_UPSTREAM_KEEPALIVE_EXPIRY_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_KEEPALIVE_EXPIRY_SECONDS"
_UPSTREAM_MAX_CONNECTIONS_ENV = "AGENT_CONTROL_AUTH_UPSTREAM_MAX_CONNECTIONS"
_UPSTREAM_MAX_KEEPALIVE_CONNECTIONS_ENV = (
"AGENT_CONTROL_AUTH_UPSTREAM_MAX_KEEPALIVE_CONNECTIONS"
)

# Runtime flow.
_RUNTIME_MODE_ENV = "AGENT_CONTROL_RUNTIME_AUTH_MODE"
Expand Down Expand Up @@ -212,25 +217,42 @@ def _build_default_provider() -> RequestAuthorizer:
url = os.environ.get(_UPSTREAM_URL_ENV)
if not url:
raise RuntimeError(f"{_MODE_ENV}=http_upstream but {_UPSTREAM_URL_ENV} is not set.")
timeout = float(os.environ.get(_UPSTREAM_TIMEOUT_ENV, "5.0"))
timeout = _load_float_env(_UPSTREAM_TIMEOUT_ENV, 5.0)
token = os.environ.get(_UPSTREAM_TOKEN_ENV)
token_header = os.environ.get(_UPSTREAM_TOKEN_HEADER_ENV, "X-Agent-Control-Service-Token")
extra_forward_headers = _parse_extra_forward_headers(
os.environ.get(_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV)
)
ca_file = (os.environ.get(_UPSTREAM_CA_FILE_ENV) or "").strip() or None
keepalive_expiry_seconds = _load_float_env(_UPSTREAM_KEEPALIVE_EXPIRY_ENV, 1.0)
max_connections = _load_int_env(_UPSTREAM_MAX_CONNECTIONS_ENV, 100)
max_keepalive_connections = _load_int_env(_UPSTREAM_MAX_KEEPALIVE_CONNECTIONS_ENV, 20)
_validate_http_upstream_connection_config(
keepalive_expiry_seconds=keepalive_expiry_seconds,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
)
_logger.info("Default auth provider: http_upstream url=%s", url)
try:
return HttpUpstreamAuthProvider(
HttpUpstreamConfig(
url=url,
timeout_seconds=timeout,
service_token=token,
service_token_header=token_header,
extra_forward_headers=extra_forward_headers,
ca_file=ca_file,
)
upstream_config = HttpUpstreamConfig(
url=url,
timeout_seconds=timeout,
service_token=token,
service_token_header=token_header,
extra_forward_headers=extra_forward_headers,
ca_file=ca_file,
keepalive_expiry_seconds=keepalive_expiry_seconds,
max_connections=max_connections,
max_keepalive_connections=max_keepalive_connections,
)
except ValueError as exc:
raise RuntimeError(
"Invalid http_upstream auth configuration from "
f"{_UPSTREAM_TOKEN_HEADER_ENV} or "
f"{_UPSTREAM_EXTRA_FORWARD_HEADERS_ENV}: {exc}"
) from exc
try:
return HttpUpstreamAuthProvider(upstream_config)
except (OSError, ssl.SSLError) as exc:
raise RuntimeError(
f"{_UPSTREAM_CA_FILE_ENV}={ca_file!r} not found or unreadable."
Expand Down Expand Up @@ -279,6 +301,53 @@ def _parse_extra_forward_headers(raw: str | None) -> tuple[str, ...]:
return tuple(result)


def _load_float_env(env_name: str, default: float) -> float:
raw = os.environ.get(env_name)
if raw is None:
return default
try:
return float(raw)
except ValueError as exc:
raise RuntimeError(f"{env_name}={raw!r} is not a number.") from exc


def _load_int_env(env_name: str, default: int) -> int:
raw = os.environ.get(env_name)
if raw is None:
return default
try:
return int(raw)
except ValueError as exc:
raise RuntimeError(f"{env_name}={raw!r} is not an integer.") from exc


def _validate_http_upstream_connection_config(
*,
keepalive_expiry_seconds: float,
max_connections: int,
max_keepalive_connections: int,
) -> None:
if keepalive_expiry_seconds < 0:
raise RuntimeError(
f"{_UPSTREAM_KEEPALIVE_EXPIRY_ENV}={keepalive_expiry_seconds} "
"must be greater than or equal to 0."
)
if max_connections <= 0:
raise RuntimeError(
f"{_UPSTREAM_MAX_CONNECTIONS_ENV}={max_connections} must be greater than 0."
)
if max_keepalive_connections < 0:
raise RuntimeError(
f"{_UPSTREAM_MAX_KEEPALIVE_CONNECTIONS_ENV}={max_keepalive_connections} "
"must be greater than or equal to 0."
)
if max_keepalive_connections > max_connections:
raise RuntimeError(
f"{_UPSTREAM_MAX_KEEPALIVE_CONNECTIONS_ENV}={max_keepalive_connections} "
f"must be less than or equal to {_UPSTREAM_MAX_CONNECTIONS_ENV}={max_connections}."
)


def _resolve_runtime_mode() -> str:
raw = os.environ.get(_RUNTIME_MODE_ENV)
if raw is None or not raw.strip():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@
import ssl
from dataclasses import dataclass
from datetime import datetime
from time import perf_counter
from typing import Any

import httpx
from agent_control_models.errors import ErrorCode, ErrorReason
from fastapi import Request
from prometheus_client import Counter, Histogram
from pydantic import (
BaseModel,
ConfigDict,
Expand All @@ -66,6 +68,17 @@

_DEFAULT_FORWARDED_HEADERS = ("X-API-Key", "Authorization", "Cookie")

_AUTH_UPSTREAM_ATTEMPTS = Counter(
"agent_control_server_auth_upstream_attempts_total",
"Auth upstream HTTP attempts made by Agent Control.",
("operation", "outcome", "status_code", "error_type"),
)
_AUTH_UPSTREAM_ATTEMPT_DURATION = Histogram(
"agent_control_server_auth_upstream_attempt_duration_seconds",
"Duration of auth upstream HTTP attempts made by Agent Control.",
("operation", "outcome"),
)


class _UpstreamGrant(BaseModel):
"""Strict schema for the upstream authorization-service response.
Expand Down Expand Up @@ -154,7 +167,30 @@ class HttpUpstreamConfig:
ca_file: str | None = None
"""Optional CA bundle path used only when verifying the auth upstream."""

keepalive_expiry_seconds: float = 1.0
"""Idle lifetime for pooled upstream connections.

Keep this shorter than the upstream server's own keepalive/recycle window
so Agent Control does not reuse sockets the upstream has already closed.
"""

max_connections: int = 100
"""Maximum concurrent connections to the auth upstream."""

max_keepalive_connections: int = 20
"""Maximum idle connections retained for the auth upstream."""

def __post_init__(self) -> None:
if self.keepalive_expiry_seconds < 0:
raise ValueError("keepalive_expiry_seconds must be greater than or equal to 0")
if self.max_connections <= 0:
raise ValueError("max_connections must be greater than 0")
if self.max_keepalive_connections < 0:
raise ValueError("max_keepalive_connections must be greater than or equal to 0")
if self.max_keepalive_connections > self.max_connections:
raise ValueError(
"max_keepalive_connections must be less than or equal to max_connections"
)
if self.service_token is None:
return
forwarded = {
Expand All @@ -180,14 +216,18 @@ def __init__(
self._owns_client = client is None
if client is not None:
self._client = client
elif config.ca_file is not None:
ssl_context = ssl.create_default_context(cafile=config.ca_file)
self._client = httpx.AsyncClient(
timeout=config.timeout_seconds,
verify=ssl_context,
)
else:
self._client = httpx.AsyncClient(timeout=config.timeout_seconds)
client_kwargs: dict[str, Any] = {
"timeout": config.timeout_seconds,
"limits": httpx.Limits(
max_connections=config.max_connections,
max_keepalive_connections=config.max_keepalive_connections,
keepalive_expiry=config.keepalive_expiry_seconds,
),
}
if config.ca_file is not None:
client_kwargs["verify"] = ssl.create_default_context(cafile=config.ca_file)
self._client = httpx.AsyncClient(**client_kwargs)

async def aclose(self) -> None:
"""Release the HTTP client if this provider created it."""
Expand All @@ -205,27 +245,43 @@ async def authorize(
if context:
payload["context"] = context

response = await self._post_upstream(operation, payload, headers)
return self._handle_response(response, operation, context)

async def _post_upstream(
self,
operation: Operation,
payload: dict[str, Any],
headers: dict[str, str],
) -> httpx.Response:
started = perf_counter()
try:
response = await self._client.post(
self._config.url,
json=payload,
headers=headers,
)
except httpx.HTTPError as exc:
_observe_upstream_attempt(
operation,
perf_counter() - started,
outcome="http_error",
error=exc,
)
_logger.warning(
"Auth upstream unreachable for operation %s: %s",
operation.value,
exc,
)
raise APIError(
status_code=503,
error_code=ErrorCode.AUTH_MISCONFIGURED,
reason=ErrorReason.SERVICE_UNAVAILABLE,
detail="Authorization service unavailable.",
hint="Retry the request; if the failure persists, contact the operator.",
) from exc
raise _authorization_service_unavailable_error() from exc

return self._handle_response(response, operation, context)
_observe_upstream_attempt(
operation,
perf_counter() - started,
outcome="response",
status_code=response.status_code,
)
return response

def _forward_headers(self, request: Request) -> dict[str, str]:
headers: dict[str, str] = {}
Expand Down Expand Up @@ -363,6 +419,36 @@ def _parse_principal(self, response: httpx.Response) -> Principal:
)


def _observe_upstream_attempt(
operation: Operation,
duration_seconds: float,
*,
outcome: str,
status_code: int | None = None,
error: httpx.HTTPError | None = None,
) -> None:
_AUTH_UPSTREAM_ATTEMPTS.labels(
operation=operation.value,
outcome=outcome,
status_code=str(status_code) if status_code is not None else "none",
error_type=type(error).__name__ if error is not None else "none",
).inc()
_AUTH_UPSTREAM_ATTEMPT_DURATION.labels(
operation=operation.value,
outcome=outcome,
).observe(duration_seconds)


def _authorization_service_unavailable_error() -> APIError:
return APIError(
status_code=503,
error_code=ErrorCode.AUTH_MISCONFIGURED,
reason=ErrorReason.SERVICE_UNAVAILABLE,
detail="Authorization service unavailable.",
hint="Retry the request; if the failure persists, contact the operator.",
)


def _ensure_target_context_matches_grant(
context: dict[str, Any] | None,
principal: Principal,
Expand Down
Loading
Loading