diff --git a/src/core/errors.py b/src/core/errors.py
index 840cd75f..6626a639 100644
--- a/src/core/errors.py
+++ b/src/core/errors.py
@@ -5,3 +5,9 @@ class DatasetError(IntEnum):
     NOT_FOUND = 111
     NO_ACCESS = 112
     NO_DATA_FILE = 113
+
+
+class UserError(IntEnum):
+    NOT_FOUND = 120
+    NO_ACCESS = 121
+    HAS_RESOURCES = 122
diff --git a/src/database/users.py b/src/database/users.py
index b439be7e..cb75920d 100644
--- a/src/database/users.py
+++ b/src/database/users.py
@@ -72,3 +72,44 @@ def groups(self) -> list[UserGroup]:
             groups = get_user_groups_for(user_id=self.user_id, connection=self._database)
             self._groups = [UserGroup(group_id) for group_id in groups]
         return self._groups
+
+
+# One COUNT query per table that records which user created or uploaded a row.
+_OWNED_RESOURCE_QUERIES = (
+    "SELECT COUNT(*) FROM dataset WHERE uploader = :user_id",
+    "SELECT COUNT(*) FROM implementation WHERE uploader = :user_id",
+    "SELECT COUNT(*) FROM run WHERE uploader = :user_id",
+    "SELECT COUNT(*) FROM study WHERE creator = :user_id",
+    "SELECT COUNT(*) FROM task_study WHERE uploader = :user_id",
+    "SELECT COUNT(*) FROM run_study WHERE uploader = :user_id",
+    "SELECT COUNT(*) FROM dataset_tag WHERE uploader = :user_id",
+)
+
+
+def get_user_resource_count(*, user_id: int, expdb: Connection) -> int:
+    """Return how many rows across the tracked tables are attributed to the user.
+
+    Counts datasets, flows (``implementation``), runs, studies, study
+    memberships (``task_study`` and ``run_study``) and dataset tags.
+    """
+    return sum(
+        int(expdb.execute(text(query), parameters={"user_id": user_id}).scalar() or 0)
+        for query in _OWNED_RESOURCE_QUERIES
+    )
+
+
+def delete_user(*, user_id: int, connection: Connection) -> None:
+    """Remove the user and their group memberships from the user database.
+
+    Both deletes run inside a savepoint; leaving the ``with`` block through an
+    exception rolls the savepoint back, so no explicit rollback is needed.
+    """
+    with connection.begin_nested():
+        connection.execute(
+            text("DELETE FROM users_groups WHERE user_id = :user_id"),
+            parameters={"user_id": user_id},
+        )
+        connection.execute(
+            text("DELETE FROM users WHERE id = :user_id"),
+            parameters={"user_id": user_id},
+        )
diff --git a/src/main.py b/src/main.py
index 560b4c50..528ad32b 100644
--- a/src/main.py
+++ b/src/main.py
@@ -14,6 +14,7 @@
 from routers.openml.study import router as study_router
 from routers.openml.tasks import router as task_router
 from routers.openml.tasktype import router as ttype_router
+from routers.openml.users import router as users_router
 
 
 def _parse_args() -> argparse.Namespace:
@@ -55,6 +56,7 @@ def create_api() -> FastAPI:
     app.include_router(task_router)
     app.include_router(flows_router)
     app.include_router(study_router)
+    app.include_router(users_router)
     return app
 
 
diff --git a/src/routers/openml/flows.py b/src/routers/openml/flows.py
index cb6df5d9..31562a08 100644
--- a/src/routers/openml/flows.py
+++ b/src/routers/openml/flows.py
@@ -7,19 +7,22 @@
 import database.flows
 from core.conversions import _str_to_num
 from routers.dependencies import expdb_connection
-from schemas.flows import Flow, Parameter, Subflow
+from schemas.flows import Flow, FlowExistsBody, Parameter, Subflow
 
 router = APIRouter(prefix="/flows", tags=["flows"])
 
 
-@router.get("/exists/{name}/{external_version}")
+@router.post("/exists")
 def flow_exists(
-    name: str,
-    external_version: str,
+    body: FlowExistsBody,
     expdb: Annotated[Connection, Depends(expdb_connection)],
 ) -> dict[Literal["flow_id"], int]:
     """Check if a Flow with the name and version exists, if so, return the flow id."""
