diff --git a/src/database/runs.py b/src/database/runs.py new file mode 100644 index 00000000..6a4989cc --- /dev/null +++ b/src/database/runs.py @@ -0,0 +1,36 @@ +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import Connection, Row, text + + +def get_run(run_id: int, expdb: Connection) -> Row | None: + """Check if a run exists. Used to distinguish 571 (run not found) from 572 (no trace).""" + return expdb.execute( + text( + """ + SELECT rid + FROM run + WHERE rid = :run_id + """, + ), + parameters={"run_id": run_id}, + ).one_or_none() + + +def get_trace(run_id: int, expdb: Connection) -> Sequence[Row]: + """Fetch all trace iterations for a run, ordered as PHP does: repeat, fold, iteration.""" + return cast( + "Sequence[Row]", + expdb.execute( + text( + """ + SELECT `repeat`, `fold`, `iteration`, setup_string, evaluation, selected + FROM trace + WHERE run_id = :run_id + ORDER BY `repeat` ASC, `fold` ASC, `iteration` ASC + """, + ), + parameters={"run_id": run_id}, + ).all(), + ) diff --git a/src/main.py b/src/main.py index 560b4c50..2fe219ae 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.runs import router as runs_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -55,6 +56,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(runs_router) return app diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..61e853a0 --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,56 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import Connection + +import database.runs +from routers.dependencies import expdb_connection +from schemas.runs import RunTrace, RunTraceResponse, TraceIteration + +router = APIRouter(prefix="/runs", tags=["runs"]) + + +@router.get("/trace/{run_id}") +def get_run_trace( + run_id: int, + expdb: Annotated[Connection, Depends(expdb_connection)], +) -> RunTraceResponse: + """Get the optimization trace for a run. + + Returns all hyperparameter configurations tried during tuning, their + evaluations, and whether each was selected. Mirrors PHP API behavior. + """ + # 571: run does not exist at all + if not database.runs.get_run(run_id, expdb): + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "571", "message": "Run not found."}, + ) + + trace_rows = database.runs.get_trace(run_id, expdb) + + # 572: run exists but has no trace data + if not trace_rows: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "572", "message": "No trace found for run."}, + ) + + return RunTraceResponse( + trace=RunTrace( + # Cast to str: PHP returns run_id and all iteration fields as strings. + run_id=str(run_id), + trace_iteration=[ + TraceIteration( + repeat=str(row.repeat), + fold=str(row.fold), + iteration=str(row.iteration), + setup_string=row.setup_string, + evaluation=row.evaluation, + selected=row.selected, + ) + for row in trace_rows + ], + ), + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py new file mode 100644 index 00000000..30123b33 --- /dev/null +++ b/src/schemas/runs.py @@ -0,0 +1,22 @@ +from typing import Literal + +from pydantic import BaseModel + + +class TraceIteration(BaseModel): + repeat: str + fold: str + iteration: str + setup_string: str + evaluation: str + selected: Literal["true", "false"] + + +class RunTrace(BaseModel): + run_id: str + trace_iteration: list[TraceIteration] + + +# Wraps RunTrace in {"trace": {...}} to match PHP API response structure. +class RunTraceResponse(BaseModel): + trace: RunTrace diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py new file mode 100644 index 00000000..a1e6916b --- /dev/null +++ b/tests/routers/openml/runs_test.py @@ -0,0 +1,71 @@ +from http import HTTPStatus + +import pytest +from starlette.testclient import TestClient + + +@pytest.mark.parametrize("run_id", [34]) +def test_get_run_trace(py_api: TestClient, run_id: int) -> None: + response = py_api.get(f"/runs/trace/{run_id}") + assert response.status_code == HTTPStatus.OK + + body = response.json() + assert "trace" in body + + trace = body["trace"] + assert trace["run_id"] == str(run_id) + assert "trace_iteration" in trace + assert len(trace["trace_iteration"]) > 0 + + # Verify structure and types of each iteration — PHP returns all fields as strings + for iteration in trace["trace_iteration"]: + assert "repeat" in iteration + assert "fold" in iteration + assert "iteration" in iteration + assert "setup_string" in iteration + assert "evaluation" in iteration + assert "selected" in iteration + assert isinstance(iteration["repeat"], str) + assert isinstance(iteration["fold"], str) + assert isinstance(iteration["iteration"], str) + assert isinstance(iteration["setup_string"], str) + assert isinstance(iteration["evaluation"], str) + assert iteration["selected"] in ("true", "false") + + +def test_get_run_trace_ordering(py_api: TestClient) -> None: + """Trace iterations must be ordered by repeat, fold, iteration ASC — matches PHP.""" + response = py_api.get("/runs/trace/34") + assert response.status_code == HTTPStatus.OK + + iterations = response.json()["trace"]["trace_iteration"] + keys = [(int(i["repeat"]), int(i["fold"]), int(i["iteration"])) for i in iterations] + assert keys == sorted(keys) + + +def test_get_run_trace_run_not_found(py_api: TestClient) -> None: + """Run does not exist at all — expect error 571.""" + response = py_api.get("/runs/trace/999999") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "571" + + +def test_get_run_trace_negative_id(py_api: TestClient) -> None: + """Negative run_id can never exist — expect error 571.""" + response = py_api.get("/runs/trace/-1") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "571" + + +def test_get_run_trace_invalid_id(py_api: TestClient) -> None: + """Non-integer run_id — FastAPI should reject with 422 before hitting our handler.""" + response = py_api.get("/runs/trace/abc") + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +def test_get_run_trace_no_trace(py_api: TestClient) -> None: + """Run exists but has no trace data — expect error 572. + Run 24 exists in the test DB but has no trace rows.""" + response = py_api.get("/runs/trace/24") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "572"