diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d2..8d4f37c 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,9 +1,15 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Row, text +from sqlalchemy import Row, RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection +ALLOWED_LOOKUP_TABLES = ["estimation_procedure", "evaluation_measure", "task_type", "dataset"] +PK_MAPPING = { + "task_type": "ttid", + "dataset": "did", +} + async def get(id_: int, expdb: AsyncConnection) -> Row | None: row = await expdb.execute( @@ -115,3 +121,22 @@ async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: ) tag_rows = rows.all() return [row.tag for row in tag_rows] + + +async def get_lookup_data(table: str, id_: int, expdb: AsyncConnection) -> RowMapping | None: + if table not in ALLOWED_LOOKUP_TABLES: + msg = f"Table {table} is not allowed for lookup." + raise ValueError(msg) + + pk = PK_MAPPING.get(table, "id") + result = await expdb.execute( + text( + f""" + SELECT * + FROM {table} + WHERE `{pk}` = :id_ + """, # noqa: S608 + ), + parameters={"id_": id_}, + ) + return result.mappings().one_or_none() diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 788cd80..c02c8da 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -3,8 +3,8 @@ from typing import Annotated, cast import xmltodict -from fastapi import APIRouter, Depends -from sqlalchemy import RowMapping, text +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import RowMapping from sqlalchemy.ext.asyncio import AsyncConnection import config @@ -17,6 +17,7 @@ router = APIRouter(prefix="/tasks", tags=["tasks"]) type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None +ALLOWED_LOOKUP_TABLES = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"} def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: @@ -95,7 +96,7 @@ async def fill_template( ) -async def _fill_json_template( # noqa: C901 +async def _fill_json_template( # noqa: C901, PLR0912 template: JSON, task: RowMapping, task_inputs: dict[str, str | int], @@ -128,23 +129,29 @@ async def _fill_json_template( # noqa: C901 (field,) = match.groups() if field not in fetched_data: table, _ = field.split(".") - result = await connection.execute( - text( - f""" - SELECT * - FROM {table} - WHERE `id` = :id_ - """, # noqa: S608 - ), - # Not sure how parametrize table names, as the parametrization adds - # quotes which is not legal. - parameters={"id_": int(task_inputs[table])}, - ) - rows = result.mappings() - row_data = next(rows, None) + # List of tables allowed for [LOOKUP:table.column] directive. + # This is a security measure to prevent SQL injection via table names. + if table not in task_inputs or not task_inputs[table]: + msg = f"Missing or empty input for lookup table: {table}" + raise HTTPException(status_code=400, detail=msg) + + try: + id_val = int(task_inputs[table]) + except ValueError: + msg = f"Invalid integer id for table {table}: {task_inputs[table]}" + raise HTTPException(status_code=400, detail=msg) from None + + try: + row_data = await database.tasks.get_lookup_data( + table=table, + id_=id_val, + expdb=connection, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e if row_data is None: msg = f"No data found for table {table} with id {task_inputs[table]}" - raise ValueError(msg) + raise HTTPException(status_code=400, detail=msg) for column, value in row_data.items(): fetched_data[f"{table}.{column}"] = value if match.string == template: