From 443e76795d462113e095d9db56085313d2575275 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Fri, 4 Apr 2025 08:39:37 -0400 Subject: [PATCH] feat(api): require final review permission for review endpoints BREAKING CHANGE: All `/review/`endpoints now require the `source_collector_final_review` permission --- api/routes/review.py | 10 ++++++---- security_manager/SecurityManager.py | 16 +++++++++++++--- tests/test_automated/integration/api/conftest.py | 12 ++++++++++-- .../security_manager/test_security_manager.py | 4 ++-- 4 files changed, 31 insertions(+), 11 deletions(-) diff --git a/api/routes/review.py b/api/routes/review.py index 649e0b39..62bf5de6 100644 --- a/api/routes/review.py +++ b/api/routes/review.py @@ -7,7 +7,7 @@ from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo, FinalReviewBaseInfo from core.DTOs.GetNextURLForFinalReviewResponse import GetNextURLForFinalReviewResponse, \ GetNextURLForFinalReviewOuterResponse -from security_manager.SecurityManager import AccessInfo, get_access_info +from security_manager.SecurityManager import AccessInfo, get_access_info, require_permission, Permissions review_router = APIRouter( prefix="/review", @@ -15,10 +15,12 @@ responses={404: {"description": "Not found"}}, ) +requires_final_review_permission = require_permission(Permissions.SOURCE_COLLECTOR_FINAL_REVIEW) + @review_router.get("/next-source") async def get_next_source( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(requires_final_review_permission), batch_id: Optional[int] = Query( description="The batch id of the next URL to get. " "If not specified, defaults to first qualifying URL", @@ -30,7 +32,7 @@ async def get_next_source( @review_router.post("/approve-source") async def approve_source( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(requires_final_review_permission), approval_info: FinalReviewApprovalInfo = FinalReviewApprovalInfo, batch_id: Optional[int] = Query( description="The batch id of the next URL to get. " @@ -47,7 +49,7 @@ async def approve_source( @review_router.post("/reject-source") async def reject_source( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(requires_final_review_permission), review_info: FinalReviewBaseInfo = FinalReviewBaseInfo, batch_id: Optional[int] = Query( description="The batch id of the next URL to get. " diff --git a/security_manager/SecurityManager.py b/security_manager/SecurityManager.py index 8d80f46c..92da2975 100644 --- a/security_manager/SecurityManager.py +++ b/security_manager/SecurityManager.py @@ -20,6 +20,7 @@ def get_secret_key(): class Permissions(Enum): SOURCE_COLLECTOR = "source_collector" + SOURCE_COLLECTOR_FINAL_REVIEW = "source_collector_final_review" class AccessInfo(BaseModel): user_id: int @@ -65,9 +66,13 @@ def get_relevant_permissions(raw_permissions: list[str]) -> list[Permissions]: continue return relevant_permissions - def check_access(self, token: str) -> AccessInfo: + def check_access( + self, + token: str, + permission: Permissions + ) -> AccessInfo: access_info = self.validate_token(token) - if not access_info.has_permission(Permissions.SOURCE_COLLECTOR): + if not access_info.has_permission(permission): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access forbidden", @@ -80,4 +85,9 @@ def check_access(self, token: str) -> AccessInfo: def get_access_info( token: Annotated[str, Depends(oauth2_scheme)] ) -> AccessInfo: - return SecurityManager().check_access(token) \ No newline at end of file + return SecurityManager().check_access(token, Permissions.SOURCE_COLLECTOR) + +def require_permission(permission: Permissions): + def dependency(token: Annotated[str, Depends(oauth2_scheme)]) -> AccessInfo: + return SecurityManager().check_access(token, permission=permission) + return dependency \ No newline at end of file diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py index d9a504a7..a0a46abf 100644 --- a/tests/test_automated/integration/api/conftest.py +++ b/tests/test_automated/integration/api/conftest.py @@ -6,8 +6,9 @@ from starlette.testclient import TestClient from api.main import app +from api.routes.review import requires_final_review_permission from core.SourceCollectorCore import SourceCollectorCore -from security_manager.SecurityManager import get_access_info, AccessInfo, Permissions +from security_manager.SecurityManager import get_access_info, AccessInfo, Permissions, require_permission from tests.helpers.DBDataCreator import DBDataCreator from tests.test_automated.integration.api.helpers.RequestValidator import RequestValidator @@ -27,12 +28,19 @@ def adb_client(self): def override_access_info() -> AccessInfo: - return AccessInfo(user_id=MOCK_USER_ID, permissions=[Permissions.SOURCE_COLLECTOR]) + return AccessInfo( + user_id=MOCK_USER_ID, + permissions=[ + Permissions.SOURCE_COLLECTOR, + Permissions.SOURCE_COLLECTOR_FINAL_REVIEW + ] + ) @pytest.fixture def client(db_client_test) -> Generator[TestClient, None, None]: with TestClient(app) as c: app.dependency_overrides[get_access_info] = override_access_info + app.dependency_overrides[requires_final_review_permission] = override_access_info core: SourceCollectorCore = c.app.state.core # core.shutdown() yield c diff --git a/tests/test_automated/unit/security_manager/test_security_manager.py b/tests/test_automated/unit/security_manager/test_security_manager.py index f827cc1b..fd03fee5 100644 --- a/tests/test_automated/unit/security_manager/test_security_manager.py +++ b/tests/test_automated/unit/security_manager/test_security_manager.py @@ -49,7 +49,7 @@ def test_validate_token_failure(mock_get_secret_key, mock_jwt_decode): def test_check_access_success(mock_get_secret_key, mock_jwt_decode): sm = SecurityManager() - sm.check_access(VALID_TOKEN) # Should not raise any exceptions. + sm.check_access(VALID_TOKEN, Permissions.SOURCE_COLLECTOR) # Should not raise any exceptions. def test_check_access_failure(mock_get_secret_key, mock_jwt_decode): @@ -57,7 +57,7 @@ def test_check_access_failure(mock_get_secret_key, mock_jwt_decode): with patch(get_patch_path("SecurityManager.validate_token"), return_value=AccessInfo(user_id=1, permissions=[])): sm = SecurityManager() with pytest.raises(HTTPException) as exc_info: - sm.check_access(VALID_TOKEN) + sm.check_access(VALID_TOKEN, Permissions.SOURCE_COLLECTOR) assert exc_info.value.status_code == 403