Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions server/src/agent_control_server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,24 @@ class AgentControlServerDatabaseConfig(BaseSettings):
"DB_DATABASE",
)
driver: str = _env_alias_field("psycopg", "AGENT_CONTROL_DB_DRIVER", "DB_DRIVER")
pool_size: int = Field(
default=5,
ge=1,
validation_alias=AliasChoices("AGENT_CONTROL_DB_POOL_SIZE", "DB_POOL_SIZE"),
)
max_overflow: int = Field(
default=0,
ge=0,
validation_alias=AliasChoices("AGENT_CONTROL_DB_MAX_OVERFLOW", "DB_MAX_OVERFLOW"),
)
pool_timeout_seconds: float = Field(
default=5.0,
gt=0,
validation_alias=AliasChoices(
"AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS",
"DB_POOL_TIMEOUT_SECONDS",
),
)

def get_url(self) -> str:
"""Get database URL, preferring an explicit URL if configured."""
Expand Down
55 changes: 51 additions & 4 deletions server/src/agent_control_server/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from collections.abc import AsyncGenerator
from typing import Any

from prometheus_client import Gauge
from sqlalchemy import event
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.asyncio.engine import AsyncEngine
from sqlalchemy.orm import DeclarativeBase

from .config import db_config
from .config import AgentControlServerDatabaseConfig, db_config


class Base(DeclarativeBase):
Expand All @@ -13,11 +18,53 @@ class Base(DeclarativeBase):
# Async SQLAlchemy setup for PostgreSQL
db_url = db_config.get_url()

