From 64ee6c20a72660699f7c854bc346d647036e6f4f Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Sun, 8 Mar 2026 12:03:41 +0530 Subject: [PATCH 1/4] Add GET /runs/trace/{run_id} endpoint --- src/database/runs.py | 36 ++++++++++++++++ src/main.py | 2 + src/routers/openml/runs.py | 51 ++++++++++++++++++++++ src/schemas/runs.py | 20 +++++++++ tests/routers/openml/runs_test.py | 71 +++++++++++++++++++++++++++++++ 5 files changed, 180 insertions(+) create mode 100644 src/database/runs.py create mode 100644 src/routers/openml/runs.py create mode 100644 src/schemas/runs.py create mode 100644 tests/routers/openml/runs_test.py diff --git a/src/database/runs.py b/src/database/runs.py new file mode 100644 index 00000000..6a4989cc --- /dev/null +++ b/src/database/runs.py @@ -0,0 +1,36 @@ +from collections.abc import Sequence +from typing import cast + +from sqlalchemy import Connection, Row, text + + +def get_run(run_id: int, expdb: Connection) -> Row | None: + """Check if a run exists. Used to distinguish 571 (run not found) from 572 (no trace).""" + return expdb.execute( + text( + """ + SELECT rid + FROM run + WHERE rid = :run_id + """, + ), + parameters={"run_id": run_id}, + ).one_or_none() + + +def get_trace(run_id: int, expdb: Connection) -> Sequence[Row]: + """Fetch all trace iterations for a run, ordered as PHP does: repeat, fold, iteration.""" + return cast( + "Sequence[Row]", + expdb.execute( + text( + """ + SELECT `repeat`, `fold`, `iteration`, setup_string, evaluation, selected + FROM trace + WHERE run_id = :run_id + ORDER BY `repeat` ASC, `fold` ASC, `iteration` ASC + """, + ), + parameters={"run_id": run_id}, + ).all(), + ) diff --git a/src/main.py b/src/main.py index 560b4c50..2fe219ae 100644 --- a/src/main.py +++ b/src/main.py @@ -11,6 +11,7 @@ from routers.openml.evaluations import router as evaluationmeasures_router from routers.openml.flows import router as flows_router from routers.openml.qualities import router as qualities_router +from routers.openml.runs import router as runs_router from routers.openml.study import router as study_router from routers.openml.tasks import router as task_router from routers.openml.tasktype import router as ttype_router @@ -55,6 +56,7 @@ def create_api() -> FastAPI: app.include_router(task_router) app.include_router(flows_router) app.include_router(study_router) + app.include_router(runs_router) return app diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py new file mode 100644 index 00000000..74edac6e --- /dev/null +++ b/src/routers/openml/runs.py @@ -0,0 +1,51 @@ +from http import HTTPStatus +from typing import Annotated + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import Connection + +import database.runs +from routers.dependencies import expdb_connection +from schemas.runs import RunTrace, RunTraceResponse, TraceIteration + +router = APIRouter(prefix="/runs", tags=["runs"]) + + +@router.get("/trace/{run_id}") +def get_run_trace( + run_id: int, + expdb: Annotated[Connection, Depends(expdb_connection)] = None, +) -> RunTraceResponse: + # 571: run does not exist at all + if not database.runs.get_run(run_id, expdb): + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "571", "message": "Run not found."}, + ) + + trace_rows = database.runs.get_trace(run_id, expdb) + + # 572: run exists but has no trace data + if not trace_rows: + raise HTTPException( + status_code=HTTPStatus.PRECONDITION_FAILED, + detail={"code": "572", "message": "No trace found for run."}, + ) + + return RunTraceResponse( + trace=RunTrace( + # Cast to str: PHP returns run_id and all iteration fields as strings. + run_id=str(run_id), + trace_iteration=[ + TraceIteration( + repeat=str(row.repeat), + fold=str(row.fold), + iteration=str(row.iteration), + setup_string=row.setup_string, + evaluation=row.evaluation, + selected=row.selected, + ) + for row in trace_rows + ], + ), + ) diff --git a/src/schemas/runs.py b/src/schemas/runs.py new file mode 100644 index 00000000..b22a5b7c --- /dev/null +++ b/src/schemas/runs.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + + +class TraceIteration(BaseModel): + repeat: str + fold: str + iteration: str + setup_string: str + evaluation: str + selected: str + + +class RunTrace(BaseModel): + run_id: str + trace_iteration: list[TraceIteration] + + +# Wraps RunTrace in {"trace": {...}} to match PHP API response structure. +class RunTraceResponse(BaseModel): + trace: RunTrace diff --git a/tests/routers/openml/runs_test.py b/tests/routers/openml/runs_test.py new file mode 100644 index 00000000..a1e6916b --- /dev/null +++ b/tests/routers/openml/runs_test.py @@ -0,0 +1,71 @@ +from http import HTTPStatus + +import pytest +from starlette.testclient import TestClient + + +@pytest.mark.parametrize("run_id", [34]) +def test_get_run_trace(py_api: TestClient, run_id: int) -> None: + response = py_api.get(f"/runs/trace/{run_id}") + assert response.status_code == HTTPStatus.OK + + body = response.json() + assert "trace" in body + + trace = body["trace"] + assert trace["run_id"] == str(run_id) + assert "trace_iteration" in trace + assert len(trace["trace_iteration"]) > 0 + + # Verify structure and types of each iteration — PHP returns all fields as strings + for iteration in trace["trace_iteration"]: + assert "repeat" in iteration + assert "fold" in iteration + assert "iteration" in iteration + assert "setup_string" in iteration + assert "evaluation" in iteration + assert "selected" in iteration + assert isinstance(iteration["repeat"], str) + assert isinstance(iteration["fold"], str) + assert isinstance(iteration["iteration"], str) + assert isinstance(iteration["setup_string"], str) + assert isinstance(iteration["evaluation"], str) + assert iteration["selected"] in ("true", "false") + + +def test_get_run_trace_ordering(py_api: TestClient) -> None: + """Trace iterations must be ordered by repeat, fold, iteration ASC — matches PHP.""" + response = py_api.get("/runs/trace/34") + assert response.status_code == HTTPStatus.OK + + iterations = response.json()["trace"]["trace_iteration"] + keys = [(int(i["repeat"]), int(i["fold"]), int(i["iteration"])) for i in iterations] + assert keys == sorted(keys) + + +def test_get_run_trace_run_not_found(py_api: TestClient) -> None: + """Run does not exist at all — expect error 571.""" + response = py_api.get("/runs/trace/999999") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "571" + + +def test_get_run_trace_negative_id(py_api: TestClient) -> None: + """Negative run_id can never exist — expect error 571.""" + response = py_api.get("/runs/trace/-1") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "571" + + +def test_get_run_trace_invalid_id(py_api: TestClient) -> None: + """Non-integer run_id — FastAPI should reject with 422 before hitting our handler.""" + response = py_api.get("/runs/trace/abc") + assert response.status_code == HTTPStatus.UNPROCESSABLE_ENTITY + + +def test_get_run_trace_no_trace(py_api: TestClient) -> None: + """Run exists but has no trace data — expect error 572. + Run 24 exists in the test DB but has no trace rows.""" + response = py_api.get("/runs/trace/24") + assert response.status_code == HTTPStatus.PRECONDITION_FAILED + assert response.json()["detail"]["code"] == "572" From 257c444a5e6393ebedd5fa862b30fa64899ca3ec Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Sun, 8 Mar 2026 12:20:32 +0530 Subject: [PATCH 2/4] Address review feedback: remove None default, add Literal type for selected --- src/routers/openml/runs.py | 2 +- src/schemas/runs.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 74edac6e..0791de92 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -14,7 +14,7 @@ @router.get("/trace/{run_id}") def get_run_trace( run_id: int, - expdb: Annotated[Connection, Depends(expdb_connection)] = None, + expdb: Annotated[Connection, Depends(expdb_connection)], ) -> RunTraceResponse: # 571: run does not exist at all if not database.runs.get_run(run_id, expdb): diff --git a/src/schemas/runs.py b/src/schemas/runs.py index b22a5b7c..30123b33 100644 --- a/src/schemas/runs.py +++ b/src/schemas/runs.py @@ -1,3 +1,5 @@ +from typing import Literal + from pydantic import BaseModel @@ -7,7 +9,7 @@ class TraceIteration(BaseModel): iteration: str setup_string: str evaluation: str - selected: str + selected: Literal["true", "false"] class RunTrace(BaseModel): From ae1f4560c2aa0ccc38f7f9192616942966f94de0 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Sun, 8 Mar 2026 12:39:46 +0530 Subject: [PATCH 3/4] Add docstring to get_run_trace, cast selected to str --- src/routers/openml/runs.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index 0791de92..d551bb32 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -16,6 +16,11 @@ def get_run_trace( run_id: int, expdb: Annotated[Connection, Depends(expdb_connection)], ) -> RunTraceResponse: + """Get the optimization trace for a run. + + Returns all hyperparameter configurations tried during tuning, their + evaluations, and whether each was selected. Mirrors PHP API behavior. + """ # 571: run does not exist at all if not database.runs.get_run(run_id, expdb): raise HTTPException( @@ -43,7 +48,7 @@ def get_run_trace( iteration=str(row.iteration), setup_string=row.setup_string, evaluation=row.evaluation, - selected=row.selected, + selected="true" if row.selected else "false", ) for row in trace_rows ], From 6b876f36e2de92745d063c07dc775887c342f0e0 Mon Sep 17 00:00:00 2001 From: saathviksheerla Date: Sun, 8 Mar 2026 13:07:28 +0530 Subject: [PATCH 4/4] Use row.selected directly: enum returns plain string from DB --- src/routers/openml/runs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/routers/openml/runs.py b/src/routers/openml/runs.py index d551bb32..61e853a0 100644 --- a/src/routers/openml/runs.py +++ b/src/routers/openml/runs.py @@ -48,7 +48,7 @@ def get_run_trace( iteration=str(row.iteration), setup_string=row.setup_string, evaluation=row.evaluation, - selected="true" if row.selected else "false", + selected=row.selected, ) for row in trace_rows ],