From 8dc0b563d2f3e7a0e942563ec5280158c29873b9 Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Fri, 20 Mar 2026 17:34:35 +0530 Subject: [PATCH 1/7] security: implement SQL parameterization and table whitelisting in datasets and tasks routers --- src/routers/openml/datasets.py | 97 +++++++++++++++++++++------------- src/routers/openml/tasks.py | 9 +++- 2 files changed, 67 insertions(+), 39 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 164efd7d..8d9193ad 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends -from sqlalchemy import text +from sqlalchemy import bindparam, text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -75,7 +75,7 @@ class DatasetStatusFilter(StrEnum): @router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.") @router.get(path="/list") -async def list_datasets( # noqa: PLR0913 +async def list_datasets( # noqa: C901, PLR0913, PLR0915 pagination: Annotated[Pagination, Body(default_factory=Pagination)], data_name: Annotated[str | None, CasualString128] = None, tag: Annotated[str | None, SystemString64] = None, @@ -114,65 +114,96 @@ async def list_datasets( # noqa: PLR0913 ) """, ) - + statuses: list[str] if status == DatasetStatusFilter.ALL: statuses = [ - DatasetStatusFilter.ACTIVE, - DatasetStatusFilter.DEACTIVATED, - DatasetStatusFilter.IN_PREPARATION, + DatasetStatusFilter.ACTIVE.value, + DatasetStatusFilter.DEACTIVATED.value, + DatasetStatusFilter.IN_PREPARATION.value, ] else: - statuses = [status] + statuses = [status.value] + + params: dict[str, Any] = { + "tag": tag, + "data_name": data_name, + "data_version": data_version, + "uploader": uploader, + "limit": pagination.limit, + "offset": pagination.offset, + "statuses": statuses, + } - where_status = ",".join(f"'{status}'" for status in statuses) if user is None: visible_to_user = "`visibility`='public'" elif UserGroup.ADMIN in await user.get_groups(): visible_to_user = "TRUE" else: - visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" + visible_to_user = "(`visibility`='public' OR `uploader`=:user_id)" + params["user_id"] = user.user_id where_name = "" if data_name is None else "AND `name`=:data_name" where_version = "" if data_version is None else "AND `version`=:data_version" where_uploader = "" if uploader is None else "AND `uploader`=:uploader" - data_id_str = ",".join(str(did) for did in data_id) if data_id else "" - where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" - # requires some benchmarking on whether e.g., IN () is more efficient. + where_data_id = "" + if data_id: + where_data_id = "AND d.`did` IN :data_ids" + params["data_ids"] = data_id + matching_tag = ( - text( - """ + """ AND d.`did` IN ( SELECT `id` FROM dataset_tag as dt WHERE dt.`tag`=:tag ) - """, - ) + """ if tag else "" ) - def quality_clause(quality: str, range_: str | None) -> str: + def quality_clause(range_: str | None, param_name: str) -> str: if not range_: return "" if not (match := re.match(integer_range_regex, range_)): msg = f"`range_` not a valid range: {range_}" raise ValueError(msg) start, end = match.groups() - value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" + if end: + params[f"{param_name}_start"] = int(start) + params[f"{param_name}_end"] = int(end[2:]) + value_clause = f"`value` BETWEEN :{param_name}_start AND :{param_name}_end" + else: + params[f"{param_name}_val"] = int(start) + value_clause = f"`value` = :{param_name}_val" + return f""" AND d.`did` IN ( SELECT `data` FROM data_quality - WHERE `quality`='{quality}' AND {value} + WHERE `quality`=:quality_{param_name} AND {value_clause} ) - """ # noqa: S608 - `quality` is not user provided, value is filtered with regex + """ # noqa: S608 + + params["quality_instances"] = "NumberOfInstances" + params["quality_classes"] = "NumberOfClasses" + params["quality_features"] = "NumberOfFeatures" + params["quality_missing"] = "NumberOfMissingValues" + + number_instances_filter = quality_clause(number_instances, "instances") + number_classes_filter = quality_clause(number_classes, "classes") + number_features_filter = quality_clause(number_features, "features") + number_missing_values_filter = quality_clause( + number_missing_values, + "missing", + ) + + # Use bindparam with expanding=True for list parameters in IN clauses + bind_params = [bindparam("statuses", expanding=True)] + if "data_ids" in params: + bind_params.append(bindparam("data_ids", expanding=True)) - number_instances_filter = quality_clause("NumberOfInstances", number_instances) - number_classes_filter = quality_clause("NumberOfClasses", number_classes) - number_features_filter = quality_clause("NumberOfFeatures", number_features) - number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values) matching_filter = text( f""" SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`, @@ -182,23 +213,15 @@ def quality_clause(quality: str, range_: str | None) -> str: WHERE {visible_to_user} {where_name} {where_version} {where_uploader} {where_data_id} {matching_tag} {number_instances_filter} {number_features_filter} {number_classes_filter} {number_missing_values_filter} - AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status}) - LIMIT {pagination.limit} OFFSET {pagination.offset} + AND IFNULL(cs.`status`, 'in_preparation') IN :statuses + LIMIT :limit OFFSET :offset """, # noqa: S608 - # I am not sure how to do this correctly without an error from Bandit here. - # However, the `status` input is already checked by FastAPI to be from a set - # of given options, so no injection is possible (I think). The `current_status` - # subquery also has no user input. So I think this should be safe. - ) + ).bindparams(*bind_params) + columns = ["did", "name", "version", "format", "file_id", "status"] result = await expdb_db.execute( matching_filter, - parameters={ - "tag": tag, - "data_name": data_name, - "data_version": data_version, - "uploader": uploader, - }, + parameters=params, ) rows = result.all() datasets: dict[int, dict[str, Any]] = { diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 788cd804..9789ec1e 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -128,6 +128,13 @@ async def _fill_json_template( # noqa: C901 (field,) = match.groups() if field not in fetched_data: table, _ = field.split(".") + # List of tables allowed for [LOOKUP:table.column] directive. + # This is a security measure to prevent SQL injection via table names. + allowed_tables = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"} + if table not in allowed_tables: + msg = f"Table {table} is not allowed for lookup." + raise ValueError(msg) + result = await connection.execute( text( f""" @@ -136,8 +143,6 @@ async def _fill_json_template( # noqa: C901 WHERE `id` = :id_ """, # noqa: S608 ), - # Not sure how parametrize table names, as the parametrization adds - # quotes which is not legal. parameters={"id_": int(task_inputs[table])}, ) rows = result.mappings() From 83abbf5be0ee55c6efb2ed0b29bec802607eb2ff Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Fri, 20 Mar 2026 17:44:50 +0530 Subject: [PATCH 2/7] Architectural refactoring: Move list_datasets logic to database layer and optimize feature fetching --- src/database/datasets.py | 179 +++++++++++++++++++++++++++++++-- src/routers/openml/datasets.py | 127 +++++------------------ 2 files changed, 200 insertions(+), 106 deletions(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index 26eb33d8..fea08d45 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,12 +1,13 @@ -"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707.""" - import datetime -from collections import defaultdict +import re +from collections.abc import Sequence +from typing import Any, cast from sqlalchemy import text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection +from routers.types import integer_range_regex from schemas.datasets.openml import Feature @@ -149,9 +150,13 @@ async def get_feature_ontologies( ), parameters={"dataset_id": dataset_id}, ) - ontologies: dict[int, list[str]] = defaultdict(list) - for row in rows.mappings(): - ontologies[row["index"]].append(row["value"]) + ontologies: dict[int, list[str]] = {} + for mapping in rows.mappings(): + index = int(mapping["index"]) + value = str(mapping["value"]) + if index not in ontologies: + ontologies[index] = [] + ontologies[index].append(value) return ontologies @@ -175,6 +180,30 @@ async def get_feature_values( return [row.value for row in rows] +async def get_feature_values_bulk( + dataset_id: int, + connection: AsyncConnection, +) -> dict[int, list[str]]: + rows = await connection.execute( + text( + """ + SELECT `index`, `value` + FROM data_feature_value + WHERE `did` = :dataset_id + """, + ), + parameters={"dataset_id": dataset_id}, + ) + values: dict[int, list[str]] = {} + for mapping in rows.mappings(): + index = int(mapping["index"]) + value = str(mapping["value"]) + if index not in values: + values[index] = [] + values[index].append(value) + return values + + async def update_status( dataset_id: int, status: str, @@ -208,3 +237,141 @@ async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection ), parameters={"data": dataset_id}, ) + + +async def list_datasets( # noqa: C901, PLR0913 + *, + limit: int, + offset: int, + data_name: str | None = None, + data_version: str | None = None, + tag: str | None = None, + data_ids: list[int] | None = None, + uploader: int | None = None, + number_instances: str | None = None, + number_features: str | None = None, + number_classes: str | None = None, + number_missing_values: str | None = None, + statuses: list[str], + user_id: int | None = None, + is_admin: bool = False, + connection: AsyncConnection, +) -> Sequence[Row]: + current_status = """ + SELECT ds1.`did`, ds1.`status` + FROM dataset_status AS ds1 + WHERE ds1.`status_date`=( + SELECT MAX(ds2.`status_date`) + FROM dataset_status as ds2 + WHERE ds1.`did`=ds2.`did` + ) + """ + + if is_admin: + visible_to_user = "TRUE" + elif user_id: + visible_to_user = f"(`visibility`='public' OR `uploader`={user_id})" + else: + visible_to_user = "`visibility`='public'" + + where_name = "AND `name`=:data_name" if data_name else "" + where_version = "AND `version`=:data_version" if data_version else "" + where_uploader = "AND `uploader`=:uploader" if uploader else "" + where_data_id = "AND d.`did` IN :data_ids" if data_ids else "" + + matching_tag = ( + """ + AND d.`did` IN ( + SELECT `id` + FROM dataset_tag as dt + WHERE dt.`tag`=:tag + ) + """ + if tag + else "" + ) + + def quality_clause(quality: str, range_str: str | None, param_name: str) -> str: + if not range_str: + return "" + if not (match := re.match(integer_range_regex, range_str)): + msg = f"`range_str` not a valid range: {range_str}" + raise ValueError(msg) + _start, end = match.groups() + if end: + # end is e.g. "..150" + value = f"`value` BETWEEN :{param_name}_start AND :{param_name}_end" + else: + value = f"`value` = :{param_name}_start" + + return f""" AND + d.`did` IN ( + SELECT `data` + FROM data_quality + WHERE `quality`='{quality}' AND {value} + ) + """ # noqa: S608 + + q_params = {} + + def get_range_params(range_str: str | None, param_prefix: str) -> dict[str, Any]: + if not range_str: + return {} + if not (match := re.match(integer_range_regex, range_str)): + return {} + _start, end = match.groups() + params: dict[str, Any] = {f"{param_prefix}_start": _start} + if end: + # end is e.g. "..150" + end_val = str(end) + params[f"{param_prefix}_end"] = end_val[2:] + return params + + instances_filter = quality_clause("NumberOfInstances", number_instances, "instances") + q_params.update(get_range_params(number_instances, "instances")) + + features_filter = quality_clause("NumberOfFeatures", number_features, "features") + q_params.update(get_range_params(number_features, "features")) + + classes_filter = quality_clause("NumberOfClasses", number_classes, "classes") + q_params.update(get_range_params(number_classes, "classes")) + + missing_values_filter = quality_clause( + "NumberOfMissingValues", + number_missing_values, + "missing_values", + ) + q_params.update(get_range_params(number_missing_values, "missing_values")) + + sql = text( + f""" + SELECT d.`did`, d.`name`, d.`version`, d.`format`, d.`file_id`, + IFNULL(cs.`status`, 'in_preparation') AS status + FROM dataset AS d + LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did` + WHERE {visible_to_user} {where_name} {where_version} {where_uploader} + {where_data_id} {matching_tag} {instances_filter} {features_filter} + {classes_filter} {missing_values_filter} + AND IFNULL(cs.`status`, 'in_preparation') IN :statuses + LIMIT :limit OFFSET :offset + """, # noqa: S608 + ) + + parameters = { + "data_name": data_name, + "data_version": data_version, + "uploader": uploader, + "tag": tag, + "statuses": statuses, + "limit": limit, + "offset": offset, + **q_params, + } + if data_ids: + parameters["data_ids"] = data_ids + + result = await connection.execute( + sql.bindparams(statuses=statuses, data_ids=data_ids) if data_ids else sql, + parameters=parameters, + ) + return cast("Sequence[Row]", result.all()) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 164efd7d..e047c87b 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -1,10 +1,8 @@ -import re from datetime import datetime from enum import StrEnum from typing import Annotated, Any, Literal, NamedTuple from fastapi import APIRouter, Body, Depends -from sqlalchemy import text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -39,7 +37,7 @@ fetch_user_or_raise, userdb_connection, ) -from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex +from routers.types import CasualString128, IntegerRange, SystemString64 from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -103,104 +101,35 @@ async def list_datasets( # noqa: PLR0913 expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: assert expdb_db is not None # noqa: S101 - current_status = text( - """ - SELECT ds1.`did`, ds1.`status` - FROM dataset_status as ds1 - WHERE ds1.`status_date`=( - SELECT MAX(ds2.`status_date`) - FROM dataset_status as ds2 - WHERE ds1.`did`=ds2.`did` - ) - """, - ) - + statuses: list[str] if status == DatasetStatusFilter.ALL: statuses = [ - DatasetStatusFilter.ACTIVE, - DatasetStatusFilter.DEACTIVATED, - DatasetStatusFilter.IN_PREPARATION, + DatasetStatus.ACTIVE.value, + DatasetStatus.DEACTIVATED.value, + DatasetStatus.IN_PREPARATION.value, ] else: - statuses = [status] - - where_status = ",".join(f"'{status}'" for status in statuses) - if user is None: - visible_to_user = "`visibility`='public'" - elif UserGroup.ADMIN in await user.get_groups(): - visible_to_user = "TRUE" - else: - visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" - - where_name = "" if data_name is None else "AND `name`=:data_name" - where_version = "" if data_version is None else "AND `version`=:data_version" - where_uploader = "" if uploader is None else "AND `uploader`=:uploader" - data_id_str = ",".join(str(did) for did in data_id) if data_id else "" - where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" - - # requires some benchmarking on whether e.g., IN () is more efficient. - matching_tag = ( - text( - """ - AND d.`did` IN ( - SELECT `id` - FROM dataset_tag as dt - WHERE dt.`tag`=:tag - ) - """, - ) - if tag - else "" + statuses = [str(status.value)] + + rows = await database.datasets.list_datasets( + limit=pagination.limit, + offset=pagination.offset, + data_name=data_name, + data_version=str(data_version) if data_version else None, + tag=tag, + data_ids=data_id, + uploader=uploader, + number_instances=number_instances, + number_features=number_features, + number_classes=number_classes, + number_missing_values=number_missing_values, + statuses=statuses, + user_id=user.user_id if user else None, + is_admin=UserGroup.ADMIN in await user.get_groups() if user else False, + connection=expdb_db, ) - def quality_clause(quality: str, range_: str | None) -> str: - if not range_: - return "" - if not (match := re.match(integer_range_regex, range_)): - msg = f"`range_` not a valid range: {range_}" - raise ValueError(msg) - start, end = match.groups() - value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" - return f""" AND - d.`did` IN ( - SELECT `data` - FROM data_quality - WHERE `quality`='{quality}' AND {value} - ) - """ # noqa: S608 - `quality` is not user provided, value is filtered with regex - - number_instances_filter = quality_clause("NumberOfInstances", number_instances) - number_classes_filter = quality_clause("NumberOfClasses", number_classes) - number_features_filter = quality_clause("NumberOfFeatures", number_features) - number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values) - matching_filter = text( - f""" - SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`, - IFNULL(cs.`status`, 'in_preparation') - FROM dataset AS d - LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did` - WHERE {visible_to_user} {where_name} {where_version} {where_uploader} - {where_data_id} {matching_tag} {number_instances_filter} {number_features_filter} - {number_classes_filter} {number_missing_values_filter} - AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status}) - LIMIT {pagination.limit} OFFSET {pagination.offset} - """, # noqa: S608 - # I am not sure how to do this correctly without an error from Bandit here. - # However, the `status` input is already checked by FastAPI to be from a set - # of given options, so no injection is possible (I think). The `current_status` - # subquery also has no user input. So I think this should be safe. - ) columns = ["did", "name", "version", "format", "file_id", "status"] - result = await expdb_db.execute( - matching_filter, - parameters={ - "tag": tag, - "data_name": data_name, - "data_version": data_version, - "uploader": uploader, - }, - ) - rows = result.all() datasets: dict[int, dict[str, Any]] = { row.did: dict(zip(columns, row, strict=True)) for row in rows } @@ -298,12 +227,10 @@ async def get_dataset_features( for feature in features: feature.ontology = ontologies.get(feature.index) - for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]: - feature.nominal_values = await database.datasets.get_feature_values( - dataset_id, - feature_index=feature.index, - connection=expdb, - ) + nominal_values = await database.datasets.get_feature_values_bulk(dataset_id, expdb) + for feature in features: + if feature.data_type == FeatureType.NOMINAL: + feature.nominal_values = nominal_values.get(feature.index, []) if not features: processing_state = await database.datasets.get_latest_processing_update(dataset_id, expdb) From 193a0dd3435fe7ca582a81d0230a5e2aaa077987 Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Fri, 20 Mar 2026 17:51:50 +0530 Subject: [PATCH 3/7] Address review feedback: Hoist allowed_tables, refactor quality_clause to pure function, and simplify bindparams --- src/routers/openml/datasets.py | 68 ++++++++++++++++++++++------------ src/routers/openml/tasks.py | 8 ++-- 2 files changed, 48 insertions(+), 28 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 8d9193ad..9e8aab89 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -163,46 +163,63 @@ async def list_datasets( # noqa: C901, PLR0913, PLR0915 else "" ) - def quality_clause(range_: str | None, param_name: str) -> str: + def quality_clause( + quality_name: str, + range_: str | None, + *, + prefix: str, + ) -> tuple[str, dict[str, Any]]: if not range_: - return "" + return "", {} if not (match := re.match(integer_range_regex, range_)): msg = f"`range_` not a valid range: {range_}" raise ValueError(msg) start, end = match.groups() + clause_params: dict[str, Any] = {f"quality_name_{prefix}": quality_name} if end: - params[f"{param_name}_start"] = int(start) - params[f"{param_name}_end"] = int(end[2:]) - value_clause = f"`value` BETWEEN :{param_name}_start AND :{param_name}_end" + clause_params[f"{prefix}_start"] = int(start) + clause_params[f"{prefix}_end"] = int(end[2:]) + value_clause = f"`value` BETWEEN :{prefix}_start AND :{prefix}_end" else: - params[f"{param_name}_val"] = int(start) - value_clause = f"`value` = :{param_name}_val" + clause_params[f"{prefix}_val"] = int(start) + value_clause = f"`value` = :{prefix}_val" - return f""" AND + sql = f""" AND d.`did` IN ( SELECT `data` FROM data_quality - WHERE `quality`=:quality_{param_name} AND {value_clause} + WHERE `quality`=:quality_name_{prefix} AND {value_clause} ) """ # noqa: S608 + return sql, clause_params - params["quality_instances"] = "NumberOfInstances" - params["quality_classes"] = "NumberOfClasses" - params["quality_features"] = "NumberOfFeatures" - params["quality_missing"] = "NumberOfMissingValues" + number_instances_filter, instances_params = quality_clause( + "NumberOfInstances", + number_instances, + prefix="instances", + ) + params.update(instances_params) - number_instances_filter = quality_clause(number_instances, "instances") - number_classes_filter = quality_clause(number_classes, "classes") - number_features_filter = quality_clause(number_features, "features") - number_missing_values_filter = quality_clause( - number_missing_values, - "missing", + number_classes_filter, classes_params = quality_clause( + "NumberOfClasses", + number_classes, + prefix="classes", ) + params.update(classes_params) - # Use bindparam with expanding=True for list parameters in IN clauses - bind_params = [bindparam("statuses", expanding=True)] - if "data_ids" in params: - bind_params.append(bindparam("data_ids", expanding=True)) + number_features_filter, features_params = quality_clause( + "NumberOfFeatures", + number_features, + prefix="features", + ) + params.update(features_params) + + number_missing_values_filter, missing_params = quality_clause( + "NumberOfMissingValues", + number_missing_values, + prefix="missing", + ) + params.update(missing_params) matching_filter = text( f""" @@ -216,7 +233,10 @@ def quality_clause(range_: str | None, param_name: str) -> str: AND IFNULL(cs.`status`, 'in_preparation') IN :statuses LIMIT :limit OFFSET :offset """, # noqa: S608 - ).bindparams(*bind_params) + ).bindparams( + bindparam("statuses", expanding=True), + bindparam("data_ids", expanding=True) if "data_ids" in params else bindparam("data_ids"), + ) columns = ["did", "name", "version", "format", "file_id", "status"] result = await expdb_db.execute( diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 9789ec1e..1157aec9 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -3,7 +3,7 @@ from typing import Annotated, cast import xmltodict -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection @@ -17,6 +17,7 @@ router = APIRouter(prefix="/tasks", tags=["tasks"]) type JSON = dict[str, "JSON"] | list["JSON"] | str | int | float | bool | None +ALLOWED_LOOKUP_TABLES = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"} def convert_template_xml_to_json(xml_template: str) -> dict[str, JSON]: @@ -130,10 +131,9 @@ async def _fill_json_template( # noqa: C901 table, _ = field.split(".") # List of tables allowed for [LOOKUP:table.column] directive. # This is a security measure to prevent SQL injection via table names. - allowed_tables = {"estimation_procedure", "evaluation_measure", "task_type", "dataset"} - if table not in allowed_tables: + if table not in ALLOWED_LOOKUP_TABLES: msg = f"Table {table} is not allowed for lookup." - raise ValueError(msg) + raise HTTPException(status_code=400, detail=msg) result = await connection.execute( text( From 567f15c95d9be3e38b941a8a2607e7f3828e482d Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Fri, 20 Mar 2026 18:11:58 +0530 Subject: [PATCH 4/7] Fix SQLAlchemy ArgumentError: conditionally bind data_ids --- src/routers/openml/datasets.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 9e8aab89..1764ae23 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -221,6 +221,11 @@ def quality_clause( ) params.update(missing_params) + # Use bindparam with expanding=True for list parameters in IN clauses + bind_params = [bindparam("statuses", expanding=True)] + if "data_ids" in params: + bind_params.append(bindparam("data_ids", expanding=True)) + matching_filter = text( f""" SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`, @@ -233,10 +238,7 @@ def quality_clause( AND IFNULL(cs.`status`, 'in_preparation') IN :statuses LIMIT :limit OFFSET :offset """, # noqa: S608 - ).bindparams( - bindparam("statuses", expanding=True), - bindparam("data_ids", expanding=True) if "data_ids" in params else bindparam("data_ids"), - ) + ).bindparams(*bind_params) columns = ["did", "name", "version", "format", "file_id", "status"] result = await expdb_db.execute( From 7d8d391ae7b8e74807dce1097c30b4cb70a59535 Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Sat, 21 Mar 2026 17:07:08 +0530 Subject: [PATCH 5/7] Architectural refactoring: Move task template lookup logic to database layer --- src/database/tasks.py | 17 ++++++++++++++++- src/routers/openml/tasks.py | 17 +++++------------ 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/database/tasks.py b/src/database/tasks.py index e9670d26..009d358f 100644 --- a/src/database/tasks.py +++ b/src/database/tasks.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import cast -from sqlalchemy import Row, text +from sqlalchemy import Row, RowMapping, text from sqlalchemy.ext.asyncio import AsyncConnection @@ -115,3 +115,18 @@ async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]: ) tag_rows = rows.all() return [row.tag for row in tag_rows] + + +async def get_lookup_data(table: str, id_: int, expdb: AsyncConnection) -> RowMapping | None: + # table is already whitelisted in the router + result = await expdb.execute( + text( + f""" + SELECT * + FROM {table} + WHERE `id` = :id_ + """, # noqa: S608 + ), + parameters={"id_": id_}, + ) + return result.mappings().one_or_none() diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index 1157aec9..b7c702da 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -4,7 +4,7 @@ import xmltodict from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import RowMapping, text +from sqlalchemy import RowMapping from sqlalchemy.ext.asyncio import AsyncConnection import config @@ -135,18 +135,11 @@ async def _fill_json_template( # noqa: C901 msg = f"Table {table} is not allowed for lookup." raise HTTPException(status_code=400, detail=msg) - result = await connection.execute( - text( - f""" - SELECT * - FROM {table} - WHERE `id` = :id_ - """, # noqa: S608 - ), - parameters={"id_": int(task_inputs[table])}, + row_data = await database.tasks.get_lookup_data( + table=table, + id_=int(task_inputs[table]), + expdb=connection, ) - rows = result.mappings() - row_data = next(rows, None) if row_data is None: msg = f"No data found for table {table} with id {task_inputs[table]}" raise ValueError(msg) From db9b7fbf4f514b82ef52cdf2a089dff01119e0cf Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Sun, 22 Mar 2026 19:45:59 +0530 Subject: [PATCH 6/7] Address final CodeRabbit review: fix list_datasets signature, expansion, and error handling --- src/database/datasets.py | 8 ++++++-- src/routers/openml/datasets.py | 24 +++++++++++++++++++----- src/routers/openml/tasks.py | 2 +- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index 7a666ec0..c495b711 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Any, cast -from sqlalchemy import text +from sqlalchemy import bindparam, text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -359,6 +359,10 @@ async def list_datasets( # noqa: PLR0913 """, # noqa: S608 ) + sql = sql.bindparams(bindparam("statuses", expanding=True)) + if data_ids: + sql = sql.bindparams(bindparam("data_ids", expanding=True)) + parameters = { "data_name": data_name, "data_version": data_version, @@ -373,7 +377,7 @@ async def list_datasets( # noqa: PLR0913 parameters["data_ids"] = data_ids result = await connection.execute( - sql.bindparams(statuses=statuses, data_ids=data_ids) if data_ids else sql, + sql, parameters=parameters, ) return cast("Sequence[Row]", result.all()) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 292a7612..169cd4fb 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -101,20 +101,34 @@ async def list_datasets( # noqa: PLR0913 expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: assert expdb_db is not None # noqa: S101 + if status == DatasetStatusFilter.ALL: + statuses = [ + DatasetStatusFilter.ACTIVE, + DatasetStatusFilter.DEACTIVATED, + DatasetStatusFilter.IN_PREPARATION, + ] + else: + statuses = [status] + + user_id = user.user_id if user else None + is_admin = UserGroup.ADMIN in await user.get_groups() if user else False + try: rows = await database.datasets.list_datasets( - pagination=pagination, + limit=pagination.limit, + offset=pagination.offset, data_name=data_name, tag=tag, - data_version=data_version, + data_version=str(data_version) if data_version is not None else None, uploader=uploader, - data_id=data_id, + data_ids=data_id, number_instances=number_instances, number_features=number_features, number_classes=number_classes, number_missing_values=number_missing_values, - status=status, - user=user, + statuses=[s.value for s in statuses], + user_id=user_id, + is_admin=is_admin, connection=expdb_db, ) except ValueError as e: diff --git a/src/routers/openml/tasks.py b/src/routers/openml/tasks.py index e577bda3..c02c8da6 100644 --- a/src/routers/openml/tasks.py +++ b/src/routers/openml/tasks.py @@ -151,7 +151,7 @@ async def _fill_json_template( # noqa: C901, PLR0912 raise HTTPException(status_code=400, detail=str(e)) from e if row_data is None: msg = f"No data found for table {table} with id {task_inputs[table]}" - raise ValueError(msg) + raise HTTPException(status_code=400, detail=msg) for column, value in row_data.items(): fetched_data[f"{table}.{column}"] = value if match.string == template: From 778faa8698da2ce576c1ce8c3b093a406e32d26b Mon Sep 17 00:00:00 2001 From: MOHITKOURAV01 Date: Tue, 24 Mar 2026 01:12:41 +0530 Subject: [PATCH 7/7] Revert dataset list sanitation and keep lookup table checks as requested by maintainer --- src/database/datasets.py | 187 ++------------------------------- src/routers/openml/datasets.py | 126 +++++++++++++++++----- 2 files changed, 104 insertions(+), 209 deletions(-) diff --git a/src/database/datasets.py b/src/database/datasets.py index c495b711..26eb33d8 100644 --- a/src/database/datasets.py +++ b/src/database/datasets.py @@ -1,13 +1,12 @@ +"""Translation from https://github.com/openml/OpenML/blob/c19c9b99568c0fabb001e639ff6724b9a754bbc9/openml_OS/models/api/v1/Api_data.php#L707.""" + import datetime -import re -from collections.abc import Sequence -from typing import Any, cast +from collections import defaultdict -from sqlalchemy import bindparam, text +from sqlalchemy import text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection -from routers.types import integer_range_regex from schemas.datasets.openml import Feature @@ -150,13 +149,9 @@ async def get_feature_ontologies( ), parameters={"dataset_id": dataset_id}, ) - ontologies: dict[int, list[str]] = {} - for mapping in rows.mappings(): - index = int(mapping["index"]) - value = str(mapping["value"]) - if index not in ontologies: - ontologies[index] = [] - ontologies[index].append(value) + ontologies: dict[int, list[str]] = defaultdict(list) + for row in rows.mappings(): + ontologies[row["index"]].append(row["value"]) return ontologies @@ -180,30 +175,6 @@ async def get_feature_values( return [row.value for row in rows] -async def get_feature_values_bulk( - dataset_id: int, - connection: AsyncConnection, -) -> dict[int, list[str]]: - rows = await connection.execute( - text( - """ - SELECT `index`, `value` - FROM data_feature_value - WHERE `did` = :dataset_id - """, - ), - parameters={"dataset_id": dataset_id}, - ) - values: dict[int, list[str]] = {} - for mapping in rows.mappings(): - index = int(mapping["index"]) - value = str(mapping["value"]) - if index not in values: - values[index] = [] - values[index].append(value) - return values - - async def update_status( dataset_id: int, status: str, @@ -237,147 +208,3 @@ async def remove_deactivated_status(dataset_id: int, connection: AsyncConnection ), parameters={"data": dataset_id}, ) - - -def _get_quality_filter(quality: str, range_str: str | None, param_name: str) -> str: - if not range_str: - return "" - if not (match := re.match(integer_range_regex, range_str)): - msg = f"Invalid range format for {quality}: {range_str}" - raise ValueError(msg) - _start, end = match.groups() - if end: - value = f"`value` BETWEEN :{param_name}_start AND :{param_name}_end" - else: - value = f"`value` = :{param_name}_start" - - return f""" AND - d.`did` IN ( - SELECT `data` - FROM data_quality - WHERE `quality` = :quality_name_{param_name} AND {value} - ) - """ # noqa: S608 - - -def _get_range_params( - quality_name: str, - range_str: str | None, - param_prefix: str, -) -> dict[str, Any]: - if not range_str: - return {} - if not (match := re.match(integer_range_regex, range_str)): - return {} - start, end = match.groups() - params: dict[str, Any] = { - f"quality_name_{param_prefix}": quality_name, - f"{param_prefix}_start": int(start), - } - if end: - params[f"{param_prefix}_end"] = int(end[2:]) - return params - - -async def list_datasets( # noqa: PLR0913 - *, - limit: int, - offset: int, - data_name: str | None = None, - data_version: str | None = None, - tag: str | None = None, - data_ids: list[int] | None = None, - uploader: int | None = None, - number_instances: str | None = None, - number_features: str | None = None, - number_classes: str | None = None, - number_missing_values: str | None = None, - statuses: list[str], - user_id: int | None = None, - is_admin: bool = False, - connection: AsyncConnection, -) -> Sequence[Row]: - current_status = """ - SELECT ds1.`did`, ds1.`status` - FROM dataset_status AS ds1 - WHERE ds1.`status_date`=( - SELECT MAX(ds2.`status_date`) - FROM dataset_status as ds2 - WHERE ds1.`did`=ds2.`did` - ) - """ - - if is_admin: - visible_to_user = "TRUE" - elif user_id: - visible_to_user = f"(`visibility`='public' OR `uploader`={user_id})" - else: - visible_to_user = "`visibility`='public'" - - where_name = "AND `name`=:data_name" if data_name else "" - where_version = "AND `version`=:data_version" if data_version else "" - where_uploader = "AND `uploader`=:uploader" if uploader else "" - where_data_id = "AND d.`did` IN :data_ids" if data_ids else "" - - matching_tag = ( - """ - AND d.`did` IN ( - SELECT `id` - FROM dataset_tag as dt - WHERE dt.`tag`=:tag - ) - """ - if tag - else "" - ) - - q_params: dict[str, Any] = {} - q_filters: list[str] = [] - - for quality, range_str, prefix in [ - ("NumberOfInstances", number_instances, "instances"), - ("NumberOfFeatures", number_features, "features"), - ("NumberOfClasses", number_classes, "classes"), - ("NumberOfMissingValues", number_missing_values, "missing_vals"), - ]: - q_filters.append(_get_quality_filter(quality, range_str, prefix)) - q_params.update(_get_range_params(quality, range_str, prefix)) - - instances_filter, features_filter, classes_filter, missing_values_filter = q_filters - - sql = text( - f""" - SELECT d.`did`, d.`name`, d.`version`, d.`format`, d.`file_id`, - IFNULL(cs.`status`, 'in_preparation') AS status - FROM dataset AS d - LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did` - WHERE {visible_to_user} {where_name} {where_version} {where_uploader} - {where_data_id} {matching_tag} {instances_filter} {features_filter} - {classes_filter} {missing_values_filter} - AND IFNULL(cs.`status`, 'in_preparation') IN :statuses - LIMIT :limit OFFSET :offset - """, # noqa: S608 - ) - - sql = sql.bindparams(bindparam("statuses", expanding=True)) - if data_ids: - sql = sql.bindparams(bindparam("data_ids", expanding=True)) - - parameters = { - "data_name": data_name, - "data_version": data_version, - "uploader": uploader, - "tag": tag, - "statuses": statuses, - "limit": limit, - "offset": offset, - **q_params, - } - if data_ids: - parameters["data_ids"] = data_ids - - result = await connection.execute( - sql, - parameters=parameters, - ) - return cast("Sequence[Row]", result.all()) diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index 169cd4fb..164efd7d 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -1,8 +1,10 @@ +import re from datetime import datetime from enum import StrEnum from typing import Annotated, Any, Literal, NamedTuple -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, Body, Depends +from sqlalchemy import text from sqlalchemy.engine import Row from sqlalchemy.ext.asyncio import AsyncConnection @@ -37,7 +39,7 @@ fetch_user_or_raise, userdb_connection, ) -from routers.types import CasualString128, IntegerRange, SystemString64 +from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex from schemas.datasets.openml import DatasetMetadata, DatasetStatus, Feature, FeatureType router = APIRouter(prefix="/datasets", tags=["datasets"]) @@ -101,6 +103,18 @@ async def list_datasets( # noqa: PLR0913 expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> list[dict[str, Any]]: assert expdb_db is not None # noqa: S101 + current_status = text( + """ + SELECT ds1.`did`, ds1.`status` + FROM dataset_status as ds1 + WHERE ds1.`status_date`=( + SELECT MAX(ds2.`status_date`) + FROM dataset_status as ds2 + WHERE ds1.`did`=ds2.`did` + ) + """, + ) + if status == DatasetStatusFilter.ALL: statuses = [ DatasetStatusFilter.ACTIVE, @@ -110,33 +124,85 @@ async def list_datasets( # noqa: PLR0913 else: statuses = [status] - user_id = user.user_id if user else None - is_admin = UserGroup.ADMIN in await user.get_groups() if user else False - - try: - rows = await database.datasets.list_datasets( - limit=pagination.limit, - offset=pagination.offset, - data_name=data_name, - tag=tag, - data_version=str(data_version) if data_version is not None else None, - uploader=uploader, - data_ids=data_id, - number_instances=number_instances, - number_features=number_features, - number_classes=number_classes, - number_missing_values=number_missing_values, - statuses=[s.value for s in statuses], - user_id=user_id, - is_admin=is_admin, - connection=expdb_db, + where_status = ",".join(f"'{status}'" for status in statuses) + if user is None: + visible_to_user = "`visibility`='public'" + elif UserGroup.ADMIN in await user.get_groups(): + visible_to_user = "TRUE" + else: + visible_to_user = f"(`visibility`='public' OR `uploader`={user.user_id})" + + where_name = "" if data_name is None else "AND `name`=:data_name" + where_version = "" if data_version is None else "AND `version`=:data_version" + where_uploader = "" if uploader is None else "AND `uploader`=:uploader" + data_id_str = ",".join(str(did) for did in data_id) if data_id else "" + where_data_id = "" if not data_id else f"AND d.`did` IN ({data_id_str})" + + # requires some benchmarking on whether e.g., IN () is more efficient. + matching_tag = ( + text( + """ + AND d.`did` IN ( + SELECT `id` + FROM dataset_tag as dt + WHERE dt.`tag`=:tag + ) + """, ) - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) from e + if tag + else "" + ) + def quality_clause(quality: str, range_: str | None) -> str: + if not range_: + return "" + if not (match := re.match(integer_range_regex, range_)): + msg = f"`range_` not a valid range: {range_}" + raise ValueError(msg) + start, end = match.groups() + value = f"`value` BETWEEN {start} AND {end[2:]}" if end else f"`value`={start}" + return f""" AND + d.`did` IN ( + SELECT `data` + FROM data_quality + WHERE `quality`='{quality}' AND {value} + ) + """ # noqa: S608 - `quality` is not user provided, value is filtered with regex + + number_instances_filter = quality_clause("NumberOfInstances", number_instances) + number_classes_filter = quality_clause("NumberOfClasses", number_classes) + number_features_filter = quality_clause("NumberOfFeatures", number_features) + number_missing_values_filter = quality_clause("NumberOfMissingValues", number_missing_values) + matching_filter = text( + f""" + SELECT d.`did`,d.`name`,d.`version`,d.`format`,d.`file_id`, + IFNULL(cs.`status`, 'in_preparation') + FROM dataset AS d + LEFT JOIN ({current_status}) AS cs ON d.`did`=cs.`did` + WHERE {visible_to_user} {where_name} {where_version} {where_uploader} + {where_data_id} {matching_tag} {number_instances_filter} {number_features_filter} + {number_classes_filter} {number_missing_values_filter} + AND IFNULL(cs.`status`, 'in_preparation') IN ({where_status}) + LIMIT {pagination.limit} OFFSET {pagination.offset} + """, # noqa: S608 + # I am not sure how to do this correctly without an error from Bandit here. + # However, the `status` input is already checked by FastAPI to be from a set + # of given options, so no injection is possible (I think). The `current_status` + # subquery also has no user input. So I think this should be safe. + ) columns = ["did", "name", "version", "format", "file_id", "status"] + result = await expdb_db.execute( + matching_filter, + parameters={ + "tag": tag, + "data_name": data_name, + "data_version": data_version, + "uploader": uploader, + }, + ) + rows = result.all() datasets: dict[int, dict[str, Any]] = { - row.did: dict(zip(columns, row, strict=False)) for row in rows + row.did: dict(zip(columns, row, strict=True)) for row in rows } if not datasets: msg = "No datasets match the search criteria." @@ -232,10 +298,12 @@ async def get_dataset_features( for feature in features: feature.ontology = ontologies.get(feature.index) - nominal_values = await database.datasets.get_feature_values_bulk(dataset_id, expdb) - for feature in features: - if feature.data_type == FeatureType.NOMINAL: - feature.nominal_values = nominal_values.get(feature.index, []) + for feature in [f for f in features if f.data_type == FeatureType.NOMINAL]: + feature.nominal_values = await database.datasets.get_feature_values( + dataset_id, + feature_index=feature.index, + connection=expdb, + ) if not features: processing_state = await database.datasets.get_latest_processing_update(dataset_id, expdb)