-    flow = database.flows.get_by_name(name=name, external_version=external_version, expdb=expdb)
+    flow = database.flows.get_by_name(
+        name=body.name,
+        external_version=body.external_version,
+        expdb=expdb,
+    )
     if flow is None:
         raise HTTPException(
             status_code=HTTPStatus.NOT_FOUND,
@@ -28,6 +31,16 @@ def flow_exists(
     return {"flow_id": flow.id}
 
 
+@router.get("/exists/{name}/{external_version}", deprecated=True)
+def flow_exists_get(
+    name: str,
+    external_version: str,
+    expdb: Annotated[Connection, Depends(expdb_connection)],
+) -> dict[Literal["flow_id"], int]:
+    """Deprecated: use POST /flows/exists instead."""
+    return flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb)
+
+
 @router.get("/{flow_id}")
 def get_flow(flow_id: int, expdb: Annotated[Connection, Depends(expdb_connection)] = None) -> Flow:
     flow = database.flows.get(flow_id, expdb)
diff --git a/src/routers/openml/users.py b/src/routers/openml/users.py
new file mode 100644
index 00000000..f43a2b50
--- /dev/null
+++ b/src/routers/openml/users.py
@@ -0,0 +1,95 @@
+import uuid
+from http import HTTPStatus
+from typing import Annotated, Any
+
+from fastapi import APIRouter, Depends, HTTPException
+from sqlalchemy import Connection, text
+
+from core.errors import UserError
+from database.users import User, UserGroup, delete_user, get_user_resource_count
+from routers.dependencies import expdb_connection, fetch_user, userdb_connection
+
+router = APIRouter(prefix="/users", tags=["users"])
+
+
+def _set_session_hash(user_db: Connection, *, user_id: int, session_hash: str) -> None:
+    """Overwrite the stored session hash of ``user_id`` and commit the change."""
+    user_db.execute(
+        text("UPDATE users SET session_hash = :hash WHERE id = :id"),
+        parameters={"hash": session_hash, "id": user_id},
+    )
+    user_db.commit()
+
+
+@router.delete(
+    "/{user_id}",
+    summary="Delete a user account",
+    description=(
+        "Deletes the account of the specified user. "
+        "Only the account owner or an admin may perform this action. "
+        "Deletion is blocked if the user has uploaded any owned resources."
+    ),
+)
+def delete_account(
+    user_id: int,
+    caller: Annotated[User | None, Depends(fetch_user)] = None,
+    user_db: Annotated[Connection, Depends(userdb_connection)] = None,
+    expdb: Annotated[Connection, Depends(expdb_connection)] = None,
+) -> dict[str, Any]:
+    """Delete the account ``user_id`` on behalf of ``caller``.
+
+    Raises HTTPException with 401 when no caller is authenticated, 403 when the
+    caller is neither an admin nor the account owner, 404 when the account does
+    not exist, and 409 when the account still owns resources.
+    """
+    if caller is None:
+        raise HTTPException(
+            status_code=HTTPStatus.UNAUTHORIZED,
+            detail={"code": str(int(UserError.NO_ACCESS)), "message": "Authentication required"},
+        )
+
+    is_admin = UserGroup.ADMIN in caller.groups
+    is_self = caller.user_id == user_id
+    if not (is_admin or is_self):
+        raise HTTPException(
+            status_code=HTTPStatus.FORBIDDEN,
+            detail={"code": str(int(UserError.NO_ACCESS)), "message": "No access granted"},
+        )
+
+    original = user_db.execute(
+        text("SELECT session_hash FROM users WHERE id = :id FOR UPDATE"),
+        parameters={"id": user_id},
+    ).fetchone()
+    if original is None:
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND,
+            detail={"code": str(int(UserError.NOT_FOUND)), "message": "User not found"},
+        )
+    original_session_hash = original[0]
+
+    # Invalidate the session before checking resources, so the account cannot
+    # create new resources while the deletion is in progress.
+    _set_session_hash(user_db, user_id=user_id, session_hash=uuid.uuid4().hex)
+
+    try:
+        resource_count = get_user_resource_count(user_id=user_id, expdb=expdb)
+        if resource_count > 0:
+            raise HTTPException(
+                status_code=HTTPStatus.CONFLICT,
+                detail={
+                    "code": str(int(UserError.HAS_RESOURCES)),
+                    "message": (
+                        f"User has {resource_count} resource(s). "
+                        "Remove or transfer resources before deleting the account."
+                    ),
+                },
+            )
+        delete_user(user_id=user_id, connection=user_db)
+        user_db.commit()
+    except Exception:
+        # The account still exists: discard any half-applied work and hand the
+        # original session hash back before reporting the failure.
+        user_db.rollback()
+        _set_session_hash(user_db, user_id=user_id, session_hash=original_session_hash)
+        raise
+    return {"user_id": user_id, "deleted": True}
diff --git a/src/schemas/flows.py b/src/schemas/flows.py
index a6cd479c..50e2491c 100644
--- a/src/schemas/flows.py
+++ b/src/schemas/flows.py
@@ -6,6 +6,11 @@
 from pydantic import BaseModel, ConfigDict, Field
 
 
