diff --git a/ENV.md b/ENV.md index a2e84f24..22f84cb8 100644 --- a/ENV.md +++ b/ENV.md @@ -2,25 +2,26 @@ This page provides a full list, with description, of all the environment variabl Please ensure these are properly defined in a `.env` file in the root directory. -| Name | Description | Example | -|----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------| -| `GOOGLE_API_KEY` | The API key required for accessing the Google Custom Search API | `abc123` | -| `GOOGLE_CSE_ID` | The CSE ID required for accessing the Google Custom Search API | `abc123` | -|`POSTGRES_USER` | The username for the test database | `test_source_collector_user` | -|`POSTGRES_PASSWORD` | The password for the test database | `HanviliciousHamiltonHilltops` | -|`POSTGRES_DB` | The database name for the test database | `source_collector_test_db` | -|`POSTGRES_HOST` | The host for the test database | `127.0.0.1` | -|`POSTGRES_PORT` | The port for the test database | `5432` | -|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` | -|`DEV`| Set to any value to run the application in development mode. | `true` | -|`DEEPSEEK_API_KEY`| The API key required for accessing the DeepSeek API. | `abc123` | -|`OPENAI_API_KEY`| The API key required for accessing the OpenAI API. | `abc123` | -|`PDAP_EMAIL`| An email address for accessing the PDAP API.[^1] | `abc123@test.com` | -|`PDAP_PASSWORD`| A password for accessing the PDAP API.[^1] | `abc123` | -|`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` | -|`PDAP_API_URL`| The URL for the PDAP API| `https://data-sources-v2.pdap.dev/api`| -|`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications| `abc123` | -|`HUGGINGFACE_INFERENCE_API_KEY` | The API key required for accessing the Huggingface Inference API. | `abc123` | +| Name | Description | Example | +|--------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------| +| `GOOGLE_API_KEY` | The API key required for accessing the Google Custom Search API | `abc123` | +| `GOOGLE_CSE_ID` | The CSE ID required for accessing the Google Custom Search API | `abc123` | +| `POSTGRES_USER` | The username for the test database | `test_source_collector_user` | +| `POSTGRES_PASSWORD` | The password for the test database | `HanviliciousHamiltonHilltops` | +| `POSTGRES_DB` | The database name for the test database | `source_collector_test_db` | +| `POSTGRES_HOST` | The host for the test database | `127.0.0.1` | +| `POSTGRES_PORT` | The port for the test database | `5432` | +| `DS_APP_SECRET_KEY` | The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` | +| `DEV` | Set to any value to run the application in development mode. | `true` | +| `DEEPSEEK_API_KEY` | The API key required for accessing the DeepSeek API. | `abc123` | +| `OPENAI_API_KEY` | The API key required for accessing the OpenAI API. | `abc123` | +| `PDAP_EMAIL` | An email address for accessing the PDAP API.[^1] | `abc123@test.com` | +| `PDAP_PASSWORD` | A password for accessing the PDAP API.[^1] | `abc123` | +| `PDAP_API_KEY` | An API key for accessing the PDAP API. | `abc123` | +| `PDAP_API_URL` | The URL for the PDAP API | `https://data-sources-v2.pdap.dev/api` | +| `DISCORD_WEBHOOK_URL` | The URL for the Discord webhook used for notifications | `abc123` | +| `HUGGINGFACE_INFERENCE_API_KEY` | The API key required for accessing the Hugging Face Inference API. | `abc123` | +| `HUGGINGFACE_HUB_TOKEN` | `abc123` | The API key required for uploading to the PDAP HuggingFace account via Hugging Face Hub API. | [^1:] The user account in question will require elevated permissions to access certain endpoints. At a minimum, the user will require the `source_collector` and `db_write` permissions. diff --git a/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py b/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py new file mode 100644 index 00000000..45cf66a0 --- /dev/null +++ b/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py @@ -0,0 +1,74 @@ +"""Setup for upload to huggingface task + +Revision ID: 637de6eaa3ab +Revises: 59d2af1bab33 +Create Date: 2025-07-26 08:30:37.940091 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = '637de6eaa3ab' +down_revision: Union[str, None] = '59d2af1bab33' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +TABLE_NAME = "huggingface_upload_state" + + +def upgrade() -> None: + op.create_table( + TABLE_NAME, + id_column(), + sa.Column( + "last_upload_at", + sa.DateTime(), + nullable=False + ), + ) + + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face' + ] + ) + + +def downgrade() -> None: + op.drop_table(TABLE_NAME) + + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources' + ] + ) diff --git a/src/api/main.py b/src/api/main.py index 355fbedf..46ae4a3a 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -31,6 +31,7 @@ from src.db.client.async_ import AsyncDatabaseClient from src.db.client.sync import DatabaseClient from src.core.tasks.url.operators.url_html.scraper.root_url_cache.core import RootURLCache +from src.external.huggingface.hub.client import HuggingFaceHubClient from src.external.huggingface.inference.client import HuggingFaceInferenceClient from src.external.pdap.client import PDAPClient @@ -101,7 +102,10 @@ async def lifespan(app: FastAPI): handler=task_handler, loader=ScheduledTaskOperatorLoader( adb_client=adb_client, - pdap_client=pdap_client + pdap_client=pdap_client, + hf_client=HuggingFaceHubClient( + token=env_var_manager.hf_hub_token + ) ) ) await async_scheduled_task_manager.setup() diff --git a/src/core/env_var_manager.py b/src/core/env_var_manager.py index 8fce7ac3..98a78b69 100644 --- a/src/core/env_var_manager.py +++ b/src/core/env_var_manager.py @@ -30,6 +30,7 @@ def _load(self): self.openai_api_key = self.require_env("OPENAI_API_KEY") self.hf_inference_api_key = self.require_env("HUGGINGFACE_INFERENCE_API_KEY") + self.hf_hub_token = self.require_env("HUGGINGFACE_HUB_TOKEN") self.postgres_user = self.require_env("POSTGRES_USER") self.postgres_password = self.require_env("POSTGRES_PASSWORD") diff --git a/src/db/models/instantiations/sync_state/__init__.py b/src/core/tasks/scheduled/huggingface/__init__.py similarity index 100% rename from src/db/models/instantiations/sync_state/__init__.py rename to src/core/tasks/scheduled/huggingface/__init__.py diff --git a/src/core/tasks/scheduled/huggingface/operator.py b/src/core/tasks/scheduled/huggingface/operator.py new file mode 100644 index 00000000..45e35e17 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/operator.py @@ -0,0 +1,36 @@ + +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.external.huggingface.hub.client import HuggingFaceHubClient + + +class PushToHuggingFaceTaskOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.PUSH_TO_HUGGINGFACE + + def __init__( + self, + adb_client: AsyncDatabaseClient, + hf_client: HuggingFaceHubClient + ): + super().__init__(adb_client) + self.hf_client = hf_client + + async def inner_task_logic(self): + # Check if any valid urls have been updated + valid_urls_updated = await self.adb_client.check_valid_urls_updated() + print(f"Valid urls updated: {valid_urls_updated}") + if not valid_urls_updated: + print("No valid urls updated, skipping.") + return + + + # Otherwise, push to huggingface + run_dt = await self.adb_client.get_current_database_time() + outputs = await self.adb_client.get_data_sources_raw_for_huggingface() + self.hf_client.push_data_sources_raw_to_hub(outputs) + + await self.adb_client.set_hugging_face_upload_state(run_dt.replace(tzinfo=None)) diff --git a/src/core/tasks/scheduled/huggingface/queries/__init__.py b/src/core/tasks/scheduled/huggingface/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/huggingface/queries/check/__init__.py b/src/core/tasks/scheduled/huggingface/queries/check/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/huggingface/queries/check/core.py b/src/core/tasks/scheduled/huggingface/queries/check/core.py new file mode 100644 index 00000000..7b724a30 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/check/core.py @@ -0,0 +1,14 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.huggingface.queries.check.requester import CheckValidURLsUpdatedRequester +from src.db.queries.base.builder import QueryBuilderBase + + +class CheckValidURLsUpdatedQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + requester = CheckValidURLsUpdatedRequester(session=session) + latest_upload = await requester.latest_upload() + return await requester.has_valid_urls(latest_upload) + + diff --git a/src/core/tasks/scheduled/huggingface/queries/check/requester.py b/src/core/tasks/scheduled/huggingface/queries/check/requester.py new file mode 100644 index 00000000..6af94560 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/check/requester.py @@ -0,0 +1,53 @@ +from datetime import datetime + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.functions import count + +from src.collectors.enums import URLStatus +from src.db.helpers.session import session_helper as sh +from src.db.models.instantiations.state.huggingface import HuggingFaceUploadState +from src.db.models.instantiations.url.compressed_html import URLCompressedHTML +from src.db.models.instantiations.url.core.sqlalchemy import URL + + +class CheckValidURLsUpdatedRequester: + + def __init__(self, session: AsyncSession): + self.session = session + + async def latest_upload(self) -> datetime: + query = ( + select( + HuggingFaceUploadState.last_upload_at + ) + ) + return await sh.scalar( + session=self.session, + query=query + ) + + async def has_valid_urls(self, last_upload_at: datetime | None) -> bool: + query = ( + select(count(URL.id)) + .join( + URLCompressedHTML, + URL.id == URLCompressedHTML.url_id + ) + .where( + URL.outcome.in_( + [ + URLStatus.VALIDATED, + URLStatus.NOT_RELEVANT.value, + URLStatus.SUBMITTED.value, + ] + ), + ) + ) + if last_upload_at is not None: + query = query.where(URL.updated_at > last_upload_at) + url_count = await sh.scalar( + session=self.session, + query=query + ) + return url_count > 0 diff --git a/src/core/tasks/scheduled/huggingface/queries/get/__init__.py b/src/core/tasks/scheduled/huggingface/queries/get/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/huggingface/queries/get/convert.py b/src/core/tasks/scheduled/huggingface/queries/get/convert.py new file mode 100644 index 00000000..0f8e26a6 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/get/convert.py @@ -0,0 +1,16 @@ +from src.collectors.enums import URLStatus +from src.core.enums import RecordType +from src.core.tasks.scheduled.huggingface.queries.get.enums import RecordTypeCoarse +from src.core.tasks.scheduled.huggingface.queries.get.mappings import FINE_COARSE_RECORD_TYPE_MAPPING, \ + OUTCOME_RELEVANCY_MAPPING + + +def convert_fine_to_coarse_record_type( + fine_record_type: RecordType +) -> RecordTypeCoarse: + return FINE_COARSE_RECORD_TYPE_MAPPING[fine_record_type] + +def convert_url_status_to_relevant( + url_status: URLStatus +) -> bool: + return OUTCOME_RELEVANCY_MAPPING[url_status] \ No newline at end of file diff --git a/src/core/tasks/scheduled/huggingface/queries/get/core.py b/src/core/tasks/scheduled/huggingface/queries/get/core.py new file mode 100644 index 00000000..7deea322 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/get/core.py @@ -0,0 +1,65 @@ +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.collectors.enums import URLStatus +from src.core.tasks.scheduled.huggingface.queries.get.convert import convert_url_status_to_relevant, \ + convert_fine_to_coarse_record_type +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from src.db.models.instantiations.url.compressed_html import URLCompressedHTML +from src.db.models.instantiations.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.utils.compression import decompress_html +from src.db.helpers.session import session_helper as sh + +class GetForLoadingToHuggingFaceQueryBuilder(QueryBuilderBase): + + + async def run(self, session: AsyncSession) -> list[GetForLoadingToHuggingFaceOutput]: + label_url_id = 'url_id' + label_url = 'url' + label_url_status = 'url_status' + label_record_type_fine = 'record_type_fine' + label_html = 'html' + + + query = ( + select( + URL.id.label(label_url_id), + URL.url.label(label_url), + URL.outcome.label(label_url_status), + URL.record_type.label(label_record_type_fine), + URLCompressedHTML.compressed_html.label(label_html) + ) + .join( + URLCompressedHTML, + URL.id == URLCompressedHTML.url_id + ) + .where( + URL.outcome.in_([ + URLStatus.VALIDATED, + URLStatus.NOT_RELEVANT, + URLStatus.SUBMITTED + ]) + ) + ) + db_results = await sh.mappings( + session=session, + query=query + ) + final_results = [] + for result in db_results: + output = GetForLoadingToHuggingFaceOutput( + url_id=result[label_url_id], + url=result[label_url], + relevant=convert_url_status_to_relevant(result[label_url_status]), + record_type_fine=result[label_record_type_fine], + record_type_coarse=convert_fine_to_coarse_record_type( + result[label_record_type_fine] + ), + html=decompress_html(result[label_html]) + ) + final_results.append(output) + + return final_results diff --git a/src/core/tasks/scheduled/huggingface/queries/get/enums.py b/src/core/tasks/scheduled/huggingface/queries/get/enums.py new file mode 100644 index 00000000..86e1c511 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/get/enums.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class RecordTypeCoarse(Enum): + INFO_ABOUT_AGENCIES = "Info About Agencies" + INFO_ABOUT_OFFICERS = "Info About Officers" + AGENCY_PUBLISHED_RESOURCES = "Agency-Published Resources" + POLICE_AND_PUBLIC = "Police & Public Interactions" + POOR_DATA_SOURCE = "Poor Data Source" + NOT_RELEVANT = "Not Relevant" + JAILS_AND_COURTS = "Jails & Courts Specific" + OTHER = "Other" \ No newline at end of file diff --git a/src/core/tasks/scheduled/huggingface/queries/get/mappings.py b/src/core/tasks/scheduled/huggingface/queries/get/mappings.py new file mode 100644 index 00000000..2196a927 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/get/mappings.py @@ -0,0 +1,54 @@ +from src.collectors.enums import URLStatus +from src.core.enums import RecordType +from src.core.tasks.scheduled.huggingface.queries.get.enums import RecordTypeCoarse + +FINE_COARSE_RECORD_TYPE_MAPPING = { + # Police and Public + RecordType.ACCIDENT_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.ARREST_RECORDS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CALLS_FOR_SERVICE: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CAR_GPS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CITATIONS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.DISPATCH_LOGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.DISPATCH_RECORDINGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.FIELD_CONTACTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.INCIDENT_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.MISC_POLICE_ACTIVITY: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.OFFICER_INVOLVED_SHOOTINGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.STOPS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.SURVEYS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.USE_OF_FORCE_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.VEHICLE_PURSUITS: RecordTypeCoarse.POLICE_AND_PUBLIC, + # Info About Officers + RecordType.COMPLAINTS_AND_MISCONDUCT: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.DAILY_ACTIVITY_LOGS: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.TRAINING_AND_HIRING_INFO: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.PERSONNEL_RECORDS: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + # Info About Agencies + RecordType.ANNUAL_AND_MONTHLY_REPORTS: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.BUDGETS_AND_FINANCES: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.CONTACT_INFO_AND_AGENCY_META: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.GEOGRAPHIC: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.LIST_OF_DATA_SOURCES: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.POLICIES_AND_CONTRACTS: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + # Agency-Published Resources + RecordType.CRIME_MAPS_AND_REPORTS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.CRIME_STATISTICS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.MEDIA_BULLETINS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.RECORDS_REQUEST_INFO: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.RESOURCES: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.SEX_OFFENDER_REGISTRY: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.WANTED_PERSONS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + # Jails and Courts Specific + RecordType.BOOKING_REPORTS: RecordTypeCoarse.JAILS_AND_COURTS, + RecordType.COURT_CASES: RecordTypeCoarse.JAILS_AND_COURTS, + RecordType.INCARCERATION_RECORDS: RecordTypeCoarse.JAILS_AND_COURTS, + # Other + None: RecordTypeCoarse.NOT_RELEVANT +} + +OUTCOME_RELEVANCY_MAPPING = { + URLStatus.SUBMITTED: True, + URLStatus.VALIDATED: True, + URLStatus.NOT_RELEVANT: False +} \ No newline at end of file diff --git a/src/core/tasks/scheduled/huggingface/queries/get/model.py b/src/core/tasks/scheduled/huggingface/queries/get/model.py new file mode 100644 index 00000000..8aa52b16 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/get/model.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +from src.core.enums import RecordType +from src.core.tasks.scheduled.huggingface.queries.get.enums import RecordTypeCoarse + + +class GetForLoadingToHuggingFaceOutput(BaseModel): + url_id: int + url: str + relevant: bool + record_type_fine: RecordType | None + record_type_coarse: RecordTypeCoarse | None + html: str \ No newline at end of file diff --git a/src/core/tasks/scheduled/huggingface/queries/state.py b/src/core/tasks/scheduled/huggingface/queries/state.py new file mode 100644 index 00000000..5e04c809 --- /dev/null +++ b/src/core/tasks/scheduled/huggingface/queries/state.py @@ -0,0 +1,24 @@ +from datetime import datetime + +from sqlalchemy import delete, insert +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.instantiations.state.huggingface import HuggingFaceUploadState +from src.db.queries.base.builder import QueryBuilderBase + + +class SetHuggingFaceUploadStateQueryBuilder(QueryBuilderBase): + + def __init__(self, dt: datetime): + super().__init__() + self.dt = dt + + async def run(self, session: AsyncSession) -> None: + # Delete entry if any exists + await session.execute( + delete(HuggingFaceUploadState) + ) + # Insert entry + await session.execute( + insert(HuggingFaceUploadState).values(last_upload_at=self.dt) + ) diff --git a/src/core/tasks/scheduled/loader.py b/src/core/tasks/scheduled/loader.py index bd2e4b84..36f28db5 100644 --- a/src/core/tasks/scheduled/loader.py +++ b/src/core/tasks/scheduled/loader.py @@ -1,6 +1,8 @@ +from src.core.tasks.scheduled.huggingface.operator import PushToHuggingFaceTaskOperator from src.core.tasks.scheduled.sync.agency.operator import SyncAgenciesTaskOperator from src.core.tasks.scheduled.sync.data_sources.operator import SyncDataSourcesTaskOperator from src.db.client.async_ import AsyncDatabaseClient +from src.external.huggingface.hub.client import HuggingFaceHubClient from src.external.pdap.client import PDAPClient @@ -10,10 +12,12 @@ def __init__( self, adb_client: AsyncDatabaseClient, pdap_client: PDAPClient, + hf_client: HuggingFaceHubClient ): # Dependencies self.adb_client = adb_client self.pdap_client = pdap_client + self.hf_client = hf_client async def get_sync_agencies_task_operator(self) -> SyncAgenciesTaskOperator: @@ -27,3 +31,9 @@ async def get_sync_data_sources_task_operator(self) -> SyncDataSourcesTaskOperat adb_client=self.adb_client, pdap_client=self.pdap_client ) + + async def get_push_to_hugging_face_task_operator(self) -> PushToHuggingFaceTaskOperator: + return PushToHuggingFaceTaskOperator( + adb_client=self.adb_client, + hf_client=self.hf_client + ) diff --git a/src/core/tasks/scheduled/manager.py b/src/core/tasks/scheduled/manager.py index 66b50535..ac16eb31 100644 --- a/src/core/tasks/scheduled/manager.py +++ b/src/core/tasks/scheduled/manager.py @@ -31,6 +31,7 @@ def __init__( self.populate_backlog_snapshot_job = None self.sync_agencies_job = None self.sync_data_sources_job = None + self.push_to_hugging_face_job = None async def setup(self): self.scheduler.start() @@ -79,6 +80,17 @@ async def add_scheduled_tasks(self): "operator": await self.loader.get_sync_data_sources_task_operator() } ) + # TODO: enable once more URLs with HTML have been added to the database. + # self.push_to_hugging_face_job = self.scheduler.add_job( + # self.run_task, + # trigger=IntervalTrigger( + # days=1, + # start_date=datetime.now() + timedelta(minutes=4) + # ), + # kwargs={ + # "operator": await self.loader.get_push_to_hugging_face_task_operator() + # } + # ) def shutdown(self): if self.scheduler.running: diff --git a/src/core/tasks/scheduled/sync/agency/queries/get_sync_params.py b/src/core/tasks/scheduled/sync/agency/queries/get_sync_params.py index 8ff148e8..a502a156 100644 --- a/src/core/tasks/scheduled/sync/agency/queries/get_sync_params.py +++ b/src/core/tasks/scheduled/sync/agency/queries/get_sync_params.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.core.tasks.scheduled.sync.agency.dtos.parameters import AgencySyncParameters -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState from src.db.queries.base.builder import QueryBuilderBase diff --git a/src/core/tasks/scheduled/sync/agency/queries/mark_full_sync.py b/src/core/tasks/scheduled/sync/agency/queries/mark_full_sync.py index 50e7642c..f92a8798 100644 --- a/src/core/tasks/scheduled/sync/agency/queries/mark_full_sync.py +++ b/src/core/tasks/scheduled/sync/agency/queries/mark_full_sync.py @@ -1,6 +1,6 @@ from sqlalchemy import update, func, text, Update -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState def get_mark_full_agencies_sync_query() -> Update: diff --git a/src/core/tasks/scheduled/sync/agency/queries/update_sync_progress.py b/src/core/tasks/scheduled/sync/agency/queries/update_sync_progress.py index 2055bdc9..6cc88398 100644 --- a/src/core/tasks/scheduled/sync/agency/queries/update_sync_progress.py +++ b/src/core/tasks/scheduled/sync/agency/queries/update_sync_progress.py @@ -1,6 +1,6 @@ from sqlalchemy import Update, update -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState def get_update_agencies_sync_progress_query(page: int) -> Update: diff --git a/src/core/tasks/scheduled/sync/data_sources/queries/get_sync_params.py b/src/core/tasks/scheduled/sync/data_sources/queries/get_sync_params.py index 695813c6..5608dfe4 100644 --- a/src/core/tasks/scheduled/sync/data_sources/queries/get_sync_params.py +++ b/src/core/tasks/scheduled/sync/data_sources/queries/get_sync_params.py @@ -3,7 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.core.tasks.scheduled.sync.data_sources.params import DataSourcesSyncParameters -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState from src.db.queries.base.builder import QueryBuilderBase diff --git a/src/core/tasks/scheduled/sync/data_sources/queries/mark_full_sync.py b/src/core/tasks/scheduled/sync/data_sources/queries/mark_full_sync.py index d896f765..f2966c69 100644 --- a/src/core/tasks/scheduled/sync/data_sources/queries/mark_full_sync.py +++ b/src/core/tasks/scheduled/sync/data_sources/queries/mark_full_sync.py @@ -1,6 +1,6 @@ from sqlalchemy import Update, update, func, text -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState def get_mark_full_data_sources_sync_query() -> Update: diff --git a/src/core/tasks/scheduled/sync/data_sources/queries/update_sync_progress.py b/src/core/tasks/scheduled/sync/data_sources/queries/update_sync_progress.py index d6ba80e8..51962fff 100644 --- a/src/core/tasks/scheduled/sync/data_sources/queries/update_sync_progress.py +++ b/src/core/tasks/scheduled/sync/data_sources/queries/update_sync_progress.py @@ -1,6 +1,6 @@ from sqlalchemy import update, Update -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState def get_update_data_sources_sync_progress_query(page: int) -> Update: diff --git a/src/db/client/async_.py b/src/db/client/async_.py index bb444c0e..9f554f87 100644 --- a/src/db/client/async_.py +++ b/src/db/client/async_.py @@ -52,6 +52,10 @@ from src.collectors.enums import URLStatus, CollectorType from src.core.enums import BatchStatus, SuggestionType, RecordType, SuggestedStatus from src.core.env_var_manager import EnvVarManager +from src.core.tasks.scheduled.huggingface.queries.check.core import CheckValidURLsUpdatedQueryBuilder +from src.core.tasks.scheduled.huggingface.queries.get.core import GetForLoadingToHuggingFaceQueryBuilder +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from src.core.tasks.scheduled.huggingface.queries.state import SetHuggingFaceUploadStateQueryBuilder from src.core.tasks.scheduled.sync.agency.dtos.parameters import AgencySyncParameters from src.core.tasks.scheduled.sync.agency.queries.get_sync_params import GetAgenciesSyncParametersQueryBuilder from src.core.tasks.scheduled.sync.agency.queries.mark_full_sync import get_mark_full_agencies_sync_query @@ -141,11 +145,11 @@ from src.external.pdap.dtos.sync.agencies import AgenciesSyncResponseInnerInfo from src.external.pdap.dtos.sync.data_sources import DataSourcesSyncResponseInnerInfo - class AsyncDatabaseClient: def __init__(self, db_url: Optional[str] = None): if db_url is None: db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) + self.db_url = db_url echo = ConfigManager.get_sqlalchemy_echo() self.engine = create_async_engine( url=db_url, @@ -1490,16 +1494,8 @@ async def get_pending_urls_not_recently_probed_for_404(self, session: AsyncSessi urls = raw_result.scalars().all() return [URL404ProbeTDO(url=url.url, url_id=url.id) for url in urls] - @session_manager - async def get_urls_aggregated_pending_metrics( - self, - session: AsyncSession - ): - builder = GetMetricsURLSAggregatedPendingQueryBuilder() - result = await builder.run( - session=session - ) - return result + async def get_urls_aggregated_pending_metrics(self): + return await self.run_query_builder(GetMetricsURLSAggregatedPendingQueryBuilder()) async def get_agencies_sync_parameters(self) -> AgencySyncParameters: return await self.run_query_builder( @@ -1514,7 +1510,7 @@ async def get_data_sources_sync_parameters(self) -> DataSourcesSyncParameters: async def upsert_agencies( self, agencies: list[AgenciesSyncResponseInnerInfo] - ): + ) -> None: await self.bulk_upsert( models=convert_agencies_sync_response_to_agencies_upsert(agencies) ) @@ -1522,23 +1518,23 @@ async def upsert_agencies( async def upsert_urls_from_data_sources( self, data_sources: list[DataSourcesSyncResponseInnerInfo] - ): + ) -> None: await self.run_query_builder( UpsertURLsFromDataSourcesQueryBuilder( sync_infos=data_sources ) ) - async def update_agencies_sync_progress(self, page: int): + async def update_agencies_sync_progress(self, page: int) -> None: await self.execute(get_update_agencies_sync_progress_query(page)) - async def update_data_sources_sync_progress(self, page: int): + async def update_data_sources_sync_progress(self, page: int) -> None: await self.execute(get_update_data_sources_sync_progress_query(page)) - async def mark_full_data_sources_sync(self): + async def mark_full_data_sources_sync(self) -> None: await self.execute(get_mark_full_data_sources_sync_query()) - async def mark_full_agencies_sync(self): + async def mark_full_agencies_sync(self) -> None: await self.execute(get_mark_full_agencies_sync_query()) @session_manager @@ -1562,10 +1558,28 @@ async def add_raw_html( self, session: AsyncSession, info_list: list[RawHTMLInfo] - ): + ) -> None: for info in info_list: compressed_html = URLCompressedHTML( url_id=info.url_id, compressed_html=compress_html(info.html) ) session.add(compressed_html) + + async def get_data_sources_raw_for_huggingface(self) -> list[GetForLoadingToHuggingFaceOutput]: + return await self.run_query_builder( + GetForLoadingToHuggingFaceQueryBuilder() + ) + + async def set_hugging_face_upload_state(self, dt: datetime) -> None: + await self.run_query_builder( + SetHuggingFaceUploadStateQueryBuilder(dt=dt) + ) + + async def check_valid_urls_updated(self) -> bool: + return await self.run_query_builder( + CheckValidURLsUpdatedQueryBuilder() + ) + + async def get_current_database_time(self) -> datetime: + return await self.scalar(select(func.now())) diff --git a/src/db/enums.py b/src/db/enums.py index 7ea8de8c..6c1d1496 100644 --- a/src/db/enums.py +++ b/src/db/enums.py @@ -43,6 +43,7 @@ class TaskType(PyEnum): PROBE_404 = "404 Probe" SYNC_AGENCIES = "Sync Agencies" SYNC_DATA_SOURCES = "Sync Data Sources" + PUSH_TO_HUGGINGFACE = "Push to Hugging Face" class ChangeLogOperationType(PyEnum): INSERT = "INSERT" diff --git a/src/db/models/instantiations/batch/sqlalchemy.py b/src/db/models/instantiations/batch/sqlalchemy.py index c1bf14fb..b001dbac 100644 --- a/src/db/models/instantiations/batch/sqlalchemy.py +++ b/src/db/models/instantiations/batch/sqlalchemy.py @@ -49,7 +49,8 @@ class Batch(StandardBase): urls = relationship( "URL", secondary="link_batch_urls", - back_populates="batch" + back_populates="batch", + overlaps="url" ) # missings = relationship("Missing", back_populates="batch") # Not in active use logs = relationship("Log", back_populates="batch") diff --git a/src/db/models/instantiations/state/__init__.py b/src/db/models/instantiations/state/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/instantiations/state/huggingface.py b/src/db/models/instantiations/state/huggingface.py new file mode 100644 index 00000000..58e54cdc --- /dev/null +++ b/src/db/models/instantiations/state/huggingface.py @@ -0,0 +1,10 @@ +from sqlalchemy import Column, Integer, DateTime + +from src.db.models.templates import Base + + +class HuggingFaceUploadState(Base): + __tablename__ = "huggingface_upload_state" + + id = Column(Integer, primary_key=True) + last_upload_at = Column(DateTime, nullable=False) \ No newline at end of file diff --git a/src/db/models/instantiations/state/sync/__init__.py b/src/db/models/instantiations/state/sync/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/instantiations/sync_state/agencies.py b/src/db/models/instantiations/state/sync/agencies.py similarity index 100% rename from src/db/models/instantiations/sync_state/agencies.py rename to src/db/models/instantiations/state/sync/agencies.py diff --git a/src/db/models/instantiations/sync_state/data_sources.py b/src/db/models/instantiations/state/sync/data_sources.py similarity index 100% rename from src/db/models/instantiations/sync_state/data_sources.py rename to src/db/models/instantiations/state/sync/data_sources.py diff --git a/src/db/models/instantiations/url/compressed_html.py b/src/db/models/instantiations/url/compressed_html.py index 206348ac..92e340a5 100644 --- a/src/db/models/instantiations/url/compressed_html.py +++ b/src/db/models/instantiations/url/compressed_html.py @@ -1,5 +1,5 @@ from sqlalchemy import Column, LargeBinary -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, Mapped from src.db.models.mixins import CreatedAtMixin, URLDependentMixin from src.db.models.templates import StandardBase @@ -12,7 +12,7 @@ class URLCompressedHTML( ): __tablename__ = 'url_compressed_html' - compressed_html = Column(LargeBinary, nullable=False) + compressed_html: Mapped[bytes] = Column(LargeBinary, nullable=False) url = relationship( "URL", diff --git a/src/db/models/instantiations/url/core/sqlalchemy.py b/src/db/models/instantiations/url/core/sqlalchemy.py index c20343b6..8a476071 100644 --- a/src/db/models/instantiations/url/core/sqlalchemy.py +++ b/src/db/models/instantiations/url/core/sqlalchemy.py @@ -20,7 +20,7 @@ class URL(UpdatedAtMixin, CreatedAtMixin, StandardBase): # The metadata from the collector collector_metadata = Column(JSON) # The outcome of the URL: submitted, human_labeling, rejected, duplicate, etc. - outcome = enum_column( + outcome: Column = enum_column( URLStatus, name='url_status', nullable=False diff --git a/src/external/huggingface/hub/__init__.py b/src/external/huggingface/hub/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/external/huggingface/hub/client.py b/src/external/huggingface/hub/client.py new file mode 100644 index 00000000..9bb63391 --- /dev/null +++ b/src/external/huggingface/hub/client.py @@ -0,0 +1,20 @@ + +from datasets import Dataset + +from src.external.huggingface.hub.constants import DATA_SOURCES_RAW_REPO_ID +from src.external.huggingface.hub.format import format_as_huggingface_dataset +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput + + +class HuggingFaceHubClient: + + def __init__(self, token: str): + self.token = token + + def _push_dataset_to_hub(self, repo_id: str, dataset: Dataset): + dataset.push_to_hub(repo_id=repo_id, token=self.token) + + def push_data_sources_raw_to_hub(self, outputs: list[GetForLoadingToHuggingFaceOutput]): + dataset = format_as_huggingface_dataset(outputs) + print(dataset) + self._push_dataset_to_hub(repo_id=DATA_SOURCES_RAW_REPO_ID, dataset=dataset) \ No newline at end of file diff --git a/src/external/huggingface/hub/constants.py b/src/external/huggingface/hub/constants.py new file mode 100644 index 00000000..2cffa4f8 --- /dev/null +++ b/src/external/huggingface/hub/constants.py @@ -0,0 +1,3 @@ + + +DATA_SOURCES_RAW_REPO_ID = "PDAP/data_sources_raw" \ No newline at end of file diff --git a/src/external/huggingface/hub/format.py b/src/external/huggingface/hub/format.py new file mode 100644 index 00000000..b103d31d --- /dev/null +++ b/src/external/huggingface/hub/format.py @@ -0,0 +1,23 @@ +from datasets import Dataset + +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput + + +def format_as_huggingface_dataset(outputs: list[GetForLoadingToHuggingFaceOutput]) -> Dataset: + d = { + 'url_id': [], + 'url': [], + 'relevant': [], + 'record_type_fine': [], + 'record_type_coarse': [], + 'html': [] + } + for output in outputs: + d['url_id'].append(output.url_id) + d['url'].append(output.url) + d['relevant'].append(output.relevant) + d['record_type_fine'].append(output.record_type_fine) + d['record_type_coarse'].append(output.record_type_coarse) + d['html'].append(output.html) + return Dataset.from_dict(d) + diff --git a/tests/automated/integration/api/metrics/urls/aggregated/test_core.py b/tests/automated/integration/api/metrics/urls/aggregated/test_core.py index 15b48f1e..c8957952 100644 --- a/tests/automated/integration/api/metrics/urls/aggregated/test_core.py +++ b/tests/automated/integration/api/metrics/urls/aggregated/test_core.py @@ -22,7 +22,7 @@ async def test_get_urls_aggregated_metrics(api_test_helper): ] ) batch_0 = await ath.db_data_creator.batch_v2(batch_0_params) - oldest_url_id = batch_0.url_creation_infos[URLStatus.PENDING].url_mappings[0].url_id + oldest_url_id = batch_0.urls_by_status[URLStatus.PENDING].url_mappings[0].url_id batch_1_params = TestBatchCreationParameters( diff --git a/tests/automated/integration/api/test_annotate.py b/tests/automated/integration/api/test_annotate.py index b0039212..c4b1f33c 100644 --- a/tests/automated/integration/api/test_annotate.py +++ b/tests/automated/integration/api/test_annotate.py @@ -21,7 +21,7 @@ from tests.helpers.setup.annotate_agency.model import AnnotateAgencySetupInfo from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import BatchURLCreationInfo +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo from tests.automated.integration.api.conftest import MOCK_USER_ID def check_url_mappings_match( diff --git a/tests/automated/integration/core/async_/conclude_task/test_error.py b/tests/automated/integration/core/async_/conclude_task/test_error.py index 0f92fd26..2b8c1996 100644 --- a/tests/automated/integration/core/async_/conclude_task/test_error.py +++ b/tests/automated/integration/core/async_/conclude_task/test_error.py @@ -1,13 +1,11 @@ import pytest from src.core.enums import BatchStatus -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome -from src.db.enums import TaskType from tests.automated.integration.core.async_.conclude_task.helpers import setup_run_info from tests.automated.integration.core.async_.conclude_task.setup_info import TestAsyncCoreSetupInfo from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/core/async_/conclude_task/test_success.py b/tests/automated/integration/core/async_/conclude_task/test_success.py index 19bd0f4f..54de38f1 100644 --- a/tests/automated/integration/core/async_/conclude_task/test_success.py +++ b/tests/automated/integration/core/async_/conclude_task/test_success.py @@ -1,13 +1,11 @@ import pytest from src.core.enums import BatchStatus -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome -from src.db.enums import TaskType from tests.automated.integration.core.async_.conclude_task.helpers import setup_run_info from tests.automated.integration.core.async_.conclude_task.setup_info import TestAsyncCoreSetupInfo from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/core/async_/run_task/test_break_loop.py b/tests/automated/integration/core/async_/run_task/test_break_loop.py index e438c26d..303ee39d 100644 --- a/tests/automated/integration/core/async_/run_task/test_break_loop.py +++ b/tests/automated/integration/core/async_/run_task/test_break_loop.py @@ -7,7 +7,7 @@ from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/core/async_/run_task/test_prereq_met.py b/tests/automated/integration/core/async_/run_task/test_prereq_met.py index b171402d..00484e15 100644 --- a/tests/automated/integration/core/async_/run_task/test_prereq_met.py +++ b/tests/automated/integration/core/async_/run_task/test_prereq_met.py @@ -9,7 +9,7 @@ from src.db.enums import TaskType from src.db.models.instantiations.task.core import Task from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py b/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py index 37ed6462..0c261097 100644 --- a/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py +++ b/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py @@ -3,7 +3,7 @@ from src.db.constants import PLACEHOLDER_AGENCY_NAME from src.db.models.instantiations.agency.sqlalchemy import Agency from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py b/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py index ccf76dc8..1653da61 100644 --- a/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py +++ b/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py @@ -3,7 +3,7 @@ from src.core.enums import SuggestedStatus from src.db.dtos.url.mapping import URLMapping from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/approve_url/test_basic.py b/tests/automated/integration/db/client/approve_url/test_basic.py index df783e84..f438426f 100644 --- a/tests/automated/integration/db/client/approve_url/test_basic.py +++ b/tests/automated/integration/db/client/approve_url/test_basic.py @@ -8,7 +8,7 @@ from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata from src.db.models.instantiations.url.reviewing_user import ReviewingUserURL from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/approve_url/test_error.py b/tests/automated/integration/db/client/approve_url/test_error.py index 1e7b92d8..9523a16c 100644 --- a/tests/automated/integration/db/client/approve_url/test_error.py +++ b/tests/automated/integration/db/client/approve_url/test_error.py @@ -4,7 +4,7 @@ from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo from src.core.enums import RecordType from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py index adb48844..3f5c3182 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py @@ -2,7 +2,7 @@ from src.core.enums import SuggestedStatus, RecordType from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py index bce7d8e2..ad4fe3d6 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py @@ -1,7 +1,7 @@ import pytest from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py index 874dba18..38e0527c 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py @@ -2,7 +2,7 @@ from src.core.enums import SuggestionType from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py index 4b04d4d1..72430fec 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py @@ -5,7 +5,7 @@ from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py index b82ebee2..b278352c 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py @@ -1,6 +1,6 @@ import pytest -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py index 6c9a29c8..7e68ada4 100644 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py +++ b/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py @@ -1,7 +1,7 @@ import pytest from src.collectors.enums import URLStatus -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py index 57c6ae35..9c452f15 100644 --- a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py +++ b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py @@ -2,7 +2,7 @@ from src.core.enums import SuggestedStatus from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py index 3736c2b8..95e40847 100644 --- a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py +++ b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py @@ -2,7 +2,7 @@ from src.collectors.enums import URLStatus from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_add_url_error_info.py b/tests/automated/integration/db/client/test_add_url_error_info.py index 3bb25e58..55e84836 100644 --- a/tests/automated/integration/db/client/test_add_url_error_info.py +++ b/tests/automated/integration/db/client/test_add_url_error_info.py @@ -2,7 +2,7 @@ from src.db.client.async_ import AsyncDatabaseClient from src.db.models.instantiations.url.error_info.pydantic import URLErrorPydanticInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_delete_old_logs.py b/tests/automated/integration/db/client/test_delete_old_logs.py index 1a5b0cd7..61f94af0 100644 --- a/tests/automated/integration/db/client/test_delete_old_logs.py +++ b/tests/automated/integration/db/client/test_delete_old_logs.py @@ -3,7 +3,7 @@ import pytest from src.db.models.instantiations.log.pydantic.info import LogInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_delete_url_updated_at.py b/tests/automated/integration/db/client/test_delete_url_updated_at.py index 34bbc7b3..620e0318 100644 --- a/tests/automated/integration/db/client/test_delete_url_updated_at.py +++ b/tests/automated/integration/db/client/test_delete_url_updated_at.py @@ -1,5 +1,5 @@ from src.db.models.instantiations.url.core.pydantic import URLInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def test_delete_url_updated_at(db_data_creator: DBDataCreator): diff --git a/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py b/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py index 5a402727..a1df2164 100644 --- a/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py +++ b/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py @@ -2,7 +2,7 @@ from src.core.enums import SuggestionType from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py b/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py index 8f03286c..707399c9 100644 --- a/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py +++ b/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py @@ -1,7 +1,7 @@ import pytest from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py b/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py index 292ab33f..203cb710 100644 --- a/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py +++ b/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py @@ -2,7 +2,7 @@ from src.core.enums import RecordType from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_insert_logs.py b/tests/automated/integration/db/client/test_insert_logs.py index 6da198d8..dff43790 100644 --- a/tests/automated/integration/db/client/test_insert_logs.py +++ b/tests/automated/integration/db/client/test_insert_logs.py @@ -1,7 +1,7 @@ import pytest from src.db.models.instantiations.log.pydantic.info import LogInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/structure/test_html_content.py b/tests/automated/integration/db/structure/test_html_content.py index 8c9c3207..936a8a25 100644 --- a/tests/automated/integration/db/structure/test_html_content.py +++ b/tests/automated/integration/db/structure/test_html_content.py @@ -6,7 +6,7 @@ from src.util.helper_functions import get_enum_values from tests.automated.integration.db.structure.testers.models.column import ColumnTester from tests.automated.integration.db.structure.testers.table import TableTester -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def test_html_content(db_data_creator: DBDataCreator): diff --git a/tests/automated/integration/db/structure/test_root_url.py b/tests/automated/integration/db/structure/test_root_url.py index 7c3712df..8f8be80b 100644 --- a/tests/automated/integration/db/structure/test_root_url.py +++ b/tests/automated/integration/db/structure/test_root_url.py @@ -2,7 +2,7 @@ from tests.automated.integration.db.structure.testers.models.column import ColumnTester from tests.automated.integration.db.structure.testers.table import TableTester -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def test_root_url(db_data_creator: DBDataCreator): diff --git a/tests/automated/integration/db/structure/test_upsert_new_agencies.py b/tests/automated/integration/db/structure/test_upsert_new_agencies.py index 17a184f4..0993c7a7 100644 --- a/tests/automated/integration/db/structure/test_upsert_new_agencies.py +++ b/tests/automated/integration/db/structure/test_upsert_new_agencies.py @@ -3,7 +3,7 @@ from src.core.enums import SuggestionType from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo from src.db.models.instantiations.agency.sqlalchemy import Agency -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/structure/test_url.py b/tests/automated/integration/db/structure/test_url.py index c9c3cf79..1c14d519 100644 --- a/tests/automated/integration/db/structure/test_url.py +++ b/tests/automated/integration/db/structure/test_url.py @@ -5,7 +5,7 @@ from src.util.helper_functions import get_enum_values from tests.automated.integration.db.structure.testers.models.column import ColumnTester from tests.automated.integration.db.structure.testers.table import TableTester -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def test_url(db_data_creator: DBDataCreator): diff --git a/tests/automated/integration/tasks/asserts.py b/tests/automated/integration/tasks/asserts.py index 224e56a1..fa69d4a1 100644 --- a/tests/automated/integration/tasks/asserts.py +++ b/tests/automated/integration/tasks/asserts.py @@ -1,4 +1,5 @@ from src.core.tasks.base.run_info import TaskOperatorRunInfo +from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome @@ -10,7 +11,9 @@ async def assert_prereqs_met(operator): meets_prereqs = await operator.meets_task_prerequisites() assert meets_prereqs +def assert_task_ran_without_error(run_info: TaskOperatorRunInfo): + assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message -def assert_task_has_expected_run_info(run_info: TaskOperatorRunInfo, url_ids: list[int]): +def assert_url_task_has_expected_run_info(run_info: URLTaskOperatorRunInfo, url_ids: list[int]): assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message assert run_info.linked_url_ids == url_ids diff --git a/tests/automated/integration/tasks/scheduled/huggingface/__init__.py b/tests/automated/integration/tasks/scheduled/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/scheduled/huggingface/conftest.py b/tests/automated/integration/tasks/scheduled/huggingface/conftest.py new file mode 100644 index 00000000..29d397b4 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/conftest.py @@ -0,0 +1,14 @@ +from unittest.mock import AsyncMock + +import pytest + +from src.core.tasks.scheduled.huggingface.operator import PushToHuggingFaceTaskOperator +from src.external.huggingface.hub.client import HuggingFaceHubClient + + +@pytest.fixture +def operator(adb_client_test): + yield PushToHuggingFaceTaskOperator( + adb_client=adb_client_test, + hf_client=AsyncMock(spec=HuggingFaceHubClient) + ) \ No newline at end of file diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/__init__.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/data.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/data.py new file mode 100644 index 00000000..d28aa8f2 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/data.py @@ -0,0 +1,71 @@ +from src.collectors.enums import URLStatus +from src.core.enums import RecordType +from src.core.tasks.scheduled.huggingface.queries.get.enums import RecordTypeCoarse +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.entry \ + import TestPushToHuggingFaceURLSetupEntry as Entry +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.output import \ + TestPushToHuggingFaceURLSetupExpectedOutput as Output +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.input import \ + TestPushToHuggingFaceURLSetupEntryInput as Input + +ENTRIES = [ + # Because pending, should not be picked up + Entry( + input=Input( + outcome=URLStatus.PENDING, + has_html_content=True, + record_type=RecordType.INCARCERATION_RECORDS + ), + expected_output=Output( + picked_up=False, + ) + ), + # Because no html content, should not be picked up + Entry( + input=Input( + outcome=URLStatus.SUBMITTED, + has_html_content=False, + record_type=RecordType.RECORDS_REQUEST_INFO + ), + expected_output=Output( + picked_up=False, + ) + ), + # Remainder should be picked up + Entry( + input=Input( + outcome=URLStatus.VALIDATED, + has_html_content=True, + record_type=RecordType.RECORDS_REQUEST_INFO + ), + expected_output=Output( + picked_up=True, + coarse_record_type=RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + relevant=True + ) + ), + Entry( + input=Input( + outcome=URLStatus.SUBMITTED, + has_html_content=True, + record_type=RecordType.INCARCERATION_RECORDS + ), + expected_output=Output( + picked_up=True, + coarse_record_type=RecordTypeCoarse.JAILS_AND_COURTS, + relevant=True + ) + ), + Entry( + input=Input( + outcome=URLStatus.NOT_RELEVANT, + has_html_content=True, + record_type=None + ), + expected_output=Output( + picked_up=True, + coarse_record_type=RecordTypeCoarse.NOT_RELEVANT, + relevant=False + ) + ), +] diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/manager.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/manager.py new file mode 100644 index 00000000..9b6606d2 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/manager.py @@ -0,0 +1,46 @@ +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from src.db.client.async_ import AsyncDatabaseClient +from tests.automated.integration.tasks.scheduled.huggingface.setup.data import ENTRIES +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.output import \ + TestPushToHuggingFaceURLSetupExpectedOutput +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.record import \ + TestPushToHuggingFaceRecordSetupRecord as Record, TestPushToHuggingFaceRecordSetupRecord +from tests.automated.integration.tasks.scheduled.huggingface.setup.queries.setup import \ + SetupTestPushToHuggingFaceEntryQueryBuilder +from tests.helpers.data_creator.core import DBDataCreator + + +class PushToHuggingFaceTestSetupManager: + + def __init__(self, adb_client: AsyncDatabaseClient): + self.adb_client = adb_client + self.entries = ENTRIES + # Connects a URL ID to the expectation that it will be picked up + self._id_to_record: dict[int, TestPushToHuggingFaceRecordSetupRecord] = {} + + async def setup(self) -> None: + records: list[Record] = await self.adb_client.run_query_builder( + SetupTestPushToHuggingFaceEntryQueryBuilder(self.entries) + ) + for record in records: + if not record.expected_output.picked_up: + continue + self._id_to_record[record.url_id] = record + + def check_results(self, outputs: list[GetForLoadingToHuggingFaceOutput]) -> None: + # Check that both expected and actual results are same length + length_expected = len(self._id_to_record.keys()) + length_actual = len(outputs) + assert length_expected == length_actual, f"Expected {length_expected} results, got {length_actual}" + + # Check attributes of each URL match what is expected + for output in outputs: + url_id = output.url_id + record = self._id_to_record[url_id] + + expected_output = record.expected_output + assert output.relevant == expected_output.relevant + assert output.record_type_coarse == expected_output.coarse_record_type, \ + f"Expected {expected_output.coarse_record_type} but got {output.record_type_coarse}" + assert output.record_type_fine == record.record_type_fine + diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/models/__init__.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/models/entry.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/entry.py new file mode 100644 index 00000000..e072a1b6 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/entry.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.input import \ + TestPushToHuggingFaceURLSetupEntryInput +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.output import \ + TestPushToHuggingFaceURLSetupExpectedOutput + + +class TestPushToHuggingFaceURLSetupEntry(BaseModel): + input: TestPushToHuggingFaceURLSetupEntryInput + expected_output: TestPushToHuggingFaceURLSetupExpectedOutput + diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/models/input.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/input.py new file mode 100644 index 00000000..cd68782e --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/input.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from src.core.enums import RecordType + + +class TestPushToHuggingFaceURLSetupEntryInput(BaseModel): + outcome: URLStatus + record_type: RecordType | None + has_html_content: bool diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/models/output.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/output.py new file mode 100644 index 00000000..c1303543 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/output.py @@ -0,0 +1,22 @@ +from typing import Self + +from pydantic import BaseModel, model_validator + +from src.core.enums import RecordType +from src.core.tasks.scheduled.huggingface.queries.get.enums import RecordTypeCoarse + + +class TestPushToHuggingFaceURLSetupExpectedOutput(BaseModel): + picked_up: bool + relevant: bool | None = None + coarse_record_type: RecordTypeCoarse | None = None + + @model_validator(mode='after') + def validate_coarse_record_type_and_relevant(self) -> Self: + if not self.picked_up: + return self + if self.coarse_record_type is None: + raise ValueError('Coarse record type should be provided if picked up') + if self.relevant is None: + raise ValueError('Relevant should be provided if picked up') + return self diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/models/record.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/record.py new file mode 100644 index 00000000..becabc17 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/models/record.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.core.enums import RecordType +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.output import \ + TestPushToHuggingFaceURLSetupExpectedOutput + + +class TestPushToHuggingFaceRecordSetupRecord(BaseModel): + expected_output: TestPushToHuggingFaceURLSetupExpectedOutput + record_type_fine: RecordType | None + url_id: int \ No newline at end of file diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/queries/__init__.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/scheduled/huggingface/setup/queries/setup.py b/tests/automated/integration/tasks/scheduled/huggingface/setup/queries/setup.py new file mode 100644 index 00000000..dc0a3452 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/setup/queries/setup.py @@ -0,0 +1,55 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.instantiations.url.compressed_html import URLCompressedHTML +from src.db.models.instantiations.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.utils.compression import compress_html +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.entry import \ + TestPushToHuggingFaceURLSetupEntry as Entry +from tests.automated.integration.tasks.scheduled.huggingface.setup.models.record import \ + TestPushToHuggingFaceRecordSetupRecord as Record + + +class SetupTestPushToHuggingFaceEntryQueryBuilder(QueryBuilderBase): + + def __init__( + self, + entries: list[Entry] + ): + super().__init__() + self.entries = entries + + async def run(self, session: AsyncSession) -> list[Record]: + records = [] + for idx, entry in enumerate(self.entries): + if idx % 2 == 0: + name = "Test Push to Hugging Face URL Setup Entry" + description = "This is a test push to Hugging Face URL setup entry" + else: + name = None + description = None + inp = entry.input + url = URL( + url=f"www.testPushToHuggingFaceURLSetupEntry.com/{idx}", + outcome=inp.outcome, + name=name, + description=description, + record_type=inp.record_type, + ) + session.add(url) + await session.flush() + if entry.input.has_html_content: + compressed_html = URLCompressedHTML( + url_id=url.id, + compressed_html=compress_html(f"
Test Push to Hugging Face URL Setup Entry {idx}
"), + ) + session.add(compressed_html) + record = Record( + url_id=url.id, + expected_output=entry.expected_output, + record_type_fine=inp.record_type + ) + records.append(record) + + return records + diff --git a/tests/automated/integration/tasks/scheduled/huggingface/test_happy_path.py b/tests/automated/integration/tasks/scheduled/huggingface/test_happy_path.py new file mode 100644 index 00000000..d5eca4a7 --- /dev/null +++ b/tests/automated/integration/tasks/scheduled/huggingface/test_happy_path.py @@ -0,0 +1,42 @@ +from unittest.mock import AsyncMock + +import pytest + +from src.core.tasks.scheduled.huggingface.operator import PushToHuggingFaceTaskOperator +from src.core.tasks.scheduled.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from tests.automated.integration.tasks.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.scheduled.huggingface.setup.manager import PushToHuggingFaceTestSetupManager +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_happy_path( + operator: PushToHuggingFaceTaskOperator, + db_data_creator: DBDataCreator +): + hf_client = operator.hf_client + push_function: AsyncMock = hf_client.push_data_sources_raw_to_hub + + # Check, prior to adding URLs, that task does not run + task_info = await operator.run_task(1) + assert_task_ran_without_error(task_info) + push_function.assert_not_called() + + # Add URLs + manager = PushToHuggingFaceTestSetupManager(adb_client=db_data_creator.adb_client) + await manager.setup() + + # Run task + task_info = await operator.run_task(2) + assert_task_ran_without_error(task_info) + push_function.assert_called_once() + + call_args: list[GetForLoadingToHuggingFaceOutput] = push_function.call_args.args[0] + + # Check for calls to HF Client + manager.check_results(call_args) + + # Test that after update, running again yields no results + task_info = await operator.run_task(3) + assert_task_ran_without_error(task_info) + push_function.assert_called_once() \ No newline at end of file diff --git a/tests/automated/integration/tasks/scheduled/sync/agency/helpers.py b/tests/automated/integration/tasks/scheduled/sync/agency/helpers.py index a60f0586..7c35a654 100644 --- a/tests/automated/integration/tasks/scheduled/sync/agency/helpers.py +++ b/tests/automated/integration/tasks/scheduled/sync/agency/helpers.py @@ -6,7 +6,7 @@ from src.db.client.async_ import AsyncDatabaseClient from src.db.models.instantiations.agency.sqlalchemy import Agency -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState from src.external.pdap.client import PDAPClient from tests.automated.integration.tasks.scheduled.sync.agency.data import PREEXISTING_AGENCIES diff --git a/tests/automated/integration/tasks/scheduled/sync/agency/test_interruption.py b/tests/automated/integration/tasks/scheduled/sync/agency/test_interruption.py index 41f4b86c..2f112175 100644 --- a/tests/automated/integration/tasks/scheduled/sync/agency/test_interruption.py +++ b/tests/automated/integration/tasks/scheduled/sync/agency/test_interruption.py @@ -4,7 +4,7 @@ from src.core.tasks.scheduled.sync.agency.operator import SyncAgenciesTaskOperator from src.core.tasks.url.enums import TaskOperatorOutcome from src.db.models.instantiations.agency.sqlalchemy import Agency -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState from tests.automated.integration.tasks.scheduled.sync.agency.data import FIRST_CALL_RESPONSE, \ THIRD_CALL_RESPONSE, SECOND_CALL_RESPONSE from tests.automated.integration.tasks.scheduled.sync.agency.existence_checker import AgencyChecker diff --git a/tests/automated/integration/tasks/scheduled/sync/agency/test_no_new_results.py b/tests/automated/integration/tasks/scheduled/sync/agency/test_no_new_results.py index 20a179bd..18fd263b 100644 --- a/tests/automated/integration/tasks/scheduled/sync/agency/test_no_new_results.py +++ b/tests/automated/integration/tasks/scheduled/sync/agency/test_no_new_results.py @@ -7,7 +7,7 @@ from src.core.tasks.scheduled.sync.agency.dtos.parameters import AgencySyncParameters from src.core.tasks.scheduled.sync.agency.operator import SyncAgenciesTaskOperator from src.db.models.instantiations.agency.sqlalchemy import Agency -from src.db.models.instantiations.sync_state.agencies import AgenciesSyncState +from src.db.models.instantiations.state.sync.agencies import AgenciesSyncState from tests.automated.integration.tasks.scheduled.sync.agency.data import THIRD_CALL_RESPONSE from tests.automated.integration.tasks.scheduled.sync.agency.existence_checker import AgencyChecker from tests.automated.integration.tasks.scheduled.sync.agency.helpers import patch_sync_agencies, check_sync_concluded diff --git a/tests/automated/integration/tasks/scheduled/sync/data_sources/check.py b/tests/automated/integration/tasks/scheduled/sync/data_sources/check.py index 5968831f..e5a3c4ba 100644 --- a/tests/automated/integration/tasks/scheduled/sync/data_sources/check.py +++ b/tests/automated/integration/tasks/scheduled/sync/data_sources/check.py @@ -3,7 +3,7 @@ from sqlalchemy import select, cast, func, TIMESTAMP from src.db.client.async_ import AsyncDatabaseClient -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState from src.db.models.instantiations.url.core.sqlalchemy import URL diff --git a/tests/automated/integration/tasks/scheduled/sync/data_sources/conftest.py b/tests/automated/integration/tasks/scheduled/sync/data_sources/conftest.py index 470504ab..017a9894 100644 --- a/tests/automated/integration/tasks/scheduled/sync/data_sources/conftest.py +++ b/tests/automated/integration/tasks/scheduled/sync/data_sources/conftest.py @@ -2,7 +2,7 @@ from src.core.tasks.scheduled.sync.data_sources.operator import SyncDataSourcesTaskOperator from src.external.pdap.client import PDAPClient -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest_asyncio.fixture diff --git a/tests/automated/integration/tasks/scheduled/sync/data_sources/test_interruption.py b/tests/automated/integration/tasks/scheduled/sync/data_sources/test_interruption.py index 955c33fb..81fb8806 100644 --- a/tests/automated/integration/tasks/scheduled/sync/data_sources/test_interruption.py +++ b/tests/automated/integration/tasks/scheduled/sync/data_sources/test_interruption.py @@ -3,7 +3,7 @@ from src.core.tasks.scheduled.sync.data_sources.operator import SyncDataSourcesTaskOperator from src.core.tasks.url.enums import TaskOperatorOutcome -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState from tests.automated.integration.tasks.scheduled.sync.data_sources.check import check_sync_concluded from tests.automated.integration.tasks.scheduled.sync.data_sources.setup.core import patch_sync_data_sources from tests.automated.integration.tasks.scheduled.sync.data_sources.setup.data import ENTRIES diff --git a/tests/automated/integration/tasks/scheduled/sync/data_sources/test_no_new_results.py b/tests/automated/integration/tasks/scheduled/sync/data_sources/test_no_new_results.py index f32a12ec..880c2ef3 100644 --- a/tests/automated/integration/tasks/scheduled/sync/data_sources/test_no_new_results.py +++ b/tests/automated/integration/tasks/scheduled/sync/data_sources/test_no_new_results.py @@ -5,7 +5,7 @@ from src.core.tasks.scheduled.sync.data_sources.operator import SyncDataSourcesTaskOperator from src.core.tasks.scheduled.sync.data_sources.params import DataSourcesSyncParameters -from src.db.models.instantiations.sync_state.data_sources import DataSourcesSyncState +from src.db.models.instantiations.state.sync.data_sources import DataSourcesSyncState from tests.automated.integration.tasks.scheduled.sync.data_sources.check import check_sync_concluded from tests.automated.integration.tasks.scheduled.sync.data_sources.setup.core import patch_sync_data_sources from tests.automated.integration.tasks.scheduled.sync.data_sources.setup.data import ENTRIES diff --git a/tests/automated/integration/tasks/url/auto_relevant/setup.py b/tests/automated/integration/tasks/url/auto_relevant/setup.py index fdd17e16..38c57409 100644 --- a/tests/automated/integration/tasks/url/auto_relevant/setup.py +++ b/tests/automated/integration/tasks/url/auto_relevant/setup.py @@ -5,7 +5,7 @@ from src.external.huggingface.inference.models.output import BasicOutput from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters -from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfoV2 +from tests.helpers.data_creator.core import DBDataCreator async def setup_operator(adb_client: AsyncDatabaseClient) -> URLAutoRelevantTaskOperator: diff --git a/tests/automated/integration/tasks/url/auto_relevant/test_task.py b/tests/automated/integration/tasks/url/auto_relevant/test_task.py index 886cec09..fab2edfe 100644 --- a/tests/automated/integration/tasks/url/auto_relevant/test_task.py +++ b/tests/automated/integration/tasks/url/auto_relevant/test_task.py @@ -7,7 +7,7 @@ from src.db.models.instantiations.url.core.sqlalchemy import URL from src.db.models.instantiations.url.error_info.sqlalchemy import URLErrorInfo from src.db.models.instantiations.url.suggestion.relevant.auto.sqlalchemy import AutoRelevantSuggestion -from tests.automated.integration.tasks.asserts import assert_prereqs_not_met, assert_task_has_expected_run_info, \ +from tests.automated.integration.tasks.asserts import assert_prereqs_not_met, assert_url_task_has_expected_run_info, \ assert_prereqs_met from tests.automated.integration.tasks.url.auto_relevant.setup import setup_operator, setup_urls @@ -25,7 +25,7 @@ async def test_url_auto_relevant_task(db_data_creator): run_info = await operator.run_task(task_id) - assert_task_has_expected_run_info(run_info, url_ids) + assert_url_task_has_expected_run_info(run_info, url_ids) adb_client = db_data_creator.adb_client # Get URLs, confirm one is marked as error diff --git a/tests/automated/integration/tasks/url/duplicate/test_url_duplicate_task.py b/tests/automated/integration/tasks/url/duplicate/test_url_duplicate_task.py index 816724b8..bd66e409 100644 --- a/tests/automated/integration/tasks/url/duplicate/test_url_duplicate_task.py +++ b/tests/automated/integration/tasks/url/duplicate/test_url_duplicate_task.py @@ -10,7 +10,7 @@ from src.collectors.enums import URLStatus from src.core.tasks.url.enums import TaskOperatorOutcome from tests.automated.integration.tasks.url.duplicate.constants import BATCH_CREATION_PARAMETERS -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from pdap_access_manager import ResponseInfo from src.external.pdap.client import PDAPClient @@ -32,7 +32,7 @@ async def test_url_duplicate_task( # Add three URLs to the database, one of which is in error, the other two pending creation_info = await db_data_creator.batch_v2(BATCH_CREATION_PARAMETERS) - pending_urls: list[URLMapping] = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings + pending_urls: list[URLMapping] = creation_info.urls_by_status[URLStatus.PENDING].url_mappings duplicate_url = pending_urls[0] non_duplicate_url = pending_urls[1] assert await operator.meets_task_prerequisites() diff --git a/tests/automated/integration/tasks/url/html/asserts.py b/tests/automated/integration/tasks/url/html/asserts.py index 5566aab6..9ca241cd 100644 --- a/tests/automated/integration/tasks/url/html/asserts.py +++ b/tests/automated/integration/tasks/url/html/asserts.py @@ -1,4 +1,6 @@ +from src.api.endpoints.task.by_id.dto import TaskInfo from src.collectors.enums import URLStatus +from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.db.client.async_ import AsyncDatabaseClient from src.db.enums import TaskType from tests.automated.integration.tasks.url.html.mocks.constants import MOCK_HTML_CONTENT @@ -46,5 +48,5 @@ def assert_task_type_is_html(task_info): assert task_info.task_type == TaskType.HTML -def assert_task_ran_without_error(task_info): +def assert_html_task_ran_without_error(task_info: TaskInfo): assert task_info.error_info is None diff --git a/tests/automated/integration/tasks/url/html/test_task.py b/tests/automated/integration/tasks/url/html/test_task.py index e39d7576..2592713f 100644 --- a/tests/automated/integration/tasks/url/html/test_task.py +++ b/tests/automated/integration/tasks/url/html/test_task.py @@ -2,10 +2,10 @@ from src.db.enums import TaskType from tests.automated.integration.tasks.url.html.asserts import assert_success_url_has_two_html_content_entries, assert_404_url_has_404_status, assert_task_has_one_url_error, \ - assert_task_type_is_html, assert_task_ran_without_error, assert_url_has_one_compressed_html_content_entry -from tests.automated.integration.tasks.asserts import assert_prereqs_not_met, assert_task_has_expected_run_info + assert_task_type_is_html, assert_html_task_ran_without_error, assert_url_has_one_compressed_html_content_entry +from tests.automated.integration.tasks.asserts import assert_prereqs_not_met, assert_url_task_has_expected_run_info from tests.automated.integration.tasks.url.html.setup import setup_urls, setup_operator -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -22,14 +22,14 @@ async def test_url_html_task(db_data_creator: DBDataCreator): task_id = await db_data_creator.adb_client.initiate_task(task_type=TaskType.HTML) run_info = await operator.run_task(task_id) - assert_task_has_expected_run_info(run_info, url_ids) + assert_url_task_has_expected_run_info(run_info, url_ids) task_info = await db_data_creator.adb_client.get_task_info( task_id=operator.task_id ) - assert_task_ran_without_error(task_info) + assert_html_task_ran_without_error(task_info) assert_task_type_is_html(task_info) assert_task_has_one_url_error(task_info) diff --git a/tests/automated/integration/tasks/url/submit_approved/setup.py b/tests/automated/integration/tasks/url/submit_approved/setup.py index cdf88d97..c1a1d4f4 100644 --- a/tests/automated/integration/tasks/url/submit_approved/setup.py +++ b/tests/automated/integration/tasks/url/submit_approved/setup.py @@ -1,6 +1,7 @@ from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo from src.core.enums import RecordType -from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfo +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo async def setup_validated_urls(db_data_creator: DBDataCreator) -> list[str]: diff --git a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py b/tests/automated/integration/tasks/url/test_agency_preannotation_task.py index f7b75f51..d11a1def 100644 --- a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py +++ b/tests/automated/integration/tasks/url/test_agency_preannotation_task.py @@ -26,7 +26,8 @@ from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo from src.external.pdap.client import PDAPClient -from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfoV2 +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 sample_agency_suggestions = [ URLAgencySuggestionInfo( @@ -127,7 +128,7 @@ async def mock_run_subtask( ] ) ) - d[strategy] = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings[0].url_id + d[strategy] = creation_info.urls_by_status[URLStatus.PENDING].url_mappings[0].url_id # Confirm meets prerequisites diff --git a/tests/automated/integration/tasks/url/test_example_task.py b/tests/automated/integration/tasks/url/test_example_task.py index 9a2a2fc9..06678658 100644 --- a/tests/automated/integration/tasks/url/test_example_task.py +++ b/tests/automated/integration/tasks/url/test_example_task.py @@ -5,7 +5,7 @@ from src.db.enums import TaskType from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.base import URLTaskOperatorBase -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator class ExampleTaskOperator(URLTaskOperatorBase): diff --git a/tests/automated/integration/tasks/url/test_url_404_probe.py b/tests/automated/integration/tasks/url/test_url_404_probe.py index 8966e416..54592640 100644 --- a/tests/automated/integration/tasks/url/test_url_404_probe.py +++ b/tests/automated/integration/tasks/url/test_url_404_probe.py @@ -12,7 +12,7 @@ from src.collectors.enums import URLStatus from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.url_html.scraper.request_interface.dtos.url_response import URLResponseInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters @@ -102,12 +102,12 @@ async def mock_make_simple_requests(self, urls: list[str]) -> list[URLResponseIn assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message - pending_url_mappings = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings + pending_url_mappings = creation_info.urls_by_status[URLStatus.PENDING].url_mappings url_id_success = pending_url_mappings[0].url_id url_id_404 = pending_url_mappings[1].url_id url_id_error = pending_url_mappings[2].url_id - url_id_initial_error = creation_info.url_creation_infos[URLStatus.ERROR].url_mappings[0].url_id + url_id_initial_error = creation_info.urls_by_status[URLStatus.ERROR].url_mappings[0].url_id # Check that URLProbedFor404 has been appropriately populated probed_for_404_objects: list[URLProbedFor404] = await db_data_creator.adb_client.get_all(URLProbedFor404) diff --git a/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py b/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py index e9f55240..ed7f1336 100644 --- a/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py +++ b/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py @@ -7,7 +7,7 @@ from src.db.models.instantiations.url.core.sqlalchemy import URL from src.collectors.enums import CollectorType from src.core.tasks.url.enums import TaskOperatorOutcome -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def batch_and_url( diff --git a/tests/automated/integration/tasks/url/test_url_record_type_task.py b/tests/automated/integration/tasks/url/test_url_record_type_task.py index 514aa716..3ea95811 100644 --- a/tests/automated/integration/tasks/url/test_url_record_type_task.py +++ b/tests/automated/integration/tasks/url/test_url_record_type_task.py @@ -7,7 +7,7 @@ from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator from src.core.enums import RecordType -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from src.core.tasks.url.operators.record_type.llm_api.record_classifier.deepseek import DeepSeekRecordClassifier @pytest.mark.asyncio diff --git a/tests/conftest.py b/tests/conftest.py index 4e724563..e3789b45 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ from src.db.helpers.connect import get_postgres_connection_string from src.util.helper_functions import load_from_environment from tests.helpers.alembic_runner import AlembicRunner -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.populate import populate_database from tests.helpers.setup.wipe import wipe_database @@ -42,7 +42,8 @@ def setup_and_teardown(): "PDAP_API_URL", "DISCORD_WEBHOOK_URL", "OPENAI_API_KEY", - "HUGGINGFACE_INFERENCE_API_KEY" + "HUGGINGFACE_INFERENCE_API_KEY", + "HUGGINGFACE_HUB_TOKEN" ] all_env_vars = required_env_vars.copy() for env_var in test_env_vars: diff --git a/tests/helpers/api_test_helper.py b/tests/helpers/api_test_helper.py index 55a85345..2ff51f98 100644 --- a/tests/helpers/api_test_helper.py +++ b/tests/helpers/api_test_helper.py @@ -5,7 +5,7 @@ from src.core.core import AsyncCore from src.core.enums import BatchStatus from tests.automated.integration.api._helpers.RequestValidator import RequestValidator -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @dataclass diff --git a/tests/helpers/data_creator/__init__.py b/tests/helpers/data_creator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/db_data_creator.py b/tests/helpers/data_creator/core.py similarity index 94% rename from tests/helpers/db_data_creator.py rename to tests/helpers/data_creator/core.py index a8d8331a..696ca104 100644 --- a/tests/helpers/db_data_creator.py +++ b/tests/helpers/data_creator/core.py @@ -1,9 +1,8 @@ +from collections import defaultdict from datetime import datetime from random import randint from typing import List, Optional -from pydantic import BaseModel - from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo from src.api.endpoints.review.enums import RejectionReason @@ -16,7 +15,6 @@ from src.db.models.instantiations.url.error_info.pydantic import URLErrorPydanticInfo from src.db.dtos.url.html_content import URLHTMLContentInfo, HTMLContentType from src.db.models.instantiations.url.core.pydantic import URLInfo -from src.db.dtos.url.mapping import URLMapping from src.db.client.sync import DatabaseClient from src.db.dtos.url.raw_html import RawHTMLInfo from src.db.enums import TaskType @@ -26,35 +24,12 @@ from src.core.enums import BatchStatus, SuggestionType, RecordType, SuggestedStatus from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 +from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo from tests.helpers.simple_test_data_functions import generate_test_urls -class URLCreationInfo(BaseModel): - url_mappings: list[URLMapping] - outcome: URLStatus - annotation_info: Optional[AnnotationInfo] = None - - @property - def url_ids(self) -> list[int]: - return [url_mapping.url_id for url_mapping in self.url_mappings] - -class BatchURLCreationInfoV2(BaseModel): - batch_id: int - url_creation_infos: dict[URLStatus, URLCreationInfo] - - @property - def url_ids(self) -> list[int]: - url_creation_infos = self.url_creation_infos.values() - url_ids = [] - for url_creation_info in url_creation_infos: - url_ids.extend(url_creation_info.url_ids) - return url_ids - -class BatchURLCreationInfo(BaseModel): - batch_id: int - url_ids: list[int] - urls: list[str] - class DBDataCreator: """ Assists in the creation of test data @@ -92,18 +67,20 @@ async def batch_v2( self, parameters: TestBatchCreationParameters ) -> BatchURLCreationInfoV2: + # Create batch batch_id = self.batch( strategy=parameters.strategy, batch_status=parameters.outcome, created_at=parameters.created_at ) + # Return early if batch would not involve URL creation if parameters.outcome in (BatchStatus.ERROR, BatchStatus.ABORTED): return BatchURLCreationInfoV2( batch_id=batch_id, - url_creation_infos={} ) - d: dict[URLStatus, URLCreationInfo] = {} + urls_by_status: dict[URLStatus, URLCreationInfo] = {} + urls_by_order: list[URLCreationInfo] = [] # Create urls for url_parameters in parameters.urls: iui: InsertURLsInfo = self.urls( @@ -122,14 +99,17 @@ async def batch_v2( annotation_info=url_parameters.annotation_info ) - d[url_parameters.status] = URLCreationInfo( + creation_info = URLCreationInfo( url_mappings=iui.url_mappings, outcome=url_parameters.status, annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None ) + urls_by_order.append(creation_info) + urls_by_status[url_parameters.status] = creation_info + return BatchURLCreationInfoV2( batch_id=batch_id, - url_creation_infos=d + urls_by_status=urls_by_status, ) async def batch_and_urls( diff --git a/tests/helpers/data_creator/models/__init__.py b/tests/helpers/data_creator/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/models/creation_info/__init__.py b/tests/helpers/data_creator/models/creation_info/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/models/creation_info/batch/__init__.py b/tests/helpers/data_creator/models/creation_info/batch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/models/creation_info/batch/v1.py b/tests/helpers/data_creator/models/creation_info/batch/v1.py new file mode 100644 index 00000000..d5451eca --- /dev/null +++ b/tests/helpers/data_creator/models/creation_info/batch/v1.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class BatchURLCreationInfo(BaseModel): + batch_id: int + url_ids: list[int] + urls: list[str] diff --git a/tests/helpers/data_creator/models/creation_info/batch/v2.py b/tests/helpers/data_creator/models/creation_info/batch/v2.py new file mode 100644 index 00000000..3e6ed74a --- /dev/null +++ b/tests/helpers/data_creator/models/creation_info/batch/v2.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo + + +class BatchURLCreationInfoV2(BaseModel): + batch_id: int + urls_by_status: dict[URLStatus, URLCreationInfo] = {} + + @property + def url_ids(self) -> list[int]: + url_creation_infos = self.urls_by_status.values() + url_ids = [] + for url_creation_info in url_creation_infos: + url_ids.extend(url_creation_info.url_ids) + return url_ids diff --git a/tests/helpers/data_creator/models/creation_info/url.py b/tests/helpers/data_creator/models/creation_info/url.py new file mode 100644 index 00000000..082769e7 --- /dev/null +++ b/tests/helpers/data_creator/models/creation_info/url.py @@ -0,0 +1,17 @@ +from typing import Optional + +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from src.db.dtos.url.mapping import URLMapping +from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo + + +class URLCreationInfo(BaseModel): + url_mappings: list[URLMapping] + outcome: URLStatus + annotation_info: Optional[AnnotationInfo] = None + + @property + def url_ids(self) -> list[int]: + return [url_mapping.url_id for url_mapping in self.url_mappings] diff --git a/tests/helpers/setup/annotate_agency/core.py b/tests/helpers/setup/annotate_agency/core.py index fbd7bc53..6827194d 100644 --- a/tests/helpers/setup/annotate_agency/core.py +++ b/tests/helpers/setup/annotate_agency/core.py @@ -1,5 +1,6 @@ from src.core.enums import SuggestionType -from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfo +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo from tests.helpers.setup.annotate_agency.model import AnnotateAgencySetupInfo diff --git a/tests/helpers/setup/annotation/core.py b/tests/helpers/setup/annotation/core.py index d8d3bb0c..ff5105cd 100644 --- a/tests/helpers/setup/annotation/core.py +++ b/tests/helpers/setup/annotation/core.py @@ -1,5 +1,5 @@ from src.collectors.enums import URLStatus -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.annotation.model import AnnotationSetupInfo diff --git a/tests/helpers/setup/final_review/core.py b/tests/helpers/setup/final_review/core.py index 87c4da59..d9c3aa10 100644 --- a/tests/helpers/setup/final_review/core.py +++ b/tests/helpers/setup/final_review/core.py @@ -2,7 +2,7 @@ from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo from src.core.enums import RecordType -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.final_review.model import FinalReviewSetupInfo diff --git a/tests/manual/html_collector/test_html_tag_collector_integration.py b/tests/manual/html_collector/test_html_tag_collector_integration.py index bc48da9f..ef8f0df3 100644 --- a/tests/manual/html_collector/test_html_tag_collector_integration.py +++ b/tests/manual/html_collector/test_html_tag_collector_integration.py @@ -6,7 +6,7 @@ from src.core.tasks.url.operators.url_html.scraper.root_url_cache.core import RootURLCache from src.db.client.async_ import AsyncDatabaseClient from src.db.models.instantiations.url.core.pydantic import URLInfo -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator URLS = [ "https://pdap.io",