Skip to content
Open
27 changes: 26 additions & 1 deletion src/database/tasks.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from collections.abc import Sequence
from typing import cast

from sqlalchemy import Row, text
from sqlalchemy import Row, RowMapping, text
from sqlalchemy.ext.asyncio import AsyncConnection

ALLOWED_LOOKUP_TABLES = ["estimation_procedure", "evaluation_measure", "task_type", "dataset"]
PK_MAPPING = {
"task_type": "ttid",
"dataset": "did",
}


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
Expand Down Expand Up @@ -115,3 +121,22 @@ async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]


async def get_lookup_data(table: str, id_: int, expdb: AsyncConnection) -> RowMapping | None:
if table not in ALLOWED_LOOKUP_TABLES:
msg = f"Table {table} is not allowed for lookup."
raise ValueError(msg)

pk = PK_MAPPING.get(table, "id")
result = await expdb.execute(
text(
f"""
SELECT *
FROM {table}
WHERE `{pk}` = :id_
""", # noqa: S608
),
parameters={"id_": id_},
)
return result.mappings().one_or_none()
43 changes: 25 additions & 18 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from typing import Annotated, cast

import xmltodict
from fastapi import APIRouter, Depends
from sqlalchemy import RowMapping, text
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import RowMapping
from sqlalchemy.ext.asyncio import AsyncConnection

import config
Expand All @@ -17,6 +17,7 @@
router = APIRouter(prefix="/tasks", tags=["tasks"])

type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None
ALLOWED_LOOKUP_TABLES = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"}


def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]:
Expand Down Expand Up @@ -95,7 +96,7 @@ async def fill_template(
)


async def _fill_json_template( # noqa: C901
async def _fill_json_template( # noqa: C901, PLR0912
template: JSON,
task: RowMapping,
task_inputs: dict[str, str | int],
Expand Down Expand Up @@ -128,23 +129,29 @@ async def _fill_json_template( # noqa: C901
(field,) = match.groups()
if field not in fetched_data:
table, _ = field.split(".")
result = await connection.execute(
text(
f"""
SELECT *
FROM {table}
WHERE `id` = :id_
""", # noqa: S608
),
# Not sure how parametrize table names, as the parametrization adds
# quotes which is not legal.
parameters={"id_": int(task_inputs[table])},
)
rows = result.mappings()
row_data = next(rows, None)
# List of tables allowed for [LOOKUP:table.column] directive.
# This is a security measure to prevent SQL injection via table names.
if table not in task_inputs or not task_inputs[table]:
msg = f"Missing or empty input for lookup table: {table}"
raise HTTPException(status_code=400, detail=msg)

try:
id_val = int(task_inputs[table])
except ValueError:
msg = f"Invalid integer id for table {table}: {task_inputs[table]}"
raise HTTPException(status_code=400, detail=msg) from None

try:
row_data = await database.tasks.get_lookup_data(
table=table,
id_=id_val,
expdb=connection,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
if row_data is None:
msg = f"No data found for table {table} with id {task_inputs[table]}"
raise ValueError(msg)
raise HTTPException(status_code=400, detail=msg)
for column, value in row_data.items():
fetched_data[f"{table}.{column}"] = value
if match.string == template:
Expand Down