diff --git a/src/api/endpoints/agencies/routes.py b/src/api/endpoints/agencies/routes.py index b0a756aa..bfbf456f 100644 --- a/src/api/endpoints/agencies/routes.py +++ b/src/api/endpoints/agencies/routes.py @@ -16,6 +16,8 @@ from src.api.endpoints.agencies.root.post.response import AgencyPostResponse from src.api.shared.models.message_response import MessageResponse from src.core.core import AsyncCore +from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_admin_access_info agencies_router = APIRouter(prefix="/agencies", tags=["Agencies"]) @@ -34,7 +36,9 @@ async def get_agencies( @agencies_router.post("") async def create_agency( request: AgencyPostRequest, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), + ) -> AgencyPostResponse: return await async_core.adb_client.run_query_builder( AddAgencyQueryBuilder(request=request) @@ -45,6 +49,7 @@ async def delete_agency( agency_id: int = Path( description="Agency ID to delete" ), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await async_core.adb_client.run_query_builder( @@ -58,6 +63,7 @@ async def update_agency( agency_id: int = Path( description="Agency ID to update" ), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await async_core.adb_client.run_query_builder( @@ -84,6 +90,7 @@ async def add_location_to_agency( location_id: int = Path( description="Location ID to add" ), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await async_core.adb_client.run_query_builder( @@ -99,6 +106,7 @@ async def remove_location_from_agency( location_id: int = Path( description="Location ID to remove" ), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await async_core.adb_client.run_query_builder( diff --git a/src/api/endpoints/annotate/routes.py b/src/api/endpoints/annotate/routes.py index 945de945..0af2afcb 100644 --- a/src/api/endpoints/annotate/routes.py +++ b/src/api/endpoints/annotate/routes.py @@ -17,7 +17,7 @@ from src.core.core import AsyncCore from src.db.queries.implementations.anonymous_session import MakeAnonymousSessionQueryBuilder from src.security.dtos.access_info import AccessInfo -from src.security.manager import get_access_info, get_standard_user_access_info +from src.security.manager import get_admin_access_info, get_standard_user_access_info annotate_router = APIRouter( prefix="/annotate", @@ -136,7 +136,7 @@ async def migrate_annotations_to_user( async def get_agency_suggestions( url_id: int, async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), location_id: int | None = Query(default=None) ) -> AgencyAnnotationResponseOuterInfo: return await async_core.adb_client.run_query_builder( diff --git a/src/api/endpoints/batch/routes.py b/src/api/endpoints/batch/routes.py index 4dfbbbfc..81abb7bc 100644 --- a/src/api/endpoints/batch/routes.py +++ b/src/api/endpoints/batch/routes.py @@ -12,7 +12,7 @@ from src.core.core import AsyncCore from src.db.models.materialized_views.batch_url_status.enums import BatchURLStatusViewEnum from src.security.dtos.access_info import AccessInfo -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info batch_router = APIRouter( prefix="/batch", @@ -36,7 +36,7 @@ async def get_batch_status( default=1 ), core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> GetBatchSummariesResponse: """ Get the status of recent batches @@ -52,7 +52,7 @@ async def get_batch_status( async def get_batch_info( batch_id: int = Path(description="The batch id"), core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> BatchSummary: return await core.get_batch_info(batch_id) @@ -64,7 +64,7 @@ async def get_urls_by_batch( default=1 ), core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> GetURLsByBatchResponse: return await core.get_urls_by_batch(batch_id, page=page) @@ -76,7 +76,7 @@ async def get_duplicates_by_batch( default=1 ), core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> GetDuplicatesByBatchResponse: return await core.get_duplicate_urls_by_batch(batch_id, page=page) @@ -84,7 +84,7 @@ async def get_duplicates_by_batch( async def get_batch_logs( batch_id: int = Path(description="The batch id"), async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> GetBatchLogsResponse: """ Retrieve the logs for a recent batch. @@ -96,6 +96,6 @@ async def get_batch_logs( async def abort_batch( batch_id: int = Path(description="The batch id"), async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> MessageResponse: return await async_core.abort_batch(batch_id) \ No newline at end of file diff --git a/src/api/endpoints/collector/routes.py b/src/api/endpoints/collector/routes.py index 4818dc63..0ab89261 100644 --- a/src/api/endpoints/collector/routes.py +++ b/src/api/endpoints/collector/routes.py @@ -10,7 +10,7 @@ from src.collectors.impl.example.dtos.input import ExampleInputDTO from src.collectors.enums import CollectorType from src.core.core import AsyncCore -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo from src.collectors.impl.ckan.dtos.input import CKANInputDTO from src.collectors.impl.muckrock.collectors.all_foia.dto import MuckrockAllFOIARequestsCollectorInputDTO @@ -27,7 +27,7 @@ async def start_example_collector( dto: ExampleInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the example collector @@ -42,7 +42,7 @@ async def start_example_collector( async def start_ckan_collector( dto: CKANInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the ckan collector @@ -57,7 +57,7 @@ async def start_ckan_collector( async def start_common_crawler_collector( dto: CommonCrawlerInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the common crawler collector @@ -72,7 +72,7 @@ async def start_common_crawler_collector( async def start_auto_googler_collector( dto: AutoGooglerInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the auto googler collector @@ -87,7 +87,7 @@ async def start_auto_googler_collector( async def start_muckrock_collector( dto: MuckrockSimpleSearchCollectorInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the muckrock collector @@ -102,7 +102,7 @@ async def start_muckrock_collector( async def start_muckrock_county_collector( dto: MuckrockCountySearchCollectorInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the muckrock county level collector @@ -117,7 +117,7 @@ async def start_muckrock_county_collector( async def start_muckrock_all_foia_collector( dto: MuckrockAllFOIARequestsCollectorInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> CollectorStartInfo: """ Start the muckrock collector for all FOIA requests @@ -132,7 +132,7 @@ async def start_muckrock_all_foia_collector( async def upload_manual_collector( dto: ManualBatchInputDTO, core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> ManualBatchResponseDTO: """ Uploads a manual "collector" with existing data diff --git a/src/api/endpoints/data_source/routes.py b/src/api/endpoints/data_source/routes.py index 25787b85..a657ac18 100644 --- a/src/api/endpoints/data_source/routes.py +++ b/src/api/endpoints/data_source/routes.py @@ -13,6 +13,8 @@ from src.api.endpoints.data_source.by_id.put.request import DataSourcePutRequest from src.api.shared.models.message_response import MessageResponse from src.core.core import AsyncCore +from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_admin_access_info data_sources_router = APIRouter( prefix="/data-sources", @@ -45,6 +47,7 @@ async def get_data_source_by_id( async def update_data_source( url_id: int , request: DataSourcePutRequest, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await check_is_data_source_url(url_id=url_id, adb_client=async_core.adb_client) @@ -70,6 +73,7 @@ async def get_data_source_agencies( async def add_agency_to_data_source( url_id: int, agency_id: int, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await add_data_source_agency_link( @@ -83,6 +87,7 @@ async def add_agency_to_data_source( async def remove_agency_from_data_source( url_id: int, agency_id: int, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await delete_data_source_agency_link( diff --git a/src/api/endpoints/locations/routes.py b/src/api/endpoints/locations/routes.py index 4a0ef096..c86f66b5 100644 --- a/src/api/endpoints/locations/routes.py +++ b/src/api/endpoints/locations/routes.py @@ -5,6 +5,8 @@ from src.api.endpoints.locations.post.request import AddLocationRequestModel from src.api.endpoints.locations.post.response import AddLocationResponseModel from src.core.core import AsyncCore +from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_admin_access_info location_url_router = APIRouter( prefix="/locations", @@ -15,6 +17,7 @@ @location_url_router.post("") async def create_location( request: AddLocationRequestModel, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> AddLocationResponseModel: return await async_core.adb_client.run_query_builder( diff --git a/src/api/endpoints/meta_url/routes.py b/src/api/endpoints/meta_url/routes.py index 82a36756..790fd519 100644 --- a/src/api/endpoints/meta_url/routes.py +++ b/src/api/endpoints/meta_url/routes.py @@ -12,6 +12,8 @@ from src.api.endpoints.meta_url.by_id.put.request import UpdateMetaURLRequest from src.api.shared.models.message_response import MessageResponse from src.core.core import AsyncCore +from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_admin_access_info meta_urls_router = APIRouter( prefix="/meta-urls", @@ -35,6 +37,7 @@ async def get_meta_urls( async def update_meta_url( url_id: int, request: UpdateMetaURLRequest, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await check_is_meta_url(url_id=url_id, adb_client=async_core.adb_client) @@ -61,6 +64,7 @@ async def get_meta_url_agencies( async def add_agency_to_meta_url( url_id: int, agency_id: int, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await add_meta_url_agency_link( @@ -74,6 +78,7 @@ async def add_agency_to_meta_url( async def remove_agency_from_meta_url( url_id: int, agency_id: int, + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> MessageResponse: await delete_meta_url_agency_link( diff --git a/src/api/endpoints/metrics/routes.py b/src/api/endpoints/metrics/routes.py index 59fa5906..06c09de3 100644 --- a/src/api/endpoints/metrics/routes.py +++ b/src/api/endpoints/metrics/routes.py @@ -10,7 +10,7 @@ from src.api.endpoints.metrics.dtos.get.urls.breakdown.pending import GetMetricsURLsBreakdownPendingResponseDTO from src.api.endpoints.metrics.dtos.get.urls.breakdown.submitted import GetMetricsURLsBreakdownSubmittedResponseDTO from src.core.core import AsyncCore -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo metrics_router = APIRouter( @@ -22,14 +22,14 @@ @metrics_router.get("/batches/aggregated") async def get_batches_aggregated_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsBatchesAggregatedResponseDTO: return await core.get_batches_aggregated_metrics() @metrics_router.get("/batches/breakdown") async def get_batches_breakdown_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), page: int = Query( description="The page number", default=1 @@ -40,34 +40,34 @@ async def get_batches_breakdown_metrics( @metrics_router.get("/urls/aggregate") async def get_urls_aggregated_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsURLsAggregatedResponseDTO: return await core.get_urls_aggregated_metrics() @metrics_router.get("/urls/aggregate/pending") async def get_urls_aggregated_pending_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsURLsAggregatedPendingResponseDTO: return await core.get_urls_aggregated_pending_metrics() @metrics_router.get("/urls/breakdown/submitted") async def get_urls_breakdown_submitted_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsURLsBreakdownSubmittedResponseDTO: return await core.get_urls_breakdown_submitted_metrics() @metrics_router.get("/urls/breakdown/pending") async def get_urls_breakdown_pending_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsURLsBreakdownPendingResponseDTO: return await core.get_urls_breakdown_pending_metrics() @metrics_router.get("/backlog") async def get_backlog_metrics( core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetMetricsBacklogResponseDTO: return await core.get_backlog_metrics() \ No newline at end of file diff --git a/src/api/endpoints/proposals/routes.py b/src/api/endpoints/proposals/routes.py index 147e0501..9259a341 100644 --- a/src/api/endpoints/proposals/routes.py +++ b/src/api/endpoints/proposals/routes.py @@ -18,14 +18,14 @@ from src.api.shared.models.message_response import MessageResponse from src.core.core import AsyncCore from src.security.dtos.access_info import AccessInfo -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info proposal_router = APIRouter(prefix="/proposal", tags=["Pending"]) @proposal_router.get("/agencies") async def get_pending_agencies( async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> ProposalAgencyGetOuterResponse: return await async_core.adb_client.run_query_builder( ProposalAgencyGetQueryBuilder(), @@ -37,7 +37,7 @@ async def approve_proposed_agency( proposed_agency_id: int = Path( description="Proposed agency ID to approve" ), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> ProposalAgencyApproveResponse: return await async_core.adb_client.run_query_builder( ProposalAgencyApproveQueryBuilder( @@ -53,7 +53,7 @@ async def reject_proposed_agency( proposed_agency_id: int = Path( description="Proposed agency ID to reject" ), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> ProposalAgencyRejectResponse: return await async_core.adb_client.run_query_builder( ProposalAgencyRejectQueryBuilder( diff --git a/src/api/endpoints/root.py b/src/api/endpoints/root.py index 03b05ed4..044c0a5f 100644 --- a/src/api/endpoints/root.py +++ b/src/api/endpoints/root.py @@ -1,6 +1,6 @@ from fastapi import APIRouter, Query, Depends -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo root_router = APIRouter(prefix="", tags=["Root"]) @@ -8,7 +8,7 @@ @root_router.get("/") async def root( test: str = Query(description="A test parameter"), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> dict[str, str]: """ A simple root endpoint for testing and pinging diff --git a/src/api/endpoints/search/routes.py b/src/api/endpoints/search/routes.py index 58b661e8..aa3c730b 100644 --- a/src/api/endpoints/search/routes.py +++ b/src/api/endpoints/search/routes.py @@ -8,7 +8,7 @@ from src.api.endpoints.search.dtos.response import SearchURLResponse from src.core.core import AsyncCore from src.db.models.impl.agency.enums import JurisdictionType -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo search_router = APIRouter(prefix="/search", tags=["Search"]) @@ -17,7 +17,7 @@ @search_router.get("/url") async def search_url( url: str = Query(description="The URL to search for"), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> SearchURLResponse: """ @@ -44,7 +44,7 @@ async def search_agency( description="The page to search for", default=1 ), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> list[AgencySearchResponse]: if query is None and location_id is None and jurisdiction_type is None: diff --git a/src/api/endpoints/submit/routes.py b/src/api/endpoints/submit/routes.py index dec7e2aa..b7e2344c 100644 --- a/src/api/endpoints/submit/routes.py +++ b/src/api/endpoints/submit/routes.py @@ -14,7 +14,7 @@ from src.api.endpoints.submit.url.queries.core import SubmitURLQueryBuilder from src.core.core import AsyncCore from src.security.dtos.access_info import AccessInfo -from src.security.manager import get_access_info, get_standard_user_access_info +from src.security.manager import get_admin_access_info, get_standard_user_access_info submit_router = APIRouter(prefix="/submit", tags=["Submit"]) @@ -23,7 +23,7 @@ ) async def submit_url( request: URLSubmissionRequest, - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_standard_user_access_info), async_core: AsyncCore = Depends(get_async_core), ) -> URLSubmissionResponse: return await async_core.adb_client.run_query_builder( diff --git a/src/api/endpoints/task/routes.py b/src/api/endpoints/task/routes.py index 23f52999..3bb039b7 100644 --- a/src/api/endpoints/task/routes.py +++ b/src/api/endpoints/task/routes.py @@ -9,7 +9,7 @@ from src.db.enums import TaskType from src.core.core import AsyncCore from src.core.enums import BatchStatus -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo task_router = APIRouter( @@ -34,7 +34,7 @@ async def get_tasks( default=None ), async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetTasksResponse: return await async_core.get_tasks( page=page, @@ -45,7 +45,7 @@ async def get_tasks( @task_router.get("/status") async def get_task_status( async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> GetTaskStatusResponseInfo: return await async_core.get_current_task_status() @@ -53,7 +53,7 @@ async def get_task_status( async def get_task_info( task_id: int = Path(description="The task id"), async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info) + access_info: AccessInfo = Depends(get_admin_access_info) ) -> TaskInfo: return await async_core.get_task_info(task_id) diff --git a/src/api/endpoints/url/routes.py b/src/api/endpoints/url/routes.py index 7d184e6e..77a0a749 100644 --- a/src/api/endpoints/url/routes.py +++ b/src/api/endpoints/url/routes.py @@ -6,7 +6,7 @@ from src.api.endpoints.url.get.dto import GetURLsResponseInfo from src.api.shared.models.message_response import MessageResponse from src.core.core import AsyncCore -from src.security.manager import get_access_info +from src.security.manager import get_admin_access_info from src.security.dtos.access_info import AccessInfo url_router = APIRouter( @@ -26,7 +26,7 @@ async def get_urls( default=False ), async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> GetURLsResponseInfo: result = await async_core.get_urls(page=page, errors=errors) return result @@ -50,6 +50,7 @@ async def get_url_screenshot( async def delete_url( url_id: int, async_core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_admin_access_info), ) -> MessageResponse: await async_core.adb_client.run_query_builder( DeleteURLQueryBuilder(url_id=url_id) diff --git a/src/security/manager.py b/src/security/manager.py index abeade07..8ec7996a 100644 --- a/src/security/manager.py +++ b/src/security/manager.py @@ -64,7 +64,7 @@ def check_access( oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") -def get_access_info( +def get_admin_access_info( token: Annotated[str, Depends(oauth2_scheme)] ) -> AccessInfo: return SecurityManager().check_access(token, Permissions.SOURCE_COLLECTOR) diff --git a/tests/automated/integration/conftest.py b/tests/automated/integration/conftest.py index 8a9a8569..c15ba98c 100644 --- a/tests/automated/integration/conftest.py +++ b/tests/automated/integration/conftest.py @@ -18,7 +18,7 @@ from src.db.models.impl.url.core.sqlalchemy import URL from src.security.dtos.access_info import AccessInfo from src.security.enums import Permissions -from src.security.manager import get_access_info, get_standard_user_access_info +from src.security.manager import get_admin_access_info, get_standard_user_access_info from tests.automated.integration.api._helpers.RequestValidator import RequestValidator from tests.helpers.api_test_helper import APITestHelper from tests.helpers.data_creator.core import DBDataCreator @@ -133,7 +133,7 @@ def override_access_info() -> AccessInfo: @pytest.fixture(scope="session") def client(disable_task_flags) -> Generator[TestClient, None, None]: with TestClient(app) as c: - app.dependency_overrides[get_access_info] = override_access_info + app.dependency_overrides[get_admin_access_info] = override_access_info app.dependency_overrides[get_standard_user_access_info] = override_access_info async_core: AsyncCore = c.app.state.async_core diff --git a/tests/automated/unit/security_manager/test_security_manager.py b/tests/automated/unit/security_manager/test_security_manager.py index ae58ed6e..42ae8e4d 100644 --- a/tests/automated/unit/security_manager/test_security_manager.py +++ b/tests/automated/unit/security_manager/test_security_manager.py @@ -6,7 +6,7 @@ from src.security.dtos.access_info import AccessInfo from src.security.enums import Permissions -from src.security.manager import SecurityManager, get_access_info +from src.security.manager import SecurityManager, get_admin_access_info SECRET_KEY = "test_secret_key" VALID_TOKEN = "valid_token" @@ -64,6 +64,6 @@ def test_check_access_failure(mock_get_secret_key, mock_jwt_decode): def test_get_access_info(mock_get_secret_key, mock_jwt_decode): - access_info = get_access_info(token=VALID_TOKEN) + access_info = get_admin_access_info(token=VALID_TOKEN) assert access_info.user_id == 1 assert Permissions.SOURCE_COLLECTOR in access_info.permissions