+class FlowExistsBody(BaseModel):
+    name: str
+    external_version: str
+
+
 class Parameter(BaseModel):
     name: str
     default_value: Any
diff --git a/tests/routers/openml/flows_test.py b/tests/routers/openml/flows_test.py
index d5188d0e..60b1baed 100644
--- a/tests/routers/openml/flows_test.py
+++ b/tests/routers/openml/flows_test.py
@@ -8,6 +8,7 @@
 from starlette.testclient import TestClient
 
 from routers.openml.flows import flow_exists
+from schemas.flows import FlowExistsBody
 from tests.conftest import Flow
 
 
@@ -25,7 +26,7 @@ def test_flow_exists_calls_db_correctly(
     mocker: MockerFixture,
 ) -> None:
     mocked_db = mocker.patch("database.flows.get_by_name")
-    flow_exists(name, external_version, expdb_test)
+    flow_exists(FlowExistsBody(name=name, external_version=external_version), expdb_test)
     mocked_db.assert_called_once_with(
         name=name,
         external_version=external_version,
@@ -47,25 +48,43 @@ def test_flow_exists_processes_found(
         "database.flows.get_by_name",
         return_value=fake_flow,
     )
-    response = flow_exists("name", "external_version", expdb_test)
+    response = flow_exists(
+        FlowExistsBody(name="name", external_version="external_version"), expdb_test
+    )
     assert response == {"flow_id": fake_flow.id}
 
 
 def test_flow_exists_handles_flow_not_found(mocker: MockerFixture, expdb_test: Connection) -> None:
     mocker.patch("database.flows.get_by_name", return_value=None)
     with pytest.raises(HTTPException) as error:
-        flow_exists("foo", "bar", expdb_test)
+        flow_exists(FlowExistsBody(name="foo", external_version="bar"), expdb_test)
     assert error.value.status_code == HTTPStatus.NOT_FOUND
     assert error.value.detail == "Flow not found."
 
 
 def test_flow_exists(flow: Flow, py_api: TestClient) -> None:
-    response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
+    response = py_api.post(
+        "/flows/exists", json={"name": flow.name, "external_version": flow.external_version}
+    )
     assert response.status_code == HTTPStatus.OK
     assert response.json() == {"flow_id": flow.id}
 
 
 def test_flow_exists_not_exists(py_api: TestClient) -> None:
+    response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert response.json()["detail"] == "Flow not found."
+
+
+def test_flow_exists_get_alias(flow: Flow, py_api: TestClient) -> None:
+    """Test the deprecated GET wrapper for backward compatibility."""
+    response = py_api.get(f"/flows/exists/{flow.name}/{flow.external_version}")
+    assert response.status_code == HTTPStatus.OK
+    assert response.json() == {"flow_id": flow.id}
+
+
+def test_flow_exists_get_alias_not_exists(py_api: TestClient) -> None:
+    """Test the deprecated GET wrapper returns 404 for non-existent flows."""
     response = py_api.get("/flows/exists/foo/bar")
     assert response.status_code == HTTPStatus.NOT_FOUND
     assert response.json()["detail"] == "Flow not found."
diff --git a/tests/routers/openml/migration/flows_migration_test.py b/tests/routers/openml/migration/flows_migration_test.py
index 674bc439..5ee8e592 100644
--- a/tests/routers/openml/migration/flows_migration_test.py
+++ b/tests/routers/openml/migration/flows_migration_test.py
@@ -18,9 +18,8 @@ def test_flow_exists_not(
     py_api: TestClient,
     php_api: TestClient,
 ) -> None:
-    path = "exists/foo/bar"
-    py_response = py_api.get(f"/flows/{path}")
-    php_response = php_api.get(f"/flow/{path}")
+    py_response = py_api.post("/flows/exists", json={"name": "foo", "external_version": "bar"})
+    php_response = php_api.get("/flow/exists/foo/bar")
 
     assert py_response.status_code == HTTPStatus.NOT_FOUND
     assert php_response.status_code == HTTPStatus.OK
@@ -36,9 +35,13 @@ def test_flow_exists(
     py_api: TestClient,
     php_api: TestClient,
 ) -> None:
-    path = f"exists/{persisted_flow.name}/{persisted_flow.external_version}"
-    py_response = py_api.get(f"/flows/{path}")
-    php_response = php_api.get(f"/flow/{path}")
+    py_response = py_api.post(
+        "/flows/exists",
+        json={"name": persisted_flow.name, "external_version": persisted_flow.external_version},
+    )
+    php_response = php_api.get(
+        f"/flow/exists/{persisted_flow.name}/{persisted_flow.external_version}"
+    )
 
     assert py_response.status_code == php_response.status_code, php_response.content
 
diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py
index 45b330ae..9b108f2f 100644
--- a/tests/routers/openml/users_test.py
+++ b/tests/routers/openml/users_test.py
@@ -1,27 +1,190 @@
+from http import HTTPStatus
+
 import pytest
-from sqlalchemy import Connection
+from fastapi.testclient import TestClient
+from sqlalchemy import Connection, text
 
 from database.users import User
 from routers.dependencies import fetch_user
 from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey
 
 
 @pytest.mark.parametrize(
     ("api_key", "user"),
     [
         (ApiKey.ADMIN, ADMIN_USER),
         (ApiKey.OWNER_USER, OWNER_USER),
         (ApiKey.SOME_USER, SOME_USER),
     ],
 )
 def test_fetch_user(api_key: str, user: User, user_test: Connection) -> None:
     db_user = fetch_user(api_key, user_data=user_test)
     assert db_user is not None
     assert user.user_id == db_user.user_id
     assert set(user.groups) == set(db_user.groups)
 
 
 def test_fetch_user_invalid_key_returns_none(user_test: Connection) -> None:
     assert fetch_user(api_key=None, user_data=user_test) is None
     invalid_key = "f" * 32
     assert fetch_user(api_key=invalid_key, user_data=user_test) is None
+
+
+@pytest.mark.mut
+def test_delete_user_self(py_api: TestClient, user_test: Connection) -> None:
+    """A user without resources can delete their own account."""
+    user_test.execute(
+        text(
+            "INSERT INTO users (session_hash, email, first_name, last_name, password)"
+            " VALUES ('aaaabbbbccccddddaaaabbbbccccdddd', 'del@test.com', 'Del', 'User', 'x')",
+        ),
+    )
+    (new_id,) = user_test.execute(text("SELECT LAST_INSERT_ID()")).one()
+
+    user_test.execute(
+        text("INSERT INTO users_groups (user_id, group_id) VALUES (:id, 2)"),
+        parameters={"id": new_id},
+    )
+
+    response = py_api.delete(f"/users/{new_id}?api_key=aaaabbbbccccddddaaaabbbbccccdddd")
+    assert response.status_code == HTTPStatus.OK
+    assert response.json() == {"user_id": new_id, "deleted": True}
+
+    user_count = user_test.execute(
+        text("SELECT COUNT(*) FROM users WHERE id = :id"),
+        parameters={"id": new_id},
+    ).scalar()
+    group_count = user_test.execute(
+        text("SELECT COUNT(*) FROM users_groups WHERE user_id = :id"),
+        parameters={"id": new_id},
+    ).scalar()
+    assert user_count == 0
+    assert group_count == 0
+
+
+@pytest.mark.mut
+def test_delete_user_as_admin(py_api: TestClient, user_test: Connection) -> None:
+    """An admin can delete any user without resources."""
+    user_test.execute(
+        text(
+            "INSERT INTO users (session_hash, email, first_name, last_name, password)"
+            " VALUES ('eeeeffffaaaabbbbeeeeffffaaaabbbb', 'del2@test.com', 'Del2', 'User', 'x')",
+        ),
+    )
+    (new_id,) = user_test.execute(text("SELECT LAST_INSERT_ID()")).one()
+
+    user_test.execute(
+        text("INSERT INTO users_groups (user_id, group_id) VALUES (:id, 2)"),
+        parameters={"id": new_id},
+    )
+
+    response = py_api.delete(f"/users/{new_id}?api_key={ApiKey.ADMIN}")
+    assert response.status_code == HTTPStatus.OK
+    assert response.json() == {"user_id": new_id, "deleted": True}
+
+    user_count = user_test.execute(
+        text("SELECT COUNT(*) FROM users WHERE id = :id"),
+        parameters={"id": new_id},
+    ).scalar()
+    assert user_count == 0
+
+
+def test_delete_user_no_auth(py_api: TestClient) -> None:
+    """No API key → 401."""
+    response = py_api.delete("/users/2")
+    assert response.status_code == HTTPStatus.UNAUTHORIZED
+
+
+def test_delete_user_not_owner(py_api: TestClient) -> None:
+    """A non-owner non-admin user cannot delete someone else's account → 403."""
+    response = py_api.delete(f"/users/3229?api_key={ApiKey.SOME_USER}")
+    assert response.status_code == HTTPStatus.FORBIDDEN
+
+
+def test_delete_user_not_found(py_api: TestClient) -> None:
+    """Deleting a non-existent user → 404."""
+    response = py_api.delete(f"/users/99999999?api_key={ApiKey.ADMIN}")
+    assert response.status_code == HTTPStatus.NOT_FOUND
+    assert response.json()["detail"]["code"] == "120"
+
+
+def test_delete_user_has_resources(py_api: TestClient, user_test: Connection) -> None:
+    """A user with resources (datasets, flows, runs) gets a 409 Conflict."""
+    target_id = 16
+    response = py_api.delete(f"/users/{target_id}?api_key={ApiKey.DATASET_130_OWNER}")
+
+    assert response.status_code == HTTPStatus.CONFLICT
+    assert response.json()["detail"]["code"] == "122"
+    assert "resource(s)" in response.json()["detail"]["message"]
+
+    user_count = user_test.execute(
+        text("SELECT COUNT(*) FROM users WHERE id = :id"),
+        parameters={"id": target_id},
+    ).scalar()
+    assert user_count == 1
+
+
+@pytest.mark.mut
+@pytest.mark.parametrize(
+    "insert_sql",
+    [
+        pytest.param(
+            "INSERT INTO dataset (uploader, name, format) VALUES (:id, 'x', 'ARFF')",
+            id="dataset",
+        ),
+        pytest.param(
+            "INSERT INTO implementation (uploader, fullname, name, version, "
+            "external_version, uploadDate) VALUES (:id, 'x', 'x', 1, '1', '2024-01-01')",
+            id="implementation",
+        ),
+        pytest.param(
+            "INSERT INTO run (uploader, task_id, setup) VALUES (:id, 1, 1)",
+            id="run",
+        ),
+        pytest.param(
+            "INSERT INTO study (creator, name, main_entity_type) VALUES (:id, 'x', 'run')",
+            id="study",
+        ),
+        pytest.param(
+            "INSERT INTO task_study (uploader, study_id, task_id) VALUES (:id, 14, 1)",
+            id="task_study",
+        ),
+        pytest.param(
+            "INSERT INTO run_study (uploader, study_id, run_id) VALUES (:id, 14, 1)",
+            id="run_study",
+        ),
+        pytest.param(
+            "INSERT INTO dataset_tag (uploader, id, tag) VALUES (:id, 1, 'x')",
+            id="dataset_tag",
+        ),
+    ],
+)
+def test_delete_user_has_resources_parametrized(
+    py_api: TestClient,
+    user_test: Connection,
+    expdb_test: Connection,
+    insert_sql: str,
+) -> None:
+    """Verify that possessing any tracked resource blocks deletion."""
+    user_test.execute(
+        text(
+            "INSERT INTO users (session_hash, email, first_name, last_name, password)"
+            " VALUES ('eeeeffffccccddddaaaabbbbccccdddd', 'res@test.com', 'Del', 'User', 'x')",
+        ),
+    )
+    (new_id,) = user_test.execute(text("SELECT LAST_INSERT_ID()")).one()
+
+    # Disable constraints temporarily to inject simple orphaned rows for testing 409
+    expdb_test.execute(text("SET FOREIGN_KEY_CHECKS=0"))
+    try:
+        expdb_test.execute(text(insert_sql), parameters={"id": new_id})
+    finally:
+        # Re-enable the checks even if the insert fails, so later statements on
+        # this connection do not run with constraints switched off.
+        expdb_test.execute(text("SET FOREIGN_KEY_CHECKS=1"))
+    expdb_test.commit()
+
+    response = py_api.delete(f"/users/{new_id}?api_key=eeeeffffccccddddaaaabbbbccccdddd")
+
+    assert response.status_code == HTTPStatus.CONFLICT
+    assert response.json()["detail"]["code"] == "122"