diff --git a/server/src/agent_control_server/config.py b/server/src/agent_control_server/config.py index 82ce14a4..a295d18c 100644 --- a/server/src/agent_control_server/config.py +++ b/server/src/agent_control_server/config.py @@ -125,6 +125,24 @@ class AgentControlServerDatabaseConfig(BaseSettings): "DB_DATABASE", ) driver: str = _env_alias_field("psycopg", "AGENT_CONTROL_DB_DRIVER", "DB_DRIVER") + pool_size: int = Field( + default=5, + ge=1, + validation_alias=AliasChoices("AGENT_CONTROL_DB_POOL_SIZE", "DB_POOL_SIZE"), + ) + max_overflow: int = Field( + default=0, + ge=0, + validation_alias=AliasChoices("AGENT_CONTROL_DB_MAX_OVERFLOW", "DB_MAX_OVERFLOW"), + ) + pool_timeout_seconds: float = Field( + default=5.0, + gt=0, + validation_alias=AliasChoices( + "AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS", + "DB_POOL_TIMEOUT_SECONDS", + ), + ) def get_url(self) -> str: """Get database URL, preferring an explicit URL if configured.""" diff --git a/server/src/agent_control_server/db.py b/server/src/agent_control_server/db.py index 4ba28846..a0816886 100644 --- a/server/src/agent_control_server/db.py +++ b/server/src/agent_control_server/db.py @@ -1,9 +1,14 @@ from collections.abc import AsyncGenerator +from typing import Any +from prometheus_client import Gauge +from sqlalchemy import event +from sqlalchemy.engine.url import make_url from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio.engine import AsyncEngine from sqlalchemy.orm import DeclarativeBase -from .config import db_config +from .config import AgentControlServerDatabaseConfig, db_config class Base(DeclarativeBase): @@ -13,11 +18,53 @@ class Base(DeclarativeBase): # Async SQLAlchemy setup for PostgreSQL db_url = db_config.get_url() -async_engine = create_async_engine( - db_url, - echo=False, +SQLALCHEMY_CHECKED_OUT_CONNECTIONS = Gauge( + "agent_control_server_sqlalchemy_checked_out_connections", + "Number of checked out SQLAlchemy connections.", + ["pool_name"], + multiprocess_mode="livesum", ) + +def _supports_queue_pool_config(url: str) -> bool: + """Return whether SQLAlchemy QueuePool kwargs should be applied for this URL.""" + return make_url(url).get_backend_name() != "sqlite" + + +def _build_async_engine_kwargs( + url: str, + config: AgentControlServerDatabaseConfig, +) -> dict[str, Any]: + """Build async SQLAlchemy engine kwargs from database config.""" + kwargs: dict[str, Any] = {"echo": False} + if not _supports_queue_pool_config(url): + return kwargs + + kwargs.update( + pool_pre_ping=True, + pool_size=config.pool_size, + max_overflow=config.max_overflow, + pool_timeout=config.pool_timeout_seconds, + pool_reset_on_return="rollback", + ) + return kwargs + + +def _instrument_connection_pool(engine: AsyncEngine) -> None: + """Track checked-out connections from the async engine's underlying pool.""" + + @event.listens_for(engine.sync_engine.pool, "checkin") + def receive_checkin(dbapi_conn: Any, connection_record: Any) -> None: + SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").dec() + + @event.listens_for(engine.sync_engine.pool, "checkout") + def receive_checkout(dbapi_conn: Any, connection_record: Any, connection_proxy: Any) -> None: + SQLALCHEMY_CHECKED_OUT_CONNECTIONS.labels("default").inc() + + +async_engine = create_async_engine(db_url, **_build_async_engine_kwargs(db_url, db_config)) +_instrument_connection_pool(async_engine) + AsyncSessionLocal = async_sessionmaker( bind=async_engine, autoflush=False, diff --git a/server/src/agent_control_server/endpoints/evaluation.py b/server/src/agent_control_server/endpoints/evaluation.py index bc66381f..a31d757d 100644 --- a/server/src/agent_control_server/endpoints/evaluation.py +++ b/server/src/agent_control_server/endpoints/evaluation.py @@ -13,10 +13,9 @@ from agent_control_models.errors import ErrorCode, ValidationErrorItem from fastapi import APIRouter, Depends, Request from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from ..auth_framework import Operation, Principal, require_operation -from ..db import get_async_db +from ..db import AsyncSessionLocal from ..errors import APIValidationError, NotFoundError from ..logging_utils import get_logger from ..models import Agent @@ -136,6 +135,41 @@ async def _evaluation_context(request: Request) -> dict[str, object]: return {"target_type": target_type, "target_id": target_id} +async def _load_engine_controls( + request: EvaluationRequest, + principal: Principal, +) -> list[ControlAdapter]: + """Load and materialize controls before evaluator execution starts.""" + namespace_key = principal.namespace_key + + async with AsyncSessionLocal() as db: + agent_result = await db.execute( + select(Agent).where( + Agent.name == request.agent_name, + Agent.namespace_key == namespace_key, + ) + ) + agent = agent_result.scalar_one_or_none() + if agent is None: + raise NotFoundError( + error_code=ErrorCode.AGENT_NOT_FOUND, + detail=f"Agent '{request.agent_name}' not found", + resource="Agent", + resource_id=request.agent_name, + hint="Register the agent via initAgent before evaluating.", + ) + + runtime_controls = await ControlService(db).list_runtime_controls_for_agent( + request.agent_name, + namespace_key=namespace_key, + target_type=request.target_type, + target_id=request.target_id, + allow_invalid_step_name_regex=True, + ) + + return [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls] + + @router.post( "", response_model=EvaluationResponse, @@ -144,7 +178,6 @@ async def _evaluation_context(request: Request) -> dict[str, object]: ) async def evaluate( request: EvaluationRequest, - db: AsyncSession = Depends(get_async_db), principal: Principal = Depends( require_operation(Operation.RUNTIME_USE, context_builder=_evaluation_context) ), @@ -163,34 +196,7 @@ async def evaluate( on the server; SDKs reconstruct and emit those events separately through the observability ingestion endpoint. """ - namespace_key = principal.namespace_key - - agent_result = await db.execute( - select(Agent).where( - Agent.name == request.agent_name, - Agent.namespace_key == namespace_key, - ) - ) - agent = agent_result.scalar_one_or_none() - if agent is None: - raise NotFoundError( - error_code=ErrorCode.AGENT_NOT_FOUND, - detail=f"Agent '{request.agent_name}' not found", - resource="Agent", - resource_id=request.agent_name, - hint="Register the agent via initAgent before evaluating.", - ) - - runtime_controls = await ControlService(db).list_runtime_controls_for_agent( - request.agent_name, - namespace_key=namespace_key, - target_type=request.target_type, - target_id=request.target_id, - allow_invalid_step_name_regex=True, - ) - - engine_controls = [ControlAdapter(c.id, c.name, c.control) for c in runtime_controls] - + engine_controls = await _load_engine_controls(request, principal) engine = ControlEngine(engine_controls) try: raw_response = await engine.process(request) diff --git a/server/src/agent_control_server/main.py b/server/src/agent_control_server/main.py index 4aea5a6e..16152824 100644 --- a/server/src/agent_control_server/main.py +++ b/server/src/agent_control_server/main.py @@ -19,7 +19,7 @@ from . import __version__ as server_version from .auth import get_api_key_from_header from .config import observability_settings, settings -from .db import AsyncSessionLocal +from .db import AsyncSessionLocal, async_engine from .endpoints.agents import router as agent_router from .endpoints.auth import router as auth_router from .endpoints.control_bindings import router as control_binding_router @@ -198,6 +198,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await app.state.event_store.close() logger.info("EventStore closed") + await async_engine.dispose() + logger.info("Database engine disposed") + app = FastAPI( title="Agent Control Server", diff --git a/server/tests/test_config.py b/server/tests/test_config.py index 51f48be1..59fdf338 100644 --- a/server/tests/test_config.py +++ b/server/tests/test_config.py @@ -89,6 +89,21 @@ def test_db_config_ignores_blank_agent_control_url_and_uses_legacy(monkeypatch) assert config.get_url() == "sqlite:///tmp/legacy.db" +def test_db_config_reads_pool_settings_from_env(monkeypatch) -> None: + # Given: database pool settings are configured via environment variables + monkeypatch.setenv("AGENT_CONTROL_DB_POOL_SIZE", "7") + monkeypatch.setenv("AGENT_CONTROL_DB_MAX_OVERFLOW", "2") + monkeypatch.setenv("AGENT_CONTROL_DB_POOL_TIMEOUT_SECONDS", "3.5") + + # When: loading DB config from the environment + config = AgentControlServerDatabaseConfig() + + # Then: the explicit pool settings are used + assert config.pool_size == 7 + assert config.max_overflow == 2 + assert config.pool_timeout_seconds == 3.5 + + def test_settings_parses_cors_origins_string() -> None: # Given: a comma-separated CORS origins string settings = Settings(cors_origins="https://a.example, https://b.example") diff --git a/server/tests/test_db.py b/server/tests/test_db.py new file mode 100644 index 00000000..2b8b314a --- /dev/null +++ b/server/tests/test_db.py @@ -0,0 +1,44 @@ +"""Tests for server database engine configuration.""" + +from agent_control_server.config import AgentControlServerDatabaseConfig +from agent_control_server.db import _build_async_engine_kwargs + + +def test_build_async_engine_kwargs_applies_postgres_pool_config() -> None: + # Given: custom PostgreSQL connection pool settings + config = AgentControlServerDatabaseConfig( + pool_size=7, + max_overflow=2, + pool_timeout_seconds=3.5, + ) + + # When: building async engine kwargs for Postgres + kwargs = _build_async_engine_kwargs( + "postgresql+psycopg://user:password@localhost:5432/agent_control", + config, + ) + + # Then: the engine is configured with a bounded, health-checked pool + assert kwargs == { + "echo": False, + "pool_pre_ping": True, + "pool_size": 7, + "max_overflow": 2, + "pool_timeout": 3.5, + "pool_reset_on_return": "rollback", + } + + +def test_build_async_engine_kwargs_skips_pool_config_for_sqlite() -> None: + # Given: custom pool settings with a SQLite URL + config = AgentControlServerDatabaseConfig( + pool_size=7, + max_overflow=2, + pool_timeout_seconds=3.5, + ) + + # When: building async engine kwargs for SQLite + kwargs = _build_async_engine_kwargs("sqlite+aiosqlite:///tmp/agent-control.db", config) + + # Then: SQLite keeps SQLAlchemy's default local-dev pool behavior + assert kwargs == {"echo": False} diff --git a/server/tests/test_evaluation_error_handling.py b/server/tests/test_evaluation_error_handling.py index 1df795da..6de16afa 100644 --- a/server/tests/test_evaluation_error_handling.py +++ b/server/tests/test_evaluation_error_handling.py @@ -5,11 +5,13 @@ from agent_control_models import ( ControlMatch, EvaluationRequest, + EvaluationResponse, EvaluatorResult, Step, ) from fastapi.testclient import TestClient +from agent_control_server.db import async_engine from agent_control_server.endpoints.evaluation import ( SAFE_EVALUATOR_ERROR, SAFE_EVALUATOR_TIMEOUT_ERROR, @@ -327,6 +329,31 @@ async def raise_value_error(*_args, **_kwargs): assert body["errors"][0]["message"] == "Invalid evaluation request or control configuration." +def test_evaluation_releases_db_connection_before_engine_processing( + client: TestClient, + monkeypatch, +) -> None: + """Evaluation should not hold a DB connection while evaluator work runs.""" + agent_name, _ = create_and_assign_policy(client) + checked_out_counts: list[int] = [] + + import agent_control_engine.core as core_module + + async def process_with_pool_assertion(*_args, **_kwargs): + pool = async_engine.sync_engine.pool + checked_out = pool.checkedout() if hasattr(pool, "checkedout") else 0 + checked_out_counts.append(checked_out) + return EvaluationResponse(is_safe=True, confidence=1.0) + + monkeypatch.setattr(core_module.ControlEngine, "process", process_with_pool_assertion) + + payload = Step(type="llm", name="test-step", input="test content", output=None) + req = EvaluationRequest(agent_name=agent_name, step=payload, stage="pre") + resp = client.post("/api/v1/evaluation", json=req.model_dump(mode="json")) + + assert resp.status_code == 200 + assert checked_out_counts == [0] + def test_evaluation_ignores_merge_headers_and_remains_pure(client: TestClient) -> None: """/evaluation should return only semantic results regardless of merge headers."""