diff --git a/src/core/tasks/base/operator.py b/src/core/tasks/base/operator.py index ba7a3d3a..ce0ee860 100644 --- a/src/core/tasks/base/operator.py +++ b/src/core/tasks/base/operator.py @@ -45,7 +45,7 @@ async def run_info(self, outcome: TaskOperatorOutcome, message: str) -> TaskOper @abstractmethod - async def inner_task_logic(self): + async def inner_task_logic(self) -> None: raise NotImplementedError async def handle_task_error(self, e): diff --git a/src/core/tasks/url/loader.py b/src/core/tasks/url/loader.py index 24986a85..50ff8920 100644 --- a/src/core/tasks/url/loader.py +++ b/src/core/tasks/url/loader.py @@ -4,6 +4,7 @@ from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator +from src.core.tasks.url.operators.agency_identification.subtasks.loader import AgencyIdentificationSubtaskLoader from src.core.tasks.url.operators.auto_relevant.core import URLAutoRelevantTaskOperator from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator @@ -59,8 +60,10 @@ async def get_url_record_type_task_operator(self): async def get_agency_identification_task_operator(self): operator = AgencyIdentificationTaskOperator( adb_client=self.adb_client, - pdap_client=self.pdap_client, - muckrock_api_interface=self.muckrock_api_interface + loader=AgencyIdentificationSubtaskLoader( + pdap_client=self.pdap_client, + muckrock_api_interface=self.muckrock_api_interface + ) ) return operator @@ -104,7 +107,7 @@ async def get_task_operators(self) -> list[URLTaskOperatorBase]: await self.get_url_duplicate_task_operator(), await self.get_url_404_probe_task_operator(), await self.get_url_record_type_task_operator(), - # await self.get_agency_identification_task_operator(), + await self.get_agency_identification_task_operator(), await self.get_url_miscellaneous_metadata_task_operator(), await self.get_submit_approved_url_task_operator(), await self.get_url_auto_relevance_task_operator() diff --git a/src/core/tasks/url/operators/agency_identification/core.py b/src/core/tasks/url/operators/agency_identification/core.py index 993807fd..759cfe81 100644 --- a/src/core/tasks/url/operators/agency_identification/core.py +++ b/src/core/tasks/url/operators/agency_identification/core.py @@ -1,75 +1,87 @@ -from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface +from src.collectors.enums import CollectorType +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.output import GetAgencySuggestionsOutput from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo from src.core.tasks.url.operators.agency_identification.dtos.tdo import AgencyIdentificationTDO +from src.core.tasks.url.operators.agency_identification.subtasks.impl.base import AgencyIdentificationSubtaskBase +from src.core.tasks.url.operators.agency_identification.subtasks.loader import AgencyIdentificationSubtaskLoader +from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.db.client.async_ import AsyncDatabaseClient -from src.db.models.instantiations.url.error_info.pydantic import URLErrorPydanticInfo from src.db.enums import TaskType -from src.collectors.enums import CollectorType -from src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.core.tasks.url.subtasks.agency_identification.auto_googler import AutoGooglerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.ckan import CKANAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.common_crawler import CommonCrawlerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.muckrock import MuckrockAgencyIdentificationSubtask -from src.core.enums import SuggestionType -from src.external.pdap.client import PDAPClient - +from src.db.models.instantiations.url.error_info.pydantic import URLErrorPydanticInfo -# TODO: Validate with Manual Tests class AgencyIdentificationTaskOperator(URLTaskOperatorBase): def __init__( self, adb_client: AsyncDatabaseClient, - pdap_client: PDAPClient, - muckrock_api_interface: MuckrockAPIInterface, + loader: AgencyIdentificationSubtaskLoader, ): super().__init__(adb_client) - self.pdap_client = pdap_client - self.muckrock_api_interface = muckrock_api_interface + self.loader = loader @property - def task_type(self): + def task_type(self) -> TaskType: return TaskType.AGENCY_IDENTIFICATION - async def meets_task_prerequisites(self): + async def meets_task_prerequisites(self) -> bool: has_urls_without_agency_suggestions = await self.adb_client.has_urls_without_agency_suggestions() return has_urls_without_agency_suggestions - async def get_pending_urls_without_agency_identification(self): + async def get_pending_urls_without_agency_identification(self) -> list[AgencyIdentificationTDO]: return await self.adb_client.get_urls_without_agency_suggestions() - async def get_muckrock_subtask(self): - return MuckrockAgencyIdentificationSubtask( - muckrock_api_interface=self.muckrock_api_interface, - pdap_client=self.pdap_client - ) - - async def get_subtask(self, collector_type: CollectorType): - match collector_type: - case CollectorType.MUCKROCK_SIMPLE_SEARCH: - return await self.get_muckrock_subtask() - case CollectorType.MUCKROCK_COUNTY_SEARCH: - return await self.get_muckrock_subtask() - case CollectorType.MUCKROCK_ALL_SEARCH: - return await self.get_muckrock_subtask() - case CollectorType.AUTO_GOOGLER: - return AutoGooglerAgencyIdentificationSubtask() - case CollectorType.COMMON_CRAWLER: - return CommonCrawlerAgencyIdentificationSubtask() - case CollectorType.CKAN: - return CKANAgencyIdentificationSubtask( - pdap_client=self.pdap_client - ) - return None + async def get_subtask( + self, + collector_type: CollectorType + ) -> AgencyIdentificationSubtaskBase: + """Get subtask based on collector type.""" + return await self.loader.load_subtask(collector_type) @staticmethod - async def run_subtask(subtask, url_id, collector_metadata) -> list[URLAgencySuggestionInfo]: - return await subtask.run(url_id=url_id, collector_metadata=collector_metadata) + async def run_subtask( + subtask: AgencyIdentificationSubtaskBase, + url_id: int, + collector_metadata: dict | None + ) -> list[URLAgencySuggestionInfo]: + return await subtask.run( + url_id=url_id, + collector_metadata=collector_metadata + ) - async def inner_task_logic(self): + async def inner_task_logic(self) -> None: tdos: list[AgencyIdentificationTDO] = await self.get_pending_urls_without_agency_identification() await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) + output = await self._get_agency_suggestions(tdos) + + await self._process_agency_suggestions(output.agency_suggestions) + await self.adb_client.add_url_error_infos(output.error_infos) + + async def _process_agency_suggestions( + self, + suggestions: list[URLAgencySuggestionInfo] + ) -> None: + non_unknown_agency_suggestions = [ + suggestion for suggestion in suggestions + if suggestion.suggestion_type != SuggestionType.UNKNOWN + ] + await self.adb_client.upsert_new_agencies(non_unknown_agency_suggestions) + confirmed_suggestions = [ + suggestion for suggestion in suggestions + if suggestion.suggestion_type == SuggestionType.CONFIRMED + ] + await self.adb_client.add_confirmed_agency_url_links(confirmed_suggestions) + non_confirmed_suggestions = [ + suggestion for suggestion in suggestions + if suggestion.suggestion_type != SuggestionType.CONFIRMED + ] + await self.adb_client.add_agency_auto_suggestions(non_confirmed_suggestions) + + async def _get_agency_suggestions( + self, + tdos: list[AgencyIdentificationTDO] + ) -> GetAgencySuggestionsOutput: error_infos = [] all_agency_suggestions = [] for tdo in tdos: @@ -88,13 +100,10 @@ async def inner_task_logic(self): error=str(e), ) error_infos.append(error_info) - - non_unknown_agency_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type != SuggestionType.UNKNOWN] - await self.adb_client.upsert_new_agencies(non_unknown_agency_suggestions) - confirmed_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type == SuggestionType.CONFIRMED] - await self.adb_client.add_confirmed_agency_url_links(confirmed_suggestions) - non_confirmed_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type != SuggestionType.CONFIRMED] - await self.adb_client.add_agency_auto_suggestions(non_confirmed_suggestions) - await self.adb_client.add_url_error_infos(error_infos) + output = GetAgencySuggestionsOutput( + agency_suggestions=all_agency_suggestions, + error_infos=error_infos + ) + return output diff --git a/src/core/tasks/url/operators/agency_identification/dtos/output.py b/src/core/tasks/url/operators/agency_identification/dtos/output.py new file mode 100644 index 00000000..46f3aa97 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/dtos/output.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.db.models.instantiations.url.error_info.pydantic import URLErrorPydanticInfo + + +class GetAgencySuggestionsOutput(BaseModel): + error_infos: list[URLErrorPydanticInfo] + agency_suggestions: list[URLAgencySuggestionInfo] \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py b/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py index c0ea08f4..f42ecfc2 100644 --- a/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py +++ b/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py @@ -7,7 +7,7 @@ class URLAgencySuggestionInfo(BaseModel): url_id: int - suggestion_type: SuggestionType + suggestion_type: SuggestionType = SuggestionType.UNKNOWN pdap_agency_id: Optional[int] = None agency_name: Optional[str] = None state: Optional[str] = None diff --git a/src/core/tasks/url/operators/agency_identification/dtos/tdo.py b/src/core/tasks/url/operators/agency_identification/dtos/tdo.py index 70ff1ae5..35f22844 100644 --- a/src/core/tasks/url/operators/agency_identification/dtos/tdo.py +++ b/src/core/tasks/url/operators/agency_identification/dtos/tdo.py @@ -8,4 +8,4 @@ class AgencyIdentificationTDO(BaseModel): url_id: int collector_metadata: Optional[dict] = None - collector_type: CollectorType + collector_type: CollectorType | None diff --git a/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py b/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py index 0c814cb2..63ade865 100644 --- a/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py +++ b/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py @@ -15,11 +15,15 @@ class GetPendingURLsWithoutAgencySuggestionsQueryBuilder(QueryBuilderBase): async def run(self, session: AsyncSession) -> list[AgencyIdentificationTDO]: statement = ( - select(URL.id, URL.collector_metadata, Batch.strategy) + select( + URL.id, + URL.collector_metadata, + Batch.strategy + ) .select_from(URL) .where(URL.outcome == URLStatus.PENDING.value) - .join(LinkBatchURL) - .join(Batch) + .outerjoin(LinkBatchURL) + .outerjoin(Batch) ) statement = StatementComposer.exclude_urls_with_agency_suggestions(statement) statement = statement.limit(100) @@ -28,7 +32,7 @@ async def run(self, session: AsyncSession) -> list[AgencyIdentificationTDO]: AgencyIdentificationTDO( url_id=raw_result[0], collector_metadata=raw_result[1], - collector_type=CollectorType(raw_result[2]) + collector_type=CollectorType(raw_result[2]) if raw_result[2] is not None else None ) for raw_result in raw_results ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/queries/has_urls_without_agency_suggestions.py b/src/core/tasks/url/operators/agency_identification/queries/has_urls_without_agency_suggestions.py new file mode 100644 index 00000000..88e3c828 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/queries/has_urls_without_agency_suggestions.py @@ -0,0 +1,27 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.collectors.enums import URLStatus +from src.db.models.instantiations.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.statement_composer import StatementComposer + + +class HasURLsWithoutAgencySuggestionsQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> bool: + statement = ( + select( + URL.id + ).where( + URL.outcome == URLStatus.PENDING.value + ) + ) + + statement = StatementComposer.exclude_urls_with_agency_suggestions(statement) + raw_result = await session.execute(statement) + result = raw_result.all() + return len(result) != 0 \ No newline at end of file diff --git a/src/core/tasks/url/subtasks/agency_identification/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/__init__.py similarity index 100% rename from src/core/tasks/url/subtasks/agency_identification/__init__.py rename to src/core/tasks/url/operators/agency_identification/subtasks/__init__.py diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/subtasks/agency_identification/base.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/base.py similarity index 87% rename from src/core/tasks/url/subtasks/agency_identification/base.py rename to src/core/tasks/url/operators/agency_identification/subtasks/impl/base.py index 5727fcc8..96f98f30 100644 --- a/src/core/tasks/url/subtasks/agency_identification/base.py +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/base.py @@ -11,6 +11,6 @@ class AgencyIdentificationSubtaskBase(ABC): async def run( self, url_id: int, - collector_metadata: Optional[dict] = None + collector_metadata: dict | None = None ) -> list[URLAgencySuggestionInfo]: raise NotImplementedError diff --git a/src/core/tasks/url/subtasks/agency_identification/ckan.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan.py similarity index 72% rename from src/core/tasks/url/subtasks/agency_identification/ckan.py rename to src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan.py index 6092aed4..15dddf6f 100644 --- a/src/core/tasks/url/subtasks/agency_identification/ckan.py +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan.py @@ -1,12 +1,15 @@ -from typing import Optional +from typing import final + +from typing_extensions import override from src.core.helpers import process_match_agency_response_to_suggestions from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.core.tasks.url.operators.agency_identification.subtasks.impl.base import AgencyIdentificationSubtaskBase from src.external.pdap.client import PDAPClient from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse - -class CKANAgencyIdentificationSubtask: +@final +class CKANAgencyIdentificationSubtask(AgencyIdentificationSubtaskBase): def __init__( self, @@ -14,10 +17,11 @@ def __init__( ): self.pdap_client = pdap_client + @override async def run( self, url_id: int, - collector_metadata: Optional[dict] + collector_metadata: dict | None = None ) -> list[URLAgencySuggestionInfo]: agency_name = collector_metadata["agency_name"] match_agency_response: MatchAgencyResponse = await self.pdap_client.match_agency( diff --git a/src/core/tasks/url/subtasks/agency_identification/muckrock.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock.py similarity index 84% rename from src/core/tasks/url/subtasks/agency_identification/muckrock.py rename to src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock.py index df61e281..fd3b9ec2 100644 --- a/src/core/tasks/url/subtasks/agency_identification/muckrock.py +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock.py @@ -1,4 +1,6 @@ -from typing import Optional +from typing import final + +from typing_extensions import override from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse @@ -6,11 +8,12 @@ from src.core.exceptions import MuckrockAPIError from src.core.helpers import process_match_agency_response_to_suggestions from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.core.tasks.url.operators.agency_identification.subtasks.impl.base import AgencyIdentificationSubtaskBase from src.external.pdap.client import PDAPClient from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse - -class MuckrockAgencyIdentificationSubtask: +@final +class MuckrockAgencyIdentificationSubtask(AgencyIdentificationSubtaskBase): def __init__( self, @@ -20,10 +23,11 @@ def __init__( self.muckrock_api_interface = muckrock_api_interface self.pdap_client = pdap_client + @override async def run( self, url_id: int, - collector_metadata: Optional[dict] + collector_metadata: dict | None = None ) -> list[URLAgencySuggestionInfo]: muckrock_agency_id = collector_metadata["agency"] agency_lookup_response: AgencyLookupResponse = await self.muckrock_api_interface.lookup_agency( diff --git a/src/core/tasks/url/subtasks/agency_identification/auto_googler.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/unknown.py similarity index 56% rename from src/core/tasks/url/subtasks/agency_identification/auto_googler.py rename to src/core/tasks/url/operators/agency_identification/subtasks/impl/unknown.py index 6f19ee7b..7ffd57bc 100644 --- a/src/core/tasks/url/subtasks/agency_identification/auto_googler.py +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/unknown.py @@ -1,16 +1,21 @@ -from typing import Optional +from typing_extensions import override, final from src.core.enums import SuggestionType from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.core.tasks.url.subtasks.agency_identification.base import AgencyIdentificationSubtaskBase +from src.core.tasks.url.operators.agency_identification.subtasks.impl.base import AgencyIdentificationSubtaskBase +@final +class UnknownAgencyIdentificationSubtask(AgencyIdentificationSubtaskBase): + """A subtask that returns an unknown suggestion. -class AutoGooglerAgencyIdentificationSubtask(AgencyIdentificationSubtaskBase): + Used in cases where the agency cannot be reliably inferred from the source. + """ + @override async def run( self, url_id: int, - collector_metadata: Optional[dict] = None + collector_metadata: dict | None = None ) -> list[URLAgencySuggestionInfo]: return [ URLAgencySuggestionInfo( diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/loader.py b/src/core/tasks/url/operators/agency_identification/subtasks/loader.py new file mode 100644 index 00000000..71f53568 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/loader.py @@ -0,0 +1,48 @@ +from src.collectors.enums import CollectorType +from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface +from src.core.tasks.url.operators.agency_identification.subtasks.impl.base import AgencyIdentificationSubtaskBase +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan import CKANAgencyIdentificationSubtask +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock import \ + MuckrockAgencyIdentificationSubtask +from src.core.tasks.url.operators.agency_identification.subtasks.impl.unknown import UnknownAgencyIdentificationSubtask +from src.external.pdap.client import PDAPClient + + +class AgencyIdentificationSubtaskLoader: + """Loads subtasks and associated dependencies.""" + + def __init__( + self, + pdap_client: PDAPClient, + muckrock_api_interface: MuckrockAPIInterface + ): + self.pdap_client = pdap_client + self.muckrock_api_interface = muckrock_api_interface + + async def _load_muckrock_subtask(self) -> MuckrockAgencyIdentificationSubtask: + return MuckrockAgencyIdentificationSubtask( + muckrock_api_interface=self.muckrock_api_interface, + pdap_client=self.pdap_client + ) + + async def _load_ckan_subtask(self) -> CKANAgencyIdentificationSubtask: + return CKANAgencyIdentificationSubtask( + pdap_client=self.pdap_client + ) + + async def load_subtask(self, collector_type: CollectorType) -> AgencyIdentificationSubtaskBase: + """Get subtask based on collector type.""" + match collector_type: + case CollectorType.MUCKROCK_SIMPLE_SEARCH: + return await self._load_muckrock_subtask() + case CollectorType.MUCKROCK_COUNTY_SEARCH: + return await self._load_muckrock_subtask() + case CollectorType.MUCKROCK_ALL_SEARCH: + return await self._load_muckrock_subtask() + case CollectorType.AUTO_GOOGLER: + return UnknownAgencyIdentificationSubtask() + case CollectorType.COMMON_CRAWLER: + return UnknownAgencyIdentificationSubtask() + case CollectorType.CKAN: + return await self._load_ckan_subtask() + return UnknownAgencyIdentificationSubtask() \ No newline at end of file diff --git a/src/core/tasks/url/operators/base.py b/src/core/tasks/url/operators/base.py index 59c41c6a..d4d1667e 100644 --- a/src/core/tasks/url/operators/base.py +++ b/src/core/tasks/url/operators/base.py @@ -17,7 +17,7 @@ def __init__(self, adb_client: AsyncDatabaseClient): self.linked_url_ids = [] @abstractmethod - async def meets_task_prerequisites(self): + async def meets_task_prerequisites(self) -> bool: """ A task should not be initiated unless certain conditions are met diff --git a/src/core/tasks/url/subtasks/agency_identification/common_crawler.py b/src/core/tasks/url/subtasks/agency_identification/common_crawler.py deleted file mode 100644 index fae8faaf..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/common_crawler.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Optional - -from src.core.enums import SuggestionType -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo - - -class CommonCrawlerAgencyIdentificationSubtask: - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] - ) -> list[URLAgencySuggestionInfo]: - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.UNKNOWN, - pdap_agency_id=None, - agency_name=None, - state=None, - county=None, - locality=None - ) - ] diff --git a/src/db/client/async_.py b/src/db/client/async_.py index 9f554f87..b4311733 100644 --- a/src/db/client/async_.py +++ b/src/db/client/async_.py @@ -73,6 +73,8 @@ from src.core.tasks.url.operators.agency_identification.dtos.tdo import AgencyIdentificationTDO from src.core.tasks.url.operators.agency_identification.queries.get_pending_urls_without_agency_suggestions import \ GetPendingURLsWithoutAgencySuggestionsQueryBuilder +from src.core.tasks.url.operators.agency_identification.queries.has_urls_without_agency_suggestions import \ + HasURLsWithoutAgencySuggestionsQueryBuilder from src.core.tasks.url.operators.auto_relevant.models.tdo import URLRelevantTDO from src.core.tasks.url.operators.auto_relevant.queries.get_tdos import GetAutoRelevantTDOsQueryBuilder from src.core.tasks.url.operators.submit_approved_url.queries.get import GetValidatedURLsQueryBuilder @@ -721,23 +723,8 @@ async def get_tasks( tasks=final_results ) - @session_manager - async def has_urls_without_agency_suggestions( - self, - session: AsyncSession - ) -> bool: - statement = ( - select( - URL.id - ).where( - URL.outcome == URLStatus.PENDING.value - ) - ) - - statement = self.statement_composer.exclude_urls_with_agency_suggestions(statement) - raw_result = await session.execute(statement) - result = raw_result.all() - return len(result) != 0 + async def has_urls_without_agency_suggestions(self) -> bool: + return await self.run_query_builder(HasURLsWithoutAgencySuggestionsQueryBuilder()) async def get_urls_without_agency_suggestions( self diff --git a/src/db/client/sync.py b/src/db/client/sync.py index 361cb25a..866feb25 100644 --- a/src/db/client/sync.py +++ b/src/db/client/sync.py @@ -127,11 +127,12 @@ def insert_url(self, session, url_info: URLInfo) -> int: session.add(url_entry) session.commit() session.refresh(url_entry) - link = LinkBatchURL( - batch_id=url_info.batch_id, - url_id=url_entry.id - ) - session.add(link) + if url_info.batch_id is not None: + link = LinkBatchURL( + batch_id=url_info.batch_id, + url_id=url_entry.id + ) + session.add(link) return url_entry.id def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo: diff --git a/tests/automated/integration/api/test_annotate.py b/tests/automated/integration/api/test_annotate.py index c4b1f33c..690b83e4 100644 --- a/tests/automated/integration/api/test_annotate.py +++ b/tests/automated/integration/api/test_annotate.py @@ -187,7 +187,7 @@ async def test_annotate_relevancy_already_annotated_by_different_user( await ath.db_data_creator.user_relevant_suggestion( url_id=creation_info.url_ids[0], user_id=2, - relevant=True + suggested_status=SuggestedStatus.RELEVANT ) # Annotate with different user (default is 1) and get conflict error diff --git a/tests/automated/integration/tasks/url/agency_identification/__init__.py b/tests/automated/integration/tasks/url/agency_identification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/__init__.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/asserts.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/asserts.py new file mode 100644 index 00000000..c96aa4db --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/happy_path/asserts.py @@ -0,0 +1,19 @@ +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.instantiations.agency.sqlalchemy import Agency +from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion + + +async def assert_expected_confirmed_and_auto_suggestions(adb_client: AsyncDatabaseClient): + confirmed_suggestions = await adb_client.get_urls_with_confirmed_agencies() + + # The number of confirmed suggestions is dependent on how often + # the subtask iterated through the sample agency suggestions defined in `data.py` + assert len(confirmed_suggestions) == 3 + agencies = await adb_client.get_all(Agency) + assert len(agencies) == 2 + auto_suggestions = await adb_client.get_all(AutomatedUrlAgencySuggestion) + assert len(auto_suggestions) == 4 + # Of the auto suggestions, 2 should be unknown + assert len([s for s in auto_suggestions if s.is_unknown]) == 2 + # Of the auto suggestions, 2 should not be unknown + assert len([s for s in auto_suggestions if not s.is_unknown]) == 2 diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/conftest.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/conftest.py new file mode 100644 index 00000000..d3a95856 --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/happy_path/conftest.py @@ -0,0 +1,29 @@ +from unittest.mock import create_autospec, AsyncMock + +import pytest + +from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface +from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator +from src.core.tasks.url.operators.agency_identification.subtasks.loader import AgencyIdentificationSubtaskLoader +from src.db.client.async_ import AsyncDatabaseClient +from src.external.pdap.client import PDAPClient +from tests.automated.integration.tasks.url.agency_identification.happy_path.mock import mock_run_subtask + + +@pytest.fixture +def operator( + adb_client_test: AsyncDatabaseClient +): + + operator = AgencyIdentificationTaskOperator( + adb_client=adb_client_test, + loader=AgencyIdentificationSubtaskLoader( + pdap_client=create_autospec(PDAPClient), + muckrock_api_interface=create_autospec(MuckrockAPIInterface) + ) + ) + operator.run_subtask = AsyncMock( + side_effect=mock_run_subtask + ) + + return operator diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/data.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/data.py new file mode 100644 index 00000000..ea224c37 --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/happy_path/data.py @@ -0,0 +1,34 @@ + + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo + +SAMPLE_AGENCY_SUGGESTIONS = [ + URLAgencySuggestionInfo( + url_id=-1, # This will be overwritten + suggestion_type=SuggestionType.UNKNOWN, + pdap_agency_id=None, + agency_name=None, + state=None, + county=None, + locality=None + ), + URLAgencySuggestionInfo( + url_id=-1, # This will be overwritten + suggestion_type=SuggestionType.CONFIRMED, + pdap_agency_id=-1, + agency_name="Test Agency", + state="Test State", + county="Test County", + locality="Test Locality" + ), + URLAgencySuggestionInfo( + url_id=-1, # This will be overwritten + suggestion_type=SuggestionType.AUTO_SUGGESTION, + pdap_agency_id=-1, + agency_name="Test Agency 2", + state="Test State 2", + county="Test County 2", + locality="Test Locality 2" + ) +] diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/mock.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/mock.py new file mode 100644 index 00000000..cec98d3c --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/happy_path/mock.py @@ -0,0 +1,19 @@ +from copy import deepcopy +from typing import Optional + +from src.core.enums import SuggestionType +from tests.automated.integration.tasks.url.agency_identification.happy_path.data import SAMPLE_AGENCY_SUGGESTIONS + + +async def mock_run_subtask( + subtask, + url_id: int, + collector_metadata: Optional[dict] +): + """A mocked version of run_subtask that returns a single suggestion for each url_id.""" + + # Deepcopy to prevent using the same instance in memory + suggestion = deepcopy(SAMPLE_AGENCY_SUGGESTIONS[url_id % 3]) + suggestion.url_id = url_id + suggestion.pdap_agency_id = (url_id % 3) if suggestion.suggestion_type != SuggestionType.UNKNOWN else None + return [suggestion] diff --git a/tests/automated/integration/tasks/url/agency_identification/happy_path/test_happy_path.py b/tests/automated/integration/tasks/url/agency_identification/happy_path/test_happy_path.py new file mode 100644 index 00000000..5cae5a26 --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/happy_path/test_happy_path.py @@ -0,0 +1,129 @@ +from unittest.mock import AsyncMock + +import pytest +from aiohttp import ClientSession + +from src.collectors.enums import CollectorType, URLStatus +from src.core.tasks.url.enums import TaskOperatorOutcome +from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan import CKANAgencyIdentificationSubtask +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock import \ + MuckrockAgencyIdentificationSubtask +from src.core.tasks.url.operators.agency_identification.subtasks.impl.unknown import UnknownAgencyIdentificationSubtask +from tests.automated.integration.tasks.url.agency_identification.happy_path.asserts import \ + assert_expected_confirmed_and_auto_suggestions +from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 + + +@pytest.mark.asyncio +async def test_agency_identification_task( + db_data_creator: DBDataCreator, + test_client_session: ClientSession, + operator: AgencyIdentificationTaskOperator +): + """Test full flow of AgencyIdentificationTaskOperator""" + + + # Confirm does not yet meet prerequisites + assert not await operator.meets_task_prerequisites() + + collector_type_to_url_id: dict[CollectorType | None, int] = {} + + # Create six urls, one from each strategy + for strategy in [ + CollectorType.COMMON_CRAWLER, + CollectorType.AUTO_GOOGLER, + CollectorType.MUCKROCK_COUNTY_SEARCH, + CollectorType.MUCKROCK_SIMPLE_SEARCH, + CollectorType.MUCKROCK_ALL_SEARCH, + CollectorType.CKAN, + ]: + # Create two URLs for each, one pending and one errored + creation_info: BatchURLCreationInfoV2 = await db_data_creator.batch_v2( + parameters=TestBatchCreationParameters( + strategy=strategy, + urls=[ + TestURLCreationParameters( + count=1, + status=URLStatus.PENDING, + with_html_content=True + ), + TestURLCreationParameters( + count=1, + status=URLStatus.ERROR, + with_html_content=True + ) + ] + ) + ) + collector_type_to_url_id[strategy] = creation_info.urls_by_status[URLStatus.PENDING].url_mappings[0].url_id + + # Create an additional two urls with no collector. + response = await db_data_creator.url_v2( + parameters=[ + TestURLCreationParameters( + count=1, + status=URLStatus.PENDING, + with_html_content=True + ), + TestURLCreationParameters( + count=1, + status=URLStatus.ERROR, + with_html_content=True + ) + ] + ) + collector_type_to_url_id[None] = response.urls_by_status[URLStatus.PENDING].url_mappings[0].url_id + + + # Confirm meets prerequisites + assert await operator.meets_task_prerequisites() + # Run task + run_info = await operator.run_task(1) + assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message + + # Confirm tasks are piped into the correct subtasks + # * common_crawler into common_crawler_subtask + # * auto_googler into auto_googler_subtask + # * muckrock_county_search into muckrock_subtask + # * muckrock_simple_search into muckrock_subtask + # * muckrock_all_search into muckrock_subtask + # * ckan into ckan_subtask + + + mock_run_subtask: AsyncMock = operator.run_subtask + + # Check correct number of calls to run_subtask + assert mock_run_subtask.call_count == 7 + + # Confirm subtask classes are correct for the given urls + d2 = {} + for call_arg in mock_run_subtask.call_args_list: + subtask_class = call_arg[0][0].__class__ + url_id = call_arg[0][1] + d2[url_id] = subtask_class + + + subtask_class_collector_type = [ + (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_ALL_SEARCH), + (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_COUNTY_SEARCH), + (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_SIMPLE_SEARCH), + (CKANAgencyIdentificationSubtask, CollectorType.CKAN), + (UnknownAgencyIdentificationSubtask, CollectorType.COMMON_CRAWLER), + (UnknownAgencyIdentificationSubtask, CollectorType.AUTO_GOOGLER), + (UnknownAgencyIdentificationSubtask, None) + ] + + for subtask_class, collector_type in subtask_class_collector_type: + url_id = collector_type_to_url_id[collector_type] + assert d2[url_id] == subtask_class + + + # Confirm task again does not meet prerequisites + assert not await operator.meets_task_prerequisites() + # # Check confirmed and auto suggestions + adb_client = db_data_creator.adb_client + await assert_expected_confirmed_and_auto_suggestions(adb_client) diff --git a/tests/automated/integration/tasks/url/agency_identification/subtasks/__init__.py b/tests/automated/integration/tasks/url/agency_identification/subtasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/agency_identification/subtasks/test_ckan.py b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_ckan.py new file mode 100644 index 00000000..6a2e4fed --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_ckan.py @@ -0,0 +1,58 @@ +from unittest.mock import AsyncMock + +import pytest + +from src.external.pdap.enums import MatchAgencyResponseStatus +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan import CKANAgencyIdentificationSubtask +from src.core.enums import SuggestionType +from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse +from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_ckan_subtask(db_data_creator: DBDataCreator): + # Test that ckan subtask correctly sends agency id to + # CKANAPIInterface, sends resultant agency name to + # PDAPClient and adds received suggestions to + # url_agency_suggestions + + pdap_client = AsyncMock() + pdap_client.match_agency.return_value = MatchAgencyResponse( + status=MatchAgencyResponseStatus.PARTIAL_MATCH, + matches=[ + MatchAgencyInfo( + id=1, + submitted_name="Mock Agency Name", + ), + MatchAgencyInfo( + id=2, + submitted_name="Another Mock Agency Name", + ) + ] + ) # Assuming MatchAgencyResponse is a class + + # Create an instance of CKANAgencyIdentificationSubtask + task = CKANAgencyIdentificationSubtask(pdap_client) + + # Call the run method with static values + collector_metadata = {"agency_name": "Test Agency"} + url_id = 1 + + # Call the run method + result = await task.run(url_id, collector_metadata) + + # Check the result + assert len(result) == 2 + assert result[0].url_id == 1 + assert result[0].suggestion_type == SuggestionType.AUTO_SUGGESTION + assert result[0].pdap_agency_id == 1 + assert result[0].agency_name == "Mock Agency Name" + assert result[1].url_id == 1 + assert result[1].suggestion_type == SuggestionType.AUTO_SUGGESTION + assert result[1].pdap_agency_id == 2 + assert result[1].agency_name == "Another Mock Agency Name" + + # Assert methods called as expected + pdap_client.match_agency.assert_called_once_with(name="Test Agency") + diff --git a/tests/automated/integration/tasks/url/agency_identification/subtasks/test_muckrock.py b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_muckrock.py new file mode 100644 index 00000000..87bc6614 --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_muckrock.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock + +import pytest + +from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface +from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse +from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock import MuckrockAgencyIdentificationSubtask +from src.external.pdap.client import PDAPClient +from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo +from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse +from src.external.pdap.enums import MatchAgencyResponseStatus +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_muckrock_subtask(db_data_creator: DBDataCreator): + # Test that muckrock subtask correctly sends agency name to + # MatchAgenciesInterface and adds received suggestions to + # url_agency_suggestions + + # Create mock instances for dependency injections + muckrock_api_interface_mock = MagicMock(spec=MuckrockAPIInterface) + pdap_client_mock = MagicMock(spec=PDAPClient) + + # Set up mock return values for method calls + muckrock_api_interface_mock.lookup_agency.return_value = AgencyLookupResponse( + type=AgencyLookupResponseType.FOUND, + name="Mock Agency Name", + error=None + ) + + pdap_client_mock.match_agency.return_value = MatchAgencyResponse( + status=MatchAgencyResponseStatus.PARTIAL_MATCH, + matches=[ + MatchAgencyInfo( + id=1, + submitted_name="Mock Agency Name", + ), + MatchAgencyInfo( + id=2, + submitted_name="Another Mock Agency Name", + ) + ] + ) + + # Create an instance of MuckrockAgencyIdentificationSubtask with mock dependencies + muckrock_agency_identification_subtask = MuckrockAgencyIdentificationSubtask( + muckrock_api_interface=muckrock_api_interface_mock, + pdap_client=pdap_client_mock + ) + + # Run the subtask + results: list[URLAgencySuggestionInfo] = await muckrock_agency_identification_subtask.run( + url_id=1, + collector_metadata={ + "agency": 123 + } + ) + + # Verify the results + assert len(results) == 2 + assert results[0].url_id == 1 + assert results[0].suggestion_type == SuggestionType.AUTO_SUGGESTION + assert results[0].pdap_agency_id == 1 + assert results[0].agency_name == "Mock Agency Name" + assert results[1].url_id == 1 + assert results[1].suggestion_type == SuggestionType.AUTO_SUGGESTION + assert results[1].pdap_agency_id == 2 + assert results[1].agency_name == "Another Mock Agency Name" + + # Assert methods called as expected + muckrock_api_interface_mock.lookup_agency.assert_called_once_with( + muckrock_agency_id=123 + ) + pdap_client_mock.match_agency.assert_called_once_with( + name="Mock Agency Name" + ) diff --git a/tests/automated/integration/tasks/url/agency_identification/subtasks/test_unknown.py b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_unknown.py new file mode 100644 index 00000000..aab59dca --- /dev/null +++ b/tests/automated/integration/tasks/url/agency_identification/subtasks/test_unknown.py @@ -0,0 +1,16 @@ +import pytest + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.core.tasks.url.operators.agency_identification.subtasks.impl.unknown import UnknownAgencyIdentificationSubtask + + +@pytest.mark.asyncio +async def test_unknown_agency_identification_subtask(): + # Test that no_collector subtask correctly adds URL to + # url_agency_suggestions with label 'Unknown' + subtask = UnknownAgencyIdentificationSubtask() + results: list[URLAgencySuggestionInfo] = await subtask.run(url_id=1, collector_metadata={}) + assert len(results) == 1 + assert results[0].url_id == 1 + assert results[0].suggestion_type == SuggestionType.UNKNOWN \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py b/tests/automated/integration/tasks/url/test_agency_preannotation_task.py deleted file mode 100644 index d11a1def..00000000 --- a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py +++ /dev/null @@ -1,327 +0,0 @@ -from copy import deepcopy -from typing import Optional -from unittest.mock import MagicMock, AsyncMock, patch - -import pytest -from aiohttp import ClientSession - -from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface -from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse -from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType -from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion -from src.external.pdap.enums import MatchAgencyResponseStatus -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from src.db.models.instantiations.agency.sqlalchemy import Agency -from src.collectors.enums import CollectorType, URLStatus -from src.core.tasks.url.enums import TaskOperatorOutcome -from src.core.tasks.url.subtasks.agency_identification.auto_googler import AutoGooglerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.ckan import CKANAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.common_crawler import CommonCrawlerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.muckrock import MuckrockAgencyIdentificationSubtask -from src.core.enums import SuggestionType -from pdap_access_manager import AccessManager -from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse -from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo -from src.external.pdap.client import PDAPClient -from tests.helpers.data_creator.core import DBDataCreator -from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 - -sample_agency_suggestions = [ - URLAgencySuggestionInfo( - url_id=-1, # This will be overwritten - suggestion_type=SuggestionType.UNKNOWN, - pdap_agency_id=None, - agency_name=None, - state=None, - county=None, - locality=None - ), - URLAgencySuggestionInfo( - url_id=-1, # This will be overwritten - suggestion_type=SuggestionType.CONFIRMED, - pdap_agency_id=-1, - agency_name="Test Agency", - state="Test State", - county="Test County", - locality="Test Locality" - ), - URLAgencySuggestionInfo( - url_id=-1, # This will be overwritten - suggestion_type=SuggestionType.AUTO_SUGGESTION, - pdap_agency_id=-1, - agency_name="Test Agency 2", - state="Test State 2", - county="Test County 2", - locality="Test Locality 2" - ) -] - -@pytest.mark.asyncio -async def test_agency_preannotation_task(db_data_creator: DBDataCreator): - async def mock_run_subtask( - subtask, - url_id: int, - collector_metadata: Optional[dict] - ): - # Deepcopy to prevent using the same instance in memory - suggestion = deepcopy(sample_agency_suggestions[url_id % 3]) - suggestion.url_id = url_id - suggestion.pdap_agency_id = (url_id % 3) if suggestion.suggestion_type != SuggestionType.UNKNOWN else None - return [suggestion] - - async with ClientSession() as session: - mock = MagicMock() - access_manager = AccessManager( - email=mock.email, - password=mock.password, - api_key=mock.api_key, - session=session - ) - pdap_client = PDAPClient( - access_manager=access_manager - ) - muckrock_api_interface = MuckrockAPIInterface(session=session) - with patch.object( - AgencyIdentificationTaskOperator, - "run_subtask", - side_effect=mock_run_subtask, - ) as mock: - operator = AgencyIdentificationTaskOperator( - adb_client=db_data_creator.adb_client, - pdap_client=pdap_client, - muckrock_api_interface=muckrock_api_interface - ) - - # Confirm does not yet meet prerequisites - assert not await operator.meets_task_prerequisites() - - - d = {} - - # Create six urls, one from each strategy - for strategy in [ - CollectorType.COMMON_CRAWLER, - CollectorType.AUTO_GOOGLER, - CollectorType.MUCKROCK_COUNTY_SEARCH, - CollectorType.MUCKROCK_SIMPLE_SEARCH, - CollectorType.MUCKROCK_ALL_SEARCH, - CollectorType.CKAN - ]: - # Create two URLs for each, one pending and one errored - creation_info: BatchURLCreationInfoV2 = await db_data_creator.batch_v2( - parameters=TestBatchCreationParameters( - strategy=strategy, - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING, - with_html_content=True - ), - TestURLCreationParameters( - count=1, - status=URLStatus.ERROR, - with_html_content=True - ) - ] - ) - ) - d[strategy] = creation_info.urls_by_status[URLStatus.PENDING].url_mappings[0].url_id - - - # Confirm meets prerequisites - assert await operator.meets_task_prerequisites() - # Run task - run_info = await operator.run_task(1) - assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message - - # Confirm tasks are piped into the correct subtasks - # * common_crawler into common_crawler_subtask - # * auto_googler into auto_googler_subtask - # * muckrock_county_search into muckrock_subtask - # * muckrock_simple_search into muckrock_subtask - # * muckrock_all_search into muckrock_subtask - # * ckan into ckan_subtask - - assert mock.call_count == 6 - - - # Confirm subtask classes are correct for the given urls - d2 = {} - for call_arg in mock.call_args_list: - subtask_class = call_arg[0][0].__class__ - url_id = call_arg[0][1] - d2[url_id] = subtask_class - - - subtask_class_collector_type = [ - (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_ALL_SEARCH), - (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_COUNTY_SEARCH), - (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_SIMPLE_SEARCH), - (CKANAgencyIdentificationSubtask, CollectorType.CKAN), - (CommonCrawlerAgencyIdentificationSubtask, CollectorType.COMMON_CRAWLER), - (AutoGooglerAgencyIdentificationSubtask, CollectorType.AUTO_GOOGLER) - ] - - for subtask_class, collector_type in subtask_class_collector_type: - url_id = d[collector_type] - assert d2[url_id] == subtask_class - - - # Confirm task again does not meet prerequisites - assert not await operator.meets_task_prerequisites() - - - - - # Check confirmed and auto suggestions - adb_client = db_data_creator.adb_client - confirmed_suggestions = await adb_client.get_urls_with_confirmed_agencies() - assert len(confirmed_suggestions) == 2 - - agencies = await adb_client.get_all(Agency) - assert len(agencies) == 2 - - auto_suggestions = await adb_client.get_all(AutomatedUrlAgencySuggestion) - assert len(auto_suggestions) == 4 - - # Of the auto suggestions, 2 should be unknown - assert len([s for s in auto_suggestions if s.is_unknown]) == 2 - - # Of the auto suggestions, 2 should not be unknown - assert len([s for s in auto_suggestions if not s.is_unknown]) == 2 - -@pytest.mark.asyncio -async def test_common_crawler_subtask(db_data_creator: DBDataCreator): - # Test that common_crawler subtask correctly adds URL to - # url_agency_suggestions with label 'Unknown' - subtask = CommonCrawlerAgencyIdentificationSubtask() - results: list[URLAgencySuggestionInfo] = await subtask.run(url_id=1, collector_metadata={}) - assert len(results) == 1 - assert results[0].url_id == 1 - assert results[0].suggestion_type == SuggestionType.UNKNOWN - - -@pytest.mark.asyncio -async def test_auto_googler_subtask(db_data_creator: DBDataCreator): - # Test that auto_googler subtask correctly adds URL to - # url_agency_suggestions with label 'Unknown' - subtask = AutoGooglerAgencyIdentificationSubtask() - results: list[URLAgencySuggestionInfo] = await subtask.run(url_id=1, collector_metadata={}) - assert len(results) == 1 - assert results[0].url_id == 1 - assert results[0].suggestion_type == SuggestionType.UNKNOWN - -@pytest.mark.asyncio -async def test_muckrock_subtask(db_data_creator: DBDataCreator): - # Test that muckrock subtask correctly sends agency name to - # MatchAgenciesInterface and adds received suggestions to - # url_agency_suggestions - - # Create mock instances for dependency injections - muckrock_api_interface_mock = MagicMock(spec=MuckrockAPIInterface) - pdap_client_mock = MagicMock(spec=PDAPClient) - - # Set up mock return values for method calls - muckrock_api_interface_mock.lookup_agency.return_value = AgencyLookupResponse( - type=AgencyLookupResponseType.FOUND, - name="Mock Agency Name", - error=None - ) - - pdap_client_mock.match_agency.return_value = MatchAgencyResponse( - status=MatchAgencyResponseStatus.PARTIAL_MATCH, - matches=[ - MatchAgencyInfo( - id=1, - submitted_name="Mock Agency Name", - ), - MatchAgencyInfo( - id=2, - submitted_name="Another Mock Agency Name", - ) - ] - ) - - # Create an instance of MuckrockAgencyIdentificationSubtask with mock dependencies - muckrock_agency_identification_subtask = MuckrockAgencyIdentificationSubtask( - muckrock_api_interface=muckrock_api_interface_mock, - pdap_client=pdap_client_mock - ) - - # Run the subtask - results: list[URLAgencySuggestionInfo] = await muckrock_agency_identification_subtask.run( - url_id=1, - collector_metadata={ - "agency": 123 - } - ) - - # Verify the results - assert len(results) == 2 - assert results[0].url_id == 1 - assert results[0].suggestion_type == SuggestionType.AUTO_SUGGESTION - assert results[0].pdap_agency_id == 1 - assert results[0].agency_name == "Mock Agency Name" - assert results[1].url_id == 1 - assert results[1].suggestion_type == SuggestionType.AUTO_SUGGESTION - assert results[1].pdap_agency_id == 2 - assert results[1].agency_name == "Another Mock Agency Name" - - # Assert methods called as expected - muckrock_api_interface_mock.lookup_agency.assert_called_once_with( - muckrock_agency_id=123 - ) - pdap_client_mock.match_agency.assert_called_once_with( - name="Mock Agency Name" - ) - - -@pytest.mark.asyncio -async def test_ckan_subtask(db_data_creator: DBDataCreator): - # Test that ckan subtask correctly sends agency id to - # CKANAPIInterface, sends resultant agency name to - # PDAPClient and adds received suggestions to - # url_agency_suggestions - - pdap_client = AsyncMock() - pdap_client.match_agency.return_value = MatchAgencyResponse( - status=MatchAgencyResponseStatus.PARTIAL_MATCH, - matches=[ - MatchAgencyInfo( - id=1, - submitted_name="Mock Agency Name", - ), - MatchAgencyInfo( - id=2, - submitted_name="Another Mock Agency Name", - ) - ] - ) # Assuming MatchAgencyResponse is a class - - # Create an instance of CKANAgencyIdentificationSubtask - task = CKANAgencyIdentificationSubtask(pdap_client) - - # Call the run method with static values - collector_metadata = {"agency_name": "Test Agency"} - url_id = 1 - - # Call the run method - result = await task.run(url_id, collector_metadata) - - # Check the result - assert len(result) == 2 - assert result[0].url_id == 1 - assert result[0].suggestion_type == SuggestionType.AUTO_SUGGESTION - assert result[0].pdap_agency_id == 1 - assert result[0].agency_name == "Mock Agency Name" - assert result[1].url_id == 1 - assert result[1].suggestion_type == SuggestionType.AUTO_SUGGESTION - assert result[1].pdap_agency_id == 2 - assert result[1].agency_name == "Another Mock Agency Name" - - # Assert methods called as expected - pdap_client.match_agency.assert_called_once_with(name="Test Agency") - diff --git a/tests/conftest.py b/tests/conftest.py index e3789b45..f26249cd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,12 @@ import logging from typing import Any, Generator, AsyncGenerator +from unittest.mock import AsyncMock import pytest import pytest_asyncio +from aiohttp import ClientSession from alembic.config import Config +from pdap_access_manager import AccessManager from sqlalchemy import create_engine, inspect, MetaData from sqlalchemy.orm import scoped_session, sessionmaker @@ -123,3 +126,8 @@ def db_data_creator( ): db_data_creator = DBDataCreator(db_client=db_client_test) yield db_data_creator + +@pytest.fixture +async def test_client_session() -> AsyncGenerator[ClientSession, Any]: + async with ClientSession() as session: + yield session diff --git a/tests/helpers/data_creator/commands/__init__.py b/tests/helpers/data_creator/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/base.py b/tests/helpers/data_creator/commands/base.py new file mode 100644 index 00000000..84e77621 --- /dev/null +++ b/tests/helpers/data_creator/commands/base.py @@ -0,0 +1,43 @@ +from abc import ABC, abstractmethod + +from src.db.client.async_ import AsyncDatabaseClient +from src.db.client.sync import DatabaseClient +from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer + + +class DBDataCreatorCommandBase(ABC): + + def __init__(self,): + self._clients: DBDataCreatorClientContainer | None = None + + def load_clients(self, clients: DBDataCreatorClientContainer): + self._clients = clients + + @property + def clients(self) -> DBDataCreatorClientContainer: + if self._clients is None: + raise Exception("Clients not loaded") + return self._clients + + @property + def db_client(self) -> DatabaseClient: + return self.clients.db + + @property + def adb_client(self) -> AsyncDatabaseClient: + return self.clients.adb + + def run_command_sync(self, command: "DBDataCreatorCommandBase"): + command.load_clients(self._clients) + return command.run_sync() + + async def run_command(self, command: "DBDataCreatorCommandBase"): + command.load_clients(self._clients) + return await command.run() + + @abstractmethod + async def run(self): + raise NotImplementedError + + async def run_sync(self): + raise NotImplementedError \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/__init__.py b/tests/helpers/data_creator/commands/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/impl/agency.py b/tests/helpers/data_creator/commands/impl/agency.py new file mode 100644 index 00000000..97b27a1a --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/agency.py @@ -0,0 +1,29 @@ +from random import randint +from typing import final + +from typing_extensions import override + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + +@final +class AgencyCommand(DBDataCreatorCommandBase): + + @override + async def run(self) -> int: + agency_id = randint(1, 99999999) + await self.adb_client.upsert_new_agencies( + suggestions=[ + URLAgencySuggestionInfo( + url_id=-1, + suggestion_type=SuggestionType.UNKNOWN, + pdap_agency_id=agency_id, + agency_name=f"Test Agency {agency_id}", + state=f"Test State {agency_id}", + county=f"Test County {agency_id}", + locality=f"Test Locality {agency_id}" + ) + ] + ) + return agency_id diff --git a/tests/helpers/data_creator/commands/impl/annotate.py b/tests/helpers/data_creator/commands/impl/annotate.py new file mode 100644 index 00000000..5f341326 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/annotate.py @@ -0,0 +1,102 @@ +from typing import final + +from typing_extensions import override + +from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo +from src.api.endpoints.review.enums import RejectionReason +from src.core.enums import SuggestionType +from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.suggestion.auto.agency import AgencyAutoSuggestionsCommand +from tests.helpers.data_creator.commands.impl.suggestion.auto.record_type import AutoRecordTypeSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.auto.relevant import AutoRelevantSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.agency import AgencyUserSuggestionsCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.record_type import UserRecordTypeSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.relevant import UserRelevantSuggestionCommand + + +@final +class AnnotateCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + annotation_info: AnnotationInfo + ): + super().__init__() + self.url_id = url_id + self.annotation_info = annotation_info + + @override + async def run(self) -> None: + info = self.annotation_info + if info.user_relevant is not None: + await self.run_command( + UserRelevantSuggestionCommand( + url_id=self.url_id, + suggested_status=info.user_relevant + ) + ) + if info.auto_relevant is not None: + await self.run_command( + AutoRelevantSuggestionCommand( + url_id=self.url_id, + relevant=info.auto_relevant + ) + ) + if info.user_record_type is not None: + await self.run_command( + UserRecordTypeSuggestionCommand( + url_id=self.url_id, + record_type=info.user_record_type, + ) + ) + if info.auto_record_type is not None: + await self.run_command( + AutoRecordTypeSuggestionCommand( + url_id=self.url_id, + record_type=info.auto_record_type + ) + ) + if info.user_agency is not None: + await self.run_command( + AgencyUserSuggestionsCommand( + url_id=self.url_id, + agency_annotation_info=info.user_agency + ) + ) + if info.auto_agency is not None: + await self.run_command( + AgencyAutoSuggestionsCommand( + url_id=self.url_id, + count=1, + suggestion_type=SuggestionType.AUTO_SUGGESTION + ) + ) + if info.confirmed_agency is not None: + await self.run_command( + AgencyAutoSuggestionsCommand( + url_id=self.url_id, + count=1, + suggestion_type=SuggestionType.CONFIRMED + ) + ) + if info.final_review_approved is not None: + if info.final_review_approved: + final_review_approval_info = FinalReviewApprovalInfo( + url_id=self.url_id, + record_type=self.annotation_info.user_record_type, + agency_ids=[self.annotation_info.user_agency.suggested_agency] + if self.annotation_info.user_agency is not None else None, + description="Test Description", + ) + await self.adb_client.approve_url( + approval_info=final_review_approval_info, + user_id=1 + ) + else: + await self.adb_client.reject_url( + url_id=self.url_id, + user_id=1, + rejection_reason=RejectionReason.NOT_RELEVANT + ) diff --git a/tests/helpers/data_creator/commands/impl/batch.py b/tests/helpers/data_creator/commands/impl/batch.py new file mode 100644 index 00000000..09cdbe61 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/batch.py @@ -0,0 +1,35 @@ +from datetime import datetime +from typing import Optional + +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus +from src.db.models.instantiations.batch.pydantic import BatchInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + + +class DBDataCreatorBatchCommand(DBDataCreatorCommandBase): + + def __init__( + self, + strategy: CollectorType = CollectorType.EXAMPLE, + batch_status: BatchStatus = BatchStatus.IN_PROCESS, + created_at: Optional[datetime] = None + ): + super().__init__() + self.strategy = strategy + self.batch_status = batch_status + self.created_at = created_at + + async def run(self) -> int: + raise NotImplementedError + + def run_sync(self) -> int: + return self.db_client.insert_batch( + BatchInfo( + strategy=self.strategy.value, + status=self.batch_status, + parameters={"test_key": "test_value"}, + user_id=1, + date_generated=self.created_at + ) + ) \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/batch_v2.py b/tests/helpers/data_creator/commands/impl/batch_v2.py new file mode 100644 index 00000000..524416da --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/batch_v2.py @@ -0,0 +1,43 @@ +from src.core.enums import BatchStatus +from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.batch import DBDataCreatorBatchCommand +from tests.helpers.data_creator.commands.impl.urls_v2.core import URLsV2Command +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 + + +class BatchV2Command(DBDataCreatorCommandBase): + + def __init__( + self, + parameters: TestBatchCreationParameters + ): + super().__init__() + self.parameters = parameters + + async def run(self) -> BatchURLCreationInfoV2: + # Create batch + command = DBDataCreatorBatchCommand( + strategy=self.parameters.strategy, + batch_status=self.parameters.outcome, + created_at=self.parameters.created_at + ) + batch_id = self.run_command_sync(command) + # Return early if batch would not involve URL creation + if self.parameters.outcome in (BatchStatus.ERROR, BatchStatus.ABORTED): + return BatchURLCreationInfoV2( + batch_id=batch_id, + ) + + response = await self.run_command( + URLsV2Command( + parameters=self.parameters.urls, + batch_id=batch_id, + created_at=self.parameters.created_at + ) + ) + + return BatchURLCreationInfoV2( + batch_id=batch_id, + urls_by_status=response.urls_by_status, + ) diff --git a/tests/helpers/data_creator/commands/impl/html_data.py b/tests/helpers/data_creator/commands/impl/html_data.py new file mode 100644 index 00000000..6c9e95e3 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/html_data.py @@ -0,0 +1,42 @@ +from src.db.dtos.url.html_content import URLHTMLContentInfo, HTMLContentType +from src.db.dtos.url.raw_html import RawHTMLInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer + + +class HTMLDataCreatorCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_ids: list[int] + ): + super().__init__() + self.url_ids = url_ids + + async def run(self) -> None: + html_content_infos = [] + raw_html_info_list = [] + for url_id in self.url_ids: + html_content_infos.append( + URLHTMLContentInfo( + url_id=url_id, + content_type=HTMLContentType.TITLE, + content="test html content" + ) + ) + html_content_infos.append( + URLHTMLContentInfo( + url_id=url_id, + content_type=HTMLContentType.DESCRIPTION, + content="test description" + ) + ) + raw_html_info = RawHTMLInfo( + url_id=url_id, + html="" + ) + raw_html_info_list.append(raw_html_info) + + await self.adb_client.add_raw_html(raw_html_info_list) + await self.adb_client.add_html_content_infos(html_content_infos) + diff --git a/tests/helpers/data_creator/commands/impl/suggestion/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py b/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py new file mode 100644 index 00000000..e096d15e --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py @@ -0,0 +1,29 @@ +from typing import final + +from typing_extensions import override + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.agency import AgencyCommand + +@final +class AgencyConfirmedSuggestionCommand(DBDataCreatorCommandBase): + + def __init__(self, url_id: int): + super().__init__() + self.url_id = url_id + + @override + async def run(self) -> int: + agency_id = await self.run_command(AgencyCommand()) + await self.adb_client.add_confirmed_agency_url_links( + suggestions=[ + URLAgencySuggestionInfo( + url_id=self.url_id, + suggestion_type=SuggestionType.CONFIRMED, + pdap_agency_id=agency_id + ) + ] + ) + return agency_id \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/agency.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/agency.py new file mode 100644 index 00000000..96743df8 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/agency.py @@ -0,0 +1,46 @@ +from typing import final + +from typing_extensions import override + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.agency import AgencyCommand + +@final +class AgencyAutoSuggestionsCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + count: int, + suggestion_type: SuggestionType = SuggestionType.AUTO_SUGGESTION + ): + super().__init__() + if suggestion_type == SuggestionType.UNKNOWN: + count = 1 # Can only be one auto suggestion if unknown + self.url_id = url_id + self.count = count + self.suggestion_type = suggestion_type + + @override + async def run(self) -> None: + suggestions = [] + for _ in range(self.count): + if self.suggestion_type == SuggestionType.UNKNOWN: + pdap_agency_id = None + else: + pdap_agency_id = await self.run_command(AgencyCommand()) + suggestion = URLAgencySuggestionInfo( + url_id=self.url_id, + suggestion_type=self.suggestion_type, + pdap_agency_id=pdap_agency_id, + state="Test State", + county="Test County", + locality="Test Locality" + ) + suggestions.append(suggestion) + + await self.adb_client.add_agency_auto_suggestions( + suggestions=suggestions + ) \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py new file mode 100644 index 00000000..25ad6e53 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py @@ -0,0 +1,20 @@ +from src.core.enums import RecordType +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + + +class AutoRecordTypeSuggestionCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + record_type: RecordType + ): + super().__init__() + self.url_id = url_id + self.record_type = record_type + + async def run(self) -> None: + await self.adb_client.add_auto_record_type_suggestion( + url_id=self.url_id, + record_type=self.record_type + ) \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py new file mode 100644 index 00000000..58dfc8fb --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py @@ -0,0 +1,24 @@ +from src.db.models.instantiations.url.suggestion.relevant.auto.pydantic.input import AutoRelevancyAnnotationInput +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + + +class AutoRelevantSuggestionCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + relevant: bool = True + ): + super().__init__() + self.url_id = url_id + self.relevant = relevant + + async def run(self) -> None: + await self.adb_client.add_auto_relevant_suggestion( + input_=AutoRelevancyAnnotationInput( + url_id=self.url_id, + is_relevant=self.relevant, + confidence=0.5, + model_name="test_model" + ) + ) diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/user/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py b/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py new file mode 100644 index 00000000..35418679 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py @@ -0,0 +1,37 @@ +from random import randint +from typing import final + +from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.agency import AgencyCommand + + +@final +class AgencyUserSuggestionsCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + user_id: int | None = None, + agency_annotation_info: URLAgencyAnnotationPostInfo | None = None + ): + super().__init__() + if user_id is None: + user_id = randint(1, 99999999) + self.url_id = url_id + self.user_id = user_id + self.agency_annotation_info = agency_annotation_info + + async def run(self) -> None: + if self.agency_annotation_info is None: + agency_annotation_info = URLAgencyAnnotationPostInfo( + suggested_agency=await self.run_command(AgencyCommand()) + ) + else: + agency_annotation_info = self.agency_annotation_info + await self.adb_client.add_agency_manual_suggestion( + agency_id=agency_annotation_info.suggested_agency, + url_id=self.url_id, + user_id=self.user_id, + is_new=agency_annotation_info.is_new + ) diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py b/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py new file mode 100644 index 00000000..03c7ab0b --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py @@ -0,0 +1,25 @@ +from random import randint + +from src.core.enums import RecordType +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + + +class UserRecordTypeSuggestionCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + record_type: RecordType, + user_id: int | None = None, + ): + super().__init__() + self.url_id = url_id + self.user_id = user_id if user_id is not None else randint(1, 99999999) + self.record_type = record_type + + async def run(self) -> None: + await self.adb_client.add_user_record_type_suggestion( + url_id=self.url_id, + user_id=self.user_id, + record_type=self.record_type + ) \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py b/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py new file mode 100644 index 00000000..9d4df2c3 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py @@ -0,0 +1,29 @@ +from random import randint +from typing import final + +from typing_extensions import override + +from src.core.enums import SuggestedStatus +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase + +@final +class UserRelevantSuggestionCommand(DBDataCreatorCommandBase): + + def __init__( + self, + url_id: int, + user_id: int | None = None, + suggested_status: SuggestedStatus = SuggestedStatus.RELEVANT + ): + super().__init__() + self.url_id = url_id + self.user_id = user_id if user_id is not None else randint(1, 99999999) + self.suggested_status = suggested_status + + @override + async def run(self) -> None: + await self.adb_client.add_user_relevant_suggestion( + url_id=self.url_id, + user_id=self.user_id, + suggested_status=self.suggested_status + ) \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/urls.py b/tests/helpers/data_creator/commands/impl/urls.py new file mode 100644 index 00000000..daec2445 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/urls.py @@ -0,0 +1,64 @@ +from datetime import datetime + +from src.collectors.enums import URLStatus +from src.core.tasks.url.operators.submit_approved_url.tdo import SubmittedURLInfo +from src.db.dtos.url.insert import InsertURLsInfo +from src.db.models.instantiations.url.core.pydantic import URLInfo +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.simple_test_data_functions import generate_test_urls + + +class URLsDBDataCreatorCommand(DBDataCreatorCommandBase): + + def __init__( + self, + batch_id: int | None, + url_count: int, + collector_metadata: dict | None = None, + outcome: URLStatus = URLStatus.PENDING, + created_at: datetime | None = None + ): + super().__init__() + self.batch_id = batch_id + self.url_count = url_count + self.collector_metadata = collector_metadata + self.outcome = outcome + self.created_at = created_at + + async def run(self) -> InsertURLsInfo: + raise NotImplementedError + + def run_sync(self) -> InsertURLsInfo: + raw_urls = generate_test_urls(self.url_count) + url_infos: list[URLInfo] = [] + for url in raw_urls: + url_infos.append( + URLInfo( + url=url, + outcome=self.outcome, + name="Test Name" if self.outcome == URLStatus.VALIDATED else None, + collector_metadata=self.collector_metadata, + created_at=self.created_at + ) + ) + + url_insert_info = self.db_client.insert_urls( + url_infos=url_infos, + batch_id=self.batch_id, + ) + + # If outcome is submitted, also add entry to DataSourceURL + if self.outcome == URLStatus.SUBMITTED: + submitted_url_infos = [] + for url_id in url_insert_info.url_ids: + submitted_url_info = SubmittedURLInfo( + url_id=url_id, + data_source_id=url_id, # Use same ID for convenience, + request_error=None, + submitted_at=self.created_at + ) + submitted_url_infos.append(submitted_url_info) + self.db_client.mark_urls_as_submitted(submitted_url_infos) + + + return url_insert_info \ No newline at end of file diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/__init__.py b/tests/helpers/data_creator/commands/impl/urls_v2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/core.py b/tests/helpers/data_creator/commands/impl/urls_v2/core.py new file mode 100644 index 00000000..29d260d6 --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/urls_v2/core.py @@ -0,0 +1,66 @@ +from datetime import datetime + +from src.collectors.enums import URLStatus +from src.db.dtos.url.insert import InsertURLsInfo +from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.annotate import AnnotateCommand +from tests.helpers.data_creator.commands.impl.html_data import HTMLDataCreatorCommand +from tests.helpers.data_creator.commands.impl.urls import URLsDBDataCreatorCommand +from tests.helpers.data_creator.commands.impl.urls_v2.response import URLsV2Response +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 +from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo + + +class URLsV2Command(DBDataCreatorCommandBase): + + def __init__( + self, + parameters: list[TestURLCreationParameters], + batch_id: int | None = None, + created_at: datetime | None = None + ): + super().__init__() + self.parameters = parameters + self.batch_id = batch_id + self.created_at = created_at + + async def run(self) -> URLsV2Response: + urls_by_status: dict[URLStatus, URLCreationInfo] = {} + urls_by_order: list[URLCreationInfo] = [] + # Create urls + for url_parameters in self.parameters: + command = URLsDBDataCreatorCommand( + batch_id=self.batch_id, + url_count=url_parameters.count, + outcome=url_parameters.status, + created_at=self.created_at + ) + iui: InsertURLsInfo = self.run_command_sync(command) + url_ids = [iui.url_id for iui in iui.url_mappings] + if url_parameters.with_html_content: + command = HTMLDataCreatorCommand( + url_ids=url_ids + ) + await self.run_command(command) + if url_parameters.annotation_info.has_annotations(): + for url_id in url_ids: + await self.run_command( + AnnotateCommand( + url_id=url_id, + annotation_info=url_parameters.annotation_info + ) + ) + + creation_info = URLCreationInfo( + url_mappings=iui.url_mappings, + outcome=url_parameters.status, + annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None + ) + urls_by_order.append(creation_info) + urls_by_status[url_parameters.status] = creation_info + + return URLsV2Response( + urls_by_status=urls_by_status, + urls_by_order=urls_by_order + ) diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/response.py b/tests/helpers/data_creator/commands/impl/urls_v2/response.py new file mode 100644 index 00000000..db19328e --- /dev/null +++ b/tests/helpers/data_creator/commands/impl/urls_v2/response.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo + + +class URLsV2Response(BaseModel): + urls_by_status: dict[URLStatus, URLCreationInfo] = {} + urls_by_order: list[URLCreationInfo] = [] \ No newline at end of file diff --git a/tests/helpers/data_creator/core.py b/tests/helpers/data_creator/core.py index 696ca104..f86e9a25 100644 --- a/tests/helpers/data_creator/core.py +++ b/tests/helpers/data_creator/core.py @@ -1,7 +1,7 @@ from collections import defaultdict from datetime import datetime from random import randint -from typing import List, Optional +from typing import List, Optional, Any from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo @@ -24,6 +24,24 @@ from src.core.enums import BatchStatus, SuggestionType, RecordType, SuggestedStatus from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase +from tests.helpers.data_creator.commands.impl.agency import AgencyCommand +from tests.helpers.data_creator.commands.impl.annotate import AnnotateCommand +from tests.helpers.data_creator.commands.impl.batch import DBDataCreatorBatchCommand +from tests.helpers.data_creator.commands.impl.batch_v2 import BatchV2Command +from tests.helpers.data_creator.commands.impl.html_data import HTMLDataCreatorCommand +from tests.helpers.data_creator.commands.impl.suggestion.agency_confirmed import AgencyConfirmedSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.auto.agency import AgencyAutoSuggestionsCommand +from tests.helpers.data_creator.commands.impl.suggestion.auto.record_type import AutoRecordTypeSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.auto.relevant import AutoRelevantSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.agency import AgencyUserSuggestionsCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.record_type import UserRecordTypeSuggestionCommand +from tests.helpers.data_creator.commands.impl.suggestion.user.relevant import UserRelevantSuggestionCommand +from tests.helpers.data_creator.commands.impl.urls import URLsDBDataCreatorCommand +from tests.helpers.data_creator.commands.impl.urls_v2.core import URLsV2Command +from tests.helpers.data_creator.commands.impl.urls_v2.response import URLsV2Response +from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo @@ -40,6 +58,18 @@ def __init__(self, db_client: Optional[DatabaseClient] = None): else: self.db_client = DatabaseClient() self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient() + self.clients = DBDataCreatorClientContainer( + adb=self.adb_client, + db=self.db_client + ) + + def run_command_sync(self, command: DBDataCreatorCommandBase) -> Any: + command.load_clients(self.clients) + return command.run_sync() + + async def run_command(self, command: DBDataCreatorCommandBase) -> Any: + command.load_clients(self.clients) + return await command.run() def batch( self, @@ -47,15 +77,12 @@ def batch( batch_status: BatchStatus = BatchStatus.IN_PROCESS, created_at: Optional[datetime] = None ) -> int: - return self.db_client.insert_batch( - BatchInfo( - strategy=strategy.value, - status=batch_status, - parameters={"test_key": "test_value"}, - user_id=1, - date_generated=created_at - ) + command = DBDataCreatorBatchCommand( + strategy=strategy, + batch_status=batch_status, + created_at=created_at ) + return self.run_command_sync(command) async def task(self, url_ids: Optional[list[int]] = None) -> int: task_id = await self.adb_client.initiate_task(task_type=TaskType.HTML) @@ -67,51 +94,23 @@ async def batch_v2( self, parameters: TestBatchCreationParameters ) -> BatchURLCreationInfoV2: - # Create batch - batch_id = self.batch( - strategy=parameters.strategy, - batch_status=parameters.outcome, - created_at=parameters.created_at - ) - # Return early if batch would not involve URL creation - if parameters.outcome in (BatchStatus.ERROR, BatchStatus.ABORTED): - return BatchURLCreationInfoV2( - batch_id=batch_id, - ) + return await self.run_command(BatchV2Command(parameters)) - urls_by_status: dict[URLStatus, URLCreationInfo] = {} - urls_by_order: list[URLCreationInfo] = [] - # Create urls - for url_parameters in parameters.urls: - iui: InsertURLsInfo = self.urls( + async def url_v2( + self, + parameters: list[TestURLCreationParameters], + batch_id: int | None = None, + created_at: datetime | None = None + ) -> URLsV2Response: + return await self.run_command( + URLsV2Command( + parameters=parameters, batch_id=batch_id, - url_count=url_parameters.count, - outcome=url_parameters.status, - created_at=parameters.created_at + created_at=created_at ) - url_ids = [iui.url_id for iui in iui.url_mappings] - if url_parameters.with_html_content: - await self.html_data(url_ids) - if url_parameters.annotation_info.has_annotations(): - for url_id in url_ids: - await self.annotate( - url_id=url_id, - annotation_info=url_parameters.annotation_info - ) - - creation_info = URLCreationInfo( - url_mappings=iui.url_mappings, - outcome=url_parameters.status, - annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None - ) - urls_by_order.append(creation_info) - urls_by_status[url_parameters.status] = creation_info - - return BatchURLCreationInfoV2( - batch_id=batch_id, - urls_by_status=urls_by_status, ) + async def batch_and_urls( self, strategy: CollectorType = CollectorType.EXAMPLE, @@ -146,97 +145,28 @@ async def batch_and_urls( ) async def agency(self) -> int: - agency_id = randint(1, 99999999) - await self.adb_client.upsert_new_agencies( - suggestions=[ - URLAgencySuggestionInfo( - url_id=-1, - suggestion_type=SuggestionType.UNKNOWN, - pdap_agency_id=agency_id, - agency_name=f"Test Agency {agency_id}", - state=f"Test State {agency_id}", - county=f"Test County {agency_id}", - locality=f"Test Locality {agency_id}" - ) - ] - ) - return agency_id + return await self.run_command(AgencyCommand()) async def auto_relevant_suggestions(self, url_id: int, relevant: bool = True): - await self.adb_client.add_auto_relevant_suggestion( - input_=AutoRelevancyAnnotationInput( + await self.run_command( + AutoRelevantSuggestionCommand( url_id=url_id, - is_relevant=relevant, - confidence=0.5, - model_name="test_model" + relevant=relevant ) ) - async def annotate( - self, - url_id: int, - annotation_info: AnnotationInfo - ): - info = annotation_info - if info.user_relevant is not None: - await self.user_relevant_suggestion_v2(url_id=url_id, suggested_status=info.user_relevant) - if info.auto_relevant is not None: - await self.auto_relevant_suggestions(url_id=url_id, relevant=info.auto_relevant) - if info.user_record_type is not None: - await self.user_record_type_suggestion(url_id=url_id, record_type=info.user_record_type) - if info.auto_record_type is not None: - await self.auto_record_type_suggestions(url_id=url_id, record_type=info.auto_record_type) - if info.user_agency is not None: - await self.agency_user_suggestions(url_id=url_id, agency_annotation_info=info.user_agency) - if info.auto_agency is not None: - await self.agency_auto_suggestions(url_id=url_id, count=1, suggestion_type=SuggestionType.AUTO_SUGGESTION) - if info.confirmed_agency is not None: - await self.agency_auto_suggestions(url_id=url_id, count=1, suggestion_type=SuggestionType.CONFIRMED) - if info.final_review_approved is not None: - if info.final_review_approved: - final_review_approval_info = FinalReviewApprovalInfo( - url_id=url_id, - record_type=annotation_info.user_record_type, - agency_ids=[annotation_info.user_agency.suggested_agency] - if annotation_info.user_agency is not None else None, - description="Test Description", - ) - await self.adb_client.approve_url( - approval_info=final_review_approval_info, - user_id=1 - ) - else: - await self.adb_client.reject_url( - url_id=url_id, - user_id=1, - rejection_reason=RejectionReason.NOT_RELEVANT - ) - - async def user_relevant_suggestion( self, url_id: int, - user_id: Optional[int] = None, - relevant: bool = True - ): - await self.user_relevant_suggestion_v2( - url_id=url_id, - user_id=user_id, - suggested_status=SuggestedStatus.RELEVANT if relevant else SuggestedStatus.NOT_RELEVANT - ) - - async def user_relevant_suggestion_v2( - self, - url_id: int, - user_id: Optional[int] = None, + user_id: int | None = None, suggested_status: SuggestedStatus = SuggestedStatus.RELEVANT - ): - if user_id is None: - user_id = randint(1, 99999999) - await self.adb_client.add_user_relevant_suggestion( - url_id=url_id, - user_id=user_id, - suggested_status=suggested_status + ) -> None: + await self.run_command( + UserRelevantSuggestionCommand( + url_id=url_id, + user_id=user_id, + suggested_status=suggested_status + ) ) async def user_record_type_suggestion( @@ -244,22 +174,27 @@ async def user_record_type_suggestion( url_id: int, record_type: RecordType, user_id: Optional[int] = None, - ): - if user_id is None: - user_id = randint(1, 99999999) - await self.adb_client.add_user_record_type_suggestion( - url_id=url_id, - user_id=user_id, - record_type=record_type + ) -> None: + await self.run_command( + UserRecordTypeSuggestionCommand( + url_id=url_id, + record_type=record_type, + user_id=user_id + ) ) - async def auto_record_type_suggestions(self, url_id: int, record_type: RecordType): - await self.adb_client.add_auto_record_type_suggestion( - url_id=url_id, - record_type=record_type + async def auto_record_type_suggestions( + self, + url_id: int, + record_type: RecordType + ): + await self.run_command( + AutoRecordTypeSuggestionCommand( + url_id=url_id, + record_type=record_type + ) ) - async def auto_suggestions( self, url_ids: list[int], @@ -315,43 +250,18 @@ def urls( self, batch_id: int, url_count: int, - collector_metadata: Optional[dict] = None, + collector_metadata: dict | None = None, outcome: URLStatus = URLStatus.PENDING, - created_at: Optional[datetime] = None + created_at: datetime | None = None ) -> InsertURLsInfo: - raw_urls = generate_test_urls(url_count) - url_infos: List[URLInfo] = [] - for url in raw_urls: - url_infos.append( - URLInfo( - url=url, - outcome=outcome, - name="Test Name" if outcome == URLStatus.VALIDATED else None, - collector_metadata=collector_metadata, - created_at=created_at - ) - ) - - url_insert_info = self.db_client.insert_urls( - url_infos=url_infos, + command = URLsDBDataCreatorCommand( batch_id=batch_id, + url_count=url_count, + collector_metadata=collector_metadata, + outcome=outcome, + created_at=created_at ) - - # If outcome is submitted, also add entry to DataSourceURL - if outcome == URLStatus.SUBMITTED: - submitted_url_infos = [] - for url_id in url_insert_info.url_ids: - submitted_url_info = SubmittedURLInfo( - url_id=url_id, - data_source_id=url_id, # Use same ID for convenience, - request_error=None, - submitted_at=created_at - ) - submitted_url_infos.append(submitted_url_info) - self.db_client.mark_urls_as_submitted(submitted_url_infos) - - - return url_insert_info + return self.run_command_sync(command) async def url_miscellaneous_metadata( self, @@ -394,32 +304,11 @@ def duplicate_urls(self, duplicate_batch_id: int, url_ids: list[int]): self.db_client.insert_duplicates(duplicate_infos) - async def html_data(self, url_ids: list[int]): - html_content_infos = [] - raw_html_info_list = [] - for url_id in url_ids: - html_content_infos.append( - URLHTMLContentInfo( - url_id=url_id, - content_type=HTMLContentType.TITLE, - content="test html content" - ) - ) - html_content_infos.append( - URLHTMLContentInfo( - url_id=url_id, - content_type=HTMLContentType.DESCRIPTION, - content="test description" - ) - ) - raw_html_info = RawHTMLInfo( - url_id=url_id, - html="" - ) - raw_html_info_list.append(raw_html_info) - - await self.adb_client.add_raw_html(raw_html_info_list) - await self.adb_client.add_html_content_infos(html_content_infos) + async def html_data(self, url_ids: list[int]) -> None: + command = HTMLDataCreatorCommand( + url_ids=url_ids + ) + await self.run_command(command) async def error_info( self, @@ -444,28 +333,13 @@ async def agency_auto_suggestions( url_id: int, count: int, suggestion_type: SuggestionType = SuggestionType.AUTO_SUGGESTION - ): - if suggestion_type == SuggestionType.UNKNOWN: - count = 1 # Can only be one auto suggestion if unknown - - suggestions = [] - for _ in range(count): - if suggestion_type == SuggestionType.UNKNOWN: - pdap_agency_id = None - else: - pdap_agency_id = await self.agency() - suggestion = URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=suggestion_type, - pdap_agency_id=pdap_agency_id, - state="Test State", - county="Test County", - locality="Test Locality" + ) -> None: + await self.run_command( + AgencyAutoSuggestionsCommand( + url_id=url_id, + count=count, + suggestion_type=suggestion_type ) - suggestions.append(suggestion) - - await self.adb_client.add_agency_auto_suggestions( - suggestions=suggestions ) async def agency_confirmed_suggestion( @@ -473,37 +347,22 @@ async def agency_confirmed_suggestion( url_id: int ) -> int: """ - Creates a confirmed agency suggestion - and returns the auto-generated pdap_agency_id + Create a confirmed agency suggestion and return the auto-generated pdap_agency_id. """ - agency_id = await self.agency() - await self.adb_client.add_confirmed_agency_url_links( - suggestions=[ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.CONFIRMED, - pdap_agency_id=agency_id - ) - ] + return await self.run_command( + AgencyConfirmedSuggestionCommand(url_id) ) - return agency_id async def agency_user_suggestions( self, url_id: int, - user_id: Optional[int] = None, - agency_annotation_info: Optional[URLAgencyAnnotationPostInfo] = None - ): - if user_id is None: - user_id = randint(1, 99999999) - - if agency_annotation_info is None: - agency_annotation_info = URLAgencyAnnotationPostInfo( - suggested_agency=await self.agency() + user_id: int | None = None, + agency_annotation_info: URLAgencyAnnotationPostInfo | None = None + ) -> None: + await self.run_command( + AgencyUserSuggestionsCommand( + url_id=url_id, + user_id=user_id, + agency_annotation_info=agency_annotation_info ) - await self.adb_client.add_agency_manual_suggestion( - agency_id=agency_annotation_info.suggested_agency, - url_id=url_id, - user_id=user_id, - is_new=agency_annotation_info.is_new ) diff --git a/tests/helpers/data_creator/models/clients.py b/tests/helpers/data_creator/models/clients.py new file mode 100644 index 00000000..a8256dfc --- /dev/null +++ b/tests/helpers/data_creator/models/clients.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from src.db.client.async_ import AsyncDatabaseClient +from src.db.client.sync import DatabaseClient + + +class DBDataCreatorClientContainer(BaseModel): + db: DatabaseClient + adb: AsyncDatabaseClient + + class Config: + arbitrary_types_allowed = True diff --git a/tests/helpers/setup/final_review/core.py b/tests/helpers/setup/final_review/core.py index d9c3aa10..6c4a3498 100644 --- a/tests/helpers/setup/final_review/core.py +++ b/tests/helpers/setup/final_review/core.py @@ -1,7 +1,7 @@ from typing import Optional from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.core.enums import RecordType +from src.core.enums import RecordType, SuggestedStatus from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.final_review.model import FinalReviewSetupInfo @@ -46,7 +46,7 @@ async def add_record_type_suggestion(record_type: RecordType): async def add_relevant_suggestion(relevant: bool): await db_data_creator.user_relevant_suggestion( url_id=url_mapping.url_id, - relevant=relevant + suggested_status=SuggestedStatus.RELEVANT if relevant else SuggestedStatus.NOT_RELEVANT ) await db_data_creator.auto_relevant_suggestions( diff --git a/tests/manual/source_collectors/test_autogoogler_collector.py b/tests/manual/source_collectors/test_autogoogler_collector.py index c5ebda01..320434e1 100644 --- a/tests/manual/source_collectors/test_autogoogler_collector.py +++ b/tests/manual/source_collectors/test_autogoogler_collector.py @@ -20,13 +20,9 @@ async def test_autogoogler_collector(monkeypatch): collector = AutoGooglerCollector( batch_id=1, dto=AutoGooglerInputDTO( - urls_per_result=5, + urls_per_result=20, queries=[ - "brooklyn new york city police data", - "queens new york city police data", - "staten island new york city police data", - "manhattan new york city police data", - "bronx new york city police data" + "pennsylvania police officer roster" ], ), logger = AsyncMock(spec=AsyncCoreLogger),