From 74247a70b71040188021ceea4e7fabff964e19f2 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Sat, 15 Mar 2025 17:46:23 -0400 Subject: [PATCH 1/5] DRAFT --- ENV.md | 8 +++-- collector_db/enums.py | 1 + core/AsyncCore.py | 30 ++++++++++++++----- .../task_data_objects/SubmitApprovedURLTDO.py | 11 +++++++ core/README.md | 3 +- core/classes/SubmitApprovedURLTaskOperator.py | 28 +++++++++++++++++ pdap_api_client/AccessManager.py | 5 ++-- 7 files changed, 72 insertions(+), 14 deletions(-) create mode 100644 core/DTOs/task_data_objects/SubmitApprovedURLTDO.py create mode 100644 core/classes/SubmitApprovedURLTaskOperator.py diff --git a/ENV.md b/ENV.md index 68359348..92b7de31 100644 --- a/ENV.md +++ b/ENV.md @@ -14,11 +14,13 @@ Please ensure these are properly defined in a `.env` file in the root directory. |`POSTGRES_DB` | The database name for the test database | `source_collector_test_db` | |`POSTGRES_HOST` | The host for the test database | `127.0.0.1` | |`POSTGRES_PORT` | The port for the test database | `5432` | -|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` | +|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` | |`DEV`| Set to any value to run the application in development mode. | `true` | |`DEEPSEEK_API_KEY`| The API key required for accessing the DeepSeek API. | `abc123` | |`OPENAI_API_KEY`| The API key required for accessing the OpenAI API. | `abc123` | -|`PDAP_EMAIL`| An email address for accessing the PDAP API. | `abc123@test.com` | -|`PDAP_PASSWORD`| A password for accessing the PDAP API. | `abc123` | +|`PDAP_EMAIL`| An email address for accessing the PDAP API.[^1] | `abc123@test.com` | +|`PDAP_PASSWORD`| A password for accessing the PDAP API.[^1] | `abc123` | |`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` | +|`PDAP_API_URL`| The URL for the PDAP API. | `https://data-sources-v2.pdap.dev/api` | +[^1]: The user account in question will require elevated permissions to access certain endpoints. At a minimum, the user will require the `source_collector` and `db_write` permissions.
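For reference, a minimal sketch of how the four `PDAP_*` variables above are consumed, mirroring the `get_pdap_client` helper added later in this patch (values are placeholders; `build_url` reads `PDAP_API_URL` separately when composing request URLs):

    from pdap_api_client.AccessManager import AccessManager
    from pdap_api_client.PDAPClient import PDAPClient
    from util.helper_functions import get_from_env

    # Credentials come from the .env file described above; the account
    # must hold the `source_collector` and `db_write` permissions.
    pdap_client = PDAPClient(
        access_manager=AccessManager(
            email=get_from_env("PDAP_EMAIL"),
            password=get_from_env("PDAP_PASSWORD"),
            api_key=get_from_env("PDAP_API_KEY"),
        ),
    )
    # No base URL is passed here: build_url() resolves PDAP_API_URL per request.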
\ No newline at end of file diff --git a/collector_db/enums.py b/collector_db/enums.py index 2d82e87b..60b3df13 100644 --- a/collector_db/enums.py +++ b/collector_db/enums.py @@ -37,6 +37,7 @@ class TaskType(PyEnum): RELEVANCY = "Relevancy" RECORD_TYPE = "Record Type" AGENCY_IDENTIFICATION = "Agency Identification" + SUBMIT_APPROVED = "Submit Approved URLs" class PGEnum(TypeDecorator): impl = postgresql.ENUM diff --git a/core/AsyncCore.py b/core/AsyncCore.py index 4854926e..08480a61 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -18,6 +18,7 @@ from core.DTOs.AnnotationRequestInfo import AnnotationRequestInfo from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator +from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator from core.classes.TaskOperatorBase import TaskOperatorBase from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator @@ -51,6 +52,15 @@ def __init__( self.logger.addHandler(logging.StreamHandler()) self.logger.setLevel(logging.INFO) + async def get_pdap_client(self): + return PDAPClient( + access_manager=AccessManager( + email=get_from_env("PDAP_EMAIL"), + password=get_from_env("PDAP_PASSWORD"), + api_key=get_from_env("PDAP_API_KEY"), + ), + ) + async def get_url_html_task_operator(self): self.logger.info("Running URL HTML Task") operator = URLHTMLTaskOperator( @@ -76,13 +86,7 @@ async def get_url_record_type_task_operator(self): return operator async def get_agency_identification_task_operator(self): - pdap_client = PDAPClient( - access_manager=AccessManager( - email=get_from_env("PDAP_EMAIL"), - password=get_from_env("PDAP_PASSWORD"), - api_key=get_from_env("PDAP_API_KEY"), - ), - ) + pdap_client = await self.get_pdap_client() muckrock_api_interface = MuckrockAPIInterface() operator = AgencyIdentificationTaskOperator( adb_client=self.adb_client, @@ -91,12 +95,22 @@ async def get_agency_identification_task_operator(self): ) return operator + async def get_submit_approved_url_task_operator(self): + pdap_client = await self.get_pdap_client() + operator = SubmitApprovedURLTaskOperator( + adb_client=self.adb_client, + pdap_client=pdap_client + ) + return operator + + async def get_task_operators(self) -> list[TaskOperatorBase]: return [ await self.get_url_html_task_operator(), await self.get_url_relevance_huggingface_task_operator(), await self.get_url_record_type_task_operator(), - await self.get_agency_identification_task_operator() + await self.get_agency_identification_task_operator(), + await self.get_submit_approved_url_task_operator(), ] async def run_tasks(self): diff --git a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py new file mode 100644 index 00000000..ee1b8dc6 --- /dev/null +++ b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py @@ -0,0 +1,11 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.enums import RecordType + + +class SubmitApprovedURLTDO(BaseModel): + url: str + record_type: RecordType + agency_id: Optional[int] \ No newline at end of file diff --git a/core/README.md b/core/README.md index 25b1cde3..9546f613 100644 --- a/core/README.md +++ b/core/README.md @@ -11,4 +11,5 @@ The Source Collector Core is a directory which integrates: - **Cycle**: Refers to the overall lifecycle for Each URL -- from initial retrieval in a Batch to 
either disposal or incorporation into the Data Sources App Database - **Task**: A semi-independent operation performed on a set of URLs. These include: Collection, retrieving HTML data, getting metadata via Machine Learning, and so on. - **Task Set**: Refers to a group of URLs that are operated on together as part of a single task. These URLs in a set are not necessarily all from the same batch. URLs in a task set should only be operated on in that task once. -- **Task Operator**: A class which performs a single task on a set of URLs. \ No newline at end of file +- **Task Operator**: A class which performs a single task on a set of URLs. +- **Subtask**: A subcomponent of a Task Operator which performs a single operation on a single URL. Often distinguished by the Collector Strategy used for that URL. \ No newline at end of file diff --git a/core/classes/SubmitApprovedURLTaskOperator.py b/core/classes/SubmitApprovedURLTaskOperator.py new file mode 100644 index 00000000..633f8c1e --- /dev/null +++ b/core/classes/SubmitApprovedURLTaskOperator.py @@ -0,0 +1,28 @@ +from collector_db.AsyncDatabaseClient import AsyncDatabaseClient +from collector_db.enums import TaskType +from core.classes.TaskOperatorBase import TaskOperatorBase +from pdap_api_client.PDAPClient import PDAPClient + + +class SubmitApprovedURLTaskOperator(TaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + pdap_client: PDAPClient + ): + super().__init__(adb_client) + self.pdap_client = pdap_client + + @property + def task_type(self): + return TaskType.SUBMIT_APPROVED + + async def meets_task_prerequisites(self): + return await self.adb_client.has_validated_urls() + + async def inner_task_logic(self): + raise NotImplementedError + + async def update_errors_in_database(self, error_tdos: list[UrlHtmlTDO]): + raise NotImplementedError \ No newline at end of file diff --git a/pdap_api_client/AccessManager.py b/pdap_api_client/AccessManager.py index c39ba1e8..1020f365 100644 --- a/pdap_api_client/AccessManager.py +++ b/pdap_api_client/AccessManager.py @@ -5,8 +5,8 @@ from aiohttp import ClientSession from pdap_api_client.DTOs import RequestType, Namespaces, RequestInfo, ResponseInfo +from util.helper_functions import get_from_env -API_URL = "https://data-sources-v2.pdap.dev/api" request_methods = { RequestType.POST: ClientSession.post, RequestType.PUT: ClientSession.put, @@ -23,7 +23,8 @@ def build_url( namespace: Namespaces, subdomains: Optional[list[str]] = None ): - url = f"{API_URL}/{namespace.value}" + api_url = get_from_env('PDAP_API_URL') + url = f"{api_url}/{namespace.value}" if subdomains is not None: url = f"{url}/{'/'.join(subdomains)}" return url From a3dedcd0574746a63ce1241e3527189078a52270 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Mon, 31 Mar 2025 19:52:24 -0400 Subject: [PATCH 2/5] DRAFT --- ..._add_data_source_id_column_to_url_table.py | 31 +++++++++++++++++++ collector_db/models.py | 1 + .../task_data_objects/SubmitApprovedURLTDO.py | 8 ++++- core/classes/SubmitApprovedURLTaskOperator.py | 1 + 4 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py diff --git a/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py new file mode 100644 index 00000000..8e15dbf2 --- /dev/null +++ b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py @@ 
-0,0 +1,31 @@ +"""Add data source ID column to URL table + +Revision ID: 33a546c93441 +Revises: 5ea47dacd0ef +Create Date: 2025-03-29 17:16:11.863064 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '33a546c93441' +down_revision: Union[str, None] = '5ea47dacd0ef' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + 'url', + sa.Column('data_source_id', sa.Integer(), nullable=True) + ) + # Add unique constraint to data_source_id column + op.create_unique_constraint('uq_data_source_id', 'url', ['data_source_id']) + + +def downgrade() -> None: + op.drop_column('url', 'data_source_id') diff --git a/collector_db/models.py b/collector_db/models.py index 55b75af2..4a82e68c 100644 --- a/collector_db/models.py +++ b/collector_db/models.py @@ -105,6 +105,7 @@ class URL(Base): relevant = Column(Boolean, nullable=True) created_at = get_created_at_column() updated_at = get_updated_at_column() + data_source_id = Column(Integer, nullable=True) # Relationships batch = relationship("Batch", back_populates="urls") diff --git a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py index ee1b8dc6..fc6e789b 100644 --- a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py +++ b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py @@ -8,4 +8,10 @@ class SubmitApprovedURLTDO(BaseModel): url: str record_type: RecordType - agency_id: Optional[int] \ No newline at end of file + agency_id: Optional[int] + name: str + description: str + record_formats: Optional[list[str]] = None + data_portal_type: Optional[str] = None + supplying_entity: Optional[str] = None + data_source_id: Optional[int] = None \ No newline at end of file diff --git a/core/classes/SubmitApprovedURLTaskOperator.py b/core/classes/SubmitApprovedURLTaskOperator.py index 633f8c1e..06b28a18 100644 --- a/core/classes/SubmitApprovedURLTaskOperator.py +++ b/core/classes/SubmitApprovedURLTaskOperator.py @@ -1,5 +1,6 @@ from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.enums import TaskType +from core.DTOs.task_data_objects.UrlHtmlTDO import UrlHtmlTDO from core.classes.TaskOperatorBase import TaskOperatorBase from pdap_api_client.PDAPClient import PDAPClient From 77c7dff9ed972b01ca679602f2f84ca97e2371c9 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Tue, 8 Apr 2025 20:34:53 -0400 Subject: [PATCH 3/5] DRAFT --- ..._add_data_source_id_column_to_url_table.py | 10 +- collector_db/AsyncDatabaseClient.py | 64 ++++++- .../task_data_objects/SubmitApprovedURLTDO.py | 3 +- core/classes/SubmitApprovedURLTaskOperator.py | 32 +++- pdap_api_client/DTOs.py | 1 + pdap_api_client/PDAPClient.py | 30 ++- tests/helpers/DBDataCreator.py | 5 +- .../tasks/test_submit_approved_url_task.py | 171 ++++++++++++++++++ 8 files changed, 302 insertions(+), 14 deletions(-) create mode 100644 tests/test_automated/integration/tasks/test_submit_approved_url_task.py diff --git a/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py index 8e15dbf2..b92fe1ef 100644 --- a/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py +++ b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py @@ -1,7 +1,7 @@ """Add data source ID 
column to URL table Revision ID: 33a546c93441 -Revises: 5ea47dacd0ef +Revises: 45271f8fe75d Create Date: 2025-03-29 17:16:11.863064 """ @@ -13,19 +13,19 @@ # revision identifiers, used by Alembic. revision: str = '33a546c93441' -down_revision: Union[str, None] = '5ea47dacd0ef' +down_revision: Union[str, None] = '45271f8fe75d' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: op.add_column( - 'url', + 'urls', sa.Column('data_source_id', sa.Integer(), nullable=True) ) # Add unique constraint to data_source_id column - op.create_unique_constraint('uq_data_source_id', 'url', ['data_source_id']) + op.create_unique_constraint('uq_data_source_id', 'urls', ['data_source_id']) def downgrade() -> None: - op.drop_column('url', 'data_source_id') + op.drop_column('urls', 'data_source_id') diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py index 34ebe7f7..e74a28ec 100644 --- a/collector_db/AsyncDatabaseClient.py +++ b/collector_db/AsyncDatabaseClient.py @@ -15,7 +15,6 @@ from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo, HTMLContentType from collector_db.DTOs.URLInfo import URLInfo from collector_db.DTOs.URLMapping import URLMapping -from collector_db.DTOs.URLWithHTML import URLWithHTML from collector_db.StatementComposer import StatementComposer from collector_db.constants import PLACEHOLDER_AGENCY_NAME from collector_db.enums import URLMetadataAttributeType, TaskType @@ -37,6 +36,7 @@ GetURLsResponseInnerInfo from core.DTOs.URLAgencySuggestionInfo import URLAgencySuggestionInfo from core.DTOs.task_data_objects.AgencyIdentificationTDO import AgencyIdentificationTDO +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO from core.DTOs.task_data_objects.URLMiscellaneousMetadataTDO import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo from core.enums import BatchStatus, SuggestionType, RecordType from html_tag_collector.DataClassTags import convert_to_response_html_info @@ -1337,4 +1337,64 @@ async def reject_url( url_id=url_id ) - session.add(rejecting_user_url) \ No newline at end of file + session.add(rejecting_user_url) + + @session_manager + async def has_validated_urls(self, session: AsyncSession) -> bool: + query = ( + select(URL) + .where(URL.outcome == URLStatus.VALIDATED.value) + ) + urls = await session.execute(query) + urls = urls.scalars().all() + return len(urls) > 0 + + @session_manager + async def get_validated_urls( + self, + session: AsyncSession + ) -> list[SubmitApprovedURLTDO]: + query = ( + select(URL) + .where(URL.outcome == URLStatus.VALIDATED.value) + .options( + selectinload(URL.optional_data_source_metadata), + selectinload(URL.confirmed_agencies) + ) + ) + urls = await session.execute(query) + urls = urls.scalars().all() + results: list[SubmitApprovedURLTDO] = [] + for url in urls: + agency_ids = [] + for agency in url.confirmed_agencies: + agency_ids.append(agency.agency_id) + tdo = SubmitApprovedURLTDO( + url_id=url.id, + url=url.url, + name=url.name, + agency_ids=agency_ids, + description=url.description, + record_type=url.record_type, + record_formats=url.optional_data_source_metadata.record_formats, + data_portal_type=url.optional_data_source_metadata.data_portal_type, + supplying_entity=url.optional_data_source_metadata.supplying_entity, + ) + results.append(tdo) + return results + + @session_manager + async def mark_urls_as_submitted(self, session: AsyncSession, tdos: list[SubmitApprovedURLTDO]): + for tdo in 
tdos: + url_id = tdo.url_id + data_source_id = tdo.data_source_id + query = ( + update(URL) + .where(URL.id == url_id) + .values( + data_source_id=data_source_id, + outcome=URLStatus.SUBMITTED.value + ) + ) + await session.execute(query) + diff --git a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py index fc6e789b..45fa7daf 100644 --- a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py +++ b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py @@ -6,9 +6,10 @@ class SubmitApprovedURLTDO(BaseModel): + url_id: int url: str record_type: RecordType - agency_id: Optional[int] + agency_ids: list[int] name: str description: str record_formats: Optional[list[str]] = None diff --git a/core/classes/SubmitApprovedURLTaskOperator.py b/core/classes/SubmitApprovedURLTaskOperator.py index 06b28a18..2a308e7c 100644 --- a/core/classes/SubmitApprovedURLTaskOperator.py +++ b/core/classes/SubmitApprovedURLTaskOperator.py @@ -1,6 +1,7 @@ from collector_db.AsyncDatabaseClient import AsyncDatabaseClient +from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo from collector_db.enums import TaskType -from core.DTOs.task_data_objects.UrlHtmlTDO import UrlHtmlTDO +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO from core.classes.TaskOperatorBase import TaskOperatorBase from pdap_api_client.PDAPClient import PDAPClient @@ -23,7 +24,30 @@ async def meets_task_prerequisites(self): return await self.adb_client.has_validated_urls() async def inner_task_logic(self): - raise NotImplementedError + # Retrieve all URLs that are validated and not submitted + tdos: list[SubmitApprovedURLTDO] = await self.adb_client.get_validated_urls() - async def update_errors_in_database(self, error_tdos: list[UrlHtmlTDO]): - raise NotImplementedError \ No newline at end of file + # Link URLs to this task + await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) + + # Submit each URL, recording errors if they exist + error_infos: list[URLErrorPydanticInfo] = [] + success_tdos: list[SubmitApprovedURLTDO] = [] + for tdo in tdos: + try: + data_source_id = await self.pdap_client.submit_url(tdo) + tdo.data_source_id = data_source_id + success_tdos.append(tdo) + except Exception as e: + error_info = URLErrorPydanticInfo( + task_id=self.task_id, + url_id=tdo.url_id, + error=str(e), + ) + error_infos.append(error_info) + + # Update the database for successful submissions + await self.adb_client.mark_urls_as_submitted(tdos=success_tdos) + + # Update the database for failed submissions + await self.adb_client.add_url_error_infos(error_infos) diff --git a/pdap_api_client/DTOs.py b/pdap_api_client/DTOs.py index 19255a35..37d7e857 100644 --- a/pdap_api_client/DTOs.py +++ b/pdap_api_client/DTOs.py @@ -36,6 +36,7 @@ class Namespaces(Enum): AUTH = "auth" MATCH = "match" CHECK = "check" + DATA_SOURCES = "data-sources" class RequestType(Enum): diff --git a/pdap_api_client/PDAPClient.py b/pdap_api_client/PDAPClient.py index b2b89564..8b1c5e82 100644 --- a/pdap_api_client/PDAPClient.py +++ b/pdap_api_client/PDAPClient.py @@ -1,5 +1,6 @@ from typing import Optional +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO from pdap_api_client.AccessManager import build_url, AccessManager from pdap_api_client.DTOs import MatchAgencyInfo, UniqueURLDuplicateInfo, UniqueURLResponseInfo, Namespaces, \ RequestType, RequestInfo, MatchAgencyResponse @@ -21,7 +22,6 @@ async def match_agency( county: Optional[str] = None, locality: 
Optional[str] = None ) -> MatchAgencyResponse: - # TODO: Change to async """ Returns agencies, if any, that match or partially match the search criteria """ @@ -84,3 +84,31 @@ async def is_url_unique( is_unique=is_unique, duplicates=duplicates ) + + async def submit_url( + self, + tdo: SubmitApprovedURLTDO + ) -> int: + url = build_url( + namespace=Namespaces.DATA_SOURCES, + ) + headers = await self.access_manager.jwt_header() + request_info = RequestInfo( + type_=RequestType.POST, + url=url, + headers=headers, + json={ + "entry_data": { + "name": tdo.name, + "description": tdo.description, + "source_url": tdo.url, + "record_type_name": tdo.record_type.value, + "record_formats": tdo.record_formats, + "data_portal_type": tdo.data_portal_type, + "supplying_entity": tdo.supplying_entity + }, + "linked_agency_ids": tdo.agency_ids + } + ) + response_info = await self.access_manager.make_request(request_info) + return response_info.data["id"] diff --git a/tests/helpers/DBDataCreator.py b/tests/helpers/DBDataCreator.py index 9f9719a7..dbf7072a 100644 --- a/tests/helpers/DBDataCreator.py +++ b/tests/helpers/DBDataCreator.py @@ -61,7 +61,10 @@ async def batch_and_urls( if with_html_content: await self.html_data(url_ids) - return BatchURLCreationInfo(batch_id=batch_id, url_ids=url_ids) + return BatchURLCreationInfo( + batch_id=batch_id, + url_ids=url_ids + ) async def agency(self) -> int: agency_id = randint(1, 99999999) diff --git a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py new file mode 100644 index 00000000..75630af8 --- /dev/null +++ b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py @@ -0,0 +1,171 @@ +from http import HTTPStatus +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from collector_db.models import URL +from collector_manager.enums import URLStatus +from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo +from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome +from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator +from core.enums import RecordType +from helpers.DBDataCreator import BatchURLCreationInfo, DBDataCreator +from pdap_api_client.AccessManager import AccessManager +from pdap_api_client.DTOs import RequestInfo, RequestType, ResponseInfo +from pdap_api_client.PDAPClient import PDAPClient + + +@pytest.fixture +def mock_pdap_client(): + mock_access_manager = MagicMock( + spec=AccessManager + ) + mock_access_manager.make_request = AsyncMock( + side_effect=[ + ResponseInfo( + status_code=HTTPStatus.OK, + data={ + "id": 21 + } + ), + ResponseInfo( + status_code=HTTPStatus.OK, + data={ + "id": 34 + } + ) + ] + ) + mock_access_manager.jwt_header = AsyncMock( + return_value={"Authorization": "Bearer token"} + ) + pdap_client = PDAPClient( + access_manager=mock_access_manager + ) + return pdap_client + +async def setup_validated_urls(db_data_creator: DBDataCreator): + creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls( + url_count=2, + with_html_content=True + ) + url_1 = creation_info.url_ids[0] + url_2 = creation_info.url_ids[1] + await db_data_creator.adb_client.approve_url( + approval_info=FinalReviewApprovalInfo( + url_id=url_1, + record_type=RecordType.ACCIDENT_REPORTS, + agency_ids=[1, 2], + name="URL 1 Name", + description="URL 1 Description", + record_formats=["Record Format 1", "Record Format 2"], + data_portal_type="Data Portal Type 1", + 
supplying_entity="Supplying Entity 1" + ), + user_id=1 + ) + await db_data_creator.adb_client.approve_url( + approval_info=FinalReviewApprovalInfo( + url_id=url_2, + record_type=RecordType.INCARCERATION_RECORDS, + agency_ids=[3, 4], + name="URL 2 Name", + description="URL 2 Description", + ), + user_id=1 + ) + +@pytest.mark.asyncio +async def test_submit_approved_url_task( + db_data_creator, + mock_pdap_client, + monkeypatch +): + monkeypatch.setenv("PDAP_API_URL", "http://localhost:8000") + + # Get Task Operator + operator = SubmitApprovedURLTaskOperator( + adb_client=db_data_creator.adb_client, + pdap_client=mock_pdap_client + ) + + # Check Task Operator does not yet meet pre-requisites + assert not await operator.meets_task_prerequisites() + + # Create URLs with status 'validated' in database and all requisite URL values + # Ensure they have optional metadata as well + await setup_validated_urls(db_data_creator) + + # Check Task Operator does meet pre-requisites + assert await operator.meets_task_prerequisites() + + # Run Task + run_info = await operator.run_task(task_id=1) + + # Check Task has been marked as completed + assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message + + # Get URLs + urls = await db_data_creator.adb_client.get_all(URL, order_by_attribute="id") + url_1 = urls[0] + url_2 = urls[1] + + # Check URLs have been marked as 'submitted' + assert url_1.outcome == URLStatus.SUBMITTED.value + assert url_2.outcome == URLStatus.SUBMITTED.value + + # Check URLs now have data source ids + assert url_1.data_source_id == 21 + assert url_2.data_source_id == 34 + + # Check mock method was called twice with expected parameters + access_manager = mock_pdap_client.access_manager + assert access_manager.make_request.call_count == 2 + # Check first call + + + call_1 = access_manager.make_request.call_args_list[0][0][0] + expected_call_1 = RequestInfo( + type_=RequestType.POST, + url="http://localhost:8000/data-sources", + headers=access_manager.jwt_header.return_value, + json={ + "entry_data": { + "name": "URL 1 Name", + "source_url": url_1.url, + "record_type_name": "Accident Reports", + "description": "URL 1 Description", + "record_formats": ["Record Format 1", "Record Format 2"], + "data_portal_type": "Data Portal Type 1", + "supplying_entity": "Supplying Entity 1" + }, + "linked_agency_ids": [1, 2] + } + ) + assert call_1.type_ == expected_call_1.type_ + assert call_1.url == expected_call_1.url + assert call_1.headers == expected_call_1.headers + assert call_1.json == expected_call_1.json + # Check second call + call_2 = access_manager.make_request.call_args_list[1][0][0] + expected_call_2 = RequestInfo( + type_=RequestType.POST, + url="http://localhost:8000/data-sources", + headers=access_manager.jwt_header.return_value, + json={ + "entry_data": { + "name": "URL 2 Name", + "source_url": url_2.url, + "record_type_name": "Incarceration Records", + "description": "URL 2 Description", + "data_portal_type": None, + "supplying_entity": None, + "record_formats": None + }, + "linked_agency_ids": [3, 4] + } + ) + assert call_2.type_ == expected_call_2.type_ + assert call_2.url == expected_call_2.url + assert call_2.headers == expected_call_2.headers + assert call_2.json == expected_call_2.json \ No newline at end of file From 82e65b9f13e008821c0de9ee429cefe9c0511cfa Mon Sep 17 00:00:00 2001 From: Max Chis Date: Tue, 15 Apr 2025 16:47:03 -0400 Subject: [PATCH 4/5] feat(app): add submit approved URL task --- ...3794fa4e9_add_submit_url_task_type_enum.py | 48 +++++ 
...33d2e_revert_to_pending_validated_urls_.py | 42 +++++ api/main.py | 32 +++- api/routes/batch.py | 18 +- collector_db/AsyncDatabaseClient.py | 111 +++++++++-- collector_db/DTOs/URLInfo.py | 1 + collector_db/DatabaseClient.py | 96 +--------- collector_db/helper_functions.py | 13 +- collector_db/models.py | 6 +- core/AsyncCore.py | 29 ++- .../task_data_objects/SubmitApprovedURLTDO.py | 9 +- core/EnvVarManager.py | 76 ++++++++ core/SourceCollectorCore.py | 27 +-- core/TaskManager.py | 29 +-- core/classes/SubmitApprovedURLTaskOperator.py | 42 +++-- core/enums.py | 18 +- html_tag_collector/RootURLCache.py | 4 +- llm_api_logic/OpenAIRecordClassifier.py | 5 +- pdap_api_client/AccessManager.py | 4 +- pdap_api_client/DTOs.py | 1 + pdap_api_client/PDAPClient.py | 65 +++++-- start_mirrored_local_app.py | 62 ++++--- tests/conftest.py | 31 +++- tests/helpers/DBDataCreator.py | 12 +- .../integration/api/conftest.py | 12 +- .../integration/api/test_annotate.py | 3 - .../integration/api/test_duplicates.py | 3 +- .../collector_db/test_database_structure.py | 9 +- .../collector_db/test_db_client.py | 17 +- .../integration/core/test_async_core.py | 1 + .../tasks/test_submit_approved_url_task.py | 174 +++++++++++------- util/DiscordNotifier.py | 8 +- util/helper_functions.py | 13 ++ 33 files changed, 687 insertions(+), 334 deletions(-) create mode 100644 alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py create mode 100644 alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py create mode 100644 core/EnvVarManager.py diff --git a/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py b/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py new file mode 100644 index 00000000..e1d5b725 --- /dev/null +++ b/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py @@ -0,0 +1,48 @@ +"""Add Submit URL Task Type Enum + +Revision ID: b363794fa4e9 +Revises: 33a546c93441 +Create Date: 2025-04-15 13:38:58.293627 + +""" +from typing import Sequence, Union + + +from util.alembic_helpers import switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = 'b363794fa4e9' +down_revision: Union[str, None] = '33a546c93441' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + "HTML", + "Relevancy", + "Record Type", + "Agency Identification", + "Misc Metadata", + "Submit Approved URLs" + ] + ) + + +def downgrade() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + "HTML", + "Relevancy", + "Record Type", + "Agency Identification", + "Misc Metadata", + ] + ) \ No newline at end of file diff --git a/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py b/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py new file mode 100644 index 00000000..82ce97a4 --- /dev/null +++ b/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py @@ -0,0 +1,42 @@ +"""Revert to pending validated URLs without name and add constraint + +Revision ID: ed06a5633d2e +Revises: b363794fa4e9 +Create Date: 2025-04-15 15:32:26.465488 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision: str = 'ed06a5633d2e' +down_revision: Union[str, None] = 'b363794fa4e9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + + op.execute( + """ + UPDATE public.urls + SET OUTCOME = 'pending' + WHERE OUTCOME = 'validated' AND NAME IS NULL + """ + ) + + op.create_check_constraint( + 'url_name_not_null_when_validated', + 'urls', + "NAME IS NOT NULL OR OUTCOME != 'validated'" + ) + + +def downgrade() -> None: + op.drop_constraint( + 'url_name_not_null_when_validated', + 'urls', + type_='check' + ) diff --git a/api/main.py b/api/main.py index 19f8de8d..40970e4f 100644 --- a/api/main.py +++ b/api/main.py @@ -1,5 +1,6 @@ from contextlib import asynccontextmanager +import aiohttp import uvicorn from fastapi import FastAPI @@ -15,6 +16,7 @@ from collector_manager.AsyncCollectorManager import AsyncCollectorManager from core.AsyncCore import AsyncCore from core.AsyncCoreLogger import AsyncCoreLogger +from core.EnvVarManager import EnvVarManager from core.ScheduledTaskManager import AsyncScheduledTaskManager from core.SourceCollectorCore import SourceCollectorCore from core.TaskManager import TaskManager @@ -22,18 +24,27 @@ from html_tag_collector.RootURLCache import RootURLCache from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface +from pdap_api_client.AccessManager import AccessManager +from pdap_api_client.PDAPClient import PDAPClient from util.DiscordNotifier import DiscordPoster -from util.helper_functions import get_from_env + @asynccontextmanager async def lifespan(app: FastAPI): + env_var_manager = EnvVarManager.get() + # Initialize shared dependencies - db_client = DatabaseClient() - adb_client = AsyncDatabaseClient() + db_client = DatabaseClient( + db_url=env_var_manager.get_postgres_connection_string() + ) + adb_client = AsyncDatabaseClient( + db_url=env_var_manager.get_postgres_connection_string(is_async=True) + ) await setup_database(db_client) core_logger = AsyncCoreLogger(adb_client=adb_client) + session = aiohttp.ClientSession() source_collector_core = SourceCollectorCore( db_client=DatabaseClient(), @@ -46,7 +57,15 @@ async def lifespan(app: FastAPI): root_url_cache=RootURLCache() ), discord_poster=DiscordPoster( - webhook_url=get_from_env("DISCORD_WEBHOOK_URL") + webhook_url=env_var_manager.discord_webhook_url + ), + pdap_client=PDAPClient( + access_manager=AccessManager( + email=env_var_manager.pdap_email, + password=env_var_manager.pdap_password, + api_key=env_var_manager.pdap_api_key, + session=session + ) ) ) async_collector_manager = AsyncCollectorManager( @@ -72,17 +91,17 @@ async def lifespan(app: FastAPI): yield # Code here runs before shutdown # Shutdown logic (if needed) + # Clean up resources, close connections, etc. await core_logger.shutdown() await async_core.shutdown() source_collector_core.shutdown() - # Clean up resources, close connections, etc. 
+ await session.close() pass async def setup_database(db_client): # Initialize database if dev environment, otherwise apply migrations try: - get_from_env("DEV") db_client.init_db() except Exception as e: return @@ -95,6 +114,7 @@ async def setup_database(db_client): lifespan=lifespan ) + routers = [ root_router, collector_router, diff --git a/api/routes/batch.py b/api/routes/batch.py index 23df2394..9d4b62cc 100644 --- a/api/routes/batch.py +++ b/api/routes/batch.py @@ -25,7 +25,7 @@ @batch_router.get("") -def get_batch_status( +async def get_batch_status( collector_type: Optional[CollectorType] = Query( description="Filter by collector type", default=None @@ -38,13 +38,13 @@ def get_batch_status( description="The page number", default=1 ), - core: SourceCollectorCore = Depends(get_core), + core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetBatchStatusResponse: """ Get the status of recent batches """ - return core.get_batch_statuses(collector_type=collector_type, status=status, page=page) + return await core.get_batch_statuses(collector_type=collector_type, status=status, page=page) @batch_router.get("/{batch_id}") @@ -69,28 +69,28 @@ async def get_urls_by_batch( return await core.get_urls_by_batch(batch_id, page=page) @batch_router.get("/{batch_id}/duplicates") -def get_duplicates_by_batch( +async def get_duplicates_by_batch( batch_id: int = Path(description="The batch id"), page: int = Query( description="The page number", default=1 ), - core: SourceCollectorCore = Depends(get_core), + core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetDuplicatesByBatchResponse: - return core.get_duplicate_urls_by_batch(batch_id, page=page) + return await core.get_duplicate_urls_by_batch(batch_id, page=page) @batch_router.get("/{batch_id}/logs") -def get_batch_logs( +async def get_batch_logs( batch_id: int = Path(description="The batch id"), - core: SourceCollectorCore = Depends(get_core), + async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetBatchLogsResponse: """ Retrieve the logs for a recent batch. Note that for later batches, the logs may not be available. 
""" - return core.get_batch_logs(batch_id) + return await async_core.get_batch_logs(batch_id) @batch_router.post("/{batch_id}/abort") async def abort_batch( diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py index c44468a4..98410b6f 100644 --- a/collector_db/AsyncDatabaseClient.py +++ b/collector_db/AsyncDatabaseClient.py @@ -5,16 +5,16 @@ from sqlalchemy import select, exists, func, case, desc, Select, not_, and_, or_, update, Delete, Insert, asc from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.orm import selectinload, joinedload, QueryableAttribute +from sqlalchemy.orm import selectinload, joinedload, QueryableAttribute, aliased from sqlalchemy.sql.functions import coalesce from starlette import status from collector_db.ConfigManager import ConfigManager from collector_db.DTOConverter import DTOConverter from collector_db.DTOs.BatchInfo import BatchInfo -from collector_db.DTOs.DuplicateInfo import DuplicateInsertInfo +from collector_db.DTOs.DuplicateInfo import DuplicateInsertInfo, DuplicateInfo from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo -from collector_db.DTOs.LogInfo import LogInfo +from collector_db.DTOs.LogInfo import LogInfo, LogOutputInfo from collector_db.DTOs.TaskInfo import TaskInfo from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo, HTMLContentType @@ -41,8 +41,9 @@ GetURLsResponseInnerInfo from core.DTOs.URLAgencySuggestionInfo import URLAgencySuggestionInfo from core.DTOs.task_data_objects.AgencyIdentificationTDO import AgencyIdentificationTDO -from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO, SubmittedURLInfo from core.DTOs.task_data_objects.URLMiscellaneousMetadataTDO import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo +from core.EnvVarManager import EnvVarManager from core.enums import BatchStatus, SuggestionType, RecordType from html_tag_collector.DataClassTags import convert_to_response_html_info @@ -58,7 +59,9 @@ def add_standard_limit_and_offset(statement, page, limit=100): class AsyncDatabaseClient: - def __init__(self, db_url: str = get_postgres_connection_string(is_async=True)): + def __init__(self, db_url: Optional[str] = None): + if db_url is None: + db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) self.engine = create_async_engine( url=db_url, echo=ConfigManager.get_sqlalchemy_echo(), @@ -1487,8 +1490,9 @@ async def get_validated_urls( .where(URL.outcome == URLStatus.VALIDATED.value) .options( selectinload(URL.optional_data_source_metadata), - selectinload(URL.confirmed_agencies) - ) + selectinload(URL.confirmed_agencies), + selectinload(URL.reviewing_user) + ).limit(100) ) urls = await session.execute(query) urls = urls.scalars().all() @@ -1497,6 +1501,17 @@ async def get_validated_urls( agency_ids = [] for agency in url.confirmed_agencies: agency_ids.append(agency.agency_id) + optional_metadata = url.optional_data_source_metadata + + if optional_metadata is None: + record_formats = None + data_portal_type = None + supplying_entity = None + else: + record_formats = optional_metadata.record_formats + data_portal_type = optional_metadata.data_portal_type + supplying_entity = optional_metadata.supplying_entity + tdo = SubmitApprovedURLTDO( url_id=url.id, url=url.url, @@ -1504,18 +1519,19 @@ 
async def get_validated_urls( agency_ids=agency_ids, description=url.description, record_type=url.record_type, - record_formats=url.optional_data_source_metadata.record_formats, - data_portal_type=url.optional_data_source_metadata.data_portal_type, - supplying_entity=url.optional_data_source_metadata.supplying_entity, + record_formats=record_formats, + data_portal_type=data_portal_type, + supplying_entity=supplying_entity, + approving_user_id=url.reviewing_user.user_id ) results.append(tdo) return results @session_manager - async def mark_urls_as_submitted(self, session: AsyncSession, tdos: list[SubmitApprovedURLTDO]): - for tdo in tdos: - url_id = tdo.url_id - data_source_id = tdo.data_source_id + async def mark_urls_as_submitted(self, session: AsyncSession, infos: list[SubmittedURLInfo]): + for info in infos: + url_id = info.url_id + data_source_id = info.data_source_id query = ( update(URL) .where(URL.id == url_id) @@ -1526,3 +1542,70 @@ async def mark_urls_as_submitted(self, session: AsyncSession, tdos: list[SubmitA ) await session.execute(query) + @session_manager + async def get_duplicates_by_batch_id(self, session, batch_id: int, page: int) -> List[DuplicateInfo]: + original_batch = aliased(Batch) + duplicate_batch = aliased(Batch) + + query = ( + Select( + URL.url.label("source_url"), + URL.id.label("original_url_id"), + duplicate_batch.id.label("duplicate_batch_id"), + duplicate_batch.parameters.label("duplicate_batch_parameters"), + original_batch.id.label("original_batch_id"), + original_batch.parameters.label("original_batch_parameters"), + ) + .select_from(Duplicate) + .join(URL, Duplicate.original_url_id == URL.id) + .join(duplicate_batch, Duplicate.batch_id == duplicate_batch.id) + .join(original_batch, URL.batch_id == original_batch.id) + .filter(duplicate_batch.id == batch_id) + .limit(100) + .offset((page - 1) * 100) + ) + raw_results = await session.execute(query) + results = raw_results.all() + final_results = [] + for result in results: + final_results.append( + DuplicateInfo( + source_url=result.source_url, + duplicate_batch_id=result.duplicate_batch_id, + duplicate_metadata=result.duplicate_batch_parameters, + original_batch_id=result.original_batch_id, + original_metadata=result.original_batch_parameters, + original_url_id=result.original_url_id + ) + ) + return final_results + + @session_manager + async def get_recent_batch_status_info( + self, + session, + page: int, + collector_type: Optional[CollectorType] = None, + status: Optional[BatchStatus] = None, + ) -> List[BatchInfo]: + # Get only the batch_id, collector_type, status, and created_at + limit = 100 + query = (Select(Batch) + .order_by(Batch.date_generated.desc())) + if collector_type: + query = query.filter(Batch.strategy == collector_type.value) + if status: + query = query.filter(Batch.status == status.value) + query = (query.limit(limit) + .offset((page - 1) * limit)) + raw_results = await session.execute(query) + batches = raw_results.scalars().all() + return [BatchInfo(**batch.__dict__) for batch in batches] + + @session_manager + async def get_logs_by_batch_id(self, session, batch_id: int) -> List[LogOutputInfo]: + query = Select(Log).filter_by(batch_id=batch_id).order_by(Log.created_at.asc()) + raw_results = await session.execute(query) + logs = raw_results.scalars().all() + return ([LogOutputInfo(**log.__dict__) for log in logs]) + diff --git a/collector_db/DTOs/URLInfo.py b/collector_db/DTOs/URLInfo.py index afe6c2f2..c47d2830 100644 --- a/collector_db/DTOs/URLInfo.py +++ 
b/collector_db/DTOs/URLInfo.py @@ -13,3 +13,4 @@ class URLInfo(BaseModel): collector_metadata: Optional[dict] = None outcome: URLStatus = URLStatus.PENDING updated_at: Optional[datetime.datetime] = None + name: Optional[str] = None diff --git a/collector_db/DatabaseClient.py b/collector_db/DatabaseClient.py index 8d72ef0d..b8547f1d 100644 --- a/collector_db/DatabaseClient.py +++ b/collector_db/DatabaseClient.py @@ -16,13 +16,17 @@ from collector_db.helper_functions import get_postgres_connection_string from collector_db.models import Base, Batch, URL, Log, Duplicate from collector_manager.enums import CollectorType +from core.EnvVarManager import EnvVarManager from core.enums import BatchStatus # Database Client class DatabaseClient: - def __init__(self, db_url: str = get_postgres_connection_string()): + def __init__(self, db_url: Optional[str] = None): """Initialize the DatabaseClient.""" + if db_url is None: + db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) + self.engine = create_engine( url=db_url, echo=ConfigManager.get_sqlalchemy_echo(), @@ -49,10 +53,6 @@ def wrapper(self, *args, **kwargs): return wrapper - def row_to_dict(self, row: Row) -> dict: - return dict(row._mapping) - - @session_manager def insert_batch(self, session, batch_info: BatchInfo) -> int: """Insert a new batch into the database and return its ID.""" @@ -105,24 +105,14 @@ def insert_url(self, session, url_info: URLInfo) -> int: batch_id=url_info.batch_id, url=url_info.url, collector_metadata=url_info.collector_metadata, - outcome=url_info.outcome.value + outcome=url_info.outcome.value, + name=url_info.name ) session.add(url_entry) session.commit() session.refresh(url_entry) return url_entry.id - @session_manager - def add_duplicate_info(self, session, duplicate_infos: list[DuplicateInfo]): - # TODO: Add test for this method when testing CollectorDatabaseProcessor - for duplicate_info in duplicate_infos: - duplicate = Duplicate( - batch_id=duplicate_info.original_batch_id, - original_url_id=duplicate_info.original_url_id, - ) - session.add(duplicate) - - def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo: url_mappings = [] duplicates = [] @@ -163,83 +153,11 @@ def insert_logs(self, session, log_infos: List[LogInfo]): log.created_at = log_info.created_at session.add(log) - @session_manager - def get_logs_by_batch_id(self, session, batch_id: int) -> List[LogOutputInfo]: - logs = session.query(Log).filter_by(batch_id=batch_id).order_by(Log.created_at.asc()).all() - return ([LogOutputInfo(**log.__dict__) for log in logs]) - - @session_manager - def get_all_logs(self, session) -> List[LogInfo]: - logs = session.query(Log).all() - return ([LogInfo(**log.__dict__) for log in logs]) - @session_manager def get_batch_status(self, session, batch_id: int) -> BatchStatus: batch = session.query(Batch).filter_by(id=batch_id).first() return BatchStatus(batch.status) - @session_manager - def get_recent_batch_status_info( - self, - session, - page: int, - collector_type: Optional[CollectorType] = None, - status: Optional[BatchStatus] = None, - ) -> List[BatchInfo]: - # Get only the batch_id, collector_type, status, and created_at - limit = 100 - query = (session.query(Batch) - .order_by(Batch.date_generated.desc())) - if collector_type: - query = query.filter(Batch.strategy == collector_type.value) - if status: - query = query.filter(Batch.status == status.value) - query = (query.limit(limit) - .offset((page - 1) * limit)) - batches = query.all() - return 
[BatchInfo(**batch.__dict__) for batch in batches] - - @session_manager - def get_duplicates_by_batch_id(self, session, batch_id: int, page: int) -> List[DuplicateInfo]: - original_batch = aliased(Batch) - duplicate_batch = aliased(Batch) - - query = ( - session.query( - URL.url.label("source_url"), - URL.id.label("original_url_id"), - duplicate_batch.id.label("duplicate_batch_id"), - duplicate_batch.parameters.label("duplicate_batch_parameters"), - original_batch.id.label("original_batch_id"), - original_batch.parameters.label("original_batch_parameters"), - ) - .select_from(Duplicate) - .join(URL, Duplicate.original_url_id == URL.id) - .join(duplicate_batch, Duplicate.batch_id == duplicate_batch.id) - .join(original_batch, URL.batch_id == original_batch.id) - .filter(duplicate_batch.id == batch_id) - .limit(100) - .offset((page - 1) * 100) - ) - results = query.all() - final_results = [] - for result in results: - final_results.append( - DuplicateInfo( - source_url=result.source_url, - duplicate_batch_id=result.duplicate_batch_id, - duplicate_metadata=result.duplicate_batch_parameters, - original_batch_id=result.original_batch_id, - original_metadata=result.original_batch_parameters, - original_url_id=result.original_url_id - ) - ) - return final_results - - @session_manager - def delete_all_logs(self, session): - session.query(Log).delete() - @session_manager def delete_old_logs(self, session): """ diff --git a/collector_db/helper_functions.py b/collector_db/helper_functions.py index dcb161b9..4f99556a 100644 --- a/collector_db/helper_functions.py +++ b/collector_db/helper_functions.py @@ -2,15 +2,8 @@ import dotenv +from core.EnvVarManager import EnvVarManager + def get_postgres_connection_string(is_async = False): - dotenv.load_dotenv() - username = os.getenv("POSTGRES_USER") - password = os.getenv("POSTGRES_PASSWORD") - host = os.getenv("POSTGRES_HOST") - port = os.getenv("POSTGRES_PORT") - database = os.getenv("POSTGRES_DB") - driver = "postgresql" - if is_async: - driver += "+asyncpg" - return f"{driver}://{username}:{password}@{host}:{port}/{database}" \ No newline at end of file + return EnvVarManager.get().get_postgres_connection_string(is_async) diff --git a/collector_db/models.py b/collector_db/models.py index e0fd1b88..e98ef437 100644 --- a/collector_db/models.py +++ b/collector_db/models.py @@ -129,8 +129,8 @@ class URL(Base): "AutoRelevantSuggestion", uselist=False, back_populates="url") user_relevant_suggestions = relationship( "UserRelevantSuggestion", back_populates="url") - reviewing_users = relationship( - "ReviewingUserURL", back_populates="url") + reviewing_user = relationship( + "ReviewingUserURL", uselist=False, back_populates="url") optional_data_source_metadata = relationship( "URLOptionalDataSourceMetadata", uselist=False, back_populates="url") confirmed_agencies = relationship( @@ -164,7 +164,7 @@ class ReviewingUserURL(Base): created_at = get_created_at_column() # Relationships - url = relationship("URL", back_populates="reviewing_users") + url = relationship("URL", uselist=False, back_populates="reviewing_user") class RootURL(Base): __tablename__ = 'root_url_cache' diff --git a/core/AsyncCore.py b/core/AsyncCore.py index 85762c85..cb9a80bc 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -10,6 +10,9 @@ from collector_manager.enums import CollectorType from core.DTOs.CollectorStartInfo import CollectorStartInfo from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo +from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse +from 
core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse +from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ @@ -17,11 +20,10 @@ from core.DTOs.GetTasksResponse import GetTasksResponse from core.DTOs.GetURLsByBatchResponse import GetURLsByBatchResponse from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo +from core.DTOs.MessageResponse import MessageResponse from core.TaskManager import TaskManager -from core.enums import BatchStatus, SuggestionType, RecordType +from core.enums import BatchStatus, RecordType -from pdap_api_client.AccessManager import AccessManager -from pdap_api_client.PDAPClient import PDAPClient from security_manager.SecurityManager import AccessInfo @@ -57,6 +59,27 @@ async def abort_batch(self, batch_id: int) -> MessageResponse: await self.collector_manager.abort_collector_async(cid=batch_id) return MessageResponse(message=f"Batch aborted.") + async def get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> GetDuplicatesByBatchResponse: + dup_infos = await self.adb_client.get_duplicates_by_batch_id(batch_id, page=page) + return GetDuplicatesByBatchResponse(duplicates=dup_infos) + + async def get_batch_statuses( + self, + collector_type: Optional[CollectorType], + status: Optional[BatchStatus], + page: int + ) -> GetBatchStatusResponse: + results = await self.adb_client.get_recent_batch_status_info( + collector_type=collector_type, + status=status, + page=page + ) + return GetBatchStatusResponse(results=results) + + async def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: + logs = await self.adb_client.get_logs_by_batch_id(batch_id) + return GetBatchLogsResponse(logs=logs) + #endregion # region Collector diff --git a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py index 45fa7daf..c5b002d0 100644 --- a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py +++ b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py @@ -12,7 +12,14 @@ class SubmitApprovedURLTDO(BaseModel): agency_ids: list[int] name: str description: str + approving_user_id: int record_formats: Optional[list[str]] = None data_portal_type: Optional[str] = None supplying_entity: Optional[str] = None - data_source_id: Optional[int] = None \ No newline at end of file + data_source_id: Optional[int] = None + request_error: Optional[str] = None + +class SubmittedURLInfo(BaseModel): + url_id: int + data_source_id: Optional[int] + request_error: Optional[str] \ No newline at end of file diff --git a/core/EnvVarManager.py b/core/EnvVarManager.py new file mode 100644 index 00000000..39e4ce83 --- /dev/null +++ b/core/EnvVarManager.py @@ -0,0 +1,76 @@ +import os + +class EnvVarManager: + _instance = None + _allow_direct_init = False # internal flag + + """ + A class for unified management of environment variables + """ + def __new__(cls, *args, **kwargs): + if not cls._allow_direct_init: + raise RuntimeError("Use `EnvVarManager.get()` or `EnvVarManager.override()` instead.") + return super().__new__(cls) + + def __init__(self, env: dict = os.environ): + self.env = env + self._load() + + def _load(self): + + self.google_api_key = self.require_env("GOOGLE_API_KEY") + self.google_cse_id = 
self.require_env("GOOGLE_CSE_ID") + + self.pdap_email = self.require_env("PDAP_EMAIL") + self.pdap_password = self.require_env("PDAP_PASSWORD") + self.pdap_api_key = self.require_env("PDAP_API_KEY") + self.pdap_api_url = self.require_env("PDAP_API_URL") + + self.discord_webhook_url = self.require_env("DISCORD_WEBHOOK_URL") + + self.openai_api_key = self.require_env("OPENAI_API_KEY") + + self.postgres_user = self.require_env("POSTGRES_USER") + self.postgres_password = self.require_env("POSTGRES_PASSWORD") + self.postgres_host = self.require_env("POSTGRES_HOST") + self.postgres_port = self.require_env("POSTGRES_PORT") + self.postgres_db = self.require_env("POSTGRES_DB") + + @classmethod + def get(cls): + """ + Get the singleton instance, loading from environment if not yet + instantiated + """ + if cls._instance is None: + cls._allow_direct_init = True + cls._instance = cls(os.environ) + cls._allow_direct_init = False + return cls._instance + + @classmethod + def override(cls, env: dict): + """ + Create singleton instance that + overrides the environment variables with injected values + """ + cls._allow_direct_init = True + cls._instance = cls(env) + cls._allow_direct_init = False + + @classmethod + def reset(cls): + cls._instance = None + + def get_postgres_connection_string(self, is_async = False): + driver = "postgresql" + if is_async: + driver += "+asyncpg" + return (f"{driver}://{self.postgres_user}:{self.postgres_password}" + f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}") + + def require_env(self, key: str, allow_none: bool = False): + val = self.env.get(key) + if val is None and not allow_none: + raise ValueError(f"Environment variable {key} is not set") + return val \ No newline at end of file diff --git a/core/SourceCollectorCore.py b/core/SourceCollectorCore.py index 8002717c..4516ceb5 100644 --- a/core/SourceCollectorCore.py +++ b/core/SourceCollectorCore.py @@ -2,10 +2,6 @@ from collector_db.DatabaseClient import DatabaseClient -from collector_manager.enums import CollectorType -from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse -from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse -from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.ScheduledTaskManager import ScheduledTaskManager from core.enums import BatchStatus @@ -15,38 +11,21 @@ def __init__( self, core_logger: Optional[Any] = None, # Deprecated collector_manager: Optional[Any] = None, # Deprecated - db_client: DatabaseClient = DatabaseClient(), + db_client: Optional[DatabaseClient] = None, dev_mode: bool = False ): + if db_client is None: + db_client = DatabaseClient() self.db_client = db_client if not dev_mode: self.scheduled_task_manager = ScheduledTaskManager(db_client=db_client) else: self.scheduled_task_manager = None - def get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> GetDuplicatesByBatchResponse: - dup_infos = self.db_client.get_duplicates_by_batch_id(batch_id, page=page) - return GetDuplicatesByBatchResponse(duplicates=dup_infos) - - def get_batch_statuses( - self, - collector_type: Optional[CollectorType], - status: Optional[BatchStatus], - page: int - ) -> GetBatchStatusResponse: - results = self.db_client.get_recent_batch_status_info( - collector_type=collector_type, - status=status, - page=page - ) - return GetBatchStatusResponse(results=results) def get_status(self, batch_id: int) -> BatchStatus: return self.db_client.get_batch_status(batch_id) - def get_batch_logs(self, batch_id: int) -> 
-    def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse:
-        logs = self.db_client.get_logs_by_batch_id(batch_id)
-        return GetBatchLogsResponse(logs=logs)

     def shutdown(self):
         if self.scheduled_task_manager is not None:
diff --git a/core/TaskManager.py b/core/TaskManager.py
index 624cb906..7796e80e 100644
--- a/core/TaskManager.py
+++ b/core/TaskManager.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Optional

 from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
@@ -9,6 +8,7 @@
 from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome
 from core.FunctionTrigger import FunctionTrigger
 from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator
+from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator
 from core.classes.TaskOperatorBase import TaskOperatorBase
 from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator
 from core.classes.URLMiscellaneousMetadataTaskOperator import URLMiscellaneousMetadataTaskOperator
@@ -19,10 +19,8 @@
 from html_tag_collector.URLRequestInterface import URLRequestInterface
 from hugging_face.HuggingFaceInterface import HuggingFaceInterface
 from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier
-from pdap_api_client.AccessManager import AccessManager
 from pdap_api_client.PDAPClient import PDAPClient
 from util.DiscordNotifier import DiscordPoster
-from util.helper_functions import get_from_env

 TASK_REPEAT_THRESHOLD = 20
@@ -35,12 +33,16 @@ def __init__(
         self,
         url_request_interface: URLRequestInterface,
         html_parser: HTMLResponseParser,
         discord_poster: DiscordPoster,
+        pdap_client: PDAPClient
     ):
+        # Dependencies
         self.adb_client = adb_client
+        self.pdap_client = pdap_client
         self.huggingface_interface = huggingface_interface
         self.url_request_interface = url_request_interface
         self.html_parser = html_parser
         self.discord_poster = discord_poster
+
         self.logger = logging.getLogger(__name__)
         self.logger.addHandler(logging.StreamHandler())
         self.logger.setLevel(logging.INFO)
@@ -73,21 +75,21 @@ async def get_url_record_type_task_operator(self):
         return operator

     async def get_agency_identification_task_operator(self):
-        pdap_client = PDAPClient(
-            access_manager=AccessManager(
-                email=get_from_env("PDAP_EMAIL"),
-                password=get_from_env("PDAP_PASSWORD"),
-                api_key=get_from_env("PDAP_API_KEY"),
-            ),
-        )
         muckrock_api_interface = MuckrockAPIInterface()
         operator = AgencyIdentificationTaskOperator(
             adb_client=self.adb_client,
-            pdap_client=pdap_client,
+            pdap_client=self.pdap_client,
             muckrock_api_interface=muckrock_api_interface
         )
         return operator

+    async def get_submit_approved_url_task_operator(self):
+        operator = SubmitApprovedURLTaskOperator(
+            adb_client=self.adb_client,
+            pdap_client=self.pdap_client
+        )
+        return operator
+
     async def get_url_miscellaneous_metadata_task_operator(self):
         operator = URLMiscellaneousMetadataTaskOperator(
             adb_client=self.adb_client
@@ -96,11 +98,12 @@
     async def get_task_operators(self) -> list[TaskOperatorBase]:
         return [
-            # await self.get_url_html_task_operator(),
+            await self.get_url_html_task_operator(),
             await self.get_url_relevance_huggingface_task_operator(),
             await self.get_url_record_type_task_operator(),
             await self.get_agency_identification_task_operator(),
-            await self.get_url_miscellaneous_metadata_task_operator()
+            await self.get_url_miscellaneous_metadata_task_operator(),
+            await self.get_submit_approved_url_task_operator()
         ]

     #endregion
diff --git a/core/classes/SubmitApprovedURLTaskOperator.py b/core/classes/SubmitApprovedURLTaskOperator.py
index 2a308e7c..81f0b242 100644
--- a/core/classes/SubmitApprovedURLTaskOperator.py
+++ b/core/classes/SubmitApprovedURLTaskOperator.py
@@ -31,23 +31,35 @@ async def inner_task_logic(self):
         await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos])

         # Submit each URL, recording errors if they exist
-        error_infos: list[URLErrorPydanticInfo] = []
-        success_tdos: list[SubmitApprovedURLTDO] = []
-        for tdo in tdos:
-            try:
-                data_source_id = await self.pdap_client.submit_url(tdo)
-                tdo.data_source_id = data_source_id
-                success_tdos.append(tdo)
-            except Exception as e:
-                error_info = URLErrorPydanticInfo(
-                    task_id=self.task_id,
-                    url_id=tdo.url_id,
-                    error=str(e),
-                )
-                error_infos.append(error_info)
+        submitted_url_infos = await self.pdap_client.submit_urls(tdos)
+
+        error_infos = await self.get_error_infos(submitted_url_infos)
+        success_infos = await self.get_success_infos(submitted_url_infos)

         # Update the database for successful submissions
-        await self.adb_client.mark_urls_as_submitted(tdos=success_tdos)
+        await self.adb_client.mark_urls_as_submitted(infos=success_infos)

         # Update the database for failed submissions
         await self.adb_client.add_url_error_infos(error_infos)
+
+    async def get_success_infos(self, submitted_url_infos):
+        success_infos = [
+            response_object for response_object in submitted_url_infos
+            if response_object.data_source_id is not None
+        ]
+        return success_infos
+
+    async def get_error_infos(self, submitted_url_infos):
+        error_infos: list[URLErrorPydanticInfo] = []
+        error_response_objects = [
+            response_object for response_object in submitted_url_infos
+            if response_object.request_error is not None
+        ]
+        for error_response_object in error_response_objects:
+            error_info = URLErrorPydanticInfo(
+                task_id=self.task_id,
+                url_id=error_response_object.url_id,
+                error=error_response_object.request_error,
+            )
+            error_infos.append(error_info)
+        return error_infos
diff --git a/core/enums.py b/core/enums.py
index 213db47c..cfccbb92 100644
--- a/core/enums.py
+++ b/core/enums.py
@@ -7,11 +7,10 @@ class BatchStatus(Enum):
     ERROR = "error"
     ABORTED = "aborted"

-class LabelStudioTaskStatus(Enum):
-    PENDING = "pending"
-    COMPLETED = "completed"
-
 class RecordType(Enum):
+    """
+    All available URL record types
+    """
     ACCIDENT_REPORTS = "Accident Reports"
     ARREST_RECORDS = "Arrest Records"
     CALLS_FOR_SERVICE = "Calls for Service"
@@ -51,8 +50,19 @@

 class SuggestionType(Enum):
+    """
+    Identifies the specific kind of suggestion made for a URL
+    """
     AUTO_SUGGESTION = "Auto Suggestion"
     MANUAL_SUGGESTION = "Manual Suggestion"
     UNKNOWN = "Unknown"
     NEW_AGENCY = "New Agency"
     CONFIRMED = "Confirmed"
+
+class SubmitResponseStatus(Enum):
+    """
+    Response statuses from the /source-collector/data-sources endpoint
+    """
+    SUCCESS = "success"
+    FAILURE = "failure"
+    ALREADY_EXISTS = "already_exists"
\ No newline at end of file
diff --git a/html_tag_collector/RootURLCache.py b/html_tag_collector/RootURLCache.py
index e306b6e1..165be89d 100644
--- a/html_tag_collector/RootURLCache.py
+++ b/html_tag_collector/RootURLCache.py
@@ -16,7 +16,9 @@ class RootURLCacheResponseInfo:
     exception: Optional[Exception] = None

 class RootURLCache:
-    def __init__(self, adb_client: AsyncDatabaseClient = AsyncDatabaseClient()):
+    def __init__(self, adb_client: Optional[AsyncDatabaseClient] = None):
+        if adb_client is None:
+            adb_client = AsyncDatabaseClient()
         self.adb_client = adb_client
         self.cache = None
diff --git a/llm_api_logic/OpenAIRecordClassifier.py b/llm_api_logic/OpenAIRecordClassifier.py
index fc20a0e2..cc0829b5 100644
--- a/llm_api_logic/OpenAIRecordClassifier.py
+++ b/llm_api_logic/OpenAIRecordClassifier.py
@@ -1,17 +1,16 @@
-from typing import Any

 from openai.types.chat import ParsedChatCompletion

+from core.EnvVarManager import EnvVarManager
 from llm_api_logic.LLMRecordClassifierBase import RecordClassifierBase
 from llm_api_logic.RecordTypeStructuredOutput import RecordTypeStructuredOutput
-from util.helper_functions import get_from_env


 class OpenAIRecordClassifier(RecordClassifierBase):

     @property
     def api_key(self):
-        return get_from_env("OPENAI_API_KEY")
+        return EnvVarManager.get().openai_api_key

     @property
     def model_name(self):
diff --git a/pdap_api_client/AccessManager.py b/pdap_api_client/AccessManager.py
index 1020f365..aadd8451 100644
--- a/pdap_api_client/AccessManager.py
+++ b/pdap_api_client/AccessManager.py
@@ -4,8 +4,8 @@
 import requests
 from aiohttp import ClientSession

+from core.EnvVarManager import EnvVarManager
 from pdap_api_client.DTOs import RequestType, Namespaces, RequestInfo, ResponseInfo
-from util.helper_functions import get_from_env

 request_methods = {
     RequestType.POST: ClientSession.post,
@@ -23,7 +23,7 @@ def build_url(
     namespace: Namespaces,
     subdomains: Optional[list[str]] = None
 ):
-    api_url = get_from_env('PDAP_API_URL')
+    api_url = EnvVarManager.get().pdap_api_url
     url = f"{api_url}/{namespace.value}"
     if subdomains is not None:
         url = f"{url}/{'/'.join(subdomains)}"
diff --git a/pdap_api_client/DTOs.py b/pdap_api_client/DTOs.py
index 37d7e857..93f67839 100644
--- a/pdap_api_client/DTOs.py
+++ b/pdap_api_client/DTOs.py
@@ -37,6 +37,7 @@ class Namespaces(Enum):
     MATCH = "match"
     CHECK = "check"
     DATA_SOURCES = "data-sources"
+    SOURCE_COLLECTOR = "source-collector"

 class RequestType(Enum):
diff --git a/pdap_api_client/PDAPClient.py b/pdap_api_client/PDAPClient.py
index 8b1c5e82..24b9d98c 100644
--- a/pdap_api_client/PDAPClient.py
+++ b/pdap_api_client/PDAPClient.py
@@ -1,6 +1,6 @@
 from typing import Optional

-from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO
+from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO, SubmittedURLInfo
 from pdap_api_client.AccessManager import build_url, AccessManager
 from pdap_api_client.DTOs import MatchAgencyInfo, UniqueURLDuplicateInfo, UniqueURLResponseInfo, Namespaces, \
     RequestType, RequestInfo, MatchAgencyResponse
@@ -85,30 +85,59 @@
             duplicates=duplicates
         )

-    async def submit_url(
+    async def submit_urls(
         self,
-        tdo: SubmitApprovedURLTDO
-    ) -> int:
-        url = build_url(
-            namespace=Namespaces.DATA_SOURCES,
+        tdos: list[SubmitApprovedURLTDO]
+    ) -> list[SubmittedURLInfo]:
+        """
+        Submits URLs to the Data Sources App in a single batch request,
+        returning a SubmittedURLInfo with the data source id or request
+        error for each URL
+        """
+        request_url = build_url(
+            namespace=Namespaces.SOURCE_COLLECTOR,
+            subdomains=["data-sources"]
         )
+
+        # Map each URL to its URL id so responses can be matched back
+        url_id_dict = {tdo.url: tdo.url_id for tdo in tdos}
+
+        data_sources_json = []
+        for tdo in tdos:
+            data_sources_json.append({
+                "name": tdo.name,
+                "description": tdo.description,
+                "source_url": tdo.url,
+                "record_type": tdo.record_type.value,
+                "record_formats": tdo.record_formats,
+                "data_portal_type": tdo.data_portal_type,
+                "last_approval_editor": tdo.approving_user_id,
+                "supplying_entity": tdo.supplying_entity,
+                "agency_ids": tdo.agency_ids
+            })
+
+        headers = await self.access_manager.jwt_header()
         request_info = RequestInfo(
             type_=RequestType.POST,
-            url=url,
+            url=request_url,
             headers=headers,
             json={
-                "entry_data": {
-                    "name": tdo.name,
-                    "description": tdo.description,
-                    "source_url": tdo.url,
-                    "record_type_name": tdo.record_type.value,
-                    "record_formats": tdo.record_formats,
-                    "data_portal_type": tdo.data_portal_type,
-                    "supplying_entity": tdo.supplying_entity
-                },
-                "linked_agency_ids": tdo.agency_ids
+                "data_sources": data_sources_json
             }
         )
         response_info = await self.access_manager.make_request(request_info)
-        return response_info.data["id"]
+        data_sources_response_json = response_info.data["data_sources"]
+
+        results = []
+        for data_source in data_sources_response_json:
+            url = data_source["url"]
+            response_object = SubmittedURLInfo(
+                url_id=url_id_dict[url],
+                data_source_id=data_source["data_source_id"],
+                request_error=data_source["error"]
+            )
+            results.append(response_object)
+
+        return results
diff --git a/start_mirrored_local_app.py b/start_mirrored_local_app.py
index 48859adc..940c372e 100644
--- a/start_mirrored_local_app.py
+++ b/start_mirrored_local_app.py
@@ -19,7 +19,7 @@
 import docker
 from docker.errors import APIError, NotFound
 from docker.models.containers import Container
-from pydantic import BaseModel, model_validator, AfterValidator
+from pydantic import BaseModel, AfterValidator

 from apply_migrations import apply_migrations
 from util.helper_functions import get_from_env
@@ -193,7 +193,7 @@ def get_image(self, dockerfile_info: DockerfileInfo):
     def run_container(
         self,
         docker_info: DockerInfo,
-    ):
+    ) -> Container:
         print(f"Running container {docker_info.name}")
         try:
             container = self.client.containers.get(docker_info.name)
@@ -255,24 +255,11 @@
     def set_last_run_time(self):
         with open("local_state/last_run.txt", "w") as f:
             f.write(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
-
-
-def main():
-    docker_manager = DockerManager()
-    # Ensure docker is running, and start if not
-    if not is_docker_running():
-        start_docker_engine()
-
-
-    # Ensure Dockerfile for database is running, and if not, start it
-    database_docker_info = DockerInfo(
+def get_database_docker_info() -> DockerInfo:
+    return DockerInfo(
         dockerfile_info=DockerfileInfo(
             image_tag="postgres:15",
         ),
-        # volume_info=VolumeInfo(
-        #     host_path="dbscripts",
-        #     container_path="/var/lib/postgresql/data"
-        # ),
         name="data_source_identification_db",
         ports={
             "5432/tcp": 5432
@@ -290,12 +277,9 @@
             start_period=2
         )
     )
-    container = docker_manager.run_container(database_docker_info)
-    wait_for_pg_to_be_ready(container)
-
-    # Start dockerfile for Datadumper
-    data_dumper_docker_info = DockerInfo(
+def get_data_dumper_docker_info() -> DockerInfo:
+    return DockerInfo(
         dockerfile_info=DockerfileInfo(
             image_tag="datadumper",
             dockerfile_directory="local_database/DataDumper"
@@ -320,6 +304,21 @@
         command="bash"
     )

+def main():
+    docker_manager = DockerManager()
+    # Ensure docker is running, and start if not
+    if not is_docker_running():
+        start_docker_engine()
+
+    # Ensure Dockerfile for database is running, and if not, start it
+    database_docker_info = get_database_docker_info()
+    container = docker_manager.run_container(database_docker_info)
+    wait_for_pg_to_be_ready(container)
+
+
+    # Start dockerfile for Datadumper
+    data_dumper_docker_info = get_data_dumper_docker_info()
+
     # If not last run within 24 hours, run dump operation in Datadumper
     # Check cache if exists and
     checker = TimestampChecker()
@@ -343,11 +342,20 @@
     apply_migrations()

     # Run `fastapi dev main.py`
-    uvicorn.run(
-        "api.main:app",
-        host="0.0.0.0",
-        port=8000
-    )
+    try:
+        uvicorn.run(
+            "api.main:app",
+            host="0.0.0.0",
+            port=8000
+        )
+    finally:
+        # Stop all running containers on shutdown
+        print("Stopping containers...")
+        for container in docker_manager.client.containers.list():
+            container.stop()
+
+        print("Containers stopped.")
+
+
diff --git a/tests/conftest.py b/tests/conftest.py
index 8aeb6dc6..d7b1bce7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,5 +1,4 @@
 import pytest
-from alembic import command
 from alembic.config import Config
 from sqlalchemy import create_engine, inspect, MetaData
 from sqlalchemy.orm import scoped_session, sessionmaker
@@ -8,12 +7,42 @@
 from collector_db.DatabaseClient import DatabaseClient
 from collector_db.helper_functions import get_postgres_connection_string
 from collector_db.models import Base
+from core.EnvVarManager import EnvVarManager
 from tests.helpers.AlembicRunner import AlembicRunner
 from tests.helpers.DBDataCreator import DBDataCreator
+from util.helper_functions import load_from_environment


 @pytest.fixture(autouse=True, scope="session")
 def setup_and_teardown():
+    # Set up environment variables that must be defined
+    # outside of tests
+    required_env_vars: dict = load_from_environment(
+        keys=[
+            "POSTGRES_USER",
+            "POSTGRES_PASSWORD",
+            "POSTGRES_HOST",
+            "POSTGRES_PORT",
+            "POSTGRES_DB",
+        ]
+    )
+    # Add test environment variables
+    test_env_vars = [
+        "GOOGLE_API_KEY",
+        "GOOGLE_CSE_ID",
+        "PDAP_EMAIL",
+        "PDAP_PASSWORD",
+        "PDAP_API_KEY",
+        "PDAP_API_URL",
+        "DISCORD_WEBHOOK_URL",
+        "OPENAI_API_KEY",
+    ]
+    all_env_vars = required_env_vars.copy()
+    for env_var in test_env_vars:
+        all_env_vars[env_var] = "TEST"
+
+    EnvVarManager.override(all_env_vars)
+
     conn = get_postgres_connection_string()
     engine = create_engine(conn)
     alembic_cfg = Config("alembic.ini")
diff --git a/tests/helpers/DBDataCreator.py b/tests/helpers/DBDataCreator.py
index 60b873c4..613bfe4d 100644
--- a/tests/helpers/DBDataCreator.py
+++ b/tests/helpers/DBDataCreator.py
@@ -23,13 +23,16 @@ class BatchURLCreationInfo(BaseModel):
     batch_id: int
     url_ids: list[int]
+    urls: list[str]


 class DBDataCreator:
     """
     Assists in the creation of test data
     """
-    def __init__(self, db_client: DatabaseClient = DatabaseClient()):
-        self.db_client = db_client
+    def __init__(self, db_client: Optional[DatabaseClient] = None):
+        if db_client is None:
+            db_client = DatabaseClient()
+        self.db_client = db_client
         self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient()

     def batch(self, strategy: CollectorType = CollectorType.EXAMPLE) -> int:
@@ -63,7 +67,8 @@ async def batch_and_urls(
         return BatchURLCreationInfo(
             batch_id=batch_id,
-            url_ids=url_ids
+            url_ids=url_ids,
+            urls=[iui.url for iui in iuis.url_mappings]
         )

     async def agency(self) -> int:
@@ -189,6 +194,7 @@ def urls(
                 URLInfo(
                     url=url,
                     outcome=outcome,
+                    name="Test Name" if outcome == URLStatus.VALIDATED else None,
                     collector_metadata=collector_metadata
                 )
             )
diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py
index b466bfbb..ae34b28e 100644
--- a/tests/test_automated/integration/api/conftest.py
+++ b/tests/test_automated/integration/api/conftest.py
@@ -1,9 +1,6 @@
-import asyncio
-import logging
-import os
 from dataclasses import dataclass
 from typing import Generator
-from unittest.mock import MagicMock, AsyncMock
+from unittest.mock import MagicMock, AsyncMock, patch

 import pytest
 import pytest_asyncio
 from api.main import app
 from core.AsyncCore import AsyncCore
-from core.AsyncCoreLogger import AsyncCoreLogger
 from core.SourceCollectorCore import SourceCollectorCore
 from security_manager.SecurityManager import get_access_info, AccessInfo, Permissions
 from tests.helpers.DBDataCreator import DBDataCreator
@@ -48,9 +44,6 @@ def override_access_info() -> AccessInfo:

 @pytest.fixture(scope="session")
 def client() -> Generator[TestClient, None, None]:
-    # Mock envioronment
-    _original_env = dict(os.environ)
-    os.environ["DISCORD_WEBHOOK_URL"] = "https://discord.com"
     with TestClient(app) as c:
         app.dependency_overrides[get_access_info] = override_access_info
         async_core: AsyncCore = c.app.state.async_core
@@ -67,8 +60,5 @@
         yield c

-    # Reset environment variables back to original state
-    os.environ.clear()
-    os.environ.update(_original_env)


 @pytest_asyncio.fixture
diff --git a/tests/test_automated/integration/api/test_annotate.py b/tests/test_automated/integration/api/test_annotate.py
index 0e462ba5..0501ac1f 100644
--- a/tests/test_automated/integration/api/test_annotate.py
+++ b/tests/test_automated/integration/api/test_annotate.py
@@ -1,15 +1,12 @@
-from typing import Any

 import pytest

 from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo
 from collector_db.DTOs.URLMapping import URLMapping
-from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource
 from collector_db.models import UserUrlAgencySuggestion, UserRelevantSuggestion, UserRecordTypeSuggestion
 from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo
 from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo
 from core.DTOs.GetNextURLForAgencyAnnotationResponse import URLAgencyAnnotationPostInfo
-from core.DTOs.GetNextURLForAnnotationResponse import GetNextURLForAnnotationResponse
 from core.DTOs.RecordTypeAnnotationPostInfo import RecordTypeAnnotationPostInfo
 from core.DTOs.RelevanceAnnotationPostInfo import RelevanceAnnotationPostInfo
 from core.enums import RecordType, SuggestionType
diff --git a/tests/test_automated/integration/api/test_duplicates.py b/tests/test_automated/integration/api/test_duplicates.py
index c42b894d..a5c77b29 100644
--- a/tests/test_automated/integration/api/test_duplicates.py
+++ b/tests/test_automated/integration/api/test_duplicates.py
@@ -1,5 +1,4 @@
 import time
-from unittest.mock import AsyncMock

 from collector_db.DTOs.BatchInfo import BatchInfo
 from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO
@@ -30,7 +29,7 @@ def test_duplicates(api_test_helper):

     assert batch_id_2 is not None

-    time.sleep(2)
+    time.sleep(1.5)

     bi_1: BatchInfo = ath.request_validator.get_batch_info(batch_id_1)
     bi_2: BatchInfo = ath.request_validator.get_batch_info(batch_id_2)
diff --git a/tests/test_automated/integration/collector_db/test_database_structure.py b/tests/test_automated/integration/collector_db/test_database_structure.py
index 2b2fcbca..6d82631c 100644
--- a/tests/test_automated/integration/collector_db/test_database_structure.py
+++ b/tests/test_automated/integration/collector_db/test_database_structure.py
@@ -52,9 +52,11 @@ def __init__(
         self,
         columns: list[ColumnTester],
         table_name: str,
-        engine: sa.Engine = create_engine(get_postgres_connection_string()),
+        engine: Optional[sa.Engine] = None,
         constraints: Optional[list[ConstraintTester]] = None,
     ):
+        if engine is None:
+            engine = create_engine(get_postgres_connection_string())
         self.columns = columns
         self.table_name = table_name
         self.constraints = constraints
@@ -228,6 +230,11 @@ def test_url(db_data_creator: DBDataCreator):
                 column_name="outcome",
                 type_=postgresql.ENUM,
                 allowed_values=get_enum_values(URLStatus)
+            ),
+            ColumnTester(
+                column_name="name",
+                type_=sa.String,
+                allowed_values=['test'],
             )
         ],
         engine=db_data_creator.db_client.engine
diff --git a/tests/test_automated/integration/collector_db/test_db_client.py b/tests/test_automated/integration/collector_db/test_db_client.py
index c78bf57e..7b98728f 100644
--- a/tests/test_automated/integration/collector_db/test_db_client.py
+++ b/tests/test_automated/integration/collector_db/test_db_client.py
@@ -60,11 +60,12 @@ async def test_insert_urls(
     assert insert_urls_info.original_count == 2
     assert insert_urls_info.duplicate_count == 1

-
-def test_insert_logs(db_data_creator: DBDataCreator):
+@pytest.mark.asyncio
+async def test_insert_logs(db_data_creator: DBDataCreator):
     batch_id_1 = db_data_creator.batch()
     batch_id_2 = db_data_creator.batch()

+    adb_client = db_data_creator.adb_client
     db_client = db_data_creator.db_client
     db_client.insert_logs(
         log_infos=[
@@ -74,26 +75,28 @@
         ]
     )

-    logs = db_client.get_logs_by_batch_id(batch_id_1)
+    logs = await adb_client.get_logs_by_batch_id(batch_id_1)
     assert len(logs) == 2

-    logs = db_client.get_logs_by_batch_id(batch_id_2)
+    logs = await adb_client.get_logs_by_batch_id(batch_id_2)
     assert len(logs) == 1

-def test_delete_old_logs(db_data_creator: DBDataCreator):
+@pytest.mark.asyncio
+async def test_delete_old_logs(db_data_creator: DBDataCreator):
     batch_id = db_data_creator.batch()
     old_datetime = datetime.now() - timedelta(days=1)
     db_client = db_data_creator.db_client
+    adb_client = db_data_creator.adb_client
     log_infos = []
     for i in range(3):
         log_infos.append(LogInfo(log="test log", batch_id=batch_id, created_at=old_datetime))
     db_client.insert_logs(log_infos=log_infos)
-    logs = db_client.get_logs_by_batch_id(batch_id=batch_id)
+    logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id)
     assert len(logs) == 3

     db_client.delete_old_logs()
-    logs = db_client.get_logs_by_batch_id(batch_id=batch_id)
+    logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id)
     assert len(logs) == 0

 def test_delete_url_updated_at(db_data_creator: DBDataCreator):
diff --git a/tests/test_automated/integration/core/test_async_core.py b/tests/test_automated/integration/core/test_async_core.py
index b4b8e740..ed314dfd 100644
--- a/tests/test_automated/integration/core/test_async_core.py
+++ b/tests/test_automated/integration/core/test_async_core.py
@@ -21,6 +21,7 @@ def setup_async_core(adb_client: AsyncDatabaseClient):
             url_request_interface=AsyncMock(),
             html_parser=AsyncMock(),
             discord_poster=AsyncMock(),
+            pdap_client=AsyncMock()
         ),
         collector_manager=AsyncMock()
     )
diff --git a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
index 75630af8..04256de9 100644
--- a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
+++ b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
@@ -3,39 +3,54 @@

 import pytest

-from collector_db.models import URL
+from collector_db.enums import TaskType
+from collector_db.models import URL, URLErrorInfo
 from collector_manager.enums import URLStatus
 from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo
 from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome
 from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator
-from core.enums import RecordType
+from core.enums import RecordType, SubmitResponseStatus
 from helpers.DBDataCreator import BatchURLCreationInfo, DBDataCreator
 from pdap_api_client.AccessManager import AccessManager
 from pdap_api_client.DTOs import RequestInfo, RequestType, ResponseInfo
 from pdap_api_client.PDAPClient import PDAPClient

+def mock_make_request(pdap_client: PDAPClient, urls: list[str]):
+    assert len(urls) == 3, "Expected 3 urls"
+    pdap_client.access_manager.make_request = AsyncMock(
+        return_value=ResponseInfo(
+            status_code=HTTPStatus.OK,
+            data={
+                "data_sources": [
+                    {
+                        "url": urls[0],
+                        "status": SubmitResponseStatus.SUCCESS.value,
+                        "error": None,
+                        "data_source_id": 21,
+                    },
+                    {
+                        "url": urls[1],
+                        "status": SubmitResponseStatus.SUCCESS.value,
+                        "error": None,
+                        "data_source_id": 34,
+                    },
+                    {
+                        "url": urls[2],
+                        "status": SubmitResponseStatus.FAILURE.value,
+                        "error": "Test Error",
+                        "data_source_id": None
+                    }
+                ]
+            }
+        )
+    )
+
 @pytest.fixture
-def mock_pdap_client():
+def mock_pdap_client() -> PDAPClient:
     mock_access_manager = MagicMock(
         spec=AccessManager
     )
-    mock_access_manager.make_request = AsyncMock(
-        side_effect=[
-            ResponseInfo(
-                status_code=HTTPStatus.OK,
-                data={
-                    "id": 21
-                }
-            ),
-            ResponseInfo(
-                status_code=HTTPStatus.OK,
-                data={
-                    "id": 34
-                }
-            )
-        ]
-    )
     mock_access_manager.jwt_header = AsyncMock(
         return_value={"Authorization": "Bearer token"}
     )
@@ -44,13 +59,15 @@
     )
     return pdap_client

-async def setup_validated_urls(db_data_creator: DBDataCreator):
+async def setup_validated_urls(db_data_creator: DBDataCreator) -> list[str]:
     creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls(
-        url_count=2,
+        url_count=3,
         with_html_content=True
     )
+
     url_1 = creation_info.url_ids[0]
     url_2 = creation_info.url_ids[1]
+    url_3 = creation_info.url_ids[2]
     await db_data_creator.adb_client.approve_url(
         approval_info=FinalReviewApprovalInfo(
             url_id=url_1,
@@ -72,16 +89,30 @@
             name="URL 2 Name",
             description="URL 2 Description",
         ),
-        user_id=1
+        user_id=2
     )
+    await db_data_creator.adb_client.approve_url(
+        approval_info=FinalReviewApprovalInfo(
+            url_id=url_3,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_ids=[5, 6],
+            name="URL 3 Name",
+            description="URL 3 Description",
+        ),
+        user_id=3
+    )
+    return creation_info.urls

 @pytest.mark.asyncio
 async def test_submit_approved_url_task(
     db_data_creator,
-    mock_pdap_client,
-    monkeypatch
+    mock_pdap_client: PDAPClient
 ):
-    monkeypatch.setenv("PDAP_API_URL", "http://localhost:8000")
+    """
+    The submit_approved_url_task should submit
+    all validated URLs to the PDAP Data Sources App
+    """
+
     # Get Task Operator
     operator = SubmitApprovedURLTaskOperator(
@@ -94,13 +124,17 @@
     # Create URLs with status 'validated' in database and all requisite URL values
     # Ensure they have optional metadata as well
-    await setup_validated_urls(db_data_creator)
+    urls = await setup_validated_urls(db_data_creator)
+    mock_make_request(mock_pdap_client, urls)

     # Check Task Operator does meet pre-requisites
     assert await operator.meets_task_prerequisites()

     # Run Task
-    run_info = await operator.run_task(task_id=1)
+    task_id = await db_data_creator.adb_client.initiate_task(
+        task_type=TaskType.SUBMIT_APPROVED
+    )
+    run_info = await operator.run_task(task_id=task_id)

     # Check Task has been marked as completed
     assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
@@ -109,63 +143,73 @@
     urls = await db_data_creator.adb_client.get_all(URL, order_by_attribute="id")
     url_1 = urls[0]
     url_2 = urls[1]
+    url_3 = urls[2]

     # Check URLs have been marked as 'submitted'
     assert url_1.outcome == URLStatus.SUBMITTED.value
     assert url_2.outcome == URLStatus.SUBMITTED.value
+    assert url_3.outcome == URLStatus.ERROR.value

     # Check URLs now have data source ids
     assert url_1.data_source_id == 21
     assert url_2.data_source_id == 34
+    assert url_3.data_source_id is None

-    # Check mock method was called twice with expected parameters
-    access_manager = mock_pdap_client.access_manager
-    assert access_manager.make_request.call_count == 2
-    # Check first call
+    # Check that errored URL has entry in url_error_info
+    url_errors = await db_data_creator.adb_client.get_all(URLErrorInfo)
+    assert len(url_errors) == 1
+    url_error = url_errors[0]
+    assert url_error.url_id == url_3.id
+    assert url_error.error == "Test Error"

+    # Check mock method was called with the expected parameters
+    access_manager = mock_pdap_client.access_manager
+    access_manager.make_request.assert_called_once()
     call_1 = access_manager.make_request.call_args_list[0][0][0]
     expected_call_1 = RequestInfo(
         type_=RequestType.POST,
-        url="http://localhost:8000/data-sources",
+        url="TEST/source-collector/data-sources",
         headers=access_manager.jwt_header.return_value,
         json={
-            "entry_data": {
-                "name": "URL 1 Name",
-                "source_url": url_1.url,
-                "record_type_name": "Accident Reports",
-                "description": "URL 1 Description",
-                "record_formats": ["Record Format 1", "Record Format 2"],
-                "data_portal_type": "Data Portal Type 1",
-                "supplying_entity": "Supplying Entity 1"
-            },
-            "linked_agency_ids": [1, 2]
+            "data_sources": [
+                {
+                    "name": "URL 1 Name",
+                    "source_url": url_1.url,
+                    "record_type": "Accident Reports",
+                    "description": "URL 1 Description",
+                    "record_formats": ["Record Format 1", "Record Format 2"],
+                    "data_portal_type": "Data Portal Type 1",
+                    "last_approval_editor": 1,
+                    "supplying_entity": "Supplying Entity 1",
+                    "agency_ids": [1, 2]
+                },
+                {
+                    "name": "URL 2 Name",
+                    "source_url": url_2.url,
+                    "record_type": "Incarceration Records",
+                    "description": "URL 2 Description",
+                    "last_approval_editor": 2,
+                    "supplying_entity": None,
+                    "record_formats": None,
+                    "data_portal_type": None,
+                    "agency_ids": [3, 4]
+                },
+                {
+                    "name": "URL 3 Name",
+                    "source_url": url_3.url,
+                    "record_type": "Accident Reports",
+                    "description": "URL 3 Description",
+                    "last_approval_editor": 3,
+                    "supplying_entity": None,
+                    "record_formats": None,
+                    "data_portal_type": None,
+                    "agency_ids": [5, 6]
+                }
+            ]
         }
     )
     assert call_1.type_ == expected_call_1.type_
     assert call_1.url == expected_call_1.url
     assert call_1.headers == expected_call_1.headers
     assert call_1.json == expected_call_1.json
-    # Check second call
-    call_2 = access_manager.make_request.call_args_list[1][0][0]
-    expected_call_2 = RequestInfo(
-        type_=RequestType.POST,
-        url="http://localhost:8000/data-sources",
-        headers=access_manager.jwt_header.return_value,
-        json={
-            "entry_data": {
-                "name": "URL 2 Name",
-                "source_url": url_2.url,
-                "record_type_name": "Incarceration Records",
-                "description": "URL 2 Description",
-                "data_portal_type": None,
-                "supplying_entity": None,
-                "record_formats": None
-            },
-            "linked_agency_ids": [3, 4]
-        }
-    )
-    assert call_2.type_ == expected_call_2.type_
-    assert call_2.url == expected_call_2.url
-    assert call_2.headers == expected_call_2.headers
-    assert call_2.json == expected_call_2.json
\ No newline at end of file
diff --git a/util/DiscordNotifier.py b/util/DiscordNotifier.py
index 15e74020..6df1aa90 100644
--- a/util/DiscordNotifier.py
+++ b/util/DiscordNotifier.py
@@ -10,4 +10,10 @@ def __init__(self, webhook_url: str):
             raise ValueError("WEBHOOK_URL environment variable not set")
         self.webhook_url = webhook_url

     def post_to_discord(self, message):
-        requests.post(self.webhook_url, json={"content": message})
+        try:
+            requests.post(self.webhook_url, json={"content": message})
+        except Exception as e:
+            logging.error(
+                f"Error posting message to Discord: {e}."
+                f"\n\nMessage: {message}"
+            )
diff --git a/util/helper_functions.py b/util/helper_functions.py
index bf72d39b..7d6c7f8d 100644
--- a/util/helper_functions.py
+++ b/util/helper_functions.py
@@ -16,6 +16,19 @@ def get_from_env(key: str, allow_none: bool = False):
         raise ValueError(f"Environment variable {key} is not set")
     return val

+def load_from_environment(keys: list[str]) -> dict[str, str]:
+    """
+    Load selected keys from environment, returning a dictionary
+    """
+    original_environment = os.environ.copy()
+    try:
+        load_dotenv()
+        return {key: os.getenv(key) for key in keys}
+    finally:
+        # Restore the original environment
+        os.environ.clear()
+        os.environ.update(original_environment)
+
 def base_model_list_dump(model_list: list[BaseModel]) -> list[dict]:
     return [model.model_dump() for model in model_list]


From beae1b9885b0e2e94dfefad703bf4c02fb9fce3b Mon Sep 17 00:00:00 2001
From: Max Chis
Date: Tue, 15 Apr 2025 16:53:48 -0400
Subject: [PATCH 5/5] fix(tests): fix import bug

---
 .../integration/tasks/test_submit_approved_url_task.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
index 04256de9..b15ff9d5 100644
--- a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
+++ b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py
@@ -10,7 +10,7 @@
 from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome
 from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator
 from core.enums import RecordType, SubmitResponseStatus
-from helpers.DBDataCreator import BatchURLCreationInfo, DBDataCreator
+from tests.helpers.DBDataCreator import BatchURLCreationInfo, DBDataCreator
 from pdap_api_client.AccessManager import AccessManager
 from pdap_api_client.DTOs import RequestInfo, RequestType, ResponseInfo
 from pdap_api_client.PDAPClient import PDAPClient
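
A minimal usage sketch of the new EnvVarManager singleton introduced above. The values are illustrative only; note that `_load()` requires every key shown, so `override()` must supply all of them (as `tests/conftest.py` does), and `get()` assumes the variables from ENV.md are present in `os.environ`:

    from core.EnvVarManager import EnvVarManager

    # Application code: lazily construct the singleton from os.environ
    env = EnvVarManager.get()
    env.get_postgres_connection_string(is_async=True)
    # -> "postgresql+asyncpg://<user>:<password>@<host>:<port>/<db>"

    # Tests: inject values instead of mutating os.environ
    EnvVarManager.override({
        **{key: "TEST" for key in [
            "GOOGLE_API_KEY", "GOOGLE_CSE_ID",
            "PDAP_EMAIL", "PDAP_PASSWORD", "PDAP_API_KEY", "PDAP_API_URL",
            "DISCORD_WEBHOOK_URL", "OPENAI_API_KEY",
        ]},
        "POSTGRES_USER": "postgres",
        "POSTGRES_PASSWORD": "postgres",
        "POSTGRES_HOST": "127.0.0.1",
        "POSTGRES_PORT": "5432",
        "POSTGRES_DB": "source_collector_test_db",
    })
    assert EnvVarManager.get().pdap_api_url == "TEST"

    # Direct construction is blocked by __new__ and raises RuntimeError:
    # EnvVarManager()

    EnvVarManager.reset()  # clear the singleton between test sessions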