Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/core/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,9 @@ class DatasetError(IntEnum):
NOT_FOUND = 111
NO_ACCESS = 112
NO_DATA_FILE = 113


class UserError(IntEnum):
NOT_FOUND = 120
NO_ACCESS = 121
HAS_RESOURCES = 122
81 changes: 81 additions & 0 deletions src/database/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,84 @@ def groups(self) -> list[UserGroup]:
groups = get_user_groups_for(user_id=self.user_id, connection=self._database)
self._groups = [UserGroup(group_id) for group_id in groups]
return self._groups


def get_user_resource_count(*, user_id: int, expdb: Connection) -> int:
"""Return the total number of datasets, flows, and runs owned by the user."""
dataset_count = (
expdb.execute(
text("SELECT COUNT(*) FROM dataset WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)
flow_count = (
expdb.execute(
text("SELECT COUNT(*) FROM implementation WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)
run_count = (
expdb.execute(
text("SELECT COUNT(*) FROM run WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)

study_count = (
expdb.execute(
text("SELECT COUNT(*) FROM study WHERE creator = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)
task_study_count = (
expdb.execute(
text("SELECT COUNT(*) FROM task_study WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)
run_study_count = (
expdb.execute(
text("SELECT COUNT(*) FROM run_study WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)
dataset_tag_count = (
expdb.execute(
text("SELECT COUNT(*) FROM dataset_tag WHERE uploader = :user_id"),
parameters={"user_id": user_id},
).scalar()
or 0
)

return int(
dataset_count
+ flow_count
+ run_count
+ study_count
+ task_study_count
+ run_study_count
+ dataset_tag_count,
)


def delete_user(*, user_id: int, connection: Connection) -> None:
"""Remove the user and their group memberships from the user database."""
with connection.begin_nested() as transaction:
try:
connection.execute(
text("DELETE FROM users_groups WHERE user_id = :user_id"),
parameters={"user_id": user_id},
)
connection.execute(
text("DELETE FROM users WHERE id = :user_id"),
parameters={"user_id": user_id},
)
except Exception:
transaction.rollback()
raise
2 changes: 2 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from routers.openml.study import router as study_router
from routers.openml.tasks import router as task_router
from routers.openml.tasktype import router as ttype_router
from routers.openml.users import router as users_router


def _parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -55,6 +56,7 @@ def create_api() -> FastAPI:
app.include_router(task_router)
app.include_router(flows_router)
app.include_router(study_router)
app.include_router(users_router)
return app


Expand Down
23 changes: 18 additions & 5 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
import database.flows
from core.conversions import _str_to_num
from routers.dependencies import expdb_connection
from schemas.flows import Flow, Parameter, Subflow
from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow

router = APIRouter(prefix="/flows", tags=["flows"])


@router.get("/exists/{name}/{external_version}")
@router.post("/exists")
def flow_exists(
name: str,
external_version: str,
body: FlowExistsBody,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Check if a Flow with the name and version exists, if so, return the flow id."""
flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
flow = database.flows.get_by_name(
name=body.name,
external_version=body.external_version,
expdb=expdb,
)
if flow is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
Expand All @@ -28,6 +31,16 @@ def flow_exists(
return {"flow_id": flow.id}


@router.get("/exists/{name}/{external_version}", deprecated=True)
def flow_exists_get(
name: str,
external_version: str,
expdb: Annotated[Connection, Depends(expdb_connection)],
) -> dict[Literal["flow_id"], int]:
"""Deprecated: use POST /flows/exists instead."""
return flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)


@router.get("/{flow_id}")
def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
flow = database.flows.get(flow_id, expdb)
Expand Down
94 changes: 94 additions & 0 deletions src/routers/openml/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from http import HTTPStatus
from typing import Annotated, Any

from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import Connection

from core.errors import UserError
from database.users import User, UserGroup, delete_user, get_user_resource_count
from routers.dependencies import expdb_connection, fetch_user, userdb_connection

router = APIRouter(prefix="/users", tags=["users"])


@router.delete(
"/{user_id}",
summary="Delete a user account",
description=(
"Deletes the account of the specified user. "
"Only the account owner or an admin may perform this action. "
"Deletion is blocked if the user has uploaded any owned resources."
),
)
def delete_account(
user_id: int,
caller: Annotated[User | None, Depends(fetch_user)] = None,
user_db: Annotated[Connection, Depends(userdb_connection)] = None,
expdb: Annotated[Connection, Depends(expdb_connection)] = None,
) -> dict[str, Any]:
if caller is None:
raise HTTPException(
status_code=HTTPStatus.UNAUTHORIZED,
detail={"code": str(int(UserError.NO_ACCESS)), "message": "Authentication required"},
)

is_admin = UserGroup.ADMIN in caller.groups
is_self = caller.user_id == user_id

if not is_admin and not is_self:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail={"code": str(int(UserError.NO_ACCESS)), "message": "No access granted"},
)

import uuid

from sqlalchemy import text # noqa: PLC0415

original = user_db.execute(
text("SELECT session_hash FROM users WHERE id = :id FOR UPDATE"),
parameters={"id": user_id},
).fetchone()

if original is None:
raise HTTPException(
status_code=HTTPStatus.NOT_FOUND,
detail={"code": str(int(UserError.NOT_FOUND)), "message": "User not found"},
)

# Invalidate session immediately to prevent concurrent resource creation
original_session_hash = original[0]
temp_lock_hash = uuid.uuid4().hex
user_db.execute(
text("UPDATE users SET session_hash = :lock_hash WHERE id = :id"),
parameters={"lock_hash": temp_lock_hash, "id": user_id},
)
user_db.commit()
Comment on lines +59 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

DELETION_PENDING becomes a working credential here.

src/database/users.py:29-40 still authenticates with WHERE session_hash = :api_key, so after Line 61 any request that sends api_key=DELETION_PENDING can impersonate the target user until this flow restores or deletes the row. Use an unguessable tombstone value here, or better, a dedicated pending flag that the auth path rejects.

Possible minimal fix
 from http import HTTPStatus
+import secrets
 from typing import Annotated, Any
-    user_db.execute(
-        text("UPDATE users SET session_hash = 'DELETION_PENDING' WHERE id = :id"),
-        parameters={"id": user_id},
-    )
+    pending_session_hash = secrets.token_hex(32)
+    user_db.execute(
+        text("UPDATE users SET session_hash = :session_hash WHERE id = :id"),
+        parameters={"session_hash": pending_session_hash, "id": user_id},
+    )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/routers/openml/users.py` around lines 57 - 64, The current update sets
session_hash = 'DELETION_PENDING' which can be used as a valid api_key; change
the invalidation to use an unguessable tombstone or a dedicated flag that the
auth path will reject: in the block that writes session_hash (the code using
original_session_hash and user_db.execute(... "UPDATE users SET session_hash =
'DELETION_PENDING' ...")), replace the literal with a securely generated unique
value (e.g. a UUID or cryptographic random token) or instead set a new
deletion_pending boolean/enum column and update that field; also update the
authentication check in the auth function in users.py (the code that queries
WHERE session_hash = :api_key) to explicitly reject tombstone tokens or to check
deletion_pending = false so the temporary marker cannot be used to authenticate.

Comment on lines +48 to +66
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Concurrent delete attempts can restore the wrong session_hash.

Lines 48-66 snapshot whatever value is currently in users.session_hash. If a second request enters after another delete attempt has already written its UUID lock, original_session_hash becomes that temporary value. When both requests unwind through Lines 87-94, the last restore wins and can leave the account stuck with a random hash even though deletion was blocked. This needs an explicit “deletion in progress” state instead of restoring whatever happened to be in session_hash.

Also applies to: 87-94

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/routers/openml/users.py` around lines 48 - 66, The current logic
snapshots and later restores users.session_hash, which can clobber another
request's temporary lock; instead add and use an explicit deletion-in-progress
flag on the row (e.g., a boolean column deletion_in_progress) and update the
flow in the SELECT FOR UPDATE block and the restore block: after selecting the
row in the code around the SELECT ... FOR UPDATE (where original_session_hash
and temp_lock_hash are set) set deletion_in_progress = true in the same
transaction, and when cleaning up (the section currently restoring session_hash
around lines 87-94) only clear deletion_in_progress = false (and avoid blindly
restoring session_hash), or restore the original session_hash only if the
current session_hash still equals the temp_lock_hash; update any
rollback/cleanup to clear deletion_in_progress to false. Locate changes to the
SELECT ... FOR UPDATE block and the corresponding restore/cleanup block in
users.py to implement this.


deletion_successful = False
try:
resource_count = get_user_resource_count(user_id=user_id, expdb=expdb)
if resource_count > 0:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail={
"code": str(int(UserError.HAS_RESOURCES)),
"message": (
f"User has {resource_count} resource(s). "
"Remove or transfer resources before deleting the account."
),
},
)

delete_user(user_id=user_id, connection=user_db)
user_db.commit()
deletion_successful = True
return {"user_id": user_id, "deleted": True}
finally:
if not deletion_successful:
# Restore session hash if deletion did not complete successfully
user_db.execute(
text("UPDATE users SET session_hash = :hash WHERE id = :id"),
parameters={"hash": original_session_hash, "id": user_id},
)
user_db.commit()
5 changes: 5 additions & 0 deletions src/schemas/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
from pydantic import BaseModel, ConfigDict, Field


class FlowExistsBody(BaseModel):
name: str
external_version: str


class Parameter(BaseModel):
name: str
default_value: Any
Expand Down
27 changes: 23 additions & 4 deletions tests/routers/openml/flows_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from starlette.testclient import TestClient

from routers.openml.flows import flow_exists
from schemas.flows import FlowExistsBody
from tests.conftest import Flow


Expand All @@ -25,7 +26,7 @@ def test_flow_exists_calls_db_correctly(
mocker: MockerFixture,
) -> None:
mocked_db = mocker.patch("database.flows.get_by_name")
flow_exists(name, external_version, expdb_test)
flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test)
mocked_db.assert_called_once_with(
name=name,
external_version=external_version,
Expand All @@ -47,25 +48,43 @@ def test_flow_exists_processes_found(
"database.flows.get_by_name",
return_value=fake_flow,
)
response = flow_exists("name", "external_version", expdb_test)
response = flow_exists(
FlowExistsBody(name="name", external_version="external_version"), expdb_test
)
assert response == {"flow_id": fake_flow.id}


def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None:
mocker.patch("database.flows.get_by_name", return_value=None)
with pytest.raises(HTTPException) as error:
flow_exists("foo", "bar", expdb_test)
flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test)
assert error.value.status_code == HTTPStatus.NOT_FOUND
assert error.value.detail == "Flow not found."


def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
response = py_api.post(
"/flows/exists", json={"name": flow.name, "external_version": flow.external_version}
)
assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}


def test_flow_exists_not_exists(py_api: TestClient) -> None:
response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == "Flow not found."


def test_flow_exists_get_alias(flow: Flow, py_api: TestClient) -> None:
"""Test the deprecated GET wrapper for backward compatibility."""
response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
assert response.status_code == HTTPStatus.OK
assert response.json() == {"flow_id": flow.id}


def test_flow_exists_get_alias_not_exists(py_api: TestClient) -> None:
"""Test the deprecated GET wrapper returns 404 for non-existent flows."""
response = py_api.get("/flows/exists/foo/bar")
assert response.status_code == HTTPStatus.NOT_FOUND
assert response.json()["detail"] == "Flow not found."
Expand Down
15 changes: 9 additions & 6 deletions tests/routers/openml/migration/flows_migration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,8 @@ def test_flow_exists_not(
py_api: TestClient,
php_api: TestClient,
) -> None:
path = "exists/foo/bar"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")
py_response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
php_response = php_api.get("/flow/exists/foo/bar")

assert py_response.status_code == HTTPStatus.NOT_FOUND
assert php_response.status_code == HTTPStatus.OK
Expand All @@ -36,9 +35,13 @@ def test_flow_exists(
py_api: TestClient,
php_api: TestClient,
) -> None:
path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
py_response = py_api.get(f"/flows/{path}")
php_response = php_api.get(f"/flow/{path}")
py_response = py_api.post(
"/flows/exists",
json={"name": persisted_flow.name, "external_version": persisted_flow.external_version},
)
php_response = php_api.get(
f"/flow/exists/{persisted_flow.name}/{persisted_flow.external_version}"
)

assert py_response.status_code == php_response.status_code, php_response.content

Expand Down
Loading