Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions src/core/tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from collections.abc import Awaitable, Callable
from typing import Any

from sqlalchemy import Row
from sqlalchemy.ext.asyncio import AsyncConnection

from core.errors import TagAlreadyExistsError, TagNotFoundError, TagNotOwnedError
from database.users import User, UserGroup


async def tag_entity(
entity_id: int,
tag: str,
user: User,
expdb: AsyncConnection,
*,
get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
tag_fn: Callable[..., Awaitable[None]],
response_key: str,
) -> dict[str, dict[str, Any]]:
tags = await get_tags_fn(entity_id, expdb)
if tag.casefold() in (t.casefold() for t in tags):
msg = f"Entity {entity_id} already tagged with {tag!r}."
raise TagAlreadyExistsError(msg)
await tag_fn(entity_id, tag, user_id=user.user_id, expdb=expdb)
tags = await get_tags_fn(entity_id, expdb)
return {response_key: {"id": str(entity_id), "tag": tags}}
Comment on lines +26 to +27
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to just return [*tags, tag] and forgo making a database request to fetch all tags again. It will be more efficient. In either case the chance that there are concurrent requests that result in returning an outdated set of tags exists, but that's ok.



async def untag_entity(
entity_id: int,
tag: str,
user: User,
expdb: AsyncConnection,
*,
get_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[Row | None]],
delete_tag_fn: Callable[[int, str, AsyncConnection], Awaitable[None]],
get_tags_fn: Callable[[int, AsyncConnection], Awaitable[list[str]]],
response_key: str,
) -> dict[str, dict[str, Any]]:
existing = await get_tag_fn(entity_id, tag, expdb)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better for now to forgo the get_tag_fn and just retrieve all tags:

  • It's not obvious getting the single tag first is more efficient. In the happy path a second database trip is needed to fetch the remainder of the tags.
  • Not requiring the function simplifies the function signature.
  • The PHP API currently actually compares in a case-insensitive way. I'm not sure from the top of my head if the current database schema facilitates this in which case the behavior would be preserved implicitly, but it's in any case good if this behavior doesn't suddenly change if we change the column type in the database (with some other collation).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this would also result in the get_tag (singular) functions becoming dead code. If correct, then those can be removed from the PR. We can easily re-introduce them if we do have a need for them, but I suspect there won't be.

if existing is None:
msg = f"Tag {tag!r} not found on entity {entity_id}."
raise TagNotFoundError(msg)
groups = await user.get_groups()
if existing.uploader != user.user_id and UserGroup.ADMIN not in groups:
msg = f"Tag {tag!r} on entity {entity_id} is not owned by you."
raise TagNotOwnedError(msg)
await delete_tag_fn(entity_id, tag, expdb)
tags = await get_tags_fn(entity_id, expdb)
return {response_key: {"id": str(entity_id), "tag": tags}}
37 changes: 25 additions & 12 deletions src/database/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "implementation_tag"
_ID_COLUMN = "id"
Comment on lines +9 to +10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_TABLE = "implementation_tag"
_ID_COLUMN = "id"
_TAG_TABLE = "implementation_tag"
_TAG_TABLE_ID_COLUMN = "id"

Would be better because it's clear what kind of a table or column they refer to.
Same remark of course for other files of other entities, and their usages need to be updated accordingly.



