diff --git a/api/routes/annotate.py b/api/routes/annotate.py index 84ba00e4..95512a0b 100644 --- a/api/routes/annotate.py +++ b/api/routes/annotate.py @@ -4,10 +4,12 @@ from api.dependencies import get_async_core from core.AsyncCore import AsyncCore +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ URLAgencyAnnotationPostInfo +from core.DTOs.GetNextURLForAllAnnotationResponse import GetNextURLForAllAnnotationResponse from core.DTOs.RecordTypeAnnotationPostInfo import RecordTypeAnnotationPostInfo from core.DTOs.RelevanceAnnotationPostInfo import RelevanceAnnotationPostInfo from security_manager.SecurityManager import get_access_info, AccessInfo @@ -18,6 +20,11 @@ responses={404: {"description": "Not found"}}, ) +batch_query = Query( + description="The batch id of the next URL to get. " + "If not specified, defaults to first qualifying URL", + default=None +) @annotate_router.get("/relevance") async def get_next_url_for_relevance_annotation( @@ -40,10 +47,7 @@ async def annotate_url_for_relevance_and_get_next_url( url_id: int = Path(description="The URL id to annotate"), async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), + batch_id: Optional[int] = batch_query ) -> GetNextRelevanceAnnotationResponseOuterInfo: """ Post URL annotation and get next URL to annotate @@ -62,10 +66,7 @@ async def annotate_url_for_relevance_and_get_next_url( async def get_next_url_for_record_type_annotation( access_info: AccessInfo = Depends(get_access_info), async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), + batch_id: Optional[int] = batch_query ) -> GetNextRecordTypeAnnotationResponseOuterInfo: return await async_core.get_next_url_for_record_type_annotation( user_id=access_info.user_id, @@ -78,10 +79,7 @@ async def annotate_url_for_record_type_and_get_next_url( url_id: int = Path(description="The URL id to annotate"), async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), + batch_id: Optional[int] = batch_query ) -> GetNextRecordTypeAnnotationResponseOuterInfo: """ Post URL annotation and get next URL to annotate @@ -100,10 +98,7 @@ async def annotate_url_for_record_type_and_get_next_url( async def get_next_url_for_agency_annotation( access_info: AccessInfo = Depends(get_access_info), async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), + batch_id: Optional[int] = batch_query ) -> GetNextURLForAgencyAnnotationResponse: return await async_core.get_next_url_agency_for_annotation( user_id=access_info.user_id, @@ -116,10 +111,7 @@ async def annotate_url_for_agency_and_get_next_url( agency_annotation_post_info: URLAgencyAnnotationPostInfo, async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), + batch_id: Optional[int] = batch_query ) -> GetNextURLForAgencyAnnotationResponse: """ Post URL annotation and get next URL to annotate @@ -133,3 +125,33 @@ async def annotate_url_for_agency_and_get_next_url( user_id=access_info.user_id, batch_id=batch_id ) + +@annotate_router.get("/all") +async def get_next_url_for_all_annotations( + access_info: AccessInfo = Depends(get_access_info), + async_core: AsyncCore = Depends(get_async_core), + batch_id: Optional[int] = batch_query +) -> GetNextURLForAllAnnotationResponse: + return await async_core.get_next_url_for_all_annotations( + batch_id=batch_id + ) + +@annotate_router.post("/all/{url_id}") +async def annotate_url_for_all_annotations_and_get_next_url( + url_id: int, + all_annotation_post_info: AllAnnotationPostInfo, + async_core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_access_info), + batch_id: Optional[int] = batch_query +) -> GetNextURLForAllAnnotationResponse: + """ + Post URL annotation and get next URL to annotate + """ + await async_core.submit_url_for_all_annotations( + user_id=access_info.user_id, + url_id=url_id, + post_info=all_annotation_post_info + ) + return await async_core.get_next_url_for_all_annotations( + batch_id=batch_id + ) \ No newline at end of file diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py index eb68735c..46cd89db 100644 --- a/collector_db/AsyncDatabaseClient.py +++ b/collector_db/AsyncDatabaseClient.py @@ -3,11 +3,10 @@ from typing import Optional, Type, Any, List from fastapi import HTTPException -from sqlalchemy import select, exists, func, case, desc, Select, not_, and_, or_, update, Delete, Insert, asc, delete +from sqlalchemy import select, exists, func, case, desc, Select, not_, and_, update, asc, delete from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.orm import selectinload, joinedload, QueryableAttribute, aliased -from sqlalchemy.sql.functions import coalesce from starlette import status from collector_db.ConfigManager import ConfigManager @@ -23,18 +22,20 @@ from collector_db.DTOs.URLMapping import URLMapping from collector_db.StatementComposer import StatementComposer from collector_db.constants import PLACEHOLDER_AGENCY_NAME -from collector_db.enums import URLMetadataAttributeType, TaskType -from collector_db.helper_functions import get_postgres_connection_string +from collector_db.enums import TaskType from collector_db.models import URL, URLErrorInfo, URLHTMLContent, Base, \ RootURL, Task, TaskError, LinkTaskURL, Batch, Agency, AutomatedUrlAgencySuggestion, \ UserUrlAgencySuggestion, AutoRelevantSuggestion, AutoRecordTypeSuggestion, UserRelevantSuggestion, \ UserRecordTypeSuggestion, ReviewingUserURL, URLOptionalDataSourceMetadata, ConfirmedURLAgency, Duplicate, Log from collector_manager.enums import URLStatus, CollectorType +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ GetNextURLForAgencyAgencyInfo, GetNextURLForAgencyAnnotationInnerResponse +from core.DTOs.GetNextURLForAllAnnotationResponse import GetNextURLForAllAnnotationResponse, \ + GetNextURLForAllAnnotationInnerResponse from core.DTOs.GetNextURLForFinalReviewResponse import GetNextURLForFinalReviewResponse, FinalReviewAnnotationInfo, \ FinalReviewOptionalMetadata from core.DTOs.GetTasksResponse import GetTasksResponse, GetTasksResponseTaskInfo @@ -129,7 +130,6 @@ async def get_next_url_for_user_annotation( session: AsyncSession, user_suggestion_model_to_exclude: UserSuggestionModel, auto_suggestion_relationship: QueryableAttribute, - user_id: int, batch_id: Optional[int], check_if_annotated_not_relevant: bool = False ) -> URL: @@ -140,14 +140,7 @@ async def get_next_url_for_user_annotation( .where(URL.outcome == URLStatus.PENDING.value) # URL must not have user suggestion .where( - not_( - exists( - select(user_suggestion_model_to_exclude) - .where( - user_suggestion_model_to_exclude.url_id == URL.id, - ) - ) - ) + StatementComposer.user_suggestion_not_exists(user_suggestion_model_to_exclude) ) ) @@ -213,7 +206,6 @@ async def get_next_url_for_relevance_annotation( session, user_suggestion_model_to_exclude=UserRelevantSuggestion, auto_suggestion_relationship=URL.auto_relevant_suggestion, - user_id=user_id, batch_id=batch_id ) if url is None: @@ -254,7 +246,6 @@ async def get_next_url_for_record_type_annotation( session, user_suggestion_model_to_exclude=UserRecordTypeSuggestion, auto_suggestion_relationship=URL.auto_record_type_suggestion, - user_id=user_id, batch_id=batch_id, check_if_annotated_not_relevant=True ) @@ -823,9 +814,7 @@ async def get_next_url_agency_for_annotation( select(URL.id, URL.url) # Must not have confirmed agencies .where( - and_( - URL.outcome == URLStatus.PENDING.value - ) + URL.outcome == URLStatus.PENDING.value ) ) @@ -838,9 +827,7 @@ async def get_next_url_agency_for_annotation( .where( ~exists( select(UserUrlAgencySuggestion). - where( - UserUrlAgencySuggestion.url_id == URL.id - ). + where(UserUrlAgencySuggestion.url_id == URL.id). correlate(URL) ) ) @@ -885,37 +872,8 @@ async def get_next_url_agency_for_annotation( result = results[0] url_id = result[0] url = result[1] - # Get relevant autosuggestions and agency info, if an associated agency exists - statement = ( - select( - AutomatedUrlAgencySuggestion.agency_id, - AutomatedUrlAgencySuggestion.is_unknown, - Agency.name, - Agency.state, - Agency.county, - Agency.locality - ) - .join(Agency, isouter=True) - .where(AutomatedUrlAgencySuggestion.url_id == url_id) - ) - raw_autosuggestions = await session.execute(statement) - autosuggestions = raw_autosuggestions.all() - agency_suggestions = [] - for autosuggestion in autosuggestions: - agency_id = autosuggestion[0] - is_unknown = autosuggestion[1] - name = autosuggestion[2] - state = autosuggestion[3] - county = autosuggestion[4] - locality = autosuggestion[5] - agency_suggestions.append(GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.AUTO_SUGGESTION if not is_unknown else SuggestionType.UNKNOWN, - pdap_agency_id=agency_id, - agency_name=name, - state=state, - county=county, - locality=locality - )) + + agency_suggestions = await self.get_agency_suggestions(session, url_id=url_id) # Get HTML content info html_content_infos = await self.get_html_content_info(url_id) @@ -1626,5 +1584,141 @@ async def delete_old_logs(self, session): ) await session.execute(statement) + async def get_agency_suggestions(self, session, url_id: int) -> List[GetNextURLForAgencyAgencyInfo]: + # Get relevant autosuggestions and agency info, if an associated agency exists + + statement = ( + select( + AutomatedUrlAgencySuggestion.agency_id, + AutomatedUrlAgencySuggestion.is_unknown, + Agency.name, + Agency.state, + Agency.county, + Agency.locality + ) + .join(Agency, isouter=True) + .where(AutomatedUrlAgencySuggestion.url_id == url_id) + ) + raw_autosuggestions = await session.execute(statement) + autosuggestions = raw_autosuggestions.all() + agency_suggestions = [] + for autosuggestion in autosuggestions: + agency_id = autosuggestion[0] + is_unknown = autosuggestion[1] + name = autosuggestion[2] + state = autosuggestion[3] + county = autosuggestion[4] + locality = autosuggestion[5] + agency_suggestions.append(GetNextURLForAgencyAgencyInfo( + suggestion_type=SuggestionType.AUTO_SUGGESTION if not is_unknown else SuggestionType.UNKNOWN, + pdap_agency_id=agency_id, + agency_name=name, + state=state, + county=county, + locality=locality + )) + return agency_suggestions + + @session_manager + async def get_next_url_for_all_annotations(self, session, batch_id: Optional[int] = None) -> GetNextURLForAllAnnotationResponse: + query = ( + Select(URL) + .where( + and_( + URL.outcome == URLStatus.PENDING.value, + StatementComposer.user_suggestion_not_exists(UserUrlAgencySuggestion), + StatementComposer.user_suggestion_not_exists(UserRecordTypeSuggestion), + StatementComposer.user_suggestion_not_exists(UserRelevantSuggestion), + ) + ) + ) + if batch_id is not None: + query = query.where(URL.batch_id == batch_id) + + load_options = [ + URL.html_content, + URL.automated_agency_suggestions, + URL.auto_relevant_suggestion, + URL.auto_record_type_suggestion + ] + select_in_loads = [selectinload(load_option) for load_option in load_options] + + # Add load options + query = query.options( + *select_in_loads + ) + + query = query.order_by(URL.id.asc()).limit(1) + raw_results = await session.execute(query) + url = raw_results.scalars().one_or_none() + if url is None: + return GetNextURLForAllAnnotationResponse( + next_annotation=None + ) + + html_response_info = DTOConverter.html_content_list_to_html_response_info( + url.html_content + ) + + if url.auto_relevant_suggestion is not None: + auto_relevant = url.auto_relevant_suggestion.relevant + else: + auto_relevant = None + + if url.auto_record_type_suggestion is not None: + auto_record_type = url.auto_record_type_suggestion.record_type + else: + auto_record_type = None + + agency_suggestions = await self.get_agency_suggestions(session, url_id=url.id) + + return GetNextURLForAllAnnotationResponse( + next_annotation=GetNextURLForAllAnnotationInnerResponse( + url_id=url.id, + url=url.url, + html_info=html_response_info, + suggested_relevant=auto_relevant, + suggested_record_type=auto_record_type, + agency_suggestions=agency_suggestions + ) + ) + + @session_manager + async def add_all_annotations_to_url( + self, + session, + user_id: int, + url_id: int, + post_info: AllAnnotationPostInfo + ): + + # Add relevant annotation + relevant_suggestion = UserRelevantSuggestion( + url_id=url_id, + user_id=user_id, + relevant=post_info.is_relevant + ) + session.add(relevant_suggestion) + + # If not relevant, do nothing else + if not post_info.is_relevant: + return + + record_type_suggestion = UserRecordTypeSuggestion( + url_id=url_id, + user_id=user_id, + record_type=post_info.record_type.value + ) + session.add(record_type_suggestion) + + agency_suggestion = UserUrlAgencySuggestion( + url_id=url_id, + user_id=user_id, + agency_id=post_info.agency.suggested_agency, + is_new=post_info.agency.is_new + ) + session.add(agency_suggestion) + + diff --git a/collector_db/StatementComposer.py b/collector_db/StatementComposer.py index e25ba5d4..ca66f6ba 100644 --- a/collector_db/StatementComposer.py +++ b/collector_db/StatementComposer.py @@ -1,11 +1,11 @@ from typing import Any -from sqlalchemy import Select, select, exists, Table, func, Subquery, and_ +from sqlalchemy import Select, select, exists, Table, func, Subquery, and_, not_, ColumnElement from sqlalchemy.orm import aliased from collector_db.enums import URLMetadataAttributeType, ValidationStatus, TaskType from collector_db.models import URL, URLHTMLContent, AutomatedUrlAgencySuggestion, URLOptionalDataSourceMetadata, Batch, \ - ConfirmedURLAgency, LinkTaskURL, Task + ConfirmedURLAgency, LinkTaskURL, Task, UserUrlAgencySuggestion, UserRecordTypeSuggestion, UserRelevantSuggestion from collector_manager.enums import URLStatus, CollectorType from core.enums import BatchStatus @@ -94,4 +94,24 @@ def pending_urls_missing_miscellaneous_metadata_query() -> Select: Batch ) - return query \ No newline at end of file + return query + + + @staticmethod + def user_suggestion_not_exists( + model_to_exclude: UserUrlAgencySuggestion or + UserRecordTypeSuggestion or + UserRelevantSuggestion + ) -> ColumnElement[bool]: + # + + subquery = not_( + exists( + select(model_to_exclude) + .where( + model_to_exclude.url_id == URL.id, + ) + ) + ) + + return subquery \ No newline at end of file diff --git a/core/AsyncCore.py b/core/AsyncCore.py index d436d3c9..92f097db 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -8,6 +8,7 @@ from collector_db.enums import TaskType from collector_manager.AsyncCollectorManager import AsyncCollectorManager from collector_manager.enums import CollectorType +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo from core.DTOs.CollectorStartInfo import CollectorStartInfo from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse @@ -17,6 +18,7 @@ from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ URLAgencyAnnotationPostInfo +from core.DTOs.GetNextURLForAllAnnotationResponse import GetNextURLForAllAnnotationResponse from core.DTOs.GetTasksResponse import GetTasksResponse from core.DTOs.GetURLsByBatchResponse import GetURLsByBatchResponse from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo @@ -227,6 +229,26 @@ async def get_next_source_for_review( batch_id=batch_id ) + async def get_next_url_for_all_annotations( + self, + batch_id: Optional[int] + ) -> GetNextURLForAllAnnotationResponse: + return await self.adb_client.get_next_url_for_all_annotations( + batch_id=batch_id + ) + + async def submit_url_for_all_annotations( + self, + user_id: int, + url_id: int, + post_info: AllAnnotationPostInfo + ): + await self.adb_client.add_all_annotations_to_url( + user_id=user_id, + url_id=url_id, + post_info=post_info + ) + async def approve_url( self, approval_info: FinalReviewApprovalInfo, diff --git a/core/DTOs/AllAnnotationPostInfo.py b/core/DTOs/AllAnnotationPostInfo.py new file mode 100644 index 00000000..a462b40b --- /dev/null +++ b/core/DTOs/AllAnnotationPostInfo.py @@ -0,0 +1,35 @@ +from http import HTTPStatus +from typing import Optional + +from fastapi import HTTPException +from pydantic import BaseModel, model_validator + +from core.DTOs.GetNextURLForAgencyAnnotationResponse import URLAgencyAnnotationPostInfo +from core.enums import RecordType +from core.exceptions import FailedValidationException + + +class AllAnnotationPostInfo(BaseModel): + is_relevant: bool + record_type: Optional[RecordType] = None + agency: Optional[URLAgencyAnnotationPostInfo] = None + + @model_validator(mode="before") + def allow_record_type_and_agency_only_if_relevant(cls, values): + is_relevant = values.get("is_relevant") + record_type = values.get("record_type") + agency = values.get("agency") + + if not is_relevant: + if record_type is not None: + raise FailedValidationException("record_type must be None if is_relevant is False") + + if agency is not None: + raise FailedValidationException("agency must be None if is_relevant is False") + return values + # Similarly, if relevant, record_type and agency must be provided + if record_type is None: + raise FailedValidationException("record_type must be provided if is_relevant is True") + if agency is None: + raise FailedValidationException("agency must be provided if is_relevant is True") + return values \ No newline at end of file diff --git a/core/DTOs/GetNextRecordTypeAnnotationResponseInfo.py b/core/DTOs/GetNextRecordTypeAnnotationResponseInfo.py index 783b5516..4280e00d 100644 --- a/core/DTOs/GetNextRecordTypeAnnotationResponseInfo.py +++ b/core/DTOs/GetNextRecordTypeAnnotationResponseInfo.py @@ -12,7 +12,7 @@ class GetNextRecordTypeAnnotationResponseInfo(BaseModel): title="Information about the URL" ) suggested_record_type: Optional[RecordType] = Field( - title="Whether the auto-labeler identified the URL as relevant or not" + title="What record type, if any, the auto-labeler identified the URL as" ) html_info: ResponseHTMLInfo = Field( title="HTML information about the URL" diff --git a/core/DTOs/GetNextURLForAllAnnotationResponse.py b/core/DTOs/GetNextURLForAllAnnotationResponse.py new file mode 100644 index 00000000..f4fa4bb8 --- /dev/null +++ b/core/DTOs/GetNextURLForAllAnnotationResponse.py @@ -0,0 +1,24 @@ +from typing import Optional + +from pydantic import Field, BaseModel + +from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAgencyInfo +from core.enums import RecordType +from html_tag_collector.DataClassTags import ResponseHTMLInfo + + +class GetNextURLForAllAnnotationInnerResponse(BaseModel): + url_id: int + url: str + html_info: ResponseHTMLInfo + agency_suggestions: Optional[list[GetNextURLForAgencyAgencyInfo]] + suggested_relevant: Optional[bool] = Field( + title="Whether the auto-labeler identified the URL as relevant or not" + ) + suggested_record_type: Optional[RecordType] = Field( + title="What record type, if any, the auto-labeler identified the URL as" + ) + + +class GetNextURLForAllAnnotationResponse(BaseModel): + next_annotation: Optional[GetNextURLForAllAnnotationInnerResponse] \ No newline at end of file diff --git a/core/DTOs/GetNextURLForAnnotationResponse.py b/core/DTOs/GetNextURLForAnnotationResponse.py deleted file mode 100644 index b4bc1087..00000000 --- a/core/DTOs/GetNextURLForAnnotationResponse.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.DTOs.AnnotationRequestInfo import AnnotationRequestInfo - - -class GetNextURLForAnnotationResponse(BaseModel): - next_annotation: Optional[AnnotationRequestInfo] = None diff --git a/core/exceptions.py b/core/exceptions.py index d9685245..e3e93e55 100644 --- a/core/exceptions.py +++ b/core/exceptions.py @@ -1,3 +1,8 @@ +from http import HTTPStatus + +from fastapi import HTTPException + + class InvalidPreprocessorError(Exception): pass @@ -8,3 +13,8 @@ class MuckrockAPIError(Exception): class MatchAgencyError(Exception): pass + + +class FailedValidationException(HTTPException): + def __init__(self, detail: str): + super().__init__(status_code=HTTPStatus.BAD_REQUEST, detail=detail) \ No newline at end of file diff --git a/tests/test_automated/integration/api/helpers/RequestValidator.py b/tests/test_automated/integration/api/helpers/RequestValidator.py index 4a12bb0e..28e4b4a3 100644 --- a/tests/test_automated/integration/api/helpers/RequestValidator.py +++ b/tests/test_automated/integration/api/helpers/RequestValidator.py @@ -10,6 +10,7 @@ from collector_db.enums import TaskType from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO from collector_manager.enums import CollectorType +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo, FinalReviewBaseInfo from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse @@ -18,6 +19,7 @@ from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ URLAgencyAnnotationPostInfo +from core.DTOs.GetNextURLForAllAnnotationResponse import GetNextURLForAllAnnotationResponse from core.DTOs.GetNextURLForFinalReviewResponse import GetNextURLForFinalReviewOuterResponse from core.DTOs.GetTasksResponse import GetTasksResponse from core.DTOs.GetURLsByBatchResponse import GetURLsByBatchResponse @@ -294,4 +296,37 @@ async def get_current_task_status(self) -> GetTaskStatusResponseInfo: data = self.get( url=f"/task/status" ) - return GetTaskStatusResponseInfo(**data) \ No newline at end of file + return GetTaskStatusResponseInfo(**data) + + async def get_next_url_for_all_annotations( + self, + batch_id: Optional[int] = None + ) -> GetNextURLForAllAnnotationResponse: + params = {} + update_if_not_none( + target=params, + source={"batch_id": batch_id} + ) + data = self.get( + url=f"/annotate/all", + params=params + ) + return GetNextURLForAllAnnotationResponse(**data) + + async def post_all_annotations_and_get_next( + self, + url_id: int, + all_annotations_post_info: AllAnnotationPostInfo, + batch_id: Optional[int] = None, + ) -> GetNextURLForAllAnnotationResponse: + params = {} + update_if_not_none( + target=params, + source={"batch_id": batch_id} + ) + data = self.post( + url=f"/annotate/all/{url_id}", + params=params, + json=all_annotations_post_info.model_dump(mode='json') + ) + return GetNextURLForAllAnnotationResponse(**data) \ No newline at end of file diff --git a/tests/test_automated/integration/api/test_annotate.py b/tests/test_automated/integration/api/test_annotate.py index d5b6dade..a03540a1 100644 --- a/tests/test_automated/integration/api/test_annotate.py +++ b/tests/test_automated/integration/api/test_annotate.py @@ -1,16 +1,20 @@ +from http import HTTPStatus import pytest from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo from collector_db.DTOs.URLMapping import URLMapping from collector_db.models import UserUrlAgencySuggestion, UserRelevantSuggestion, UserRecordTypeSuggestion +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import URLAgencyAnnotationPostInfo from core.DTOs.RecordTypeAnnotationPostInfo import RecordTypeAnnotationPostInfo from core.DTOs.RelevanceAnnotationPostInfo import RelevanceAnnotationPostInfo from core.enums import RecordType, SuggestionType -from tests.helpers.complex_test_data_functions import AnnotateAgencySetupInfo, setup_for_annotate_agency +from core.exceptions import FailedValidationException +from tests.helpers.complex_test_data_functions import AnnotateAgencySetupInfo, setup_for_annotate_agency, \ + setup_for_get_next_url_for_final_review from html_tag_collector.DataClassTags import ResponseHTMLInfo from tests.helpers.DBDataCreator import BatchURLCreationInfo from tests.test_automated.integration.api.conftest import MOCK_USER_ID @@ -514,3 +518,135 @@ async def test_annotate_agency_submit_new(api_test_helper): assert len(all_manual_suggestions) == 1 assert all_manual_suggestions[0].is_new +@pytest.mark.asyncio +async def test_annotate_all(api_test_helper): + """ + Test the happy path workflow for the all-annotations endpoint + The user should be able to get a valid URL (filtering on batch id if needed), + submit a full annotation, and receive another URL + """ + ath = api_test_helper + adb_client = ath.adb_client() + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_1 = setup_info_1.url_mapping + setup_info_2 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_2 = setup_info_2.url_mapping + + # First, get a valid URL to annotate + get_response_1 = await ath.request_validator.get_next_url_for_all_annotations() + + # Apply the second batch id as a filter and see that a different URL is returned + get_response_2 = await ath.request_validator.get_next_url_for_all_annotations( + batch_id=setup_info_2.batch_id + ) + + assert get_response_1.next_annotation.url_id != get_response_2.next_annotation.url_id + + # Annotate the first and submit + agency_id = await ath.db_data_creator.agency() + post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_1.url_id, + all_annotations_post_info=AllAnnotationPostInfo( + is_relevant=True, + record_type=RecordType.ACCIDENT_REPORTS, + agency=URLAgencyAnnotationPostInfo( + is_new=False, + suggested_agency=agency_id + ) + ) + ) + assert post_response_1.next_annotation is not None + + # Confirm the second is received + assert post_response_1.next_annotation.url_id == url_mapping_2.url_id + + # Upon submitting the second, confirm that no more URLs are returned through either POST or GET + post_response_2 = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_2.url_id, + all_annotations_post_info=AllAnnotationPostInfo( + is_relevant=False, + ) + ) + assert post_response_2.next_annotation is None + + get_response_3 = await ath.request_validator.get_next_url_for_all_annotations() + assert get_response_3.next_annotation is None + + + # Check that all annotations are present in the database + + # Should be two relevance annotations, one True and one False + all_relevance_suggestions = await adb_client.get_all(UserRelevantSuggestion) + assert len(all_relevance_suggestions) == 2 + assert all_relevance_suggestions[0].relevant == True + assert all_relevance_suggestions[1].relevant == False + + # Should be one agency + all_agency_suggestions = await adb_client.get_all(UserUrlAgencySuggestion) + assert len(all_agency_suggestions) == 1 + assert all_agency_suggestions[0].is_new == False + assert all_agency_suggestions[0].agency_id == agency_id + + # Should be one record type + all_record_type_suggestions = await adb_client.get_all(UserRecordTypeSuggestion) + assert len(all_record_type_suggestions) == 1 + assert all_record_type_suggestions[0].record_type == RecordType.ACCIDENT_REPORTS.value + +@pytest.mark.asyncio +async def test_annotate_all_post_batch_filtering(api_test_helper): + """ + Batch filtering should also work when posting annotations + """ + ath = api_test_helper + adb_client = ath.adb_client() + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_1 = setup_info_1.url_mapping + setup_info_2 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + setup_info_3 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_3 = setup_info_3.url_mapping + + # Submit the first annotation, using the third batch id, and receive the third URL + post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_1.url_id, + batch_id=setup_info_3.batch_id, + all_annotations_post_info=AllAnnotationPostInfo( + is_relevant=True, + record_type=RecordType.ACCIDENT_REPORTS, + agency=URLAgencyAnnotationPostInfo( + is_new=True + ) + ) + ) + + assert post_response_1.next_annotation.url_id == url_mapping_3.url_id + + +@pytest.mark.asyncio +async def test_annotate_all_validation_error(api_test_helper): + """ + Validation errors in the PostInfo DTO should result in a 400 BAD REQUEST response + """ + ath = api_test_helper + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_1 = setup_info_1.url_mapping + + with pytest.raises(FailedValidationException) as e: + response = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_1.url_id, + all_annotations_post_info=AllAnnotationPostInfo( + is_relevant=False, + record_type=RecordType.ACCIDENT_REPORTS + ) + ) diff --git a/tests/test_automated/unit/dto/__init__.py b/tests/test_automated/unit/dto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_automated/unit/dto/test_all_annotation_post_info.py b/tests/test_automated/unit/dto/test_all_annotation_post_info.py new file mode 100644 index 00000000..3e5cbab4 --- /dev/null +++ b/tests/test_automated/unit/dto/test_all_annotation_post_info.py @@ -0,0 +1,37 @@ +import pytest +from pydantic import ValidationError + +from core.DTOs.AllAnnotationPostInfo import AllAnnotationPostInfo +from core.enums import RecordType +from core.exceptions import FailedValidationException + +# Mock values to pass +mock_record_type = RecordType.ARREST_RECORDS.value # replace with valid RecordType if Enum +mock_agency = {"is_new": False, "suggested_agency": 1} # replace with a valid dict for the URLAgencyAnnotationPostInfo model + +@pytest.mark.parametrize( + "is_relevant, record_type, agency, should_raise", + [ + (True, mock_record_type, mock_agency, False), # valid + (True, None, mock_agency, True), # missing record_type + (True, mock_record_type, None, True), # missing agency + (True, None, None, True), # missing both + (False, None, None, False), # valid + (False, mock_record_type, None, True), # record_type present + (False, None, mock_agency, True), # agency present + (False, mock_record_type, mock_agency, True), # both present + ] +) +def test_all_annotation_post_info_validation(is_relevant, record_type, agency, should_raise): + data = { + "is_relevant": is_relevant, + "record_type": record_type, + "agency": agency + } + + if should_raise: + with pytest.raises(FailedValidationException): + AllAnnotationPostInfo(**data) + else: + model = AllAnnotationPostInfo(**data) + assert model.is_relevant == is_relevant