async_engine = create_async_engine(
db_url,
echo=False,
# Gauge of SQLAlchemy connections currently checked out of the pool, labelled
# by pool name. The listeners registered in _instrument_connection_pool
# increment it on pool "checkout" and decrement it on pool "checkin".
# NOTE(review): multiprocess_mode="livesum" presumably makes multi-worker
# deployments report the sum over live processes -- confirm against the
# prometheus_client multiprocess documentation.
SQLALCHEMY_CHECKED_OUT_CONNECTIONS = Gauge(
"agent_control_server_sqlalchemy_checked_out_connections",
"Number of checked out SQLAlchemy connections.",
["pool_name"],
multiprocess_mode="livesum",
)


def _supports_queue_pool_config(url: str) -> bool:
    """Tell whether QueuePool tuning kwargs may be passed for ``url``.

    The decision is made on the backend name of the parsed URL: every backend
    other than ``"sqlite"`` qualifies.
    """
    backend_name = make_url(url).get_backend_name()
    return not backend_name == "sqlite"


def _build_async_engine_kwargs(
    url: str,
    config: AgentControlServerDatabaseConfig,
) -> dict[str, Any]:
    """Assemble the keyword arguments handed to ``create_async_engine``.

    Args:
        url: Database URL the engine will connect to.
        config: Database settings supplying the pool size, overflow and
            timeout values.

    Returns:
        ``{"echo": False}`` alone when the URL does not support queue-pool
        settings; otherwise that entry plus pre-ping, the configured pool
        sizing/timeout and rollback-on-return.
    """
    if _supports_queue_pool_config(url):
        return {
            "echo": False,
            "pool_pre_ping": True,
            "pool_size": config.pool_size,
            "max_overflow": config.max_overflow,
            "pool_timeout": config.pool_timeout_seconds,
            "pool_reset_on_return": "rollback",
        }

    # Backends without queue-pool support keep SQLAlchemy's default pooling.
    return {"echo": False}


def _instrument_connection_pool(engine: AsyncEngine) -> None:
    """Keep the checked-out connection gauge in step with the engine's pool.

    Two listeners are attached to the pool of ``engine.sync_engine``: every
    ``checkin`` event decrements the gauge and every ``checkout`` event
    increments it, both under the ``"default"`` pool label.
    """
    pool = engine.sync_engine.pool

    def receive_checkin(dbapi_conn: Any, connection_record: Any) -> None:
        SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").dec()

    def receive_checkout(dbapi_conn: Any, connection_record: Any, connection_proxy: Any) -> None:
        SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").inc()

    # Explicit registration, in the same order as the decorator form.
    event.listen(pool, "checkin", receive_checkin)
    event.listen(pool, "checkout", receive_checkout)


# Module-level engine: created once at import time from the configured URL and
# the kwargs derived from db_config, then wired to the checked-out-connections
# gauge via pool event listeners.
async_engine = create_async_engine(db_url, **_build_async_engine_kwargs(db_url, db_config))
_instrument_connection_pool(async_engine)

AsyncSessionLocal = async_sessionmaker(
bind=async_engine,
autoflush=False,
Expand Down
68 changes: 37 additions & 31 deletions server/src/agent_control_server/endpoints/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,9 @@
from agent_control_models.errors import ErrorCode, ValidationErrorItem
from fastapi import APIRouter, Depends, Request
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from ..auth_framework import Operation, Principal, require_operation
from ..db import get_async_db
from ..db import AsyncSessionLocal
from ..errors import APIValidationError, NotFoundError
from ..logging_utils import get_logger
from ..models import Agent
Expand Down Expand Up @@ -136,6 +135,41 @@ async def _evaluation_context(request: Request) -> dict[str, object]:
return {"target_type": target_type, "target_id": target_id}


async def _load_engine_controls(
    request: EvaluationRequest,
    principal: Principal,
) -> list[ControlAdapter]:
    """Fetch the agent's runtime controls and wrap them for the control engine.

    A session from ``AsyncSessionLocal`` is opened only for the agent lookup
    and the runtime-control query.

    Args:
        request: Evaluation request naming the agent and optional target.
        principal: Authenticated caller; its ``namespace_key`` scopes both
            queries.

    Returns:
        One ``ControlAdapter`` per runtime control returned for the agent.

    Raises:
        NotFoundError: With ``ErrorCode.AGENT_NOT_FOUND`` when no agent named
            ``request.agent_name`` exists in the principal's namespace.
    """
    namespace_key = principal.namespace_key
    agent_lookup = select(Agent).where(
        Agent.name == request.agent_name,
        Agent.namespace_key == namespace_key,
    )

    async with AsyncSessionLocal() as session:
        lookup_result = await session.execute(agent_lookup)
        if lookup_result.scalar_one_or_none() is None:
            raise NotFoundError(
                error_code=ErrorCode.AGENT_NOT_FOUND,
                detail=f"Agent '{request.agent_name}' not found",
                resource="Agent",
                resource_id=request.agent_name,
                hint="Register the agent via initAgent before evaluating.",
            )

        control_service = ControlService(session)
        runtime_controls = await control_service.list_runtime_controls_for_agent(
            request.agent_name,
            namespace_key=namespace_key,
            target_type=request.target_type,
            target_id=request.target_id,
            allow_invalid_step_name_regex=True,
        )

    # NOTE(review): adapters are built after the session block exits, matching
    # the layout implied by the original's blank line -- confirm.
    return [
        ControlAdapter(control.id, control.name, control.control)
        for control in runtime_controls
    ]


@router.post(
"",
response_model=EvaluationResponse,
Expand All @@ -144,7 +178,6 @@ async def _evaluation_context(request: Request) -> dict[str, object]:
)
async def evaluate(
request: EvaluationRequest,
db: AsyncSession = Depends(get_async_db),
principal: Principal = Depends(
require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context)
),
Expand All @@ -163,34 +196,7 @@ async def evaluate(
on the server; SDKs reconstruct and emit those events separately through
the observability ingestion endpoint.
"""
namespace_key = principal.namespace_key

agent_result = await db.execute(
select(Agent).where(
Agent.name == request.agent_name,
Agent.namespace_key == namespace_key,
)
)
agent = agent_result.scalar_one_or_none()
if agent is None:
raise NotFoundError(
error_code=ErrorCode.AGENT_NOT_FOUND,
detail=f"Agent '{request.agent_name}' not found",
resource="Agent",
resource_id=request.agent_name,
hint="Register the agent via initAgent before evaluating.",
)

runtime_controls = await ControlService(db).list_runtime_controls_for_agent(
request.agent_name,
namespace_key=namespace_key,
target_type=request.target_type,
target_id=request.target_id,
allow_invalid_step_name_regex=True,
)

engine_controls = [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls]

engine_controls = await _load_engine_controls(request, principal)
engine = ControlEngine(engine_controls)
try:
raw_response = await engine.process(request)
Expand Down
5 changes: 4 additions & 1 deletion server/src/agent_control_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from . import __version__ as server_version
from .auth import get_api_key_from_header
from .config import observability_settings, settings
from .db import AsyncSessionLocal
from .db import AsyncSessionLocal, async_engine
from .endpoints.agents import router as agent_router
from .endpoints.auth import router as auth_router
from .endpoints.control_bindings import router as control_binding_router
Expand Down Expand Up @@ -198,6 +198,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await app.state.event_store.close()
logger.info("EventStore closed")

await async_engine.dispose()
logger.info("Database engine disposed")


app = FastAPI(
title="Agent Control Server",
Expand Down
15 changes: 15 additions & 0 deletions server/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,21 @@ def test_db_config_ignores_blank_agent_control_url_and_uses_legacy(monkeypatch)
assert config.get_url() == "sqlite:///tmp/legacy.db"


def test_db_config_reads_pool_settings_from_env(monkeypatch) -> None:
    # Given: each pool setting is supplied through its environment variable
    pool_env = {
        "AGENT_CONTROL_DB_POOL_SIZE": "7",
        "AGENT_CONTROL_DB_MAX_OVERFLOW": "2",
        "AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS": "3.5",
    }
    for env_name, env_value in pool_env.items():
        monkeypatch.setenv(env_name, env_value)

    # When: the DB config is built from the environment
    config = AgentControlServerDatabaseConfig()

    # Then: every pool setting reflects the parsed environment value
    observed = (config.pool_size, config.max_overflow, config.pool_timeout_seconds)
    assert observed == (7, 2, 3.5)


def test_settings_parses_cors_origins_string() -> None:
# Given: a comma-separated CORS origins string
settings = Settings(cors_origins="https://a.example, https://b.example")
Expand Down
44 changes: 44 additions & 0 deletions server/tests/test_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Tests for server database engine configuration."""

from agent_control_server.config import AgentControlServerDatabaseConfig
from agent_control_server.db import _build_async_engine_kwargs


def test_build_async_engine_kwargs_applies_postgres_pool_config() -> None:
    # Given: a PostgreSQL URL and non-default connection pool settings
    postgres_url = "postgresql+psycopg://user:password@localhost:5432/agent_control"
    config = AgentControlServerDatabaseConfig(
        pool_size=7,
        max_overflow=2,
        pool_timeout_seconds=3.5,
    )
    expected_kwargs = dict(
        echo=False,
        pool_pre_ping=True,
        pool_size=7,
        max_overflow=2,
        pool_timeout=3.5,
        pool_reset_on_return="rollback",
    )

    # When: the engine kwargs are derived for that URL
    kwargs = _build_async_engine_kwargs(postgres_url, config)

    # Then: the pool is bounded, health-checked and sized from the config
    assert kwargs == expected_kwargs


def test_build_async_engine_kwargs_skips_pool_config_for_sqlite() -> None:
    # Given: a SQLite URL together with non-default pool settings
    sqlite_url = "sqlite+aiosqlite:///tmp/agent-control.db"
    config = AgentControlServerDatabaseConfig(
        pool_size=7,
        max_overflow=2,
        pool_timeout_seconds=3.5,
    )

    # When: the engine kwargs are derived for that URL
    kwargs = _build_async_engine_kwargs(sqlite_url, config)

    # Then: none of the pool settings are forwarded for SQLite
    expected_kwargs = {"echo": False}
    assert kwargs == expected_kwargs
27 changes: 27 additions & 0 deletions server/tests/test_evaluation_error_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from agent_control_models import (
ControlMatch,
EvaluationRequest,
EvaluationResponse,
EvaluatorResult,
Step,
)
from fastapi.testclient import TestClient

from agent_control_server.db import async_engine
from agent_control_server.endpoints.evaluation import (
SAFE_EVALUATOR_ERROR,
SAFE_EVALUATOR_TIMEOUT_ERROR,
Expand Down Expand Up @@ -327,6 +329,31 @@ async def raise_value_error(*_args, **_kwargs):
assert body["errors"][0]["message"] == "Invalid evaluation request or control configuration."


def test_evaluation_releases_db_connection_before_engine_processing(
    client: TestClient,
    monkeypatch,
) -> None:
    """No pool connection may be checked out while the control engine runs."""
    agent_name, _ = create_and_assign_policy(client)
    checked_out_counts: list[int] = []

    import agent_control_engine.core as core_module

    async def process_with_pool_assertion(*_args, **_kwargs):
        # Record the pool's checked-out count at the moment the engine runs;
        # pools without a ``checkedout`` method are recorded as zero.
        pool = async_engine.sync_engine.pool
        if hasattr(pool, "checkedout"):
            checked_out_counts.append(pool.checkedout())
        else:
            checked_out_counts.append(0)
        return EvaluationResponse(is_safe=True, confidence=1.0)

    monkeypatch.setattr(core_module.ControlEngine, "process", process_with_pool_assertion)

    step = Step(type="llm", name="test-step", input="test content", output=None)
    evaluation_request = EvaluationRequest(agent_name=agent_name, step=step, stage="pre")
    resp = client.post(
        "/api/v1/evaluation",
        json=evaluation_request.model_dump(mode="json"),
    )

    assert resp.status_code == 200
    assert checked_out_counts == [0]


def test_evaluation_ignores_merge_headers_and_remains_pure(client: TestClient) -> None:
"""/evaluation should return only semantic results regardless of merge headers."""
Expand Down
Loading