async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:
rows = await expdb.execute(
Expand All @@ -23,18 +28,7 @@ async def get_subflows(for_flow: int, expdb: AsyncConnection) -> Sequence[Row]:


async def get_tags(flow_id: int, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
SELECT tag
FROM implementation_tag
WHERE id = :flow_id
""",
),
parameters={"flow_id": flow_id},
)
tag_rows = rows.all()
return [tag.tag for tag in tag_rows]
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=flow_id, expdb=expdb)


async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
Expand All @@ -54,6 +48,25 @@ async def get_parameters(flow_id: int, expdb: AsyncConnection) -> Sequence[Row]:
)


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
await insert_tag(
table=_TABLE,
id_column=_ID_COLUMN,
id_=id_,
tag_=tag_,
user_id=user_id,
expdb=expdb,
)


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)


async def get_by_name(name: str, external_version: str, expdb: AsyncConnection) -> Row | None:
"""Get flow by name and external version."""
row = await expdb.execute(
Expand Down
44 changes: 44 additions & 0 deletions src/database/runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "run_tag"
_ID_COLUMN = "id"


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
text(
"""
SELECT *
FROM run
WHERE `id` = :run_id
""",
),
parameters={"run_id": id_},
)
return row.one_or_none()


async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb)


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
await insert_tag(
table=_TABLE,
id_column=_ID_COLUMN,
id_=id_,
tag_=tag_,
user_id=user_id,
expdb=expdb,
)


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
82 changes: 82 additions & 0 deletions src/database/tagging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection


async def insert_tag(
*,
table: str,
id_column: str,
Comment on lines +7 to +8
Copy link
Contributor

@coderabbitai coderabbitai bot Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Constrain table and id_column before interpolating them into shared SQL.

These helpers are now the reusable tagging boundary, but they accept arbitrary identifier strings and splice them straight into every query. The current callers pass constants, yet one future non-constant call turns this module into an injection sink. Please validate against a closed set of supported identifier combinations here, instead of relying on every caller to stay disciplined.

Also applies to: 16-19, 35-39, 56-59, 74-78

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/database/tagging.py` around lines 7 - 8, The helpers in this module
accept the identifier parameters table and id_column and interpolate them
directly into shared SQL; validate these inputs against a closed allowlist
before any string interpolation by adding an allowlist mapping (e.g.,
ALLOWED_TABLES = {"tags": {"id"}, "items": {"item_id"}}) and checking that table
is a key and id_column is one of its allowed columns, raising ValueError on
mismatch; perform this check at the start of the public helper functions that
accept table/id_column so no query is constructed for invalid values, and when
you must embed identifiers into SQL use a safe identifier mechanism (e.g.,
psycopg2.sql.Identifier or the DB driver's proper identifier-quoting helper)
rather than naive string concatenation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is a good comment. id_column should always be id or study_id, tables should be setup_tag,dataset_tag,task_tag,implementation_tag, study_tag. You can validate this in a common function.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
cat src/database/tagging.py

Length of output: 1700


@PGijsbers, thanks for confirming the allowed values. Here is a proposed patch that adds a validation helper and calls it in all four functions. Note that run_tag (used by the new src/database/runs.py) is not in your list — I've included it below but please confirm whether it belongs:

from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

# Allowlist of valid (table → allowed id_columns) combinations.
_ALLOWED_IDENTIFIERS: dict[str, frozenset[str]] = {
    "dataset_tag": frozenset({"id"}),
    "implementation_tag": frozenset({"id"}),
    "run_tag": frozenset({"id"}),       # used by database/runs.py — confirm if needed
    "setup_tag": frozenset({"id"}),
    "study_tag": frozenset({"study_id"}),
    "task_tag": frozenset({"id"}),
}


def _validate_identifiers(table: str, id_column: str) -> None:
    allowed_columns = _ALLOWED_IDENTIFIERS.get(table)
    if allowed_columns is None:
        raise ValueError(f"Unknown tag table: {table!r}")
    if id_column not in allowed_columns:
        raise ValueError(
            f"Unknown id_column {id_column!r} for table {table!r}. "
            f"Allowed: {sorted(allowed_columns)}"
        )


async def insert_tag(
    *,
    table: str,
    id_column: str,
    id_: int,
    tag_: str,
    user_id: int,
    expdb: AsyncConnection,
) -> None:
    _validate_identifiers(table, id_column)
    await expdb.execute(
        text(
            f"""
            INSERT INTO {table}(`{id_column}`, `tag`, `uploader`)
            VALUES (:id, :tag, :user_id)
            """,
        ),
        parameters={"id": id_, "tag": tag_, "user_id": user_id},
    )


async def select_tag(
    *,
    table: str,
    id_column: str,
    id_: int,
    tag_: str,
    expdb: AsyncConnection,
) -> Row | None:
    _validate_identifiers(table, id_column)
    result = await expdb.execute(
        text(
            f"""
            SELECT `{id_column}` as id, `tag`, `uploader`
            FROM {table}
            WHERE `{id_column}` = :id AND `tag` = :tag
            """,
        ),
        parameters={"id": id_, "tag": tag_},
    )
    return result.one_or_none()


async def remove_tag(
    *,
    table: str,
    id_column: str,
    id_: int,
    tag_: str,
    expdb: AsyncConnection,
) -> None:
    _validate_identifiers(table, id_column)
    await expdb.execute(
        text(
            f"""
            DELETE FROM {table}
            WHERE `{id_column}` = :id AND `tag` = :tag
            """,
        ),
        parameters={"id": id_, "tag": tag_},
    )


async def select_tags(
    *,
    table: str,
    id_column: str,
    id_: int,
    expdb: AsyncConnection,
) -> list[str]:
    _validate_identifiers(table, id_column)
    result = await expdb.execute(
        text(
            f"""
            SELECT `tag`
            FROM {table}
            WHERE `{id_column}` = :id
            """,
        ),
        parameters={"id": id_},
    )
    return [row.tag for row in result.all()]

Key design decisions:

  • _ALLOWED_IDENTIFIERS maps each table to the set of columns it accepts, so the two dimensions are validated together (e.g. you can't accidentally use study_id against task_tag).
  • _validate_identifiers is called once at the top of every public function, before any SQL is constructed, so the f-string interpolation is always gated.
  • The ValueError messages identify the bad value by name, which makes debugging straightforward without leaking sensitive data.

✏️ Learnings added
Learnt from: PGijsbers
Repo: openml/server-api PR: 250
File: src/database/tagging.py:7-8
Timestamp: 2026-03-23T14:36:51.518Z
Learning: In `src/database/tagging.py` (openml/server-api), the allowed tag table names are: "setup_tag", "dataset_tag", "task_tag", "implementation_tag", "study_tag", "run_tag". The allowed id_column values per table are: most tables use "id"; "study_tag" uses "study_id". Identifier inputs to the shared SQL helpers (insert_tag, select_tag, remove_tag, select_tags) must be validated against this allowlist before any SQL interpolation.

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.


🧠 Learnings used
Learnt from: saathviksheerla
Repo: openml/server-api PR: 267
File: src/routers/openml/runs.py:45-51
Timestamp: 2026-03-08T07:59:51.515Z
Learning: In Python code using SQLAlchemy, VARCHAR/text columns are returned as Python strings. Avoid wrapping such values with str(...) when constructing objects (e.g., TraceIteration). Treat string-type columns as strings and pass them through directly unless a transformation is actually required. This guidance applies to similar string columns across Python files in the repository (and can be considered for enum-like values that are exposed as strings).

id_: int,
tag_: str,
user_id: int,
expdb: AsyncConnection,
) -> None:
await expdb.execute(
text(
f"""
INSERT INTO {table}(`{id_column}`, `tag`, `uploader`)
VALUES (:id, :tag, :user_id)
""",
),
parameters={"id": id_, "tag": tag_, "user_id": user_id},
)


async def select_tag(
*,
table: str,
id_column: str,
id_: int,
tag_: str,
expdb: AsyncConnection,
) -> Row | None:
result = await expdb.execute(
text(
f"""
SELECT `{id_column}` as id, `tag`, `uploader`
FROM {table}
WHERE `{id_column}` = :id AND `tag` = :tag
""",
),
parameters={"id": id_, "tag": tag_},
)
return result.one_or_none()


async def remove_tag(
*,
table: str,
id_column: str,
id_: int,
tag_: str,
expdb: AsyncConnection,
) -> None:
await expdb.execute(
text(
f"""
DELETE FROM {table}
WHERE `{id_column}` = :id AND `tag` = :tag
""",
),
parameters={"id": id_, "tag": tag_},
)
Comment on lines +46 to +62
Copy link
Contributor

@coderabbitai coderabbitai bot Mar 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bind the DELETE to the checked owner.

The new /untag flow does the ownership check before it calls delete_tag, but remove_tag deletes by (id, tag) only. If that row is removed and recreated by someone else between the check and this DELETE, the original requester can still delete the new owner's tag. Make the mutation conditional on the expected uploader (or row PK) so authorization and deletion stay atomic.

🔒 Suggested direction
 async def remove_tag(
     *,
     table: str,
     id_column: str,
     id_: int,
     tag_: str,
+    uploader: int | None = None,
     expdb: AsyncConnection,
 ) -> None:
+    uploader_clause = " AND `uploader` = :uploader" if uploader is not None else ""
     await expdb.execute(
         text(
             f"""
             DELETE FROM {table}
-            WHERE `{id_column}` = :id AND `tag` = :tag
+            WHERE `{id_column}` = :id AND `tag` = :tag{uploader_clause}
             """,
         ),
-        parameters={"id": id_, "tag": tag_},
+        parameters={
+            "id": id_,
+            "tag": tag_,
+            **({"uploader": uploader} if uploader is not None else {}),
+        },
     )

Then thread the expected uploader through core.tagging.untag_entity for non-admin deletes.

🧰 Tools
🪛 Ruff (0.15.6)

[error] 56-59: Possible SQL injection vector through string-based query construction

(S608)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/database/tagging.py` around lines 46 - 62, remove_tag currently deletes
rows by (id, tag) only which allows TOCTOU deletion by another uploader; change
remove_tag to accept an expected_uploader (or expected_row_pk) parameter and
include it in the DELETE WHERE clause (e.g., AND `uploader` = :expected_uploader
or AND primary_key = :expected_pk) so the DELETE is conditional/atomic; update
callers (notably core.tagging.untag_entity and any pathway used by non-admins)
to pass the checked uploader value when calling remove_tag so authorization is
enforced in the same SQL statement.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can ignore this one, because we would also need to account for requests by administrators to remove the tag. This isn't a very impactful race condition, but I'll make a note of it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@PGijsbers, understood — acknowledged that the admin bypass makes the conditional-uploader approach non-trivial, and that the window for this race condition is narrow in practice. Would you like me to open a follow-up GitHub issue to track it so the note doesn't get lost?


🧠 Learnings used
Learnt from: saathviksheerla
Repo: openml/server-api PR: 267
File: src/routers/openml/runs.py:45-51
Timestamp: 2026-03-08T07:59:51.515Z
Learning: In Python code using SQLAlchemy, VARCHAR/text columns are returned as Python strings. Avoid wrapping such values with str(...) when constructing objects (e.g., TraceIteration). Treat string-type columns as strings and pass them through directly unless a transformation is actually required. This guidance applies to similar string columns across Python files in the repository (and can be considered for enum-like values that are exposed as strings).



async def select_tags(
*,
table: str,
id_column: str,
id_: int,
expdb: AsyncConnection,
) -> list[str]:
result = await expdb.execute(
text(
f"""
SELECT `tag`
FROM {table}
WHERE `{id_column}` = :id
""",
),
parameters={"id": id_},
)
return [row.tag for row in result.all()]
35 changes: 24 additions & 11 deletions src/database/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
from sqlalchemy import Row, text
from sqlalchemy.ext.asyncio import AsyncConnection

from database.tagging import insert_tag, remove_tag, select_tag, select_tags

_TABLE = "task_tag"
_ID_COLUMN = "id"


async def get(id_: int, expdb: AsyncConnection) -> Row | None:
row = await expdb.execute(
Expand Down Expand Up @@ -103,15 +108,23 @@ async def get_task_type_inout_with_template(


async def get_tags(id_: int, expdb: AsyncConnection) -> list[str]:
rows = await expdb.execute(
text(
"""
SELECT `tag`
FROM task_tag
WHERE `id` = :task_id
""",
),
parameters={"task_id": id_},
return await select_tags(table=_TABLE, id_column=_ID_COLUMN, id_=id_, expdb=expdb)


async def tag(id_: int, tag_: str, *, user_id: int, expdb: AsyncConnection) -> None:
await insert_tag(
table=_TABLE,
id_column=_ID_COLUMN,
id_=id_,
tag_=tag_,
user_id=user_id,
expdb=expdb,
)
tag_rows = rows.all()
return [row.tag for row in tag_rows]


async def get_tag(id_: int, tag_: str, expdb: AsyncConnection) -> Row | None:
return await select_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)


async def delete_tag(id_: int, tag_: str, expdb: AsyncConnection) -> None:
await remove_tag(table=_TABLE, id_column=_ID_COLUMN, id_=id_, tag_=tag_, expdb=expdb)
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from routers.openml.evaluations import router as evaluationmeasures_router
from routers.openml.flows import router as flows_router
from routers.openml.qualities import router as qualities_router
from routers.openml.runs import router as runs_router
from routers.openml.setups import router as setup_router
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
Expand Down Expand Up @@ -68,6 +69,7 @@ def create_api() -> FastAPI:
app.include_router(estimationprocedure_router)
app.include_router(task_router)
app.include_router(flows_router)
app.include_router(runs_router)
app.include_router(study_router)
app.include_router(setup_router)
return app
Expand Down
46 changes: 43 additions & 3 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,57 @@
from typing import Annotated, Literal
from typing import Annotated, Any, Literal

from fastapi import APIRouter, Depends
from fastapi import APIRouter, Body, Depends
from sqlalchemy.ext.asyncio import AsyncConnection

import database.flows
from core.conversions import _str_to_num
from core.errors import FlowNotFoundError
from routers.dependencies import expdb_connection
from core.tagging import tag_entity, untag_entity
from database.users import User
from routers.dependencies import expdb_connection, fetch_user_or_raise
from routers.types import SystemString64
from schemas.flows import Flow, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.post(path="/tag")
async def tag_flow(
flow_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
return await tag_entity(
flow_id,
tag,
user,
expdb,
get_tags_fn=database.flows.get_tags,
tag_fn=database.flows.tag,
response_key="flow_tag",
)


@router.post(path="/untag")
async def untag_flow(
flow_id: Annotated[int, Body()],
tag: Annotated[str, SystemString64],
user: Annotated[User, Depends(fetch_user_or_raise)],
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, Any]]:
return await untag_entity(
flow_id,
tag,
user,
expdb,
get_tag_fn=database.flows.get_tag,
delete_tag_fn=database.flows.delete_tag,
get_tags_fn=database.flows.get_tags,
response_key="flow_tag",
)


@router.get("/exists/{name}/{external_version}")
async def flow_exists(
name: str,
Expand Down
Loading