diff --git a/Dockerfile b/Dockerfile index 85931528..e96272b0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,6 +14,8 @@ RUN uv sync --locked --no-dev # Must call from the root directory because uv does not add playwright to path RUN playwright install-deps chromium RUN playwright install chromium +# Download Spacy Model +RUN python -m spacy download en_core_web_sm # Copy project files COPY src ./src diff --git a/ENV.md b/ENV.md index a2e84f24..b957bc11 100644 --- a/ENV.md +++ b/ENV.md @@ -2,28 +2,119 @@ This page provides a full list, with description, of all the environment variabl Please ensure these are properly defined in a `.env` file in the root directory. -| Name | Description | Example | -|----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------------------------------| -| `GOOGLE_API_KEY` | The API key required for accessing the Google Custom Search API | `abc123` | -| `GOOGLE_CSE_ID` | The CSE ID required for accessing the Google Custom Search API | `abc123` | -|`POSTGRES_USER` | The username for the test database | `test_source_collector_user` | -|`POSTGRES_PASSWORD` | The password for the test database | `HanviliciousHamiltonHilltops` | -|`POSTGRES_DB` | The database name for the test database | `source_collector_test_db` | -|`POSTGRES_HOST` | The host for the test database | `127.0.0.1` | -|`POSTGRES_PORT` | The port for the test database | `5432` | -|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` | -|`DEV`| Set to any value to run the application in development mode. | `true` | -|`DEEPSEEK_API_KEY`| The API key required for accessing the DeepSeek API. | `abc123` | -|`OPENAI_API_KEY`| The API key required for accessing the OpenAI API. | `abc123` | -|`PDAP_EMAIL`| An email address for accessing the PDAP API.[^1] | `abc123@test.com` | -|`PDAP_PASSWORD`| A password for accessing the PDAP API.[^1] | `abc123` | -|`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` | -|`PDAP_API_URL`| The URL for the PDAP API| `https://data-sources-v2.pdap.dev/api`| -|`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications| `abc123` | -|`HUGGINGFACE_INFERENCE_API_KEY` | The API key required for accessing the Huggingface Inference API. 
| `abc123` |
+| Name | Description | Example |
+|------|-------------|---------|
+| `GOOGLE_API_KEY` | The API key required for accessing the Google Custom Search API | `abc123` |
+| `GOOGLE_CSE_ID` | The CSE ID required for accessing the Google Custom Search API | `abc123` |
+| `POSTGRES_USER` | The username for the test database | `test_source_collector_user` |
+| `POSTGRES_PASSWORD` | The password for the test database | `HanviliciousHamiltonHilltops` |
+| `POSTGRES_DB` | The database name for the test database | `source_collector_test_db` |
+| `POSTGRES_HOST` | The host for the test database | `127.0.0.1` |
+| `POSTGRES_PORT` | The port for the test database | `5432` |
+| `DS_APP_SECRET_KEY` | The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` |
+| `DEV` | Set to any value to run the application in development mode. | `true` |
+| `DEEPSEEK_API_KEY` | The API key required for accessing the DeepSeek API. | `abc123` |
+| `OPENAI_API_KEY` | The API key required for accessing the OpenAI API. | `abc123` |
+| `PDAP_EMAIL` | An email address for accessing the PDAP API.[^1] | `abc123@test.com` |
+| `PDAP_PASSWORD` | A password for accessing the PDAP API.[^1] | `abc123` |
+| `PDAP_API_KEY` | An API key for accessing the PDAP API. | `abc123` |
+| `PDAP_API_URL` | The URL for the PDAP API. | `https://data-sources-v2.pdap.dev/api` |
+| `DISCORD_WEBHOOK_URL` | The URL for the Discord webhook used for notifications. | `abc123` |
+| `HUGGINGFACE_INFERENCE_API_KEY` | The API key required for accessing the Hugging Face Inference API. | `abc123` |
+| `HUGGINGFACE_HUB_TOKEN` | The API key required for uploading to the PDAP Hugging Face account via the Hugging Face Hub API. | `abc123` |
+| `INTERNET_ARCHIVE_S3_KEYS` | Keys used for saving a URL to the Internet Archive. | `abc123:gpb0dk` |
+
+[^1]: The user account in question will require elevated permissions to access certain endpoints. At a minimum, the user will require the `source_collector` and `db_write` permissions.
+
+# Variables With Defaults
+
+The following environment variables have default values that will be used if not otherwise defined.
+
+| Variable | Description | Default |
+|----------|-------------|---------|
+| `URL_TASKS_FREQUENCY_MINUTES` | The frequency of the `RUN_URL_TASKS` Scheduled Task, in minutes. | `60` |
+
+# Flags
+
+Flags are used to enable/disable certain features. Set a flag to `1` to enable the feature and `0` to disable it. By default, all flags are enabled.
+
+## Configuration Flags
+
+Configuration flags are used to enable/disable certain configurations.
+
+| Flag | Description |
+|------|-------------|
+| `POST_TO_DISCORD_FLAG` | Enables posting errors to Discord. |
+| `PROGRESS_BAR_FLAG` | Enables progress bars on some tasks. |
+
+
+## Task Flags
+Task flags are used to enable/disable certain tasks.
+
+Note that some tasks/subtasks are themselves enabled by other tasks.
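+
+For example, a `.env` excerpt that keeps scheduled tasks running but switches off a couple of individual tasks might look like the following (flag values here are illustrative only; the flags themselves are documented in the tables below):
+
+```
+SCHEDULED_TASKS_FLAG=1
+IA_SAVE_TASK_FLAG=0
+URL_SCREENSHOT_TASK_FLAG=0
+```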
+
+### Scheduled Task Flags
+
+| Flag | Description |
+|------|-------------|
+| `SCHEDULED_TASKS_FLAG` | All scheduled tasks. Disabling this flag disables all other scheduled tasks. |
+| `PUSH_TO_HUGGING_FACE_TASK_FLAG` | Pushes data to Hugging Face. |
+| `POPULATE_BACKLOG_SNAPSHOT_TASK_FLAG` | Populates the backlog snapshot. |
+| `DELETE_OLD_LOGS_TASK_FLAG` | Deletes old logs. |
+| `RUN_URL_TASKS_TASK_FLAG` | Runs URL tasks. |
+| `IA_PROBE_TASK_FLAG` | Extracts and links Internet Archives metadata to URLs. |
+| `IA_SAVE_TASK_FLAG` | Saves URLs to the Internet Archives. |
+| `MARK_TASK_NEVER_COMPLETED_TASK_FLAG` | Marks tasks that were started but never completed (usually due to a restart). |
+| `DELETE_STALE_SCREENSHOTS_TASK_FLAG` | Deletes stale screenshots for URLs already validated. |
+| `TASK_CLEANUP_TASK_FLAG` | Cleans up tasks that are no longer needed. |
+| `REFRESH_MATERIALIZED_VIEWS_TASK_FLAG` | Refreshes materialized views. |
+
+### URL Task Flags
+
+URL Task Flags are collectively controlled by the `RUN_URL_TASKS_TASK_FLAG` flag.
+
+
+| Flag | Description |
+|------|-------------|
+| `URL_HTML_TASK_FLAG` | URL HTML scraping task. |
+| `URL_RECORD_TYPE_TASK_FLAG` | Automatically assigns Record Types to URLs. |
+| `URL_AGENCY_IDENTIFICATION_TASK_FLAG` | Automatically assigns and suggests Agencies for URLs. |
+| `URL_SUBMIT_APPROVED_TASK_FLAG` | Submits approved URLs to the Data Sources App. |
+| `URL_MISC_METADATA_TASK_FLAG` | Adds misc metadata to URLs. |
+| `URL_AUTO_RELEVANCE_TASK_FLAG` | Automatically assigns Relevances to URLs. |
+| `URL_PROBE_TASK_FLAG` | Probes URLs for web metadata. |
+| `URL_ROOT_URL_TASK_FLAG` | Extracts and links Root URLs to URLs. |
+| `URL_SCREENSHOT_TASK_FLAG` | Takes screenshots of URLs. |
+| `URL_AUTO_VALIDATE_TASK_FLAG` | Automatically validates URLs. |
+| `URL_AUTO_NAME_TASK_FLAG` | Automatically names URLs. |
+| `URL_SUSPEND_TASK_FLAG` | Suspends URLs meeting suspension criteria. |
+| `URL_SUBMIT_META_URLS_TASK_FLAG` | Submits meta URLs to the Data Sources App. |
+
+### Agency ID Subtasks
+
+Agency ID Subtasks are collectively disabled by disabling the `URL_AGENCY_IDENTIFICATION_TASK_FLAG` flag.
+
+| Flag | Description |
+|------|-------------|
+| `AGENCY_ID_HOMEPAGE_MATCH_FLAG` | Enables the homepage match subtask for agency identification. |
+| `AGENCY_ID_NLP_LOCATION_MATCH_FLAG` | Enables the NLP location match subtask for agency identification. |
+| `AGENCY_ID_CKAN_FLAG` | Enables the CKAN subtask for agency identification. |
+| `AGENCY_ID_MUCKROCK_FLAG` | Enables the MuckRock subtask for agency identification. |
+| `AGENCY_ID_BATCH_LINK_FLAG` | Enables the Batch Link subtask for agency identification. |
+
+
+### Location ID Subtasks
+
+Location ID Subtasks are collectively disabled by disabling the `URL_LOCATION_IDENTIFICATION_TASK_FLAG` flag.
+
+| Flag | Description |
+|------|-------------|
+| `LOCATION_ID_NLP_LOCATION_MATCH_FLAG` | Enables the NLP location match subtask for location identification. |
+| `LOCATION_ID_BATCH_LINK_FLAG` | Enables the Batch Link subtask for location identification.
| + + ## Foreign Data Wrapper (FDW) ``` FDW_DATA_SOURCES_HOST=127.0.0.1 # The host of the Data Sources Database, used for FDW setup diff --git a/alembic/env.py b/alembic/env.py index 3d305e32..ff14698b 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,4 +1,3 @@ -import logging from datetime import datetime from logging.config import fileConfig @@ -6,8 +5,8 @@ from sqlalchemy import engine_from_config from sqlalchemy import pool -from src.db.helpers import get_postgres_connection_string -from src.db.models.templates import Base +from src.db.helpers.connect import get_postgres_connection_string +from src.db.models.templates_.base import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/2025_07_21_0637-59d2af1bab33_setup_for_sync_data_sources_task.py b/alembic/versions/2025_07_21_0637-59d2af1bab33_setup_for_sync_data_sources_task.py new file mode 100644 index 00000000..9e990bc1 --- /dev/null +++ b/alembic/versions/2025_07_21_0637-59d2af1bab33_setup_for_sync_data_sources_task.py @@ -0,0 +1,285 @@ +"""Setup for sync data sources task + +Revision ID: 59d2af1bab33 +Revises: 9552d354ccf4 +Create Date: 2025-07-21 06:37:51.043504 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +from src.util.alembic_helpers import switch_enum_type, id_column + +# revision identifiers, used by Alembic. +revision: str = '59d2af1bab33' +down_revision: Union[str, None] = '9552d354ccf4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +SYNC_STATE_TABLE_NAME = "data_sources_sync_state" +URL_DATA_SOURCES_METADATA_TABLE_NAME = "url_data_sources_metadata" + +CONFIRMED_AGENCY_TABLE_NAME = "confirmed_url_agency" +LINK_URLS_AGENCIES_TABLE_NAME = "link_urls_agencies" +CHANGE_LOG_TABLE_NAME = "change_log" + +AGENCIES_TABLE_NAME = "agencies" + +TABLES_TO_LOG = [ + LINK_URLS_AGENCIES_TABLE_NAME, + "urls", + "url_data_sources", + "agencies", +] + +OperationTypeEnum = sa.Enum("UPDATE", "DELETE", "INSERT", name="operation_type") + + +def upgrade() -> None: + _create_data_sources_sync_state_table() + _create_data_sources_sync_task() + + _rename_confirmed_url_agency_to_link_urls_agencies() + _create_change_log_table() + _add_jsonb_diff_val_function() + _create_log_table_changes_trigger() + + + _add_table_change_log_triggers() + _add_agency_id_column() + + + +def downgrade() -> None: + _drop_data_sources_sync_task() + _drop_data_sources_sync_state_table() + _drop_change_log_table() + _drop_table_change_log_triggers() + _drop_jsonb_diff_val_function() + _drop_log_table_changes_trigger() + + _rename_link_urls_agencies_to_confirmed_url_agency() + + OperationTypeEnum.drop(op.get_bind()) + _drop_agency_id_column() + + + +def _add_jsonb_diff_val_function() -> None: + op.execute( + """ + CREATE OR REPLACE FUNCTION jsonb_diff_val(val1 JSONB, val2 JSONB) + RETURNS JSONB AS + $$ + DECLARE + result JSONB; + v RECORD; + BEGIN + result = val1; + FOR v IN SELECT * FROM jsonb_each(val2) + LOOP + IF result @> jsonb_build_object(v.key, v.value) + THEN + result = result - v.key; + ELSIF result ? 
v.key THEN
+                CONTINUE;
+            ELSE
+                result = result || jsonb_build_object(v.key, 'null');
+            END IF;
+        END LOOP;
+        RETURN result;
+    END;
+    $$ LANGUAGE plpgsql;
+    """
+    )
+
+def _drop_jsonb_diff_val_function() -> None:
+    op.execute("DROP FUNCTION IF EXISTS jsonb_diff_val(val1 JSONB, val2 JSONB)")
+
+def _create_log_table_changes_trigger() -> None:
+    op.execute(
+        f"""
+        CREATE OR REPLACE FUNCTION public.log_table_changes()
+            RETURNS trigger
+            LANGUAGE 'plpgsql'
+            COST 100
+            VOLATILE NOT LEAKPROOF
+        AS $BODY$
+        DECLARE
+            old_values JSONB;
+            new_values JSONB;
+            old_to_new JSONB;
+            new_to_old JSONB;
+        BEGIN
+            -- Handle DELETE operations (store entire OLD row since all data is lost)
+            IF (TG_OP = 'DELETE') THEN
+                old_values = row_to_json(OLD)::jsonb;
+
+                INSERT INTO {CHANGE_LOG_TABLE_NAME} (operation_type, table_name, affected_id, old_data)
+                VALUES ('DELETE', TG_TABLE_NAME, OLD.id, old_values);
+
+                RETURN OLD;
+
+            -- Handle UPDATE operations (only log the changed columns)
+            ELSIF (TG_OP = 'UPDATE') THEN
+                old_values = row_to_json(OLD)::jsonb;
+                new_values = row_to_json(NEW)::jsonb;
+                new_to_old = jsonb_diff_val(old_values, new_values);
+                old_to_new = jsonb_diff_val(new_values, old_values);
+
+                -- Skip logging if both old_to_new and new_to_old are NULL or empty JSON objects
+                IF (new_to_old IS NOT NULL AND new_to_old <> '{{}}') OR
+                   (old_to_new IS NOT NULL AND old_to_new <> '{{}}') THEN
+                    INSERT INTO {CHANGE_LOG_TABLE_NAME} (operation_type, table_name, affected_id, old_data, new_data)
+                    VALUES ('UPDATE', TG_TABLE_NAME, OLD.id, new_to_old, old_to_new);
+                END IF;
+
+                RETURN NEW;
+
+            -- Handle INSERT operations
+            ELSIF (TG_OP = 'INSERT') THEN
+                new_values = row_to_json(NEW)::jsonb;
+
+                -- Skip logging if new_values is NULL or an empty JSON object
+                IF new_values IS NOT NULL AND new_values <> '{{}}' THEN
+                    INSERT INTO {CHANGE_LOG_TABLE_NAME} (operation_type, table_name, affected_id, new_data)
+                    VALUES ('INSERT', TG_TABLE_NAME, NEW.id, new_values);
+                END IF;
+
+                RETURN NEW;
+            END IF;
+        END;
+        $BODY$;
+        """
+    )
+
+def _drop_log_table_changes_trigger() -> None:
+    # Drop the trigger function created in _create_log_table_changes_trigger();
+    # the per-table triggers that call it are dropped in _drop_table_change_log_triggers().
+    op.execute("DROP FUNCTION IF EXISTS public.log_table_changes()")
+
+def _create_data_sources_sync_state_table() -> None:
+    table = op.create_table(
+        SYNC_STATE_TABLE_NAME,
+        id_column(),
+        sa.Column('last_full_sync_at', sa.DateTime(), nullable=True),
+        sa.Column('current_cutoff_date', sa.Date(), nullable=True),
+        sa.Column('current_page', sa.Integer(), nullable=True),
+    )
+    # Add row to `data_sources_sync_state` table
+    op.bulk_insert(
+        table,
+        [
+            {
+                "last_full_sync_at": None,
+                "current_cutoff_date": None,
+                "current_page": None
+            }
+        ]
+    )
+
+def _drop_data_sources_sync_state_table() -> None:
+    op.drop_table(SYNC_STATE_TABLE_NAME)
+
+def _create_data_sources_sync_task() -> None:
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            'HTML',
+            'Relevancy',
+            'Record Type',
+            'Agency Identification',
+            'Misc Metadata',
+            'Submit Approved URLs',
+            'Duplicate Detection',
+            '404 Probe',
+            'Sync Agencies',
+            'Sync Data Sources'
+        ]
+    )
+
+def _drop_data_sources_sync_task() -> None:
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            'HTML',
+            'Relevancy',
+            'Record Type',
+            'Agency Identification',
+            'Misc Metadata',
+            'Submit Approved URLs',
+            'Duplicate Detection',
+            '404 Probe',
+            'Sync Agencies',
+        ]
+    )
+
+def _create_change_log_table() -> None:
+    # Create change_log table
+
op.create_table( + CHANGE_LOG_TABLE_NAME, + id_column(), + sa.Column("operation_type", OperationTypeEnum, nullable=False), + sa.Column("table_name", sa.String(), nullable=False), + sa.Column("affected_id", sa.Integer(), nullable=False), + sa.Column("old_data", JSONB, nullable=True), + sa.Column("new_data", JSONB, nullable=True), + sa.Column( + "created_at", sa.DateTime(), server_default=sa.func.now(), nullable=False + ), + ) + +def _drop_change_log_table() -> None: + op.drop_table(CHANGE_LOG_TABLE_NAME) + +def _rename_confirmed_url_agency_to_link_urls_agencies() -> None: + op.rename_table(CONFIRMED_AGENCY_TABLE_NAME, LINK_URLS_AGENCIES_TABLE_NAME) + +def _rename_link_urls_agencies_to_confirmed_url_agency() -> None: + op.rename_table(LINK_URLS_AGENCIES_TABLE_NAME, CONFIRMED_AGENCY_TABLE_NAME) + +def _add_table_change_log_triggers() -> None: + # Create trigger for tables: + def create_table_trigger(table_name: str) -> None: + op.execute( + """ + CREATE OR REPLACE TRIGGER log_{table_name}_changes + BEFORE INSERT OR DELETE OR UPDATE + ON public.{table_name} + FOR EACH ROW + EXECUTE FUNCTION public.log_table_changes(); + """.format(table_name=table_name) + ) + + for table_name in TABLES_TO_LOG: + create_table_trigger(table_name) + +def _drop_table_change_log_triggers() -> None: + def drop_table_trigger(table_name: str) -> None: + op.execute( + f""" + DROP TRIGGER log_{table_name}_changes + ON public.{table_name} + """ + ) + + for table_name in TABLES_TO_LOG: + drop_table_trigger(table_name) + +def _add_agency_id_column(): + op.add_column( + AGENCIES_TABLE_NAME, + id_column(), + ) + + +def _drop_agency_id_column(): + op.drop_column( + AGENCIES_TABLE_NAME, + 'id', + ) diff --git a/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py b/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py new file mode 100644 index 00000000..45cf66a0 --- /dev/null +++ b/alembic/versions/2025_07_26_0830-637de6eaa3ab_setup_for_upload_to_huggingface_task.py @@ -0,0 +1,74 @@ +"""Setup for upload to huggingface task + +Revision ID: 637de6eaa3ab +Revises: 59d2af1bab33 +Create Date: 2025-07-26 08:30:37.940091 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, switch_enum_type + +# revision identifiers, used by Alembic. 
+revision: str = '637de6eaa3ab' +down_revision: Union[str, None] = '59d2af1bab33' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +TABLE_NAME = "huggingface_upload_state" + + +def upgrade() -> None: + op.create_table( + TABLE_NAME, + id_column(), + sa.Column( + "last_upload_at", + sa.DateTime(), + nullable=False + ), + ) + + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face' + ] + ) + + +def downgrade() -> None: + op.drop_table(TABLE_NAME) + + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources' + ] + ) diff --git a/alembic/versions/2025_07_31_1536-99eceed6e614_add_web_status_info_table.py b/alembic/versions/2025_07_31_1536-99eceed6e614_add_web_status_info_table.py new file mode 100644 index 00000000..891bef3a --- /dev/null +++ b/alembic/versions/2025_07_31_1536-99eceed6e614_add_web_status_info_table.py @@ -0,0 +1,156 @@ +"""Add HTML Status Info table + +Revision ID: 99eceed6e614 +Revises: 637de6eaa3ab +Create Date: 2025-07-31 15:36:40.966605 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, created_at_column, updated_at_column, url_id_column, switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = '99eceed6e614' +down_revision: Union[str, None] = '637de6eaa3ab' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +WEB_STATUS_ENUM = sa.Enum( + "not_attempted", + "success", + "error", + "404_not_found", + name="web_status" +) +SCRAPE_STATUS_ENUM = sa.Enum( + "success", + "error", + name="scrape_status", +) + +URL_WEB_METADATA_TABLE_NAME = 'url_web_metadata' +URL_SCRAPE_INFO = 'url_scrape_info' + + + + + +def upgrade() -> None: + _create_url_html_info_table() + _add_url_probe_task_type_enum() + _set_up_scrape_info_table() + _use_existing_html_data_to_add_scrape_info() + +def _use_existing_html_data_to_add_scrape_info(): + op.execute( + f""" + INSERT INTO {URL_SCRAPE_INFO} (url_id, status) + SELECT url_id, 'success'::scrape_status + FROM url_compressed_html + """ + ) + op.execute( + f""" + INSERT INTO {URL_SCRAPE_INFO} (url_id, status) + SELECT distinct(url_id), 'success'::scrape_status + FROM url_html_content + LEFT JOIN URL_COMPRESSED_HTML USING (url_id) + WHERE URL_COMPRESSED_HTML.url_id IS NULL + """ + ) + +def downgrade() -> None: + _drop_scrape_info_table() + # Drop Enums + WEB_STATUS_ENUM.drop(op.get_bind(), checkfirst=True) + _drop_url_probe_task_type_enum() + _tear_down_scrape_info_table() + + +def _set_up_scrape_info_table(): + op.create_table( + URL_SCRAPE_INFO, + id_column(), + url_id_column(), + sa.Column( + 'status', + SCRAPE_STATUS_ENUM, + nullable=False, + comment='The status of the most recent scrape attempt.' 
+ ), + created_at_column(), + updated_at_column(), + sa.UniqueConstraint('url_id', name='uq_url_scrape_info_url_id') + ) + + + + +def _tear_down_scrape_info_table(): + op.drop_table(URL_SCRAPE_INFO) + # Drop enum + SCRAPE_STATUS_ENUM.drop(op.get_bind(), checkfirst=True) + + +def _add_url_probe_task_type_enum() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe' + ] + ) + +def _drop_url_probe_task_type_enum() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face' + ] + ) + +def _create_url_html_info_table() -> None: + op.create_table( + URL_WEB_METADATA_TABLE_NAME, + id_column(), + url_id_column(), + sa.Column('accessed', sa.Boolean(), nullable=False), + sa.Column('status_code', sa.Integer(), nullable=True), + sa.Column('content_type', sa.Text(), nullable=True), + sa.Column('error_message', sa.Text(), nullable=True), + created_at_column(), + updated_at_column(), + sa.UniqueConstraint('url_id', name='uq_url_web_status_info_url_id'), + sa.CheckConstraint('status_code >= 100', name='ck_url_web_status_info_status_code_min'), + sa.CheckConstraint('status_code <= 999', name='ck_url_web_status_info_status_code_max'), + ) + +def _drop_scrape_info_table() -> None: + op.drop_table(URL_WEB_METADATA_TABLE_NAME) diff --git a/alembic/versions/2025_08_03_1800-571ada5b81b9_add_link_urls_redirect_url_table.py b/alembic/versions/2025_08_03_1800-571ada5b81b9_add_link_urls_redirect_url_table.py new file mode 100644 index 00000000..33c2a8c6 --- /dev/null +++ b/alembic/versions/2025_08_03_1800-571ada5b81b9_add_link_urls_redirect_url_table.py @@ -0,0 +1,110 @@ +"""Add link_urls_redirect_url table + +Revision ID: 571ada5b81b9 +Revises: 99eceed6e614 +Create Date: 2025-08-03 18:00:06.345733 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, created_at_column, updated_at_column + +# revision identifiers, used by Alembic. 
+revision: str = '571ada5b81b9' +down_revision: Union[str, None] = '99eceed6e614' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +URLS_TABLE = 'urls' +LINK_URLS_REDIRECT_URL_TABLE = 'link_urls_redirect_url' + +SOURCE_ENUM = sa.Enum( + 'collector', + 'data_sources_app', + 'redirect', + 'root_url', + 'manual', + name='url_source' +) + +def upgrade() -> None: + _create_link_urls_redirect_url_table() + _add_source_column_to_urls_table() + + + +def downgrade() -> None: + _drop_link_urls_redirect_url_table() + _drop_source_column_from_urls_table() + + +def _create_link_urls_redirect_url_table(): + op.create_table( + LINK_URLS_REDIRECT_URL_TABLE, + id_column(), + sa.Column('source_url_id', sa.Integer(), nullable=False), + sa.Column('destination_url_id', sa.Integer(), nullable=False), + created_at_column(), + updated_at_column(), + sa.ForeignKeyConstraint(['source_url_id'], [URLS_TABLE + '.id'], ), + sa.ForeignKeyConstraint(['destination_url_id'], [URLS_TABLE + '.id'], ), + sa.UniqueConstraint( + 'source_url_id', + 'destination_url_id', + name='link_urls_redirect_url_uq_source_url_id_destination_url_id' + ), + ) + + +def _add_source_column_to_urls_table(): + # Create enum + SOURCE_ENUM.create(op.get_bind(), checkfirst=True) + op.add_column( + URLS_TABLE, + sa.Column( + 'source', + SOURCE_ENUM, + nullable=True, + comment='The source of the URL.' + ) + ) + # Add sources to existing URLs + op.execute( + f"""UPDATE {URLS_TABLE} + SET source = 'collector'::url_source + """ + ) + op.execute( + f"""UPDATE {URLS_TABLE} + SET source = 'data_sources_app'::url_source + FROM url_data_sources WHERE url_data_sources.url_id = {URLS_TABLE}.id + AND url_data_sources.data_source_id IS NOT NULL; + """ + ) + op.execute( + f"""UPDATE {URLS_TABLE} + SET source = 'collector'::url_source + FROM link_batch_urls WHERE link_batch_urls.url_id = {URLS_TABLE}.id + AND link_batch_urls.batch_id IS NOT NULL; + """ + ) + + # Make source required + op.alter_column( + URLS_TABLE, + 'source', + nullable=False + ) + + +def _drop_link_urls_redirect_url_table(): + op.drop_table(LINK_URLS_REDIRECT_URL_TABLE) + + +def _drop_source_column_from_urls_table(): + op.drop_column(URLS_TABLE, 'source') + # Drop enum + SOURCE_ENUM.drop(op.get_bind(), checkfirst=True) diff --git a/alembic/versions/2025_08_09_2031-8cd5aa7670ff_remove_functional_duplicates.py b/alembic/versions/2025_08_09_2031-8cd5aa7670ff_remove_functional_duplicates.py new file mode 100644 index 00000000..201d2448 --- /dev/null +++ b/alembic/versions/2025_08_09_2031-8cd5aa7670ff_remove_functional_duplicates.py @@ -0,0 +1,124 @@ +"""Remove functional duplicates and setup constraints on fragments and nbsp + +Revision ID: 8cd5aa7670ff +Revises: 571ada5b81b9 +Create Date: 2025-08-09 20:31:58.865231 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = '8cd5aa7670ff' +down_revision: Union[str, None] = '571ada5b81b9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +COMPRESSED_HTML_FOREIGN_KEY_NAME = 'fk_url_compressed_html_url_id' +COMPRESSED_HTML_TABLE_NAME = 'url_compressed_html' + +URL_HTML_CONTENT_FOREIGN_KEY_NAME = 'url_html_content_url_id_fkey' +URL_HTML_CONTENT_TABLE_NAME = 'url_html_content' + +URL_ERROR_INFO_TABLE_NAME = 'url_error_info' +URL_ERROR_INFO_FOREIGN_KEY_NAME = 'url_error_info_url_id_fkey' + +URLS_NBSP_CHECK_CONSTRAINT_NAME = 'urls_nbsp_check' +URLS_FRAGMENTS_CHECK_CONSTRAINT_NAME = 'urls_fragments_check' + +AUTOMATED_URL_AGENCY_SUGGESTION_TABLE_NAME = 'automated_url_agency_suggestions' +AUTOMATED_URL_AGENCY_SUGGESTION_FOREIGN_KEY_NAME = 'automated_url_agency_suggestions_url_id_fkey' + + +def upgrade() -> None: + _add_cascade_foreign_key(URL_HTML_CONTENT_TABLE_NAME, foreign_key_name=URL_HTML_CONTENT_FOREIGN_KEY_NAME) + _add_cascade_foreign_key(COMPRESSED_HTML_TABLE_NAME, foreign_key_name=COMPRESSED_HTML_FOREIGN_KEY_NAME) + _add_cascade_foreign_key(URL_ERROR_INFO_TABLE_NAME, foreign_key_name=URL_ERROR_INFO_FOREIGN_KEY_NAME) + _add_cascade_foreign_key(AUTOMATED_URL_AGENCY_SUGGESTION_TABLE_NAME, foreign_key_name=AUTOMATED_URL_AGENCY_SUGGESTION_FOREIGN_KEY_NAME) + _remove_data_source_urls() + _reset_data_sources_sync_state() + _add_constraint_forbidding_nbsp() + _delete_duplicate_urls() + _remove_fragments_from_urls() + _add_constraint_forbidding_fragments() + + +def downgrade() -> None: + _remove_constraint_forbidding_fragments() + _remove_constraint_forbidding_nbsp() + _remove_cascade_foreign_key(URL_ERROR_INFO_TABLE_NAME, foreign_key_name=URL_ERROR_INFO_FOREIGN_KEY_NAME) + _remove_cascade_foreign_key(COMPRESSED_HTML_TABLE_NAME, foreign_key_name=COMPRESSED_HTML_FOREIGN_KEY_NAME) + _remove_cascade_foreign_key(URL_HTML_CONTENT_TABLE_NAME, foreign_key_name=URL_HTML_CONTENT_FOREIGN_KEY_NAME) + # _remove_cascade_foreign_key(AUTOMATED_URL_AGENCY_SUGGESTION_TABLE_NAME, foreign_key_name=AUTOMATED_URL_AGENCY_SUGGESTION_FOREIGN_KEY_NAME) + +def _delete_duplicate_urls() -> None: + op.execute('delete from urls where id in (2341,2343,2344,2347,2348,2349,2354,2359,2361,2501,2504,2505,2506,2507)') + +def _create_url_foreign_key_with_cascade(table_name: str, foreign_key_name: str) -> None: + op.create_foreign_key( + foreign_key_name, + table_name, + referent_table='urls', + local_cols=['url_id'], remote_cols=['id'], + ondelete='CASCADE' + ) + +def _create_url_foreign_key_without_cascade(table_name: str, foreign_key_name: str) -> None: + op.create_foreign_key( + foreign_key_name, + table_name, + referent_table='urls', + local_cols=['url_id'], remote_cols=['id'] + ) + +def _remove_cascade_foreign_key(table_name: str, foreign_key_name: str) -> None: + op.drop_constraint(foreign_key_name, table_name=table_name, type_='foreignkey') + _create_url_foreign_key_without_cascade(table_name, foreign_key_name=foreign_key_name) + +def _add_cascade_foreign_key(table_name: str, foreign_key_name: str) -> None: + op.drop_constraint(foreign_key_name, table_name=table_name, type_='foreignkey') + _create_url_foreign_key_with_cascade(table_name, foreign_key_name=foreign_key_name) + +def _remove_data_source_urls() -> None: + op.execute(""" + delete from urls + where source = 'data_sources_app' + """ + ) + +def _reset_data_sources_sync_state() -> None: + op.execute(""" + delete from data_sources_sync_state + """ + ) + +def _add_constraint_forbidding_nbsp() -> None: + 
op.create_check_constraint(
+        constraint_name=URLS_NBSP_CHECK_CONSTRAINT_NAME,
+        table_name='urls',
+        # Forbid literal non-breaking spaces (U+00A0, written as an escape for clarity)
+        condition="url not like '%\u00a0%'"
+    )
+
+def _add_constraint_forbidding_fragments() -> None:
+    op.create_check_constraint(
+        constraint_name=URLS_FRAGMENTS_CHECK_CONSTRAINT_NAME,
+        table_name='urls',
+        condition="url not like '%#%'"
+    )
+
+def _remove_constraint_forbidding_nbsp() -> None:
+    op.drop_constraint(URLS_NBSP_CHECK_CONSTRAINT_NAME, table_name='urls', type_='check')
+
+def _remove_constraint_forbidding_fragments() -> None:
+    op.drop_constraint(URLS_FRAGMENTS_CHECK_CONSTRAINT_NAME, table_name='urls', type_='check')
+
+def _remove_fragments_from_urls() -> None:
+    # Remove fragments and everything after them
+    op.execute("""
+    update urls
+    set url = substring(url from 1 for position('#' in url) - 1)
+    where url like '%#%'
+    """)
\ No newline at end of file
diff --git a/alembic/versions/2025_08_10_1032-11ece61d7ac2_add_scheduled_tasks.py b/alembic/versions/2025_08_10_1032-11ece61d7ac2_add_scheduled_tasks.py
new file mode 100644
index 00000000..97fbd655
--- /dev/null
+++ b/alembic/versions/2025_08_10_1032-11ece61d7ac2_add_scheduled_tasks.py
@@ -0,0 +1,63 @@
+"""Add scheduled tasks
+
+Revision ID: 11ece61d7ac2
+Revises: 8cd5aa7670ff
+Create Date: 2025-08-10 10:32:11.400714
+
+"""
+from typing import Sequence, Union
+
+from src.util.alembic_helpers import switch_enum_type
+
+# revision identifiers, used by Alembic.
+revision: str = '11ece61d7ac2'
+down_revision: Union[str, None] = '8cd5aa7670ff'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            'HTML',
+            'Relevancy',
+            'Record Type',
+            'Agency Identification',
+            'Misc Metadata',
+            'Submit Approved URLs',
+            'Duplicate Detection',
+            '404 Probe',
+            'Sync Agencies',
+            'Sync Data Sources',
+            'Push to Hugging Face',
+            'URL Probe',
+            'Populate Backlog Snapshot',
+            'Delete Old Logs',
+            'Run URL Task Cycles'
+        ]
+    )
+
+
+def downgrade() -> None:
+    switch_enum_type(
+        table_name='tasks',
+        column_name='task_type',
+        enum_name='task_type',
+        new_enum_values=[
+            'HTML',
+            'Relevancy',
+            'Record Type',
+            'Agency Identification',
+            'Misc Metadata',
+            'Submit Approved URLs',
+            'Duplicate Detection',
+            '404 Probe',
+            'Sync Agencies',
+            'Sync Data Sources',
+            'Push to Hugging Face',
+            'URL Probe'
+        ]
+    )
diff --git a/alembic/versions/2025_08_10_2046-5930e70660c5_change_url_outcome_to_url_status.py b/alembic/versions/2025_08_10_2046-5930e70660c5_change_url_outcome_to_url_status.py
new file mode 100644
index 00000000..c24d5ac8
--- /dev/null
+++ b/alembic/versions/2025_08_10_2046-5930e70660c5_change_url_outcome_to_url_status.py
@@ -0,0 +1,26 @@
+"""Change URL outcome to URL status
+
+Revision ID: 5930e70660c5
+Revises: 11ece61d7ac2
+Create Date: 2025-08-10 20:46:58.576623
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '5930e70660c5' +down_revision: Union[str, None] = '11ece61d7ac2' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.alter_column('urls', 'outcome', new_column_name='status') + + +def downgrade() -> None: + op.alter_column('urls', 'status', new_column_name='outcome') diff --git a/alembic/versions/2025_08_11_0914-c14d669d7c0d_change_link_table_nomenclature.py b/alembic/versions/2025_08_11_0914-c14d669d7c0d_change_link_table_nomenclature.py new file mode 100644 index 00000000..834f81fb --- /dev/null +++ b/alembic/versions/2025_08_11_0914-c14d669d7c0d_change_link_table_nomenclature.py @@ -0,0 +1,28 @@ +"""Change Link table nomenclature + +Revision ID: c14d669d7c0d +Revises: 5930e70660c5 +Create Date: 2025-08-11 09:14:08.034093 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c14d669d7c0d' +down_revision: Union[str, None] = '5930e70660c5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +OLD_URL_DATA_SOURCE_NAME = "url_data_sources" +NEW_URL_DATA_SOURCE_NAME = "url_data_source" + +def upgrade() -> None: + op.rename_table(OLD_URL_DATA_SOURCE_NAME, NEW_URL_DATA_SOURCE_NAME) + + +def downgrade() -> None: + op.rename_table(NEW_URL_DATA_SOURCE_NAME, OLD_URL_DATA_SOURCE_NAME) diff --git a/alembic/versions/2025_08_11_0931-9a56916ea7d8_remove_agencies_ds_last_updated_at.py b/alembic/versions/2025_08_11_0931-9a56916ea7d8_remove_agencies_ds_last_updated_at.py new file mode 100644 index 00000000..a14cf32b --- /dev/null +++ b/alembic/versions/2025_08_11_0931-9a56916ea7d8_remove_agencies_ds_last_updated_at.py @@ -0,0 +1,31 @@ +"""Remove agencies.ds_last_updated_at + +Revision ID: 9a56916ea7d8 +Revises: c14d669d7c0d +Create Date: 2025-08-11 09:31:18.268319 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '9a56916ea7d8' +down_revision: Union[str, None] = 'c14d669d7c0d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +COLUMN_NAME = "ds_last_updated_at" +TABLE_NAME = "agencies" + +def upgrade() -> None: + op.drop_column(TABLE_NAME, COLUMN_NAME) + + +def downgrade() -> None: + op.add_column( + table_name=TABLE_NAME, + column=sa.Column(COLUMN_NAME, sa.DateTime(), nullable=False), + ) diff --git a/alembic/versions/2025_08_12_0819-49fd9f295b8d_refine_root_table_logic.py b/alembic/versions/2025_08_12_0819-49fd9f295b8d_refine_root_table_logic.py new file mode 100644 index 00000000..28b1f049 --- /dev/null +++ b/alembic/versions/2025_08_12_0819-49fd9f295b8d_refine_root_table_logic.py @@ -0,0 +1,147 @@ +"""Refine root table logic + +Revision ID: 49fd9f295b8d +Revises: 9a56916ea7d8 +Create Date: 2025-08-12 08:19:08.170835 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, updated_at_column, url_id_column, created_at_column, switch_enum_type + +# revision identifiers, used by Alembic. 
+revision: str = '49fd9f295b8d' +down_revision: Union[str, None] = '9a56916ea7d8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +ROOT_URLS_TABLE_NAME = "root_urls" +ROOT_URL_CACHE_TABLE_NAME = "root_url_cache" + +LINK_URLS_ROOT_URL_TABLE_NAME = "link_urls_root_url" +FLAG_ROOT_URL_TABLE_NAME = "flag_root_url" + + + + +def upgrade() -> None: + _drop_root_url_cache() + _drop_root_urls() + _create_flag_root_url() + _create_link_urls_root_url() + _add_root_url_task_enum() + + +def downgrade() -> None: + _create_root_url_cache() + _create_root_urls() + _drop_link_urls_root_url() + _drop_flag_root_url() + _remove_root_url_task_enum() + +def _add_root_url_task_enum(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL' + ] + ) + + +def _remove_root_url_task_enum(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles' + ] + ) + + +def _drop_root_url_cache(): + op.drop_table(ROOT_URL_CACHE_TABLE_NAME) + +def _drop_root_urls(): + op.drop_table(ROOT_URLS_TABLE_NAME) + +def _create_root_url_cache(): + op.create_table( + ROOT_URL_CACHE_TABLE_NAME, + id_column(), + sa.Column('url', sa.String(), nullable=False), + sa.Column('page_title', sa.String(), nullable=False), + sa.Column('page_description', sa.String(), nullable=True), + updated_at_column(), + sa.UniqueConstraint('url', name='root_url_cache_uq_url') + ) + +def _create_root_urls(): + op.create_table( + ROOT_URLS_TABLE_NAME, + id_column(), + sa.Column('url', sa.String(), nullable=False), + sa.Column('page_title', sa.String(), nullable=False), + sa.Column('page_description', sa.String(), nullable=True), + updated_at_column(), + sa.UniqueConstraint('url', name='uq_root_url_url') + ) + +def _create_link_urls_root_url(): + op.create_table( + LINK_URLS_ROOT_URL_TABLE_NAME, + id_column(), + url_id_column(), + url_id_column('root_url_id'), + created_at_column(), + updated_at_column(), + sa.UniqueConstraint('url_id', 'root_url_id') + ) + +def _drop_link_urls_root_url(): + op.drop_table(LINK_URLS_ROOT_URL_TABLE_NAME) + +def _create_flag_root_url(): + op.create_table( + FLAG_ROOT_URL_TABLE_NAME, + url_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint('url_id') + ) + +def _drop_flag_root_url(): + op.drop_table(FLAG_ROOT_URL_TABLE_NAME) \ No newline at end of file diff --git a/alembic/versions/2025_08_14_0722-2a7192657354_add_internet_archive_tables.py b/alembic/versions/2025_08_14_0722-2a7192657354_add_internet_archive_tables.py new file mode 100644 index 00000000..afdaecbe --- /dev/null +++ b/alembic/versions/2025_08_14_0722-2a7192657354_add_internet_archive_tables.py @@ -0,0 +1,108 @@ +"""Add Internet Archive Tables + +Revision ID: 2a7192657354 +Revises: 49fd9f295b8d +Create Date: 2025-08-14 07:22:15.308210 + +""" +from typing import 
Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import url_id_column, created_at_column, id_column, updated_at_column, switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = '2a7192657354' +down_revision: Union[str, None] = '49fd9f295b8d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +IA_METADATA_TABLE_NAME = "urls_internet_archive_metadata" +IA_FLAGS_TABLE_NAME = "flag_url_checked_for_internet_archive" + +def upgrade() -> None: + _create_metadata_table() + _create_flags_table() + _add_internet_archives_task_enum() + +def downgrade() -> None: + op.drop_table(IA_METADATA_TABLE_NAME) + op.drop_table(IA_FLAGS_TABLE_NAME) + _remove_internet_archives_task_enum() + + +def _create_metadata_table(): + op.create_table( + IA_METADATA_TABLE_NAME, + id_column(), + url_id_column(), + sa.Column('archive_url', sa.String(), nullable=False), + sa.Column('digest', sa.String(), nullable=False), + sa.Column('length', sa.Integer(), nullable=False), + created_at_column(), + updated_at_column(), + sa.UniqueConstraint('url_id', name='uq_url_id_internet_archive_metadata') + ) + +def _add_internet_archives_task_enum(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive' + ] + ) + +def _remove_internet_archives_task_enum(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + ] + ) + +def _create_flags_table(): + op.create_table( + IA_FLAGS_TABLE_NAME, + url_id_column(), + sa.Column('success', sa.Boolean(), nullable=False), + created_at_column(), + sa.PrimaryKeyConstraint('url_id') + ) + diff --git a/alembic/versions/2025_08_17_1830-8a70ee509a74_add_internet_archives_upload_task.py b/alembic/versions/2025_08_17_1830-8a70ee509a74_add_internet_archives_upload_task.py new file mode 100644 index 00000000..4523e8c2 --- /dev/null +++ b/alembic/versions/2025_08_17_1830-8a70ee509a74_add_internet_archives_upload_task.py @@ -0,0 +1,43 @@ +"""Add internet archives upload task + +Revision ID: 8a70ee509a74 +Revises: 2a7192657354 +Create Date: 2025-08-17 18:30:18.353605 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, url_id_column, created_at_column + +# revision identifiers, used by Alembic. 
+revision: str = '8a70ee509a74' +down_revision: Union[str, None] = '2a7192657354' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +IA_PROBE_METADATA_TABLE_NAME_OLD = "urls_internet_archive_metadata" +IA_PROBE_METADATA_TABLE_NAME_NEW = "url_internet_archives_probe_metadata" + +IA_UPLOAD_METADATA_TABLE_NAME = "url_internet_archives_save_metadata" + +def upgrade() -> None: + _create_internet_archive_save_metadata_table() + op.rename_table(IA_PROBE_METADATA_TABLE_NAME_OLD, IA_PROBE_METADATA_TABLE_NAME_NEW) + + + +def downgrade() -> None: + op.drop_table(IA_UPLOAD_METADATA_TABLE_NAME) + op.rename_table(IA_PROBE_METADATA_TABLE_NAME_NEW, IA_PROBE_METADATA_TABLE_NAME_OLD) + +def _create_internet_archive_save_metadata_table() -> None: + op.create_table( + IA_UPLOAD_METADATA_TABLE_NAME, + id_column(), + url_id_column(), + created_at_column(), + sa.Column('last_uploaded_at', sa.DateTime(), nullable=False, server_default=sa.text('now()')), + ) \ No newline at end of file diff --git a/alembic/versions/2025_08_19_0803-b741b65a1431_augment_auto_agency_suggestions.py b/alembic/versions/2025_08_19_0803-b741b65a1431_augment_auto_agency_suggestions.py new file mode 100644 index 00000000..de3069e2 --- /dev/null +++ b/alembic/versions/2025_08_19_0803-b741b65a1431_augment_auto_agency_suggestions.py @@ -0,0 +1,254 @@ +"""Augment auto_agency_suggestions + +Revision ID: b741b65a1431 +Revises: 8a70ee509a74 +Create Date: 2025-08-19 08:03:12.106575 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import created_at_column, updated_at_column, id_column, url_id_column, switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = 'b741b65a1431' +down_revision: Union[str, None] = '8a70ee509a74' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +OLD_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME = "automated_url_agency_suggestions" +NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME = "url_auto_agency_suggestions" + +OLD_LINK_URLS_AGENCY_TABLE_NAME = "link_urls_agencies" +NEW_LINK_URLS_AGENCY_TABLE_NAME = "link_urls_agency" + +AGENCY_AUTO_SUGGESTION_METHOD_ENUM = sa.Enum( + "homepage_match", + "nlp_location_match", + "muckrock_match", + "ckan_match", + name="agency_auto_suggestion_method", +) + +FLAG_URL_VALIDATED_TABLE_NAME = "flag_url_validated" + +VALIDATED_URL_TYPE_ENUM = sa.Enum( + "data source", + "meta url", + "not relevant", + "individual record", + name="validated_url_type" +) + + + + +def upgrade() -> None: + op.rename_table(OLD_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME) + op.rename_table(OLD_LINK_URLS_AGENCY_TABLE_NAME, NEW_LINK_URLS_AGENCY_TABLE_NAME) + _alter_auto_agency_suggestions_table() + _create_flag_url_validated_table() + _add_urls_to_flag_url_validated_table() + _remove_validated_and_submitted_url_statuses() + _reset_agencies_sync_state() + + +def downgrade() -> None: + op.rename_table(NEW_LINK_URLS_AGENCY_TABLE_NAME, OLD_LINK_URLS_AGENCY_TABLE_NAME) + _revert_auto_agency_suggestions_table() + op.rename_table(NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, OLD_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME) + _revert_url_statuses() + _update_validated_and_submitted_url_statuses() + op.drop_table(FLAG_URL_VALIDATED_TABLE_NAME) + _drop_validated_url_type_enum() + +def _reset_agencies_sync_state(): + op.execute( + """ + UPDATE agencies_sync_state + set + last_full_sync_at = null, + 
current_cutoff_date = null,
+            current_page = null
+        """
+    )
+
+def _remove_validated_and_submitted_url_statuses():
+    switch_enum_type(
+        table_name="urls",
+        column_name="status",
+        enum_name="url_status",
+        new_enum_values=[
+            'ok',
+            'duplicate',
+            'error',
+            '404 not found',
+        ],
+        check_constraints_to_drop=['url_name_not_null_when_validated'],
+        conversion_mappings={
+            'validated': 'ok',
+            'submitted': 'ok',
+            'pending': 'ok',
+            'not relevant': 'ok',
+            'individual record': 'ok'
+        }
+    )
+
+def _add_urls_to_flag_url_validated_table():
+    op.execute("""
+        INSERT INTO flag_url_validated (url_id, type)
+        SELECT
+            urls.id,
+            CASE urls.status::text
+                WHEN 'validated' THEN 'data source'
+                WHEN 'submitted' THEN 'data source'
+                ELSE urls.status::text
+            END::validated_url_type
+        FROM urls
+        WHERE urls.status in ('validated', 'submitted', 'individual record', 'not relevant')""")
+
+def _revert_url_statuses():
+    switch_enum_type(
+        table_name="urls",
+        column_name="status",
+        enum_name="url_status",
+        new_enum_values=[
+            'pending',
+            'validated',
+            'submitted',
+            'duplicate',
+            'not relevant',
+            'error',
+            '404 not found',
+            'individual record'
+        ],
+        conversion_mappings={
+            'ok': 'pending',
+        }
+    )
+    op.create_check_constraint(
+        "url_name_not_null_when_validated",
+        "urls",
+        "(name IS NOT NULL) OR (status <> 'validated'::url_status)"
+    )
+
+def _update_validated_and_submitted_url_statuses():
+    op.execute("""
+        UPDATE urls
+        SET status = 'not relevant'
+        FROM flag_url_validated
+        WHERE urls.id = flag_url_validated.url_id
+        AND flag_url_validated.type = 'not relevant'
+    """)
+
+    op.execute("""
+        UPDATE urls
+        SET status = 'individual record'
+        FROM flag_url_validated
+        WHERE urls.id = flag_url_validated.url_id
+        AND flag_url_validated.type = 'individual record'
+    """)
+
+    # Flagged data sources without a linked data source were previously 'validated'
+    op.execute("""
+        UPDATE urls
+        SET status = 'validated'
+        FROM flag_url_validated
+        left join url_data_source on flag_url_validated.url_id = url_data_source.url_id
+        WHERE urls.id = flag_url_validated.url_id
+        AND flag_url_validated.type = 'data source'
+        AND url_data_source.url_id is NULL
+    """)
+
+    # Flagged data sources with a linked data source were previously 'submitted'
+    op.execute("""
+        UPDATE urls
+        SET status = 'submitted'
+        FROM flag_url_validated
+        left join url_data_source on flag_url_validated.url_id = url_data_source.url_id
+        WHERE urls.id = flag_url_validated.url_id
+        AND flag_url_validated.type = 'data source'
+        AND url_data_source.url_id is not NULL
+    """)
+
+
+def _create_flag_url_validated_table():
+    op.create_table(
+        FLAG_URL_VALIDATED_TABLE_NAME,
+        id_column(),
+        url_id_column(),
+        sa.Column(
+            'type',
+            VALIDATED_URL_TYPE_ENUM,
+            nullable=False,
+        ),
+        created_at_column(),
+        updated_at_column(),
+        sa.UniqueConstraint('url_id', name='uq_flag_url_validated_url_id')
+    )
+
+def _drop_validated_url_type_enum():
+    VALIDATED_URL_TYPE_ENUM.drop(op.get_bind())
+
+def _alter_auto_agency_suggestions_table():
+    AGENCY_AUTO_SUGGESTION_METHOD_ENUM.create(op.get_bind())
+    # Created At
+    op.add_column(
+        NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME,
+        created_at_column()
+    )
+    # Updated At
+    op.add_column(
+        NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME,
+        updated_at_column()
+    )
+    # Method
+    op.add_column(
+        NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME,
+        sa.Column(
+            'method',
+            AGENCY_AUTO_SUGGESTION_METHOD_ENUM,
+            nullable=True
+        )
+    )
+    # Confidence
+    op.add_column(
+        NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME,
+        sa.Column(
+            'confidence',
+            sa.Float(),
+            server_default=sa.text('0.0'),
+            nullable=False
+        )
+    )
+    # Check constraint that confidence is between 0 and 1
+    op.create_check_constraint(
"auto_url_agency_suggestions_check_confidence_between_0_and_1", + NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, + "confidence BETWEEN 0 AND 1" + ) + + +def _revert_auto_agency_suggestions_table(): + # Created At + op.drop_column( + NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, + 'created_at' + ) + # Updated At + op.drop_column( + NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, + 'updated_at' + ) + # Method + op.drop_column( + NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, + 'method' + ) + # Confidence + op.drop_column( + NEW_AUTO_URL_AGENCY_SUGGESTIONS_TABLE_NAME, + 'confidence' + ) + AGENCY_AUTO_SUGGESTION_METHOD_ENUM.drop(op.get_bind()) + diff --git a/alembic/versions/2025_08_31_1930-70baaee0dd79_overhaul_agency_identification.py b/alembic/versions/2025_08_31_1930-70baaee0dd79_overhaul_agency_identification.py new file mode 100644 index 00000000..39703fde --- /dev/null +++ b/alembic/versions/2025_08_31_1930-70baaee0dd79_overhaul_agency_identification.py @@ -0,0 +1,267 @@ +"""Overhaul agency identification + +Revision ID: 70baaee0dd79 +Revises: b741b65a1431 +Create Date: 2025-08-31 19:30:20.690369 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, url_id_column, created_at_column, agency_id_column, updated_at_column, \ + task_id_column + +# revision identifiers, used by Alembic. +revision: str = '70baaee0dd79' +down_revision: Union[str, None] = 'b741b65a1431' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +URL_HAS_AGENCY_SUGGESTIONS_VIEW_NAME: str = "url_has_agency_auto_suggestions_view" +URL_UNKNOWN_AGENCIES_VIEW_NAME: str = "url_unknown_agencies_view" + +URL_AUTO_AGENCY_SUBTASK_TABLE_NAME: str = "url_auto_agency_id_subtasks" +LINK_AGENCY_ID_SUBTASK_AGENCIES_TABLE_NAME: str = "agency_id_subtask_suggestions" + +META_URL_VIEW_NAME: str = "meta_url_view" +UNVALIDATED_URL_VIEW_NAME: str = "unvalidated_url_view" + +URL_AUTO_AGENCY_SUGGESTIONS_TABLE_NAME: str = "url_auto_agency_suggestions" + +AGENCY_AUTO_SUGGESTION_METHOD_ENUM = sa.dialects.postgresql.ENUM( + name="agency_auto_suggestion_method", + create_type=False +) + +SUBTASK_DETAIL_CODE_ENUM = sa.Enum( + 'no details', + 'retrieval error', + 'homepage-single agency', + 'homepage-multi agency', + name="agency_id_subtask_detail_code", +) + + + + + +def upgrade() -> None: + _create_url_auto_agency_subtask_table() + _create_url_unknown_agencies_view() + _create_meta_url_view() + _create_link_agency_id_subtask_agencies_table() + _drop_url_annotation_flags_view() + _create_new_url_annotation_flags_view() + _drop_url_auto_agency_suggestions_table() + _create_unvalidated_urls_view() + + +def downgrade() -> None: + _drop_url_unknown_agencies_view() + _create_url_auto_agency_suggestions_table() + _drop_url_annotation_flags_view() + _create_old_url_annotation_flags_view() + _drop_link_agency_id_subtask_agencies_table() + _drop_url_auto_agency_subtask_table() + _drop_meta_url_view() + SUBTASK_DETAIL_CODE_ENUM.drop(op.get_bind()) + _drop_unvalidated_urls_view() + +def _create_unvalidated_urls_view(): + op.execute(f""" + CREATE OR REPLACE VIEW {UNVALIDATED_URL_VIEW_NAME} as + select + u.id as url_id + from + urls u + left join flag_url_validated fuv + on fuv.url_id = u.id + where + fuv.type is null + """) + +def _drop_unvalidated_urls_view(): + op.execute(f"DROP VIEW IF EXISTS {UNVALIDATED_URL_VIEW_NAME}") + + +def _drop_url_annotation_flags_view(): + op.execute(f"DROP VIEW IF EXISTS url_annotation_flags") + + 
+def _drop_meta_url_view(): + op.execute(f"DROP VIEW IF EXISTS {META_URL_VIEW_NAME}") + + +def _create_meta_url_view(): + op.execute(f""" + CREATE OR REPLACE VIEW {META_URL_VIEW_NAME} AS + SELECT + urls.id as url_id + FROM urls + INNER JOIN flag_url_validated fuv on fuv.url_id = urls.id + where fuv.type = 'meta url' + """) + +def _drop_url_auto_agency_suggestions_table(): + op.drop_table(URL_AUTO_AGENCY_SUGGESTIONS_TABLE_NAME) + + +def _create_new_url_annotation_flags_view(): + + op.execute( + f""" + CREATE OR REPLACE VIEW url_annotation_flags AS + ( + SELECT u.id as url_id, + EXISTS (SELECT 1 FROM public.auto_record_type_suggestions a WHERE a.url_id = u.id) AS has_auto_record_type_suggestion, + EXISTS (SELECT 1 FROM public.auto_relevant_suggestions a WHERE a.url_id = u.id) AS has_auto_relevant_suggestion, + EXISTS (SELECT 1 FROM public.{URL_AUTO_AGENCY_SUBTASK_TABLE_NAME} a WHERE a.url_id = u.id) AS has_auto_agency_suggestion, + EXISTS (SELECT 1 FROM public.user_record_type_suggestions a WHERE a.url_id = u.id) AS has_user_record_type_suggestion, + EXISTS (SELECT 1 FROM public.user_relevant_suggestions a WHERE a.url_id = u.id) AS has_user_relevant_suggestion, + EXISTS (SELECT 1 FROM public.user_url_agency_suggestions a WHERE a.url_id = u.id) AS has_user_agency_suggestion, + EXISTS (SELECT 1 FROM public.link_urls_agency a WHERE a.url_id = u.id) AS has_confirmed_agency, + EXISTS (SELECT 1 FROM public.reviewing_user_url a WHERE a.url_id = u.id) AS was_reviewed + FROM urls u + ) + """ + ) + + +def _create_url_unknown_agencies_view(): + op.execute( + f""" + CREATE OR REPLACE VIEW {URL_UNKNOWN_AGENCIES_VIEW_NAME} AS + SELECT + u.id + FROM urls u + LEFT JOIN {URL_AUTO_AGENCY_SUBTASK_TABLE_NAME} uas ON u.id = uas.url_id + GROUP BY u.id + HAVING bool_or(uas.agencies_found) = false + """ + ) + + +def _create_url_auto_agency_subtask_table(): + op.create_table( + URL_AUTO_AGENCY_SUBTASK_TABLE_NAME, + id_column(), + task_id_column(), + url_id_column(), + sa.Column( + "type", + AGENCY_AUTO_SUGGESTION_METHOD_ENUM, + nullable=False + ), + sa.Column( + "agencies_found", + sa.Boolean(), + nullable=False + ), + sa.Column( + "detail", + SUBTASK_DETAIL_CODE_ENUM, + server_default=sa.text("'no details'"), + nullable=False + ), + created_at_column() + ) + + +def _create_link_agency_id_subtask_agencies_table(): + op.create_table( + LINK_AGENCY_ID_SUBTASK_AGENCIES_TABLE_NAME, + id_column(), + sa.Column( + "subtask_id", + sa.Integer(), + sa.ForeignKey( + f'{URL_AUTO_AGENCY_SUBTASK_TABLE_NAME}.id', + ondelete='CASCADE' + ), + nullable=False, + comment='A foreign key to the `url_auto_agency_subtask` table.' 
+        ),
+        sa.Column(
+            "confidence",
+            # Integer percentage (0-100), unlike the 0-1 float scale used by
+            # the url_auto_agency_suggestions table dropped in this revision.
+            sa.Integer,
+            sa.CheckConstraint(
+                "confidence BETWEEN 0 and 100"
+            ),
+            nullable=False,
+        ),
+        agency_id_column(),
+        created_at_column()
+    )
+
+
+def _drop_link_agency_id_subtask_agencies_table():
+    op.drop_table(LINK_AGENCY_ID_SUBTASK_AGENCIES_TABLE_NAME)
+
+
+def _drop_url_auto_agency_subtask_table():
+    op.drop_table(URL_AUTO_AGENCY_SUBTASK_TABLE_NAME)
+
+
+def _create_url_auto_agency_suggestions_table():
+    op.create_table(
+        URL_AUTO_AGENCY_SUGGESTIONS_TABLE_NAME,
+        id_column(),
+        agency_id_column(),
+        url_id_column(),
+        sa.Column(
+            "is_unknown",
+            sa.Boolean(),
+            nullable=False
+        ),
+        created_at_column(),
+        updated_at_column(),
+        sa.Column(
+            'method',
+            AGENCY_AUTO_SUGGESTION_METHOD_ENUM,
+            nullable=True
+        ),
+        sa.Column(
+            'confidence',
+            sa.Float(),
+            server_default=sa.text('0.0'),
+            nullable=False
+        ),
+        sa.UniqueConstraint("agency_id", "url_id")
+    )
+
+
+def _drop_url_unknown_agencies_view():
+    op.execute(f"DROP VIEW IF EXISTS {URL_UNKNOWN_AGENCIES_VIEW_NAME}")
+
+
+def _create_old_url_annotation_flags_view():
+    op.execute(
+        f"""
+        CREATE OR REPLACE VIEW url_annotation_flags AS
+        (
+        SELECT u.id,
+            CASE WHEN arts.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_record_type_suggestion,
+            CASE WHEN ars.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_relevant_suggestion,
+            CASE WHEN auas.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_agency_suggestion,
+            CASE WHEN urts.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_record_type_suggestion,
+            CASE WHEN urs.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_relevant_suggestion,
+            CASE WHEN uuas.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_agency_suggestion,
+            CASE WHEN cua.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_confirmed_agency,
+            CASE WHEN ruu.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS was_reviewed
+        FROM urls u
+            LEFT JOIN public.auto_record_type_suggestions arts ON u.id = arts.url_id
+            LEFT JOIN public.auto_relevant_suggestions ars ON u.id = ars.url_id
+            LEFT JOIN public.{URL_AUTO_AGENCY_SUGGESTIONS_TABLE_NAME} auas ON u.id = auas.url_id
+            LEFT JOIN public.user_record_type_suggestions urts ON u.id = urts.url_id
+            LEFT JOIN public.user_relevant_suggestions urs ON u.id = urs.url_id
+            LEFT JOIN public.user_url_agency_suggestions uuas ON u.id = uuas.url_id
+            LEFT JOIN public.reviewing_user_url ruu ON u.id = ruu.url_id
+            LEFT JOIN public.link_urls_agency cua ON u.id = cua.url_id
+        )
+        """
+    )
diff --git a/alembic/versions/2025_09_12_2040-e7189dc92a83_create_url_screenshot_task.py b/alembic/versions/2025_09_12_2040-e7189dc92a83_create_url_screenshot_task.py
new file mode 100644
index 00000000..0348c6c3
--- /dev/null
+++ b/alembic/versions/2025_09_12_2040-e7189dc92a83_create_url_screenshot_task.py
@@ -0,0 +1,122 @@
+"""Create url screenshot task
+
+Revision ID: e7189dc92a83
+Revises: 70baaee0dd79
+Create Date: 2025-09-12 20:40:45.950204
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+from src.util.alembic_helpers import switch_enum_type, id_column, url_id_column, created_at_column, updated_at_column
+
+# revision identifiers, used by Alembic.
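+# Note (editorial): 'Screenshot' is added by re-listing every enum member
+# because PostgreSQL enums cannot drop values in place; the switch_enum_type
+# helper presumably recreates the type from the given list, which is also
+# what makes the downgrade path below possible.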
+revision: str = 'e7189dc92a83' +down_revision: Union[str, None] = '70baaee0dd79' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +URL_SCREENSHOT_TABLE_NAME = "url_screenshot" +SCREENSHOT_ERROR_TABLE_NAME = "error_url_screenshot" + + + +def upgrade() -> None: + _add_url_screenshot_task() + _add_url_screenshot_table() + _add_screenshot_error_table() + + + +def downgrade() -> None: + _remove_url_screenshot_task() + _remove_url_screenshot_table() + _remove_screenshot_error_table() + + +def _add_screenshot_error_table(): + op.create_table( + SCREENSHOT_ERROR_TABLE_NAME, + url_id_column(), + sa.Column('error', sa.String(), nullable=False), + created_at_column(), + sa.PrimaryKeyConstraint('url_id') + ) + + +def _add_url_screenshot_table(): + op.create_table( + URL_SCREENSHOT_TABLE_NAME, + url_id_column(), + sa.Column('content', sa.LargeBinary(), nullable=False), + sa.Column('file_size', sa.Integer(), nullable=False), + created_at_column(), + updated_at_column(), + sa.UniqueConstraint('url_id', name='uq_url_id_url_screenshot') + ) + + +def _remove_url_screenshot_table(): + op.drop_table(URL_SCREENSHOT_TABLE_NAME) + + +def _remove_screenshot_error_table(): + op.drop_table(SCREENSHOT_ERROR_TABLE_NAME) + + +def _add_url_screenshot_task(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive', + 'Screenshot' + ] + ) + +def _remove_url_screenshot_task(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive' + ] + ) \ No newline at end of file diff --git a/alembic/versions/2025_09_15_1137-d5f92e6fedf4_add_location_tables.py b/alembic/versions/2025_09_15_1137-d5f92e6fedf4_add_location_tables.py new file mode 100644 index 00000000..be2c22e9 --- /dev/null +++ b/alembic/versions/2025_09_15_1137-d5f92e6fedf4_add_location_tables.py @@ -0,0 +1,161 @@ +"""Add Location tables + +Revision ID: d5f92e6fedf4 +Revises: e7189dc92a83 +Create Date: 2025-09-15 11:37:58.183674 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
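+# Note (editorial): the locations table created below is polymorphic over
+# location_type, with a CHECK constraint spelling out which foreign keys each
+# type may carry. As a sketch, a State row carries only a state_id:
+#
+#   INSERT INTO locations (type, state_id) VALUES ('State', 1);
+#
+# while a Locality row must also reference its county and locality.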
+revision: str = 'd5f92e6fedf4' +down_revision: Union[str, None] = 'e7189dc92a83' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +US_STATES_TABLE_NAME = 'us_states' +COUNTIES_TABLE_NAME = 'counties' +LOCALITIES_TABLE_NAME = 'localities' +LOCATIONS_TABLE_NAME = 'locations' +LINK_AGENCIES_LOCATIONS_TABLE_NAME = 'link_agencies_locations' + +def upgrade() -> None: + _create_location_type() + _create_us_states_table() + _create_counties_table() + _create_localities_table() + _create_locations_table() + _create_link_agencies_locations_table() + +def downgrade() -> None: + _remove_link_agencies_locations_table() + _remove_locations_table() + _remove_localities_table() + _remove_counties_table() + _remove_us_states_table() + _remove_location_type() + +def _create_location_type(): + op.execute(""" + create type location_type as enum ('National', 'State', 'County', 'Locality') + """) + +def _remove_location_type(): + op.execute(""" + drop type location_type + """) + +def _create_us_states_table(): + op.execute(""" + create table if not exists public.us_states + ( + state_iso text not null + constraint unique_state_iso + unique, + state_name text, + id bigint generated always as identity + primary key + ) + """) + +def _create_counties_table(): + op.execute(""" + create table if not exists public.counties + ( + fips varchar not null + constraint unique_fips + unique, + name text, + lat double precision, + lng double precision, + population bigint, + agencies text, + id bigint generated always as identity + primary key, + state_id integer + references public.us_states, + unique (fips, state_id), + constraint unique_county_name_and_state + unique (name, state_id) + ) + """) + +def _create_localities_table(): + op.execute(""" + create table if not exists public.localities + ( + id bigint generated always as identity + primary key, + name varchar(255) not null + constraint localities_name_check + check ((name)::text !~~ '%,%'::text), + county_id integer not null + references public.counties, + unique (name, county_id) + ) + + """) + +def _create_locations_table(): + op.execute(""" + create table if not exists public.locations + ( + id bigint generated always as identity + primary key, + type location_type not null, + state_id bigint + references public.us_states + on delete cascade, + county_id bigint + references public.counties + on delete cascade, + locality_id bigint + references public.localities + on delete cascade, + lat double precision, + lng double precision, + unique (id, type, state_id, county_id, locality_id), + constraint locations_check + check (((type = 'National'::location_type) AND (state_id IS NULL) AND (county_id IS NULL) AND + (locality_id IS NULL)) OR + ((type = 'State'::location_type) AND (county_id IS NULL) AND (locality_id IS NULL)) OR + ((type = 'County'::location_type) AND (county_id IS NOT NULL) AND (locality_id IS NULL)) OR + ((type = 'Locality'::location_type) AND (county_id IS NOT NULL) AND (locality_id IS NOT NULL))) + ) + """) + +def _create_link_agencies_locations_table(): + op.execute(""" + create table if not exists public.link_agencies_locations + ( + id serial + primary key, + agency_id integer not null + references public.agencies + on delete cascade, + location_id integer not null + references public.locations + on delete cascade, + constraint unique_agency_location + unique (agency_id, location_id) + ) + """) + +def _remove_link_agencies_locations_table(): + 
op.drop_table(LINK_AGENCIES_LOCATIONS_TABLE_NAME) + +def _remove_locations_table(): + op.drop_table(LOCATIONS_TABLE_NAME) + +def _remove_localities_table(): + op.drop_table(LOCALITIES_TABLE_NAME) + +def _remove_counties_table(): + op.drop_table(COUNTIES_TABLE_NAME) + +def _remove_us_states_table(): + op.drop_table(US_STATES_TABLE_NAME) diff --git a/alembic/versions/2025_09_15_1905-93cbaa3b8e9b_add_location_annotation_logic.py b/alembic/versions/2025_09_15_1905-93cbaa3b8e9b_add_location_annotation_logic.py new file mode 100644 index 00000000..55bb5ea5 --- /dev/null +++ b/alembic/versions/2025_09_15_1905-93cbaa3b8e9b_add_location_annotation_logic.py @@ -0,0 +1,426 @@ +"""Add location annotation logic + +Revision ID: 93cbaa3b8e9b +Revises: d5f92e6fedf4 +Create Date: 2025-09-15 19:05:27.872875 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +from src.util.alembic_helpers import switch_enum_type, url_id_column, location_id_column, created_at_column, id_column, \ + task_id_column, user_id_column + +# revision identifiers, used by Alembic. +revision: str = '93cbaa3b8e9b' +down_revision: Union[str, None] = 'd5f92e6fedf4' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +USER_LOCATION_SUGGESTIONS_TABLE_NAME = 'user_location_suggestions' +AUTO_LOCATION_ID_SUBTASK_TABLE_NAME = 'auto_location_id_subtasks' +LOCATION_ID_SUBTASK_SUGGESTIONS_TABLE_NAME = 'location_id_subtask_suggestions' +LOCATION_ID_TASK_TYPE = 'Location ID' +LOCATION_ID_SUBTASK_TYPE_NAME = 'location_id_subtask_type' + + + + +def upgrade() -> None: + _add_location_id_task_type() + _create_user_location_suggestions_table() + _create_auto_location_id_subtask_table() + _create_location_id_subtask_suggestions_table() + _create_new_url_annotation_flags_view() + _create_locations_expanded_view() + _create_state_location_trigger() + _create_county_location_trigger() + _create_locality_location_trigger() + _add_pg_trgm_extension() + +def downgrade() -> None: + _drop_locations_expanded_view() + _create_old_url_annotation_flags_view() + _drop_location_id_subtask_suggestions_table() + _drop_auto_location_id_subtask_table() + _drop_user_location_suggestions_table() + _drop_location_id_task_type() + _drop_location_id_subtask_type() + _drop_state_location_trigger() + _drop_county_location_trigger() + _drop_locality_location_trigger() + _drop_pg_trgm_extension() + +def _drop_pg_trgm_extension(): + op.execute(""" + drop extension if exists pg_trgm; + """) + +def _add_pg_trgm_extension(): + op.execute(""" + create extension if not exists pg_trgm; + """) + + +def _create_state_location_trigger(): + # Function + op.execute(""" + create function insert_state_location() returns trigger + language plpgsql + as + $$ + BEGIN + -- Insert a new location of type 'State' when a new state is added + INSERT INTO locations (type, state_id) + VALUES ('State', NEW.id); + RETURN NEW; + END; + $$; + """) + + # Trigger + op.execute(""" + create trigger after_state_insert + after insert + on us_states + for each row + execute procedure insert_state_location(); + """) + + +def _create_county_location_trigger(): + # Function + op.execute(""" + create function insert_county_location() returns trigger + language plpgsql + as + $$ + BEGIN + -- Insert a new location of type 'County' when a new county is added + INSERT INTO locations (type, state_id, county_id) + VALUES ('County', NEW.state_id, NEW.id); + RETURN NEW; + END; + $$; + """) + + # Trigger + op.execute(""" + 
create trigger after_county_insert + after insert + on counties + for each row + execute procedure insert_county_location(); + """) + + +def _create_locality_location_trigger(): + # Function + op.execute(""" + create function insert_locality_location() returns trigger + language plpgsql + as + $$ + DECLARE + v_state_id BIGINT; + BEGIN + -- Get the state_id from the associated county + SELECT c.state_id INTO v_state_id + FROM counties c + WHERE c.id = NEW.county_id; + + -- Insert a new location of type 'Locality' when a new locality is added + INSERT INTO locations (type, state_id, county_id, locality_id) + VALUES ('Locality', v_state_id, NEW.county_id, NEW.id); + + RETURN NEW; + END; + $$; + """) + + # Trigger + op.execute(""" + create trigger after_locality_insert + after insert + on localities + for each row + execute procedure insert_locality_location(); + + """) + + +def _drop_state_location_trigger(): + # Trigger + op.execute(""" + drop trigger if exists after_state_insert on us_states; + """) + + # Function + op.execute(""" + drop function if exists insert_state_location; + """) + + + + +def _drop_locality_location_trigger(): + # Trigger + op.execute(""" + drop trigger if exists after_locality_insert on localities; + """) + + # Function + op.execute(""" + drop function if exists insert_locality_location; + """) + + + +def _drop_county_location_trigger(): + # Trigger + op.execute(""" + drop trigger if exists after_county_insert on counties; + """) + + # Function + op.execute(""" + drop function if exists insert_county_location; + """) + + + +def _create_new_url_annotation_flags_view(): + op.execute("""DROP VIEW IF EXISTS url_annotation_flags;""") + op.execute( + f""" + CREATE OR REPLACE VIEW url_annotation_flags AS + ( + SELECT u.id as url_id, + EXISTS (SELECT 1 FROM public.auto_record_type_suggestions a WHERE a.url_id = u.id) AS has_auto_record_type_suggestion, + EXISTS (SELECT 1 FROM public.auto_relevant_suggestions a WHERE a.url_id = u.id) AS has_auto_relevant_suggestion, + EXISTS (SELECT 1 FROM public.url_auto_agency_id_subtasks a WHERE a.url_id = u.id) AS has_auto_agency_suggestion, + EXISTS (SELECT 1 FROM public.auto_location_id_subtasks a WHERE a.url_id = u.id) AS has_auto_location_suggestion, + EXISTS (SELECT 1 FROM public.user_record_type_suggestions a WHERE a.url_id = u.id) AS has_user_record_type_suggestion, + EXISTS (SELECT 1 FROM public.user_relevant_suggestions a WHERE a.url_id = u.id) AS has_user_relevant_suggestion, + EXISTS (SELECT 1 FROM public.user_url_agency_suggestions a WHERE a.url_id = u.id) AS has_user_agency_suggestion, + EXISTS (SELECT 1 FROM public.user_location_suggestions a WHERE a.url_id = u.id) AS has_user_location_suggestion, + EXISTS (SELECT 1 FROM public.link_urls_agency a WHERE a.url_id = u.id) AS has_confirmed_agency, + EXISTS (SELECT 1 FROM public.reviewing_user_url a WHERE a.url_id = u.id) AS was_reviewed + FROM urls u + ) + """ + ) + +def _create_old_url_annotation_flags_view(): + op.execute("""DROP VIEW IF EXISTS url_annotation_flags;""") + op.execute( + f""" + CREATE OR REPLACE VIEW url_annotation_flags AS + ( + SELECT u.id as url_id, + EXISTS (SELECT 1 FROM public.auto_record_type_suggestions a WHERE a.url_id = u.id) AS has_auto_record_type_suggestion, + EXISTS (SELECT 1 FROM public.auto_relevant_suggestions a WHERE a.url_id = u.id) AS has_auto_relevant_suggestion, + EXISTS (SELECT 1 FROM public.url_auto_agency_id_subtasks a WHERE a.url_id = u.id) AS has_auto_agency_suggestion, + EXISTS (SELECT 1 FROM public.user_record_type_suggestions a 
WHERE a.url_id = u.id) AS has_user_record_type_suggestion, + EXISTS (SELECT 1 FROM public.user_relevant_suggestions a WHERE a.url_id = u.id) AS has_user_relevant_suggestion, + EXISTS (SELECT 1 FROM public.user_url_agency_suggestions a WHERE a.url_id = u.id) AS has_user_agency_suggestion, + EXISTS (SELECT 1 FROM public.link_urls_agency a WHERE a.url_id = u.id) AS has_confirmed_agency, + EXISTS (SELECT 1 FROM public.reviewing_user_url a WHERE a.url_id = u.id) AS was_reviewed + FROM urls u + ) + """ + ) + + +def _drop_locations_expanded_view(): + op.execute(""" + drop view if exists public.locations_expanded; + """) + +def _create_locations_expanded_view(): + op.execute(""" + create or replace view public.locations_expanded + (id, type, state_name, state_iso, county_name, county_fips, locality_name, locality_id, state_id, county_id, + display_name, full_display_name) + as + SELECT + locations.id, + locations.type, + us_states.state_name, + us_states.state_iso, + counties.name AS county_name, + counties.fips AS county_fips, + localities.name AS locality_name, + localities.id AS locality_id, + us_states.id AS state_id, + counties.id AS county_id, + CASE + WHEN locations.type = 'Locality'::location_type THEN localities.name + WHEN locations.type = 'County'::location_type THEN counties.name::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + ELSE NULL::character varying + END AS display_name, + CASE + WHEN locations.type = 'Locality'::location_type THEN concat(localities.name, ', ', counties.name, ', ', + us_states.state_name)::character varying + WHEN locations.type = 'County'::location_type + THEN concat(counties.name, ', ', us_states.state_name)::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + ELSE NULL::character varying + END AS full_display_name + FROM + locations + LEFT JOIN us_states ON locations.state_id = us_states.id + LEFT JOIN counties ON locations.county_id = counties.id + LEFT JOIN localities ON locations.locality_id = localities.id; + + """) + +def _add_location_id_task_type(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive', + 'Screenshot', + LOCATION_ID_TASK_TYPE + ] + ) + + +def _create_user_location_suggestions_table(): + op.create_table( + USER_LOCATION_SUGGESTIONS_TABLE_NAME, + url_id_column(), + user_id_column(), + location_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint( + 'url_id', + 'user_id', + 'location_id', + name='user_location_suggestions_pk' + ) + ) + + +def _create_auto_location_id_subtask_table(): + op.create_table( + AUTO_LOCATION_ID_SUBTASK_TABLE_NAME, + id_column(), + task_id_column(), + url_id_column(), + sa.Column( + 'locations_found', + sa.Boolean(), + nullable=False + ), + sa.Column( + 'type', + sa.Enum( + 'nlp_location_frequency', + name='auto_location_id_subtask_type' + ), + nullable=False + ), + created_at_column(), + sa.UniqueConstraint( + 'url_id', + 'type', + name='auto_location_id_subtask_url_id_type_unique' + ) + ) + + +def 
_create_location_id_subtask_suggestions_table(): + op.create_table( + LOCATION_ID_SUBTASK_SUGGESTIONS_TABLE_NAME, + sa.Column( + 'subtask_id', + sa.Integer(), + sa.ForeignKey( + f'{AUTO_LOCATION_ID_SUBTASK_TABLE_NAME}.id', + ondelete='CASCADE' + ), + ), + location_id_column(), + sa.Column( + 'confidence', + sa.Float(), + nullable=False + ), + created_at_column(), + sa.PrimaryKeyConstraint( + 'subtask_id', + 'location_id', + name='location_id_subtask_suggestions_pk' + ) + ) + + + +def _drop_location_id_task_type(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive', + 'Screenshot', + ] + ) + + +def _drop_auto_location_id_subtask_table(): + op.drop_table(AUTO_LOCATION_ID_SUBTASK_TABLE_NAME) + + +def _drop_user_location_suggestions_table(): + op.drop_table(USER_LOCATION_SUGGESTIONS_TABLE_NAME) + + +def _drop_location_id_subtask_suggestions_table(): + op.drop_table(LOCATION_ID_SUBTASK_SUGGESTIONS_TABLE_NAME) + +def _drop_location_id_subtask_type(): + op.execute(""" + DROP TYPE IF EXISTS auto_location_id_subtask_type; + """) + diff --git a/alembic/versions/2025_09_21_0940-8d7208843b76_update_for_human_agreement_logic.py b/alembic/versions/2025_09_21_0940-8d7208843b76_update_for_human_agreement_logic.py new file mode 100644 index 00000000..08378218 --- /dev/null +++ b/alembic/versions/2025_09_21_0940-8d7208843b76_update_for_human_agreement_logic.py @@ -0,0 +1,406 @@ +"""Update for human agreement logic + +Revision ID: 8d7208843b76 +Revises: 93cbaa3b8e9b +Create Date: 2025-09-21 09:40:36.506827 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import switch_enum_type, url_id_column, created_at_column + +# revision identifiers, used by Alembic. 
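+# Note (editorial): url_annotation_count_view, defined below, tallies the
+# automated and user annotations attached to each URL. A consumer deciding
+# whether enough annotators have weighed in might run (sketch):
+#
+#   SELECT url_id FROM url_annotation_count_view WHERE total_anno_count >= 2;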
+revision: str = '8d7208843b76' +down_revision: Union[str, None] = '93cbaa3b8e9b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +AUTO_VALIDATION_TASK_TYPE: str = 'Auto Validate' +URL_TYPE_NAME: str = 'url_type' +VALIDATED_URL_TYPE_NAME: str = 'validated_url_type' +FLAG_URL_VALIDATED_TABLE_NAME: str = 'flag_url_validated' + +USER_RELEVANT_SUGGESTIONS_TABLE_NAME: str = 'user_relevant_suggestions' +USER_URL_TYPE_SUGGESTIONS_TABLE_NAME: str = 'user_url_type_suggestions' + +FLAG_URL_AUTO_VALIDATED_TABLE_NAME: str = 'flag_url_auto_validated' + + +def _create_anno_count_view(): + op.execute(""" + CREATE OR REPLACE VIEW url_annotation_count_view AS + with auto_location_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_location_id_subtasks anno on u.id = anno.url_id + group by u.id +) +, auto_agency_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.url_auto_agency_id_subtasks anno on u.id = anno.url_id + group by u.id +) +, auto_url_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_relevant_suggestions anno on u.id = anno.url_id + group by u.id +) +, auto_record_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_record_type_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_location_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_location_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_agency_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_url_agency_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_url_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_url_type_suggestions anno on u.id = anno.url_id + group by u.id + ) +, user_record_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_record_type_suggestions anno on u.id = anno.url_id + group by u.id +) +select + u.id as url_id, + coalesce(auto_ag.cnt, 0) as auto_agency_count, + coalesce(auto_loc.cnt, 0) as auto_location_count, + coalesce(auto_rec.cnt, 0) as auto_record_type_count, + coalesce(auto_typ.cnt, 0) as auto_url_type_count, + coalesce(user_ag.cnt, 0) as user_agency_count, + coalesce(user_loc.cnt, 0) as user_location_count, + coalesce(user_rec.cnt, 0) as user_record_type_count, + coalesce(user_typ.cnt, 0) as user_url_type_count, + ( + coalesce(auto_ag.cnt, 0) + + coalesce(auto_loc.cnt, 0) + + coalesce(auto_rec.cnt, 0) + + coalesce(auto_typ.cnt, 0) + + coalesce(user_ag.cnt, 0) + + coalesce(user_loc.cnt, 0) + + coalesce(user_rec.cnt, 0) + + coalesce(user_typ.cnt, 0) + ) as total_anno_count + + from urls u + left join auto_agency_count auto_ag on auto_ag.id = u.id + left join auto_location_count auto_loc on auto_loc.id = u.id + left join auto_record_type_count auto_rec on auto_rec.id = u.id + left join auto_url_type_count auto_typ on auto_typ.id = u.id + left join user_agency_count user_ag on user_ag.id = u.id + left join user_location_count user_loc on user_loc.id = u.id + left join user_record_type_count user_rec on user_rec.id = u.id + left join user_url_type_count user_typ on user_typ.id = u.id + + + """) + + +def upgrade() -> None: + _drop_meta_url_view() + _drop_unvalidated_url_view() + + # URL Type + _rename_validated_url_type_to_url_type() + 
_add_not_found_url_type() + + # suggested Status + _rename_user_relevant_suggestions_to_user_url_type_suggestions() + _rename_suggested_status_column_to_type() + _switch_suggested_status_with_url_type() + _remove_suggested_status_enum() + + _add_flag_url_auto_validated_table() + _add_auto_validate_task() + + _create_anno_count_view() + + + _add_meta_url_view() + _add_unvalidated_url_view() + + +def _remove_suggested_status_enum(): + op.execute(f"DROP TYPE suggested_status") + + +def _add_suggested_status_enum(): + op.execute( + "create type suggested_status as enum " + + "('relevant', 'not relevant', 'individual record', 'broken page/404 not found');" + ) + + +def _drop_anno_count_view(): + op.execute(""" + DROP VIEW IF EXISTS url_annotation_count_view + """) + + +def downgrade() -> None: + _drop_meta_url_view() + _drop_unvalidated_url_view() + _drop_anno_count_view() + + # Suggested Status + _add_suggested_status_enum() + _replace_url_type_with_suggested_status() + _rename_type_column_to_suggested_status() + _rename_user_url_type_suggestions_to_user_relevant_suggestions() + + # URL Type + _remove_not_found_url_type() + _rename_url_type_to_validated_url_type() + + _remove_auto_validate_task() + _remove_flag_url_auto_validated_table() + + + _add_meta_url_view() + _add_unvalidated_url_view() + +def _rename_suggested_status_column_to_type(): + op.alter_column( + table_name=USER_URL_TYPE_SUGGESTIONS_TABLE_NAME, + column_name='suggested_status', + new_column_name='type' + ) + + +def _rename_type_column_to_suggested_status(): + op.alter_column( + table_name=USER_URL_TYPE_SUGGESTIONS_TABLE_NAME, + column_name='type', + new_column_name='suggested_status' + ) + + + + +def _drop_unvalidated_url_view(): + op.execute("DROP VIEW IF EXISTS unvalidated_url_view") + + +def _add_unvalidated_url_view(): + op.execute(""" + CREATE OR REPLACE VIEW unvalidated_url_view AS + select + u.id as url_id + from + urls u + left join flag_url_validated fuv + on fuv.url_id = u.id + where + fuv.type is null + """) + + +def _add_meta_url_view(): + op.execute(""" + CREATE OR REPLACE VIEW meta_url_view AS + SELECT + urls.id as url_id + FROM urls + INNER JOIN flag_url_validated fuv on fuv.url_id = urls.id + where fuv.type = 'meta url' + """) + +def _drop_meta_url_view(): + op.execute("DROP VIEW IF EXISTS meta_url_view") + +def _rename_validated_url_type_to_url_type(): + op.execute(f""" + ALTER TYPE {VALIDATED_URL_TYPE_NAME} RENAME TO {URL_TYPE_NAME} + """) + +def _rename_url_type_to_validated_url_type(): + op.execute(f""" + ALTER TYPE {URL_TYPE_NAME} RENAME TO {VALIDATED_URL_TYPE_NAME} + """) + +def _add_not_found_url_type(): + switch_enum_type( + table_name=FLAG_URL_VALIDATED_TABLE_NAME, + column_name='type', + enum_name=URL_TYPE_NAME, + new_enum_values=[ + 'data source', + 'meta url', + 'not relevant', + 'individual record', + 'not found' + ] + ) + +def _remove_not_found_url_type(): + switch_enum_type( + table_name=FLAG_URL_VALIDATED_TABLE_NAME, + column_name='type', + enum_name=URL_TYPE_NAME, + new_enum_values=[ + 'data source', + 'meta url', + 'not relevant', + 'individual record' + ] + ) + + +def _switch_suggested_status_with_url_type(): + op.execute(f""" + ALTER TABLE {USER_URL_TYPE_SUGGESTIONS_TABLE_NAME} + ALTER COLUMN type type {URL_TYPE_NAME} + USING ( + CASE type::text + WHEN 'relevant' THEN 'data source' + WHEN 'broken page/404 not found' THEN 'not found' + ELSE type::text + END + )::{URL_TYPE_NAME} + """) + + + +def _replace_url_type_with_suggested_status(): + op.execute(f""" + ALTER TABLE 
{USER_URL_TYPE_SUGGESTIONS_TABLE_NAME} + ALTER COLUMN type type suggested_status + USING ( + CASE type::text + WHEN 'data source' THEN 'relevant' + WHEN 'not found' THEN 'broken page/404 not found' + ELSE type::text + END + )::suggested_status + + """) + + + + +def _add_flag_url_auto_validated_table(): + op.create_table( + FLAG_URL_AUTO_VALIDATED_TABLE_NAME, + url_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint('url_id') + ) + + + +def _remove_flag_url_auto_validated_table(): + op.drop_table(FLAG_URL_AUTO_VALIDATED_TABLE_NAME) + + + +def _add_auto_validate_task(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive', + 'Screenshot', + 'Location ID', + AUTO_VALIDATION_TASK_TYPE, + ] + ) + + +def _rename_user_relevant_suggestions_to_user_url_type_suggestions(): + op.rename_table( + old_table_name=USER_RELEVANT_SUGGESTIONS_TABLE_NAME, + new_table_name=USER_URL_TYPE_SUGGESTIONS_TABLE_NAME + ) + + + +def _rename_user_url_type_suggestions_to_user_relevant_suggestions(): + op.rename_table( + old_table_name=USER_URL_TYPE_SUGGESTIONS_TABLE_NAME, + new_table_name=USER_RELEVANT_SUGGESTIONS_TABLE_NAME + ) + + +def _remove_auto_validate_task(): + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + 'HTML', + 'Relevancy', + 'Record Type', + 'Agency Identification', + 'Misc Metadata', + 'Submit Approved URLs', + 'Duplicate Detection', + '404 Probe', + 'Sync Agencies', + 'Sync Data Sources', + 'Push to Hugging Face', + 'URL Probe', + 'Populate Backlog Snapshot', + 'Delete Old Logs', + 'Run URL Task Cycles', + 'Root URL', + 'Internet Archives Probe', + 'Internet Archives Archive', + 'Screenshot', + 'Location ID' + ] + ) + + diff --git a/alembic/versions/2025_09_22_1309-6b3db0c19f9b_update_suggestion_constraints.py b/alembic/versions/2025_09_22_1309-6b3db0c19f9b_update_suggestion_constraints.py new file mode 100644 index 00000000..afd688aa --- /dev/null +++ b/alembic/versions/2025_09_22_1309-6b3db0c19f9b_update_suggestion_constraints.py @@ -0,0 +1,51 @@ +"""Update suggestion constraints + +Revision ID: 6b3db0c19f9b +Revises: 8d7208843b76 +Create Date: 2025-09-22 13:09:42.830264 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
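+# Note (editorial): dropping the per-URL unique constraints below lets a URL
+# accumulate suggestions from multiple users, which the human agreement
+# logic introduced in the previous revision appears to depend on.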
+revision: str = '6b3db0c19f9b' +down_revision: Union[str, None] = '8d7208843b76' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.drop_constraint( + table_name="user_url_type_suggestions", + constraint_name='uq_user_relevant_suggestions_url_id' + ) + op.drop_constraint( + table_name="user_url_agency_suggestions", + constraint_name='uq_user_agency_suggestions_url_id' + ) + op.drop_constraint( + table_name="user_record_type_suggestions", + constraint_name='uq_user_record_type_suggestions_url_id' + ) + + +def downgrade() -> None: + op.create_unique_constraint( + constraint_name='uq_user_relevant_suggestions_url_id', + table_name="user_url_type_suggestions", + columns=["url_id"], + ) + op.create_unique_constraint( + constraint_name='uq_user_agency_suggestions_url_id', + table_name="user_url_agency_suggestions", + columns=["url_id"], + ) + op.create_unique_constraint( + constraint_name='uq_user_record_type_suggestions_url_id', + table_name="user_record_type_suggestions", + columns=["url_id"], + ) diff --git a/alembic/versions/2025_09_22_1916-e6a1a1b3bad4_add_url_record_type.py b/alembic/versions/2025_09_22_1916-e6a1a1b3bad4_add_url_record_type.py new file mode 100644 index 00000000..cf69e8b0 --- /dev/null +++ b/alembic/versions/2025_09_22_1916-e6a1a1b3bad4_add_url_record_type.py @@ -0,0 +1,127 @@ +"""Add URL record type + +Revision ID: e6a1a1b3bad4 +Revises: 6b3db0c19f9b +Create Date: 2025-09-22 19:16:01.744304 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from src.util.alembic_helpers import url_id_column, created_at_column, id_column + +# revision identifiers, used by Alembic. 
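+# Note (editorial): this revision moves record_type off the urls table into
+# a dedicated url_record_type table. Readers of urls.record_type would now
+# join instead, e.g. (sketch):
+#
+#   SELECT u.id, urt.record_type
+#   FROM urls u
+#   LEFT JOIN url_record_type urt ON urt.url_id = u.id;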
+revision: str = 'e6a1a1b3bad4' +down_revision: Union[str, None] = '6b3db0c19f9b' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +URL_RECORD_TYPE_TABLE_NAME = "url_record_type" + + + + +def upgrade() -> None: + _create_url_record_type_table() + _migrate_url_record_types_to_url_record_type_table() + _drop_record_type_column() + _drop_agencies_sync_state() + _drop_data_sources_sync_state() + +def _drop_agencies_sync_state(): + op.drop_table("agencies_sync_state") + + +def _drop_data_sources_sync_state(): + op.drop_table("data_sources_sync_state") + + +def _create_data_sources_sync_state(): + table = op.create_table( + "data_sources_sync_state", + id_column(), + sa.Column('last_full_sync_at', sa.DateTime(), nullable=True), + sa.Column('current_cutoff_date', sa.Date(), nullable=True), + sa.Column('current_page', sa.Integer(), nullable=True), + ) + # Add row to `data_sources_sync_state` table + op.bulk_insert( + table, + [ + { + "last_full_sync_at": None, + "current_cutoff_date": None, + "current_page": None + } + ] + ) + + +def _create_agencies_sync_state(): + table = op.create_table( + 'agencies_sync_state', + id_column(), + sa.Column('last_full_sync_at', sa.DateTime(), nullable=True), + sa.Column('current_cutoff_date', sa.Date(), nullable=True), + sa.Column('current_page', sa.Integer(), nullable=True), + ) + + # Add row to `agencies_sync_state` table + op.bulk_insert( + table, + [ + { + "last_full_sync_at": None, + "current_cutoff_date": None, + "current_page": None + } + ] + ) + + +def downgrade() -> None: + _add_record_type_column() + _migrate_url_record_types_from_url_record_type_table() + _drop_url_record_type_table() + _create_agencies_sync_state() + _create_data_sources_sync_state() + +def _drop_record_type_column(): + op.drop_column("urls", "record_type") + +def _add_record_type_column(): + op.add_column("urls", sa.Column("record_type", postgresql.ENUM(name="record_type", create_type=False), nullable=True)) + + +def _create_url_record_type_table(): + op.create_table( + URL_RECORD_TYPE_TABLE_NAME, + url_id_column(primary_key=True), + sa.Column("record_type", postgresql.ENUM(name="record_type", create_type=False), nullable=False), + created_at_column() + ) + + +def _drop_url_record_type_table(): + op.drop_table(URL_RECORD_TYPE_TABLE_NAME) + + +def _migrate_url_record_types_from_url_record_type_table(): + op.execute(""" + UPDATE urls + SET record_type = url_record_type.record_type + FROM url_record_type + WHERE urls.id = url_record_type.url_id + """) + + +def _migrate_url_record_types_to_url_record_type_table(): + op.execute(""" + INSERT INTO url_record_type (url_id, record_type) + SELECT id, record_type + FROM urls + WHERE record_type IS NOT NULL + """) diff --git a/alembic/versions/2025_09_24_1739-3687026267fc_add_url_naming_logic.py b/alembic/versions/2025_09_24_1739-3687026267fc_add_url_naming_logic.py new file mode 100644 index 00000000..9e6a3821 --- /dev/null +++ b/alembic/versions/2025_09_24_1739-3687026267fc_add_url_naming_logic.py @@ -0,0 +1,69 @@ +"""Add URL naming logic + +Revision ID: 3687026267fc +Revises: e6a1a1b3bad4 +Create Date: 2025-09-24 17:39:55.353947 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import id_column, url_id_column, created_at_column, user_id_column + +# revision identifiers, used by Alembic. 
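+# Note (editorial): url_name_suggestions stores one row per distinct
+# (url_id, suggestion) pair, while link_user_name_suggestions records which
+# users endorsed a given suggestion.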
+revision: str = '3687026267fc'
+down_revision: Union[str, None] = 'e6a1a1b3bad4'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    _add_auto_name_task()
+    _create_url_name_suggestion_table()
+    _create_link_user_name_suggestion_table()
+
+
+def downgrade() -> None:
+    pass
+
+
+def _add_auto_name_task():
+    op.execute("""ALTER TYPE task_type ADD VALUE 'Auto Name';""")
+
+
+def _create_url_name_suggestion_table():
+    op.create_table(
+        'url_name_suggestions',
+        id_column(),
+        url_id_column(),
+        sa.Column('suggestion', sa.String(length=100), nullable=False),
+        sa.Column(
+            'source', sa.Enum(
+                "HTML Metadata Title",
+                "User",
+                name="suggestion_source_enum"
+            )
+        ),
+        created_at_column(),
+        sa.UniqueConstraint(
+            'url_id', 'suggestion', name='url_name_suggestions_url_id_suggestion_unique'
+        )
+    )
+
+
+def _create_link_user_name_suggestion_table():
+    op.create_table(
+        'link_user_name_suggestions',
+        user_id_column(),
+        sa.Column(
+            "suggestion_id",
+            sa.Integer(),
+            sa.ForeignKey("url_name_suggestions.id"),
+            nullable=False,
+        ),
+        created_at_column(),
+        sa.PrimaryKeyConstraint(
+            "user_id",
+            "suggestion_id"
+        )
+    )
\ No newline at end of file
diff --git a/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py b/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py
new file mode 100644
index 00000000..e27633fe
--- /dev/null
+++ b/alembic/versions/2025_09_26_0718-7b955c783e27_add_dependent_locations.py
@@ -0,0 +1,56 @@
+"""Add dependent locations
+
+Revision ID: 7b955c783e27
+Revises: 3687026267fc
+Create Date: 2025-09-26 07:18:37.916841
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
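+# Usage sketch (editorial): the dependent_locations view created below maps
+# every parent location to each location it contains, so expanding a state
+# into its counties and localities becomes a single lookup:
+#
+#   SELECT dependent_location_id
+#   FROM dependent_locations
+#   WHERE parent_location_id = 5;  -- hypothetical state location id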
+revision: str = '7b955c783e27' +down_revision: Union[str, None] = '3687026267fc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + create view dependent_locations(parent_location_id, dependent_location_id) as + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'County'::location_type AND lp.type = 'State'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.county_id = lp.county_id AND ld.type = 'Locality'::location_type AND lp.type = 'County'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'Locality'::location_type AND lp.type = 'State'::location_type + UNION ALL + SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id + FROM + locations lp + JOIN locations ld ON lp.type = 'National'::location_type AND (ld.type = ANY + (ARRAY ['State'::location_type, 'County'::location_type, 'Locality'::location_type])) + """) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_26_1357-b9317c6836e7_add_agency_and_jurisdiction_type.py b/alembic/versions/2025_09_26_1357-b9317c6836e7_add_agency_and_jurisdiction_type.py new file mode 100644 index 00000000..7d917fbf --- /dev/null +++ b/alembic/versions/2025_09_26_1357-b9317c6836e7_add_agency_and_jurisdiction_type.py @@ -0,0 +1,67 @@ +"""Add agency and jurisdiction type + +Revision ID: b9317c6836e7 +Revises: 7b955c783e27 +Create Date: 2025-09-26 13:57:42.357788 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'b9317c6836e7' +down_revision: Union[str, None] = '7b955c783e27' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _add_agency_type_column(): + agency_type_enum = sa.Enum( + "unknown", + "incarceration", + "law enforcement", + "court", + "aggregated", + name="agency_type_enum", + create_type=True, + ) + agency_type_enum.create(op.get_bind()) + + op.add_column( + table_name="agencies", + column=sa.Column( + "agency_type", + agency_type_enum, + server_default="unknown", + nullable=False, + ) + ) + + +def _add_jurisdiction_type_column(): + jurisdiction_type_enum = sa.Enum( + 'school', 'county', 'local', 'port', 'tribal', 'transit', 'state', 'federal', + name="jurisdiction_type_enum", + ) + jurisdiction_type_enum.create(op.get_bind()) + + op.add_column( + table_name="agencies", + column=sa.Column( + "jurisdiction_type", + jurisdiction_type_enum, + nullable=True, + ) + ) + + +def upgrade() -> None: + _add_agency_type_column() + _add_jurisdiction_type_column() + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_26_1751-d4c63e23d3f0_update_locations_expanded_view.py b/alembic/versions/2025_09_26_1751-d4c63e23d3f0_update_locations_expanded_view.py new file mode 100644 index 00000000..871e54b9 --- /dev/null +++ b/alembic/versions/2025_09_26_1751-d4c63e23d3f0_update_locations_expanded_view.py @@ -0,0 +1,85 @@ +"""Update locations expanded view + +Revision ID: d4c63e23d3f0 +Revises: b9317c6836e7 +Create Date: 2025-09-26 17:51:41.214287 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ENUM + +from src.util.alembic_helpers import id_column, location_id_column, created_at_column + +# revision identifiers, used by Alembic. 
+revision: str = 'd4c63e23d3f0' +down_revision: Union[str, None] = 'b9317c6836e7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def _update_locations_expanded_view(): + op.execute( + """ + CREATE OR REPLACE VIEW locations_expanded as + SELECT locations.id, + locations.type, + us_states.state_name, + us_states.state_iso, + counties.name AS county_name, + counties.fips AS county_fips, + localities.name AS locality_name, + localities.id AS locality_id, + us_states.id AS state_id, + counties.id AS county_id, + CASE + WHEN locations.type = 'Locality'::location_type THEN localities.name + WHEN locations.type = 'County'::location_type THEN counties.name::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + WHEN locations.type = 'National'::location_type THEN 'United States' + ELSE NULL::character varying + END AS display_name, + CASE + WHEN locations.type = 'Locality'::location_type THEN concat(localities.name, ', ', counties.name, + ', ', + us_states.state_name)::character varying + WHEN locations.type = 'County'::location_type + THEN concat(counties.name, ', ', us_states.state_name)::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + WHEN locations.type = 'National'::location_type THEN 'United States' + ELSE NULL::character varying + END AS full_display_name + FROM locations + LEFT JOIN us_states + ON locations.state_id = us_states.id + LEFT JOIN counties + ON locations.county_id = counties.id + LEFT JOIN localities + ON locations.locality_id = localities.id + """ + ) + + +def _create_new_agency_suggestion_table(): + op.create_table( + 'new_agency_suggestions', + id_column(), + location_id_column(), + sa.Column('name', sa.String()), + sa.Column('jurisdiction_type', ENUM(name='jurisdiction_type_enum', create_type=False), nullable=True), + sa.Column('agency_type', ENUM(name='agency_type_enum', create_type=False), nullable=True), + created_at_column() + ) + + +def upgrade() -> None: + _update_locations_expanded_view() + _create_new_agency_suggestion_table() + + + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_26_2002-50a710e413f8_add_suggestion_url_link_table.py b/alembic/versions/2025_09_26_2002-50a710e413f8_add_suggestion_url_link_table.py new file mode 100644 index 00000000..0c55aad5 --- /dev/null +++ b/alembic/versions/2025_09_26_2002-50a710e413f8_add_suggestion_url_link_table.py @@ -0,0 +1,39 @@ +"""Add new agency suggestion url link table + +Revision ID: 50a710e413f8 +Revises: d4c63e23d3f0 +Create Date: 2025-09-26 20:02:10.867728 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import url_id_column, agency_id_column, created_at_column + +# revision identifiers, used by Alembic. 
+revision: str = '50a710e413f8' +down_revision: Union[str, None] = 'd4c63e23d3f0' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'link_url_new_agency_suggestion', + url_id_column(), + sa.Column( + 'suggestion_id', + sa.Integer, + sa.ForeignKey('new_agency_suggestions.id'), nullable=False + ), + created_at_column(), + sa.PrimaryKeyConstraint( + 'url_id', 'suggestion_id' + ) + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_29_1246-5be534715a01_add_agency_location_not_found_logic.py b/alembic/versions/2025_09_29_1246-5be534715a01_add_agency_location_not_found_logic.py new file mode 100644 index 00000000..171adcbe --- /dev/null +++ b/alembic/versions/2025_09_29_1246-5be534715a01_add_agency_location_not_found_logic.py @@ -0,0 +1,74 @@ +"""Add Agency/Location Not Found Logic + +Revision ID: 5be534715a01 +Revises: 50a710e413f8 +Create Date: 2025-09-29 12:46:27.140173 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import created_at_column, url_id_column, user_id_column + +# revision identifiers, used by Alembic. +revision: str = '5be534715a01' +down_revision: Union[str, None] = '50a710e413f8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +def upgrade() -> None: + add_link_user_suggestion_agency_not_found_table() + add_link_user_suggestion_location_not_found_table() + add_flag_url_suspended_table() + add_url_suspend_task_type() + remove_link_url_new_agency_suggestion_table() + remove_new_agency_suggestions_table() + +def add_url_suspend_task_type(): + op.execute( + """ + ALTER TYPE task_type ADD VALUE 'Suspend URLs'; + """ + ) + +def add_link_user_suggestion_agency_not_found_table(): + op.create_table( + "link_user_suggestion_agency_not_found", + user_id_column(), + url_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint("user_id", "url_id"), + ) + + +def add_link_user_suggestion_location_not_found_table(): + op.create_table( + "link_user_suggestion_location_not_found", + user_id_column(), + url_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint("user_id", "url_id"), + ) + + +def add_flag_url_suspended_table(): + op.create_table( + "flag_url_suspended", + url_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint("url_id"), + ) + + +def remove_link_url_new_agency_suggestion_table(): + op.drop_table("link_url_new_agency_suggestion") + + +def remove_new_agency_suggestions_table(): + op.drop_table("new_agency_suggestions") + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_30_1046-84a3de626ad8_add_link_user_submitted_url_table.py b/alembic/versions/2025_09_30_1046-84a3de626ad8_add_link_user_submitted_url_table.py new file mode 100644 index 00000000..fe7d9309 --- /dev/null +++ b/alembic/versions/2025_09_30_1046-84a3de626ad8_add_link_user_submitted_url_table.py @@ -0,0 +1,34 @@ +"""Add link user submitted URL table + +Revision ID: 84a3de626ad8 +Revises: 5be534715a01 +Create Date: 2025-09-30 10:46:16.552174 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +from src.util.alembic_helpers import url_id_column, user_id_column, created_at_column + +# revision identifiers, used by Alembic. 
+revision: str = '84a3de626ad8' +down_revision: Union[str, None] = '5be534715a01' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "link_user_submitted_urls", + url_id_column(), + user_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint("url_id", "user_id"), + sa.UniqueConstraint("url_id") + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_09_30_1613-241fd3925f5d_add_logic_for_meta_url_submissions.py b/alembic/versions/2025_09_30_1613-241fd3925f5d_add_logic_for_meta_url_submissions.py new file mode 100644 index 00000000..fb30fba2 --- /dev/null +++ b/alembic/versions/2025_09_30_1613-241fd3925f5d_add_logic_for_meta_url_submissions.py @@ -0,0 +1,63 @@ +"""Add logic for meta URL submissions + +Revision ID: 241fd3925f5d +Revises: 84a3de626ad8 +Create Date: 2025-09-30 16:13:03.980113 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +from src.util.alembic_helpers import url_id_column, created_at_column, agency_id_column + +# revision identifiers, used by Alembic. +revision: str = '241fd3925f5d' +down_revision: Union[str, None] = '84a3de626ad8' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute("""ALTER TYPE task_type ADD VALUE 'Submit Meta URLs'""") + op.create_table( + "url_ds_meta_url", + url_id_column(), + agency_id_column(), + sa.Column("ds_meta_url_id", sa.Integer(), nullable=False), + created_at_column(), + sa.PrimaryKeyConstraint( + "url_id", + "agency_id" + ), + sa.UniqueConstraint( + "ds_meta_url_id" + ) + ) + op.execute("""ALTER TYPE task_type ADD VALUE 'Delete Stale Screenshots'""") + op.execute("""ALTER TYPE task_type ADD VALUE 'Mark Task Never Completed'""") + op.execute(""" + CREATE TYPE task_status_enum as ENUM( + 'complete', + 'in-process', + 'error', + 'aborted', + 'never-completed' + ) + """) + op.execute(""" + ALTER TABLE tasks + ALTER COLUMN task_status DROP DEFAULT, + ALTER COLUMN task_status TYPE task_status_enum + USING ( + CASE task_status::text -- old enum -> text + WHEN 'ready to label' THEN 'complete'::task_status_enum + ELSE task_status::text::task_status_enum + END + ); + """) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_03_1546-c5c20af87511_add_task_cleanup_task.py b/alembic/versions/2025_10_03_1546-c5c20af87511_add_task_cleanup_task.py new file mode 100644 index 00000000..39a1004f --- /dev/null +++ b/alembic/versions/2025_10_03_1546-c5c20af87511_add_task_cleanup_task.py @@ -0,0 +1,28 @@ +"""Add task cleanup task + +Revision ID: c5c20af87511 +Revises: 241fd3925f5d +Create Date: 2025-10-03 15:46:00.212674 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'c5c20af87511' +down_revision: Union[str, None] = '241fd3925f5d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + ALTER TYPE task_type ADD VALUE 'Task Cleanup' + """) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_03_1831-dc6ab5157c49_add_url_task_error_table_and_remove_url_.py b/alembic/versions/2025_10_03_1831-dc6ab5157c49_add_url_task_error_table_and_remove_url_.py new file mode 100644 index 00000000..e6a4e93d --- /dev/null +++ b/alembic/versions/2025_10_03_1831-dc6ab5157c49_add_url_task_error_table_and_remove_url_.py @@ -0,0 +1,54 @@ +"""Add url_task_error table and remove url_error_info + +Revision ID: dc6ab5157c49 +Revises: c5c20af87511 +Create Date: 2025-10-03 18:31:54.887740 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ENUM + +from src.util.alembic_helpers import url_id_column, task_id_column, created_at_column + +# revision identifiers, used by Alembic. +revision: str = 'dc6ab5157c49' +down_revision: Union[str, None] = 'c5c20af87511' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + + + + +def upgrade() -> None: + _remove_url_error_info() + _remove_url_screenshot_error() + _add_url_task_error() + +def _remove_url_error_info(): + op.drop_table("url_error_info") + +def _remove_url_screenshot_error(): + op.drop_table("error_url_screenshot") + +def _add_url_task_error(): + op.create_table( + "url_task_error", + url_id_column(), + task_id_column(), + sa.Column( + "task_type", + ENUM(name="task_type", create_type=False) + ), + sa.Column("error", sa.String(), nullable=False), + created_at_column(), + sa.PrimaryKeyConstraint("url_id", "task_type") + ) + + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_04_1541-445d8858b23a_remove_agency_location_columns.py b/alembic/versions/2025_10_04_1541-445d8858b23a_remove_agency_location_columns.py new file mode 100644 index 00000000..c7d98156 --- /dev/null +++ b/alembic/versions/2025_10_04_1541-445d8858b23a_remove_agency_location_columns.py @@ -0,0 +1,29 @@ +"""Remove agency location columns + +Revision ID: 445d8858b23a +Revises: dc6ab5157c49 +Create Date: 2025-10-04 15:41:52.384222 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '445d8858b23a' +down_revision: Union[str, None] = 'dc6ab5157c49' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +TABLE_NAME = 'agencies' + +def upgrade() -> None: + op.drop_column(TABLE_NAME, 'locality') + op.drop_column(TABLE_NAME, 'state') + op.drop_column(TABLE_NAME, 'county') + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_04_1640-f708c6a8ae5d_remove_unused_batches_columns.py b/alembic/versions/2025_10_04_1640-f708c6a8ae5d_remove_unused_batches_columns.py new file mode 100644 index 00000000..83d8c441 --- /dev/null +++ b/alembic/versions/2025_10_04_1640-f708c6a8ae5d_remove_unused_batches_columns.py @@ -0,0 +1,31 @@ +"""Remove unused batches columns + +Revision ID: f708c6a8ae5d +Revises: 445d8858b23a +Create Date: 2025-10-04 16:40:11.064794 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision: str = 'f708c6a8ae5d' +down_revision: Union[str, None] = '445d8858b23a' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +TABLE_NAME = "batches" + +def upgrade() -> None: + op.drop_column(TABLE_NAME, "strategy_success_rate") + op.drop_column(TABLE_NAME, "metadata_success_rate") + op.drop_column(TABLE_NAME, "agency_match_rate") + op.drop_column(TABLE_NAME, "record_type_match_rate") + op.drop_column(TABLE_NAME, "record_category_match_rate") + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_05_0757-dff1085d1c3d_add_url_task_count_views.py b/alembic/versions/2025_10_05_0757-dff1085d1c3d_add_url_task_count_views.py new file mode 100644 index 00000000..0c60096c --- /dev/null +++ b/alembic/versions/2025_10_05_0757-dff1085d1c3d_add_url_task_count_views.py @@ -0,0 +1,60 @@ +"""Add URL Task Count Views + +Revision ID: dff1085d1c3d +Revises: f708c6a8ae5d +Create Date: 2025-10-05 07:57:09.333844 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'dff1085d1c3d' +down_revision: Union[str, None] = 'f708c6a8ae5d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + CREATE VIEW URL_TASK_COUNT_1_WEEK AS + ( + select + t.task_type, + count(ltu.url_id) + from + tasks t + join link_task_urls ltu + on ltu.task_id = t.id + where + t.updated_at > (now() - INTERVAL '1 week') + group by + t.task_type + ) + + """) + + op.execute(""" + CREATE VIEW URL_TASK_COUNT_1_DAY AS + ( + select + t.task_type, + count(ltu.url_id) + from + tasks t + join link_task_urls ltu + on ltu.task_id = t.id + where + t.updated_at > (now() - INTERVAL '1 day') + group by + t.task_type + ) + + """) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_09_2046-7c4049508bfc_add_link_tables_for_location_batch_and_.py b/alembic/versions/2025_10_09_2046-7c4049508bfc_add_link_tables_for_location_batch_and_.py new file mode 100644 index 00000000..8972c0d0 --- /dev/null +++ b/alembic/versions/2025_10_09_2046-7c4049508bfc_add_link_tables_for_location_batch_and_.py @@ -0,0 +1,58 @@ +"""Add link tables for location_batch and agency_batch + +Revision ID: 7c4049508bfc +Revises: dff1085d1c3d +Create Date: 2025-10-09 20:46:30.013715 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import batch_id_column, location_id_column, created_at_column, agency_id_column + +# revision identifiers, used by Alembic. 
+revision: str = '7c4049508bfc' +down_revision: Union[str, None] = 'dff1085d1c3d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + + + + +def upgrade() -> None: + _create_link_location_batches_table() + _create_link_agency_batches_table() + +def _create_link_location_batches_table(): + op.create_table( + "link_location_batches", + batch_id_column(), + location_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint( + 'batch_id', + 'location_id', + name='link_location_batches_pk' + ) + ) + + +def _create_link_agency_batches_table(): + op.create_table( + "link_agency_batches", + batch_id_column(), + agency_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint( + 'batch_id', + 'agency_id', + name='link_agency_batches_pk' + ) + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_11_1438-8b2adc95c5d7_add_batch_link_subtasks.py b/alembic/versions/2025_10_11_1438-8b2adc95c5d7_add_batch_link_subtasks.py new file mode 100644 index 00000000..49fd2354 --- /dev/null +++ b/alembic/versions/2025_10_11_1438-8b2adc95c5d7_add_batch_link_subtasks.py @@ -0,0 +1,34 @@ +"""Add batch link subtasks + +Revision ID: 8b2adc95c5d7 +Revises: 7c4049508bfc +Create Date: 2025-10-11 14:38:01.874040 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import add_enum_value + +# revision identifiers, used by Alembic. +revision: str = '8b2adc95c5d7' +down_revision: Union[str, None] = '7c4049508bfc' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + add_enum_value( + enum_name="agency_auto_suggestion_method", + enum_value="batch_link" + ) + add_enum_value( + enum_name="auto_location_id_subtask_type", + enum_value="batch_link" + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_11_1913-25b3fc777c31_add_url_status_view.py b/alembic/versions/2025_10_11_1913-25b3fc777c31_add_url_status_view.py new file mode 100644 index 00000000..e620828a --- /dev/null +++ b/alembic/versions/2025_10_11_1913-25b3fc777c31_add_url_status_view.py @@ -0,0 +1,88 @@ +"""Add URL status view + +Revision ID: 25b3fc777c31 +Revises: 8b2adc95c5d7 +Create Date: 2025-10-11 19:13:03.309461 + +""" +from typing import Sequence, Union + +from alembic import op + +from src.util.alembic_helpers import add_enum_value + +# revision identifiers, used by Alembic. 
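+# Unlike the plain URL_TASK_COUNT views added earlier, the materialized view
+# created below is a stored snapshot and does not update on its own. The
+# 'Refresh Materialized Views' task type added at the end of this migration
+# presumably drives a scheduled job that runs:
+#
+#     REFRESH MATERIALIZED VIEW url_status_mat_view;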
+revision: str = '25b3fc777c31' +down_revision: Union[str, None] = '8b2adc95c5d7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + CREATE MATERIALIZED VIEW url_status_mat_view AS + with + urls_with_relevant_errors as ( + select + ute.url_id + from + url_task_error ute + where + ute.task_type in ( + 'Screenshot', + 'HTML', + 'URL Probe' + ) + ) + select + u.id as url_id, + case + when ( + -- Validated as not relevant, individual record, or not found + fuv.type in ('not relevant', 'individual record', 'not found') + -- Has Meta URL in data sources app + OR udmu.url_id is not null + -- Has data source in data sources app + OR uds.url_id is not null + ) Then 'Submitted/Pipeline Complete' + when fuv.type is not null THEN 'Accepted' + when ( + -- Has compressed HTML + uch.url_id is not null + AND + -- Has web metadata + uwm.url_id is not null + AND + -- Has screenshot + us.url_id is not null + ) THEN 'Community Labeling' + when uwre.url_id is not null then 'Error' + ELSE 'Intake' + END as status + + from + urls u + left join urls_with_relevant_errors uwre + on u.id = uwre.url_id + left join url_screenshot us + on u.id = us.url_id + left join url_compressed_html uch + on u.id = uch.url_id + left join url_web_metadata uwm + on u.id = uwm.url_id + left join flag_url_validated fuv + on u.id = fuv.url_id + left join url_ds_meta_url udmu + on u.id = udmu.url_id + left join url_data_source uds + on u.id = uds.url_id + """) + + add_enum_value( + enum_name="task_type", + enum_value="Refresh Materialized Views" + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_12_1549-d55ec2987702_remove_404_probe_task.py b/alembic/versions/2025_10_12_1549-d55ec2987702_remove_404_probe_task.py new file mode 100644 index 00000000..26fb9d0e --- /dev/null +++ b/alembic/versions/2025_10_12_1549-d55ec2987702_remove_404_probe_task.py @@ -0,0 +1,157 @@ +"""Remove 404 Probe Task + +Revision ID: d55ec2987702 +Revises: 25b3fc777c31 +Create Date: 2025-10-12 15:49:01.945412 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import remove_enum_value, add_enum_value + +# revision identifiers, used by Alembic. 
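+# Postgres has no ALTER TYPE ... DROP VALUE, so remove_enum_value presumably
+# rebuilds the type and re-binds the listed (table, column) targets, roughly:
+#
+#     ALTER TYPE task_type RENAME TO task_type_old;
+#     CREATE TYPE task_type AS ENUM ('...remaining values...');
+#     ALTER TABLE tasks ALTER COLUMN task_type
+#         TYPE task_type USING task_type::text::task_type;
+#     DROP TYPE task_type_old;
+#
+# Hence, below, the dependent views are dropped and rows referencing
+# '404 Probe' are deleted before the enum values are removed.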
+revision: str = 'd55ec2987702' +down_revision: Union[str, None] = '25b3fc777c31' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + _drop_views() + add_enum_value( + enum_name="url_type", + enum_value="broken page" + ) + + op.execute( + """DELETE FROM TASKS WHERE task_type = '404 Probe'""" + ) + op.execute( + """DELETE FROM url_task_error WHERE task_type = '404 Probe'""" + ) + remove_enum_value( + enum_name="task_type", + value_to_remove="404 Probe", + targets=[ + ("tasks", "task_type"), + ("url_task_error", "task_type") + ] + ) + op.execute( + """UPDATE URLS SET status = 'ok' WHERE status = '404 not found'""" + ) + remove_enum_value( + enum_name="url_status", + value_to_remove="404 not found", + targets=[ + ("urls", "status") + ] + ) + + op.drop_table("url_probed_for_404") + + _recreate_views() + +def _drop_views(): + op.execute("drop view url_task_count_1_day") + op.execute("drop view url_task_count_1_week") + op.execute("drop materialized view url_status_mat_view") + +def _recreate_views(): + op.execute(""" + create view url_task_count_1_day(task_type, count) as + SELECT + t.task_type, + count(ltu.url_id) AS count + FROM + tasks t + JOIN link_task_urls ltu + ON ltu.task_id = t.id + WHERE + t.updated_at > (now() - '1 day'::interval) + GROUP BY + t.task_type; + """) + + op.execute(""" + create view url_task_count_1_week(task_type, count) as + SELECT + t.task_type, + count(ltu.url_id) AS count + FROM + tasks t + JOIN link_task_urls ltu + ON ltu.task_id = t.id + WHERE + t.updated_at > (now() - '7 days'::interval) + GROUP BY + t.task_type; + """) + + op.execute( + """ + CREATE MATERIALIZED VIEW url_status_mat_view AS + with + urls_with_relevant_errors as ( + select + ute.url_id + from + url_task_error ute + where + ute.task_type in ( + 'Screenshot', + 'HTML', + 'URL Probe' + ) + ) + select + u.id as url_id, + case + when ( + -- Validated as not relevant, individual record, or not found + fuv.type in ('not relevant', 'individual record', 'not found') + -- Has Meta URL in data sources app + OR udmu.url_id is not null + -- Has data source in data sources app + OR uds.url_id is not null + ) Then 'Submitted/Pipeline Complete' + when fuv.type is not null THEN 'Accepted' + when ( + -- Has compressed HTML + uch.url_id is not null + AND + -- Has web metadata + uwm.url_id is not null + AND + -- Has screenshot + us.url_id is not null + ) THEN 'Community Labeling' + when uwre.url_id is not null then 'Error' + ELSE 'Intake' + END as status + + from + urls u + left join urls_with_relevant_errors uwre + on u.id = uwre.url_id + left join url_screenshot us + on u.id = us.url_id + left join url_compressed_html uch + on u.id = uch.url_id + left join url_web_metadata uwm + on u.id = uwm.url_id + left join flag_url_validated fuv + on u.id = fuv.url_id + left join url_ds_meta_url udmu + on u.id = udmu.url_id + left join url_data_source uds + on u.id = uds.url_id + """ + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_12_1828-51bde16e22f7_add_batch_url_status_materialized_view.py b/alembic/versions/2025_10_12_1828-51bde16e22f7_add_batch_url_status_materialized_view.py new file mode 100644 index 00000000..8a3524e8 --- /dev/null +++ b/alembic/versions/2025_10_12_1828-51bde16e22f7_add_batch_url_status_materialized_view.py @@ -0,0 +1,87 @@ +"""Add Batch URL Status materialized view + +Revision ID: 51bde16e22f7 +Revises: d55ec2987702 +Create Date: 2025-10-12 18:28:28.602086 + +""" +from typing import 
import
Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '51bde16e22f7' +down_revision: Union[str, None] = 'd55ec2987702' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute(""" + CREATE MATERIALIZED VIEW batch_url_status_mat_view as ( + with + batches_with_urls as ( + select + b.id + from + batches b + where + exists( + select + 1 + from + link_batch_urls lbu + where + lbu.batch_id = b.id + ) + ) + , batches_with_only_validated_urls as ( + select + b.id + from + batches b + where + exists( + select + 1 + from + link_batch_urls lbu + left join flag_url_validated fuv on fuv.url_id = lbu.url_id + where + lbu.batch_id = b.id + and fuv.id is not null + ) + and not exists( + select + 1 + from + link_batch_urls lbu + left join flag_url_validated fuv on fuv.url_id = lbu.url_id + where + lbu.batch_id = b.id + and fuv.id is null + ) + ) + + select + b.id as batch_id, + case + when b.status = 'error' THEN 'Error' + when (bwu.id is null) THEN 'No URLs' + when (bwovu.id is not null) THEN 'Labeling Complete' + else 'Has Unlabeled URLs' + end as batch_url_status + from + batches b + left join batches_with_urls bwu + on bwu.id = b.id + left join batches_with_only_validated_urls bwovu + on bwovu.id = b.id + ) + """) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_12_2036-43077d7e08c5_eliminate_contact_info_and_agency_meta_.py b/alembic/versions/2025_10_12_2036-43077d7e08c5_eliminate_contact_info_and_agency_meta_.py new file mode 100644 index 00000000..e5a2513f --- /dev/null +++ b/alembic/versions/2025_10_12_2036-43077d7e08c5_eliminate_contact_info_and_agency_meta_.py @@ -0,0 +1,45 @@ +"""Eliminate Contact Info and Agency Meta Record Type + +Revision ID: 43077d7e08c5 +Revises: 51bde16e22f7 +Create Date: 2025-10-12 20:36:17.965218 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import remove_enum_value + +# revision identifiers, used by Alembic. 
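+# As with the '404 Probe' removal earlier in this changeset, rows holding the
+# retired value are deleted from every target table before remove_enum_value
+# runs; the enum rebuild would otherwise fail while 'Contact Info & Agency
+# Meta' is still referenced.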
+revision: str = '43077d7e08c5' +down_revision: Union[str, None] = '51bde16e22f7' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.execute( + """DELETE FROM URL_RECORD_TYPE WHERE RECORD_TYPE = 'Contact Info & Agency Meta'""" + ) + op.execute( + """DELETE FROM auto_record_type_suggestions WHERE record_type = 'Contact Info & Agency Meta'""" + ) + op.execute( + """DELETE FROM user_record_type_suggestions WHERE record_type = 'Contact Info & Agency Meta'""" + ) + + remove_enum_value( + enum_name="record_type", + value_to_remove="Contact Info & Agency Meta", + targets=[ + ("url_record_type", "record_type"), + ("auto_record_type_suggestions", "record_type"), + ("user_record_type_suggestions", "record_type") + ] + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/2025_10_13_2007-7aace6587d1a_add_anonymous_annotation_tables.py b/alembic/versions/2025_10_13_2007-7aace6587d1a_add_anonymous_annotation_tables.py new file mode 100644 index 00000000..18cf4230 --- /dev/null +++ b/alembic/versions/2025_10_13_2007-7aace6587d1a_add_anonymous_annotation_tables.py @@ -0,0 +1,60 @@ +"""Add anonymous annotation tables + +Revision ID: 7aace6587d1a +Revises: 43077d7e08c5 +Create Date: 2025-10-13 20:07:18.388899 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + +from src.util.alembic_helpers import url_id_column, agency_id_column, created_at_column, location_id_column, enum_column + +# revision identifiers, used by Alembic. +revision: str = '7aace6587d1a' +down_revision: Union[str, None] = '43077d7e08c5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "anonymous_annotation_agency", + url_id_column(), + agency_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint('url_id', 'agency_id') + ) + op.create_table( + "anonymous_annotation_location", + url_id_column(), + location_id_column(), + created_at_column(), + sa.PrimaryKeyConstraint('url_id', 'location_id') + ) + op.create_table( + "anonymous_annotation_record_type", + url_id_column(), + enum_column( + column_name="record_type", + enum_name="record_type" + ), + created_at_column(), + sa.PrimaryKeyConstraint('url_id', 'record_type') + ) + op.create_table( + "anonymous_annotation_url_type", + url_id_column(), + enum_column( + column_name="url_type", + enum_name="url_type" + ), + created_at_column(), + sa.PrimaryKeyConstraint('url_id', 'url_type') + ) + + +def downgrade() -> None: + pass diff --git a/alembic/versions/d7eb670edaf0_revise_agency_identification_logic.py b/alembic/versions/d7eb670edaf0_revise_agency_identification_logic.py index cd68a4b5..6ba6f7c9 100644 --- a/alembic/versions/d7eb670edaf0_revise_agency_identification_logic.py +++ b/alembic/versions/d7eb670edaf0_revise_agency_identification_logic.py @@ -118,7 +118,7 @@ def upgrade(): def downgrade(): # Drop constraints first op.drop_constraint("uq_confirmed_url_agency", "confirmed_url_agency", type_="unique") - op.drop_constraint("uq_automated_url_agency_suggestions", "automated_url_agency_suggestions", type_="unique") + # op.drop_constraint("uq_automated_url_agency_suggestions", "automated_url_agency_suggestions", type_="unique") op.drop_constraint("uq_user_url_agency_suggestions", "user_url_agency_suggestions", type_="unique") # Drop tables diff --git a/apply_migrations.py b/apply_migrations.py index 6b3188f3..2b217c8b 100644 --- 
a/apply_migrations.py +++ b/apply_migrations.py @@ -1,7 +1,8 @@ from alembic import command from alembic.config import Config -from src.db.helpers import get_postgres_connection_string +from src.db.helpers.connect import get_postgres_connection_string + def apply_migrations(): print("Applying migrations...") diff --git a/local_database/DataDumper/dump.sh b/local_database/DataDumper/dump.sh index 482a3ca1..6d7fa669 100644 --- a/local_database/DataDumper/dump.sh +++ b/local_database/DataDumper/dump.sh @@ -23,6 +23,7 @@ else fi # Run pg_dump -pg_dump -h $DB_HOST -p $DB_PORT -U $DB_USER -d $DB_NAME $PG_DUMP_FLAGS -f $DUMP_FILE +echo "(Excluding url_screenshot table data)" +pg_dump -h $DB_HOST -p $DB_PORT -U $DB_USER -d $DB_NAME $PG_DUMP_FLAGS -f $DUMP_FILE --exclude-table-data=url_screenshot echo "Dump completed. File saved to $DUMP_FILE." diff --git a/local_database/DockerInfos.py b/local_database/DockerInfos.py index 654b59bc..4d1d2a8f 100644 --- a/local_database/DockerInfos.py +++ b/local_database/DockerInfos.py @@ -28,7 +28,7 @@ def get_database_docker_info() -> DockerInfo: def get_source_collector_data_dumper_info() -> DockerInfo: return DockerInfo( dockerfile_info=DockerfileInfo( - image_tag="datadumper", + image_tag="datadumper_sc", dockerfile_directory=str(project_path( "local_database", "DataDumper" @@ -42,7 +42,7 @@ def get_source_collector_data_dumper_info() -> DockerInfo: )), container_path="/dump" ), - name="datadumper", + name="datadumper_sc", environment={ "DUMP_HOST": get_from_env("DUMP_HOST"), "DUMP_USER": get_from_env("DUMP_USER"), diff --git a/local_database/classes/DockerClient.py b/local_database/classes/DockerClient.py index ca9d535b..5c33e7d9 100644 --- a/local_database/classes/DockerClient.py +++ b/local_database/classes/DockerClient.py @@ -1,5 +1,7 @@ import docker from docker.errors import NotFound, APIError +from docker.models.containers import Container +from docker.models.networks import Network from local_database.DTOs import DockerfileInfo, DockerInfo @@ -9,7 +11,7 @@ class DockerClient: def __init__(self): self.client = docker.from_env() - def run_command(self, command: str, container_id: str): + def run_command(self, command: str, container_id: str) -> None: exec_id = self.client.api.exec_create( container_id, cmd=command, @@ -20,7 +22,7 @@ def run_command(self, command: str, container_id: str): for line in output_stream: print(line.decode().rstrip()) - def start_network(self, network_name): + def start_network(self, network_name) -> Network: try: self.client.networks.create(network_name, driver="bridge") except APIError as e: @@ -30,14 +32,14 @@ def start_network(self, network_name): print("Network already exists") return self.client.networks.get(network_name) - def stop_network(self, network_name): + def stop_network(self, network_name) -> None: self.client.networks.get(network_name).remove() def get_image( self, dockerfile_info: DockerfileInfo, force_rebuild: bool = False - ): + ) -> None: if dockerfile_info.dockerfile_directory: # Build image from Dockerfile self.client.images.build( @@ -58,7 +60,7 @@ def get_image( except NotFound: self.client.images.pull(dockerfile_info.image_tag) - def get_existing_container(self, docker_info_name: str): + def get_existing_container(self, docker_info_name: str) -> Container | None: try: return self.client.containers.get(docker_info_name) except NotFound: diff --git a/local_database/classes/DockerContainer.py b/local_database/classes/DockerContainer.py index 33b71ce0..0a86e601 100644 --- 
a/local_database/classes/DockerContainer.py +++ b/local_database/classes/DockerContainer.py @@ -11,19 +11,19 @@ def __init__(self, dc: DockerClient, container: Container): self.dc = dc self.container = container - def run_command(self, command: str): + def run_command(self, command: str) -> None: self.dc.run_command(command, self.container.id) - def stop(self): + def stop(self) -> None: self.container.stop() - def log_to_file(self): + def log_to_file(self) -> None: logs = self.container.logs(stdout=True, stderr=True) container_name = self.container.name with open(f"{container_name}.log", "wb") as f: f.write(logs) - def wait_for_pg_to_be_ready(self): + def wait_for_pg_to_be_ready(self) -> None: for i in range(30): exit_code, output = self.container.exec_run("pg_isready") print(output) diff --git a/local_database/classes/DockerManager.py b/local_database/classes/DockerManager.py index ac294dc1..fc32c3bc 100644 --- a/local_database/classes/DockerManager.py +++ b/local_database/classes/DockerManager.py @@ -4,6 +4,8 @@ import docker from docker.errors import APIError +from docker.models.containers import Container +from docker.models.networks import Network from local_database.DTOs import DockerfileInfo, DockerInfo from local_database.classes.DockerClient import DockerClient @@ -20,7 +22,7 @@ def __init__(self): self.network = self.start_network() @staticmethod - def start_docker_engine(): + def start_docker_engine() -> None: system = platform.system() match system: @@ -41,7 +43,7 @@ def start_docker_engine(): sys.exit(1) @staticmethod - def is_docker_running(): + def is_docker_running() -> bool: try: client = docker.from_env() client.ping() @@ -50,16 +52,23 @@ def is_docker_running(): print(f"Docker is not running: {e}") return False - def run_command(self, command: str, container_id: str): + def run_command( + self, + command: str, + container_id: str + ) -> None: self.client.run_command(command, container_id) - def start_network(self): + def start_network(self) -> Network: return self.client.start_network(self.network_name) - def stop_network(self): + def stop_network(self) -> None: self.client.stop_network(self.network_name) - def get_image(self, dockerfile_info: DockerfileInfo): + def get_image( + self, + dockerfile_info: DockerfileInfo + ) -> None: self.client.get_image(dockerfile_info) def run_container( @@ -74,5 +83,5 @@ def run_container( ) return DockerContainer(self.client, raw_container) - def get_containers(self): + def get_containers(self) -> list[Container]: return self.client.client.containers.list() \ No newline at end of file diff --git a/local_database/classes/TimestampChecker.py b/local_database/classes/TimestampChecker.py index 56779fd4..fc2c25a0 100644 --- a/local_database/classes/TimestampChecker.py +++ b/local_database/classes/TimestampChecker.py @@ -1,27 +1,26 @@ -import datetime import os -from typing import Optional +from datetime import datetime, timedelta class TimestampChecker: def __init__(self): - self.last_run_time: Optional[datetime.datetime] = self.load_last_run_time() + self.last_run_time: datetime | None = self.load_last_run_time() - def load_last_run_time(self) -> Optional[datetime.datetime]: + def load_last_run_time(self) -> datetime | None: # Check if file `last_run.txt` exists # If it does, load the last run time if os.path.exists("local_state/last_run.txt"): with open("local_state/last_run.txt", "r") as f: - return datetime.datetime.strptime( + return datetime.strptime( f.read(), "%Y-%m-%d %H:%M:%S" ) return None - def last_run_within_24_hours(self): + 
def last_run_within_24_hours(self) -> bool: if self.last_run_time is None: return False - return datetime.datetime.now() - self.last_run_time < datetime.timedelta(days=1) + return datetime.now() - self.last_run_time < timedelta(days=1) def set_last_run_time(self): # If directory `local_state` doesn't exist, create it @@ -29,4 +28,4 @@ def set_last_run_time(self): os.makedirs("local_state") with open("local_state/last_run.txt", "w") as f: - f.write(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")) + f.write(datetime.now().strftime("%Y-%m-%d %H:%M:%S")) diff --git a/local_database/create_database.py b/local_database/create_database.py index 67eae70b..e18cbd2a 100644 --- a/local_database/create_database.py +++ b/local_database/create_database.py @@ -15,7 +15,7 @@ # Connect to the default 'postgres' database to create other databases -def connect(database="postgres", autocommit=True): +def connect(database="postgres", autocommit=True) -> psycopg2.extensions.connection: conn = psycopg2.connect( dbname=database, user=POSTGRES_USER, @@ -27,7 +27,7 @@ def connect(database="postgres", autocommit=True): conn.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) return conn -def create_database(db_name): +def create_database(db_name: str) -> None: conn = connect("postgres") with conn.cursor() as cur: cur.execute(sql.SQL(""" @@ -48,7 +48,7 @@ def create_database(db_name): except Exception as e: print(f"❌ Failed to create {db_name}: {e}") -def main(): +def main() -> None: print("Creating databases...") create_database(LOCAL_SOURCE_COLLECTOR_DB_NAME) diff --git a/local_database/setup.py b/local_database/setup.py index 99ff1da9..64f5af48 100644 --- a/local_database/setup.py +++ b/local_database/setup.py @@ -7,14 +7,19 @@ MAX_RETRIES = 20 SLEEP_SECONDS = 1 -def run_command(cmd, check=True, capture_output=False, **kwargs): +def run_command( + cmd: str, + check: bool = True, + capture_output: bool = False, + **kwargs: dict +) -> subprocess.CompletedProcess: try: return subprocess.run(cmd, shell=True, check=check, capture_output=capture_output, text=True, **kwargs) except subprocess.CalledProcessError as e: print(f"Command '{cmd}' failed: {e}") sys.exit(1) -def get_postgres_container_id(): +def get_postgres_container_id() -> str: result = run_command(f"docker-compose ps -q {POSTGRES_SERVICE_NAME}", capture_output=True) container_id = result.stdout.strip() if not container_id: @@ -22,7 +27,7 @@ def get_postgres_container_id(): sys.exit(1) return container_id -def wait_for_postgres(container_id): +def wait_for_postgres(container_id: str) -> None: print("Waiting for Postgres to be ready...") for i in range(MAX_RETRIES): try: @@ -36,7 +41,7 @@ def wait_for_postgres(container_id): print("Postgres did not become ready in time.") sys.exit(1) -def main(): +def main() -> None: print("Stopping Docker Compose...") run_command("docker-compose down") diff --git a/pyproject.toml b/pyproject.toml index 15e3c8ea..70f54673 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,7 @@ version = "0.1.0" requires-python = ">=3.11" dependencies = [ "aiohttp~=3.11.11", + "aiolimiter>=1.2.1", "alembic~=1.14.0", "apscheduler~=3.11.0", "asyncpg~=0.30.0", @@ -23,6 +24,8 @@ dependencies = [ "marshmallow~=3.23.2", "openai~=1.60.1", "pdap-access-manager==0.3.6", + "pillow>=11.3.0", + "pip>=25.2", "playwright~=1.49.1", "psycopg2-binary~=2.9.6", "psycopg[binary]~=3.1.20", @@ -30,6 +33,8 @@ dependencies = [ "pyjwt~=2.10.1", "python-dotenv~=1.0.1", "requests~=2.32.3", + "side-effects>=1.6.dev0", + "spacy>=3.8.7", 
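+    # Note: spaCy language models are distributed separately from the library;
+    # a pipeline such as en_core_web_sm must be downloaded before it can be
+    # loaded (e.g. `python -m spacy download en_core_web_sm`).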
"sqlalchemy~=2.0.36", "starlette~=0.45.3", "tqdm>=4.64.1", @@ -46,6 +51,7 @@ dev = [ "pytest-asyncio~=0.25.2", "pytest-mock==3.12.0", "pytest-timeout~=2.3.1", + "vulture>=2.14", ] diff --git a/src/api/endpoints/annotate/_shared/extract.py b/src/api/endpoints/annotate/_shared/extract.py new file mode 100644 index 00000000..390579d9 --- /dev/null +++ b/src/api/endpoints/annotate/_shared/extract.py @@ -0,0 +1,64 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate._shared.queries.get_annotation_batch_info import GetAnnotationBatchInfoQueryBuilder +from src.api.endpoints.annotate.all.get.models.agency import AgencyAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.location import LocationAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.name import NameAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.record_type import RecordTypeAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse, \ + GetNextURLForAllAnnotationInnerResponse +from src.api.endpoints.annotate.all.get.models.url_type import URLTypeAnnotationSuggestion +from src.api.endpoints.annotate.all.get.queries.agency.core import GetAgencySuggestionsQueryBuilder +from src.api.endpoints.annotate.all.get.queries.convert import \ + convert_user_url_type_suggestion_to_url_type_annotation_suggestion, \ + convert_user_record_type_suggestion_to_record_type_annotation_suggestion +from src.api.endpoints.annotate.all.get.queries.location_.core import GetLocationSuggestionsQueryBuilder +from src.api.endpoints.annotate.all.get.queries.name.core import GetNameSuggestionsQueryBuilder +from src.db.dto_converter import DTOConverter +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion + + +async def extract_and_format_get_annotation_result( + session: AsyncSession, + url: URL, + batch_id: int | None = None +): + html_response_info = DTOConverter.html_content_list_to_html_response_info( + url.html_content + ) + url_type_suggestions: list[URLTypeAnnotationSuggestion] = \ + convert_user_url_type_suggestion_to_url_type_annotation_suggestion( + url.user_relevant_suggestions + ) + record_type_suggestions: list[RecordTypeAnnotationSuggestion] = \ + convert_user_record_type_suggestion_to_record_type_annotation_suggestion( + url.user_record_type_suggestions + ) + agency_suggestions: AgencyAnnotationResponseOuterInfo = \ + await GetAgencySuggestionsQueryBuilder(url_id=url.id).run(session) + location_suggestions: LocationAnnotationResponseOuterInfo = \ + await GetLocationSuggestionsQueryBuilder(url_id=url.id).run(session) + name_suggestions: list[NameAnnotationSuggestion] = \ + await GetNameSuggestionsQueryBuilder(url_id=url.id).run(session) + return GetNextURLForAllAnnotationResponse( + next_annotation=GetNextURLForAllAnnotationInnerResponse( + url_info=URLMapping( + url_id=url.id, + url=url.url + ), + html_info=html_response_info, + url_type_suggestions=url_type_suggestions, + record_type_suggestions=record_type_suggestions, + agency_suggestions=agency_suggestions, + batch_info=await GetAnnotationBatchInfoQueryBuilder( + batch_id=batch_id, + models=[ + UserUrlAgencySuggestion, + ] + ).run(session), + location_suggestions=location_suggestions, + name_suggestions=name_suggestions + ) + ) diff --git 
a/src/api/endpoints/annotate/_shared/queries/get_annotation_batch_info.py b/src/api/endpoints/annotate/_shared/queries/get_annotation_batch_info.py index 15f5b631..5a56cf32 100644 --- a/src/api/endpoints/annotate/_shared/queries/get_annotation_batch_info.py +++ b/src/api/endpoints/annotate/_shared/queries/get_annotation_batch_info.py @@ -5,8 +5,8 @@ from src.api.endpoints.annotate.dtos.shared.batch import AnnotationBatchInfo from src.collectors.enums import URLStatus -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL from src.db.queries.base.builder import QueryBuilderBase from src.db.statement_composer import StatementComposer from src.db.types import UserSuggestionType @@ -42,7 +42,7 @@ async def run( ) common_where_clause = [ - URL.outcome == URLStatus.PENDING.value, + URL.status == URLStatus.OK.value, LinkBatchURL.batch_id == self.batch_id, ] diff --git a/src/api/endpoints/annotate/_shared/queries/get_next_url_for_user_annotation.py b/src/api/endpoints/annotate/_shared/queries/get_next_url_for_user_annotation.py deleted file mode 100644 index 3bda8ff3..00000000 --- a/src/api/endpoints/annotate/_shared/queries/get_next_url_for_user_annotation.py +++ /dev/null @@ -1,75 +0,0 @@ -from sqlalchemy import select, not_, exists -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import QueryableAttribute, joinedload - -from src.collectors.enums import URLStatus -from src.core.enums import SuggestedStatus -from src.db.client.types import UserSuggestionModel -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from src.db.queries.base.builder import QueryBuilderBase -from src.db.statement_composer import StatementComposer - - -class GetNextURLForUserAnnotationQueryBuilder(QueryBuilderBase): - - def __init__( - self, - user_suggestion_model_to_exclude: UserSuggestionModel, - auto_suggestion_relationship: QueryableAttribute, - batch_id: int | None, - check_if_annotated_not_relevant: bool = False - ): - super().__init__() - self.check_if_annotated_not_relevant = check_if_annotated_not_relevant - self.batch_id = batch_id - self.user_suggestion_model_to_exclude = user_suggestion_model_to_exclude - self.auto_suggestion_relationship = auto_suggestion_relationship - - async def run(self, session: AsyncSession): - query = ( - select( - URL, - ) - ) - - if self.batch_id is not None: - query = ( - query - .join(LinkBatchURL) - .where(LinkBatchURL.batch_id == self.batch_id) - ) - - query = ( - query - .where(URL.outcome == URLStatus.PENDING.value) - # URL must not have user suggestion - .where( - StatementComposer.user_suggestion_not_exists(self.user_suggestion_model_to_exclude) - ) - ) - - if self.check_if_annotated_not_relevant: - query = query.where( - not_( - exists( - select(UserRelevantSuggestion) - .where( - UserRelevantSuggestion.url_id == URL.id, - UserRelevantSuggestion.suggested_status != SuggestedStatus.RELEVANT.value - ) - ) - ) - ) - - - - query = query.options( - joinedload(self.auto_suggestion_relationship), - joinedload(URL.html_content) - ).limit(1) - - raw_result = await session.execute(query) - - return raw_result.unique().scalars().one_or_none() \ No newline at end of file diff --git 
a/src/api/endpoints/annotate/agency/get/dto.py b/src/api/endpoints/annotate/agency/get/dto.py index f2dda0f5..a0c06622 100644 --- a/src/api/endpoints/annotate/agency/get/dto.py +++ b/src/api/endpoints/annotate/agency/get/dto.py @@ -7,17 +7,12 @@ class GetNextURLForAgencyAgencyInfo(BaseModel): suggestion_type: SuggestionType - pdap_agency_id: Optional[int] = None - agency_name: Optional[str] = None - state: Optional[str] = None - county: Optional[str] = None - locality: Optional[str] = None - -class GetNextURLForAgencyAnnotationInnerResponse(AnnotationInnerResponseInfoBase): - agency_suggestions: list[ - GetNextURLForAgencyAgencyInfo - ] - -class GetNextURLForAgencyAnnotationResponse(BaseModel): - next_annotation: Optional[GetNextURLForAgencyAnnotationInnerResponse] + pdap_agency_id: int | None = None + agency_name: str | None = None + state: str | None = None + county: str | None = None + locality: str | None = None +class AgencySuggestionAndUserCount(BaseModel): + suggestion: GetNextURLForAgencyAgencyInfo + user_count: int \ No newline at end of file diff --git a/src/api/endpoints/annotate/agency/get/queries/agency_suggestion.py b/src/api/endpoints/annotate/agency/get/queries/agency_suggestion.py deleted file mode 100644 index f1ab8b67..00000000 --- a/src/api/endpoints/annotate/agency/get/queries/agency_suggestion.py +++ /dev/null @@ -1,55 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo -from src.core.enums import SuggestionType -from src.db.models.instantiations.agency import Agency -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion -from src.db.queries.base.builder import QueryBuilderBase - - -class GetAgencySuggestionsQueryBuilder(QueryBuilderBase): - - def __init__( - self, - url_id: int - ): - super().__init__() - self.url_id = url_id - - async def run(self, session: AsyncSession) -> list[GetNextURLForAgencyAgencyInfo]: - # Get relevant autosuggestions and agency info, if an associated agency exists - - statement = ( - select( - AutomatedUrlAgencySuggestion.agency_id, - AutomatedUrlAgencySuggestion.is_unknown, - Agency.name, - Agency.state, - Agency.county, - Agency.locality - ) - .join(Agency, isouter=True) - .where(AutomatedUrlAgencySuggestion.url_id == self.url_id) - ) - raw_autosuggestions = await session.execute(statement) - autosuggestions = raw_autosuggestions.all() - agency_suggestions = [] - for autosuggestion in autosuggestions: - agency_id = autosuggestion[0] - is_unknown = autosuggestion[1] - name = autosuggestion[2] - state = autosuggestion[3] - county = autosuggestion[4] - locality = autosuggestion[5] - agency_suggestions.append( - GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.AUTO_SUGGESTION if not is_unknown else SuggestionType.UNKNOWN, - pdap_agency_id=agency_id, - agency_name=name, - state=state, - county=county, - locality=locality - ) - ) - return agency_suggestions \ No newline at end of file diff --git a/src/api/endpoints/annotate/agency/get/queries/next_for_annotation.py b/src/api/endpoints/annotate/agency/get/queries/next_for_annotation.py deleted file mode 100644 index 5bfd6e8a..00000000 --- a/src/api/endpoints/annotate/agency/get/queries/next_for_annotation.py +++ /dev/null @@ -1,128 +0,0 @@ -from sqlalchemy import select, exists -from sqlalchemy.ext.asyncio import AsyncSession - -from src.api.endpoints.annotate._shared.queries.get_annotation_batch_info import 
GetAnnotationBatchInfoQueryBuilder -from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAnnotationResponse, \ - GetNextURLForAgencyAnnotationInnerResponse -from src.api.endpoints.annotate.agency.get.queries.agency_suggestion import GetAgencySuggestionsQueryBuilder -from src.collectors.enums import URLStatus -from src.core.enums import SuggestedStatus -from src.core.tasks.url.operators.url_html.scraper.parser.util import convert_to_response_html_info -from src.db.dtos.url.mapping import URLMapping -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from src.db.queries.base.builder import QueryBuilderBase -from src.db.queries.implementations.core.get.html_content_info import GetHTMLContentInfoQueryBuilder - - -class GetNextURLAgencyForAnnotationQueryBuilder(QueryBuilderBase): - - def __init__( - self, - batch_id: int | None, - user_id: int - ): - super().__init__() - self.batch_id = batch_id - self.user_id = user_id - - async def run( - self, - session: AsyncSession - ) -> GetNextURLForAgencyAnnotationResponse: - """ - Retrieve URL for annotation - The URL must - not be a confirmed URL - not have been annotated by this user - have extant autosuggestions - """ - # Select statement - query = select(URL.id, URL.url) - if self.batch_id is not None: - query = query.join(LinkBatchURL).where(LinkBatchURL.batch_id == self.batch_id) - - # Must not have confirmed agencies - query = query.where( - URL.outcome == URLStatus.PENDING.value - ) - - - # Must not have been annotated by a user - query = ( - query.join(UserUrlAgencySuggestion, isouter=True) - .where( - ~exists( - select(UserUrlAgencySuggestion). - where(UserUrlAgencySuggestion.url_id == URL.id). - correlate(URL) - ) - ) - # Must have extant autosuggestions - .join(AutomatedUrlAgencySuggestion, isouter=True) - .where( - exists( - select(AutomatedUrlAgencySuggestion). - where(AutomatedUrlAgencySuggestion.url_id == URL.id). - correlate(URL) - ) - ) - # Must not have confirmed agencies - .join(ConfirmedURLAgency, isouter=True) - .where( - ~exists( - select(ConfirmedURLAgency). - where(ConfirmedURLAgency.url_id == URL.id). - correlate(URL) - ) - ) - # Must not have been marked as "Not Relevant" by this user - .join(UserRelevantSuggestion, isouter=True) - .where( - ~exists( - select(UserRelevantSuggestion). 
-                    where(
-                        (UserRelevantSuggestion.user_id == self.user_id) &
-                        (UserRelevantSuggestion.url_id == URL.id) &
-                        (UserRelevantSuggestion.suggested_status != SuggestedStatus.RELEVANT.value)
-                    ).correlate(URL)
-                )
-            )
-        ).limit(1)
-        raw_result = await session.execute(query)
-        results = raw_result.all()
-        if len(results) == 0:
-            return GetNextURLForAgencyAnnotationResponse(
-                next_annotation=None
-            )
-
-        result = results[0]
-        url_id = result[0]
-        url = result[1]
-
-        agency_suggestions = await GetAgencySuggestionsQueryBuilder(url_id=url_id).run(session)
-
-        # Get HTML content info
-        html_content_infos = await GetHTMLContentInfoQueryBuilder(url_id).run(session)
-        response_html_info = convert_to_response_html_info(html_content_infos)
-
-        return GetNextURLForAgencyAnnotationResponse(
-            next_annotation=GetNextURLForAgencyAnnotationInnerResponse(
-                url_info=URLMapping(
-                    url=url,
-                    url_id=url_id
-                ),
-                html_info=response_html_info,
-                agency_suggestions=agency_suggestions,
-                batch_info=await GetAnnotationBatchInfoQueryBuilder(
-                    batch_id=self.batch_id,
-                    models=[
-                        UserUrlAgencySuggestion,
-                    ]
-                ).run(session)
-            )
-        )
\ No newline at end of file
diff --git a/src/api/endpoints/annotate/agency/post/dto.py b/src/api/endpoints/annotate/agency/post/dto.py
index 1d0ade02..dc41720a 100644
--- a/src/api/endpoints/annotate/agency/post/dto.py
+++ b/src/api/endpoints/annotate/agency/post/dto.py
@@ -5,4 +5,4 @@ class URLAgencyAnnotationPostInfo(BaseModel):
     is_new: bool = False
-    suggested_agency: Optional[int] = None
+    suggested_agency: int | None = None
diff --git a/src/api/endpoints/annotate/all/get/dto.py b/src/api/endpoints/annotate/all/get/dto.py
deleted file mode 100644
index 63d46ce6..00000000
--- a/src/api/endpoints/annotate/all/get/dto.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from typing import Optional
-
-from pydantic import Field, BaseModel
-
-from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo
-from src.api.endpoints.annotate.dtos.shared.base.response import AnnotationInnerResponseInfoBase
-from src.api.endpoints.annotate.relevance.get.dto import RelevanceAnnotationResponseInfo
-from src.core.enums import RecordType
-
-
-class GetNextURLForAllAnnotationInnerResponse(AnnotationInnerResponseInfoBase):
-    agency_suggestions: list[GetNextURLForAgencyAgencyInfo] | None = Field(
-        title="The auto-labeler's suggestions for agencies"
-    )
-    suggested_relevant: RelevanceAnnotationResponseInfo | None = Field(
-        title="Whether the auto-labeler identified the URL as relevant or not"
-    )
-    suggested_record_type: RecordType | None = Field(
-        title="What record type, if any, the auto-labeler identified the URL as"
-    )
-
-
-class GetNextURLForAllAnnotationResponse(BaseModel):
-    next_annotation: Optional[GetNextURLForAllAnnotationInnerResponse]
\ No newline at end of file
diff --git a/src/api/endpoints/annotate/agency/get/queries/__init__.py b/src/api/endpoints/annotate/all/get/models/__init__.py
similarity index 100%
rename from src/api/endpoints/annotate/agency/get/queries/__init__.py
rename to src/api/endpoints/annotate/all/get/models/__init__.py
diff --git a/src/api/endpoints/annotate/all/get/models/agency.py b/src/api/endpoints/annotate/all/get/models/agency.py
new file mode 100644
index 00000000..45806d98
--- /dev/null
+++ b/src/api/endpoints/annotate/all/get/models/agency.py
@@ -0,0 +1,27 @@
+from pydantic import BaseModel, Field
+
+
+class AgencyAnnotationAutoSuggestion(BaseModel):
+    agency_id: int
+    agency_name: str
+    confidence: int = Field(
+        title="The confidence of the agency suggestion",
+
ge=0, + le=100, + ) + +class AgencyAnnotationUserSuggestion(BaseModel): + agency_id: int + agency_name: str + user_count: int + +class AgencyAnnotationUserSuggestionOuterInfo(BaseModel): + suggestions: list[AgencyAnnotationUserSuggestion] + not_found_count: int = Field( + title="How many users listed the agency as not found.", + ge=0, + ) + +class AgencyAnnotationResponseOuterInfo(BaseModel): + user: AgencyAnnotationUserSuggestionOuterInfo + auto: list[AgencyAnnotationAutoSuggestion] \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/models/location.py b/src/api/endpoints/annotate/all/get/models/location.py new file mode 100644 index 00000000..fb467004 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/models/location.py @@ -0,0 +1,35 @@ +from pydantic import BaseModel, Field + + +class LocationAnnotationAutoSuggestion(BaseModel): + location_id: int + location_name: str = Field( + title="The full name of the location" + ) + confidence: int = Field( + title="The confidence of the location", + ge=0, + le=100, + ) + + +class LocationAnnotationUserSuggestion(BaseModel): + location_id: int + location_name: str = Field( + title="The full name of the location" + ) + user_count: int = Field( + title="The number of users who suggested this location", + ge=1, + ) + +class LocationAnnotationUserSuggestionOuterInfo(BaseModel): + suggestions: list[LocationAnnotationUserSuggestion] + not_found_count: int = Field( + title="How many users listed the location as not found.", + ge=0, + ) + +class LocationAnnotationResponseOuterInfo(BaseModel): + user: LocationAnnotationUserSuggestionOuterInfo + auto: list[LocationAnnotationAutoSuggestion] \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/models/name.py b/src/api/endpoints/annotate/all/get/models/name.py new file mode 100644 index 00000000..80857305 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/models/name.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class NameAnnotationSuggestion(BaseModel): + name: str + suggestion_id: int + endorsement_count: int \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/models/record_type.py b/src/api/endpoints/annotate/all/get/models/record_type.py new file mode 100644 index 00000000..a1c24911 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/models/record_type.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + +from src.core.enums import RecordType + + + +class RecordTypeAnnotationSuggestion(BaseModel): + record_type: RecordType + endorsement_count: int + + diff --git a/src/api/endpoints/annotate/all/get/models/response.py b/src/api/endpoints/annotate/all/get/models/response.py new file mode 100644 index 00000000..989dbf8d --- /dev/null +++ b/src/api/endpoints/annotate/all/get/models/response.py @@ -0,0 +1,35 @@ +from typing import Optional + +from pydantic import Field, BaseModel + +from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo +from src.api.endpoints.annotate.all.get.models.agency import AgencyAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.location import LocationAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.name import NameAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.record_type import RecordTypeAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.url_type import URLTypeAnnotationSuggestion +from src.api.endpoints.annotate.dtos.shared.base.response import 
AnnotationInnerResponseInfoBase
+from src.api.endpoints.annotate.relevance.get.dto import RelevanceAnnotationResponseInfo
+from src.core.enums import RecordType
+
+
+class GetNextURLForAllAnnotationInnerResponse(AnnotationInnerResponseInfoBase):
+    agency_suggestions: AgencyAnnotationResponseOuterInfo | None = Field(
+        title="User and Auto-Suggestions for agencies"
+    )
+    location_suggestions: LocationAnnotationResponseOuterInfo | None = Field(
+        title="User and Auto-Suggestions for locations"
+    )
+    url_type_suggestions: list[URLTypeAnnotationSuggestion] = Field(
+        title="User suggestions for the URL's type"
+    )
+    record_type_suggestions: list[RecordTypeAnnotationSuggestion] = Field(
+        title="What record type, if any, users and the auto-labeler identified the URL as"
+    )
+    name_suggestions: list[NameAnnotationSuggestion] | None = Field(
+        title="User and Auto-Suggestions for names"
+    )
+
+
+class GetNextURLForAllAnnotationResponse(BaseModel):
+    next_annotation: GetNextURLForAllAnnotationInnerResponse | None
\ No newline at end of file
diff --git a/src/api/endpoints/annotate/all/get/models/url_type.py b/src/api/endpoints/annotate/all/get/models/url_type.py
new file mode 100644
index 00000000..cbc947e6
--- /dev/null
+++ b/src/api/endpoints/annotate/all/get/models/url_type.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+from src.db.models.impl.flag.url_validated.enums import URLType
+
+
+class URLTypeAnnotationSuggestion(BaseModel):
+    url_type: URLType
+    endorsement_count: int
diff --git a/src/api/endpoints/annotate/dtos/record_type/__init__.py b/src/api/endpoints/annotate/all/get/queries/__init__.py
similarity index 100%
rename from src/api/endpoints/annotate/dtos/record_type/__init__.py
rename to src/api/endpoints/annotate/all/get/queries/__init__.py
diff --git a/src/api/endpoints/annotate/relevance/post/__init__.py b/src/api/endpoints/annotate/all/get/queries/agency/__init__.py
similarity index 100%
rename from src/api/endpoints/annotate/relevance/post/__init__.py
rename to src/api/endpoints/annotate/all/get/queries/agency/__init__.py
diff --git a/src/api/endpoints/annotate/all/get/queries/agency/core.py b/src/api/endpoints/annotate/all/get/queries/agency/core.py
new file mode 100644
index 00000000..28cfbd2d
--- /dev/null
+++ b/src/api/endpoints/annotate/all/get/queries/agency/core.py
@@ -0,0 +1,41 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.api.endpoints.annotate.all.get.models.agency import AgencyAnnotationResponseOuterInfo, \
+    AgencyAnnotationUserSuggestionOuterInfo, AgencyAnnotationUserSuggestion, AgencyAnnotationAutoSuggestion
+from src.api.endpoints.annotate.all.get.queries.agency.requester import GetAgencySuggestionsRequester
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class GetAgencySuggestionsQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        location_id: int | None = None
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.location_id = location_id
+
+    async def run(self, session: AsyncSession) -> AgencyAnnotationResponseOuterInfo:
+        requester = GetAgencySuggestionsRequester(
+            session,
+            url_id=self.url_id,
+            location_id=self.location_id
+        )
+
+        user_suggestions: list[AgencyAnnotationUserSuggestion] = \
+            await requester.get_user_agency_suggestions()
+        auto_suggestions: list[AgencyAnnotationAutoSuggestion] = \
+            await requester.get_auto_agency_suggestions()
+        not_found_count: int = \
+            await requester.get_not_found_count()
+        return AgencyAnnotationResponseOuterInfo(
+            user=AgencyAnnotationUserSuggestionOuterInfo(
+                suggestions=user_suggestions,
+                not_found_count=not_found_count
+            ),
+            auto=auto_suggestions,
+        )
+
+
diff --git a/src/api/endpoints/annotate/all/get/queries/agency/requester.py b/src/api/endpoints/annotate/all/get/queries/agency/requester.py
new file mode 100644
index 00000000..fc309e50
--- /dev/null
+++ b/src/api/endpoints/annotate/all/get/queries/agency/requester.py
@@ -0,0 +1,137 @@
+from typing import Sequence
+
+from sqlalchemy import func, select, RowMapping
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.api.endpoints.annotate.all.get.models.agency import AgencyAnnotationAutoSuggestion, \
+    AgencyAnnotationUserSuggestion
+from src.api.endpoints.annotate.all.get.queries.agency.suggestions_with_highest_confidence import \
+    SuggestionsWithHighestConfidenceCTE
+from src.db.helpers.session import session_helper as sh
+from src.db.models.impl.agency.sqlalchemy import Agency
+from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
+from src.db.models.impl.link.user_suggestion_not_found.agency.sqlalchemy import LinkUserSuggestionAgencyNotFound
+from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion
+from src.db.templates.requester import RequesterBase
+
+
+class GetAgencySuggestionsRequester(RequesterBase):
+
+    def __init__(
+        self,
+        session: AsyncSession,
+        url_id: int,
+        location_id: int | None
+    ):
+        super().__init__(session)
+        self.url_id = url_id
+        self.location_id = location_id
+
+    async def get_user_agency_suggestions(self) -> list[AgencyAnnotationUserSuggestion]:
+        query = (
+            select(
+                UserUrlAgencySuggestion.agency_id,
+                func.count(UserUrlAgencySuggestion.user_id).label("count"),
+                Agency.name.label("agency_name"),
+            )
+            .join(
+                Agency,
+                Agency.agency_id == UserUrlAgencySuggestion.agency_id
+            )
+
+        )
+
+        if self.location_id is not None:
+            query = (
+                query.join(
+                    LinkAgencyLocation,
+                    LinkAgencyLocation.agency_id == UserUrlAgencySuggestion.agency_id
+                )
+                .where(
+                    LinkAgencyLocation.location_id == self.location_id
+                )
+            )
+
+        query = (
+            query.where(
+                UserUrlAgencySuggestion.url_id == self.url_id
+            )
+            .group_by(
+                UserUrlAgencySuggestion.agency_id,
+                Agency.name
+            )
+            .order_by(
+                func.count(UserUrlAgencySuggestion.user_id).desc()
+            )
+            .limit(3)
+        )
+
+        results: Sequence[RowMapping] = await sh.mappings(self.session, query=query)
+
+        return [
+            AgencyAnnotationUserSuggestion(
+                agency_id=autosuggestion["agency_id"],
+                user_count=autosuggestion["count"],
+                agency_name=autosuggestion["agency_name"],
+            )
+            for autosuggestion in results
+        ]
+
+
+    async def get_auto_agency_suggestions(self) -> list[AgencyAnnotationAutoSuggestion]:
+        cte = SuggestionsWithHighestConfidenceCTE()
+        query = (
+            select(
+                cte.agency_id,
+                cte.confidence,
+                Agency.name.label("agency_name"),
+            )
+            .join(
+                Agency,
+                Agency.agency_id == cte.agency_id
+            )
+        )
+
+        if self.location_id is not None:
+            query = (
+                query.join(
+                    LinkAgencyLocation,
+                    LinkAgencyLocation.agency_id == cte.agency_id
+                )
+                .where(
+                    LinkAgencyLocation.location_id == self.location_id
+                )
+            )
+
+        query = 
( + query.where( + cte.url_id == self.url_id + ) + .order_by( + cte.confidence.desc() + ) + .limit(3) + ) + + results: Sequence[RowMapping] = await sh.mappings(self.session, query=query) + + return [ + AgencyAnnotationAutoSuggestion( + agency_id=autosuggestion["agency_id"], + confidence=autosuggestion["confidence"], + agency_name=autosuggestion["agency_name"], + ) + for autosuggestion in results + ] + + async def get_not_found_count(self) -> int: + query = ( + select( + func.count(LinkUserSuggestionAgencyNotFound.user_id) + ) + .where( + LinkUserSuggestionAgencyNotFound.url_id == self.url_id + ) + ) + + return await sh.scalar(self.session, query=query) \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/queries/agency/suggestions_with_highest_confidence.py b/src/api/endpoints/annotate/all/get/queries/agency/suggestions_with_highest_confidence.py new file mode 100644 index 00000000..6d389b11 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/agency/suggestions_with_highest_confidence.py @@ -0,0 +1,62 @@ +from sqlalchemy import CTE, select, func, Column + +from src.db.models.impl.url.suggestion.agency.subtask.sqlalchemy import URLAutoAgencyIDSubtask +from src.db.models.impl.url.suggestion.agency.suggestion.sqlalchemy import AgencyIDSubtaskSuggestion + +SUGGESTIONS_WITH_HIGHEST_CONFIDENCE_CTE: CTE = ( + select( + URLAutoAgencyIDSubtask.url_id, + AgencyIDSubtaskSuggestion.agency_id, + func.max(AgencyIDSubtaskSuggestion.confidence) + ) + .select_from(URLAutoAgencyIDSubtask) + .join( + AgencyIDSubtaskSuggestion, + URLAutoAgencyIDSubtask.id == AgencyIDSubtaskSuggestion.subtask_id + ) + .group_by( + URLAutoAgencyIDSubtask.url_id, + AgencyIDSubtaskSuggestion.agency_id + ) + .cte("suggestions_with_highest_confidence") +) + +class SuggestionsWithHighestConfidenceCTE: + + def __init__(self): + self._cte = ( + select( + URLAutoAgencyIDSubtask.url_id, + AgencyIDSubtaskSuggestion.agency_id, + func.max(AgencyIDSubtaskSuggestion.confidence).label("confidence") + ) + .select_from(URLAutoAgencyIDSubtask) + .join( + AgencyIDSubtaskSuggestion, + URLAutoAgencyIDSubtask.id == AgencyIDSubtaskSuggestion.subtask_id + ) + .where( + AgencyIDSubtaskSuggestion.agency_id.isnot(None) + ) + .group_by( + URLAutoAgencyIDSubtask.url_id, + AgencyIDSubtaskSuggestion.agency_id + ) + .cte("suggestions_with_highest_confidence") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self._cte.columns.url_id + + @property + def agency_id(self) -> Column[int]: + return self._cte.columns.agency_id + + @property + def confidence(self) -> Column[float]: + return self._cte.columns.confidence \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/queries/convert.py b/src/api/endpoints/annotate/all/get/queries/convert.py new file mode 100644 index 00000000..535a7d15 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/convert.py @@ -0,0 +1,43 @@ +from collections import Counter + +from src.api.endpoints.annotate.all.get.models.record_type import RecordTypeAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.url_type import URLTypeAnnotationSuggestion +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion + + +def 
convert_user_url_type_suggestion_to_url_type_annotation_suggestion( + db_suggestions: list[UserURLTypeSuggestion] +) -> list[URLTypeAnnotationSuggestion]: + counter: Counter[URLType] = Counter() + for suggestion in db_suggestions: + counter[suggestion.type] += 1 + anno_suggestions: list[URLTypeAnnotationSuggestion] = [] + for url_type, endorsement_count in counter.most_common(3): + anno_suggestions.append( + URLTypeAnnotationSuggestion( + url_type=url_type, + endorsement_count=endorsement_count, + ) + ) + return anno_suggestions + +def convert_user_record_type_suggestion_to_record_type_annotation_suggestion( + db_suggestions: list[UserRecordTypeSuggestion] +) -> list[RecordTypeAnnotationSuggestion]: + counter: Counter[RecordType] = Counter() + for suggestion in db_suggestions: + counter[suggestion.record_type] += 1 + + anno_suggestions: list[RecordTypeAnnotationSuggestion] = [] + for record_type, endorsement_count in counter.most_common(3): + anno_suggestions.append( + RecordTypeAnnotationSuggestion( + record_type=record_type, + endorsement_count=endorsement_count, + ) + ) + + return anno_suggestions \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/get/queries/core.py b/src/api/endpoints/annotate/all/get/queries/core.py new file mode 100644 index 00000000..e37f2396 --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/core.py @@ -0,0 +1,125 @@ +from sqlalchemy import Select, exists, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from src.api.endpoints.annotate._shared.extract import extract_and_format_get_annotation_result +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.collectors.enums import URLStatus +from src.db.models.impl.flag.url_suspended.sqlalchemy import FlagURLSuspended +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL +from src.db.models.views.url_anno_count import URLAnnotationCount +from src.db.models.views.url_annotations_flags import URLAnnotationFlagsView +from src.db.queries.base.builder import QueryBuilderBase + + +class GetNextURLForAllAnnotationQueryBuilder(QueryBuilderBase): + + def __init__( + self, + batch_id: int | None, + user_id: int, + url_id: int | None = None + ): + super().__init__() + self.batch_id = batch_id + self.url_id = url_id + self.user_id = user_id + + async def run( + self, + session: AsyncSession + ) -> GetNextURLForAllAnnotationResponse: + query = ( + Select(URL) + # URL Must be unvalidated + .join( + UnvalidatedURL, + UnvalidatedURL.url_id == URL.id + ) + .join( + URLAnnotationFlagsView, + URLAnnotationFlagsView.url_id == URL.id + ) + .join( + URLAnnotationCount, + URLAnnotationCount.url_id == URL.id + ) + ) + if self.batch_id is not None: + query = query.join(LinkBatchURL).where(LinkBatchURL.batch_id == self.batch_id) + if self.url_id is not None: + query = query.where(URL.id == self.url_id) + query = ( + query + .where( + URL.status == URLStatus.OK.value, + # Must not have been previously annotated by user + ~exists( + 
select(UserURLTypeSuggestion.id) + .where( + UserURLTypeSuggestion.url_id == URL.id, + UserURLTypeSuggestion.user_id == self.user_id, + ) + ), + ~exists( + select(UserUrlAgencySuggestion.id) + .where( + UserUrlAgencySuggestion.url_id == URL.id, + UserUrlAgencySuggestion.user_id == self.user_id, + ) + ), + ~exists( + select( + UserLocationSuggestion.url_id + ) + .where( + UserLocationSuggestion.url_id == URL.id, + UserLocationSuggestion.user_id == self.user_id, + ) + ), + ~exists( + select( + UserRecordTypeSuggestion.url_id + ) + .where( + UserRecordTypeSuggestion.url_id == URL.id, + UserRecordTypeSuggestion.user_id == self.user_id, + ) + ), + ~exists( + select( + FlagURLSuspended.url_id + ) + .where( + FlagURLSuspended.url_id == URL.id, + ) + ) + ) + ) + # Add load options + query = query.options( + joinedload(URL.html_content), + joinedload(URL.user_relevant_suggestions), + joinedload(URL.user_record_type_suggestions), + joinedload(URL.name_suggestions), + ) + + query = query.order_by( + URLAnnotationCount.total_anno_count.desc(), + URL.id.asc() + ).limit(1) + raw_results = (await session.execute(query)).unique() + url: URL | None = raw_results.scalars().one_or_none() + if url is None: + return GetNextURLForAllAnnotationResponse( + next_annotation=None + ) + + return await extract_and_format_get_annotation_result(session, url=url, batch_id=self.batch_id) + diff --git a/src/collectors/source_collectors/__init__.py b/src/api/endpoints/annotate/all/get/queries/location_/__init__.py similarity index 100% rename from src/collectors/source_collectors/__init__.py rename to src/api/endpoints/annotate/all/get/queries/location_/__init__.py diff --git a/src/api/endpoints/annotate/all/get/queries/location_/core.py b/src/api/endpoints/annotate/all/get/queries/location_/core.py new file mode 100644 index 00000000..85db523c --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/location_/core.py @@ -0,0 +1,35 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate.all.get.models.location import LocationAnnotationResponseOuterInfo, \ + LocationAnnotationUserSuggestion, LocationAnnotationAutoSuggestion, LocationAnnotationUserSuggestionOuterInfo +from src.api.endpoints.annotate.all.get.queries.location_.requester import GetLocationSuggestionsRequester +from src.db.queries.base.builder import QueryBuilderBase + + +class GetLocationSuggestionsQueryBuilder(QueryBuilderBase): + + def __init__( + self, + url_id: int + ): + super().__init__() + self.url_id = url_id + + + async def run(self, session: AsyncSession) -> LocationAnnotationResponseOuterInfo: + requester = GetLocationSuggestionsRequester(session) + user_suggestions: list[LocationAnnotationUserSuggestion] = \ + await requester.get_user_location_suggestions(self.url_id) + auto_suggestions: list[LocationAnnotationAutoSuggestion] = \ + await requester.get_auto_location_suggestions(self.url_id) + not_found_count: int = \ + await requester.get_not_found_count(self.url_id) + + return LocationAnnotationResponseOuterInfo( + user=LocationAnnotationUserSuggestionOuterInfo( + suggestions=user_suggestions, + not_found_count=not_found_count + ),
+ auto=auto_suggestions + ) + diff --git a/src/api/endpoints/annotate/all/get/queries/location_/requester.py b/src/api/endpoints/annotate/all/get/queries/location_/requester.py new file mode 100644 index 00000000..c60c8efe --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/location_/requester.py @@ -0,0 +1,94 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping + +from src.api.endpoints.annotate.all.get.models.location import LocationAnnotationUserSuggestion, \ + LocationAnnotationAutoSuggestion +from src.db.models.impl.link.user_suggestion_not_found.location.sqlalchemy import LinkUserSuggestionLocationNotFound +from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion +from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.views.location_expanded import LocationExpandedView +from src.db.templates.requester import RequesterBase + +from src.db.helpers.session import session_helper as sh + +class GetLocationSuggestionsRequester(RequesterBase): + + + async def get_user_location_suggestions(self, url_id: int) -> list[LocationAnnotationUserSuggestion]: + query = ( + select( + UserLocationSuggestion.location_id, + LocationExpandedView.display_name.label("location_name"), + func.count(UserLocationSuggestion.user_id).label('user_count') + ) + .join( + LocationExpandedView, + LocationExpandedView.id == UserLocationSuggestion.location_id + ) + .where( + UserLocationSuggestion.url_id == url_id + ) + .group_by( + UserLocationSuggestion.location_id, + LocationExpandedView.display_name + ) + .order_by( + func.count(UserLocationSuggestion.user_id).desc() + ) + ) + raw_results: Sequence[RowMapping] = await sh.mappings(self.session, query) + return [ + LocationAnnotationUserSuggestion( + **raw_result + ) + for raw_result in raw_results + ] + + + + async def get_auto_location_suggestions( + self, + url_id: int + ) -> list[LocationAnnotationAutoSuggestion]: + query = ( + select( + LocationExpandedView.full_display_name.label("location_name"), + LocationIDSubtaskSuggestion.location_id, + LocationIDSubtaskSuggestion.confidence, + ) + .join( + LocationExpandedView, + LocationExpandedView.id == LocationIDSubtaskSuggestion.location_id + ) + .join( + AutoLocationIDSubtask, + AutoLocationIDSubtask.id == LocationIDSubtaskSuggestion.subtask_id + ) + .where( + AutoLocationIDSubtask.url_id == url_id + ) + .order_by( + LocationIDSubtaskSuggestion.confidence.desc() + ) + ) + raw_results: Sequence[RowMapping] = await sh.mappings(self.session, query) + return [ + LocationAnnotationAutoSuggestion( + **raw_result + ) + for raw_result in raw_results + ] + + async def get_not_found_count(self, url_id: int) -> int: + query = ( + select( + func.count(LinkUserSuggestionLocationNotFound.user_id) + ) + .where( + LinkUserSuggestionLocationNotFound.url_id == url_id + ) + ) + + return await sh.scalar(self.session, query=query) \ No newline at end of file diff --git a/src/collectors/source_collectors/auto_googler/__init__.py b/src/api/endpoints/annotate/all/get/queries/name/__init__.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/__init__.py rename to src/api/endpoints/annotate/all/get/queries/name/__init__.py diff --git a/src/api/endpoints/annotate/all/get/queries/name/core.py b/src/api/endpoints/annotate/all/get/queries/name/core.py new file mode 
100644 index 00000000..b048cb2c --- /dev/null +++ b/src/api/endpoints/annotate/all/get/queries/name/core.py @@ -0,0 +1,58 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate.all.get.models.name import NameAnnotationSuggestion +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion +from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion +from src.db.queries.base.builder import QueryBuilderBase + + +class GetNameSuggestionsQueryBuilder(QueryBuilderBase): + + def __init__( + self, + url_id: int + ): + super().__init__() + self.url_id = url_id + + async def run(self, session: AsyncSession) -> list[NameAnnotationSuggestion]: + query = ( + select( + URLNameSuggestion.id.label('suggestion_id'), + URLNameSuggestion.suggestion.label('name'), + func.count( + LinkUserNameSuggestion.user_id + ).label('endorsement_count'), + ) + .outerjoin( + LinkUserNameSuggestion, + LinkUserNameSuggestion.suggestion_id == URLNameSuggestion.id, + ) + .where( + URLNameSuggestion.url_id == self.url_id, + ) + .group_by( + URLNameSuggestion.id, + URLNameSuggestion.suggestion, + ) + .order_by( + func.count(LinkUserNameSuggestion.user_id).desc(), + URLNameSuggestion.id.asc(), + ) + .limit(3) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + return [ + NameAnnotationSuggestion( + **mapping + ) + for mapping in mappings + ] + + + diff --git a/src/api/endpoints/annotate/all/get/query.py b/src/api/endpoints/annotate/all/get/query.py deleted file mode 100644 index 1191e8d6..00000000 --- a/src/api/endpoints/annotate/all/get/query.py +++ /dev/null @@ -1,112 +0,0 @@ -from sqlalchemy import Select, and_ -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload - -from src.api.endpoints.annotate._shared.queries.get_annotation_batch_info import GetAnnotationBatchInfoQueryBuilder -from src.api.endpoints.annotate.agency.get.queries.agency_suggestion import GetAgencySuggestionsQueryBuilder -from src.api.endpoints.annotate.agency.get.queries.next_for_annotation import GetNextURLAgencyForAnnotationQueryBuilder -from src.api.endpoints.annotate.all.get.dto import GetNextURLForAllAnnotationResponse, \ - GetNextURLForAllAnnotationInnerResponse -from src.api.endpoints.annotate.relevance.get.dto import RelevanceAnnotationResponseInfo -from src.collectors.enums import URLStatus -from src.db.dto_converter import DTOConverter -from src.db.dtos.url.mapping import URLMapping -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from src.db.queries.base.builder import QueryBuilderBase -from src.db.statement_composer import StatementComposer - - -class GetNextURLForAllAnnotationQueryBuilder(QueryBuilderBase): - - def __init__( - self, - batch_id: int | None - ): - super().__init__() - self.batch_id = batch_id - - async def run( - self, - session: AsyncSession - ) -> GetNextURLForAllAnnotationResponse: - query = Select(URL) - if self.batch_id is not None: - query = 
query.join(LinkBatchURL).where(LinkBatchURL.batch_id == self.batch_id) - query = ( - query - .where( - and_( - URL.outcome == URLStatus.PENDING.value, - StatementComposer.user_suggestion_not_exists(UserUrlAgencySuggestion), - StatementComposer.user_suggestion_not_exists(UserRecordTypeSuggestion), - StatementComposer.user_suggestion_not_exists(UserRelevantSuggestion), - ) - ) - ) - - - load_options = [ - URL.html_content, - URL.automated_agency_suggestions, - URL.auto_relevant_suggestion, - URL.auto_record_type_suggestion - ] - select_in_loads = [ - selectinload(load_option) for load_option in load_options - ] - - # Add load options - query = query.options( - *select_in_loads - ) - - query = query.order_by(URL.id.asc()).limit(1) - raw_results = await session.execute(query) - url = raw_results.scalars().one_or_none() - if url is None: - return GetNextURLForAllAnnotationResponse( - next_annotation=None - ) - - html_response_info = DTOConverter.html_content_list_to_html_response_info( - url.html_content - ) - - if url.auto_relevant_suggestion is not None: - auto_relevant = url.auto_relevant_suggestion - else: - auto_relevant = None - - if url.auto_record_type_suggestion is not None: - auto_record_type = url.auto_record_type_suggestion.record_type - else: - auto_record_type = None - - agency_suggestions = await GetAgencySuggestionsQueryBuilder(url_id=url.id).run(session) - - return GetNextURLForAllAnnotationResponse( - next_annotation=GetNextURLForAllAnnotationInnerResponse( - url_info=URLMapping( - url_id=url.id, - url=url.url - ), - html_info=html_response_info, - suggested_relevant=RelevanceAnnotationResponseInfo( - is_relevant=auto_relevant.relevant, - confidence=auto_relevant.confidence, - model_name=auto_relevant.model_name - ) if auto_relevant is not None else None, - suggested_record_type=auto_record_type, - agency_suggestions=agency_suggestions, - batch_info=await GetAnnotationBatchInfoQueryBuilder( - batch_id=self.batch_id, - models=[ - UserUrlAgencySuggestion, - ] - ).run(session) - ) - ) \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/post/dto.py b/src/api/endpoints/annotate/all/post/dto.py deleted file mode 100644 index 293dcd7a..00000000 --- a/src/api/endpoints/annotate/all/post/dto.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel, model_validator - -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.core.enums import RecordType, SuggestedStatus -from src.core.exceptions import FailedValidationException - - -class AllAnnotationPostInfo(BaseModel): - suggested_status: SuggestedStatus - record_type: Optional[RecordType] = None - agency: Optional[URLAgencyAnnotationPostInfo] = None - - @model_validator(mode="after") - def allow_record_type_and_agency_only_if_relevant(self): - suggested_status = self.suggested_status - record_type = self.record_type - agency = self.agency - - if suggested_status != SuggestedStatus.RELEVANT: - if record_type is not None: - raise FailedValidationException("record_type must be None if suggested_status is not relevant") - - if agency is not None: - raise FailedValidationException("agency must be None if suggested_status is not relevant") - return self - # Similarly, if relevant, record_type and agency must be provided - if record_type is None: - raise FailedValidationException("record_type must be provided if suggested_status is relevant") - if agency is None: - raise FailedValidationException("agency must be provided if suggested_status is 
relevant") - return self \ No newline at end of file diff --git a/src/collectors/source_collectors/auto_googler/dtos/__init__.py b/src/api/endpoints/annotate/all/post/models/__init__.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/dtos/__init__.py rename to src/api/endpoints/annotate/all/post/models/__init__.py diff --git a/src/api/endpoints/annotate/all/post/models/agency.py b/src/api/endpoints/annotate/all/post/models/agency.py new file mode 100644 index 00000000..97574e86 --- /dev/null +++ b/src/api/endpoints/annotate/all/post/models/agency.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel, model_validator + + +class AnnotationPostAgencyInfo(BaseModel): + not_found: bool = False + agency_ids: list[int] = [] + + @property + def empty(self) -> bool: + return len(self.agency_ids) == 0 + + @model_validator(mode="after") + def forbid_not_found_if_agency_ids(self): + if self.not_found and len(self.agency_ids) > 0: + raise ValueError("not_found must be False if agency_ids is not empty") + return self diff --git a/src/api/endpoints/annotate/all/post/models/location.py b/src/api/endpoints/annotate/all/post/models/location.py new file mode 100644 index 00000000..1eb7947d --- /dev/null +++ b/src/api/endpoints/annotate/all/post/models/location.py @@ -0,0 +1,16 @@ +from pydantic import BaseModel, model_validator + + +class AnnotationPostLocationInfo(BaseModel): + not_found: bool = False + location_ids: list[int] = [] + + @property + def empty(self) -> bool: + return len(self.location_ids) == 0 + + @model_validator(mode="after") + def forbid_not_found_if_location_ids(self): + if self.not_found and len(self.location_ids) > 0: + raise ValueError("not_found must be False if location_ids is not empty") + return self \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/post/models/name.py b/src/api/endpoints/annotate/all/post/models/name.py new file mode 100644 index 00000000..4cc63682 --- /dev/null +++ b/src/api/endpoints/annotate/all/post/models/name.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, ConfigDict + + +class AnnotationPostNameInfo(BaseModel): + model_config = ConfigDict(extra="forbid") + new_name: str | None = None + existing_name_id: int | None = None + + @property + def empty(self) -> bool: + return self.new_name is None and self.existing_name_id is None \ No newline at end of file diff --git a/src/api/endpoints/annotate/all/post/models/request.py b/src/api/endpoints/annotate/all/post/models/request.py new file mode 100644 index 00000000..8de222de --- /dev/null +++ b/src/api/endpoints/annotate/all/post/models/request.py @@ -0,0 +1,42 @@ +from pydantic import BaseModel, model_validator, ConfigDict + +from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo +from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo +from src.api.endpoints.annotate.all.post.models.name import AnnotationPostNameInfo +from src.core.enums import RecordType +from src.core.exceptions import FailedValidationException +from src.db.models.impl.flag.url_validated.enums import URLType + + +class AllAnnotationPostInfo(BaseModel): + model_config = ConfigDict(extra='forbid') + + suggested_status: URLType + record_type: RecordType | None = None + agency_info: AnnotationPostAgencyInfo = AnnotationPostAgencyInfo() + location_info: AnnotationPostLocationInfo = AnnotationPostLocationInfo() + name_info: AnnotationPostNameInfo = AnnotationPostNameInfo() + + @model_validator(mode="after") + def 
forbid_record_type_if_meta_url_or_individual_record(self): + if self.suggested_status not in [ + URLType.META_URL, + URLType.INDIVIDUAL_RECORD, + ]: + return self + if self.record_type is not None: + raise FailedValidationException("record_type must be None if suggested_status is META_URL or INDIVIDUAL_RECORD") + return self + + @model_validator(mode="after") + def forbid_all_else_if_not_relevant(self): + if self.suggested_status != URLType.NOT_RELEVANT: + return self + if self.record_type is not None: + raise FailedValidationException("record_type must be None if suggested_status is NOT_RELEVANT") + if not self.agency_info.empty: + raise FailedValidationException("agency_info must be empty if suggested_status is NOT_RELEVANT") + if not self.location_info.empty: + raise FailedValidationException("location_info must be empty if suggested_status is NOT_RELEVANT") + return self + diff --git a/src/api/endpoints/annotate/all/post/query.py b/src/api/endpoints/annotate/all/post/query.py new file mode 100644 index 00000000..4056de8e --- /dev/null +++ b/src/api/endpoints/annotate/all/post/query.py @@ -0,0 +1,51 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.api.endpoints.annotate.all.post.requester import AddAllAnnotationsToURLRequester +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.queries.base.builder import QueryBuilderBase + + +class AddAllAnnotationsToURLQueryBuilder(QueryBuilderBase): + + def __init__( + self, + user_id: int, + url_id: int, + post_info: AllAnnotationPostInfo + ): + super().__init__() + self.user_id = user_id + self.url_id = url_id + self.post_info = post_info + + + async def run(self, session: AsyncSession) -> None: + requester = AddAllAnnotationsToURLRequester( + session=session, + url_id=self.url_id, + user_id=self.user_id + ) + + # Add relevant annotation + requester.add_relevant_annotation(self.post_info.suggested_status) + + await requester.optionally_add_name_suggestion(self.post_info.name_info) + + + # If not relevant, do nothing else + if self.post_info.suggested_status == URLType.NOT_RELEVANT: + return + + requester.add_location_ids(self.post_info.location_info.location_ids) + + # TODO (TEST): Add test for submitting Meta URL validation + requester.optionally_add_record_type(self.post_info.record_type) + + requester.add_agency_ids(self.post_info.agency_info.agency_ids) + + if self.post_info.location_info.not_found: + requester.add_not_found_location() + + if self.post_info.agency_info.not_found: + requester.add_not_found_agency() diff --git a/src/api/endpoints/annotate/all/post/requester.py b/src/api/endpoints/annotate/all/post/requester.py new file mode 100644 index 00000000..14064e8a --- /dev/null +++ b/src/api/endpoints/annotate/all/post/requester.py @@ -0,0 +1,111 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate.all.post.models.name import AnnotationPostNameInfo +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion +from src.db.models.impl.link.user_suggestion_not_found.agency.sqlalchemy import LinkUserSuggestionAgencyNotFound +from src.db.models.impl.link.user_suggestion_not_found.location.sqlalchemy import LinkUserSuggestionLocationNotFound +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from 
src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource +from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion +from src.db.templates.requester import RequesterBase + + +class AddAllAnnotationsToURLRequester(RequesterBase): + + def __init__( + self, + session: AsyncSession, + url_id: int, + user_id: int, + ): + super().__init__(session=session) + self.url_id = url_id + self.user_id = user_id + + def optionally_add_record_type( + self, + rt: RecordType | None, + ) -> None: + if rt is None: + return + record_type_suggestion = UserRecordTypeSuggestion( + url_id=self.url_id, + user_id=self.user_id, + record_type=rt.value + ) + self.session.add(record_type_suggestion) + + def add_relevant_annotation( + self, + url_type: URLType, + ) -> None: + relevant_suggestion = UserURLTypeSuggestion( + url_id=self.url_id, + user_id=self.user_id, + type=url_type + ) + self.session.add(relevant_suggestion) + + def add_agency_ids(self, agency_ids: list[int]) -> None: + for agency_id in agency_ids: + agency_suggestion = UserUrlAgencySuggestion( + url_id=self.url_id, + user_id=self.user_id, + agency_id=agency_id, + ) + self.session.add(agency_suggestion) + + def add_location_ids(self, location_ids: list[int]) -> None: + locations: list[UserLocationSuggestion] = [] + for location_id in location_ids: + locations.append(UserLocationSuggestion( + url_id=self.url_id, + user_id=self.user_id, + location_id=location_id + )) + self.session.add_all(locations) + + async def optionally_add_name_suggestion( + self, + name_info: AnnotationPostNameInfo + ) -> None: + if name_info.empty: + return + if name_info.existing_name_id is not None: + link = LinkUserNameSuggestion( + user_id=self.user_id, + suggestion_id=name_info.existing_name_id, + ) + self.session.add(link) + return + name_suggestion = URLNameSuggestion( + url_id=self.url_id, + suggestion=name_info.new_name, + source=NameSuggestionSource.USER + ) + self.session.add(name_suggestion) + await self.session.flush() + link = LinkUserNameSuggestion( + user_id=self.user_id, + suggestion_id=name_suggestion.id, + ) + self.session.add(link) + + def add_not_found_agency(self) -> None: + not_found_agency = LinkUserSuggestionAgencyNotFound( + user_id=self.user_id, + url_id=self.url_id, + ) + self.session.add(not_found_agency) + + def add_not_found_location(self) -> None: + not_found_location = LinkUserSuggestionLocationNotFound( + user_id=self.user_id, + url_id=self.url_id, + ) + self.session.add(not_found_location) diff --git a/src/collectors/source_collectors/ckan/__init__.py b/src/api/endpoints/annotate/anonymous/__init__.py similarity index 100% rename from src/collectors/source_collectors/ckan/__init__.py rename to src/api/endpoints/annotate/anonymous/__init__.py diff --git a/src/collectors/source_collectors/ckan/dtos/__init__.py b/src/api/endpoints/annotate/anonymous/get/__init__.py similarity index 100% rename from src/collectors/source_collectors/ckan/dtos/__init__.py rename to src/api/endpoints/annotate/anonymous/get/__init__.py diff --git a/src/api/endpoints/annotate/anonymous/get/query.py b/src/api/endpoints/annotate/anonymous/get/query.py new file mode 100644 index 00000000..7e5f2e53 --- /dev/null +++ b/src/api/endpoints/annotate/anonymous/get/query.py @@ -0,0 
+1,61 @@ +from typing import Any + +from sqlalchemy import Select, func +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from src.api.endpoints.annotate._shared.extract import extract_and_format_get_annotation_result +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.collectors.enums import URLStatus +from src.db.helpers.query import not_exists_url +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.anonymous.url_type.sqlalchemy import AnonymousAnnotationURLType +from src.db.models.views.unvalidated_url import UnvalidatedURL +from src.db.models.views.url_anno_count import URLAnnotationCount +from src.db.models.views.url_annotations_flags import URLAnnotationFlagsView +from src.db.queries.base.builder import QueryBuilderBase + + +class GetNextURLForAnonymousAnnotationQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> GetNextURLForAllAnnotationResponse: + + query = ( + Select(URL) + # URL Must be unvalidated + .join( + UnvalidatedURL, + UnvalidatedURL.url_id == URL.id + ) + .join( + URLAnnotationFlagsView, + URLAnnotationFlagsView.url_id == URL.id + ) + .join( + URLAnnotationCount, + URLAnnotationCount.url_id == URL.id + ) + .where( + URL.status == URLStatus.OK.value, + not_exists_url(AnonymousAnnotationURLType) + ) + .options( + joinedload(URL.html_content), + joinedload(URL.user_relevant_suggestions), + joinedload(URL.user_record_type_suggestions), + joinedload(URL.name_suggestions), + ) + .order_by( + func.random() + ) + .limit(1) + ) + + raw_results = (await session.execute(query)).unique() + url: URL | None = raw_results.scalars().one_or_none() + if url is None: + return GetNextURLForAllAnnotationResponse( + next_annotation=None + ) + + return await extract_and_format_get_annotation_result(session, url=url) diff --git a/src/collectors/source_collectors/ckan/dtos/search/__init__.py b/src/api/endpoints/annotate/anonymous/post/__init__.py similarity index 100% rename from src/collectors/source_collectors/ckan/dtos/search/__init__.py rename to src/api/endpoints/annotate/anonymous/post/__init__.py diff --git a/src/api/endpoints/annotate/anonymous/post/query.py b/src/api/endpoints/annotate/anonymous/post/query.py new file mode 100644 index 00000000..faa7aa1d --- /dev/null +++ b/src/api/endpoints/annotate/anonymous/post/query.py @@ -0,0 +1,56 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.db.models.impl.url.suggestion.anonymous.agency.sqlalchemy import AnonymousAnnotationAgency +from src.db.models.impl.url.suggestion.anonymous.location.sqlalchemy import AnonymousAnnotationLocation +from src.db.models.impl.url.suggestion.anonymous.record_type.sqlalchemy import AnonymousAnnotationRecordType +from src.db.models.impl.url.suggestion.anonymous.url_type.sqlalchemy import AnonymousAnnotationURLType +from src.db.queries.base.builder import QueryBuilderBase + + +class AddAnonymousAnnotationsToURLQueryBuilder(QueryBuilderBase): + def __init__( + self, + url_id: int, + post_info: AllAnnotationPostInfo + ): + super().__init__() + self.url_id = url_id + self.post_info = post_info + + async def run(self, session: AsyncSession) -> None: + + url_type_suggestion = AnonymousAnnotationURLType( + url_id=self.url_id, + url_type=self.post_info.suggested_status + ) + session.add(url_type_suggestion) + + if self.post_info.record_type is not 
None: + record_type_suggestion = AnonymousAnnotationRecordType( + url_id=self.url_id, + record_type=self.post_info.record_type + ) + session.add(record_type_suggestion) + + if len(self.post_info.location_info.location_ids) != 0: + location_suggestions = [ + AnonymousAnnotationLocation( + url_id=self.url_id, + location_id=location_id + ) + for location_id in self.post_info.location_info.location_ids + ] + session.add_all(location_suggestions) + + if len(self.post_info.agency_info.agency_ids) != 0: + agency_suggestions = [ + AnonymousAnnotationAgency( + url_id=self.url_id, + agency_id=agency_id + ) + for agency_id in self.post_info.agency_info.agency_ids + ] + session.add_all(agency_suggestions) + + # Ignore Name suggestions \ No newline at end of file diff --git a/src/api/endpoints/annotate/dtos/record_type/post.py b/src/api/endpoints/annotate/dtos/record_type/post.py deleted file mode 100644 index a3c7a653..00000000 --- a/src/api/endpoints/annotate/dtos/record_type/post.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - -from src.core.enums import RecordType - - -class RecordTypeAnnotationPostInfo(BaseModel): - record_type: RecordType \ No newline at end of file diff --git a/src/api/endpoints/annotate/dtos/record_type/response.py b/src/api/endpoints/annotate/dtos/record_type/response.py deleted file mode 100644 index d46c8e12..00000000 --- a/src/api/endpoints/annotate/dtos/record_type/response.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import Optional - -from pydantic import Field, BaseModel - -from src.api.endpoints.annotate.dtos.shared.base.response import AnnotationInnerResponseInfoBase -from src.core.enums import RecordType - - -class GetNextRecordTypeAnnotationResponseInfo( - AnnotationInnerResponseInfoBase -): - suggested_record_type: Optional[RecordType] = Field( - title="What record type, if any, the auto-labeler identified the URL as" - ) - -class GetNextRecordTypeAnnotationResponseOuterInfo( - BaseModel -): - next_annotation: Optional[GetNextRecordTypeAnnotationResponseInfo] diff --git a/src/api/endpoints/annotate/dtos/shared/base/response.py b/src/api/endpoints/annotate/dtos/shared/base/response.py index a7e30385..edcc80e1 100644 --- a/src/api/endpoints/annotate/dtos/shared/base/response.py +++ b/src/api/endpoints/annotate/dtos/shared/base/response.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from src.api.endpoints.annotate.dtos.shared.batch import AnnotationBatchInfo -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo from src.db.dtos.url.mapping import URLMapping @@ -14,6 +14,6 @@ class AnnotationInnerResponseInfoBase(BaseModel): html_info: ResponseHTMLInfo = Field( title="HTML information about the URL" ) - batch_info: Optional[AnnotationBatchInfo] = Field( + batch_info: AnnotationBatchInfo | None = Field( title="Information about the annotation batch" ) \ No newline at end of file diff --git a/src/api/endpoints/annotate/relevance/get/dto.py b/src/api/endpoints/annotate/relevance/get/dto.py index b4467365..8855fdf3 100644 --- a/src/api/endpoints/annotate/relevance/get/dto.py +++ b/src/api/endpoints/annotate/relevance/get/dto.py @@ -15,11 +15,3 @@ class RelevanceAnnotationResponseInfo(BaseModel): model_name: str | None = Field( title="The name of the model that made the annotation" ) - -class GetNextRelevanceAnnotationResponseInfo(AnnotationInnerResponseInfoBase): - annotation: 
RelevanceAnnotationResponseInfo | None = Field( - title="The auto-labeler's annotation for relevance" - ) - -class GetNextRelevanceAnnotationResponseOuterInfo(BaseModel): - next_annotation: Optional[GetNextRelevanceAnnotationResponseInfo] diff --git a/src/api/endpoints/annotate/relevance/get/query.py b/src/api/endpoints/annotate/relevance/get/query.py deleted file mode 100644 index ffd37d2c..00000000 --- a/src/api/endpoints/annotate/relevance/get/query.py +++ /dev/null @@ -1,65 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncSession - -from src.api.endpoints.annotate._shared.queries.get_annotation_batch_info import GetAnnotationBatchInfoQueryBuilder -from src.api.endpoints.annotate._shared.queries.get_next_url_for_user_annotation import \ - GetNextURLForUserAnnotationQueryBuilder -from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseInfo, \ - RelevanceAnnotationResponseInfo -from src.core.tasks.url.operators.auto_relevant.models.annotation import RelevanceAnnotationInfo -from src.db.dto_converter import DTOConverter -from src.db.dtos.url.mapping import URLMapping -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from src.db.queries.base.builder import QueryBuilderBase - - -class GetNextUrlForRelevanceAnnotationQueryBuilder(QueryBuilderBase): - - def __init__( - self, - batch_id: int | None - ): - super().__init__() - self.batch_id = batch_id - - async def run( - self, - session: AsyncSession - ) -> GetNextRelevanceAnnotationResponseInfo | None: - url = await GetNextURLForUserAnnotationQueryBuilder( - user_suggestion_model_to_exclude=UserRelevantSuggestion, - auto_suggestion_relationship=URL.auto_relevant_suggestion, - batch_id=self.batch_id - ).run(session) - if url is None: - return None - - # Next, get all HTML content for the URL - html_response_info = DTOConverter.html_content_list_to_html_response_info( - url.html_content - ) - - if url.auto_relevant_suggestion is not None: - suggestion = url.auto_relevant_suggestion - else: - suggestion = None - - return GetNextRelevanceAnnotationResponseInfo( - url_info=URLMapping( - url=url.url, - url_id=url.id - ), - annotation=RelevanceAnnotationResponseInfo( - is_relevant=suggestion.relevant, - confidence=suggestion.confidence, - model_name=suggestion.model_name - ) if suggestion else None, - html_info=html_response_info, - batch_info=await GetAnnotationBatchInfoQueryBuilder( - batch_id=self.batch_id, - models=[ - UserUrlAgencySuggestion, - ] - ).run(session) - ) diff --git a/src/api/endpoints/annotate/relevance/post/dto.py b/src/api/endpoints/annotate/relevance/post/dto.py deleted file mode 100644 index a29a5327..00000000 --- a/src/api/endpoints/annotate/relevance/post/dto.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - -from src.core.enums import SuggestedStatus - - -class RelevanceAnnotationPostInfo(BaseModel): - suggested_status: SuggestedStatus \ No newline at end of file diff --git a/src/api/endpoints/annotate/routes.py b/src/api/endpoints/annotate/routes.py index fb5b117e..a09ee1ec 100644 --- a/src/api/endpoints/annotate/routes.py +++ b/src/api/endpoints/annotate/routes.py @@ -1,19 +1,16 @@ -from typing import Optional - -from fastapi import APIRouter, Depends, Path, Query +from fastapi import APIRouter, Depends, Query from src.api.dependencies import get_async_core -from 
src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAnnotationResponse -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.api.endpoints.annotate.all.get.dto import GetNextURLForAllAnnotationResponse -from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.post import RecordTypeAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseOuterInfo -from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseOuterInfo -from src.api.endpoints.annotate.relevance.post.dto import RelevanceAnnotationPostInfo +from src.api.endpoints.annotate.all.get.models.agency import AgencyAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.api.endpoints.annotate.all.get.queries.agency.core import GetAgencySuggestionsQueryBuilder +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.api.endpoints.annotate.all.post.query import AddAllAnnotationsToURLQueryBuilder +from src.api.endpoints.annotate.anonymous.get.query import GetNextURLForAnonymousAnnotationQueryBuilder +from src.api.endpoints.annotate.anonymous.post.query import AddAnonymousAnnotationsToURLQueryBuilder from src.core.core import AsyncCore -from src.security.manager import get_access_info from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_access_info annotate_router = APIRouter( prefix="/annotate", @@ -26,115 +23,51 @@ "If not specified, defaults to first qualifying URL", default=None ) - -@annotate_router.get("/relevance") -async def get_next_url_for_relevance_annotation( - access_info: AccessInfo = Depends(get_access_info), - async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = Query( - description="The batch id of the next URL to get. " - "If not specified, defaults to first qualifying URL", - default=None), -) -> GetNextRelevanceAnnotationResponseOuterInfo: - return await async_core.get_next_url_for_relevance_annotation( - user_id=access_info.user_id, - batch_id=batch_id - ) +url_id_query = Query( + description="The URL id to annotate. 
" + + "If not specified, defaults to first qualifying URL", + default=None +) -@annotate_router.post("/relevance/{url_id}") -async def annotate_url_for_relevance_and_get_next_url( - relevance_annotation_post_info: RelevanceAnnotationPostInfo, - url_id: int = Path(description="The URL id to annotate"), - async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = batch_query -) -> GetNextRelevanceAnnotationResponseOuterInfo: - """ - Post URL annotation and get next URL to annotate - """ - await async_core.submit_url_relevance_annotation( - user_id=access_info.user_id, - url_id=url_id, - suggested_status=relevance_annotation_post_info.suggested_status - ) - return await async_core.get_next_url_for_relevance_annotation( - user_id=access_info.user_id, - batch_id=batch_id +@annotate_router.get("/anonymous") +async def get_next_url_for_all_annotations_anonymous( + async_core: AsyncCore = Depends(get_async_core), +) -> GetNextURLForAllAnnotationResponse: + return await async_core.adb_client.run_query_builder( + GetNextURLForAnonymousAnnotationQueryBuilder() ) -@annotate_router.get("/record-type") -async def get_next_url_for_record_type_annotation( - access_info: AccessInfo = Depends(get_access_info), - async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = batch_query -) -> GetNextRecordTypeAnnotationResponseOuterInfo: - return await async_core.get_next_url_for_record_type_annotation( - user_id=access_info.user_id, - batch_id=batch_id +@annotate_router.post("/anonymous/{url_id}") +async def annotate_url_for_all_annotations_and_get_next_url_anonymous( + url_id: int, + all_annotation_post_info: AllAnnotationPostInfo, + async_core: AsyncCore = Depends(get_async_core), +) -> GetNextURLForAllAnnotationResponse: + await async_core.adb_client.run_query_builder( + AddAnonymousAnnotationsToURLQueryBuilder( + url_id=url_id, + post_info=all_annotation_post_info + ) ) -@annotate_router.post("/record-type/{url_id}") -async def annotate_url_for_record_type_and_get_next_url( - record_type_annotation_post_info: RecordTypeAnnotationPostInfo, - url_id: int = Path(description="The URL id to annotate"), - async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = batch_query -) -> GetNextRecordTypeAnnotationResponseOuterInfo: - """ - Post URL annotation and get next URL to annotate - """ - await async_core.submit_url_record_type_annotation( - user_id=access_info.user_id, - url_id=url_id, - record_type=record_type_annotation_post_info.record_type, - ) - return await async_core.get_next_url_for_record_type_annotation( - user_id=access_info.user_id, - batch_id=batch_id + return await async_core.adb_client.run_query_builder( + GetNextURLForAnonymousAnnotationQueryBuilder() ) -@annotate_router.get("/agency") -async def get_next_url_for_agency_annotation( - access_info: AccessInfo = Depends(get_access_info), - async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = batch_query -) -> GetNextURLForAgencyAnnotationResponse: - return await async_core.get_next_url_agency_for_annotation( - user_id=access_info.user_id, - batch_id=batch_id - ) -@annotate_router.post("/agency/{url_id}") -async def annotate_url_for_agency_and_get_next_url( - url_id: int, - agency_annotation_post_info: URLAgencyAnnotationPostInfo, - async_core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = 
batch_query -) -> GetNextURLForAgencyAnnotationResponse: - """ - Post URL annotation and get next URL to annotate - """ - await async_core.submit_url_agency_annotation( - user_id=access_info.user_id, - url_id=url_id, - agency_post_info=agency_annotation_post_info - ) - return await async_core.get_next_url_agency_for_annotation( - user_id=access_info.user_id, - batch_id=batch_id - ) @annotate_router.get("/all") async def get_next_url_for_all_annotations( access_info: AccessInfo = Depends(get_access_info), async_core: AsyncCore = Depends(get_async_core), - batch_id: Optional[int] = batch_query + batch_id: int | None = batch_query, + anno_url_id: int | None = url_id_query ) -> GetNextURLForAllAnnotationResponse: - return await async_core.get_next_url_for_all_annotations( - batch_id=batch_id + return await async_core.adb_client.get_next_url_for_all_annotations( + batch_id=batch_id, + user_id=access_info.user_id, + url_id=anno_url_id ) @annotate_router.post("/all/{url_id}") @@ -143,16 +76,36 @@ async def annotate_url_for_all_annotations_and_get_next_url( all_annotation_post_info: AllAnnotationPostInfo, async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), - batch_id: Optional[int] = batch_query + batch_id: int | None = batch_query, + anno_url_id: int | None = url_id_query ) -> GetNextURLForAllAnnotationResponse: """ Post URL annotation and get next URL to annotate """ - await async_core.submit_url_for_all_annotations( + await async_core.adb_client.run_query_builder( + AddAllAnnotationsToURLQueryBuilder( + user_id=access_info.user_id, + url_id=url_id, + post_info=all_annotation_post_info + ) + ) + + return await async_core.adb_client.get_next_url_for_all_annotations( + batch_id=batch_id, user_id=access_info.user_id, - url_id=url_id, - post_info=all_annotation_post_info + url_id=anno_url_id ) - return await async_core.get_next_url_for_all_annotations( - batch_id=batch_id + +@annotate_router.get("/suggestions/agencies/{url_id}") +async def get_agency_suggestions( + url_id: int, + async_core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_access_info), + location_id: int | None = Query(default=None) +) -> AgencyAnnotationResponseOuterInfo: + return await async_core.adb_client.run_query_builder( + GetAgencySuggestionsQueryBuilder( + url_id=url_id, + location_id=location_id + ) ) \ No newline at end of file diff --git a/src/api/endpoints/batch/dtos/get/logs.py b/src/api/endpoints/batch/dtos/get/logs.py index a350caa1..09ac7bba 100644 --- a/src/api/endpoints/batch/dtos/get/logs.py +++ b/src/api/endpoints/batch/dtos/get/logs.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from src.db.dtos.log import LogOutputInfo +from src.db.models.impl.log.pydantic.output import LogOutputInfo class GetBatchLogsResponse(BaseModel): diff --git a/src/api/endpoints/batch/dtos/get/summaries/summary.py b/src/api/endpoints/batch/dtos/get/summaries/summary.py index f00a42a5..4ca06768 100644 --- a/src/api/endpoints/batch/dtos/get/summaries/summary.py +++ b/src/api/endpoints/batch/dtos/get/summaries/summary.py @@ -13,6 +13,6 @@ class BatchSummary(BaseModel): status: BatchStatus parameters: dict user_id: int - compute_time: Optional[float] + compute_time: float | None date_generated: datetime.datetime url_counts: BatchSummaryURLCounts diff --git a/src/api/endpoints/batch/duplicates/dto.py b/src/api/endpoints/batch/duplicates/dto.py index 3838be77..dce8ae02 100644 --- a/src/api/endpoints/batch/duplicates/dto.py +++ 
b/src/api/endpoints/batch/duplicates/dto.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from src.db.dtos.duplicate import DuplicateInfo +from src.db.models.impl.duplicate.pydantic.info import DuplicateInfo class GetDuplicatesByBatchResponse(BaseModel): diff --git a/src/api/endpoints/batch/duplicates/query.py b/src/api/endpoints/batch/duplicates/query.py index a4c3aa31..b09b6e5d 100644 --- a/src/api/endpoints/batch/duplicates/query.py +++ b/src/api/endpoints/batch/duplicates/query.py @@ -2,11 +2,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import aliased -from src.db.dtos.duplicate import DuplicateInfo -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.duplicate import Duplicate -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL +from src.db.models.impl.duplicate.pydantic.info import DuplicateInfo +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.duplicate.sqlalchemy import Duplicate +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL from src.db.queries.base.builder import QueryBuilderBase @@ -50,7 +50,7 @@ async def run(self, session: AsyncSession) -> list[DuplicateInfo]: final_results.append( DuplicateInfo( source_url=result.source_url, - duplicate_batch_id=result.duplicate_batch_id, + batch_id=result.duplicate_batch_id, duplicate_metadata=result.duplicate_batch_parameters, original_batch_id=result.original_batch_id, original_metadata=result.original_batch_parameters, diff --git a/src/api/endpoints/batch/routes.py b/src/api/endpoints/batch/routes.py index 879c643d..bd7bbf61 100644 --- a/src/api/endpoints/batch/routes.py +++ b/src/api/endpoints/batch/routes.py @@ -13,6 +13,7 @@ from src.collectors.enums import CollectorType from src.core.core import AsyncCore from src.core.enums import BatchStatus +from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum from src.security.dtos.access_info import AccessInfo from src.security.manager import get_access_info @@ -25,18 +26,14 @@ @batch_router.get("") async def get_batch_status( - collector_type: Optional[CollectorType] = Query( + collector_type: CollectorType | None = Query( description="Filter by collector type", default=None ), - status: Optional[BatchStatus] = Query( + status: BatchURLStatusEnum | None = Query( description="Filter by status", default=None ), - has_pending_urls: Optional[bool] = Query( - description="Filter by whether the batch has pending URLs", - default=None - ), page: int = Query( description="The page number", default=1 @@ -50,7 +47,6 @@ async def get_batch_status( return await core.get_batch_statuses( collector_type=collector_type, status=status, - has_pending_urls=has_pending_urls, page=page ) diff --git a/src/api/endpoints/batch/urls/dto.py b/src/api/endpoints/batch/urls/dto.py index 40b1e753..5e671e4b 100644 --- a/src/api/endpoints/batch/urls/dto.py +++ b/src/api/endpoints/batch/urls/dto.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from src.db.dtos.url.core import URLInfo +from src.db.models.impl.url.core.pydantic.info import URLInfo class GetURLsByBatchResponse(BaseModel): diff --git a/src/api/endpoints/batch/urls/query.py b/src/api/endpoints/batch/urls/query.py index fcfba3ee..391a265f 100644 --- a/src/api/endpoints/batch/urls/query.py +++ b/src/api/endpoints/batch/urls/query.py @@ -1,9 +1,9 @@ from sqlalchemy import Select from 
sqlalchemy.ext.asyncio import AsyncSession -from src.db.dtos.url.core import URLInfo -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.db.models.impl.url.core.sqlalchemy import URL from src.db.queries.base.builder import QueryBuilderBase diff --git a/src/api/endpoints/collector/dtos/manual_batch/post.py b/src/api/endpoints/collector/dtos/manual_batch/post.py index f7de1ecf..6ec62579 100644 --- a/src/api/endpoints/collector/dtos/manual_batch/post.py +++ b/src/api/endpoints/collector/dtos/manual_batch/post.py @@ -7,13 +7,13 @@ class ManualBatchInnerInputDTO(BaseModel): url: str - name: Optional[str] = None - description: Optional[str] = None - collector_metadata: Optional[dict] = None - record_type: Optional[RecordType] = None - record_formats: Optional[list[str]] = None - data_portal_type: Optional[str] = None - supplying_entity: Optional[str] = None + name: str | None = None + description: str | None = None + collector_metadata: dict | None = None + record_type: RecordType | None = None + record_formats: list[str] | None = None + data_portal_type: str | None = None + supplying_entity: str | None = None class ManualBatchInputDTO(BaseModel): diff --git a/src/api/endpoints/collector/manual/query.py b/src/api/endpoints/collector/manual/query.py index 2f29a357..4f8956dc 100644 --- a/src/api/endpoints/collector/manual/query.py +++ b/src/api/endpoints/collector/manual/query.py @@ -5,10 +5,12 @@ from src.api.endpoints.collector.dtos.manual_batch.response import ManualBatchResponseDTO from src.collectors.enums import CollectorType, URLStatus from src.core.enums import BatchStatus -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType from src.db.queries.base.builder import QueryBuilderBase @@ -36,9 +38,9 @@ async def run(self, session: AsyncSession) -> ManualBatchResponseDTO: session.add(batch) await session.flush() - batch_id = batch.id - url_ids = [] - duplicate_urls = [] + batch_id: int = batch.id + url_ids: list[int] = [] + duplicate_urls: list[str] = [] for entry in self.dto.entries: url = URL( @@ -46,10 +48,11 @@ async def run(self, session: AsyncSession) -> ManualBatchResponseDTO: name=entry.name, description=entry.description, collector_metadata=entry.collector_metadata, - outcome=URLStatus.PENDING.value, - record_type=entry.record_type.value if entry.record_type is not None else None, + status=URLStatus.OK.value, + source=URLSource.MANUAL ) + async with session.begin_nested(): try: session.add(url) @@ -58,6 +61,15 @@ async def run(self, session: AsyncSession) -> ManualBatchResponseDTO: duplicate_urls.append(entry.url) continue await session.flush() + + if entry.record_type is not None: + record_type = URLRecordType( + url_id=url.id, + 
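# record_type now lives in its own URLRecordType row rather than as a column on URL +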
record_type=entry.record_type, + ) + session.add(record_type) + + link = LinkBatchURL( batch_id=batch_id, url_id=url.id diff --git a/src/api/endpoints/collector/routes.py b/src/api/endpoints/collector/routes.py index 6f39d27f..4818dc63 100644 --- a/src/api/endpoints/collector/routes.py +++ b/src/api/endpoints/collector/routes.py @@ -5,17 +5,17 @@ from src.api.endpoints.collector.dtos.collector_start import CollectorStartInfo from src.api.endpoints.collector.dtos.manual_batch.post import ManualBatchInputDTO from src.api.endpoints.collector.dtos.manual_batch.response import ManualBatchResponseDTO -from src.collectors.source_collectors.auto_googler.dtos.input import AutoGooglerInputDTO -from src.collectors.source_collectors.common_crawler.input import CommonCrawlerInputDTO -from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO +from src.collectors.impl.auto_googler.dtos.input import AutoGooglerInputDTO +from src.collectors.impl.common_crawler.input import CommonCrawlerInputDTO +from src.collectors.impl.example.dtos.input import ExampleInputDTO from src.collectors.enums import CollectorType from src.core.core import AsyncCore from src.security.manager import get_access_info from src.security.dtos.access_info import AccessInfo -from src.collectors.source_collectors.ckan.dtos.input import CKANInputDTO -from src.collectors.source_collectors.muckrock.collectors.all_foia.dto import MuckrockAllFOIARequestsCollectorInputDTO -from src.collectors.source_collectors.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO -from src.collectors.source_collectors.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO +from src.collectors.impl.ckan.dtos.input import CKANInputDTO +from src.collectors.impl.muckrock.collectors.all_foia.dto import MuckrockAllFOIARequestsCollectorInputDTO +from src.collectors.impl.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO +from src.collectors.impl.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO collector_router = APIRouter( prefix="/collector", diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/__init__.py b/src/api/endpoints/contributions/__init__.py similarity index 100% rename from src/collectors/source_collectors/ckan/scraper_toolkit/__init__.py rename to src/api/endpoints/contributions/__init__.py diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/__init__.py b/src/api/endpoints/contributions/leaderboard/__init__.py similarity index 100% rename from src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/__init__.py rename to src/api/endpoints/contributions/leaderboard/__init__.py diff --git a/src/api/endpoints/contributions/leaderboard/query.py b/src/api/endpoints/contributions/leaderboard/query.py new file mode 100644 index 00000000..4075585f --- /dev/null +++ b/src/api/endpoints/contributions/leaderboard/query.py @@ -0,0 +1,39 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.contributions.leaderboard.response import ContributionsLeaderboardResponse, \ + ContributionsLeaderboardInnerResponse +from src.api.endpoints.contributions.shared.contributions import ContributionsCTEContainer +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class GetContributionsLeaderboardQueryBuilder(QueryBuilderBase): + + async def run(self, 
session: AsyncSession) -> ContributionsLeaderboardResponse: + cte = ContributionsCTEContainer() + + query = ( + select( + cte.user_id, + cte.count, + ) + .order_by( + cte.count.desc() + ) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + inner_responses = [ + ContributionsLeaderboardInnerResponse( + user_id=mapping["user_id"], + count=mapping["count"] + ) + for mapping in mappings + ] + + return ContributionsLeaderboardResponse( + leaderboard=inner_responses + ) \ No newline at end of file diff --git a/src/api/endpoints/contributions/leaderboard/response.py b/src/api/endpoints/contributions/leaderboard/response.py new file mode 100644 index 00000000..a92c177b --- /dev/null +++ b/src/api/endpoints/contributions/leaderboard/response.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + + +class ContributionsLeaderboardInnerResponse(BaseModel): + user_id: int + count: int + +class ContributionsLeaderboardResponse(BaseModel): + leaderboard: list[ContributionsLeaderboardInnerResponse] \ No newline at end of file diff --git a/src/api/endpoints/contributions/routes.py b/src/api/endpoints/contributions/routes.py new file mode 100644 index 00000000..c6fdc739 --- /dev/null +++ b/src/api/endpoints/contributions/routes.py @@ -0,0 +1,46 @@ +from fastapi import APIRouter, Depends + +from src.api.dependencies import get_async_core +from src.api.endpoints.contributions.leaderboard.query import GetContributionsLeaderboardQueryBuilder +from src.api.endpoints.contributions.leaderboard.response import ContributionsLeaderboardResponse +from src.api.endpoints.contributions.user.queries.core import GetUserContributionsQueryBuilder +from src.api.endpoints.contributions.user.response import ContributionsUserResponse +from src.core.core import AsyncCore +from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_access_info + +contributions_router = APIRouter( + prefix="/contributions", + tags=["Contributions"], +) + +@contributions_router.get("/leaderboard") +async def get_leaderboard( + core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_access_info) +) -> ContributionsLeaderboardResponse: + """Returns the leaderboard of user contributions.""" + return await core.adb_client.run_query_builder( + GetContributionsLeaderboardQueryBuilder() + ) + +@contributions_router.get("/user") +async def get_user_contributions( + core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_access_info) +) -> ContributionsUserResponse: + """Get contributions for the user and how often their annotations agreed with the final validation of URLs. + + Agreement for each is based on the number of the user's correct annotations for that URL attribute + divided by their total number of annotations for that URL attribute. + + "Correct" in this case means the user's annotation value for that URL attribute + aligned with the final validated value for that attribute. + + In the case of attributes with multiple validated values, such as agency ID, + agreement is counted when the user's suggested value aligns with any of the final validated values.
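+ + For example (illustrative numbers): a user who made four agency annotations on URLs that were later validated, three of which matched a validated agency, would have an agency agreement of 3 / 4 = 0.75.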
+ """ + + return await core.adb_client.run_query_builder( + GetUserContributionsQueryBuilder(access_info.user_id) + ) \ No newline at end of file diff --git a/src/collectors/source_collectors/common_crawler/__init__.py b/src/api/endpoints/contributions/shared/__init__.py similarity index 100% rename from src/collectors/source_collectors/common_crawler/__init__.py rename to src/api/endpoints/contributions/shared/__init__.py diff --git a/src/api/endpoints/contributions/shared/contributions.py b/src/api/endpoints/contributions/shared/contributions.py new file mode 100644 index 00000000..477f0365 --- /dev/null +++ b/src/api/endpoints/contributions/shared/contributions.py @@ -0,0 +1,31 @@ +from sqlalchemy import select, func, CTE, Column + +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion + + +class ContributionsCTEContainer: + + def __init__(self): + self._cte = ( + select( + UserURLTypeSuggestion.user_id, + func.count().label("count") + ) + .group_by( + UserURLTypeSuggestion.user_id + ) + .cte("contributions") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def count(self) -> Column[int]: + return self.cte.c.count + + @property + def user_id(self) -> Column[int]: + return self.cte.c.user_id + diff --git a/src/collectors/source_collectors/example/__init__.py b/src/api/endpoints/contributions/user/__init__.py similarity index 100% rename from src/collectors/source_collectors/example/__init__.py rename to src/api/endpoints/contributions/user/__init__.py diff --git a/src/collectors/source_collectors/example/dtos/__init__.py b/src/api/endpoints/contributions/user/queries/__init__.py similarity index 100% rename from src/collectors/source_collectors/example/dtos/__init__.py rename to src/api/endpoints/contributions/user/queries/__init__.py diff --git a/src/collectors/source_collectors/muckrock/__init__.py b/src/api/endpoints/contributions/user/queries/agreement/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/__init__.py rename to src/api/endpoints/contributions/user/queries/agreement/__init__.py diff --git a/src/api/endpoints/contributions/user/queries/agreement/agency.py b/src/api/endpoints/contributions/user/queries/agreement/agency.py new file mode 100644 index 00000000..96011e06 --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/agreement/agency.py @@ -0,0 +1,60 @@ +from sqlalchemy import select, func, exists, and_ + +from src.api.endpoints.contributions.user.queries.annotated_and_validated import AnnotatedAndValidatedCTEContainer +from src.api.endpoints.contributions.user.queries.templates.agreement import AgreementCTEContainer +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion + + +def get_agency_agreement_cte_container( + inner_cte: AnnotatedAndValidatedCTEContainer +) -> AgreementCTEContainer: + + count_cte = ( + select( + inner_cte.user_id, + func.count() + ) + .join( + UserUrlAgencySuggestion, + and_( + inner_cte.user_id == UserUrlAgencySuggestion.user_id, + inner_cte.url_id == UserUrlAgencySuggestion.url_id + ) + ) + .group_by( + inner_cte.user_id + ) + .cte("agency_count_total") + ) + + agreed_cte = ( + select( + inner_cte.user_id, + func.count() + ) + .join( + UserUrlAgencySuggestion, + and_( + inner_cte.user_id == UserUrlAgencySuggestion.user_id, + inner_cte.url_id == UserUrlAgencySuggestion.url_id + ) + ) + .where( + exists() + .where( + LinkURLAgency.url_id == 
UserUrlAgencySuggestion.url_id, + LinkURLAgency.agency_id == UserUrlAgencySuggestion.agency_id + ) + ) + .group_by( + inner_cte.user_id + ) + .cte("agency_count_agreed") + ) + + return AgreementCTEContainer( + count_cte=count_cte, + agreed_cte=agreed_cte, + name="agency" + ) diff --git a/src/api/endpoints/contributions/user/queries/agreement/record_type.py b/src/api/endpoints/contributions/user/queries/agreement/record_type.py new file mode 100644 index 00000000..2cde5ab5 --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/agreement/record_type.py @@ -0,0 +1,54 @@ +from sqlalchemy import select, func, and_ + +from src.api.endpoints.contributions.user.queries.annotated_and_validated import AnnotatedAndValidatedCTEContainer +from src.api.endpoints.contributions.user.queries.templates.agreement import AgreementCTEContainer +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion + + +def get_record_type_agreement_cte_container( + inner_cte: AnnotatedAndValidatedCTEContainer +) -> AgreementCTEContainer: + + count_cte = ( + select( + inner_cte.user_id, + func.count().label("count") + ) + .join( + UserRecordTypeSuggestion, + and_( + inner_cte.user_id == UserRecordTypeSuggestion.user_id, + inner_cte.url_id == UserRecordTypeSuggestion.url_id + ) + ) + .group_by( + inner_cte.user_id + ) + .cte("record_type_count_total") + ) + + agreed_cte = ( + select( + inner_cte.user_id, + func.count().label("count") + ) + .join( + UserRecordTypeSuggestion, + and_( + inner_cte.user_id == UserRecordTypeSuggestion.user_id, + inner_cte.url_id == UserRecordTypeSuggestion.url_id + ) + ) + .join( + URLRecordType, + and_( + URLRecordType.url_id == inner_cte.url_id, + URLRecordType.record_type == UserRecordTypeSuggestion.record_type + ) + ) + .group_by( + inner_cte.user_id + ) + .cte("record_type_count_agreed") + ) + + return AgreementCTEContainer( + count_cte=count_cte, + agreed_cte=agreed_cte, + name="record_type" + ) \ No newline at end of file diff --git a/src/api/endpoints/contributions/user/queries/agreement/url_type.py b/src/api/endpoints/contributions/user/queries/agreement/url_type.py new file mode 100644 index 00000000..cf028bf1 --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/agreement/url_type.py @@ -0,0 +1,61 @@ +from sqlalchemy import select, func, and_ + +from src.api.endpoints.contributions.user.queries.annotated_and_validated import AnnotatedAndValidatedCTEContainer +from src.api.endpoints.contributions.user.queries.templates.agreement import AgreementCTEContainer +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion + + +def get_url_type_agreement_cte_container( + inner_cte: AnnotatedAndValidatedCTEContainer +) -> AgreementCTEContainer: + + # Count CTE is number of User URL Type Suggestions + count_cte = ( + select( + inner_cte.user_id, + func.count().label("count") + ) + .join( + UserURLTypeSuggestion, + and_( + inner_cte.user_id == UserURLTypeSuggestion.user_id, + inner_cte.url_id == UserURLTypeSuggestion.url_id + ) + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == inner_cte.url_id + ) + .group_by( + inner_cte.user_id + ) + .cte("url_type_count_total") + ) + + agreed_cte = ( + select( + inner_cte.user_id, + func.count().label("count") + ) + .join( + UserURLTypeSuggestion, + and_( + inner_cte.user_id == UserURLTypeSuggestion.user_id, + inner_cte.url_id == UserURLTypeSuggestion.url_id + ) + ) + .join( + FlagURLValidated, + and_( + FlagURLValidated.url_id == inner_cte.url_id, + UserURLTypeSuggestion.type == FlagURLValidated.type + ) + ) + .group_by( + inner_cte.user_id + ) + .cte("url_type_count_agreed") + ) + + return AgreementCTEContainer( + count_cte=count_cte, + 
agreed_cte=agreed_cte, + name="url_type" + ) + diff --git a/src/api/endpoints/contributions/user/queries/annotated_and_validated.py b/src/api/endpoints/contributions/user/queries/annotated_and_validated.py new file mode 100644 index 00000000..a9740328 --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/annotated_and_validated.py @@ -0,0 +1,34 @@ +from sqlalchemy import select, Column, CTE + +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion + + +class AnnotatedAndValidatedCTEContainer: + + def __init__(self, user_id: int | None): + self._cte = ( + select( + UserURLTypeSuggestion.user_id, + UserURLTypeSuggestion.url_id + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == UserURLTypeSuggestion.url_id + ) + ) + if user_id is not None: + self._cte = self._cte.where(UserURLTypeSuggestion.user_id == user_id) + self._cte = self._cte.cte("annotated_and_validated") + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self.cte.c.url_id + + @property + def user_id(self) -> Column[int]: + return self.cte.c.user_id \ No newline at end of file diff --git a/src/api/endpoints/contributions/user/queries/core.py b/src/api/endpoints/contributions/user/queries/core.py new file mode 100644 index 00000000..57727215 --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/core.py @@ -0,0 +1,59 @@ +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.contributions.shared.contributions import ContributionsCTEContainer +from src.api.endpoints.contributions.user.queries.agreement.agency import get_agency_agreement_cte_container +from src.api.endpoints.contributions.user.queries.agreement.record_type import get_record_type_agreement_cte_container +from src.api.endpoints.contributions.user.queries.agreement.url_type import get_url_type_agreement_cte_container +from src.api.endpoints.contributions.user.queries.annotated_and_validated import AnnotatedAndValidatedCTEContainer +from src.api.endpoints.contributions.user.queries.templates.agreement import AgreementCTEContainer +from src.api.endpoints.contributions.user.response import ContributionsUserResponse, ContributionsUserAgreement +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class GetUserContributionsQueryBuilder(QueryBuilderBase): + + def __init__(self, user_id: int): + super().__init__() + self.user_id = user_id + + async def run(self, session: AsyncSession) -> ContributionsUserResponse: + inner_cte = AnnotatedAndValidatedCTEContainer(self.user_id) + + contributions_cte = ContributionsCTEContainer() + record_type_agree: AgreementCTEContainer = get_record_type_agreement_cte_container(inner_cte) + agency_agree: AgreementCTEContainer = get_agency_agreement_cte_container(inner_cte) + url_type_agree: AgreementCTEContainer = get_url_type_agreement_cte_container(inner_cte) + + query = ( + select( + contributions_cte.count, + record_type_agree.agreement.label("record_type"), + agency_agree.agreement.label("agency"), + url_type_agree.agreement.label("url_type") + ) + .join( + record_type_agree.cte, + contributions_cte.user_id == record_type_agree.user_id + ) + .join( + agency_agree.cte, + contributions_cte.user_id == agency_agree.user_id + ) + .join( + url_type_agree.cte, + contributions_cte.user_id == url_type_agree.user_id + ) + ) + + 
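# The agreement CTEs are filtered to this user's annotations, so the combined query returns at most one row. + 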
mapping: RowMapping = await sh.mapping(session, query=query) + + return ContributionsUserResponse( + count_validated=mapping["count"], + agreement=ContributionsUserAgreement( + record_type=mapping["record_type"], + agency=mapping["agency"], + url_type=mapping["url_type"] + ) + ) \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/api_interface/__init__.py b/src/api/endpoints/contributions/user/queries/templates/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/api_interface/__init__.py rename to src/api/endpoints/contributions/user/queries/templates/__init__.py diff --git a/src/api/endpoints/contributions/user/queries/templates/agreement.py b/src/api/endpoints/contributions/user/queries/templates/agreement.py new file mode 100644 index 00000000..8479f90c --- /dev/null +++ b/src/api/endpoints/contributions/user/queries/templates/agreement.py @@ -0,0 +1,35 @@ +from sqlalchemy import CTE, select, Column, Float, cast, func + + +class AgreementCTEContainer: + + def __init__( + self, + count_cte: CTE, + agreed_cte: CTE, + name: str + ): + # Cast to float to avoid integer division; coalesce + outer join keep users with zero agreed annotations. + self._cte = ( + select( + count_cte.c.user_id, + (cast(func.coalesce(agreed_cte.c.count, 0), Float) / count_cte.c.count).label("agreement") + ) + .outerjoin( + agreed_cte, + count_cte.c.user_id == agreed_cte.c.user_id + ) + .cte(f"{name}_agreement") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def user_id(self) -> Column[int]: + return self.cte.c.user_id + + @property + def agreement(self) -> Column[float]: + return self.cte.c.agreement + diff --git a/src/api/endpoints/contributions/user/response.py b/src/api/endpoints/contributions/user/response.py new file mode 100644 index 00000000..8151c493 --- /dev/null +++ b/src/api/endpoints/contributions/user/response.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + +class ContributionsUserAgreement(BaseModel): + record_type: float = Field(ge=0, le=1) + agency: float = Field(ge=0, le=1) + url_type: float = Field(ge=0, le=1) + +class ContributionsUserResponse(BaseModel): + count_validated: int + agreement: ContributionsUserAgreement \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/collectors/__init__.py b/src/api/endpoints/metrics/backlog/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/__init__.py rename to src/api/endpoints/metrics/backlog/__init__.py diff --git a/src/api/endpoints/metrics/backlog/query.py b/src/api/endpoints/metrics/backlog/query.py new file mode 100644 index 00000000..788ef424 --- /dev/null +++ b/src/api/endpoints/metrics/backlog/query.py @@ -0,0 +1,53 @@ +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.dtos.get.backlog import GetMetricsBacklogResponseDTO, GetMetricsBacklogResponseInnerDTO +from src.db.models.impl.backlog_snapshot import BacklogSnapshot +from src.db.queries.base.builder import QueryBuilderBase + + +class GetBacklogMetricsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> GetMetricsBacklogResponseDTO: + month = func.date_trunc('month', BacklogSnapshot.created_at) + + # 1. Create a subquery that assigns row_number() partitioned by month + monthly_snapshot_subq = ( + select( + BacklogSnapshot.id, + BacklogSnapshot.created_at, + BacklogSnapshot.count_pending_total, + month.label("month_start"), + func.row_number() + .over( + partition_by=month, + order_by=BacklogSnapshot.created_at.desc() + ) + .label("row_number") + ) + .subquery() + ) + + # 2. 
Filter for the top (most recent) row in each month + stmt = ( + select( + monthly_snapshot_subq.c.month_start, + monthly_snapshot_subq.c.created_at, + monthly_snapshot_subq.c.count_pending_total + ) + .where(monthly_snapshot_subq.c.row_number == 1) + .order_by(monthly_snapshot_subq.c.month_start) + ) + + raw_result = await session.execute(stmt) + results = raw_result.all() + final_results = [] + for result in results: + final_results.append( + GetMetricsBacklogResponseInnerDTO( + month=result.month_start.strftime("%B %Y"), + count_pending_total=result.count_pending_total, + ) + ) + + return GetMetricsBacklogResponseDTO(entries=final_results) \ No newline at end of file diff --git a/src/api/endpoints/metrics/batches/aggregated/query.py b/src/api/endpoints/metrics/batches/aggregated/query.py deleted file mode 100644 index 12616a22..00000000 --- a/src/api/endpoints/metrics/batches/aggregated/query.py +++ /dev/null @@ -1,117 +0,0 @@ -from sqlalchemy import case, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql.functions import coalesce - -from src.api.endpoints.metrics.batches.aggregated.dto import GetMetricsBatchesAggregatedResponseDTO, \ - GetMetricsBatchesAggregatedInnerResponseDTO -from src.collectors.enums import URLStatus, CollectorType -from src.core.enums import BatchStatus -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.queries.base.builder import QueryBuilderBase -from src.db.statement_composer import StatementComposer - - -class GetBatchesAggregatedMetricsQueryBuilder(QueryBuilderBase): - - async def run( - self, - session: AsyncSession - ) -> GetMetricsBatchesAggregatedResponseDTO: - sc = StatementComposer - - # First, get all batches broken down by collector type and status - def batch_column(status: BatchStatus, label): - return sc.count_distinct( - case( - ( - Batch.status == status.value, - Batch.id - ) - ), - label=label - ) - - batch_count_subquery = select( - batch_column(BatchStatus.READY_TO_LABEL, label="done_count"), - batch_column(BatchStatus.ERROR, label="error_count"), - Batch.strategy, - ).group_by(Batch.strategy).subquery("batch_count") - - def url_column(status: URLStatus, label): - return sc.count_distinct( - case( - ( - URL.outcome == status.value, - URL.id - ) - ), - label=label - ) - - # Next, count urls - url_count_subquery = select( - Batch.strategy, - url_column(URLStatus.PENDING, label="pending_count"), - url_column(URLStatus.ERROR, label="error_count"), - url_column(URLStatus.VALIDATED, label="validated_count"), - url_column(URLStatus.SUBMITTED, label="submitted_count"), - url_column(URLStatus.NOT_RELEVANT, label="rejected_count"), - - ).join( - LinkBatchURL, - LinkBatchURL.url_id == URL.id - ).outerjoin( - Batch, Batch.id == LinkBatchURL.batch_id - ).group_by( - Batch.strategy - ).subquery("url_count") - - # Combine - query = select( - Batch.strategy, - batch_count_subquery.c.done_count.label("batch_done_count"), - batch_count_subquery.c.error_count.label("batch_error_count"), - coalesce(url_count_subquery.c.pending_count, 0).label("pending_count"), - coalesce(url_count_subquery.c.error_count, 0).label("error_count"), - coalesce(url_count_subquery.c.submitted_count, 0).label("submitted_count"), - coalesce(url_count_subquery.c.rejected_count, 0).label("rejected_count"), - coalesce(url_count_subquery.c.validated_count, 0).label("validated_count") - ).join( - batch_count_subquery, - 
Batch.strategy == batch_count_subquery.c.strategy - ).outerjoin( - url_count_subquery, - Batch.strategy == url_count_subquery.c.strategy - ) - raw_results = await session.execute(query) - results = raw_results.all() - d: dict[CollectorType, GetMetricsBatchesAggregatedInnerResponseDTO] = {} - for result in results: - d[CollectorType(result.strategy)] = GetMetricsBatchesAggregatedInnerResponseDTO( - count_successful_batches=result.batch_done_count, - count_failed_batches=result.batch_error_count, - count_urls=result.pending_count + result.submitted_count + - result.rejected_count + result.error_count + - result.validated_count, - count_urls_pending=result.pending_count, - count_urls_validated=result.validated_count, - count_urls_submitted=result.submitted_count, - count_urls_rejected=result.rejected_count, - count_urls_errors=result.error_count - ) - - total_batch_query = await session.execute( - select( - sc.count_distinct(Batch.id, label="count") - ) - ) - total_batch_count = total_batch_query.scalars().one_or_none() - if total_batch_count is None: - total_batch_count = 0 - - return GetMetricsBatchesAggregatedResponseDTO( - total_batches=total_batch_count, - by_strategy=d - ) \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/collectors/all_foia/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/all_foia/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/__init__.py diff --git a/src/collectors/source_collectors/muckrock/collectors/county/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/all_urls/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/county/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/all_urls/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/all_urls/query.py b/src/api/endpoints/metrics/batches/aggregated/query/all_urls/query.py new file mode 100644 index 00000000..7eed215a --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/all_urls/query.py @@ -0,0 +1,28 @@ +from typing import Sequence + +from sqlalchemy import func, select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.queries.base.builder import QueryBuilderBase + + +class CountAllURLsByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[CountByBatchStrategyResponse]: + + query = ( + select( + Batch.strategy, + func.count(LinkBatchURL.url_id).label("count") + ) + .join(LinkBatchURL) + .group_by(Batch.strategy) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results = [CountByBatchStrategyResponse(**mapping) for mapping in mappings] + return results \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/collectors/simple/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/simple/__init__.py rename to 
src/api/endpoints/metrics/batches/aggregated/query/batch_status_/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/query.py b/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/query.py new file mode 100644 index 00000000..f8587b68 --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/query.py @@ -0,0 +1,37 @@ +from typing import Sequence + +from sqlalchemy import CTE, select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.batch_status_.response import \ + BatchStatusCountByBatchStrategyResponseDTO +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class BatchStatusByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[BatchStatusCountByBatchStrategyResponseDTO]: + query = ( + select( + Batch.strategy, + Batch.status, + func.count(Batch.id).label("count") + ) + .group_by(Batch.strategy, Batch.status) + ) + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + results: list[BatchStatusCountByBatchStrategyResponseDTO] = [] + for mapping in mappings: + results.append( + BatchStatusCountByBatchStrategyResponseDTO( + strategy=CollectorType(mapping["strategy"]), + status=BatchStatus(mapping["status"]), + count=mapping["count"] + ) + ) + return results \ No newline at end of file diff --git a/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/response.py b/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/response.py new file mode 100644 index 00000000..79c1b2dd --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/batch_status_/response.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus + + +class BatchStatusCountByBatchStrategyResponseDTO(BaseModel): + strategy: CollectorType + status: BatchStatus + count: int \ No newline at end of file diff --git a/src/api/endpoints/metrics/batches/aggregated/query/core.py b/src/api/endpoints/metrics/batches/aggregated/query/core.py new file mode 100644 index 00000000..c17f0f6d --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/core.py @@ -0,0 +1,79 @@ +from sqlalchemy import case, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.functions import coalesce, func + +from src.api.endpoints.metrics.batches.aggregated.dto import GetMetricsBatchesAggregatedResponseDTO, \ + GetMetricsBatchesAggregatedInnerResponseDTO +from src.api.endpoints.metrics.batches.aggregated.query.all_urls.query import CountAllURLsByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.batch_status_.query import \ + BatchStatusByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.requester_.requester import \ + GetBatchesAggregatedMetricsQueryRequester +from src.api.endpoints.metrics.batches.aggregated.query.submitted_.query import \ + CountSubmittedByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.url_error.query import URLErrorByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.validated_.query import \ + 
ValidatedURLCountByBatchStrategyQueryBuilder +from src.collectors.enums import URLStatus, CollectorType +from src.core.enums import BatchStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.queries.base.builder import QueryBuilderBase +from src.db.statement_composer import StatementComposer + + +class GetBatchesAggregatedMetricsQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> GetMetricsBatchesAggregatedResponseDTO: + + requester = GetBatchesAggregatedMetricsQueryRequester(session=session) + + url_error_count_dict: dict[CollectorType, int] = await requester.url_error_by_collector_strategy() + url_pending_count_dict: dict[CollectorType, int] = await requester.pending_url_count_by_collector_strategy() + url_submitted_count_dict: dict[CollectorType, int] = await requester.submitted_url_count_by_collector_strategy() + url_validated_count_dict: dict[CollectorType, int] = await requester.validated_url_count_by_collector_strategy() + url_rejected_count_dict: dict[CollectorType, int] = await requester.rejected_url_count_by_collector_strategy() + url_total_count_dict: dict[CollectorType, int] = await requester.url_count_by_collector_strategy() + batch_status_count_dict: dict[ + CollectorType, + dict[BatchStatus, int] + ] = await requester.batch_status_by_collector_strategy() + + d: dict[CollectorType, GetMetricsBatchesAggregatedInnerResponseDTO] = {} + for collector_type in CollectorType: + inner_response = GetMetricsBatchesAggregatedInnerResponseDTO( + count_successful_batches=batch_status_count_dict[collector_type][BatchStatus.READY_TO_LABEL], + count_failed_batches=batch_status_count_dict[collector_type][BatchStatus.ERROR], + count_urls=url_total_count_dict[collector_type], + count_urls_pending=url_pending_count_dict[collector_type], + count_urls_validated=url_validated_count_dict[collector_type], + count_urls_submitted=url_submitted_count_dict[collector_type], + count_urls_rejected=url_rejected_count_dict[collector_type], + count_urls_errors=url_error_count_dict[collector_type], + ) + d[collector_type] = inner_response + + total_batch_query = await session.execute( + select( + func.count(Batch.id).label("count") + ) + ) + total_batch_count = total_batch_query.scalars().one_or_none() + if total_batch_count is None: + total_batch_count = 0 + + return GetMetricsBatchesAggregatedResponseDTO( + total_batches=total_batch_count, + by_strategy=d + ) \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/models/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetch_requests/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/models/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/models/strategy_count.py b/src/api/endpoints/metrics/batches/aggregated/query/models/strategy_count.py new file mode 100644 index 00000000..9ceb7781 --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/models/strategy_count.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.collectors.enums import 
CollectorType + + +class CountByBatchStrategyResponse(BaseModel): + strategy: CollectorType + count: int \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/fetchers/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/pending/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetchers/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/pending/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/pending/query.py b/src/api/endpoints/metrics/batches/aggregated/query/pending/query.py new file mode 100644 index 00000000..224d3bad --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/pending/query.py @@ -0,0 +1,37 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class PendingURLCountByBatchStrategyQueryBuilder(QueryBuilderBase): + async def run( + self, session: AsyncSession + ) -> list[CountByBatchStrategyResponse]: + + query = ( + select( + Batch.strategy, + func.count(LinkBatchURL.url_id).label("count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .outerjoin( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .where(FlagURLValidated.url_id.is_(None)) + .group_by(Batch.strategy) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results = [CountByBatchStrategyResponse(**mapping) for mapping in mappings] + return results diff --git a/src/collectors/source_collectors/muckrock/fetchers/foia/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/rejected/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetchers/foia/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/rejected/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/rejected/query.py b/src/api/endpoints/metrics/batches/aggregated/query/rejected/query.py new file mode 100644 index 00000000..7b94f2ba --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/rejected/query.py @@ -0,0 +1,39 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class RejectedURLCountByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run( + self, session: AsyncSession + ) -> list[CountByBatchStrategyResponse]: + + query = ( + select( + Batch.strategy, + func.count(FlagURLValidated.url_id).label("count") + ) + .join( + LinkBatchURL, + 
LinkBatchURL.batch_id == Batch.id + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .where(FlagURLValidated.type == URLType.NOT_RELEVANT) + .group_by(Batch.strategy) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results = [CountByBatchStrategyResponse(**mapping) for mapping in mappings] + return results diff --git a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/requester_/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetchers/jurisdiction/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/requester_/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/requester_/convert.py b/src/api/endpoints/metrics/batches/aggregated/query/requester_/convert.py new file mode 100644 index 00000000..4a129dfb --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/requester_/convert.py @@ -0,0 +1,11 @@ +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.collectors.enums import CollectorType + + +def convert_strategy_counts_to_strategy_count_dict( + responses: list[CountByBatchStrategyResponse] +) -> dict[CollectorType, int]: + result: dict[CollectorType, int] = {collector_type: 0 for collector_type in CollectorType} + for response in responses: + result[response.strategy] = response.count + return result \ No newline at end of file diff --git a/src/api/endpoints/metrics/batches/aggregated/query/requester_/requester.py b/src/api/endpoints/metrics/batches/aggregated/query/requester_/requester.py new file mode 100644 index 00000000..ac4c6dfa --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/requester_/requester.py @@ -0,0 +1,75 @@ + +from src.api.endpoints.metrics.batches.aggregated.query.all_urls.query import CountAllURLsByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.batch_status_.query import \ + BatchStatusByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.batch_status_.response import \ + BatchStatusCountByBatchStrategyResponseDTO +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.api.endpoints.metrics.batches.aggregated.query.pending.query import PendingURLCountByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.rejected.query import \ + RejectedURLCountByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.requester_.convert import \ + convert_strategy_counts_to_strategy_count_dict +from src.api.endpoints.metrics.batches.aggregated.query.submitted_.query import \ + CountSubmittedByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.url_error.query import URLErrorByBatchStrategyQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.validated_.query import \ + ValidatedURLCountByBatchStrategyQueryBuilder +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus +from src.db.queries.base.builder import QueryBuilderBase +from src.db.templates.requester import RequesterBase + + +class GetBatchesAggregatedMetricsQueryRequester(RequesterBase): + + async def _run_strategy_count_query_builder( + self, query_builder: type[QueryBuilderBase]) -> dict[CollectorType, int]: + responses: 
list[CountByBatchStrategyResponse] = \ + await query_builder().run(self.session) + + return convert_strategy_counts_to_strategy_count_dict(responses) + + async def url_error_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(URLErrorByBatchStrategyQueryBuilder) + + async def url_count_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(CountAllURLsByBatchStrategyQueryBuilder) + + async def submitted_url_count_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(CountSubmittedByBatchStrategyQueryBuilder) + + async def validated_url_count_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(ValidatedURLCountByBatchStrategyQueryBuilder) + + async def rejected_url_count_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(RejectedURLCountByBatchStrategyQueryBuilder) + + async def pending_url_count_by_collector_strategy(self) -> dict[CollectorType, int]: + return await self._run_strategy_count_query_builder(PendingURLCountByBatchStrategyQueryBuilder) + + async def batch_status_by_collector_strategy(self) -> dict[ + CollectorType, + dict[BatchStatus, int] + ]: + + responses: list[BatchStatusCountByBatchStrategyResponseDTO] = \ + await BatchStatusByBatchStrategyQueryBuilder().run(self.session) + + result: dict[CollectorType, dict[BatchStatus, int]] = { + collector_type: { + BatchStatus.ERROR: 0, + BatchStatus.READY_TO_LABEL: 0, + } + for collector_type in CollectorType + } + for response in responses: + if response.status not in ( + BatchStatus.ERROR, + BatchStatus.READY_TO_LABEL + ): + continue + result[response.strategy][response.status] = response.count + + return result + diff --git a/src/collectors/source_collectors/muckrock/fetchers/templates/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/submitted_/__init__.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetchers/templates/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/submitted_/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/submitted_/query.py b/src/api/endpoints/metrics/batches/aggregated/query/submitted_/query.py new file mode 100644 index 00000000..ee8f8065 --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/submitted_/query.py @@ -0,0 +1,45 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.collectors.enums import CollectorType +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.queries.base.builder import QueryBuilderBase + + +class CountSubmittedByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[ + CountByBatchStrategyResponse + ]: + query = ( + select( + Batch.strategy, + func.count(URLDataSource.id).label("count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .join( + URLDataSource, + URLDataSource.url_id == LinkBatchURL.url_id + 
) + .group_by(Batch.strategy) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results: list[CountByBatchStrategyResponse] = [] + for mapping in mappings: + results.append( + CountByBatchStrategyResponse( + strategy=CollectorType(mapping["strategy"]), + count=mapping["count"] + ) + ) + return results diff --git a/src/core/tasks/dtos/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/url_error/__init__.py similarity index 100% rename from src/core/tasks/dtos/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/url_error/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/url_error/query.py b/src/api/endpoints/metrics/batches/aggregated/query/url_error/query.py new file mode 100644 index 00000000..9bcc3a57 --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/url_error/query.py @@ -0,0 +1,34 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.collectors.enums import URLStatus +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class URLErrorByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[CountByBatchStrategyResponse]: + query = ( + select( + Batch.strategy, + func.count(URL.id).label("count") + ) + .select_from(Batch) + .join(LinkBatchURL) + .join(URL) + .where(URL.status == URLStatus.ERROR) + .group_by(Batch.strategy, URL.status) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results = [CountByBatchStrategyResponse(**mapping) for mapping in mappings] + return results + + diff --git a/src/core/tasks/scheduled/operators/__init__.py b/src/api/endpoints/metrics/batches/aggregated/query/validated_/__init__.py similarity index 100% rename from src/core/tasks/scheduled/operators/__init__.py rename to src/api/endpoints/metrics/batches/aggregated/query/validated_/__init__.py diff --git a/src/api/endpoints/metrics/batches/aggregated/query/validated_/query.py b/src/api/endpoints/metrics/batches/aggregated/query/validated_/query.py new file mode 100644 index 00000000..155cbcb0 --- /dev/null +++ b/src/api/endpoints/metrics/batches/aggregated/query/validated_/query.py @@ -0,0 +1,38 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.batches.aggregated.query.models.strategy_count import CountByBatchStrategyResponse +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.queries.base.builder import QueryBuilderBase + + +class ValidatedURLCountByBatchStrategyQueryBuilder(QueryBuilderBase): + + async def run( + self, session: AsyncSession + ) -> list[CountByBatchStrategyResponse]: + + query = ( + select( + Batch.strategy, + func.count(FlagURLValidated.url_id).label("count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id 
+ ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .group_by(Batch.strategy) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + results = [CountByBatchStrategyResponse(**mapping) for mapping in mappings] + return results diff --git a/src/core/tasks/scheduled/operators/agency_sync/__init__.py b/src/api/endpoints/metrics/batches/breakdown/error/__init__.py similarity index 100% rename from src/core/tasks/scheduled/operators/agency_sync/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/error/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/error/cte_.py b/src/api/endpoints/metrics/batches/breakdown/error/cte_.py new file mode 100644 index 00000000..ed2ff44f --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/error/cte_.py @@ -0,0 +1,25 @@ +from sqlalchemy import select, func, CTE, Column + +from src.collectors.enums import URLStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.url.core.sqlalchemy import URL + +URL_ERROR_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(LinkBatchURL.url_id).label("count_error") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .join( + URL, + URL.id == LinkBatchURL.url_id + ) + .where(URL.status == URLStatus.ERROR) + .group_by(Batch.id) + .cte("error") +) diff --git a/src/core/tasks/scheduled/operators/agency_sync/dtos/__init__.py b/src/api/endpoints/metrics/batches/breakdown/not_relevant/__init__.py similarity index 100% rename from src/core/tasks/scheduled/operators/agency_sync/dtos/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/not_relevant/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/not_relevant/cte_.py b/src/api/endpoints/metrics/batches/breakdown/not_relevant/cte_.py new file mode 100644 index 00000000..6342018b --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/not_relevant/cte_.py @@ -0,0 +1,27 @@ +from sqlalchemy import select, func + +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL + +NOT_RELEVANT_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(FlagURLValidated.url_id).label("count_rejected") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .where( + FlagURLValidated.type == URLType.NOT_RELEVANT + ) + .group_by(Batch.id) + .cte("not_relevant") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/queries/__init__.py b/src/api/endpoints/metrics/batches/breakdown/pending/__init__.py similarity index 100% rename from src/core/tasks/url/operators/agency_identification/queries/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/pending/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/pending/cte_.py b/src/api/endpoints/metrics/batches/breakdown/pending/cte_.py new file mode 100644 index 00000000..bf09f345 --- /dev/null +++ 
b/src/api/endpoints/metrics/batches/breakdown/pending/cte_.py @@ -0,0 +1,26 @@ +from sqlalchemy import select, func + +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL + +PENDING_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(LinkBatchURL.url_id).label("count_pending") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .outerjoin( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .where( + FlagURLValidated.url_id.is_(None) + ) + .group_by(Batch.id) + .cte("pending") +) \ No newline at end of file diff --git a/src/api/endpoints/metrics/batches/breakdown/query.py b/src/api/endpoints/metrics/batches/breakdown/query.py index 771543ac..5847e309 100644 --- a/src/api/endpoints/metrics/batches/breakdown/query.py +++ b/src/api/endpoints/metrics/batches/breakdown/query.py @@ -1,14 +1,21 @@ -from sqlalchemy import select, case +from sqlalchemy import select, case, Column from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql.functions import coalesce from src.api.endpoints.metrics.batches.breakdown.dto import GetMetricsBatchesBreakdownResponseDTO, \ GetMetricsBatchesBreakdownInnerResponseDTO +from src.api.endpoints.metrics.batches.breakdown.error.cte_ import URL_ERROR_CTE +from src.api.endpoints.metrics.batches.breakdown.not_relevant.cte_ import NOT_RELEVANT_CTE +from src.api.endpoints.metrics.batches.breakdown.pending.cte_ import PENDING_CTE +from src.api.endpoints.metrics.batches.breakdown.submitted.cte_ import SUBMITTED_CTE +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.api.endpoints.metrics.batches.breakdown.total.cte_ import TOTAL_CTE +from src.api.endpoints.metrics.batches.breakdown.validated.cte_ import VALIDATED_CTE from src.collectors.enums import URLStatus, CollectorType from src.core.enums import BatchStatus -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL from src.db.queries.base.builder import QueryBuilderBase from src.db.statement_composer import StatementComposer @@ -32,28 +39,32 @@ async def run(self, session: AsyncSession) -> GetMetricsBatchesBreakdownResponse Batch.date_generated.label("created_at"), ) - def url_column(status: URLStatus, label): - return sc.count_distinct( - case( - ( - URL.outcome == status.value, - URL.id - ) - ), - label=label - ) + all_ctes: list[BatchesBreakdownURLCTE] = [ + URL_ERROR_CTE, + NOT_RELEVANT_CTE, + PENDING_CTE, + SUBMITTED_CTE, + TOTAL_CTE, + VALIDATED_CTE + ] + + count_columns: list[Column] = [ + cte.count for cte in all_ctes + ] + count_query = select( - LinkBatchURL.batch_id, - sc.count_distinct(URL.id, label="count_total"), - url_column(URLStatus.PENDING, label="count_pending"), - url_column(URLStatus.SUBMITTED, label="count_submitted"), - url_column(URLStatus.NOT_RELEVANT, label="count_rejected"), - url_column(URLStatus.ERROR, label="count_error"), - url_column(URLStatus.VALIDATED, label="count_validated"), - ).join(URL, LinkBatchURL.url_id == 
URL.id).group_by( - LinkBatchURL.batch_id - ).subquery("url_count") + Batch.id.label("batch_id"), + *count_columns + ) + for cte in all_ctes: + count_query = count_query.outerjoin( + cte.query, + Batch.id == cte.batch_id + ) + + count_query = count_query.cte("url_count") + query = (select( main_query.c.strategy, diff --git a/src/core/tasks/url/operators/submit_approved_url/__init__.py b/src/api/endpoints/metrics/batches/breakdown/submitted/__init__.py similarity index 100% rename from src/core/tasks/url/operators/submit_approved_url/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/submitted/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/submitted/cte_.py b/src/api/endpoints/metrics/batches/breakdown/submitted/cte_.py new file mode 100644 index 00000000..face1891 --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/submitted/cte_.py @@ -0,0 +1,23 @@ +from sqlalchemy import select, func + +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource + +SUBMITTED_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(URLDataSource.id).label("count_submitted") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .join( + URLDataSource, + URLDataSource.url_id == LinkBatchURL.url_id + ) + .group_by(Batch.id) + .cte("submitted") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_404_probe/__init__.py b/src/api/endpoints/metrics/batches/breakdown/templates/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_404_probe/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/templates/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/templates/cte_.py b/src/api/endpoints/metrics/batches/breakdown/templates/cte_.py new file mode 100644 index 00000000..3fdd7521 --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/templates/cte_.py @@ -0,0 +1,20 @@ +from sqlalchemy import Column +from sqlalchemy import CTE + + +class BatchesBreakdownURLCTE: + + def __init__(self, query: CTE): + self._query = query + + @property + def query(self) -> CTE: + return self._query + + @property + def batch_id(self) -> Column: + return self._query.columns[0] + + @property + def count(self) -> Column: + return self._query.columns[1] \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_duplicate/__init__.py b/src/api/endpoints/metrics/batches/breakdown/total/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_duplicate/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/total/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/total/cte_.py b/src/api/endpoints/metrics/batches/breakdown/total/cte_.py new file mode 100644 index 00000000..33cf0c84 --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/total/cte_.py @@ -0,0 +1,15 @@ +from sqlalchemy import select, func + +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL + +TOTAL_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(LinkBatchURL.url_id).label("count_total") + ) + .join(LinkBatchURL) + .group_by(Batch.id) + 
.cte("total") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_html/__init__.py b/src/api/endpoints/metrics/batches/breakdown/validated/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/__init__.py rename to src/api/endpoints/metrics/batches/breakdown/validated/__init__.py diff --git a/src/api/endpoints/metrics/batches/breakdown/validated/cte_.py b/src/api/endpoints/metrics/batches/breakdown/validated/cte_.py new file mode 100644 index 00000000..b6ff5ef1 --- /dev/null +++ b/src/api/endpoints/metrics/batches/breakdown/validated/cte_.py @@ -0,0 +1,23 @@ +from sqlalchemy import select, func + +from src.api.endpoints.metrics.batches.breakdown.templates.cte_ import BatchesBreakdownURLCTE +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL + +VALIDATED_CTE = BatchesBreakdownURLCTE( + select( + Batch.id, + func.count(FlagURLValidated.url_id).label("count_validated") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == LinkBatchURL.url_id + ) + .group_by(Batch.id) + .cte("validated") +) \ No newline at end of file diff --git a/src/api/endpoints/metrics/dtos/get/urls/aggregated/core.py b/src/api/endpoints/metrics/dtos/get/urls/aggregated/core.py index 66009223..7dbbc48a 100644 --- a/src/api/endpoints/metrics/dtos/get/urls/aggregated/core.py +++ b/src/api/endpoints/metrics/dtos/get/urls/aggregated/core.py @@ -2,13 +2,17 @@ from pydantic import BaseModel +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.views.url_status.enums import URLStatusViewEnum + +class GetMetricsURLValidatedOldestPendingURL(BaseModel): + url_id: int + created_at: datetime.datetime class GetMetricsURLsAggregatedResponseDTO(BaseModel): count_urls_total: int - count_urls_pending: int - count_urls_submitted: int - count_urls_rejected: int - count_urls_validated: int - count_urls_errors: int - oldest_pending_url_created_at: datetime.datetime - oldest_pending_url_id: int \ No newline at end of file + count_urls_status: dict[URLStatusViewEnum, int] + count_urls_type: dict[URLType, int] + count_urls_record_type: dict[RecordType, int] + oldest_pending_url: GetMetricsURLValidatedOldestPendingURL | None diff --git a/src/core/tasks/url/operators/url_html/queries/__init__.py b/src/api/endpoints/metrics/urls/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/queries/__init__.py rename to src/api/endpoints/metrics/urls/__init__.py diff --git a/src/core/tasks/url/operators/url_html/scraper/__init__.py b/src/api/endpoints/metrics/urls/aggregated/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/__init__.py rename to src/api/endpoints/metrics/urls/aggregated/__init__.py diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/__init__.py b/src/api/endpoints/metrics/urls/aggregated/query/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/parser/__init__.py rename to src/api/endpoints/metrics/urls/aggregated/query/__init__.py diff --git a/src/api/endpoints/metrics/urls/aggregated/query/core.py b/src/api/endpoints/metrics/urls/aggregated/query/core.py new file mode 100644 index 00000000..c6dbc29f --- /dev/null +++ 
b/src/api/endpoints/metrics/urls/aggregated/query/core.py @@ -0,0 +1,40 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.dtos.get.urls.aggregated.core import GetMetricsURLsAggregatedResponseDTO, \ + GetMetricsURLValidatedOldestPendingURL +from src.api.endpoints.metrics.urls.aggregated.query.subqueries.all import ALL_SUBQUERY +from src.api.endpoints.metrics.urls.aggregated.query.subqueries.oldest_pending_url import \ + GetOldestPendingURLQueryBuilder +from src.api.endpoints.metrics.urls.aggregated.query.subqueries.record_type import GetURLRecordTypeCountQueryBuilder +from src.api.endpoints.metrics.urls.aggregated.query.subqueries.status import GetURLStatusCountQueryBuilder +from src.api.endpoints.metrics.urls.aggregated.query.subqueries.url_type import GetURLTypeCountQueryBuilder +from src.core.enums import RecordType +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.views.url_status.enums import URLStatusViewEnum +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLsAggregatedMetricsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> GetMetricsURLsAggregatedResponseDTO: + + oldest_pending_url: GetMetricsURLValidatedOldestPendingURL | None = \ + await GetOldestPendingURLQueryBuilder().run(session=session) + + status_counts: dict[URLStatusViewEnum, int] = \ + await GetURLStatusCountQueryBuilder().run(session=session) + + validated_counts: dict[URLType, int] = \ + await GetURLTypeCountQueryBuilder().run(session=session) + + record_type_counts: dict[RecordType, int] = \ + await GetURLRecordTypeCountQueryBuilder().run(session=session) + + return GetMetricsURLsAggregatedResponseDTO( + count_urls_total=await sh.scalar(session, query=ALL_SUBQUERY), + oldest_pending_url=oldest_pending_url, + count_urls_status=status_counts, + count_urls_type=validated_counts, + count_urls_record_type=record_type_counts, + ) diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/dtos/__init__.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/parser/dtos/__init__.py rename to src/api/endpoints/metrics/urls/aggregated/query/subqueries/__init__.py diff --git a/src/api/endpoints/metrics/urls/aggregated/query/subqueries/all.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/all.py new file mode 100644 index 00000000..a2d09217 --- /dev/null +++ b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/all.py @@ -0,0 +1,9 @@ +from sqlalchemy import select, func + +from src.db.models.impl.url.core.sqlalchemy import URL + +ALL_SUBQUERY = ( + select( + func.count(URL.id).label("count") + ) +) \ No newline at end of file diff --git a/src/api/endpoints/metrics/urls/aggregated/query/subqueries/oldest_pending_url.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/oldest_pending_url.py new file mode 100644 index 00000000..2a951b4a --- /dev/null +++ b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/oldest_pending_url.py @@ -0,0 +1,47 @@ +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.dtos.get.urls.aggregated.core import GetMetricsURLValidatedOldestPendingURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.views.url_status.core import URLStatusMatView +from src.db.models.views.url_status.enums 
import URLStatusViewEnum +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class GetOldestPendingURLQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> GetMetricsURLValidatedOldestPendingURL | None: + + query = ( + select( + URLStatusMatView.url_id, + URL.created_at + ) + .join( + URL, + URLStatusMatView.url_id == URL.id + ).where( + URLStatusMatView.status.not_in( + [ + URLStatusViewEnum.SUBMITTED_PIPELINE_COMPLETE.value, + URLStatusViewEnum.ACCEPTED.value, + ] + ) + ).order_by( + URL.created_at.asc() + ).limit(1) + ) + + mapping: RowMapping | None = (await session.execute(query)).mappings().one_or_none() + if mapping is None: + return None + + return GetMetricsURLValidatedOldestPendingURL( + url_id=mapping["url_id"], + created_at=mapping["created_at"], + ) + diff --git a/src/api/endpoints/metrics/urls/aggregated/query/subqueries/record_type.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/record_type.py new file mode 100644 index 00000000..a4923af6 --- /dev/null +++ b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/record_type.py @@ -0,0 +1,33 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.enums import RecordType +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLRecordTypeCountQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> dict[RecordType, int]: + query = ( + select( + URLRecordType.record_type, + func.count(URLRecordType.url_id).label("count") + ) + .group_by( + URLRecordType.record_type + ) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + return { + mapping["record_type"]: mapping["count"] + for mapping in mappings + } \ No newline at end of file diff --git a/src/api/endpoints/metrics/urls/aggregated/query/subqueries/status.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/status.py new file mode 100644 index 00000000..05813ce0 --- /dev/null +++ b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/status.py @@ -0,0 +1,36 @@ +from typing import Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.helpers.session import session_helper as sh +from src.db.models.views.url_status.core import URLStatusMatView +from src.db.models.views.url_status.enums import URLStatusViewEnum +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLStatusCountQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> dict[URLStatusViewEnum, int]: + + query = ( + select( + URLStatusMatView.status, + func.count( + URLStatusMatView.url_id + ).label("count") + ) + .group_by( + URLStatusMatView.status + ) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + return { + URLStatusViewEnum(mapping["status"]): mapping["count"] + for mapping in mappings + } diff --git a/src/api/endpoints/metrics/urls/aggregated/query/subqueries/url_type.py b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/url_type.py new file mode 100644 index 00000000..6561850e --- /dev/null +++ b/src/api/endpoints/metrics/urls/aggregated/query/subqueries/url_type.py @@ -0,0 +1,33 @@ +from typing import 
Sequence + +from sqlalchemy import select, func, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLTypeCountQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> dict[URLType, int]: + query = ( + select( + FlagURLValidated.type, + func.count(FlagURLValidated.url_id).label("count") + ) + .group_by( + FlagURLValidated.type + ) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + return { + mapping["type"]: mapping["count"] + for mapping in mappings + } \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/__init__.py b/src/api/endpoints/metrics/urls/breakdown/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/request_interface/__init__.py rename to src/api/endpoints/metrics/urls/breakdown/__init__.py diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/__init__.py b/src/api/endpoints/metrics/urls/breakdown/query/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/__init__.py rename to src/api/endpoints/metrics/urls/breakdown/query/__init__.py diff --git a/src/api/endpoints/metrics/urls/breakdown/query/core.py b/src/api/endpoints/metrics/urls/breakdown/query/core.py new file mode 100644 index 00000000..e585554c --- /dev/null +++ b/src/api/endpoints/metrics/urls/breakdown/query/core.py @@ -0,0 +1,91 @@ +from typing import Any + +from sqlalchemy import select, case, literal, func +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.metrics.dtos.get.urls.breakdown.pending import GetMetricsURLsBreakdownPendingResponseInnerDTO, \ + GetMetricsURLsBreakdownPendingResponseDTO +from src.collectors.enums import URLStatus +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLsBreakdownPendingMetricsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> GetMetricsURLsBreakdownPendingResponseDTO: + + flags = ( + select( + URL.id.label("url_id"), + case((UserRecordTypeSuggestion.url_id != None, literal(True)), else_=literal(False)).label( + "has_user_record_type_annotation" + ), + case((UserURLTypeSuggestion.url_id != None, literal(True)), else_=literal(False)).label( + "has_user_relevant_annotation" + ), + case((UserUrlAgencySuggestion.url_id != None, literal(True)), else_=literal(False)).label( + "has_user_agency_annotation" + ), + ) + .outerjoin(UserRecordTypeSuggestion, URL.id == UserRecordTypeSuggestion.url_id) + .outerjoin(UserURLTypeSuggestion, URL.id == UserURLTypeSuggestion.url_id) + .outerjoin(UserUrlAgencySuggestion, URL.id == UserUrlAgencySuggestion.url_id) + ).cte("flags") + + month = func.date_trunc('month', URL.created_at) + + # Build the query + query = ( + select( + month.label('month'), + 
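The status, URL-type, and record-type subqueries above all share one shape: `SELECT <col>, COUNT(...) ... GROUP BY <col>`, folded into an enum-keyed dict. A self-contained sketch of that pattern follows; the table, enum, and in-memory engine are illustrative stand-ins, not objects from this repo.

```python
# Sketch of the GROUP BY -> enum-keyed dict pattern used by the subqueries above.
import enum

from sqlalchemy import Column, Enum, Integer, MetaData, Table, create_engine, func, select


class URLType(enum.Enum):
    DATA_SOURCE = "data_source"
    NOT_RELEVANT = "not_relevant"


metadata = MetaData()
flags = Table(
    "flags", metadata,
    Column("url_id", Integer, primary_key=True),
    Column("type", Enum(URLType)),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)

with engine.connect() as conn:
    conn.execute(flags.insert(), [
        {"url_id": 1, "type": URLType.DATA_SOURCE},
        {"url_id": 2, "type": URLType.DATA_SOURCE},
        {"url_id": 3, "type": URLType.NOT_RELEVANT},
    ])
    query = (
        select(flags.c.type, func.count(flags.c.url_id).label("count"))
        .group_by(flags.c.type)
    )
    # Enum members with no rows simply do not appear as keys.
    counts = {row["type"]: row["count"] for row in conn.execute(query).mappings()}

print(counts)  # {URLType.DATA_SOURCE: 2, URLType.NOT_RELEVANT: 1}
```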
diff --git a/src/api/endpoints/metrics/urls/breakdown/query/core.py b/src/api/endpoints/metrics/urls/breakdown/query/core.py
new file mode 100644
index 00000000..e585554c
--- /dev/null
+++ b/src/api/endpoints/metrics/urls/breakdown/query/core.py
@@ -0,0 +1,91 @@
+from typing import Any
+
+from sqlalchemy import select, case, literal, func
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.api.endpoints.metrics.dtos.get.urls.breakdown.pending import GetMetricsURLsBreakdownPendingResponseInnerDTO, \
+    GetMetricsURLsBreakdownPendingResponseDTO
+from src.collectors.enums import URLStatus
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion
+from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion
+from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class GetURLsBreakdownPendingMetricsQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> GetMetricsURLsBreakdownPendingResponseDTO:
+
+        flags = (
+            select(
+                URL.id.label("url_id"),
+                case((UserRecordTypeSuggestion.url_id != None, literal(True)), else_=literal(False)).label(
+                    "has_user_record_type_annotation"
+                ),
+                case((UserURLTypeSuggestion.url_id != None, literal(True)), else_=literal(False)).label(
+                    "has_user_relevant_annotation"
+                ),
+                case((UserUrlAgencySuggestion.url_id != None, literal(True)), else_=literal(False)).label(
+                    "has_user_agency_annotation"
+                ),
+            )
+            .outerjoin(UserRecordTypeSuggestion, URL.id == UserRecordTypeSuggestion.url_id)
+            .outerjoin(UserURLTypeSuggestion, URL.id == UserURLTypeSuggestion.url_id)
+            .outerjoin(UserUrlAgencySuggestion, URL.id == UserUrlAgencySuggestion.url_id)
+        ).cte("flags")
+
+        month = func.date_trunc('month', URL.created_at)
+
+        # Build the query
+        query = (
+            select(
+                month.label('month'),
+                func.count(URL.id).label('count_total'),
+                func.count(
+                    case(
+                        (flags.c.has_user_record_type_annotation == True, 1)
+                    )
+                ).label('user_record_type_count'),
+                func.count(
+                    case(
+                        (flags.c.has_user_relevant_annotation == True, 1)
+                    )
+                ).label('user_relevant_count'),
+                func.count(
+                    case(
+                        (flags.c.has_user_agency_annotation == True, 1)
+                    )
+                ).label('user_agency_count'),
+            )
+            .outerjoin(flags, flags.c.url_id == URL.id)
+            .outerjoin(
+                FlagURLValidated,
+                FlagURLValidated.url_id == URL.id
+            )
+            .where(
+                FlagURLValidated.url_id.is_(None),
+                URL.status == URLStatus.OK
+            )
+            .group_by(month)
+            .order_by(month.asc())
+        )
+
+        # Execute the query and return the results
+        results = await session.execute(query)
+        all_results = results.all()
+        final_results: list[GetMetricsURLsBreakdownPendingResponseInnerDTO] = []
+
+        for result in all_results:
+            dto = GetMetricsURLsBreakdownPendingResponseInnerDTO(
+                month=result.month.strftime("%B %Y"),
+                count_pending_total=result.count_total,
+                count_pending_relevant_user=result.user_relevant_count,
+                count_pending_record_type_user=result.user_record_type_count,
+                count_pending_agency_user=result.user_agency_count,
+            )
+            final_results.append(dto)
+        return GetMetricsURLsBreakdownPendingResponseDTO(
+            entries=final_results,
+        )
\ No newline at end of file
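The breakdown query above leans on `COUNT` ignoring NULLs: a `case` with no `else_` yields NULL for non-matching rows, producing a conditional count per month bucket. A minimal sketch of the idiom, with an illustrative table:

```python
# Sketch of the COUNT(CASE WHEN ...) idiom from the breakdown query above.
from sqlalchemy import Boolean, Column, DateTime, Integer, MetaData, Table, case, func, select

metadata = MetaData()
urls = Table(
    "urls", metadata,
    Column("id", Integer, primary_key=True),
    Column("created_at", DateTime),
    Column("has_user_annotation", Boolean),
)

month = func.date_trunc("month", urls.c.created_at)  # PostgreSQL-specific function

query = (
    select(
        month.label("month"),
        func.count(urls.c.id).label("count_total"),
        # Non-matching rows fall through to NULL and are excluded from the count.
        func.count(case((urls.c.has_user_annotation == True, 1))).label("count_annotated"),
    )
    .group_by(month)
    .order_by(month.asc())
)

print(query)  # compiles to SELECT ..., count(CASE WHEN ... THEN 1 END) ... GROUP BY ...
```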
" "If none, defers to an existing supplying entity only if that exists.", default=None diff --git a/src/api/endpoints/review/approve/query.py b/src/api/endpoints/review/approve/query.py deleted file mode 100644 index bff32bf3..00000000 --- a/src/api/endpoints/review/approve/query.py +++ /dev/null @@ -1,150 +0,0 @@ -from typing import Any - -from sqlalchemy import Select, select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload -from starlette.exceptions import HTTPException -from starlette.status import HTTP_400_BAD_REQUEST - -from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo -from src.collectors.enums import URLStatus -from src.db.constants import PLACEHOLDER_AGENCY_NAME -from src.db.models.instantiations.agency import Agency -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.reviewing_user import ReviewingUserURL -from src.db.queries.base.builder import QueryBuilderBase - - -class ApproveURLQueryBuilder(QueryBuilderBase): - - def __init__( - self, - user_id: int, - approval_info: FinalReviewApprovalInfo - ): - super().__init__() - self.user_id = user_id - self.approval_info = approval_info - - async def run(self, session: AsyncSession) -> None: - # Get URL - def update_if_not_none( - model, - field, - value: Any, - required: bool = False - ): - if value is not None: - setattr(model, field, value) - return - if not required: - return - model_value = getattr(model, field, None) - if model_value is None: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail=f"Must specify {field} if it does not already exist" - ) - - query = ( - Select(URL) - .where(URL.id == self.approval_info.url_id) - .options( - joinedload(URL.optional_data_source_metadata), - joinedload(URL.confirmed_agencies), - ) - ) - - url = await session.execute(query) - url = url.scalars().first() - - update_if_not_none( - url, - "record_type", - self.approval_info.record_type.value - if self.approval_info.record_type is not None else None, - required=True - ) - - # Get existing agency ids - existing_agencies = url.confirmed_agencies or [] - existing_agency_ids = [agency.agency_id for agency in existing_agencies] - new_agency_ids = self.approval_info.agency_ids or [] - if len(existing_agency_ids) == 0 and len(new_agency_ids) == 0: - raise HTTPException( - status_code=HTTP_400_BAD_REQUEST, - detail="Must specify agency_id if URL does not already have a confirmed agency" - ) - - # Get any existing agency ids that are not in the new agency ids - # If new agency ids are specified, overwrite existing - if len(new_agency_ids) != 0: - for existing_agency in existing_agencies: - if existing_agency.id not in new_agency_ids: - # If the existing agency id is not in the new agency ids, delete it - await session.delete(existing_agency) - # Add any new agency ids that are not in the existing agency ids - for new_agency_id in new_agency_ids: - if new_agency_id not in existing_agency_ids: - # Check if the new agency exists in the database - query = ( - select(Agency) - .where(Agency.agency_id == new_agency_id) - ) - existing_agency = await session.execute(query) - existing_agency = existing_agency.scalars().first() - if existing_agency is None: - # If not, create it - agency = Agency( - agency_id=new_agency_id, - name=PLACEHOLDER_AGENCY_NAME, - 
diff --git a/src/api/endpoints/review/approve/query.py b/src/api/endpoints/review/approve/query.py
deleted file mode 100644
index bff32bf3..00000000
--- a/src/api/endpoints/review/approve/query.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from typing import Any
-
-from sqlalchemy import Select, select
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
-from starlette.exceptions import HTTPException
-from starlette.status import HTTP_400_BAD_REQUEST
-
-from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo
-from src.collectors.enums import URLStatus
-from src.db.constants import PLACEHOLDER_AGENCY_NAME
-from src.db.models.instantiations.agency import Agency
-from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency
-from src.db.models.instantiations.url.core import URL
-from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata
-from src.db.models.instantiations.url.reviewing_user import ReviewingUserURL
-from src.db.queries.base.builder import QueryBuilderBase
-
-
-class ApproveURLQueryBuilder(QueryBuilderBase):
-
-    def __init__(
-        self,
-        user_id: int,
-        approval_info: FinalReviewApprovalInfo
-    ):
-        super().__init__()
-        self.user_id = user_id
-        self.approval_info = approval_info
-
-    async def run(self, session: AsyncSession) -> None:
-        # Get URL
-        def update_if_not_none(
-            model,
-            field,
-            value: Any,
-            required: bool = False
-        ):
-            if value is not None:
-                setattr(model, field, value)
-                return
-            if not required:
-                return
-            model_value = getattr(model, field, None)
-            if model_value is None:
-                raise HTTPException(
-                    status_code=HTTP_400_BAD_REQUEST,
-                    detail=f"Must specify {field} if it does not already exist"
-                )
-
-        query = (
-            Select(URL)
-            .where(URL.id == self.approval_info.url_id)
-            .options(
-                joinedload(URL.optional_data_source_metadata),
-                joinedload(URL.confirmed_agencies),
-            )
-        )
-
-        url = await session.execute(query)
-        url = url.scalars().first()
-
-        update_if_not_none(
-            url,
-            "record_type",
-            self.approval_info.record_type.value
-            if self.approval_info.record_type is not None else None,
-            required=True
-        )
-
-        # Get existing agency ids
-        existing_agencies = url.confirmed_agencies or []
-        existing_agency_ids = [agency.agency_id for agency in existing_agencies]
-        new_agency_ids = self.approval_info.agency_ids or []
-        if len(existing_agency_ids) == 0 and len(new_agency_ids) == 0:
-            raise HTTPException(
-                status_code=HTTP_400_BAD_REQUEST,
-                detail="Must specify agency_id if URL does not already have a confirmed agency"
-            )
-
-        # Get any existing agency ids that are not in the new agency ids
-        # If new agency ids are specified, overwrite existing
-        if len(new_agency_ids) != 0:
-            for existing_agency in existing_agencies:
-                if existing_agency.id not in new_agency_ids:
-                    # If the existing agency id is not in the new agency ids, delete it
-                    await session.delete(existing_agency)
-        # Add any new agency ids that are not in the existing agency ids
-        for new_agency_id in new_agency_ids:
-            if new_agency_id not in existing_agency_ids:
-                # Check if the new agency exists in the database
-                query = (
-                    select(Agency)
-                    .where(Agency.agency_id == new_agency_id)
-                )
-                existing_agency = await session.execute(query)
-                existing_agency = existing_agency.scalars().first()
-                if existing_agency is None:
-                    # If not, create it
-                    agency = Agency(
-                        agency_id=new_agency_id,
-                        name=PLACEHOLDER_AGENCY_NAME,
-                    )
-                    session.add(agency)
-
-                # If the new agency id is not in the existing agency ids, add it
-                confirmed_url_agency = ConfirmedURLAgency(
-                    url_id=self.approval_info.url_id,
-                    agency_id=new_agency_id
-                )
-                session.add(confirmed_url_agency)
-
-            # If it does, do nothing
-
-        url.outcome = URLStatus.VALIDATED.value
-
-        update_if_not_none(url, "name", self.approval_info.name, required=True)
-        update_if_not_none(url, "description", self.approval_info.description, required=True)
-
-        optional_metadata = url.optional_data_source_metadata
-        if optional_metadata is None:
-            url.optional_data_source_metadata = URLOptionalDataSourceMetadata(
-                record_formats=self.approval_info.record_formats,
-                data_portal_type=self.approval_info.data_portal_type,
-                supplying_entity=self.approval_info.supplying_entity
-            )
-        else:
-            update_if_not_none(
-                optional_metadata,
-                "record_formats",
-                self.approval_info.record_formats
-            )
-            update_if_not_none(
-                optional_metadata,
-                "data_portal_type",
-                self.approval_info.data_portal_type
-            )
-            update_if_not_none(
-                optional_metadata,
-                "supplying_entity",
-                self.approval_info.supplying_entity
-            )
-
-        # Add approving user
-        approving_user_url = ReviewingUserURL(
-            user_id=self.user_id,
-            url_id=self.approval_info.url_id
-        )
-
-        session.add(approving_user_url)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/__init__.py b/src/api/endpoints/review/approve/query_/__init__.py
similarity index 100%
rename from src/core/tasks/url/operators/url_html/scraper/root_url_cache/__init__.py
rename to src/api/endpoints/review/approve/query_/__init__.py
diff --git a/src/api/endpoints/review/approve/query_/core.py b/src/api/endpoints/review/approve/query_/core.py
new file mode 100644
index 00000000..15641764
--- /dev/null
+++ b/src/api/endpoints/review/approve/query_/core.py
@@ -0,0 +1,174 @@
+from sqlalchemy import Select, select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import joinedload
+from starlette.exceptions import HTTPException
+from starlette.status import HTTP_400_BAD_REQUEST
+
+from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo
+from src.api.endpoints.review.approve.query_.util import update_if_not_none
+from src.collectors.enums import URLStatus
+from src.db.constants import PLACEHOLDER_AGENCY_NAME
+from src.db.models.impl.agency.sqlalchemy import Agency
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata
+from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType
+from src.db.models.impl.url.reviewing_user import ReviewingUserURL
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class ApproveURLQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        user_id: int,
+        approval_info: FinalReviewApprovalInfo
+    ):
+        super().__init__()
+        self.user_id = user_id
+        self.approval_info = approval_info
+
+    async def run(self, session: AsyncSession) -> None:
+        # Get URL
+
+        url = await self._get_url(session)
+
+        await self._optionally_update_record_type(session)
+
+        # Get existing agency ids
+        existing_agencies = url.confirmed_agencies or []
+        existing_agency_ids = [agency.agency_id for agency in existing_agencies]
+        new_agency_ids = self.approval_info.agency_ids or []
+        await self._check_for_unspecified_agency_ids(existing_agency_ids, new_agency_ids)
+
+        await self._overwrite_existing_agencies(existing_agencies, new_agency_ids, session)
+        # Add any new agency ids that are not in the existing agency ids
+        await self._add_new_agencies(existing_agency_ids, new_agency_ids, session)
+
+        await self._add_validated_flag(session, url=url)
+
+        await self._optionally_update_required_metadata(url)
+        await self._optionally_update_optional_metadata(url)
+        await self._add_approving_user(session)
+
+    async def _optionally_update_required_metadata(self, url: URL) -> None:
+        update_if_not_none(url, "name", self.approval_info.name, required=True)
+        update_if_not_none(url, "description", self.approval_info.description, required=False)
+
+    async def _add_approving_user(self, session: AsyncSession) -> None:
+        approving_user_url = ReviewingUserURL(
+            user_id=self.user_id,
+            url_id=self.approval_info.url_id
+        )
+        session.add(approving_user_url)
+
+    async def _optionally_update_optional_metadata(self, url: URL) -> None:
+        optional_metadata = url.optional_data_source_metadata
+        if optional_metadata is None:
+            url.optional_data_source_metadata = URLOptionalDataSourceMetadata(
+                record_formats=self.approval_info.record_formats,
+                data_portal_type=self.approval_info.data_portal_type,
+                supplying_entity=self.approval_info.supplying_entity
+            )
+        else:
+            update_if_not_none(
+                optional_metadata,
+                "record_formats",
+                self.approval_info.record_formats
+            )
+            update_if_not_none(
+                optional_metadata,
+                "data_portal_type",
+                self.approval_info.data_portal_type
+            )
+            update_if_not_none(
+                optional_metadata,
+                "supplying_entity",
+                self.approval_info.supplying_entity
+            )
+
+    async def _optionally_update_record_type(self, session: AsyncSession) -> None:
+        if self.approval_info.record_type is None:
+            return
+
+        record_type = URLRecordType(
+            url_id=self.approval_info.url_id,
+            record_type=self.approval_info.record_type.value
+        )
+        session.add(record_type)
+
+    async def _get_url(self, session: AsyncSession) -> URL:
+        query = (
+            Select(URL)
+            .where(URL.id == self.approval_info.url_id)
+            .options(
+                joinedload(URL.optional_data_source_metadata),
+                joinedload(URL.confirmed_agencies),
+            )
+        )
+        url = await session.execute(query)
+        url = url.scalars().first()
+        return url
+
+    async def _check_for_unspecified_agency_ids(
+        self,
+        existing_agency_ids: list[int],
+        new_agency_ids: list[int]
+    ) -> None:
+        """
+        raises:
+            HTTPException: If no agency ids are specified and no existing agency ids are found
+        """
+        if len(existing_agency_ids) == 0 and len(new_agency_ids) == 0:
+            raise HTTPException(
+                status_code=HTTP_400_BAD_REQUEST,
+                detail="Must specify agency_id if URL does not already have a confirmed agency"
+            )
+
+    async def _overwrite_existing_agencies(self, existing_agencies, new_agency_ids, session):
+        # Get any existing agency ids that are not in the new agency ids
+        # If new agency ids are specified, overwrite existing
+        if len(new_agency_ids) != 0:
+            for existing_agency in existing_agencies:
+                if existing_agency.id not in new_agency_ids:
+                    # If the existing agency id is not in the new agency ids, delete it
+                    await session.delete(existing_agency)
+
+    async def _add_new_agencies(self, existing_agency_ids, new_agency_ids, session):
+        for new_agency_id in new_agency_ids:
+            if new_agency_id in existing_agency_ids:
+                continue
+            # Check if the new agency exists in the database
+            query = (
+                select(Agency)
+                .where(Agency.agency_id == new_agency_id)
+            )
+            existing_agency = await session.execute(query)
+            existing_agency = existing_agency.scalars().first()
+            if existing_agency is None:
+                # If not, raise an error
+                raise HTTPException(
+                    status_code=HTTP_400_BAD_REQUEST,
+                    detail="Agency not found"
+                )
+
+            # If the new agency id is not in the existing agency ids, add it
+            confirmed_url_agency = LinkURLAgency(
+                url_id=self.approval_info.url_id,
+                agency_id=new_agency_id
+            )
+            session.add(confirmed_url_agency)
+
+    async def _add_validated_flag(
+        self,
+        session: AsyncSession,
+        url: URL
+    ) -> None:
+        flag = FlagURLValidated(
+            url_id=url.id,
+            type=URLType.DATA_SOURCE
+        )
+        session.add(flag)
diff --git a/src/api/endpoints/review/approve/query_/util.py b/src/api/endpoints/review/approve/query_/util.py
new file mode 100644
index 00000000..219a1f86
--- /dev/null
+++ b/src/api/endpoints/review/approve/query_/util.py
@@ -0,0 +1,23 @@
+from typing import Any
+
+from starlette.exceptions import HTTPException
+from starlette.status import HTTP_400_BAD_REQUEST
+
+
+def update_if_not_none(
+    model,
+    field,
+    value: Any,
+    required: bool = False
+):
+    if value is not None:
+        setattr(model, field, value)
+        return
+    if not required:
+        return
+    model_value = getattr(model, field, None)
+    if model_value is None:
+        raise HTTPException(
+            status_code=HTTP_400_BAD_REQUEST,
+            detail=f"Must specify {field} if it does not already exist"
+        )
diff --git a/src/api/endpoints/review/next/dto.py b/src/api/endpoints/review/next/dto.py
index 7fc53b17..13a68239 100644
--- a/src/api/endpoints/review/next/dto.py
+++ b/src/api/endpoints/review/next/dto.py
@@ -1,43 +1,42 @@
-from typing import Optional
-
 from pydantic import BaseModel, Field
 
-from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo
+from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo, AgencySuggestionAndUserCount
 from src.api.endpoints.annotate.relevance.get.dto import RelevanceAnnotationResponseInfo
-from src.core.enums import RecordType, SuggestedStatus
-from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo
+from src.core.enums import RecordType
+from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo
+from src.db.models.impl.flag.url_validated.enums import URLType
 
 
 class FinalReviewAnnotationRelevantInfo(BaseModel):
-    auto: Optional[RelevanceAnnotationResponseInfo] = Field(title="Whether the auto-labeler has marked the URL as relevant")
-    user: Optional[SuggestedStatus] = Field(
-        title="The status marked by a user, if any",
+    auto: RelevanceAnnotationResponseInfo | None = Field(title="Whether the auto-labeler has marked the URL as relevant")
+    user: dict[URLType, int] = Field(
+        title="How users have labeled the URLType"
     )
 
 
 class FinalReviewAnnotationRecordTypeInfo(BaseModel):
-    auto: Optional[RecordType] = Field(
+    auto: RecordType | None = Field(
         title="The record type suggested by the auto-labeler"
     )
-    user: Optional[RecordType] = Field(
-        title="The record type suggested by a user",
+    user: dict[RecordType, int] = Field(
+        title="The record types suggested by other users",
    )
 
 
 # region Agency
 class FinalReviewAnnotationAgencyAutoInfo(BaseModel):
     unknown: bool = Field(title="Whether the auto-labeler suggested the URL as unknown")
-    suggestions: Optional[list[GetNextURLForAgencyAgencyInfo]] = Field(
+    suggestions: list[GetNextURLForAgencyAgencyInfo] | None = Field(
         title="A list of agencies, if any, suggested by the auto-labeler",
     )
 
 
 class FinalReviewAnnotationAgencyInfo(BaseModel):
-    confirmed: Optional[list[GetNextURLForAgencyAgencyInfo]] = Field(
+    confirmed: list[GetNextURLForAgencyAgencyInfo] | None = Field(
         title="The confirmed agency for the URL",
     )
-    auto: Optional[FinalReviewAnnotationAgencyAutoInfo] = Field(
+    auto: FinalReviewAnnotationAgencyAutoInfo | None = Field(
         title="A single agency or a list of agencies suggested by the auto-labeler",)
-    user: Optional[GetNextURLForAgencyAgencyInfo] = Field(
-        title="A single agency suggested by a user",
+    user: list[AgencySuggestionAndUserCount] = Field(
+        title="Agencies suggested by users",
     )
 
 # endregion
@@ -53,15 +52,15 @@ class FinalReviewAnnotationInfo(BaseModel):
     )
 
 class FinalReviewOptionalMetadata(BaseModel):
-    record_formats: Optional[list[str]] = Field(
+    record_formats: list[str] | None = Field(
         title="The record formats of the source",
         default=None
     )
-    data_portal_type: Optional[str] = Field(
+    data_portal_type: str | None = Field(
         title="The data portal type of the source",
         default=None
     )
-    supplying_entity: Optional[str] = Field(
+    supplying_entity: str | None = Field(
         title="The supplying entity of the source",
         default=None
     )
@@ -77,8 +76,8 @@ class FinalReviewBatchInfo(BaseModel):
 class GetNextURLForFinalReviewResponse(BaseModel):
     id: int = Field(title="The id of the URL")
     url: str = Field(title="The URL")
-    name: Optional[str] = Field(title="The name of the source")
-    description: Optional[str] = Field(title="The description of the source")
+    name: str | None = Field(title="The name of the source")
+    description: str | None = Field(title="The description of the source")
     html_info: ResponseHTMLInfo = Field(title="The HTML content of the URL")
     annotations: FinalReviewAnnotationInfo = Field(
         title="The annotations for the URL, from both users and the auto-labeler",
@@ -86,12 +85,12 @@ class GetNextURLForFinalReviewResponse(BaseModel):
     optional_metadata: FinalReviewOptionalMetadata = Field(
         title="Optional metadata for the source",
     )
-    batch_info: Optional[FinalReviewBatchInfo] = Field(
+    batch_info: FinalReviewBatchInfo | None = Field(
         title="Information about the batch",
     )
 
 class GetNextURLForFinalReviewOuterResponse(BaseModel):
-    next_source: Optional[GetNextURLForFinalReviewResponse] = Field(
+    next_source: GetNextURLForFinalReviewResponse | None = Field(
         title="The next source to be reviewed",
     )
     remaining: int = Field(
diff --git a/src/api/endpoints/review/next/query.py b/src/api/endpoints/review/next/query.py
deleted file mode 100644
index 8f7d5e35..00000000
--- a/src/api/endpoints/review/next/query.py
+++ /dev/null
@@ -1,297 +0,0 @@
-from typing import Optional, Type
-
-from sqlalchemy import FromClause, select, and_, Select, desc, asc, func, join
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
-
-from src.api.endpoints.review.next.dto import FinalReviewOptionalMetadata, FinalReviewBatchInfo, \
-    GetNextURLForFinalReviewOuterResponse, GetNextURLForFinalReviewResponse, FinalReviewAnnotationInfo
-from src.collectors.enums import URLStatus
-from src.core.tasks.url.operators.url_html.scraper.parser.util import convert_to_response_html_info
-from src.db.constants import USER_ANNOTATION_MODELS, ALL_ANNOTATION_MODELS
-from src.db.dto_converter import DTOConverter
-from src.db.dtos.url.html_content import URLHTMLContentInfo
-from src.db.exceptions import FailedQueryException
-from src.db.models.instantiations.batch import Batch
-from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency
-from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL
-from src.db.models.instantiations.url.core import URL
-from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion
-from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion
-from src.db.models.mixins import URLDependentMixin
-from src.db.queries.base.builder import QueryBuilderBase
-from src.db.queries.implementations.core.common.annotation_exists import AnnotationExistsCTEQueryBuilder
-
-TOTAL_DISTINCT_ANNOTATION_COUNT_LABEL = "total_distinct_annotation_count"
-
-
-class GetNextURLForFinalReviewQueryBuilder(QueryBuilderBase):
-
-    def __init__(self, batch_id: Optional[int] = None):
-        super().__init__()
-        self.batch_id = batch_id
-        self.anno_exists_builder = AnnotationExistsCTEQueryBuilder()
-        # The below relationships are joined directly to the URL
-        self.single_join_relationships = [
-            URL.html_content,
-            URL.auto_record_type_suggestion,
-            URL.auto_relevant_suggestion,
-            URL.user_relevant_suggestion,
-            URL.user_record_type_suggestion,
-            URL.optional_data_source_metadata,
-        ]
-        # The below relationships are joined to entities that are joined to the URL
-        self.double_join_relationships = [
-            (URL.automated_agency_suggestions, AutomatedUrlAgencySuggestion.agency),
-            (URL.user_agency_suggestion, UserUrlAgencySuggestion.agency),
-            (URL.confirmed_agencies, ConfirmedURLAgency.agency)
-        ]
-
-        self.count_label = "count"
-
-    def _get_where_exist_clauses(
-        self,
-        query: FromClause,
-    ):
-        where_clauses = []
-        for model in USER_ANNOTATION_MODELS:
-            label = self.anno_exists_builder.get_exists_label(model)
-            where_clause = getattr(query.c, label) == 1
-            where_clauses.append(where_clause)
-        return where_clauses
-
-    def _build_base_query(
-        self,
-        anno_exists_query: FromClause,
-    ) -> Select:
-        builder = self.anno_exists_builder
-        where_exist_clauses = self._get_where_exist_clauses(
-            builder.query
-        )
-
-        query = (
-            select(
-                URL,
-                self._sum_exists_query(anno_exists_query, USER_ANNOTATION_MODELS)
-            )
-            .select_from(anno_exists_query)
-            .join(
-                URL,
-                URL.id == builder.url_id
-            )
-        )
-        if self.batch_id is not None:
-            query = (
-                query.join(
-                    LinkBatchURL
-                )
-                .where(
-                    LinkBatchURL.batch_id == self.batch_id
-                )
-            )
-
-        query = (
-            query.where(
-                and_(
-                    URL.outcome == URLStatus.PENDING.value,
-                    *where_exist_clauses
-                )
-            )
-        )
-        return query
-
-
-    def _sum_exists_query(self, query, models: list[Type[URLDependentMixin]]):
-        return sum(
-            [getattr(query.c, self.anno_exists_builder.get_exists_label(model)) for model in models]
-        ).label(TOTAL_DISTINCT_ANNOTATION_COUNT_LABEL)
-
-
-    async def _apply_batch_id_filter(self, url_query: Select, batch_id: Optional[int]):
-        if batch_id is None:
-            return url_query
-        return url_query.where(URL.batch_id == batch_id)
-
-    async def _apply_options(
-        self,
-        url_query: Select
-    ):
-        return url_query.options(
-            *[
-                joinedload(relationship)
-                for relationship in self.single_join_relationships
-            ],
-            *[
-                joinedload(primary).joinedload(secondary)
-                for primary, secondary in self.double_join_relationships
-            ]
-        )
-
-    async def _apply_order_clause(self, url_query: Select):
-        return url_query.order_by(
-            desc(TOTAL_DISTINCT_ANNOTATION_COUNT_LABEL),
-            asc(URL.id)
-        )
-
-    async def _extract_html_content_infos(self, url: URL) -> list[URLHTMLContentInfo]:
-        html_content = url.html_content
-        html_content_infos = [
-            URLHTMLContentInfo(**html_info.__dict__)
-            for html_info in html_content
-        ]
-        return html_content_infos
-
-    async def _extract_optional_metadata(self, url: URL) -> FinalReviewOptionalMetadata:
-        if url.optional_data_source_metadata is None:
-            return FinalReviewOptionalMetadata()
-        return FinalReviewOptionalMetadata(
-            record_formats=url.optional_data_source_metadata.record_formats,
-            data_portal_type=url.optional_data_source_metadata.data_portal_type,
-            supplying_entity=url.optional_data_source_metadata.supplying_entity
-        )
-
-    async def get_batch_info(self, session: AsyncSession) -> Optional[FinalReviewBatchInfo]:
-        if self.batch_id is None:
-            return None
-
-        count_reviewed_query = await self.get_count_reviewed_query()
-
-        count_ready_query = await self.get_count_ready_query()
-
-        full_query = (
-            select(
-                func.coalesce(count_reviewed_query.c[self.count_label], 0).label("count_reviewed"),
-                func.coalesce(count_ready_query.c[self.count_label], 0).label("count_ready_for_review")
-            )
-            .select_from(
-                count_ready_query.outerjoin(
-                    count_reviewed_query,
-                    count_reviewed_query.c.batch_id == count_ready_query.c.batch_id
-                )
-            )
-        )
-
-        raw_result = await session.execute(full_query)
-        return FinalReviewBatchInfo(**raw_result.mappings().one())
-
-    async def get_count_ready_query(self):
-        builder = self.anno_exists_builder
-        count_ready_query = (
-            select(
-                LinkBatchURL.batch_id,
-                func.count(URL.id).label(self.count_label)
-            )
-            .select_from(LinkBatchURL)
-            .join(URL)
-            .join(
-                builder.query,
-                builder.url_id == URL.id
-            )
-            .where(
-                LinkBatchURL.batch_id == self.batch_id,
-                URL.outcome == URLStatus.PENDING.value,
-                *self._get_where_exist_clauses(
-                    builder.query
-                )
-            )
-            .group_by(LinkBatchURL.batch_id)
-            .subquery("count_ready")
-        )
-        return count_ready_query
-
-    async def get_count_reviewed_query(self):
-        count_reviewed_query = (
-            select(
-                Batch.id.label("batch_id"),
-                func.count(URL.id).label(self.count_label)
-            )
-            .select_from(Batch)
-            .join(LinkBatchURL)
-            .outerjoin(URL, URL.id == LinkBatchURL.url_id)
-            .where(
-                URL.outcome.in_(
-                    [
-                        URLStatus.VALIDATED.value,
-                        URLStatus.NOT_RELEVANT.value,
-                        URLStatus.SUBMITTED.value,
-                        URLStatus.INDIVIDUAL_RECORD.value
-                    ]
-                ),
-                LinkBatchURL.batch_id == self.batch_id
-            )
-            .group_by(Batch.id)
-            .subquery("count_reviewed")
-        )
-        return count_reviewed_query
-
-    async def run(
-        self,
-        session: AsyncSession
-    ) -> GetNextURLForFinalReviewOuterResponse:
-        await self.anno_exists_builder.build()
-
-        url_query = await self.build_url_query()
-
-        raw_result = await session.execute(url_query.limit(1))
-        row = raw_result.unique().first()
-
-        if row is None:
-            return GetNextURLForFinalReviewOuterResponse(
-                next_source=None,
-                remaining=0
-            )
-
-        count_query = (
-            select(
-                func.count()
-            ).select_from(url_query.subquery("count"))
-        )
-        remaining_result = (await session.execute(count_query)).scalar()
-
-
-        result: URL = row[0]
-
-        html_content_infos = await self._extract_html_content_infos(result)
-        optional_metadata = await self._extract_optional_metadata(result)
-
-        batch_info = await self.get_batch_info(session)
-        try:
-
-            next_source = GetNextURLForFinalReviewResponse(
-                id=result.id,
-                url=result.url,
-                html_info=convert_to_response_html_info(html_content_infos),
-                name=result.name,
-                description=result.description,
-                annotations=FinalReviewAnnotationInfo(
-                    relevant=DTOConverter.final_review_annotation_relevant_info(
-                        user_suggestion=result.user_relevant_suggestion,
-                        auto_suggestion=result.auto_relevant_suggestion
-                    ),
-                    record_type=DTOConverter.final_review_annotation_record_type_info(
-                        user_suggestion=result.user_record_type_suggestion,
-                        auto_suggestion=result.auto_record_type_suggestion
-                    ),
-                    agency=DTOConverter.final_review_annotation_agency_info(
-                        automated_agency_suggestions=result.automated_agency_suggestions,
-                        user_agency_suggestion=result.user_agency_suggestion,
-                        confirmed_agencies=result.confirmed_agencies
-                    )
-                ),
-                optional_metadata=optional_metadata,
-                batch_info=batch_info
-            )
-            return GetNextURLForFinalReviewOuterResponse(
-                next_source=next_source,
-                remaining=remaining_result
-            )
-        except Exception as e:
-            raise FailedQueryException(f"Failed to convert result for url id {result.id} to response") from e
-
-    async def build_url_query(self):
-        anno_exists_query = self.anno_exists_builder.query
-        url_query = self._build_base_query(anno_exists_query)
-        url_query = await self._apply_options(url_query)
-        url_query = await self._apply_order_clause(url_query)
-
-        return url_query
diff --git a/src/api/endpoints/review/reject/query.py b/src/api/endpoints/review/reject/query.py
index 50bee0bc..1f9dfe91 100644
--- a/src/api/endpoints/review/reject/query.py
+++ b/src/api/endpoints/review/reject/query.py
@@ -5,8 +5,10 @@
 
 from src.api.endpoints.review.enums import RejectionReason
 from src.collectors.enums import URLStatus
-from src.db.models.instantiations.url.core import URL
-from src.db.models.instantiations.url.reviewing_user import ReviewingUserURL
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.reviewing_user import ReviewingUserURL
 from src.db.queries.base.builder import QueryBuilderBase
 
 
@@ -33,19 +35,26 @@ async def run(self, session) -> None:
         url = await session.execute(query)
         url = url.scalars().first()
 
+        validation_type: URLType
         match self.rejection_reason:
             case RejectionReason.INDIVIDUAL_RECORD:
-                url.outcome = URLStatus.INDIVIDUAL_RECORD.value
+                validation_type = URLType.INDIVIDUAL_RECORD
             case RejectionReason.BROKEN_PAGE_404:
-                url.outcome = URLStatus.NOT_FOUND.value
+                validation_type = URLType.BROKEN_PAGE
            case RejectionReason.NOT_RELEVANT:
-                url.outcome = URLStatus.NOT_RELEVANT.value
+                validation_type = URLType.NOT_RELEVANT
             case _:
                 raise HTTPException(
                     status_code=HTTP_400_BAD_REQUEST,
                     detail="Invalid rejection reason"
                 )
 
+        flag_url_validated = FlagURLValidated(
+            url_id=self.url_id,
+            type=validation_type
+        )
+        session.add(flag_url_validated)
+
         # Add rejecting user
         rejecting_user_url = ReviewingUserURL(
             user_id=self.user_id,
get. " - "If not specified, defaults to first qualifying URL", - default=None -) - -@review_router.get("/next-source") -async def get_next_source( - core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(requires_final_review_permission), - batch_id: int | None = batch_id_query, -) -> GetNextURLForFinalReviewOuterResponse: - return await core.get_next_source_for_review(batch_id=batch_id) - -@review_router.post("/approve-source") -async def approve_source( - core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(requires_final_review_permission), - approval_info: FinalReviewApprovalInfo = FinalReviewApprovalInfo, - batch_id: int | None = batch_id_query, -) -> GetNextURLForFinalReviewOuterResponse: - await core.approve_url( - approval_info, - access_info=access_info, - ) - return await core.get_next_source_for_review(batch_id=batch_id) - -@review_router.post("/reject-source") -async def reject_source( - core: AsyncCore = Depends(get_async_core), - access_info: AccessInfo = Depends(requires_final_review_permission), - review_info: FinalReviewRejectionInfo = FinalReviewRejectionInfo, - batch_id: int | None = batch_id_query, -) -> GetNextURLForFinalReviewOuterResponse: - await core.reject_url( - url_id=review_info.url_id, - access_info=access_info, - rejection_reason=review_info.rejection_reason - ) - return await core.get_next_source_for_review(batch_id=batch_id) diff --git a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/dtos/__init__.py b/src/api/endpoints/search/agency/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/root_url_cache/dtos/__init__.py rename to src/api/endpoints/search/agency/__init__.py diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/__init__.py b/src/api/endpoints/search/agency/ctes/__init__.py similarity index 100% rename from src/core/tasks/url/operators/url_miscellaneous_metadata/__init__.py rename to src/api/endpoints/search/agency/ctes/__init__.py diff --git a/src/api/endpoints/search/agency/ctes/with_location_id.py b/src/api/endpoints/search/agency/ctes/with_location_id.py new file mode 100644 index 00000000..345cb245 --- /dev/null +++ b/src/api/endpoints/search/agency/ctes/with_location_id.py @@ -0,0 +1,48 @@ +from sqlalchemy import select, literal, CTE, Column + +from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation +from src.db.models.views.dependent_locations import DependentLocationView + + +class WithLocationIdCTEContainer: + + def __init__(self, location_id: int): + + target_locations_cte = ( + select( + literal(location_id).label("location_id") + ) + .union( + select( + DependentLocationView.dependent_location_id + ) + .where( + DependentLocationView.parent_location_id == location_id + ) + ) + .cte("target_locations") + ) + + self._cte = ( + select( + LinkAgencyLocation.agency_id, + LinkAgencyLocation.location_id + ) + .join( + target_locations_cte, + target_locations_cte.c.location_id == LinkAgencyLocation.location_id + ) + .cte("with_location_id") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def agency_id(self) -> Column: + return self._cte.c.agency_id + + @property + def location_id(self) -> Column: + return self._cte.c.location_id \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/queries/__init__.py b/src/api/endpoints/search/agency/models/__init__.py similarity index 100% rename from 
diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/queries/__init__.py b/src/api/endpoints/search/agency/models/__init__.py
similarity index 100%
rename from src/core/tasks/url/operators/url_miscellaneous_metadata/queries/__init__.py
rename to src/api/endpoints/search/agency/models/__init__.py
diff --git a/src/api/endpoints/search/agency/models/response.py b/src/api/endpoints/search/agency/models/response.py
new file mode 100644
index 00000000..1b6b82d5
--- /dev/null
+++ b/src/api/endpoints/search/agency/models/response.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+from src.db.models.impl.agency.enums import AgencyType, JurisdictionType
+
+
+class AgencySearchResponse(BaseModel):
+    agency_id: int
+    agency_name: str
+    jurisdiction_type: JurisdictionType | None
+    agency_type: AgencyType
+    location_display_name: str
diff --git a/src/api/endpoints/search/agency/query.py b/src/api/endpoints/search/agency/query.py
new file mode 100644
index 00000000..9476e039
--- /dev/null
+++ b/src/api/endpoints/search/agency/query.py
@@ -0,0 +1,84 @@
+from typing import Sequence
+
+from sqlalchemy import select, func, RowMapping
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.api.endpoints.search.agency.ctes.with_location_id import WithLocationIdCTEContainer
+from src.api.endpoints.search.agency.models.response import AgencySearchResponse
+from src.db.helpers.session import session_helper as sh
+from src.db.models.impl.agency.enums import JurisdictionType
+from src.db.models.impl.agency.sqlalchemy import Agency
+from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
+from src.db.models.views.location_expanded import LocationExpandedView
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class SearchAgencyQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        location_id: int | None,
+        query: str | None,
+        jurisdiction_type: JurisdictionType | None,
+    ):
+        super().__init__()
+        self.location_id = location_id
+        self.query = query
+        self.jurisdiction_type = jurisdiction_type
+
+    async def run(self, session: AsyncSession) -> list[AgencySearchResponse]:
+
+        query = (
+            select(
+                Agency.agency_id,
+                Agency.name.label("agency_name"),
+                Agency.jurisdiction_type,
+                Agency.agency_type,
+                LocationExpandedView.display_name.label("location_display_name")
+            )
+        )
+        if self.location_id is None:
+            query = query.join(
+                LinkAgencyLocation,
+                LinkAgencyLocation.agency_id == Agency.agency_id
+            ).join(
+                LocationExpandedView,
+                LocationExpandedView.id == LinkAgencyLocation.location_id
+            )
+        else:
+            with_location_id_cte_container = WithLocationIdCTEContainer(self.location_id)
+            query = query.join(
+                with_location_id_cte_container.cte,
+                with_location_id_cte_container.agency_id == Agency.agency_id
+            ).join(
+                LocationExpandedView,
+                LocationExpandedView.id == with_location_id_cte_container.location_id
+            )
+
+        if self.jurisdiction_type is not None:
+            query = query.where(
+                Agency.jurisdiction_type == self.jurisdiction_type
+            )
+
+        if self.query is not None:
+            query = query.order_by(
+                func.similarity(
+                    Agency.name,
+                    self.query
+                ).desc()
+            )
+
+        query = query.limit(50)
+
+        mappings: Sequence[RowMapping] = await sh.mappings(session, query)
+
+        return [
+            AgencySearchResponse(
+                **mapping
+            )
+            for mapping in mappings
+        ]
+
+
+
+
diff --git a/src/api/endpoints/search/dtos/response.py b/src/api/endpoints/search/dtos/response.py
index 1a46c0be..c2283ea4 100644
--- a/src/api/endpoints/search/dtos/response.py
+++ b/src/api/endpoints/search/dtos/response.py
@@ -5,4 +5,4 @@
 
 class SearchURLResponse(BaseModel):
     found: bool
-    url_id: Optional[int] = None
\ No newline at end of file
+    url_id: int | None = None
\ No newline at end of file
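`func.similarity` in the query builder above renders PostgreSQL's `similarity()` function, which is supplied by the `pg_trgm` extension (`CREATE EXTENSION pg_trgm`), so this ordering only works where that extension is installed. A sketch of the ordering clause in isolation, with an illustrative table:

```python
# func.similarity compiles to pg_trgm's similarity(); table is illustrative.
from sqlalchemy import Column, Integer, MetaData, String, Table, func, select
from sqlalchemy.dialects import postgresql

metadata = MetaData()
agency = Table(
    "agency", metadata,
    Column("agency_id", Integer, primary_key=True),
    Column("name", String),
)

search_term = "police department"
query = (
    select(agency.c.agency_id, agency.c.name)
    # Best trigram matches first; rows are never filtered out, only reordered.
    .order_by(func.similarity(agency.c.name, search_term).desc())
    .limit(50)
)

print(query.compile(dialect=postgresql.dialect()))
# ... ORDER BY similarity(agency.name, %(similarity_1)s) DESC LIMIT 50
```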
diff --git a/src/api/endpoints/search/routes.py b/src/api/endpoints/search/routes.py
index a1b576f2..f2abb93c 100644
--- a/src/api/endpoints/search/routes.py
+++ b/src/api/endpoints/search/routes.py
@@ -1,8 +1,13 @@
-from fastapi import APIRouter, Query, Depends
+
+from fastapi import APIRouter, Query, Depends, HTTPException
+from starlette import status
 
 from src.api.dependencies import get_async_core
+from src.api.endpoints.search.agency.models.response import AgencySearchResponse
+from src.api.endpoints.search.agency.query import SearchAgencyQueryBuilder
 from src.api.endpoints.search.dtos.response import SearchURLResponse
 from src.core.core import AsyncCore
+from src.db.models.impl.agency.enums import JurisdictionType
 from src.security.manager import get_access_info
 from src.security.dtos.access_info import AccessInfo
 
@@ -18,4 +23,36 @@
     """
     Search for a URL in the database
     """
-    return await async_core.search_for_url(url)
\ No newline at end of file
+    return await async_core.search_for_url(url)
+
+
+@search_router.get("/agency")
+async def search_agency(
+    location_id: int | None = Query(
+        description="The location id to search for",
+        default=None
+    ),
+    query: str | None = Query(
+        description="The query to search for",
+        default=None
+    ),
+    jurisdiction_type: JurisdictionType | None = Query(
+        description="The jurisdiction type to search for",
+        default=None
+    ),
+    access_info: AccessInfo = Depends(get_access_info),
+    async_core: AsyncCore = Depends(get_async_core),
+) -> list[AgencySearchResponse]:
+    if query is None and location_id is None and jurisdiction_type is None:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="At least one of query, location_id, or jurisdiction_type must be provided"
+        )
+
+    return await async_core.adb_client.run_query_builder(
+        SearchAgencyQueryBuilder(
+            location_id=location_id,
+            query=query,
+            jurisdiction_type=jurisdiction_type
+        )
+    )
\ No newline at end of file
diff --git a/src/core/tasks/url/subtasks/agency_identification/__init__.py b/src/api/endpoints/submit/__init__.py
similarity index 100%
rename from src/core/tasks/url/subtasks/agency_identification/__init__.py
rename to src/api/endpoints/submit/__init__.py
diff --git a/src/api/endpoints/submit/routes.py b/src/api/endpoints/submit/routes.py
new file mode 100644
index 00000000..d91d1821
--- /dev/null
+++ b/src/api/endpoints/submit/routes.py
@@ -0,0 +1,24 @@
+from fastapi import APIRouter, Depends
+
+from src.api.dependencies import get_async_core
+from src.api.endpoints.submit.url.models.request import URLSubmissionRequest
+from src.api.endpoints.submit.url.models.response import URLSubmissionResponse
+from src.api.endpoints.submit.url.queries.core import SubmitURLQueryBuilder
+from src.core.core import AsyncCore
+from src.security.dtos.access_info import AccessInfo
+from src.security.manager import get_access_info
+
+submit_router = APIRouter(prefix="/submit", tags=["submit"])
+
+@submit_router.post("/url")
+async def submit_url(
+    request: URLSubmissionRequest,
+    access_info: AccessInfo = Depends(get_access_info),
+    async_core: AsyncCore = Depends(get_async_core),
+) -> URLSubmissionResponse:
+    return await async_core.adb_client.run_query_builder(
+        SubmitURLQueryBuilder(
+            request=request,
+            user_id=access_info.user_id
+        )
+    )
\ No newline at end of file
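A hypothetical client call against the new `POST /submit/url` route follows; the base URL and auth header are assumptions, and `httpx` is not a dependency introduced by this diff.

```python
# Hypothetical call to the new submit endpoint; host and auth are assumed.
import httpx

response = httpx.post(
    "http://localhost:8000/submit/url",
    headers={"Authorization": "Bearer <token>"},
    json={
        "url": "https://example.com/police/records?utm_source=x",
        "name": "Example Police Records",
    },
)
print(response.json())
# e.g. {"url_original": "https://example.com/police/records?utm_source=x",
#       "url_cleaned": "...", "status": "accepted_with_cleaning", "url_id": 123}
```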
diff --git a/src/db/dtos/url/annotations/__init__.py b/src/api/endpoints/submit/url/__init__.py
similarity index 100%
rename from src/db/dtos/url/annotations/__init__.py
rename to src/api/endpoints/submit/url/__init__.py
diff --git a/src/api/endpoints/submit/url/enums.py b/src/api/endpoints/submit/url/enums.py
new file mode 100644
index 00000000..08802072
--- /dev/null
+++ b/src/api/endpoints/submit/url/enums.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+class URLSubmissionStatus(Enum):
+    ACCEPTED_AS_IS = "accepted_as_is"
+    ACCEPTED_WITH_CLEANING = "accepted_with_cleaning"
+    DATABASE_DUPLICATE = "database_duplicate"
+    INVALID = "invalid"
\ No newline at end of file
diff --git a/src/db/dtos/url/annotations/auto/__init__.py b/src/api/endpoints/submit/url/models/__init__.py
similarity index 100%
rename from src/db/dtos/url/annotations/auto/__init__.py
rename to src/api/endpoints/submit/url/models/__init__.py
diff --git a/src/api/endpoints/submit/url/models/request.py b/src/api/endpoints/submit/url/models/request.py
new file mode 100644
index 00000000..5b52d761
--- /dev/null
+++ b/src/api/endpoints/submit/url/models/request.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+from src.core.enums import RecordType
+
+
+class URLSubmissionRequest(BaseModel):
+    url: str
+    record_type: RecordType | None = None
+    name: str | None = None
+    location_id: int | None = None
+    agency_id: int | None = None
\ No newline at end of file
diff --git a/src/api/endpoints/submit/url/models/response.py b/src/api/endpoints/submit/url/models/response.py
new file mode 100644
index 00000000..f2f8d031
--- /dev/null
+++ b/src/api/endpoints/submit/url/models/response.py
@@ -0,0 +1,18 @@
+from pydantic import BaseModel, model_validator
+
+from src.api.endpoints.submit.url.enums import URLSubmissionStatus
+
+
+class URLSubmissionResponse(BaseModel):
+    url_original: str
+    url_cleaned: str | None = None
+    status: URLSubmissionStatus
+    url_id: int | None = None
+
+    @model_validator(mode="after")
+    def validate_url_id_if_accepted(self):
+        if self.status in [URLSubmissionStatus.ACCEPTED_AS_IS, URLSubmissionStatus.ACCEPTED_WITH_CLEANING]:
+            if self.url_id is None:
+                raise ValueError("url_id is required for accepted urls")
+        return self
+
diff --git a/src/db/models/instantiations/__init__.py b/src/api/endpoints/submit/url/queries/__init__.py
similarity index 100%
rename from src/db/models/instantiations/__init__.py
rename to src/api/endpoints/submit/url/queries/__init__.py
diff --git a/src/api/endpoints/submit/url/queries/convert.py b/src/api/endpoints/submit/url/queries/convert.py
new file mode 100644
index 00000000..90a32566
--- /dev/null
+++ b/src/api/endpoints/submit/url/queries/convert.py
@@ -0,0 +1,21 @@
+from src.api.endpoints.submit.url.enums import URLSubmissionStatus
+from src.api.endpoints.submit.url.models.response import URLSubmissionResponse
+
+
+def convert_invalid_url_to_url_response(
+    url: str
+) -> URLSubmissionResponse:
+    return URLSubmissionResponse(
+        url_original=url,
+        status=URLSubmissionStatus.INVALID,
+    )
+
+def convert_duplicate_urls_to_url_response(
+    clean_url: str,
+    original_url: str
+) -> URLSubmissionResponse:
+    return URLSubmissionResponse(
+        url_original=original_url,
+        url_cleaned=clean_url,
+        status=URLSubmissionStatus.DATABASE_DUPLICATE,
+    )
diff --git a/src/api/endpoints/submit/url/queries/core.py b/src/api/endpoints/submit/url/queries/core.py
new file mode 100644
index 00000000..081b5456
--- /dev/null
+++ b/src/api/endpoints/submit/url/queries/core.py
@@ -0,0 +1,128 @@
+
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.api.endpoints.submit.url.enums import URLSubmissionStatus
+from src.api.endpoints.submit.url.models.request import URLSubmissionRequest
+from src.api.endpoints.submit.url.models.response import URLSubmissionResponse
+from src.api.endpoints.submit.url.queries.convert import convert_invalid_url_to_url_response, \
+    convert_duplicate_urls_to_url_response
+from src.api.endpoints.submit.url.queries.dedupe import DeduplicateURLQueryBuilder
+from src.collectors.enums import URLStatus
+from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion
+from src.db.models.impl.link.user_suggestion_not_found.users_submitted_url.sqlalchemy import LinkUserSubmittedURL
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion
+from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion
+from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource
+from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion
+from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion
+from src.db.queries.base.builder import QueryBuilderBase
+from src.db.utils.validate import is_valid_url
+from src.util.clean import clean_url
+
+
+class SubmitURLQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        request: URLSubmissionRequest,
+        user_id: int
+    ):
+        super().__init__()
+        self.request = request
+        self.user_id = user_id
+
+    async def run(self, session: AsyncSession) -> URLSubmissionResponse:
+        url_original: str = self.request.url
+
+        # Filter out invalid URLs
+        valid: bool = is_valid_url(url_original)
+        if not valid:
+            return convert_invalid_url_to_url_response(url_original)
+
+        # Clean URLs
+        url_clean: str = clean_url(url_original)
+
+        # Check if duplicate
+        is_duplicate: bool = await DeduplicateURLQueryBuilder(url=url_clean).run(session)
+        if is_duplicate:
+            return convert_duplicate_urls_to_url_response(
+                clean_url=url_clean,
+                original_url=url_original
+            )
+
+        # Submit URL and get URL id
+
+        # Add URL
+        url_insert = URL(
+            url=url_clean,
+            source=URLSource.MANUAL,
+            status=URLStatus.OK,
+        )
+        session.add(url_insert)
+        await session.flush()
+
+        # Add Link
+        link = LinkUserSubmittedURL(
+            url_id=url_insert.id,
+            user_id=self.user_id,
+        )
+        session.add(link)
+
+        # Add record type as suggestion if exists
+        if self.request.record_type is not None:
+            rec_sugg = UserRecordTypeSuggestion(
+                user_id=self.user_id,
+                url_id=url_insert.id,
+                record_type=self.request.record_type.value
+            )
+            session.add(rec_sugg)
+
+        # Add name as suggestion if exists
+        if self.request.name is not None:
+            name_sugg = URLNameSuggestion(
+                url_id=url_insert.id,
+                suggestion=self.request.name,
+                source=NameSuggestionSource.USER
+            )
+            session.add(name_sugg)
+            await session.flush()
+
+            link_name_sugg = LinkUserNameSuggestion(
+                suggestion_id=name_sugg.id,
+                user_id=self.user_id
+            )
+            session.add(link_name_sugg)
+
+        # Add location ID as suggestion if exists
+        if self.request.location_id is not None:
+            loc_sugg = UserLocationSuggestion(
+                user_id=self.user_id,
+                url_id=url_insert.id,
+                location_id=self.request.location_id
+            )
+            session.add(loc_sugg)
+
+        # Add agency ID as suggestion if exists
+        if self.request.agency_id is not None:
+            agen_sugg = UserUrlAgencySuggestion(
+                user_id=self.user_id,
+                url_id=url_insert.id,
+                agency_id=self.request.agency_id
+            )
+            session.add(agen_sugg)
+
+        if url_clean == url_original:
+            status = URLSubmissionStatus.ACCEPTED_AS_IS
+        else:
+            status = URLSubmissionStatus.ACCEPTED_WITH_CLEANING
+
+        return URLSubmissionResponse(
+            url_original=url_original,
+            url_cleaned=url_clean,
+            status=status,
+            url_id=url_insert.id,
+        )
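A behavior sketch of `URLSubmissionResponse`'s `model_validator` follows: accepted statuses must carry a `url_id`. The import paths assume the modules added in this diff are on the path.

```python
# Behavior sketch of the model_validator on URLSubmissionResponse.
from src.api.endpoints.submit.url.enums import URLSubmissionStatus
from src.api.endpoints.submit.url.models.response import URLSubmissionResponse

# Valid: accepted submission carrying an id.
URLSubmissionResponse(
    url_original="https://example.com/data/",
    url_cleaned="https://example.com/data",
    status=URLSubmissionStatus.ACCEPTED_WITH_CLEANING,
    url_id=123,
)

# Valid: rejected submissions need no id.
URLSubmissionResponse(
    url_original="not-a-url",
    status=URLSubmissionStatus.INVALID,
)

# Raises pydantic.ValidationError ("url_id is required for accepted urls"):
# URLSubmissionResponse(
#     url_original="https://example.com/data",
#     status=URLSubmissionStatus.ACCEPTED_AS_IS,
# )
```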
diff --git a/src/api/endpoints/submit/url/queries/dedupe.py b/src/api/endpoints/submit/url/queries/dedupe.py
new file mode 100644
index 00000000..43c92edd
--- /dev/null
+++ b/src/api/endpoints/submit/url/queries/dedupe.py
@@ -0,0 +1,28 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.db.helpers.session import session_helper as sh
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class DeduplicateURLQueryBuilder(QueryBuilderBase):
+
+    def __init__(self, url: str):
+        super().__init__()
+        self.url = url
+
+    async def run(self, session: AsyncSession) -> bool:
+
+        query = select(
+            URL.url
+        ).where(
+            URL.url == self.url
+        )
+
+        return await sh.has_results(session, query=query)
+
+
+
+
+
diff --git a/src/api/endpoints/task/by_id/dto.py b/src/api/endpoints/task/by_id/dto.py
index 411ad7f7..64595f5d 100644
--- a/src/api/endpoints/task/by_id/dto.py
+++ b/src/api/endpoints/task/by_id/dto.py
@@ -1,18 +1,17 @@
 import datetime
-from typing import Optional
 
 from pydantic import BaseModel
 
-from src.db.dtos.url.error import URLErrorPydanticInfo
-from src.db.dtos.url.core import URLInfo
 from src.db.enums import TaskType
-from src.core.enums import BatchStatus
+from src.db.models.impl.task.enums import TaskStatus
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+from src.db.models.impl.url.error_info.pydantic import URLErrorInfoPydantic
 
 
 class TaskInfo(BaseModel):
     task_type: TaskType
-    task_status: BatchStatus
+    task_status: TaskStatus
     updated_at: datetime.datetime
-    error_info: Optional[str] = None
+    error_info: str | None = None
     urls: list[URLInfo]
-    url_errors: list[URLErrorPydanticInfo]
\ No newline at end of file
+    url_errors: list[URLErrorInfoPydantic]
\ No newline at end of file
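The dedupe builder above delegates the existence check to a session helper; a minimal equivalent under plain SQLAlchemy is an EXISTS query, sketched here with an illustrative table:

```python
# Sketch of the duplicate-existence check as a plain EXISTS query.
from sqlalchemy import Column, Integer, MetaData, String, Table, exists, select

metadata = MetaData()
urls = Table(
    "urls", metadata,
    Column("id", Integer, primary_key=True),
    Column("url", String, unique=True),
)

def duplicate_check_query(candidate: str):
    # SELECT EXISTS (SELECT * FROM urls WHERE urls.url = :candidate)
    return select(exists().where(urls.c.url == candidate))

print(duplicate_check_query("https://example.com/data"))
```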
urls = task.urls @@ -43,23 +43,23 @@ async def run(self, session: AsyncSession) -> TaskInfo: batch_id=url.batch.id, url=url.url, collector_metadata=url.collector_metadata, - outcome=URLStatus(url.outcome), + status=URLStatus(url.status), updated_at=url.updated_at ) url_infos.append(url_info) errored_urls = [] - for url in task.errored_urls: - url_error_info = URLErrorPydanticInfo( + for url in task.url_errors: + url_error_info = URLErrorInfoPydantic( task_id=url.task_id, url_id=url.url_id, error=url.error, - updated_at=url.updated_at + updated_at=url.created_at ) errored_urls.append(url_error_info) return TaskInfo( task_type=TaskType(task.task_type), - task_status=BatchStatus(task.task_status), + task_status=TaskStatus(task.task_status), error_info=error, updated_at=task.updated_at, urls=url_infos, diff --git a/src/api/endpoints/task/routes.py b/src/api/endpoints/task/routes.py index a719d6b9..23f52999 100644 --- a/src/api/endpoints/task/routes.py +++ b/src/api/endpoints/task/routes.py @@ -25,11 +25,11 @@ async def get_tasks( description="The page number", default=1 ), - task_status: Optional[BatchStatus] = Query( + task_status: BatchStatus | None = Query( description="Filter by task status", default=None ), - task_type: Optional[TaskType] = Query( + task_type: TaskType | None = Query( description="Filter by task type", default=None ), diff --git a/src/db/models/instantiations/link/__init__.py b/src/api/endpoints/url/by_id/__init__.py similarity index 100% rename from src/db/models/instantiations/link/__init__.py rename to src/api/endpoints/url/by_id/__init__.py diff --git a/src/db/models/instantiations/task/__init__.py b/src/api/endpoints/url/by_id/screenshot/__init__.py similarity index 100% rename from src/db/models/instantiations/task/__init__.py rename to src/api/endpoints/url/by_id/screenshot/__init__.py diff --git a/src/api/endpoints/url/by_id/screenshot/query.py b/src/api/endpoints/url/by_id/screenshot/query.py new file mode 100644 index 00000000..93a38b23 --- /dev/null +++ b/src/api/endpoints/url/by_id/screenshot/query.py @@ -0,0 +1,28 @@ +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class GetURLScreenshotQueryBuilder(QueryBuilderBase): + + def __init__(self, url_id: int): + super().__init__() + self.url_id = url_id + + async def run(self, session: AsyncSession) -> bytes | None: + + query = ( + select(URLScreenshot.content) + .where(URLScreenshot.url_id == self.url_id) + ) + + return await sh.one_or_none( + session=session, + query=query + ) + diff --git a/src/api/endpoints/url/by_id/screenshot/wrapper.py b/src/api/endpoints/url/by_id/screenshot/wrapper.py new file mode 100644 index 00000000..9de38cbb --- /dev/null +++ b/src/api/endpoints/url/by_id/screenshot/wrapper.py @@ -0,0 +1,22 @@ +from http import HTTPStatus + +from fastapi import HTTPException + +from src.api.endpoints.url.by_id.screenshot.query import GetURLScreenshotQueryBuilder +from src.db.client.async_ import AsyncDatabaseClient + + +async def get_url_screenshot_wrapper( + url_id: int, + adb_client: AsyncDatabaseClient, +) -> bytes: + + raw_result: bytes | None = await adb_client.run_query_builder( + GetURLScreenshotQueryBuilder(url_id=url_id) + ) + if raw_result is None: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail="URL not found" + ) + return 
raw_result \ No newline at end of file diff --git a/src/api/endpoints/url/get/dto.py b/src/api/endpoints/url/get/dto.py index 3b3e980e..a4616d7e 100644 --- a/src/api/endpoints/url/get/dto.py +++ b/src/api/endpoints/url/get/dto.py @@ -4,10 +4,11 @@ from pydantic import BaseModel from src.collectors.enums import URLStatus -from src.db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource +from src.db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource, TaskType + class GetURLsResponseErrorInfo(BaseModel): - id: int + task: TaskType error: str updated_at: datetime.datetime @@ -25,7 +26,7 @@ class GetURLsResponseInnerInfo(BaseModel): batch_id: int | None url: str status: URLStatus - collector_metadata: Optional[dict] + collector_metadata: dict | None updated_at: datetime.datetime created_at: datetime.datetime errors: list[GetURLsResponseErrorInfo] diff --git a/src/api/endpoints/url/get/query.py b/src/api/endpoints/url/get/query.py index 1ba5a75f..d7198612 100644 --- a/src/api/endpoints/url/get/query.py +++ b/src/api/endpoints/url/get/query.py @@ -5,8 +5,8 @@ from src.api.endpoints.url.get.dto import GetURLsResponseInfo, GetURLsResponseErrorInfo, GetURLsResponseInnerInfo from src.collectors.enums import URLStatus from src.db.client.helpers import add_standard_limit_and_offset -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.error_info import URLErrorInfo +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError from src.db.queries.base.builder import QueryBuilderBase @@ -23,14 +23,14 @@ def __init__( async def run(self, session: AsyncSession) -> GetURLsResponseInfo: statement = select(URL).options( - selectinload(URL.error_info), + selectinload(URL.task_errors), selectinload(URL.batch) ).order_by(URL.id) if self.errors: # Only return URLs with errors statement = statement.where( exists( - select(URLErrorInfo).where(URLErrorInfo.url_id == URL.id) + select(URLTaskError).where(URLTaskError.url_id == URL.id) ) ) add_standard_limit_and_offset(statement, self.page) @@ -39,11 +39,11 @@ async def run(self, session: AsyncSession) -> GetURLsResponseInfo: final_results = [] for result in all_results: error_results = [] - for error in result.error_info: + for error in result.task_errors: error_result = GetURLsResponseErrorInfo( - id=error.id, + task=error.task_type, error=error.error, - updated_at=error.updated_at + updated_at=error.created_at ) error_results.append(error_result) final_results.append( @@ -51,7 +51,7 @@ async def run(self, session: AsyncSession) -> GetURLsResponseInfo: id=result.id, batch_id=result.batch.id if result.batch is not None else None, url=result.url, - status=URLStatus(result.outcome), + status=URLStatus(result.status), collector_metadata=result.collector_metadata, updated_at=result.updated_at, created_at=result.created_at, diff --git a/src/api/endpoints/url/routes.py b/src/api/endpoints/url/routes.py index 225dd5d6..c7bb59b0 100644 --- a/src/api/endpoints/url/routes.py +++ b/src/api/endpoints/url/routes.py @@ -1,6 +1,7 @@ -from fastapi import APIRouter, Query, Depends +from fastapi import APIRouter, Query, Depends, Response from src.api.dependencies import get_async_core +from src.api.endpoints.url.by_id.screenshot.wrapper import get_url_screenshot_wrapper from src.api.endpoints.url.get.dto import GetURLsResponseInfo from src.core.core import AsyncCore from src.security.manager import get_access_info @@ -27,3 +28,18 @@ async 
def get_urls( ) -> GetURLsResponseInfo: result = await async_core.get_urls(page=page, errors=errors) return result + +@url_router.get("/{url_id}/screenshot") +async def get_url_screenshot( + url_id: int, + async_core: AsyncCore = Depends(get_async_core), +) -> Response: + + raw_result: bytes = await get_url_screenshot_wrapper( + url_id=url_id, + adb_client=async_core.adb_client + ) + return Response( + content=raw_result, + media_type="image/webp" + ) diff --git a/src/api/main.py b/src/api/main.py index 355fbedf..2d31dc1f 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -10,34 +10,43 @@ from src.api.endpoints.annotate.routes import annotate_router from src.api.endpoints.batch.routes import batch_router from src.api.endpoints.collector.routes import collector_router +from src.api.endpoints.contributions.routes import contributions_router from src.api.endpoints.metrics.routes import metrics_router -from src.api.endpoints.review.routes import review_router from src.api.endpoints.root import root_router from src.api.endpoints.search.routes import search_router +from src.api.endpoints.submit.routes import submit_router from src.api.endpoints.task.routes import task_router from src.api.endpoints.url.routes import url_router +from src.collectors.impl.muckrock.api_interface.core import MuckrockAPIInterface from src.collectors.manager import AsyncCollectorManager -from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface from src.core.core import AsyncCore -from src.core.logger import AsyncCoreLogger from src.core.env_var_manager import EnvVarManager +from src.core.logger import AsyncCoreLogger from src.core.tasks.handler import TaskHandler from src.core.tasks.scheduled.loader import ScheduledTaskOperatorLoader from src.core.tasks.scheduled.manager import AsyncScheduledTaskManager +from src.core.tasks.scheduled.registry.core import ScheduledJobRegistry from src.core.tasks.url.loader import URLTaskOperatorLoader from src.core.tasks.url.manager import TaskManager -from src.core.tasks.url.operators.url_html.scraper.parser.core import HTMLResponseParser -from src.core.tasks.url.operators.url_html.scraper.request_interface.core import URLRequestInterface +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.enums import \ + SpacyModelType +from src.core.tasks.url.operators.html.scraper.parser.core import HTMLResponseParser from src.db.client.async_ import AsyncDatabaseClient from src.db.client.sync import DatabaseClient -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.core import RootURLCache +from src.external.huggingface.hub.client import HuggingFaceHubClient from src.external.huggingface.inference.client import HuggingFaceInferenceClient +from src.external.internet_archives.client import InternetArchivesClient from src.external.pdap.client import PDAPClient +from src.external.url_request.core import URLRequestInterface +from environs import Env @asynccontextmanager async def lifespan(app: FastAPI): env_var_manager = EnvVarManager.get() + env = Env() + env.read_env() # Initialize shared dependencies db_client = DatabaseClient( @@ -51,11 +60,16 @@ async def lifespan(app: FastAPI): session = aiohttp.ClientSession() - task_handler = TaskHandler( - adb_client=adb_client, - discord_poster=DiscordPoster( + if env.bool("POST_TO_DISCORD_FLAG", True): + discord_poster = DiscordPoster( 
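+            # Constructed only when POST_TO_DISCORD_FLAG is enabled (it
+            # defaults to on); otherwise the TaskHandler receives None.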
webhook_url=env_var_manager.discord_webhook_url ) + else: + discord_poster = None + + task_handler = TaskHandler( + adb_client=adb_client, + discord_poster=discord_poster ) pdap_client = PDAPClient( access_manager=AccessManager( @@ -72,9 +86,7 @@ async def lifespan(app: FastAPI): loader=URLTaskOperatorLoader( adb_client=adb_client, url_request_interface=URLRequestInterface(), - html_parser=HTMLResponseParser( - root_url_cache=RootURLCache() - ), + html_parser=HTMLResponseParser(), pdap_client=pdap_client, muckrock_api_interface=MuckrockAPIInterface( session=session @@ -82,6 +94,9 @@ async def lifespan(app: FastAPI): hf_inference_client=HuggingFaceInferenceClient( session=session, token=env_var_manager.hf_inference_api_key + ), + nlp_processor=NLPProcessor( + model_type=SpacyModelType.EN_CORE_WEB_SM ) ), ) @@ -97,12 +112,19 @@ async def lifespan(app: FastAPI): collector_manager=async_collector_manager ) async_scheduled_task_manager = AsyncScheduledTaskManager( - async_core=async_core, handler=task_handler, loader=ScheduledTaskOperatorLoader( adb_client=adb_client, - pdap_client=pdap_client - ) + pdap_client=pdap_client, + hf_client=HuggingFaceHubClient( + token=env_var_manager.hf_hub_token + ), + async_core=async_core, + ia_client=InternetArchivesClient( + session=session + ) + ), + registry=ScheduledJobRegistry() ) await async_scheduled_task_manager.setup() @@ -152,9 +174,10 @@ async def redirect_docs(): annotate_router, url_router, task_router, - review_router, search_router, - metrics_router + metrics_router, + submit_router, + contributions_router ] for router in routers: diff --git a/src/collectors/enums.py b/src/collectors/enums.py index 1732bd19..f40e5f19 100644 --- a/src/collectors/enums.py +++ b/src/collectors/enums.py @@ -11,11 +11,6 @@ class CollectorType(Enum): MANUAL = "manual" class URLStatus(Enum): - PENDING = "pending" - SUBMITTED = "submitted" - VALIDATED = "validated" + OK = "ok" ERROR = "error" DUPLICATE = "duplicate" - NOT_RELEVANT = "not relevant" - NOT_FOUND = "404 not found" - INDIVIDUAL_RECORD = "individual record" diff --git a/src/collectors/source_collectors/README.md b/src/collectors/impl/README.md similarity index 100% rename from src/collectors/source_collectors/README.md rename to src/collectors/impl/README.md diff --git a/src/db/models/instantiations/url/__init__.py b/src/collectors/impl/__init__.py similarity index 100% rename from src/db/models/instantiations/url/__init__.py rename to src/collectors/impl/__init__.py diff --git a/src/collectors/source_collectors/auto_googler/README.md b/src/collectors/impl/auto_googler/README.md similarity index 100% rename from src/collectors/source_collectors/auto_googler/README.md rename to src/collectors/impl/auto_googler/README.md diff --git a/src/db/models/instantiations/url/suggestion/__init__.py b/src/collectors/impl/auto_googler/__init__.py similarity index 100% rename from src/db/models/instantiations/url/suggestion/__init__.py rename to src/collectors/impl/auto_googler/__init__.py diff --git a/src/collectors/impl/auto_googler/auto_googler.py b/src/collectors/impl/auto_googler/auto_googler.py new file mode 100644 index 00000000..bbaefed9 --- /dev/null +++ b/src/collectors/impl/auto_googler/auto_googler.py @@ -0,0 +1,35 @@ +from src.collectors.impl.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO +from src.collectors.impl.auto_googler.searcher import GoogleSearcher +from src.collectors.impl.auto_googler.dtos.config import SearchConfig + + +class AutoGoogler: + """ + The AutoGoogler orchestrates 
the process of fetching URLs from Google Search
+    and processing them for source collection.
+    """
+    def __init__(
+        self,
+        search_config: SearchConfig,
+        google_searcher: GoogleSearcher
+    ):
+        self.search_config = search_config
+        self.google_searcher = google_searcher
+        self.data: dict[str, list[GoogleSearchQueryResultsInnerDTO]] = {
+            query: [] for query in search_config.queries
+        }
+
+    async def run(self):
+        """
+        Runs the AutoGoogler.
+        Yields status messages.
+        """
+        for query in self.search_config.queries:
+            yield f"Searching for '{query}' ..."
+            results = await self.google_searcher.search(query)
+            if results is not None:
+                yield f"Found {len(results)} results for '{query}'."
+                self.data[query] = results
+            else:
+                yield f"No results found for '{query}'."
+        yield "Done."
+
diff --git a/src/collectors/impl/auto_googler/collector.py b/src/collectors/impl/auto_googler/collector.py
new file mode 100644
index 00000000..9046f421
--- /dev/null
+++ b/src/collectors/impl/auto_googler/collector.py
@@ -0,0 +1,78 @@
+from typing import Any
+
+from src.collectors.impl.auto_googler.queries.agency import AutoGooglerAddAgencyQueryBuilder
+from src.collectors.impl.auto_googler.queries.location import AutoGooglerAddLocationQueryBuilder
+from src.collectors.impl.base import AsyncCollectorBase
+from src.collectors.enums import CollectorType
+from src.core.env_var_manager import EnvVarManager
+from src.core.preprocessors.autogoogler import AutoGooglerPreprocessor
+from src.collectors.impl.auto_googler.auto_googler import AutoGoogler
+from src.collectors.impl.auto_googler.dtos.output import AutoGooglerInnerOutputDTO
+from src.collectors.impl.auto_googler.dtos.input import AutoGooglerInputDTO
+from src.collectors.impl.auto_googler.searcher import GoogleSearcher
+from src.collectors.impl.auto_googler.dtos.config import SearchConfig
+from src.util.helper_functions import base_model_list_dump
+
+
+class AutoGooglerCollector(AsyncCollectorBase):
+    collector_type = CollectorType.AUTO_GOOGLER
+    preprocessor = AutoGooglerPreprocessor
+
+    async def run_to_completion(self) -> AutoGoogler:
+        dto: AutoGooglerInputDTO = self.dto
+
+        queries: list[str] = dto.queries.copy()
+
+        if dto.agency_id is not None:
+            agency_name: str = await self.adb_client.run_query_builder(
+                AutoGooglerAddAgencyQueryBuilder(
+                    batch_id=self.batch_id,
+                    agency_id=dto.agency_id,
+                )
+            )
+
+            # Add to all queries
+            queries = [f"{query} {agency_name}" for query in queries]
+
+        if dto.location_id is not None:
+            location_name: str = await self.adb_client.run_query_builder(
+                AutoGooglerAddLocationQueryBuilder(
+                    batch_id=self.batch_id,
+                    location_id=dto.location_id,
+                )
+            )
+
+            # Add to all queries
+            queries = [f"{query} {location_name}" for query in queries]
+
+        env_var_manager = EnvVarManager.get()
+        auto_googler = AutoGoogler(
+            search_config=SearchConfig(
+                urls_per_result=dto.urls_per_result,
+                queries=queries,
+            ),
+            google_searcher=GoogleSearcher(
+                api_key=env_var_manager.google_api_key,
+                cse_id=env_var_manager.google_cse_id,
+            )
+        )
+        async for log in auto_googler.run():
+            await self.log(log)
+        return auto_googler
+
+    async def run_implementation(self) -> None:
+
+        auto_googler: AutoGoogler = await self.run_to_completion()
+
+        inner_data: list[dict[str, Any]] = []
+        for query in auto_googler.search_config.queries:
+            query_results: list[AutoGooglerInnerOutputDTO] = auto_googler.data[query]
+            inner_data.append({
+                "query": query,
+                "query_results": base_model_list_dump(query_results),
+            })
+
+        self.data = 
{"data": inner_data} + diff --git a/src/db/models/instantiations/url/suggestion/agency/__init__.py b/src/collectors/impl/auto_googler/dtos/__init__.py similarity index 100% rename from src/db/models/instantiations/url/suggestion/agency/__init__.py rename to src/collectors/impl/auto_googler/dtos/__init__.py diff --git a/src/collectors/source_collectors/auto_googler/dtos/config.py b/src/collectors/impl/auto_googler/dtos/config.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/dtos/config.py rename to src/collectors/impl/auto_googler/dtos/config.py diff --git a/src/collectors/impl/auto_googler/dtos/input.py b/src/collectors/impl/auto_googler/dtos/input.py new file mode 100644 index 00000000..07c55eec --- /dev/null +++ b/src/collectors/impl/auto_googler/dtos/input.py @@ -0,0 +1,23 @@ +from pydantic import BaseModel, Field + + +class AutoGooglerInputDTO(BaseModel): + urls_per_result: int = Field( + description="Maximum number of URLs returned per result. Minimum is 1. Default is 10", + default=10, + ge=1, + le=50 + ) + queries: list[str] = Field( + description="List of queries to search for.", + min_length=1, + max_length=100 + ) + agency_id: int | None = Field( + description="ID of the agency to search for. Optional.", + default=None + ) + location_id: int | None = Field( + description="ID of the location to search for. Optional.", + default=None + ) diff --git a/src/collectors/source_collectors/auto_googler/dtos/output.py b/src/collectors/impl/auto_googler/dtos/output.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/dtos/output.py rename to src/collectors/impl/auto_googler/dtos/output.py diff --git a/src/collectors/source_collectors/auto_googler/dtos/query_results.py b/src/collectors/impl/auto_googler/dtos/query_results.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/dtos/query_results.py rename to src/collectors/impl/auto_googler/dtos/query_results.py diff --git a/src/collectors/source_collectors/auto_googler/exceptions.py b/src/collectors/impl/auto_googler/exceptions.py similarity index 100% rename from src/collectors/source_collectors/auto_googler/exceptions.py rename to src/collectors/impl/auto_googler/exceptions.py diff --git a/src/db/models/instantiations/url/suggestion/record_type/__init__.py b/src/collectors/impl/auto_googler/queries/__init__.py similarity index 100% rename from src/db/models/instantiations/url/suggestion/record_type/__init__.py rename to src/collectors/impl/auto_googler/queries/__init__.py diff --git a/src/collectors/impl/auto_googler/queries/agency.py b/src/collectors/impl/auto_googler/queries/agency.py new file mode 100644 index 00000000..344ea31f --- /dev/null +++ b/src/collectors/impl/auto_googler/queries/agency.py @@ -0,0 +1,36 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.agency.sqlalchemy import Agency +from src.db.models.impl.link.agency_batch.sqlalchemy import LinkAgencyBatch +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class AutoGooglerAddAgencyQueryBuilder(QueryBuilderBase): + + def __init__( + self, + batch_id: int, + agency_id: int, + ): + super().__init__() + self.batch_id = batch_id + self.agency_id = agency_id + + async def run(self, session: AsyncSession) -> str: + """Add link and return agency name.""" + + link = LinkAgencyBatch( + batch_id=self.batch_id, + agency_id=self.agency_id + ) + session.add(link) + + 
query = (
+            select(
+                Agency.name
+            )
+            .where(
+                Agency.id == self.agency_id
+            )
+        )
+
+        return await sh.scalar(session, query=query)
\ No newline at end of file
diff --git a/src/collectors/impl/auto_googler/queries/location.py b/src/collectors/impl/auto_googler/queries/location.py
new file mode 100644
index 00000000..b554176a
--- /dev/null
+++ b/src/collectors/impl/auto_googler/queries/location.py
@@ -0,0 +1,39 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.db.models.impl.link.location_batch.sqlalchemy import LinkLocationBatch
+from src.db.models.views.location_expanded import LocationExpandedView
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class AutoGooglerAddLocationQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        batch_id: int,
+        location_id: int
+    ):
+        super().__init__()
+        self.batch_id = batch_id
+        self.location_id = location_id
+
+    async def run(self, session: AsyncSession) -> str:
+        """Add link and return location name."""
+
+        link = LinkLocationBatch(
+            batch_id=self.batch_id,
+            location_id=self.location_id
+        )
+        session.add(link)
+
+        query = (
+            select(
+                LocationExpandedView.full_display_name
+            )
+            .where(
+                LocationExpandedView.id == self.location_id
+            )
+        )
+
+        return await sh.scalar(session, query=query)
diff --git a/src/collectors/impl/auto_googler/searcher.py b/src/collectors/impl/auto_googler/searcher.py
new file mode 100644
index 00000000..cb877e25
--- /dev/null
+++ b/src/collectors/impl/auto_googler/searcher.py
@@ -0,0 +1,85 @@
+import aiohttp
+
+from src.collectors.impl.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO
+from src.collectors.impl.auto_googler.exceptions import QuotaExceededError
+
+
+class GoogleSearcher:
+    """
+    A class that provides a GoogleSearcher object for performing searches using the Google Custom Search API.
+
+    Attributes:
+        api_key (str): The API key required for accessing the Google Custom Search API.
+        cse_id (str): The CSE (Custom Search Engine) ID required for identifying the specific search engine to use.
+
+    Methods:
+        __init__(api_key: str, cse_id: str)
+            Initializes a GoogleSearcher object with the provided API key and CSE ID. Raises a RuntimeError if either
+            the API key or CSE ID is None.
+
+        search(query: str) -> list[GoogleSearchQueryResultsInnerDTO] | None
+            Performs a search using the Google Custom Search API with the provided query string. Returns a list of
+            search results, or None if the response contains no items. Raises a QuotaExceededError if the daily
+            quota for the API has been exceeded, and a RuntimeError for any other HTTP error.
+    """
+    GOOGLE_SEARCH_URL = "https://www.googleapis.com/customsearch/v1"
+
+    def __init__(
+        self,
+        api_key: str,
+        cse_id: str
+    ):
+        if api_key is None or cse_id is None:
+            raise RuntimeError("Custom search API key and CSE ID cannot be None.")
+        self.api_key = api_key
+        self.cse_id = cse_id
+
+    async def search(self, query: str) -> list[GoogleSearchQueryResultsInnerDTO] | None:
+        """
+        Searches for results using the specified query.
+
+        Args:
+            query (str): The query to search for.
+
+        Returns: list[GoogleSearchQueryResultsInnerDTO] | None: A list of DTOs representing the search results.
+            If the response contains no items, None is returned.
+        """
+        try:
+            return await self.get_query_results(query)
+        except aiohttp.ClientResponseError as e:
+            # aiohttp raises ClientResponseError (not googleapiclient's HttpError,
+            # since the request is made via aiohttp); a 429 status signals quota exhaustion.
+            if e.status == 429 or "Quota exceeded" in str(e):
+                raise QuotaExceededError("Quota exceeded for the day")
+            raise RuntimeError(f"An error occurred: {str(e)}")
+
+    async def get_query_results(self, query: str) -> list[GoogleSearchQueryResultsInnerDTO] | None:
+        params = {
+            "key": self.api_key,
+            "cx": self.cse_id,
+            "q": query,
+        }
+
+        async with aiohttp.ClientSession() as session:
+            async with session.get(self.GOOGLE_SEARCH_URL, params=params) as response:
+                response.raise_for_status()
+                results = await response.json()
+
+        if "items" not in results:
+            return None
+
+        items = []
+
+        for item in results["items"]:
+            inner_dto = GoogleSearchQueryResultsInnerDTO(
+                url=item["link"],
+                title=item["title"],
+                snippet=item.get("snippet", ""),
+            )
+            items.append(inner_dto)
+
+        return items
diff --git a/src/collectors/impl/base.py b/src/collectors/impl/base.py
new file mode 100644
index 00000000..c3986c64
--- /dev/null
+++ b/src/collectors/impl/base.py
@@ -0,0 +1,134 @@
+import abc
+import asyncio
+import time
+from abc import ABC
+from typing import Type, Optional
+
+from pydantic import BaseModel
+
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.db.models.impl.log.pydantic.info import LogInfo
+from src.collectors.enums import CollectorType
+from src.core.logger import AsyncCoreLogger
+from src.core.function_trigger import FunctionTrigger
+from src.core.enums import BatchStatus
+from src.core.preprocessors.base import PreprocessorBase
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+
+
+class AsyncCollectorBase(ABC):
+    collector_type: CollectorType = None
+    preprocessor: Type[PreprocessorBase] = None
+
+    def __init__(
+        self,
+        batch_id: int,
+        dto: BaseModel,
+        logger: AsyncCoreLogger,
+        adb_client: AsyncDatabaseClient,
+        raise_error: bool = False,
+        post_collection_function_trigger: Optional[FunctionTrigger] = None,
+    ) -> None:
+        self.post_collection_function_trigger = post_collection_function_trigger
+        self.batch_id = batch_id
+        self.adb_client = adb_client
+        self.dto = dto
+        self.data: Optional[BaseModel] = None
+        self.logger = logger
+        self.status = BatchStatus.IN_PROCESS
+        self.start_time = None
+        self.compute_time = None
+        self.raise_error = raise_error
+
+    @abc.abstractmethod
+    async def run_implementation(self) -> None:
+        """
+        This is the method that will be overridden by each collector.
+        No other methods should be modified except for this one.
+ However, in each inherited class, new methods in addition to this one can be created + Returns: + + """ + raise NotImplementedError + + async def start_timer(self) -> None: + self.start_time = time.time() + + async def stop_timer(self) -> None: + self.compute_time = time.time() - self.start_time + + async def handle_error(self, e: Exception) -> None: + if self.raise_error: + raise e + await self.log(f"Error: {e}") + await self.adb_client.update_batch_post_collection( + batch_id=self.batch_id, + batch_status=self.status, + compute_time=self.compute_time, + total_url_count=0, + original_url_count=0, + duplicate_url_count=0 + ) + + async def process(self) -> None: + await self.log("Processing collector...") + preprocessor: PreprocessorBase = self.preprocessor() + url_infos: list[URLInfo] = preprocessor.preprocess(self.data) + await self.log(f"URLs processed: {len(url_infos)}") + + await self.log("Inserting URLs...") + insert_urls_info: InsertURLsInfo = await self.adb_client.insert_urls( + url_infos=url_infos, + batch_id=self.batch_id + ) + await self.log("Updating batch...") + await self.adb_client.update_batch_post_collection( + batch_id=self.batch_id, + total_url_count=insert_urls_info.total_count, + duplicate_url_count=insert_urls_info.duplicate_count, + original_url_count=insert_urls_info.original_count, + batch_status=self.status, + compute_time=self.compute_time + ) + await self.log("Done processing collector.") + + if self.post_collection_function_trigger is not None: + await self.post_collection_function_trigger.trigger_or_rerun() + + async def run(self) -> None: + try: + await self.start_timer() + await self.run_implementation() + await self.stop_timer() + await self.log("Collector completed successfully.") + await self.close() + await self.process() + except asyncio.CancelledError: + await self.stop_timer() + self.status = BatchStatus.ABORTED + await self.adb_client.update_batch_post_collection( + batch_id=self.batch_id, + batch_status=BatchStatus.ABORTED, + compute_time=self.compute_time, + total_url_count=0, + original_url_count=0, + duplicate_url_count=0 + ) + except Exception as e: + await self.stop_timer() + self.status = BatchStatus.ERROR + await self.handle_error(e) + + async def log( + self, + message: str, + ) -> None: + await self.logger.log(LogInfo( + batch_id=self.batch_id, + log=message + )) + + async def close(self) -> None: + self.status = BatchStatus.READY_TO_LABEL diff --git a/src/collectors/source_collectors/ckan/README.md b/src/collectors/impl/ckan/README.md similarity index 100% rename from src/collectors/source_collectors/ckan/README.md rename to src/collectors/impl/ckan/README.md diff --git a/src/db/models/instantiations/url/suggestion/relevant/__init__.py b/src/collectors/impl/ckan/__init__.py similarity index 100% rename from src/db/models/instantiations/url/suggestion/relevant/__init__.py rename to src/collectors/impl/ckan/__init__.py diff --git a/src/collectors/impl/ckan/collector.py b/src/collectors/impl/ckan/collector.py new file mode 100644 index 00000000..42390306 --- /dev/null +++ b/src/collectors/impl/ckan/collector.py @@ -0,0 +1,71 @@ +from pydantic import BaseModel + +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.enums import CollectorType +from src.core.preprocessors.ckan import CKANPreprocessor +from src.collectors.impl.ckan.dtos.input import CKANInputDTO +from src.collectors.impl.ckan.scraper_toolkit.search_funcs.group import ckan_group_package_search +from 
src.collectors.impl.ckan.scraper_toolkit.search_funcs.organization import ckan_package_search_from_organization
+from src.collectors.impl.ckan.scraper_toolkit.search_funcs.package import ckan_package_search
+from src.collectors.impl.ckan.scraper_toolkit.search import perform_search, get_flat_list, deduplicate_entries, \
+    get_collections, filter_result, parse_result
+from src.util.helper_functions import base_model_list_dump
+
+SEARCH_FUNCTION_MAPPINGS = {
+    "package_search": ckan_package_search,
+    "group_search": ckan_group_package_search,
+    "organization_search": ckan_package_search_from_organization
+}
+
+class CKANCollector(AsyncCollectorBase):
+    collector_type = CollectorType.CKAN
+    preprocessor = CKANPreprocessor
+
+    async def run_implementation(self) -> None:
+        results = await self.get_results()
+        flat_list = get_flat_list(results)
+        deduped_flat_list = deduplicate_entries(flat_list)
+
+        list_with_collection_child_packages = await self.add_collection_child_packages(deduped_flat_list)
+
+        filtered_results = list(
+            filter(
+                filter_result,
+                list_with_collection_child_packages
+            )
+        )
+        parsed_results = list(map(parse_result, filtered_results))
+
+        self.data = {"results": parsed_results}
+
+    async def add_collection_child_packages(self, deduped_flat_list):
+        # TODO: Find a way to clearly indicate which parts call from the CKAN API
+        list_with_collection_child_packages = []
+        count = len(deduped_flat_list)
+        for idx, result in enumerate(deduped_flat_list):
+            if "extras" in result.keys():
+                await self.log(f"Found collection ({idx + 1}/{count}): {result['id']}")
+                collections = await get_collections(result)
+                if collections:
+                    list_with_collection_child_packages += collections[0]
+                    continue
+
+            list_with_collection_child_packages.append(result)
+        return list_with_collection_child_packages
+
+    async def get_results(self):
+        results = []
+        dto: CKANInputDTO = self.dto
+        for search in SEARCH_FUNCTION_MAPPINGS.keys():
+            await self.log(f"Running search '{search}'...")
+            sub_dtos: list[BaseModel] = getattr(dto, search)
+            if sub_dtos is None:
+                continue
+            func = SEARCH_FUNCTION_MAPPINGS[search]
+            results = await perform_search(
+                search_func=func,
+                search_terms=base_model_list_dump(model_list=sub_dtos),
+                results=results
+            )
+        return results
+
diff --git a/src/collectors/source_collectors/ckan/constants.py b/src/collectors/impl/ckan/constants.py
similarity index 100%
rename from src/collectors/source_collectors/ckan/constants.py
rename to src/collectors/impl/ckan/constants.py
diff --git a/src/db/queries/implementations/core/tasks/__init__.py b/src/collectors/impl/ckan/dtos/__init__.py
similarity index 100%
rename from src/db/queries/implementations/core/tasks/__init__.py
rename to src/collectors/impl/ckan/dtos/__init__.py
diff --git a/src/collectors/impl/ckan/dtos/input.py b/src/collectors/impl/ckan/dtos/input.py
new file mode 100644
index 00000000..315bcafd
--- /dev/null
+++ b/src/collectors/impl/ckan/dtos/input.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel, Field
+
+from src.collectors.impl.ckan.dtos.search.group_and_organization import GroupAndOrganizationSearchDTO
+from src.collectors.impl.ckan.dtos.search.package import CKANPackageSearchDTO
+
+
+class CKANInputDTO(BaseModel):
+    package_search: list[CKANPackageSearchDTO] | None = Field(
+        description="The list of package searches to perform.",
+        default=None
+    )
+    group_search: list[GroupAndOrganizationSearchDTO] | None = Field(
+        description="The list of group searches to perform.",
+        default=None
+    )
+    organization_search: list[GroupAndOrganizationSearchDTO] | None = Field(
+        description="The list of organization searches to perform.",
+        default=None
+    )
diff --git a/src/collectors/source_collectors/ckan/dtos/package.py b/src/collectors/impl/ckan/dtos/package.py
similarity index 100%
rename from src/collectors/source_collectors/ckan/dtos/package.py
rename to src/collectors/impl/ckan/dtos/package.py
diff --git a/src/db/queries/implementations/core/tasks/agency_sync/__init__.py b/src/collectors/impl/ckan/dtos/search/__init__.py
similarity index 100%
rename from src/db/queries/implementations/core/tasks/agency_sync/__init__.py
rename to src/collectors/impl/ckan/dtos/search/__init__.py
diff --git a/src/collectors/source_collectors/ckan/dtos/search/_helpers.py b/src/collectors/impl/ckan/dtos/search/_helpers.py
similarity index 100%
rename from src/collectors/source_collectors/ckan/dtos/search/_helpers.py
rename to src/collectors/impl/ckan/dtos/search/_helpers.py
diff --git a/src/collectors/source_collectors/ckan/dtos/search/group_and_organization.py b/src/collectors/impl/ckan/dtos/search/group_and_organization.py
similarity index 76%
rename from src/collectors/source_collectors/ckan/dtos/search/group_and_organization.py
rename to src/collectors/impl/ckan/dtos/search/group_and_organization.py
index da413ce1..4a352321 100644
--- a/src/collectors/source_collectors/ckan/dtos/search/group_and_organization.py
+++ b/src/collectors/impl/ckan/dtos/search/group_and_organization.py
@@ -2,7 +2,7 @@
 
 from pydantic import BaseModel, Field
 
-from src.collectors.source_collectors.ckan.dtos.search._helpers import url_field
+from src.collectors.impl.ckan.dtos.search._helpers import url_field
 
 
 class GroupAndOrganizationSearchDTO(BaseModel):
diff --git a/src/collectors/impl/ckan/dtos/search/package.py b/src/collectors/impl/ckan/dtos/search/package.py
new file mode 100644
index 00000000..3ef73d1a
--- /dev/null
+++ b/src/collectors/impl/ckan/dtos/search/package.py
@@ -0,0 +1,14 @@
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+from src.collectors.impl.ckan.dtos.search._helpers import url_field
+
+
+class CKANPackageSearchDTO(BaseModel):
+    url: str = url_field
+    terms: Optional[list[str]] = Field(
+        description="The search terms to use to refine the packages returned. 
" + "None will return all packages.", + default=None + ) diff --git a/src/collectors/source_collectors/ckan/exceptions.py b/src/collectors/impl/ckan/exceptions.py similarity index 100% rename from src/collectors/source_collectors/ckan/exceptions.py rename to src/collectors/impl/ckan/exceptions.py diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/README.md b/src/collectors/impl/ckan/scraper_toolkit/README.md similarity index 100% rename from src/collectors/source_collectors/ckan/scraper_toolkit/README.md rename to src/collectors/impl/ckan/scraper_toolkit/README.md diff --git a/tests/automated/integration/api/review/__init__.py b/src/collectors/impl/ckan/scraper_toolkit/__init__.py similarity index 100% rename from tests/automated/integration/api/review/__init__.py rename to src/collectors/impl/ckan/scraper_toolkit/__init__.py diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/_api_interface.py b/src/collectors/impl/ckan/scraper_toolkit/_api_interface.py similarity index 96% rename from src/collectors/source_collectors/ckan/scraper_toolkit/_api_interface.py rename to src/collectors/impl/ckan/scraper_toolkit/_api_interface.py index d94c1516..8f557f3f 100644 --- a/src/collectors/source_collectors/ckan/scraper_toolkit/_api_interface.py +++ b/src/collectors/impl/ckan/scraper_toolkit/_api_interface.py @@ -3,7 +3,7 @@ import aiohttp from aiohttp import ContentTypeError -from src.collectors.source_collectors.ckan.exceptions import CKANAPIError +from src.collectors.impl.ckan.exceptions import CKANAPIError class CKANAPIInterface: diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search.py b/src/collectors/impl/ckan/scraper_toolkit/search.py similarity index 96% rename from src/collectors/source_collectors/ckan/scraper_toolkit/search.py rename to src/collectors/impl/ckan/scraper_toolkit/search.py index 5bf686d1..7cd24b27 100644 --- a/src/collectors/source_collectors/ckan/scraper_toolkit/search.py +++ b/src/collectors/impl/ckan/scraper_toolkit/search.py @@ -7,9 +7,9 @@ from from_root import from_root from tqdm import tqdm -from src.collectors.source_collectors.ckan.scraper_toolkit.search_funcs.collection import ckan_collection_search -from src.collectors.source_collectors.ckan.dtos.package import Package -from src.collectors.source_collectors.ckan.constants import CKAN_DATA_TYPES, CKAN_TYPE_CONVERSION_MAPPING +from src.collectors.impl.ckan.scraper_toolkit.search_funcs.collection import ckan_collection_search +from src.collectors.impl.ckan.dtos.package import Package +from src.collectors.impl.ckan.constants import CKAN_DATA_TYPES, CKAN_TYPE_CONVERSION_MAPPING p = from_root(".pydocstyle").parent sys.path.insert(1, str(p)) diff --git a/tests/automated/integration/api/review/rejection/__init__.py b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/__init__.py similarity index 100% rename from tests/automated/integration/api/review/rejection/__init__.py rename to src/collectors/impl/ckan/scraper_toolkit/search_funcs/__init__.py diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/collection.py b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/collection.py similarity index 98% rename from src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/collection.py rename to src/collectors/impl/ckan/scraper_toolkit/search_funcs/collection.py index 07fcd0f9..cd275fc0 100644 --- a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/collection.py +++ 
b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/collection.py @@ -7,7 +7,7 @@ import aiohttp from bs4 import ResultSet, Tag, BeautifulSoup -from src.collectors.source_collectors.ckan.dtos.package import Package +from src.collectors.impl.ckan.dtos.package import Package async def ckan_collection_search(base_url: str, collection_id: str) -> list[Package]: diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/group.py b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/group.py similarity index 88% rename from src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/group.py rename to src/collectors/impl/ckan/scraper_toolkit/search_funcs/group.py index 1c0a296d..b74d32f2 100644 --- a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/group.py +++ b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/group.py @@ -1,7 +1,7 @@ import sys from typing import Optional, Any -from src.collectors.source_collectors.ckan.scraper_toolkit._api_interface import CKANAPIInterface +from src.collectors.impl.ckan.scraper_toolkit._api_interface import CKANAPIInterface async def ckan_group_package_search( diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/organization.py b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/organization.py similarity index 82% rename from src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/organization.py rename to src/collectors/impl/ckan/scraper_toolkit/search_funcs/organization.py index 45ff6767..6f53ce52 100644 --- a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/organization.py +++ b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/organization.py @@ -1,7 +1,7 @@ from typing import Any -from src.collectors.source_collectors.ckan.scraper_toolkit._api_interface import CKANAPIInterface -from src.collectors.source_collectors.ckan.scraper_toolkit.search_funcs.package import ckan_package_search +from src.collectors.impl.ckan.scraper_toolkit._api_interface import CKANAPIInterface +from src.collectors.impl.ckan.scraper_toolkit.search_funcs.package import ckan_package_search async def ckan_package_search_from_organization( diff --git a/src/collectors/impl/ckan/scraper_toolkit/search_funcs/package.py b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/package.py new file mode 100644 index 00000000..e6bb2495 --- /dev/null +++ b/src/collectors/impl/ckan/scraper_toolkit/search_funcs/package.py @@ -0,0 +1,52 @@ +import sys +from typing import Optional, Any + +from src.collectors.impl.ckan.scraper_toolkit._api_interface import CKANAPIInterface + + +async def ckan_package_search( + base_url: str, + query: Optional[str] = None, + rows: Optional[int] = sys.maxsize, + start: Optional[int] = 0, + **kwargs, +) -> list[dict[str, Any]]: + """Performs a CKAN package (dataset) search from a CKAN data catalog URL. + + :param base_url: Base URL to search from. e.g. "https://catalog.data.gov/" + :param query: Search string, defaults to None. None will return all packages. + :param rows: Maximum number of results to return, defaults to maximum integer. + :param start: Offsets the results, defaults to 0. + :param kwargs: See https://docs.ckan.org/en/2.10/api/index.html#ckan.logic.action.get.package_search for additional arguments. + :return: List of dictionaries representing the CKAN package search results. 
+ """ + interface = CKANAPIInterface(base_url) + results = [] + offset = start + rows_max = 1000 # CKAN's package search has a hard limit of 1000 packages returned at a time by default + + while start < rows: + num_rows = rows - start + offset + packages: dict = await interface.package_search( + query=query, rows=num_rows, start=start, **kwargs + ) + add_base_url_to_packages(base_url, packages) + results += packages["results"] + + total_results = packages["count"] + if rows > total_results: + rows = total_results + + result_len = len(packages["results"]) + # Check if the website has a different rows_max value than CKAN's default + if result_len != rows_max and start + rows_max < total_results: + rows_max = result_len + + start += rows_max + + return results + + +def add_base_url_to_packages(base_url, packages): + # Add the base_url to each package + [package.update(base_url=base_url) for package in packages["results"]] diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/__init__.py b/src/collectors/impl/common_crawler/__init__.py similarity index 100% rename from tests/automated/integration/db/client/get_next_url_for_final_review/__init__.py rename to src/collectors/impl/common_crawler/__init__.py diff --git a/src/collectors/impl/common_crawler/collector.py b/src/collectors/impl/common_crawler/collector.py new file mode 100644 index 00000000..f390ef71 --- /dev/null +++ b/src/collectors/impl/common_crawler/collector.py @@ -0,0 +1,25 @@ +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.enums import CollectorType +from src.core.preprocessors.common_crawler import CommonCrawlerPreprocessor +from src.collectors.impl.common_crawler.crawler import CommonCrawler +from src.collectors.impl.common_crawler.input import CommonCrawlerInputDTO + + +class CommonCrawlerCollector(AsyncCollectorBase): + collector_type = CollectorType.COMMON_CRAWLER + preprocessor = CommonCrawlerPreprocessor + + async def run_implementation(self) -> None: + print("Running Common Crawler...") + dto: CommonCrawlerInputDTO = self.dto + common_crawler = CommonCrawler( + crawl_id=dto.common_crawl_id, + url=dto.url, + keyword=dto.search_term, + start_page=dto.start_page, + num_pages=dto.total_pages, + ) + async for status in common_crawler.run(): + await self.log(status) + + self.data = {"urls": common_crawler.url_results} \ No newline at end of file diff --git a/src/collectors/source_collectors/common_crawler/crawler.py b/src/collectors/impl/common_crawler/crawler.py similarity index 98% rename from src/collectors/source_collectors/common_crawler/crawler.py rename to src/collectors/impl/common_crawler/crawler.py index ca4f7ca9..f963aa4a 100644 --- a/src/collectors/source_collectors/common_crawler/crawler.py +++ b/src/collectors/impl/common_crawler/crawler.py @@ -6,7 +6,7 @@ import aiohttp -from src.collectors.source_collectors.common_crawler.utils import URLWithParameters +from src.collectors.impl.common_crawler.utils import URLWithParameters async def async_make_request( search_url: 'URLWithParameters' diff --git a/src/collectors/source_collectors/common_crawler/input.py b/src/collectors/impl/common_crawler/input.py similarity index 100% rename from src/collectors/source_collectors/common_crawler/input.py rename to src/collectors/impl/common_crawler/input.py diff --git a/src/collectors/source_collectors/common_crawler/utils.py b/src/collectors/impl/common_crawler/utils.py similarity index 100% rename from src/collectors/source_collectors/common_crawler/utils.py rename to 
src/collectors/impl/common_crawler/utils.py diff --git a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/__init__.py b/src/collectors/impl/example/__init__.py similarity index 100% rename from tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/__init__.py rename to src/collectors/impl/example/__init__.py diff --git a/src/collectors/impl/example/core.py b/src/collectors/impl/example/core.py new file mode 100644 index 00000000..4bccf242 --- /dev/null +++ b/src/collectors/impl/example/core.py @@ -0,0 +1,34 @@ +""" +Example collector +Exists as a proof of concept for collector functionality + +""" +import asyncio + +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.impl.example.dtos.input import ExampleInputDTO +from src.collectors.impl.example.dtos.output import ExampleOutputDTO +from src.collectors.enums import CollectorType +from src.core.preprocessors.example import ExamplePreprocessor + + +class ExampleCollector(AsyncCollectorBase): + collector_type = CollectorType.EXAMPLE + preprocessor = ExamplePreprocessor + + async def run_implementation(self) -> None: + dto: ExampleInputDTO = self.dto + sleep_time = dto.sleep_time + for i in range(sleep_time): # Simulate a task + await self.log(f"Step {i + 1}/{sleep_time}") + await self.sleep() + self.data = ExampleOutputDTO( + message=f"Data collected by {self.batch_id}", + urls=["https://example.com", "https://example.com/2"], + parameters=self.dto.model_dump(), + ) + + @staticmethod + async def sleep(): + # Simulate work + await asyncio.sleep(1) \ No newline at end of file diff --git a/tests/automated/integration/html_tag_collector/__init__.py b/src/collectors/impl/example/dtos/__init__.py similarity index 100% rename from tests/automated/integration/html_tag_collector/__init__.py rename to src/collectors/impl/example/dtos/__init__.py diff --git a/src/collectors/source_collectors/example/dtos/input.py b/src/collectors/impl/example/dtos/input.py similarity index 100% rename from src/collectors/source_collectors/example/dtos/input.py rename to src/collectors/impl/example/dtos/input.py diff --git a/src/collectors/source_collectors/example/dtos/output.py b/src/collectors/impl/example/dtos/output.py similarity index 100% rename from src/collectors/source_collectors/example/dtos/output.py rename to src/collectors/impl/example/dtos/output.py diff --git a/src/collectors/source_collectors/muckrock/README.md b/src/collectors/impl/muckrock/README.md similarity index 100% rename from src/collectors/source_collectors/muckrock/README.md rename to src/collectors/impl/muckrock/README.md diff --git a/tests/automated/integration/tasks/scheduled/agency_sync/__init__.py b/src/collectors/impl/muckrock/__init__.py similarity index 100% rename from tests/automated/integration/tasks/scheduled/agency_sync/__init__.py rename to src/collectors/impl/muckrock/__init__.py diff --git a/tests/automated/integration/tasks/url/auto_relevant/__init__.py b/src/collectors/impl/muckrock/api_interface/__init__.py similarity index 100% rename from tests/automated/integration/tasks/url/auto_relevant/__init__.py rename to src/collectors/impl/muckrock/api_interface/__init__.py diff --git a/src/collectors/impl/muckrock/api_interface/core.py b/src/collectors/impl/muckrock/api_interface/core.py new file mode 100644 index 00000000..4dd97572 --- /dev/null +++ b/src/collectors/impl/muckrock/api_interface/core.py @@ -0,0 +1,40 @@ +from typing import Optional + +import requests +from aiohttp import 
ClientSession + +from src.collectors.impl.muckrock.api_interface.lookup_response import AgencyLookupResponse +from src.collectors.impl.muckrock.enums import AgencyLookupResponseType + + +class MuckrockAPIInterface: + + def __init__(self, session: Optional[ClientSession] = None): + self.base_url = "https://www.muckrock.com/api_v1/" + self.session = session + + def build_url(self, subpath: str): + return f"{self.base_url}{subpath}" + + + async def lookup_agency(self, muckrock_agency_id: int) -> AgencyLookupResponse: + url = self.build_url(f"agency/{muckrock_agency_id}") + try: + async with self.session.get(url) as results: + results.raise_for_status() + json = await results.json() + name = json["name"] + return AgencyLookupResponse( + name=name, type=AgencyLookupResponseType.FOUND + ) + except requests.exceptions.HTTPError as e: + return AgencyLookupResponse( + name=None, + type=AgencyLookupResponseType.ERROR, + error=str(e) + ) + except KeyError: + return AgencyLookupResponse( + name=None, type=AgencyLookupResponseType.NOT_FOUND + ) + diff --git a/src/collectors/impl/muckrock/api_interface/lookup_response.py b/src/collectors/impl/muckrock/api_interface/lookup_response.py new file mode 100644 index 00000000..d1fd9635 --- /dev/null +++ b/src/collectors/impl/muckrock/api_interface/lookup_response.py @@ -0,0 +1,11 @@ +from typing import Optional + +from pydantic import BaseModel + +from src.collectors.impl.muckrock.enums import AgencyLookupResponseType + + +class AgencyLookupResponse(BaseModel): + name: str | None + type: AgencyLookupResponseType + error: str | None = None diff --git a/tests/automated/integration/tasks/url/duplicate/__init__.py b/src/collectors/impl/muckrock/collectors/__init__.py similarity index 100% rename from tests/automated/integration/tasks/url/duplicate/__init__.py rename to src/collectors/impl/muckrock/collectors/__init__.py diff --git a/tests/automated/integration/tasks/url/html/__init__.py b/src/collectors/impl/muckrock/collectors/all_foia/__init__.py similarity index 100% rename from tests/automated/integration/tasks/url/html/__init__.py rename to src/collectors/impl/muckrock/collectors/all_foia/__init__.py diff --git a/src/collectors/impl/muckrock/collectors/all_foia/core.py b/src/collectors/impl/muckrock/collectors/all_foia/core.py new file mode 100644 index 00000000..f4249b2a --- /dev/null +++ b/src/collectors/impl/muckrock/collectors/all_foia/core.py @@ -0,0 +1,50 @@ +from src.collectors.enums import CollectorType +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.impl.muckrock.collectors.all_foia.dto import MuckrockAllFOIARequestsCollectorInputDTO +from src.collectors.impl.muckrock.fetchers.foia.core import FOIAFetcher +from src.collectors.impl.muckrock.exceptions import MuckrockNoMoreDataError +from src.core.preprocessors.muckrock import MuckrockPreprocessor + + +class MuckrockAllFOIARequestsCollector(AsyncCollectorBase): + """ + Retrieves urls associated with all Muckrock FOIA requests + """ + collector_type = CollectorType.MUCKROCK_ALL_SEARCH + preprocessor = MuckrockPreprocessor + + async def run_implementation(self) -> None: + dto: MuckrockAllFOIARequestsCollectorInputDTO = self.dto + start_page = dto.start_page + fetcher = FOIAFetcher( + start_page=start_page, + ) + total_pages = dto.total_pages + all_page_data = await self.get_page_data(fetcher, start_page, total_pages) + all_transformed_data = self.transform_data(all_page_data) + self.data = {"urls": all_transformed_data} + + + async def get_page_data(self, fetcher, 
start_page, total_pages): + all_page_data = [] + for page in range(start_page, start_page + total_pages): + await self.log(f"Fetching page {fetcher.current_page}") + try: + page_data = await fetcher.fetch_next_page() + except MuckrockNoMoreDataError: + await self.log(f"No more data to fetch at page {fetcher.current_page}") + break + if page_data is None: + continue + all_page_data.append(page_data) + return all_page_data + + def transform_data(self, all_page_data): + all_transformed_data = [] + for page_data in all_page_data: + for data in page_data["results"]: + all_transformed_data.append({ + "url": data["absolute_url"], + "metadata": data + }) + return all_transformed_data diff --git a/src/collectors/source_collectors/muckrock/collectors/all_foia/dto.py b/src/collectors/impl/muckrock/collectors/all_foia/dto.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/all_foia/dto.py rename to src/collectors/impl/muckrock/collectors/all_foia/dto.py diff --git a/tests/automated/integration/tasks/url/html/mocks/__init__.py b/src/collectors/impl/muckrock/collectors/county/__init__.py similarity index 100% rename from tests/automated/integration/tasks/url/html/mocks/__init__.py rename to src/collectors/impl/muckrock/collectors/county/__init__.py diff --git a/src/collectors/impl/muckrock/collectors/county/core.py b/src/collectors/impl/muckrock/collectors/county/core.py new file mode 100644 index 00000000..50c79470 --- /dev/null +++ b/src/collectors/impl/muckrock/collectors/county/core.py @@ -0,0 +1,60 @@ +from src.collectors.enums import CollectorType +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.impl.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO +from src.collectors.impl.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest +from src.collectors.impl.muckrock.fetch_requests.jurisdiction_loop import \ + JurisdictionLoopFetchRequest +from src.collectors.impl.muckrock.fetchers.foia.loop import FOIALoopFetcher +from src.collectors.impl.muckrock.fetchers.jurisdiction.generator import \ + JurisdictionGeneratorFetcher +from src.core.preprocessors.muckrock import MuckrockPreprocessor + + +class MuckrockCountyLevelSearchCollector(AsyncCollectorBase): + """ + Searches for any and all requests in a certain county + """ + collector_type = CollectorType.MUCKROCK_COUNTY_SEARCH + preprocessor = MuckrockPreprocessor + + async def run_implementation(self) -> None: + jurisdiction_ids = await self.get_jurisdiction_ids() + if jurisdiction_ids is None: + await self.log("No jurisdictions found") + return + all_data = await self.get_foia_records(jurisdiction_ids) + formatted_data = self.format_data(all_data) + self.data = {"urls": formatted_data} + + def format_data(self, all_data): + formatted_data = [] + for data in all_data: + formatted_data.append({ + "url": data["absolute_url"], + "metadata": data + }) + return formatted_data + + async def get_foia_records(self, jurisdiction_ids): + all_data = [] + for name, id_ in jurisdiction_ids.items(): + await self.log(f"Fetching records for {name}...") + request = FOIALoopFetchRequest(jurisdiction=id_) + fetcher = FOIALoopFetcher(request) + await fetcher.loop_fetch() + all_data.extend(fetcher.ffm.results) + return all_data + + async def get_jurisdiction_ids(self): + dto: MuckrockCountySearchCollectorInputDTO = self.dto + parent_jurisdiction_id = dto.parent_jurisdiction_id + request = JurisdictionLoopFetchRequest( + level="l", + parent=parent_jurisdiction_id, + 
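+            # level="l" is assumed here to request local-level jurisdictions
+            # beneath the parent, per MuckRock's jurisdiction level codes.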
town_names=dto.town_names + ) + fetcher = JurisdictionGeneratorFetcher(initial_request=request) + async for message in fetcher.generator_fetch(): + await self.log(message) + jurisdiction_ids = fetcher.jfm.jurisdictions + return jurisdiction_ids diff --git a/src/collectors/source_collectors/muckrock/collectors/county/dto.py b/src/collectors/impl/muckrock/collectors/county/dto.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/county/dto.py rename to src/collectors/impl/muckrock/collectors/county/dto.py diff --git a/tests/automated/unit/dto/__init__.py b/src/collectors/impl/muckrock/collectors/simple/__init__.py similarity index 100% rename from tests/automated/unit/dto/__init__.py rename to src/collectors/impl/muckrock/collectors/simple/__init__.py diff --git a/src/collectors/impl/muckrock/collectors/simple/core.py b/src/collectors/impl/muckrock/collectors/simple/core.py new file mode 100644 index 00000000..1470b7c1 --- /dev/null +++ b/src/collectors/impl/muckrock/collectors/simple/core.py @@ -0,0 +1,58 @@ +import itertools + +from src.collectors.enums import CollectorType +from src.collectors.impl.base import AsyncCollectorBase +from src.collectors.impl.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO +from src.collectors.impl.muckrock.collectors.simple.searcher import FOIASearcher +from src.collectors.impl.muckrock.fetchers.foia.core import FOIAFetcher +from src.collectors.impl.muckrock.exceptions import SearchCompleteException +from src.core.preprocessors.muckrock import MuckrockPreprocessor + + +class MuckrockSimpleSearchCollector(AsyncCollectorBase): + """ + Performs searches on MuckRock's database + by matching a search string to title of request + """ + collector_type = CollectorType.MUCKROCK_SIMPLE_SEARCH + preprocessor = MuckrockPreprocessor + + def check_for_count_break(self, count, max_count) -> None: + if max_count is None: + return + if count >= max_count: + raise SearchCompleteException + + async def run_implementation(self) -> None: + fetcher = FOIAFetcher() + dto: MuckrockSimpleSearchCollectorInputDTO = self.dto + searcher = FOIASearcher( + fetcher=fetcher, + search_term=dto.search_string + ) + max_count = dto.max_results + all_results = [] + results_count = 0 + for search_count in itertools.count(): + try: + results = await searcher.get_next_page_results() + all_results.extend(results) + results_count += len(results) + self.check_for_count_break(results_count, max_count) + except SearchCompleteException: + break + await self.log(f"Search {search_count}: Found {len(results)} results") + + await self.log(f"Search Complete. 
Total results: {results_count}") + self.data = {"urls": self.format_results(all_results)} + + def format_results(self, results: list[dict]) -> list[dict]: + formatted_results = [] + for result in results: + formatted_result = { + "url": result["absolute_url"], + "metadata": result + } + formatted_results.append(formatted_result) + + return formatted_results diff --git a/src/collectors/source_collectors/muckrock/collectors/simple/dto.py b/src/collectors/impl/muckrock/collectors/simple/dto.py similarity index 100% rename from src/collectors/source_collectors/muckrock/collectors/simple/dto.py rename to src/collectors/impl/muckrock/collectors/simple/dto.py diff --git a/src/collectors/impl/muckrock/collectors/simple/searcher.py b/src/collectors/impl/muckrock/collectors/simple/searcher.py new file mode 100644 index 00000000..2f326a5d --- /dev/null +++ b/src/collectors/impl/muckrock/collectors/simple/searcher.py @@ -0,0 +1,43 @@ +from typing import Optional + +from src.collectors.impl.muckrock.fetchers.foia.core import FOIAFetcher +from src.collectors.impl.muckrock.exceptions import SearchCompleteException + + +class FOIASearcher: + """ + Used for searching FOIA data from MuckRock + """ + + def __init__(self, fetcher: FOIAFetcher, search_term: Optional[str] = None): + self.fetcher = fetcher + self.search_term = search_term + + async def fetch_page(self) -> list[dict] | None: + """ + Fetches the next page of results using the fetcher. + """ + data = await self.fetcher.fetch_next_page() + if data is None or data.get("results") is None: + return None + return data.get("results") + + def filter_results(self, results: list[dict]) -> list[dict]: + """ + Filters the results based on the search term. + Override or modify as needed for custom filtering logic. + """ + if self.search_term: + return [result for result in results if self.search_term.lower() in result["title"].lower()] + return results + + + async def get_next_page_results(self) -> list[dict]: + """ + Fetches and processes the next page of results. 
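A minimal usage sketch of the termination pattern in the collector above: SearchCompleteException doubles as the loop's exit signal, raised either by the searcher when a page comes back empty or by the count check when the cap is reached. This assumes the modules introduced in this diff; the search term and cap below are placeholders.

    import asyncio

    from src.collectors.impl.muckrock.collectors.simple.searcher import FOIASearcher
    from src.collectors.impl.muckrock.exceptions import SearchCompleteException
    from src.collectors.impl.muckrock.fetchers.foia.core import FOIAFetcher


    async def collect_matching_titles(search_term: str, max_results: int) -> list[dict]:
        # FOIASearcher filters each fetched page by title match.
        searcher = FOIASearcher(fetcher=FOIAFetcher(), search_term=search_term)
        results: list[dict] = []
        while True:
            try:
                page = await searcher.get_next_page_results()
            except SearchCompleteException:
                break  # no more pages
            results.extend(page)
            if len(results) >= max_results:
                break  # cap reached
        return results


    if __name__ == "__main__":
        # Placeholder query and cap, for illustration only.
        print(len(asyncio.run(collect_matching_titles("police", 50))))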
+ """ + results = await self.fetch_page() + if not results: + raise SearchCompleteException + return self.filter_results(results) + diff --git a/src/collectors/source_collectors/muckrock/constants.py b/src/collectors/impl/muckrock/constants.py similarity index 100% rename from src/collectors/source_collectors/muckrock/constants.py rename to src/collectors/impl/muckrock/constants.py diff --git a/src/collectors/source_collectors/muckrock/enums.py b/src/collectors/impl/muckrock/enums.py similarity index 100% rename from src/collectors/source_collectors/muckrock/enums.py rename to src/collectors/impl/muckrock/enums.py diff --git a/src/collectors/source_collectors/muckrock/exceptions.py b/src/collectors/impl/muckrock/exceptions.py similarity index 100% rename from src/collectors/source_collectors/muckrock/exceptions.py rename to src/collectors/impl/muckrock/exceptions.py diff --git a/tests/manual/migration_with_prod_data/__init__.py b/src/collectors/impl/muckrock/fetch_requests/__init__.py similarity index 100% rename from tests/manual/migration_with_prod_data/__init__.py rename to src/collectors/impl/muckrock/fetch_requests/__init__.py diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/base.py b/src/collectors/impl/muckrock/fetch_requests/base.py similarity index 100% rename from src/collectors/source_collectors/muckrock/fetch_requests/base.py rename to src/collectors/impl/muckrock/fetch_requests/base.py diff --git a/src/collectors/impl/muckrock/fetch_requests/foia.py b/src/collectors/impl/muckrock/fetch_requests/foia.py new file mode 100644 index 00000000..87a66811 --- /dev/null +++ b/src/collectors/impl/muckrock/fetch_requests/foia.py @@ -0,0 +1,6 @@ +from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest + + +class FOIAFetchRequest(FetchRequest): + page: int + page_size: int diff --git a/src/collectors/impl/muckrock/fetch_requests/foia_loop.py b/src/collectors/impl/muckrock/fetch_requests/foia_loop.py new file mode 100644 index 00000000..0371eeae --- /dev/null +++ b/src/collectors/impl/muckrock/fetch_requests/foia_loop.py @@ -0,0 +1,5 @@ +from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest + + +class FOIALoopFetchRequest(FetchRequest): + jurisdiction: int diff --git a/src/collectors/impl/muckrock/fetch_requests/jurisdiction_by_id.py b/src/collectors/impl/muckrock/fetch_requests/jurisdiction_by_id.py new file mode 100644 index 00000000..22d23f74 --- /dev/null +++ b/src/collectors/impl/muckrock/fetch_requests/jurisdiction_by_id.py @@ -0,0 +1,5 @@ +from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest + + +class JurisdictionByIDFetchRequest(FetchRequest): + jurisdiction_id: int diff --git a/src/collectors/impl/muckrock/fetch_requests/jurisdiction_loop.py b/src/collectors/impl/muckrock/fetch_requests/jurisdiction_loop.py new file mode 100644 index 00000000..369fbeed --- /dev/null +++ b/src/collectors/impl/muckrock/fetch_requests/jurisdiction_loop.py @@ -0,0 +1,7 @@ +from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest + + +class JurisdictionLoopFetchRequest(FetchRequest): + level: str + parent: int + town_names: list diff --git a/api/main.py b/src/collectors/impl/muckrock/fetchers/__init__.py similarity index 100% rename from api/main.py rename to src/collectors/impl/muckrock/fetchers/__init__.py diff --git a/src/collectors/impl/muckrock/fetchers/foia/__init__.py b/src/collectors/impl/muckrock/fetchers/foia/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/src/collectors/impl/muckrock/fetchers/foia/core.py b/src/collectors/impl/muckrock/fetchers/foia/core.py new file mode 100644 index 00000000..c6c51d94 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/foia/core.py @@ -0,0 +1,36 @@ +from src.collectors.impl.muckrock.fetch_requests.foia import FOIAFetchRequest +from src.collectors.impl.muckrock.fetchers.templates.fetcher import MuckrockFetcherBase +from src.collectors.impl.muckrock.constants import BASE_MUCKROCK_URL + +FOIA_BASE_URL = f"{BASE_MUCKROCK_URL}/foia" + + +class FOIAFetcher(MuckrockFetcherBase): + """ + A fetcher for FOIA requests. + Iterates through all FOIA requests available through the MuckRock FOIA API. + """ + + def __init__(self, start_page: int = 1, per_page: int = 100): + """ + Constructor for the FOIAFetcher class. + + Args: + start_page (int): The page number to start fetching from (default is 1). + per_page (int): The number of results to fetch per page (default is 100). + """ + self.current_page = start_page + self.per_page = per_page + + def build_url(self, request: FOIAFetchRequest) -> str: + return f"{FOIA_BASE_URL}?page={request.page}&page_size={request.page_size}&format=json" + + async def fetch_next_page(self) -> dict | None: + """ + Fetches data from a specific page of the MuckRock FOIA API. + """ + page = self.current_page + self.current_page += 1 + request = FOIAFetchRequest(page=page, page_size=self.per_page) + return await self.fetch(request) + diff --git a/src/collectors/impl/muckrock/fetchers/foia/generator.py b/src/collectors/impl/muckrock/fetchers/foia/generator.py new file mode 100644 index 00000000..9260f43b --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/foia/generator.py @@ -0,0 +1,16 @@ +from src.collectors.impl.muckrock.fetch_requests import FOIALoopFetchRequest +from src.collectors.impl.muckrock.fetchers.foia.manager import FOIAFetchManager +from src.collectors.impl.muckrock.fetchers.templates.generator import MuckrockGeneratorFetcher + + +class FOIAGeneratorFetcher(MuckrockGeneratorFetcher): + + def __init__(self, initial_request: FOIALoopFetchRequest): + super().__init__(initial_request) + self.ffm = FOIAFetchManager() + + def process_results(self, results: list[dict]): + self.ffm.process_results(results) + return (f"Loop {self.ffm.loop_count}: " + f"Found {self.ffm.num_found_last_loop} FOIA records;" + f"{self.ffm.num_found} FOIA records found total.") diff --git a/src/collectors/impl/muckrock/fetchers/foia/loop.py b/src/collectors/impl/muckrock/fetchers/foia/loop.py new file mode 100644 index 00000000..44b4b845 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/foia/loop.py @@ -0,0 +1,25 @@ +from datasets import tqdm + +from src.collectors.impl.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest +from src.collectors.impl.muckrock.fetchers.foia.manager import FOIAFetchManager +from src.collectors.impl.muckrock.fetchers.templates.loop import MuckrockLoopFetcher + + +class FOIALoopFetcher(MuckrockLoopFetcher): + + def __init__(self, initial_request: FOIALoopFetchRequest): + super().__init__(initial_request) + self.pbar_records = tqdm( + desc="Fetching FOIA records", + unit="record", + ) + self.ffm = FOIAFetchManager() + + def process_results(self, results: list[dict]): + self.ffm.process_results(results) + + def build_url(self, request: FOIALoopFetchRequest): + return self.ffm.build_url(request) + + def report_progress(self): + self.pbar_records.update(self.ffm.num_found_last_loop) diff --git a/src/collectors/impl/muckrock/fetchers/foia/manager.py 
b/src/collectors/impl/muckrock/fetchers/foia/manager.py new file mode 100644 index 00000000..09f71a59 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/foia/manager.py @@ -0,0 +1,20 @@ +from src.collectors.impl.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest +from src.collectors.impl.muckrock.constants import BASE_MUCKROCK_URL + + +class FOIAFetchManager: + + def __init__(self): + self.num_found = 0 + self.loop_count = 0 + self.num_found_last_loop = 0 + self.results = [] + + def build_url(self, request: FOIALoopFetchRequest): + return f"{BASE_MUCKROCK_URL}/foia/?status=done&jurisdiction={request.jurisdiction}" + + def process_results(self, results: list[dict]): + self.loop_count += 1 + self.num_found_last_loop = len(results) + self.results.extend(results) + self.num_found += len(results) \ No newline at end of file diff --git a/src/collectors/impl/muckrock/fetchers/jurisdiction/__init__.py b/src/collectors/impl/muckrock/fetchers/jurisdiction/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/collectors/impl/muckrock/fetchers/jurisdiction/core.py b/src/collectors/impl/muckrock/fetchers/jurisdiction/core.py new file mode 100644 index 00000000..8f21bca3 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/jurisdiction/core.py @@ -0,0 +1,13 @@ +from src.collectors.impl.muckrock.fetch_requests.jurisdiction_by_id import \ + JurisdictionByIDFetchRequest +from src.collectors.impl.muckrock.fetchers.templates.fetcher import MuckrockFetcherBase +from src.collectors.impl.muckrock.constants import BASE_MUCKROCK_URL + + +class JurisdictionByIDFetcher(MuckrockFetcherBase): + + def build_url(self, request: JurisdictionByIDFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/jurisdiction/{request.jurisdiction_id}/" + + async def get_jurisdiction(self, jurisdiction_id: int) -> dict: + return await self.fetch(request=JurisdictionByIDFetchRequest(jurisdiction_id=jurisdiction_id)) diff --git a/src/collectors/impl/muckrock/fetchers/jurisdiction/generator.py b/src/collectors/impl/muckrock/fetchers/jurisdiction/generator.py new file mode 100644 index 00000000..394a6801 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/jurisdiction/generator.py @@ -0,0 +1,17 @@ +from src.collectors.impl.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest +from src.collectors.impl.muckrock.fetchers.jurisdiction.manager import JurisdictionFetchManager +from src.collectors.impl.muckrock.fetchers.templates.generator import MuckrockGeneratorFetcher + + +class JurisdictionGeneratorFetcher(MuckrockGeneratorFetcher): + + def __init__(self, initial_request: JurisdictionLoopFetchRequest): + super().__init__(initial_request) + self.jfm = JurisdictionFetchManager(town_names=initial_request.town_names) + + def build_url(self, request: JurisdictionLoopFetchRequest) -> str: + return self.jfm.build_url(request) + + def process_results(self, results: list[dict]): + return self.jfm.process_results(results) + diff --git a/src/collectors/impl/muckrock/fetchers/jurisdiction/loop.py b/src/collectors/impl/muckrock/fetchers/jurisdiction/loop.py new file mode 100644 index 00000000..16ecdaa3 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/jurisdiction/loop.py @@ -0,0 +1,38 @@ +from tqdm import tqdm + +from src.collectors.impl.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest +from src.collectors.impl.muckrock.fetchers.jurisdiction.manager import JurisdictionFetchManager +from src.collectors.impl.muckrock.fetchers.templates.loop import 
MuckrockLoopFetcher + + +class JurisdictionLoopFetcher(MuckrockLoopFetcher): + + def __init__(self, initial_request: JurisdictionLoopFetchRequest): + super().__init__(initial_request) + self.jfm = JurisdictionFetchManager(town_names=initial_request.town_names) + self.pbar_jurisdictions = tqdm( + total=len(self.jfm.town_names), + desc="Fetching jurisdictions", + unit="jurisdiction", + position=0, + leave=False + ) + self.pbar_page = tqdm( + desc="Processing pages", + unit="page", + position=1, + leave=False + ) + + def build_url(self, request: JurisdictionLoopFetchRequest) -> str: + return self.jfm.build_url(request) + + def process_results(self, results: list[dict]): + self.jfm.process_results(results) + + def report_progress(self): + old_num_jurisdictions_found = self.jfm.num_jurisdictions_found + self.jfm.num_jurisdictions_found = len(self.jfm.jurisdictions) + difference = self.jfm.num_jurisdictions_found - old_num_jurisdictions_found + self.pbar_jurisdictions.update(difference) + self.pbar_page.update(1) diff --git a/src/collectors/impl/muckrock/fetchers/jurisdiction/manager.py b/src/collectors/impl/muckrock/fetchers/jurisdiction/manager.py new file mode 100644 index 00000000..9cd24df2 --- /dev/null +++ b/src/collectors/impl/muckrock/fetchers/jurisdiction/manager.py @@ -0,0 +1,22 @@ +from src.collectors.impl.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest +from src.collectors.impl.muckrock.constants import BASE_MUCKROCK_URL + + +class JurisdictionFetchManager: + + def __init__(self, town_names: list[str]): + self.town_names = town_names + self.num_jurisdictions_found = 0 + self.total_found = 0 + self.jurisdictions = {} + + def build_url(self, request: JurisdictionLoopFetchRequest) -> str: + return f"{BASE_MUCKROCK_URL}/jurisdiction/?level={request.level}&parent={request.parent}" + + def process_results(self, results: list[dict]): + for item in results: + if item["name"] in self.town_names: + self.jurisdictions[item["name"]] = item["id"] + self.total_found += 1 + self.num_jurisdictions_found = len(self.jurisdictions) + return f"Found {self.num_jurisdictions_found} jurisdictions; {self.total_found} entries found total." 
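The jurisdiction and FOIA fetchers above compose into a two-stage pipeline: resolve town names to MuckRock jurisdiction IDs, then pull completed FOIA requests per jurisdiction. A minimal sketch under the package layout introduced in this diff; the parent jurisdiction ID and town name are placeholders.

    import asyncio

    from src.collectors.impl.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest
    from src.collectors.impl.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest
    from src.collectors.impl.muckrock.fetchers.foia.loop import FOIALoopFetcher
    from src.collectors.impl.muckrock.fetchers.jurisdiction.generator import JurisdictionGeneratorFetcher


    async def main() -> None:
        # Stage 1: resolve town names to jurisdiction IDs ("l" = local level).
        # parent=16 and the town name are illustrative values only.
        request = JurisdictionLoopFetchRequest(
            level="l", parent=16, town_names=["Springfield"]
        )
        jurisdiction_fetcher = JurisdictionGeneratorFetcher(initial_request=request)
        async for status in jurisdiction_fetcher.generator_fetch():
            print(status)

        # Stage 2: pull completed FOIA requests for each resolved jurisdiction.
        for name, jurisdiction_id in jurisdiction_fetcher.jfm.jurisdictions.items():
            foia_fetcher = FOIALoopFetcher(FOIALoopFetchRequest(jurisdiction=jurisdiction_id))
            await foia_fetcher.loop_fetch()
            print(f"{name}: {len(foia_fetcher.ffm.results)} FOIA records")


    if __name__ == "__main__":
        asyncio.run(main())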
diff --git a/src/collectors/impl/muckrock/fetchers/templates/__init__.py b/src/collectors/impl/muckrock/fetchers/templates/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/collectors/source_collectors/muckrock/fetchers/templates/fetcher.py b/src/collectors/impl/muckrock/fetchers/templates/fetcher.py
similarity index 83%
rename from src/collectors/source_collectors/muckrock/fetchers/templates/fetcher.py
rename to src/collectors/impl/muckrock/fetchers/templates/fetcher.py
index 6661c04a..1c41f6fd 100644
--- a/src/collectors/source_collectors/muckrock/fetchers/templates/fetcher.py
+++ b/src/collectors/impl/muckrock/fetchers/templates/fetcher.py
@@ -4,8 +4,8 @@
 import requests
 import aiohttp
 
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-from src.collectors.source_collectors.muckrock.exceptions import MuckrockNoMoreDataError, MuckrockServerError
+from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest
+from src.collectors.impl.muckrock.exceptions import MuckrockNoMoreDataError, MuckrockServerError
 
 
 class MuckrockFetcherBase(ABC):
diff --git a/src/collectors/impl/muckrock/fetchers/templates/generator.py b/src/collectors/impl/muckrock/fetchers/templates/generator.py
new file mode 100644
index 00000000..55fa62ec
--- /dev/null
+++ b/src/collectors/impl/muckrock/fetchers/templates/generator.py
@@ -0,0 +1,30 @@
+from src.collectors.impl.muckrock.fetchers.templates.iter_fetcher import MuckrockIterFetcherBase
+from src.collectors.impl.muckrock.exceptions import RequestFailureException
+
+
+class MuckrockGeneratorFetcher(MuckrockIterFetcherBase):
+    """
+    Similar to the MuckRock loop fetcher, but behaves
+    as a generator instead of a loop
+    """
+
+    async def generator_fetch(self):
+        """
+        Fetches data and yields status messages between requests
+        """
+        url = self.build_url(self.initial_request)
+        final_message = "No more records found. Exiting..."
+        while url is not None:
+            try:
+                data = await self.get_response(url)
+            except RequestFailureException:
+                final_message = "Request unexpectedly failed. Exiting..."
+                break
+
+            yield self.process_results(data["results"])
+            url = data["next"]
+
+        yield final_message
+
+
+
diff --git a/src/collectors/source_collectors/muckrock/fetchers/templates/iter_fetcher.py b/src/collectors/impl/muckrock/fetchers/templates/iter_fetcher.py
similarity index 83%
rename from src/collectors/source_collectors/muckrock/fetchers/templates/iter_fetcher.py
rename to src/collectors/impl/muckrock/fetchers/templates/iter_fetcher.py
index cc397242..66ee4cd3 100644
--- a/src/collectors/source_collectors/muckrock/fetchers/templates/iter_fetcher.py
+++ b/src/collectors/impl/muckrock/fetchers/templates/iter_fetcher.py
@@ -3,8 +3,8 @@
 import aiohttp
 import requests
 
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-from src.collectors.source_collectors.muckrock.exceptions import RequestFailureException
+from src.collectors.impl.muckrock.fetch_requests.base import FetchRequest
+from src.collectors.impl.muckrock.exceptions import RequestFailureException
 
 
 class MuckrockIterFetcherBase(ABC):
diff --git a/src/collectors/impl/muckrock/fetchers/templates/loop.py b/src/collectors/impl/muckrock/fetchers/templates/loop.py
new file mode 100644
index 00000000..427564c2
--- /dev/null
+++ b/src/collectors/impl/muckrock/fetchers/templates/loop.py
@@ -0,0 +1,32 @@
+from abc import abstractmethod
+import asyncio
+
+from src.collectors.impl.muckrock.fetchers.templates.iter_fetcher import MuckrockIterFetcherBase
+from src.collectors.impl.muckrock.exceptions import RequestFailureException
+
+
+class MuckrockLoopFetcher(MuckrockIterFetcherBase):
+
+    async def loop_fetch(self):
+        url = self.build_url(self.initial_request)
+        while url is not None:
+            try:
+                data = await self.get_response(url)
+            except RequestFailureException:
+                break
+
+            url = self.process_data(data)
+            await asyncio.sleep(1)
+
+    def process_data(self, data: dict):
+        """
+        Process data and get next url, if any
+        """
+        self.process_results(data["results"])
+        self.report_progress()
+        url = data["next"]
+        return url
+
+    @abstractmethod
+    def report_progress(self):
+        pass
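The two iteration templates above differ only in how callers observe progress: the generator yields a status string between requests, while the loop runs to completion and reports via tqdm. A minimal sketch of consuming each, using the FOIA subclasses defined earlier in this diff and a placeholder jurisdiction ID.

    import asyncio

    from src.collectors.impl.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest
    from src.collectors.impl.muckrock.fetchers.foia.generator import FOIAGeneratorFetcher
    from src.collectors.impl.muckrock.fetchers.foia.loop import FOIALoopFetcher


    async def main() -> None:
        request = FOIALoopFetchRequest(jurisdiction=350)  # placeholder ID

        # Generator template: caller decides what to do with each status message.
        gen_fetcher = FOIAGeneratorFetcher(initial_request=request)
        async for status in gen_fetcher.generator_fetch():
            print(status)

        # Loop template: runs unattended, progress reported through tqdm bars.
        loop_fetcher = FOIALoopFetcher(FOIALoopFetchRequest(jurisdiction=350))
        await loop_fetcher.loop_fetch()
        print(f"{loop_fetcher.ffm.num_found} records fetched")


    if __name__ == "__main__":
        asyncio.run(main())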
diff --git a/src/collectors/manager.py b/src/collectors/manager.py
index b90e03a6..a493b92c 100644
--- a/src/collectors/manager.py
+++ b/src/collectors/manager.py
@@ -6,7 +6,7 @@
 from pydantic import BaseModel
 
 from src.db.client.async_ import AsyncDatabaseClient
-from src.collectors.source_collectors.base import AsyncCollectorBase
+from src.collectors.impl.base import AsyncCollectorBase
 from src.collectors.exceptions import InvalidCollectorError
 from src.collectors.mapping import COLLECTOR_MAPPING
 from src.collectors.enums import CollectorType
diff --git a/src/collectors/mapping.py b/src/collectors/mapping.py
index e07cac09..32aeda5a 100644
--- a/src/collectors/mapping.py
+++ b/src/collectors/mapping.py
@@ -1,11 +1,11 @@
 from src.collectors.enums import CollectorType
-from src.collectors.source_collectors.auto_googler.collector import AutoGooglerCollector
-from src.collectors.source_collectors.ckan.collector import CKANCollector
-from src.collectors.source_collectors.common_crawler.collector import CommonCrawlerCollector
-from src.collectors.source_collectors.example.core import ExampleCollector
-from src.collectors.source_collectors.muckrock.collectors.all_foia.core import MuckrockAllFOIARequestsCollector
-from src.collectors.source_collectors.muckrock.collectors.county.core import MuckrockCountyLevelSearchCollector
-from src.collectors.source_collectors.muckrock.collectors.simple.core import MuckrockSimpleSearchCollector
+from src.collectors.impl.auto_googler.collector import AutoGooglerCollector
+from src.collectors.impl.ckan.collector import CKANCollector
+from src.collectors.impl.common_crawler.collector import CommonCrawlerCollector
+from src.collectors.impl.example.core import ExampleCollector
+from src.collectors.impl.muckrock.collectors.all_foia.core import MuckrockAllFOIARequestsCollector
+from src.collectors.impl.muckrock.collectors.county.core import MuckrockCountyLevelSearchCollector
+from src.collectors.impl.muckrock.collectors.simple.core import MuckrockSimpleSearchCollector
 
 COLLECTOR_MAPPING = {
     CollectorType.EXAMPLE: ExampleCollector,
diff --git a/src/collectors/queries/__init__.py b/src/collectors/queries/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/collectors/queries/get_url_info.py b/src/collectors/queries/get_url_info.py
new file mode 100644
index 00000000..9dc9fc24
--- /dev/null
+++ b/src/collectors/queries/get_url_info.py
@@ -0,0 +1,21 @@
+from sqlalchemy import Select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class GetURLInfoByURLQueryBuilder(QueryBuilderBase):
+
+    def __init__(self, url: str):
+        super().__init__()
+        self.url = url
+
+    async def run(self, session: AsyncSession) -> URLInfo | None:
+        query = Select(URL).where(URL.url == self.url)
+        raw_result = await session.execute(query)
+        url = raw_result.scalars().first()
+        if url is None:
+            return None
+        return URLInfo(**url.__dict__)
\ No newline at end of file
diff --git a/src/collectors/queries/insert/__init__.py b/src/collectors/queries/insert/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/collectors/queries/insert/url.py b/src/collectors/queries/insert/url.py
new file mode 100644
index 00000000..af72a3aa
--- /dev/null
+++ b/src/collectors/queries/insert/url.py
@@ -0,0 +1,33 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class InsertURLQueryBuilder(QueryBuilderBase):
+
+
+    def __init__(self, url_info: URLInfo):
+        super().__init__()
+        self.url_info = url_info
+
+    async def run(self, session: AsyncSession) -> int:
+        """Insert a new URL into the database."""
+        url_entry = URL(
+            url=self.url_info.url,
+            collector_metadata=self.url_info.collector_metadata,
+            status=self.url_info.status.value,
+            source=self.url_info.source
+        )
+        if self.url_info.created_at is not None:
+            url_entry.created_at = self.url_info.created_at
+        session.add(url_entry)
+        await session.flush()
+        link = LinkBatchURL(
+            batch_id=self.url_info.batch_id,
+            url_id=url_entry.id
+        )
+        session.add(link)
+        return url_entry.id
\ No newline at end of file
diff --git a/src/collectors/queries/insert/urls/__init__.py b/src/collectors/queries/insert/urls/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/collectors/queries/insert/urls/query.py b/src/collectors/queries/insert/urls/query.py
new file mode 100644
index 00000000..75176158
--- /dev/null
+++ b/src/collectors/queries/insert/urls/query.py
@@ -0,0 +1,56 @@
+from sqlalchemy.exc import IntegrityError
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.collectors.queries.insert.urls.request_manager import InsertURLsRequestManager
+from src.util.clean import clean_url
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.db.dtos.url.mapping import URLMapping
+from src.db.models.impl.duplicate.pydantic.insert import DuplicateInsertInfo
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class InsertURLsQueryBuilder(QueryBuilderBase):
+
+    def __init__(
+        self,
+        url_infos: list[URLInfo],
+        batch_id: int
+    ):
+        super().__init__()
+        self.url_infos = url_infos
+        self.batch_id = batch_id
+
+    async def run(self, session: AsyncSession) -> InsertURLsInfo:
+        url_mappings = []
+        duplicates = []
+        rm = InsertURLsRequestManager(session=session)
+        for url_info in self.url_infos:
+            url_info.url = clean_url(url_info.url)
+            url_info.batch_id = self.batch_id
+            try:
+                async with session.begin_nested():
+                    url_id = await rm.insert_url(url_info)
+                    url_mappings.append(
+                        URLMapping(
+                            url_id=url_id,
+                            url=url_info.url
+                        )
+                    )
+            except IntegrityError:
+                # The savepoint is rolled back automatically by begin_nested()
+                orig_url_info = await rm.get_url_info_by_url(url_info.url)
+                duplicate_info = DuplicateInsertInfo(
+                    batch_id=self.batch_id,
+                    original_url_id=orig_url_info.id
+                )
+                duplicates.append(duplicate_info)
+        await rm.insert_duplicates(duplicates)
+
+        return InsertURLsInfo(
+            url_mappings=url_mappings,
+            total_count=len(self.url_infos),
+            original_count=len(url_mappings),
+            duplicate_count=len(duplicates),
+            url_ids=[url_mapping.url_id for url_mapping in url_mappings]
+        )
diff --git a/src/collectors/queries/insert/urls/request_manager.py b/src/collectors/queries/insert/urls/request_manager.py
new file mode 100644
index 00000000..22f6ff66
--- /dev/null
+++ b/src/collectors/queries/insert/urls/request_manager.py
@@ -0,0 +1,33 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.collectors.queries.get_url_info import GetURLInfoByURLQueryBuilder
+from src.collectors.queries.insert.url import InsertURLQueryBuilder
+from src.db.models.impl.duplicate.pydantic.insert import DuplicateInsertInfo
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+
+from src.db.helpers.session import session_helper as sh
+
+
+class InsertURLsRequestManager:
+
+    def __init__(
+        self,
+        session: AsyncSession
+    ):
+        self.session = session
+
+    async def insert_url(self, url_info: URLInfo) -> int:
+        return await InsertURLQueryBuilder(
+            url_info=url_info
+        ).run(self.session)
+
+    async def get_url_info_by_url(self, url: str) -> URLInfo | None:
+        return await GetURLInfoByURLQueryBuilder(
+            url=url
+        ).run(self.session)
+
+    async def insert_duplicates(
+        self,
+        duplicates: list[DuplicateInsertInfo]
+    ) -> None:
+        await sh.bulk_insert(self.session, models=duplicates)
\ No newline at end of file
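The insert path above relies on a savepoint per row, so one duplicate URL does not abort the whole batch. A self-contained sketch of the same SQLAlchemy begin_nested()/IntegrityError pattern, using a stand-in table and an in-memory SQLite database (requires aiosqlite) rather than the project's models.

    import asyncio

    from sqlalchemy import String, UniqueConstraint
    from sqlalchemy.exc import IntegrityError
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Url(Base):
        __tablename__ = "urls"
        __table_args__ = (UniqueConstraint("url"),)
        id: Mapped[int] = mapped_column(primary_key=True)
        url: Mapped[str] = mapped_column(String)


    async def insert_urls(session: AsyncSession, urls: list[str]) -> tuple[int, int]:
        inserted = duplicates = 0
        for url in urls:
            try:
                async with session.begin_nested():  # savepoint per row
                    session.add(Url(url=url))
                    await session.flush()
                inserted += 1
            except IntegrityError:  # savepoint already rolled back
                duplicates += 1
        return inserted, duplicates


    async def main() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        async with AsyncSession(engine) as session:
            async with session.begin():
                # Expected output: (2, 1) -- the repeated URL is a duplicate.
                print(await insert_urls(session, ["a.com", "b.com", "a.com"]))


    asyncio.run(main())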
diff --git a/src/collectors/source_collectors/auto_googler/auto_googler.py b/src/collectors/source_collectors/auto_googler/auto_googler.py
deleted file mode 100644
index 49cdc2de..00000000
--- a/src/collectors/source_collectors/auto_googler/auto_googler.py
+++ /dev/null
@@ -1,31 +0,0 @@
-from src.collectors.source_collectors.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO
-from src.collectors.source_collectors.auto_googler.searcher import GoogleSearcher
-from src.collectors.source_collectors.auto_googler.dtos.config import SearchConfig
-
-
-class AutoGoogler:
-    """
-    The AutoGoogler orchestrates the process of fetching urls from Google Search
-    and processing them for source collection
-
-    """
-    def __init__(self, search_config: SearchConfig, google_searcher: GoogleSearcher):
-        self.search_config = search_config
-        self.google_searcher = google_searcher
-        self.data: dict[str, list[GoogleSearchQueryResultsInnerDTO]] = {
-            query : [] for query in search_config.queries
-        }
-
-    async def run(self) -> str:
-        """
-        Runs the AutoGoogler
-        Yields status messages
-        """
-        for query in self.search_config.queries:
-            yield f"Searching for '{query}' ..."
-            results = await self.google_searcher.search(query)
-            yield f"Found {len(results)} results for '{query}'."
-            if results is not None:
-                self.data[query] = results
-        yield "Done."
-
diff --git a/src/collectors/source_collectors/auto_googler/collector.py b/src/collectors/source_collectors/auto_googler/collector.py
deleted file mode 100644
index 718bdfb7..00000000
--- a/src/collectors/source_collectors/auto_googler/collector.py
+++ /dev/null
@@ -1,48 +0,0 @@
-
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.enums import CollectorType
-from src.core.env_var_manager import EnvVarManager
-from src.core.preprocessors.autogoogler import AutoGooglerPreprocessor
-from src.collectors.source_collectors.auto_googler.auto_googler import AutoGoogler
-from src.collectors.source_collectors.auto_googler.dtos.output import AutoGooglerInnerOutputDTO
-from src.collectors.source_collectors.auto_googler.dtos.input import AutoGooglerInputDTO
-from src.collectors.source_collectors.auto_googler.searcher import GoogleSearcher
-from src.collectors.source_collectors.auto_googler.dtos.config import SearchConfig
-from src.util.helper_functions import base_model_list_dump
-
-
-class AutoGooglerCollector(AsyncCollectorBase):
-    collector_type = CollectorType.AUTO_GOOGLER
-    preprocessor = AutoGooglerPreprocessor
-
-    async def run_to_completion(self) -> AutoGoogler:
-        dto: AutoGooglerInputDTO = self.dto
-        env_var_manager = EnvVarManager.get()
-        auto_googler = AutoGoogler(
-            search_config=SearchConfig(
-                urls_per_result=dto.urls_per_result,
-                queries=dto.queries,
-            ),
-            google_searcher=GoogleSearcher(
-                api_key=env_var_manager.google_api_key,
-                cse_id=env_var_manager.google_cse_id,
-            )
-        )
-        async for log in auto_googler.run():
-            await self.log(log)
-        return auto_googler
-
-    async def run_implementation(self) -> None:
-
-        auto_googler = await self.run_to_completion()
-
-        inner_data = []
-        for query in auto_googler.search_config.queries:
-            query_results: list[AutoGooglerInnerOutputDTO] = auto_googler.data[query]
-            inner_data.append({
-                "query": query,
-                "query_results": base_model_list_dump(query_results),
-            })
-
-        self.data = {"data": inner_data}
-
diff --git a/src/collectors/source_collectors/auto_googler/dtos/input.py b/src/collectors/source_collectors/auto_googler/dtos/input.py
deleted file mode 100644
index 801d6104..00000000
--- a/src/collectors/source_collectors/auto_googler/dtos/input.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from pydantic import BaseModel, Field
-
-
-class AutoGooglerInputDTO(BaseModel):
-    urls_per_result: int = Field(
-        description="Maximum number of URLs returned per result. Minimum is 1. Default is 10",
-        default=10,
-        ge=1,
-        le=50
-    )
-    queries: list[str] = Field(
-        description="List of queries to search for.",
-        min_length=1,
-        max_length=100
-    )
diff --git a/src/collectors/source_collectors/auto_googler/searcher.py b/src/collectors/source_collectors/auto_googler/searcher.py
deleted file mode 100644
index aa8a0bb6..00000000
--- a/src/collectors/source_collectors/auto_googler/searcher.py
+++ /dev/null
@@ -1,85 +0,0 @@
-from typing import Union
-
-import aiohttp
-from googleapiclient.errors import HttpError
-
-from src.collectors.source_collectors.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO
-from src.collectors.source_collectors.auto_googler.exceptions import QuotaExceededError
-
-
-class GoogleSearcher:
-    """
-    A class that provides a GoogleSearcher object for performing searches using the Google Custom Search API.
-
-    Attributes:
-        api_key (str): The API key required for accessing the Google Custom Search API.
-        cse_id (str): The CSE (Custom Search Engine) ID required for identifying the specific search engine to use.
-        service (Google API service): The Google API service object for performing the search.
-
-    Methods:
-        __init__(api_key: str, cse_id: str)
-            Initializes a GoogleSearcher object with the provided API key and CSE ID. Raises a RuntimeError if either
-            the API key or CSE ID is None.
-
-        search(query: str) -> Union[list[dict], None]
-            Performs a search using the Google Custom Search API with the provided query string. Returns a list of
-            search results as dictionaries or None if the daily quota for the API has been exceeded. Raises a RuntimeError
-            if any other error occurs during the search.
-    """
-    GOOGLE_SEARCH_URL = "https://www.googleapis.com/customsearch/v1"
-
-    def __init__(
-            self,
-            api_key: str,
-            cse_id: str
-    ):
-        if api_key is None or cse_id is None:
-            raise RuntimeError("Custom search API key and CSE ID cannot be None.")
-        self.api_key = api_key
-        self.cse_id = cse_id
-
-    async def search(self, query: str) -> Union[list[dict], None]:
-        """
-        Searches for results using the specified query.
-
-        Args:
-            query (str): The query to search for.
-
-        Returns: Union[list[dict], None]: A list of dictionaries representing the search results.
-            If the daily quota is exceeded, None is returned.
-        """
-        try:
-            return await self.get_query_results(query)
-        # Process your results
-        except HttpError as e:
-            if "Quota exceeded" in str(e):
-                raise QuotaExceededError("Quota exceeded for the day")
-            else:
-                raise RuntimeError(f"An error occurred: {str(e)}")
-
-    async def get_query_results(self, query) -> list[GoogleSearchQueryResultsInnerDTO] or None:
-        params = {
-            "key": self.api_key,
-            "cx": self.cse_id,
-            "q": query,
-        }
-
-        async with aiohttp.ClientSession() as session:
-            async with session.get(self.GOOGLE_SEARCH_URL, params=params) as response:
-                response.raise_for_status()
-                results = await response.json()
-
-        if "items" not in results:
-            return None
-
-        items = []
-
-        for item in results["items"]:
-            inner_dto = GoogleSearchQueryResultsInnerDTO(
-                url=item["link"],
-                title=item["title"],
-                snippet=item.get("snippet", ""),
-            )
-            items.append(inner_dto)
-
-        return items
diff --git a/src/collectors/source_collectors/base.py b/src/collectors/source_collectors/base.py
deleted file mode 100644
index 5fbb08c5..00000000
--- a/src/collectors/source_collectors/base.py
+++ /dev/null
@@ -1,133 +0,0 @@
-import abc
-import asyncio
-import time
-from abc import ABC
-from typing import Type, Optional
-
-from pydantic import BaseModel
-
-from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.url.insert import InsertURLsInfo
-from src.db.dtos.log import LogInfo
-from src.collectors.enums import CollectorType
-from src.core.logger import AsyncCoreLogger
-from src.core.function_trigger import FunctionTrigger
-from src.core.enums import BatchStatus
-from src.core.preprocessors.base import PreprocessorBase
-
-
-class AsyncCollectorBase(ABC):
-    collector_type: CollectorType = None
-    preprocessor: Type[PreprocessorBase] = None
-
-
-    def __init__(
-        self,
-        batch_id: int,
-        dto: BaseModel,
-        logger: AsyncCoreLogger,
-        adb_client: AsyncDatabaseClient,
-        raise_error: bool = False,
-        post_collection_function_trigger: Optional[FunctionTrigger] = None,
-    ) -> None:
-        self.post_collection_function_trigger = post_collection_function_trigger
-        self.batch_id = batch_id
-        self.adb_client = adb_client
-        self.dto = dto
-        self.data: Optional[BaseModel] = None
-        self.logger = logger
-        self.status = BatchStatus.IN_PROCESS
-        self.start_time = None
-        self.compute_time = None
-        self.raise_error = raise_error
-
-    @abc.abstractmethod
-    async def run_implementation(self) -> None:
-        """
-        This is the method that will be overridden by each collector
-        No other methods should be modified except for this one.
-        However, in each inherited class, new methods in addition to this one can be created
-        Returns:
-
-        """
-        raise NotImplementedError
-
-    async def start_timer(self) -> None:
-        self.start_time = time.time()
-
-    async def stop_timer(self) -> None:
-        self.compute_time = time.time() - self.start_time
-
-    async def handle_error(self, e: Exception) -> None:
-        if self.raise_error:
-            raise e
-        await self.log(f"Error: {e}")
-        await self.adb_client.update_batch_post_collection(
-            batch_id=self.batch_id,
-            batch_status=self.status,
-            compute_time=self.compute_time,
-            total_url_count=0,
-            original_url_count=0,
-            duplicate_url_count=0
-        )
-
-    async def process(self) -> None:
-        await self.log("Processing collector...")
-        preprocessor = self.preprocessor()
-        url_infos = preprocessor.preprocess(self.data)
-        await self.log(f"URLs processed: {len(url_infos)}")
-
-        await self.log("Inserting URLs...")
-        insert_urls_info: InsertURLsInfo = await self.adb_client.insert_urls(
-            url_infos=url_infos,
-            batch_id=self.batch_id
-        )
-        await self.log("Updating batch...")
-        await self.adb_client.update_batch_post_collection(
-            batch_id=self.batch_id,
-            total_url_count=insert_urls_info.total_count,
-            duplicate_url_count=insert_urls_info.duplicate_count,
-            original_url_count=insert_urls_info.original_count,
-            batch_status=self.status,
-            compute_time=self.compute_time
-        )
-        await self.log("Done processing collector.")
-
-        if self.post_collection_function_trigger is not None:
-            await self.post_collection_function_trigger.trigger_or_rerun()
-
-    async def run(self) -> None:
-        try:
-            await self.start_timer()
-            await self.run_implementation()
-            await self.stop_timer()
-            await self.log("Collector completed successfully.")
-            await self.close()
-            await self.process()
-        except asyncio.CancelledError:
-            await self.stop_timer()
-            self.status = BatchStatus.ABORTED
-            await self.adb_client.update_batch_post_collection(
-                batch_id=self.batch_id,
-                batch_status=BatchStatus.ABORTED,
-                compute_time=self.compute_time,
-                total_url_count=0,
-                original_url_count=0,
-                duplicate_url_count=0
-            )
-        except Exception as e:
-            await self.stop_timer()
-            self.status = BatchStatus.ERROR
-            await self.handle_error(e)
-
-    async def log(
-        self,
-        message: str,
-    ) -> None:
-        await self.logger.log(LogInfo(
-            batch_id=self.batch_id,
-            log=message
-        ))
-
-    async def close(self) -> None:
-        self.status = BatchStatus.READY_TO_LABEL
diff --git a/src/collectors/source_collectors/ckan/collector.py b/src/collectors/source_collectors/ckan/collector.py
deleted file mode 100644
index 3239e83b..00000000
--- a/src/collectors/source_collectors/ckan/collector.py
+++ /dev/null
@@ -1,71 +0,0 @@
-from pydantic import BaseModel
-
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.enums import CollectorType
-from src.core.preprocessors.ckan import CKANPreprocessor
-from src.collectors.source_collectors.ckan.dtos.input import CKANInputDTO
-from src.collectors.source_collectors.ckan.scraper_toolkit.search_funcs.group import ckan_group_package_search
-from src.collectors.source_collectors.ckan.scraper_toolkit.search_funcs.organization import ckan_package_search_from_organization
-from src.collectors.source_collectors.ckan.scraper_toolkit.search_funcs.package import ckan_package_search
-from src.collectors.source_collectors.ckan.scraper_toolkit.search import perform_search, get_flat_list, deduplicate_entries, \
-    get_collections, filter_result, parse_result
-from src.util.helper_functions import base_model_list_dump
-
-SEARCH_FUNCTION_MAPPINGS = {
-    "package_search": ckan_package_search,
-    "group_search": ckan_group_package_search,
-    "organization_search": ckan_package_search_from_organization
-}
-
-class CKANCollector(AsyncCollectorBase):
-    collector_type = CollectorType.CKAN
-    preprocessor = CKANPreprocessor
-
-    async def run_implementation(self):
-        results = await self.get_results()
-        flat_list = get_flat_list(results)
-        deduped_flat_list = deduplicate_entries(flat_list)
-
-        list_with_collection_child_packages = await self.add_collection_child_packages(deduped_flat_list)
-
-        filtered_results = list(
-            filter(
-                filter_result,
-                list_with_collection_child_packages
-            )
-        )
-        parsed_results = list(map(parse_result, filtered_results))
-
-        self.data = {"results": parsed_results}
-
-    async def add_collection_child_packages(self, deduped_flat_list):
-        # TODO: Find a way to clearly indicate which parts call from the CKAN API
-        list_with_collection_child_packages = []
-        count = len(deduped_flat_list)
-        for idx, result in enumerate(deduped_flat_list):
-            if "extras" in result.keys():
-                await self.log(f"Found collection ({idx + 1}/{count}): {result['id']}")
-                collections = await get_collections(result)
-                if collections:
-                    list_with_collection_child_packages += collections[0]
-                    continue
-
-            list_with_collection_child_packages.append(result)
-        return list_with_collection_child_packages
-
-    async def get_results(self):
-        results = []
-        dto: CKANInputDTO = self.dto
-        for search in SEARCH_FUNCTION_MAPPINGS.keys():
-            await self.log(f"Running search '{search}'...")
-            sub_dtos: list[BaseModel] = getattr(dto, search)
-            if sub_dtos is None:
-                continue
-            func = SEARCH_FUNCTION_MAPPINGS[search]
-            results = await perform_search(
-                search_func=func,
-                search_terms=base_model_list_dump(model_list=sub_dtos),
-                results=results
-            )
-        return results
-
diff --git a/src/collectors/source_collectors/ckan/dtos/input.py b/src/collectors/source_collectors/ckan/dtos/input.py
deleted file mode 100644
index b835999e..00000000
--- a/src/collectors/source_collectors/ckan/dtos/input.py
+++ /dev/null
@@ -1,19 +0,0 @@
-from pydantic import BaseModel, Field
-
-from src.collectors.source_collectors.ckan.dtos.search.group_and_organization import GroupAndOrganizationSearchDTO
-from src.collectors.source_collectors.ckan.dtos.search.package import CKANPackageSearchDTO
-
-
-class CKANInputDTO(BaseModel):
-    package_search: list[CKANPackageSearchDTO] or None = Field(
-        description="The list of package searches to perform.",
-        default=None
-    )
-    group_search: list[GroupAndOrganizationSearchDTO] or None = Field(
-        description="The list of group searches to perform.",
-        default=None
-    )
-    organization_search: list[GroupAndOrganizationSearchDTO] or None = Field(
-        description="The list of organization searches to perform.",
-        default=None
-    )
diff --git a/src/collectors/source_collectors/ckan/dtos/search/package.py b/src/collectors/source_collectors/ckan/dtos/search/package.py
deleted file mode 100644
index 43fcbda5..00000000
--- a/src/collectors/source_collectors/ckan/dtos/search/package.py
+++ /dev/null
@@ -1,14 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel, Field
-
-from src.collectors.source_collectors.ckan.dtos.search._helpers import url_field
-
-
-class CKANPackageSearchDTO(BaseModel):
-    url: str = url_field
-    terms: Optional[list[str]] = Field(
-        description="The search terms to use to refine the packages returned. "
-                    "None will return all packages.",
-        default=None
-    )
diff --git a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/package.py b/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/package.py
deleted file mode 100644
index f5737b35..00000000
--- a/src/collectors/source_collectors/ckan/scraper_toolkit/search_funcs/package.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import sys
-from typing import Optional, Any
-
-from src.collectors.source_collectors.ckan.scraper_toolkit._api_interface import CKANAPIInterface
-
-
-async def ckan_package_search(
-    base_url: str,
-    query: Optional[str] = None,
-    rows: Optional[int] = sys.maxsize,
-    start: Optional[int] = 0,
-    **kwargs,
-) -> list[dict[str, Any]]:
-    """Performs a CKAN package (dataset) search from a CKAN data catalog URL.
-
-    :param base_url: Base URL to search from. e.g. "https://catalog.data.gov/"
-    :param query: Search string, defaults to None. None will return all packages.
-    :param rows: Maximum number of results to return, defaults to maximum integer.
-    :param start: Offsets the results, defaults to 0.
-    :param kwargs: See https://docs.ckan.org/en/2.10/api/index.html#ckan.logic.action.get.package_search for additional arguments.
-    :return: List of dictionaries representing the CKAN package search results.
-    """
-    interface = CKANAPIInterface(base_url)
-    results = []
-    offset = start
-    rows_max = 1000  # CKAN's package search has a hard limit of 1000 packages returned at a time by default
-
-    while start < rows:
-        num_rows = rows - start + offset
-        packages: dict = await interface.package_search(
-            query=query, rows=num_rows, start=start, **kwargs
-        )
-        add_base_url_to_packages(base_url, packages)
-        results += packages["results"]
-
-        total_results = packages["count"]
-        if rows > total_results:
-            rows = total_results
-
-        result_len = len(packages["results"])
-        # Check if the website has a different rows_max value than CKAN's default
-        if result_len != rows_max and start + rows_max < total_results:
-            rows_max = result_len
-
-        start += rows_max
-
-    return results
-
-
-def add_base_url_to_packages(base_url, packages):
-    # Add the base_url to each package
-    [package.update(base_url=base_url) for package in packages["results"]]
diff --git a/src/collectors/source_collectors/common_crawler/collector.py b/src/collectors/source_collectors/common_crawler/collector.py
deleted file mode 100644
index e5e65dfe..00000000
--- a/src/collectors/source_collectors/common_crawler/collector.py
+++ /dev/null
@@ -1,25 +0,0 @@
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.enums import CollectorType
-from src.core.preprocessors.common_crawler import CommonCrawlerPreprocessor
-from src.collectors.source_collectors.common_crawler.crawler import CommonCrawler
-from src.collectors.source_collectors.common_crawler.input import CommonCrawlerInputDTO
-
-
-class CommonCrawlerCollector(AsyncCollectorBase):
-    collector_type = CollectorType.COMMON_CRAWLER
-    preprocessor = CommonCrawlerPreprocessor
-
-    async def run_implementation(self) -> None:
-        print("Running Common Crawler...")
-        dto: CommonCrawlerInputDTO = self.dto
-        common_crawler = CommonCrawler(
-            crawl_id=dto.common_crawl_id,
-            url=dto.url,
-            keyword=dto.search_term,
-            start_page=dto.start_page,
-            num_pages=dto.total_pages,
-        )
-        async for status in common_crawler.run():
-            await self.log(status)
-
-        self.data = {"urls": common_crawler.url_results}
\ No newline at end of file
diff --git a/src/collectors/source_collectors/example/core.py b/src/collectors/source_collectors/example/core.py
deleted file mode 100644
index 988caa09..00000000
--- a/src/collectors/source_collectors/example/core.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""
-Example collector
-Exists as a proof of concept for collector functionality
-
-"""
-import asyncio
-
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO
-from src.collectors.source_collectors.example.dtos.output import ExampleOutputDTO
-from src.collectors.enums import CollectorType
-from src.core.preprocessors.example import ExamplePreprocessor
-
-
-class ExampleCollector(AsyncCollectorBase):
-    collector_type = CollectorType.EXAMPLE
-    preprocessor = ExamplePreprocessor
-
-    async def run_implementation(self) -> None:
-        dto: ExampleInputDTO = self.dto
-        sleep_time = dto.sleep_time
-        for i in range(sleep_time):  # Simulate a task
-            await self.log(f"Step {i + 1}/{sleep_time}")
-            await self.sleep()
-        self.data = ExampleOutputDTO(
-            message=f"Data collected by {self.batch_id}",
-            urls=["https://example.com", "https://example.com/2"],
-            parameters=self.dto.model_dump(),
-        )
-
-    @staticmethod
-    async def sleep():
-        # Simulate work
-        await asyncio.sleep(1)
\ No newline at end of file
diff --git a/src/collectors/source_collectors/muckrock/api_interface/core.py b/src/collectors/source_collectors/muckrock/api_interface/core.py
deleted file mode 100644
index 3b174cf5..00000000
--- a/src/collectors/source_collectors/muckrock/api_interface/core.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from typing import Optional
-
-import requests
-from aiohttp import ClientSession
-
-from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse
-from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType
-
-
-class MuckrockAPIInterface:
-
-    def __init__(self, session: Optional[ClientSession] = None):
-        self.base_url = "https://www.muckrock.com/api_v1/"
-        self.session = session
-
-    def build_url(self, subpath: str):
-        return f"{self.base_url}{subpath}"
-
-
-    async def lookup_agency(self, muckrock_agency_id: int) -> AgencyLookupResponse:
-        url = self.build_url(f"agency/{muckrock_agency_id}")
-        try:
-            async with self.session.get(url) as results:
-                results.raise_for_status()
-                json = await results.json()
-                name = json["name"]
-                return AgencyLookupResponse(
-                    name=name, type=AgencyLookupResponseType.FOUND
-                )
-        except requests.exceptions.HTTPError as e:
-            return AgencyLookupResponse(
-                name=None,
-                type=AgencyLookupResponseType.ERROR,
-                error=str(e)
-            )
-        except KeyError:
-            return AgencyLookupResponse(
-                name=None, type=AgencyLookupResponseType.NOT_FOUND
-            )
-
diff --git a/src/collectors/source_collectors/muckrock/api_interface/lookup_response.py b/src/collectors/source_collectors/muckrock/api_interface/lookup_response.py
deleted file mode 100644
index a714eeb5..00000000
--- a/src/collectors/source_collectors/muckrock/api_interface/lookup_response.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType
-
-
-class AgencyLookupResponse(BaseModel):
-    name: Optional[str]
-    type: AgencyLookupResponseType
-    error: Optional[str] = None
diff --git a/src/collectors/source_collectors/muckrock/collectors/all_foia/core.py b/src/collectors/source_collectors/muckrock/collectors/all_foia/core.py
deleted file mode 100644
index 0033d242..00000000
--- a/src/collectors/source_collectors/muckrock/collectors/all_foia/core.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from src.collectors.enums import CollectorType
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.source_collectors.muckrock.collectors.all_foia.dto import MuckrockAllFOIARequestsCollectorInputDTO
-from src.collectors.source_collectors.muckrock.fetchers.foia.core import FOIAFetcher
-from src.collectors.source_collectors.muckrock.exceptions import MuckrockNoMoreDataError
-from src.core.preprocessors.muckrock import MuckrockPreprocessor
-
-
-class MuckrockAllFOIARequestsCollector(AsyncCollectorBase):
-    """
-    Retrieves urls associated with all Muckrock FOIA requests
-    """
-    collector_type = CollectorType.MUCKROCK_ALL_SEARCH
-    preprocessor = MuckrockPreprocessor
-
-    async def run_implementation(self) -> None:
-        dto: MuckrockAllFOIARequestsCollectorInputDTO = self.dto
-        start_page = dto.start_page
-        fetcher = FOIAFetcher(
-            start_page=start_page,
-        )
-        total_pages = dto.total_pages
-        all_page_data = await self.get_page_data(fetcher, start_page, total_pages)
-        all_transformed_data = self.transform_data(all_page_data)
-        self.data = {"urls": all_transformed_data}
-
-
-    async def get_page_data(self, fetcher, start_page, total_pages):
-        all_page_data = []
-        for page in range(start_page, start_page + total_pages):
-            await self.log(f"Fetching page {fetcher.current_page}")
-            try:
-                page_data = await fetcher.fetch_next_page()
-            except MuckrockNoMoreDataError:
-                await self.log(f"No more data to fetch at page {fetcher.current_page}")
-                break
-            if page_data is None:
-                continue
-            all_page_data.append(page_data)
-        return all_page_data
-
-    def transform_data(self, all_page_data):
-        all_transformed_data = []
-        for page_data in all_page_data:
-            for data in page_data["results"]:
-                all_transformed_data.append({
-                    "url": data["absolute_url"],
-                    "metadata": data
-                })
-        return all_transformed_data
diff --git a/src/collectors/source_collectors/muckrock/collectors/county/core.py b/src/collectors/source_collectors/muckrock/collectors/county/core.py
deleted file mode 100644
index 9a429d5d..00000000
--- a/src/collectors/source_collectors/muckrock/collectors/county/core.py
+++ /dev/null
@@ -1,60 +0,0 @@
-from src.collectors.enums import CollectorType
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.source_collectors.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO
-from src.collectors.source_collectors.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest
-from src.collectors.source_collectors.muckrock.fetch_requests.jurisdiction_loop import \
-    JurisdictionLoopFetchRequest
-from src.collectors.source_collectors.muckrock.fetchers.foia.loop import FOIALoopFetcher
-from src.collectors.source_collectors.muckrock.fetchers.jurisdiction.generator import \
-    JurisdictionGeneratorFetcher
-from src.core.preprocessors.muckrock import MuckrockPreprocessor
-
-
-class MuckrockCountyLevelSearchCollector(AsyncCollectorBase):
-    """
-    Searches for any and all requests in a certain county
-    """
-    collector_type = CollectorType.MUCKROCK_COUNTY_SEARCH
-    preprocessor = MuckrockPreprocessor
-
-    async def run_implementation(self) -> None:
-        jurisdiction_ids = await self.get_jurisdiction_ids()
-        if jurisdiction_ids is None:
-            await self.log("No jurisdictions found")
-            return
-        all_data = await self.get_foia_records(jurisdiction_ids)
-        formatted_data = self.format_data(all_data)
-        self.data = {"urls": formatted_data}
-
-    def format_data(self, all_data):
-        formatted_data = []
-        for data in all_data:
-            formatted_data.append({
-                "url": data["absolute_url"],
-                "metadata": data
-            })
-        return formatted_data
-
-    async def get_foia_records(self, jurisdiction_ids):
-        all_data = []
-        for name, id_ in jurisdiction_ids.items():
-            await self.log(f"Fetching records for {name}...")
-            request = FOIALoopFetchRequest(jurisdiction=id_)
-            fetcher = FOIALoopFetcher(request)
-            await fetcher.loop_fetch()
-            all_data.extend(fetcher.ffm.results)
-        return all_data
-
-    async def get_jurisdiction_ids(self):
-        dto: MuckrockCountySearchCollectorInputDTO = self.dto
-        parent_jurisdiction_id = dto.parent_jurisdiction_id
-        request = JurisdictionLoopFetchRequest(
-            level="l",
-            parent=parent_jurisdiction_id,
-            town_names=dto.town_names
-        )
-        fetcher = JurisdictionGeneratorFetcher(initial_request=request)
-        async for message in fetcher.generator_fetch():
-            await self.log(message)
-        jurisdiction_ids = fetcher.jfm.jurisdictions
-        return jurisdiction_ids
diff --git a/src/collectors/source_collectors/muckrock/collectors/simple/core.py b/src/collectors/source_collectors/muckrock/collectors/simple/core.py
deleted file mode 100644
index 2776a69e..00000000
--- a/src/collectors/source_collectors/muckrock/collectors/simple/core.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import itertools
-
-from src.collectors.enums import CollectorType
-from src.collectors.source_collectors.base import AsyncCollectorBase
-from src.collectors.source_collectors.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO
-from src.collectors.source_collectors.muckrock.collectors.simple.searcher import FOIASearcher
-from src.collectors.source_collectors.muckrock.fetchers.foia.core import FOIAFetcher
-from src.collectors.source_collectors.muckrock.exceptions import SearchCompleteException
-from src.core.preprocessors.muckrock import MuckrockPreprocessor
-
-
-class MuckrockSimpleSearchCollector(AsyncCollectorBase):
-    """
-    Performs searches on MuckRock's database
-    by matching a search string to title of request
-    """
-    collector_type = CollectorType.MUCKROCK_SIMPLE_SEARCH
-    preprocessor = MuckrockPreprocessor
-
-    def check_for_count_break(self, count, max_count) -> None:
-        if max_count is None:
-            return
-        if count >= max_count:
-            raise SearchCompleteException
-
-    async def run_implementation(self) -> None:
-        fetcher = FOIAFetcher()
-        dto: MuckrockSimpleSearchCollectorInputDTO = self.dto
-        searcher = FOIASearcher(
-            fetcher=fetcher,
-            search_term=dto.search_string
-        )
-        max_count = dto.max_results
-        all_results = []
-        results_count = 0
-        for search_count in itertools.count():
-            try:
-                results = await searcher.get_next_page_results()
-                all_results.extend(results)
-                results_count += len(results)
-                self.check_for_count_break(results_count, max_count)
-            except SearchCompleteException:
-                break
-            await self.log(f"Search {search_count}: Found {len(results)} results")
-
-        await self.log(f"Search Complete. Total results: {results_count}")
-        self.data = {"urls": self.format_results(all_results)}
-
-    def format_results(self, results: list[dict]) -> list[dict]:
-        formatted_results = []
-        for result in results:
-            formatted_result = {
-                "url": result["absolute_url"],
-                "metadata": result
-            }
-            formatted_results.append(formatted_result)
-
-        return formatted_results
diff --git a/src/collectors/source_collectors/muckrock/collectors/simple/searcher.py b/src/collectors/source_collectors/muckrock/collectors/simple/searcher.py
deleted file mode 100644
index 3bb13617..00000000
--- a/src/collectors/source_collectors/muckrock/collectors/simple/searcher.py
+++ /dev/null
@@ -1,43 +0,0 @@
-from typing import Optional
-
-from src.collectors.source_collectors.muckrock.fetchers.foia.core import FOIAFetcher
-from src.collectors.source_collectors.muckrock.exceptions import SearchCompleteException
-
-
-class FOIASearcher:
-    """
-    Used for searching FOIA data from MuckRock
-    """
-
-    def __init__(self, fetcher: FOIAFetcher, search_term: Optional[str] = None):
-        self.fetcher = fetcher
-        self.search_term = search_term
-
-    async def fetch_page(self) -> list[dict] | None:
-        """
-        Fetches the next page of results using the fetcher.
-        """
-        data = await self.fetcher.fetch_next_page()
-        if data is None or data.get("results") is None:
-            return None
-        return data.get("results")
-
-    def filter_results(self, results: list[dict]) -> list[dict]:
-        """
-        Filters the results based on the search term.
-        Override or modify as needed for custom filtering logic.
-        """
-        if self.search_term:
-            return [result for result in results if self.search_term.lower() in result["title"].lower()]
-        return results
-
-
-    async def get_next_page_results(self) -> list[dict]:
-        """
-        Fetches and processes the next page of results.
-        """
-        results = await self.fetch_page()
-        if not results:
-            raise SearchCompleteException
-        return self.filter_results(results)
-
diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/foia.py b/src/collectors/source_collectors/muckrock/fetch_requests/foia.py
deleted file mode 100644
index 1f0bffec..00000000
--- a/src/collectors/source_collectors/muckrock/fetch_requests/foia.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-
-
-class FOIAFetchRequest(FetchRequest):
-    page: int
-    page_size: int
diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/foia_loop.py b/src/collectors/source_collectors/muckrock/fetch_requests/foia_loop.py
deleted file mode 100644
index 54c063b6..00000000
--- a/src/collectors/source_collectors/muckrock/fetch_requests/foia_loop.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-
-
-class FOIALoopFetchRequest(FetchRequest):
-    jurisdiction: int
diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_by_id.py b/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_by_id.py
deleted file mode 100644
index 7825ade6..00000000
--- a/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_by_id.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-
-
-class JurisdictionByIDFetchRequest(FetchRequest):
-    jurisdiction_id: int
diff --git a/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_loop.py b/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_loop.py
deleted file mode 100644
index a39da62d..00000000
--- a/src/collectors/source_collectors/muckrock/fetch_requests/jurisdiction_loop.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from src.collectors.source_collectors.muckrock.fetch_requests.base import FetchRequest
-
-
-class JurisdictionLoopFetchRequest(FetchRequest):
-    level: str
-    parent: int
-    town_names: list
diff --git a/src/collectors/source_collectors/muckrock/fetchers/foia/core.py b/src/collectors/source_collectors/muckrock/fetchers/foia/core.py
deleted file mode 100644
index 5717f112..00000000
--- a/src/collectors/source_collectors/muckrock/fetchers/foia/core.py
+++ /dev/null
@@ -1,36 +0,0 @@
-from src.collectors.source_collectors.muckrock.fetch_requests.foia import FOIAFetchRequest
-from src.collectors.source_collectors.muckrock.fetchers.templates.fetcher import MuckrockFetcherBase
-from src.collectors.source_collectors.muckrock.constants import BASE_MUCKROCK_URL
-
-FOIA_BASE_URL = f"{BASE_MUCKROCK_URL}/foia"
-
-
-class FOIAFetcher(MuckrockFetcherBase):
-    """
-    A fetcher for FOIA requests.
-    Iterates through all FOIA requests available through the MuckRock FOIA API.
-    """
-
-    def __init__(self, start_page: int = 1, per_page: int = 100):
-        """
-        Constructor for the FOIAFetcher class.
-
-        Args:
-            start_page (int): The page number to start fetching from (default is 1).
-            per_page (int): The number of results to fetch per page (default is 100).
-        """
-        self.current_page = start_page
-        self.per_page = per_page
-
-    def build_url(self, request: FOIAFetchRequest) -> str:
-        return f"{FOIA_BASE_URL}?page={request.page}&page_size={request.page_size}&format=json"
-
-    async def fetch_next_page(self) -> dict | None:
-        """
-        Fetches data from a specific page of the MuckRock FOIA API.
- """ - page = self.current_page - self.current_page += 1 - request = FOIAFetchRequest(page=page, page_size=self.per_page) - return await self.fetch(request) - diff --git a/src/collectors/source_collectors/muckrock/fetchers/foia/generator.py b/src/collectors/source_collectors/muckrock/fetchers/foia/generator.py deleted file mode 100644 index 8e4fa7ac..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/foia/generator.py +++ /dev/null @@ -1,16 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetch_requests import FOIALoopFetchRequest -from src.collectors.source_collectors.muckrock.fetchers.foia.manager import FOIAFetchManager -from src.collectors.source_collectors.muckrock.fetchers.templates.generator import MuckrockGeneratorFetcher - - -class FOIAGeneratorFetcher(MuckrockGeneratorFetcher): - - def __init__(self, initial_request: FOIALoopFetchRequest): - super().__init__(initial_request) - self.ffm = FOIAFetchManager() - - def process_results(self, results: list[dict]): - self.ffm.process_results(results) - return (f"Loop {self.ffm.loop_count}: " - f"Found {self.ffm.num_found_last_loop} FOIA records;" - f"{self.ffm.num_found} FOIA records found total.") diff --git a/src/collectors/source_collectors/muckrock/fetchers/foia/loop.py b/src/collectors/source_collectors/muckrock/fetchers/foia/loop.py deleted file mode 100644 index ec21810e..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/foia/loop.py +++ /dev/null @@ -1,25 +0,0 @@ -from datasets import tqdm - -from src.collectors.source_collectors.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest -from src.collectors.source_collectors.muckrock.fetchers.foia.manager import FOIAFetchManager -from src.collectors.source_collectors.muckrock.fetchers.templates.loop import MuckrockLoopFetcher - - -class FOIALoopFetcher(MuckrockLoopFetcher): - - def __init__(self, initial_request: FOIALoopFetchRequest): - super().__init__(initial_request) - self.pbar_records = tqdm( - desc="Fetching FOIA records", - unit="record", - ) - self.ffm = FOIAFetchManager() - - def process_results(self, results: list[dict]): - self.ffm.process_results(results) - - def build_url(self, request: FOIALoopFetchRequest): - return self.ffm.build_url(request) - - def report_progress(self): - self.pbar_records.update(self.ffm.num_found_last_loop) diff --git a/src/collectors/source_collectors/muckrock/fetchers/foia/manager.py b/src/collectors/source_collectors/muckrock/fetchers/foia/manager.py deleted file mode 100644 index 7a38caaa..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/foia/manager.py +++ /dev/null @@ -1,20 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetch_requests.foia_loop import FOIALoopFetchRequest -from src.collectors.source_collectors.muckrock.constants import BASE_MUCKROCK_URL - - -class FOIAFetchManager: - - def __init__(self): - self.num_found = 0 - self.loop_count = 0 - self.num_found_last_loop = 0 - self.results = [] - - def build_url(self, request: FOIALoopFetchRequest): - return f"{BASE_MUCKROCK_URL}/foia/?status=done&jurisdiction={request.jurisdiction}" - - def process_results(self, results: list[dict]): - self.loop_count += 1 - self.num_found_last_loop = len(results) - self.results.extend(results) - self.num_found += len(results) \ No newline at end of file diff --git a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/core.py b/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/core.py deleted file mode 100644 index befbc3e9..00000000 --- 
a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/core.py +++ /dev/null @@ -1,13 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetch_requests.jurisdiction_by_id import \ - JurisdictionByIDFetchRequest -from src.collectors.source_collectors.muckrock.fetchers.templates.fetcher import MuckrockFetcherBase -from src.collectors.source_collectors.muckrock.constants import BASE_MUCKROCK_URL - - -class JurisdictionByIDFetcher(MuckrockFetcherBase): - - def build_url(self, request: JurisdictionByIDFetchRequest) -> str: - return f"{BASE_MUCKROCK_URL}/jurisdiction/{request.jurisdiction_id}/" - - async def get_jurisdiction(self, jurisdiction_id: int) -> dict: - return await self.fetch(request=JurisdictionByIDFetchRequest(jurisdiction_id=jurisdiction_id)) diff --git a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/generator.py b/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/generator.py deleted file mode 100644 index b285e852..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/generator.py +++ /dev/null @@ -1,17 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest -from src.collectors.source_collectors.muckrock.fetchers.jurisdiction.manager import JurisdictionFetchManager -from src.collectors.source_collectors.muckrock.fetchers.templates.generator import MuckrockGeneratorFetcher - - -class JurisdictionGeneratorFetcher(MuckrockGeneratorFetcher): - - def __init__(self, initial_request: JurisdictionLoopFetchRequest): - super().__init__(initial_request) - self.jfm = JurisdictionFetchManager(town_names=initial_request.town_names) - - def build_url(self, request: JurisdictionLoopFetchRequest) -> str: - return self.jfm.build_url(request) - - def process_results(self, results: list[dict]): - return self.jfm.process_results(results) - diff --git a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/loop.py b/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/loop.py deleted file mode 100644 index 5ca4b900..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/loop.py +++ /dev/null @@ -1,38 +0,0 @@ -from tqdm import tqdm - -from src.collectors.source_collectors.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest -from src.collectors.source_collectors.muckrock.fetchers.jurisdiction.manager import JurisdictionFetchManager -from src.collectors.source_collectors.muckrock.fetchers.templates.loop import MuckrockLoopFetcher - - -class JurisdictionLoopFetcher(MuckrockLoopFetcher): - - def __init__(self, initial_request: JurisdictionLoopFetchRequest): - super().__init__(initial_request) - self.jfm = JurisdictionFetchManager(town_names=initial_request.town_names) - self.pbar_jurisdictions = tqdm( - total=len(self.jfm.town_names), - desc="Fetching jurisdictions", - unit="jurisdiction", - position=0, - leave=False - ) - self.pbar_page = tqdm( - desc="Processing pages", - unit="page", - position=1, - leave=False - ) - - def build_url(self, request: JurisdictionLoopFetchRequest) -> str: - return self.jfm.build_url(request) - - def process_results(self, results: list[dict]): - self.jfm.process_results(results) - - def report_progress(self): - old_num_jurisdictions_found = self.jfm.num_jurisdictions_found - self.jfm.num_jurisdictions_found = len(self.jfm.jurisdictions) - difference = self.jfm.num_jurisdictions_found - old_num_jurisdictions_found - self.pbar_jurisdictions.update(difference) - 
self.pbar_page.update(1) diff --git a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/manager.py b/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/manager.py deleted file mode 100644 index dfd27569..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/jurisdiction/manager.py +++ /dev/null @@ -1,22 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetch_requests.jurisdiction_loop import JurisdictionLoopFetchRequest -from src.collectors.source_collectors.muckrock.constants import BASE_MUCKROCK_URL - - -class JurisdictionFetchManager: - - def __init__(self, town_names: list[str]): - self.town_names = town_names - self.num_jurisdictions_found = 0 - self.total_found = 0 - self.jurisdictions = {} - - def build_url(self, request: JurisdictionLoopFetchRequest) -> str: - return f"{BASE_MUCKROCK_URL}/jurisdiction/?level={request.level}&parent={request.parent}" - - def process_results(self, results: list[dict]): - for item in results: - if item["name"] in self.town_names: - self.jurisdictions[item["name"]] = item["id"] - self.total_found += 1 - self.num_jurisdictions_found = len(self.jurisdictions) - return f"Found {self.num_jurisdictions_found} jurisdictions; {self.total_found} entries found total." diff --git a/src/collectors/source_collectors/muckrock/fetchers/templates/generator.py b/src/collectors/source_collectors/muckrock/fetchers/templates/generator.py deleted file mode 100644 index 3a6a0e01..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/templates/generator.py +++ /dev/null @@ -1,30 +0,0 @@ -from src.collectors.source_collectors.muckrock.fetchers.templates.iter_fetcher import MuckrockIterFetcherBase -from src.collectors.source_collectors.muckrock.exceptions import RequestFailureException - - -class MuckrockGeneratorFetcher(MuckrockIterFetcherBase): - """ - Similar to the Muckrock Loop fetcher, but behaves - as a generator instead of a loop - """ - - async def generator_fetch(self) -> str: - """ - Fetches data and yields status messages between requests - """ - url = self.build_url(self.initial_request) - final_message = "No more records found. Exiting..." - while url is not None: - try: - data = await self.get_response(url) - except RequestFailureException: - final_message = "Request unexpectedly failed. Exiting..." 
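# ---------------------------------------------------------------------------
# Editor's note (annotation, not part of the patch): the deleted loop and
# generator templates in this region share one pagination contract with the
# MuckRock API: fetch a page, consume data["results"], then follow the
# data["next"] cursor until it is None. A minimal sketch of that contract,
# assuming only a `get_response(url)` coroutine like the one used on
# MuckrockIterFetcherBase (the helper name `follow_pages` is hypothetical):
#
#     async def follow_pages(get_response, first_url: str) -> list[dict]:
#         results: list[dict] = []
#         url: str | None = first_url
#         while url is not None:
#             data = await get_response(url)   # one page of API results
#             results.extend(data["results"])  # accumulate this page
#             url = data["next"]               # None on the final page
#         return results
# ---------------------------------------------------------------------------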
- break - - yield self.process_results(data["results"]) - url = data["next"] - - yield final_message - - - diff --git a/src/collectors/source_collectors/muckrock/fetchers/templates/loop.py b/src/collectors/source_collectors/muckrock/fetchers/templates/loop.py deleted file mode 100644 index c3b5dc0f..00000000 --- a/src/collectors/source_collectors/muckrock/fetchers/templates/loop.py +++ /dev/null @@ -1,32 +0,0 @@ -from abc import abstractmethod -from time import sleep - -from src.collectors.source_collectors.muckrock.fetchers.templates.iter_fetcher import MuckrockIterFetcherBase -from src.collectors.source_collectors.muckrock.exceptions import RequestFailureException - - -class MuckrockLoopFetcher(MuckrockIterFetcherBase): - - async def loop_fetch(self): - url = self.build_url(self.initial_request) - while url is not None: - try: - data = await self.get_response(url) - except RequestFailureException: - break - - url = self.process_data(data) - sleep(1) - - def process_data(self, data: dict): - """ - Process data and get next url, if any - """ - self.process_results(data["results"]) - self.report_progress() - url = data["next"] - return url - - @abstractmethod - def report_progress(self): - pass diff --git a/src/core/core.py b/src/core/core.py index 78554b39..7d4ac083 100644 --- a/src/core/core.py +++ b/src/core/core.py @@ -3,14 +3,10 @@ from fastapi import HTTPException from pydantic import BaseModel -from sqlalchemy.exc import IntegrityError -from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAnnotationResponse -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.api.endpoints.annotate.all.get.dto import GetNextURLForAllAnnotationResponse -from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseOuterInfo -from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseOuterInfo +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.api.endpoints.annotate.all.post.query import AddAllAnnotationsToURLQueryBuilder from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary @@ -32,22 +28,24 @@ from src.api.endpoints.review.next.dto import GetNextURLForFinalReviewOuterResponse from src.api.endpoints.search.dtos.response import SearchURLResponse from src.api.endpoints.task.by_id.dto import TaskInfo +from src.api.endpoints.task.dtos.get.task_status import GetTaskStatusResponseInfo from src.api.endpoints.task.dtos.get.tasks import GetTasksResponse from src.api.endpoints.url.get.dto import GetURLsResponseInfo -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.batch import BatchInfo -from src.api.endpoints.task.dtos.get.task_status import GetTaskStatusResponseInfo -from src.db.enums import TaskType -from src.collectors.manager import AsyncCollectorManager from src.collectors.enums import CollectorType +from src.collectors.manager import AsyncCollectorManager +from src.core.enums import BatchStatus from src.core.tasks.url.manager import TaskManager -from src.core.error_manager.core import ErrorManager -from src.core.enums import BatchStatus, RecordType, 
AnnotationType, SuggestedStatus - +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.batch.pydantic.info import BatchInfo +from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum from src.security.dtos.access_info import AccessInfo class AsyncCore: + task_manager: TaskManager | None = None + adb_client: AsyncDatabaseClient | None = None + collector_manager: AsyncCollectorManager | None = None def __init__( self, @@ -57,7 +55,6 @@ def __init__( ): self.task_manager = task_manager self.adb_client = adb_client - self.collector_manager = collector_manager @@ -91,16 +88,14 @@ async def get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> Get async def get_batch_statuses( self, - collector_type: Optional[CollectorType], - status: Optional[BatchStatus], - has_pending_urls: Optional[bool], + collector_type: CollectorType | None, + status: BatchURLStatusEnum | None, page: int ) -> GetBatchSummariesResponse: results = await self.adb_client.get_batch_summaries( collector_type=collector_type, status=status, page=page, - has_pending_urls=has_pending_urls ) return results @@ -112,10 +107,10 @@ async def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: # region Collector async def initiate_collector( - self, - collector_type: CollectorType, - user_id: int, - dto: Optional[BaseModel] = None, + self, + collector_type: CollectorType, + user_id: int, + dto: BaseModel | None = None, ) -> CollectorStartInfo: """ Reserves a batch ID from the database @@ -159,157 +154,9 @@ async def get_tasks( task_status=task_status ) - async def get_task_info(self, task_id: int) -> TaskInfo: return await self.adb_client.get_task_info(task_id=task_id) - - #region Annotations and Review - - async def submit_url_relevance_annotation( - self, - user_id: int, - url_id: int, - suggested_status: SuggestedStatus - ): - try: - return await self.adb_client.add_user_relevant_suggestion( - user_id=user_id, - url_id=url_id, - suggested_status=suggested_status - ) - except IntegrityError: - return await ErrorManager.raise_annotation_exists_error( - annotation_type=AnnotationType.RELEVANCE, - url_id=url_id - ) - - async def get_next_url_for_relevance_annotation( - self, - user_id: int, - batch_id: Optional[int] - ) -> GetNextRelevanceAnnotationResponseOuterInfo: - next_annotation = await self.adb_client.get_next_url_for_relevance_annotation( - user_id=user_id, - batch_id=batch_id - ) - return GetNextRelevanceAnnotationResponseOuterInfo( - next_annotation=next_annotation - ) - - async def get_next_url_for_record_type_annotation( - self, - user_id: int, - batch_id: Optional[int] - ) -> GetNextRecordTypeAnnotationResponseOuterInfo: - next_annotation = await self.adb_client.get_next_url_for_record_type_annotation( - user_id=user_id, - batch_id=batch_id - ) - return GetNextRecordTypeAnnotationResponseOuterInfo( - next_annotation=next_annotation - ) - - async def submit_url_record_type_annotation( - self, - user_id: int, - url_id: int, - record_type: RecordType, - ): - try: - return await self.adb_client.add_user_record_type_suggestion( - user_id=user_id, - url_id=url_id, - record_type=record_type - ) - except IntegrityError: - return await ErrorManager.raise_annotation_exists_error( - annotation_type=AnnotationType.RECORD_TYPE, - url_id=url_id - ) - - - async def get_next_url_agency_for_annotation( - self, - user_id: int, - batch_id: Optional[int] - ) -> GetNextURLForAgencyAnnotationResponse: - return await 
self.adb_client.get_next_url_agency_for_annotation( - user_id=user_id, - batch_id=batch_id - ) - - async def submit_url_agency_annotation( - self, - user_id: int, - url_id: int, - agency_post_info: URLAgencyAnnotationPostInfo - ) -> GetNextURLForAgencyAnnotationResponse: - if not agency_post_info.is_new and not agency_post_info.suggested_agency: - raise ValueError("suggested_agency must be provided if is_new is False") - - if agency_post_info.is_new: - agency_suggestion_id = None - else: - agency_suggestion_id = agency_post_info.suggested_agency - return await self.adb_client.add_agency_manual_suggestion( - user_id=user_id, - url_id=url_id, - agency_id=agency_suggestion_id, - is_new=agency_post_info.is_new, - ) - - async def get_next_source_for_review( - self, - batch_id: Optional[int] - ) -> GetNextURLForFinalReviewOuterResponse: - return await self.adb_client.get_next_url_for_final_review( - batch_id=batch_id - ) - - async def get_next_url_for_all_annotations( - self, - batch_id: Optional[int] - ) -> GetNextURLForAllAnnotationResponse: - return await self.adb_client.get_next_url_for_all_annotations( - batch_id=batch_id - ) - - async def submit_url_for_all_annotations( - self, - user_id: int, - url_id: int, - post_info: AllAnnotationPostInfo - ): - await self.adb_client.add_all_annotations_to_url( - user_id=user_id, - url_id=url_id, - post_info=post_info - ) - - async def approve_url( - self, - approval_info: FinalReviewApprovalInfo, - access_info: AccessInfo - ): - await self.adb_client.approve_url( - approval_info=approval_info, - user_id=access_info.user_id - ) - - - async def reject_url( - self, - url_id: int, - access_info: AccessInfo, - rejection_reason: RejectionReason - ): - await self.adb_client.reject_url( - url_id=url_id, - user_id=access_info.user_id, - rejection_reason=rejection_reason - ) - async def upload_manual_batch( self, dto: ManualBatchInputDTO, diff --git a/src/core/enums.py b/src/core/enums.py index c6f90c80..fa64a5cb 100644 --- a/src/core/enums.py +++ b/src/core/enums.py @@ -16,6 +16,7 @@ class RecordType(Enum): """ All available URL record types """ + # Police and Public ACCIDENT_REPORTS = "Accident Reports" ARREST_RECORDS = "Arrest Records" CALLS_FOR_SERVICE = "Calls for Service" @@ -31,16 +32,21 @@ class RecordType(Enum): SURVEYS = "Surveys" USE_OF_FORCE_REPORTS = "Use of Force Reports" VEHICLE_PURSUITS = "Vehicle Pursuits" + + # Info About Officers COMPLAINTS_AND_MISCONDUCT = "Complaints & Misconduct" DAILY_ACTIVITY_LOGS = "Daily Activity Logs" TRAINING_AND_HIRING_INFO = "Training & Hiring Info" PERSONNEL_RECORDS = "Personnel Records" + + # Info About Agencies ANNUAL_AND_MONTHLY_REPORTS = "Annual & Monthly Reports" BUDGETS_AND_FINANCES = "Budgets & Finances" - CONTACT_INFO_AND_AGENCY_META = "Contact Info & Agency Meta" GEOGRAPHIC = "Geographic" LIST_OF_DATA_SOURCES = "List of Data Sources" POLICIES_AND_CONTRACTS = "Policies & Contracts" + + # Agency-Published Resources CRIME_MAPS_AND_REPORTS = "Crime Maps & Reports" CRIME_STATISTICS = "Crime Statistics" MEDIA_BULLETINS = "Media Bulletins" @@ -48,9 +54,13 @@ class RecordType(Enum): RESOURCES = "Resources" SEX_OFFENDER_REGISTRY = "Sex Offender Registry" WANTED_PERSONS = "Wanted Persons" + + # Jails and Courts Specific BOOKING_REPORTS = "Booking Reports" COURT_CASES = "Court Cases" INCARCERATION_RECORDS = "Incarceration Records" + + # Other OTHER = "Other" @@ -71,12 +81,3 @@ class SubmitResponseStatus(Enum): SUCCESS = "success" FAILURE = "FAILURE" ALREADY_EXISTS = "already_exists" - -class SuggestedStatus(Enum): - 
""" - Possible values for user_relevant_suggestions:suggested_status - """ - RELEVANT = "relevant" - NOT_RELEVANT = "not relevant" - INDIVIDUAL_RECORD = "individual record" - BROKEN_PAGE_404 = "broken page/404 not found" \ No newline at end of file diff --git a/src/core/env_var_manager.py b/src/core/env_var_manager.py index 8fce7ac3..cbf424ec 100644 --- a/src/core/env_var_manager.py +++ b/src/core/env_var_manager.py @@ -16,7 +16,8 @@ def __init__(self, env: dict = os.environ): self.env = env self._load() - def _load(self): + def _load(self) -> None: + """Load environment variables from environment""" self.google_api_key = self.require_env("GOOGLE_API_KEY") self.google_cse_id = self.require_env("GOOGLE_CSE_ID") @@ -30,6 +31,7 @@ def _load(self): self.openai_api_key = self.require_env("OPENAI_API_KEY") self.hf_inference_api_key = self.require_env("HUGGINGFACE_INFERENCE_API_KEY") + self.hf_hub_token = self.require_env("HUGGINGFACE_HUB_TOKEN") self.postgres_user = self.require_env("POSTGRES_USER") self.postgres_password = self.require_env("POSTGRES_PASSWORD") diff --git a/src/core/exceptions.py b/src/core/exceptions.py index e3e93e55..a361a24d 100644 --- a/src/core/exceptions.py +++ b/src/core/exceptions.py @@ -3,10 +3,6 @@ from fastapi import HTTPException -class InvalidPreprocessorError(Exception): - pass - - class MuckrockAPIError(Exception): pass @@ -17,4 +13,5 @@ class MatchAgencyError(Exception): class FailedValidationException(HTTPException): def __init__(self, detail: str): - super().__init__(status_code=HTTPStatus.BAD_REQUEST, detail=detail) \ No newline at end of file + super().__init__(status_code=HTTPStatus.BAD_REQUEST, detail=detail) + diff --git a/src/core/helpers.py b/src/core/helpers.py deleted file mode 100644 index eeb951fe..00000000 --- a/src/core/helpers.py +++ /dev/null @@ -1,48 +0,0 @@ -from src.core.enums import SuggestionType -from src.core.exceptions import MatchAgencyError -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse -from src.external.pdap.enums import MatchAgencyResponseStatus - - -def process_match_agency_response_to_suggestions( - url_id: int, - match_agency_response: MatchAgencyResponse -) -> list[URLAgencySuggestionInfo]: - if match_agency_response.status == MatchAgencyResponseStatus.EXACT_MATCH: - match = match_agency_response.matches[0] - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.CONFIRMED, - pdap_agency_id=int(match.id), - agency_name=match.submitted_name, - state=match.state, - county=match.county, - ) - ] - if match_agency_response.status == MatchAgencyResponseStatus.NO_MATCH: - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.UNKNOWN, - ) - ] - - if match_agency_response.status != MatchAgencyResponseStatus.PARTIAL_MATCH: - raise MatchAgencyError( - f"Unknown Match Agency Response Status: {match_agency_response.status}" - ) - - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.AUTO_SUGGESTION, - pdap_agency_id=match.id, - agency_name=match.submitted_name, - state=match.state, - county=match.county, - locality=match.locality - ) - for match in match_agency_response.matches - ] diff --git a/src/core/logger.py b/src/core/logger.py index e49dd057..22f35492 100644 --- a/src/core/logger.py +++ b/src/core/logger.py @@ -1,7 +1,7 @@ import asyncio from src.db.client.async_ import AsyncDatabaseClient -from 
src.db.dtos.log import LogInfo +from src.db.models.impl.log.pydantic.info import LogInfo class AsyncCoreLogger: diff --git a/src/core/preprocessors/autogoogler.py b/src/core/preprocessors/autogoogler.py index e827c77d..e3771f2c 100644 --- a/src/core/preprocessors/autogoogler.py +++ b/src/core/preprocessors/autogoogler.py @@ -1,7 +1,8 @@ from typing import List -from src.db.dtos.url.core import URLInfo from src.core.preprocessors.base import PreprocessorBase +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.info import URLInfo class AutoGooglerPreprocessor(PreprocessorBase): @@ -18,6 +19,7 @@ def preprocess_entry(self, entry: dict) -> list[URLInfo]: "snippet": qr["snippet"], "title": qr["title"] }, + source=URLSource.COLLECTOR )) return url_infos diff --git a/src/core/preprocessors/base.py b/src/core/preprocessors/base.py index dea8df10..16d9432b 100644 --- a/src/core/preprocessors/base.py +++ b/src/core/preprocessors/base.py @@ -2,7 +2,7 @@ from abc import ABC from typing import List -from src.db.dtos.url.core import URLInfo +from src.db.models.impl.url.core.pydantic.info import URLInfo class PreprocessorBase(ABC): diff --git a/src/core/preprocessors/ckan.py b/src/core/preprocessors/ckan.py index c07d4ab5..671134c2 100644 --- a/src/core/preprocessors/ckan.py +++ b/src/core/preprocessors/ckan.py @@ -1,7 +1,7 @@ from datetime import datetime from typing import List -from src.db.dtos.url.core import URLInfo +from src.db.models.impl.url.core.pydantic.info import URLInfo class CKANPreprocessor: diff --git a/src/core/preprocessors/common_crawler.py b/src/core/preprocessors/common_crawler.py index 9a7e1d04..d831c520 100644 --- a/src/core/preprocessors/common_crawler.py +++ b/src/core/preprocessors/common_crawler.py @@ -1,7 +1,8 @@ from typing import List -from src.db.dtos.url.core import URLInfo from src.core.preprocessors.base import PreprocessorBase +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.info import URLInfo class CommonCrawlerPreprocessor(PreprocessorBase): @@ -12,6 +13,7 @@ def preprocess(self, data: dict) -> List[URLInfo]: for url in data["urls"]: url_info = URLInfo( url=url, + source=URLSource.COLLECTOR ) url_infos.append(url_info) diff --git a/src/core/preprocessors/example.py b/src/core/preprocessors/example.py index dfc7338a..31e68e44 100644 --- a/src/core/preprocessors/example.py +++ b/src/core/preprocessors/example.py @@ -1,8 +1,9 @@ from typing import List -from src.db.dtos.url.core import URLInfo -from src.collectors.source_collectors.example.dtos.output import ExampleOutputDTO +from src.collectors.impl.example.dtos.output import ExampleOutputDTO from src.core.preprocessors.base import PreprocessorBase +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.info import URLInfo class ExamplePreprocessor(PreprocessorBase): @@ -12,6 +13,7 @@ def preprocess(self, data: ExampleOutputDTO) -> List[URLInfo]: for url in data.urls: url_info = URLInfo( url=url, + source=URLSource.COLLECTOR ) url_infos.append(url_info) diff --git a/src/core/preprocessors/muckrock.py b/src/core/preprocessors/muckrock.py index 281ea2f8..1e05395a 100644 --- a/src/core/preprocessors/muckrock.py +++ b/src/core/preprocessors/muckrock.py @@ -1,7 +1,8 @@ from typing import List -from src.db.dtos.url.core import URLInfo from src.core.preprocessors.base import PreprocessorBase +from src.db.models.impl.url.core.enums import URLSource +from 
src.db.models.impl.url.core.pydantic.info import URLInfo class MuckrockPreprocessor(PreprocessorBase): @@ -12,6 +13,7 @@ def preprocess(self, data: dict) -> List[URLInfo]: url_info = URLInfo( url=entry["url"], collector_metadata=entry["metadata"], + source=URLSource.COLLECTOR ) url_infos.append(url_info) diff --git a/src/core/tasks/base/operator.py b/src/core/tasks/base/operator.py index ba7a3d3a..51f07a47 100644 --- a/src/core/tasks/base/operator.py +++ b/src/core/tasks/base/operator.py @@ -1,16 +1,30 @@ import traceback from abc import ABC, abstractmethod +from src.core.enums import BatchStatus from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from src.db.client.async_ import AsyncDatabaseClient from src.db.enums import TaskType +from src.db.models.impl.task.enums import TaskStatus +from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall class TaskOperatorBase(ABC): def __init__(self, adb_client: AsyncDatabaseClient): - self.adb_client = adb_client - self.task_id = None + self._adb_client = adb_client + self._task_id: int | None = None + + @property + def task_id(self) -> int: + if self._task_id is None: + raise AttributeError("Task id is not set. Call initiate_task_in_db() first.") + return self._task_id + + @property + def adb_client(self) -> AsyncDatabaseClient: + return self._adb_client @property @abstractmethod @@ -27,8 +41,8 @@ async def initiate_task_in_db(self) -> int: async def conclude_task(self): raise NotImplementedError - async def run_task(self, task_id: int) -> TaskOperatorRunInfo: - self.task_id = task_id + async def run_task(self) -> TaskOperatorRunInfo: + self._task_id = await self.initiate_task_in_db() try: await self.inner_task_logic() return await self.conclude_task() @@ -45,12 +59,27 @@ async def run_info(self, outcome: TaskOperatorOutcome, message: str) -> TaskOper @abstractmethod - async def inner_task_logic(self): + async def inner_task_logic(self) -> None: raise NotImplementedError async def handle_task_error(self, e): - await self.adb_client.update_task_status(task_id=self.task_id, status=BatchStatus.ERROR) + await self.adb_client.update_task_status(task_id=self.task_id, status=TaskStatus.ERROR) await self.adb_client.add_task_error( task_id=self.task_id, error=str(e) ) + + async def add_task_errors( + self, + errors: list[URLTaskErrorSmall] + ) -> None: + inserts: list[URLTaskErrorPydantic] = [ + URLTaskErrorPydantic( + task_id=self.task_id, + url_id=error.url_id, + task_type=self.task_type, + error=error.error + ) + for error in errors + ] + await self.adb_client.bulk_insert(inserts) \ No newline at end of file diff --git a/src/core/tasks/base/run_info.py b/src/core/tasks/base/run_info.py index b822c59f..78e6b357 100644 --- a/src/core/tasks/base/run_info.py +++ b/src/core/tasks/base/run_info.py @@ -7,7 +7,7 @@ class TaskOperatorRunInfo(BaseModel): - task_id: Optional[int] + task_id: int | None task_type: TaskType outcome: TaskOperatorOutcome message: str = "" \ No newline at end of file diff --git a/src/core/tasks/dtos/run_info.py b/src/core/tasks/dtos/run_info.py deleted file mode 100644 index 2296f65b..00000000 --- a/src/core/tasks/dtos/run_info.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from src.core.tasks.base.run_info import TaskOperatorRunInfo -from src.core.tasks.url.enums import TaskOperatorOutcome - - -class 
URLTaskOperatorRunInfo(TaskOperatorRunInfo): - linked_url_ids: list[int] diff --git a/src/core/tasks/handler.py b/src/core/tasks/handler.py index 3e3aca77..92b96103 100644 --- a/src/core/tasks/handler.py +++ b/src/core/tasks/handler.py @@ -4,10 +4,10 @@ from src.core.enums import BatchStatus from src.core.tasks.base.run_info import TaskOperatorRunInfo -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from src.db.client.async_ import AsyncDatabaseClient from src.db.enums import TaskType +from src.db.models.impl.task.enums import TaskStatus class TaskHandler: @@ -15,7 +15,7 @@ class TaskHandler: def __init__( self, adb_client: AsyncDatabaseClient, - discord_poster: DiscordPoster + discord_poster: DiscordPoster | None ): self.adb_client = adb_client self.discord_poster = discord_poster @@ -25,7 +25,10 @@ def __init__( self.logger.setLevel(logging.INFO) - async def post_to_discord(self, message: str): + async def post_to_discord(self, message: str) -> None: + if self.discord_poster is None: + print("Post to Discord disabled by POST_TO_DISCORD_FLAG") + return self.discord_poster.post_to_discord(message=message) async def initiate_task_in_db(self, task_type: TaskType) -> int: # @@ -40,19 +43,23 @@ async def handle_outcome(self, run_info: TaskOperatorRunInfo): # case TaskOperatorOutcome.SUCCESS: await self.adb_client.update_task_status( task_id=run_info.task_id, - status=BatchStatus.READY_TO_LABEL + status=TaskStatus.COMPLETE ) async def handle_task_error(self, run_info: TaskOperatorRunInfo): # await self.adb_client.update_task_status( task_id=run_info.task_id, - status=BatchStatus.ERROR) + status=TaskStatus.ERROR + ) await self.adb_client.add_task_error( task_id=run_info.task_id, error=run_info.message ) - self.discord_poster.post_to_discord( - message=f"Task {run_info.task_id} ({run_info.task_type.value}) failed with error.") + msg: str = f"Task {run_info.task_id} ({run_info.task_type.value}) failed with error: {run_info.message[:100]}..." 
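# ---------------------------------------------------------------------------
# Editor's note (annotation, not part of the patch): handle_task_error now
# records TaskStatus.ERROR and the error row before notifying Discord, and
# notification goes through the guarded post_to_discord wrapper added above,
# which no-ops (with a console notice) when POST_TO_DISCORD_FLAG has disabled
# the poster. A rough usage sketch under that assumption (construction
# details elided; `adb_client` stands in for the app's AsyncDatabaseClient):
#
#     handler = TaskHandler(adb_client=adb_client, discord_poster=None)
#     await handler.post_to_discord("test")  # prints the "disabled" notice
# ---------------------------------------------------------------------------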
+        print(msg)
+        await self.post_to_discord(
+            message=msg
+        )
 
     async def link_urls_to_task(self, task_id: int, url_ids: list[int]):
         await self.adb_client.link_urls_to_task(
diff --git a/src/core/tasks/mixins/__init__.py b/src/core/tasks/mixins/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/mixins/link_urls.py b/src/core/tasks/mixins/link_urls.py
new file mode 100644
index 00000000..f58a3dff
--- /dev/null
+++ b/src/core/tasks/mixins/link_urls.py
@@ -0,0 +1,41 @@
+from abc import abstractmethod
+
+from src.db.client.async_ import AsyncDatabaseClient
+
+
+class LinkURLsMixin:
+
+    def __init__(
+        self,
+        *args,
+        **kwargs
+    ):
+        super().__init__(*args, **kwargs)
+        self._urls_linked = False
+        self._linked_url_ids = []
+
+    @property
+    def urls_linked(self) -> bool:
+        return self._urls_linked
+
+    @property
+    def linked_url_ids(self) -> list[int]:
+        return self._linked_url_ids
+
+    @property
+    @abstractmethod
+    def adb_client(self) -> AsyncDatabaseClient:
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def task_id(self) -> int:
+        raise NotImplementedError
+
+    async def link_urls_to_task(self, url_ids: list[int]):
+        self._linked_url_ids = url_ids
+        await self.adb_client.link_urls_to_task(
+            task_id=self.task_id,
+            url_ids=url_ids
+        )
+        self._urls_linked = True
\ No newline at end of file
diff --git a/src/core/tasks/mixins/prereq.py b/src/core/tasks/mixins/prereq.py
new file mode 100644
index 00000000..dcfec66b
--- /dev/null
+++ b/src/core/tasks/mixins/prereq.py
@@ -0,0 +1,15 @@
+from abc import ABC, abstractmethod
+
+
+class HasPrerequisitesMixin(ABC):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @abstractmethod
+    async def meets_task_prerequisites(self) -> bool:
+        """
+        A task should not be initiated unless certain
+        conditions are met
+        """
+        raise NotImplementedError
\ No newline at end of file
diff --git a/src/core/tasks/scheduled/enums.py b/src/core/tasks/scheduled/enums.py
new file mode 100644
index 00000000..e011ab6e
--- /dev/null
+++ b/src/core/tasks/scheduled/enums.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class IntervalEnum(Enum):
+    DAILY = 60 * 24
+    HOURLY = 60
+    TEN_MINUTES = 10
\ No newline at end of file
diff --git a/src/core/tasks/scheduled/impl/__init__.py b/src/core/tasks/scheduled/impl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/scheduled/impl/backlog/__init__.py b/src/core/tasks/scheduled/impl/backlog/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/scheduled/impl/backlog/operator.py b/src/core/tasks/scheduled/impl/backlog/operator.py
new file mode 100644
index 00000000..d628c91c
--- /dev/null
+++ b/src/core/tasks/scheduled/impl/backlog/operator.py
@@ -0,0 +1,16 @@
+from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.enums import TaskType
+
+
+class PopulateBacklogSnapshotTaskOperator(ScheduledTaskOperatorBase):
+
+    def __init__(self, adb_client: AsyncDatabaseClient):
+        super().__init__(adb_client)
+
+    @property
+    def task_type(self) -> TaskType:
+        return TaskType.POPULATE_BACKLOG_SNAPSHOT
+
+    async def inner_task_logic(self) -> None:
+        await self.adb_client.populate_backlog_snapshot()
\ No newline at end of file
diff --git a/src/core/tasks/scheduled/impl/delete_logs/__init__.py
b/src/core/tasks/scheduled/impl/delete_logs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/delete_logs/operator.py b/src/core/tasks/scheduled/impl/delete_logs/operator.py new file mode 100644 index 00000000..41be3af9 --- /dev/null +++ b/src/core/tasks/scheduled/impl/delete_logs/operator.py @@ -0,0 +1,21 @@ +import datetime + +from sqlalchemy import delete + +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.log.sqlalchemy import Log + + +class DeleteOldLogsTaskOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.DELETE_OLD_LOGS + + async def inner_task_logic(self) -> None: + statement = delete(Log).where( + Log.created_at < datetime.datetime.now() - datetime.timedelta(days=7) + ) + await self.adb_client.execute(statement) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/delete_stale_screenshots/__init__.py b/src/core/tasks/scheduled/impl/delete_stale_screenshots/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/delete_stale_screenshots/operator.py b/src/core/tasks/scheduled/impl/delete_stale_screenshots/operator.py new file mode 100644 index 00000000..0c386cfe --- /dev/null +++ b/src/core/tasks/scheduled/impl/delete_stale_screenshots/operator.py @@ -0,0 +1,15 @@ +from src.core.tasks.scheduled.impl.delete_stale_screenshots.query import DeleteStaleScreenshotsQueryBuilder +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.enums import TaskType + + +class DeleteStaleScreenshotsTaskOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.DELETE_STALE_SCREENSHOTS + + async def inner_task_logic(self) -> None: + await self.adb_client.run_query_builder( + DeleteStaleScreenshotsQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/delete_stale_screenshots/query.py b/src/core/tasks/scheduled/impl/delete_stale_screenshots/query.py new file mode 100644 index 00000000..624f44c5 --- /dev/null +++ b/src/core/tasks/scheduled/impl/delete_stale_screenshots/query.py @@ -0,0 +1,31 @@ +from typing import Any + +from sqlalchemy import delete, exists, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from src.db.queries.base.builder import QueryBuilderBase + + +class DeleteStaleScreenshotsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> Any: + + statement = ( + delete( + URLScreenshot + ) + .where( + exists( + select( + FlagURLValidated + ) + .where( + FlagURLValidated.url_id == URLScreenshot.url_id, + ) + ) + ) + ) + + await session.execute(statement) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/huggingface/__init__.py b/src/core/tasks/scheduled/impl/huggingface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/huggingface/operator.py b/src/core/tasks/scheduled/impl/huggingface/operator.py new file mode 100644 index 00000000..9bb7a85e --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/operator.py @@ -0,0 +1,49 @@ +from itertools import count + +from src.core.tasks.mixins.prereq import 
HasPrerequisitesMixin +from src.core.tasks.scheduled.impl.huggingface.queries.check.core import CheckValidURLsUpdatedQueryBuilder +from src.core.tasks.scheduled.impl.huggingface.queries.get.core import GetForLoadingToHuggingFaceQueryBuilder +from src.core.tasks.scheduled.impl.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.external.huggingface.hub.client import HuggingFaceHubClient + + +class PushToHuggingFaceTaskOperator( + ScheduledTaskOperatorBase, + HasPrerequisitesMixin +): + + @property + def task_type(self) -> TaskType: + return TaskType.PUSH_TO_HUGGINGFACE + + def __init__( + self, + adb_client: AsyncDatabaseClient, + hf_client: HuggingFaceHubClient + ): + super().__init__(adb_client) + self.hf_client = hf_client + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + CheckValidURLsUpdatedQueryBuilder() + ) + + async def inner_task_logic(self): + """Push raw data sources to huggingface.""" + run_dt = await self.adb_client.get_current_database_time() + for idx in count(start=1): + outputs: list[GetForLoadingToHuggingFaceOutput] = await self._get_data_sources_raw_for_huggingface(page=idx) + if len(outputs) == 0: + break + self.hf_client.push_data_sources_raw_to_hub(outputs, idx=idx) + + await self.adb_client.set_hugging_face_upload_state(run_dt.replace(tzinfo=None)) + + async def _get_data_sources_raw_for_huggingface(self, page: int) -> list[GetForLoadingToHuggingFaceOutput]: + return await self.adb_client.run_query_builder( + GetForLoadingToHuggingFaceQueryBuilder(page) + ) diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/__init__.py b/src/core/tasks/scheduled/impl/huggingface/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/check/__init__.py b/src/core/tasks/scheduled/impl/huggingface/queries/check/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/check/core.py b/src/core/tasks/scheduled/impl/huggingface/queries/check/core.py new file mode 100644 index 00000000..c76fa2e1 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/check/core.py @@ -0,0 +1,14 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.impl.huggingface.queries.check.requester import CheckValidURLsUpdatedRequester +from src.db.queries.base.builder import QueryBuilderBase + + +class CheckValidURLsUpdatedQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + requester = CheckValidURLsUpdatedRequester(session=session) + latest_upload = await requester.latest_upload() + return await requester.has_valid_urls(latest_upload) + + diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/check/requester.py b/src/core/tasks/scheduled/impl/huggingface/queries/check/requester.py new file mode 100644 index 00000000..ef43bd3d --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/check/requester.py @@ -0,0 +1,52 @@ +from datetime import datetime +from operator import or_ + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql.functions import count + +from src.collectors.enums import URLStatus +from src.db.enums import TaskType +from src.db.helpers.query import 
not_exists_url, no_url_task_error, exists_url +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.state.huggingface import HuggingFaceUploadState +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.models.impl.url.core.sqlalchemy import URL + + +class CheckValidURLsUpdatedRequester: + + def __init__(self, session: AsyncSession): + self.session = session + + async def latest_upload(self) -> datetime: + query = ( + select( + HuggingFaceUploadState.last_upload_at + ) + ) + return await sh.scalar( + session=self.session, + query=query + ) + + async def has_valid_urls(self, last_upload_at: datetime | None) -> bool: + query = ( + select(count(URL.id)) + .join( + URLCompressedHTML, + URL.id == URLCompressedHTML.url_id + ) + .where( + exists_url(FlagURLValidated), + no_url_task_error(TaskType.PUSH_TO_HUGGINGFACE) + ) + ) + if last_upload_at is not None: + query = query.where(URL.updated_at > last_upload_at) + url_count = await sh.scalar( + session=self.session, + query=query + ) + return url_count > 0 diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/__init__.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/convert.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/convert.py new file mode 100644 index 00000000..41926fe4 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/get/convert.py @@ -0,0 +1,22 @@ +from src.core.enums import RecordType +from src.core.tasks.scheduled.impl.huggingface.queries.get.enums import RecordTypeCoarse +from src.core.tasks.scheduled.impl.huggingface.queries.get.mappings import FINE_COARSE_RECORD_TYPE_MAPPING +from src.db.models.impl.flag.url_validated.enums import URLType + + +def convert_fine_to_coarse_record_type( + fine_record_type: RecordType +) -> RecordTypeCoarse: + return FINE_COARSE_RECORD_TYPE_MAPPING[fine_record_type] + + +def convert_validated_type_to_relevant( + validated_type: URLType +) -> bool: + match validated_type: + case URLType.NOT_RELEVANT: + return False + case URLType.DATA_SOURCE: + return True + case _: + raise ValueError(f"Disallowed validated type: {validated_type}") \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/core.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/core.py new file mode 100644 index 00000000..5b6bd08d --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/get/core.py @@ -0,0 +1,81 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.impl.huggingface.queries.get.convert import convert_fine_to_coarse_record_type, \ + convert_validated_type_to_relevant +from src.core.tasks.scheduled.impl.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput +from src.db.client.helpers import add_standard_limit_and_offset +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType +from src.db.queries.base.builder import 
QueryBuilderBase +from src.db.utils.compression import decompress_html + + +class GetForLoadingToHuggingFaceQueryBuilder(QueryBuilderBase): + + def __init__(self, page: int): + super().__init__() + self.page = page + + + async def run(self, session: AsyncSession) -> list[GetForLoadingToHuggingFaceOutput]: + label_url_id = 'url_id' + label_url = 'url' + label_record_type_fine = 'record_type_fine' + label_html = 'html' + label_type = 'type' + + + query = ( + select( + URL.id.label(label_url_id), + URL.url.label(label_url), + URLRecordType.record_type.label(label_record_type_fine), + URLCompressedHTML.compressed_html.label(label_html), + FlagURLValidated.type.label(label_type) + ) + .join( + URLRecordType, + URL.id == URLRecordType.url_id + ) + .join( + URLCompressedHTML, + URL.id == URLCompressedHTML.url_id + ) + .outerjoin( + FlagURLValidated, + URL.id == FlagURLValidated.url_id + ) + .where( + FlagURLValidated.type.in_( + (URLType.DATA_SOURCE, + URLType.NOT_RELEVANT) + ) + ) + ) + query = add_standard_limit_and_offset(page=self.page, statement=query) + db_results = await sh.mappings( + session=session, + query=query + ) + final_results = [] + for result in db_results: + output = GetForLoadingToHuggingFaceOutput( + url_id=result[label_url_id], + url=result[label_url], + relevant=convert_validated_type_to_relevant( + URLType(result[label_type]) + ), + record_type_fine=result[label_record_type_fine], + record_type_coarse=convert_fine_to_coarse_record_type( + result[label_record_type_fine] + ), + html=decompress_html(result[label_html]) + ) + final_results.append(output) + + return final_results diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/enums.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/enums.py new file mode 100644 index 00000000..86e1c511 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/get/enums.py @@ -0,0 +1,12 @@ +from enum import Enum + + +class RecordTypeCoarse(Enum): + INFO_ABOUT_AGENCIES = "Info About Agencies" + INFO_ABOUT_OFFICERS = "Info About Officers" + AGENCY_PUBLISHED_RESOURCES = "Agency-Published Resources" + POLICE_AND_PUBLIC = "Police & Public Interactions" + POOR_DATA_SOURCE = "Poor Data Source" + NOT_RELEVANT = "Not Relevant" + JAILS_AND_COURTS = "Jails & Courts Specific" + OTHER = "Other" \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/mappings.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/mappings.py new file mode 100644 index 00000000..0621ee52 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/get/mappings.py @@ -0,0 +1,48 @@ +from src.collectors.enums import URLStatus +from src.core.enums import RecordType +from src.core.tasks.scheduled.impl.huggingface.queries.get.enums import RecordTypeCoarse + +FINE_COARSE_RECORD_TYPE_MAPPING = { + # Police and Public + RecordType.ACCIDENT_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.ARREST_RECORDS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CALLS_FOR_SERVICE: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CAR_GPS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.CITATIONS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.DISPATCH_LOGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.DISPATCH_RECORDINGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.FIELD_CONTACTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.INCIDENT_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.MISC_POLICE_ACTIVITY: RecordTypeCoarse.POLICE_AND_PUBLIC, + 
RecordType.OFFICER_INVOLVED_SHOOTINGS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.STOPS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.SURVEYS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.USE_OF_FORCE_REPORTS: RecordTypeCoarse.POLICE_AND_PUBLIC, + RecordType.VEHICLE_PURSUITS: RecordTypeCoarse.POLICE_AND_PUBLIC, + # Info About Officers + RecordType.COMPLAINTS_AND_MISCONDUCT: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.DAILY_ACTIVITY_LOGS: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.TRAINING_AND_HIRING_INFO: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + RecordType.PERSONNEL_RECORDS: RecordTypeCoarse.INFO_ABOUT_OFFICERS, + # Info About Agencies + RecordType.ANNUAL_AND_MONTHLY_REPORTS: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.BUDGETS_AND_FINANCES: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.GEOGRAPHIC: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.LIST_OF_DATA_SOURCES: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + RecordType.POLICIES_AND_CONTRACTS: RecordTypeCoarse.INFO_ABOUT_AGENCIES, + # Agency-Published Resources + RecordType.CRIME_MAPS_AND_REPORTS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.CRIME_STATISTICS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.MEDIA_BULLETINS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.RECORDS_REQUEST_INFO: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.RESOURCES: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.SEX_OFFENDER_REGISTRY: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + RecordType.WANTED_PERSONS: RecordTypeCoarse.AGENCY_PUBLISHED_RESOURCES, + # Jails and Courts Specific + RecordType.BOOKING_REPORTS: RecordTypeCoarse.JAILS_AND_COURTS, + RecordType.COURT_CASES: RecordTypeCoarse.JAILS_AND_COURTS, + RecordType.INCARCERATION_RECORDS: RecordTypeCoarse.JAILS_AND_COURTS, + # Other + RecordType.OTHER: RecordTypeCoarse.OTHER, + None: RecordTypeCoarse.NOT_RELEVANT +} diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/get/model.py b/src/core/tasks/scheduled/impl/huggingface/queries/get/model.py new file mode 100644 index 00000000..187b2ee2 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/get/model.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +from src.core.enums import RecordType +from src.core.tasks.scheduled.impl.huggingface.queries.get.enums import RecordTypeCoarse + + +class GetForLoadingToHuggingFaceOutput(BaseModel): + url_id: int + url: str + relevant: bool + record_type_fine: RecordType | None + record_type_coarse: RecordTypeCoarse | None + html: str \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/huggingface/queries/state.py b/src/core/tasks/scheduled/impl/huggingface/queries/state.py new file mode 100644 index 00000000..3abebc71 --- /dev/null +++ b/src/core/tasks/scheduled/impl/huggingface/queries/state.py @@ -0,0 +1,24 @@ +from datetime import datetime + +from sqlalchemy import delete, insert +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.state.huggingface import HuggingFaceUploadState +from src.db.queries.base.builder import QueryBuilderBase + + +class SetHuggingFaceUploadStateQueryBuilder(QueryBuilderBase): + + def __init__(self, dt: datetime): + super().__init__() + self.dt = dt + + async def run(self, session: AsyncSession) -> None: + # Delete entry if any exists + await session.execute( + delete(HuggingFaceUploadState) + ) + # Insert entry + await session.execute( + insert(HuggingFaceUploadState).values(last_upload_at=self.dt) + ) diff --git 
a/src/core/tasks/scheduled/impl/internet_archives/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/probe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/convert.py b/src/core/tasks/scheduled/impl/internet_archives/probe/convert.py new file mode 100644 index 00000000..efd5e45c --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/convert.py @@ -0,0 +1,16 @@ +from src.db.models.impl.url.internet_archives.probe.pydantic import URLInternetArchiveMetadataPydantic +from src.external.internet_archives.models.ia_url_mapping import InternetArchivesURLMapping +from src.util.url_mapper import URLMapper + + +def convert_ia_url_mapping_to_ia_metadata( + url_mapper: URLMapper, + ia_mapping: InternetArchivesURLMapping +) -> URLInternetArchiveMetadataPydantic: + iam = ia_mapping.ia_metadata + return URLInternetArchiveMetadataPydantic( + url_id=url_mapper.get_id(ia_mapping.url), + archive_url=iam.archive_url, + digest=iam.digest, + length=iam.length + ) diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/filter.py b/src/core/tasks/scheduled/impl/internet_archives/probe/filter.py new file mode 100644 index 00000000..2713b080 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/filter.py @@ -0,0 +1,16 @@ +from src.external.internet_archives.models.ia_url_mapping import InternetArchivesURLMapping +from src.core.tasks.scheduled.impl.internet_archives.probe.models.subset import IAURLMappingSubsets + + +def filter_into_subsets( + ia_mappings: list[InternetArchivesURLMapping] +) -> IAURLMappingSubsets: + subsets = IAURLMappingSubsets() + for ia_mapping in ia_mappings: + if ia_mapping.has_error: + subsets.error.append(ia_mapping) + + if ia_mapping.has_metadata: + subsets.has_metadata.append(ia_mapping) + + return subsets diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/models/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/probe/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/models/subset.py b/src/core/tasks/scheduled/impl/internet_archives/probe/models/subset.py new file mode 100644 index 00000000..b01fd317 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/models/subset.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.external.internet_archives.models.ia_url_mapping import InternetArchivesURLMapping + + +class IAURLMappingSubsets(BaseModel): + error: list[InternetArchivesURLMapping] = [] + has_metadata: list[InternetArchivesURLMapping] = [] \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/operator.py b/src/core/tasks/scheduled/impl/internet_archives/probe/operator.py new file mode 100644 index 00000000..f4773417 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/operator.py @@ -0,0 +1,119 @@ +from tqdm.asyncio import tqdm_asyncio + +from src.core.tasks.mixins.link_urls import LinkURLsMixin +from src.core.tasks.mixins.prereq import HasPrerequisitesMixin +from src.core.tasks.scheduled.impl.internet_archives.probe.convert import convert_ia_url_mapping_to_ia_metadata +from src.core.tasks.scheduled.impl.internet_archives.probe.filter import filter_into_subsets +from 
src.core.tasks.scheduled.impl.internet_archives.probe.models.subset import IAURLMappingSubsets +from src.core.tasks.scheduled.impl.internet_archives.probe.queries.delete import \ + DeleteOldUnsuccessfulIACheckedFlagsQueryBuilder +from src.core.tasks.scheduled.impl.internet_archives.probe.queries.get import GetURLsForInternetArchivesTaskQueryBuilder +from src.core.tasks.scheduled.impl.internet_archives.probe.queries.prereq import \ + CheckURLInternetArchivesTaskPrerequisitesQueryBuilder +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.enums import TaskType +from src.db.models.impl.flag.checked_for_ia.pydantic import FlagURLCheckedForInternetArchivesPydantic +from src.db.models.impl.url.internet_archives.probe.pydantic import URLInternetArchiveMetadataPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall +from src.external.internet_archives.client import InternetArchivesClient +from src.external.internet_archives.models.ia_url_mapping import InternetArchivesURLMapping +from src.util.progress_bar import get_progress_bar_disabled +from src.util.url_mapper import URLMapper + + +class InternetArchivesProbeTaskOperator( + ScheduledTaskOperatorBase, + HasPrerequisitesMixin, + LinkURLsMixin +): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + ia_client: InternetArchivesClient + ): + super().__init__(adb_client) + self.ia_client = ia_client + + @property + def task_type(self) -> TaskType: + return TaskType.IA_PROBE + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + CheckURLInternetArchivesTaskPrerequisitesQueryBuilder() + ) + + async def inner_task_logic(self) -> None: + await self.adb_client.run_query_builder( + DeleteOldUnsuccessfulIACheckedFlagsQueryBuilder() + ) + + url_mappings: list[URLMapping] = await self._get_url_mappings() + if len(url_mappings) == 0: + return + mapper = URLMapper(url_mappings) + + await self.link_urls_to_task(mapper.get_all_ids()) + + ia_mappings: list[InternetArchivesURLMapping] = await self._search_for_internet_archive_links(mapper.get_all_urls()) + await self._add_ia_flags_to_db(mapper, ia_mappings=ia_mappings) + + subsets: IAURLMappingSubsets = filter_into_subsets(ia_mappings) + await self._add_errors_to_db(mapper, ia_mappings=subsets.error) + await self._add_ia_metadata_to_db(mapper, ia_mappings=subsets.has_metadata) + + async def _add_errors_to_db(self, mapper: URLMapper, ia_mappings: list[InternetArchivesURLMapping]) -> None: + url_error_info_list: list[URLTaskErrorSmall] = [] + for ia_mapping in ia_mappings: + url_id = mapper.get_id(ia_mapping.url) + url_error_info = URLTaskErrorSmall( + url_id=url_id, + error=ia_mapping.error, + ) + url_error_info_list.append(url_error_info) + await self.add_task_errors(url_error_info_list) + + async def _get_url_mappings(self) -> list[URLMapping]: + return await self.adb_client.run_query_builder( + GetURLsForInternetArchivesTaskQueryBuilder() + ) + + async def _search_for_internet_archive_links(self, urls: list[str]) -> list[InternetArchivesURLMapping]: + return await tqdm_asyncio.gather( + *[ + self.ia_client.search_for_url_snapshot(url) + for url in urls + ], + timeout=60 * 10, # 10 minutes + disable=get_progress_bar_disabled() + ) + + async def _add_ia_metadata_to_db( + self, + url_mapper: URLMapper, + ia_mappings: list[InternetArchivesURLMapping], + ) -> None: + 
insert_objects: list[URLInternetArchiveMetadataPydantic] = [ + convert_ia_url_mapping_to_ia_metadata( + url_mapper=url_mapper, + ia_mapping=ia_mapping + ) + for ia_mapping in ia_mappings + ] + await self.adb_client.bulk_insert(insert_objects) + + async def _add_ia_flags_to_db( + self, mapper: URLMapper, ia_mappings: list[InternetArchivesURLMapping]) -> None: + flags: list[FlagURLCheckedForInternetArchivesPydantic] = [] + for ia_mapping in ia_mappings: + url_id = mapper.get_id(ia_mapping.url) + flag = FlagURLCheckedForInternetArchivesPydantic( + url_id=url_id, + success=not ia_mapping.has_error + ) + flags.append(flag) + await self.adb_client.bulk_insert(flags) + diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/queries/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/queries/cte.py b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/cte.py new file mode 100644 index 00000000..7de8b290 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/cte.py @@ -0,0 +1,42 @@ +from sqlalchemy import select, or_, exists, func, text, CTE, ColumnElement + +from src.db.helpers.query import not_exists_url +from src.db.models.impl.flag.checked_for_ia.sqlalchemy import FlagURLCheckedForInternetArchives +from src.db.models.impl.url.core.sqlalchemy import URL + + +class CheckURLInternetArchivesCTEContainer: + + def __init__(self): + + self._cte = ( + select( + URL.id.label("url_id"), + URL.url + ) + .where( + or_( + not_exists_url(FlagURLCheckedForInternetArchives), + exists( + select(FlagURLCheckedForInternetArchives.url_id) + .where( + FlagURLCheckedForInternetArchives.url_id == URL.id, + ~FlagURLCheckedForInternetArchives.success, + FlagURLCheckedForInternetArchives.created_at < func.now() - text("INTERVAL '1 week'") + ) + ) + ) + ).cte("check_url_internet_archives_prereq") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> ColumnElement[int]: + return self._cte.c.url_id + + @property + def url(self) -> ColumnElement[str]: + return self._cte.c.url \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/queries/delete.py b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/delete.py new file mode 100644 index 00000000..2d9a08e1 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/delete.py @@ -0,0 +1,24 @@ +from sqlalchemy import delete, exists, select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.impl.internet_archives.probe.queries.cte import CheckURLInternetArchivesCTEContainer +from src.db.models.impl.flag.checked_for_ia.sqlalchemy import FlagURLCheckedForInternetArchives +from src.db.queries.base.builder import QueryBuilderBase + +class DeleteOldUnsuccessfulIACheckedFlagsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> None: + cte = CheckURLInternetArchivesCTEContainer() + query = ( + delete(FlagURLCheckedForInternetArchives) + .where( + exists( + select(cte.url_id) + .where( + FlagURLCheckedForInternetArchives.url_id == cte.url_id, + ) + ) + ) + ) + + await session.execute(query) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/queries/get.py b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/get.py new file mode 100644 index 00000000..3306943a --- 
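# --- Editorial note (not part of this diff) ----------------------------------
# Every query module in this change subclasses QueryBuilderBase and is executed
# through AsyncDatabaseClient.run_query_builder(...). The base class itself is
# not shown in the diff; inferred from its subclasses, its contract is roughly
# the sketch below (the real base may carry more state or helpers):
from abc import ABC, abstractmethod
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession


class QueryBuilderBase(ABC):
    """Encapsulates one database operation against an open async session."""

    @abstractmethod
    async def run(self, session: AsyncSession) -> Any:
        ...
# run_query_builder presumably opens the session/transaction, awaits run(), and
# commits, so builders like the DELETE above never manage transactions themselves.
# ------------------------------------------------------------------------------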
/dev/null
+++ b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/get.py
@@ -0,0 +1,28 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.scheduled.impl.internet_archives.probe.queries.cte import CheckURLInternetArchivesCTEContainer
+from src.db.dtos.url.mapping import URLMapping
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class GetURLsForInternetArchivesTaskQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> list[URLMapping]:
+        cte = CheckURLInternetArchivesCTEContainer()
+        query = (
+            select(
+                cte.url_id,
+                cte.url
+            )
+            .limit(100)
+        )
+
+        db_mappings = await sh.mappings(session, query=query)
+        return [
+            URLMapping(
+                url_id=mapping["url_id"],
+                url=mapping["url"]
+            ) for mapping in db_mappings
+        ]
diff --git a/src/core/tasks/scheduled/impl/internet_archives/probe/queries/prereq.py b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/prereq.py
new file mode 100644
index 00000000..d8994641
--- /dev/null
+++ b/src/core/tasks/scheduled/impl/internet_archives/probe/queries/prereq.py
@@ -0,0 +1,16 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.scheduled.impl.internet_archives.probe.queries.cte import CheckURLInternetArchivesCTEContainer
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class CheckURLInternetArchivesTaskPrerequisitesQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> bool:
+        cte = CheckURLInternetArchivesCTEContainer()
+        query = (
+            select(cte.url_id)
+        )
+        return await sh.results_exist(session, query=query)
diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/save/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/filter.py b/src/core/tasks/scheduled/impl/internet_archives/save/filter.py
new file mode 100644
index 00000000..2a66ad26
--- /dev/null
+++ b/src/core/tasks/scheduled/impl/internet_archives/save/filter.py
@@ -0,0 +1,14 @@
+from src.core.tasks.scheduled.impl.internet_archives.save.models.mapping import URLInternetArchivesSaveResponseMapping
+from src.core.tasks.scheduled.impl.internet_archives.save.models.subset import IASaveURLMappingSubsets
+
+
+def filter_save_responses(
+    resp_mappings: list[URLInternetArchivesSaveResponseMapping]
+) -> IASaveURLMappingSubsets:
+    subsets = IASaveURLMappingSubsets()
+    for resp_mapping in resp_mappings:
+        if resp_mapping.response.has_error:
+            subsets.error.append(resp_mapping)
+        else:
+            subsets.success.append(resp_mapping)
+    return subsets
\ No newline at end of file
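# --- Editorial note (not part of this diff) ----------------------------------
# The original filter appended resp_mapping.response to the subsets, which
# contradicts the IASaveURLMappingSubsets annotations and strips the URL needed
# downstream; the fix above keeps the whole url/response mapping. A usage
# sketch, assuming InternetArchivesSaveResponseInfo accepts these constructor
# arguments (its real fields beyond has_error/error are not shown in the diff):
ok = URLInternetArchivesSaveResponseMapping(
    url="https://example.com/a",
    response=InternetArchivesSaveResponseInfo(has_error=False, error=None),
)
failed = URLInternetArchivesSaveResponseMapping(
    url="https://example.com/b",
    response=InternetArchivesSaveResponseInfo(has_error=True, error="rate limited"),
)
subsets = filter_save_responses([ok, failed])
assert [m.url for m in subsets.success] == ["https://example.com/a"]
assert [m.url for m in subsets.error] == ["https://example.com/b"]
# ------------------------------------------------------------------------------
diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/mapper.py b/src/core/tasks/scheduled/impl/internet_archives/save/mapper.py
new file mode 100644
index 00000000..1d20b1c2
--- /dev/null
+++ 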
b/src/core/tasks/scheduled/impl/internet_archives/save/mapper.py @@ -0,0 +1,18 @@ +from src.core.tasks.scheduled.impl.internet_archives.save.models.entry import InternetArchivesSaveTaskEntry + + +class URLToEntryMapper: + + def __init__(self, entries: list[InternetArchivesSaveTaskEntry]): + self._url_to_entry: dict[str, InternetArchivesSaveTaskEntry] = { + entry.url: entry for entry in entries + } + + def get_is_new(self, url: str) -> bool: + return self._url_to_entry[url].is_new + + def get_url_id(self, url: str) -> int: + return self._url_to_entry[url].url_id + + def get_all_urls(self) -> list[str]: + return list(self._url_to_entry.keys()) diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/models/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/save/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/models/entry.py b/src/core/tasks/scheduled/impl/internet_archives/save/models/entry.py new file mode 100644 index 00000000..6e4ae84e --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/models/entry.py @@ -0,0 +1,15 @@ +from pydantic import BaseModel + +from src.db.dtos.url.mapping import URLMapping + + +class InternetArchivesSaveTaskEntry(BaseModel): + url: str + url_id: int + is_new: bool + + def to_url_mapping(self) -> URLMapping: + return URLMapping( + url_id=self.url_id, + url=self.url + ) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/models/mapping.py b/src/core/tasks/scheduled/impl/internet_archives/save/models/mapping.py new file mode 100644 index 00000000..d30362a3 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/models/mapping.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.external.internet_archives.models.save_response import InternetArchivesSaveResponseInfo + + +class URLInternetArchivesSaveResponseMapping(BaseModel): + url: str + response: InternetArchivesSaveResponseInfo \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/models/subset.py b/src/core/tasks/scheduled/impl/internet_archives/save/models/subset.py new file mode 100644 index 00000000..a6b29794 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/models/subset.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.core.tasks.scheduled.impl.internet_archives.save.models.mapping import URLInternetArchivesSaveResponseMapping + + +class IASaveURLMappingSubsets(BaseModel): + error: list[URLInternetArchivesSaveResponseMapping] = [] + success: list[URLInternetArchivesSaveResponseMapping] = [] \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/operator.py b/src/core/tasks/scheduled/impl/internet_archives/save/operator.py new file mode 100644 index 00000000..fad0d7ac --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/operator.py @@ -0,0 +1,133 @@ +from src.core.tasks.mixins.link_urls import LinkURLsMixin +from src.core.tasks.mixins.prereq import HasPrerequisitesMixin +from src.core.tasks.scheduled.impl.internet_archives.save.filter import filter_save_responses +from src.core.tasks.scheduled.impl.internet_archives.save.mapper import URLToEntryMapper +from src.core.tasks.scheduled.impl.internet_archives.save.models.entry import InternetArchivesSaveTaskEntry +from src.core.tasks.scheduled.impl.internet_archives.save.models.mapping import URLInternetArchivesSaveResponseMapping +from 
src.core.tasks.scheduled.impl.internet_archives.save.models.subset import IASaveURLMappingSubsets
+from src.core.tasks.scheduled.impl.internet_archives.save.queries.get import \
+    GetURLsForInternetArchivesSaveTaskQueryBuilder
+from src.core.tasks.scheduled.impl.internet_archives.save.queries.prereq import \
+    MeetsPrerequisitesForInternetArchivesSaveQueryBuilder
+from src.core.tasks.scheduled.impl.internet_archives.save.queries.update import \
+    UpdateInternetArchivesSaveMetadataQueryBuilder
+from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.enums import TaskType
+from src.db.models.impl.url.internet_archives.save.pydantic import URLInternetArchiveSaveMetadataPydantic
+from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall
+from src.external.internet_archives.client import InternetArchivesClient
+from src.external.internet_archives.models.save_response import InternetArchivesSaveResponseInfo
+
+
+class InternetArchivesSaveTaskOperator(
+    ScheduledTaskOperatorBase,
+    HasPrerequisitesMixin,
+    LinkURLsMixin
+):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        ia_client: InternetArchivesClient
+    ):
+        super().__init__(adb_client)
+        self.ia_client = ia_client
+
+    async def meets_task_prerequisites(self) -> bool:
+        return await self.adb_client.run_query_builder(
+            MeetsPrerequisitesForInternetArchivesSaveQueryBuilder()
+        )
+
+    @property
+    def task_type(self) -> TaskType:
+        return TaskType.IA_SAVE
+
+    async def inner_task_logic(self) -> None:
+        entries: list[InternetArchivesSaveTaskEntry] = await self._get_valid_entries()
+        mapper = URLToEntryMapper(entries)
+        url_ids = [entry.url_id for entry in entries]
+        await self.link_urls_to_task(url_ids=url_ids)
+
+        # Save all to internet archives and get responses
+        resp_mappings: list[URLInternetArchivesSaveResponseMapping] = await self._save_all_to_internet_archives(
+            mapper.get_all_urls()
+        )
+
+        # Separate errors from successful saves
+        subsets: IASaveURLMappingSubsets = filter_save_responses(resp_mappings)
+
+        # Save errors
+        await self._add_errors_to_db(mapper, resp_mappings=subsets.error)
+
+        # Save successful saves that are new archive entries
+        await self._save_new_saves_to_db(mapper, ia_mappings=subsets.success)
+
+        # Save successful saves that are existing archive entries
+        await self._save_existing_saves_to_db(mapper, ia_mappings=subsets.success)
+
+
+
+    async def _save_all_to_internet_archives(self, urls: list[str]) -> list[URLInternetArchivesSaveResponseMapping]:
+        resp_mappings: list[URLInternetArchivesSaveResponseMapping] = []
+        for url in urls:
+            resp: InternetArchivesSaveResponseInfo = await self.ia_client.save_to_internet_archives(url)
+            mapping = URLInternetArchivesSaveResponseMapping(
+                url=url,
+                response=resp
+            )
+            resp_mappings.append(mapping)
+        return resp_mappings
+
+    async def _get_valid_entries(self) -> list[InternetArchivesSaveTaskEntry]:
+        return await self.adb_client.run_query_builder(
+            GetURLsForInternetArchivesSaveTaskQueryBuilder()
+        )
+
+    async def _add_errors_to_db(
+        self,
+        mapper: URLToEntryMapper,
+        resp_mappings: list[URLInternetArchivesSaveResponseMapping]
+    ) -> None:
+        error_info_list: list[URLTaskErrorSmall] = []
+        for resp_mapping in resp_mappings:
+            url_id = mapper.get_url_id(resp_mapping.url)
+            url_error_info = URLTaskErrorSmall(
+                url_id=url_id,
+                error=resp_mapping.response.error,
+            )
+            error_info_list.append(url_error_info)
+        await self.add_task_errors(error_info_list)
+
+    async def _save_new_saves_to_db(
+        
self, + mapper: URLToEntryMapper, + ia_mappings: list[URLInternetArchivesSaveResponseMapping] + ) -> None: + insert_objects: list[URLInternetArchiveSaveMetadataPydantic] = [] + for ia_mapping in ia_mappings: + is_new = mapper.get_is_new(ia_mapping.url) + if not is_new: + continue + insert_object = URLInternetArchiveSaveMetadataPydantic( + url_id=mapper.get_url_id(ia_mapping.url), + ) + insert_objects.append(insert_object) + await self.adb_client.bulk_insert(insert_objects) + + async def _save_existing_saves_to_db( + self, + mapper: URLToEntryMapper, + ia_mappings: list[URLInternetArchivesSaveResponseMapping] + ) -> None: + url_ids: list[int] = [] + for ia_mapping in ia_mappings: + is_new = mapper.get_is_new(ia_mapping.url) + if is_new: + continue + url_ids.append(mapper.get_url_id(ia_mapping.url)) + await self.adb_client.run_query_builder( + UpdateInternetArchivesSaveMetadataQueryBuilder( + url_ids=url_ids + ) + ) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/get.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/get.py new file mode 100644 index 00000000..0c853775 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/queries/get.py @@ -0,0 +1,29 @@ +from typing import Sequence + +from sqlalchemy import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.impl.internet_archives.save.models.entry import InternetArchivesSaveTaskEntry +from src.core.tasks.scheduled.impl.internet_archives.save.queries.shared.get_valid_entries import \ + IA_SAVE_VALID_ENTRIES_QUERY +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLsForInternetArchivesSaveTaskQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[InternetArchivesSaveTaskEntry]: + query = ( + IA_SAVE_VALID_ENTRIES_QUERY + # Limit to 15, which is the maximum number of URLs that can be saved at once. 
+ .limit(15) + ) + + db_mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + return [ + InternetArchivesSaveTaskEntry( + url_id=mapping["id"], + url=mapping["url"], + is_new=mapping["is_new"], + ) for mapping in db_mappings + ] \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/prereq.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/prereq.py new file mode 100644 index 00000000..1c661807 --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/queries/prereq.py @@ -0,0 +1,20 @@ +from sqlalchemy import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.scheduled.impl.internet_archives.save.queries.shared.get_valid_entries import \ + IA_SAVE_VALID_ENTRIES_QUERY +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class MeetsPrerequisitesForInternetArchivesSaveQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + + query = ( + IA_SAVE_VALID_ENTRIES_QUERY + .limit(1) + ) + + result: RowMapping | None = await sh.one_or_none(session, query=query) + + return result is not None \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/shared/__init__.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/shared/get_valid_entries.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/shared/get_valid_entries.py new file mode 100644 index 00000000..b0f9eeea --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/queries/shared/get_valid_entries.py @@ -0,0 +1,51 @@ +from sqlalchemy import select, or_, func, text + +from src.db.models.impl.flag.checked_for_ia.sqlalchemy import FlagURLCheckedForInternetArchives +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.internet_archives.probe.sqlalchemy import URLInternetArchivesProbeMetadata +from src.db.models.impl.url.internet_archives.save.sqlalchemy import URLInternetArchivesSaveMetadata +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata + +IA_SAVE_VALID_ENTRIES_QUERY = ( + select( + URL.id, + URL.url, + (URLInternetArchivesSaveMetadata.url_id.is_(None)).label("is_new"), + ) + # URL must have been previously probed for its online status. + .join( + URLWebMetadata, + URL.id == URLWebMetadata.url_id + ) + # URL must have been previously probed for an Internet Archive URL. 
+ .join( + FlagURLCheckedForInternetArchives, + URL.id == FlagURLCheckedForInternetArchives.url_id + ) + + .outerjoin( + URLInternetArchivesProbeMetadata, + URL.id == URLInternetArchivesProbeMetadata.url_id + ) + .outerjoin( + URLInternetArchivesSaveMetadata, + URL.id == URLInternetArchivesSaveMetadata.url_id, + + ) + .where( + # Must not have been archived at all + # OR not have been archived in the last month + or_( + URLInternetArchivesSaveMetadata.url_id.is_(None), + URLInternetArchivesSaveMetadata.last_uploaded_at < func.now() - text("INTERVAL '1 month'") + ), + # Must have returned a 200 status code + URLWebMetadata.status_code == 200 + ) + # Order favoring URLs that have never been archived, and never been probed + .order_by( + URLInternetArchivesProbeMetadata.url_id.is_(None).desc(), + URLInternetArchivesSaveMetadata.url_id.is_(None).desc(), + ) + .limit(100) +) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/internet_archives/save/queries/update.py b/src/core/tasks/scheduled/impl/internet_archives/save/queries/update.py new file mode 100644 index 00000000..dd80d18f --- /dev/null +++ b/src/core/tasks/scheduled/impl/internet_archives/save/queries/update.py @@ -0,0 +1,21 @@ +from sqlalchemy import update, func +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.url.internet_archives.save.sqlalchemy import URLInternetArchivesSaveMetadata +from src.db.queries.base.builder import QueryBuilderBase + + +class UpdateInternetArchivesSaveMetadataQueryBuilder(QueryBuilderBase): + + def __init__(self, url_ids: list[int]): + super().__init__() + self.url_ids = url_ids + + async def run(self, session: AsyncSession) -> None: + stmt = ( + update(URLInternetArchivesSaveMetadata) + .where(URLInternetArchivesSaveMetadata.url_id.in_(self.url_ids)) + .values(last_uploaded_at=func.now()) + ) + await session.execute(stmt) + diff --git a/src/core/tasks/scheduled/impl/mark_never_completed/__init__.py b/src/core/tasks/scheduled/impl/mark_never_completed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/mark_never_completed/operator.py b/src/core/tasks/scheduled/impl/mark_never_completed/operator.py new file mode 100644 index 00000000..7ec08298 --- /dev/null +++ b/src/core/tasks/scheduled/impl/mark_never_completed/operator.py @@ -0,0 +1,15 @@ +from src.core.tasks.scheduled.impl.mark_never_completed.query import MarkTaskNeverCompletedQueryBuilder +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.enums import TaskType + + +class MarkTaskNeverCompletedOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.MARK_TASK_NEVER_COMPLETED + + async def inner_task_logic(self) -> None: + await self.adb_client.run_query_builder( + MarkTaskNeverCompletedQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/mark_never_completed/query.py b/src/core/tasks/scheduled/impl/mark_never_completed/query.py new file mode 100644 index 00000000..1aba3aea --- /dev/null +++ b/src/core/tasks/scheduled/impl/mark_never_completed/query.py @@ -0,0 +1,28 @@ +from datetime import timedelta, datetime +from typing import Any + +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.enums import BatchStatus +from src.db.enums import TaskType +from src.db.models.impl.task.core import Task +from src.db.models.impl.task.enums import TaskStatus +from src.db.queries.base.builder import 
QueryBuilderBase + + +class MarkTaskNeverCompletedQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> Any: + statement = ( + update( + Task + ).values( + task_status=TaskStatus.NEVER_COMPLETED.value + ). + where( + Task.task_status == TaskStatus.IN_PROCESS, + Task.updated_at < datetime.now() - timedelta(hours=1) + ) + ) + await session.execute(statement) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/refresh_materialized_views/__init__.py b/src/core/tasks/scheduled/impl/refresh_materialized_views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/refresh_materialized_views/operator.py b/src/core/tasks/scheduled/impl/refresh_materialized_views/operator.py new file mode 100644 index 00000000..e19feee5 --- /dev/null +++ b/src/core/tasks/scheduled/impl/refresh_materialized_views/operator.py @@ -0,0 +1,12 @@ +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.enums import TaskType + + +class RefreshMaterializedViewsOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.REFRESH_MATERIALIZED_VIEWS + + async def inner_task_logic(self) -> None: + await self.adb_client.refresh_materialized_views() \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/run_url_tasks/__init__.py b/src/core/tasks/scheduled/impl/run_url_tasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/run_url_tasks/operator.py b/src/core/tasks/scheduled/impl/run_url_tasks/operator.py new file mode 100644 index 00000000..ef76fbac --- /dev/null +++ b/src/core/tasks/scheduled/impl/run_url_tasks/operator.py @@ -0,0 +1,17 @@ +from src.core.core import AsyncCore +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.enums import TaskType + + +class RunURLTasksTaskOperator(ScheduledTaskOperatorBase): + + def __init__(self, async_core: AsyncCore): + super().__init__(async_core.adb_client) + self.async_core = async_core + + @property + def task_type(self) -> TaskType: + return TaskType.RUN_URL_TASKS + + async def inner_task_logic(self) -> None: + await self.async_core.run_tasks() \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/task_cleanup/__init__.py b/src/core/tasks/scheduled/impl/task_cleanup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/impl/task_cleanup/operator.py b/src/core/tasks/scheduled/impl/task_cleanup/operator.py new file mode 100644 index 00000000..ea4febcd --- /dev/null +++ b/src/core/tasks/scheduled/impl/task_cleanup/operator.py @@ -0,0 +1,15 @@ +from src.core.tasks.scheduled.impl.task_cleanup.query import TaskCleanupQueryBuilder +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase +from src.db.enums import TaskType + + +class TaskCleanupOperator(ScheduledTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.TASK_CLEANUP + + async def inner_task_logic(self) -> None: + await self.adb_client.run_query_builder( + TaskCleanupQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/scheduled/impl/task_cleanup/query.py b/src/core/tasks/scheduled/impl/task_cleanup/query.py new file mode 100644 index 00000000..b455e1c6 --- /dev/null +++ b/src/core/tasks/scheduled/impl/task_cleanup/query.py @@ -0,0 +1,23 @@ +from datetime import timedelta, datetime +from typing import Any + +from 
sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.task.core import Task +from src.db.queries.base.builder import QueryBuilderBase + + +class TaskCleanupQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> Any: + one_week_ago: datetime = datetime.now() - timedelta(days=7) + + statement = ( + delete(Task) + .where( + Task.updated_at < one_week_ago + ) + ) + + await session.execute(statement) \ No newline at end of file diff --git a/src/core/tasks/scheduled/loader.py b/src/core/tasks/scheduled/loader.py index fb92dcb0..82ac92cc 100644 --- a/src/core/tasks/scheduled/loader.py +++ b/src/core/tasks/scheduled/loader.py @@ -1,5 +1,22 @@ -from src.core.tasks.scheduled.operators.agency_sync.core import SyncAgenciesTaskOperator +from environs import Env + +from src.core.core import AsyncCore +from src.core.tasks.scheduled.enums import IntervalEnum +from src.core.tasks.scheduled.impl.backlog.operator import PopulateBacklogSnapshotTaskOperator +from src.core.tasks.scheduled.impl.delete_logs.operator import DeleteOldLogsTaskOperator +from src.core.tasks.scheduled.impl.delete_stale_screenshots.operator import DeleteStaleScreenshotsTaskOperator +from src.core.tasks.scheduled.impl.huggingface.operator import PushToHuggingFaceTaskOperator +from src.core.tasks.scheduled.impl.internet_archives.probe.operator import InternetArchivesProbeTaskOperator +from src.core.tasks.scheduled.impl.internet_archives.save.operator import InternetArchivesSaveTaskOperator +from src.core.tasks.scheduled.impl.mark_never_completed.operator import MarkTaskNeverCompletedOperator +from src.core.tasks.scheduled.impl.mark_never_completed.query import MarkTaskNeverCompletedQueryBuilder +from src.core.tasks.scheduled.impl.refresh_materialized_views.operator import RefreshMaterializedViewsOperator +from src.core.tasks.scheduled.impl.run_url_tasks.operator import RunURLTasksTaskOperator +from src.core.tasks.scheduled.impl.task_cleanup.operator import TaskCleanupOperator +from src.core.tasks.scheduled.models.entry import ScheduledTaskEntry from src.db.client.async_ import AsyncDatabaseClient +from src.external.huggingface.hub.client import HuggingFaceHubClient +from src.external.internet_archives.client import InternetArchivesClient from src.external.pdap.client import PDAPClient @@ -7,17 +24,96 @@ class ScheduledTaskOperatorLoader: def __init__( self, + async_core: AsyncCore, adb_client: AsyncDatabaseClient, pdap_client: PDAPClient, + hf_client: HuggingFaceHubClient, + ia_client: InternetArchivesClient ): # Dependencies + self.async_core = async_core self.adb_client = adb_client self.pdap_client = pdap_client + # External Interfaces + self.hf_client = hf_client + self.ia_client = ia_client + + self.env = Env() + self.env.read_env() + + def setup_flag(self, name: str) -> bool: + return self.env.bool(name, default=True) + + + async def load_entries(self) -> list[ScheduledTaskEntry]: + scheduled_task_flag = self.env.bool("SCHEDULED_TASKS_FLAG", default=True) + if not scheduled_task_flag: + print("Scheduled tasks are disabled.") + return [] + - async def get_sync_agencies_task_operator(self): - operator = SyncAgenciesTaskOperator( - adb_client=self.adb_client, - pdap_client=self.pdap_client - ) - return operator \ No newline at end of file + return [ + ScheduledTaskEntry( + operator=InternetArchivesProbeTaskOperator( + adb_client=self.adb_client, + ia_client=self.ia_client + ), + interval_minutes=IntervalEnum.TEN_MINUTES.value, + 
enabled=self.setup_flag("IA_PROBE_TASK_FLAG"), + ), + ScheduledTaskEntry( + operator=InternetArchivesSaveTaskOperator( + adb_client=self.adb_client, + ia_client=self.ia_client + ), + interval_minutes=IntervalEnum.TEN_MINUTES.value, + enabled=self.setup_flag("IA_SAVE_TASK_FLAG"), + ), + ScheduledTaskEntry( + operator=DeleteOldLogsTaskOperator(adb_client=self.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("DELETE_OLD_LOGS_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=RunURLTasksTaskOperator(async_core=self.async_core), + interval_minutes=self.env.int( + "URL_TASKS_FREQUENCY_MINUTES", + default=IntervalEnum.HOURLY.value + ), + enabled=self.setup_flag("RUN_URL_TASKS_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=PopulateBacklogSnapshotTaskOperator(adb_client=self.async_core.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("POPULATE_BACKLOG_SNAPSHOT_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=PushToHuggingFaceTaskOperator( + adb_client=self.async_core.adb_client, + hf_client=self.hf_client + ), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("PUSH_TO_HUGGING_FACE_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=MarkTaskNeverCompletedOperator(adb_client=self.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("MARK_TASK_NEVER_COMPLETED_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=DeleteStaleScreenshotsTaskOperator(adb_client=self.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("DELETE_STALE_SCREENSHOTS_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=TaskCleanupOperator(adb_client=self.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("TASK_CLEANUP_TASK_FLAG") + ), + ScheduledTaskEntry( + operator=RefreshMaterializedViewsOperator(adb_client=self.adb_client), + interval_minutes=IntervalEnum.DAILY.value, + enabled=self.setup_flag("REFRESH_MATERIALIZED_VIEWS_TASK_FLAG") + ) + ] diff --git a/src/core/tasks/scheduled/manager.py b/src/core/tasks/scheduled/manager.py index 44576cfa..87cb5a27 100644 --- a/src/core/tasks/scheduled/manager.py +++ b/src/core/tasks/scheduled/manager.py @@ -1,80 +1,72 @@ -from datetime import datetime, timedelta - -from apscheduler.schedulers.asyncio import AsyncIOScheduler -from apscheduler.triggers.interval import IntervalTrigger -from src.core.core import AsyncCore from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.core.tasks.handler import TaskHandler +from src.core.tasks.mixins.link_urls import LinkURLsMixin +from src.core.tasks.mixins.prereq import HasPrerequisitesMixin from src.core.tasks.scheduled.loader import ScheduledTaskOperatorLoader -from src.core.tasks.scheduled.operators.base import ScheduledTaskOperatorBase +from src.core.tasks.scheduled.models.entry import ScheduledTaskEntry +from src.core.tasks.scheduled.registry.core import ScheduledJobRegistry +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase class AsyncScheduledTaskManager: def __init__( self, - async_core: AsyncCore, handler: TaskHandler, - loader: ScheduledTaskOperatorLoader + loader: ScheduledTaskOperatorLoader, + registry: ScheduledJobRegistry ): - # Dependencies - self.async_core = async_core - self.handler = handler - self.loader = loader - # Main objects - self.scheduler = AsyncIOScheduler() + # Dependencies + self._handler = handler + self._loader = loader + self._registry = registry - # Jobs - self.run_cycles_job = None - 
self.delete_logs_job = None - self.populate_backlog_snapshot_job = None - self.sync_agencies_job = None async def setup(self): - self.scheduler.start() + self._registry.start_scheduler() await self.add_scheduled_tasks() + await self._registry.report_next_scheduled_task() + + async def add_scheduled_tasks(self): - self.run_cycles_job = self.scheduler.add_job( - self.async_core.run_tasks, - trigger=IntervalTrigger( - hours=1, - start_date=datetime.now() + timedelta(minutes=1) - ), - misfire_grace_time=60 - ) - self.delete_logs_job = self.scheduler.add_job( - self.async_core.adb_client.delete_old_logs, - trigger=IntervalTrigger( - days=1, - start_date=datetime.now() + timedelta(minutes=10) - ) - ) - self.populate_backlog_snapshot_job = self.scheduler.add_job( - self.async_core.adb_client.populate_backlog_snapshot, - trigger=IntervalTrigger( - days=1, - start_date=datetime.now() + timedelta(minutes=20) + """ + Modifies: + self._registry + """ + entries: list[ScheduledTaskEntry] = await self._loader.load_entries() + enabled_entries: list[ScheduledTaskEntry] = [] + for entry in entries: + if not entry.enabled: + print(f"{entry.operator.task_type.value} is disabled. Skipping add to scheduler.") + continue + enabled_entries.append(entry) + + initial_lag: int = 1 + for idx, entry in enumerate(enabled_entries): + await self._registry.add_job( + func=self.run_task, + entry=entry, + minute_lag=idx + initial_lag ) - ) - self.sync_agencies_job = self.scheduler.add_job( - self.run_task, - trigger=IntervalTrigger( - days=1, - start_date=datetime.now() + timedelta(minutes=2) - ), - kwargs={ - "operator": await self.loader.get_sync_agencies_task_operator() - } - ) def shutdown(self): - if self.scheduler.running: - self.scheduler.shutdown() + self._registry.shutdown_scheduler() async def run_task(self, operator: ScheduledTaskOperatorBase): print(f"Running {operator.task_type.value} Task") - task_id = await self.handler.initiate_task_in_db(task_type=operator.task_type) - run_info: TaskOperatorRunInfo = await operator.run_task(task_id) - await self.handler.handle_outcome(run_info) + if issubclass(operator.__class__, HasPrerequisitesMixin): + operator: HasPrerequisitesMixin + if not await operator.meets_task_prerequisites(): + operator: ScheduledTaskOperatorBase + print(f"Prerequisites not met for {operator.task_type.value} Task. 
Skipping.") + return + run_info: TaskOperatorRunInfo = await operator.run_task() + if issubclass(operator.__class__, LinkURLsMixin): + operator: LinkURLsMixin + if not operator.urls_linked: + operator: ScheduledTaskOperatorBase + raise Exception(f"Task {operator.task_type.value} has not been linked to any URLs but is designated as a link task") + await self._handler.handle_outcome(run_info) + await self._registry.report_next_scheduled_task() diff --git a/src/core/tasks/scheduled/models/__init__.py b/src/core/tasks/scheduled/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/models/entry.py b/src/core/tasks/scheduled/models/entry.py new file mode 100644 index 00000000..32abb913 --- /dev/null +++ b/src/core/tasks/scheduled/models/entry.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + +from src.core.tasks.scheduled.enums import IntervalEnum +from src.core.tasks.scheduled.templates.operator import ScheduledTaskOperatorBase + + +class ScheduledTaskEntry(BaseModel): + + class Config: + arbitrary_types_allowed = True + + operator: ScheduledTaskOperatorBase + interval_minutes: int + enabled: bool diff --git a/src/core/tasks/scheduled/operators/agency_sync/constants.py b/src/core/tasks/scheduled/operators/agency_sync/constants.py deleted file mode 100644 index a58a7aca..00000000 --- a/src/core/tasks/scheduled/operators/agency_sync/constants.py +++ /dev/null @@ -1,7 +0,0 @@ - - -""" -Denotes the maximum number of requests to the Agencies Sync endpoint -permissible in a single task run. -""" -MAX_SYNC_REQUESTS = 30 \ No newline at end of file diff --git a/src/core/tasks/scheduled/operators/agency_sync/core.py b/src/core/tasks/scheduled/operators/agency_sync/core.py deleted file mode 100644 index c522effd..00000000 --- a/src/core/tasks/scheduled/operators/agency_sync/core.py +++ /dev/null @@ -1,48 +0,0 @@ -from src.core.tasks.scheduled.operators.agency_sync.constants import MAX_SYNC_REQUESTS -from src.core.tasks.scheduled.operators.agency_sync.dtos.parameters import AgencySyncParameters -from src.core.tasks.scheduled.operators.agency_sync.exceptions import MaxRequestsExceededError -from src.core.tasks.scheduled.operators.base import ScheduledTaskOperatorBase -from src.db.client.async_ import AsyncDatabaseClient -from src.db.enums import TaskType -from src.external.pdap.client import PDAPClient - - -class SyncAgenciesTaskOperator(ScheduledTaskOperatorBase): - - def __init__( - self, - adb_client: AsyncDatabaseClient, - pdap_client: PDAPClient - ): - super().__init__(adb_client) - self.pdap_client = pdap_client - - @property - def task_type(self) -> TaskType: # - return TaskType.SYNC_AGENCIES - - async def inner_task_logic(self): - params = await self.adb_client.get_agencies_sync_parameters() - if params.page is None: - params.page = 1 - - response = await self.pdap_client.sync_agencies(params) - request_count = 1 - while len(response.agencies) > 0: - if request_count > MAX_SYNC_REQUESTS: - raise MaxRequestsExceededError( - f"Max requests in a single task run ({MAX_SYNC_REQUESTS}) exceeded." 
- ) - await self.adb_client.upsert_agencies(response.agencies) - - params = AgencySyncParameters( - page=params.page + 1, - cutoff_date=params.cutoff_date - ) - await self.adb_client.update_agencies_sync_progress(params.page) - - response = await self.pdap_client.sync_agencies(params) - request_count += 1 - - await self.adb_client.mark_full_agencies_sync() - diff --git a/src/core/tasks/scheduled/operators/agency_sync/dtos/parameters.py b/src/core/tasks/scheduled/operators/agency_sync/dtos/parameters.py deleted file mode 100644 index 3d8cceb4..00000000 --- a/src/core/tasks/scheduled/operators/agency_sync/dtos/parameters.py +++ /dev/null @@ -1,9 +0,0 @@ -from datetime import date -from typing import Optional - -from pydantic import BaseModel - - -class AgencySyncParameters(BaseModel): - cutoff_date: Optional[date] - page: Optional[int] diff --git a/src/core/tasks/scheduled/operators/agency_sync/exceptions.py b/src/core/tasks/scheduled/operators/agency_sync/exceptions.py deleted file mode 100644 index 0af9937f..00000000 --- a/src/core/tasks/scheduled/operators/agency_sync/exceptions.py +++ /dev/null @@ -1,5 +0,0 @@ - - - -class MaxRequestsExceededError(Exception): - pass \ No newline at end of file diff --git a/src/core/tasks/scheduled/registry/__init__.py b/src/core/tasks/scheduled/registry/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/registry/core.py b/src/core/tasks/scheduled/registry/core.py new file mode 100644 index 00000000..e9fc205b --- /dev/null +++ b/src/core/tasks/scheduled/registry/core.py @@ -0,0 +1,69 @@ +from datetime import datetime, timedelta +from typing import Callable + +from apscheduler.job import Job +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from apscheduler.triggers.interval import IntervalTrigger + +from src.core.tasks.scheduled.models.entry import ScheduledTaskEntry +from src.core.tasks.scheduled.registry.format import format_job_datetime +from src.db.enums import TaskType + + +class ScheduledJobRegistry: + + + def __init__(self): + # Main objects + self.scheduler = AsyncIOScheduler() + + # Jobs + self._jobs: dict[TaskType, Job] = {} + + async def add_job( + self, + func: Callable, + entry: ScheduledTaskEntry, + minute_lag: int + ) -> None: + """ + Modifies: + self._jobs + """ + job: Job = self.scheduler.add_job( + id=entry.operator.task_type.value, + func=func, + trigger=IntervalTrigger( + minutes=entry.interval_minutes, + start_date=datetime.now() + timedelta(minutes=minute_lag) + ), + misfire_grace_time=60, + kwargs={"operator": entry.operator} + ) + run_time_str: str = format_job_datetime(job.next_run_time) + print(f"Adding {job.id} task to scheduler. 
" + + f"First run at {run_time_str}") + self._jobs[entry.operator.task_type] = job + + def start_scheduler(self) -> None: + """ + Modifies: + self.scheduler + """ + self.scheduler.start() + + def shutdown_scheduler(self) -> None: + if self.scheduler.running: + self.scheduler.shutdown() + + async def report_next_scheduled_task(self): + jobs: list[Job] = self.scheduler.get_jobs() + if len(jobs) == 0: + print("No scheduled tasks found.") + return + + jobs_sorted: list[Job] = sorted(jobs, key=lambda job: job.next_run_time) + next_job: Job = jobs_sorted[0] + + run_time_str: str = format_job_datetime(next_job.next_run_time) + print(f"Next scheduled task: {run_time_str} ({next_job.id})") \ No newline at end of file diff --git a/src/core/tasks/scheduled/registry/format.py b/src/core/tasks/scheduled/registry/format.py new file mode 100644 index 00000000..23eea364 --- /dev/null +++ b/src/core/tasks/scheduled/registry/format.py @@ -0,0 +1,7 @@ +from datetime import datetime + +def format_job_datetime(dt: datetime) -> str: + date_str: str = dt.strftime("%Y-%m-%d") + format_24: str = dt.strftime("%H:%M:%S") + format_12: str = dt.strftime("%I:%M:%S %p") + return f"{date_str} {format_24} ({format_12})" \ No newline at end of file diff --git a/src/core/tasks/scheduled/templates/__init__.py b/src/core/tasks/scheduled/templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/scheduled/operators/base.py b/src/core/tasks/scheduled/templates/operator.py similarity index 100% rename from src/core/tasks/scheduled/operators/base.py rename to src/core/tasks/scheduled/templates/operator.py diff --git a/src/core/tasks/url/loader.py b/src/core/tasks/url/loader.py index 99997e3f..b5910f5e 100644 --- a/src/core/tasks/url/loader.py +++ b/src/core/tasks/url/loader.py @@ -2,22 +2,33 @@ The task loader loads task a task operator and all dependencies. 
""" -from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface +from environs import Env + +from src.collectors.impl.muckrock.api_interface.core import MuckrockAPIInterface +from src.core.tasks.url.models.entry import URLTaskEntry from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator +from src.core.tasks.url.operators.agency_identification.subtasks.loader import AgencyIdentificationSubtaskLoader +from src.core.tasks.url.operators.auto_name.core import AutoNameURLTaskOperator from src.core.tasks.url.operators.auto_relevant.core import URLAutoRelevantTaskOperator -from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.html.core import URLHTMLTaskOperator +from src.core.tasks.url.operators.html.scraper.parser.core import HTMLResponseParser +from src.core.tasks.url.operators.location_id.core import LocationIdentificationTaskOperator +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.loader import LocationIdentificationSubtaskLoader +from src.core.tasks.url.operators.misc_metadata.core import URLMiscellaneousMetadataTaskOperator +from src.core.tasks.url.operators.probe.core import URLProbeTaskOperator from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator from src.core.tasks.url.operators.record_type.llm_api.record_classifier.openai import OpenAIRecordClassifier -from src.core.tasks.url.operators.submit_approved_url.core import SubmitApprovedURLTaskOperator -from src.core.tasks.url.operators.url_404_probe.core import URL404ProbeTaskOperator -from src.core.tasks.url.operators.url_duplicate.core import URLDuplicateTaskOperator -from src.core.tasks.url.operators.url_html.core import URLHTMLTaskOperator -from src.core.tasks.url.operators.url_html.scraper.parser.core import HTMLResponseParser -from src.core.tasks.url.operators.url_html.scraper.request_interface.core import URLRequestInterface -from src.core.tasks.url.operators.url_miscellaneous_metadata.core import URLMiscellaneousMetadataTaskOperator +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.core.tasks.url.operators.screenshot.core import URLScreenshotTaskOperator +from src.core.tasks.url.operators.submit_approved.core import SubmitApprovedURLTaskOperator +from src.core.tasks.url.operators.submit_meta_urls.core import SubmitMetaURLsTaskOperator +from src.core.tasks.url.operators.suspend.core import SuspendURLTaskOperator +from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator from src.db.client.async_ import AsyncDatabaseClient from src.external.huggingface.inference.client import HuggingFaceInferenceClient from src.external.pdap.client import PDAPClient +from src.external.url_request.core import URLRequestInterface class URLTaskOperatorLoader: @@ -29,83 +40,185 @@ def __init__( html_parser: HTMLResponseParser, pdap_client: PDAPClient, muckrock_api_interface: MuckrockAPIInterface, - hf_inference_client: HuggingFaceInferenceClient + hf_inference_client: HuggingFaceInferenceClient, + nlp_processor: NLPProcessor ): # Dependencies self.adb_client = adb_client self.url_request_interface = url_request_interface self.html_parser = html_parser + self.nlp_processor = nlp_processor + self.env = Env() # External clients and interfaces self.pdap_client = pdap_client self.muckrock_api_interface = 
muckrock_api_interface self.hf_inference_client = hf_inference_client - async def get_url_html_task_operator(self): + def setup_flag(self, name: str) -> bool: + return self.env.bool( + name, + default=True + ) + + def _get_url_html_task_operator(self) -> URLTaskEntry: operator = URLHTMLTaskOperator( adb_client=self.adb_client, url_request_interface=self.url_request_interface, html_parser=self.html_parser ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_HTML_TASK_FLAG") + ) - async def get_url_record_type_task_operator(self): + def _get_url_record_type_task_operator(self) -> URLTaskEntry: operator = URLRecordTypeTaskOperator( adb_client=self.adb_client, classifier=OpenAIRecordClassifier() ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_RECORD_TYPE_TASK_FLAG") + ) - async def get_agency_identification_task_operator(self): + def _get_agency_identification_task_operator(self) -> URLTaskEntry: operator = AgencyIdentificationTaskOperator( adb_client=self.adb_client, - pdap_client=self.pdap_client, - muckrock_api_interface=self.muckrock_api_interface + loader=AgencyIdentificationSubtaskLoader( + pdap_client=self.pdap_client, + muckrock_api_interface=self.muckrock_api_interface, + adb_client=self.adb_client, + ) + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_AGENCY_IDENTIFICATION_TASK_FLAG") ) - return operator - async def get_submit_approved_url_task_operator(self): + def _get_submit_approved_url_task_operator(self) -> URLTaskEntry: operator = SubmitApprovedURLTaskOperator( adb_client=self.adb_client, pdap_client=self.pdap_client ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_SUBMIT_APPROVED_TASK_FLAG") + ) + + def _get_submit_meta_urls_task_operator(self) -> URLTaskEntry: + operator = SubmitMetaURLsTaskOperator( + adb_client=self.adb_client, + pdap_client=self.pdap_client + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_SUBMIT_META_URLS_TASK_FLAG") + ) - async def get_url_miscellaneous_metadata_task_operator(self): + def _get_url_miscellaneous_metadata_task_operator(self) -> URLTaskEntry: operator = URLMiscellaneousMetadataTaskOperator( adb_client=self.adb_client ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_MISC_METADATA_TASK_FLAG") + ) + - async def get_url_duplicate_task_operator(self): - operator = URLDuplicateTaskOperator( + def _get_url_auto_relevance_task_operator(self) -> URLTaskEntry: + operator = URLAutoRelevantTaskOperator( adb_client=self.adb_client, - pdap_client=self.pdap_client + hf_client=self.hf_inference_client + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_AUTO_RELEVANCE_TASK_FLAG") ) - return operator - async def get_url_404_probe_task_operator(self): - operator = URL404ProbeTaskOperator( + def _get_url_probe_task_operator(self) -> URLTaskEntry: + operator = URLProbeTaskOperator( adb_client=self.adb_client, url_request_interface=self.url_request_interface ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_PROBE_TASK_FLAG") + ) - async def get_url_auto_relevance_task_operator(self): - operator = URLAutoRelevantTaskOperator( + def _get_url_root_url_task_operator(self) -> URLTaskEntry: + operator = URLRootURLTaskOperator( + adb_client=self.adb_client + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_ROOT_URL_TASK_FLAG") + ) + + 
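# --- Editorial note (not part of this diff) ----------------------------------
# Each loader method follows the same shape: build the operator, then wrap it in
# a URLTaskEntry gated by an env flag that defaults to enabled. A hypothetical
# new task (names below are illustrative, not part of this change) would be
# wired as another method on URLTaskOperatorLoader:
def _get_my_new_task_operator(self) -> URLTaskEntry:
    operator = MyNewTaskOperator(adb_client=self.adb_client)  # hypothetical operator
    return URLTaskEntry(
        operator=operator,
        enabled=self.setup_flag("MY_NEW_TASK_FLAG")  # hypothetical flag name
    )
# ...and registered by appending self._get_my_new_task_operator() to the list
# returned by load_entries() below.
# ------------------------------------------------------------------------------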
def _get_url_screenshot_task_operator(self) -> URLTaskEntry: + operator = URLScreenshotTaskOperator( adb_client=self.adb_client, - hf_client=self.hf_inference_client ) - return operator + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_SCREENSHOT_TASK_FLAG") + ) + + def _get_location_id_task_operator(self) -> URLTaskEntry: + operator = LocationIdentificationTaskOperator( + adb_client=self.adb_client, + loader=LocationIdentificationSubtaskLoader( + adb_client=self.adb_client, + nlp_processor=self.nlp_processor + ) + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_LOCATION_IDENTIFICATION_TASK_FLAG") + ) + + def _get_auto_validate_task_operator(self) -> URLTaskEntry: + operator = AutoValidateURLTaskOperator( + adb_client=self.adb_client + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_AUTO_VALIDATE_TASK_FLAG") + ) + + def _get_auto_name_task_operator(self) -> URLTaskEntry: + operator = AutoNameURLTaskOperator( + adb_client=self.adb_client, + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_AUTO_NAME_TASK_FLAG") + ) + + def _get_suspend_url_task_operator(self) -> URLTaskEntry: + operator = SuspendURLTaskOperator( + adb_client=self.adb_client + ) + return URLTaskEntry( + operator=operator, + enabled=self.setup_flag("URL_SUSPEND_TASK_FLAG") + ) + - async def get_task_operators(self) -> list[URLTaskOperatorBase]: + async def load_entries(self) -> list[URLTaskEntry]: return [ - await self.get_url_html_task_operator(), - await self.get_url_duplicate_task_operator(), - await self.get_url_404_probe_task_operator(), - await self.get_url_record_type_task_operator(), - await self.get_agency_identification_task_operator(), - await self.get_url_miscellaneous_metadata_task_operator(), - await self.get_submit_approved_url_task_operator(), - await self.get_url_auto_relevance_task_operator() + self._get_url_root_url_task_operator(), + self._get_url_probe_task_operator(), + self._get_url_html_task_operator(), + self._get_url_record_type_task_operator(), + self._get_agency_identification_task_operator(), + self._get_url_miscellaneous_metadata_task_operator(), + self._get_submit_approved_url_task_operator(), + self._get_submit_meta_urls_task_operator(), + self._get_url_auto_relevance_task_operator(), + self._get_url_screenshot_task_operator(), + self._get_location_id_task_operator(), + self._get_auto_validate_task_operator(), + self._get_auto_name_task_operator(), + self._get_suspend_url_task_operator(), ] diff --git a/src/core/tasks/url/manager.py b/src/core/tasks/url/manager.py index 1d843b95..7fc6b4e3 100644 --- a/src/core/tasks/url/manager.py +++ b/src/core/tasks/url/manager.py @@ -1,9 +1,10 @@ import logging +from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.core.tasks.handler import TaskHandler from src.core.tasks.url.loader import URLTaskOperatorLoader +from src.core.tasks.url.models.entry import URLTaskEntry from src.db.enums import TaskType -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.function_trigger import FunctionTrigger @@ -28,41 +29,45 @@ def __init__( #region Tasks - async def set_manager_status(self, task_type: TaskType): + async def set_manager_status(self, task_type: TaskType) -> None: + """ + Modifies: + self.manager_status + """ self.manager_status = task_type - async def run_tasks(self): - operators = await self.loader.get_task_operators() - for operator in operators: - 
count = 0 - await self.set_manager_status(task_type=operator.task_type) + async def run_tasks(self) -> None: + entries: list[URLTaskEntry] = await self.loader.load_entries() + for entry in entries: + if not entry.enabled: + continue + await self._run_task(entry) + await self.set_manager_status(task_type=TaskType.IDLE) + async def _run_task(self, entry: URLTaskEntry) -> None: + operator = entry.operator + count = 0 + await self.set_manager_status(task_type=operator.task_type) + meets_prereq = await operator.meets_task_prerequisites() + while meets_prereq: + print(f"Running {operator.task_type.value} Task") + if count > TASK_REPEAT_THRESHOLD: + message = f"Task {operator.task_type.value} has been run more than {TASK_REPEAT_THRESHOLD} times in a row. Task loop terminated." + print(message) + await self.handler.post_to_discord(message=message) + break + run_info: TaskOperatorRunInfo = await operator.run_task() + await self.conclude_task(run_info) + if run_info.outcome == TaskOperatorOutcome.ERROR: + break + count += 1 meets_prereq = await operator.meets_task_prerequisites() - while meets_prereq: - print(f"Running {operator.task_type.value} Task") - if count > TASK_REPEAT_THRESHOLD: - message = f"Task {operator.task_type.value} has been run more than {TASK_REPEAT_THRESHOLD} times in a row. Task loop terminated." - print(message) - await self.handler.post_to_discord(message=message) - break - task_id = await self.handler.initiate_task_in_db(task_type=operator.task_type) - run_info: URLTaskOperatorRunInfo = await operator.run_task(task_id) - await self.conclude_task(run_info) - if run_info.outcome == TaskOperatorOutcome.ERROR: - break - count += 1 - meets_prereq = await operator.meets_task_prerequisites() - await self.set_manager_status(task_type=TaskType.IDLE) - async def trigger_task_run(self): + async def trigger_task_run(self) -> None: await self.task_trigger.trigger_or_rerun() - async def conclude_task(self, run_info: URLTaskOperatorRunInfo): - await self.handler.link_urls_to_task( - task_id=run_info.task_id, - url_ids=run_info.linked_url_ids - ) + async def conclude_task(self, run_info: TaskOperatorRunInfo) -> None: await self.handler.handle_outcome(run_info) diff --git a/src/core/tasks/url/models/__init__.py b/src/core/tasks/url/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/models/entry.py b/src/core/tasks/url/models/entry.py new file mode 100644 index 00000000..eeb09047 --- /dev/null +++ b/src/core/tasks/url/models/entry.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.base import URLTaskOperatorBase + + +class URLTaskEntry(BaseModel): + + class Config: + arbitrary_types_allowed = True + + operator: URLTaskOperatorBase + enabled: bool \ No newline at end of file diff --git a/src/core/tasks/url/operators/_shared/__init__.py b/src/core/tasks/url/operators/_shared/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/_shared/container/__init__.py b/src/core/tasks/url/operators/_shared/container/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/_shared/container/subtask/__init__.py b/src/core/tasks/url/operators/_shared/container/subtask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/_shared/container/subtask/eligible.py b/src/core/tasks/url/operators/_shared/container/subtask/eligible.py new file mode 100644 index 00000000..989b509f --- /dev/null +++ 
b/src/core/tasks/url/operators/_shared/container/subtask/eligible.py
@@ -0,0 +1,40 @@
+from sqlalchemy import CTE, Column, ColumnElement, exists
+
+from src.db.models.impl.url.core.sqlalchemy import URL
+
+
+class URLsSubtaskEligibleCTEContainer:
+    """
+    CTE for URLs eligible for a given subtask.
+    A successful left join on this indicates the URL is eligible for the subtask.
+    A true value for `subtask_entry_exists` indicates that
+    a subtask entry for the URL already exists.
+    """
+
+    def __init__(
+        self,
+        cte: CTE,
+    ) -> None:
+        self._cte = cte
+
+    @property
+    def cte(self) -> CTE:
+        return self._cte
+
+    @property
+    def entry_exists(self) -> ColumnElement[bool]:
+        return self.cte.c['subtask_entry_exists']
+
+    @property
+    def url_id(self) -> Column[int]:
+        return self.cte.c['id']
+
+    @property
+    def eligible_query(self) -> ColumnElement[bool]:
+        return (
+            exists()
+            .where(
+                self.url_id == URL.id,
+                self.entry_exists.is_(False),
+            )
+        )
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/_shared/container/subtask/exists.py b/src/core/tasks/url/operators/_shared/container/subtask/exists.py
new file mode 100644
index 00000000..f10956d3
--- /dev/null
+++ b/src/core/tasks/url/operators/_shared/container/subtask/exists.py
@@ -0,0 +1,33 @@
+from sqlalchemy import CTE, Column, ColumnElement, exists
+
+from src.db.models.impl.url.core.sqlalchemy import URL
+
+
+class URLsSubtaskExistsCTEContainer:
+    """
+    Base class for CTEs that determine validity for each subtask.
+
+    Single-column CTEs intended to be left-joined and considered valid only
+    if the joined row is not null.
+    """
+
+    def __init__(
+        self,
+        cte: CTE,
+    ) -> None:
+        self._cte = cte
+
+    @property
+    def cte(self) -> CTE:
+        return self._cte
+
+    @property
+    def url_id(self) -> Column[int]:
+        return self.cte.columns[0]
+
+    @property
+    def not_exists_query(self) -> ColumnElement[bool]:
+        return (
+            ~exists()
+            .where(self.url_id == URL.id)
+        )
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/_shared/ctes/__init__.py b/src/core/tasks/url/operators/_shared/ctes/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/_shared/ctes/validated.py b/src/core/tasks/url/operators/_shared/ctes/validated.py
new file mode 100644
index 00000000..43f6a6ba
--- /dev/null
+++ b/src/core/tasks/url/operators/_shared/ctes/validated.py
@@ -0,0 +1,16 @@
+from sqlalchemy import select
+
+from src.core.tasks.url.operators._shared.container.subtask.exists import \
+    URLsSubtaskExistsCTEContainer
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+
+cte = (
+    select(
+        FlagURLValidated.url_id
+    )
+    .cte("validated_exists")
+)
+
+VALIDATED_EXISTS_CONTAINER = URLsSubtaskExistsCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/_shared/exceptions.py b/src/core/tasks/url/operators/_shared/exceptions.py
new file mode 100644
index 00000000..709189e3
--- /dev/null
+++ b/src/core/tasks/url/operators/_shared/exceptions.py
@@ -0,0 +1,4 @@
+
+
+class SubtaskError(Exception):
+    pass
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/core.py b/src/core/tasks/url/operators/agency_identification/core.py
index d93143aa..7657ea0e 100644
--- a/src/core/tasks/url/operators/agency_identification/core.py
+++ b/src/core/tasks/url/operators/agency_identification/core.py
@@ -1,100 +1,68 @@
-from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface
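
Before the agency_identification rewrite continues below, it may help to see how the `_shared` CTE containers compose. A toy sketch of the `not_exists_query` guard from `URLsSubtaskExistsCTEContainer`, using illustrative table and column names rather than the project's real models:

```python
# Toy version of the URLsSubtaskExistsCTEContainer pattern: a single-column
# CTE of url_ids plus a ~exists() guard selecting URLs with no entry yet.
# Table and column names here are illustrative, not the project's models.
from sqlalchemy import Column, Integer, String, exists, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class URL(Base):
    __tablename__ = "url"
    id = Column(Integer, primary_key=True)
    url = Column(String)


class URLScreenshot(Base):
    __tablename__ = "url_screenshot"
    id = Column(Integer, primary_key=True)
    url_id = Column(Integer)


# Single-column CTE: URLs that already have a screenshot entry.
screenshot_exists_cte = select(URLScreenshot.url_id).cte("screenshot_exists")

# The container's `not_exists_query`: true when no entry row matches the URL.
not_exists_guard = ~exists().where(screenshot_exists_cte.c.url_id == URL.id)

# URLs still eligible for the subtask.
eligible_urls = select(URL.id).where(not_exists_guard)
print(eligible_urls)  # renders the SELECT ... WHERE NOT (EXISTS ...) statement
```
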
-from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.core.tasks.url.operators.agency_identification.dtos.tdo import AgencyIdentificationTDO +from src.core.tasks.mixins.link_urls import LinkURLsMixin +from src.core.tasks.url.operators._shared.exceptions import SubtaskError +from src.core.tasks.url.operators.agency_identification.subtasks.flags.core import SubtaskFlagger +from src.core.tasks.url.operators.agency_identification.subtasks.loader import AgencyIdentificationSubtaskLoader +from src.core.tasks.url.operators.agency_identification.subtasks.models.run_info import AgencyIDSubtaskRunInfo +from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.core import \ + AgencyIDSubtaskSurveyQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase +from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo from src.db.enums import TaskType -from src.collectors.enums import CollectorType -from src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.core.tasks.url.subtasks.agency_identification.auto_googler import AutoGooglerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.ckan import CKANAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.common_crawler import CommonCrawlerAgencyIdentificationSubtask -from src.core.tasks.url.subtasks.agency_identification.muckrock import MuckrockAgencyIdentificationSubtask -from src.core.enums import SuggestionType -from src.external.pdap.client import PDAPClient +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType -# TODO: Validate with Manual Tests - -class AgencyIdentificationTaskOperator(URLTaskOperatorBase): +class AgencyIdentificationTaskOperator( + URLTaskOperatorBase, + LinkURLsMixin +): def __init__( self, adb_client: AsyncDatabaseClient, - pdap_client: PDAPClient, - muckrock_api_interface: MuckrockAPIInterface, + loader: AgencyIdentificationSubtaskLoader, ): super().__init__(adb_client) - self.pdap_client = pdap_client - self.muckrock_api_interface = muckrock_api_interface + self.loader = loader + self._subtask: AutoAgencyIDSubtaskType | None = None @property - def task_type(self): + def task_type(self) -> TaskType: return TaskType.AGENCY_IDENTIFICATION - async def meets_task_prerequisites(self): - has_urls_without_agency_suggestions = await self.adb_client.has_urls_without_agency_suggestions() - return has_urls_without_agency_suggestions - - async def get_pending_urls_without_agency_identification(self): - return await self.adb_client.get_urls_without_agency_suggestions() + async def meets_task_prerequisites(self) -> bool: + """ + Modifies: + - self._subtask + """ + flagger = SubtaskFlagger() + allowed_subtasks: list[AutoAgencyIDSubtaskType] = flagger.get_allowed_subtasks() - async def get_muckrock_subtask(self): - return MuckrockAgencyIdentificationSubtask( - muckrock_api_interface=self.muckrock_api_interface, - pdap_client=self.pdap_client - ) - - async def get_subtask(self, collector_type: CollectorType): - match collector_type: - case CollectorType.MUCKROCK_SIMPLE_SEARCH: - return await self.get_muckrock_subtask() - case CollectorType.MUCKROCK_COUNTY_SEARCH: - return await self.get_muckrock_subtask() - case CollectorType.MUCKROCK_ALL_SEARCH: - return 
await self.get_muckrock_subtask() - case CollectorType.AUTO_GOOGLER: - return AutoGooglerAgencyIdentificationSubtask() - case CollectorType.COMMON_CRAWLER: - return CommonCrawlerAgencyIdentificationSubtask() - case CollectorType.CKAN: - return CKANAgencyIdentificationSubtask( - pdap_client=self.pdap_client + next_subtask: AutoAgencyIDSubtaskType | None = \ + await self.adb_client.run_query_builder( + AgencyIDSubtaskSurveyQueryBuilder( + allowed_subtasks=allowed_subtasks ) - return None + ) + self._subtask = next_subtask + if next_subtask is None: + return False + return True - @staticmethod - async def run_subtask(subtask, url_id, collector_metadata) -> list[URLAgencySuggestionInfo]: - return await subtask.run(url_id=url_id, collector_metadata=collector_metadata) - async def inner_task_logic(self): - tdos: list[AgencyIdentificationTDO] = await self.get_pending_urls_without_agency_identification() - await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) - error_infos = [] - all_agency_suggestions = [] - for tdo in tdos: - subtask = await self.get_subtask(tdo.collector_type) - try: - new_agency_suggestions = await self.run_subtask( - subtask, - tdo.url_id, - tdo.collector_metadata - ) - all_agency_suggestions.extend(new_agency_suggestions) - except Exception as e: - error_info = URLErrorPydanticInfo( - task_id=self.task_id, - url_id=tdo.url_id, - error=str(e), - ) - error_infos.append(error_info) + async def load_subtask( + self, + subtask_type: AutoAgencyIDSubtaskType + ) -> AgencyIDSubtaskOperatorBase: + """Get subtask based on collector type.""" + return await self.loader.load_subtask(subtask_type, task_id=self.task_id) - non_unknown_agency_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type != SuggestionType.UNKNOWN] - await self.adb_client.upsert_new_agencies(non_unknown_agency_suggestions) - confirmed_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type == SuggestionType.CONFIRMED] - await self.adb_client.add_confirmed_agency_url_links(confirmed_suggestions) - non_confirmed_suggestions = [suggestion for suggestion in all_agency_suggestions if suggestion.suggestion_type != SuggestionType.CONFIRMED] - await self.adb_client.add_agency_auto_suggestions(non_confirmed_suggestions) - await self.adb_client.add_url_error_infos(error_infos) + async def inner_task_logic(self) -> None: + subtask_operator: AgencyIDSubtaskOperatorBase = await self.load_subtask(self._subtask) + print(f"Running Subtask: {self._subtask.value}") + run_info: AgencyIDSubtaskRunInfo = await subtask_operator.run() + await self.link_urls_to_task(run_info.linked_url_ids) + if not run_info.is_success: + raise SubtaskError(run_info.error) diff --git a/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py b/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py index c0ea08f4..39f2cab3 100644 --- a/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py +++ b/src/core/tasks/url/operators/agency_identification/dtos/suggestion.py @@ -7,10 +7,10 @@ class URLAgencySuggestionInfo(BaseModel): url_id: int - suggestion_type: SuggestionType - pdap_agency_id: Optional[int] = None - agency_name: Optional[str] = None - state: Optional[str] = None - county: Optional[str] = None - locality: Optional[str] = None - user_id: Optional[int] = None + suggestion_type: SuggestionType = SuggestionType.UNKNOWN + pdap_agency_id: int | None = None + agency_name: str | None = None + state: str | None = None + 
county: str | None = None + locality: str | None = None + user_id: int | None = None diff --git a/src/core/tasks/url/operators/agency_identification/dtos/tdo.py b/src/core/tasks/url/operators/agency_identification/dtos/tdo.py deleted file mode 100644 index 70ff1ae5..00000000 --- a/src/core/tasks/url/operators/agency_identification/dtos/tdo.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from src.collectors.enums import CollectorType - - -class AgencyIdentificationTDO(BaseModel): - url_id: int - collector_metadata: Optional[dict] = None - collector_type: CollectorType diff --git a/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py b/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py deleted file mode 100644 index 27459145..00000000 --- a/src/core/tasks/url/operators/agency_identification/queries/get_pending_urls_without_agency_suggestions.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import Any - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from src.collectors.enums import URLStatus, CollectorType -from src.core.tasks.url.operators.agency_identification.dtos.tdo import AgencyIdentificationTDO -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.queries.base.builder import QueryBuilderBase -from src.db.statement_composer import StatementComposer - - -class GetPendingURLsWithoutAgencySuggestionsQueryBuilder(QueryBuilderBase): - - async def run(self, session: AsyncSession) -> list[AgencyIdentificationTDO]: - - statement = ( - select(URL.id, URL.collector_metadata, Batch.strategy) - .select_from(URL) - .where(URL.outcome == URLStatus.PENDING.value) - .join(LinkBatchURL) - .join(Batch) - ) - statement = StatementComposer.exclude_urls_with_agency_suggestions(statement) - statement = statement.limit(100) - raw_results = await session.execute(statement) - return [ - AgencyIdentificationTDO( - url_id=raw_result[0], - collector_metadata=raw_result[1], - collector_type=CollectorType(raw_result[2]) - ) - for raw_result in raw_results - ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/convert.py b/src/core/tasks/url/operators/agency_identification/subtasks/convert.py new file mode 100644 index 00000000..95c9e704 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/convert.py @@ -0,0 +1,54 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.models.suggestion import AgencySuggestion +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic +from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo +from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse +from src.external.pdap.enums import MatchAgencyResponseStatus + +def convert_match_agency_response_to_subtask_data( + url_id: int, + response: 
MatchAgencyResponse,
+    subtask_type: AutoAgencyIDSubtaskType,
+    task_id: int
+) -> AutoAgencyIDSubtaskData:
+    suggestions: list[AgencySuggestion] = \
+        _convert_match_agency_response_to_suggestions(
+            response
+        )
+    agencies_found: bool = len(suggestions) > 0
+    subtask_pydantic = URLAutoAgencyIDSubtaskPydantic(
+        url_id=url_id,
+        type=subtask_type,
+        agencies_found=agencies_found,
+        task_id=task_id
+    )
+    return AutoAgencyIDSubtaskData(
+        pydantic_model=subtask_pydantic,
+        suggestions=suggestions
+    )
+
+def _convert_match_agency_response_to_suggestions(
+    match_response: MatchAgencyResponse,
+) -> list[AgencySuggestion]:
+    if match_response.status == MatchAgencyResponseStatus.EXACT_MATCH:
+        match_info: MatchAgencyInfo = match_response.matches[0]
+        return [
+            AgencySuggestion(
+                agency_id=int(match_info.id),
+                confidence=100
+            )
+        ]
+    if match_response.status == MatchAgencyResponseStatus.NO_MATCH:
+        return []
+    if match_response.status != MatchAgencyResponseStatus.PARTIAL_MATCH:
+        raise ValueError(f"Unknown Match Agency Response Status: {match_response.status}")
+    total_confidence: int = 100
+    confidence_per_match: int = total_confidence // len(match_response.matches)
+    return [
+        AgencySuggestion(
+            agency_id=int(match_info.id),
+            confidence=confidence_per_match
+        )
+        for match_info in match_response.matches
+    ]
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/flags/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/flags/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/flags/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/flags/core.py
new file mode 100644
index 00000000..41997322
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/flags/core.py
@@ -0,0 +1,26 @@
+
+from environs import Env
+
+from src.core.tasks.url.operators.agency_identification.subtasks.flags.mappings import SUBTASK_TO_ENV_FLAG
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+
+class SubtaskFlagger:
+    """
+    Manages flags allowing and disallowing subtasks.
+    """
+    def __init__(self):
+        self.env = Env()
+
+    def _get_subtask_flag(self, subtask_type: AutoAgencyIDSubtaskType) -> bool:
+        return self.env.bool(
+            SUBTASK_TO_ENV_FLAG[subtask_type],
+            default=True
+        )
+
+    def get_allowed_subtasks(self) -> list[AutoAgencyIDSubtaskType]:
+        return [
+            subtask_type
+            for subtask_type in SUBTASK_TO_ENV_FLAG
+            if self._get_subtask_flag(subtask_type)
+        ]
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/flags/mappings.py b/src/core/tasks/url/operators/agency_identification/subtasks/flags/mappings.py
new file mode 100644
index 00000000..dcc0b60c
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/flags/mappings.py
@@ -0,0 +1,9 @@
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+SUBTASK_TO_ENV_FLAG: dict[AutoAgencyIDSubtaskType, str] = {
+    AutoAgencyIDSubtaskType.HOMEPAGE_MATCH: "AGENCY_ID_HOMEPAGE_MATCH_FLAG",
+    AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH: "AGENCY_ID_NLP_LOCATION_MATCH_FLAG",
+    AutoAgencyIDSubtaskType.CKAN: "AGENCY_ID_CKAN_FLAG",
+    AutoAgencyIDSubtaskType.MUCKROCK: "AGENCY_ID_MUCKROCK_FLAG",
+    AutoAgencyIDSubtaskType.BATCH_LINK: "AGENCY_ID_BATCH_LINK_FLAG"
+}
\ No newline at end of file
diff --git 
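
A worked example of the confidence arithmetic in `_convert_match_agency_response_to_suggestions` above: an exact match gets confidence 100, while a partial match splits 100 evenly across all candidates with integer division, so a few points can be lost to flooring. A standalone sketch of just that split:

```python
# The split used for PARTIAL_MATCH: 100 confidence divided across all
# candidate matches with integer division, as in the converter above.
def split_confidence(num_matches: int, total: int = 100) -> list[int]:
    per_match = total // num_matches
    return [per_match] * num_matches


print(split_confidence(1))  # [100], the EXACT_MATCH case
print(split_confidence(3))  # [33, 33, 33]; one point is lost to flooring
```
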
a/src/core/tasks/url/operators/agency_identification/subtasks/impl/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/core.py new file mode 100644 index 00000000..9e15996f --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/core.py @@ -0,0 +1,48 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.impl.batch_link.params import \ + AgencyBatchLinkSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.batch_link.query import \ + GetLocationBatchLinkSubtaskParamsQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.models.suggestion import AgencySuggestion +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic + + +class AgencyBatchLinkSubtaskOperator(AgencyIDSubtaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + task_id: int + ): + super().__init__(adb_client=adb_client, task_id=task_id) + + async def inner_logic(self) -> None: + params: list[AgencyBatchLinkSubtaskParams] = await self._get_params() + self.linked_urls = [param.url_id for param in params] + subtask_data_list: list[AutoAgencyIDSubtaskData] = [] + for param in params: + subtask_data: AutoAgencyIDSubtaskData = AutoAgencyIDSubtaskData( + pydantic_model=URLAutoAgencyIDSubtaskPydantic( + task_id=self.task_id, + url_id=param.url_id, + type=AutoAgencyIDSubtaskType.BATCH_LINK, + agencies_found=True, + ), + suggestions=[ + AgencySuggestion( + agency_id=param.agency_id, + confidence=80, + ) + ], + ) + subtask_data_list.append(subtask_data) + + await self._upload_subtask_data(subtask_data_list) + + async def _get_params(self) -> list[AgencyBatchLinkSubtaskParams]: + return await self.adb_client.run_query_builder( + GetLocationBatchLinkSubtaskParamsQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/params.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/params.py new file mode 100644 index 00000000..3008f9be --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/params.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class AgencyBatchLinkSubtaskParams(BaseModel): + url_id: int + agency_id: int \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/query.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/query.py new file mode 100644 index 00000000..008bd1f2 --- /dev/null +++ 
b/src/core/tasks/url/operators/agency_identification/subtasks/impl/batch_link/query.py @@ -0,0 +1,45 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.batch_link.params import \ + AgencyBatchLinkSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.eligible import \ + EligibleContainer +from src.db.models.impl.link.agency_batch.sqlalchemy import LinkAgencyBatch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class GetLocationBatchLinkSubtaskParamsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[AgencyBatchLinkSubtaskParams]: + container = EligibleContainer() + query = ( + select( + container.url_id, + LinkAgencyBatch.agency_id, + ) + .select_from(container.cte) + .join( + LinkBatchURL, + LinkBatchURL.url_id == container.url_id, + ) + .join( + LinkAgencyBatch, + LinkAgencyBatch.batch_id == LinkBatchURL.batch_id, + ) + .where( + container.batch_link, + ) + .limit(500) + ) + results: Sequence[RowMapping] = await sh.mappings(session, query=query) + return [ + AgencyBatchLinkSubtaskParams( + url_id=mapping["id"], + agency_id=mapping["agency_id"], + ) + for mapping in results + ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/core.py new file mode 100644 index 00000000..d1af5391 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/core.py @@ -0,0 +1,54 @@ +from typing import final + +from typing_extensions import override + +from src.core.tasks.url.operators.agency_identification.subtasks.convert import \ + convert_match_agency_response_to_subtask_data +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan_.params import CKANAgencyIDSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan_.query import \ + GetCKANAgencyIDSubtaskParamsQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import \ + AgencyIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType +from src.external.pdap.client import PDAPClient +from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse + + +@final +class CKANAgencyIDSubtaskOperator(AgencyIDSubtaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + task_id: int, + pdap_client: PDAPClient + ): + super().__init__(adb_client, task_id=task_id) + self.pdap_client = pdap_client + + @override + async def inner_logic(self) -> None: + params: list[CKANAgencyIDSubtaskParams] = await self._get_params() + self.linked_urls = [param.url_id for param in params] + subtask_data_list: list[AutoAgencyIDSubtaskData] = [] + for param in params: + 
agency_name: str = param.collector_metadata["agency_name"] + response: MatchAgencyResponse = await self.pdap_client.match_agency( + name=agency_name + ) + subtask_data: AutoAgencyIDSubtaskData = convert_match_agency_response_to_subtask_data( + url_id=param.url_id, + response=response, + subtask_type=AutoAgencyIDSubtaskType.CKAN, + task_id=self.task_id + ) + subtask_data_list.append(subtask_data) + + await self._upload_subtask_data(subtask_data_list) + + async def _get_params(self) -> list[CKANAgencyIDSubtaskParams]: + return await self.adb_client.run_query_builder( + GetCKANAgencyIDSubtaskParamsQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/params.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/params.py new file mode 100644 index 00000000..ce4b7ce1 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/params.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class CKANAgencyIDSubtaskParams(BaseModel): + url_id: int + collector_metadata: dict \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/query.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/query.py new file mode 100644 index 00000000..503d5414 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/ckan_/query.py @@ -0,0 +1,43 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan_.params import CKANAgencyIDSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.eligible import \ + EligibleContainer +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class GetCKANAgencyIDSubtaskParamsQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> list[CKANAgencyIDSubtaskParams]: + container = EligibleContainer() + query = ( + select( + container.url_id, + URL.collector_metadata + ) + .join( + URL, + URL.id == container.url_id, + ) + .where( + container.ckan, + ) + .limit(500) + ) + + results: Sequence[RowMapping] = await sh.mappings(session, query=query) + return [ + CKANAgencyIDSubtaskParams( + url_id=mapping["id"], + collector_metadata=mapping["collector_metadata"], + ) + for mapping in results + ] diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/convert.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/convert.py new file mode 100644 index 00000000..f4ba913e --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/convert.py @@ -0,0 +1,47 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.models.entry import \ + GetHomepageMatchParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.models.mapping import \ + SubtaskURLMapping +from 
src.db.models.impl.url.suggestion.agency.subtask.enum import SubtaskDetailCode, AutoAgencyIDSubtaskType +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic +from src.db.models.impl.url.suggestion.agency.suggestion.pydantic import AgencyIDSubtaskSuggestionPydantic + + +def convert_params_to_subtask_entries( + params: list[GetHomepageMatchParams], + task_id: int +) -> list[URLAutoAgencyIDSubtaskPydantic]: + url_id_to_detail_code: dict[int, SubtaskDetailCode] = {} + for param in params: + url_id_to_detail_code[param.url_id] = param.detail_code + + results: list[URLAutoAgencyIDSubtaskPydantic] = [] + for url_id, detail_code in url_id_to_detail_code.items(): + result = URLAutoAgencyIDSubtaskPydantic( + task_id=task_id, + url_id=url_id, + type=AutoAgencyIDSubtaskType.HOMEPAGE_MATCH, + agencies_found=True, + detail=detail_code, + ) + results.append(result) + return results + +def convert_subtask_mappings_and_params_to_suggestions( + mappings: list[SubtaskURLMapping], + params: list[GetHomepageMatchParams] +) -> list[AgencyIDSubtaskSuggestionPydantic]: + url_id_to_subtask_id: dict[int, int] = { + mapping.url_id: mapping.subtask_id + for mapping in mappings + } + suggestions: list[AgencyIDSubtaskSuggestionPydantic] = [] + for param in params: + subtask_id = url_id_to_subtask_id.get(param.url_id) + suggestion = AgencyIDSubtaskSuggestionPydantic( + subtask_id=subtask_id, + agency_id=param.agency_id, + confidence=param.confidence, + ) + suggestions.append(suggestion) + return suggestions \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/core.py new file mode 100644 index 00000000..f335cb3a --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/core.py @@ -0,0 +1,63 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.convert import \ + convert_params_to_subtask_entries, convert_subtask_mappings_and_params_to_suggestions +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.models.entry import \ + GetHomepageMatchParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.models.mapping import \ + SubtaskURLMapping +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.get import \ + GetHomepageMatchSubtaskURLsQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic +from src.db.models.impl.url.suggestion.agency.suggestion.pydantic import AgencyIDSubtaskSuggestionPydantic + + +class HomepageMatchSubtaskOperator( + AgencyIDSubtaskOperatorBase, +): + + async def inner_logic(self) -> None: + # Get Params + params: list[GetHomepageMatchParams] = \ + await self.adb_client.run_query_builder( + GetHomepageMatchSubtaskURLsQueryBuilder() + ) + + # Insert Subtask Entries + subtask_entries: list[URLAutoAgencyIDSubtaskPydantic] = convert_params_to_subtask_entries( + params=params, + task_id=self.task_id + ) + subtask_mappings: list[SubtaskURLMapping] = await self.insert_subtask_entries( + entries=subtask_entries + ) + + # Link URLs + url_ids: list[int] = [mapping.url_id for mapping in subtask_mappings] + self.linked_urls = url_ids + + # Insert 
Entries + suggestions: list[AgencyIDSubtaskSuggestionPydantic] = convert_subtask_mappings_and_params_to_suggestions( + mappings=subtask_mappings, + params=params + ) + await self.adb_client.bulk_insert( + models=suggestions, + ) + + + async def insert_subtask_entries( + self, + entries: list[URLAutoAgencyIDSubtaskPydantic] + ) -> list[SubtaskURLMapping]: + subtask_ids: list[int] = await self.adb_client.bulk_insert( + models=entries, + return_ids=True + ) + mappings: list[SubtaskURLMapping] = [] + for subtask_id, entry in zip(subtask_ids, entries): + mapping = SubtaskURLMapping( + url_id=entry.url_id, + subtask_id=subtask_id, + ) + mappings.append(mapping) + return mappings diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/entry.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/entry.py new file mode 100644 index 00000000..6c65f9ad --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/entry.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + +from src.db.models.impl.url.suggestion.agency.subtask.enum import SubtaskDetailCode + + +class GetHomepageMatchParams(BaseModel): + url_id: int + agency_id: int + confidence: int = Field(..., ge=0, le=100) + detail_code: SubtaskDetailCode \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/mapping.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/mapping.py new file mode 100644 index 00000000..2e4d2fbb --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/models/mapping.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class SubtaskURLMapping(BaseModel): + url_id: int + subtask_id: int \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/consolidated.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/consolidated.py new file mode 100644 index 00000000..d90dfed6 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/consolidated.py @@ -0,0 +1,28 @@ +from sqlalchemy import CTE, select + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.count_agency_per_url import \ + COUNT_AGENCY_PER_URL_CTE +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.meta_urls_with_root_agencies import \ + META_ROOT_URLS_WITH_AGENCIES +from 
src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.unvalidated_urls_with_root import \ + UNVALIDATED_URLS_WITH_ROOT + +CONSOLIDATED_CTE: CTE = ( + select( + UNVALIDATED_URLS_WITH_ROOT.c.url_id, + META_ROOT_URLS_WITH_AGENCIES.c.agency_id, + COUNT_AGENCY_PER_URL_CTE.c.agency_count, + ) + .join( + COUNT_AGENCY_PER_URL_CTE, + COUNT_AGENCY_PER_URL_CTE.c.root_url_id == UNVALIDATED_URLS_WITH_ROOT.c.root_url_id + ) + .join( + META_ROOT_URLS_WITH_AGENCIES, + META_ROOT_URLS_WITH_AGENCIES.c.root_url_id == UNVALIDATED_URLS_WITH_ROOT.c.root_url_id + ) + .where( + COUNT_AGENCY_PER_URL_CTE.c.agency_count >= 1 + ) + .cte("consolidated") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/count_agency_per_url.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/count_agency_per_url.py new file mode 100644 index 00000000..774787b7 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/count_agency_per_url.py @@ -0,0 +1,20 @@ +from sqlalchemy import CTE, func, select + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.meta_urls_with_root import \ + META_ROOT_URLS_CTE +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency + +COUNT_AGENCY_PER_URL_CTE: CTE = ( + select( + META_ROOT_URLS_CTE.c.root_url_id, + func.count(LinkURLAgency.agency_id).label("agency_count") + ) + .join( + LinkURLAgency, + META_ROOT_URLS_CTE.c.meta_url_id == LinkURLAgency.url_id + ) + .group_by( + META_ROOT_URLS_CTE.c.root_url_id + ) + .cte("count_agency_per_url") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root.py new file mode 100644 index 00000000..63b6b417 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root.py @@ -0,0 +1,23 @@ +from sqlalchemy import CTE, select + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.whitelisted_root_urls import \ + WHITELISTED_ROOT_URLS_CTE +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.views.meta_url import MetaURL + +META_ROOT_URLS_CTE: CTE = ( + select( + MetaURL.url_id.label("meta_url_id"), + LinkURLRootURL.root_url_id + ) + .join( + LinkURLRootURL, + MetaURL.url_id == LinkURLRootURL.url_id + ) + # Must be a Whitelisted Root URL + .join( + WHITELISTED_ROOT_URLS_CTE, + WHITELISTED_ROOT_URLS_CTE.c.id == LinkURLRootURL.root_url_id + ) + .cte("meta_root_urls") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root_agencies.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root_agencies.py new file mode 100644 index 00000000..86b14ee4 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/meta_urls_with_root_agencies.py @@ -0,0 +1,20 @@ +from sqlalchemy import CTE, select + +from 
src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.meta_urls_with_root import \ + META_ROOT_URLS_CTE +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency + +META_ROOT_URLS_WITH_AGENCIES: CTE = ( + select( + META_ROOT_URLS_CTE.c.meta_url_id, + META_ROOT_URLS_CTE.c.root_url_id, + LinkURLAgency.agency_id + ) + .join( + LinkURLAgency, + META_ROOT_URLS_CTE.c.meta_url_id == LinkURLAgency.url_id + ) + .cte( + "meta_root_urls_with_agencies" + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/multi_agency_case.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/multi_agency_case.py new file mode 100644 index 00000000..edf9e601 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/multi_agency_case.py @@ -0,0 +1,17 @@ +from sqlalchemy import CTE, select, literal + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.consolidated import \ + CONSOLIDATED_CTE +from src.db.models.impl.url.suggestion.agency.subtask.enum import SubtaskDetailCode + +MULTI_AGENCY_CASE_QUERY = ( + select( + CONSOLIDATED_CTE.c.url_id, + CONSOLIDATED_CTE.c.agency_id, + (literal(100) / CONSOLIDATED_CTE.c.agency_count).label("confidence"), + literal(SubtaskDetailCode.HOMEPAGE_MULTI_AGENCY.value).label("detail_code") + ) + .where( + CONSOLIDATED_CTE.c.agency_count > 1 + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/single_agency_case.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/single_agency_case.py new file mode 100644 index 00000000..5778ecb6 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/single_agency_case.py @@ -0,0 +1,17 @@ +from sqlalchemy import select, CTE, literal + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.consolidated import \ + CONSOLIDATED_CTE +from src.db.models.impl.url.suggestion.agency.subtask.enum import SubtaskDetailCode + +SINGLE_AGENCY_CASE_QUERY = ( + select( + CONSOLIDATED_CTE.c.url_id, + CONSOLIDATED_CTE.c.agency_id, + literal(95).label("confidence"), + literal(SubtaskDetailCode.HOMEPAGE_SINGLE_AGENCY.value).label("detail_code") + ) + .where( + CONSOLIDATED_CTE.c.agency_count == 1 + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/unvalidated_urls_with_root.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/unvalidated_urls_with_root.py new file mode 100644 index 00000000..46702833 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/unvalidated_urls_with_root.py @@ -0,0 +1,22 @@ +from sqlalchemy import CTE, select + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.whitelisted_root_urls import \ + WHITELISTED_ROOT_URLS_CTE +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.views.unvalidated_url import UnvalidatedURL + +UNVALIDATED_URLS_WITH_ROOT: CTE = ( + select( + UnvalidatedURL.url_id, + LinkURLRootURL.root_url_id + ) + .join( + 
LinkURLRootURL, + UnvalidatedURL.url_id == LinkURLRootURL.url_id + ) + .join( + WHITELISTED_ROOT_URLS_CTE, + WHITELISTED_ROOT_URLS_CTE.c.id == LinkURLRootURL.root_url_id + ) + .cte("unvalidated_urls_with_root") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/whitelisted_root_urls.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/whitelisted_root_urls.py new file mode 100644 index 00000000..272717b5 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/ctes/whitelisted_root_urls.py @@ -0,0 +1,47 @@ +from sqlalchemy import CTE, select, func + +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.sqlalchemy import URL + +WHITELISTED_ROOT_URLS_CTE: CTE = ( + select( + URL.id + ) + .join( + FlagRootURL, + URL.id == FlagRootURL.url_id + ) + # Must be linked to other URLs + .join( + LinkURLRootURL, + URL.id == LinkURLRootURL.root_url_id + ) + # Those URLs must be meta URLS + .join( + FlagURLValidated, + FlagURLValidated.url_id == LinkURLRootURL.url_id + ) + # Get the Agency URLs for those URLs + .join( + LinkURLAgency, + LinkURLAgency.url_id == LinkURLRootURL.url_id + ) + .where( + # The connected URLs must be Meta URLs + FlagURLValidated.type == URLType.META_URL, + # Root URL can't be "https://catalog.data.gov" + URL.url != "https://catalog.data.gov" + ) + .group_by( + URL.id + ) + # Must have no more than two agencies connected + .having( + func.count(LinkURLAgency.agency_id) <= 2 + ) + .cte("whitelisted_root_urls") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/get.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/get.py new file mode 100644 index 00000000..10619531 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/homepage_match_/queries/get.py @@ -0,0 +1,35 @@ +from typing import Sequence + +from sqlalchemy import Select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.models.entry import \ + GetHomepageMatchParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.multi_agency_case import \ + MULTI_AGENCY_CASE_QUERY +from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.single_agency_case import \ + SINGLE_AGENCY_CASE_QUERY +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.suggestion.agency.subtask.enum import SubtaskDetailCode +from src.db.queries.base.builder import QueryBuilderBase + + +class GetHomepageMatchSubtaskURLsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[GetHomepageMatchParams]: + + query: Select = SINGLE_AGENCY_CASE_QUERY.union(MULTI_AGENCY_CASE_QUERY) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + results: list[GetHomepageMatchParams] = [] + for mapping in mappings: + response = 
GetHomepageMatchParams( + url_id=mapping["url_id"], + agency_id=mapping["agency_id"], + confidence=mapping["confidence"], + detail_code=SubtaskDetailCode(mapping["detail_code"]), + ) + results.append(response) + + return results \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/core.py new file mode 100644 index 00000000..4fa92c2e --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/core.py @@ -0,0 +1,93 @@ +from typing import final + +from typing_extensions import override + +from src.collectors.impl.muckrock.api_interface.core import MuckrockAPIInterface +from src.collectors.impl.muckrock.api_interface.lookup_response import AgencyLookupResponse +from src.collectors.impl.muckrock.enums import AgencyLookupResponseType +from src.core.tasks.url.operators.agency_identification.subtasks.convert import \ + convert_match_agency_response_to_subtask_data +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock_.params import \ + MuckrockAgencyIDSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock_.query import \ + GetMuckrockAgencyIDSubtaskParamsQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType, SubtaskDetailCode +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic +from src.external.pdap.client import PDAPClient +from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse + + +@final +class MuckrockAgencyIDSubtaskOperator(AgencyIDSubtaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + task_id: int, + muckrock_api_interface: MuckrockAPIInterface, + pdap_client: PDAPClient + ): + super().__init__(adb_client, task_id=task_id) + self.muckrock_api_interface = muckrock_api_interface + self.pdap_client = pdap_client + + @override + async def inner_logic(self) -> None: + params: list[MuckrockAgencyIDSubtaskParams] = await self._get_params() + self.linked_urls = [param.url_id for param in params] + subtask_data_list: list[AutoAgencyIDSubtaskData] = [] + for param in params: + muckrock_agency_id: int = param.collector_metadata["agency"] + agency_lookup_response: AgencyLookupResponse = await self.muckrock_api_interface.lookup_agency( + muckrock_agency_id=muckrock_agency_id + ) + if agency_lookup_response.type != AgencyLookupResponseType.FOUND: + data: AutoAgencyIDSubtaskData = await self._error_subtask_data( + url_id=param.url_id, + muckrock_agency_id=muckrock_agency_id, + agency_lookup_response=agency_lookup_response + ) + subtask_data_list.append(data) + continue + match_agency_response: MatchAgencyResponse = await self.pdap_client.match_agency( + name=agency_lookup_response.name + ) + subtask_data: AutoAgencyIDSubtaskData = convert_match_agency_response_to_subtask_data( + url_id=param.url_id, + 
response=match_agency_response, + subtask_type=AutoAgencyIDSubtaskType.MUCKROCK, + task_id=self.task_id + ) + subtask_data_list.append(subtask_data) + + await self._upload_subtask_data(subtask_data_list) + + + async def _error_subtask_data( + self, + url_id: int, + muckrock_agency_id: int, + agency_lookup_response: AgencyLookupResponse + ) -> AutoAgencyIDSubtaskData: + pydantic_model = URLAutoAgencyIDSubtaskPydantic( + task_id=self.task_id, + url_id=url_id, + type=AutoAgencyIDSubtaskType.MUCKROCK, + agencies_found=False, + detail=SubtaskDetailCode.RETRIEVAL_ERROR + ) + error: str = f"Failed to lookup muckrock agency: {muckrock_agency_id}:" + \ + f" {agency_lookup_response.type.value}: {agency_lookup_response.error}" + return AutoAgencyIDSubtaskData( + pydantic_model=pydantic_model, + suggestions=[], + error=error + ) + + async def _get_params(self) -> list[MuckrockAgencyIDSubtaskParams]: + return await self.adb_client.run_query_builder( + GetMuckrockAgencyIDSubtaskParamsQueryBuilder() + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/params.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/params.py new file mode 100644 index 00000000..6010f022 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/params.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class MuckrockAgencyIDSubtaskParams(BaseModel): + url_id: int + collector_metadata: dict \ No newline at end of file diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/query.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/query.py new file mode 100644 index 00000000..6f575b4f --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/muckrock_/query.py @@ -0,0 +1,49 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.collectors.enums import CollectorType +from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock_.params import \ + MuckrockAgencyIDSubtaskParams +from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.eligible import \ + EligibleContainer +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class GetMuckrockAgencyIDSubtaskParamsQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> list[MuckrockAgencyIDSubtaskParams]: + container = EligibleContainer() + + query = ( + select( + container.url_id, + URL.collector_metadata + ) + .join( + URL, + URL.id == container.url_id, + ) + .where( + container.muckrock, + ) + .limit(500) + ) + + results: Sequence[RowMapping] = await sh.mappings(session, query=query) + return [ + MuckrockAgencyIDSubtaskParams( + url_id=mapping["id"], + collector_metadata=mapping["collector_metadata"], + ) + for mapping in results + ] + diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
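
The Muckrock subtask above is a two-step resolution: the Muckrock agency id stored in `collector_metadata` is resolved to an agency name via the Muckrock API, and that name is then matched against PDAP; failed lookups are still recorded with a `RETRIEVAL_ERROR` detail so the URL is not surveyed again. A condensed sketch of that control flow with the external clients stubbed out; the stub methods and the `SubtaskResult` shape are assumptions for illustration, not the project's real interfaces:

```python
# Condensed sketch of the two-step Muckrock resolution. The `muckrock`
# and `pdap` stubs and the SubtaskResult shape are assumptions for
# illustration; only the control flow mirrors the operator above.
from dataclasses import dataclass, field


@dataclass
class SubtaskResult:
    url_id: int
    agencies_found: bool
    suggestions: list[int] = field(default_factory=list)
    error: str | None = None


async def resolve_muckrock_url(url_id: int, metadata: dict, muckrock, pdap) -> SubtaskResult:
    # Step 1: resolve the Muckrock agency id stored in collector metadata.
    lookup = await muckrock.lookup_agency(muckrock_agency_id=metadata["agency"])
    if lookup.type != "FOUND":
        # Failed lookups are still recorded (detail=RETRIEVAL_ERROR above),
        # so the URL is not re-surveyed for this subtask indefinitely.
        return SubtaskResult(url_id, agencies_found=False,
                             error=f"lookup failed: {lookup.error}")
    # Step 2: match the resolved agency name against PDAP.
    match = await pdap.match_agency(name=lookup.name)
    agency_ids = [int(m.id) for m in match.matches]
    return SubtaskResult(url_id, agencies_found=bool(agency_ids),
                         suggestions=agency_ids)
```
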
a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/convert.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/convert.py new file mode 100644 index 00000000..2766bff0 --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/convert.py @@ -0,0 +1,49 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.models.input import \ + NLPLocationMatchSubtaskInput +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.models.suggestion import AgencySuggestion +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType +from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic + + +def convert_location_agency_mappings_to_subtask_data_list( + task_id: int, + inputs: list[NLPLocationMatchSubtaskInput] +) -> list[AutoAgencyIDSubtaskData]: + results: list[AutoAgencyIDSubtaskData] = [] + for input_ in inputs: + suggestions: list[AgencySuggestion] = [] + if not input_.has_locations_with_agencies: + agencies_found: bool = False + else: + agencies_found: bool = True + for mapping in input_.mappings: + agency_ids: list[int] = mapping.agency_ids + confidence_per_agency: int = _calculate_confidence_per_agency( + agency_ids, + confidence=mapping.location_annotation.confidence + ) + for agency_id in agency_ids: + suggestion = AgencySuggestion( + agency_id=agency_id, + confidence=confidence_per_agency, + ) + suggestions.append(suggestion) + data = AutoAgencyIDSubtaskData( + pydantic_model=URLAutoAgencyIDSubtaskPydantic( + url_id=input_.url_id, + type=AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH, + agencies_found=agencies_found, + task_id=task_id, + ), + suggestions=suggestions, + ) + results.append(data) + return results + + +def _calculate_confidence_per_agency(agency_ids: list[int], confidence: int): + num_agencies: int = len(agency_ids) + confidence_per_agency: int = confidence // num_agencies + return confidence_per_agency + diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py new file mode 100644 index 00000000..4463ff0d --- /dev/null +++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py @@ -0,0 +1,36 @@ +from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.convert import \ + convert_location_agency_mappings_to_subtask_data_list +from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.models.input import \ + NLPLocationMatchSubtaskInput +from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.query_.query import \ + GetAgenciesLinkedToAnnotatedLocationsQueryBuilder +from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData +from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient + + +class NLPLocationMatchSubtaskOperator(AgencyIDSubtaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + task_id: int, + ) -> None: + super().__init__(adb_client, task_id=task_id) + + async def 
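The integer division in `_calculate_confidence_per_agency` splits a location annotation's confidence evenly across every agency linked to that location, so per-agency confidence shrinks as the location gets more ambiguous. A quick standalone illustration of that arithmetic (the agency IDs and the confidence of 90 are made up):

```python
def calculate_confidence_per_agency(agency_ids: list[int], confidence: int) -> int:
    # Mirrors _calculate_confidence_per_agency above: even split, rounded down.
    return confidence // len(agency_ids)

# A location annotation with confidence 90 linked to four agencies:
print(calculate_confidence_per_agency([101, 102, 103, 104], confidence=90))  # 22

# The same annotation linked to a single agency keeps the full confidence:
print(calculate_confidence_per_agency([101], confidence=90))  # 90
```

Because of the floor division, the per-agency values can sum to slightly less than the original confidence (4 × 22 = 88 here), which keeps every suggestion inside the 0-100 bound enforced by `AgencySuggestion`.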
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py
new file mode 100644
index 00000000..4463ff0d
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/core.py
@@ -0,0 +1,36 @@
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.convert import \
+    convert_location_agency_mappings_to_subtask_data_list
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.models.input import \
+    NLPLocationMatchSubtaskInput
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.query_.query import \
+    GetAgenciesLinkedToAnnotatedLocationsQueryBuilder
+from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData
+from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase
+from src.db.client.async_ import AsyncDatabaseClient
+
+
+class NLPLocationMatchSubtaskOperator(AgencyIDSubtaskOperatorBase):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        task_id: int,
+    ) -> None:
+        super().__init__(adb_client, task_id=task_id)
+
+    async def inner_logic(self) -> None:
+        inputs: list[NLPLocationMatchSubtaskInput] = await self._get_from_db()
+        await self.run_subtask_iteration(inputs)
+
+    async def run_subtask_iteration(self, inputs: list[NLPLocationMatchSubtaskInput]) -> None:
+        self.linked_urls.extend([input_.url_id for input_ in inputs])
+        subtask_data_list: list[AutoAgencyIDSubtaskData] = convert_location_agency_mappings_to_subtask_data_list(
+            task_id=self.task_id,
+            inputs=inputs,
+        )
+        await self._upload_subtask_data(subtask_data_list)
+
+    async def _get_from_db(self) -> list[NLPLocationMatchSubtaskInput]:
+        return await self.adb_client.run_query_builder(
+            GetAgenciesLinkedToAnnotatedLocationsQueryBuilder(),
+        )
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/input.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/input.py
new file mode 100644
index 00000000..74fb49d1
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/input.py
@@ -0,0 +1,17 @@
+from pydantic import BaseModel
+
+class LocationAnnotation(BaseModel):
+    location_id: int
+    confidence: int
+
+class LocationAnnotationToAgencyIDMapping(BaseModel):
+    location_annotation: LocationAnnotation
+    agency_ids: list[int]
+
+class NLPLocationMatchSubtaskInput(BaseModel):
+    url_id: int
+    mappings: list[LocationAnnotationToAgencyIDMapping]
+
+    @property
+    def has_locations_with_agencies(self) -> bool:
+        return len(self.mappings) > 0
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/subsets/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/subsets/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/subsets/nlp_responses.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/subsets/nlp_responses.py
new file mode 100644
index 00000000..304c7e01
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/models/subsets/nlp_responses.py
@@ -0,0 +1,9 @@
+from pydantic import BaseModel
+
+from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_nlp_response import \
+    URLToNLPResponseMapping
+
+
+class NLPResponseSubsets(BaseModel):
+    valid: list[URLToNLPResponseMapping]
+    invalid: list[URLToNLPResponseMapping]
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/query.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/query.py
new file mode 100644
index 00000000..f0dcac94
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/query.py
@@ -0,0 +1,84 @@
+from collections import defaultdict
+from typing import Sequence
+
+from sqlalchemy import select, RowMapping
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.models.input import \
+    NLPLocationMatchSubtaskInput, LocationAnnotationToAgencyIDMapping, LocationAnnotation
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.nlp_location import \
+    NLP_LOCATION_CONTAINER
+from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
+from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask
+from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class GetAgenciesLinkedToAnnotatedLocationsQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> list[NLPLocationMatchSubtaskInput]:
+        query = (
+            select(
+                NLP_LOCATION_CONTAINER.url_id,
+                LocationIDSubtaskSuggestion.location_id,
+                LocationIDSubtaskSuggestion.confidence,
+                LinkAgencyLocation.agency_id,
+            )
+            .join(
+                AutoLocationIDSubtask,
+                AutoLocationIDSubtask.url_id == NLP_LOCATION_CONTAINER.url_id
+            )
+            .join(
+                LocationIDSubtaskSuggestion,
+                LocationIDSubtaskSuggestion.subtask_id == AutoLocationIDSubtask.id
+            )
+            .join(
+                LinkAgencyLocation,
+                LinkAgencyLocation.location_id == LocationIDSubtaskSuggestion.location_id
+            )
+            .where(
+                ~NLP_LOCATION_CONTAINER.entry_exists
+            )
+        )
+
+        url_id_to_location_id_to_agency_ids: dict[int, dict[int, list[int]]] = defaultdict(
+            lambda: defaultdict(list)
+        )
+        url_id_to_location_id_to_annotations: dict[int, dict[int, LocationAnnotation]] = defaultdict(dict)
+
+        mappings: Sequence[RowMapping] = await sh.mappings(session, query=query)
+        for mapping in mappings:
+            url_id: int = mapping["id"]
+            location_id: int = mapping["location_id"]
+            confidence: int = mapping["confidence"]
+            agency_id: int = mapping["agency_id"]
+
+            if agency_id is None:
+                continue
+            url_id_to_location_id_to_agency_ids[url_id][location_id].append(agency_id)
+            if location_id not in url_id_to_location_id_to_annotations[url_id]:
+                location_annotation = LocationAnnotation(
+                    location_id=location_id,
+                    confidence=confidence,
+                )
+                url_id_to_location_id_to_annotations[url_id][location_id] = location_annotation
+
+        results: list[NLPLocationMatchSubtaskInput] = []
+        for url_id in url_id_to_location_id_to_agency_ids:
+            anno_mappings: list[LocationAnnotationToAgencyIDMapping] = []
+            for location_id in url_id_to_location_id_to_agency_ids[url_id]:
+                location_annotation: LocationAnnotation = url_id_to_location_id_to_annotations[url_id][location_id]
+                agency_ids: list[int] = url_id_to_location_id_to_agency_ids[url_id][location_id]
+                anno_mapping: LocationAnnotationToAgencyIDMapping = LocationAnnotationToAgencyIDMapping(
+                    location_annotation=location_annotation,
+                    agency_ids=agency_ids,
+                )
+                anno_mappings.append(anno_mapping)
+            input_ = NLPLocationMatchSubtaskInput(
+                url_id=url_id,
+                mappings=anno_mappings,
+            )
+            results.append(input_)
+        return results
+
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/response.py b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/response.py
new file mode 100644
index 00000000..6205de78
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/impl/nlp_location_match_/query_/response.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class GetAgenciesLinkedToAnnotatedLocationsResponse(BaseModel):
+    url_id: int
+    location_id: int
+    location_confidence: int
+    agency_ids: list[int]
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/loader.py b/src/core/tasks/url/operators/agency_identification/subtasks/loader.py
new file mode 100644
index 00000000..24099540
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/loader.py
@@ -0,0 +1,84 @@
+from src.collectors.impl.muckrock.api_interface.core import MuckrockAPIInterface
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.batch_link.core import \
+    AgencyBatchLinkSubtaskOperator
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.ckan_.core import CKANAgencyIDSubtaskOperator
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.core import \
+    HomepageMatchSubtaskOperator
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.muckrock_.core import \
+    MuckrockAgencyIDSubtaskOperator
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.nlp_location_match_.core import \
+    NLPLocationMatchSubtaskOperator
+from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+from src.external.pdap.client import PDAPClient
+
+
+class AgencyIdentificationSubtaskLoader:
+    """Loads subtasks and associated dependencies."""
+
+    def __init__(
+        self,
+        pdap_client: PDAPClient,
+        muckrock_api_interface: MuckrockAPIInterface,
+        adb_client: AsyncDatabaseClient,
+    ):
+        self._pdap_client = pdap_client
+        self._muckrock_api_interface = muckrock_api_interface
+        self.adb_client = adb_client
+
+    def _load_muckrock_subtask(self, task_id: int) -> MuckrockAgencyIDSubtaskOperator:
+        return MuckrockAgencyIDSubtaskOperator(
+            task_id=task_id,
+            adb_client=self.adb_client,
+            muckrock_api_interface=self._muckrock_api_interface,
+            pdap_client=self._pdap_client
+        )
+
+    def _load_ckan_subtask(self, task_id: int) -> CKANAgencyIDSubtaskOperator:
+        return CKANAgencyIDSubtaskOperator(
+            task_id=task_id,
+            adb_client=self.adb_client,
+            pdap_client=self._pdap_client
+        )
+
+    def _load_homepage_match_subtask(self, task_id: int) -> HomepageMatchSubtaskOperator:
+        return HomepageMatchSubtaskOperator(
+            task_id=task_id,
+            adb_client=self.adb_client,
+        )
+
+    def _load_nlp_location_match_subtask(self, task_id: int) -> NLPLocationMatchSubtaskOperator:
+        return NLPLocationMatchSubtaskOperator(
+            task_id=task_id,
+            adb_client=self.adb_client,
+        )
+
+    def _load_batch_link_subtask(
+        self,
+        task_id: int
+    ) -> AgencyBatchLinkSubtaskOperator:
+        return AgencyBatchLinkSubtaskOperator(
+            task_id=task_id,
+            adb_client=self.adb_client,
+        )
+
+
+    async def load_subtask(
+        self,
+        subtask_type: AutoAgencyIDSubtaskType,
+        task_id: int
+    ) -> AgencyIDSubtaskOperatorBase:
+        """Load the subtask operator for the given subtask type."""
+        match subtask_type:
+            case AutoAgencyIDSubtaskType.MUCKROCK:
+                return self._load_muckrock_subtask(task_id)
+            case AutoAgencyIDSubtaskType.CKAN:
+                return self._load_ckan_subtask(task_id)
+            case AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH:
+                return self._load_nlp_location_match_subtask(task_id)
+            case AutoAgencyIDSubtaskType.HOMEPAGE_MATCH:
+                return self._load_homepage_match_subtask(task_id)
+            case AutoAgencyIDSubtaskType.BATCH_LINK:
+                return self._load_batch_link_subtask(task_id)
+        raise ValueError(f"Unknown subtask type: {subtask_type}")
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/models/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/models/run_info.py b/src/core/tasks/url/operators/agency_identification/subtasks/models/run_info.py
new file mode 100644
index 00000000..524830e3
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/models/run_info.py
@@ -0,0 +1,14 @@
+from pydantic import BaseModel
+
+
+class AgencyIDSubtaskRunInfo(BaseModel):
+    error: str | None = None
+    linked_url_ids: list[int] | None = None
+
+    @property
+    def is_success(self) -> bool:
+        return self.error is None
+
+    @property
+    def has_linked_urls(self) -> bool:
+        return bool(self.linked_url_ids)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/models/subtask.py b/src/core/tasks/url/operators/agency_identification/subtasks/models/subtask.py
new file mode 100644
index 00000000..7da0a8f5
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/models/subtask.py
@@ -0,0 +1,18 @@
+from pydantic import BaseModel
+
+from src.core.tasks.url.operators.agency_identification.subtasks.models.suggestion import AgencySuggestion
+from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic
+
+
+class AutoAgencyIDSubtaskData(BaseModel):
+    pydantic_model: URLAutoAgencyIDSubtaskPydantic
+    suggestions: list[AgencySuggestion]
+    error: str | None = None
+
+    @property
+    def has_error(self) -> bool:
+        return self.error is not None
+
+    @property
+    def url_id(self) -> int:
+        return self.pydantic_model.url_id
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/models/suggestion.py b/src/core/tasks/url/operators/agency_identification/subtasks/models/suggestion.py
new file mode 100644
index 00000000..669c498c
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/models/suggestion.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel, Field
+
+
+class AgencySuggestion(BaseModel):
+    agency_id: int
+    confidence: int = Field(ge=0, le=100)
\ No newline at end of file
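Taken together, these models bundle one subtask row with its per-agency suggestions. A minimal construction sketch (the IDs and confidence values are hypothetical, and the pydantic model is assumed to accept the same fields the operators above pass it):

```python
from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData
from src.core.tasks.url.operators.agency_identification.subtasks.models.suggestion import AgencySuggestion
from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic

data = AutoAgencyIDSubtaskData(
    pydantic_model=URLAutoAgencyIDSubtaskPydantic(
        task_id=1,
        url_id=42,
        type=AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH,
        agencies_found=True,
    ),
    suggestions=[
        AgencySuggestion(agency_id=7, confidence=45),
        AgencySuggestion(agency_id=8, confidence=45),
    ],
)

assert not data.has_error  # no error string attached
assert data.url_id == 42   # delegated to the wrapped pydantic model
```

`_upload_subtask_data` later splits this bundle back apart: the subtask rows are bulk-inserted first, and the returned IDs are zipped with the suggestion lists to build the child suggestion rows.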
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/constants.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/constants.py
new file mode 100644
index 00000000..bea99266
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/constants.py
@@ -0,0 +1,15 @@
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+# Determines priority of subtasks, all else being equal.
+SUBTASK_HIERARCHY: list[AutoAgencyIDSubtaskType] = [
+    AutoAgencyIDSubtaskType.CKAN,
+    AutoAgencyIDSubtaskType.MUCKROCK,
+    AutoAgencyIDSubtaskType.HOMEPAGE_MATCH,
+    AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH,
+    AutoAgencyIDSubtaskType.BATCH_LINK
+]
+
+SUBTASK_HIERARCHY_MAPPING: dict[AutoAgencyIDSubtaskType, int] = {
+    subtask: idx
+    for idx, subtask in enumerate(SUBTASK_HIERARCHY)
+}
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/core.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/core.py
new file mode 100644
index 00000000..2b81d2de
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/core.py
@@ -0,0 +1,69 @@
+from collections import Counter
+
+from sqlalchemy import RowMapping
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.constants import SUBTASK_HIERARCHY, \
+    SUBTASK_HIERARCHY_MAPPING
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.eligible_counts import \
+    ELIGIBLE_COUNTS_QUERY
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class AgencyIDSubtaskSurveyQueryBuilder(QueryBuilderBase):
+    """
+    Survey applicable URLs to determine the next subtask to run
+
+    URLs are "inapplicable" if they have any of the following properties:
+    - Are validated via FlagURLValidated model
+    - Have at least one annotation with agency suggestion with confidence >= 95
+    - Have all possible subtasks completed
+
+    Returns the allowed subtask with the most applicable URLs,
+    breaking ties via SUBTASK_HIERARCHY
+    (or None if no subtask has applicable URLs)
+    """
+
+    def __init__(
+        self,
+        allowed_subtasks: list[AutoAgencyIDSubtaskType]
+    ):
+        super().__init__()
+        self._allowed_subtasks = allowed_subtasks
+
+    async def run(self, session: AsyncSession) -> AutoAgencyIDSubtaskType | None:
+        results: RowMapping = await sh.mapping(session, ELIGIBLE_COUNTS_QUERY)
+        counts: Counter[str] = Counter(results)
+
+        allowed_counts: Counter[str] = await self._filter_allowed_counts(counts)
+        if len(allowed_counts) == 0:
+            return None
+        max_count: int = max(allowed_counts.values())
+        if max_count == 0:
+            return None
+        subtasks_with_max_count: list[str] = [
+            subtask for subtask, count in allowed_counts.items()
+            if count == max_count
+        ]
+        subtasks_as_enum_list: list[AutoAgencyIDSubtaskType] = [
+            AutoAgencyIDSubtaskType(subtask)
+            for subtask in subtasks_with_max_count
+        ]
+        # Sort subtasks by priority (lower hierarchy index = higher priority)
+        sorted_subtasks: list[AutoAgencyIDSubtaskType] = sorted(
+            subtasks_as_enum_list,
+            key=lambda subtask: SUBTASK_HIERARCHY_MAPPING[subtask],
+        )
+        # Return the highest priority subtask
+        return sorted_subtasks[0]
+
+    async def _filter_allowed_counts(self, counts: Counter[str]) -> Counter[str]:
+        return Counter(
+            {
+                subtask: count
+                for subtask, count in counts.items()
+                if AutoAgencyIDSubtaskType(subtask) in self._allowed_subtasks
+            }
+        )
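Once the eligible counts come back from SQL, the selection is plain `Counter` arithmetic. A standalone sketch of a tie being broken by the hierarchy (the string keys are hypothetical stand-ins for the enum values used as column labels, and the counts are made up):

```python
from collections import Counter

# Hypothetical eligible-URL counts keyed by subtask value, as ELIGIBLE_COUNTS_QUERY would return.
counts = Counter({"ckan": 12, "muckrock": 12, "homepage_match": 3, "nlp_location_match": 0, "batch_link": 0})

hierarchy = ["ckan", "muckrock", "homepage_match", "nlp_location_match", "batch_link"]
priority = {subtask: idx for idx, subtask in enumerate(hierarchy)}

max_count = max(counts.values())                          # 12
tied = [s for s, c in counts.items() if c == max_count]   # ["ckan", "muckrock"]
winner = sorted(tied, key=lambda s: priority[s])[0]       # "ckan" wins the tie
print(winner)
```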
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/README.md b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/README.md
new file mode 100644
index 00000000..38324fa7
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/README.md
@@ -0,0 +1,3 @@
+Contains CTEs for determining eligibility for each subtask.
+
+Each file corresponds to the eligibility CTE for that subtask.
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/eligible.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/eligible.py
new file mode 100644
index 00000000..ff7e2d72
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/eligible.py
@@ -0,0 +1,64 @@
+from sqlalchemy import select, CTE, Column
+
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.exists.high_confidence_annotations import \
+    HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER
+from src.core.tasks.url.operators._shared.ctes.validated import \
+    VALIDATED_EXISTS_CONTAINER
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.batch_link import \
+    BATCH_LINK_SUBTASK_CONTAINER
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.ckan import \
+    CKAN_SUBTASK_CONTAINER
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.homepage import \
+    HOMEPAGE_SUBTASK_CONTAINER
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.muckrock import \
+    MUCKROCK_SUBTASK_CONTAINER
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.impl.nlp_location import \
+    NLP_LOCATION_CONTAINER
+from src.db.models.impl.url.core.sqlalchemy import URL
+
+class EligibleContainer:
+
+    def __init__(self):
+        self._cte = (
+            select(
+                URL.id,
+                CKAN_SUBTASK_CONTAINER.eligible_query.label("ckan"),
+                MUCKROCK_SUBTASK_CONTAINER.eligible_query.label("muckrock"),
+                HOMEPAGE_SUBTASK_CONTAINER.eligible_query.label("homepage"),
+                NLP_LOCATION_CONTAINER.eligible_query.label("nlp_location"),
+                BATCH_LINK_SUBTASK_CONTAINER.eligible_query.label("batch_link"),
+            )
+            .where(
+                HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER.not_exists_query,
+                VALIDATED_EXISTS_CONTAINER.not_exists_query,
+            )
+            .cte("eligible")
+        )
+
+    @property
+    def cte(self) -> CTE:
+        return self._cte
+
+    @property
+    def url_id(self) -> Column[int]:
+        return self._cte.c['id']
+
+    @property
+    def ckan(self) -> Column[bool]:
+        return self._cte.c['ckan']
+
+    @property
+    def batch_link(self) -> Column[bool]:
+        return self._cte.c['batch_link']
+
+    @property
+    def muckrock(self) -> Column[bool]:
+        return self._cte.c['muckrock']
+
+    @property
+    def homepage(self) -> Column[bool]:
+        return self._cte.c['homepage']
+
+    @property
+    def nlp_location(self) -> Column[bool]:
+        return self._cte.c['nlp_location']
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/exists/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/exists/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py
new file mode 100644
index 00000000..cfb92327
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py
@@ -0,0 +1,29 @@
+from sqlalchemy import select
+
+from src.core.tasks.url.operators._shared.container.subtask.exists import \
+    URLsSubtaskExistsCTEContainer
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.sqlalchemy import URLAutoAgencyIDSubtask
+from src.db.models.impl.url.suggestion.agency.suggestion.sqlalchemy import AgencyIDSubtaskSuggestion
+
+cte = (
+    select(
+        URL.id
+    )
+    .join(
+        URLAutoAgencyIDSubtask,
+        URLAutoAgencyIDSubtask.url_id == URL.id,
+    )
+    .join(
+        AgencyIDSubtaskSuggestion,
+        AgencyIDSubtaskSuggestion.subtask_id == URLAutoAgencyIDSubtask.id,
+    )
+    .where(
+        AgencyIDSubtaskSuggestion.confidence >= 95,
+    )
+    .cte("high_confidence_annotations_exists")
+)
+
+HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER = URLsSubtaskExistsCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/helpers.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/helpers.py
new file mode 100644
index 00000000..b06442ea
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/helpers.py
@@ -0,0 +1,18 @@
+from sqlalchemy import ColumnElement, exists
+
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+from src.db.models.impl.url.suggestion.agency.subtask.sqlalchemy import URLAutoAgencyIDSubtask
+
+
+def get_exists_subtask_query(
+    subtask_type: AutoAgencyIDSubtaskType,
+) -> ColumnElement[bool]:
+    return (
+        exists()
+        .where(
+            URLAutoAgencyIDSubtask.url_id == URL.id,
+            URLAutoAgencyIDSubtask.type == subtask_type,
+        )
+        .label("subtask_entry_exists")
+    )
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py
new file mode 100644
index 00000000..42fcc02f
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py
@@ -0,0 +1,31 @@
+from sqlalchemy import select
+
+from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.helpers import \
+    get_exists_subtask_query
+from src.db.models.impl.link.agency_batch.sqlalchemy import LinkAgencyBatch
+from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+cte = (
+    select(
+        URL.id,
+        get_exists_subtask_query(
+            AutoAgencyIDSubtaskType.BATCH_LINK,
+        )
+    )
+    .join(
+        LinkBatchURL,
+        LinkBatchURL.url_id == URL.id,
+    )
+    .join(
+        LinkAgencyBatch,
+        LinkAgencyBatch.batch_id == LinkBatchURL.batch_id,
+    )
+    .cte("batch_link_eligible")
+)
+
+BATCH_LINK_SUBTASK_CONTAINER = URLsSubtaskEligibleCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/ckan.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/ckan.py
new file mode 100644
index 00000000..6b8ed9e8
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/ckan.py
@@ -0,0 +1,36 @@
+from sqlalchemy import select
+
+from src.collectors.enums import CollectorType
+from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.helpers import \
+    get_exists_subtask_query
+from src.db.models.impl.batch.sqlalchemy import Batch
+from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+cte = (
+    select(
+        URL.id,
+        get_exists_subtask_query(
+            AutoAgencyIDSubtaskType.CKAN,
+        ),
+    )
+    .join(
+        LinkBatchURL,
+        LinkBatchURL.url_id == URL.id,
+    )
+    .join(
+        Batch,
+        Batch.id == LinkBatchURL.batch_id,
+    )
+    .where(
+        Batch.strategy == CollectorType.CKAN.value,
+
+    )
+    .cte("ckan_eligible")
+)
+
+CKAN_SUBTASK_CONTAINER = URLsSubtaskEligibleCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/homepage.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/homepage.py
new file mode 100644
index 00000000..7daba916
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/homepage.py
@@ -0,0 +1,33 @@
+from sqlalchemy import select, exists
+
+from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer
+from src.core.tasks.url.operators.agency_identification.subtasks.impl.homepage_match_.queries.ctes.consolidated import \
+    CONSOLIDATED_CTE
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.helpers import \
+    get_exists_subtask_query
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+VALID_URL_FLAG = (
+    exists()
+    .where(
+        URL.id == CONSOLIDATED_CTE.c.url_id,
+    )
+)
+
+cte = (
+    select(
+        URL.id,
+        get_exists_subtask_query(
+            AutoAgencyIDSubtaskType.HOMEPAGE_MATCH,
+        )
+    )
+    .where(
+        VALID_URL_FLAG,
+    )
+    .cte("homepage_eligible")
+)
+
+HOMEPAGE_SUBTASK_CONTAINER = URLsSubtaskEligibleCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/muckrock.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/muckrock.py
new file mode 100644
index 00000000..9e267f66
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/muckrock.py
@@ -0,0 +1,39 @@
+from sqlalchemy import select
+
+from src.collectors.enums import CollectorType
+from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.helpers import \
+    get_exists_subtask_query
+from src.db.models.impl.batch.sqlalchemy import Batch
+from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+cte = (
+    select(
+        URL.id,
+        get_exists_subtask_query(
+            AutoAgencyIDSubtaskType.MUCKROCK,
+        )
+    )
+    .join(
+        LinkBatchURL,
+        LinkBatchURL.url_id == URL.id,
+    )
+    .join(
+        Batch,
+        Batch.id == LinkBatchURL.batch_id,
+    )
+    .where(
+        Batch.strategy.in_(
+            (CollectorType.MUCKROCK_ALL_SEARCH.value,
+             CollectorType.MUCKROCK_COUNTY_SEARCH.value,
+             CollectorType.MUCKROCK_SIMPLE_SEARCH.value,)
+        ),
+    )
+    .cte("muckrock_eligible")
+)
+
+MUCKROCK_SUBTASK_CONTAINER = URLsSubtaskEligibleCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location.py
new file mode 100644
index 00000000..17055d1a
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location.py
@@ -0,0 +1,48 @@
+from sqlalchemy import select, exists, and_
+
+from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.subtask.helpers import \
+    get_exists_subtask_query
+from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask
+from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion
+
+cte = (
+    select(
+        URL.id,
+        get_exists_subtask_query(
+            AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH
+        )
+    )
+    .join(
+        AutoLocationIDSubtask,
+        and_(
+            AutoLocationIDSubtask.url_id == URL.id,
+            AutoLocationIDSubtask.locations_found
+        )
+    )
+    .where(
+        # One of the locations must be linked to an agency
+        exists(
+            select(
+                LinkAgencyLocation.id
+            )
+            .join(
+                LocationIDSubtaskSuggestion,
+                LocationIDSubtaskSuggestion.location_id == LinkAgencyLocation.location_id,
+            )
+            .join(
+                AutoLocationIDSubtask,
+                AutoLocationIDSubtask.id == LocationIDSubtaskSuggestion.subtask_id,
+            )
+        )
+
+    )
+    .cte("nlp_location_eligible")
+)
+
+NLP_LOCATION_CONTAINER = URLsSubtaskEligibleCTEContainer(
+    cte,
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/eligible_counts.py b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/eligible_counts.py
new file mode 100644
index 00000000..d3b7fe6b
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/queries/survey/queries/eligible_counts.py
@@ -0,0 +1,26 @@
+from sqlalchemy import select, ColumnElement, Integer, func
+
+from src.core.tasks.url.operators.agency_identification.subtasks.queries.survey.queries.ctes.eligible import \
+    EligibleContainer
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+
+
+def sum_count(col: ColumnElement[bool], subtask_type: AutoAgencyIDSubtaskType) -> ColumnElement[int]:
+    return func.coalesce(
+        func.sum(
+            col.cast(Integer)
+        ),
+        0,
+    ).label(subtask_type.value)
+
+container = EligibleContainer()
+
+ELIGIBLE_COUNTS_QUERY = (
+    select(
+        sum_count(container.ckan, AutoAgencyIDSubtaskType.CKAN),
+        sum_count(container.muckrock, AutoAgencyIDSubtaskType.MUCKROCK),
+        sum_count(container.homepage, AutoAgencyIDSubtaskType.HOMEPAGE_MATCH),
+        sum_count(container.nlp_location, AutoAgencyIDSubtaskType.NLP_LOCATION_MATCH),
+        sum_count(container.batch_link, AutoAgencyIDSubtaskType.BATCH_LINK)
+    )
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/templates/__init__.py b/src/core/tasks/url/operators/agency_identification/subtasks/templates/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/agency_identification/subtasks/templates/subtask.py b/src/core/tasks/url/operators/agency_identification/subtasks/templates/subtask.py
new file mode 100644
index 00000000..9335afcf
--- /dev/null
+++ b/src/core/tasks/url/operators/agency_identification/subtasks/templates/subtask.py
@@ -0,0 +1,96 @@
+import abc
+import traceback
+from abc import ABC
+
+from src.core.tasks.url.operators.agency_identification.subtasks.models.run_info import AgencyIDSubtaskRunInfo
+from src.core.tasks.url.operators.agency_identification.subtasks.models.subtask import AutoAgencyIDSubtaskData
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.enums import TaskType
+from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic
+from src.db.models.impl.url.suggestion.agency.suggestion.pydantic import AgencyIDSubtaskSuggestionPydantic
+from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic
+from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall
+
+
+class AgencyIDSubtaskOperatorBase(ABC):
+
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        task_id: int
+    ) -> None:
+        self.adb_client: AsyncDatabaseClient = adb_client
+        self.task_id: int = task_id
+        self.linked_urls: list[int] = []
+
+    async def run(self) -> AgencyIDSubtaskRunInfo:
+        try:
+            await self.inner_logic()
+        except Exception as e:
+            # Get stack trace
+            stack_trace: str = traceback.format_exc()
+            return AgencyIDSubtaskRunInfo(
+                error=f"{type(e).__name__}: {str(e)}: {stack_trace}",
+                linked_url_ids=self.linked_urls
+            )
+        return AgencyIDSubtaskRunInfo(
+            linked_url_ids=self.linked_urls
+        )
+
+    @abc.abstractmethod
+    async def inner_logic(self) -> None:
+        raise NotImplementedError
+
+    async def _upload_subtask_data(
+        self,
+        subtask_data_list: list[AutoAgencyIDSubtaskData]
+    ) -> None:
+
+        subtask_models: list[URLAutoAgencyIDSubtaskPydantic] = [
+            subtask_data.pydantic_model
+            for subtask_data in subtask_data_list
+        ]
+        subtask_ids: list[int] = await self.adb_client.bulk_insert(
+            models=subtask_models,
+            return_ids=True
+        )
+        suggestions: list[AgencyIDSubtaskSuggestionPydantic] = []
+        for subtask_id, subtask_info in zip(subtask_ids, subtask_data_list):
+            for suggestion in subtask_info.suggestions:
+                suggestion_pydantic = AgencyIDSubtaskSuggestionPydantic(
+                    subtask_id=subtask_id,
+                    agency_id=suggestion.agency_id,
+                    confidence=suggestion.confidence,
+                )
+                suggestions.append(suggestion_pydantic)
+
+        await self.adb_client.bulk_insert(
+            models=suggestions,
+        )
+
+        error_infos: list[URLTaskErrorSmall] = []
+        for subtask_info in subtask_data_list:
+            if not subtask_info.has_error:
+                continue
+            error_info = URLTaskErrorSmall(
+                url_id=subtask_info.url_id,
+                error=subtask_info.error,
+            )
+            error_infos.append(error_info)
+
+        await self.add_task_errors(error_infos)
+
+    async def add_task_errors(
+        self,
+        errors: list[URLTaskErrorSmall]
+    ) -> None:
+        inserts: list[URLTaskErrorPydantic] = [
+            URLTaskErrorPydantic(
+                task_id=self.task_id,
+                url_id=error.url_id,
+                task_type=TaskType.AGENCY_IDENTIFICATION,
+                error=error.error
+            )
+            for error in errors
+        ]
+        await self.adb_client.bulk_insert(inserts)
\ No newline at end of file
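`run()` is the template's only public entry point: any exception raised by `inner_logic` is converted into an `AgencyIDSubtaskRunInfo` rather than propagating. A minimal sketch of that contract, using a hypothetical failing subclass (the `adb_client` is left as `None` because this failure path never touches the database):

```python
import asyncio

from src.core.tasks.url.operators.agency_identification.subtasks.templates.subtask import AgencyIDSubtaskOperatorBase


class FailingSubtask(AgencyIDSubtaskOperatorBase):
    # Hypothetical operator used only to exercise the error path.
    async def inner_logic(self) -> None:
        raise RuntimeError("lookup failed")


async def demo() -> None:
    operator = FailingSubtask(adb_client=None, task_id=1)
    info = await operator.run()
    assert not info.is_success
    assert info.error.startswith("RuntimeError: lookup failed")

asyncio.run(demo())
```

Errors attached to individual URLs, by contrast, flow through `_upload_subtask_data` into `add_task_errors` and are persisted per URL rather than failing the whole run.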
diff --git a/src/core/tasks/url/operators/auto_name/__init__.py b/src/core/tasks/url/operators/auto_name/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/auto_name/clean.py b/src/core/tasks/url/operators/auto_name/clean.py
new file mode 100644
index 00000000..2e1820ab
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/clean.py
@@ -0,0 +1,7 @@
+from src.db.models.impl.url.suggestion.location.auto.subtask.constants import MAX_SUGGESTION_LENGTH
+
+
+def clean_title(title: str) -> str:
+    if len(title) > MAX_SUGGESTION_LENGTH:
+        return title[:MAX_SUGGESTION_LENGTH - 3] + "..."
+    return title
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/auto_name/core.py b/src/core/tasks/url/operators/auto_name/core.py
new file mode 100644
index 00000000..00af9838
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/core.py
@@ -0,0 +1,44 @@
+from src.core.tasks.url.operators.auto_name.clean import clean_title
+from src.core.tasks.url.operators.auto_name.input import AutoNamePrerequisitesInput
+from src.core.tasks.url.operators.auto_name.queries.get import AutoNameGetInputsQueryBuilder
+from src.core.tasks.url.operators.auto_name.queries.prereq import AutoNamePrerequisitesQueryBuilder
+from src.core.tasks.url.operators.base import URLTaskOperatorBase
+from src.db.enums import TaskType
+from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource
+from src.db.models.impl.url.suggestion.name.pydantic import URLNameSuggestionPydantic
+
+
+class AutoNameURLTaskOperator(URLTaskOperatorBase):
+
+    @property
+    def task_type(self) -> TaskType:
+        return TaskType.AUTO_NAME
+
+    async def meets_task_prerequisites(self) -> bool:
+        return await self.adb_client.run_query_builder(
+            AutoNamePrerequisitesQueryBuilder()
+        )
+
+    async def inner_task_logic(self) -> None:
+
+        # Get URLs with HTML metadata title
+        inputs: list[AutoNamePrerequisitesInput] = await self.adb_client.run_query_builder(
+            AutoNameGetInputsQueryBuilder()
+        )
+
+        # Link URLs to task
+        url_ids: list[int] = [input_.url_id for input_ in inputs]
+        await self.link_urls_to_task(url_ids)
+
+        # Add suggestions
+        suggestions: list[URLNameSuggestionPydantic] = [
+            URLNameSuggestionPydantic(
+                url_id=input_.url_id,
+                suggestion=clean_title(input_.title),
+                source=NameSuggestionSource.HTML_METADATA_TITLE,
+            )
+            for input_ in inputs
+        ]
+
+        await self.adb_client.bulk_insert(models=suggestions)
+
diff --git a/src/core/tasks/url/operators/auto_name/input.py b/src/core/tasks/url/operators/auto_name/input.py
new file mode 100644
index 00000000..afbd2f34
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/input.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class AutoNamePrerequisitesInput(BaseModel):
+    url_id: int
+    title: str
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/auto_name/queries/__init__.py b/src/core/tasks/url/operators/auto_name/queries/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/auto_name/queries/cte.py b/src/core/tasks/url/operators/auto_name/queries/cte.py
new file mode 100644
index 00000000..1c7fc503
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/queries/cte.py
@@ -0,0 +1,48 @@
+from sqlalchemy import select, exists, CTE, Column
+
+from src.db.enums import URLHTMLContentType, TaskType
+from src.db.helpers.query import no_url_task_error
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.html.content.sqlalchemy import URLHTMLContent
+from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource
+from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion
+
+
+class AutoNamePrerequisiteCTEContainer:
+
+    def __init__(self):
+        self._query = (
+            select(
+                URL.id.label("url_id"),
+                URLHTMLContent.content
+            )
+            .join(
+                URLHTMLContent,
+                URLHTMLContent.url_id == URL.id
+            )
+            .where(
+                URLHTMLContent.content_type == URLHTMLContentType.TITLE.value,
+                ~exists(
+                    select(
+                        URLNameSuggestion.id
+                    )
+                    .where(
+                        URLNameSuggestion.url_id == URL.id,
+                        URLNameSuggestion.source == NameSuggestionSource.HTML_METADATA_TITLE.value,
+                    )
+                ),
+                no_url_task_error(TaskType.AUTO_NAME)
+            ).cte("auto_name_prerequisites")
+        )
+
+    @property
+    def cte(self) -> CTE:
+        return self._query
+
+    @property
+    def url_id(self) -> Column[int]:
+        return self.cte.c.url_id
+
+    @property
+    def content(self) -> Column[str]:
+        return self.cte.c.content
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/auto_name/queries/get.py b/src/core/tasks/url/operators/auto_name/queries/get.py
new file mode 100644
index 00000000..b4978521
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/queries/get.py
@@ -0,0 +1,27 @@
+from typing import Sequence
+
+from sqlalchemy import select, RowMapping
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.auto_name.input import AutoNamePrerequisitesInput
+from src.core.tasks.url.operators.auto_name.queries.cte import AutoNamePrerequisiteCTEContainer
+from src.db.queries.base.builder import QueryBuilderBase
+
+from src.db.helpers.session import session_helper as sh
+
+class AutoNameGetInputsQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> list[AutoNamePrerequisitesInput]:
+        cte = AutoNamePrerequisiteCTEContainer()
+        query = select(cte.url_id, cte.content)
+
+        mappings: Sequence[RowMapping] = await sh.mappings(session=session, query=query)
+        results: list[AutoNamePrerequisitesInput] = []
+        for mapping in mappings:
+            result = AutoNamePrerequisitesInput(
+                url_id=mapping["url_id"],
+                title=mapping["content"],
+            )
+            results.append(result)
+
+        return results
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/auto_name/queries/prereq.py b/src/core/tasks/url/operators/auto_name/queries/prereq.py
new file mode 100644
index 00000000..c6224db8
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_name/queries/prereq.py
@@ -0,0 +1,16 @@
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.auto_name.queries.cte import AutoNamePrerequisiteCTEContainer
+from src.db.helpers.session import session_helper as sh
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class AutoNamePrerequisitesQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> bool:
+        cte = AutoNamePrerequisiteCTEContainer()
+        query = select(cte.url_id)
+        return await sh.results_exist(session, query=query)
+
+
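`clean_title` guarantees a suggested name never exceeds the column limit, trading the tail of an over-long title for an ellipsis. A standalone sketch of the truncation arithmetic (the real `MAX_SUGGESTION_LENGTH` value is defined elsewhere in the codebase; 60 is just an assumed stand-in):

```python
MAX_SUGGESTION_LENGTH = 60  # assumed value, for illustration only

def clean_title(title: str) -> str:
    # Mirrors auto_name/clean.py: truncate and mark the cut with an ellipsis.
    if len(title) > MAX_SUGGESTION_LENGTH:
        return title[:MAX_SUGGESTION_LENGTH - 3] + "..."
    return title

long_title = "Annual Use of Force Report for the Springfield Police Department, 2023 Edition"
cleaned = clean_title(long_title)
assert len(cleaned) == MAX_SUGGESTION_LENGTH  # 57 kept characters + "..."
```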
diff --git a/src/core/tasks/url/operators/auto_relevant/core.py b/src/core/tasks/url/operators/auto_relevant/core.py
index 1a0c6c13..86cc179e 100644
--- a/src/core/tasks/url/operators/auto_relevant/core.py
+++ b/src/core/tasks/url/operators/auto_relevant/core.py
@@ -1,11 +1,14 @@
 from src.core.tasks.url.operators.auto_relevant.models.annotation import RelevanceAnnotationInfo
 from src.core.tasks.url.operators.auto_relevant.models.tdo import URLRelevantTDO
+from src.core.tasks.url.operators.auto_relevant.queries.get import GetAutoRelevantTDOsQueryBuilder
+from src.core.tasks.url.operators.auto_relevant.queries.prereq import AutoRelevantPrerequisitesQueryBuilder
 from src.core.tasks.url.operators.auto_relevant.sort import separate_success_and_error_subsets
 from src.core.tasks.url.operators.base import URLTaskOperatorBase
 from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.url.annotations.auto.relevancy import AutoRelevancyAnnotationInput
-from src.db.dtos.url.error import URLErrorPydanticInfo
+from src.db.models.impl.url.suggestion.relevant.auto.pydantic.input import AutoRelevancyAnnotationInput
 from src.db.enums import TaskType
+from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic
+from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall
 from src.external.huggingface.inference.client import HuggingFaceInferenceClient
 from src.external.huggingface.inference.models.input import BasicInput
 
@@ -21,16 +24,18 @@ def __init__(
         self.hf_client = hf_client
 
     @property
-    def task_type(self):
+    def task_type(self) -> TaskType:
         return TaskType.RELEVANCY
 
-    async def meets_task_prerequisites(self):
-        return await self.adb_client.has_urls_with_html_data_and_without_auto_relevant_suggestion()
+    async def meets_task_prerequisites(self) -> bool:
+        return await self.adb_client.run_query_builder(
+            builder=AutoRelevantPrerequisitesQueryBuilder()
+        )
 
     async def get_tdos(self) -> list[URLRelevantTDO]:
-        return await self.adb_client.get_tdos_for_auto_relevancy()
+        return await self.adb_client.run_query_builder(builder=GetAutoRelevantTDOsQueryBuilder())
 
-    async def inner_task_logic(self):
+    async def inner_task_logic(self) -> None:
         tdos = await self.get_tdos()
         url_ids = [tdo.url_id for tdo in tdos]
         await self.link_urls_to_task(url_ids=url_ids)
@@ -41,7 +46,12 @@ async def inner_task_logic(self):
         await self.put_results_into_database(subsets.success)
         await self.update_errors_in_database(subsets.error)
 
-    async def get_ml_classifications(self, tdos: list[URLRelevantTDO]):
+    async def get_ml_classifications(self, tdos: list[URLRelevantTDO]) -> None:
+        """
+        Modifies:
+            tdo.annotation
+            tdo.error
+        """
         for tdo in tdos:
             try:
                 input_ = BasicInput(
@@ -59,7 +69,7 @@ async def get_ml_classifications(self, tdos: list[URLRelevantTDO]):
                 )
                 tdo.annotation = annotation_info
 
-    async def put_results_into_database(self, tdos: list[URLRelevantTDO]):
+    async def put_results_into_database(self, tdos: list[URLRelevantTDO]) -> None:
         inputs = []
         for tdo in tdos:
             input_ = AutoRelevancyAnnotationInput(
@@ -71,15 +81,14 @@ async def put_results_into_database(self, tdos: list[URLRelevantTDO]):
             inputs.append(input_)
         await self.adb_client.add_user_relevant_suggestions(inputs)
 
-    async def update_errors_in_database(self, tdos: list[URLRelevantTDO]):
-        error_infos = []
+    async def update_errors_in_database(self, tdos: list[URLRelevantTDO]) -> None:
+        task_errors: list[URLTaskErrorSmall] = []
         for tdo in tdos:
-            error_info = URLErrorPydanticInfo(
-                task_id=self.task_id,
+            error_info = URLTaskErrorSmall(
                 url_id=tdo.url_id,
                 error=tdo.error
             )
-            error_infos.append(error_info)
-        await self.adb_client.add_url_error_infos(error_infos)
+            task_errors.append(error_info)
+        await self.add_task_errors(task_errors)
 
diff --git a/src/core/tasks/url/operators/auto_relevant/queries/cte.py b/src/core/tasks/url/operators/auto_relevant/queries/cte.py
new file mode 100644
index 00000000..8ad33867
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_relevant/queries/cte.py
@@ -0,0 +1,39 @@
+from sqlalchemy import select, CTE
+from sqlalchemy.orm import aliased
+
+from src.collectors.enums import URLStatus
+from src.db.enums import TaskType
+from src.db.helpers.query import not_exists_url, no_url_task_error
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
+from src.db.models.impl.url.suggestion.relevant.auto.sqlalchemy import AutoRelevantSuggestion
+
+
+class AutoRelevantPrerequisitesCTEContainer:
+
+    def __init__(self):
+        self._cte = (
+            select(
+                URL
+            )
+            .join(
+                URLCompressedHTML,
+                URL.id == URLCompressedHTML.url_id
+            )
+            .where(
+                URL.status == URLStatus.OK.value,
+                not_exists_url(AutoRelevantSuggestion),
+                no_url_task_error(TaskType.RELEVANCY)
+            ).cte("auto_relevant_prerequisites")
+        )
+
+        self._url_alias = aliased(URL, self._cte)
+
+    @property
+    def cte(self) -> CTE:
+        return self._cte
+
+    @property
+    def url_alias(self):
+        """Return an ORM alias of URL mapped to the CTE."""
+        return self._url_alias
diff --git a/src/core/tasks/url/operators/auto_relevant/queries/get.py b/src/core/tasks/url/operators/auto_relevant/queries/get.py
new file mode 100644
index 00000000..6f6c59b0
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_relevant/queries/get.py
@@ -0,0 +1,38 @@
+from typing import Sequence
+
+from sqlalchemy import select, Row
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import selectinload
+
+from src.core.tasks.url.operators.auto_relevant.models.tdo import URLRelevantTDO
+from src.core.tasks.url.operators.auto_relevant.queries.cte import AutoRelevantPrerequisitesCTEContainer
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.queries.base.builder import QueryBuilderBase
+from src.db.utils.compression import decompress_html
+
+
+class GetAutoRelevantTDOsQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> list[URLRelevantTDO]:
+        cte = AutoRelevantPrerequisitesCTEContainer()
+        query = (
+            select(cte.url_alias)
+            .options(
+                selectinload(cte.url_alias.compressed_html)
+            )
+        )
+
+        query = query.limit(100).order_by(cte.url_alias.id)
+        raw_result = await session.execute(query)
+        urls: Sequence[Row[URL]] = raw_result.unique().scalars().all()
+        tdos = []
+        for url in urls:
+            tdos.append(
+                URLRelevantTDO(
+                    url_id=url.id,
+                    html=decompress_html(url.compressed_html.compressed_html),
+                )
+            )
+
+        return tdos
+
diff --git a/src/core/tasks/url/operators/auto_relevant/queries/get_tdos.py b/src/core/tasks/url/operators/auto_relevant/queries/get_tdos.py
deleted file mode 100644
index b444b5b3..00000000
--- a/src/core/tasks/url/operators/auto_relevant/queries/get_tdos.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from typing import Sequence
-
-from sqlalchemy import select, Row
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import selectinload
-
-from src.collectors.enums import URLStatus
-from src.core.tasks.url.operators.auto_relevant.models.tdo import URLRelevantTDO
-from src.db.models.instantiations.url.compressed_html import URLCompressedHTML
-from src.db.models.instantiations.url.core import URL
-from src.db.models.instantiations.url.suggestion.relevant.auto import AutoRelevantSuggestion
-from src.db.queries.base.builder import QueryBuilderBase
-from src.db.statement_composer import StatementComposer
-from src.db.utils.compression import decompress_html
-
-
-class GetAutoRelevantTDOsQueryBuilder(QueryBuilderBase):
-
-    def __init__(self):
-        super().__init__()
-
-    async def run(self, session: AsyncSession) -> list[URLRelevantTDO]:
-        query = (
-            select(
-                URL
-            )
-            .options(
-                selectinload(URL.compressed_html)
-            )
-            .join(URLCompressedHTML)
-            .where(
-                URL.outcome == URLStatus.PENDING.value,
-            )
-        )
-        query = StatementComposer.exclude_urls_with_extant_model(
-            query,
-            model=AutoRelevantSuggestion
-        )
-        query = query.limit(100).order_by(URL.id)
-        raw_result = await session.execute(query)
-        urls: Sequence[Row[URL]] = raw_result.unique().scalars().all()
-        tdos = []
-        for url in urls:
-            tdos.append(
-                URLRelevantTDO(
-                    url_id=url.id,
-                    html=decompress_html(url.compressed_html.compressed_html),
-                )
-            )
-
-        return tdos
-
diff --git a/src/core/tasks/url/operators/auto_relevant/queries/prereq.py b/src/core/tasks/url/operators/auto_relevant/queries/prereq.py
new file mode 100644
index 00000000..2736693e
--- /dev/null
+++ b/src/core/tasks/url/operators/auto_relevant/queries/prereq.py
@@ -0,0 +1,18 @@
+
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.auto_relevant.queries.cte import AutoRelevantPrerequisitesCTEContainer
+from src.db.queries.base.builder import QueryBuilderBase
+from src.db.helpers.session import session_helper as sh
+
+class AutoRelevantPrerequisitesQueryBuilder(QueryBuilderBase):
+
+    async def run(self, session: AsyncSession) -> bool:
+
+        cte = AutoRelevantPrerequisitesCTEContainer()
+        query = (
+            select(cte.url_alias)
+        )
+
+        return await sh.results_exist(session, query=query)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/base.py b/src/core/tasks/url/operators/base.py
index 59c41c6a..e1d70d5e 100644
--- a/src/core/tasks/url/operators/base.py
+++ b/src/core/tasks/url/operators/base.py
@@ -1,61 +1,36 @@
-import traceback
-from abc import ABC, abstractmethod
-
 from src.core.tasks.base.operator import TaskOperatorBase
-from src.db.client.async_ import AsyncDatabaseClient
-from src.db.enums import TaskType
-from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo
+from src.core.tasks.base.run_info import TaskOperatorRunInfo
+from src.core.tasks.mixins.link_urls import LinkURLsMixin
+from src.core.tasks.mixins.prereq import HasPrerequisitesMixin
 from src.core.tasks.url.enums import TaskOperatorOutcome
-from src.core.enums import BatchStatus
+from src.db.client.async_ import AsyncDatabaseClient
 
 
-class URLTaskOperatorBase(TaskOperatorBase):
+class URLTaskOperatorBase(
+    TaskOperatorBase,
+    LinkURLsMixin,
+    HasPrerequisitesMixin,
+):
 
     def __init__(self, adb_client: AsyncDatabaseClient):
         super().__init__(adb_client)
-        self.tasks_linked = False
-        self.linked_url_ids = []
-
-    @abstractmethod
-    async def meets_task_prerequisites(self):
-        """
-        A task should not be initiated unless certain
-        conditions are met
-        """
-        raise NotImplementedError
-
-    async def link_urls_to_task(self, url_ids: list[int]):
-        self.linked_url_ids = url_ids
 
     async def conclude_task(self):
-        if not self.linked_url_ids:
+        if not self.urls_linked:
             raise Exception("Task has not been linked to any URLs")
         return await self.run_info(
             outcome=TaskOperatorOutcome.SUCCESS,
             message="Task completed successfully"
         )
 
-    async def run_task(self, task_id: int) -> URLTaskOperatorRunInfo:
-        self.task_id = task_id
-        try:
-            await self.inner_task_logic()
-            return await self.conclude_task()
-        except Exception as e:
-            stack_trace = traceback.format_exc()
-            return await self.run_info(
-                outcome=TaskOperatorOutcome.ERROR,
-                message=str(e) + "\n" + stack_trace
-            )
-
     async def run_info(
         self,
         outcome: TaskOperatorOutcome,
         message: str
-    ) -> URLTaskOperatorRunInfo:
-        return URLTaskOperatorRunInfo(
+    ) -> TaskOperatorRunInfo:
+        return TaskOperatorRunInfo(
             task_id=self.task_id,
             task_type=self.task_type,
-            linked_url_ids=self.linked_url_ids,
             outcome=outcome,
             message=message
         )
diff --git a/src/core/tasks/url/operators/html/__init__.py b/src/core/tasks/url/operators/html/__init__.py
new file mode 100644
index 00000000..e69de29b
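Both prerequisite builders above reduce to the same pattern: build the prerequisite CTE, select from it, and ask whether any row exists. `sh.results_exist` itself is not part of this diff; a purely illustrative sketch of what such a helper could look like with async SQLAlchemy (names and signature are assumptions, not the project's actual implementation):

```python
from sqlalchemy import select, exists, Select
from sqlalchemy.ext.asyncio import AsyncSession


async def results_exist(session: AsyncSession, query: Select) -> bool:
    # Wrap the candidate query in EXISTS so the database can stop at the first match.
    result = await session.execute(select(exists(query)))
    return bool(result.scalar())
```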
00000000..e69de29b diff --git a/src/core/tasks/url/operators/url_html/content_info_getter.py b/src/core/tasks/url/operators/html/content_info_getter.py similarity index 78% rename from src/core/tasks/url/operators/url_html/content_info_getter.py rename to src/core/tasks/url/operators/html/content_info_getter.py index 644e12e4..bee7183c 100644 --- a/src/core/tasks/url/operators/url_html/content_info_getter.py +++ b/src/core/tasks/url/operators/html/content_info_getter.py @@ -1,5 +1,6 @@ -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo -from src.db.dtos.url.html_content import URLHTMLContentInfo, HTMLContentType +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo +from src.db.dtos.url.html_content import URLHTMLContentInfo +from src.db.models.impl.url.html.content.enums import HTMLContentType class HTMLContentInfoGetter: diff --git a/src/core/tasks/url/operators/html/core.py b/src/core/tasks/url/operators/html/core.py new file mode 100644 index 00000000..26f70cdb --- /dev/null +++ b/src/core/tasks/url/operators/html/core.py @@ -0,0 +1,84 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.html.filter import filter_just_urls, filter_404_subset +from src.core.tasks.url.operators.html.queries.insert.query import InsertURLHTMLInfoQueryBuilder +from src.core.tasks.url.operators.html.scraper.parser.core import HTMLResponseParser +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.external.url_request.core import URLRequestInterface + + +class URLHTMLTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + url_request_interface: URLRequestInterface, + adb_client: AsyncDatabaseClient, + html_parser: HTMLResponseParser + ): + super().__init__(adb_client) + self.url_request_interface = url_request_interface + self.html_parser = html_parser + + @property + def task_type(self) -> TaskType: + return TaskType.HTML + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.has_non_errored_urls_without_html_data() + + async def inner_task_logic(self) -> None: + tdos = await self._get_non_errored_urls_without_html_data() + url_ids = [task_info.url_info.id for task_info in tdos] + await self.link_urls_to_task(url_ids=url_ids) + + await self._get_raw_html_data_for_urls(tdos) + await self._process_html_data(tdos) + + tdos_404 = await filter_404_subset(tdos) + await self._update_404s_in_database(tdos_404) + await self._update_html_data_in_database(tdos) + + + async def _get_non_errored_urls_without_html_data(self) -> list[UrlHtmlTDO]: + pending_urls: list[URLInfo] = await self.adb_client.get_non_errored_urls_without_html_data() + tdos = [ + UrlHtmlTDO( + url_info=url_info, + ) for url_info in pending_urls + ] + return tdos + + async def _get_raw_html_data_for_urls(self, tdos: list[UrlHtmlTDO]) -> None: + just_urls = await filter_just_urls(tdos) + url_response_infos = await self.url_request_interface.make_requests_with_html(just_urls) + for tdto, url_response_info in zip(tdos, url_response_infos): + tdto.url_response_info = url_response_info + + async def _update_404s_in_database(self, tdos_404: list[UrlHtmlTDO]) -> None: + url_ids = [tdo.url_info.id for tdo in tdos_404] + await self.adb_client.mark_all_as_404(url_ids) + + + async def _process_html_data(self, 
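Note that `_get_raw_html_data_for_urls` zips the TDO list against the response list, which is only safe if the request interface returns responses in input order. The project's `URLRequestInterface` is not part of this diff; as an assumption-labeled sketch, `asyncio.gather` is one way to get that ordering guarantee while still fanning requests out concurrently:

```python
import asyncio


async def fetch(url: str) -> str:
    await asyncio.sleep(0)  # stand-in for a real HTTP request
    return f"<html>{url}</html>"


async def make_requests_with_html(urls: list[str]) -> list[str]:
    # gather() returns results in the order of its arguments, so results[i]
    # always belongs to urls[i] -- the property the zip() above relies on.
    return await asyncio.gather(*(fetch(u) for u in urls))


print(asyncio.run(make_requests_with_html(["a.example", "b.example"])))
```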
tdos: list[UrlHtmlTDO]) -> None: + """ + Modifies: + tdto.html_tag_info + """ + for tdto in tdos: + if not tdto.url_response_info.success: + continue + html_tag_info = await self.html_parser.parse( + url=tdto.url_info.url, + html_content=tdto.url_response_info.html, + content_type=tdto.url_response_info.content_type + ) + tdto.html_tag_info = html_tag_info + + async def _update_html_data_in_database(self, tdos: list[UrlHtmlTDO]) -> None: + await self.adb_client.run_query_builder( + InsertURLHTMLInfoQueryBuilder(tdos, task_id=self.task_id) + ) + + diff --git a/src/core/tasks/url/operators/html/filter.py b/src/core/tasks/url/operators/html/filter.py new file mode 100644 index 00000000..86da0e8a --- /dev/null +++ b/src/core/tasks/url/operators/html/filter.py @@ -0,0 +1,13 @@ +from http import HTTPStatus + +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO + + +async def filter_just_urls(tdos: list[UrlHtmlTDO]): + return [task_info.url_info.url for task_info in tdos] + +async def filter_404_subset(tdos: list[UrlHtmlTDO]) -> list[UrlHtmlTDO]: + return [ + tdo for tdo in tdos + if tdo.url_response_info.status == HTTPStatus.NOT_FOUND + ] diff --git a/src/core/tasks/url/operators/html/models/__init__.py b/src/core/tasks/url/operators/html/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/html/models/subsets/__init__.py b/src/core/tasks/url/operators/html/models/subsets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/html/models/subsets/error_404.py b/src/core/tasks/url/operators/html/models/subsets/error_404.py new file mode 100644 index 00000000..f526368c --- /dev/null +++ b/src/core/tasks/url/operators/html/models/subsets/error_404.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO + + +class ErrorSubsets(BaseModel): + is_404: list[UrlHtmlTDO] + not_404: list[UrlHtmlTDO] \ No newline at end of file diff --git a/src/core/tasks/url/operators/html/models/subsets/success_error.py b/src/core/tasks/url/operators/html/models/subsets/success_error.py new file mode 100644 index 00000000..75429a6e --- /dev/null +++ b/src/core/tasks/url/operators/html/models/subsets/success_error.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO + + +class SuccessErrorSubset(BaseModel): + success: list[UrlHtmlTDO] + error: list[UrlHtmlTDO] \ No newline at end of file diff --git a/src/core/tasks/url/operators/html/queries/__init__.py b/src/core/tasks/url/operators/html/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/html/queries/get.py b/src/core/tasks/url/operators/html/queries/get.py new file mode 100644 index 00000000..832d9917 --- /dev/null +++ b/src/core/tasks/url/operators/html/queries/get.py @@ -0,0 +1,31 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.statement_composer import StatementComposer + + +class GetPendingURLsWithoutHTMLDataQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[URLInfo]: + statement = StatementComposer.has_non_errored_urls_without_html_data() + statement = statement.limit(100).order_by(URL.id) + scalar_result = await session.scalars(statement) + url_results: 
list[URL] = scalar_result.all() + + final_results = [] + for url in url_results: + url_info = URLInfo( + id=url.id, + batch_id=url.batch.id if url.batch is not None else None, + url=url.url, + collector_metadata=url.collector_metadata, + status=url.status, + created_at=url.created_at, + updated_at=url.updated_at, + name=url.name + ) + final_results.append(url_info) + + return final_results diff --git a/src/core/tasks/url/operators/html/queries/insert/__init__.py b/src/core/tasks/url/operators/html/queries/insert/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/html/queries/insert/convert.py b/src/core/tasks/url/operators/html/queries/insert/convert.py new file mode 100644 index 00000000..ca827c7e --- /dev/null +++ b/src/core/tasks/url/operators/html/queries/insert/convert.py @@ -0,0 +1,76 @@ +from http import HTTPStatus + +from src.core.tasks.url.operators.html.content_info_getter import HTMLContentInfoGetter +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO +from src.db.dtos.url.html_content import URLHTMLContentInfo +from src.db.enums import TaskType +from src.db.models.impl.url.error_info.pydantic import URLErrorInfoPydantic +from src.db.models.impl.url.html.compressed.pydantic import URLCompressedHTMLPydantic +from src.db.models.impl.url.scrape_info.enums import ScrapeStatus +from src.db.models.impl.url.scrape_info.pydantic import URLScrapeInfoInsertModel +from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic +from src.db.utils.compression import compress_html +from src.external.url_request.dtos.url_response import URLResponseInfo + + +def convert_to_compressed_html(tdos: list[UrlHtmlTDO]) -> list[URLCompressedHTMLPydantic]: + models = [] + for tdo in tdos: + if tdo.url_response_info.status != HTTPStatus.OK: + continue + model = URLCompressedHTMLPydantic( + url_id=tdo.url_info.id, + compressed_html=compress_html(tdo.url_response_info.html) + ) + models.append(model) + return models + + + +def _convert_to_html_content_info_getter(tdo: UrlHtmlTDO) -> HTMLContentInfoGetter: + return HTMLContentInfoGetter( + response_html_info=tdo.html_tag_info, + url_id=tdo.url_info.id + ) + +def convert_to_html_content_info_list(tdos: list[UrlHtmlTDO]) -> list[URLHTMLContentInfo]: + html_content_infos = [] + for tdo in tdos: + if tdo.url_response_info.status != HTTPStatus.OK: + continue + hcig = _convert_to_html_content_info_getter(tdo) + results = hcig.get_all_html_content() + html_content_infos.extend(results) + return html_content_infos + +def get_scrape_status(response_info: URLResponseInfo) -> ScrapeStatus: + if response_info.success: + return ScrapeStatus.SUCCESS + return ScrapeStatus.ERROR + +def convert_to_scrape_infos(tdos: list[UrlHtmlTDO]) -> list[URLScrapeInfoInsertModel]: + models = [] + for tdo in tdos: + model = URLScrapeInfoInsertModel( + url_id=tdo.url_info.id, + status=get_scrape_status(tdo.url_response_info) + ) + models.append(model) + return models + +def convert_to_url_errors( + tdos: list[UrlHtmlTDO], + task_id: int +) -> list[URLErrorInfoPydantic]: + models = [] + for tdo in tdos: + if tdo.url_response_info.success: + continue + model = URLTaskErrorPydantic( + url_id=tdo.url_info.id, + error=tdo.url_response_info.exception, + task_id=task_id, + task_type=TaskType.HTML + ) + models.append(model) + return models \ No newline at end of file diff --git a/src/core/tasks/url/operators/html/queries/insert/query.py b/src/core/tasks/url/operators/html/queries/insert/query.py new file mode 100644 
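`compress_html` and `decompress_html` come from `src.db.utils.compression`, which is outside this diff. For readers following along, a plausible minimal equivalent, assuming zlib over UTF-8 text (the real module may use a different codec or compression level):

```python
import zlib


def compress_html(html: str) -> bytes:
    # assumption: default zlib level over UTF-8; the project's util may differ
    return zlib.compress(html.encode("utf-8"))


def decompress_html(blob: bytes) -> str:
    return zlib.decompress(blob).decode("utf-8")


assert decompress_html(compress_html("<html>hi</html>")) == "<html>hi</html>"
```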
index 00000000..e0bff2e6 --- /dev/null +++ b/src/core/tasks/url/operators/html/queries/insert/query.py @@ -0,0 +1,30 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.html.queries.insert.convert import convert_to_compressed_html, \ + convert_to_html_content_info_list, convert_to_scrape_infos, convert_to_url_errors +from src.core.tasks.url.operators.html.tdo import UrlHtmlTDO +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class InsertURLHTMLInfoQueryBuilder(QueryBuilderBase): + + def __init__(self, tdos: list[UrlHtmlTDO], task_id: int): + super().__init__() + self.tdos = tdos + self.task_id = task_id + + async def run(self, session: AsyncSession) -> None: + compressed_html_models = convert_to_compressed_html(self.tdos) + url_html_content_list = convert_to_html_content_info_list(self.tdos) + scrape_info_list = convert_to_scrape_infos(self.tdos) + url_errors = convert_to_url_errors(self.tdos, task_id=self.task_id) + + for models in [ + compressed_html_models, + url_html_content_list, + scrape_info_list, + url_errors + ]: + await sh.bulk_insert(session, models=models) + + diff --git a/src/core/tasks/url/operators/url_html/scraper/README.md b/src/core/tasks/url/operators/html/scraper/README.md similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/README.md rename to src/core/tasks/url/operators/html/scraper/README.md diff --git a/src/core/tasks/url/operators/html/scraper/__init__.py b/src/core/tasks/url/operators/html/scraper/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/README.md b/src/core/tasks/url/operators/html/scraper/parser/README.md similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/parser/README.md rename to src/core/tasks/url/operators/html/scraper/parser/README.md diff --git a/src/core/tasks/url/operators/html/scraper/parser/__init__.py b/src/core/tasks/url/operators/html/scraper/parser/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/constants.py b/src/core/tasks/url/operators/html/scraper/parser/constants.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/parser/constants.py rename to src/core/tasks/url/operators/html/scraper/parser/constants.py diff --git a/src/core/tasks/url/operators/html/scraper/parser/core.py b/src/core/tasks/url/operators/html/scraper/parser/core.py new file mode 100644 index 00000000..d79ab1f6 --- /dev/null +++ b/src/core/tasks/url/operators/html/scraper/parser/core.py @@ -0,0 +1,124 @@ +import json + +from bs4 import BeautifulSoup + +from src.core.tasks.url.operators.html.scraper.parser.constants import HEADER_TAGS +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo +from src.core.tasks.url.operators.html.scraper.parser.enums import ParserTypeEnum +from src.core.tasks.url.operators.html.scraper.parser.util import remove_excess_whitespace, add_https, \ + remove_trailing_backslash, \ + drop_hostname + + +class HTMLResponseParser: + + async def parse(self, url: str, html_content: str, content_type: str) -> ResponseHTMLInfo: + html_info = ResponseHTMLInfo() + self.add_url_and_path(html_info, html_content=html_content, url=url) + parser_type = self.get_parser_type(content_type) + if parser_type is None: + return html_info + self.add_html_from_beautiful_soup( + 
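The insert builder above follows one shape throughout: pure converter functions each turn the same TDO list into an independent list of row models (filtering on response status as they go), and the builder persists each list in turn. A dependency-free sketch of that shape, with models reduced to dicts and the database to a list:

```python
tdos = [{"url_id": 1, "ok": True}, {"url_id": 2, "ok": False}]


def to_compressed(rows: list[dict]) -> list[dict]:
    # only successful fetches get stored HTML
    return [{"table": "compressed_html", "url_id": r["url_id"]} for r in rows if r["ok"]]


def to_errors(rows: list[dict]) -> list[dict]:
    # only failures get error rows
    return [{"table": "url_task_error", "url_id": r["url_id"]} for r in rows if not r["ok"]]


db: list[dict] = []
for convert in (to_compressed, to_errors):
    db.extend(convert(tdos))  # stand-in for sh.bulk_insert(session, models=...)
print(db)
```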
html_info=html_info, + parser_type=parser_type, + html_content=html_content + ) + return html_info + + def add_html_from_beautiful_soup( + self, + html_info: ResponseHTMLInfo, + parser_type: ParserTypeEnum, + html_content: str + ) -> None: + """ + Modifies: + html_info + """ + + soup = BeautifulSoup( + markup=html_content, + features=parser_type.value, + ) + html_info.title = self.get_html_title(soup) + html_info.description = self.get_meta_description(soup) + self.add_header_tags(html_info, soup) + html_info.div = self.get_div_text(soup) + # Prevents most bs4 memory leaks + if soup.html is not None: + soup.html.decompose() + + def get_div_text(self, soup: BeautifulSoup) -> str: + div_text = "" + MAX_WORDS = 500 + for div in soup.find_all("div"): + text = div.get_text(" ", strip=True) + if text is None: + continue + # Check if adding the current text exceeds the word limit + if len(div_text.split()) + len(text.split()) <= MAX_WORDS: + div_text += text + " " + else: + break # Stop adding text if word limit is reached + + # Truncate to 5000 characters in case of run-on 'words' + div_text = div_text[: MAX_WORDS * 10] + + return div_text + + def get_meta_description(self, soup: BeautifulSoup) -> str: + meta_tag = soup.find("meta", attrs={"name": "description"}) + if meta_tag is None: + return "" + try: + return remove_excess_whitespace(meta_tag["content"]) + except KeyError: + return "" + + def add_header_tags(self, html_info: ResponseHTMLInfo, soup: BeautifulSoup): + for header_tag in HEADER_TAGS: + headers = soup.find_all(header_tag) + # Retrieves and drops headers containing links to reduce training bias + header_content = [header.get_text(" ", strip=True) for header in headers if not header.a] + tag_content = json.dumps(header_content, ensure_ascii=False) + if tag_content == "[]": + continue + setattr(html_info, header_tag, tag_content) + + def get_html_title(self, soup: BeautifulSoup) -> str | None: + if soup.title is None: + return None + if soup.title.string is None: + return None + return remove_excess_whitespace(soup.title.string) + + + def add_url_and_path( + self, + html_info: ResponseHTMLInfo, + html_content: str, + url: str + ) -> None: + """ + Modifies: + html_info.url + html_info.url_path + """ + url = add_https(url) + html_info.url = url + + url_path = drop_hostname(url) + url_path = remove_trailing_backslash(url_path) + html_info.url_path = url_path + + def get_parser_type(self, content_type: str) -> ParserTypeEnum | None: + try: + # If content type does not contain "html" or "xml" then we can assume that the content is unreadable + if "html" in content_type: + return ParserTypeEnum.LXML + if "xml" in content_type: + return ParserTypeEnum.LXML_XML + return None + except KeyError: + return None + diff --git a/src/core/tasks/url/operators/html/scraper/parser/dtos/__init__.py b/src/core/tasks/url/operators/html/scraper/parser/dtos/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/dtos/response_html.py b/src/core/tasks/url/operators/html/scraper/parser/dtos/response_html.py similarity index 91% rename from src/core/tasks/url/operators/url_html/scraper/parser/dtos/response_html.py rename to src/core/tasks/url/operators/html/scraper/parser/dtos/response_html.py index dfa34510..0df614ce 100644 --- a/src/core/tasks/url/operators/url_html/scraper/parser/dtos/response_html.py +++ b/src/core/tasks/url/operators/html/scraper/parser/dtos/response_html.py @@ -7,7 +7,6 @@ class ResponseHTMLInfo(BaseModel): url_path: 
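`get_div_text` enforces two limits: a 500-word budget while accumulating div text, then a character cap of ten characters per budgeted word to guard against run-on "words". The same accumulation logic in isolation:

```python
MAX_WORDS = 500


def collect_text(chunks: list[str]) -> str:
    out = ""
    for text in chunks:
        # stop once adding this chunk would exceed the word budget
        if len(out.split()) + len(text.split()) > MAX_WORDS:
            break
        out += text + " "
    # never exceed ~10 chars per budgeted word, even for run-on "words"
    return out[: MAX_WORDS * 10]


print(collect_text(["alpha beta", "gamma"]))  # "alpha beta gamma "
```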
str = "" title: str = "" description: str = "" - root_page_title: str = "" http_response: int = -1 h1: str = "" h2: str = "" diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/enums.py b/src/core/tasks/url/operators/html/scraper/parser/enums.py similarity index 100% rename from src/core/tasks/url/operators/url_html/scraper/parser/enums.py rename to src/core/tasks/url/operators/html/scraper/parser/enums.py diff --git a/src/core/tasks/url/operators/html/scraper/parser/mapping.py b/src/core/tasks/url/operators/html/scraper/parser/mapping.py new file mode 100644 index 00000000..b4bb4f4a --- /dev/null +++ b/src/core/tasks/url/operators/html/scraper/parser/mapping.py @@ -0,0 +1,13 @@ +from src.db.models.impl.url.html.content.enums import HTMLContentType + +ENUM_TO_ATTRIBUTE_MAPPING = { + HTMLContentType.TITLE: "title", + HTMLContentType.DESCRIPTION: "description", + HTMLContentType.H1: "h1", + HTMLContentType.H2: "h2", + HTMLContentType.H3: "h3", + HTMLContentType.H4: "h4", + HTMLContentType.H5: "h5", + HTMLContentType.H6: "h6", + HTMLContentType.DIV: "div" +} diff --git a/src/core/tasks/url/operators/html/scraper/parser/util.py b/src/core/tasks/url/operators/html/scraper/parser/util.py new file mode 100644 index 00000000..924506a1 --- /dev/null +++ b/src/core/tasks/url/operators/html/scraper/parser/util.py @@ -0,0 +1,45 @@ +from urllib.parse import urlparse + +from src.db.dtos.url.html_content import URLHTMLContentInfo +from src.core.tasks.url.operators.html.scraper.parser.mapping import ENUM_TO_ATTRIBUTE_MAPPING +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo + + +def convert_to_response_html_info( + html_content_infos: list[URLHTMLContentInfo] +) -> ResponseHTMLInfo: + response_html_info = ResponseHTMLInfo() + + for html_content_info in html_content_infos: + setattr(response_html_info, ENUM_TO_ATTRIBUTE_MAPPING[html_content_info.content_type], html_content_info.content) + + return response_html_info + + +def remove_excess_whitespace(s: str) -> str: + """Removes leading, trailing, and excess adjacent whitespace. + + Args: + s (str): String to remove whitespace from. + + Returns: + str: Clean string with excess whitespace stripped. 
+ """ + return " ".join(s.split()).strip() + + +def add_https(url: str) -> str: + if not url.startswith("http"): + url = "https://" + url + return url + + +def remove_trailing_backslash(url_path: str) -> str: + if url_path and url_path[-1] == "/": + url_path = url_path[:-1] + return url_path + + +def drop_hostname(new_url: str) -> str: + url_path = urlparse(new_url).path[1:] + return url_path diff --git a/src/core/tasks/url/operators/html/tdo.py b/src/core/tasks/url/operators/html/tdo.py new file mode 100644 index 00000000..00d5b9af --- /dev/null +++ b/src/core/tasks/url/operators/html/tdo.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel + +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.external.url_request.dtos.url_response import URLResponseInfo + + +class UrlHtmlTDO(BaseModel): + url_info: URLInfo + url_response_info: URLResponseInfo | None = None + html_tag_info: ResponseHTMLInfo | None = None + diff --git a/src/core/tasks/url/operators/location_id/__init__.py b/src/core/tasks/url/operators/location_id/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/core.py b/src/core/tasks/url/operators/location_id/core.py new file mode 100644 index 00000000..3833a80c --- /dev/null +++ b/src/core/tasks/url/operators/location_id/core.py @@ -0,0 +1,63 @@ +from src.core.tasks.mixins.link_urls import LinkURLsMixin +from src.core.tasks.url.operators._shared.exceptions import SubtaskError +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.location_id.subtasks.flags.core import SubtaskFlagger +from src.core.tasks.url.operators.location_id.subtasks.loader import LocationIdentificationSubtaskLoader +from src.core.tasks.url.operators.location_id.subtasks.models.run_info import LocationIDSubtaskRunInfo +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.core import LocationIDSurveyQueryBuilder +from src.core.tasks.url.operators.location_id.subtasks.templates.subtask import LocationIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + + +class LocationIdentificationTaskOperator( + URLTaskOperatorBase, + LinkURLsMixin, +): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + loader: LocationIdentificationSubtaskLoader, + ): + super().__init__(adb_client) + self.loader = loader + + @property + def task_type(self) -> TaskType: + return TaskType.LOCATION_ID + + async def load_subtask( + self, + subtask_type: LocationIDSubtaskType + ) -> LocationIDSubtaskOperatorBase: + return await self.loader.load_subtask(subtask_type, task_id=self.task_id) + + async def meets_task_prerequisites(self) -> bool: + """ + Modifies: + - self._subtask + """ + flagger = SubtaskFlagger() + allowed_subtasks: list[LocationIDSubtaskType] = flagger.get_allowed_subtasks() + + next_subtask: LocationIDSubtaskType | None = \ + await self.adb_client.run_query_builder( + LocationIDSurveyQueryBuilder( + allowed_subtasks=allowed_subtasks + ) + ) + self._subtask = next_subtask + if next_subtask is None: + return False + return True + + + async def inner_task_logic(self) -> None: + subtask_operator: LocationIDSubtaskOperatorBase = await self.load_subtask(self._subtask) + print(f"Running Subtask: 
{self._subtask.value}") + run_info: LocationIDSubtaskRunInfo = await subtask_operator.run() + await self.link_urls_to_task(run_info.linked_url_ids) + if not run_info.is_success: + raise SubtaskError(run_info.error) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/flags/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/flags/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/flags/core.py b/src/core/tasks/url/operators/location_id/subtasks/flags/core.py new file mode 100644 index 00000000..1b6cb55c --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/flags/core.py @@ -0,0 +1,25 @@ +from environs import Env + +from src.core.tasks.url.operators.location_id.subtasks.flags.mappings import SUBTASK_TO_ENV_FLAG +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + + +class SubtaskFlagger: + """ + Manages flags allowing and disallowing subtasks + """ + def __init__(self): + self.env = Env() + + def _get_subtask_flag(self, subtask_type: LocationIDSubtaskType) -> bool: + return self.env.bool( + SUBTASK_TO_ENV_FLAG[subtask_type], + default=True + ) + + def get_allowed_subtasks(self) -> list[LocationIDSubtaskType]: + return [ + subtask_type + for subtask_type, flag in SUBTASK_TO_ENV_FLAG.items() + if self._get_subtask_flag(subtask_type) + ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/flags/mappings.py b/src/core/tasks/url/operators/location_id/subtasks/flags/mappings.py new file mode 100644 index 00000000..48f5d194 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/flags/mappings.py @@ -0,0 +1,6 @@ +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + +SUBTASK_TO_ENV_FLAG: dict[LocationIDSubtaskType, str] = { + LocationIDSubtaskType.NLP_LOCATION_FREQUENCY: "LOCATION_ID_NLP_LOCATION_MATCH_FLAG", + LocationIDSubtaskType.BATCH_LINK: "LOCATION_ID_BATCH_LINK_FLAG", +} \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/core.py b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/core.py new file mode 100644 index 00000000..a85e572a --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/core.py @@ -0,0 +1,56 @@ +from src.core.tasks.url.operators.location_id.subtasks.impl.batch_link.inputs import LocationBatchLinkInput +from src.core.tasks.url.operators.location_id.subtasks.impl.batch_link.query import GetLocationBatchLinkQueryBuilder +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.constants import ITERATIONS_PER_SUBTASK +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion 
import LocationSuggestion +from src.core.tasks.url.operators.location_id.subtasks.templates.subtask import LocationIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.pydantic import AutoLocationIDSubtaskPydantic + + +class LocationBatchLinkSubtaskOperator(LocationIDSubtaskOperatorBase): + + def __init__( + self, + task_id: int, + adb_client: AsyncDatabaseClient, + ): + super().__init__(adb_client=adb_client, task_id=task_id) + + async def inner_logic(self) -> None: + for iteration in range(ITERATIONS_PER_SUBTASK): + inputs: list[LocationBatchLinkInput] = await self._get_from_db() + if len(inputs) == 0: + break + await self.run_subtask_iteration(inputs) + + async def run_subtask_iteration( + self, + inputs: list[LocationBatchLinkInput] + ) -> None: + self.linked_urls.extend([input_.url_id for input_ in inputs]) + subtask_data_list: list[AutoLocationIDSubtaskData] = [] + for input_ in inputs: + subtask_data_list.append( + AutoLocationIDSubtaskData( + pydantic_model=AutoLocationIDSubtaskPydantic( + url_id=input_.url_id, + task_id=self.task_id, + locations_found=True, + type=LocationIDSubtaskType.BATCH_LINK, + ), + suggestions=[ + LocationSuggestion( + location_id=input_.location_id, + confidence=80, + ) + ] + ) + ) + + await self._upload_subtask_data(subtask_data_list) + + async def _get_from_db(self) -> list[LocationBatchLinkInput]: + query = GetLocationBatchLinkQueryBuilder() + return await self.adb_client.run_query_builder(query) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/inputs.py b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/inputs.py new file mode 100644 index 00000000..0bd10414 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/inputs.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class LocationBatchLinkInput(BaseModel): + location_id: int + url_id: int \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/query.py b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/query.py new file mode 100644 index 00000000..1a7d424f --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/batch_link/query.py @@ -0,0 +1,46 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.location_id.subtasks.impl.batch_link.inputs import LocationBatchLinkInput +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.constants import \ + NUMBER_OF_ENTRIES_PER_ITERATION +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.eligible import EligibleContainer +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.link.location_batch.sqlalchemy import LinkLocationBatch +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class GetLocationBatchLinkQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[LocationBatchLinkInput]: + container = EligibleContainer() + query = ( + select( + LinkLocationBatch.location_id, + LinkBatchURL.url_id + ) + .join( + LinkLocationBatch, + LinkBatchURL.batch_id == LinkLocationBatch.batch_id, + ) + 
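The batch-link subtask's control flow is a bounded drain loop: up to `ITERATIONS_PER_SUBTASK` rounds, each pulling a fixed-size batch, stopping early once nothing is eligible. With the constants in this diff (4 iterations of 10 entries) a single run touches at most 40 URLs. The loop shape in isolation:

```python
ITERATIONS_PER_SUBTASK = 4
NUMBER_OF_ENTRIES_PER_ITERATION = 10

queue = list(range(25))  # stand-in for eligible URL ids in the database


def get_batch(n: int = NUMBER_OF_ENTRIES_PER_ITERATION) -> list[int]:
    batch, queue[:] = queue[:n], queue[n:]
    return batch


processed: list[int] = []
for _ in range(ITERATIONS_PER_SUBTASK):
    batch = get_batch()
    if not batch:
        break  # nothing eligible: stop early instead of burning iterations
    processed.extend(batch)
print(len(processed))  # 25 -- drained in three of the four allowed rounds
```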
.join( + container.cte, + LinkBatchURL.url_id == container.url_id, + ) + .where( + container.batch_link, + ) + .limit(NUMBER_OF_ENTRIES_PER_ITERATION) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + inputs: list[LocationBatchLinkInput] = [ + LocationBatchLinkInput( + location_id=mapping["location_id"], + url_id=mapping["url_id"], + ) + for mapping in mappings + ] + return inputs diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/constants.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/constants.py new file mode 100644 index 00000000..31890aaa --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/constants.py @@ -0,0 +1,4 @@ + + +ITERATIONS_PER_SUBTASK = 4 +NUMBER_OF_ENTRIES_PER_ITERATION = 10 \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/core.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/core.py new file mode 100644 index 00000000..1f9c8d62 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/core.py @@ -0,0 +1,56 @@ +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.constants import ITERATIONS_PER_SUBTASK +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.input_ import \ + NLPLocationFrequencySubtaskInput +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.core import \ + NLPLocationFrequencySubtaskInternalProcessor +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.query import \ + GetNLPLocationFrequencySubtaskInputQueryBuilder +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.templates.subtask import LocationIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient + + +class NLPLocationFrequencySubtaskOperator(LocationIDSubtaskOperatorBase): + + def __init__( + self, + task_id: int, + adb_client: AsyncDatabaseClient, + nlp_processor: NLPProcessor, + ): + super().__init__(adb_client=adb_client, task_id=task_id) + self._nlp_processor: NLPProcessor = nlp_processor + self.processor = NLPLocationFrequencySubtaskInternalProcessor( + nlp_processor=nlp_processor, + adb_client=adb_client, + task_id=task_id, + ) + + + async def inner_logic(self) -> None: + for iteration in range(ITERATIONS_PER_SUBTASK): + inputs: list[NLPLocationFrequencySubtaskInput] = await self._get_from_db() + if len(inputs) == 0: + break + await self.run_subtask_iteration(inputs) + + async def run_subtask_iteration(self, inputs: list[NLPLocationFrequencySubtaskInput]) -> None: + self.linked_urls.extend([input_.url_id for input_ in inputs]) + subtask_data_list: list[AutoLocationIDSubtaskData] = await self._process_inputs(inputs) + + await self._upload_subtask_data(subtask_data_list) + + async def _process_inputs( + self, + inputs: list[NLPLocationFrequencySubtaskInput] + ) -> list[AutoLocationIDSubtaskData]: + return await self.processor.process( + 
inputs=inputs, + ) + + + async def _get_from_db(self) -> list[NLPLocationFrequencySubtaskInput]: + return await self.adb_client.run_query_builder( + GetNLPLocationFrequencySubtaskInputQueryBuilder(), + ) diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/input_.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/input_.py new file mode 100644 index 00000000..0ba1647e --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/input_.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class NLPLocationFrequencySubtaskInput(BaseModel): + url_id: int + html: str \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_nlp_response.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_nlp_response.py new file mode 100644 index 00000000..1f611ad7 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_nlp_response.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse + + +class URLToNLPResponseMapping(BaseModel): + url_id: int + nlp_response: NLPLocationMatchResponse \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_search_response.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_search_response.py new file mode 100644 index 00000000..807b38d0 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/mappings/url_id_search_response.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.response import \ + SearchSimilarLocationsResponse +from src.external.pdap.dtos.search_agency_by_location.response import SearchAgencyByLocationResponse + + +class URLToSearchResponseMapping(BaseModel): + url_id: int + search_responses: list[SearchSimilarLocationsResponse] \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/subsets.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/subsets.py new file mode 100644 index 00000000..304c7e01 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/models/subsets.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_nlp_response import \ + URLToNLPResponseMapping + + +class NLPResponseSubsets(BaseModel): + valid: list[URLToNLPResponseMapping] + invalid: list[URLToNLPResponseMapping] \ No newline at end of file diff 
--git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/constants.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/constants.py new file mode 100644 index 00000000..cc16da9f --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/constants.py @@ -0,0 +1,3 @@ + + +MAX_NLP_CONFIDENCE: int = 90 \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/convert.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/convert.py new file mode 100644 index 00000000..8ec60b35 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/convert.py @@ -0,0 +1,149 @@ +from math import ceil + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_nlp_response import \ + URLToNLPResponseMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_search_response import \ + URLToSearchResponseMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.constants import \ + MAX_NLP_CONFIDENCE +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.counter import RequestCounter +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.models.url_id_search_params import \ + URLToSearchParamsMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.params import \ + SearchSimilarLocationsParams +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.response import \ + SearchSimilarLocationsResponse +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion import LocationSuggestion +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.pydantic import AutoLocationIDSubtaskPydantic + + +def convert_invalid_url_nlp_mappings_to_subtask_data_list( + mappings: list[URLToNLPResponseMapping], + task_id: int +) -> list[AutoLocationIDSubtaskData]: + url_ids: list[int] = [] + for mapping in mappings: + url_ids.append(mapping.url_id) + + return convert_url_ids_to_empty_subtask_data_list( + url_ids=url_ids, + task_id=task_id + ) + +def convert_url_ids_to_empty_subtask_data_list( + url_ids: list[int], + task_id: int +) -> list[AutoLocationIDSubtaskData]: + results: list[AutoLocationIDSubtaskData] = [] + for url_id in url_ids: + subtask_data = AutoLocationIDSubtaskData( + pydantic_model=AutoLocationIDSubtaskPydantic( + task_id=task_id, + url_id=url_id, + type=LocationIDSubtaskType.NLP_LOCATION_FREQUENCY, + locations_found=False + ), + suggestions=[] + ) + results.append(subtask_data) + + return results + +def convert_search_location_responses_to_subtask_data_list( + 
mappings: list[URLToSearchResponseMapping], + task_id: int +) -> list[AutoLocationIDSubtaskData]: + subtask_data_list: list[AutoLocationIDSubtaskData] = [] + + # First, extract agency suggestions for URL + for mapping in mappings: + url_id: int = mapping.url_id + search_responses: list[SearchSimilarLocationsResponse] = mapping.search_responses + suggestions: list[LocationSuggestion] = _convert_search_agency_response_to_agency_suggestions( + search_responses + ) + pydantic_model: AutoLocationIDSubtaskPydantic = convert_search_agency_response_to_subtask_pydantic( + url_id=url_id, + task_id=task_id, + suggestions=suggestions + ) + subtask_data = AutoLocationIDSubtaskData( + pydantic_model=pydantic_model, + suggestions=suggestions + ) + subtask_data_list.append(subtask_data) + + return subtask_data_list + +def convert_search_agency_response_to_subtask_pydantic( + url_id: int, + task_id: int, + suggestions: list[LocationSuggestion] +) -> AutoLocationIDSubtaskPydantic: + + return AutoLocationIDSubtaskPydantic( + task_id=task_id, + url_id=url_id, + type=LocationIDSubtaskType.NLP_LOCATION_FREQUENCY, + locations_found=len(suggestions) > 0, + ) + +def _convert_search_agency_response_to_agency_suggestions( + responses: list[SearchSimilarLocationsResponse], +) -> list[LocationSuggestion]: + suggestions: list[LocationSuggestion] = [] + for response in responses: + for result in response.results: + location_id: int = result.location_id + similarity: float = result.similarity + confidence: int = min(ceil(similarity * 100), MAX_NLP_CONFIDENCE) + suggestion: LocationSuggestion = LocationSuggestion( + location_id=location_id, + confidence=confidence, + ) + suggestions.append(suggestion) + return suggestions + + + +def convert_urls_to_search_params( + url_to_nlp_mappings: list[URLToNLPResponseMapping] +) -> list[URLToSearchParamsMapping]: + url_to_search_params_mappings: list[URLToSearchParamsMapping] = [] + counter = RequestCounter() + for mapping in url_to_nlp_mappings: + search_params: list[SearchSimilarLocationsParams] = \ + convert_nlp_response_to_search_similar_location_params( + counter=counter, + nlp_response=mapping.nlp_response, + ) + mapping = URLToSearchParamsMapping( + url_id=mapping.url_id, + search_params=search_params, + ) + url_to_search_params_mappings.append(mapping) + return url_to_search_params_mappings + + +def convert_nlp_response_to_search_similar_location_params( + nlp_response: NLPLocationMatchResponse, + counter: RequestCounter +) -> list[SearchSimilarLocationsParams]: + params: list[SearchSimilarLocationsParams] = [] + for location in nlp_response.locations: + if nlp_response.us_state is None: + raise ValueError("US State is None; cannot convert NLP response to search agency by location params") + request_id: int = counter.next() + param = SearchSimilarLocationsParams( + request_id=request_id, + query=location, + iso=nlp_response.us_state.iso, + ) + params.append(param) + + return params + diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/core.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/core.py new file mode 100644 index 00000000..bfacd67e --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/core.py @@ -0,0 +1,151 @@ +from collections import defaultdict + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.input_ import \ + NLPLocationFrequencySubtaskInput +from 
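The confidence derivation in `_convert_search_agency_response_to_agency_suggestions` turns a 0-1 similarity into an integer confidence: scale to a percentage, round up, and cap at `MAX_NLP_CONFIDENCE` (90) so NLP-derived suggestions never claim near-certainty. Worked through:

```python
from math import ceil

MAX_NLP_CONFIDENCE = 90

for similarity in (0.42, 0.899, 0.97):
    confidence = min(ceil(similarity * 100), MAX_NLP_CONFIDENCE)
    print(similarity, "->", confidence)  # 42, 90, 90
```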
src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.subsets import NLPResponseSubsets +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.filter import \ + filter_valid_and_invalid_nlp_responses, filter_top_n_suggestions, filter_out_responses_with_zero_similarity +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_search_response import \ + URLToSearchResponseMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.mapper import \ + URLRequestIDMapper +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_nlp_response import \ + URLToNLPResponseMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.convert import \ + convert_invalid_url_nlp_mappings_to_subtask_data_list, convert_search_location_responses_to_subtask_data_list, \ + convert_urls_to_search_params +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.models.url_id_search_params import \ + URLToSearchParamsMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.preprocess import \ + preprocess_html +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.core import \ + SearchSimilarLocationsQueryBuilder +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.params import \ + SearchSimilarLocationsParams +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.response import \ + SearchSimilarLocationsResponse, SearchSimilarLocationsOuterResponse +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.db.client.async_ import AsyncDatabaseClient + + +class NLPLocationFrequencySubtaskInternalProcessor: + + def __init__( + self, + nlp_processor: NLPProcessor, + adb_client: AsyncDatabaseClient, + task_id: int, + ): + self._nlp_processor = nlp_processor + self._adb_client = adb_client + self._task_id = task_id + + async def process( + self, + inputs: list[NLPLocationFrequencySubtaskInput] + ) -> list[AutoLocationIDSubtaskData]: + subtask_data_list: list[AutoLocationIDSubtaskData] = [] + + url_to_nlp_mappings: list[URLToNLPResponseMapping] = \ + self._parse_all_url_htmls_for_locations(inputs) + + # Filter out valid and invalid NLP responses + nlp_response_subsets: NLPResponseSubsets = \ + filter_valid_and_invalid_nlp_responses(url_to_nlp_mappings) + + + # For invalid responses, convert to subtask data with empty locations + subtask_data_no_location_list: list[AutoLocationIDSubtaskData] = \ + convert_invalid_url_nlp_mappings_to_subtask_data_list( + mappings=nlp_response_subsets.invalid, + task_id=self._task_id, + ) + subtask_data_list.extend(subtask_data_no_location_list) + + # For valid responses, convert to search param mappings + url_to_search_params_mappings: list[URLToSearchParamsMapping] = \ + convert_urls_to_search_params(nlp_response_subsets.valid) + + response_mappings: list[URLToSearchResponseMapping] = \ + await 
self._get_db_location_info(url_to_search_params_mappings) + + subtask_data_list_location_list: list[AutoLocationIDSubtaskData] = \ + convert_search_location_responses_to_subtask_data_list( + mappings=response_mappings, + task_id=self._task_id, + ) + + filter_top_n_suggestions(subtask_data_list_location_list) + + subtask_data_list.extend(subtask_data_list_location_list) + + return subtask_data_list + + async def _get_db_location_info( + self, + mappings: list[URLToSearchParamsMapping] + ) -> list[URLToSearchResponseMapping]: + if len(mappings) == 0: + return [] + params: list[SearchSimilarLocationsParams] = [] + # Map request IDs to URL IDs for later use + mapper = URLRequestIDMapper() + for mapping in mappings: + for search_param in mapping.search_params: + mapper.add_mapping( + request_id=search_param.request_id, + url_id=mapping.url_id, + ) + params.append(search_param) + + url_id_to_search_responses: dict[int, list[SearchSimilarLocationsResponse]] = defaultdict(list) + + outer_response: SearchSimilarLocationsOuterResponse = await self._adb_client.run_query_builder( + SearchSimilarLocationsQueryBuilder( + params=params, + ) + ) + responses: list[SearchSimilarLocationsResponse] = outer_response.responses + # Map responses to URL IDs via request IDs + for response in responses: + request_id: int = response.request_id + url_id: int = mapper.get_url_id_by_request_id(request_id) + url_id_to_search_responses[url_id].append(response) + + # Reconcile URL IDs to search responses + response_mappings: list[URLToSearchResponseMapping] = [] + for url_id, responses in url_id_to_search_responses.items(): + for response in responses: + response.results = filter_out_responses_with_zero_similarity(response.results) + + mapping = URLToSearchResponseMapping( + url_id=url_id, + search_responses=responses, + ) + response_mappings.append(mapping) + + return response_mappings + + def _parse_all_url_htmls_for_locations( + self, + inputs: list[NLPLocationFrequencySubtaskInput] + ) -> list[URLToNLPResponseMapping]: + url_to_nlp_mappings: list[URLToNLPResponseMapping] = [] + for input_ in inputs: + nlp_response: NLPLocationMatchResponse = self._parse_for_locations(input_.html) + mapping = URLToNLPResponseMapping( + url_id=input_.url_id, + nlp_response=nlp_response, + ) + url_to_nlp_mappings.append(mapping) + return url_to_nlp_mappings + + def _parse_for_locations( + self, + html: str + ) -> NLPLocationMatchResponse: + preprocessed_html: str = preprocess_html(html) + return self._nlp_processor.parse_for_locations(preprocessed_html) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/counter.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/counter.py new file mode 100644 index 00000000..12e9e048 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/counter.py @@ -0,0 +1,11 @@ + + + +class RequestCounter: + + def __init__(self): + self._counter: int = 0 + + def next(self) -> int: + self._counter += 1 + return self._counter \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/filter.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/filter.py new file mode 100644 index 00000000..474279b0 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/filter.py @@ -0,0 +1,65 @@ +from collections import defaultdict + 
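`RequestCounter` and `URLRequestIDMapper` exist because one URL can fan out into several location searches, and responses carry only a `request_id`. The bookkeeping in miniature, with dict-based stand-ins for the two helper classes:

```python
from collections import defaultdict

counter = 0
request_to_url: dict[int, int] = {}


def next_request_id(url_id: int) -> int:
    # fresh id per outgoing search, recorded against the originating URL
    global counter
    counter += 1
    request_to_url[counter] = url_id
    return counter


ids = [next_request_id(7), next_request_id(7), next_request_id(9)]
responses = [(3, "york"), (1, "springfield"), (2, "franklin")]  # any order

by_url: dict[int, list[str]] = defaultdict(list)
for request_id, payload in responses:
    by_url[request_to_url[request_id]].append(payload)
print(dict(by_url))  # {9: ['york'], 7: ['springfield', 'franklin']}
```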
+from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.mappings.url_id_nlp_response import \ + URLToNLPResponseMapping +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.subsets import NLPResponseSubsets +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.response import \ + SearchSimilarLocationsLocationInfo +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion import LocationSuggestion + + +def filter_valid_and_invalid_nlp_responses( + mappings: list[URLToNLPResponseMapping] +) -> NLPResponseSubsets: + valid: list[URLToNLPResponseMapping] = [] + invalid: list[URLToNLPResponseMapping] = [] + for mapping in mappings: + nlp_response: NLPLocationMatchResponse = mapping.nlp_response + if nlp_response.valid: + valid.append(mapping) + else: + invalid.append(mapping) + return NLPResponseSubsets( + valid=valid, + invalid=invalid, + ) + +def filter_top_n_suggestions( + subtask_data_list: list[AutoLocationIDSubtaskData], + n: int = 5 +) -> None: + """Filters out all but the top N suggestions for each URL. + + Modifies: + - AutoLocationIDSubtaskData.suggestions + """ + for subtask_data in subtask_data_list: + # Eliminate location ID duplicates; + location_to_suggestions: dict[int, list[LocationSuggestion]] = defaultdict(list) + for suggestion in subtask_data.suggestions: + location_to_suggestions[suggestion.location_id].append(suggestion) + + # in the case of a tie, keep the suggestion with the highest confidence + deduped_suggestions: list[LocationSuggestion] = [] + for location_suggestions in location_to_suggestions.values(): + location_suggestions.sort( + key=lambda x: x.confidence, + reverse=True # Descending order + ) + deduped_suggestions.append(location_suggestions[0]) + + # Sort suggestions by confidence and keep top N + suggestions_sorted: list[LocationSuggestion] = sorted( + deduped_suggestions, + key=lambda x: x.confidence, + reverse=True # Descending order + ) + subtask_data.suggestions = suggestions_sorted[:n] + +def filter_out_responses_with_zero_similarity( + entries: list[SearchSimilarLocationsLocationInfo] +) -> list[SearchSimilarLocationsLocationInfo]: + return [entry for entry in entries if entry.similarity > 0] \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/mapper.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/mapper.py new file mode 100644 index 00000000..8192dbb6 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/mapper.py @@ -0,0 +1,10 @@ +class URLRequestIDMapper: + + def __init__(self): + self._request_id_to_url_id_mapper: dict[int, int] = {} + + def add_mapping(self, request_id: int, url_id: int) -> None: + self._request_id_to_url_id_mapper[request_id] = url_id + + def get_url_id_by_request_id(self, request_id: int) -> int: + return self._request_id_to_url_id_mapper[request_id] diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/models/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/models/__init__.py new file mode 100644 index 
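`filter_top_n_suggestions` (below) performs two passes per URL: collapse duplicate location IDs keeping the highest confidence, then keep the overall top N (default 5). The same reduction on plain tuples:

```python
from collections import defaultdict

suggestions = [(12, 80), (12, 95), (40, 70), (55, 60)]  # (location_id, confidence)

# pass 1: dedupe by location, keeping the best confidence for each
best: dict[int, int] = defaultdict(int)
for location_id, confidence in suggestions:
    best[location_id] = max(best[location_id], confidence)

# pass 2: sort by confidence descending and keep the top 5
top_n = sorted(best.items(), key=lambda kv: kv[1], reverse=True)[:5]
print(top_n)  # [(12, 95), (40, 70), (55, 60)]
```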
00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/models/url_id_search_params.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/models/url_id_search_params.py new file mode 100644 index 00000000..d47992ee --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/models/url_id_search_params.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.params import \ + SearchSimilarLocationsParams +from src.external.pdap.dtos.search_agency_by_location.params import SearchAgencyByLocationParams + + +class URLToSearchParamsMapping(BaseModel): + url_id: int + search_params: list[SearchSimilarLocationsParams] + + @property + def is_empty(self) -> bool: + return len(self.search_params) == 0 \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/check.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/check.py new file mode 100644 index 00000000..502014f0 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/check.py @@ -0,0 +1,14 @@ +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.constants import \ + BLACKLISTED_WORDS +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.mappings import \ + US_STATE_ISO_TO_NAME, US_NAME_TO_STATE_ISO + + +def is_iso_us_state(iso: str) -> bool: + return iso in US_STATE_ISO_TO_NAME + +def is_name_us_state(name: str) -> bool: + return name in US_NAME_TO_STATE_ISO + +def is_blacklisted_word(word: str) -> bool: + return word.lower() in BLACKLISTED_WORDS \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/constants.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/constants.py new file mode 100644 index 00000000..01c13edb --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/constants.py @@ -0,0 +1,26 @@ + + +TOP_N_LOCATIONS_COUNT: int = 5 + +INVALID_LOCATION_CHARACTERS: set[str] = { + "=", + "\\", + "/", + "\'", + "\"" } + +# State ISOs that commonly align with other words, +# which cannot be used in simple text scanning +INVALID_SCAN_ISOS: set[str] = { + "IN", + "OR", + "ME", + "ID" +} + +BLACKLISTED_WORDS: set[str] = { + "the united states", + "download", + "geoplatform" +} \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/convert.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/convert.py new file mode 100644 index 00000000..a0796b4c --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/convert.py @@ -0,0 +1,27 @@ +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.mappings import \ + US_STATE_ISO_TO_NAME, US_NAME_TO_STATE_ISO +from 
src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.us_state import \ + USState + + +def convert_us_state_iso_to_us_state(iso: str) -> USState | None: + name: str | None = US_STATE_ISO_TO_NAME.get(iso, None) + + if name is None: + return None + + return USState( + name=name, + iso=iso + ) + +def convert_us_state_name_to_us_state(name: str) -> USState | None: + iso: str | None = US_NAME_TO_STATE_ISO.get(name, None) + + if iso is None: + return None + + return USState( + name=name, + iso=iso + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/core.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/core.py new file mode 100644 index 00000000..275e2946 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/core.py @@ -0,0 +1,90 @@ +from collections import Counter + +import spacy +from spacy import Language +from spacy.tokens import Doc + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.check import \ + is_name_us_state, is_iso_us_state, is_blacklisted_word +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.constants import \ + INVALID_LOCATION_CHARACTERS, INVALID_SCAN_ISOS +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.convert import \ + convert_us_state_name_to_us_state, convert_us_state_iso_to_us_state +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.enums import \ + SpacyModelType +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.extract import \ + extract_most_common_us_state, extract_top_n_locations +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.us_state import \ + USState + + +class NLPProcessor: + + def __init__( + self, + model_type: SpacyModelType = SpacyModelType.EN_CORE_WEB_SM + ): + self._model_type: SpacyModelType = model_type + self._model: Language | None = None + + def lazy_load_model(self) -> Language: + if self._model is None: + self._model = spacy.load(self._model_type.value, disable=['parser']) + return self._model + + + def parse_for_locations(self, html: str) -> NLPLocationMatchResponse: + model: Language = self.lazy_load_model() + doc: Doc = model(html) + us_state_counter: Counter[USState] = Counter() + location_counter: Counter[str] = Counter() + + # Scan over tokens + for token in doc: + upper_token: str = token.text.upper() + # Disregard certain ISOs that align with common words + if upper_token in INVALID_SCAN_ISOS: + continue + if not is_iso_us_state(upper_token): + continue + + us_state: USState | None = convert_us_state_iso_to_us_state(upper_token) + if us_state is not None: + us_state_counter[us_state] += 1 + + + # Scan over entities using spacy + for ent in doc.ents: + if ent.label_ != "GPE": # Geopolitical Entity + continue + text: str = ent.text + if any(char in text for char in INVALID_LOCATION_CHARACTERS): + continue + if is_blacklisted_word(text): + continue + if is_name_us_state(text): + us_state: USState | None = convert_us_state_name_to_us_state(text) + if us_state is not None: + us_state_counter[us_state] += 1 + continue + if 
is_iso_us_state(text): + us_state: USState | None = convert_us_state_iso_to_us_state(text) + if us_state is not None: + us_state_counter[us_state] += 1 + continue + location_counter[text] += 1 + + # Get most common US State if exists + most_common_us_state: USState | None = extract_most_common_us_state(us_state_counter) + + top_n_locations: list[str] = extract_top_n_locations(location_counter) + + return NLPLocationMatchResponse( + us_state=most_common_us_state, + locations=top_n_locations + ) + + + diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/enums.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/enums.py new file mode 100644 index 00000000..9d1b987b --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/enums.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class SpacyModelType(Enum): + EN_CORE_WEB_SM = "en_core_web_sm" + EN_CORE_WEB_LG = "en_core_web_lg" + EN_CORE_WEB_MD = "en_core_web_md" + EN_CORE_WEB_TRF = "en_core_web_trf" \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/extract.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/extract.py new file mode 100644 index 00000000..4b84ecc4 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/extract.py @@ -0,0 +1,25 @@ +from collections import Counter + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.constants import \ + TOP_N_LOCATIONS_COUNT +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.us_state import \ + USState + + +def extract_most_common_us_state( + us_state_counter: Counter[USState] +) -> USState | None: + try: + return us_state_counter.most_common(1)[0][0] + except IndexError: + return None + +def extract_top_n_locations( + location_counter: Counter[str] +) -> list[str]: + top_n_locations_raw: list[tuple[str, int]] = \ + location_counter.most_common(TOP_N_LOCATIONS_COUNT) + top_n_locations: list[str] = [] + for location, _ in top_n_locations_raw: + top_n_locations.append(location) + return top_n_locations \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/mappings.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/mappings.py new file mode 100644 index 00000000..03417480 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/mappings.py @@ -0,0 +1,59 @@ + + +US_STATE_ISO_TO_NAME: dict[str, str] = { + 'AL': 'Alabama', + 'AK': 'Alaska', + 'AZ': 'Arizona', + 'AR': 'Arkansas', + 'CA': 'California', + 'CO': 'Colorado', + 'CT': 'Connecticut', + 'DE': 'Delaware', + 'FL': 'Florida', + 'GA': 'Georgia', + 'HI': 'Hawaii', + 'ID': 'Idaho', + 'IL': 'Illinois', + 'IN': 'Indiana', + 'IA': 'Iowa', + 'KS': 'Kansas', + 'KY': 'Kentucky', + 'LA': 'Louisiana', + 'ME': 'Maine', + 'MD': 'Maryland', + 'MA': 'Massachusetts', + 'MI': 'Michigan', + 'MN': 'Minnesota', + 'MS': 'Mississippi', + 'MO': 'Missouri', + 'MT': 'Montana', + 'NE': 'Nebraska', + 'NV': 'Nevada', + 'NH': 'New Hampshire', + 'NJ': 'New Jersey', + 'NM': 'New Mexico', + 'NY': 'New York', + 'NC': 'North Carolina', + 'ND': 'North Dakota', + 'OH': 'Ohio', + 'OK': 'Oklahoma', + 'OR': 'Oregon', + 'PA': 
'Pennsylvania', + 'RI': 'Rhode Island', + 'SC': 'South Carolina', + 'SD': 'South Dakota', + 'TN': 'Tennessee', + 'TX': 'Texas', + 'UT': 'Utah', + 'VT': 'Vermont', + 'VA': 'Virginia', + 'WA': 'Washington', + 'WV': 'West Virginia', + 'WI': 'Wisconsin', + 'WY': 'Wyoming', + 'DC': 'District of Columbia', +} + +US_NAME_TO_STATE_ISO: dict[str, str] = { + name: iso for iso, name in US_STATE_ISO_TO_NAME.items() +} \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/params.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/params.py new file mode 100644 index 00000000..79378612 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/params.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class NLPLocationMatchParams(BaseModel): + url_id: int + html: str \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/response.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/response.py new file mode 100644 index 00000000..11fc66e5 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/response.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.us_state import \ + USState + + +class NLPLocationMatchResponse(BaseModel): + locations: list[str] + us_state: USState | None + + @property + def valid(self) -> bool: + # Valid responses must have a US State and at least one location + if self.us_state is None: + return False + if len(self.locations) == 0: + return False + return True diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/us_state.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/us_state.py new file mode 100644 index 00000000..0b29771f --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/models/us_state.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, ConfigDict + + +class USState(BaseModel): + model_config = ConfigDict(frozen=True) + + name: str + iso: str diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/preprocess.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/preprocess.py new file mode 100644 index 00000000..da20f4f4 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/nlp/preprocess.py @@ -0,0 +1,20 @@ +import re + +import unicodedata +from bs4 import BeautifulSoup + + +def preprocess_html(raw_html: str) -> str: + """Preprocess HTML to extract text content.""" + soup = BeautifulSoup(raw_html, 'lxml') + + # Remove scripts, styles, and other non-textual elements + for tag in soup(['script','style','noscript','iframe','canvas','svg','header','footer','nav','aside']): + tag.decompose() + # Extract text + text = soup.get_text(separator=' ') + # Normalize 
text and collapse whitespace + text = unicodedata.normalize('NFKC', text) + text = re.sub(r'[ \t\u00A0]+', ' ', text) + text = re.sub(r'\n\s*\n\s*\n+', '\n\n', text) + return text.strip() \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/core.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/core.py new file mode 100644 index 00000000..f6011f49 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/core.py @@ -0,0 +1,114 @@ +from collections import defaultdict +from typing import Any, Sequence + +from sqlalchemy import values, column, String, Integer, func, select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.params import \ + SearchSimilarLocationsParams +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.query_.models.response import \ + SearchSimilarLocationsOuterResponse, SearchSimilarLocationsLocationInfo, SearchSimilarLocationsResponse +from src.db.models.views.location_expanded import LocationExpandedView +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class SearchSimilarLocationsQueryBuilder(QueryBuilderBase): + + def __init__( + self, + params: list[SearchSimilarLocationsParams] + ): + super().__init__() + self.params = params + + async def run(self, session: AsyncSession) -> SearchSimilarLocationsOuterResponse: + queries_as_tups: list[tuple[int, str, str]] = [ + ( + param.request_id, + param.query, + param.iso, + ) + for param in self.params + ] + + vals = ( + values( + column("request_id", Integer), + column("query", String), + column("iso", String), + name="input_queries", + ) + .data(queries_as_tups) + .alias("input_queries_alias") + ) + + similarity = func.similarity( + vals.c.query, + LocationExpandedView.display_name, + ) + + lateral_top_5 = ( + select( + vals.c.request_id, + LocationExpandedView.id.label("location_id"), + func.row_number().over( + partition_by=vals.c.request_id, + order_by=similarity.desc(), + ).label("rank"), + similarity.label("similarity"), + ) + .join( + LocationExpandedView, + LocationExpandedView.state_iso == vals.c.iso, + ) + .order_by( + similarity.desc(), + ) + .lateral("lateral_top_5") + ) + + final = ( + select( + vals.c.request_id, + lateral_top_5.c.location_id, + lateral_top_5.c.similarity, + ).join( + lateral_top_5, + vals.c.request_id == lateral_top_5.c.request_id, + ) + .where( + lateral_top_5.c.rank <= 5, + ) + ) + + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=final) + request_id_to_locations: dict[int, list[SearchSimilarLocationsLocationInfo]] = ( + defaultdict(list) + ) + for mapping in mappings: + inner_response = SearchSimilarLocationsLocationInfo( + location_id=mapping["location_id"], + similarity=mapping["similarity"], + ) + request_id: int = mapping["request_id"] + request_id_to_locations[request_id].append(inner_response) + + responses: list[SearchSimilarLocationsResponse] = [] + for request_id, inner_responses in request_id_to_locations.items(): 
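+ # Rows within each request arrive unordered from the query; sort by descending similarity (e.g. similarities [0.42, 0.87, 0.13] become [0.87, 0.42, 0.13]) so the best candidate location comes first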
+ sorted_responses: list[SearchSimilarLocationsLocationInfo] = sorted( + inner_responses, + key=lambda x: x.similarity, + reverse=True, + ) + request_level_response = SearchSimilarLocationsResponse( + request_id=request_id, + results=sorted_responses, + ) + responses.append(request_level_response) + + return SearchSimilarLocationsOuterResponse( + responses=responses, + ) + diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/params.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/params.py new file mode 100644 index 00000000..180d27b4 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/params.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + + +class SearchSimilarLocationsParams(BaseModel): + request_id: int + query: str + iso: str = Field( + description="US State ISO Code", + max_length=2, + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/response.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/response.py new file mode 100644 index 00000000..95bf9e93 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/processor/query_/models/response.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field + + +class SearchSimilarLocationsLocationInfo(BaseModel): + location_id: int + similarity: float = Field(ge=0, le=1) + +class SearchSimilarLocationsResponse(BaseModel): + request_id: int + results: list[SearchSimilarLocationsLocationInfo] + +class SearchSimilarLocationsOuterResponse(BaseModel): + responses: list[SearchSimilarLocationsResponse] \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/query.py b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/query.py new file mode 100644 index 00000000..96b63bb1 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/impl/nlp_location_freq/query.py @@ -0,0 +1,48 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.constants import \ + NUMBER_OF_ENTRIES_PER_ITERATION +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.input_ import \ + NLPLocationFrequencySubtaskInput +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.eligible import EligibleContainer +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.queries.base.builder import QueryBuilderBase +from src.db.utils.compression import decompress_html + + +class GetNLPLocationFrequencySubtaskInputQueryBuilder(QueryBuilderBase): + + async def run( + self, + session: AsyncSession + ) -> list[NLPLocationFrequencySubtaskInput]: + container = EligibleContainer() + query = ( + select( + container.url_id, + URLCompressedHTML.compressed_html + ) + .join( + 
URLCompressedHTML, + URLCompressedHTML.url_id == container.url_id, + ) + .where( + container.nlp_location, + ) + .limit(NUMBER_OF_ENTRIES_PER_ITERATION) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + inputs: list[NLPLocationFrequencySubtaskInput] = [ + NLPLocationFrequencySubtaskInput( + url_id=mapping["id"], + html=decompress_html(mapping["compressed_html"]), + ) + for mapping in mappings + ] + return inputs + diff --git a/src/core/tasks/url/operators/location_id/subtasks/loader.py b/src/core/tasks/url/operators/location_id/subtasks/loader.py new file mode 100644 index 00000000..408b5a07 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/loader.py @@ -0,0 +1,44 @@ +from src.core.tasks.url.operators.location_id.subtasks.impl.batch_link.core import LocationBatchLinkSubtaskOperator +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.core import \ + NLPLocationFrequencySubtaskOperator +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.templates.subtask import LocationIDSubtaskOperatorBase +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + + +class LocationIdentificationSubtaskLoader: + """Loads subtasks and associated dependencies.""" + + def __init__( + self, + adb_client: AsyncDatabaseClient, + nlp_processor: NLPProcessor, + ): + self.adb_client = adb_client + self._nlp_processor = nlp_processor + + def _load_nlp_location_match_subtask(self, task_id: int) -> NLPLocationFrequencySubtaskOperator: + return NLPLocationFrequencySubtaskOperator( + task_id=task_id, + adb_client=self.adb_client, + nlp_processor=self._nlp_processor + ) + + def _load_batch_link_subtask(self, task_id: int) -> LocationBatchLinkSubtaskOperator: + return LocationBatchLinkSubtaskOperator( + task_id=task_id, + adb_client=self.adb_client, + ) + + async def load_subtask( + self, + subtask_type: LocationIDSubtaskType, + task_id: int + ) -> LocationIDSubtaskOperatorBase: + match subtask_type: + case LocationIDSubtaskType.NLP_LOCATION_FREQUENCY: + return self._load_nlp_location_match_subtask(task_id=task_id) + case LocationIDSubtaskType.BATCH_LINK: + return self._load_batch_link_subtask(task_id=task_id) + raise ValueError(f"Unknown subtask type: {subtask_type}") diff --git a/src/core/tasks/url/operators/location_id/subtasks/models/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/models/run_info.py b/src/core/tasks/url/operators/location_id/subtasks/models/run_info.py new file mode 100644 index 00000000..de382736 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/models/run_info.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel + + +class LocationIDSubtaskRunInfo(BaseModel): + error: str | None = None + linked_url_ids: list[int] | None = None + + @property + def is_success(self) -> bool: + return self.error is None + + @property + def has_linked_urls(self) -> bool: + return self.linked_url_ids is not None and len(self.linked_url_ids) > 0 \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/models/subtask.py b/src/core/tasks/url/operators/location_id/subtasks/models/subtask.py new file mode 100644 index 00000000..b06d2ff9 --- /dev/null +++
b/src/core/tasks/url/operators/location_id/subtasks/models/subtask.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion import LocationSuggestion +from src.db.models.impl.url.suggestion.location.auto.subtask.pydantic import AutoLocationIDSubtaskPydantic + + +class AutoLocationIDSubtaskData(BaseModel): + pydantic_model: AutoLocationIDSubtaskPydantic + suggestions: list[LocationSuggestion] + error: str | None = None + + @property + def has_error(self) -> bool: + return self.error is not None + + @property + def url_id(self) -> int: + return self.pydantic_model.url_id \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/models/suggestion.py b/src/core/tasks/url/operators/location_id/subtasks/models/suggestion.py new file mode 100644 index 00000000..3c4ef6e9 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/models/suggestion.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, Field + + +class LocationSuggestion(BaseModel): + location_id: int + confidence: int = Field(ge=0, le=100) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/constants.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/constants.py new file mode 100644 index 00000000..b9f85e2d --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/constants.py @@ -0,0 +1,12 @@ +# Determines priority of subtasks, all else being equal. 
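+# When several subtasks tie on eligible-URL counts, the survey query uses +# SUBTASK_HIERARCHY_MAPPING (subtask -> index in this list) to select exactly one.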
+from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + +SUBTASK_HIERARCHY: list[LocationIDSubtaskType] = [ + LocationIDSubtaskType.NLP_LOCATION_FREQUENCY, + LocationIDSubtaskType.BATCH_LINK +] + +SUBTASK_HIERARCHY_MAPPING: dict[LocationIDSubtaskType, int] = { + subtask: idx + for idx, subtask in enumerate(SUBTASK_HIERARCHY) +} \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/core.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/core.py new file mode 100644 index 00000000..c267b89e --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/core.py @@ -0,0 +1,73 @@ +from collections import Counter + +from sqlalchemy import RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.constants import SUBTASK_HIERARCHY_MAPPING +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.eligible_counts import \ + ELIGIBLE_COUNTS_QUERY +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class LocationIDSurveyQueryBuilder(QueryBuilderBase): + """ + Survey applicable URLs to determine the next subtask to run + + URLs are "inapplicable" if they have any of the following properties: + - Are validated via the FlagURLValidated model + - Have at least one annotation with a location suggestion with confidence >= 95 + - Have all possible subtasks completed + + Returns the single subtask with the most applicable URLs, + using the subtask hierarchy to break ties + (or None if no subtask has applicable URLs) + """ + + def __init__( + self, + allowed_subtasks: list[LocationIDSubtaskType] + ): + super().__init__() + self._allowed_subtasks = allowed_subtasks + + async def run(self, session: AsyncSession) -> LocationIDSubtaskType | None: + results: RowMapping = await sh.mapping(session, ELIGIBLE_COUNTS_QUERY) + counts: Counter[str] = Counter(results) + + allowed_counts: Counter[str] = await self._filter_allowed_counts(counts) + if len(allowed_counts) == 0: + return None + max_count: int = max(allowed_counts.values()) + if max_count == 0: + return None + subtasks_with_max_count: list[str] = [ + subtask for subtask, count in allowed_counts.items() + if count == max_count + ] + subtasks_as_enum_list: list[LocationIDSubtaskType] = [ + LocationIDSubtaskType(subtask) + for subtask in subtasks_with_max_count + ] + # Sort subtasks by priority + sorted_subtasks: list[LocationIDSubtaskType] = sorted( + subtasks_as_enum_list, + key=lambda subtask: SUBTASK_HIERARCHY_MAPPING[subtask], + reverse=True, + ) + # Return the highest priority subtask + return sorted_subtasks[0] + + async def _filter_allowed_counts(self, counts: Counter[str]) -> Counter[str]: + return Counter( + { + subtask: count + for subtask, count in counts.items() + if LocationIDSubtaskType(subtask) in self._allowed_subtasks + } + ) + + + + diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/__init__.py
b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/eligible.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/eligible.py new file mode 100644 index 00000000..1c97f8fb --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/eligible.py @@ -0,0 +1,45 @@ + + +from sqlalchemy import select, CTE, Column + +from src.core.tasks.url.operators._shared.ctes.validated import VALIDATED_EXISTS_CONTAINER +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.exists.high_confidence_annotations import \ + HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.subtask.impl.batch_link import \ + BATCH_LINK_CONTAINER +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.subtask.impl.nlp_location_freq import \ + NLP_LOCATION_CONTAINER +from src.db.models.impl.url.core.sqlalchemy import URL + + +class EligibleContainer: + + def __init__(self): + self._cte = ( + select( + URL.id, + NLP_LOCATION_CONTAINER.eligible_query.label("nlp_location"), + BATCH_LINK_CONTAINER.eligible_query.label("batch_link"), + ) + .where( + HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER.not_exists_query, + VALIDATED_EXISTS_CONTAINER.not_exists_query, + ) + .cte("eligible") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self._cte.c['id'] + + @property + def nlp_location(self) -> Column[bool]: + return self._cte.c['nlp_location'] + + @property + def batch_link(self) -> Column[bool]: + return self._cte.c['batch_link'] \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/exists/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/exists/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py new file mode 100644 index 00000000..7d0dddfd --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/exists/high_confidence_annotations.py @@ -0,0 +1,29 @@ +from sqlalchemy import select + +from src.core.tasks.url.operators._shared.container.subtask.exists import \ + URLsSubtaskExistsCTEContainer +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion + +cte = ( + select( + URL.id + ) + .join( + AutoLocationIDSubtask, + AutoLocationIDSubtask.url_id == URL.id, + ) + .join( + LocationIDSubtaskSuggestion, + LocationIDSubtaskSuggestion.subtask_id == AutoLocationIDSubtask.id, + ) + .where( + LocationIDSubtaskSuggestion.confidence >= 95, + ) + .cte("high_confidence_annotations_exists") +) + +HIGH_CONFIDENCE_ANNOTATIONS_EXISTS_CONTAINER = URLsSubtaskExistsCTEContainer( + cte, +) \ No newline at end of file diff --git 
a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/helpers.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/helpers.py new file mode 100644 index 00000000..acd73c4b --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/helpers.py @@ -0,0 +1,18 @@ +from sqlalchemy import ColumnElement, exists + +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask + + +def get_exists_subtask_query( + subtask_type: LocationIDSubtaskType, +) -> ColumnElement[bool]: + return ( + exists() + .where( + AutoLocationIDSubtask.url_id == URL.id, + AutoLocationIDSubtask.type == subtask_type, + ) + .label("subtask_entry_exists") + ) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py new file mode 100644 index 00000000..14c2f260 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/batch_link.py @@ -0,0 +1,31 @@ +from sqlalchemy import select + +from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.subtask.helpers import \ + get_exists_subtask_query +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.link.location_batch.sqlalchemy import LinkLocationBatch +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + +cte = ( + select( + URL.id, + get_exists_subtask_query( + LocationIDSubtaskType.BATCH_LINK + ) + ) + .join( + LinkBatchURL, + LinkBatchURL.url_id == URL.id, + ) + .join( + LinkLocationBatch, + LinkLocationBatch.batch_id == LinkBatchURL.batch_id, + ) + .cte("batch_link") +) + +BATCH_LINK_CONTAINER = URLsSubtaskEligibleCTEContainer( + cte, +) diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location_freq.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location_freq.py new file mode 100644 index 00000000..7ab2e0eb --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/ctes/subtask/impl/nlp_location_freq.py @@ -0,0 +1,25 @@ +from sqlalchemy import select + +from src.core.tasks.url.operators._shared.container.subtask.eligible import URLsSubtaskEligibleCTEContainer +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.subtask.helpers import \ + get_exists_subtask_query +from 
src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + +cte = ( + select( + URL.id, + get_exists_subtask_query( + LocationIDSubtaskType.NLP_LOCATION_FREQUENCY + ) + ) + .join( + URLCompressedHTML, + ) + .cte("nlp_location_eligible") +) + +NLP_LOCATION_CONTAINER = URLsSubtaskEligibleCTEContainer( + cte, +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/eligible_counts.py b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/eligible_counts.py new file mode 100644 index 00000000..b803b7f2 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/queries/survey/queries/eligible_counts.py @@ -0,0 +1,22 @@ +from sqlalchemy import ColumnElement, func, Integer, select + +from src.core.tasks.url.operators.location_id.subtasks.queries.survey.queries.ctes.eligible import EligibleContainer +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType + + +def sum_count(col: ColumnElement[bool], subtask_type: LocationIDSubtaskType) -> ColumnElement[int]: + return func.coalesce( + func.sum( + col.cast(Integer) + ), + 0, + ).label(subtask_type.value) + +container = EligibleContainer() + +ELIGIBLE_COUNTS_QUERY = ( + select( + sum_count(container.nlp_location, LocationIDSubtaskType.NLP_LOCATION_FREQUENCY), + sum_count(container.batch_link, LocationIDSubtaskType.BATCH_LINK) + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/location_id/subtasks/templates/__init__.py b/src/core/tasks/url/operators/location_id/subtasks/templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/location_id/subtasks/templates/subtask.py b/src/core/tasks/url/operators/location_id/subtasks/templates/subtask.py new file mode 100644 index 00000000..8ee856c2 --- /dev/null +++ b/src/core/tasks/url/operators/location_id/subtasks/templates/subtask.py @@ -0,0 +1,98 @@ +import abc +import traceback +from abc import ABC + +from src.core.tasks.url.operators.location_id.subtasks.models.run_info import LocationIDSubtaskRunInfo +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion import LocationSuggestion +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.pydantic import AutoLocationIDSubtaskPydantic +from src.db.models.impl.url.suggestion.location.auto.suggestion.pydantic import LocationIDSubtaskSuggestionPydantic +from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall + + +class LocationIDSubtaskOperatorBase(ABC): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + task_id: int + ) -> None: + self.adb_client: AsyncDatabaseClient = adb_client + self.task_id: int = task_id + self.linked_urls: list[int] = [] + + async def run(self) -> LocationIDSubtaskRunInfo: + try: + await self.inner_logic() + except Exception as e: + # Get stack trace + stack_trace: str = traceback.format_exc() + return LocationIDSubtaskRunInfo( + error=f"{type(e).__name__}: {str(e)}: {stack_trace}", + linked_url_ids=self.linked_urls + ) + return 
LocationIDSubtaskRunInfo( + linked_url_ids=self.linked_urls + ) + + @abc.abstractmethod + async def inner_logic(self) -> LocationIDSubtaskRunInfo: + raise NotImplementedError + + async def _upload_subtask_data( + self, + subtask_data_list: list[AutoLocationIDSubtaskData] + ) -> None: + + subtask_models: list[AutoLocationIDSubtaskPydantic] = [ + subtask_data.pydantic_model + for subtask_data in subtask_data_list + ] + subtask_ids: list[int] = await self.adb_client.bulk_insert( + models=subtask_models, + return_ids=True + ) + suggestions: list[LocationIDSubtaskSuggestionPydantic] = [] + for subtask_id, subtask_info in zip(subtask_ids, subtask_data_list): + suggestions_raw: list[LocationSuggestion] = subtask_info.suggestions + for suggestion in suggestions_raw: + suggestion_pydantic = LocationIDSubtaskSuggestionPydantic( + subtask_id=subtask_id, + location_id=suggestion.location_id, + confidence=suggestion.confidence, + ) + suggestions.append(suggestion_pydantic) + + await self.adb_client.bulk_insert( + models=suggestions, + ) + + error_infos: list[URLTaskErrorSmall] = [] + for subtask_info in subtask_data_list: + if not subtask_info.has_error: + continue + error_info = URLTaskErrorSmall( + url_id=subtask_info.url_id, + error=subtask_info.error, + ) + error_infos.append(error_info) + + await self.add_task_errors(error_infos) + + async def add_task_errors( + self, + errors: list[URLTaskErrorSmall] + ) -> None: + inserts: list[URLTaskErrorPydantic] = [ + URLTaskErrorPydantic( + task_id=self.task_id, + url_id=error.url_id, + task_type=TaskType.LOCATION_ID, + error=error.error + ) + for error in errors + ] + await self.adb_client.bulk_insert(inserts) \ No newline at end of file diff --git a/src/core/tasks/url/operators/misc_metadata/__init__.py b/src/core/tasks/url/operators/misc_metadata/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/misc_metadata/core.py b/src/core/tasks/url/operators/misc_metadata/core.py new file mode 100644 index 00000000..1db953d4 --- /dev/null +++ b/src/core/tasks/url/operators/misc_metadata/core.py @@ -0,0 +1,86 @@ +from src.collectors.enums import CollectorType +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.misc_metadata.queries.get_pending_urls_missing_miscellaneous_data import \ + GetPendingURLsMissingMiscellaneousDataQueryBuilder +from src.core.tasks.url.operators.misc_metadata.queries.has_pending_urls_missing_miscellaneous_data import \ + HasPendingURsMissingMiscellaneousDataQueryBuilder +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.subtasks.miscellaneous_metadata.auto_googler import AutoGooglerMiscMetadataSubtask +from src.core.tasks.url.subtasks.miscellaneous_metadata.base import \ + MiscellaneousMetadataSubtaskBase +from src.core.tasks.url.subtasks.miscellaneous_metadata.ckan import CKANMiscMetadataSubtask +from src.core.tasks.url.subtasks.miscellaneous_metadata.muckrock import MuckrockMiscMetadataSubtask +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall + + +class URLMiscellaneousMetadataTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient + ): + super().__init__(adb_client) + + @property + def task_type(self) -> TaskType: + return TaskType.MISC_METADATA + + async def meets_task_prerequisites(self) -> bool: + return await 
self.adb_client.run_query_builder(HasPendingURsMissingMiscellaneousDataQueryBuilder()) + + async def get_subtask( + self, + collector_type: CollectorType + ) -> MiscellaneousMetadataSubtaskBase | None: + match collector_type: + case CollectorType.MUCKROCK_SIMPLE_SEARCH: + return MuckrockMiscMetadataSubtask() + case CollectorType.MUCKROCK_COUNTY_SEARCH: + return MuckrockMiscMetadataSubtask() + case CollectorType.MUCKROCK_ALL_SEARCH: + return MuckrockMiscMetadataSubtask() + case CollectorType.AUTO_GOOGLER: + return AutoGooglerMiscMetadataSubtask() + case CollectorType.CKAN: + return CKANMiscMetadataSubtask() + case _: + return None + + async def html_default_logic(self, tdo: URLMiscellaneousMetadataTDO): + """ + Modifies: + tdo.name + tdo.description + """ + if tdo.name is None: + tdo.name = tdo.html_metadata_info.title + if tdo.description is None: + tdo.description = tdo.html_metadata_info.description + + async def inner_task_logic(self) -> None: + tdos: list[URLMiscellaneousMetadataTDO] = await self.get_pending_urls_missing_miscellaneous_metadata() + await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) + + task_errors: list[URLTaskErrorSmall] = [] + for tdo in tdos: + subtask = await self.get_subtask(tdo.collector_type) + try: + if subtask is not None: + subtask.process(tdo) + await self.html_default_logic(tdo) + except Exception as e: + error_info = URLTaskErrorSmall( + url_id=tdo.url_id, + error=str(e), + ) + task_errors.append(error_info) + + await self.adb_client.add_miscellaneous_metadata(tdos) + await self.add_task_errors(task_errors) + + async def get_pending_urls_missing_miscellaneous_metadata( + self, + ) -> list[URLMiscellaneousMetadataTDO]: + return await self.adb_client.run_query_builder(GetPendingURLsMissingMiscellaneousDataQueryBuilder()) diff --git a/src/core/tasks/url/operators/misc_metadata/queries/__init__.py b/src/core/tasks/url/operators/misc_metadata/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/queries/get_pending_urls_missing_miscellaneous_data.py b/src/core/tasks/url/operators/misc_metadata/queries/get_pending_urls_missing_miscellaneous_data.py similarity index 86% rename from src/core/tasks/url/operators/url_miscellaneous_metadata/queries/get_pending_urls_missing_miscellaneous_data.py rename to src/core/tasks/url/operators/misc_metadata/queries/get_pending_urls_missing_miscellaneous_data.py index c4c9892f..0efbfceb 100644 --- a/src/core/tasks/url/operators/url_miscellaneous_metadata/queries/get_pending_urls_missing_miscellaneous_data.py +++ b/src/core/tasks/url/operators/misc_metadata/queries/get_pending_urls_missing_miscellaneous_data.py @@ -1,12 +1,10 @@ -from typing import Any - from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from src.collectors.enums import CollectorType -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo -from src.db.dtos.url.html_content import HTMLContentType -from src.db.models.instantiations.url.core import URL +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo +from src.db.models.impl.url.html.content.enums import HTMLContentType +from src.db.models.impl.url.core.sqlalchemy import URL from src.db.queries.base.builder import QueryBuilderBase from src.db.statement_composer import StatementComposer diff --git 
a/src/core/tasks/url/operators/url_miscellaneous_metadata/queries/has_pending_urls_missing_miscellaneous_data.py b/src/core/tasks/url/operators/misc_metadata/queries/has_pending_urls_missing_miscellaneous_data.py similarity index 100% rename from src/core/tasks/url/operators/url_miscellaneous_metadata/queries/has_pending_urls_missing_miscellaneous_data.py rename to src/core/tasks/url/operators/misc_metadata/queries/has_pending_urls_missing_miscellaneous_data.py diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/tdo.py b/src/core/tasks/url/operators/misc_metadata/tdo.py similarity index 100% rename from src/core/tasks/url/operators/url_miscellaneous_metadata/tdo.py rename to src/core/tasks/url/operators/misc_metadata/tdo.py diff --git a/src/core/tasks/url/operators/probe/__init__.py b/src/core/tasks/url/operators/probe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/convert.py b/src/core/tasks/url/operators/probe/convert.py new file mode 100644 index 00000000..dcb211f0 --- /dev/null +++ b/src/core/tasks/url/operators/probe/convert.py @@ -0,0 +1,18 @@ +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic + + +def convert_tdo_to_web_metadata_list(tdos: list[URLProbeTDO]) -> list[URLWebMetadataPydantic]: + results: list[URLWebMetadataPydantic] = [] + for tdo in tdos: + response = tdo.response.response + web_metadata_object = URLWebMetadataPydantic( + url_id=tdo.url_mapping.url_id, + accessed=response.status_code is not None and response.status_code != 404, + status_code=response.status_code, + content_type=response.content_type, + error_message=response.error + ) + results.append(web_metadata_object) + return results + diff --git a/src/core/tasks/url/operators/probe/core.py b/src/core/tasks/url/operators/probe/core.py new file mode 100644 index 00000000..1c961155 --- /dev/null +++ b/src/core/tasks/url/operators/probe/core.py @@ -0,0 +1,85 @@ +from typing import final +from typing_extensions import override + +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.probe.convert import convert_tdo_to_web_metadata_list +from src.core.tasks.url.operators.probe.filter import filter_non_redirect_tdos, filter_redirect_tdos +from src.core.tasks.url.operators.probe.queries.insert_redirects.query import InsertRedirectsQueryBuilder +from src.core.tasks.url.operators.probe.queries.urls.not_probed.exists import HasURLsWithoutProbeQueryBuilder +from src.core.tasks.url.operators.probe.queries.urls.not_probed.get.query import GetURLsWithoutProbeQueryBuilder +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic +from src.external.url_request.core import URLRequestInterface +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.enums import TaskType + +@final +class URLProbeTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + url_request_interface: URLRequestInterface + ): + super().__init__(adb_client=adb_client) + self.url_request_interface = url_request_interface + + + @property + @override + def task_type(self) -> TaskType: + return TaskType.PROBE_URL + + @override + async def meets_task_prerequisites(self) -> bool: + return await self.has_urls_without_probe() + + async def get_urls_without_probe(self) -> list[URLProbeTDO]: + url_mappings:
list[URLMapping] = await self.adb_client.run_query_builder( + GetURLsWithoutProbeQueryBuilder() + ) + return [URLProbeTDO(url_mapping=url_mapping) for url_mapping in url_mappings] + + @override + async def inner_task_logic(self) -> None: + tdos = await self.get_urls_without_probe() + await self.link_urls_to_task( + url_ids=[tdo.url_mapping.url_id for tdo in tdos] + ) + await self.probe_urls(tdos) + await self.update_database(tdos) + + async def probe_urls(self, tdos: list[URLProbeTDO]) -> None: + """Probe URLs and add responses to URLProbeTDO + + Modifies: + URLProbeTDO.response + """ + url_to_tdo: dict[str, URLProbeTDO] = { + tdo.url_mapping.url: tdo for tdo in tdos + } + responses = await self.url_request_interface.probe_urls( + urls=[tdo.url_mapping.url for tdo in tdos] + ) + # Re-associate the responses with the URL mappings + for response in responses: + tdo = url_to_tdo[response.original_url] + tdo.response = response + + async def update_database(self, tdos: list[URLProbeTDO]) -> None: + non_redirect_tdos = filter_non_redirect_tdos(tdos) + web_metadata_objects: list[URLWebMetadataPydantic] = convert_tdo_to_web_metadata_list(non_redirect_tdos) + await self.adb_client.bulk_upsert(web_metadata_objects) + + redirect_tdos: list[URLProbeTDO] = filter_redirect_tdos(tdos) + + query_builder = InsertRedirectsQueryBuilder(tdos=redirect_tdos) + await self.adb_client.run_query_builder(query_builder) + + + async def has_urls_without_probe(self) -> bool: + return await self.adb_client.run_query_builder( + HasURLsWithoutProbeQueryBuilder() + ) + diff --git a/src/core/tasks/url/operators/probe/filter.py b/src/core/tasks/url/operators/probe/filter.py new file mode 100644 index 00000000..4a129676 --- /dev/null +++ b/src/core/tasks/url/operators/probe/filter.py @@ -0,0 +1,8 @@ +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO + + +def filter_non_redirect_tdos(tdos: list[URLProbeTDO]) -> list[URLProbeTDO]: + return [tdo for tdo in tdos if not tdo.response.is_redirect] + +def filter_redirect_tdos(tdos: list[URLProbeTDO]) -> list[URLProbeTDO]: + return [tdo for tdo in tdos if tdo.response.is_redirect] \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/__init__.py b/src/core/tasks/url/operators/probe/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/__init__.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/convert.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/convert.py new file mode 100644 index 00000000..eb0597ba --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/convert.py @@ -0,0 +1,56 @@ +from src.core.tasks.url.operators.probe.queries.insert_redirects.models.url_response_map import URLResponseMapping +from src.core.tasks.url.operators.probe.queries.urls.exist.model import UrlExistsResult +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic + + +def convert_url_response_mapping_to_web_metadata_list( + url_response_mappings: list[URLResponseMapping] +) -> list[URLWebMetadataPydantic]: + results: 
list[URLWebMetadataPydantic] = [] + for url_response_mapping in url_response_mappings: + response = url_response_mapping.response + web_metadata_object = URLWebMetadataPydantic( + url_id=url_response_mapping.url_mapping.url_id, + accessed=response.status_code is not None, + status_code=response.status_code, + content_type=response.content_type, + error_message=response.error + ) + results.append(web_metadata_object) + return results + + +def convert_to_url_mappings(url_exists_results: list[UrlExistsResult]) -> list[URLMapping]: + return [ + URLMapping( + url=url_exists_result.url, + url_id=url_exists_result.url_id + ) for url_exists_result in url_exists_results + ] + + +def convert_to_url_insert_models(urls: list[str]) -> list[URLInsertModel]: + results = [] + for url in urls: + results.append( + URLInsertModel( + url=url, + source=URLSource.REDIRECT + ) + ) + return results + +def convert_tdo_to_url_response_mappings(tdos: list[URLProbeTDO]) -> list[URLResponseMapping]: + results = [] + for tdo in tdos: + results.append( + URLResponseMapping( + url_mapping=tdo.url_mapping, + response=tdo.response.response.source + ) + ) + return results \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/extract.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/extract.py new file mode 100644 index 00000000..3de66e85 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/extract.py @@ -0,0 +1,16 @@ +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.dtos.url.mapping import URLMapping +from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair + + +def extract_response_pairs(tdos: list[URLProbeTDO]) -> list[URLProbeRedirectResponsePair]: + results: list[URLProbeRedirectResponsePair] = [] + for tdo in tdos: + if not tdo.response.is_redirect: + raise ValueError(f"Expected {tdo.url_mapping.url} to be a redirect.") + + response: URLProbeRedirectResponsePair = tdo.response.response + if not isinstance(response, URLProbeRedirectResponsePair): + raise ValueError(f"Expected {tdo.url_mapping.url} to be {URLProbeRedirectResponsePair.__name__}.") + results.append(response) + return results diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/filter.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/filter.py new file mode 100644 index 00000000..1f36893d --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/filter.py @@ -0,0 +1,14 @@ +from src.db.dtos.url.mapping import URLMapping + + +def filter_new_dest_urls( + url_mappings_in_db: list[URLMapping], + all_dest_urls: list[str] +) -> list[str]: + extant_destination_urls: set[str] = set([url_mapping.url for url_mapping in url_mappings_in_db]) + new_dest_urls: list[str] = [ + url + for url in all_dest_urls + if url not in extant_destination_urls + ] + return new_dest_urls \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/map.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/map.py new file mode 100644 index 00000000..53f2b2e1 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/map.py @@ -0,0 +1,19 @@ +from src.core.tasks.url.operators.probe.queries.insert_redirects.models.url_response_map import URLResponseMapping +from src.db.dtos.url.mapping import URLMapping +from src.external.url_request.probe.models.response import URLProbeResponse + + +def 
map_url_mappings_to_probe_responses( + url_mappings: list[URLMapping], + url_to_probe_responses: dict[str, URLProbeResponse] +) -> list[URLResponseMapping]: + results = [] + for url_mapping in url_mappings: + response = url_to_probe_responses[url_mapping.url] + results.append( + URLResponseMapping( + url_mapping=url_mapping, + response=response + ) + ) + return results \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/models/__init__.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/models/url_response_map.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/models/url_response_map.py new file mode 100644 index 00000000..efbd5db8 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/models/url_response_map.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.db.dtos.url.mapping import URLMapping +from src.external.url_request.probe.models.response import URLProbeResponse + + +class URLResponseMapping(BaseModel): + url_mapping: URLMapping + response: URLProbeResponse \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/query.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/query.py new file mode 100644 index 00000000..0ba70c47 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/query.py @@ -0,0 +1,84 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.probe.queries.insert_redirects.extract import extract_response_pairs +from src.core.tasks.url.operators.probe.queries.insert_redirects.filter import filter_new_dest_urls +from src.core.tasks.url.operators.probe.queries.insert_redirects.request_manager import InsertRedirectsRequestManager +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.dtos.url.mapping import URLMapping +from src.db.queries.base.builder import QueryBuilderBase +from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair +from src.external.url_request.probe.models.response import URLProbeResponse +from src.util.url_mapper import URLMapper + + +class InsertRedirectsQueryBuilder(QueryBuilderBase): + def __init__( + self, + tdos: list[URLProbeTDO], + ): + super().__init__() + self.tdos = tdos + self.source_url_mappings = [tdo.url_mapping for tdo in self.tdos] + self._mapper = URLMapper(self.source_url_mappings) + + self._response_pairs: list[URLProbeRedirectResponsePair] = extract_response_pairs(self.tdos) + + self._destination_probe_responses: list[URLProbeResponse] = [ + pair.destination + for pair in self._response_pairs + ] + self._destination_urls: list[str] = [ + response.url + for response in self._destination_probe_responses + ] + + self._destination_url_to_probe_response_mapping: dict[str, URLProbeResponse] = { + response.url: response + for response in self._destination_probe_responses + } + + + + + async def run(self, session: AsyncSession) -> None: + """ + Modifies: + self._mapper + """ + + rm = InsertRedirectsRequestManager( + session=session + ) + + + # Get all destination URLs already in the database + dest_url_mappings_in_db: list[URLMapping] = await rm.get_url_mappings_in_db( + urls=self._destination_urls + ) + + # Filter out to only have those URLs that are new in the database + new_dest_urls: list[str] = filter_new_dest_urls( + 
url_mappings_in_db=dest_url_mappings_in_db, + all_dest_urls=self._destination_urls + ) + + # Add the new URLs + new_dest_url_mappings: list[URLMapping] = await rm.insert_new_urls( + urls=new_dest_urls + ) + all_dest_url_mappings: list[URLMapping] = dest_url_mappings_in_db + new_dest_url_mappings + + self._mapper.add_mappings(all_dest_url_mappings) + + # Add web metadata for new URLs + await rm.add_web_metadata( + all_dest_url_mappings=all_dest_url_mappings, + dest_url_to_probe_response_mappings=self._destination_url_to_probe_response_mapping, + tdos=self.tdos + ) + + # Add redirect links for new URLs + await rm.add_redirect_links( + response_pairs=self._response_pairs, + mapper=self._mapper + ) diff --git a/src/core/tasks/url/operators/probe/queries/insert_redirects/request_manager.py b/src/core/tasks/url/operators/probe/queries/insert_redirects/request_manager.py new file mode 100644 index 00000000..35dfded5 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/insert_redirects/request_manager.py @@ -0,0 +1,116 @@ +from typing import Sequence + +from sqlalchemy import select, tuple_, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.probe.queries.insert_redirects.convert import convert_to_url_mappings, \ + convert_to_url_insert_models, convert_tdo_to_url_response_mappings, \ + convert_url_response_mapping_to_web_metadata_list +from src.core.tasks.url.operators.probe.queries.insert_redirects.map import map_url_mappings_to_probe_responses +from src.core.tasks.url.operators.probe.queries.insert_redirects.models.url_response_map import URLResponseMapping +from src.core.tasks.url.operators.probe.queries.urls.exist.model import UrlExistsResult +from src.core.tasks.url.operators.probe.queries.urls.exist.query import URLsExistInDBQueryBuilder +from src.core.tasks.url.operators.probe.tdo import URLProbeTDO +from src.db.dtos.url.mapping import URLMapping +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.link.url_redirect_url.pydantic import LinkURLRedirectURLPydantic +from src.db.models.impl.link.url_redirect_url.sqlalchemy import LinkURLRedirectURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic +from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair +from src.external.url_request.probe.models.response import URLProbeResponse +from src.util.url_mapper import URLMapper + + +class InsertRedirectsRequestManager: + + def __init__(self, session: AsyncSession): + self.session = session + + async def get_url_mappings_in_db( + self, + urls: list[str], + ): + results: list[UrlExistsResult] = await URLsExistInDBQueryBuilder( + urls=urls + ).run(self.session) + extant_urls = [result for result in results if result.exists] + return convert_to_url_mappings(extant_urls) + + async def insert_new_urls(self, urls: list[str]) -> list[URLMapping]: + if len(urls) == 0: + return [] + deduplicated_urls = list(set(urls)) + insert_models = convert_to_url_insert_models(deduplicated_urls) + url_ids = await sh.bulk_insert(self.session, models=insert_models, return_ids=True) + url_mappings = [ + URLMapping(url=url, url_id=url_id) + for url, url_id + in zip(deduplicated_urls, url_ids) + ] + return url_mappings + + async def add_web_metadata( + self, + all_dest_url_mappings: list[URLMapping], + dest_url_to_probe_response_mappings: dict[str, URLProbeResponse], + tdos: list[URLProbeTDO], + ) -> None: + 
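+        # Build web-metadata rows for both sides of each redirect: the
+        # destination mappings assembled here plus the source mappings taken
+        # from the TDOs, then upsert them in a single bulk call.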
dest_url_response_mappings = map_url_mappings_to_probe_responses( + url_mappings=all_dest_url_mappings, + url_to_probe_responses=dest_url_to_probe_response_mappings + ) + src_url_response_mappings: list[URLResponseMapping] = convert_tdo_to_url_response_mappings( + tdos=tdos + ) + all_url_response_mappings: list[URLResponseMapping] = src_url_response_mappings + dest_url_response_mappings + web_metadata_list: list[URLWebMetadataPydantic] = convert_url_response_mapping_to_web_metadata_list( + all_url_response_mappings + ) + await sh.bulk_upsert(self.session, models=web_metadata_list) + + async def add_redirect_links( + self, + response_pairs: list[URLProbeRedirectResponsePair], + mapper: URLMapper + ) -> None: + # Get all existing links and exclude + link_tuples: list[tuple[int, int]] = [] + for pair in response_pairs: + source_url_id = mapper.get_id(pair.source.url) + destination_url_id = mapper.get_id(pair.destination.url) + link_tuples.append((source_url_id, destination_url_id)) + + query = ( + select( + LinkURLRedirectURL.source_url_id, + LinkURLRedirectURL.destination_url_id + ) + .where( + tuple_( + LinkURLRedirectURL.source_url_id, + LinkURLRedirectURL.destination_url_id + ).in_(link_tuples) + ) + ) + mappings: Sequence[RowMapping] = await sh.mappings(self.session, query=query) + existing_links: set[tuple[int, int]] = { + (mapping["source_url_id"], mapping["destination_url_id"]) + for mapping in mappings + } + new_links: list[tuple[int, int]] = [ + (source_url_id, destination_url_id) + for source_url_id, destination_url_id in link_tuples + if (source_url_id, destination_url_id) not in existing_links + ] + + + links: list[LinkURLRedirectURLPydantic] = [] + for link in new_links: + source_url_id, destination_url_id = link + link = LinkURLRedirectURLPydantic( + source_url_id=source_url_id, + destination_url_id=destination_url_id + ) + links.append(link) + await sh.bulk_insert(self.session, models=links) diff --git a/src/core/tasks/url/operators/probe/queries/urls/__init__.py b/src/core/tasks/url/operators/probe/queries/urls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/urls/exist/__init__.py b/src/core/tasks/url/operators/probe/queries/urls/exist/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/urls/exist/model.py b/src/core/tasks/url/operators/probe/queries/urls/exist/model.py new file mode 100644 index 00000000..1245044c --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/urls/exist/model.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class UrlExistsResult(BaseModel): + url: str + url_id: int | None + + @property + def exists(self): + return self.url_id is not None \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/urls/exist/query.py b/src/core/tasks/url/operators/probe/queries/urls/exist/query.py new file mode 100644 index 00000000..5176add9 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/urls/exist/query.py @@ -0,0 +1,29 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.probe.queries.urls.exist.model import UrlExistsResult +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class URLsExistInDBQueryBuilder(QueryBuilderBase): + """Checks if URLs exist in the database.""" + + def __init__(self, urls: 
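+        # Raw URL strings to check; run() returns one UrlExistsResult per
+        # input URL, preserving input order.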
list[str]): + super().__init__() + self.urls = urls + + async def run(self, session: AsyncSession) -> list[UrlExistsResult]: + query = select(URL.id, URL.url).where(URL.url.in_(self.urls)) + db_mappings = await sh.mappings(session, query=query) + + url_to_id_map: dict[str, int] = { + row["url"]: row["id"] + for row in db_mappings + } + return [ + UrlExistsResult( + url=url, + url_id=url_to_id_map.get(url) + ) for url in self.urls + ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/queries/urls/not_probed/__init__.py b/src/core/tasks/url/operators/probe/queries/urls/not_probed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/urls/not_probed/exists.py b/src/core/tasks/url/operators/probe/queries/urls/not_probed/exists.py new file mode 100644 index 00000000..5954c197 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/urls/not_probed/exists.py @@ -0,0 +1,35 @@ +from datetime import timedelta, datetime + +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import override, final + +from src.db.enums import TaskType +from src.db.helpers.query import not_exists_url, no_url_task_error +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata +from src.db.queries.base.builder import QueryBuilderBase + +@final +class HasURLsWithoutProbeQueryBuilder(QueryBuilderBase): + + @override + async def run(self, session: AsyncSession) -> bool: + query = ( + select( + URL.id + ) + .outerjoin( + URLWebMetadata, + URL.id == URLWebMetadata.url_id + ) + .where( + or_( + URLWebMetadata.id.is_(None), + URLWebMetadata.updated_at < datetime.now() - timedelta(days=30) + ), + no_url_task_error(TaskType.PROBE_URL) + ) + ) + return await sh.has_results(session, query=query) diff --git a/src/core/tasks/url/operators/probe/queries/urls/not_probed/get/__init__.py b/src/core/tasks/url/operators/probe/queries/urls/not_probed/get/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/probe/queries/urls/not_probed/get/query.py b/src/core/tasks/url/operators/probe/queries/urls/not_probed/get/query.py new file mode 100644 index 00000000..36450252 --- /dev/null +++ b/src/core/tasks/url/operators/probe/queries/urls/not_probed/get/query.py @@ -0,0 +1,43 @@ +from datetime import timedelta, datetime + +from sqlalchemy import select, or_ +from sqlalchemy.ext.asyncio import AsyncSession +from typing_extensions import override, final + +from src.util.clean import clean_url +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +@final +class GetURLsWithoutProbeQueryBuilder(QueryBuilderBase): + + @override + async def run(self, session: AsyncSession) -> list[URLMapping]: + query = ( + select( + URL.id.label("url_id"), + URL.url + ) + .outerjoin( + URLWebMetadata, + URL.id == URLWebMetadata.url_id + ) + .where( + or_( + URLWebMetadata.id.is_(None), + URLWebMetadata.updated_at < datetime.now() - timedelta(days=30) + ) + ) + .limit(500) + ) + db_mappings = await sh.mappings(session, query=query) + return [ + URLMapping( + url_id=mapping["url_id"], + 
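+                # clean_url (imported above) tidies the stored URL before it
+                # is handed to the prober.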
url=clean_url(mapping["url"]) + ) for mapping in db_mappings + ] \ No newline at end of file diff --git a/src/core/tasks/url/operators/probe/tdo.py b/src/core/tasks/url/operators/probe/tdo.py new file mode 100644 index 00000000..5208fd80 --- /dev/null +++ b/src/core/tasks/url/operators/probe/tdo.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel + +from src.db.dtos.url.mapping import URLMapping +from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper + + +class URLProbeTDO(BaseModel): + url_mapping: URLMapping + response: URLProbeResponseOuterWrapper | None = None diff --git a/src/core/tasks/url/operators/record_type/core.py b/src/core/tasks/url/operators/record_type/core.py index ce73ceb4..8e31fa8d 100644 --- a/src/core/tasks/url/operators/record_type/core.py +++ b/src/core/tasks/url/operators/record_type/core.py @@ -1,10 +1,10 @@ -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo -from src.db.enums import TaskType -from src.core.tasks.url.operators.record_type.tdo import URLRecordTypeTDO -from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.core.enums import RecordType +from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.core.tasks.url.operators.record_type.llm_api.record_classifier.openai import OpenAIRecordClassifier +from src.core.tasks.url.operators.record_type.tdo import URLRecordTypeTDO +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall class URLRecordTypeTaskOperator(URLTaskOperatorBase): @@ -42,15 +42,14 @@ async def inner_task_logic(self): await self.update_errors_in_database(error_subset) async def update_errors_in_database(self, tdos: list[URLRecordTypeTDO]): - error_infos = [] + task_errors: list[URLTaskErrorSmall] = [] for tdo in tdos: - error_info = URLErrorPydanticInfo( - task_id=self.task_id, + error_info = URLTaskErrorSmall( url_id=tdo.url_with_html.url_id, error=tdo.error ) - error_infos.append(error_info) - await self.adb_client.add_url_error_infos(error_infos) + task_errors.append(error_info) + await self.add_task_errors(task_errors) async def put_results_into_database(self, tdos: list[URLRecordTypeTDO]): suggestions = [] diff --git a/src/core/tasks/url/operators/record_type/llm_api/record_classifier/base.py b/src/core/tasks/url/operators/record_type/llm_api/record_classifier/base.py index b995bda9..1268e4e5 100644 --- a/src/core/tasks/url/operators/record_type/llm_api/record_classifier/base.py +++ b/src/core/tasks/url/operators/record_type/llm_api/record_classifier/base.py @@ -70,8 +70,3 @@ async def classify_url(self, content_infos: list[URLHTMLContentInfo]) -> str: response_format=self.response_format ) return self.post_process_response(response) - - result_str = response.choices[0].message.content - - result_dict = json.loads(result_str) - return result_dict["record_type"] \ No newline at end of file diff --git a/src/core/tasks/url/operators/record_type/tdo.py b/src/core/tasks/url/operators/record_type/tdo.py index 43a32bab..3effcf53 100644 --- a/src/core/tasks/url/operators/record_type/tdo.py +++ b/src/core/tasks/url/operators/record_type/tdo.py @@ -8,8 +8,8 @@ class URLRecordTypeTDO(BaseModel): url_with_html: URLWithHTML - record_type: Optional[RecordType] = None - error: Optional[str] = None + record_type: RecordType | None = None + error: str | None = None def is_errored(self): return self.error is not None \ 
No newline at end of file diff --git a/src/core/tasks/url/operators/root_url/__init__.py b/src/core/tasks/url/operators/root_url/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/root_url/convert.py b/src/core/tasks/url/operators/root_url/convert.py new file mode 100644 index 00000000..405cbc49 --- /dev/null +++ b/src/core/tasks/url/operators/root_url/convert.py @@ -0,0 +1,49 @@ +from src.core.tasks.url.operators.root_url.extract import extract_root_url +from src.core.tasks.url.operators.root_url.models.root_mapping import URLRootURLMapping +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.root_url.pydantic import FlagRootURLPydantic +from src.db.models.impl.link.urls_root_url.pydantic import LinkURLRootURLPydantic +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.util.url_mapper import URLMapper + + +def convert_to_flag_root_url_pydantic(url_ids: list[int]) -> list[FlagRootURLPydantic]: + return [FlagRootURLPydantic(url_id=url_id) for url_id in url_ids] + +def convert_to_url_root_url_mapping(url_mappings: list[URLMapping]) -> list[URLRootURLMapping]: + return [ + URLRootURLMapping( + url=mapping.url, + root_url=extract_root_url(mapping.url) + ) for mapping in url_mappings + ] + +def convert_to_url_insert_models(urls: list[str]) -> list[URLInsertModel]: + return [ + URLInsertModel( + url=url, + source=URLSource.ROOT_URL + ) for url in urls + ] + +def convert_to_root_url_links( + root_db_mappings: list[URLMapping], + branch_db_mappings: list[URLMapping], + url_root_url_mappings: list[URLRootURLMapping] +) -> list[LinkURLRootURLPydantic]: + root_mapper = URLMapper(root_db_mappings) + branch_mapper = URLMapper(branch_db_mappings) + results: list[LinkURLRootURLPydantic] = [] + + for url_root_url_mapping in url_root_url_mappings: + root_url_id = root_mapper.get_id(url_root_url_mapping.root_url) + branch_url_id = branch_mapper.get_id(url_root_url_mapping.url) + + results.append( + LinkURLRootURLPydantic( + root_url_id=root_url_id, + url_id=branch_url_id) + ) + + return results diff --git a/src/core/tasks/url/operators/root_url/core.py b/src/core/tasks/url/operators/root_url/core.py new file mode 100644 index 00000000..e32654da --- /dev/null +++ b/src/core/tasks/url/operators/root_url/core.py @@ -0,0 +1,162 @@ +from typing import final + +from typing_extensions import override + +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.root_url.convert import convert_to_flag_root_url_pydantic, \ + convert_to_url_root_url_mapping, convert_to_url_insert_models, convert_to_root_url_links +from src.core.tasks.url.operators.root_url.models.root_mapping import URLRootURLMapping +from src.core.tasks.url.operators.root_url.queries.get import GetURLsForRootURLTaskQueryBuilder +from src.core.tasks.url.operators.root_url.queries.lookup.query import LookupRootURLsQueryBuilder +from src.core.tasks.url.operators.root_url.queries.lookup.response import LookupRootsURLResponse +from src.core.tasks.url.operators.root_url.queries.prereq import CheckPrereqsForRootURLTaskQueryBuilder +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.enums import TaskType +from src.db.models.impl.flag.root_url.pydantic import FlagRootURLPydantic +from src.db.models.impl.link.urls_root_url.pydantic import LinkURLRootURLPydantic +from 
src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.util.url_mapper import URLMapper + + +@final +class URLRootURLTaskOperator(URLTaskOperatorBase): + + def __init__(self, adb_client: AsyncDatabaseClient): + super().__init__(adb_client) + + @override + async def meets_task_prerequisites(self) -> bool: + builder = CheckPrereqsForRootURLTaskQueryBuilder() + return await self.adb_client.run_query_builder(builder) + + @property + @override + def task_type(self) -> TaskType: + return TaskType.ROOT_URL + + @override + async def inner_task_logic(self) -> None: + all_task_mappings: list[URLMapping] = await self._get_urls_for_root_url_task() + + await self.link_urls_to_task( + url_ids=[mapping.url_id for mapping in all_task_mappings] + ) + + # Get the Root URLs for all URLs + mapper = URLMapper(all_task_mappings) + + # -- Identify and Derive Root URLs -- + + root_url_mappings: list[URLRootURLMapping] = convert_to_url_root_url_mapping(all_task_mappings) + + # For those where the URL is also the Root URL, separate them + original_root_urls: list[str] = [mapping.url for mapping in root_url_mappings if mapping.is_root_url] + derived_root_urls: list[str] = [mapping.root_url for mapping in root_url_mappings if not mapping.is_root_url] + + # -- Add new Derived Root URLs -- + + # For derived Root URLs, we need to check if they are already in the database + derived_root_url_lookup_responses: list[LookupRootsURLResponse] = await self._lookup_root_urls(derived_root_urls) + + # For those not already in the database, we need to add them and get their mappings + derived_root_urls_not_in_db: list[str] = [ + response.url + for response in derived_root_url_lookup_responses + if response.url_id is None + ] + new_derived_root_url_mappings: list[URLMapping] = await self._add_new_urls(derived_root_urls_not_in_db) + + # Add these to the mapper + mapper.add_mappings(new_derived_root_url_mappings) + + # -- Flag Root URLs -- + + # Of those we obtain, we need to get those that are not yet flagged as Root URLs + extant_derived_root_url_ids_not_flagged: list[int] = [ + response.url_id + for response in derived_root_url_lookup_responses + if response.url_id is not None and not response.flagged_as_root + ] + original_root_url_ids_not_flagged: list[int] = [ + mapper.get_id(url) + for url in original_root_urls + ] + new_derived_root_url_ids_not_flagged: list[int] = [ + mapping.url_id + for mapping in new_derived_root_url_mappings + ] + + all_root_url_ids_not_flagged: list[int] = list(set( + extant_derived_root_url_ids_not_flagged + + new_derived_root_url_ids_not_flagged + + original_root_url_ids_not_flagged + )) + + await self._flag_root_urls(all_root_url_ids_not_flagged) + + # -- Add Root URL Links -- + + branch_url_mappings: list[URLRootURLMapping] = [mapping for mapping in root_url_mappings if not mapping.is_root_url] + await self._add_root_url_links( + mapper, + root_url_mappings=branch_url_mappings, + ) + + async def _add_root_url_links( + self, + mapper: URLMapper, + root_url_mappings: list[URLRootURLMapping], + ): + # For all task URLs that are not root URLs (i.e. 
'branch' URLs):
+        # - Connect them to the Root URL
+        # - Add the link
+
+        branch_urls: list[str] = [mapping.url for mapping in root_url_mappings]
+        root_urls: list[str] = [mapping.root_url for mapping in root_url_mappings]
+
+        # _lookup_root_urls returns lookup responses, not URL mappings, so
+        # convert explicitly; by this point every root URL has a row in the
+        # database, so url_id is always populated.
+        root_lookup_responses: list[LookupRootsURLResponse] = await self._lookup_root_urls(root_urls)
+        root_url_db_mappings: list[URLMapping] = [
+            URLMapping(url=response.url, url_id=response.url_id)
+            for response in root_lookup_responses
+        ]
+        task_url_db_mappings: list[URLMapping] = mapper.get_mappings_by_url(branch_urls)
+
+        links: list[LinkURLRootURLPydantic] = convert_to_root_url_links(
+            root_db_mappings=root_url_db_mappings,
+            branch_db_mappings=task_url_db_mappings,
+            url_root_url_mappings=root_url_mappings
+        )
+        await self._add_link_url_root_urls(links)
+
+    async def _flag_root_urls(
+        self,
+        url_ids: list[int]
+    ):
+        await self._flag_as_root_urls(url_ids)
+
+    async def _get_urls_for_root_url_task(self) -> list[URLMapping]:
+        builder = GetURLsForRootURLTaskQueryBuilder()
+        return await self.adb_client.run_query_builder(builder)
+
+    async def _lookup_root_urls(self, urls: list[str]) -> list[LookupRootsURLResponse]:
+        builder = LookupRootURLsQueryBuilder(urls=list(set(urls)))
+        return await self.adb_client.run_query_builder(builder)
+
+    async def _add_new_urls(self, urls: list[str]) -> list[URLMapping]:
+        if len(urls) == 0:
+            return []
+        insert_models: list[URLInsertModel] = convert_to_url_insert_models(urls)
+        url_ids: list[int] = await self.adb_client.bulk_insert(insert_models, return_ids=True)
+        mappings: list[URLMapping] = []
+        for url, url_id in zip(urls, url_ids):
+            mappings.append(
+                URLMapping(
+                    url=url,
+                    url_id=url_id
+                )
+            )
+        return mappings
+
+    async def _flag_as_root_urls(self, url_ids: list[int]) -> None:
+        flag_root_urls: list[FlagRootURLPydantic] = convert_to_flag_root_url_pydantic(url_ids)
+        await self.adb_client.bulk_insert(flag_root_urls)
+
+    async def _add_link_url_root_urls(self, links: list[LinkURLRootURLPydantic]) -> None:
+        await self.adb_client.bulk_insert(links)
diff --git a/src/core/tasks/url/operators/root_url/extract.py b/src/core/tasks/url/operators/root_url/extract.py
new file mode 100644
index 00000000..e384fd15
--- /dev/null
+++ b/src/core/tasks/url/operators/root_url/extract.py
@@ -0,0 +1,7 @@
+from urllib.parse import urlparse, ParseResult
+
+
+def extract_root_url(url: str) -> str:
+    parsed_url: ParseResult = urlparse(url)
+    root_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+    return root_url
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/root_url/models/__init__.py b/src/core/tasks/url/operators/root_url/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/root_url/models/root_mapping.py b/src/core/tasks/url/operators/root_url/models/root_mapping.py
new file mode 100644
index 00000000..7b115f36
--- /dev/null
+++ b/src/core/tasks/url/operators/root_url/models/root_mapping.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class URLRootURLMapping(BaseModel):
+    url: str
+    root_url: str
+
+    @property
+    def is_root_url(self) -> bool:
+        return self.url == self.root_url
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/root_url/queries/__init__.py b/src/core/tasks/url/operators/root_url/queries/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/root_url/queries/_shared/__init__.py b/src/core/tasks/url/operators/root_url/queries/_shared/__init__.py
new file mode 100644
index 00000000..e69de29b
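The `_shared` query in the next file uses the classic left anti-join shape: outer-join the flag and link tables, then keep only rows where both right-hand sides are NULL. A minimal, self-contained sketch of that pattern — toy tables on in-memory SQLite; the table and column names here are illustrative, not the project's:

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Url(Base):
    __tablename__ = "url"
    id = Column(Integer, primary_key=True)
    url = Column(String)


class FlagRoot(Base):
    __tablename__ = "flag_root"
    id = Column(Integer, primary_key=True)
    url_id = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Url(id=1, url="https://a.example"), Url(id=2, url="https://b.example")])
    session.add(FlagRoot(url_id=1))  # URL 1 is already flagged as a root
    session.commit()

    # Left anti-join: URLs with no matching flag row survive.
    query = (
        select(Url.id, Url.url)
        .outerjoin(FlagRoot, Url.id == FlagRoot.url_id)
        .where(FlagRoot.url_id.is_(None))
    )
    print(session.execute(query).all())  # [(2, 'https://b.example')]
```

The real query below applies the same shape twice (flag table and link table) so that a URL qualifies only when it is neither flagged nor linked.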
diff --git a/src/core/tasks/url/operators/root_url/queries/_shared/urls_without_root_id.py b/src/core/tasks/url/operators/root_url/queries/_shared/urls_without_root_id.py
new file mode 100644
index 00000000..f573133f
--- /dev/null
+++ b/src/core/tasks/url/operators/root_url/queries/_shared/urls_without_root_id.py
@@ -0,0 +1,28 @@
+"""
+A query to retrieve URLs that are neither
+- flagged as a root URL, nor
+- already linked to a root URL.
+
+"""
+
+from sqlalchemy import select
+
+from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL
+from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL
+from src.db.models.impl.url.core.sqlalchemy import URL
+
+URLS_WITHOUT_ROOT_ID_QUERY = (
+    select(
+        URL.id,
+        URL.url
+    ).outerjoin(
+        FlagRootURL,
+        URL.id == FlagRootURL.url_id
+    ).outerjoin(
+        LinkURLRootURL,
+        URL.id == LinkURLRootURL.url_id
+    ).where(
+        FlagRootURL.url_id.is_(None),
+        LinkURLRootURL.url_id.is_(None)
+    )
+)
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/root_url/queries/get.py b/src/core/tasks/url/operators/root_url/queries/get.py
new file mode 100644
index 00000000..3643f343
--- /dev/null
+++ b/src/core/tasks/url/operators/root_url/queries/get.py
@@ -0,0 +1,23 @@
+from sqlalchemy.ext.asyncio import AsyncSession
+from typing_extensions import override
+
+from src.core.tasks.url.operators.root_url.queries._shared.urls_without_root_id import URLS_WITHOUT_ROOT_ID_QUERY
+from src.db.dtos.url.mapping import URLMapping
+from src.db.helpers.session import session_helper as sh
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class GetURLsForRootURLTaskQueryBuilder(QueryBuilderBase):
+
+    @override
+    async def run(self, session: AsyncSession) -> list[URLMapping]:
+        query = (
+            URLS_WITHOUT_ROOT_ID_QUERY
+        )
+        mappings = await sh.mappings(session, query=query)
+        return [
+            URLMapping(
+                url_id=mapping["id"],
+                url=mapping["url"]
+            ) for mapping in mappings
+        ]
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/root_url/queries/lookup/__init__.py b/src/core/tasks/url/operators/root_url/queries/lookup/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/core/tasks/url/operators/root_url/queries/lookup/query.py b/src/core/tasks/url/operators/root_url/queries/lookup/query.py
new file mode 100644
index 00000000..88e1112e
--- /dev/null
+++ b/src/core/tasks/url/operators/root_url/queries/lookup/query.py
@@ -0,0 +1,58 @@
+from sqlalchemy import select, case
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.core.tasks.url.operators.root_url.queries.lookup.response import LookupRootsURLResponse
+from src.db.helpers.session import session_helper as sh
+from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.queries.base.builder import QueryBuilderBase
+
+
+class LookupRootURLsQueryBuilder(QueryBuilderBase):
+    """
+    Looks up URLs to see if they exist in the database as root URLs
+    """
+
+    def __init__(self, urls: list[str]):
+        super().__init__()
+        self.urls = urls
+
+    async def run(self, session: AsyncSession) -> list[LookupRootsURLResponse]:
+
+        # Run query
+        query = select(
+            URL.id,
+            URL.url,
+            case(
+                (FlagRootURL.url_id.is_(None), False),
+                else_=True
+            ).label("flagged_as_root")
+        ).outerjoin(FlagRootURL).where(
+            URL.url.in_(self.urls),
+        )
+        mappings = await sh.mappings(session, query=query)
+
+        # Store results in intermediate map
+        url_to_response_map: dict[str, LookupRootsURLResponse] = {}
+        for mapping in mappings:
+            url = mapping["url"]
+            response = LookupRootsURLResponse(
+                url=url,
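+                # url_id comes straight from the matched row; flagged_as_root
+                # was computed by the CASE expression in the query above.
+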
url_id=mapping["id"], + flagged_as_root=mapping["flagged_as_root"] + ) + url_to_response_map[url] = response + + # Iterate through original URLs and add missing responses + results: list[LookupRootsURLResponse] = [] + for url in self.urls: + response = url_to_response_map.get(url) + if response is None: + response = LookupRootsURLResponse( + url=url, + url_id=None, + flagged_as_root=False + ) + results.append(response) + + return results diff --git a/src/core/tasks/url/operators/root_url/queries/lookup/response.py b/src/core/tasks/url/operators/root_url/queries/lookup/response.py new file mode 100644 index 00000000..ea21b38d --- /dev/null +++ b/src/core/tasks/url/operators/root_url/queries/lookup/response.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel, model_validator + + +class LookupRootsURLResponse(BaseModel): + url: str + url_id: int | None + flagged_as_root: bool + + @property + def exists_in_db(self) -> bool: + return self.url_id is not None + + @model_validator(mode='after') + def validate_flagged_as_root(self): + if self.flagged_as_root and self.url_id is None: + raise ValueError('URL ID should be provided if flagged as root') + return self \ No newline at end of file diff --git a/src/core/tasks/url/operators/root_url/queries/prereq.py b/src/core/tasks/url/operators/root_url/queries/prereq.py new file mode 100644 index 00000000..e447f9d9 --- /dev/null +++ b/src/core/tasks/url/operators/root_url/queries/prereq.py @@ -0,0 +1,19 @@ +from typing_extensions import override + +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.root_url.queries._shared.urls_without_root_id import URLS_WITHOUT_ROOT_ID_QUERY +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class CheckPrereqsForRootURLTaskQueryBuilder(QueryBuilderBase): + + @override + async def run(self, session: AsyncSession) -> bool: + query = ( + URLS_WITHOUT_ROOT_ID_QUERY + .limit(1) + ) + result = await sh.one_or_none(session, query=query) + return result is not None \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/__init__.py b/src/core/tasks/url/operators/screenshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/screenshot/constants.py b/src/core/tasks/url/operators/screenshot/constants.py new file mode 100644 index 00000000..b41f697d --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/constants.py @@ -0,0 +1,4 @@ + + + +TASK_URL_LIMIT: int = 25 \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/convert.py b/src/core/tasks/url/operators/screenshot/convert.py new file mode 100644 index 00000000..09904ff1 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/convert.py @@ -0,0 +1,29 @@ +from src.core.tasks.url.operators.screenshot.models.outcome import URLScreenshotOutcome +from src.db.models.impl.url.screenshot.pydantic import URLScreenshotPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall + + +def convert_to_url_screenshot_pydantic( + outcomes: list[URLScreenshotOutcome] +) -> list[URLScreenshotPydantic]: + results: list[URLScreenshotPydantic] = [] + for outcome in outcomes: + result = URLScreenshotPydantic( + url_id=outcome.url_id, + content=outcome.screenshot, + file_size=len(outcome.screenshot), + ) + results.append(result) + return results + +def convert_to_task_error( + outcomes: list[URLScreenshotOutcome] +) -> list[URLTaskErrorSmall]: + results: 
list[URLTaskErrorSmall] = [] + for outcome in outcomes: + result = URLTaskErrorSmall( + url_id=outcome.url_id, + error=outcome.error, + ) + results.append(result) + return results diff --git a/src/core/tasks/url/operators/screenshot/core.py b/src/core/tasks/url/operators/screenshot/core.py new file mode 100644 index 00000000..96627ab8 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/core.py @@ -0,0 +1,62 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.screenshot.convert import convert_to_url_screenshot_pydantic, \ + convert_to_task_error +from src.core.tasks.url.operators.screenshot.filter import filter_success_outcomes +from src.core.tasks.url.operators.screenshot.get import get_url_screenshots +from src.core.tasks.url.operators.screenshot.models.outcome import URLScreenshotOutcome +from src.core.tasks.url.operators.screenshot.models.subsets import URLScreenshotOutcomeSubsets +from src.core.tasks.url.operators.screenshot.queries.get import GetURLsForScreenshotTaskQueryBuilder +from src.core.tasks.url.operators.screenshot.queries.prereq import URLsForScreenshotTaskPrerequisitesQueryBuilder +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.enums import TaskType +from src.db.models.impl.url.screenshot.pydantic import URLScreenshotPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall + + +class URLScreenshotTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + ): + super().__init__(adb_client) + + @property + def task_type(self) -> TaskType: + return TaskType.SCREENSHOT + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + URLsForScreenshotTaskPrerequisitesQueryBuilder() + ) + + async def get_urls_without_screenshot(self) -> list[URLMapping]: + return await self.adb_client.run_query_builder( + GetURLsForScreenshotTaskQueryBuilder() + ) + + async def upload_screenshots(self, outcomes: list[URLScreenshotOutcome]) -> None: + insert_models: list[URLScreenshotPydantic] = convert_to_url_screenshot_pydantic(outcomes) + await self.adb_client.bulk_insert(insert_models) + + async def upload_errors(self, outcomes: list[URLScreenshotOutcome]) -> None: + insert_models: list[URLTaskErrorSmall] = convert_to_task_error( + outcomes=outcomes, + ) + await self.add_task_errors(insert_models) + + async def inner_task_logic(self) -> None: + url_mappings: list[URLMapping] = await self.get_urls_without_screenshot() + await self.link_urls_to_task( + url_ids=[url_mapping.url_id for url_mapping in url_mappings] + ) + + outcomes: list[URLScreenshotOutcome] = await get_url_screenshots( + mappings=url_mappings + ) + + subsets: URLScreenshotOutcomeSubsets = filter_success_outcomes(outcomes) + await self.upload_screenshots(subsets.success) + await self.upload_errors(subsets.failed) + diff --git a/src/core/tasks/url/operators/screenshot/filter.py b/src/core/tasks/url/operators/screenshot/filter.py new file mode 100644 index 00000000..97cb5c89 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/filter.py @@ -0,0 +1,13 @@ +from src.core.tasks.url.operators.screenshot.models.outcome import URLScreenshotOutcome +from src.core.tasks.url.operators.screenshot.models.subsets import URLScreenshotOutcomeSubsets + + +def filter_success_outcomes(outcomes: list[URLScreenshotOutcome]) -> URLScreenshotOutcomeSubsets: + success: list[URLScreenshotOutcome] = [] + failed: 
list[URLScreenshotOutcome] = [] + for outcome in outcomes: + if outcome.success: + success.append(outcome) + else: + failed.append(outcome) + return URLScreenshotOutcomeSubsets(success=success, failed=failed) \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/get.py b/src/core/tasks/url/operators/screenshot/get.py new file mode 100644 index 00000000..7c0d6a42 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/get.py @@ -0,0 +1,22 @@ +from src.core.tasks.url.operators.screenshot.models.outcome import URLScreenshotOutcome +from src.db.dtos.url.mapping import URLMapping +from src.external.url_request.dtos.screenshot_response import URLScreenshotResponse +from src.external.url_request.screenshot_.core import get_screenshots +from src.util.url_mapper import URLMapper + + +async def get_url_screenshots(mappings: list[URLMapping]) -> list[URLScreenshotOutcome]: + mapper = URLMapper(mappings) + responses: list[URLScreenshotResponse] = await get_screenshots( + urls=mapper.get_all_urls() + ) + outcomes: list[URLScreenshotOutcome] = [] + for response in responses: + url_id: int = mapper.get_id(response.url) + outcome = URLScreenshotOutcome( + url_id=url_id, + screenshot=response.screenshot, + error=response.error, + ) + outcomes.append(outcome) + return outcomes diff --git a/src/core/tasks/url/operators/screenshot/models/__init__.py b/src/core/tasks/url/operators/screenshot/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/screenshot/models/outcome.py b/src/core/tasks/url/operators/screenshot/models/outcome.py new file mode 100644 index 00000000..4940b903 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/models/outcome.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel + + +class URLScreenshotOutcome(BaseModel): + url_id: int + screenshot: bytes | None + error: str | None + + @property + def success(self) -> bool: + return self.error is None \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/models/subsets.py b/src/core/tasks/url/operators/screenshot/models/subsets.py new file mode 100644 index 00000000..070171e6 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/models/subsets.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.core.tasks.url.operators.screenshot.models.outcome import URLScreenshotOutcome + + +class URLScreenshotOutcomeSubsets(BaseModel): + success: list[URLScreenshotOutcome] + failed: list[URLScreenshotOutcome] \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/queries/__init__.py b/src/core/tasks/url/operators/screenshot/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/screenshot/queries/cte.py b/src/core/tasks/url/operators/screenshot/queries/cte.py new file mode 100644 index 00000000..d961aabf --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/queries/cte.py @@ -0,0 +1,37 @@ +from sqlalchemy import CTE, select, Column + +from src.db.enums import TaskType +from src.db.helpers.query import url_not_validated, not_exists_url, no_url_task_error +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata + + +class URLScreenshotPrerequisitesCTEContainer: + + def __init__(self): + self._cte: CTE = ( + select( + URL.id.label("url_id"), + URL.url, + ) + .join( + URLWebMetadata, + URL.id == 
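+                # Inner join: only URLs that already have a web-metadata row
+                # qualify; the WHERE below additionally requires an
+                # unvalidated URL, an HTTP 200, no prior screenshot, and no
+                # recorded screenshot-task error.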
URLWebMetadata.url_id + ) + .where( + url_not_validated(), + not_exists_url(URLScreenshot), + no_url_task_error(TaskType.SCREENSHOT), + URLWebMetadata.status_code == 200, + ) + .cte("url_screenshot_prerequisites") + ) + + @property + def url_id(self) -> Column[int]: + return self._cte.c.url_id + + @property + def url(self) -> Column[str]: + return self._cte.c.url \ No newline at end of file diff --git a/src/core/tasks/url/operators/screenshot/queries/get.py b/src/core/tasks/url/operators/screenshot/queries/get.py new file mode 100644 index 00000000..e2dd94df --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/queries/get.py @@ -0,0 +1,25 @@ +from typing import Any, Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.screenshot.constants import TASK_URL_LIMIT +from src.core.tasks.url.operators.screenshot.queries.cte import URLScreenshotPrerequisitesCTEContainer +from src.db.dtos.url.mapping import URLMapping +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class GetURLsForScreenshotTaskQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[URLMapping]: + cte = URLScreenshotPrerequisitesCTEContainer() + + query = select( + cte.url_id, + cte.url, + ).limit(TASK_URL_LIMIT) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + return [URLMapping(**mapping) for mapping in mappings] diff --git a/src/core/tasks/url/operators/screenshot/queries/prereq.py b/src/core/tasks/url/operators/screenshot/queries/prereq.py new file mode 100644 index 00000000..885b8ad4 --- /dev/null +++ b/src/core/tasks/url/operators/screenshot/queries/prereq.py @@ -0,0 +1,21 @@ +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.screenshot.queries.cte import URLScreenshotPrerequisitesCTEContainer +from src.db.queries.base.builder import QueryBuilderBase + +from src.db.helpers.session import session_helper as sh + +class URLsForScreenshotTaskPrerequisitesQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> Any: + cte = URLScreenshotPrerequisitesCTEContainer() + + query = select( + cte.url_id, + cte.url, + ).limit(1) + + return await sh.results_exist(session=session, query=query) diff --git a/src/core/tasks/url/operators/submit_approved/__init__.py b/src/core/tasks/url/operators/submit_approved/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/submit_approved/convert.py b/src/core/tasks/url/operators/submit_approved/convert.py new file mode 100644 index 00000000..1c4a8298 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/convert.py @@ -0,0 +1,19 @@ +from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall + + +async def convert_to_task_errors( + submitted_url_infos: list[SubmittedURLInfo] +) -> list[URLTaskErrorSmall]: + task_errors: list[URLTaskErrorSmall] = [] + error_response_objects = [ + response_object for response_object in submitted_url_infos + if response_object.request_error is not None + ] + for error_response_object in error_response_objects: + error_info = URLTaskErrorSmall( + url_id=error_response_object.url_id, + error=error_response_object.request_error, + ) + task_errors.append(error_info) + return 
task_errors diff --git a/src/core/tasks/url/operators/submit_approved/core.py b/src/core/tasks/url/operators/submit_approved/core.py new file mode 100644 index 00000000..e16a1269 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/core.py @@ -0,0 +1,50 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.submit_approved.convert import convert_to_task_errors +from src.core.tasks.url.operators.submit_approved.filter import filter_successes +from src.core.tasks.url.operators.submit_approved.queries.get import GetValidatedURLsQueryBuilder +from src.core.tasks.url.operators.submit_approved.queries.has_validated import HasValidatedURLsQueryBuilder +from src.core.tasks.url.operators.submit_approved.tdo import SubmitApprovedURLTDO, SubmittedURLInfo +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall +from src.external.pdap.client import PDAPClient + + +class SubmitApprovedURLTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + pdap_client: PDAPClient + ): + super().__init__(adb_client) + self.pdap_client = pdap_client + + @property + def task_type(self): + return TaskType.SUBMIT_APPROVED + + async def meets_task_prerequisites(self): + return await self.adb_client.run_query_builder(HasValidatedURLsQueryBuilder()) + + async def inner_task_logic(self): + # Retrieve all URLs that are validated and not submitted + tdos: list[SubmitApprovedURLTDO] = await self.get_validated_urls() + + # Link URLs to this task + await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) + + # Submit each URL, recording errors if they exist + submitted_url_infos: list[SubmittedURLInfo] = await self.pdap_client.submit_data_source_urls(tdos) + + task_errors: list[URLTaskErrorSmall] = await convert_to_task_errors(submitted_url_infos) + success_infos = await filter_successes(submitted_url_infos) + + # Update the database for successful submissions + await self.adb_client.mark_urls_as_submitted(infos=success_infos) + + # Update the database for failed submissions + await self.add_task_errors(task_errors) + + async def get_validated_urls(self) -> list[SubmitApprovedURLTDO]: + return await self.adb_client.run_query_builder(GetValidatedURLsQueryBuilder()) diff --git a/src/core/tasks/url/operators/submit_approved/filter.py b/src/core/tasks/url/operators/submit_approved/filter.py new file mode 100644 index 00000000..4ba2fad8 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/filter.py @@ -0,0 +1,11 @@ +from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo + + +async def filter_successes( + submitted_url_infos: list[SubmittedURLInfo] +) -> list[SubmittedURLInfo]: + success_infos = [ + response_object for response_object in submitted_url_infos + if response_object.data_source_id is not None + ] + return success_infos diff --git a/src/core/tasks/url/operators/submit_approved/queries/__init__.py b/src/core/tasks/url/operators/submit_approved/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/submit_approved/queries/cte.py b/src/core/tasks/url/operators/submit_approved/queries/cte.py new file mode 100644 index 00000000..cf7ccb71 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/queries/cte.py @@ -0,0 +1,31 @@ +from sqlalchemy import CTE, select, exists +from sqlalchemy.orm import aliased + +from 
src.collectors.enums import URLStatus +from src.db.enums import TaskType +from src.db.helpers.query import not_exists_url, no_url_task_error +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource + +VALIDATED_URLS_WITHOUT_DS_SQ =( + select(URL) + .join( + FlagURLValidated, + FlagURLValidated.url_id == URL.id + ) + .where( + URL.status == URLStatus.OK, + URL.name.isnot(None), + FlagURLValidated.type == URLType.DATA_SOURCE, + not_exists_url(URLDataSource), + no_url_task_error(TaskType.SUBMIT_APPROVED) + ) + .subquery() +) + +VALIDATED_URLS_WITHOUT_DS_ALIAS = aliased( + URL, + VALIDATED_URLS_WITHOUT_DS_SQ +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_approved/queries/get.py b/src/core/tasks/url/operators/submit_approved/queries/get.py new file mode 100644 index 00000000..d4138f9a --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/queries/get.py @@ -0,0 +1,68 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from src.core.tasks.url.operators.submit_approved.queries.cte import VALIDATED_URLS_WITHOUT_DS_ALIAS +from src.core.tasks.url.operators.submit_approved.tdo import SubmitApprovedURLTDO +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class GetValidatedURLsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[SubmitApprovedURLTDO]: + query = await self._build_query() + urls = await sh.scalars(session, query) + return await self._process_results(urls) + + async def _process_results(self, urls): + results: list[SubmitApprovedURLTDO] = [] + for url in urls: + try: + tdo = await self._process_result(url) + except Exception as e: + raise ValueError(f"Failed to process url {url.id}") from e + results.append(tdo) + return results + + @staticmethod + async def _build_query(): + query = ( + select(VALIDATED_URLS_WITHOUT_DS_ALIAS) + .options( + selectinload(VALIDATED_URLS_WITHOUT_DS_ALIAS.optional_data_source_metadata), + selectinload(VALIDATED_URLS_WITHOUT_DS_ALIAS.confirmed_agencies), + selectinload(VALIDATED_URLS_WITHOUT_DS_ALIAS.reviewing_user), + selectinload(VALIDATED_URLS_WITHOUT_DS_ALIAS.record_type), + ).limit(100) + ) + return query + + @staticmethod + async def _process_result(url: URL) -> SubmitApprovedURLTDO: + agency_ids = [] + for agency in url.confirmed_agencies: + agency_ids.append(agency.agency_id) + optional_metadata = url.optional_data_source_metadata + if optional_metadata is None: + record_formats = None + data_portal_type = None + supplying_entity = None + else: + record_formats = optional_metadata.record_formats + data_portal_type = optional_metadata.data_portal_type + supplying_entity = optional_metadata.supplying_entity + tdo = SubmitApprovedURLTDO( + url_id=url.id, + url=url.url, + name=url.name, + agency_ids=agency_ids, + description=url.description, + record_type=url.record_type.record_type, + record_formats=record_formats, + data_portal_type=data_portal_type, + supplying_entity=supplying_entity, + approving_user_id=url.reviewing_user.user_id + ) + return tdo \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_approved/queries/has_validated.py 
b/src/core/tasks/url/operators/submit_approved/queries/has_validated.py new file mode 100644 index 00000000..2cbee486 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/queries/has_validated.py @@ -0,0 +1,18 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.submit_approved.queries.cte import VALIDATED_URLS_WITHOUT_DS_ALIAS +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class HasValidatedURLsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + query = ( + select(VALIDATED_URLS_WITHOUT_DS_ALIAS) + .limit(1) + ) + url: URL | None = await sh.one_or_none(session, query=query) + return url is not None \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_approved/queries/mark_submitted.py b/src/core/tasks/url/operators/submit_approved/queries/mark_submitted.py new file mode 100644 index 00000000..4ebfef56 --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/queries/mark_submitted.py @@ -0,0 +1,29 @@ +from sqlalchemy import update +from sqlalchemy.ext.asyncio import AsyncSession + +from src.collectors.enums import URLStatus +from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.queries.base.builder import QueryBuilderBase + + +class MarkURLsAsSubmittedQueryBuilder(QueryBuilderBase): + + def __init__(self, infos: list[SubmittedURLInfo]): + super().__init__() + self.infos = infos + + async def run(self, session: AsyncSession): + for info in self.infos: + url_id = info.url_id + data_source_id = info.data_source_id + + url_data_source_object = URLDataSource( + url_id=url_id, + data_source_id=data_source_id + ) + if info.submitted_at is not None: + url_data_source_object.created_at = info.submitted_at + session.add(url_data_source_object) + diff --git a/src/core/tasks/url/operators/submit_approved/tdo.py b/src/core/tasks/url/operators/submit_approved/tdo.py new file mode 100644 index 00000000..89d89d9e --- /dev/null +++ b/src/core/tasks/url/operators/submit_approved/tdo.py @@ -0,0 +1,26 @@ +from datetime import datetime + +from pydantic import BaseModel + +from src.core.enums import RecordType + + +class SubmitApprovedURLTDO(BaseModel): + url_id: int + url: str + record_type: RecordType + agency_ids: list[int] + name: str + description: str | None = None + approving_user_id: int + record_formats: list[str] | None = None + data_portal_type: str | None = None + supplying_entity: str | None = None + data_source_id: int | None = None + request_error: str | None = None + +class SubmittedURLInfo(BaseModel): + url_id: int + data_source_id: int | None + request_error: str | None + submitted_at: datetime | None = None \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_approved_url/core.py b/src/core/tasks/url/operators/submit_approved_url/core.py deleted file mode 100644 index dd2df39e..00000000 --- a/src/core/tasks/url/operators/submit_approved_url/core.py +++ /dev/null @@ -1,65 +0,0 @@ -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo -from src.db.enums import TaskType -from src.core.tasks.url.operators.submit_approved_url.tdo import SubmitApprovedURLTDO -from 
src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.external.pdap.client import PDAPClient - - -class SubmitApprovedURLTaskOperator(URLTaskOperatorBase): - - def __init__( - self, - adb_client: AsyncDatabaseClient, - pdap_client: PDAPClient - ): - super().__init__(adb_client) - self.pdap_client = pdap_client - - @property - def task_type(self): - return TaskType.SUBMIT_APPROVED - - async def meets_task_prerequisites(self): - return await self.adb_client.has_validated_urls() - - async def inner_task_logic(self): - # Retrieve all URLs that are validated and not submitted - tdos: list[SubmitApprovedURLTDO] = await self.adb_client.get_validated_urls() - - # Link URLs to this task - await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) - - # Submit each URL, recording errors if they exist - submitted_url_infos = await self.pdap_client.submit_urls(tdos) - - error_infos = await self.get_error_infos(submitted_url_infos) - success_infos = await self.get_success_infos(submitted_url_infos) - - # Update the database for successful submissions - await self.adb_client.mark_urls_as_submitted(infos=success_infos) - - # Update the database for failed submissions - await self.adb_client.add_url_error_infos(error_infos) - - async def get_success_infos(self, submitted_url_infos): - success_infos = [ - response_object for response_object in submitted_url_infos - if response_object.data_source_id is not None - ] - return success_infos - - async def get_error_infos(self, submitted_url_infos): - error_infos: list[URLErrorPydanticInfo] = [] - error_response_objects = [ - response_object for response_object in submitted_url_infos - if response_object.request_error is not None - ] - for error_response_object in error_response_objects: - error_info = URLErrorPydanticInfo( - task_id=self.task_id, - url_id=error_response_object.url_id, - error=error_response_object.request_error, - ) - error_infos.append(error_info) - return error_infos diff --git a/src/core/tasks/url/operators/submit_approved_url/tdo.py b/src/core/tasks/url/operators/submit_approved_url/tdo.py deleted file mode 100644 index d5193640..00000000 --- a/src/core/tasks/url/operators/submit_approved_url/tdo.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from src.core.enums import RecordType -from datetime import datetime - -class SubmitApprovedURLTDO(BaseModel): - url_id: int - url: str - record_type: RecordType - agency_ids: list[int] - name: str - description: str - approving_user_id: int - record_formats: Optional[list[str]] = None - data_portal_type: Optional[str] = None - supplying_entity: Optional[str] = None - data_source_id: Optional[int] = None - request_error: Optional[str] = None - -class SubmittedURLInfo(BaseModel): - url_id: int - data_source_id: Optional[int] - request_error: Optional[str] - submitted_at: Optional[datetime] = None \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_meta_urls/__init__.py b/src/core/tasks/url/operators/submit_meta_urls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/submit_meta_urls/core.py b/src/core/tasks/url/operators/submit_meta_urls/core.py new file mode 100644 index 00000000..e06901da --- /dev/null +++ b/src/core/tasks/url/operators/submit_meta_urls/core.py @@ -0,0 +1,78 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.submit_meta_urls.queries.get import GetMetaURLsForSubmissionQueryBuilder 
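The prerequisites CTE in this operator's queries/cte.py (shown below) excludes (url, agency) pairs that already have a meta-URL submission by way of a correlated NOT EXISTS. A standalone sketch of that shape — toy tables on in-memory SQLite; the table and column names here are illustrative, not the project's:

```python
from sqlalchemy import Column, Integer, create_engine, exists, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Link(Base):
    __tablename__ = "link_url_agency"
    id = Column(Integer, primary_key=True)
    url_id = Column(Integer)
    agency_id = Column(Integer)


class Submission(Base):
    __tablename__ = "submission"
    id = Column(Integer, primary_key=True)
    url_id = Column(Integer)
    agency_id = Column(Integer)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([
        Link(url_id=1, agency_id=10),
        Link(url_id=1, agency_id=11),
        Submission(url_id=1, agency_id=10),  # pair (1, 10) already submitted
    ])
    session.commit()

    # Keep link pairs with no matching submission row (correlated NOT EXISTS).
    query = select(Link.url_id, Link.agency_id).where(
        ~exists(
            select(Submission.id).where(
                Submission.url_id == Link.url_id,
                Submission.agency_id == Link.agency_id,
            )
        )
    )
    print(session.execute(query).all())  # [(1, 11)]
```

Checking both columns in the subquery matters here: a URL can be linked to several agencies, and only the specific (url, agency) combinations already submitted should be skipped.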
+from src.core.tasks.url.operators.submit_meta_urls.queries.prereq import \ + MeetsMetaURLSSubmissionPrerequisitesQueryBuilder +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.enums import TaskType +from src.db.models.impl.url.ds_meta_url.pydantic import URLDSMetaURLPydantic +from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall +from src.external.pdap.client import PDAPClient +from src.external.pdap.impl.meta_urls.enums import SubmitMetaURLsStatus +from src.external.pdap.impl.meta_urls.request import SubmitMetaURLsRequest +from src.external.pdap.impl.meta_urls.response import SubmitMetaURLsResponse +from src.util.url_mapper import URLMapper + + +class SubmitMetaURLsTaskOperator(URLTaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + pdap_client: PDAPClient + ): + super().__init__(adb_client) + self.pdap_client = pdap_client + + @property + def task_type(self) -> TaskType: + return TaskType.SUBMIT_META_URLS + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + MeetsMetaURLSSubmissionPrerequisitesQueryBuilder() + ) + + async def inner_task_logic(self) -> None: + requests: list[SubmitMetaURLsRequest] = await self.adb_client.run_query_builder( + GetMetaURLsForSubmissionQueryBuilder() + ) + + url_mappings: list[URLMapping] = [ + URLMapping( + url=request.url, + url_id=request.url_id, + ) + for request in requests + ] + + mapper = URLMapper(url_mappings) + + await self.link_urls_to_task(mapper.get_all_ids()) + + responses: list[SubmitMetaURLsResponse] = \ + await self.pdap_client.submit_meta_urls(requests) + + errors: list[URLTaskErrorSmall] = [] + inserts: list[URLDSMetaURLPydantic] = [] + + for response in responses: + url_id: int = mapper.get_id(response.url) + if response.status == SubmitMetaURLsStatus.SUCCESS: + inserts.append( + URLDSMetaURLPydantic( + url_id=url_id, + agency_id=response.agency_id, + ds_meta_url_id=response.meta_url_id + ) + ) + else: + errors.append( + URLTaskErrorSmall( + url_id=url_id, + error=response.error, + ) + ) + + await self.add_task_errors(errors) + await self.adb_client.bulk_insert(inserts) diff --git a/src/core/tasks/url/operators/submit_meta_urls/queries/__init__.py b/src/core/tasks/url/operators/submit_meta_urls/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/submit_meta_urls/queries/cte.py b/src/core/tasks/url/operators/submit_meta_urls/queries/cte.py new file mode 100644 index 00000000..d350258c --- /dev/null +++ b/src/core/tasks/url/operators/submit_meta_urls/queries/cte.py @@ -0,0 +1,61 @@ +from sqlalchemy import select, exists, Column, CTE + +from src.db.enums import TaskType +from src.db.helpers.query import no_url_task_error +from src.db.models.impl.agency.sqlalchemy import Agency +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.ds_meta_url.sqlalchemy import URLDSMetaURL +from src.db.models.views.meta_url import MetaURL + + +class SubmitMetaURLsPrerequisitesCTEContainer: + + def __init__(self): + + self._cte = ( + select( + URL.id.label("url_id"), + URL.url, + LinkURLAgency.agency_id, + ) + # Validated as Meta URL + .join( + MetaURL, + MetaURL.url_id == URL.id + ) + .join( + LinkURLAgency, + LinkURLAgency.url_id == URL.id + ) + # Does not have a submission + .where( + ~exists( + select( + 
URLDSMetaURL.ds_meta_url_id + ) + .where( + URLDSMetaURL.url_id == URL.id, + URLDSMetaURL.agency_id == LinkURLAgency.agency_id + ) + ), + no_url_task_error(TaskType.SUBMIT_META_URLS) + ) + .cte("submit_meta_urls_prerequisites") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self._cte.c.url_id + + @property + def agency_id(self) -> Column[int]: + return self._cte.c.agency_id + + @property + def url(self) -> Column[str]: + return self._cte.c.url \ No newline at end of file diff --git a/src/core/tasks/url/operators/submit_meta_urls/queries/get.py b/src/core/tasks/url/operators/submit_meta_urls/queries/get.py new file mode 100644 index 00000000..518393f6 --- /dev/null +++ b/src/core/tasks/url/operators/submit_meta_urls/queries/get.py @@ -0,0 +1,34 @@ +from typing import Any, Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.submit_meta_urls.queries.cte import SubmitMetaURLsPrerequisitesCTEContainer +from src.db.queries.base.builder import QueryBuilderBase +from src.external.pdap.impl.meta_urls.request import SubmitMetaURLsRequest + +from src.db.helpers.session import session_helper as sh + +class GetMetaURLsForSubmissionQueryBuilder(QueryBuilderBase): + + + async def run(self, session: AsyncSession) -> list[SubmitMetaURLsRequest]: + cte = SubmitMetaURLsPrerequisitesCTEContainer() + query = ( + select( + cte.url_id, + cte.agency_id, + cte.url + ) + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + + return [ + SubmitMetaURLsRequest( + url_id=mapping["url_id"], + agency_id=mapping["agency_id"], + url=mapping["url"], + ) + for mapping in mappings + ] diff --git a/src/core/tasks/url/operators/submit_meta_urls/queries/prereq.py b/src/core/tasks/url/operators/submit_meta_urls/queries/prereq.py new file mode 100644 index 00000000..3b5538be --- /dev/null +++ b/src/core/tasks/url/operators/submit_meta_urls/queries/prereq.py @@ -0,0 +1,20 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.submit_meta_urls.queries.cte import SubmitMetaURLsPrerequisitesCTEContainer +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + + +class MeetsMetaURLSSubmissionPrerequisitesQueryBuilder(QueryBuilderBase): + + + async def run(self, session: AsyncSession) -> bool: + cte = SubmitMetaURLsPrerequisitesCTEContainer() + query = ( + select( + cte.url_id, + ) + ) + + return await sh.has_results(session, query=query) \ No newline at end of file diff --git a/src/core/tasks/url/operators/suspend/__init__.py b/src/core/tasks/url/operators/suspend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/suspend/core.py b/src/core/tasks/url/operators/suspend/core.py new file mode 100644 index 00000000..2dcfc53b --- /dev/null +++ b/src/core/tasks/url/operators/suspend/core.py @@ -0,0 +1,30 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.suspend.queries.get.query import GetURLsForSuspensionQueryBuilder +from src.core.tasks.url.operators.suspend.queries.get.response import GetURLsForSuspensionResponse +from src.core.tasks.url.operators.suspend.queries.insert import InsertURLSuspensionsQueryBuilder +from src.core.tasks.url.operators.suspend.queries.prereq import GetURLsForSuspensionPrerequisitesQueryBuilder +from 
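A design choice worth noting in the queries above: GetMetaURLsForSubmissionQueryBuilder and MeetsMetaURLSSubmissionPrerequisitesQueryBuilder both build on the same SubmitMetaURLsPrerequisitesCTEContainer, so the eligibility rules are written once and shared by the prerequisite check and the fetch. A rough sketch of that shape, using a hypothetical table in place of the real models:

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select

metadata = MetaData()

# Hypothetical table standing in for the project's URL model.
urls = Table(
    "urls",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("url", String),
    Column("validated_as_meta", Integer),
)

# Define eligibility once, as a CTE...
eligible = (
    select(urls.c.id.label("url_id"), urls.c.url)
    .where(urls.c.validated_as_meta == 1)
    .cte("eligible")
)

# ...then derive both queries from it: the prerequisite check
# ("is there any work at all?") and the fetch itself.
prereq_query = select(select(eligible.c.url_id).exists())
fetch_query = select(eligible.c.url_id, eligible.c.url)
```

The suspend operator that follows uses the same container pattern.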
src.db.enums import TaskType + + +class SuspendURLTaskOperator(URLTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.SUSPEND_URLS + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + GetURLsForSuspensionPrerequisitesQueryBuilder() + ) + + async def inner_task_logic(self) -> None: + # Get URLs for suspension + responses: list[GetURLsForSuspensionResponse] = await self.adb_client.run_query_builder( + GetURLsForSuspensionQueryBuilder() + ) + url_ids: list[int] = [response.url_id for response in responses] + await self.link_urls_to_task(url_ids) + + await self.adb_client.run_query_builder( + InsertURLSuspensionsQueryBuilder(responses) + ) diff --git a/src/core/tasks/url/operators/suspend/queries/__init__.py b/src/core/tasks/url/operators/suspend/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/suspend/queries/cte.py b/src/core/tasks/url/operators/suspend/queries/cte.py new file mode 100644 index 00000000..7b15aee4 --- /dev/null +++ b/src/core/tasks/url/operators/suspend/queries/cte.py @@ -0,0 +1,49 @@ +from sqlalchemy import select, func, Select, exists, or_ + +from src.db.helpers.query import no_url_task_error +from src.db.models.impl.flag.url_suspended.sqlalchemy import FlagURLSuspended +from src.db.models.impl.link.user_suggestion_not_found.agency.sqlalchemy import LinkUserSuggestionAgencyNotFound +from src.db.models.impl.link.user_suggestion_not_found.location.sqlalchemy import LinkUserSuggestionLocationNotFound +from src.db.models.views.unvalidated_url import UnvalidatedURL + + +class GetURLsForSuspensionCTEContainer: + + def __init__(self): + self.cte = ( + select( + UnvalidatedURL.url_id + ) + .outerjoin( + LinkUserSuggestionAgencyNotFound, + UnvalidatedURL.url_id == LinkUserSuggestionAgencyNotFound.url_id + ) + .outerjoin( + LinkUserSuggestionLocationNotFound, + UnvalidatedURL.url_id == LinkUserSuggestionLocationNotFound.url_id + ) + .where( + ~exists( + select( + FlagURLSuspended.url_id + ) + .where( + FlagURLSuspended.url_id == UnvalidatedURL.url_id + ) + ), + ) + .group_by( + UnvalidatedURL.url_id + ) + .having( + or_( + func.count(LinkUserSuggestionAgencyNotFound.user_id) >= 2, + func.count(LinkUserSuggestionLocationNotFound.user_id) >= 2, + ) + ) + .cte("get_urls_for_suspension") + ) + + @property + def query(self) -> Select: + return select(self.cte.c.url_id) \ No newline at end of file diff --git a/src/core/tasks/url/operators/suspend/queries/get/__init__.py b/src/core/tasks/url/operators/suspend/queries/get/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/suspend/queries/get/query.py b/src/core/tasks/url/operators/suspend/queries/get/query.py new file mode 100644 index 00000000..23a48d5b --- /dev/null +++ b/src/core/tasks/url/operators/suspend/queries/get/query.py @@ -0,0 +1,16 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.suspend.queries.cte import GetURLsForSuspensionCTEContainer +from src.core.tasks.url.operators.suspend.queries.get.response import GetURLsForSuspensionResponse +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class GetURLsForSuspensionQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[GetURLsForSuspensionResponse]: + cte = GetURLsForSuspensionCTEContainer() + results = await sh.mappings(session=session,
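The HAVING clause in the suspension CTE above flags a URL once at least two user suggestions mark its agency, or at least two mark its location, as not found; the NOT EXISTS guard excludes URLs already suspended. The same rule restated in plain Python, as a sketch only, with (url_id, user_id) tuples standing in for the link-table rows:

```python
from collections import Counter


def urls_to_suspend(
    agency_not_found: list[tuple[int, int]],    # (url_id, user_id) rows
    location_not_found: list[tuple[int, int]],  # (url_id, user_id) rows
    already_suspended: set[int],
) -> set[int]:
    """Mirror the CTE's HAVING: suspend once >= 2 users reported the
    agency, or >= 2 reported the location, as not found."""
    agency_votes = Counter(url_id for url_id, _ in agency_not_found)
    location_votes = Counter(url_id for url_id, _ in location_not_found)
    candidates = {
        url_id
        for url_id in agency_votes.keys() | location_votes.keys()
        if agency_votes[url_id] >= 2 or location_votes[url_id] >= 2
    }
    return candidates - already_suspended


# URL 1 has two agency-not-found reports; URL 2 has one report of each kind.
assert urls_to_suspend(
    agency_not_found=[(1, 10), (1, 11), (2, 10)],
    location_not_found=[(2, 12)],
    already_suspended=set(),
) == {1}
```

Note this sketch assumes one row per (url, user) pair; the SQL counts joined rows, which matches under that assumption.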
query=cte.query) + return [ + GetURLsForSuspensionResponse(url_id=result["url_id"]) + for result in results + ] diff --git a/src/core/tasks/url/operators/suspend/queries/get/response.py b/src/core/tasks/url/operators/suspend/queries/get/response.py new file mode 100644 index 00000000..2f207fbe --- /dev/null +++ b/src/core/tasks/url/operators/suspend/queries/get/response.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class GetURLsForSuspensionResponse(BaseModel): + url_id: int \ No newline at end of file diff --git a/src/core/tasks/url/operators/suspend/queries/insert.py b/src/core/tasks/url/operators/suspend/queries/insert.py new file mode 100644 index 00000000..e979563f --- /dev/null +++ b/src/core/tasks/url/operators/suspend/queries/insert.py @@ -0,0 +1,24 @@ +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.suspend.queries.get.response import GetURLsForSuspensionResponse +from src.db.models.impl.flag.url_suspended.sqlalchemy import FlagURLSuspended +from src.db.queries.base.builder import QueryBuilderBase + + +class InsertURLSuspensionsQueryBuilder(QueryBuilderBase): + + def __init__(self, responses: list[GetURLsForSuspensionResponse]): + super().__init__() + self.responses = responses + + async def run(self, session: AsyncSession) -> Any: + models: list[FlagURLSuspended] = [] + for response in self.responses: + models.append( + FlagURLSuspended( + url_id=response.url_id, + ) + ) + session.add_all(models) diff --git a/src/core/tasks/url/operators/suspend/queries/prereq.py b/src/core/tasks/url/operators/suspend/queries/prereq.py new file mode 100644 index 00000000..416d68f6 --- /dev/null +++ b/src/core/tasks/url/operators/suspend/queries/prereq.py @@ -0,0 +1,12 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.suspend.queries.cte import GetURLsForSuspensionCTEContainer +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLsForSuspensionPrerequisitesQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + cte = GetURLsForSuspensionCTEContainer() + return await sh.results_exist(session=session, query=cte.query) diff --git a/src/core/tasks/url/operators/url_404_probe/core.py b/src/core/tasks/url/operators/url_404_probe/core.py deleted file mode 100644 index 7da96068..00000000 --- a/src/core/tasks/url/operators/url_404_probe/core.py +++ /dev/null @@ -1,63 +0,0 @@ -from http import HTTPStatus - -from pydantic import BaseModel - -from src.core.tasks.url.operators.url_html.scraper.request_interface.core import URLRequestInterface -from src.db.client.async_ import AsyncDatabaseClient -from src.db.enums import TaskType -from src.core.tasks.url.operators.url_404_probe.tdo import URL404ProbeTDO -from src.core.tasks.url.operators.base import URLTaskOperatorBase - - -class URL404ProbeTDOSubsets(BaseModel): - successful: list[URL404ProbeTDO] - is_404: list[URL404ProbeTDO] - - - -class URL404ProbeTaskOperator(URLTaskOperatorBase): - - def __init__( - self, - url_request_interface: URLRequestInterface, - adb_client: AsyncDatabaseClient, - ): - super().__init__(adb_client) - self.url_request_interface = url_request_interface - - @property - def task_type(self): - return TaskType.PROBE_404 - - async def meets_task_prerequisites(self): - return await self.adb_client.has_pending_urls_not_recently_probed_for_404() - - async def probe_urls_for_404(self, tdos: list[URL404ProbeTDO]): - 
responses = await self.url_request_interface.make_simple_requests( - urls=[tdo.url for tdo in tdos] - ) - for tdo, response in zip(tdos, responses): - if response.status is None: - continue - tdo.is_404 = response.status == HTTPStatus.NOT_FOUND - - - async def inner_task_logic(self): - tdos = await self.get_pending_urls_not_recently_probed_for_404() - url_ids = [task_info.url_id for task_info in tdos] - await self.link_urls_to_task(url_ids=url_ids) - await self.probe_urls_for_404(tdos) - url_ids_404 = [tdo.url_id for tdo in tdos if tdo.is_404] - - await self.update_404s_in_database(url_ids_404) - await self.mark_as_recently_probed_for_404(url_ids) - - async def get_pending_urls_not_recently_probed_for_404(self) -> list[URL404ProbeTDO]: - return await self.adb_client.get_pending_urls_not_recently_probed_for_404() - - async def update_404s_in_database(self, url_ids_404: list[int]): - await self.adb_client.mark_all_as_404(url_ids_404) - - async def mark_as_recently_probed_for_404(self, url_ids: list[int]): - await self.adb_client.mark_all_as_recently_probed_for_404(url_ids) - diff --git a/src/core/tasks/url/operators/url_404_probe/tdo.py b/src/core/tasks/url/operators/url_404_probe/tdo.py deleted file mode 100644 index f24cd7b3..00000000 --- a/src/core/tasks/url/operators/url_404_probe/tdo.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class URL404ProbeTDO(BaseModel): - url_id: int - url: str - is_404: Optional[bool] = None \ No newline at end of file diff --git a/src/core/tasks/url/operators/url_duplicate/core.py b/src/core/tasks/url/operators/url_duplicate/core.py deleted file mode 100644 index ed3d00a5..00000000 --- a/src/core/tasks/url/operators/url_duplicate/core.py +++ /dev/null @@ -1,47 +0,0 @@ -from http import HTTPStatus - -from aiohttp import ClientResponseError - -from src.db.client.async_ import AsyncDatabaseClient -from src.db.enums import TaskType -from src.core.tasks.url.operators.url_duplicate.tdo import URLDuplicateTDO -from src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.external.pdap.client import PDAPClient - - -class URLDuplicateTaskOperator(URLTaskOperatorBase): - - def __init__( - self, - adb_client: AsyncDatabaseClient, - pdap_client: PDAPClient - ): - super().__init__(adb_client) - self.pdap_client = pdap_client - - @property - def task_type(self): - return TaskType.DUPLICATE_DETECTION - - async def meets_task_prerequisites(self): - return await self.adb_client.has_pending_urls_not_checked_for_duplicates() - - async def inner_task_logic(self): - tdos: list[URLDuplicateTDO] = await self.adb_client.get_pending_urls_not_checked_for_duplicates() - url_ids = [tdo.url_id for tdo in tdos] - await self.link_urls_to_task(url_ids=url_ids) - checked_tdos = [] - for tdo in tdos: - try: - tdo.is_duplicate = await self.pdap_client.is_url_duplicate(tdo.url) - checked_tdos.append(tdo) - except ClientResponseError as e: - print("ClientResponseError:", e.status) - if e.status == HTTPStatus.TOO_MANY_REQUESTS: - break - raise e - - duplicate_url_ids = [tdo.url_id for tdo in checked_tdos if tdo.is_duplicate] - checked_url_ids = [tdo.url_id for tdo in checked_tdos] - await self.adb_client.mark_all_as_duplicates(duplicate_url_ids) - await self.adb_client.mark_as_checked_for_duplicates(checked_url_ids) diff --git a/src/core/tasks/url/operators/url_duplicate/tdo.py b/src/core/tasks/url/operators/url_duplicate/tdo.py deleted file mode 100644 index af00ce38..00000000 --- 
a/src/core/tasks/url/operators/url_duplicate/tdo.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class URLDuplicateTDO(BaseModel): - url_id: int - url: str - is_duplicate: Optional[bool] = None diff --git a/src/core/tasks/url/operators/url_html/core.py b/src/core/tasks/url/operators/url_html/core.py deleted file mode 100644 index 495845a4..00000000 --- a/src/core/tasks/url/operators/url_html/core.py +++ /dev/null @@ -1,149 +0,0 @@ -from http import HTTPStatus - -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo -from src.db.dtos.url.core import URLInfo -from src.db.dtos.url.raw_html import RawHTMLInfo -from src.db.enums import TaskType -from src.core.tasks.url.operators.url_html.tdo import UrlHtmlTDO -from src.core.tasks.url.operators.url_html.content_info_getter import HTMLContentInfoGetter -from src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.core.tasks.url.operators.url_html.scraper.parser.core import HTMLResponseParser -from src.core.tasks.url.operators.url_html.scraper.request_interface.core import URLRequestInterface - - -class URLHTMLTaskOperator(URLTaskOperatorBase): - - def __init__( - self, - url_request_interface: URLRequestInterface, - adb_client: AsyncDatabaseClient, - html_parser: HTMLResponseParser - ): - super().__init__(adb_client) - self.url_request_interface = url_request_interface - self.html_parser = html_parser - - @property - def task_type(self): - return TaskType.HTML - - async def meets_task_prerequisites(self): - return await self.adb_client.has_pending_urls_without_html_data() - - async def inner_task_logic(self): - tdos = await self.get_pending_urls_without_html_data() - url_ids = [task_info.url_info.id for task_info in tdos] - await self.link_urls_to_task(url_ids=url_ids) - await self.get_raw_html_data_for_urls(tdos) - success_subset, error_subset = await self.separate_success_and_error_subsets(tdos) - non_404_error_subset, is_404_error_subset = await self.separate_error_and_404_subsets(error_subset) - await self.process_html_data(success_subset) - await self.update_database(is_404_error_subset, non_404_error_subset, success_subset) - - async def update_database( - self, - is_404_error_subset: list[UrlHtmlTDO], - non_404_error_subset: list[UrlHtmlTDO], - success_subset: list[UrlHtmlTDO] - ): - await self.update_errors_in_database(non_404_error_subset) - await self.update_404s_in_database(is_404_error_subset) - await self.update_html_data_in_database(success_subset) - - async def get_just_urls(self, tdos: list[UrlHtmlTDO]): - return [task_info.url_info.url for task_info in tdos] - - async def get_pending_urls_without_html_data(self): - pending_urls: list[URLInfo] = await self.adb_client.get_pending_urls_without_html_data() - tdos = [ - UrlHtmlTDO( - url_info=url_info, - ) for url_info in pending_urls - ] - return tdos - - async def get_raw_html_data_for_urls(self, tdos: list[UrlHtmlTDO]): - just_urls = await self.get_just_urls(tdos) - url_response_infos = await self.url_request_interface.make_requests_with_html(just_urls) - for tdto, url_response_info in zip(tdos, url_response_infos): - tdto.url_response_info = url_response_info - - async def separate_success_and_error_subsets( - self, - tdos: list[UrlHtmlTDO] - ) -> tuple[ - list[UrlHtmlTDO], # Successful - list[UrlHtmlTDO] # Error - ]: - errored_tdos = [] - successful_tdos = [] - for tdto in tdos: - if not tdto.url_response_info.success: - errored_tdos.append(tdto) - else: - 
successful_tdos.append(tdto) - return successful_tdos, errored_tdos - - async def separate_error_and_404_subsets( - self, - tdos: list[UrlHtmlTDO] - ) -> tuple[ - list[UrlHtmlTDO], # Error - list[UrlHtmlTDO] # 404 - ]: - tdos_error = [] - tdos_404 = [] - for tdo in tdos: - if tdo.url_response_info.status is None: - tdos_error.append(tdo) - continue - if tdo.url_response_info.status == HTTPStatus.NOT_FOUND: - tdos_404.append(tdo) - else: - tdos_error.append(tdo) - return tdos_error, tdos_404 - - async def update_404s_in_database(self, tdos_404: list[UrlHtmlTDO]): - url_ids = [tdo.url_info.id for tdo in tdos_404] - await self.adb_client.mark_all_as_404(url_ids) - - async def update_errors_in_database(self, error_tdos: list[UrlHtmlTDO]): - error_infos = [] - for error_tdo in error_tdos: - error_info = URLErrorPydanticInfo( - task_id=self.task_id, - url_id=error_tdo.url_info.id, - error=str(error_tdo.url_response_info.exception), - ) - error_infos.append(error_info) - await self.adb_client.add_url_error_infos(error_infos) - - async def process_html_data(self, tdos: list[UrlHtmlTDO]): - for tdto in tdos: - - html_tag_info = await self.html_parser.parse( - url=tdto.url_info.url, - html_content=tdto.url_response_info.html, - content_type=tdto.url_response_info.content_type - ) - tdto.html_tag_info = html_tag_info - - async def update_html_data_in_database(self, tdos: list[UrlHtmlTDO]): - html_content_infos = [] - raw_html_data = [] - for tdto in tdos: - hcig = HTMLContentInfoGetter( - response_html_info=tdto.html_tag_info, - url_id=tdto.url_info.id - ) - rhi = RawHTMLInfo( - url_id=tdto.url_info.id, - html=tdto.url_response_info.html - ) - raw_html_data.append(rhi) - results = hcig.get_all_html_content() - html_content_infos.extend(results) - - await self.adb_client.add_html_content_infos(html_content_infos) - await self.adb_client.add_raw_html(raw_html_data) diff --git a/src/core/tasks/url/operators/url_html/queries/get_pending_urls_without_html_data.py b/src/core/tasks/url/operators/url_html/queries/get_pending_urls_without_html_data.py deleted file mode 100644 index 6af92abe..00000000 --- a/src/core/tasks/url/operators/url_html/queries/get_pending_urls_without_html_data.py +++ /dev/null @@ -1,32 +0,0 @@ -from sqlalchemy.ext.asyncio import AsyncSession - -from src.db.dto_converter import DTOConverter -from src.db.dtos.url.core import URLInfo -from src.db.models.instantiations.url.core import URL -from src.db.queries.base.builder import QueryBuilderBase -from src.db.statement_composer import StatementComposer - - -class GetPendingURLsWithoutHTMLDataQueryBuilder(QueryBuilderBase): - - async def run(self, session: AsyncSession) -> list[URLInfo]: - statement = StatementComposer.pending_urls_without_html_data() - statement = statement.limit(100).order_by(URL.id) - scalar_result = await session.scalars(statement) - url_results: list[URL] = scalar_result.all() - - final_results = [] - for url in url_results: - url_info = URLInfo( - id=url.id, - batch_id=url.batch.id if url.batch is not None else None, - url=url.url, - collector_metadata=url.collector_metadata, - outcome=url.outcome, - created_at=url.created_at, - updated_at=url.updated_at, - name=url.name - ) - final_results.append(url_info) - - return final_results diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/core.py b/src/core/tasks/url/operators/url_html/scraper/parser/core.py deleted file mode 100644 index 737f03dd..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/parser/core.py +++ /dev/null @@ -1,120 +0,0 @@ 
-import json -from typing import Optional - -from bs4 import BeautifulSoup - -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo -from src.core.tasks.url.operators.url_html.scraper.parser.enums import ParserTypeEnum -from src.core.tasks.url.operators.url_html.scraper.parser.constants import HEADER_TAGS -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.core import RootURLCache -from src.core.tasks.url.operators.url_html.scraper.parser.util import remove_excess_whitespace, add_https, remove_trailing_backslash, \ - drop_hostname - - -class HTMLResponseParser: - - def __init__(self, root_url_cache: RootURLCache): - self.root_url_cache = root_url_cache - - async def parse(self, url: str, html_content: str, content_type: str) -> ResponseHTMLInfo: - html_info = ResponseHTMLInfo() - self.add_url_and_path(html_info, html_content=html_content, url=url) - await self.add_root_page_titles(html_info) - parser_type = self.get_parser_type(content_type) - if parser_type is None: - return html_info - self.add_html_from_beautiful_soup( - html_info=html_info, - parser_type=parser_type, - html_content=html_content - ) - return html_info - - def add_html_from_beautiful_soup( - self, - html_info: ResponseHTMLInfo, - parser_type: ParserTypeEnum, - html_content: str - ): - soup = BeautifulSoup( - markup=html_content, - features=parser_type.value, - ) - html_info.title = self.get_html_title(soup) - html_info.description = self.get_meta_description(soup) - self.add_header_tags(html_info, soup) - html_info.div = self.get_div_text(soup) - # Prevents most bs4 memory leaks - if soup.html is not None: - soup.html.decompose() - - def get_div_text(self, soup): - div_text = "" - MAX_WORDS = 500 - for div in soup.find_all("div"): - text = div.get_text(" ", strip=True) - if text is None: - continue - # Check if adding the current text exceeds the word limit - if len(div_text.split()) + len(text.split()) <= MAX_WORDS: - div_text += text + " " - else: - break # Stop adding text if word limit is reached - - # Truncate to 5000 characters in case of run-on 'words' - div_text = div_text[: MAX_WORDS * 10] - - return div_text - - def get_meta_description(self, soup: BeautifulSoup) -> str: - meta_tag = soup.find("meta", attrs={"name": "description"}) - if meta_tag is None: - return "" - try: - return remove_excess_whitespace(meta_tag["content"]) - except KeyError: - return "" - - def add_header_tags(self, html_info: ResponseHTMLInfo, soup: BeautifulSoup): - for header_tag in HEADER_TAGS: - headers = soup.find_all(header_tag) - # Retrieves and drops headers containing links to reduce training bias - header_content = [header.get_text(" ", strip=True) for header in headers if not header.a] - tag_content = json.dumps(header_content, ensure_ascii=False) - if tag_content == "[]": - continue - setattr(html_info, header_tag, tag_content) - - def get_html_title(self, soup: BeautifulSoup) -> Optional[str]: - if soup.title is None: - return None - if soup.title.string is None: - return None - return remove_excess_whitespace(soup.title.string) - - - def add_url_and_path(self, html_info: ResponseHTMLInfo, html_content: str, url: str): - url = add_https(url) - html_info.url = url - - url_path = drop_hostname(url) - url_path = remove_trailing_backslash(url_path) - html_info.url_path = url_path - - async def add_root_page_titles(self, html_info: ResponseHTMLInfo): - root_page_title = await self.root_url_cache.get_title(html_info.url) - html_info.root_page_title = 
remove_excess_whitespace( - root_page_title - ) - - def get_parser_type(self, content_type: str) -> ParserTypeEnum or None: - try: - # If content type does not contain "html" or "xml" then we can assume that the content is unreadable - if "html" in content_type: - return ParserTypeEnum.LXML - if "xml" in content_type: - return ParserTypeEnum.LXML_XML - return None - except KeyError: - return None - diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/mapping.py b/src/core/tasks/url/operators/url_html/scraper/parser/mapping.py deleted file mode 100644 index 6b5f0b83..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/parser/mapping.py +++ /dev/null @@ -1,13 +0,0 @@ -from src.db.dtos.url.html_content import HTMLContentType - -ENUM_TO_ATTRIBUTE_MAPPING = { - HTMLContentType.TITLE: "title", - HTMLContentType.DESCRIPTION: "description", - HTMLContentType.H1: "h1", - HTMLContentType.H2: "h2", - HTMLContentType.H3: "h3", - HTMLContentType.H4: "h4", - HTMLContentType.H5: "h5", - HTMLContentType.H6: "h6", - HTMLContentType.DIV: "div" -} diff --git a/src/core/tasks/url/operators/url_html/scraper/parser/util.py b/src/core/tasks/url/operators/url_html/scraper/parser/util.py deleted file mode 100644 index 09453984..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/parser/util.py +++ /dev/null @@ -1,43 +0,0 @@ -from urllib.parse import urlparse - -from src.db.dtos.url.html_content import URLHTMLContentInfo -from src.core.tasks.url.operators.url_html.scraper.parser.mapping import ENUM_TO_ATTRIBUTE_MAPPING -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo - - -def convert_to_response_html_info(html_content_infos: list[URLHTMLContentInfo]): - response_html_info = ResponseHTMLInfo() - - for html_content_info in html_content_infos: - setattr(response_html_info, ENUM_TO_ATTRIBUTE_MAPPING[html_content_info.content_type], html_content_info.content) - - return response_html_info - - -def remove_excess_whitespace(s: str) -> str: - """Removes leading, trailing, and excess adjacent whitespace. - - Args: - s (str): String to remove whitespace from. - - Returns: - str: Clean string with excess whitespace stripped. 
- """ - return " ".join(s.split()).strip() - - -def add_https(url: str) -> str: - if not url.startswith("http"): - url = "https://" + url - return url - - -def remove_trailing_backslash(url_path): - if url_path and url_path[-1] == "/": - url_path = url_path[:-1] - return url_path - - -def drop_hostname(new_url): - url_path = urlparse(new_url).path[1:] - return url_path diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/constants.py b/src/core/tasks/url/operators/url_html/scraper/request_interface/constants.py deleted file mode 100644 index dc832aff..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/request_interface/constants.py +++ /dev/null @@ -1,2 +0,0 @@ -HTML_CONTENT_TYPE = "text/html" -MAX_CONCURRENCY = 5 diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/core.py b/src/core/tasks/url/operators/url_html/scraper/request_interface/core.py deleted file mode 100644 index f45780cb..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/request_interface/core.py +++ /dev/null @@ -1,80 +0,0 @@ -from http import HTTPStatus -from typing import Optional - -from aiohttp import ClientSession, ClientResponseError -from playwright.async_api import async_playwright -from tqdm.asyncio import tqdm - -from src.core.tasks.url.operators.url_html.scraper.request_interface.constants import HTML_CONTENT_TYPE -from src.core.tasks.url.operators.url_html.scraper.request_interface.dtos.request_resources import RequestResources -from src.core.tasks.url.operators.url_html.scraper.request_interface.dtos.url_response import URLResponseInfo - - -class URLRequestInterface: - - async def get_response(self, session: ClientSession, url: str) -> URLResponseInfo: - try: - async with session.get(url, timeout=20) as response: - response.raise_for_status() - text = await response.text() - return URLResponseInfo( - success=True, - html=text, - content_type=response.headers.get("content-type"), - status=HTTPStatus(response.status) - ) - except ClientResponseError as e: - return URLResponseInfo(success=False, status=HTTPStatus(e.status), exception=str(e)) - except Exception as e: - print(f"An error occurred while fetching {url}: {e}") - return URLResponseInfo(success=False, exception=str(e)) - - async def fetch_and_render(self, rr: RequestResources, url: str) -> Optional[URLResponseInfo]: - simple_response = await self.get_response(rr.session, url) - if not simple_response.success: - return simple_response - - if simple_response.content_type != HTML_CONTENT_TYPE: - return simple_response - - return await self.get_dynamic_html_content(rr, url) - - async def get_dynamic_html_content(self, rr, url): - # For HTML responses, attempt to load the page to check for dynamic html content - async with rr.semaphore: - page = await rr.browser.new_page() - try: - await page.goto(url) - await page.wait_for_load_state("networkidle") - html_content = await page.content() - return URLResponseInfo( - success=True, - html=html_content, - content_type=HTML_CONTENT_TYPE, - status=HTTPStatus.OK - ) - except Exception as e: - return URLResponseInfo(success=False, exception=str(e)) - finally: - await page.close() - - async def fetch_urls(self, urls: list[str]) -> list[URLResponseInfo]: - async with ClientSession() as session: - async with async_playwright() as playwright: - browser = await playwright.chromium.launch(headless=True) - request_resources = RequestResources(session=session, browser=browser) - tasks = [self.fetch_and_render(request_resources, url) for url in urls] - results 
= await tqdm.gather(*tasks) - return results - - async def make_requests_with_html( - self, - urls: list[str], - ) -> list[URLResponseInfo]: - return await self.fetch_urls(urls) - - async def make_simple_requests(self, urls: list[str]) -> list[URLResponseInfo]: - async with ClientSession() as session: - tasks = [self.get_response(session, url) for url in urls] - results = await tqdm.gather(*tasks) - return results diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/request_resources.py b/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/request_resources.py deleted file mode 100644 index 62ad714a..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/request_resources.py +++ /dev/null @@ -1,14 +0,0 @@ -import asyncio -from dataclasses import dataclass - -from aiohttp import ClientSession -from playwright.async_api import async_playwright - -from src.core.tasks.url.operators.url_html.scraper.request_interface.constants import MAX_CONCURRENCY - - -@dataclass -class RequestResources: - session: ClientSession - browser: async_playwright - semaphore: asyncio.Semaphore = asyncio.Semaphore(MAX_CONCURRENCY) diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/url_response.py b/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/url_response.py deleted file mode 100644 index 8e17c078..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/request_interface/dtos/url_response.py +++ /dev/null @@ -1,12 +0,0 @@ -from http import HTTPStatus -from typing import Optional - -from pydantic import BaseModel - - -class URLResponseInfo(BaseModel): - success: bool - status: Optional[HTTPStatus] = None - html: Optional[str] = None - content_type: Optional[str] = None - exception: Optional[str] = None diff --git a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/constants.py b/src/core/tasks/url/operators/url_html/scraper/root_url_cache/constants.py deleted file mode 100644 index 52d392e0..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/constants.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Some websites refuse the connection of automated requests, -setting the User-Agent will circumvent that. 
-""" -USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/118.0.0.0 Safari/537.36" -REQUEST_HEADERS = { - "User-Agent": USER_AGENT, - # Make sure there's no pre-mature closing of responses before a redirect completes - "accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", - } diff --git a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/core.py b/src/core/tasks/url/operators/url_html/scraper/root_url_cache/core.py deleted file mode 100644 index c30bc16e..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/core.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import Optional -from urllib.parse import urlparse - -from aiohttp import ClientSession -from bs4 import BeautifulSoup - -from src.db.client.async_ import AsyncDatabaseClient -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.constants import REQUEST_HEADERS -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.dtos.response import RootURLCacheResponseInfo - -DEBUG = False - - -class RootURLCache: - def __init__(self, adb_client: Optional[AsyncDatabaseClient] = None): - if adb_client is None: - adb_client = AsyncDatabaseClient() - self.adb_client = adb_client - self.cache = None - - async def save_to_cache(self, url: str, title: str): - if url in self.cache: - return - self.cache[url] = title - await self.adb_client.add_to_root_url_cache(url=url, page_title=title) - - async def get_from_cache(self, url: str) -> Optional[str]: - if self.cache is None: - self.cache = await self.adb_client.load_root_url_cache() - - if url in self.cache: - return self.cache[url] - return None - - async def get_request(self, url: str) -> RootURLCacheResponseInfo: - async with ClientSession() as session: - try: - async with session.get(url, headers=REQUEST_HEADERS, timeout=120) as response: - response.raise_for_status() - text = await response.text() - return RootURLCacheResponseInfo(text=text) - except Exception as e: - return RootURLCacheResponseInfo(exception=e) - - async def get_title(self, url) -> str: - if not url.startswith('http'): - url = "https://" + url - - parsed_url = urlparse(url) - root_url = f"{parsed_url.scheme}://{parsed_url.netloc}" - - title = await self.get_from_cache(root_url) - if title is not None: - return title - - response_info = await self.get_request(root_url) - if response_info.exception is not None: - return self.handle_exception(response_info.exception) - - title = await self.get_title_from_soup(response_info.text) - - await self.save_to_cache(url=root_url, title=title) - - return title - - async def get_title_from_soup(self, text: str) -> str: - soup = BeautifulSoup(text, 'html.parser') - try: - title = soup.find('title').text - except AttributeError: - title = "" - # Prevents most bs4 memory leaks - if soup.html: - soup.html.decompose() - return title - - def handle_exception(self, e): - if DEBUG: - return f"Error retrieving title: {e}" - else: - return "" diff --git a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/dtos/response.py b/src/core/tasks/url/operators/url_html/scraper/root_url_cache/dtos/response.py deleted file mode 100644 index 6ea1d21c..00000000 --- a/src/core/tasks/url/operators/url_html/scraper/root_url_cache/dtos/response.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class RootURLCacheResponseInfo(BaseModel): - class Config: - 
arbitrary_types_allowed = True - - text: Optional[str] = None - exception: Optional[Exception] = None diff --git a/src/core/tasks/url/operators/url_html/tdo.py b/src/core/tasks/url/operators/url_html/tdo.py deleted file mode 100644 index 7fe14078..00000000 --- a/src/core/tasks/url/operators/url_html/tdo.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo -from src.db.dtos.url.core import URLInfo -from src.core.tasks.url.operators.url_html.scraper.request_interface.dtos.url_response import URLResponseInfo - - -class UrlHtmlTDO(BaseModel): - url_info: URLInfo - url_response_info: Optional[URLResponseInfo] = None - html_tag_info: Optional[ResponseHTMLInfo] = None - diff --git a/src/core/tasks/url/operators/url_miscellaneous_metadata/core.py b/src/core/tasks/url/operators/url_miscellaneous_metadata/core.py deleted file mode 100644 index 988fbe8b..00000000 --- a/src/core/tasks/url/operators/url_miscellaneous_metadata/core.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Optional - -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo -from src.db.enums import TaskType -from src.collectors.enums import CollectorType -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO -from src.core.tasks.url.operators.base import URLTaskOperatorBase -from src.core.tasks.url.subtasks.miscellaneous_metadata.auto_googler import AutoGooglerMiscMetadataSubtask -from src.core.tasks.url.subtasks.miscellaneous_metadata.ckan import CKANMiscMetadataSubtask -from src.core.tasks.url.subtasks.miscellaneous_metadata.base import \ - MiscellaneousMetadataSubtaskBase -from src.core.tasks.url.subtasks.miscellaneous_metadata.muckrock import MuckrockMiscMetadataSubtask - - -class URLMiscellaneousMetadataTaskOperator(URLTaskOperatorBase): - - def __init__( - self, - adb_client: AsyncDatabaseClient - ): - super().__init__(adb_client) - - @property - def task_type(self): - return TaskType.MISC_METADATA - - async def meets_task_prerequisites(self): - return await self.adb_client.has_pending_urls_missing_miscellaneous_metadata() - - async def get_subtask( - self, - collector_type: CollectorType - ) -> Optional[MiscellaneousMetadataSubtaskBase]: - match collector_type: - case CollectorType.MUCKROCK_SIMPLE_SEARCH: - return MuckrockMiscMetadataSubtask() - case CollectorType.MUCKROCK_COUNTY_SEARCH: - return MuckrockMiscMetadataSubtask() - case CollectorType.MUCKROCK_ALL_SEARCH: - return MuckrockMiscMetadataSubtask() - case CollectorType.AUTO_GOOGLER: - return AutoGooglerMiscMetadataSubtask() - case CollectorType.CKAN: - return CKANMiscMetadataSubtask() - case _: - return None - - async def html_default_logic(self, tdo: URLMiscellaneousMetadataTDO): - if tdo.name is None: - tdo.name = tdo.html_metadata_info.title - if tdo.description is None: - tdo.description = tdo.html_metadata_info.description - - async def inner_task_logic(self): - tdos: list[URLMiscellaneousMetadataTDO] = await self.adb_client.get_pending_urls_missing_miscellaneous_metadata() - await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) - - error_infos = [] - for tdo in tdos: - subtask = await self.get_subtask(tdo.collector_type) - try: - if subtask is not None: - subtask.process(tdo) - await self.html_default_logic(tdo) - except Exception as e: - error_info = URLErrorPydanticInfo( - task_id=self.task_id, - 
url_id=tdo.url_id, - error=str(e), - ) - error_infos.append(error_info) - - await self.adb_client.add_miscellaneous_metadata(tdos) - await self.adb_client.add_url_error_infos(error_infos) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/__init__.py b/src/core/tasks/url/operators/validate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/core.py b/src/core/tasks/url/operators/validate/core.py new file mode 100644 index 00000000..9d8aa5af --- /dev/null +++ b/src/core/tasks/url/operators/validate/core.py @@ -0,0 +1,30 @@ +from src.core.tasks.url.operators.base import URLTaskOperatorBase +from src.core.tasks.url.operators.validate.queries.get.core import GetURLsForAutoValidationQueryBuilder +from src.core.tasks.url.operators.validate.queries.get.models.response import GetURLsForAutoValidationResponse +from src.core.tasks.url.operators.validate.queries.insert import InsertURLAutoValidationsQueryBuilder +from src.core.tasks.url.operators.validate.queries.prereq.core import AutoValidatePrerequisitesQueryBuilder +from src.db.enums import TaskType + + +class AutoValidateURLTaskOperator(URLTaskOperatorBase): + + @property + def task_type(self) -> TaskType: + return TaskType.AUTO_VALIDATE + + async def meets_task_prerequisites(self) -> bool: + return await self.adb_client.run_query_builder( + AutoValidatePrerequisitesQueryBuilder() + ) + + async def inner_task_logic(self) -> None: + # Get URLs for auto validation + responses: list[GetURLsForAutoValidationResponse] = await self.adb_client.run_query_builder( + GetURLsForAutoValidationQueryBuilder() + ) + url_ids: list[int] = [response.url_id for response in responses] + await self.link_urls_to_task(url_ids) + + await self.adb_client.run_query_builder( + InsertURLAutoValidationsQueryBuilder(responses) + ) diff --git a/src/core/tasks/url/operators/validate/queries/__init__.py b/src/core/tasks/url/operators/validate/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/__init__.py b/src/core/tasks/url/operators/validate/queries/ctes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/__init__.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/base.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/base.py new file mode 100644 index 00000000..7a85df9c --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/base.py @@ -0,0 +1,15 @@ +from abc import ABC, abstractmethod + +from sqlalchemy import Column, CTE + + +class ValidationCTEContainer: + _query: CTE + + @property + def url_id(self) -> Column[int]: + return self._query.c.url_id + + @property + def query(self) -> CTE: + return self._query \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/helper.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/helper.py new file mode 100644 index 00000000..6078e5bb --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/helper.py @@ -0,0 +1,17 @@ +from sqlalchemy import CTE, select + +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +def build_validation_query( + scored_cte: ScoredCTEContainer, + label: str +) -> 
CTE: + return select( + scored_cte.url_id, + scored_cte.entity.label(label) + ).where( + scored_cte.max_votes >= 2, + scored_cte.votes == scored_cte.max_votes, + scored_cte.num_labels_with_that_vote == 1 + ).cte(f"{label}_validation") diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/__init__.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/agency.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/agency.py new file mode 100644 index 00000000..b5b5ee63 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/agency.py @@ -0,0 +1,24 @@ +from sqlalchemy import select, Column + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.base import ValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.helper import build_validation_query +from src.core.tasks.url.operators.validate.queries.ctes.counts.impl.agency import AGENCY_VALIDATION_COUNTS_CTE +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +class AgencyValidationCTEContainer(ValidationCTEContainer): + + def __init__(self): + _scored = ScoredCTEContainer( + AGENCY_VALIDATION_COUNTS_CTE + ) + + self._query = build_validation_query( + _scored, + "agency_id" + ) + + + @property + def agency_id(self) -> Column[int]: + return self._query.c.agency_id \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/location.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/location.py new file mode 100644 index 00000000..29951968 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/location.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.base import ValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.helper import build_validation_query +from src.core.tasks.url.operators.validate.queries.ctes.counts.impl.location import LOCATION_VALIDATION_COUNTS_CTE +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +class LocationValidationCTEContainer(ValidationCTEContainer): + + def __init__(self): + _scored = ScoredCTEContainer( + LOCATION_VALIDATION_COUNTS_CTE + ) + + self._query = build_validation_query( + _scored, + "location_id" + ) + + @property + def location_id(self) -> Column[int]: + return self._query.c.location_id \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/name.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/name.py new file mode 100644 index 00000000..b51f77b5 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/name.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.base import ValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.helper import build_validation_query +from src.core.tasks.url.operators.validate.queries.ctes.counts.impl.name import NAME_VALIDATION_COUNTS_CTE +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +class NameValidationCTEContainer(ValidationCTEContainer): + + def __init__(self): + _scored = 
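build_validation_query above encodes the consensus rule shared by all five validation CTEs: an entity auto-validates only when its vote count equals the per-URL maximum, that maximum is at least 2, and exactly one entity holds it. The same rule in plain Python (illustrative only):

```python
from collections import Counter
from typing import Hashable, Optional


def consensus(votes: list[Hashable]) -> Optional[Hashable]:
    """Return the winning entity, or None when there is no consensus.

    Mirrors the CTE conditions: max_votes >= 2, votes == max_votes,
    and num_labels_with_that_vote == 1 (a unique maximum)."""
    counts = Counter(votes)
    if not counts:
        return None
    top = max(counts.values())
    if top < 2:
        return None
    winners = [entity for entity, n in counts.items() if n == top]
    return winners[0] if len(winners) == 1 else None


assert consensus(["A", "A", "B"]) == "A"        # unique maximum of 2+
assert consensus(["A", "B"]) is None            # maximum below threshold
assert consensus(["A", "A", "B", "B"]) is None  # tie at the maximum
```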
ScoredCTEContainer( + NAME_VALIDATION_COUNTS_CTE + ) + + self._query = build_validation_query( + _scored, + "name" + ) + + @property + def name(self) -> Column[str]: + return self._query.c.name \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/record_type.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/record_type.py new file mode 100644 index 00000000..befb0c7e --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/record_type.py @@ -0,0 +1,24 @@ +from sqlalchemy import select, Column + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.base import ValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.helper import build_validation_query +from src.core.tasks.url.operators.validate.queries.ctes.counts.impl.record_type import RECORD_TYPE_COUNTS_CTE +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +class RecordTypeValidationCTEContainer(ValidationCTEContainer): + + def __init__(self): + + _scored = ScoredCTEContainer( + RECORD_TYPE_COUNTS_CTE + ) + + self._query = build_validation_query( + _scored, + "record_type" + ) + + @property + def record_type(self) -> Column[str]: + return self._query.c.record_type \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/url_type.py b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/url_type.py new file mode 100644 index 00000000..4d4ec750 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/consensus/impl/url_type.py @@ -0,0 +1,23 @@ +from sqlalchemy import select, Column + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.base import ValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.helper import build_validation_query +from src.core.tasks.url.operators.validate.queries.ctes.counts.impl.url_type import URL_TYPES_VALIDATION_COUNTS_CTE +from src.core.tasks.url.operators.validate.queries.ctes.scored import ScoredCTEContainer + + +class URLTypeValidationCTEContainer(ValidationCTEContainer): + + def __init__(self): + _scored = ScoredCTEContainer( + URL_TYPES_VALIDATION_COUNTS_CTE + ) + + self._query = build_validation_query( + _scored, + "url_type" + ) + + @property + def url_type(self) -> Column[str]: + return self._query.c.url_type \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/__init__.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/core.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/core.py new file mode 100644 index 00000000..af7e97b4 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/core.py @@ -0,0 +1,23 @@ +from sqlalchemy import CTE, Column + + +class ValidatedCountsCTEContainer: + + def __init__(self, cte: CTE): + self._cte: CTE = cte + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self._cte.c.url_id + + @property + def entity(self) -> Column: + return self._cte.c.entity + + @property + def votes(self) -> Column[int]: + return self._cte.c.votes \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/__init__.py
b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/agency.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/agency.py new file mode 100644 index 00000000..e9df9db4 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/agency.py @@ -0,0 +1,24 @@ +from sqlalchemy import select, func + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL + +AGENCY_VALIDATION_COUNTS_CTE = ValidatedCountsCTEContainer( + ( + select( + UserUrlAgencySuggestion.url_id, + UserUrlAgencySuggestion.agency_id.label("entity"), + func.count().label("votes") + ) + .join( + UnvalidatedURL, + UserUrlAgencySuggestion.url_id == UnvalidatedURL.url_id + ) + .group_by( + UserUrlAgencySuggestion.url_id, + UserUrlAgencySuggestion.agency_id + ) + .cte("counts_agency") + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/location.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/location.py new file mode 100644 index 00000000..2ef385cc --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/location.py @@ -0,0 +1,24 @@ +from sqlalchemy import select, func + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer +from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL + +LOCATION_VALIDATION_COUNTS_CTE = ValidatedCountsCTEContainer( + ( + select( + UserLocationSuggestion.url_id, + UserLocationSuggestion.location_id.label("entity"), + func.count().label("votes") + ) + .join( + UnvalidatedURL, + UserLocationSuggestion.url_id == UnvalidatedURL.url_id + ) + .group_by( + UserLocationSuggestion.url_id, + UserLocationSuggestion.location_id + ) + .cte("counts_location") + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/name.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/name.py new file mode 100644 index 00000000..5cb014f1 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/name.py @@ -0,0 +1,28 @@ +from sqlalchemy import select, func + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer +from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion +from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL + +NAME_VALIDATION_COUNTS_CTE = ValidatedCountsCTEContainer( + ( + select( + URLNameSuggestion.url_id, + URLNameSuggestion.suggestion.label("entity"), + func.count().label("votes") + ) + .join( + UnvalidatedURL, + URLNameSuggestion.url_id == UnvalidatedURL.url_id + ) + .join( + LinkUserNameSuggestion, + LinkUserNameSuggestion.suggestion_id == URLNameSuggestion.id + ) + .group_by( + URLNameSuggestion.url_id, + URLNameSuggestion.suggestion + ) + ).cte("counts_name") +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/record_type.py 
b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/record_type.py new file mode 100644 index 00000000..6300ec92 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/record_type.py @@ -0,0 +1,24 @@ +from sqlalchemy import select, func + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL + +RECORD_TYPE_COUNTS_CTE = ValidatedCountsCTEContainer( + ( + select( + UserRecordTypeSuggestion.url_id, + UserRecordTypeSuggestion.record_type.label("entity"), + func.count().label("votes") + ) + .join( + UnvalidatedURL, + UserRecordTypeSuggestion.url_id == UnvalidatedURL.url_id + ) + .group_by( + UserRecordTypeSuggestion.url_id, + UserRecordTypeSuggestion.record_type + ) + .cte("counts_record_type") + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/url_type.py b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/url_type.py new file mode 100644 index 00000000..0e3de946 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/counts/impl/url_type.py @@ -0,0 +1,25 @@ +from sqlalchemy import select, func + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion +from src.db.models.views.unvalidated_url import UnvalidatedURL + +URL_TYPES_VALIDATION_COUNTS_CTE = ValidatedCountsCTEContainer( + ( + select( + UserURLTypeSuggestion.url_id, + UserURLTypeSuggestion.type.label("entity"), + func.count().label("votes") + ) + .join( + UnvalidatedURL, + UserURLTypeSuggestion.url_id == UnvalidatedURL.url_id + ) + .group_by( + UserURLTypeSuggestion.url_id, + UserURLTypeSuggestion.type + ) + .cte("counts_url_type") + ) +) \ No newline at end of file diff --git a/src/core/tasks/url/operators/validate/queries/ctes/scored.py b/src/core/tasks/url/operators/validate/queries/ctes/scored.py new file mode 100644 index 00000000..557e38ea --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/ctes/scored.py @@ -0,0 +1,52 @@ +from sqlalchemy import CTE, select, func, Column + +from src.core.tasks.url.operators.validate.queries.ctes.counts.core import ValidatedCountsCTEContainer + + +class ScoredCTEContainer: + + def __init__( + self, + counts_cte: ValidatedCountsCTEContainer + ): + self._cte: CTE = ( + select( + counts_cte.url_id, + counts_cte.entity, + counts_cte.votes, + func.max(counts_cte.votes).over( + partition_by=counts_cte.url_id + ).label("max_votes"), + func.count().over( + partition_by=( + counts_cte.url_id, + counts_cte.votes + ) + ).label("num_labels_with_that_vote") + ) + .cte(f"scored_{counts_cte.cte.name}") + ) + + @property + def cte(self) -> CTE: + return self._cte + + @property + def url_id(self) -> Column[int]: + return self._cte.c.url_id + + @property + def entity(self) -> Column: + return self._cte.c.entity + + @property + def votes(self) -> Column[int]: + return self._cte.c.votes + + @property + def max_votes(self) -> Column[int]: + return self._cte.c.max_votes + + @property + def num_labels_with_that_vote(self) -> Column[int]: + return self._cte.c.num_labels_with_that_vote \ No newline at end of file diff --git 
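For reference, the two window functions that ScoredCTEContainer above layers on top of a counts CTE, replayed in plain Python on toy rows (data and names are illustrative):

```python
from collections import defaultdict

rows = [  # (url_id, entity, votes), as produced by a counts CTE
    (1, "A", 2), (1, "B", 2),
    (2, "A", 3), (2, "B", 1),
]

# max(votes) OVER (PARTITION BY url_id)
max_votes: dict[int, int] = defaultdict(int)
for url_id, _, votes in rows:
    max_votes[url_id] = max(max_votes[url_id], votes)

# count(*) OVER (PARTITION BY url_id, votes)
share: dict[tuple[int, int], int] = defaultdict(int)
for url_id, _, votes in rows:
    share[(url_id, votes)] += 1

scored = [
    (url_id, entity, votes, max_votes[url_id], share[(url_id, votes)])
    for url_id, entity, votes in rows
]
assert (1, "A", 2, 2, 2) in scored  # URL 1: two entities tie at the max
assert (2, "A", 3, 3, 1) in scored  # URL 2: unique max, passes the filter
```

Only rows where votes equals max_votes, max_votes is at least 2, and the count at that vote level is 1 survive build_validation_query's WHERE clause.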
a/src/core/tasks/url/operators/validate/queries/get/__init__.py b/src/core/tasks/url/operators/validate/queries/get/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/get/core.py b/src/core/tasks/url/operators/validate/queries/get/core.py new file mode 100644 index 00000000..31d21f07 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/get/core.py @@ -0,0 +1,78 @@ +from typing import Sequence + +from sqlalchemy import select, RowMapping +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.exceptions import FailedValidationException +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.agency import AgencyValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.location import LocationValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.name import NameValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.record_type import \ + RecordTypeValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.url_type import URLTypeValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.get.models.response import GetURLsForAutoValidationResponse +from src.core.tasks.url.operators.validate.queries.helper import add_where_condition +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class GetURLsForAutoValidationQueryBuilder(QueryBuilderBase): + + + async def run(self, session: AsyncSession) -> list[GetURLsForAutoValidationResponse]: + agency = AgencyValidationCTEContainer() + location = LocationValidationCTEContainer() + url_type = URLTypeValidationCTEContainer() + record_type = RecordTypeValidationCTEContainer() + name = NameValidationCTEContainer() + + query = ( + select( + URL.id.label("url_id"), + location.location_id, + agency.agency_id, + url_type.url_type, + record_type.record_type, + name.name, + ) + .outerjoin( + agency.query, + URL.id == agency.url_id, + ) + .outerjoin( + location.query, + URL.id == location.url_id, + ) + .outerjoin( + url_type.query, + URL.id == url_type.url_id, + ) + .outerjoin( + record_type.query, + URL.id == record_type.url_id, + ) + .outerjoin( + name.query, + URL.id == name.url_id, + ) + ) + query = add_where_condition( + query, + agency=agency, + location=location, + url_type=url_type, + record_type=record_type, + name=name, + ) + + mappings: Sequence[RowMapping] = await sh.mappings(session, query=query) + responses: list[GetURLsForAutoValidationResponse] = [] + for mapping in mappings: + try: + response = GetURLsForAutoValidationResponse(**mapping) + responses.append(response) + except FailedValidationException as e: + raise FailedValidationException( + f"Failed to validate URL {mapping['url_id']}") from e + return responses diff --git a/src/core/tasks/url/operators/validate/queries/get/models/__init__.py b/src/core/tasks/url/operators/validate/queries/get/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/get/models/response.py b/src/core/tasks/url/operators/validate/queries/get/models/response.py new file mode 100644 index 00000000..6913e256 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/get/models/response.py @@ -0,0 +1,68 @@ +from pydantic import BaseModel, model_validator 
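# A minimal, standalone sketch of the Pydantic v2 "after"-validator mechanism
# that the response model below relies on: the validator runs once every field
# is populated, so cross-field consistency rules can reject a whole row.
# Hypothetical names; an illustration only, not part of this patch.
from pydantic import BaseModel, ValidationError, model_validator


class ExampleResponse(BaseModel):
    url_type: str
    record_type: str | None

    @model_validator(mode="after")
    def record_type_only_for_data_sources(self):
        # Raising ValueError here surfaces as a ValidationError on construction.
        if self.url_type != "data_source" and self.record_type is not None:
            raise ValueError("record_type only applies to data sources")
        return self


ExampleResponse(url_type="meta_url", record_type=None)        # accepted
try:
    ExampleResponse(url_type="meta_url", record_type="list")  # rejected
except ValidationError:
    pass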
+ +from src.core.enums import RecordType +from src.core.exceptions import FailedValidationException +from src.db.models.impl.flag.url_validated.enums import URLType + + +class GetURLsForAutoValidationResponse(BaseModel): + url_id: int + location_id: int | None + agency_id: int | None + url_type: URLType + record_type: RecordType | None + name: str | None + + @model_validator(mode="after") + def forbid_record_type_if_not_data_source(self): + if self.url_type == URLType.DATA_SOURCE: + return self + if self.record_type is not None: + raise FailedValidationException("record_type must be None if url_type is not DATA_SOURCE") + return self + + + @model_validator(mode="after") + def require_record_type_if_data_source(self): + if self.url_type == URLType.DATA_SOURCE and self.record_type is None: + raise FailedValidationException("record_type must be provided if url_type is DATA_SOURCE") + return self + + @model_validator(mode="after") + def require_location_if_relevant(self): + if self.url_type not in [ + URLType.META_URL, + URLType.DATA_SOURCE, + URLType.INDIVIDUAL_RECORD, + ]: + return self + if self.location_id is None: + raise FailedValidationException("location_id must be provided if url_type is META_URL, DATA_SOURCE, or INDIVIDUAL_RECORD") + return self + + + @model_validator(mode="after") + def require_agency_id_if_relevant(self): + if self.url_type not in [ + URLType.META_URL, + URLType.DATA_SOURCE, + URLType.INDIVIDUAL_RECORD, + ]: + return self + if self.agency_id is None: + raise FailedValidationException("agency_id must be provided if url_type is META_URL, DATA_SOURCE, or INDIVIDUAL_RECORD") + return self + + @model_validator(mode="after") + def forbid_all_else_if_not_relevant(self): + if self.url_type != URLType.NOT_RELEVANT: + return self + if self.record_type is not None: + raise FailedValidationException("record_type must be None if url_type is NOT_RELEVANT") + if self.agency_id is not None: + raise FailedValidationException("agency_id must be None if url_type is NOT_RELEVANT") + if self.location_id is not None: + raise FailedValidationException("location_id must be None if url_type is NOT_RELEVANT") + return self + + diff --git a/src/core/tasks/url/operators/validate/queries/helper.py b/src/core/tasks/url/operators/validate/queries/helper.py new file mode 100644 index 00000000..e2632ca6 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/helper.py @@ -0,0 +1,43 @@ +from sqlalchemy import Select, or_, and_ + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.agency import AgencyValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.location import LocationValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.name import NameValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.record_type import \ + RecordTypeValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.url_type import URLTypeValidationCTEContainer +from src.db.models.impl.flag.url_validated.enums import URLType + + +def add_where_condition( + query: Select, + agency: AgencyValidationCTEContainer, + location: LocationValidationCTEContainer, + url_type: URLTypeValidationCTEContainer, + record_type: RecordTypeValidationCTEContainer, + name: NameValidationCTEContainer, +) -> Select: + return ( + query + .where( + url_type.url_type.isnot(None), + or_( + and_( + url_type.url_type == URLType.DATA_SOURCE.value,
agency.agency_id.isnot(None), + location.location_id.isnot(None), + record_type.record_type.isnot(None), + name.name.isnot(None), + ), + and_( + url_type.url_type.in_( + (URLType.META_URL.value, URLType.INDIVIDUAL_RECORD.value) + ), + agency.agency_id.isnot(None), + location.location_id.isnot(None), + name.name.isnot(None), + ), + url_type.url_type == URLType.NOT_RELEVANT.value + ), + ) + ) diff --git a/src/core/tasks/url/operators/validate/queries/insert.py b/src/core/tasks/url/operators/validate/queries/insert.py new file mode 100644 index 00000000..31bdfa74 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/insert.py @@ -0,0 +1,85 @@ +from typing import Any + +from sqlalchemy import update, case +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.validate.queries.get.models.response import GetURLsForAutoValidationResponse +from src.db.models.impl.flag.auto_validated.pydantic import FlagURLAutoValidatedPydantic +from src.db.models.impl.flag.url_validated.pydantic import FlagURLValidatedPydantic +from src.db.models.impl.link.url_agency.pydantic import LinkURLAgencyPydantic +from src.db.models.impl.url.core.pydantic.upsert import URLUpsertModel +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.record_type.pydantic import URLRecordTypePydantic +from src.db.queries.base.builder import QueryBuilderBase +from src.db.helpers.session import session_helper as sh + +class InsertURLAutoValidationsQueryBuilder(QueryBuilderBase): + + def __init__(self, responses: list[GetURLsForAutoValidationResponse]): + super().__init__() + self._responses = responses + + async def run(self, session: AsyncSession) -> Any: + url_record_types: list[URLRecordTypePydantic] = [] + link_url_agencies: list[LinkURLAgencyPydantic] = [] + url_validated_flags: list[FlagURLValidatedPydantic] = [] + url_auto_validated_flags: list[FlagURLAutoValidatedPydantic] = [] + + for response in self._responses: + if response.agency_id is not None: + link_url_agency: LinkURLAgencyPydantic = LinkURLAgencyPydantic( + url_id=response.url_id, + agency_id=response.agency_id + ) + link_url_agencies.append(link_url_agency) + + if response.record_type is not None: + url_record_type: URLRecordTypePydantic = URLRecordTypePydantic( + url_id=response.url_id, + record_type=response.record_type + ) + url_record_types.append(url_record_type) + + url_validated_flag: FlagURLValidatedPydantic = FlagURLValidatedPydantic( + url_id=response.url_id, + type=response.url_type + ) + url_validated_flags.append(url_validated_flag) + + url_auto_validated_flag: FlagURLAutoValidatedPydantic = FlagURLAutoValidatedPydantic( + url_id=response.url_id, + ) + url_auto_validated_flags.append(url_auto_validated_flag) + + for inserts in [ + link_url_agencies, + url_record_types, + url_validated_flags, + url_auto_validated_flags, + ]: + await sh.bulk_insert(session, models=inserts) + + await self.update_urls(session) + + + async def update_urls(self, session: AsyncSession) -> Any: + id_to_name: dict[int, str] = {} + for response in self._responses: + if response.name is not None: + id_to_name[response.url_id] = response.name + + if len(id_to_name) == 0: + return + + stmt = ( + update(URL) + .where(URL.id.in_(id_to_name.keys())) + .values( + name=case( + {id_: val for id_, val in id_to_name.items()}, + value=URL.id + ) + ) + ) + + await session.execute(stmt) diff --git a/src/core/tasks/url/operators/validate/queries/prereq/__init__.py 
b/src/core/tasks/url/operators/validate/queries/prereq/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/validate/queries/prereq/core.py b/src/core/tasks/url/operators/validate/queries/prereq/core.py new file mode 100644 index 00000000..6ee25e53 --- /dev/null +++ b/src/core/tasks/url/operators/validate/queries/prereq/core.py @@ -0,0 +1,71 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.agency import AgencyValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.location import LocationValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.name import NameValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.record_type import \ + RecordTypeValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.ctes.consensus.impl.url_type import URLTypeValidationCTEContainer +from src.core.tasks.url.operators.validate.queries.helper import add_where_condition +from src.db.helpers.session import session_helper as sh +from src.db.models.views.unvalidated_url import UnvalidatedURL +from src.db.queries.base.builder import QueryBuilderBase + + +class AutoValidatePrerequisitesQueryBuilder(QueryBuilderBase): + """ + Checks whether any unvalidated URL meets one of the following prerequisites: + - Is a DATA_SOURCE URL with consensus on url_type, agency, location, record_type, and name + - Is a META_URL or INDIVIDUAL_RECORD URL with consensus on url_type, agency, location, and name + - Is a NOT_RELEVANT URL with consensus on url_type + """ + + async def run(self, session: AsyncSession) -> bool: + agency = AgencyValidationCTEContainer() + location = LocationValidationCTEContainer() + url_type = URLTypeValidationCTEContainer() + record_type = RecordTypeValidationCTEContainer() + name = NameValidationCTEContainer() + + + query = ( + select( + UnvalidatedURL.url_id, + ) + .select_from( + UnvalidatedURL + ) + .outerjoin( + agency.query, + UnvalidatedURL.url_id == agency.url_id, + ) + .outerjoin( + location.query, + UnvalidatedURL.url_id == location.url_id, + ) + .outerjoin( + url_type.query, + UnvalidatedURL.url_id == url_type.url_id, + ) + .outerjoin( + record_type.query, + UnvalidatedURL.url_id == record_type.url_id, + ) + .outerjoin( + name.query, + UnvalidatedURL.url_id == name.url_id, + ) + ) + query = add_where_condition( + query, + agency=agency, + location=location, + url_type=url_type, + record_type=record_type, + name=name, + ).limit(1) + + return await sh.results_exist(session, query=query) + + diff --git a/src/core/tasks/url/subtasks/agency_identification/auto_googler.py b/src/core/tasks/url/subtasks/agency_identification/auto_googler.py deleted file mode 100644 index 6f19ee7b..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/auto_googler.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Optional - -from src.core.enums import SuggestionType -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.core.tasks.url.subtasks.agency_identification.base import AgencyIdentificationSubtaskBase - - -class AutoGooglerAgencyIdentificationSubtask(AgencyIdentificationSubtaskBase): - - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] = None - ) -> list[URLAgencySuggestionInfo]: - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.UNKNOWN,
pdap_agency_id=None, - agency_name=None, - state=None, - county=None, - locality=None - ) - ] diff --git a/src/core/tasks/url/subtasks/agency_identification/base.py b/src/core/tasks/url/subtasks/agency_identification/base.py deleted file mode 100644 index 5727fcc8..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/base.py +++ /dev/null @@ -1,16 +0,0 @@ -import abc -from abc import ABC -from typing import Optional - -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo - - -class AgencyIdentificationSubtaskBase(ABC): - - @abc.abstractmethod - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] = None - ) -> list[URLAgencySuggestionInfo]: - raise NotImplementedError diff --git a/src/core/tasks/url/subtasks/agency_identification/ckan.py b/src/core/tasks/url/subtasks/agency_identification/ckan.py deleted file mode 100644 index 6092aed4..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/ckan.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Optional - -from src.core.helpers import process_match_agency_response_to_suggestions -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.external.pdap.client import PDAPClient -from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse - - -class CKANAgencyIdentificationSubtask: - - def __init__( - self, - pdap_client: PDAPClient - ): - self.pdap_client = pdap_client - - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] - ) -> list[URLAgencySuggestionInfo]: - agency_name = collector_metadata["agency_name"] - match_agency_response: MatchAgencyResponse = await self.pdap_client.match_agency( - name=agency_name - ) - return process_match_agency_response_to_suggestions( - url_id=url_id, - match_agency_response=match_agency_response - ) diff --git a/src/core/tasks/url/subtasks/agency_identification/common_crawler.py b/src/core/tasks/url/subtasks/agency_identification/common_crawler.py deleted file mode 100644 index fae8faaf..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/common_crawler.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Optional - -from src.core.enums import SuggestionType -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo - - -class CommonCrawlerAgencyIdentificationSubtask: - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] - ) -> list[URLAgencySuggestionInfo]: - return [ - URLAgencySuggestionInfo( - url_id=url_id, - suggestion_type=SuggestionType.UNKNOWN, - pdap_agency_id=None, - agency_name=None, - state=None, - county=None, - locality=None - ) - ] diff --git a/src/core/tasks/url/subtasks/agency_identification/muckrock.py b/src/core/tasks/url/subtasks/agency_identification/muckrock.py deleted file mode 100644 index df61e281..00000000 --- a/src/core/tasks/url/subtasks/agency_identification/muckrock.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional - -from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface -from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse -from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType -from src.core.exceptions import MuckrockAPIError -from src.core.helpers import process_match_agency_response_to_suggestions -from src.core.tasks.url.operators.agency_identification.dtos.suggestion 
import URLAgencySuggestionInfo -from src.external.pdap.client import PDAPClient -from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse - - -class MuckrockAgencyIdentificationSubtask: - - def __init__( - self, - muckrock_api_interface: MuckrockAPIInterface, - pdap_client: PDAPClient - ): - self.muckrock_api_interface = muckrock_api_interface - self.pdap_client = pdap_client - - async def run( - self, - url_id: int, - collector_metadata: Optional[dict] - ) -> list[URLAgencySuggestionInfo]: - muckrock_agency_id = collector_metadata["agency"] - agency_lookup_response: AgencyLookupResponse = await self.muckrock_api_interface.lookup_agency( - muckrock_agency_id=muckrock_agency_id - ) - if agency_lookup_response.type != AgencyLookupResponseType.FOUND: - raise MuckrockAPIError( - f"Failed to lookup muckrock agency: {muckrock_agency_id}:" - f" {agency_lookup_response.type.value}: {agency_lookup_response.error}" - ) - - match_agency_response: MatchAgencyResponse = await self.pdap_client.match_agency( - name=agency_lookup_response.name - ) - return process_match_agency_response_to_suggestions( - url_id=url_id, - match_agency_response=match_agency_response - ) diff --git a/src/core/tasks/url/subtasks/miscellaneous_metadata/auto_googler.py b/src/core/tasks/url/subtasks/miscellaneous_metadata/auto_googler.py index 0f183f78..e060d0d3 100644 --- a/src/core/tasks/url/subtasks/miscellaneous_metadata/auto_googler.py +++ b/src/core/tasks/url/subtasks/miscellaneous_metadata/auto_googler.py @@ -1,4 +1,4 @@ -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO from src.core.tasks.url.subtasks.miscellaneous_metadata.base import \ MiscellaneousMetadataSubtaskBase diff --git a/src/core/tasks/url/subtasks/miscellaneous_metadata/base.py b/src/core/tasks/url/subtasks/miscellaneous_metadata/base.py index 7b38504d..3ca7357b 100644 --- a/src/core/tasks/url/subtasks/miscellaneous_metadata/base.py +++ b/src/core/tasks/url/subtasks/miscellaneous_metadata/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO class MiscellaneousMetadataSubtaskBase(ABC): diff --git a/src/core/tasks/url/subtasks/miscellaneous_metadata/ckan.py b/src/core/tasks/url/subtasks/miscellaneous_metadata/ckan.py index 90512e2b..ef60b48c 100644 --- a/src/core/tasks/url/subtasks/miscellaneous_metadata/ckan.py +++ b/src/core/tasks/url/subtasks/miscellaneous_metadata/ckan.py @@ -1,4 +1,4 @@ -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO from src.core.tasks.url.subtasks.miscellaneous_metadata.base import \ MiscellaneousMetadataSubtaskBase diff --git a/src/core/tasks/url/subtasks/miscellaneous_metadata/muckrock.py b/src/core/tasks/url/subtasks/miscellaneous_metadata/muckrock.py index bb3eaadf..18a749b7 100644 --- a/src/core/tasks/url/subtasks/miscellaneous_metadata/muckrock.py +++ b/src/core/tasks/url/subtasks/miscellaneous_metadata/muckrock.py @@ -1,4 +1,4 @@ -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO from 
src.core.tasks.url.subtasks.miscellaneous_metadata.base import \ MiscellaneousMetadataSubtaskBase diff --git a/src/db/__init__.py b/src/db/__init__.py index e69de29b..812e7e5b 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -0,0 +1,6 @@ + + +from src.db.models.impl.location.location.sqlalchemy import Location +from src.db.models.impl.location.us_state.sqlalchemy import USState +from src.db.models.impl.location.county.sqlalchemy import County +from src.db.models.impl.location.locality.sqlalchemy import Locality diff --git a/src/db/client/async_.py b/src/db/client/async_.py index 45505be5..93c36544 100644 --- a/src/db/client/async_.py +++ b/src/db/client/async_.py @@ -1,26 +1,13 @@ from datetime import datetime, timedelta from functools import wraps -from operator import or_ from typing import Optional, Type, Any, List, Sequence -from sqlalchemy import select, exists, func, case, Select, and_, update, delete, literal, text, Row -from sqlalchemy.dialects import postgresql -from sqlalchemy.dialects.postgresql import insert as pg_insert -from sqlalchemy.exc import IntegrityError, NoResultFound +from sqlalchemy import select, exists, func, Select, and_, update, delete, Row, text from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.orm import selectinload, QueryableAttribute - -from src.api.endpoints.annotate._shared.queries.get_annotation_batch_info import GetAnnotationBatchInfoQueryBuilder -from src.api.endpoints.annotate._shared.queries.get_next_url_for_user_annotation import \ - GetNextURLForUserAnnotationQueryBuilder -from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAnnotationResponse -from src.api.endpoints.annotate.agency.get.queries.next_for_annotation import GetNextURLAgencyForAnnotationQueryBuilder -from src.api.endpoints.annotate.all.get.dto import GetNextURLForAllAnnotationResponse -from src.api.endpoints.annotate.all.get.query import GetNextURLForAllAnnotationQueryBuilder -from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseInfo -from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseInfo -from src.api.endpoints.annotate.relevance.get.query import GetNextUrlForRelevanceAnnotationQueryBuilder +from sqlalchemy.orm import selectinload + +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.api.endpoints.annotate.all.get.queries.core import GetNextURLForAllAnnotationQueryBuilder from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.api.endpoints.batch.duplicates.query import GetDuplicatesByBatchIDQueryBuilder @@ -28,108 +15,100 @@ from src.api.endpoints.collector.dtos.manual_batch.post import ManualBatchInputDTO from src.api.endpoints.collector.dtos.manual_batch.response import ManualBatchResponseDTO from src.api.endpoints.collector.manual.query import UploadManualBatchQueryBuilder +from src.api.endpoints.metrics.backlog.query import GetBacklogMetricsQueryBuilder from src.api.endpoints.metrics.batches.aggregated.dto import GetMetricsBatchesAggregatedResponseDTO -from src.api.endpoints.metrics.batches.aggregated.query import GetBatchesAggregatedMetricsQueryBuilder +from src.api.endpoints.metrics.batches.aggregated.query.core import GetBatchesAggregatedMetricsQueryBuilder 
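# The new src/db/__init__.py above imports model modules purely for side
# effects: importing a declaratively mapped class registers its table with the
# shared Base metadata, making it visible to create_all() and relationship
# resolution even if nothing else imports it. A self-contained sketch of the
# assumed motivation (illustrative names only, not part of this patch):
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Location(Base):
    __tablename__ = "locations"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(100))


# Defining (importing) the mapped class is enough to register its table:
assert "locations" in Base.metadata.tables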
from src.api.endpoints.metrics.batches.breakdown.dto import GetMetricsBatchesBreakdownResponseDTO from src.api.endpoints.metrics.batches.breakdown.query import GetBatchesBreakdownMetricsQueryBuilder -from src.api.endpoints.metrics.dtos.get.backlog import GetMetricsBacklogResponseDTO, GetMetricsBacklogResponseInnerDTO +from src.api.endpoints.metrics.dtos.get.backlog import GetMetricsBacklogResponseDTO from src.api.endpoints.metrics.dtos.get.urls.aggregated.core import GetMetricsURLsAggregatedResponseDTO -from src.api.endpoints.metrics.dtos.get.urls.breakdown.pending import GetMetricsURLsBreakdownPendingResponseDTO, \ - GetMetricsURLsBreakdownPendingResponseInnerDTO +from src.api.endpoints.metrics.dtos.get.urls.breakdown.pending import GetMetricsURLsBreakdownPendingResponseDTO from src.api.endpoints.metrics.dtos.get.urls.breakdown.submitted import GetMetricsURLsBreakdownSubmittedResponseDTO, \ GetMetricsURLsBreakdownSubmittedInnerDTO +from src.api.endpoints.metrics.urls.aggregated.query.core import GetURLsAggregatedMetricsQueryBuilder +from src.api.endpoints.metrics.urls.breakdown.query.core import GetURLsBreakdownPendingMetricsQueryBuilder from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo -from src.api.endpoints.review.approve.query import ApproveURLQueryBuilder +from src.api.endpoints.review.approve.query_.core import ApproveURLQueryBuilder from src.api.endpoints.review.enums import RejectionReason -from src.api.endpoints.review.next.dto import GetNextURLForFinalReviewOuterResponse from src.api.endpoints.review.reject.query import RejectURLQueryBuilder from src.api.endpoints.search.dtos.response import SearchURLResponse from src.api.endpoints.task.by_id.dto import TaskInfo - from src.api.endpoints.task.by_id.query import GetTaskInfoQueryBuilder from src.api.endpoints.task.dtos.get.tasks import GetTasksResponse, GetTasksResponseTaskInfo from src.api.endpoints.url.get.dto import GetURLsResponseInfo - from src.api.endpoints.url.get.query import GetURLsQueryBuilder from src.collectors.enums import URLStatus, CollectorType -from src.core.enums import BatchStatus, SuggestionType, RecordType, SuggestedStatus +from src.collectors.queries.insert.urls.query import InsertURLsQueryBuilder +from src.core.enums import BatchStatus, RecordType from src.core.env_var_manager import EnvVarManager -from src.core.tasks.scheduled.operators.agency_sync.dtos.parameters import AgencySyncParameters +from src.core.tasks.scheduled.impl.huggingface.queries.state import SetHuggingFaceUploadStateQueryBuilder from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.core.tasks.url.operators.agency_identification.dtos.tdo import AgencyIdentificationTDO -from src.core.tasks.url.operators.agency_identification.queries.get_pending_urls_without_agency_suggestions import \ - GetPendingURLsWithoutAgencySuggestionsQueryBuilder -from src.core.tasks.url.operators.auto_relevant.models.tdo import URLRelevantTDO -from src.core.tasks.url.operators.auto_relevant.queries.get_tdos import GetAutoRelevantTDOsQueryBuilder -from src.core.tasks.url.operators.submit_approved_url.tdo import SubmitApprovedURLTDO, SubmittedURLInfo -from src.core.tasks.url.operators.url_404_probe.tdo import URL404ProbeTDO -from src.core.tasks.url.operators.url_duplicate.tdo import URLDuplicateTDO -from src.core.tasks.url.operators.url_html.queries.get_pending_urls_without_html_data import \ +from src.core.tasks.url.operators.html.queries.get import \ 
GetPendingURLsWithoutHTMLDataQueryBuilder -from src.core.tasks.url.operators.url_miscellaneous_metadata.queries.get_pending_urls_missing_miscellaneous_data import \ - GetPendingURLsMissingMiscellaneousDataQueryBuilder -from src.core.tasks.url.operators.url_miscellaneous_metadata.queries.has_pending_urls_missing_miscellaneous_data import \ - HasPendingURsMissingMiscellaneousDataQueryBuilder -from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO +from src.core.tasks.url.operators.submit_approved.queries.mark_submitted import MarkURLsAsSubmittedQueryBuilder +from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo from src.db.client.helpers import add_standard_limit_and_offset from src.db.client.types import UserSuggestionModel from src.db.config_manager import ConfigManager from src.db.constants import PLACEHOLDER_AGENCY_NAME from src.db.dto_converter import DTOConverter -from src.db.dtos.batch import BatchInfo -from src.db.dtos.duplicate import DuplicateInsertInfo, DuplicateInfo -from src.db.dtos.log import LogInfo, LogOutputInfo -from src.db.dtos.url.annotations.auto.relevancy import AutoRelevancyAnnotationInput -from src.db.dtos.url.core import URLInfo -from src.db.dtos.url.error import URLErrorPydanticInfo from src.db.dtos.url.html_content import URLHTMLContentInfo from src.db.dtos.url.insert import InsertURLsInfo -from src.db.dtos.url.mapping import URLMapping from src.db.dtos.url.raw_html import RawHTMLInfo from src.db.enums import TaskType -from src.db.models.instantiations.agency import Agency -from src.db.models.instantiations.backlog_snapshot import BacklogSnapshot -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.duplicate import Duplicate -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.link.link_task_url import LinkTaskURL -from src.db.models.instantiations.log import Log -from src.db.models.instantiations.root_url_cache import RootURL -from src.db.models.instantiations.sync_state_agencies import AgenciesSyncState -from src.db.models.instantiations.task.core import Task -from src.db.models.instantiations.task.error import TaskError -from src.db.models.instantiations.url.checked_for_duplicate import URLCheckedForDuplicate -from src.db.models.instantiations.url.compressed_html import URLCompressedHTML -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.data_source import URLDataSource -from src.db.models.instantiations.url.error_info import URLErrorInfo -from src.db.models.instantiations.url.html_content import URLHTMLContent -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.probed_for_404 import URLProbedFor404 -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.auto import AutoRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.auto import AutoRelevantSuggestion -from 
src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from src.db.models.templates import Base +from src.db.helpers.session import session_helper as sh +from src.db.models.impl.agency.enums import AgencyType +from src.db.models.impl.agency.sqlalchemy import Agency +from src.db.models.impl.backlog_snapshot import BacklogSnapshot +from src.db.models.impl.batch.pydantic.info import BatchInfo +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.duplicate.pydantic.info import DuplicateInfo +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.task_url import LinkTaskURL +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.log.pydantic.info import LogInfo +from src.db.models.impl.log.pydantic.output import LogOutputInfo +from src.db.models.impl.log.sqlalchemy import Log +from src.db.models.impl.task.core import Task +from src.db.models.impl.task.enums import TaskStatus +from src.db.models.impl.task.error import TaskError +from src.db.models.impl.url.checked_for_duplicate import URLCheckedForDuplicate +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.models.impl.url.html.content.sqlalchemy import URLHTMLContent +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.auto.pydantic.input import AutoRelevancyAnnotationInput +from src.db.models.impl.url.suggestion.relevant.auto.sqlalchemy import AutoRelevantSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata +from src.db.models.templates_.base import Base +from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum from src.db.queries.base.builder import QueryBuilderBase -from src.api.endpoints.review.next.query import GetNextURLForFinalReviewQueryBuilder from src.db.queries.implementations.core.get.html_content_info import GetHTMLContentInfoQueryBuilder from src.db.queries.implementations.core.get.recent_batch_summaries.builder import GetRecentBatchSummariesQueryBuilder from src.db.queries.implementations.core.metrics.urls.aggregated.pending import \ GetMetricsURLSAggregatedPendingQueryBuilder -from src.db.queries.implementations.core.tasks.agency_sync.upsert import get_upsert_agencies_mappings +from src.db.queries.implementations.location.get import GetLocationQueryBuilder from src.db.statement_composer import StatementComposer +from src.db.templates.markers.bulk.delete import BulkDeletableModel +from src.db.templates.markers.bulk.insert import BulkInsertableModel +from src.db.templates.markers.bulk.upsert import BulkUpsertableModel from src.db.utils.compression import decompress_html, compress_html -from src.external.pdap.dtos.agencies_sync import 
AgenciesSyncResponseInnerInfo class AsyncDatabaseClient: - def __init__(self, db_url: Optional[str] = None): + def __init__(self, db_url: str | None = None): if db_url is None: db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) + self.db_url = db_url echo = ConfigManager.get_sqlalchemy_echo() self.engine = create_async_engine( url=db_url, @@ -162,18 +141,27 @@ async def wrapper(self, *args, **kwargs): return wrapper - @session_manager async def execute(self, session: AsyncSession, statement): await session.execute(statement) @session_manager - async def add(self, session: AsyncSession, model: Base): - session.add(model) + async def add( + self, + session: AsyncSession, + model: Base, + return_id: bool = False + ) -> int | None: + return await sh.add(session=session, model=model, return_id=return_id) @session_manager - async def add_all(self, session: AsyncSession, models: list[Base]): - session.add_all(models) + async def add_all( + self, + session: AsyncSession, + models: list[Base], + return_ids: bool = False + ) -> list[int] | None: + return await sh.add_all(session=session, models=models, return_ids=return_ids) @session_manager async def bulk_update( @@ -192,42 +180,43 @@ async def bulk_update( async def bulk_upsert( self, session: AsyncSession, - model: Base, - mappings: list[dict], - id_value: str = "id" + models: list[BulkUpsertableModel], ): + return await sh.bulk_upsert(session, models) - query = pg_insert(model) - - set_ = {} - for k, v in mappings[0].items(): - if k == id_value: - continue - set_[k] = getattr(query.excluded, k) - - query = query.on_conflict_do_update( - index_elements=[id_value], - set_=set_ - ) - + @session_manager + async def bulk_delete( + self, + session: AsyncSession, + models: list[BulkDeletableModel], + ): + return await sh.bulk_delete(session, models) - # Note, mapping must include primary key - await session.execute( - query, - mappings - ) + @session_manager + async def bulk_insert( + self, + session: AsyncSession, + models: list[BulkInsertableModel], + return_ids: bool = False + ) -> list[int] | None: + return await sh.bulk_insert(session, models=models, return_ids=return_ids) @session_manager async def scalar(self, session: AsyncSession, statement): - return (await session.execute(statement)).scalar() + """Fetch the first column of the first row.""" + return await sh.scalar(session, statement) @session_manager async def scalars(self, session: AsyncSession, statement): - return (await session.execute(statement)).scalars().all() + return await sh.scalars(session, statement) @session_manager async def mapping(self, session: AsyncSession, statement): - return (await session.execute(statement)).mappings().one() + return await sh.mapping(session, statement) + + @session_manager + async def one_or_none(self, session: AsyncSession, statement): + return await sh.one_or_none(session, statement) @session_manager async def run_query_builder( @@ -265,7 +254,7 @@ async def get_user_suggestion( model: UserSuggestionModel, user_id: int, url_id: int - ) -> Optional[UserSuggestionModel]: + ) -> UserSuggestionModel | None: statement = Select(model).where( and_( model.url_id == url_id, @@ -275,103 +264,35 @@ async def get_user_suggestion( result = await session.execute(statement) return result.unique().scalar_one_or_none() - async def get_next_url_for_user_annotation( - self, - user_suggestion_model_to_exclude: UserSuggestionModel, - auto_suggestion_relationship: QueryableAttribute, - batch_id: Optional[int], - check_if_annotated_not_relevant: 
bool = False - ) -> URL: - return await self.run_query_builder( - builder=GetNextURLForUserAnnotationQueryBuilder( - user_suggestion_model_to_exclude=user_suggestion_model_to_exclude, - auto_suggestion_relationship=auto_suggestion_relationship, - batch_id=batch_id, - check_if_annotated_not_relevant=check_if_annotated_not_relevant - ) - ) - - async def get_tdos_for_auto_relevancy(self) -> list[URLRelevantTDO]: - return await self.run_query_builder(builder=GetAutoRelevantTDOsQueryBuilder()) - @session_manager async def add_user_relevant_suggestion( self, session: AsyncSession, url_id: int, user_id: int, - suggested_status: SuggestedStatus + suggested_status: URLType ): prior_suggestion = await self.get_user_suggestion( session, - model=UserRelevantSuggestion, + model=UserURLTypeSuggestion, user_id=user_id, url_id=url_id ) if prior_suggestion is not None: - prior_suggestion.suggested_status = suggested_status.value + prior_suggestion.type = suggested_status.value return - suggestion = UserRelevantSuggestion( + suggestion = UserURLTypeSuggestion( url_id=url_id, user_id=user_id, - suggested_status=suggested_status.value + type=suggested_status.value ) session.add(suggestion) - async def get_next_url_for_relevance_annotation( - self, - batch_id: int | None, - user_id: int | None = None, - ) -> GetNextRelevanceAnnotationResponseInfo | None: - return await self.run_query_builder(GetNextUrlForRelevanceAnnotationQueryBuilder(batch_id)) - # endregion relevant # region record_type - @session_manager - async def get_next_url_for_record_type_annotation( - self, - session: AsyncSession, - user_id: int, - batch_id: Optional[int] - ) -> Optional[GetNextRecordTypeAnnotationResponseInfo]: - - url = await GetNextURLForUserAnnotationQueryBuilder( - user_suggestion_model_to_exclude=UserRecordTypeSuggestion, - auto_suggestion_relationship=URL.auto_record_type_suggestion, - batch_id=batch_id, - check_if_annotated_not_relevant=True - ).run(session) - if url is None: - return None - - # Next, get all HTML content for the URL - html_response_info = DTOConverter.html_content_list_to_html_response_info( - url.html_content - ) - - if url.auto_record_type_suggestion is not None: - suggestion = url.auto_record_type_suggestion.record_type - else: - suggestion = None - - return GetNextRecordTypeAnnotationResponseInfo( - url_info=URLMapping( - url=url.url, - url_id=url.id - ), - suggested_record_type=suggestion, - html_info=html_response_info, - batch_info=await GetAnnotationBatchInfoQueryBuilder( - batch_id=batch_id, - models=[ - UserUrlAgencySuggestion, - ] - ).run(session) - ) - @session_manager async def add_auto_record_type_suggestions( self, @@ -423,57 +344,18 @@ async def add_user_record_type_suggestion( # endregion record_type - @session_manager - async def add_url_error_infos(self, session: AsyncSession, url_error_infos: list[URLErrorPydanticInfo]): - for url_error_info in url_error_infos: - statement = select(URL).where(URL.id == url_error_info.url_id) - scalar_result = await session.scalars(statement) - url = scalar_result.first() - url.outcome = URLStatus.ERROR.value - - url_error = URLErrorInfo(**url_error_info.model_dump()) - session.add(url_error) - - @session_manager - async def get_urls_with_errors(self, session: AsyncSession) -> list[URLErrorPydanticInfo]: - statement = (select(URL, URLErrorInfo.error, URLErrorInfo.updated_at, URLErrorInfo.task_id) - .join(URLErrorInfo) - .where(URL.outcome == URLStatus.ERROR.value) - .order_by(URL.id)) - scalar_result = await session.execute(statement) - results = 
scalar_result.all() - final_results = [] - for url, error, updated_at, task_id in results: - final_results.append( - URLErrorPydanticInfo( - url_id=url.id, - error=error, - updated_at=updated_at, - task_id=task_id - ) - ) - - return final_results @session_manager async def add_html_content_infos(self, session: AsyncSession, html_content_infos: list[URLHTMLContentInfo]): await self._add_models(session, URLHTMLContent, html_content_infos) @session_manager - async def has_pending_urls_without_html_data(self, session: AsyncSession) -> bool: - statement = self.statement_composer.pending_urls_without_html_data() + async def has_non_errored_urls_without_html_data(self, session: AsyncSession) -> bool: + statement = self.statement_composer.has_non_errored_urls_without_html_data() statement = statement.limit(1) scalar_result = await session.scalars(statement) return bool(scalar_result.first()) - async def has_pending_urls_missing_miscellaneous_metadata(self) -> bool: - return await self.run_query_builder(HasPendingURsMissingMiscellaneousDataQueryBuilder()) - - async def get_pending_urls_missing_miscellaneous_metadata( - self, - ) -> list[URLMiscellaneousMetadataTDO]: - return await self.run_query_builder(GetPendingURLsMissingMiscellaneousDataQueryBuilder()) - @session_manager async def add_miscellaneous_metadata(self, session: AsyncSession, tdos: list[URLMiscellaneousMetadataTDO]): updates = [] @@ -502,7 +384,7 @@ async def add_miscellaneous_metadata(self, session: AsyncSession, tdos: list[URL ) session.add(metadata_object) - async def get_pending_urls_without_html_data(self) -> list[URLInfo]: + async def get_non_errored_urls_without_html_data(self) -> list[URLInfo]: return await self.run_query_builder(GetPendingURLsWithoutHTMLDataQueryBuilder()) async def get_urls_with_html_data_and_without_models( @@ -512,7 +394,7 @@ async def get_urls_with_html_data_and_without_models( ): statement = (select(URL) .options(selectinload(URL.html_content)) - .where(URL.outcome == URLStatus.PENDING.value)) + .where(URL.status == URLStatus.OK.value)) statement = self.statement_composer.exclude_urls_with_extant_model( statement=statement, model=model @@ -534,7 +416,6 @@ async def get_urls_with_html_data_and_without_auto_record_type_suggestion( model=AutoRecordTypeSuggestion ) - async def has_urls_with_html_data_and_without_models( self, session: AsyncSession, @@ -542,7 +423,7 @@ async def has_urls_with_html_data_and_without_models( ) -> bool: statement = (select(URL) .join(URLCompressedHTML) - .where(URL.outcome == URLStatus.PENDING.value)) + .where(URL.status == URLStatus.OK.value)) # Exclude URLs with auto suggested record types statement = self.statement_composer.exclude_urls_with_extant_model( statement=statement, @@ -552,13 +433,6 @@ async def has_urls_with_html_data_and_without_models( scalar_result = await session.scalars(statement) return bool(scalar_result.first()) - @session_manager - async def has_urls_with_html_data_and_without_auto_relevant_suggestion(self, session: AsyncSession) -> bool: - return await self.has_urls_with_html_data_and_without_models( - session=session, - model=AutoRelevantSuggestion - ) - @session_manager async def has_urls_with_html_data_and_without_auto_record_type_suggestion(self, session: AsyncSession) -> bool: return await self.has_urls_with_html_data_and_without_models( @@ -571,41 +445,21 @@ async def get_all( self, session, model: Base, - order_by_attribute: Optional[str] = None + order_by_attribute: str | None = None ) -> list[Base]: - """ - Get all records of a model - Used 
primarily in testing - """ - statement = select(model) - if order_by_attribute: - statement = statement.order_by(getattr(model, order_by_attribute)) - result = await session.execute(statement) - return result.scalars().all() - - @session_manager - async def load_root_url_cache(self, session: AsyncSession) -> dict[str, str]: - statement = select(RootURL) - scalar_result = await session.scalars(statement) - model_result = scalar_result.all() - d = {} - for result in model_result: - d[result.url] = result.page_title - return d - - async def add_to_root_url_cache(self, url: str, page_title: str) -> None: - cache = RootURL(url=url, page_title=page_title) - await self.add(cache) + """Get all records of a model. Used primarily in testing.""" + return await sh.get_all(session=session, model=model, order_by_attribute=order_by_attribute) async def get_urls( self, page: int, errors: bool ) -> GetURLsResponseInfo: - return await self.run_query_builder(GetURLsQueryBuilder( - page=page, errors=errors - )) - + return await self.run_query_builder( + GetURLsQueryBuilder( + page=page, errors=errors + ) + ) @session_manager async def initiate_task( @@ -625,7 +479,13 @@ async def initiate_task( return task.id @session_manager - async def update_task_status(self, session: AsyncSession, task_id: int, status: BatchStatus): + async def update_task_status( + self, + session: + AsyncSession, + task_id: int, + status: TaskStatus + ): task = await session.get(Task, task_id) task.task_status = status.value @@ -646,7 +506,12 @@ async def get_html_content_info(self, url_id: int) -> list[URLHTMLContentInfo]: return await self.run_query_builder(GetHTMLContentInfoQueryBuilder(url_id)) @session_manager - async def link_urls_to_task(self, session: AsyncSession, task_id: int, url_ids: list[int]): + async def link_urls_to_task( + self, + session: AsyncSession, + task_id: int, + url_ids: list[int] + ) -> None: for url_id in url_ids: link = LinkTaskURL( url_id=url_id, @@ -658,8 +523,8 @@ async def link_urls_to_task(self, session: AsyncSession, task_id: int, url_ids: async def get_tasks( self, session: AsyncSession, - task_type: Optional[TaskType] = None, - task_status: Optional[BatchStatus] = None, + task_type: TaskType | None = None, + task_status: BatchStatus | None = None, page: int = 1 ) -> GetTasksResponse: url_count_subquery = self.statement_composer.simple_count_subquery( @@ -669,7 +534,7 @@ async def get_tasks( ) url_error_count_subquery = self.statement_composer.simple_count_subquery( - URLErrorInfo, + URLTaskError, 'task_id', 'url_error_count' ) @@ -709,42 +574,6 @@ async def get_tasks( tasks=final_results ) - @session_manager - async def has_urls_without_agency_suggestions( - self, - session: AsyncSession - ) -> bool: - statement = ( - select( - URL.id - ).where( - URL.outcome == URLStatus.PENDING.value - ) - ) - - statement = self.statement_composer.exclude_urls_with_agency_suggestions(statement) - raw_result = await session.execute(statement) - result = raw_result.all() - return len(result) != 0 - - async def get_urls_without_agency_suggestions( - self - ) -> list[AgencyIdentificationTDO]: - """Retrieve URLs without confirmed or suggested agencies.""" - return await self.run_query_builder(GetPendingURLsWithoutAgencySuggestionsQueryBuilder()) - - - async def get_next_url_agency_for_annotation( - self, - user_id: int, - batch_id: int | None - ) -> GetNextURLForAgencyAnnotationResponse: - return await self.run_query_builder(builder=GetNextURLAgencyForAnnotationQueryBuilder( - user_id=user_id, - batch_id=batch_id - )) - 
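# upsert_new_agencies below replaces session.merge() with an explicit
# select-then-create. A generic sketch of that pattern (assumed shapes;
# an illustration only, not part of this patch):
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession


async def select_or_create(session: AsyncSession, model, **lookup):
    """Fetch a row by a natural key, adding a new one to the session if absent."""
    result = await session.execute(select(model).filter_by(**lookup))
    row = result.scalars().one_or_none()
    if row is None:
        row = model(**lookup)
        session.add(row)
    return row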
- @session_manager async def upsert_new_agencies( self, @@ -755,14 +584,14 @@ async def upsert_new_agencies( Add or update agencies in the database """ for suggestion in suggestions: - agency = Agency( - agency_id=suggestion.pdap_agency_id, - name=suggestion.agency_name, - state=suggestion.state, - county=suggestion.county, - locality=suggestion.locality - ) - await session.merge(agency) + query = select(Agency).where(Agency.agency_id == suggestion.pdap_agency_id) + result = await session.execute(query) + agency = result.scalars().one_or_none() + if agency is None: + agency = Agency(agency_id=suggestion.pdap_agency_id) + agency.name = suggestion.agency_name + agency.agency_type = AgencyType.UNKNOWN + session.add(agency) @session_manager async def add_confirmed_agency_url_links( @@ -771,26 +600,12 @@ async def add_confirmed_agency_url_links( suggestions: list[URLAgencySuggestionInfo] ): for suggestion in suggestions: - confirmed_agency = ConfirmedURLAgency( + confirmed_agency = LinkURLAgency( url_id=suggestion.url_id, agency_id=suggestion.pdap_agency_id ) session.add(confirmed_agency) - @session_manager - async def add_agency_auto_suggestions( - self, - session: AsyncSession, - suggestions: list[URLAgencySuggestionInfo] - ): - for suggestion in suggestions: - url_agency_suggestion = AutomatedUrlAgencySuggestion( - url_id=suggestion.url_id, - agency_id=suggestion.pdap_agency_id, - is_unknown=suggestion.suggestion_type == SuggestionType.UNKNOWN - ) - session.add(url_agency_suggestion) - @session_manager async def add_agency_manual_suggestion( self, @@ -810,7 +625,8 @@ async def add_agency_manual_suggestion( if len(result.all()) == 0: agency = Agency( agency_id=agency_id, - name=PLACEHOLDER_AGENCY_NAME + name=PLACEHOLDER_AGENCY_NAME, + agency_type=AgencyType.UNKNOWN, ) await session.merge(agency) @@ -824,32 +640,21 @@ async def add_agency_manual_suggestion( @session_manager async def get_urls_with_confirmed_agencies(self, session: AsyncSession) -> list[URL]: - statement = select(URL).where(exists().where(ConfirmedURLAgency.url_id == URL.id)) + statement = select(URL).where(exists().where(LinkURLAgency.url_id == URL.id)) results = await session.execute(statement) return list(results.scalars().all()) - @session_manager - async def get_next_url_for_final_review( - self, - session: AsyncSession, - batch_id: Optional[int] - ) -> GetNextURLForFinalReviewOuterResponse: - - builder = GetNextURLForFinalReviewQueryBuilder( - batch_id=batch_id - ) - result = await builder.run(session) - return result - async def approve_url( self, approval_info: FinalReviewApprovalInfo, user_id: int, ) -> None: - await self.run_query_builder(ApproveURLQueryBuilder( - user_id=user_id, - approval_info=approval_info - )) + await self.run_query_builder( + ApproveURLQueryBuilder( + user_id=user_id, + approval_info=approval_info + ) + ) async def reject_url( self, @@ -857,12 +662,13 @@ async def reject_url( user_id: int, rejection_reason: RejectionReason ) -> None: - await self.run_query_builder(RejectURLQueryBuilder( - url_id=url_id, - user_id=user_id, - rejection_reason=rejection_reason - )) - + await self.run_query_builder( + RejectURLQueryBuilder( + url_id=url_id, + user_id=user_id, + rejection_reason=rejection_reason + ) + ) @session_manager async def get_batch_by_id(self, session, batch_id: int) -> Optional[BatchSummary]: @@ -878,45 +684,19 @@ async def get_batch_by_id(self, session, batch_id: int) -> Optional[BatchSummary async def get_urls_by_batch(self, batch_id: int, page: int = 1) -> list[URLInfo]: """Retrieve all 
URLs associated with a batch.""" - return await self.run_query_builder(GetURLsByBatchQueryBuilder( - batch_id=batch_id, - page=page - )) - - @session_manager - async def insert_url(self, session: AsyncSession, url_info: URLInfo) -> int: - """Insert a new URL into the database.""" - url_entry = URL( - url=url_info.url, - collector_metadata=url_info.collector_metadata, - outcome=url_info.outcome.value - ) - if url_info.created_at is not None: - url_entry.created_at = url_info.created_at - session.add(url_entry) - await session.flush() - link = LinkBatchURL( - batch_id=url_info.batch_id, - url_id=url_entry.id + return await self.run_query_builder( + GetURLsByBatchQueryBuilder( + batch_id=batch_id, + page=page + ) ) - return url_entry.id - - @session_manager - async def get_url_info_by_url(self, session: AsyncSession, url: str) -> Optional[URLInfo]: - query = Select(URL).where(URL.url == url) - raw_result = await session.execute(query) - url = raw_result.scalars().first() - return URLInfo(**url.__dict__) @session_manager - async def get_url_info_by_id(self, session: AsyncSession, url_id: int) -> Optional[URLInfo]: - query = Select(URL).where(URL.id == url_id) - raw_result = await session.execute(query) - url = raw_result.scalars().first() - return URLInfo(**url.__dict__) - - @session_manager - async def insert_logs(self, session, log_infos: List[LogInfo]): + async def insert_logs( + self, + session: AsyncSession, + log_infos: list[LogInfo] + ) -> None: for log_info in log_infos: log = Log(log=log_info.log, batch_id=log_info.batch_id) if log_info.created_at is not None: @@ -924,16 +704,11 @@ async def insert_logs(self, session, log_infos: List[LogInfo]): session.add(log) @session_manager - async def insert_duplicates(self, session, duplicate_infos: list[DuplicateInsertInfo]): - for duplicate_info in duplicate_infos: - duplicate = Duplicate( - batch_id=duplicate_info.duplicate_batch_id, - original_url_id=duplicate_info.original_url_id, - ) - session.add(duplicate) - - @session_manager - async def insert_batch(self, session: AsyncSession, batch_info: BatchInfo) -> int: + async def insert_batch( + self, + session: AsyncSession, + batch_info: BatchInfo + ) -> int: """Insert a new batch into the database and return its ID.""" batch = Batch( strategy=batch_info.strategy, @@ -941,11 +716,6 @@ async def insert_batch(self, session: AsyncSession, batch_info: BatchInfo) -> in status=batch_info.status.value, parameters=batch_info.parameters, compute_time=batch_info.compute_time, - strategy_success_rate=0, - metadata_success_rate=0, - agency_match_rate=0, - record_type_match_rate=0, - record_category_match_rate=0, ) if batch_info.date_generated is not None: batch.date_generated = batch_info.date_generated @@ -953,42 +723,28 @@ async def insert_batch(self, session: AsyncSession, batch_info: BatchInfo) -> in await session.flush() return batch.id - async def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo: - url_mappings = [] - duplicates = [] - for url_info in url_infos: - url_info.batch_id = batch_id - try: - url_id = await self.insert_url(url_info) - url_mappings.append(URLMapping(url_id=url_id, url=url_info.url)) - except IntegrityError: - orig_url_info = await self.get_url_info_by_url(url_info.url) - duplicate_info = DuplicateInsertInfo( - duplicate_batch_id=batch_id, - original_url_id=orig_url_info.id - ) - duplicates.append(duplicate_info) - await self.insert_duplicates(duplicates) - - return InsertURLsInfo( - url_mappings=url_mappings, - total_count=len(url_infos), - 
original_count=len(url_mappings), - duplicate_count=len(duplicates), - url_ids=[url_mapping.url_id for url_mapping in url_mappings] + async def insert_urls( + self, + url_infos: list[URLInfo], + batch_id: int + ) -> InsertURLsInfo: + builder = InsertURLsQueryBuilder( + url_infos=url_infos, + batch_id=batch_id ) + return await self.run_query_builder(builder) @session_manager async def update_batch_post_collection( self, - session, + session: AsyncSession, batch_id: int, total_url_count: int, original_url_count: int, duplicate_url_count: int, batch_status: BatchStatus, compute_time: float = None, - ): + ) -> None: query = Select(Batch).where(Batch.id == batch_id) result = await session.execute(query) @@ -1000,108 +756,30 @@ async def update_batch_post_collection( batch.status = batch_status.value batch.compute_time = compute_time - @session_manager - async def has_validated_urls(self, session: AsyncSession) -> bool: - query = ( - select(URL) - .where(URL.outcome == URLStatus.VALIDATED.value) - ) - urls = await session.execute(query) - urls = urls.scalars().all() - return len(urls) > 0 - - @session_manager - async def get_validated_urls( - self, - session: AsyncSession - ) -> list[SubmitApprovedURLTDO]: - query = ( - select(URL) - .where(URL.outcome == URLStatus.VALIDATED.value) - .options( - selectinload(URL.optional_data_source_metadata), - selectinload(URL.confirmed_agencies), - selectinload(URL.reviewing_user) - ).limit(100) - ) - urls = await session.execute(query) - urls = urls.scalars().all() - results: list[SubmitApprovedURLTDO] = [] - for url in urls: - agency_ids = [] - for agency in url.confirmed_agencies: - agency_ids.append(agency.agency_id) - optional_metadata = url.optional_data_source_metadata - - if optional_metadata is None: - record_formats = None - data_portal_type = None - supplying_entity = None - else: - record_formats = optional_metadata.record_formats - data_portal_type = optional_metadata.data_portal_type - supplying_entity = optional_metadata.supplying_entity - - tdo = SubmitApprovedURLTDO( - url_id=url.id, - url=url.url, - name=url.name, - agency_ids=agency_ids, - description=url.description, - record_type=url.record_type, - record_formats=record_formats, - data_portal_type=data_portal_type, - supplying_entity=supplying_entity, - approving_user_id=url.reviewing_user.user_id - ) - results.append(tdo) - return results - - @session_manager - async def mark_urls_as_submitted(self, session: AsyncSession, infos: list[SubmittedURLInfo]): - for info in infos: - url_id = info.url_id - data_source_id = info.data_source_id - - query = ( - update(URL) - .where(URL.id == url_id) - .values( - outcome=URLStatus.SUBMITTED.value - ) - ) - - url_data_source_object = URLDataSource( - url_id=url_id, - data_source_id=data_source_id - ) - if info.submitted_at is not None: - url_data_source_object.created_at = info.submitted_at - session.add(url_data_source_object) - - await session.execute(query) + async def mark_urls_as_submitted(self, infos: list[SubmittedURLInfo]): + await self.run_query_builder(MarkURLsAsSubmittedQueryBuilder(infos)) async def get_duplicates_by_batch_id(self, batch_id: int, page: int) -> list[DuplicateInfo]: - return await self.run_query_builder(GetDuplicatesByBatchIDQueryBuilder( - batch_id=batch_id, - page=page - )) + return await self.run_query_builder( + GetDuplicatesByBatchIDQueryBuilder( + batch_id=batch_id, + page=page + ) + ) @session_manager async def get_batch_summaries( self, session, page: int, - collector_type: Optional[CollectorType] = None, - status: 
Optional[BatchStatus] = None, - has_pending_urls: Optional[bool] = None + collector_type: CollectorType | None = None, + status: BatchURLStatusEnum | None = None, ) -> GetBatchSummariesResponse: # Get only the batch_id, collector_type, status, and created_at builder = GetRecentBatchSummariesQueryBuilder( page=page, collector_type=collector_type, status=status, - has_pending_urls=has_pending_urls ) summaries = await builder.run(session) return GetBatchSummariesResponse( @@ -1125,56 +803,28 @@ async def delete_old_logs(self): await self.execute(statement) async def get_next_url_for_all_annotations( - self, batch_id: int | None = None - ) -> GetNextURLForAllAnnotationResponse: - return await self.run_query_builder(GetNextURLForAllAnnotationQueryBuilder(batch_id)) - - @session_manager - async def add_all_annotations_to_url( self, - session, user_id: int, - url_id: int, - post_info: AllAnnotationPostInfo - ): - - # Add relevant annotation - relevant_suggestion = UserRelevantSuggestion( - url_id=url_id, - user_id=user_id, - suggested_status=post_info.suggested_status.value - ) - session.add(relevant_suggestion) - - # If not relevant, do nothing else - if not post_info.suggested_status == SuggestedStatus.RELEVANT: - return - - record_type_suggestion = UserRecordTypeSuggestion( - url_id=url_id, - user_id=user_id, - record_type=post_info.record_type.value - ) - session.add(record_type_suggestion) - - agency_suggestion = UserUrlAgencySuggestion( - url_id=url_id, + batch_id: int | None = None, + url_id: int | None = None + ) -> GetNextURLForAllAnnotationResponse: + return await self.run_query_builder(GetNextURLForAllAnnotationQueryBuilder( + batch_id=batch_id, user_id=user_id, - agency_id=post_info.agency.suggested_agency, - is_new=post_info.agency.is_new - ) - session.add(agency_suggestion) + url_id=url_id + )) async def upload_manual_batch( self, user_id: int, dto: ManualBatchInputDTO ) -> ManualBatchResponseDTO: - return await self.run_query_builder(UploadManualBatchQueryBuilder( - user_id=user_id, - dto=dto - )) - + return await self.run_query_builder( + UploadManualBatchQueryBuilder( + user_id=user_id, + dto=dto + ) + ) @session_manager async def search_for_url(self, session: AsyncSession, url: str) -> SearchURLResponse: @@ -1196,7 +846,6 @@ async def get_batches_aggregated_metrics(self) -> GetMetricsBatchesAggregatedRes GetBatchesAggregatedMetricsQueryBuilder() ) - async def get_batches_breakdown_metrics( self, page: int @@ -1238,187 +887,16 @@ async def get_urls_breakdown_submitted_metrics( entries=final_results ) - @session_manager - async def get_urls_aggregated_metrics( - self, - session: AsyncSession - ) -> GetMetricsURLsAggregatedResponseDTO: - sc = StatementComposer - - oldest_pending_url_query = select( - URL.id, - URL.created_at - ).where( - URL.outcome == URLStatus.PENDING.value - ).order_by( - URL.created_at.asc() - ).limit(1) - - oldest_pending_url = await session.execute(oldest_pending_url_query) - oldest_pending_url = oldest_pending_url.one_or_none() - if oldest_pending_url is None: - oldest_pending_url_id = None - oldest_pending_created_at = None - else: - oldest_pending_url_id = oldest_pending_url.id - oldest_pending_created_at = oldest_pending_url.created_at - - def case_column(status: URLStatus, label): - return sc.count_distinct( - case( - ( - URL.outcome == status.value, - URL.id - ) - ), - label=label - ) - - count_query = select( - sc.count_distinct(URL.id, label="count"), - case_column(URLStatus.PENDING, label="count_pending"), - case_column(URLStatus.SUBMITTED, 
label="count_submitted"), - case_column(URLStatus.VALIDATED, label="count_validated"), - case_column(URLStatus.NOT_RELEVANT, label="count_rejected"), - case_column(URLStatus.ERROR, label="count_error"), - ) - raw_results = await session.execute(count_query) - results = raw_results.all() - - return GetMetricsURLsAggregatedResponseDTO( - count_urls_total=results[0].count, - count_urls_pending=results[0].count_pending, - count_urls_submitted=results[0].count_submitted, - count_urls_validated=results[0].count_validated, - count_urls_rejected=results[0].count_rejected, - count_urls_errors=results[0].count_error, - oldest_pending_url_id=oldest_pending_url_id, - oldest_pending_url_created_at=oldest_pending_created_at, - ) - - def compile(self, statement): - compiled_sql = statement.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}) - return compiled_sql - - @session_manager - async def get_urls_breakdown_pending_metrics( - self, - session: AsyncSession - ) -> GetMetricsURLsBreakdownPendingResponseDTO: - sc = StatementComposer - - flags = ( - select( - URL.id.label("url_id"), - case((UserRecordTypeSuggestion.url_id != None, literal(True)), else_=literal(False)).label( - "has_user_record_type_annotation" - ), - case((UserRelevantSuggestion.url_id != None, literal(True)), else_=literal(False)).label( - "has_user_relevant_annotation" - ), - case((UserUrlAgencySuggestion.url_id != None, literal(True)), else_=literal(False)).label( - "has_user_agency_annotation" - ), - ) - .outerjoin(UserRecordTypeSuggestion, URL.id == UserRecordTypeSuggestion.url_id) - .outerjoin(UserRelevantSuggestion, URL.id == UserRelevantSuggestion.url_id) - .outerjoin(UserUrlAgencySuggestion, URL.id == UserUrlAgencySuggestion.url_id) - ).cte("flags") + async def get_urls_aggregated_metrics(self) -> GetMetricsURLsAggregatedResponseDTO: + return await self.run_query_builder(GetURLsAggregatedMetricsQueryBuilder()) - month = func.date_trunc('month', URL.created_at) + async def get_urls_breakdown_pending_metrics(self) -> GetMetricsURLsBreakdownPendingResponseDTO: + return await self.run_query_builder(GetURLsBreakdownPendingMetricsQueryBuilder()) - # Build the query - query = ( - select( - month.label('month'), - func.count(URL.id).label('count_total'), - func.count( - case( - (flags.c.has_user_record_type_annotation == True, 1) - ) - ).label('user_record_type_count'), - func.count( - case( - (flags.c.has_user_relevant_annotation == True, 1) - ) - ).label('user_relevant_count'), - func.count( - case( - (flags.c.has_user_agency_annotation == True, 1) - ) - ).label('user_agency_count'), - ) - .outerjoin(flags, flags.c.url_id == URL.id) - .where(URL.outcome == URLStatus.PENDING.value) - .group_by(month) - .order_by(month.asc()) - ) - - # Execute the query and return the results - results = await session.execute(query) - all_results = results.all() - final_results: list[GetMetricsURLsBreakdownPendingResponseInnerDTO] = [] - - for result in all_results: - dto = GetMetricsURLsBreakdownPendingResponseInnerDTO( - month=result.month.strftime("%B %Y"), - count_pending_total=result.count_total, - count_pending_relevant_user=result.user_relevant_count, - count_pending_record_type_user=result.user_record_type_count, - count_pending_agency_user=result.user_agency_count, - ) - final_results.append(dto) - return GetMetricsURLsBreakdownPendingResponseDTO( - entries=final_results, - ) - - @session_manager async def get_backlog_metrics( self, - session: AsyncSession ) -> GetMetricsBacklogResponseDTO: - month = 
func.date_trunc('month', BacklogSnapshot.created_at) - - # 1. Create a subquery that assigns row_number() partitioned by month - monthly_snapshot_subq = ( - select( - BacklogSnapshot.id, - BacklogSnapshot.created_at, - BacklogSnapshot.count_pending_total, - month.label("month_start"), - func.row_number() - .over( - partition_by=month, - order_by=BacklogSnapshot.created_at.desc() - ) - .label("row_number") - ) - .subquery() - ) - - # 2. Filter for the top (most recent) row in each month - stmt = ( - select( - monthly_snapshot_subq.c.month_start, - monthly_snapshot_subq.c.created_at, - monthly_snapshot_subq.c.count_pending_total - ) - .where(monthly_snapshot_subq.c.row_number == 1) - .order_by(monthly_snapshot_subq.c.month_start) - ) - - raw_result = await session.execute(stmt) - results = raw_result.all() - final_results = [] - for result in results: - final_results.append( - GetMetricsBacklogResponseInnerDTO( - month=result.month_start.strftime("%B %Y"), - count_pending_total=result.count_pending_total, - ) - ) - - return GetMetricsBacklogResponseDTO(entries=final_results) + return await self.run_query_builder(GetBacklogMetricsQueryBuilder()) @session_manager async def populate_backlog_snapshot( @@ -1428,10 +906,15 @@ async def populate_backlog_snapshot( ): sc = StatementComposer # Get count of pending URLs - query = select( - sc.count_distinct(URL.id, label="count") - ).where( - URL.outcome == URLStatus.PENDING.value + query = ( + select( + sc.count_distinct(URL.id, label="count") + ) + .outerjoin(FlagURLValidated, URL.id == FlagURLValidated.url_id) + .where( + URL.status == URLStatus.OK.value, + FlagURLValidated.url_id.is_(None), + ) ) raw_result = await session.execute(query) @@ -1446,176 +929,19 @@ async def populate_backlog_snapshot( session.add(snapshot) - @session_manager - async def has_pending_urls_not_checked_for_duplicates(self, session: AsyncSession) -> bool: - query = (select( - URL.id - ).outerjoin( - URLCheckedForDuplicate, - URL.id == URLCheckedForDuplicate.url_id - ).where( - URL.outcome == URLStatus.PENDING.value, - URLCheckedForDuplicate.id == None - ).limit(1) - ) - - raw_result = await session.execute(query) - result = raw_result.one_or_none() - return result is not None - - @session_manager - async def get_pending_urls_not_checked_for_duplicates(self, session: AsyncSession) -> List[URLDuplicateTDO]: - query = (select( - URL - ).outerjoin( - URLCheckedForDuplicate, - URL.id == URLCheckedForDuplicate.url_id - ).where( - URL.outcome == URLStatus.PENDING.value, - URLCheckedForDuplicate.id == None - ).limit(100) - ) - - raw_result = await session.execute(query) - urls = raw_result.scalars().all() - return [URLDuplicateTDO(url=url.url, url_id=url.id) for url in urls] - - async def mark_all_as_duplicates(self, url_ids: List[int]): - query = update(URL).where(URL.id.in_(url_ids)).values(outcome=URLStatus.DUPLICATE.value) - await self.execute(query) - async def mark_all_as_404(self, url_ids: List[int]): - query = update(URL).where(URL.id.in_(url_ids)).values(outcome=URLStatus.NOT_FOUND.value) + query = update(URLWebMetadata).where(URLWebMetadata.url_id.in_(url_ids)).values(status_code=404) await self.execute(query) - async def mark_all_as_recently_probed_for_404( - self, - url_ids: List[int], - dt: datetime = func.now() - ): - values = [ - {"url_id": url_id, "last_probed_at": dt} for url_id in url_ids - ] - stmt = pg_insert(URLProbedFor404).values(values) - update_stmt = stmt.on_conflict_do_update( - index_elements=['url_id'], - set_={"last_probed_at": dt} - ) - await 
self.execute(update_stmt) - @session_manager async def mark_as_checked_for_duplicates(self, session: AsyncSession, url_ids: list[int]): for url_id in url_ids: url_checked_for_duplicate = URLCheckedForDuplicate(url_id=url_id) session.add(url_checked_for_duplicate) - @session_manager - async def has_pending_urls_not_recently_probed_for_404(self, session: AsyncSession) -> bool: - month_ago = func.now() - timedelta(days=30) - query = ( - select( - URL.id - ).outerjoin( - URLProbedFor404 - ).where( - and_( - URL.outcome == URLStatus.PENDING.value, - or_( - URLProbedFor404.id == None, - URLProbedFor404.last_probed_at < month_ago - ) - ) - ).limit(1) - ) - - raw_result = await session.execute(query) - result = raw_result.one_or_none() - return result is not None - - @session_manager - async def get_pending_urls_not_recently_probed_for_404(self, session: AsyncSession) -> List[URL404ProbeTDO]: - month_ago = func.now() - timedelta(days=30) - query = ( - select( - URL - ).outerjoin( - URLProbedFor404 - ).where( - and_( - URL.outcome == URLStatus.PENDING.value, - or_( - URLProbedFor404.id == None, - URLProbedFor404.last_probed_at < month_ago - ) - ) - ).limit(100) - ) - - raw_result = await session.execute(query) - urls = raw_result.scalars().all() - return [URL404ProbeTDO(url=url.url, url_id=url.id) for url in urls] - @session_manager - async def get_urls_aggregated_pending_metrics( - self, - session: AsyncSession - ): - builder = GetMetricsURLSAggregatedPendingQueryBuilder() - result = await builder.run( - session=session - ) - return result - - @session_manager - async def get_agencies_sync_parameters( - self, - session: AsyncSession - ) -> AgencySyncParameters: - query = select( - AgenciesSyncState.current_page, - AgenciesSyncState.current_cutoff_date - ) - try: - result = (await session.execute(query)).mappings().one() - return AgencySyncParameters( - page=result['current_page'], - cutoff_date=result['current_cutoff_date'] - ) - except NoResultFound: - # Add value - state = AgenciesSyncState() - session.add(state) - return AgencySyncParameters(page=None, cutoff_date=None) - - - - async def upsert_agencies( - self, - agencies: list[AgenciesSyncResponseInnerInfo] - ): - await self.bulk_upsert( - model=Agency, - mappings=get_upsert_agencies_mappings(agencies), - id_value="agency_id", - ) - - async def update_agencies_sync_progress(self, page: int): - query = update( - AgenciesSyncState - ).values( - current_page=page - ) - await self.execute(query) - - async def mark_full_agencies_sync(self): - query = update( - AgenciesSyncState - ).values( - last_full_sync_at=func.now(), - current_cutoff_date=func.now() - text('interval \'1 day\''), - current_page=None - ) - await self.execute(query) + async def get_urls_aggregated_pending_metrics(self): + return await self.run_query_builder(GetMetricsURLSAggregatedPendingQueryBuilder()) @session_manager async def get_html_for_url( @@ -1638,10 +964,40 @@ async def add_raw_html( self, session: AsyncSession, info_list: list[RawHTMLInfo] - ): + ) -> None: for info in info_list: compressed_html = URLCompressedHTML( url_id=info.url_id, compressed_html=compress_html(info.html) ) session.add(compressed_html) + + async def set_hugging_face_upload_state(self, dt: datetime) -> None: + await self.run_query_builder( + SetHuggingFaceUploadStateQueryBuilder(dt=dt) + ) + + async def get_current_database_time(self) -> datetime: + return await self.scalar(select(func.now())) + + async def get_location_id( + self, + us_state_id: int, + county_id: int | None = None, + locality_id: 
int | None = None + ) -> int | None: + return await self.run_query_builder( + GetLocationQueryBuilder( + us_state_id=us_state_id, + county_id=county_id, + locality_id=locality_id + ) + ) + + async def refresh_materialized_views(self): + await self.execute( + text("REFRESH MATERIALIZED VIEW url_status_mat_view") + ) + await self.execute( + text("REFRESH MATERIALIZED VIEW batch_url_status_mat_view") + ) \ No newline at end of file diff --git a/src/db/client/sync.py b/src/db/client/sync.py index 8ec13085..006d6f0e 100644 --- a/src/db/client/sync.py +++ b/src/db/client/sync.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Optional, List +from typing import List from sqlalchemy import create_engine, update, Select from sqlalchemy.exc import IntegrityError @@ -7,27 +7,27 @@ from src.collectors.enums import URLStatus from src.db.config_manager import ConfigManager -from src.db.dtos.batch import BatchInfo -from src.db.dtos.duplicate import DuplicateInsertInfo +from src.db.models.impl.batch.pydantic.info import BatchInfo +from src.db.models.impl.duplicate.pydantic.insert import DuplicateInsertInfo from src.db.dtos.url.insert import InsertURLsInfo -from src.db.dtos.log import LogInfo -from src.db.dtos.url.core import URLInfo +from src.db.models.impl.log.pydantic.info import LogInfo from src.db.dtos.url.mapping import URLMapping -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.templates import Base -from src.db.models.instantiations.duplicate import Duplicate -from src.db.models.instantiations.log import Log -from src.db.models.instantiations.url.data_source import URLDataSource -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.batch import Batch -from src.core.tasks.url.operators.submit_approved_url.tdo import SubmittedURLInfo +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.db.models.templates_.base import Base +from src.db.models.impl.duplicate.sqlalchemy import Duplicate +from src.db.models.impl.log.sqlalchemy import Log +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.batch.sqlalchemy import Batch +from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo from src.core.env_var_manager import EnvVarManager from src.core.enums import BatchStatus # Database Client class DatabaseClient: - def __init__(self, db_url: Optional[str] = None): + def __init__(self, db_url: str | None = None): """Initialize the DatabaseClient.""" if db_url is None: db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) @@ -58,6 +58,11 @@ def wrapper(self, *args, **kwargs): return wrapper + @session_manager + def add_all(self, session: Session, objects: list[Base]): + session.add_all(objects) + session.commit() + @session_manager def insert_batch(self, session: Session, batch_info: BatchInfo) -> int: """Insert a new batch into the database and return its ID.""" @@ -67,11 +72,6 @@ def insert_batch(self, session: Session, batch_info: BatchInfo) -> int: status=batch_info.status.value, parameters=batch_info.parameters, compute_time=batch_info.compute_time, - strategy_success_rate=0, - metadata_success_rate=0, - agency_match_rate=0, - record_type_match_rate=0, - record_category_match_rate=0, ) if batch_info.date_generated is not None: batch.date_generated = batch_info.date_generated @@ -99,7 
+99,7 @@ def insert_duplicates(
 ):
     for duplicate_info in duplicate_infos:
         duplicate = Duplicate(
-            batch_id=duplicate_info.duplicate_batch_id,
+            batch_id=duplicate_info.batch_id,
             original_url_id=duplicate_info.original_url_id,
         )
         session.add(duplicate)
@@ -119,19 +119,21 @@ def insert_url(self, session, url_info: URLInfo) -> int:
     url_entry = URL(
         url=url_info.url,
         collector_metadata=url_info.collector_metadata,
-        outcome=url_info.outcome.value,
-        name=url_info.name
+        status=url_info.status,
+        name=url_info.name,
+        source=url_info.source
     )
     if url_info.created_at is not None:
         url_entry.created_at = url_info.created_at
     session.add(url_entry)
     session.commit()
     session.refresh(url_entry)
-    link = LinkBatchURL(
-        batch_id=url_info.batch_id,
-        url_id=url_entry.id
-    )
-    session.add(link)
+    if url_info.batch_id is not None:
+        link = LinkBatchURL(
+            batch_id=url_info.batch_id,
+            url_id=url_entry.id
+        )
+        session.add(link)
     return url_entry.id

 def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo:
@@ -142,10 +144,10 @@ def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo
     try:
         url_id = self.insert_url(url_info)
         url_mappings.append(URLMapping(url_id=url_id, url=url_info.url))
-    except IntegrityError:
+    except IntegrityError as e:
         orig_url_info = self.get_url_info_by_url(url_info.url)
         duplicate_info = DuplicateInsertInfo(
-            duplicate_batch_id=batch_id,
+            batch_id=batch_id,
             original_url_id=orig_url_info.id
         )
         duplicates.append(duplicate_info)
@@ -219,14 +221,6 @@ def mark_urls_as_submitted(
     url_id = info.url_id
     data_source_id = info.data_source_id

-    query = (
-        update(URL)
-        .where(URL.id == url_id)
-        .values(
-            outcome=URLStatus.SUBMITTED.value
-        )
-    )
-
     url_data_source_object = URLDataSource(
         url_id=url_id,
         data_source_id=data_source_id
@@ -235,7 +229,6 @@ def mark_urls_as_submitted(
         url_data_source_object.created_at = info.submitted_at
     session.add(url_data_source_object)

-    session.execute(query)

 if __name__ == "__main__":
     client = DatabaseClient()
diff --git a/src/db/client/types.py b/src/db/client/types.py
index 5ee28c10..ffce5621 100644
--- a/src/db/client/types.py
+++ b/src/db/client/types.py
@@ -1,9 +1,5 @@
-from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion
-from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion
-from src.db.models.instantiations.url.suggestion.record_type.auto import AutoRecordTypeSuggestion
-from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion
-from src.db.models.instantiations.url.suggestion.relevant.auto import AutoRelevantSuggestion
-from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion
+from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion
+from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion
+from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion

-UserSuggestionModel = UserRelevantSuggestion or UserRecordTypeSuggestion or UserUrlAgencySuggestion
-AutoSuggestionModel = AutoRelevantSuggestion or AutoRecordTypeSuggestion or AutomatedUrlAgencySuggestion
+UserSuggestionModel = UserURLTypeSuggestion | UserRecordTypeSuggestion | UserUrlAgencySuggestion
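`UserSuggestionModel` is a PEP 604 union over the three user-suggestion models, so one annotation can accept any of them (chaining the classes with `or` would instead evaluate to the first class alone). A minimal sketch of the intended use, as a hypothetical caller that is not part of this diff (the `url_id` attribute is assumed):

    def suggestion_summary(model: UserSuggestionModel) -> str:
        # Accepts any of the three user-suggestion models.
        return f"{type(model).__name__} for URL {model.url_id}"

diff --git a/src/db/constants.py b/src/db/constants.py
index 80cbcd93..a3574a96 100644
--- a/src/db/constants.py
+++ b/src/db/constants.py
@@ -1,25 +1,13 @@
-from src.db.models.instantiations.url.suggestion.agency.auto import 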
AutomatedUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.auto import AutoRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.auto import AutoRelevantSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion PLACEHOLDER_AGENCY_NAME = "PLACEHOLDER_AGENCY_NAME" STANDARD_ROW_LIMIT = 100 -ALL_ANNOTATION_MODELS = [ - AutoRecordTypeSuggestion, - AutoRelevantSuggestion, - AutomatedUrlAgencySuggestion, - UserRelevantSuggestion, - UserRecordTypeSuggestion, - UserUrlAgencySuggestion -] - USER_ANNOTATION_MODELS = [ - UserRelevantSuggestion, + UserURLTypeSuggestion, UserRecordTypeSuggestion, UserUrlAgencySuggestion ] \ No newline at end of file diff --git a/src/db/dto_converter.py b/src/db/dto_converter.py index 5397c803..f0c9b097 100644 --- a/src/db/dto_converter.py +++ b/src/db/dto_converter.py @@ -1,24 +1,23 @@ -from typing import Optional +from collections import Counter from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAgencyInfo from src.api.endpoints.annotate.relevance.get.dto import RelevanceAnnotationResponseInfo from src.api.endpoints.review.next.dto import FinalReviewAnnotationRelevantInfo, FinalReviewAnnotationRecordTypeInfo, \ - FinalReviewAnnotationAgencyAutoInfo, FinalReviewAnnotationAgencyInfo + FinalReviewAnnotationAgencyInfo from src.core.enums import RecordType, SuggestionType -from src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo -from src.core.tasks.url.operators.url_html.scraper.parser.mapping import ENUM_TO_ATTRIBUTE_MAPPING -from src.db.dtos.url.html_content import HTMLContentType, URLHTMLContentInfo -from src.db.dtos.url.core import URLInfo +from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo +from src.core.tasks.url.operators.html.scraper.parser.mapping import ENUM_TO_ATTRIBUTE_MAPPING +from src.db.dtos.url.html_content import URLHTMLContentInfo from src.db.dtos.url.with_html import URLWithHTML -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.auto import AutoRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.html_content import URLHTMLContent -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.auto import AutoRelevantSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.html.content.enums import HTMLContentType +from src.db.models.impl.url.html.content.sqlalchemy 
import URLHTMLContent +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.auto.sqlalchemy import AutoRelevantSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion class DTOConverter: @@ -29,7 +28,7 @@ class DTOConverter: @staticmethod def final_review_annotation_relevant_info( - user_suggestion: UserRelevantSuggestion, + user_suggestions: list[UserURLTypeSuggestion], auto_suggestion: AutoRelevantSuggestion ) -> FinalReviewAnnotationRelevantInfo: @@ -39,15 +38,17 @@ def final_review_annotation_relevant_info( model_name=auto_suggestion.model_name ) if auto_suggestion else None - user_value = user_suggestion.suggested_status if user_suggestion else None + + user_types = [suggestion.type for suggestion in user_suggestions] + counter = Counter(user_types) return FinalReviewAnnotationRelevantInfo( auto=auto_value, - user=user_value + user=dict(counter) ) @staticmethod def final_review_annotation_record_type_info( - user_suggestion: UserRecordTypeSuggestion, + user_suggestions: list[UserRecordTypeSuggestion], auto_suggestion: AutoRecordTypeSuggestion ): @@ -55,121 +56,16 @@ def final_review_annotation_record_type_info( auto_value = None else: auto_value = RecordType(auto_suggestion.record_type) - if user_suggestion is None: - user_value = None - else: - user_value = RecordType(user_suggestion.record_type) + + record_types: list[RecordType] = [suggestion.record_type for suggestion in user_suggestions] + counter = Counter(record_types) + user_value = dict(counter) return FinalReviewAnnotationRecordTypeInfo( auto=auto_value, user=user_value ) - @staticmethod - def final_review_annotation_agency_auto_info( - automated_agency_suggestions: list[AutomatedUrlAgencySuggestion] - ) -> FinalReviewAnnotationAgencyAutoInfo: - - if len(automated_agency_suggestions) == 0: - return FinalReviewAnnotationAgencyAutoInfo( - unknown=True, - suggestions=[] - ) - - if len(automated_agency_suggestions) == 1: - suggestion = automated_agency_suggestions[0] - unknown = suggestion.is_unknown - else: - unknown = False - - if unknown: - return FinalReviewAnnotationAgencyAutoInfo( - unknown=True, - suggestions=[ - GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.UNKNOWN, - ) - ] - ) - - return FinalReviewAnnotationAgencyAutoInfo( - unknown=unknown, - suggestions=[ - GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.AUTO_SUGGESTION, - pdap_agency_id=suggestion.agency_id, - agency_name=suggestion.agency.name, - state=suggestion.agency.state, - county=suggestion.agency.county, - locality=suggestion.agency.locality - ) for suggestion in automated_agency_suggestions - ] - ) - - @staticmethod - def user_url_agency_suggestion_to_final_review_annotation_agency_user_info( - user_url_agency_suggestion: UserUrlAgencySuggestion - ) -> Optional[GetNextURLForAgencyAgencyInfo]: - suggestion = user_url_agency_suggestion - if suggestion is None: - return None - if suggestion.is_new: - return GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.NEW_AGENCY, - ) - return GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.USER_SUGGESTION, - pdap_agency_id=suggestion.agency_id, - agency_name=suggestion.agency.name, - state=suggestion.agency.state, - county=suggestion.agency.county, - 
locality=suggestion.agency.locality - ) - - - @staticmethod - def confirmed_agencies_to_final_review_annotation_agency_info( - confirmed_agencies: list[ConfirmedURLAgency] - ) -> list[GetNextURLForAgencyAgencyInfo]: - results = [] - for confirmed_agency in confirmed_agencies: - agency = confirmed_agency.agency - agency_info = GetNextURLForAgencyAgencyInfo( - suggestion_type=SuggestionType.CONFIRMED, - pdap_agency_id=agency.agency_id, - agency_name=agency.name, - state=agency.state, - county=agency.county, - locality=agency.locality - ) - results.append(agency_info) - return results - - - @staticmethod - def final_review_annotation_agency_info( - automated_agency_suggestions: list[AutomatedUrlAgencySuggestion], - confirmed_agencies: list[ConfirmedURLAgency], - user_agency_suggestion: UserUrlAgencySuggestion - ): - - confirmed_agency_info = DTOConverter.confirmed_agencies_to_final_review_annotation_agency_info( - confirmed_agencies - ) - - agency_auto_info = DTOConverter.final_review_annotation_agency_auto_info( - automated_agency_suggestions - ) - - agency_user_info = DTOConverter.user_url_agency_suggestion_to_final_review_annotation_agency_user_info( - user_agency_suggestion - ) - - return FinalReviewAnnotationAgencyInfo( - confirmed=confirmed_agency_info, - user=agency_user_info, - auto=agency_auto_info - ) @staticmethod diff --git a/src/db/dtos/batch.py b/src/db/dtos/batch.py deleted file mode 100644 index 3e1d265b..00000000 --- a/src/db/dtos/batch.py +++ /dev/null @@ -1,17 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel - -from src.core.enums import BatchStatus - - -class BatchInfo(BaseModel): - id: Optional[int] = None - strategy: str - status: BatchStatus - parameters: dict - user_id: int - total_url_count: Optional[int] = None - compute_time: Optional[float] = None - date_generated: Optional[datetime] = None diff --git a/src/db/dtos/duplicate.py b/src/db/dtos/duplicate.py deleted file mode 100644 index d978f91e..00000000 --- a/src/db/dtos/duplicate.py +++ /dev/null @@ -1,12 +0,0 @@ -from pydantic import BaseModel - - -class DuplicateInsertInfo(BaseModel): - original_url_id: int - duplicate_batch_id: int - -class DuplicateInfo(DuplicateInsertInfo): - source_url: str - original_batch_id: int - duplicate_metadata: dict - original_metadata: dict \ No newline at end of file diff --git a/src/db/dtos/log.py b/src/db/dtos/log.py deleted file mode 100644 index 43ed1cec..00000000 --- a/src/db/dtos/log.py +++ /dev/null @@ -1,16 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel - - -class LogInfo(BaseModel): - id: Optional[int] = None - log: str - batch_id: int - created_at: Optional[datetime] = None - -class LogOutputInfo(BaseModel): - id: Optional[int] = None - log: str - created_at: Optional[datetime] = None \ No newline at end of file diff --git a/src/db/dtos/metadata_annotation.py b/src/db/dtos/metadata_annotation.py deleted file mode 100644 index 5a004cf1..00000000 --- a/src/db/dtos/metadata_annotation.py +++ /dev/null @@ -1,11 +0,0 @@ -from datetime import datetime - -from pydantic import BaseModel - - -class MetadataAnnotationInfo(BaseModel): - id: int - user_id: int - metadata_id: int - value: str - created_at: datetime diff --git a/src/db/dtos/url/core.py b/src/db/dtos/url/core.py deleted file mode 100644 index e409c32c..00000000 --- a/src/db/dtos/url/core.py +++ /dev/null @@ -1,17 +0,0 @@ -import datetime -from typing import Optional - -from pydantic import BaseModel - -from 
src.collectors.enums import URLStatus - - -class URLInfo(BaseModel): - id: Optional[int] = None - batch_id: Optional[int] = None - url: str - collector_metadata: Optional[dict] = None - outcome: URLStatus = URLStatus.PENDING - updated_at: Optional[datetime.datetime] = None - created_at: Optional[datetime.datetime] = None - name: Optional[str] = None diff --git a/src/db/dtos/url/error.py b/src/db/dtos/url/error.py deleted file mode 100644 index 46f5b9fa..00000000 --- a/src/db/dtos/url/error.py +++ /dev/null @@ -1,11 +0,0 @@ -import datetime -from typing import Optional - -from pydantic import BaseModel - - -class URLErrorPydanticInfo(BaseModel): - task_id: int - url_id: int - error: str - updated_at: Optional[datetime.datetime] = None \ No newline at end of file diff --git a/src/db/dtos/url/html_content.py b/src/db/dtos/url/html_content.py index f8b24eb0..d7fb560e 100644 --- a/src/db/dtos/url/html_content.py +++ b/src/db/dtos/url/html_content.py @@ -1,21 +1,15 @@ -from enum import Enum -from typing import Optional +from src.db.models.impl.url.html.content.enums import HTMLContentType +from src.db.models.impl.url.html.content.sqlalchemy import URLHTMLContent +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel -from pydantic import BaseModel - -class HTMLContentType(Enum): - TITLE = "Title" - DESCRIPTION = "Description" - H1 = "H1" - H2 = "H2" - H3 = "H3" - H4 = "H4" - H5 = "H5" - H6 = "H6" - DIV = "Div" - -class URLHTMLContentInfo(BaseModel): - url_id: Optional[int] = None +class URLHTMLContentInfo(BulkInsertableModel): + url_id: int | None = None content_type: HTMLContentType - content: str | list[str] \ No newline at end of file + content: str | list[str] + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return URLHTMLContent \ No newline at end of file diff --git a/src/db/dtos/url/mapping.py b/src/db/dtos/url/mapping.py index 38efbce4..d48a4649 100644 --- a/src/db/dtos/url/mapping.py +++ b/src/db/dtos/url/mapping.py @@ -1,6 +1,9 @@ -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict class URLMapping(BaseModel): + """Mapping between url and url_id.""" + model_config = ConfigDict(frozen=True) # <- makes it immutable & hashable + url: str url_id: int diff --git a/src/db/dtos/url/metadata.py b/src/db/dtos/url/metadata.py deleted file mode 100644 index acac01b8..00000000 --- a/src/db/dtos/url/metadata.py +++ /dev/null @@ -1,19 +0,0 @@ -from datetime import datetime -from typing import Optional - -from pydantic import BaseModel - -from src.db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource - - -class URLMetadataInfo(BaseModel): - id: Optional[int] = None - url_id: Optional[int] = None - attribute: Optional[URLMetadataAttributeType] = None - # TODO: May need to add validation here depending on the type of attribute - value: Optional[str] = None - notes: Optional[str] = None - validation_status: Optional[ValidationStatus] = None - validation_source: Optional[ValidationSource] = None - created_at: Optional[datetime] = None - updated_at: Optional[datetime] = None \ No newline at end of file diff --git a/src/db/enums.py b/src/db/enums.py index 0a45addd..b232c188 100644 --- a/src/db/enums.py +++ b/src/db/enums.py @@ -32,22 +32,52 @@ class URLHTMLContentType(PyEnum): DIV = "Div" class TaskType(PyEnum): + + # URL Tasks HTML = "HTML" RELEVANCY = "Relevancy" RECORD_TYPE = "Record Type" AGENCY_IDENTIFICATION = "Agency Identification" MISC_METADATA = 
"Misc Metadata" SUBMIT_APPROVED = "Submit Approved URLs" + SUBMIT_META_URLS = "Submit Meta URLs" DUPLICATE_DETECTION = "Duplicate Detection" IDLE = "Idle" - PROBE_404 = "404 Probe" + PROBE_URL = "URL Probe" + ROOT_URL = "Root URL" + IA_PROBE = "Internet Archives Probe" + IA_SAVE = "Internet Archives Archive" + SCREENSHOT = "Screenshot" + LOCATION_ID = "Location ID" + AUTO_VALIDATE = "Auto Validate" + AUTO_NAME = "Auto Name" + SUSPEND_URLS = "Suspend URLs" + + # Scheduled Tasks + PUSH_TO_HUGGINGFACE = "Push to Hugging Face" SYNC_AGENCIES = "Sync Agencies" + SYNC_DATA_SOURCES = "Sync Data Sources" + POPULATE_BACKLOG_SNAPSHOT = "Populate Backlog Snapshot" + DELETE_OLD_LOGS = "Delete Old Logs" + DELETE_STALE_SCREENSHOTS = "Delete Stale Screenshots" + MARK_TASK_NEVER_COMPLETED = "Mark Task Never Completed" + RUN_URL_TASKS = "Run URL Task Cycles" + TASK_CLEANUP = "Task Cleanup" + REFRESH_MATERIALIZED_VIEWS = "Refresh Materialized Views" + +class ChangeLogOperationType(PyEnum): + INSERT = "INSERT" + UPDATE = "UPDATE" + DELETE = "DELETE" class PGEnum(TypeDecorator): impl = postgresql.ENUM + cache_ok = True + def process_bind_param(self, value: PyEnum, dialect): # Convert Python Enum to its value before binding to the DB if isinstance(value, PyEnum): return value.value return value + diff --git a/src/db/helpers.py b/src/db/helpers.py deleted file mode 100644 index 618b2e6d..00000000 --- a/src/db/helpers.py +++ /dev/null @@ -1,5 +0,0 @@ -from src.core.env_var_manager import EnvVarManager - - -def get_postgres_connection_string(is_async = False): - return EnvVarManager.get().get_postgres_connection_string(is_async) diff --git a/src/db/helpers/__init__.py b/src/db/helpers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/helpers/connect.py b/src/db/helpers/connect.py new file mode 100644 index 00000000..2a15cba5 --- /dev/null +++ b/src/db/helpers/connect.py @@ -0,0 +1,5 @@ +from src.core.env_var_manager import EnvVarManager + + +def get_postgres_connection_string(is_async = False) -> str: + return EnvVarManager.get().get_postgres_connection_string(is_async) diff --git a/src/db/helpers/query.py b/src/db/helpers/query.py new file mode 100644 index 00000000..4375cc33 --- /dev/null +++ b/src/db/helpers/query.py @@ -0,0 +1,31 @@ +from sqlalchemy import exists, ColumnElement + +from src.db.enums import TaskType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from src.db.models.mixins import URLDependentMixin + + +def url_not_validated() -> ColumnElement[bool]: + return not_exists_url(FlagURLValidated) + +def not_exists_url( + model: type[URLDependentMixin] +) -> ColumnElement[bool]: + return ~exists().where( + model.url_id == URL.id + ) + +def exists_url( + model: type[URLDependentMixin] +) -> ColumnElement[bool]: + return exists().where( + model.url_id == URL.id + ) + +def no_url_task_error(task_type: TaskType) -> ColumnElement[bool]: + return ~exists().where( + URLTaskError.url_id == URL.id, + URLTaskError.task_type == task_type + ) \ No newline at end of file diff --git a/src/db/helpers/session/__init__.py b/src/db/helpers/session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/helpers/session/parser.py b/src/db/helpers/session/parser.py new file mode 100644 index 00000000..b580dcd1 --- /dev/null +++ b/src/db/helpers/session/parser.py @@ -0,0 +1,41 @@ +from src.db.helpers.session.types 
import BulkActionType
+from src.db.models.templates_.base import Base
+from src.db.templates.protocols.sa_correlated.core import SQLAlchemyCorrelatedProtocol
+from src.db.templates.protocols.sa_correlated.with_id import SQLAlchemyCorrelatedWithIDProtocol
+from src.db.utils.validate import validate_all_models_of_same_type
+
+
+class BulkActionParser:
+
+    def __init__(
+        self,
+        models: list[BulkActionType],
+    ):
+        validate_all_models_of_same_type(models)
+        model_class = type(models[0])
+        self.models = models
+        self.model_class = model_class
+
+    @property
+    def id_field(self) -> str:
+        if not issubclass(self.model_class, SQLAlchemyCorrelatedWithIDProtocol):
+            raise TypeError("Model must implement SQLAlchemyCorrelatedWithID protocol.")
+
+        return self.model_class.id_field()
+
+    @property
+    def sa_model(self) -> type[Base]:
+        if not issubclass(self.model_class, SQLAlchemyCorrelatedProtocol):
+            raise TypeError(f"Model {self.model_class} must implement SQLAlchemyCorrelated protocol.")
+        return self.model_class.sa_model()
+
+    def get_non_id_fields(self) -> list[str]:
+        return [
+            field for field in self.model_class.model_fields.keys()
+            if field != self.id_field
+        ]
+
+    def get_all_fields(self) -> list[str]:
+        return [
+            field for field in self.model_class.model_fields.keys()
+        ]
diff --git a/src/db/helpers/session/session_helper.py b/src/db/helpers/session/session_helper.py
new file mode 100644
index 00000000..43369ff3
--- /dev/null
+++ b/src/db/helpers/session/session_helper.py
@@ -0,0 +1,234 @@
+"""
+session_helper (aliased as sh) contains a number of convenience
+functions for working with a SQLAlchemy session
+"""
+from typing import Any, Optional, Sequence
+
+import sqlalchemy as sa
+from sqlalchemy import update, ColumnElement, Row, Select
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.dialects.postgresql import insert as pg_insert
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from src.db.helpers.session.parser import BulkActionParser
+from src.db.models.templates_.base import Base
+from src.db.models.templates_.with_id import WithIDBase
+from src.db.templates.markers.bulk.delete import BulkDeletableModel
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+from src.db.templates.markers.bulk.update import BulkUpdatableModel
+from src.db.templates.markers.bulk.upsert import BulkUpsertableModel
+from src.db.templates.protocols.has_id import HasIDProtocol
+
+
+async def one_or_none(
+    session: AsyncSession,
+    query: sa.Select
+) -> sa.Row | None:
+    raw_result = await session.execute(query)
+    return raw_result.scalars().one_or_none()
+
+async def scalar(session: AsyncSession, query: sa.Select) -> Any:
+    """Fetch the first column of the first row."""
+    raw_result = await session.execute(query)
+    return raw_result.scalar()
+
+async def scalars(session: AsyncSession, query: sa.Select) -> Any:
+    raw_result = await session.execute(query)
+    return raw_result.scalars().all()
+
+async def mapping(session: AsyncSession, query: sa.Select) -> sa.RowMapping:
+    raw_result = await session.execute(query)
+    return raw_result.mappings().one()
+
+async def mappings(session: AsyncSession, query: sa.Select) -> Sequence[sa.RowMapping]:
+    raw_result = await session.execute(query)
+    return raw_result.mappings().all()
+
+async def has_results(session: AsyncSession, query: sa.Select) -> bool:
+    raw_result = await session.execute(query)
+    return raw_result.first() is not None
+
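+# A minimal usage sketch for the fetch helpers above (illustrative comment
+# only; the `URL` model, `url_id`, and `session` are assumed to be in scope):
+#
+#     import src.db.helpers.session.session_helper as sh
+#
+#     stmt = sa.select(URL).where(URL.id == url_id)
+#     url = await sh.one_or_none(session, stmt)
+#     total = await sh.scalar(session, sa.select(sa.func.count()).select_from(URL))
+
+async def bulk_upsert(
+    session: AsyncSession,
+    models: list[BulkUpsertableModel],
+) -> None:
+    if 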
len(models) == 0:
+        return
+    # Parse models to get sa_model and id_field
+    parser = BulkActionParser(models)
+
+    # Create base insert query
+    query = pg_insert(parser.sa_model)
+
+    upsert_mappings: list[dict[str, Any]] = [
+        upsert_model.model_dump() for upsert_model in models
+    ]
+
+    # Set all but two fields to the values in the upsert mapping
+    set_ = {}
+    for k in upsert_mappings[0]:
+        if k == parser.id_field:
+            continue
+        if k == "created_at":
+            continue
+        set_[k] = getattr(query.excluded, k)
+
+    # Add upsert logic to update on conflict
+    query = query.on_conflict_do_update(
+        index_elements=[parser.id_field],
+        set_=set_
+    )
+
+    # Note, mapping must include primary key
+    await session.execute(
+        statement=query,
+        params=upsert_mappings
+    )
+
+async def add(
+    session: AsyncSession,
+    model: Base,
+    return_id: bool = False
+) -> int | None:
+    session.add(model)
+    if return_id:
+        if not isinstance(model, HasIDProtocol):
+            raise AttributeError("Models must have an id attribute")
+        await session.flush()
+        return model.id
+    return None
+
+
+async def add_all(
+    session: AsyncSession,
+    models: list[WithIDBase],
+    return_ids: bool = False
+) -> list[int] | None:
+    session.add_all(models)
+    if return_ids:
+        if not isinstance(models[0], HasIDProtocol):
+            raise AttributeError("Models must have an id attribute")
+        await session.flush()
+        return [
+            model.id  # pyright: ignore [reportAttributeAccessIssue]
+            for model in models
+        ]
+    return None
+
+async def get_all(
+    session: AsyncSession,
+    model: Base,
+    order_by_attribute: Optional[str] = None
+) -> Sequence[Row]:
+    """
+    Get all records of a model
+    Used primarily in testing
+    """
+    statement = sa.select(model)
+    if order_by_attribute:
+        statement = statement.order_by(getattr(model, order_by_attribute))
+    result = await session.execute(statement)
+    return result.scalars().all()
+
+def compile_to_sql(statement) -> str:
+    compiled_sql = statement.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})
+    return str(compiled_sql)
+
+
+async def bulk_delete(session: AsyncSession, models: list[BulkDeletableModel]):
+    """Bulk delete sqlalchemy models of the same type."""
+    if len(models) == 0:
+        return
+
+    parser = BulkActionParser(models)
+
+    # Use declared field names from the model (excludes properties/methods)
+    field_names = parser.get_all_fields()
+
+    sa_model = parser.sa_model
+
+    # Get value tuples to be used in identifying attributes for bulk delete
+    value_tuples = []
+    for model in models:
+        tup = tuple(getattr(model, field) for field in field_names)
+        value_tuples.append(tup)
+
+    statement = (
+        sa.delete(
+            sa_model
+        ).where(
+            sa.tuple_(
+                *[
+                    getattr(sa_model, attr)
+                    for attr in field_names
+                ]
+            ).in_(value_tuples)
+        )
+    )
+
+    await session.execute(statement)
+
+async def bulk_insert(
+    session: AsyncSession,
+    models: list[BulkInsertableModel],
+    return_ids: bool = False
+) -> list[int] | None:
+    """Bulk insert sqlalchemy models via their pydantic counterparts."""
+
+    if len(models) == 0:
+        return None
+
+    parser = BulkActionParser(models)
+    sa_model = parser.sa_model
+
+    models_to_add = []
+    for model in models:
+        sa_model_instance = sa_model(**model.model_dump())
+        models_to_add.append(sa_model_instance)
+
+    return await add_all(
+        session=session,
+        models=models_to_add,
+        return_ids=return_ids
+    )
+
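+# A minimal sketch of the bulk helpers (illustrative comment only;
+# DuplicateInsertInfo is one BulkInsertableModel defined elsewhere in this
+# changeset, and `session` is assumed to be an AsyncSession):
+#
+#     infos = [
+#         DuplicateInsertInfo(original_url_id=1, batch_id=10),
+#         DuplicateInsertInfo(original_url_id=2, batch_id=10),
+#     ]
+#     await bulk_insert(session, infos)  # expands each model into a Duplicate row
+
+async def results_exist(
+    session: AsyncSession,
+    query: Select
+) -> bool:
+    query = query.limit(1)
+    result: sa.Row | None = await one_or_none(session=session, query=query)
+    return result is 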
not None + +async def bulk_update( + session: AsyncSession, + models: list[BulkUpdatableModel], +): + """Bulk update sqlalchemy models via their pydantic counterparts.""" + if len(models) == 0: + return + + parser = BulkActionParser(models) + + sa_model = parser.sa_model + id_field = parser.id_field + update_fields = parser.get_non_id_fields() + + + for model in models: + update_values = { + k: getattr(model, k) + for k in update_fields + } + id_value = getattr(model, id_field) + id_attr: ColumnElement = getattr(sa_model, id_field) + stmt = ( + update(sa_model) + .where( + id_attr == id_value + ) + .values(**update_values) + ) + await session.execute(stmt) + diff --git a/src/db/helpers/session/types.py b/src/db/helpers/session/types.py new file mode 100644 index 00000000..b960b76c --- /dev/null +++ b/src/db/helpers/session/types.py @@ -0,0 +1,8 @@ +from src.db.templates.markers.bulk.delete import BulkDeletableModel +from src.db.templates.markers.bulk.insert import BulkInsertableModel +from src.db.templates.markers.bulk.update import BulkUpdatableModel +from src.db.templates.markers.bulk.upsert import BulkUpsertableModel + +BulkActionType = ( + BulkInsertableModel | BulkUpdatableModel | BulkDeletableModel | BulkUpsertableModel +) diff --git a/src/db/models/exceptions.py b/src/db/models/exceptions.py new file mode 100644 index 00000000..491aa9a4 --- /dev/null +++ b/src/db/models/exceptions.py @@ -0,0 +1,4 @@ + + +class WriteToViewError(Exception): + pass \ No newline at end of file diff --git a/src/db/models/helpers.py b/src/db/models/helpers.py index f72f06ba..f547e8d4 100644 --- a/src/db/models/helpers.py +++ b/src/db/models/helpers.py @@ -1,13 +1,13 @@ -from sqlalchemy import Column, TIMESTAMP, func, Integer, ForeignKey +from sqlalchemy import Column, TIMESTAMP, func, Integer, ForeignKey, Enum as SAEnum, PrimaryKeyConstraint +from enum import Enum as PyEnum - -def get_created_at_column(): +def get_created_at_column() -> Column: return Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) def get_agency_id_foreign_column( nullable: bool = False -): +) -> Column: return Column( 'agency_id', Integer(), @@ -15,4 +15,57 @@ def get_agency_id_foreign_column( nullable=nullable ) +def enum_column( + enum_type: type[PyEnum], + name: str, + nullable: bool = False +) -> Column[SAEnum]: + return Column( + SAEnum( + enum_type, + name=name, + native_enum=True, + values_callable=lambda enum_type: [e.value for e in enum_type] + ), + nullable=nullable + ) + +def url_id_column() -> Column[int]: + return Column( + Integer(), + ForeignKey('urls.id', ondelete='CASCADE'), + nullable=False + ) + +def location_id_column() -> Column[int]: + return Column( + Integer(), + ForeignKey('locations.id', ondelete='CASCADE'), + nullable=False + ) + CURRENT_TIME_SERVER_DEFAULT = func.now() + +def url_id_primary_key_constraint() -> PrimaryKeyConstraint: + return PrimaryKeyConstraint('url_id') + +def county_column(nullable: bool = False) -> Column[int]: + return Column( + Integer(), + ForeignKey('counties.id', ondelete='CASCADE'), + nullable=nullable + ) + +def locality_column(nullable: bool = False) -> Column[int]: + return Column( + Integer(), + ForeignKey('localities.id', ondelete='CASCADE'), + nullable=nullable + ) + +def us_state_column(nullable: bool = False) -> Column[int]: + return Column( + Integer(), + ForeignKey('us_states.id', ondelete='CASCADE'), + nullable=nullable + ) \ No newline at end of file diff --git a/src/db/models/impl/__init__.py b/src/db/models/impl/__init__.py new file mode 
100644 index 00000000..e69de29b diff --git a/src/db/models/impl/agency/__init__.py b/src/db/models/impl/agency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/agency/enums.py b/src/db/models/impl/agency/enums.py new file mode 100644 index 00000000..80ed9780 --- /dev/null +++ b/src/db/models/impl/agency/enums.py @@ -0,0 +1,19 @@ +from enum import Enum + + +class AgencyType(Enum): + UNKNOWN = "unknown" + INCARCERATION = "incarceration" + LAW_ENFORCEMENT = "law enforcement" + COURT = "court" + AGGREGATED = "aggregated" + +class JurisdictionType(Enum): + SCHOOL = "school" + COUNTY = "county" + LOCAL = "local" + PORT = "port" + TRIBAL = "tribal" + TRANSIT = "transit" + STATE = "state" + FEDERAL = "federal" \ No newline at end of file diff --git a/src/db/models/impl/agency/pydantic/__init__.py b/src/db/models/impl/agency/pydantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/agency/pydantic/upsert.py b/src/db/models/impl/agency/pydantic/upsert.py new file mode 100644 index 00000000..099e8451 --- /dev/null +++ b/src/db/models/impl/agency/pydantic/upsert.py @@ -0,0 +1,22 @@ +from datetime import datetime + +from src.db.models.impl.agency.sqlalchemy import Agency +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.upsert import BulkUpsertableModel + + +class AgencyUpsertModel(BulkUpsertableModel): + + @classmethod + def id_field(cls) -> str: + return "agency_id" + + @classmethod + def sa_model(cls) -> type[Base]: + return Agency + + agency_id: int + name: str + state: str | None + county: str | None + locality: str | None diff --git a/src/db/models/impl/agency/sqlalchemy.py b/src/db/models/impl/agency/sqlalchemy.py new file mode 100644 index 00000000..002b0255 --- /dev/null +++ b/src/db/models/impl/agency/sqlalchemy.py @@ -0,0 +1,35 @@ +""" +References an agency in the data sources database. 
+""" + +from sqlalchemy import Column, Integer, String, DateTime +from sqlalchemy.orm import relationship + +from src.db.models.helpers import enum_column +from src.db.models.impl.agency.enums import AgencyType, JurisdictionType +from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin +from src.db.models.templates_.with_id import WithIDBase + + +class Agency( + CreatedAtMixin, # When agency was added to database + UpdatedAtMixin, # When agency was last updated in database + WithIDBase +): + __tablename__ = "agencies" + + # TODO: Rename agency_id to ds_agency_id + + agency_id = Column(Integer, primary_key=True) + name = Column(String, nullable=False) + agency_type = enum_column(AgencyType, name="agency_type_enum") + jurisdiction_type = enum_column( + JurisdictionType, + name="jurisdiction_type_enum", + nullable=True, + ) + + # Relationships + automated_suggestions = relationship("AgencyIDSubtaskSuggestion") + user_suggestions = relationship("UserUrlAgencySuggestion", back_populates="agency") + confirmed_urls = relationship("LinkURLAgency", back_populates="agency") diff --git a/src/db/models/impl/agency/suggestion/__init__.py b/src/db/models/impl/agency/suggestion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/agency/suggestion/pydantic.py b/src/db/models/impl/agency/suggestion/pydantic.py new file mode 100644 index 00000000..84046717 --- /dev/null +++ b/src/db/models/impl/agency/suggestion/pydantic.py @@ -0,0 +1,17 @@ +from pydantic import BaseModel + +from src.db.models.impl.agency.enums import JurisdictionType, AgencyType +from src.db.models.impl.agency.suggestion.sqlalchemy import NewAgencySuggestion +from src.db.models.templates_.base import Base + + +class NewAgencySuggestionPydantic(BaseModel): + + name: str + location_id: int + jurisdiction_type: JurisdictionType | None + agency_type: AgencyType | None + + @classmethod + def sa_model(cls) -> type[Base]: + return NewAgencySuggestion \ No newline at end of file diff --git a/src/db/models/impl/agency/suggestion/sqlalchemy.py b/src/db/models/impl/agency/suggestion/sqlalchemy.py new file mode 100644 index 00000000..f15b2ef0 --- /dev/null +++ b/src/db/models/impl/agency/suggestion/sqlalchemy.py @@ -0,0 +1,19 @@ +from sqlalchemy import String, Column + +from src.db.models.helpers import enum_column +from src.db.models.impl.agency.enums import JurisdictionType, AgencyType +from src.db.models.mixins import CreatedAtMixin, LocationDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class NewAgencySuggestion( + WithIDBase, + CreatedAtMixin, + LocationDependentMixin, +): + + __tablename__ = 'new_agency_suggestions' + + name = Column(String) + jurisdiction_type = enum_column(JurisdictionType, name='jurisdiction_type_enum', nullable=True) + agency_type = enum_column(AgencyType, name='agency_type_enum', nullable=True) \ No newline at end of file diff --git a/src/db/models/impl/backlog_snapshot.py b/src/db/models/impl/backlog_snapshot.py new file mode 100644 index 00000000..6b0982cd --- /dev/null +++ b/src/db/models/impl/backlog_snapshot.py @@ -0,0 +1,10 @@ +from sqlalchemy import Column, Integer + +from src.db.models.mixins import CreatedAtMixin +from src.db.models.templates_.with_id import WithIDBase + + +class BacklogSnapshot(CreatedAtMixin, WithIDBase): + __tablename__ = "backlog_snapshot" + + count_pending_total = Column(Integer, nullable=False) diff --git a/src/db/models/impl/batch/__init__.py b/src/db/models/impl/batch/__init__.py new file mode 100644 index 
00000000..e69de29b diff --git a/src/db/models/impl/batch/pydantic/__init__.py b/src/db/models/impl/batch/pydantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/batch/pydantic/info.py b/src/db/models/impl/batch/pydantic/info.py new file mode 100644 index 00000000..3272ceef --- /dev/null +++ b/src/db/models/impl/batch/pydantic/info.py @@ -0,0 +1,17 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel + +from src.core.enums import BatchStatus + + +class BatchInfo(BaseModel): + id: int | None = None + strategy: str + status: BatchStatus + parameters: dict + user_id: int + total_url_count: int | None = None + compute_time: float | None = None + date_generated: datetime | None = None diff --git a/src/db/models/impl/batch/pydantic/insert.py b/src/db/models/impl/batch/pydantic/insert.py new file mode 100644 index 00000000..882ab371 --- /dev/null +++ b/src/db/models/impl/batch/pydantic/insert.py @@ -0,0 +1,17 @@ +from datetime import datetime + +from src.core.enums import BatchStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class BatchInsertModel(BulkInsertableModel): + strategy: str + status: BatchStatus + parameters: dict + user_id: int + date_generated: datetime + + @classmethod + def sa_model(cls) -> type[Batch]: + return Batch \ No newline at end of file diff --git a/src/db/models/impl/batch/sqlalchemy.py b/src/db/models/impl/batch/sqlalchemy.py new file mode 100644 index 00000000..564ce163 --- /dev/null +++ b/src/db/models/impl/batch/sqlalchemy.py @@ -0,0 +1,50 @@ +from sqlalchemy import Column, Integer, TIMESTAMP, Float, JSON +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import relationship + +from src.db.models.helpers import CURRENT_TIME_SERVER_DEFAULT +from src.db.models.impl.log.sqlalchemy import Log +from src.db.models.templates_.with_id import WithIDBase +from src.db.models.types import batch_status_enum + + +class Batch(WithIDBase): + __tablename__ = 'batches' + + strategy = Column( + postgresql.ENUM( + 'example', + 'ckan', + 'muckrock_county_search', + 'auto_googler', + 'muckrock_all_search', + 'muckrock_simple_search', + 'common_crawler', + 'manual', + name='batch_strategy'), + nullable=False) + user_id = Column(Integer, nullable=False) + # Gives the status of the batch + status = Column( + batch_status_enum, + nullable=False + ) + date_generated = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) + + # Time taken to generate the batch + # TODO: Add means to update after execution + compute_time = Column(Float) + # The parameters used to generate the batch + parameters = Column(JSON) + + # Relationships + urls = relationship( + "URL", + secondary="link_batch_urls", + back_populates="batch", + overlaps="url" + ) + # These relationships exist but are never referenced by their attributes + # missings = relationship("Missing", back_populates="batch") + logs = relationship(Log, back_populates="batch") + duplicates = relationship("Duplicate", back_populates="batch") diff --git a/src/db/models/impl/change_log.py b/src/db/models/impl/change_log.py new file mode 100644 index 00000000..0cb74659 --- /dev/null +++ b/src/db/models/impl/change_log.py @@ -0,0 +1,19 @@ + +from sqlalchemy import Column, Enum +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped + +from src.db.enums import ChangeLogOperationType +from src.db.models.mixins import 
diff --git a/src/db/models/impl/change_log.py b/src/db/models/impl/change_log.py
new file mode 100644
index 00000000..0cb74659
--- /dev/null
+++ b/src/db/models/impl/change_log.py
@@ -0,0 +1,19 @@
+
+from sqlalchemy import Column, Enum
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.orm import Mapped
+
+from src.db.enums import ChangeLogOperationType
+from src.db.models.mixins import CreatedAtMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class ChangeLog(CreatedAtMixin, WithIDBase):
+
+    __tablename__ = "change_log"
+
+    operation_type = Column(Enum(ChangeLogOperationType, name="operation_type"))
+    table_name: Mapped[str]
+    affected_id: Mapped[int]
+    old_data = Column("old_data", JSONB, nullable=True)
+    new_data = Column("new_data", JSONB, nullable=True)
diff --git a/src/db/models/impl/duplicate/__init__.py b/src/db/models/impl/duplicate/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/duplicate/pydantic/__init__.py b/src/db/models/impl/duplicate/pydantic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/duplicate/pydantic/info.py b/src/db/models/impl/duplicate/pydantic/info.py
new file mode 100644
index 00000000..627f5d54
--- /dev/null
+++ b/src/db/models/impl/duplicate/pydantic/info.py
@@ -0,0 +1,8 @@
+from src.db.models.impl.duplicate.pydantic.insert import DuplicateInsertInfo
+
+
+class DuplicateInfo(DuplicateInsertInfo):
+    source_url: str
+    original_batch_id: int
+    duplicate_metadata: dict
+    original_metadata: dict
diff --git a/src/db/models/impl/duplicate/pydantic/insert.py b/src/db/models/impl/duplicate/pydantic/insert.py
new file mode 100644
index 00000000..7de4974a
--- /dev/null
+++ b/src/db/models/impl/duplicate/pydantic/insert.py
@@ -0,0 +1,11 @@
+from src.db.models.impl.duplicate.sqlalchemy import Duplicate
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class DuplicateInsertInfo(BulkInsertableModel):
+    original_url_id: int
+    batch_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[Duplicate]:
+        return Duplicate
\ No newline at end of file
diff --git a/src/db/models/impl/duplicate/sqlalchemy.py b/src/db/models/impl/duplicate/sqlalchemy.py
new file mode 100644
index 00000000..03c492e3
--- /dev/null
+++ b/src/db/models/impl/duplicate/sqlalchemy.py
@@ -0,0 +1,23 @@
+from sqlalchemy import Column, Integer, ForeignKey
+from sqlalchemy.orm import relationship
+
+from src.db.models.mixins import BatchDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class Duplicate(BatchDependentMixin, WithIDBase):
+    """
+    Identifies duplicates which occur within a batch
+    """
+    __tablename__ = 'duplicates'
+
+    original_url_id = Column(
+        Integer,
+        ForeignKey('urls.id'),
+        nullable=False,
+        doc="The original URL ID"
+    )
+
+    # Relationships
+    batch = relationship("Batch", back_populates="duplicates")
+    original_url = relationship("URL", back_populates="duplicates")
diff --git a/src/db/models/impl/flag/__init__.py b/src/db/models/impl/flag/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/auto_validated/__init__.py b/src/db/models/impl/flag/auto_validated/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/auto_validated/pydantic.py b/src/db/models/impl/flag/auto_validated/pydantic.py
new file mode 100644
index 00000000..da1efb7b
--- /dev/null
+++ b/src/db/models/impl/flag/auto_validated/pydantic.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+
+from src.db.models.impl.flag.auto_validated.sqlalchemy import FlagURLAutoValidated
+
+
+class FlagURLAutoValidatedPydantic(BaseModel):
+
+    url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[FlagURLAutoValidated]:
+        return FlagURLAutoValidated
\ No newline at end of file
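The duplicates table above records within-batch duplicate hits against the original URL row. A sketch of how it might be queried, assuming BatchDependentMixin contributes the batch_id column its name implies:

    from sqlalchemy import select, func
    from sqlalchemy.orm import Session

    from src.db.models.impl.duplicate.sqlalchemy import Duplicate

    def count_duplicates_in_batch(session: Session, batch_id: int) -> int:
        # batch_id is assumed to come from BatchDependentMixin.
        stmt = (
            select(func.count())
            .select_from(Duplicate)
            .where(Duplicate.batch_id == batch_id)
        )
        return session.execute(stmt).scalar_one()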
diff --git a/src/db/models/impl/flag/auto_validated/sqlalchemy.py b/src/db/models/impl/flag/auto_validated/sqlalchemy.py
new file mode 100644
index 00000000..a0ce02b9
--- /dev/null
+++ b/src/db/models/impl/flag/auto_validated/sqlalchemy.py
@@ -0,0 +1,18 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class FlagURLAutoValidated(
+    Base,
+    URLDependentMixin,
+    CreatedAtMixin
+):
+
+    __tablename__ = 'flag_url_auto_validated'
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            "url_id"
+        ),
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/flag/checked_for_ia/__init__.py b/src/db/models/impl/flag/checked_for_ia/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/checked_for_ia/pydantic.py b/src/db/models/impl/flag/checked_for_ia/pydantic.py
new file mode 100644
index 00000000..5b801f6d
--- /dev/null
+++ b/src/db/models/impl/flag/checked_for_ia/pydantic.py
@@ -0,0 +1,11 @@
+from src.db.models.impl.flag.checked_for_ia.sqlalchemy import FlagURLCheckedForInternetArchives
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class FlagURLCheckedForInternetArchivesPydantic(BulkInsertableModel):
+    url_id: int
+    success: bool
+
+    @classmethod
+    def sa_model(cls) -> type[FlagURLCheckedForInternetArchives]:
+        return FlagURLCheckedForInternetArchives
\ No newline at end of file
diff --git a/src/db/models/impl/flag/checked_for_ia/sqlalchemy.py b/src/db/models/impl/flag/checked_for_ia/sqlalchemy.py
new file mode 100644
index 00000000..efdf9257
--- /dev/null
+++ b/src/db/models/impl/flag/checked_for_ia/sqlalchemy.py
@@ -0,0 +1,22 @@
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy.orm import Mapped
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class FlagURLCheckedForInternetArchives(
+    URLDependentMixin,
+    CreatedAtMixin,
+    Base
+):
+
+    success: Mapped[bool]
+
+    __tablename__ = 'flag_url_checked_for_internet_archive'
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            'url_id',
+        ),
+    )
diff --git a/src/db/models/impl/flag/root_url/__init__.py b/src/db/models/impl/flag/root_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/root_url/pydantic.py b/src/db/models/impl/flag/root_url/pydantic.py
new file mode 100644
index 00000000..a840192a
--- /dev/null
+++ b/src/db/models/impl/flag/root_url/pydantic.py
@@ -0,0 +1,11 @@
+from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class FlagRootURLPydantic(BulkInsertableModel):
+
+    url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[FlagRootURL]:
+        return FlagRootURL
\ No newline at end of file
diff --git a/src/db/models/impl/flag/root_url/sqlalchemy.py b/src/db/models/impl/flag/root_url/sqlalchemy.py
new file mode 100644
index 00000000..8c8afbed
--- /dev/null
+++ b/src/db/models/impl/flag/root_url/sqlalchemy.py
@@ -0,0 +1,17 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class FlagRootURL(
+    CreatedAtMixin,
+    URLDependentMixin,
+    Base
+):
+    __tablename__ = 'flag_root_url'
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            'url_id',
+        ),
+    )
diff --git a/src/db/models/impl/flag/url_suspended/__init__.py b/src/db/models/impl/flag/url_suspended/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/url_suspended/sqlalchemy.py b/src/db/models/impl/flag/url_suspended/sqlalchemy.py
new file mode 100644
index 00000000..dea3f0b0
--- /dev/null
+++ b/src/db/models/impl/flag/url_suspended/sqlalchemy.py
@@ -0,0 +1,17 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class FlagURLSuspended(
+    Base,
+    URLDependentMixin,
+    CreatedAtMixin
+):
+
+    __tablename__ = "flag_url_suspended"
+
+    __table_args__ = (
+        PrimaryKeyConstraint("url_id"),
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/flag/url_validated/__init__.py b/src/db/models/impl/flag/url_validated/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/flag/url_validated/enums.py b/src/db/models/impl/flag/url_validated/enums.py
new file mode 100644
index 00000000..7ac2a0ad
--- /dev/null
+++ b/src/db/models/impl/flag/url_validated/enums.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class URLType(Enum):
+    DATA_SOURCE = "data source"
+    META_URL = "meta url"
+    NOT_RELEVANT = "not relevant"
+    INDIVIDUAL_RECORD = "individual record"
+    BROKEN_PAGE = "broken page"
\ No newline at end of file
diff --git a/src/db/models/impl/flag/url_validated/pydantic.py b/src/db/models/impl/flag/url_validated/pydantic.py
new file mode 100644
index 00000000..a8bd5b42
--- /dev/null
+++ b/src/db/models/impl/flag/url_validated/pydantic.py
@@ -0,0 +1,22 @@
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+from src.db.templates.markers.bulk.upsert import BulkUpsertableModel
+
+type_ = type
+
+class FlagURLValidatedPydantic(
+    BulkInsertableModel,
+    BulkUpsertableModel
+):
+
+    url_id: int
+    type: URLType
+
+    @classmethod
+    def sa_model(cls) -> type_[FlagURLValidated]:
+        return FlagURLValidated
+
+    @classmethod
+    def id_field(cls) -> str:
+        return "url_id"
\ No newline at end of file
diff --git a/src/db/models/impl/flag/url_validated/sqlalchemy.py b/src/db/models/impl/flag/url_validated/sqlalchemy.py
new file mode 100644
index 00000000..97abf056
--- /dev/null
+++ b/src/db/models/impl/flag/url_validated/sqlalchemy.py
@@ -0,0 +1,25 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.helpers import enum_column
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, UpdatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class FlagURLValidated(
+    URLDependentMixin,
+    CreatedAtMixin,
+    UpdatedAtMixin,
+    Base,
+):
+    __tablename__ = "flag_url_validated"
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            'url_id',
+        ),
+    )
+
+    type = enum_column(
+        enum_type=URLType,
+        name="url_type",
+    )
diff --git a/src/db/models/impl/link/__init__.py b/src/db/models/impl/link/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/agency_batch/__init__.py b/src/db/models/impl/link/agency_batch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/agency_batch/sqlalchemy.py b/src/db/models/impl/link/agency_batch/sqlalchemy.py
new file mode 100644
index 00000000..dcb670d3
--- /dev/null
+++ b/src/db/models/impl/link/agency_batch/sqlalchemy.py
@@ -0,0 +1,20 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.mixins import CreatedAtMixin, AgencyDependentMixin, BatchDependentMixin
+from src.db.models.templates_.base import Base
+
+
+class LinkAgencyBatch(
+    Base,
+    CreatedAtMixin,
+    BatchDependentMixin,
+    AgencyDependentMixin,
+):
+    __tablename__ = "link_agency_batches"
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            'batch_id',
+            'agency_id',
+            name='link_agency_batches_pk'
+        ),
+    )
diff --git a/src/db/models/impl/link/agency_location/__init__.py b/src/db/models/impl/link/agency_location/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/agency_location/sqlalchemy.py b/src/db/models/impl/link/agency_location/sqlalchemy.py
new file mode 100644
index 00000000..18a3ae5f
--- /dev/null
+++ b/src/db/models/impl/link/agency_location/sqlalchemy.py
@@ -0,0 +1,10 @@
+from src.db.models.mixins import AgencyDependentMixin, LocationDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class LinkAgencyLocation(
+    WithIDBase,
+    AgencyDependentMixin,
+    LocationDependentMixin,
+):
+    __tablename__ = "link_agencies_locations"
\ No newline at end of file
diff --git a/src/db/models/impl/link/batch_url/__init__.py b/src/db/models/impl/link/batch_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/batch_url/pydantic.py b/src/db/models/impl/link/batch_url/pydantic.py
new file mode 100644
index 00000000..143c57ce
--- /dev/null
+++ b/src/db/models/impl/link/batch_url/pydantic.py
@@ -0,0 +1,11 @@
+from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class LinkBatchURLPydantic(BulkInsertableModel):
+    batch_id: int
+    url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[LinkBatchURL]:
+        return LinkBatchURL
\ No newline at end of file
diff --git a/src/db/models/impl/link/batch_url/sqlalchemy.py b/src/db/models/impl/link/batch_url/sqlalchemy.py
new file mode 100644
index 00000000..951ac539
--- /dev/null
+++ b/src/db/models/impl/link/batch_url/sqlalchemy.py
@@ -0,0 +1,15 @@
+from sqlalchemy.orm import relationship
+
+from src.db.models.mixins import CreatedAtMixin, UpdatedAtMixin, BatchDependentMixin, URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class LinkBatchURL(
+    UpdatedAtMixin,
+    CreatedAtMixin,
+    URLDependentMixin,
+    BatchDependentMixin,
+    WithIDBase
+):
+    __tablename__ = "link_batch_urls"
+
diff --git a/src/db/models/impl/link/location_batch/__init__.py b/src/db/models/impl/link/location_batch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/location_batch/sqlalchemy.py b/src/db/models/impl/link/location_batch/sqlalchemy.py
new file mode 100644
index 00000000..e73a5ec8
--- /dev/null
+++ b/src/db/models/impl/link/location_batch/sqlalchemy.py
@@ -0,0 +1,21 @@
+from sqlalchemy import PrimaryKeyConstraint
+
+from src.db.models.mixins import LocationDependentMixin, BatchDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class LinkLocationBatch(
+    Base,
+    LocationDependentMixin,
+    BatchDependentMixin,
+    CreatedAtMixin
+):
+
+    __tablename__ = "link_location_batches"
+    __table_args__ = (
+        PrimaryKeyConstraint(
+            'batch_id',
+            'location_id',
+            name='link_location_batches_pk'
+        ),
+    )
\ No newline at end of file
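The flag_* tables above share one pattern: url_id doubles as the primary key, so the mere presence of a row is the flag and CreatedAtMixin records when it was raised. Checking a flag is then an EXISTS probe; a sketch, assuming URLDependentMixin supplies the url_id column:

    from sqlalchemy import exists, select
    from sqlalchemy.orm import Session

    from src.db.models.impl.flag.auto_validated.sqlalchemy import FlagURLAutoValidated

    def is_auto_validated(session: Session, url_id: int) -> bool:
        # True when a flag row exists for this URL; no payload columns needed.
        stmt = select(exists().where(FlagURLAutoValidated.url_id == url_id))
        return session.execute(stmt).scalar_one()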
diff --git a/src/db/models/impl/link/task_url.py b/src/db/models/impl/link/task_url.py
new file mode 100644
index 00000000..2535d317
--- /dev/null
+++ b/src/db/models/impl/link/task_url.py
@@ -0,0 +1,15 @@
+from sqlalchemy import UniqueConstraint, Column, Integer, ForeignKey
+
+from src.db.models.templates_.base import Base
+
+
+class LinkTaskURL(Base):
+    __tablename__ = 'link_task_urls'
+    __table_args__ = (UniqueConstraint(
+        "task_id",
+        "url_id",
+        name="uq_task_id_url_id"),
+    )
+
+    task_id = Column(Integer, ForeignKey('tasks.id', ondelete="CASCADE"), primary_key=True)
+    url_id = Column(Integer, ForeignKey('urls.id', ondelete="CASCADE"), primary_key=True)
diff --git a/src/db/models/impl/link/url_agency/__init__.py b/src/db/models/impl/link/url_agency/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/url_agency/pydantic.py b/src/db/models/impl/link/url_agency/pydantic.py
new file mode 100644
index 00000000..fe9194de
--- /dev/null
+++ b/src/db/models/impl/link/url_agency/pydantic.py
@@ -0,0 +1,19 @@
+from pydantic import ConfigDict
+
+from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency
+from src.db.templates.markers.bulk.delete import BulkDeletableModel
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class LinkURLAgencyPydantic(
+    BulkDeletableModel,
+    BulkInsertableModel
+):
+    model_config = ConfigDict(frozen=True)
+
+    url_id: int
+    agency_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[LinkURLAgency]:
+        return LinkURLAgency
\ No newline at end of file
diff --git a/src/db/models/impl/link/url_agency/sqlalchemy.py b/src/db/models/impl/link/url_agency/sqlalchemy.py
new file mode 100644
index 00000000..875fa25f
--- /dev/null
+++ b/src/db/models/impl/link/url_agency/sqlalchemy.py
@@ -0,0 +1,19 @@
+from sqlalchemy import UniqueConstraint
+from sqlalchemy.orm import relationship, Mapped
+
+from src.db.models.helpers import get_agency_id_foreign_column
+from src.db.models.mixins import URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class LinkURLAgency(URLDependentMixin, WithIDBase):
+    __tablename__ = "link_urls_agency"
+
+    agency_id: Mapped[int] = get_agency_id_foreign_column()
+
+    url = relationship("URL", back_populates="confirmed_agencies")
+    agency = relationship("Agency", back_populates="confirmed_urls")
+
+    __table_args__ = (
+        UniqueConstraint("url_id", "agency_id", name="uq_confirmed_url_agency"),
+    )
diff --git a/src/db/models/impl/link/url_redirect_url/__init__.py b/src/db/models/impl/link/url_redirect_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/url_redirect_url/pydantic.py b/src/db/models/impl/link/url_redirect_url/pydantic.py
new file mode 100644
index 00000000..b7b5dff3
--- /dev/null
+++ b/src/db/models/impl/link/url_redirect_url/pydantic.py
@@ -0,0 +1,12 @@
+from src.db.models.impl.link.url_redirect_url.sqlalchemy import LinkURLRedirectURL
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class LinkURLRedirectURLPydantic(BulkInsertableModel):
+    source_url_id: int
+    destination_url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[LinkURLRedirectURL]:
+        return LinkURLRedirectURL
+
diff --git a/src/db/models/impl/link/url_redirect_url/sqlalchemy.py b/src/db/models/impl/link/url_redirect_url/sqlalchemy.py
new file mode 100644
index 00000000..312cbb57
--- /dev/null
+++ b/src/db/models/impl/link/url_redirect_url/sqlalchemy.py
@@ -0,0 +1,10 @@
+from src.db.models.helpers import url_id_column
+from src.db.models.templates_.standard import StandardBase
+
+
+
+class LinkURLRedirectURL(StandardBase):
+    __tablename__ = "link_urls_redirect_url"
+    source_url_id = url_id_column()
+    destination_url_id = url_id_column()
+
diff --git a/src/db/models/impl/link/urls_root_url/__init__.py b/src/db/models/impl/link/urls_root_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/urls_root_url/pydantic.py b/src/db/models/impl/link/urls_root_url/pydantic.py
new file mode 100644
index 00000000..c3037567
--- /dev/null
+++ b/src/db/models/impl/link/urls_root_url/pydantic.py
@@ -0,0 +1,12 @@
+from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class LinkURLRootURLPydantic(BulkInsertableModel):
+
+    url_id: int
+    root_url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[LinkURLRootURL]:
+        return LinkURLRootURL
\ No newline at end of file
diff --git a/src/db/models/impl/link/urls_root_url/sqlalchemy.py b/src/db/models/impl/link/urls_root_url/sqlalchemy.py
new file mode 100644
index 00000000..a856dd31
--- /dev/null
+++ b/src/db/models/impl/link/urls_root_url/sqlalchemy.py
@@ -0,0 +1,14 @@
+from src.db.models.helpers import url_id_column
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, UpdatedAtMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class LinkURLRootURL(
+    UpdatedAtMixin,
+    CreatedAtMixin,
+    URLDependentMixin,
+    WithIDBase
+):
+    __tablename__ = "link_urls_root_url"
+
+    root_url_id = url_id_column()
\ No newline at end of file
diff --git a/src/db/models/impl/link/user_name_suggestion/__init__.py b/src/db/models/impl/link/user_name_suggestion/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/user_name_suggestion/pydantic.py b/src/db/models/impl/link/user_name_suggestion/pydantic.py
new file mode 100644
index 00000000..6e07989b
--- /dev/null
+++ b/src/db/models/impl/link/user_name_suggestion/pydantic.py
@@ -0,0 +1,12 @@
+from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class LinkUserNameSuggestionPydantic(BulkInsertableModel):
+
+    suggestion_id: int
+    user_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[LinkUserNameSuggestion]:
+        return LinkUserNameSuggestion
\ No newline at end of file
diff --git a/src/db/models/impl/link/user_name_suggestion/sqlalchemy.py b/src/db/models/impl/link/user_name_suggestion/sqlalchemy.py
new file mode 100644
index 00000000..316a8e3c
--- /dev/null
+++ b/src/db/models/impl/link/user_name_suggestion/sqlalchemy.py
@@ -0,0 +1,25 @@
+from sqlalchemy import Column, Integer, ForeignKey
+
+from src.db.models.mixins import CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class LinkUserNameSuggestion(
+    Base,
+    CreatedAtMixin,
+):
+
+    __tablename__ = "link_user_name_suggestions"
+
+    suggestion_id = Column(
+        Integer,
+        ForeignKey("url_name_suggestions.id"),
+        primary_key=True,
+        nullable=False,
+    )
+
+    user_id = Column(
+        Integer,
+        primary_key=True,
+        nullable=False,
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/link/user_suggestion_not_found/__init__.py b/src/db/models/impl/link/user_suggestion_not_found/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/user_suggestion_not_found/agency/__init__.py b/src/db/models/impl/link/user_suggestion_not_found/agency/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/user_suggestion_not_found/agency/sqlalchemy.py b/src/db/models/impl/link/user_suggestion_not_found/agency/sqlalchemy.py
new file mode 100644
index 00000000..0092f504
--- /dev/null
+++ b/src/db/models/impl/link/user_suggestion_not_found/agency/sqlalchemy.py
@@ -0,0 +1,20 @@
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy.orm import Mapped
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+from src.util.alembic_helpers import user_id_column
+
+
+class LinkUserSuggestionAgencyNotFound(
+    Base,
+    URLDependentMixin,
+    CreatedAtMixin,
+):
+    __tablename__ = "link_user_suggestion_agency_not_found"
+
+    user_id: Mapped[int] = user_id_column()
+
+    __table_args__ = (
+        PrimaryKeyConstraint("url_id", "user_id"),
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/link/user_suggestion_not_found/location/__init__.py b/src/db/models/impl/link/user_suggestion_not_found/location/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/user_suggestion_not_found/location/sqlalchemy.py b/src/db/models/impl/link/user_suggestion_not_found/location/sqlalchemy.py
new file mode 100644
index 00000000..d608b04d
--- /dev/null
+++ b/src/db/models/impl/link/user_suggestion_not_found/location/sqlalchemy.py
@@ -0,0 +1,20 @@
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy.orm import Mapped
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+from src.util.alembic_helpers import user_id_column
+
+
+class LinkUserSuggestionLocationNotFound(
+    Base,
+    URLDependentMixin,
+    CreatedAtMixin,
+):
+    __tablename__ = "link_user_suggestion_location_not_found"
+
+    user_id: Mapped[int] = user_id_column()
+
+    __table_args__ = (
+        PrimaryKeyConstraint("url_id", "user_id"),
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/link/user_suggestion_not_found/users_submitted_url/__init__.py b/src/db/models/impl/link/user_suggestion_not_found/users_submitted_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/link/user_suggestion_not_found/users_submitted_url/sqlalchemy.py b/src/db/models/impl/link/user_suggestion_not_found/users_submitted_url/sqlalchemy.py
new file mode 100644
index 00000000..23e61993
--- /dev/null
+++ b/src/db/models/impl/link/user_suggestion_not_found/users_submitted_url/sqlalchemy.py
@@ -0,0 +1,19 @@
+from sqlalchemy import Column, Integer, PrimaryKeyConstraint, UniqueConstraint
+from sqlalchemy.orm import Mapped
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin
+from src.db.models.templates_.base import Base
+
+
+class LinkUserSubmittedURL(
+    Base,
+    URLDependentMixin,
+    CreatedAtMixin,
+):
+    __tablename__ = "link_user_submitted_urls"
+    __table_args__ = (
+        PrimaryKeyConstraint("url_id", "user_id"),
+        UniqueConstraint("url_id"),
+    )
+
+    user_id: Mapped[int]
\ No newline at end of file
diff --git a/src/db/models/impl/location/__init__.py b/src/db/models/impl/location/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/location/county/__init__.py b/src/db/models/impl/location/county/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/location/county/sqlalchemy.py b/src/db/models/impl/location/county/sqlalchemy.py
new file mode 100644
index 00000000..99d82bdc
--- /dev/null
+++ b/src/db/models/impl/location/county/sqlalchemy.py
@@ -0,0 +1,18 @@
+from sqlalchemy import String, Column, Float, Integer
+from sqlalchemy.orm import Mapped
+
+from src.db.models.helpers import us_state_column
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class County(
+    WithIDBase,
+):
+    __tablename__ = "counties"
+
+    name: Mapped[str]
+    state_id: Mapped[int] = us_state_column()
+    fips: Mapped[str] = Column(String(5), nullable=True)
+    lat: Mapped[float] = Column(Float, nullable=True)
+    lng: Mapped[float] = Column(Float, nullable=True)
+    population: Mapped[int] = Column(Integer, nullable=True)
\ No newline at end of file
diff --git a/src/db/models/impl/location/locality/__init__.py b/src/db/models/impl/location/locality/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/location/locality/sqlalchemy.py b/src/db/models/impl/location/locality/sqlalchemy.py
new file mode 100644
index 00000000..c462a8c1
--- /dev/null
+++ b/src/db/models/impl/location/locality/sqlalchemy.py
@@ -0,0 +1,15 @@
+from sqlalchemy import String, Column
+from sqlalchemy.orm import Mapped
+
+from src.db.models.helpers import county_column
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class Locality(
+    WithIDBase,
+):
+
+    __tablename__ = "localities"
+
+    name = Column(String(255), nullable=False)
+    county_id: Mapped[int] = county_column(nullable=False)
diff --git a/src/db/models/impl/location/location/__init__.py b/src/db/models/impl/location/location/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/location/location/enums.py b/src/db/models/impl/location/location/enums.py
new file mode 100644
index 00000000..24a99ce9
--- /dev/null
+++ b/src/db/models/impl/location/location/enums.py
@@ -0,0 +1,8 @@
+from enum import Enum
+
+
+class LocationType(Enum):
+    NATIONAL = "National"
+    STATE = "State"
+    COUNTY = "County"
+    LOCALITY = "Locality"
\ No newline at end of file
diff --git a/src/db/models/impl/location/location/sqlalchemy.py b/src/db/models/impl/location/location/sqlalchemy.py
new file mode 100644
index 00000000..1a5dc435
--- /dev/null
+++ b/src/db/models/impl/location/location/sqlalchemy.py
@@ -0,0 +1,19 @@
+from sqlalchemy import Float, Column
+
+from src.db.models.helpers import us_state_column, county_column, locality_column, enum_column
+from src.db.models.impl.location.location.enums import LocationType
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class Location(
+    WithIDBase
+):
+
+    __tablename__ = "locations"
+
+    state_id = us_state_column(nullable=True)
+    county_id = county_column(nullable=True)
+    locality_id = locality_column(nullable=True)
+    type = enum_column(LocationType, name="location_type", nullable=False)
+    lat = Column(Float(), nullable=True)
+    lng = Column(Float(), nullable=True)
diff --git a/src/db/models/impl/location/us_state/__init__.py b/src/db/models/impl/location/us_state/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/location/us_state/sqlalchemy.py b/src/db/models/impl/location/us_state/sqlalchemy.py
new file mode 100644
index 00000000..c4cdfc2f
--- /dev/null
+++ b/src/db/models/impl/location/us_state/sqlalchemy.py
@@ -0,0 +1,12 @@
+from sqlalchemy.orm import Mapped
+
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class USState(
+    WithIDBase,
+):
+    __tablename__ = "us_states"
+
+    state_name: Mapped[str]
+    state_iso: Mapped[str]
\ No newline at end of file
diff --git a/src/db/models/impl/log/__init__.py b/src/db/models/impl/log/__init__.py
new file mode 100644
index 00000000..e69de29b
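The locations table above models a hierarchy: the type column says which scope a row represents, and only the matching foreign keys are populated, with the rest left NULL. For example, a county-scoped row might be built like this (a sketch; the IDs are hypothetical, and the column names follow the us_state_column/county_column/locality_column helpers above):

    from src.db.models.impl.location.location.enums import LocationType
    from src.db.models.impl.location.location.sqlalchemy import Location

    # County scope: state_id and county_id set, locality_id left NULL.
    county_location = Location(
        state_id=38,       # hypothetical us_states.id
        county_id=42,      # hypothetical counties.id
        locality_id=None,
        type=LocationType.COUNTY,
    )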
diff --git a/src/db/models/impl/log/pydantic/__init__.py b/src/db/models/impl/log/pydantic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/log/pydantic/info.py b/src/db/models/impl/log/pydantic/info.py
new file mode 100644
index 00000000..76af0dd7
--- /dev/null
+++ b/src/db/models/impl/log/pydantic/info.py
@@ -0,0 +1,11 @@
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class LogInfo(BaseModel):
+    id: int | None = None
+    log: str
+    batch_id: int
+    created_at: datetime | None = None
diff --git a/src/db/models/impl/log/pydantic/output.py b/src/db/models/impl/log/pydantic/output.py
new file mode 100644
index 00000000..36ea843b
--- /dev/null
+++ b/src/db/models/impl/log/pydantic/output.py
@@ -0,0 +1,10 @@
+from datetime import datetime
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class LogOutputInfo(BaseModel):
+    id: int | None = None
+    log: str
+    created_at: datetime | None = None
diff --git a/src/db/models/impl/log/sqlalchemy.py b/src/db/models/impl/log/sqlalchemy.py
new file mode 100644
index 00000000..60f17875
--- /dev/null
+++ b/src/db/models/impl/log/sqlalchemy.py
@@ -0,0 +1,14 @@
+from sqlalchemy import Column, Text
+from sqlalchemy.orm import relationship
+
+from src.db.models.mixins import CreatedAtMixin, BatchDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class Log(CreatedAtMixin, BatchDependentMixin, WithIDBase):
+    __tablename__ = 'logs'
+
+    log = Column(Text, nullable=False)
+
+    # Relationships
+    batch = relationship("Batch", back_populates="logs")
diff --git a/src/db/models/instantiations/missing.py b/src/db/models/impl/missing.py
similarity index 82%
rename from src/db/models/instantiations/missing.py
rename to src/db/models/impl/missing.py
index 0babd91d..6ad868df 100644
--- a/src/db/models/instantiations/missing.py
+++ b/src/db/models/impl/missing.py
@@ -3,10 +3,10 @@
 from src.db.models.helpers import get_created_at_column
 from src.db.models.mixins import BatchDependentMixin
-from src.db.models.templates import StandardModel
+from src.db.models.templates_.with_id import WithIDBase
 
 
-class Missing(BatchDependentMixin, StandardModel):
+class Missing(BatchDependentMixin, WithIDBase):
     __tablename__ = 'missing'
 
     place_id = Column(Integer, nullable=False)
diff --git a/src/db/models/impl/state/__init__.py b/src/db/models/impl/state/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/state/huggingface.py b/src/db/models/impl/state/huggingface.py
new file mode 100644
index 00000000..d858dc0a
--- /dev/null
+++ b/src/db/models/impl/state/huggingface.py
@@ -0,0 +1,10 @@
+from sqlalchemy import Column, Integer, DateTime
+
+from src.db.models.templates_.base import Base
+
+
+class HuggingFaceUploadState(Base):
+    __tablename__ = "huggingface_upload_state"
+
+    id = Column(Integer, primary_key=True)
+    last_upload_at = Column(DateTime, nullable=False)
\ No newline at end of file
diff --git a/src/db/models/impl/task/__init__.py b/src/db/models/impl/task/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/task/core.py b/src/db/models/impl/task/core.py
new file mode 100644
index 00000000..566dd116
--- /dev/null
+++ b/src/db/models/impl/task/core.py
@@ -0,0 +1,39 @@
+from sqlalchemy import Column
+from sqlalchemy.orm import relationship
+
+from src.db.enums import PGEnum, TaskType
+from src.db.models.impl.task.error import TaskError
+from src.db.models.mixins import UpdatedAtMixin
+from src.db.models.templates_.with_id import WithIDBase
+from src.db.models.types import batch_status_enum
+
+
+
+class Task(UpdatedAtMixin, WithIDBase):
+    __tablename__ = 'tasks'
+
+    task_type = Column(
+        PGEnum(
+            *[task_type.value for task_type in TaskType],
+            name='task_type'
+        ), nullable=False)
+    task_status = Column(
+        PGEnum(
+            'complete',
+            'in-process',
+            'error',
+            'aborted',
+            'never_completed',
+            name='task_status_enum'
+        ),
+        nullable=False
+    )
+
+    # Relationships
+    urls = relationship(
+        "URL",
+        secondary="link_task_urls",
+        back_populates="tasks"
+    )
+    errors = relationship(TaskError)
+    url_errors = relationship("URLTaskError")
diff --git a/src/db/models/impl/task/enums.py b/src/db/models/impl/task/enums.py
new file mode 100644
index 00000000..b166d747
--- /dev/null
+++ b/src/db/models/impl/task/enums.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class TaskStatus(Enum):
+    COMPLETE = "complete"
+    IN_PROCESS = "in-process"
+    ERROR = "error"
+    ABORTED = "aborted"
+    NEVER_COMPLETED = "never_completed"
diff --git a/src/db/models/impl/task/error.py b/src/db/models/impl/task/error.py
new file mode 100644
index 00000000..2de0c66a
--- /dev/null
+++ b/src/db/models/impl/task/error.py
@@ -0,0 +1,20 @@
+from sqlalchemy import Column, Text, UniqueConstraint
+from sqlalchemy.orm import relationship
+
+from src.db.models.mixins import UpdatedAtMixin, TaskDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class TaskError(UpdatedAtMixin, TaskDependentMixin, WithIDBase):
+    __tablename__ = 'task_errors'
+
+    error = Column(Text, nullable=False)
+
+    # Relationships
+    task = relationship("Task")
+
+    __table_args__ = (UniqueConstraint(
+        "task_id",
+        "error",
+        name="uq_task_id_error"),
+    )
diff --git a/src/db/models/impl/url/__init__.py b/src/db/models/impl/url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/instantiations/url/checked_for_duplicate.py b/src/db/models/impl/url/checked_for_duplicate.py
similarity index 82%
rename from src/db/models/instantiations/url/checked_for_duplicate.py
rename to src/db/models/impl/url/checked_for_duplicate.py
index d5811c6e..bb7cf666 100644
--- a/src/db/models/instantiations/url/checked_for_duplicate.py
+++ b/src/db/models/impl/url/checked_for_duplicate.py
@@ -1,10 +1,10 @@
 from sqlalchemy.orm import relationship
 
 from src.db.models.mixins import CreatedAtMixin, URLDependentMixin
-from src.db.models.templates import StandardModel
+from src.db.models.templates_.with_id import WithIDBase
 
 
-class URLCheckedForDuplicate(CreatedAtMixin, URLDependentMixin, StandardModel):
+class URLCheckedForDuplicate(CreatedAtMixin, URLDependentMixin, WithIDBase):
     __tablename__ = 'url_checked_for_duplicate'
 
     # Relationships
diff --git a/src/db/models/impl/url/core/__init__.py b/src/db/models/impl/url/core/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/core/enums.py b/src/db/models/impl/url/core/enums.py
new file mode 100644
index 00000000..88fe5bc4
--- /dev/null
+++ b/src/db/models/impl/url/core/enums.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class URLSource(Enum):
+    COLLECTOR = "collector"
+    MANUAL = "manual"
+    DATA_SOURCES = "data_sources_app"
+    REDIRECT = "redirect"
+    ROOT_URL = "root_url"
\ No newline at end of file
diff --git a/src/db/models/impl/url/core/pydantic/__init__.py b/src/db/models/impl/url/core/pydantic/__init__.py
new file mode 100644
index 00000000..e69de29b
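URLSource above records how a URL entered the system (collector run, manual entry, the Data Sources app, a redirect, or root-URL derivation), so downstream code can filter on provenance directly. A sketch:

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from src.db.models.impl.url.core.enums import URLSource
    from src.db.models.impl.url.core.sqlalchemy import URL

    def collector_urls(session: Session) -> list[URL]:
        # Only URLs that entered via a collector run.
        stmt = select(URL).where(URL.source == URLSource.COLLECTOR)
        return list(session.execute(stmt).scalars())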
diff --git a/src/db/models/impl/url/core/pydantic/info.py b/src/db/models/impl/url/core/pydantic/info.py
new file mode 100644
index 00000000..0985b3fc
--- /dev/null
+++ b/src/db/models/impl/url/core/pydantic/info.py
@@ -0,0 +1,19 @@
+import datetime
+from typing import Optional
+
+from pydantic import BaseModel
+
+from src.collectors.enums import URLStatus
+from src.db.models.impl.url.core.enums import URLSource
+
+
+class URLInfo(BaseModel):
+    id: int | None = None
+    batch_id: int | None = None
+    url: str
+    collector_metadata: dict | None = None
+    status: URLStatus = URLStatus.OK
+    updated_at: datetime.datetime | None = None
+    created_at: datetime.datetime | None = None
+    name: str | None = None
+    source: URLSource | None = None
diff --git a/src/db/models/impl/url/core/pydantic/insert.py b/src/db/models/impl/url/core/pydantic/insert.py
new file mode 100644
index 00000000..f04dd3df
--- /dev/null
+++ b/src/db/models/impl/url/core/pydantic/insert.py
@@ -0,0 +1,20 @@
+from src.collectors.enums import URLStatus
+from src.core.enums import RecordType
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.templates_.base import Base
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLInsertModel(BulkInsertableModel):
+
+    @classmethod
+    def sa_model(cls) -> type[Base]:
+        """Defines the SQLAlchemy model."""
+        return URL
+
+    url: str
+    collector_metadata: dict | None = None
+    name: str | None = None
+    status: URLStatus = URLStatus.OK
+    source: URLSource
\ No newline at end of file
diff --git a/src/db/models/impl/url/core/pydantic/upsert.py b/src/db/models/impl/url/core/pydantic/upsert.py
new file mode 100644
index 00000000..8a101c70
--- /dev/null
+++ b/src/db/models/impl/url/core/pydantic/upsert.py
@@ -0,0 +1,18 @@
+from src.db.models.impl.url.core.sqlalchemy import URL
+from src.db.models.templates_.base import Base
+from src.db.templates.markers.bulk.upsert import BulkUpsertableModel
+
+
+class URLUpsertModel(BulkUpsertableModel):
+
+    @classmethod
+    def id_field(cls) -> str:
+        return "id"
+
+    @classmethod
+    def sa_model(cls) -> type[Base]:
+        """Defines the SQLAlchemy model."""
+        return URL
+
+    id: int
+    name: str | None
diff --git a/src/db/models/impl/url/core/sqlalchemy.py b/src/db/models/impl/url/core/sqlalchemy.py
new file mode 100644
index 00000000..3582dd56
--- /dev/null
+++ b/src/db/models/impl/url/core/sqlalchemy.py
@@ -0,0 +1,109 @@
+from sqlalchemy import Column, Text, String, JSON
+from sqlalchemy.orm import relationship
+
+from src.collectors.enums import URLStatus
+from src.db.models.helpers import enum_column
+from src.db.models.impl.url.checked_for_duplicate import URLCheckedForDuplicate
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
+from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType
+from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask
+from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion
+from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError
+from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class URL(UpdatedAtMixin, CreatedAtMixin, WithIDBase):
+    __tablename__ = 'urls'
+
+    # The batch this URL is associated with
+    url = Column(Text, unique=True)
+    name = Column(String)
+    description = Column(Text)
+    # The metadata from the collector
+    collector_metadata = Column(JSON)
+    # The outcome of the URL: submitted, human_labeling, rejected, duplicate, etc.
+    status = enum_column(
+        URLStatus,
+        name='url_status',
+        nullable=False
+    )
+
+    source = enum_column(
+        URLSource,
+        name='url_source',
+        nullable=False
+    )
+
+    # Relationships
+    batch = relationship(
+        "Batch",
+        secondary="link_batch_urls",
+        back_populates="urls",
+        uselist=False,
+    )
+    record_type = relationship(
+        URLRecordType,
+        uselist=False,
+    )
+    duplicates = relationship("Duplicate", back_populates="original_url")
+    html_content = relationship("URLHTMLContent", back_populates="url", cascade="all, delete-orphan")
+    task_errors = relationship(
+        URLTaskError,
+        cascade="all, delete-orphan"
+    )
+    tasks = relationship(
+        "Task",
+        secondary="link_task_urls",
+        back_populates="urls",
+    )
+    auto_agency_subtasks = relationship(
+        "URLAutoAgencyIDSubtask"
+    )
+    auto_location_subtasks = relationship(
+        AutoLocationIDSubtask
+    )
+    name_suggestions = relationship(
+        URLNameSuggestion
+    )
+    user_agency_suggestions = relationship(
+        "UserUrlAgencySuggestion", back_populates="url")
+    auto_record_type_suggestion = relationship(
+        "AutoRecordTypeSuggestion", uselist=False, back_populates="url")
+    user_record_type_suggestions = relationship(
+        "UserRecordTypeSuggestion", back_populates="url")
+    auto_relevant_suggestion = relationship(
+        "AutoRelevantSuggestion", uselist=False, back_populates="url")
+    user_relevant_suggestions = relationship(
+        "UserURLTypeSuggestion", back_populates="url")
+    reviewing_user = relationship(
+        "ReviewingUserURL", uselist=False, back_populates="url")
+    optional_data_source_metadata = relationship(
+        "URLOptionalDataSourceMetadata", uselist=False, back_populates="url")
+    confirmed_agencies = relationship(
+        "LinkURLAgency",
+    )
+    data_source = relationship(
+        "URLDataSource",
+        back_populates="url",
+        uselist=False
+    )
+    checked_for_duplicate = relationship(
+        URLCheckedForDuplicate,
+        uselist=False,
+        back_populates="url"
+    )
+    compressed_html = relationship(
+        URLCompressedHTML,
+        uselist=False,
+        back_populates="url"
+    )
+    scrape_info = relationship(
+        "URLScrapeInfo",
+        uselist=False,
+    )
+    web_metadata = relationship(
+        "URLWebMetadata",
+        uselist=False,
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/url/data_source/__init__.py b/src/db/models/impl/url/data_source/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/data_source/pydantic.py b/src/db/models/impl/url/data_source/pydantic.py
new file mode 100644
index 00000000..7d02c5df
--- /dev/null
+++ b/src/db/models/impl/url/data_source/pydantic.py
@@ -0,0 +1,11 @@
+from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLDataSourcePydantic(BulkInsertableModel):
+    data_source_id: int
+    url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[URLDataSource]:
+        return URLDataSource
\ No newline at end of file
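The URL model above is the hub of the schema; most other tables in this diff hang off it through relationships. Once a row is loaded, its annotations are reachable as attributes; a sketch (the relationships default to lazy loading, so each attribute access may emit its own query):

    from sqlalchemy.orm import Session

    from src.db.models.impl.url.core.sqlalchemy import URL

    def summarize_url(session: Session, url_id: int) -> str:
        # Walk a few relationships for a quick human-readable summary.
        url = session.get(URL, url_id)
        if url is None:
            raise ValueError(f"no URL row with id {url_id}")
        n_agencies = len(url.confirmed_agencies)
        has_data_source = url.data_source is not None
        return f"{url.url}: {n_agencies} confirmed agencies, promoted: {has_data_source}"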
diff --git a/src/db/models/impl/url/data_source/sqlalchemy.py b/src/db/models/impl/url/data_source/sqlalchemy.py
new file mode 100644
index 00000000..be7bf047
--- /dev/null
+++ b/src/db/models/impl/url/data_source/sqlalchemy.py
@@ -0,0 +1,18 @@
+from sqlalchemy import Column, Integer
+from sqlalchemy.orm import relationship
+
+from src.db.models.mixins import CreatedAtMixin, URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class URLDataSource(CreatedAtMixin, URLDependentMixin, WithIDBase):
+    __tablename__ = "url_data_source"
+
+    data_source_id = Column(Integer, nullable=False)
+
+    # Relationships
+    url = relationship(
+        "URL",
+        back_populates="data_source",
+        uselist=False
+    )
diff --git a/src/db/models/impl/url/ds_meta_url/__init__.py b/src/db/models/impl/url/ds_meta_url/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/ds_meta_url/pydantic.py b/src/db/models/impl/url/ds_meta_url/pydantic.py
new file mode 100644
index 00000000..8f7674e9
--- /dev/null
+++ b/src/db/models/impl/url/ds_meta_url/pydantic.py
@@ -0,0 +1,14 @@
+from pydantic import BaseModel
+
+from src.db.models.impl.url.ds_meta_url.sqlalchemy import URLDSMetaURL
+
+
+class URLDSMetaURLPydantic(BaseModel):
+
+    url_id: int
+    ds_meta_url_id: int
+    agency_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[URLDSMetaURL]:
+        return URLDSMetaURL
\ No newline at end of file
diff --git a/src/db/models/impl/url/ds_meta_url/sqlalchemy.py b/src/db/models/impl/url/ds_meta_url/sqlalchemy.py
new file mode 100644
index 00000000..e642a694
--- /dev/null
+++ b/src/db/models/impl/url/ds_meta_url/sqlalchemy.py
@@ -0,0 +1,20 @@
+from sqlalchemy import Column, Integer, PrimaryKeyConstraint, UniqueConstraint
+
+from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, AgencyDependentMixin
+from src.db.models.templates_.base import Base
+
+
+class URLDSMetaURL(
+    Base,
+    URLDependentMixin,
+    AgencyDependentMixin,
+    CreatedAtMixin
+):
+    __tablename__ = "url_ds_meta_url"
+
+    ds_meta_url_id = Column(Integer)
+
+    __table_args__ = (
+        PrimaryKeyConstraint("url_id", "agency_id"),
+        UniqueConstraint("ds_meta_url_id"),
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/url/error_info/__init__.py b/src/db/models/impl/url/error_info/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/error_info/pydantic.py b/src/db/models/impl/url/error_info/pydantic.py
new file mode 100644
index 00000000..3ae4d482
--- /dev/null
+++ b/src/db/models/impl/url/error_info/pydantic.py
@@ -0,0 +1,10 @@
+import datetime
+
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLErrorInfoPydantic(BulkInsertableModel):
+    task_id: int
+    url_id: int
+    error: str
+    updated_at: datetime.datetime = None
diff --git a/src/db/models/impl/url/html/__init__.py b/src/db/models/impl/url/html/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/html/compressed/__init__.py b/src/db/models/impl/url/html/compressed/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/html/compressed/pydantic.py b/src/db/models/impl/url/html/compressed/pydantic.py
new file mode 100644
index 00000000..1409d924
--- /dev/null
+++ b/src/db/models/impl/url/html/compressed/pydantic.py
@@ -0,0 +1,13 @@
+from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
+from src.db.models.templates_.base import Base
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLCompressedHTMLPydantic(BulkInsertableModel):
+    url_id: int
+    compressed_html: bytes
+
+    @classmethod
+    def sa_model(cls) -> type[Base]:
+        """Defines the SQLAlchemy model."""
+        return URLCompressedHTML
\ No newline at end of file
diff --git a/src/db/models/impl/url/html/compressed/sqlalchemy.py b/src/db/models/impl/url/html/compressed/sqlalchemy.py
new file mode 100644
index 00000000..995c5b25
--- /dev/null
+++ b/src/db/models/impl/url/html/compressed/sqlalchemy.py
@@ -0,0 +1,21 @@
+from sqlalchemy import Column, LargeBinary
+from sqlalchemy.orm import relationship, Mapped
+
+from src.db.models.mixins import CreatedAtMixin, URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class URLCompressedHTML(
+    CreatedAtMixin,
+    URLDependentMixin,
+    WithIDBase
+):
+    __tablename__ = 'url_compressed_html'
+
+    compressed_html: Mapped[bytes] = Column(LargeBinary, nullable=False)
+
+    url = relationship(
+        "URL",
+        uselist=False,
+        back_populates="compressed_html"
+    )
\ No newline at end of file
diff --git a/src/db/models/impl/url/html/content/__init__.py b/src/db/models/impl/url/html/content/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/html/content/enums.py b/src/db/models/impl/url/html/content/enums.py
new file mode 100644
index 00000000..13820352
--- /dev/null
+++ b/src/db/models/impl/url/html/content/enums.py
@@ -0,0 +1,13 @@
+from enum import Enum
+
+
+class HTMLContentType(Enum):
+    TITLE = "Title"
+    DESCRIPTION = "Description"
+    H1 = "H1"
+    H2 = "H2"
+    H3 = "H3"
+    H4 = "H4"
+    H5 = "H5"
+    H6 = "H6"
+    DIV = "Div"
diff --git a/src/db/models/impl/url/html/content/pydantic.py b/src/db/models/impl/url/html/content/pydantic.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/html/content/sqlalchemy.py b/src/db/models/impl/url/html/content/sqlalchemy.py
new file mode 100644
index 00000000..63e4da76
--- /dev/null
+++ b/src/db/models/impl/url/html/content/sqlalchemy.py
@@ -0,0 +1,28 @@
+from sqlalchemy import UniqueConstraint, Column, Text
+from sqlalchemy.orm import relationship
+
+from src.db.enums import PGEnum
+from src.db.models.mixins import UpdatedAtMixin, URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class URLHTMLContent(
+    UpdatedAtMixin,
+    URLDependentMixin,
+    WithIDBase
+):
+    __tablename__ = 'url_html_content'
+    __table_args__ = (UniqueConstraint(
+        "url_id",
+        "content_type",
+        name="uq_url_id_content_type"),
+    )
+
+    content_type = Column(
+        PGEnum('Title', 'Description', 'H1', 'H2', 'H3', 'H4', 'H5', 'H6', 'Div', name='url_html_content_type'),
+        nullable=False)
+    content = Column(Text, nullable=False)
+
+
+    # Relationships
+    url = relationship("URL", back_populates="html_content")
diff --git a/src/db/models/impl/url/internet_archives/__init__.py b/src/db/models/impl/url/internet_archives/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/internet_archives/probe/__init__.py b/src/db/models/impl/url/internet_archives/probe/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/internet_archives/probe/pydantic.py b/src/db/models/impl/url/internet_archives/probe/pydantic.py
new file mode 100644
index 00000000..d62eceeb
--- /dev/null
+++ b/src/db/models/impl/url/internet_archives/probe/pydantic.py
@@ -0,0 +1,14 @@
+from src.db.models.impl.url.internet_archives.probe.sqlalchemy import URLInternetArchivesProbeMetadata
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLInternetArchiveMetadataPydantic(BulkInsertableModel):
+
+    url_id: int
+    archive_url: str
+    digest: str
+    length: int
+
+    @classmethod
+    def sa_model(cls) -> type[URLInternetArchivesProbeMetadata]:
+        return URLInternetArchivesProbeMetadata
diff --git a/src/db/models/impl/url/internet_archives/probe/sqlalchemy.py b/src/db/models/impl/url/internet_archives/probe/sqlalchemy.py
new file mode 100644
index 00000000..122905a7
--- /dev/null
+++ b/src/db/models/impl/url/internet_archives/probe/sqlalchemy.py
@@ -0,0 +1,15 @@
+from sqlalchemy.orm import Mapped
+
+from src.db.models.mixins import URLDependentMixin
+from src.db.models.templates_.standard import StandardBase
+
+
+class URLInternetArchivesProbeMetadata(
+    StandardBase,
+    URLDependentMixin
+):
+    __tablename__ = 'url_internet_archives_probe_metadata'
+
+    archive_url: Mapped[str]
+    digest: Mapped[str]
+    length: Mapped[int]
\ No newline at end of file
diff --git a/src/db/models/impl/url/internet_archives/save/__init__.py b/src/db/models/impl/url/internet_archives/save/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/db/models/impl/url/internet_archives/save/pydantic.py b/src/db/models/impl/url/internet_archives/save/pydantic.py
new file mode 100644
index 00000000..16e9f281
--- /dev/null
+++ b/src/db/models/impl/url/internet_archives/save/pydantic.py
@@ -0,0 +1,10 @@
+from src.db.models.impl.url.internet_archives.save.sqlalchemy import URLInternetArchivesSaveMetadata
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+class URLInternetArchiveSaveMetadataPydantic(BulkInsertableModel):
+    url_id: int
+
+    @classmethod
+    def sa_model(cls) -> type[URLInternetArchivesSaveMetadata]:
+        return URLInternetArchivesSaveMetadata
\ No newline at end of file
diff --git a/src/db/models/impl/url/internet_archives/save/sqlalchemy.py b/src/db/models/impl/url/internet_archives/save/sqlalchemy.py
new file mode 100644
index 00000000..791f4077
--- /dev/null
+++ b/src/db/models/impl/url/internet_archives/save/sqlalchemy.py
@@ -0,0 +1,14 @@
+from sqlalchemy import Column, DateTime, func
+
+from src.db.models.mixins import URLDependentMixin
+from src.db.models.templates_.with_id import WithIDBase
+
+
+class URLInternetArchivesSaveMetadata(
+    WithIDBase,
+    URLDependentMixin
+):
+    __tablename__ = 'url_internet_archives_save_metadata'
+
+    created_at = Column(DateTime, nullable=False, server_default=func.now())
+    last_uploaded_at = Column(DateTime, nullable=False, server_default=func.now())
diff --git a/src/db/models/instantiations/url/optional_data_source_metadata.py b/src/db/models/impl/url/optional_data_source_metadata.py
similarity index 79%
rename from src/db/models/instantiations/url/optional_data_source_metadata.py
rename to src/db/models/impl/url/optional_data_source_metadata.py
index 84871982..bb2a95e5 100644
--- a/src/db/models/instantiations/url/optional_data_source_metadata.py
+++ b/src/db/models/impl/url/optional_data_source_metadata.py
@@ -2,10 +2,10 @@
 from sqlalchemy.orm import relationship
 
 from src.db.models.mixins import URLDependentMixin
-from src.db.models.templates import StandardModel
+from src.db.models.templates_.with_id import WithIDBase
 
 
-class URLOptionalDataSourceMetadata(URLDependentMixin, StandardModel):
+class URLOptionalDataSourceMetadata(URLDependentMixin, WithIDBase):
     __tablename__ = 'url_optional_data_source_metadata'
 
     record_formats = Column(ARRAY(String), nullable=True)
diff --git a/src/db/models/impl/url/record_type/__init__.py b/src/db/models/impl/url/record_type/__init__.py
new file mode 100644
index 00000000..e69de29b
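url_compressed_html above stores page HTML as an opaque LargeBinary blob. The compression codec is not visible in this diff; assuming something zlib-like, the round trip into the bulk-insert carrier would look roughly like:

    import zlib

    from src.db.models.impl.url.html.compressed.pydantic import URLCompressedHTMLPydantic

    def pack_html(url_id: int, html: str) -> URLCompressedHTMLPydantic:
        # Codec is an assumption; the repository may use a different scheme.
        return URLCompressedHTMLPydantic(
            url_id=url_id,
            compressed_html=zlib.compress(html.encode("utf-8")),
        )

    def unpack_html(blob: bytes) -> str:
        return zlib.decompress(blob).decode("utf-8")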
BulkUpsertableModel + + +class URLRecordTypePydantic( + BulkInsertableModel, + BulkUpsertableModel, +): + url_id: int + record_type: RecordType + + @classmethod + def sa_model(cls) -> type[URLRecordType]: + return URLRecordType + + @classmethod + def id_field(cls) -> str: + return "url_id" \ No newline at end of file diff --git a/src/db/models/impl/url/record_type/sqlalchemy.py b/src/db/models/impl/url/record_type/sqlalchemy.py new file mode 100644 index 00000000..7e8f2fac --- /dev/null +++ b/src/db/models/impl/url/record_type/sqlalchemy.py @@ -0,0 +1,17 @@ +from sqlalchemy.orm import Mapped + +from src.core.enums import RecordType +from src.db.models.helpers import url_id_primary_key_constraint, enum_column +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class URLRecordType( + Base, + CreatedAtMixin, + URLDependentMixin +): + __tablename__ = "url_record_type" + __table_args__ = (url_id_primary_key_constraint(),) + + record_type: Mapped[RecordType] = enum_column(RecordType, name="record_type", nullable=False) \ No newline at end of file diff --git a/src/db/models/instantiations/url/reviewing_user.py b/src/db/models/impl/url/reviewing_user.py similarity index 79% rename from src/db/models/instantiations/url/reviewing_user.py rename to src/db/models/impl/url/reviewing_user.py index d28a33e7..9213a157 100644 --- a/src/db/models/instantiations/url/reviewing_user.py +++ b/src/db/models/impl/url/reviewing_user.py @@ -2,10 +2,10 @@ from sqlalchemy.orm import relationship from src.db.models.mixins import CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel +from src.db.models.templates_.with_id import WithIDBase -class ReviewingUserURL(CreatedAtMixin, URLDependentMixin, StandardModel): +class ReviewingUserURL(CreatedAtMixin, URLDependentMixin, WithIDBase): __tablename__ = 'reviewing_user_url' __table_args__ = ( UniqueConstraint( diff --git a/src/db/models/impl/url/scrape_info/__init__.py b/src/db/models/impl/url/scrape_info/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/scrape_info/enums.py b/src/db/models/impl/url/scrape_info/enums.py new file mode 100644 index 00000000..3e16fff3 --- /dev/null +++ b/src/db/models/impl/url/scrape_info/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class ScrapeStatus(Enum): + SUCCESS = "success" + ERROR = "error" \ No newline at end of file diff --git a/src/db/models/impl/url/scrape_info/pydantic.py b/src/db/models/impl/url/scrape_info/pydantic.py new file mode 100644 index 00000000..1aaf2205 --- /dev/null +++ b/src/db/models/impl/url/scrape_info/pydantic.py @@ -0,0 +1,13 @@ +from src.db.models.impl.url.scrape_info.enums import ScrapeStatus +from src.db.models.impl.url.scrape_info.sqlalchemy import URLScrapeInfo +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class URLScrapeInfoInsertModel(BulkInsertableModel): + url_id: int + status: ScrapeStatus + + @classmethod + def sa_model(cls) -> type[Base]: + return URLScrapeInfo \ No newline at end of file diff --git a/src/db/models/impl/url/scrape_info/sqlalchemy.py b/src/db/models/impl/url/scrape_info/sqlalchemy.py new file mode 100644 index 00000000..b50f2903 --- /dev/null +++ b/src/db/models/impl/url/scrape_info/sqlalchemy.py @@ -0,0 +1,17 @@ +from src.db.models.helpers import enum_column +from src.db.models.impl.url.scrape_info.enums import ScrapeStatus +from src.db.models.mixins import 
URLDependentMixin +from src.db.models.templates_.standard import StandardBase + + +class URLScrapeInfo( + StandardBase, + URLDependentMixin +): + + __tablename__ = 'url_scrape_info' + + status = enum_column( + enum_type=ScrapeStatus, + name='scrape_status', + ) \ No newline at end of file diff --git a/src/db/models/impl/url/screenshot/__init__.py b/src/db/models/impl/url/screenshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/screenshot/pydantic.py b/src/db/models/impl/url/screenshot/pydantic.py new file mode 100644 index 00000000..027bec19 --- /dev/null +++ b/src/db/models/impl/url/screenshot/pydantic.py @@ -0,0 +1,13 @@ +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class URLScreenshotPydantic(BulkInsertableModel): + url_id: int + content: bytes + file_size: int + + @classmethod + def sa_model(cls) -> type[Base]: + return URLScreenshot diff --git a/src/db/models/impl/url/screenshot/sqlalchemy.py b/src/db/models/impl/url/screenshot/sqlalchemy.py new file mode 100644 index 00000000..e61a77ea --- /dev/null +++ b/src/db/models/impl/url/screenshot/sqlalchemy.py @@ -0,0 +1,22 @@ +from sqlalchemy import Column, LargeBinary, Integer, UniqueConstraint, PrimaryKeyConstraint + +from src.db.models.helpers import url_id_primary_key_constraint +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, UpdatedAtMixin +from src.db.models.templates_.base import Base + + +class URLScreenshot( + Base, + URLDependentMixin, + CreatedAtMixin, + UpdatedAtMixin, +): + __tablename__ = "url_screenshot" + __table_args__ = ( + url_id_primary_key_constraint(), + ) + + + content = Column(LargeBinary, nullable=False) + file_size = Column(Integer, nullable=False) + diff --git a/src/db/models/instantiations/url/suggestion/README.md b/src/db/models/impl/url/suggestion/README.md similarity index 100% rename from src/db/models/instantiations/url/suggestion/README.md rename to src/db/models/impl/url/suggestion/README.md diff --git a/src/db/models/impl/url/suggestion/__init__.py b/src/db/models/impl/url/suggestion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/agency/__init__.py b/src/db/models/impl/url/suggestion/agency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/agency/subtask/__init__.py b/src/db/models/impl/url/suggestion/agency/subtask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/agency/subtask/enum.py b/src/db/models/impl/url/suggestion/agency/subtask/enum.py new file mode 100644 index 00000000..ef1ecbc0 --- /dev/null +++ b/src/db/models/impl/url/suggestion/agency/subtask/enum.py @@ -0,0 +1,15 @@ +from enum import Enum + + +class AutoAgencyIDSubtaskType(Enum): + HOMEPAGE_MATCH = "homepage_match" + NLP_LOCATION_MATCH = "nlp_location_match" + MUCKROCK = "muckrock_match" + CKAN = "ckan_match" + BATCH_LINK = "batch_link" + +class SubtaskDetailCode(Enum): + NO_DETAILS = "no details" + RETRIEVAL_ERROR = "retrieval error" + HOMEPAGE_SINGLE_AGENCY = "homepage-single agency" + HOMEPAGE_MULTI_AGENCY = "homepage-multi agency" \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/agency/subtask/pydantic.py b/src/db/models/impl/url/suggestion/agency/subtask/pydantic.py new file mode 100644 index 00000000..f2e9be57 --- /dev/null +++ 
b/src/db/models/impl/url/suggestion/agency/subtask/pydantic.py @@ -0,0 +1,17 @@ +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType, SubtaskDetailCode +from src.db.models.impl.url.suggestion.agency.subtask.sqlalchemy import URLAutoAgencyIDSubtask +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + +type_alias = type + +class URLAutoAgencyIDSubtaskPydantic(BulkInsertableModel): + task_id: int + url_id: int + type: AutoAgencyIDSubtaskType + agencies_found: bool + detail: SubtaskDetailCode = SubtaskDetailCode.NO_DETAILS + + @classmethod + def sa_model(cls) -> type_alias[Base]: + return URLAutoAgencyIDSubtask \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/agency/subtask/sqlalchemy.py b/src/db/models/impl/url/suggestion/agency/subtask/sqlalchemy.py new file mode 100644 index 00000000..89371498 --- /dev/null +++ b/src/db/models/impl/url/suggestion/agency/subtask/sqlalchemy.py @@ -0,0 +1,35 @@ +from sqlalchemy.orm import relationship + +from src.db.models.helpers import enum_column +from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType, SubtaskDetailCode +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, TaskDependentMixin +from src.db.models.templates_.with_id import WithIDBase + +import sqlalchemy as sa + +class URLAutoAgencyIDSubtask( + WithIDBase, + URLDependentMixin, + TaskDependentMixin, + CreatedAtMixin +): + + __tablename__ = "url_auto_agency_id_subtasks" + + type = enum_column( + AutoAgencyIDSubtaskType, + name="agency_auto_suggestion_method" + ) + agencies_found = sa.Column( + sa.Boolean(), + nullable=False + ) + detail = enum_column( + SubtaskDetailCode, + name="agency_id_subtask_detail_code", + ) + + suggestions = relationship( + "AgencyIDSubtaskSuggestion", + cascade="all, delete-orphan" + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/agency/suggestion/__init__.py b/src/db/models/impl/url/suggestion/agency/suggestion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/agency/suggestion/pydantic.py b/src/db/models/impl/url/suggestion/agency/suggestion/pydantic.py new file mode 100644 index 00000000..5a0fd2b8 --- /dev/null +++ b/src/db/models/impl/url/suggestion/agency/suggestion/pydantic.py @@ -0,0 +1,16 @@ +from src.db.models.impl.url.suggestion.agency.suggestion.sqlalchemy import AgencyIDSubtaskSuggestion +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class AgencyIDSubtaskSuggestionPydantic( + BulkInsertableModel, +): + subtask_id: int + agency_id: int + confidence: int + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return AgencyIDSubtaskSuggestion \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/agency/suggestion/sqlalchemy.py b/src/db/models/impl/url/suggestion/agency/suggestion/sqlalchemy.py new file mode 100644 index 00000000..de6ee029 --- /dev/null +++ b/src/db/models/impl/url/suggestion/agency/suggestion/sqlalchemy.py @@ -0,0 +1,28 @@ +import sqlalchemy as sa +from sqlalchemy.orm import relationship + +from src.db.models.mixins import CreatedAtMixin, AgencyDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class AgencyIDSubtaskSuggestion( + WithIDBase, + CreatedAtMixin, + AgencyDependentMixin, +): + __tablename__ = 
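URLAutoAgencyIDSubtaskPydantic is the insert-side record of one agency-identification subtask run. A small example grounded in the model above; the ids are hypothetical:

```python
from src.db.models.impl.url.suggestion.agency.subtask.enum import (
    AutoAgencyIDSubtaskType,
    SubtaskDetailCode,
)
from src.db.models.impl.url.suggestion.agency.subtask.pydantic import (
    URLAutoAgencyIDSubtaskPydantic,
)

# Record that the homepage-match subtask ran for this URL and found nothing.
row = URLAutoAgencyIDSubtaskPydantic(
    task_id=1,  # hypothetical task id
    url_id=1,   # hypothetical url id
    type=AutoAgencyIDSubtaskType.HOMEPAGE_MATCH,
    agencies_found=False,
)
# `detail` defaults to NO_DETAILS when the subtask has nothing extra to report.
assert row.detail is SubtaskDetailCode.NO_DETAILS
```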
"agency_id_subtask_suggestions" + + subtask_id = sa.Column( + sa.Integer, + sa.ForeignKey("url_auto_agency_id_subtasks.id"), + nullable=False + ) + confidence = sa.Column( + sa.Integer, + sa.CheckConstraint( + "confidence BETWEEN 0 and 100" + ), + nullable=False, + ) + + agency = relationship("Agency", viewonly=True) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/agency/user.py b/src/db/models/impl/url/suggestion/agency/user.py new file mode 100644 index 00000000..f7c43aad --- /dev/null +++ b/src/db/models/impl/url/suggestion/agency/user.py @@ -0,0 +1,21 @@ +from sqlalchemy import Column, Boolean, UniqueConstraint, Integer +from sqlalchemy.orm import relationship, Mapped + +from src.db.models.helpers import get_agency_id_foreign_column +from src.db.models.mixins import URLDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class UserUrlAgencySuggestion(URLDependentMixin, WithIDBase): + __tablename__ = "user_url_agency_suggestions" + + agency_id: Mapped[int] = get_agency_id_foreign_column(nullable=True) + user_id = Column(Integer, nullable=False) + is_new = Column(Boolean, nullable=True) + + agency = relationship("Agency", back_populates="user_suggestions") + url = relationship("URL", back_populates="user_agency_suggestions") + + __table_args__ = ( + UniqueConstraint("agency_id", "url_id", "user_id", name="uq_user_url_agency_suggestions"), + ) diff --git a/src/db/models/impl/url/suggestion/anonymous/__init__.py b/src/db/models/impl/url/suggestion/anonymous/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/anonymous/agency/__init__.py b/src/db/models/impl/url/suggestion/anonymous/agency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/anonymous/agency/sqlalchemy.py b/src/db/models/impl/url/suggestion/anonymous/agency/sqlalchemy.py new file mode 100644 index 00000000..afea2f23 --- /dev/null +++ b/src/db/models/impl/url/suggestion/anonymous/agency/sqlalchemy.py @@ -0,0 +1,16 @@ +from sqlalchemy import PrimaryKeyConstraint + +from src.db.models.mixins import URLDependentMixin, AgencyDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class AnonymousAnnotationAgency( + Base, + URLDependentMixin, + AgencyDependentMixin, + CreatedAtMixin +): + __tablename__ = "anonymous_annotation_agency" + __table_args__ = ( + PrimaryKeyConstraint("url_id", "agency_id"), + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/anonymous/location/__init__.py b/src/db/models/impl/url/suggestion/anonymous/location/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/anonymous/location/sqlalchemy.py b/src/db/models/impl/url/suggestion/anonymous/location/sqlalchemy.py new file mode 100644 index 00000000..f02cb7ba --- /dev/null +++ b/src/db/models/impl/url/suggestion/anonymous/location/sqlalchemy.py @@ -0,0 +1,17 @@ +from sqlalchemy import PrimaryKeyConstraint + +from src.db.models.mixins import LocationDependentMixin, URLDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class AnonymousAnnotationLocation( + Base, + URLDependentMixin, + LocationDependentMixin, + CreatedAtMixin +): + + __tablename__ = "anonymous_annotation_location" + __table_args__ = ( + PrimaryKeyConstraint("url_id", "location_id"), + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/anonymous/record_type/__init__.py 
b/src/db/models/impl/url/suggestion/anonymous/record_type/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/anonymous/record_type/sqlalchemy.py b/src/db/models/impl/url/suggestion/anonymous/record_type/sqlalchemy.py new file mode 100644 index 00000000..25a9ddec --- /dev/null +++ b/src/db/models/impl/url/suggestion/anonymous/record_type/sqlalchemy.py @@ -0,0 +1,23 @@ +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.orm import Mapped + +from src.core.enums import RecordType +from src.db.models.helpers import enum_column +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class AnonymousAnnotationRecordType( + Base, + URLDependentMixin, + CreatedAtMixin +): + __tablename__ = "anonymous_annotation_record_type" + __table_args__ = ( + PrimaryKeyConstraint("url_id", "record_type"), + ) + + record_type: Mapped[RecordType] = enum_column( + name="record_type", + enum_type=RecordType, + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/anonymous/url_type/__init__.py b/src/db/models/impl/url/suggestion/anonymous/url_type/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/anonymous/url_type/sqlalchemy.py b/src/db/models/impl/url/suggestion/anonymous/url_type/sqlalchemy.py new file mode 100644 index 00000000..f9033ffa --- /dev/null +++ b/src/db/models/impl/url/suggestion/anonymous/url_type/sqlalchemy.py @@ -0,0 +1,23 @@ +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.orm import Mapped + +from src.db.models.helpers import enum_column +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class AnonymousAnnotationURLType( + Base, + URLDependentMixin, + CreatedAtMixin +): + __tablename__ = "anonymous_annotation_url_type" + __table_args__ = ( + PrimaryKeyConstraint("url_id", "url_type"), + ) + + url_type: Mapped[URLType] = enum_column( + name="url_type", + enum_type=URLType, + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/__init__.py b/src/db/models/impl/url/suggestion/location/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/location/auto/__init__.py b/src/db/models/impl/url/suggestion/location/auto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/location/auto/subtask/__init__.py b/src/db/models/impl/url/suggestion/location/auto/subtask/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/location/auto/subtask/constants.py b/src/db/models/impl/url/suggestion/location/auto/subtask/constants.py new file mode 100644 index 00000000..d6b887c7 --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/subtask/constants.py @@ -0,0 +1,3 @@ + + +MAX_SUGGESTION_LENGTH: int = 100 \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/auto/subtask/enums.py b/src/db/models/impl/url/suggestion/location/auto/subtask/enums.py new file mode 100644 index 00000000..c4937af3 --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/subtask/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class LocationIDSubtaskType(Enum): + NLP_LOCATION_FREQUENCY = 'nlp_location_frequency' + BATCH_LINK = 'batch_link' \ No newline at 
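The three anonymous-annotation tables deliberately use composite primary keys instead of surrogate ids, so each (url, value) pair can be recorded at most once. That makes idempotent writes natural with PostgreSQL's ON CONFLICT. A sketch for the record-type table, assuming an open `Session` named `session`:

```python
from sqlalchemy.dialects.postgresql import insert

from src.core.enums import RecordType
from src.db.models.impl.url.suggestion.anonymous.record_type.sqlalchemy import (
    AnonymousAnnotationRecordType,
)

# The composite primary key (url_id, record_type) means one anonymous vote
# per URL per record type; ON CONFLICT DO NOTHING makes repeats harmless.
stmt = (
    insert(AnonymousAnnotationRecordType)
    .values(url_id=1, record_type=list(RecordType)[0])  # any RecordType member
    .on_conflict_do_nothing(index_elements=["url_id", "record_type"])
)
session.execute(stmt)
session.commit()
```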
end of file diff --git a/src/db/models/impl/url/suggestion/location/auto/subtask/pydantic.py b/src/db/models/impl/url/suggestion/location/auto/subtask/pydantic.py new file mode 100644 index 00000000..091a00b9 --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/subtask/pydantic.py @@ -0,0 +1,19 @@ +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class AutoLocationIDSubtaskPydantic( + BulkInsertableModel, +): + + url_id: int + task_id: int + locations_found: bool + type: LocationIDSubtaskType + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return AutoLocationIDSubtask \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/auto/subtask/sqlalchemy.py b/src/db/models/impl/url/suggestion/location/auto/subtask/sqlalchemy.py new file mode 100644 index 00000000..b7412d1e --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/subtask/sqlalchemy.py @@ -0,0 +1,28 @@ +from sqlalchemy import Column, Boolean +from sqlalchemy.orm import relationship, Mapped + +from src.db.models.helpers import enum_column +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion +from src.db.models.mixins import CreatedAtMixin, TaskDependentMixin, URLDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class AutoLocationIDSubtask( + WithIDBase, + CreatedAtMixin, + TaskDependentMixin, + URLDependentMixin, +): + + __tablename__ = 'auto_location_id_subtasks' + + locations_found = Column(Boolean(), nullable=False) + type: Mapped[LocationIDSubtaskType] = enum_column( + LocationIDSubtaskType, + name='auto_location_id_subtask_type' + ) + + suggestions = relationship( + LocationIDSubtaskSuggestion + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/auto/suggestion/__init__.py b/src/db/models/impl/url/suggestion/location/auto/suggestion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/location/auto/suggestion/pydantic.py b/src/db/models/impl/url/suggestion/location/auto/suggestion/pydantic.py new file mode 100644 index 00000000..1ddc53d7 --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/suggestion/pydantic.py @@ -0,0 +1,15 @@ +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class LocationIDSubtaskSuggestionPydantic(BulkInsertableModel): + + subtask_id: int + location_id: int + confidence: float + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return LocationIDSubtaskSuggestion \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/auto/suggestion/sqlalchemy.py b/src/db/models/impl/url/suggestion/location/auto/suggestion/sqlalchemy.py new file mode 100644 index 00000000..0d5ea926 --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/auto/suggestion/sqlalchemy.py @@ -0,0 +1,27 @@ +from sqlalchemy import Column, Integer, 
ForeignKey, Float, PrimaryKeyConstraint +from sqlalchemy.orm import Mapped + +from src.db.models.helpers import location_id_column +from src.db.models.templates_.base import Base + + +class LocationIDSubtaskSuggestion( + Base, +): + + __tablename__ = 'location_id_subtask_suggestions' + __table_args__ = ( + PrimaryKeyConstraint( + 'subtask_id', + 'location_id', + name='location_id_subtask_suggestions_pk' + ), + ) + subtask_id = Column( + Integer, + ForeignKey('auto_location_id_subtasks.id'), + nullable=False, + primary_key=True, + ) + location_id: Mapped[int] = location_id_column() + confidence = Column(Float, nullable=False) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/location/user/__init__.py b/src/db/models/impl/url/suggestion/location/user/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/location/user/pydantic.py b/src/db/models/impl/url/suggestion/location/user/pydantic.py new file mode 100644 index 00000000..11f2218b --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/user/pydantic.py @@ -0,0 +1,16 @@ +from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class UserLocationSuggestionPydantic( + BulkInsertableModel, +): + + location_id: int + url_id: int + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return UserLocationSuggestion diff --git a/src/db/models/impl/url/suggestion/location/user/sqlalchemy.py b/src/db/models/impl/url/suggestion/location/user/sqlalchemy.py new file mode 100644 index 00000000..a9d4ae8b --- /dev/null +++ b/src/db/models/impl/url/suggestion/location/user/sqlalchemy.py @@ -0,0 +1,21 @@ +from sqlalchemy import Integer, Column, PrimaryKeyConstraint + +from src.db.models.mixins import CreatedAtMixin, URLDependentMixin, LocationDependentMixin +from src.db.models.templates_.base import Base + + +class UserLocationSuggestion( + Base, + CreatedAtMixin, + LocationDependentMixin, + URLDependentMixin +): + __tablename__ = 'user_location_suggestions' + __table_args__ = ( + PrimaryKeyConstraint('url_id', 'location_id', 'user_id'), + ) + + user_id = Column( + Integer, + nullable=False, + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/name/__init__.py b/src/db/models/impl/url/suggestion/name/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/name/enums.py b/src/db/models/impl/url/suggestion/name/enums.py new file mode 100644 index 00000000..89b570e6 --- /dev/null +++ b/src/db/models/impl/url/suggestion/name/enums.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class NameSuggestionSource(Enum): + HTML_METADATA_TITLE = "HTML Metadata Title" + USER = "User" \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/name/pydantic.py b/src/db/models/impl/url/suggestion/name/pydantic.py new file mode 100644 index 00000000..244e02c2 --- /dev/null +++ b/src/db/models/impl/url/suggestion/name/pydantic.py @@ -0,0 +1,17 @@ +from pydantic import Field + +from src.db.models.impl.url.suggestion.location.auto.subtask.constants import MAX_SUGGESTION_LENGTH +from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource +from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion +from src.db.templates.markers.bulk.insert import BulkInsertableModel + + +class 
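Location suggestions hang off subtask runs the same way agency suggestions do, keyed by (subtask_id, location_id) and carrying a float confidence. A query sketch, assuming an open `Session` named `session`, that ranks a URL's automated location suggestions across all of its subtask runs:

```python
from sqlalchemy import select

from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import (
    AutoLocationIDSubtask,
)
from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import (
    LocationIDSubtaskSuggestion,
)

# Highest-confidence location candidates for a hypothetical url id.
stmt = (
    select(LocationIDSubtaskSuggestion)
    .join(
        AutoLocationIDSubtask,
        AutoLocationIDSubtask.id == LocationIDSubtaskSuggestion.subtask_id,
    )
    .where(AutoLocationIDSubtask.url_id == 1)
    .order_by(LocationIDSubtaskSuggestion.confidence.desc())
)
suggestions = session.scalars(stmt).all()
```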
URLNameSuggestionPydantic(BulkInsertableModel): + + url_id: int + suggestion: str = Field(..., max_length=MAX_SUGGESTION_LENGTH) + source: NameSuggestionSource + + @classmethod + def sa_model(cls) -> type[URLNameSuggestion]: + return URLNameSuggestion \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/name/sqlalchemy.py b/src/db/models/impl/url/suggestion/name/sqlalchemy.py new file mode 100644 index 00000000..2f11542d --- /dev/null +++ b/src/db/models/impl/url/suggestion/name/sqlalchemy.py @@ -0,0 +1,23 @@ +from sqlalchemy import Column, String +from sqlalchemy.orm import Mapped + +from src.db.models.helpers import enum_column +from src.db.models.impl.url.suggestion.location.auto.subtask.constants import MAX_SUGGESTION_LENGTH +from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin +from src.db.models.templates_.with_id import WithIDBase + + +class URLNameSuggestion( + WithIDBase, + CreatedAtMixin, + URLDependentMixin +): + + __tablename__ = "url_name_suggestions" + + suggestion = Column(String(MAX_SUGGESTION_LENGTH), nullable=False) + source: Mapped[NameSuggestionSource] = enum_column( + NameSuggestionSource, + name="suggestion_source_enum" + ) \ No newline at end of file diff --git a/src/db/models/impl/url/suggestion/record_type/__init__.py b/src/db/models/impl/url/suggestion/record_type/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/record_type/auto.py b/src/db/models/impl/url/suggestion/record_type/auto.py new file mode 100644 index 00000000..2aaed526 --- /dev/null +++ b/src/db/models/impl/url/suggestion/record_type/auto.py @@ -0,0 +1,27 @@ +from sqlalchemy import Column, UniqueConstraint +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import relationship + +from src.db.models.mixins import URLDependentMixin, UpdatedAtMixin, CreatedAtMixin +from src.db.models.templates_.with_id import WithIDBase +from src.db.models.types import record_type_values + + +class AutoRecordTypeSuggestion( + UpdatedAtMixin, + CreatedAtMixin, + URLDependentMixin, + WithIDBase +): + __tablename__ = "auto_record_type_suggestions" + record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=False) + + __table_args__ = ( + UniqueConstraint("url_id", name="auto_record_type_suggestions_uq_url_id"), + ) + + # Relationships + + url = relationship("URL", back_populates="auto_record_type_suggestion") + + diff --git a/src/db/models/impl/url/suggestion/record_type/user.py b/src/db/models/impl/url/suggestion/record_type/user.py new file mode 100644 index 00000000..5b9dde8c --- /dev/null +++ b/src/db/models/impl/url/suggestion/record_type/user.py @@ -0,0 +1,22 @@ +from sqlalchemy import Column, Integer, UniqueConstraint +from sqlalchemy.dialects import postgresql +from sqlalchemy.orm import relationship + +from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin +from src.db.models.templates_.with_id import WithIDBase +from src.db.models.types import record_type_values + + +class UserRecordTypeSuggestion(UpdatedAtMixin, CreatedAtMixin, URLDependentMixin, WithIDBase): + __tablename__ = "user_record_type_suggestions" + + user_id = Column(Integer, nullable=False) + record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=False) + + __table_args__ = ( + UniqueConstraint("url_id", "user_id", name="uq_user_record_type_suggestions"), + ) + + # Relationships + 
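URLNameSuggestion caps suggestions at MAX_SUGGESTION_LENGTH on both layers: pydantic's `Field(max_length=...)` rejects overlong values before a row is built, and `String(MAX_SUGGESTION_LENGTH)` backs that up in the database. A sketch, assuming pydantic v2:

```python
from pydantic import ValidationError

from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource
from src.db.models.impl.url.suggestion.name.pydantic import URLNameSuggestionPydantic

try:
    URLNameSuggestionPydantic(
        url_id=1,
        suggestion="x" * 101,  # one character over MAX_SUGGESTION_LENGTH
        source=NameSuggestionSource.HTML_METADATA_TITLE,
    )
except ValidationError:
    pass  # rejected client-side; the DB column would reject it as well
```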
+ url = relationship("URL", back_populates="user_record_type_suggestions") diff --git a/src/db/models/impl/url/suggestion/relevant/__init__.py b/src/db/models/impl/url/suggestion/relevant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/relevant/auto/__init__.py b/src/db/models/impl/url/suggestion/relevant/auto/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/suggestion/relevant/auto/pydantic/__init__.py b/src/db/models/impl/url/suggestion/relevant/auto/pydantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/dtos/url/annotations/auto/relevancy.py b/src/db/models/impl/url/suggestion/relevant/auto/pydantic/input.py similarity index 100% rename from src/db/dtos/url/annotations/auto/relevancy.py rename to src/db/models/impl/url/suggestion/relevant/auto/pydantic/input.py diff --git a/src/db/models/impl/url/suggestion/relevant/auto/sqlalchemy.py b/src/db/models/impl/url/suggestion/relevant/auto/sqlalchemy.py new file mode 100644 index 00000000..49dc7457 --- /dev/null +++ b/src/db/models/impl/url/suggestion/relevant/auto/sqlalchemy.py @@ -0,0 +1,21 @@ +from sqlalchemy import Column, Boolean, UniqueConstraint, String, Float +from sqlalchemy.orm import relationship + +from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class AutoRelevantSuggestion(UpdatedAtMixin, CreatedAtMixin, URLDependentMixin, WithIDBase): + __tablename__ = "auto_relevant_suggestions" + + relevant = Column(Boolean, nullable=True) + confidence = Column(Float, nullable=True) + model_name = Column(String, nullable=True) + + __table_args__ = ( + UniqueConstraint("url_id", name="auto_relevant_suggestions_uq_url_id"), + ) + + # Relationships + + url = relationship("URL", back_populates="auto_relevant_suggestion") diff --git a/src/db/models/impl/url/suggestion/relevant/user.py b/src/db/models/impl/url/suggestion/relevant/user.py new file mode 100644 index 00000000..c7070b5e --- /dev/null +++ b/src/db/models/impl/url/suggestion/relevant/user.py @@ -0,0 +1,31 @@ +from sqlalchemy import Column, UniqueConstraint, Integer +from sqlalchemy.orm import relationship, Mapped + +from src.db.models.helpers import enum_column +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin +from src.db.models.templates_.with_id import WithIDBase + + +class UserURLTypeSuggestion( + UpdatedAtMixin, + CreatedAtMixin, + URLDependentMixin, + WithIDBase +): + __tablename__ = "user_url_type_suggestions" + + user_id = Column(Integer, nullable=False) + type: Mapped[URLType | None] = enum_column( + URLType, + name="url_type", + nullable=True + ) + + __table_args__ = ( + UniqueConstraint("url_id", "user_id", name="uq_user_relevant_suggestions"), + ) + + # Relationships + 
00000000..87172ad7 --- /dev/null +++ b/src/db/models/impl/url/task_error/pydantic_/insert.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from src.db.enums import TaskType +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from src.db.models.templates_.base import Base + + +class URLTaskErrorPydantic(BaseModel): + + url_id: int + task_id: int + task_type: TaskType + error: str + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return URLTaskError diff --git a/src/db/models/impl/url/task_error/pydantic_/small.py b/src/db/models/impl/url/task_error/pydantic_/small.py new file mode 100644 index 00000000..ad14458e --- /dev/null +++ b/src/db/models/impl/url/task_error/pydantic_/small.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class URLTaskErrorSmall(BaseModel): + """Small version of URLTaskErrorPydantic, to be used with the `add_task_errors` method.""" + url_id: int + error: str \ No newline at end of file diff --git a/src/db/models/impl/url/task_error/sqlalchemy.py b/src/db/models/impl/url/task_error/sqlalchemy.py new file mode 100644 index 00000000..3c4ab016 --- /dev/null +++ b/src/db/models/impl/url/task_error/sqlalchemy.py @@ -0,0 +1,23 @@ +from sqlalchemy import String, Column, PrimaryKeyConstraint +from sqlalchemy.orm import Mapped + +from src.db.enums import TaskType +from src.db.models.helpers import enum_column +from src.db.models.mixins import URLDependentMixin, TaskDependentMixin, CreatedAtMixin +from src.db.models.templates_.base import Base + + +class URLTaskError( + Base, + URLDependentMixin, + TaskDependentMixin, + CreatedAtMixin, +): + __tablename__ = "url_task_error" + + task_type: Mapped[TaskType] = enum_column(TaskType, name="task_type") + error: Mapped[str] = Column(String) + + __table_args__ = ( + PrimaryKeyConstraint("url_id", "task_type"), + ) \ No newline at end of file diff --git a/src/db/models/impl/url/web_metadata/__init__.py b/src/db/models/impl/url/web_metadata/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/impl/url/web_metadata/insert.py b/src/db/models/impl/url/web_metadata/insert.py new file mode 100644 index 00000000..4467b9da --- /dev/null +++ b/src/db/models/impl/url/web_metadata/insert.py @@ -0,0 +1,27 @@ +from pydantic import Field + +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata +from src.db.models.templates_.base import Base +from src.db.templates.markers.bulk.insert import BulkInsertableModel +from src.db.templates.markers.bulk.upsert import BulkUpsertableModel + + +class URLWebMetadataPydantic( + BulkInsertableModel, + BulkUpsertableModel +): + + @classmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + return URLWebMetadata + + @classmethod + def id_field(cls) -> str: + return "url_id" + + url_id: int + accessed: bool + status_code: int | None = Field(le=999, ge=100) + content_type: str | None + error_message: str | None \ No newline at end of file diff --git a/src/db/models/impl/url/web_metadata/sqlalchemy.py b/src/db/models/impl/url/web_metadata/sqlalchemy.py new file mode 100644 index 00000000..45f5233c --- /dev/null +++ b/src/db/models/impl/url/web_metadata/sqlalchemy.py @@ -0,0 +1,33 @@ +from sqlalchemy import Column, Text, Boolean, Integer + +from src.db.models.mixins import URLDependentMixin, CreatedAtMixin, UpdatedAtMixin +from src.db.models.templates_.with_id import WithIDBase + + +class URLWebMetadata( + WithIDBase, + URLDependentMixin, + CreatedAtMixin, + 
UpdatedAtMixin +): + """Contains information about the web page.""" + __tablename__ = "url_web_metadata" + + accessed = Column( + Boolean(), + nullable=False + ) + status_code = Column( + Integer(), + nullable=True + ) + content_type = Column( + Text(), + nullable=True + ) + error_message = Column( + Text(), + nullable=True + ) + + diff --git a/src/db/models/instantiations/agency.py b/src/db/models/instantiations/agency.py deleted file mode 100644 index 37beec3d..00000000 --- a/src/db/models/instantiations/agency.py +++ /dev/null @@ -1,33 +0,0 @@ -""" -References an agency in the data sources database. -""" - -from sqlalchemy import Column, Integer, String, DateTime -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin -from src.db.models.templates import Base - - -class Agency( - CreatedAtMixin, # When agency was added to database - UpdatedAtMixin, # When agency was last updated in database - Base -): - __tablename__ = "agencies" - - agency_id = Column(Integer, primary_key=True) - name = Column(String, nullable=False) - state = Column(String, nullable=True) - county = Column(String, nullable=True) - locality = Column(String, nullable=True) - ds_last_updated_at = Column( - DateTime, - nullable=True, - comment="The last time the agency was updated in the data sources database." - ) - - # Relationships - automated_suggestions = relationship("AutomatedUrlAgencySuggestion", back_populates="agency") - user_suggestions = relationship("UserUrlAgencySuggestion", back_populates="agency") - confirmed_urls = relationship("ConfirmedURLAgency", back_populates="agency") diff --git a/src/db/models/instantiations/backlog_snapshot.py b/src/db/models/instantiations/backlog_snapshot.py deleted file mode 100644 index 240a82fd..00000000 --- a/src/db/models/instantiations/backlog_snapshot.py +++ /dev/null @@ -1,10 +0,0 @@ -from sqlalchemy import Column, Integer - -from src.db.models.mixins import CreatedAtMixin -from src.db.models.templates import StandardModel - - -class BacklogSnapshot(CreatedAtMixin, StandardModel): - __tablename__ = "backlog_snapshot" - - count_pending_total = Column(Integer, nullable=False) diff --git a/src/db/models/instantiations/batch.py b/src/db/models/instantiations/batch.py deleted file mode 100644 index 89645f4a..00000000 --- a/src/db/models/instantiations/batch.py +++ /dev/null @@ -1,56 +0,0 @@ -from sqlalchemy import Column, Integer, TIMESTAMP, Float, JSON -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import relationship - -from src.db.models.helpers import CURRENT_TIME_SERVER_DEFAULT -from src.db.models.templates import StandardModel -from src.db.models.types import batch_status_enum - - -class Batch(StandardModel): - __tablename__ = 'batches' - - strategy = Column( - postgresql.ENUM( - 'example', - 'ckan', - 'muckrock_county_search', - 'auto_googler', - 'muckrock_all_search', - 'muckrock_simple_search', - 'common_crawler', - 'manual', - name='batch_strategy'), - nullable=False) - user_id = Column(Integer, nullable=False) - # Gives the status of the batch - status = Column( - batch_status_enum, - nullable=False - ) - date_generated = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) - # How often URLs ended up approved in the database - strategy_success_rate = Column(Float) - # Percentage of metadata identified by models - metadata_success_rate = Column(Float) - # Rate of matching to agencies - agency_match_rate = Column(Float) - # Rate of matching to record types - 
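URLWebMetadataPydantic above is marked both insertable and upsertable, with `id_field()` naming the conflict target. A sketch of what an upsert against `url_web_metadata` could look like, assuming an open `Session` named `session` and a unique index on `url_id` (implied by the one-row-per-URL upsert semantics, though not visible in this diff); the repository's real upsert machinery is not shown here:

```python
from sqlalchemy.dialects.postgresql import insert

from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic
from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata

row = URLWebMetadataPydantic(
    url_id=1,  # hypothetical url id
    accessed=True,
    status_code=200,
    content_type="text/html",
    error_message=None,
)
values = row.model_dump()
stmt = (
    insert(URLWebMetadata)
    .values(**values)
    .on_conflict_do_update(
        index_elements=[URLWebMetadataPydantic.id_field()],  # "url_id"
        set_={k: v for k, v in values.items() if k != "url_id"},
    )
)
session.execute(stmt)
session.commit()
```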
record_type_match_rate = Column(Float) - # Rate of matching to record categories - record_category_match_rate = Column(Float) - # Time taken to generate the batch - # TODO: Add means to update after execution - compute_time = Column(Float) - # The parameters used to generate the batch - parameters = Column(JSON) - - # Relationships - urls = relationship( - "URL", - secondary="link_batch_urls", - back_populates="batch" - ) - # missings = relationship("Missing", back_populates="batch") # Not in active use - logs = relationship("Log", back_populates="batch") - duplicates = relationship("Duplicate", back_populates="batch") diff --git a/src/db/models/instantiations/confirmed_url_agency.py b/src/db/models/instantiations/confirmed_url_agency.py deleted file mode 100644 index db63b114..00000000 --- a/src/db/models/instantiations/confirmed_url_agency.py +++ /dev/null @@ -1,19 +0,0 @@ -from sqlalchemy import UniqueConstraint -from sqlalchemy.orm import relationship - -from src.db.models.helpers import get_agency_id_foreign_column -from src.db.models.mixins import URLDependentMixin -from src.db.models.templates import StandardModel - - -class ConfirmedURLAgency(URLDependentMixin, StandardModel): - __tablename__ = "confirmed_url_agency" - - agency_id = get_agency_id_foreign_column() - - url = relationship("URL", back_populates="confirmed_agencies") - agency = relationship("Agency", back_populates="confirmed_urls") - - __table_args__ = ( - UniqueConstraint("url_id", "agency_id", name="uq_confirmed_url_agency"), - ) diff --git a/src/db/models/instantiations/duplicate.py b/src/db/models/instantiations/duplicate.py deleted file mode 100644 index 7a80d918..00000000 --- a/src/db/models/instantiations/duplicate.py +++ /dev/null @@ -1,23 +0,0 @@ -from sqlalchemy import Column, Integer, ForeignKey -from sqlalchemy.orm import relationship - -from src.db.models.mixins import BatchDependentMixin -from src.db.models.templates import StandardModel - - -class Duplicate(BatchDependentMixin, StandardModel): - """ - Identifies duplicates which occur within a batch - """ - __tablename__ = 'duplicates' - - original_url_id = Column( - Integer, - ForeignKey('urls.id'), - nullable=False, - doc="The original URL ID" - ) - - # Relationships - batch = relationship("Batch", back_populates="duplicates") - original_url = relationship("URL", back_populates="duplicates") diff --git a/src/db/models/instantiations/link/link_batch_urls.py b/src/db/models/instantiations/link/link_batch_urls.py deleted file mode 100644 index f357ae6a..00000000 --- a/src/db/models/instantiations/link/link_batch_urls.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy.orm import relationship - -from src.db.models.mixins import CreatedAtMixin, UpdatedAtMixin, BatchDependentMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class LinkBatchURL( - UpdatedAtMixin, - CreatedAtMixin, - URLDependentMixin, - BatchDependentMixin, - StandardModel -): - __tablename__ = "link_batch_urls" - - url = relationship('URL') - batch = relationship('Batch') \ No newline at end of file diff --git a/src/db/models/instantiations/link/link_task_url.py b/src/db/models/instantiations/link/link_task_url.py deleted file mode 100644 index 02ef02c3..00000000 --- a/src/db/models/instantiations/link/link_task_url.py +++ /dev/null @@ -1,15 +0,0 @@ -from sqlalchemy import UniqueConstraint, Column, Integer, ForeignKey - -from src.db.models.templates import Base - - -class LinkTaskURL(Base): - __tablename__ = 'link_task_urls' - __table_args__ = (UniqueConstraint( 
- "task_id", - "url_id", - name="uq_task_id_url_id"), - ) - - task_id = Column(Integer, ForeignKey('tasks.id', ondelete="CASCADE"), primary_key=True) - url_id = Column(Integer, ForeignKey('urls.id', ondelete="CASCADE"), primary_key=True) diff --git a/src/db/models/instantiations/log.py b/src/db/models/instantiations/log.py deleted file mode 100644 index 756e10c5..00000000 --- a/src/db/models/instantiations/log.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy import Column, Text -from sqlalchemy.orm import relationship - -from src.db.models.mixins import CreatedAtMixin, BatchDependentMixin -from src.db.models.templates import StandardModel - - -class Log(CreatedAtMixin, BatchDependentMixin, StandardModel): - __tablename__ = 'logs' - - log = Column(Text, nullable=False) - - # Relationships - batch = relationship("Batch", back_populates="logs") diff --git a/src/db/models/instantiations/root_url_cache.py b/src/db/models/instantiations/root_url_cache.py deleted file mode 100644 index d121ae28..00000000 --- a/src/db/models/instantiations/root_url_cache.py +++ /dev/null @@ -1,17 +0,0 @@ -from sqlalchemy import UniqueConstraint, Column, String - -from src.db.models.mixins import UpdatedAtMixin -from src.db.models.templates import StandardModel - - -class RootURL(UpdatedAtMixin, StandardModel): - __tablename__ = 'root_url_cache' - __table_args__ = ( - UniqueConstraint( - "url", - name="uq_root_url_url"), - ) - - url = Column(String, nullable=False) - page_title = Column(String, nullable=False) - page_description = Column(String, nullable=True) diff --git a/src/db/models/instantiations/sync_state_agencies.py b/src/db/models/instantiations/sync_state_agencies.py deleted file mode 100644 index 207a2936..00000000 --- a/src/db/models/instantiations/sync_state_agencies.py +++ /dev/null @@ -1,32 +0,0 @@ -""" -Tracks the status of the agencies sync -""" - -from sqlalchemy import DateTime, Date, Integer, Column - -from src.db.models.templates import Base - - -class AgenciesSyncState(Base): - __tablename__ = 'agencies_sync_state' - id = Column(Integer, primary_key=True) - last_full_sync_at = Column( - DateTime(), - nullable=True, - comment="The datetime of the last *full* sync " - "(i.e., the last sync that got all entries " - "available to be synchronized)." - ) - current_cutoff_date = Column( - Date(), - nullable=True, - comment="Tracks the cutoff date passed to the agencies sync endpoint." - "On completion of a full sync, this is set to " - "the day before the present day." - ) - current_page = Column( - Integer(), - nullable=True, - comment="Tracks the current page passed to the agencies sync endpoint." - "On completion of a full sync, this is set to `null`." 
- ) \ No newline at end of file diff --git a/src/db/models/instantiations/task/core.py b/src/db/models/instantiations/task/core.py deleted file mode 100644 index 89c80405..00000000 --- a/src/db/models/instantiations/task/core.py +++ /dev/null @@ -1,27 +0,0 @@ -from sqlalchemy import Column -from sqlalchemy.orm import relationship - -from src.db.enums import PGEnum, TaskType -from src.db.models.mixins import UpdatedAtMixin -from src.db.models.templates import StandardModel -from src.db.models.types import batch_status_enum - - -class Task(UpdatedAtMixin, StandardModel): - __tablename__ = 'tasks' - - task_type = Column( - PGEnum( - *[task_type.value for task_type in TaskType], - name='task_type' - ), nullable=False) - task_status = Column(batch_status_enum, nullable=False) - - # Relationships - urls = relationship( - "URL", - secondary="link_task_urls", - back_populates="tasks" - ) - error = relationship("TaskError", back_populates="task") - errored_urls = relationship("URLErrorInfo", back_populates="task") diff --git a/src/db/models/instantiations/task/error.py b/src/db/models/instantiations/task/error.py deleted file mode 100644 index cf1ae24f..00000000 --- a/src/db/models/instantiations/task/error.py +++ /dev/null @@ -1,20 +0,0 @@ -from sqlalchemy import Column, Text, UniqueConstraint -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, TaskDependentMixin -from src.db.models.templates import StandardModel - - -class TaskError(UpdatedAtMixin, TaskDependentMixin, StandardModel): - __tablename__ = 'task_errors' - - error = Column(Text, nullable=False) - - # Relationships - task = relationship("Task", back_populates="error") - - __table_args__ = (UniqueConstraint( - "task_id", - "error", - name="uq_task_id_error"), - ) diff --git a/src/db/models/instantiations/url/compressed_html.py b/src/db/models/instantiations/url/compressed_html.py deleted file mode 100644 index 5c2e06c0..00000000 --- a/src/db/models/instantiations/url/compressed_html.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqlalchemy import Column, LargeBinary -from sqlalchemy.orm import relationship - -from src.db.models.mixins import CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class URLCompressedHTML( - CreatedAtMixin, - URLDependentMixin, - StandardModel -): - __tablename__ = 'url_compressed_html' - - compressed_html = Column(LargeBinary, nullable=False) - - url = relationship( - "URL", - uselist=False, - back_populates="compressed_html" - ) \ No newline at end of file diff --git a/src/db/models/instantiations/url/core.py b/src/db/models/instantiations/url/core.py deleted file mode 100644 index 8e9860fc..00000000 --- a/src/db/models/instantiations/url/core.py +++ /dev/null @@ -1,89 +0,0 @@ -from sqlalchemy import Column, Integer, ForeignKey, Text, String, JSON -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin -from src.db.models.templates import StandardModel -from src.db.models.types import record_type_values - - -class URL(UpdatedAtMixin, CreatedAtMixin, StandardModel): - __tablename__ = 'urls' - - # The batch this URL is associated with - url = Column(Text, unique=True) - name = Column(String) - description = Column(Text) - # The metadata from the collector - collector_metadata = Column(JSON) - # The outcome of the URL: submitted, human_labeling, rejected, duplicate, etc. 
- outcome = Column( - postgresql.ENUM( - 'pending', - 'submitted', - 'validated', - 'not relevant', - 'duplicate', - 'error', - '404 not found', - 'individual record', - name='url_status' - ), - nullable=False - ) - record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=True) - - # Relationships - batch = relationship( - "Batch", - secondary="link_batch_urls", - back_populates="urls", - uselist=False - ) - duplicates = relationship("Duplicate", back_populates="original_url") - html_content = relationship("URLHTMLContent", back_populates="url", cascade="all, delete-orphan") - error_info = relationship("URLErrorInfo", back_populates="url", cascade="all, delete-orphan") - tasks = relationship( - "Task", - secondary="link_task_urls", - back_populates="urls", - ) - automated_agency_suggestions = relationship( - "AutomatedUrlAgencySuggestion", back_populates="url") - user_agency_suggestion = relationship( - "UserUrlAgencySuggestion", uselist=False, back_populates="url") - auto_record_type_suggestion = relationship( - "AutoRecordTypeSuggestion", uselist=False, back_populates="url") - user_record_type_suggestion = relationship( - "UserRecordTypeSuggestion", uselist=False, back_populates="url") - auto_relevant_suggestion = relationship( - "AutoRelevantSuggestion", uselist=False, back_populates="url") - user_relevant_suggestion = relationship( - "UserRelevantSuggestion", uselist=False, back_populates="url") - reviewing_user = relationship( - "ReviewingUserURL", uselist=False, back_populates="url") - optional_data_source_metadata = relationship( - "URLOptionalDataSourceMetadata", uselist=False, back_populates="url") - confirmed_agencies = relationship( - "ConfirmedURLAgency", - ) - data_source = relationship( - "URLDataSource", - back_populates="url", - uselist=False - ) - checked_for_duplicate = relationship( - "URLCheckedForDuplicate", - uselist=False, - back_populates="url" - ) - probed_for_404 = relationship( - "URLProbedFor404", - uselist=False, - back_populates="url" - ) - compressed_html = relationship( - "URLCompressedHTML", - uselist=False, - back_populates="url" - ) \ No newline at end of file diff --git a/src/db/models/instantiations/url/data_source.py b/src/db/models/instantiations/url/data_source.py deleted file mode 100644 index ad6caf46..00000000 --- a/src/db/models/instantiations/url/data_source.py +++ /dev/null @@ -1,18 +0,0 @@ -from sqlalchemy import Column, Integer -from sqlalchemy.orm import relationship - -from src.db.models.mixins import CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class URLDataSource(CreatedAtMixin, URLDependentMixin, StandardModel): - __tablename__ = "url_data_sources" - - data_source_id = Column(Integer, nullable=False) - - # Relationships - url = relationship( - "URL", - back_populates="data_source", - uselist=False - ) diff --git a/src/db/models/instantiations/url/error_info.py b/src/db/models/instantiations/url/error_info.py deleted file mode 100644 index d2a09b6a..00000000 --- a/src/db/models/instantiations/url/error_info.py +++ /dev/null @@ -1,20 +0,0 @@ -from sqlalchemy import UniqueConstraint, Column, Text -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, TaskDependentMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class URLErrorInfo(UpdatedAtMixin, TaskDependentMixin, URLDependentMixin, StandardModel): - __tablename__ = 'url_error_info' - __table_args__ = (UniqueConstraint( - "url_id", - "task_id", - 
name="uq_url_id_error"), - ) - - error = Column(Text, nullable=False) - - # Relationships - url = relationship("URL", back_populates="error_info") - task = relationship("Task", back_populates="errored_urls") diff --git a/src/db/models/instantiations/url/html_content.py b/src/db/models/instantiations/url/html_content.py deleted file mode 100644 index 39ad3666..00000000 --- a/src/db/models/instantiations/url/html_content.py +++ /dev/null @@ -1,24 +0,0 @@ -from sqlalchemy import UniqueConstraint, Column, Text -from sqlalchemy.orm import relationship - -from src.db.enums import PGEnum -from src.db.models.mixins import UpdatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class URLHTMLContent(UpdatedAtMixin, URLDependentMixin, StandardModel): - __tablename__ = 'url_html_content' - __table_args__ = (UniqueConstraint( - "url_id", - "content_type", - name="uq_url_id_content_type"), - ) - - content_type = Column( - PGEnum('Title', 'Description', 'H1', 'H2', 'H3', 'H4', 'H5', 'H6', 'Div', name='url_html_content_type'), - nullable=False) - content = Column(Text, nullable=False) - - - # Relationships - url = relationship("URL", back_populates="html_content") diff --git a/src/db/models/instantiations/url/probed_for_404.py b/src/db/models/instantiations/url/probed_for_404.py deleted file mode 100644 index 3913e37e..00000000 --- a/src/db/models/instantiations/url/probed_for_404.py +++ /dev/null @@ -1,14 +0,0 @@ -from sqlalchemy.orm import relationship - -from src.db.models.helpers import get_created_at_column -from src.db.models.mixins import URLDependentMixin -from src.db.models.templates import StandardModel - - -class URLProbedFor404(URLDependentMixin, StandardModel): - __tablename__ = 'url_probed_for_404' - - last_probed_at = get_created_at_column() - - # Relationships - url = relationship("URL", uselist=False, back_populates="probed_for_404") diff --git a/src/db/models/instantiations/url/suggestion/agency/auto.py b/src/db/models/instantiations/url/suggestion/agency/auto.py deleted file mode 100644 index 5831882f..00000000 --- a/src/db/models/instantiations/url/suggestion/agency/auto.py +++ /dev/null @@ -1,20 +0,0 @@ -from sqlalchemy import Column, Boolean, UniqueConstraint -from sqlalchemy.orm import relationship - -from src.db.models.helpers import get_agency_id_foreign_column -from src.db.models.mixins import URLDependentMixin -from src.db.models.templates import StandardModel - - -class AutomatedUrlAgencySuggestion(URLDependentMixin, StandardModel): - __tablename__ = "automated_url_agency_suggestions" - - agency_id = get_agency_id_foreign_column(nullable=True) - is_unknown = Column(Boolean, nullable=True) - - agency = relationship("Agency", back_populates="automated_suggestions") - url = relationship("URL", back_populates="automated_agency_suggestions") - - __table_args__ = ( - UniqueConstraint("agency_id", "url_id", name="uq_automated_url_agency_suggestions"), - ) diff --git a/src/db/models/instantiations/url/suggestion/agency/user.py b/src/db/models/instantiations/url/suggestion/agency/user.py deleted file mode 100644 index cb92bfc0..00000000 --- a/src/db/models/instantiations/url/suggestion/agency/user.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqlalchemy import Column, Boolean, UniqueConstraint, Integer -from sqlalchemy.orm import relationship - -from src.db.models.helpers import get_agency_id_foreign_column -from src.db.models.mixins import URLDependentMixin -from src.db.models.templates import StandardModel - - -class 
UserUrlAgencySuggestion(URLDependentMixin, StandardModel): - __tablename__ = "user_url_agency_suggestions" - - agency_id = get_agency_id_foreign_column(nullable=True) - user_id = Column(Integer, nullable=False) - is_new = Column(Boolean, nullable=True) - - agency = relationship("Agency", back_populates="user_suggestions") - url = relationship("URL", back_populates="user_agency_suggestion") - - __table_args__ = ( - UniqueConstraint("agency_id", "url_id", "user_id", name="uq_user_url_agency_suggestions"), - ) diff --git a/src/db/models/instantiations/url/suggestion/record_type/auto.py b/src/db/models/instantiations/url/suggestion/record_type/auto.py deleted file mode 100644 index 00d738b8..00000000 --- a/src/db/models/instantiations/url/suggestion/record_type/auto.py +++ /dev/null @@ -1,27 +0,0 @@ -from sqlalchemy import Column, UniqueConstraint -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import relationship - -from src.db.models.mixins import URLDependentMixin, UpdatedAtMixin, CreatedAtMixin -from src.db.models.templates import StandardModel -from src.db.models.types import record_type_values - - -class AutoRecordTypeSuggestion( - UpdatedAtMixin, - CreatedAtMixin, - URLDependentMixin, - StandardModel -): - __tablename__ = "auto_record_type_suggestions" - record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=False) - - __table_args__ = ( - UniqueConstraint("url_id", name="auto_record_type_suggestions_uq_url_id"), - ) - - # Relationships - - url = relationship("URL", back_populates="auto_record_type_suggestion") - - diff --git a/src/db/models/instantiations/url/suggestion/record_type/user.py b/src/db/models/instantiations/url/suggestion/record_type/user.py deleted file mode 100644 index cda6fb17..00000000 --- a/src/db/models/instantiations/url/suggestion/record_type/user.py +++ /dev/null @@ -1,22 +0,0 @@ -from sqlalchemy import Column, Integer, UniqueConstraint -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel -from src.db.models.types import record_type_values - - -class UserRecordTypeSuggestion(UpdatedAtMixin, CreatedAtMixin, URLDependentMixin, StandardModel): - __tablename__ = "user_record_type_suggestions" - - user_id = Column(Integer, nullable=False) - record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=False) - - __table_args__ = ( - UniqueConstraint("url_id", "user_id", name="uq_user_record_type_suggestions"), - ) - - # Relationships - - url = relationship("URL", back_populates="user_record_type_suggestion") diff --git a/src/db/models/instantiations/url/suggestion/relevant/auto.py b/src/db/models/instantiations/url/suggestion/relevant/auto.py deleted file mode 100644 index db7f8ea2..00000000 --- a/src/db/models/instantiations/url/suggestion/relevant/auto.py +++ /dev/null @@ -1,21 +0,0 @@ -from sqlalchemy import Column, Boolean, UniqueConstraint, String, Float -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class AutoRelevantSuggestion(UpdatedAtMixin, CreatedAtMixin, URLDependentMixin, StandardModel): - __tablename__ = "auto_relevant_suggestions" - - relevant = Column(Boolean, nullable=True) - confidence = Column(Float, nullable=True) - model_name = Column(String, nullable=True) - - 
__table_args__ = ( - UniqueConstraint("url_id", name="auto_relevant_suggestions_uq_url_id"), - ) - - # Relationships - - url = relationship("URL", back_populates="auto_relevant_suggestion") diff --git a/src/db/models/instantiations/url/suggestion/relevant/user.py b/src/db/models/instantiations/url/suggestion/relevant/user.py deleted file mode 100644 index 35d30c44..00000000 --- a/src/db/models/instantiations/url/suggestion/relevant/user.py +++ /dev/null @@ -1,35 +0,0 @@ -from sqlalchemy import Column, UniqueConstraint, Integer -from sqlalchemy.dialects import postgresql -from sqlalchemy.orm import relationship - -from src.db.models.mixins import UpdatedAtMixin, CreatedAtMixin, URLDependentMixin -from src.db.models.templates import StandardModel - - -class UserRelevantSuggestion( - UpdatedAtMixin, - CreatedAtMixin, - URLDependentMixin, - StandardModel -): - __tablename__ = "user_relevant_suggestions" - - user_id = Column(Integer, nullable=False) - suggested_status = Column( - postgresql.ENUM( - 'relevant', - 'not relevant', - 'individual record', - 'broken page/404 not found', - name='suggested_status' - ), - nullable=True - ) - - __table_args__ = ( - UniqueConstraint("url_id", "user_id", name="uq_user_relevant_suggestions"), - ) - - # Relationships - - url = relationship("URL", back_populates="user_relevant_suggestion") diff --git a/src/db/models/mixins.py b/src/db/models/mixins.py index 541e5d09..12a0b2a1 100644 --- a/src/db/models/mixins.py +++ b/src/db/models/mixins.py @@ -1,5 +1,8 @@ -from sqlalchemy import Column, Integer, ForeignKey, TIMESTAMP +from typing import ClassVar +from sqlalchemy import Column, Integer, ForeignKey, TIMESTAMP, event + +from src.db.models.exceptions import WriteToViewError from src.db.models.helpers import get_created_at_column, CURRENT_TIME_SERVER_DEFAULT @@ -35,6 +38,15 @@ class BatchDependentMixin: nullable=False ) +class LocationDependentMixin: + location_id = Column( + Integer, + ForeignKey( + 'locations.id', + ondelete="CASCADE", + ), + nullable=False + ) class AgencyDependentMixin: agency_id = Column( @@ -58,3 +70,17 @@ class UpdatedAtMixin: server_default=CURRENT_TIME_SERVER_DEFAULT, onupdate=CURRENT_TIME_SERVER_DEFAULT ) + +class ViewMixin: + """Attach to any mapped class that represents a DB view.""" + __is_view__: ClassVar[bool] = True + + @classmethod + def __declare_last__(cls) -> None: + # Block writes on this mapped class + for evt in ("before_insert", "before_update", "before_delete"): + event.listen(cls, evt, cls._block_write) + + @staticmethod + def _block_write(mapper, connection, target): + raise WriteToViewError(f"{type(target).__name__} is a read-only view.") diff --git a/src/db/models/templates.py b/src/db/models/templates.py deleted file mode 100644 index 3e0a1c95..00000000 --- a/src/db/models/templates.py +++ /dev/null @@ -1,11 +0,0 @@ -from sqlalchemy import Integer, Column -from sqlalchemy.orm import declarative_base - -# Base class for SQLAlchemy ORM models -Base = declarative_base() - -class StandardModel(Base): - __abstract__ = True - - id = Column(Integer, primary_key=True, autoincrement=True) - diff --git a/src/db/models/templates_/__init__.py b/src/db/models/templates_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/templates_/base.py b/src/db/models/templates_/base.py new file mode 100644 index 00000000..0ec5f68e --- /dev/null +++ b/src/db/models/templates_/base.py @@ -0,0 +1,4 @@ +"""Base class for SQLAlchemy ORM models.""" +from sqlalchemy.orm import declarative_base + +Base = 
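The new ViewMixin turns accidental writes to view-mapped classes into immediate errors: `__declare_last__` hooks the mapper's insert, update, and delete events, and `_block_write` raises. A sketch of the failure mode, assuming an open `Session` named `session`; `MetaURL` is one of the view-mapped classes introduced later in this diff:

```python
from src.db.models.exceptions import WriteToViewError
from src.db.models.views.meta_url import MetaURL

session.add(MetaURL(url_id=1))  # hypothetical url id
try:
    session.flush()  # before_insert fires and _block_write raises
except WriteToViewError as e:
    session.rollback()
    print(e)  # "MetaURL is a read-only view."
```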
declarative_base() diff --git a/src/db/models/templates_/standard.py b/src/db/models/templates_/standard.py new file mode 100644 index 00000000..85a01941 --- /dev/null +++ b/src/db/models/templates_/standard.py @@ -0,0 +1,14 @@ +from sqlalchemy import Column, Integer + +from src.db.models.mixins import CreatedAtMixin, UpdatedAtMixin +from src.db.models.templates_.base import Base + + +class StandardBase( + Base, + CreatedAtMixin, + UpdatedAtMixin, +): + __abstract__ = True + + id = Column(Integer, primary_key=True, autoincrement=True) diff --git a/src/db/models/templates_/with_id.py b/src/db/models/templates_/with_id.py new file mode 100644 index 00000000..e454f215 --- /dev/null +++ b/src/db/models/templates_/with_id.py @@ -0,0 +1,11 @@ +from sqlalchemy import Integer, Column + +from src.db.models.templates_.base import Base + + + +class WithIDBase(Base): + __abstract__ = True + + id = Column(Integer, primary_key=True, autoincrement=True) + diff --git a/src/db/models/views/__init__.py b/src/db/models/views/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/views/batch_url_status/__init__.py b/src/db/models/views/batch_url_status/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/views/batch_url_status/core.py b/src/db/models/views/batch_url_status/core.py new file mode 100644 index 00000000..888ca169 --- /dev/null +++ b/src/db/models/views/batch_url_status/core.py @@ -0,0 +1,81 @@ +""" +CREATE MATERIALIZED VIEW batch_url_status_mat_view as ( + with + batches_with_urls as ( + select + b.id as batch_id + from + batches b + where + exists( + select + 1 + from + link_batch_urls lbu + where + lbu.batch_id = b.id + ) + ) + , batches_with_only_validated_urls as ( + select + b.id + from + batches b + where + exists( + select + 1 + from + link_batch_urls lbu + left join flag_url_validated fuv on fuv.url_id = lbu.url_id + where + lbu.batch_id = b.id + and fuv.id is not null + ) + and not exists( + select + 1 + from + link_batch_urls lbu + left join flag_url_validated fuv on fuv.url_id = lbu.url_id + where + lbu.batch_id = b.id + and fuv.id is null + ) + ) + +select + b.id, + case + when b.status = 'error' THEN 'Error' + when (bwu.id is null) THEN 'No URLs' + when (bwovu.id is not null) THEN 'Labeling Complete' + else 'Has Unlabeled URLs' + end as batch_url_status +from + batches b + left join batches_with_urls bwu + on bwu.id = b.id + left join batches_with_only_validated_urls bwovu + on bwovu.id = b.id +) +""" +from sqlalchemy import PrimaryKeyConstraint, String, Column + +from src.db.models.mixins import ViewMixin, BatchDependentMixin +from src.db.models.templates_.base import Base + + +class BatchURLStatusMatView( + Base, + ViewMixin, + BatchDependentMixin +): + + batch_url_status = Column(String) + + __tablename__ = "batch_url_status_mat_view" + __table_args__ = ( + PrimaryKeyConstraint("batch_id"), + {"info": "view"} + ) \ No newline at end of file diff --git a/src/db/models/views/batch_url_status/enums.py b/src/db/models/views/batch_url_status/enums.py new file mode 100644 index 00000000..2f524de4 --- /dev/null +++ b/src/db/models/views/batch_url_status/enums.py @@ -0,0 +1,8 @@ +from enum import Enum + + +class BatchURLStatusEnum(Enum): + ERROR = "Error" + NO_URLS = "No URLs" + LABELING_COMPLETE = "Labeling Complete" + HAS_UNLABELED_URLS = "Has Unlabeled URLs" \ No newline at end of file diff --git a/src/db/models/views/dependent_locations.py b/src/db/models/views/dependent_locations.py new file mode 100644 index 
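Because `batch_url_status_mat_view` is a materialized view, reads go through the mapped class while freshness is managed with an explicit refresh. A sketch, assuming an open `Session` named `session` and a hypothetical batch id; the enum added alongside the view gives the four status strings symbolic names:

```python
from sqlalchemy import select, text

from src.db.models.views.batch_url_status.core import BatchURLStatusMatView
from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum

# Materialized views serve stale data until refreshed; refreshing is a plain
# PostgreSQL statement, not something SQLAlchemy issues implicitly.
session.execute(text("REFRESH MATERIALIZED VIEW batch_url_status_mat_view"))

row = session.execute(
    select(BatchURLStatusMatView).where(BatchURLStatusMatView.batch_id == 1)
).scalar_one_or_none()
if row is not None:
    status = BatchURLStatusEnum(row.batch_url_status)  # e.g. LABELING_COMPLETE
```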
00000000..95f3db98 --- /dev/null +++ b/src/db/models/views/dependent_locations.py @@ -0,0 +1,54 @@ +""" +create view dependent_locations(parent_location_id, dependent_location_id) as +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'County'::location_type AND lp.type = 'State'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.county_id = lp.county_id AND ld.type = 'Locality'::location_type AND lp.type = 'County'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON ld.state_id = lp.state_id AND ld.type = 'Locality'::location_type AND lp.type = 'State'::location_type +UNION ALL +SELECT + lp.id AS parent_location_id, + ld.id AS dependent_location_id +FROM + locations lp + JOIN locations ld ON lp.type = 'National'::location_type AND (ld.type = ANY + (ARRAY ['State'::location_type, 'County'::location_type, 'Locality'::location_type])); +""" +from sqlalchemy import Column, Integer, ForeignKey + +from src.db.models.mixins import ViewMixin +from src.db.models.templates_.base import Base + + +class DependentLocationView(Base, ViewMixin): + + __tablename__ = "dependent_locations" + __table_args__ = ( + {"info": "view"} + ) + + parent_location_id = Column( + Integer, + ForeignKey("locations.id"), + primary_key=True, + ) + dependent_location_id = Column( + Integer, + ForeignKey("locations.id"), + primary_key=True + ) diff --git a/src/db/models/views/location_expanded.py b/src/db/models/views/location_expanded.py new file mode 100644 index 00000000..1eb973aa --- /dev/null +++ b/src/db/models/views/location_expanded.py @@ -0,0 +1,66 @@ +""" +create or replace view public.locations_expanded + (id, type, state_name, state_iso, county_name, county_fips, locality_name, locality_id, state_id, county_id, + display_name, full_display_name) +as +SELECT + locations.id, + locations.type, + us_states.state_name, + us_states.state_iso, + counties.name AS county_name, + counties.fips AS county_fips, + localities.name AS locality_name, + localities.id AS locality_id, + us_states.id AS state_id, + counties.id AS county_id, + CASE + WHEN locations.type = 'Locality'::location_type THEN localities.name + WHEN locations.type = 'County'::location_type THEN counties.name::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + ELSE NULL::character varying + END AS display_name, + CASE + WHEN locations.type = 'Locality'::location_type THEN concat(localities.name, ', ', counties.name, ', ', + us_states.state_name)::character varying + WHEN locations.type = 'County'::location_type + THEN concat(counties.name, ', ', us_states.state_name)::character varying + WHEN locations.type = 'State'::location_type THEN us_states.state_name::character varying + ELSE NULL::character varying + END AS full_display_name +FROM + locations + LEFT JOIN us_states ON locations.state_id = us_states.id + LEFT JOIN counties ON locations.county_id = counties.id + LEFT JOIN localities ON locations.locality_id = localities.id; +""" +from sqlalchemy import Column, String, Integer + +from src.db.models.helpers import enum_column +from src.db.models.impl.location.location.enums import LocationType +from src.db.models.mixins import ViewMixin, LocationDependentMixin +from src.db.models.templates_.with_id import 
WithIDBase + + +class LocationExpandedView( + WithIDBase, + ViewMixin, + LocationDependentMixin +): + + __tablename__ = "locations_expanded" + __table_args__ = ( + {"info": "view"} + ) + + type = enum_column(LocationType, name="location_type", nullable=False) + state_name = Column(String) + state_iso = Column(String) + county_name = Column(String) + county_fips = Column(String) + locality_name = Column(String) + locality_id = Column(Integer) + state_id = Column(Integer) + county_id = Column(Integer) + display_name = Column(String) + full_display_name = Column(String) diff --git a/src/db/models/views/meta_url.py b/src/db/models/views/meta_url.py new file mode 100644 index 00000000..20437075 --- /dev/null +++ b/src/db/models/views/meta_url.py @@ -0,0 +1,26 @@ +""" + CREATE OR REPLACE VIEW meta_url_view AS + SELECT + urls.id as url_id + FROM urls + INNER JOIN flag_url_validated fuv on fuv.url_id = urls.id + where fuv.type = 'meta url' +""" + +from sqlalchemy import PrimaryKeyConstraint + +from src.db.models.mixins import ViewMixin, URLDependentMixin +from src.db.models.templates_.base import Base + + +class MetaURL( + Base, + ViewMixin, + URLDependentMixin, +): + + __tablename__ = "meta_url_view" + __table_args__ = ( + PrimaryKeyConstraint("url_id"), + {"info": "view"} + ) \ No newline at end of file diff --git a/src/db/models/views/unvalidated_url.py b/src/db/models/views/unvalidated_url.py new file mode 100644 index 00000000..bcfa9293 --- /dev/null +++ b/src/db/models/views/unvalidated_url.py @@ -0,0 +1,28 @@ +""" +CREATE OR REPLACE VIEW unvalidated_url_view AS +select + u.id as url_id +from + urls u + left join flag_url_validated fuv + on fuv.url_id = u.id +where + fuv.type is null +""" +from sqlalchemy import PrimaryKeyConstraint + +from src.db.models.mixins import ViewMixin, URLDependentMixin +from src.db.models.templates_.base import Base + + +class UnvalidatedURL( + Base, + ViewMixin, + URLDependentMixin, +): + + __tablename__ = "unvalidated_url_view" + __table_args__ = ( + PrimaryKeyConstraint("url_id"), + {"info": "view"} + ) \ No newline at end of file diff --git a/src/db/models/views/url_anno_count.py b/src/db/models/views/url_anno_count.py new file mode 100644 index 00000000..232f0d21 --- /dev/null +++ b/src/db/models/views/url_anno_count.py @@ -0,0 +1,125 @@ +""" + CREATE OR REPLACE VIEW url_annotation_count AS + with auto_location_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_location_id_subtasks anno on u.id = anno.url_id + group by u.id +) +, auto_agency_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.url_auto_agency_id_subtasks anno on u.id = anno.url_id + group by u.id +) +, auto_url_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_relevant_suggestions anno on u.id = anno.url_id + group by u.id +) +, auto_record_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.auto_record_type_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_location_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_location_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_agency_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_url_agency_suggestions anno on u.id = anno.url_id + group by u.id +) +, user_url_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join 
public.user_url_type_suggestions anno on u.id = anno.url_id + group by u.id + ) +, user_record_type_count as ( + select + u.id, + count(anno.url_id) as cnt + from urls u + inner join public.user_record_type_suggestions anno on u.id = anno.url_id + group by u.id +) +select + u.id as url_id, + coalesce(auto_ag.cnt, 0) as auto_agency_count, + coalesce(auto_loc.cnt, 0) as auto_location_count, + coalesce(auto_rec.cnt, 0) as auto_record_type_count, + coalesce(auto_typ.cnt, 0) as auto_url_type_count, + coalesce(user_ag.cnt, 0) as user_agency_count, + coalesce(user_loc.cnt, 0) as user_location_count, + coalesce(user_rec.cnt, 0) as user_record_type_count, + coalesce(user_typ.cnt, 0) as user_url_type_count, + ( + coalesce(auto_ag.cnt, 0) + + coalesce(auto_loc.cnt, 0) + + coalesce(auto_rec.cnt, 0) + + coalesce(auto_typ.cnt, 0) + + coalesce(user_ag.cnt, 0) + + coalesce(user_loc.cnt, 0) + + coalesce(user_rec.cnt, 0) + + coalesce(user_typ.cnt, 0) + ) as total_anno_count + + from urls u + left join auto_agency_count auto_ag on auto_ag.id = u.id + left join auto_location_count auto_loc on auto_loc.id = u.id + left join auto_record_type_count auto_rec on auto_rec.id = u.id + left join auto_url_type_count auto_typ on auto_typ.id = u.id + left join user_agency_count user_ag on user_ag.id = u.id + left join user_location_count user_loc on user_loc.id = u.id + left join user_record_type_count user_rec on user_rec.id = u.id + left join user_url_type_count user_typ on user_typ.id = u.id +""" +from sqlalchemy import PrimaryKeyConstraint, Column, Integer + +from src.db.models.helpers import url_id_primary_key_constraint +from src.db.models.mixins import ViewMixin, URLDependentMixin +from src.db.models.templates_.base import Base + + +class URLAnnotationCount( + Base, + ViewMixin, + URLDependentMixin +): + + __tablename__ = "url_annotation_count_view" + __table_args__ = ( + url_id_primary_key_constraint(), + {"info": "view"} + ) + + auto_agency_count = Column(Integer, nullable=False) + auto_location_count = Column(Integer, nullable=False) + auto_record_type_count = Column(Integer, nullable=False) + auto_url_type_count = Column(Integer, nullable=False) + user_agency_count = Column(Integer, nullable=False) + user_location_count = Column(Integer, nullable=False) + user_record_type_count = Column(Integer, nullable=False) + user_url_type_count = Column(Integer, nullable=False) + total_anno_count = Column(Integer, nullable=False) \ No newline at end of file diff --git a/src/db/models/views/url_annotations_flags.py b/src/db/models/views/url_annotations_flags.py new file mode 100644 index 00000000..57d8e866 --- /dev/null +++ b/src/db/models/views/url_annotations_flags.py @@ -0,0 +1,51 @@ +""" +CREATE OR REPLACE VIEW url_annotation_flags AS +( +SELECT u.id, + CASE WHEN arts.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_record_type_suggestion, + CASE WHEN ars.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_relevant_suggestion, + CASE WHEN auas.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_auto_agency_suggestion, + CASE WHEN urts.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_record_type_suggestion, + CASE WHEN urs.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_relevant_suggestion, + CASE WHEN uuas.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_user_agency_suggestion, + CASE WHEN cua.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS has_confirmed_agency, + CASE WHEN ruu.url_id IS NOT NULL THEN TRUE ELSE FALSE END AS was_reviewed +FROM urls u + LEFT JOIN 
public.auto_record_type_suggestions arts ON u.id = arts.url_id + LEFT JOIN public.auto_relevant_suggestions ars ON u.id = ars.url_id + LEFT JOIN public.{URL_AUTO_AGENCY_SUGGESTIONS_TABLE_NAME} auas ON u.id = auas.url_id + LEFT JOIN public.user_record_type_suggestions urts ON u.id = urts.url_id + LEFT JOIN public.user_relevant_suggestions urs ON u.id = urs.url_id + LEFT JOIN public.user_url_agency_suggestions uuas ON u.id = uuas.url_id + LEFT JOIN public.reviewing_user_url ruu ON u.id = ruu.url_id + LEFT JOIN public.link_urls_agency cua on u.id = cua.url_id + ) +""" + +from sqlalchemy import PrimaryKeyConstraint, Column, Boolean + +from src.db.models.mixins import ViewMixin, URLDependentMixin +from src.db.models.templates_.base import Base + + +class URLAnnotationFlagsView( + Base, + ViewMixin, + URLDependentMixin +): + __tablename__ = "url_annotation_flags" + __table_args__ = ( + PrimaryKeyConstraint("url_id"), + {"info": "view"} + ) + + has_auto_record_type_suggestion = Column(Boolean, nullable=False) + has_auto_relevant_suggestion = Column(Boolean, nullable=False) + has_auto_agency_suggestion = Column(Boolean, nullable=False) + has_auto_location_suggestion = Column(Boolean, nullable=False) + has_user_record_type_suggestion = Column(Boolean, nullable=False) + has_user_relevant_suggestion = Column(Boolean, nullable=False) + has_user_agency_suggestion = Column(Boolean, nullable=False) + has_user_location_suggestion = Column(Boolean, nullable=False) + has_confirmed_agency = Column(Boolean, nullable=False) + was_reviewed = Column(Boolean, nullable=False) \ No newline at end of file diff --git a/src/db/models/views/url_status/__init__.py b/src/db/models/views/url_status/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/models/views/url_status/core.py b/src/db/models/views/url_status/core.py new file mode 100644 index 00000000..77a01139 --- /dev/null +++ b/src/db/models/views/url_status/core.py @@ -0,0 +1,77 @@ +""" + CREATE MATERIALIZED VIEW url_status_mat_view AS + with + urls_with_relevant_errors as ( + select + ute.url_id + from + url_task_error ute + where + ute.task_type in ( + 'Screenshot', + 'HTML', + 'URL Probe' + ) + ) + select + u.id as url_id, + case + when ( + -- Validated as not relevant, individual record, or not found + fuv.type in ('not relevant', 'individual record', 'not found') + -- Has Meta URL in data sources app + OR udmu.url_id is not null + -- Has data source in data sources app + OR uds.url_id is not null + ) Then 'Submitted/Pipeline Complete' + when fuv.type is not null THEN 'Accepted' + when ( + -- Has compressed HTML + uch.url_id is not null + AND + -- Has web metadata + uwm.url_id is not null + AND + -- Has screenshot + us.url_id is not null + ) THEN 'Community Labeling' + when uwre.url_id is not null then 'Error' + ELSE 'Intake' + END as status + + from + urls u + left join urls_with_relevant_errors uwre + on u.id = uwre.url_id + left join url_screenshot us + on u.id = us.url_id + left join url_compressed_html uch + on u.id = uch.url_id + left join url_web_metadata uwm + on u.id = uwm.url_id + left join flag_url_validated fuv + on u.id = fuv.url_id + left join url_ds_meta_url udmu + on u.id = udmu.url_id + left join url_data_source uds + on u.id = uds.url_id +""" +from sqlalchemy import String, Column + +from src.db.models.helpers import url_id_primary_key_constraint +from src.db.models.mixins import ViewMixin, URLDependentMixin +from src.db.models.templates_.base import Base + + +class URLStatusMatView( + Base, + ViewMixin, + 
URLDependentMixin +): + __tablename__ = "url_status_mat_view" + __table_args__ = ( + url_id_primary_key_constraint(), + {"info": "view"} + ) + + status = Column(String) \ No newline at end of file diff --git a/src/db/models/views/url_status/enums.py b/src/db/models/views/url_status/enums.py new file mode 100644 index 00000000..82995812 --- /dev/null +++ b/src/db/models/views/url_status/enums.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class URLStatusViewEnum(Enum): + INTAKE = "Intake" + ACCEPTED = "Accepted" + SUBMITTED_PIPELINE_COMPLETE = "Submitted/Pipeline Complete" + ERROR = "Error" + COMMUNITY_LABELING = "Community Labeling" \ No newline at end of file diff --git a/src/db/queries/base/builder.py b/src/db/queries/base/builder.py index 5806ef47..f0ef345c 100644 --- a/src/db/queries/base/builder.py +++ b/src/db/queries/base/builder.py @@ -1,16 +1,16 @@ from typing import Any, Generic, Optional from sqlalchemy import FromClause, ColumnClause -from sqlalchemy.dialects import postgresql from sqlalchemy.ext.asyncio import AsyncSession +from src.db.helpers.session import session_helper as sh from src.db.types import LabelsType class QueryBuilderBase(Generic[LabelsType]): - def __init__(self, labels: Optional[LabelsType] = None): - self.query: Optional[FromClause] = None + def __init__(self, labels: LabelsType | None = None): + self.query: FromClause | None = None self.labels = labels def get(self, key: str) -> ColumnClause: @@ -33,9 +33,4 @@ async def run(self, session: AsyncSession) -> Any: @staticmethod def compile(query) -> Any: - return query.compile( - dialect=postgresql.dialect(), - compile_kwargs={ - "literal_binds": True - } - ) + return sh.compile_to_sql(query) diff --git a/src/db/queries/implementations/core/common/annotation_exists.py b/src/db/queries/implementations/core/common/annotation_exists.py deleted file mode 100644 index 656b56f3..00000000 --- a/src/db/queries/implementations/core/common/annotation_exists.py +++ /dev/null @@ -1,72 +0,0 @@ -""" -The annotation exists common table expression -Provides a set of boolean flags indicating whether a URL -has each kind of possible annotation -Each row should have the following columns: -- url_id -- UserRelevantSuggestion_exists -- UserRecordTypeSuggestion_exists -- UserUrlAgencySuggestion_exists -- UserAutoRelevantSuggestion_exists -- UserAutoRecordTypeSuggestion_exists -- UserAutoUrlAgencySuggestion_exists -""" - -from typing import Any, Type - -from sqlalchemy import case, func, Select, select - -from src.collectors.enums import URLStatus -from src.db.constants import ALL_ANNOTATION_MODELS -from src.db.models.instantiations.url.core import URL -from src.db.models.mixins import URLDependentMixin -from src.db.queries.base.builder import QueryBuilderBase - - -class AnnotationExistsCTEQueryBuilder(QueryBuilderBase): - - @property - def url_id(self): - return self.query.c.url_id - - def get_exists_label(self, model: Type[URLDependentMixin]): - return f"{model.__name__}_exists" - - def get_all(self) -> list[Any]: - l = [self.url_id] - for model in ALL_ANNOTATION_MODELS: - label = self.get_exists_label(model) - l.append(self.get(label)) - return l - - async def _annotation_exists_case( - self, - ): - cases = [] - for model in ALL_ANNOTATION_MODELS: - cases.append( - case( - ( - func.bool_or(model.url_id.is_not(None)), 1 - ), - else_=0 - ).label(self.get_exists_label(model)) - ) - return cases - - async def _outer_join_models(self, query: Select): - for model in ALL_ANNOTATION_MODELS: - query = query.outerjoin(model) - return query - 
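Context for the `annotation_exists` rewrite deleted here and re-added below: the builder derives one boolean-ish flag per annotation model by wrapping `bool_or` in a `case` and grouping by URL. A hedged, standalone sketch of that aggregation follows; `URL` and `UserNote` are invented stand-ins for the real annotation models.

```python
from sqlalchemy import Column, ForeignKey, Integer, case, func, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class URL(Base):
    __tablename__ = "urls"
    id = Column(Integer, primary_key=True)

class UserNote(Base):
    __tablename__ = "user_notes"
    id = Column(Integer, primary_key=True)
    url_id = Column(Integer, ForeignKey("urls.id"))

# 1 if any joined row exists for this URL, else 0 -- mirrors get_exists_label naming.
exists_flag = case(
    (func.bool_or(UserNote.url_id.is_not(None)), 1),
    else_=0,
).label("UserNote_exists")

cte = (
    select(URL.id.label("url_id"), exists_flag)
    .outerjoin(UserNote, UserNote.url_id == URL.id)
    .group_by(URL.id)
    .cte("annotations_exist")
)
print(select(cte))  # inspect the generated SQL
```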
- - async def build(self) -> Any: - annotation_exists_cases_all = await self._annotation_exists_case() - anno_exists_query = select( - URL.id.label("url_id"), - *annotation_exists_cases_all - ) - anno_exists_query = await self._outer_join_models(anno_exists_query) - anno_exists_query = anno_exists_query.where(URL.outcome == URLStatus.PENDING.value) - anno_exists_query = anno_exists_query.group_by(URL.id).cte("annotations_exist") - self.query = anno_exists_query diff --git a/src/db/queries/implementations/core/common/annotation_exists_/__init__.py b/src/db/queries/implementations/core/common/annotation_exists_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/common/annotation_exists_/constants.py b/src/db/queries/implementations/core/common/annotation_exists_/constants.py new file mode 100644 index 00000000..1237634e --- /dev/null +++ b/src/db/queries/implementations/core/common/annotation_exists_/constants.py @@ -0,0 +1,15 @@ +from src.db.models.impl.url.suggestion.agency.subtask.sqlalchemy import URLAutoAgencyIDSubtask +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.auto.sqlalchemy import AutoRelevantSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion + +ALL_ANNOTATION_MODELS = [ + AutoRecordTypeSuggestion, + AutoRelevantSuggestion, + URLAutoAgencyIDSubtask, + UserURLTypeSuggestion, + UserRecordTypeSuggestion, + UserUrlAgencySuggestion +] diff --git a/src/db/queries/implementations/core/common/annotation_exists_/core.py b/src/db/queries/implementations/core/common/annotation_exists_/core.py new file mode 100644 index 00000000..53e8bcf6 --- /dev/null +++ b/src/db/queries/implementations/core/common/annotation_exists_/core.py @@ -0,0 +1,80 @@ +""" +The annotation exists common table expression +Provides a set of boolean flags indicating whether a URL +has each kind of possible annotation +Each row should have the following columns: +- url_id +- UserRelevantSuggestion_exists +- UserRecordTypeSuggestion_exists +- UserUrlAgencySuggestion_exists +- UserAutoRelevantSuggestion_exists +- UserAutoRecordTypeSuggestion_exists +- UserAutoUrlAgencySuggestion_exists +""" + +from typing import Any, Type + +from sqlalchemy import case, func, Select, select + +from src.collectors.enums import URLStatus +from src.db.queries.implementations.core.common.annotation_exists_.constants import ALL_ANNOTATION_MODELS +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.mixins import URLDependentMixin +from src.db.queries.base.builder import QueryBuilderBase + + +class AnnotationExistsCTEQueryBuilder(QueryBuilderBase): + + @property + def url_id(self): + return self.query.c.url_id + + def get_exists_label(self, model: Type[URLDependentMixin]) -> str: + return f"{model.__name__}_exists" + + def get_all(self) -> list[Any]: + l = [self.url_id] + for model in ALL_ANNOTATION_MODELS: + label = self.get_exists_label(model) + l.append(self.get(label)) + return l + + async def _annotation_exists_case( + self, + ) -> list[Any]: + cases = [] + for model in ALL_ANNOTATION_MODELS: + cases.append( + case( + ( + func.bool_or(model.url_id.is_not(None)), 1 + ), + else_=0 + 
).label(self.get_exists_label(model)) + ) + return cases + + async def _outer_join_models(self, query: Select): + for model in ALL_ANNOTATION_MODELS: + query = query.outerjoin(model) + return query + + + async def build(self) -> Any: + annotation_exists_cases_all = await self._annotation_exists_case() + anno_exists_query = select( + URL.id.label("url_id"), + *annotation_exists_cases_all + ) + anno_exists_query = await self._outer_join_models(anno_exists_query) + anno_exists_query = anno_exists_query.outerjoin( + FlagURLValidated, + FlagURLValidated.url_id == URL.id + ) + anno_exists_query = anno_exists_query.where( + URL.status == URLStatus.OK.value, + FlagURLValidated.url_id.is_(None) + ) + anno_exists_query = anno_exists_query.group_by(URL.id).cte("annotations_exist") + self.query = anno_exists_query diff --git a/src/db/queries/implementations/core/get/html_content_info.py b/src/db/queries/implementations/core/get/html_content_info.py index fb26a527..3d2ad559 100644 --- a/src/db/queries/implementations/core/get/html_content_info.py +++ b/src/db/queries/implementations/core/get/html_content_info.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.db.dtos.url.html_content import URLHTMLContentInfo -from src.db.models.instantiations.url.html_content import URLHTMLContent +from src.db.models.impl.url.html.content.sqlalchemy import URLHTMLContent from src.db.queries.base.builder import QueryBuilderBase diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/builder.py b/src/db/queries/implementations/core/get/recent_batch_summaries/builder.py index 8ac1b4af..5de2eb55 100644 --- a/src/db/queries/implementations/core/get/recent_batch_summaries/builder.py +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/builder.py @@ -1,4 +1,3 @@ -from typing import Optional from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession @@ -7,7 +6,9 @@ from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.collectors.enums import CollectorType from src.core.enums import BatchStatus -from src.db.models.instantiations.batch import Batch +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.views.batch_url_status.core import BatchURLStatusMatView +from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum from src.db.queries.base.builder import QueryBuilderBase from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.builder import URLCountsCTEQueryBuilder from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.labels import URLCountsLabels @@ -18,15 +19,13 @@ class GetRecentBatchSummariesQueryBuilder(QueryBuilderBase): def __init__( self, page: int = 1, - has_pending_urls: Optional[bool] = None, - collector_type: Optional[CollectorType] = None, - status: Optional[BatchStatus] = None, - batch_id: Optional[int] = None, + collector_type: CollectorType | None = None, + status: BatchURLStatusEnum | None = None, + batch_id: int | None = None, ): super().__init__() self.url_counts_cte = URLCountsCTEQueryBuilder( page=page, - has_pending_urls=has_pending_urls, collector_type=collector_type, status=status, batch_id=batch_id, @@ -37,18 +36,30 @@ async def run(self, session: AsyncSession) -> list[BatchSummary]: builder = self.url_counts_cte count_labels: URLCountsLabels = builder.labels - query = Select( - *builder.get_all(), - Batch.strategy, - Batch.status, - Batch.parameters, - Batch.user_id, - Batch.compute_time, - 
Batch.date_generated, - ).join( - builder.query, - builder.get(count_labels.batch_id) == Batch.id, + query = ( + Select( + *builder.get_all(), + Batch.strategy, + Batch.status, + BatchURLStatusMatView.batch_url_status, + Batch.parameters, + Batch.user_id, + Batch.compute_time, + Batch.date_generated, + ).join( + builder.query, + builder.get(count_labels.batch_id) == Batch.id, + ).outerjoin( + BatchURLStatusMatView, + BatchURLStatusMatView.batch_id == Batch.id, + ).order_by( + Batch.id.asc() + ) + ) + + + raw_results = await session.execute(query) summaries: list[BatchSummary] = [] diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/builder.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/builder.py index 571db2a0..4921337f 100644 --- a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/builder.py +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/builder.py @@ -1,15 +1,24 @@ -from typing import Optional - from sqlalchemy import Select, case, Label, and_, exists -from sqlalchemy.sql.functions import count, coalesce +from sqlalchemy.sql.functions import count, coalesce, func from src.collectors.enums import URLStatus, CollectorType from src.core.enums import BatchStatus -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.batch import Batch +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.models.views.batch_url_status.core import BatchURLStatusMatView +from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum from src.db.queries.base.builder import QueryBuilderBase from src.db.queries.helpers import add_page_offset +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.all import ALL_CTE +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.duplicate import DUPLICATE_CTE +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.error import ERROR_CTE +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.not_relevant import NOT_RELEVANT_CTE +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.pending import PENDING_CTE +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte.submitted import SUBMITTED_CTE from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.labels import URLCountsLabels @@ -18,14 +27,12 @@ class URLCountsCTEQueryBuilder(QueryBuilderBase): def __init__( self, page: int = 1, - has_pending_urls: Optional[bool] = None, - collector_type: Optional[CollectorType] = None, - status: Optional[BatchStatus] = None, - batch_id: Optional[int] = None + collector_type: CollectorType | None = None, + status: BatchURLStatusEnum | None = None, + batch_id: int | None = None ): super().__init__(URLCountsLabels()) self.page = page - self.has_pending_urls = has_pending_urls self.collector_type = collector_type self.status = status self.batch_id = batch_id @@ -33,31 +40,35 @@ def __init__( def 
get_core_query(self): labels: URLCountsLabels = self.labels - return ( + query = ( Select( Batch.id.label(labels.batch_id), - coalesce(count(URL.id), 0).label(labels.total), - self.count_case_url_status(URLStatus.PENDING, labels.pending), - self.count_case_url_status(URLStatus.SUBMITTED, labels.submitted), - self.count_case_url_status(URLStatus.NOT_RELEVANT, labels.not_relevant), - self.count_case_url_status(URLStatus.ERROR, labels.error), - self.count_case_url_status(URLStatus.DUPLICATE, labels.duplicate), + func.coalesce(DUPLICATE_CTE.count, 0).label(labels.duplicate), + func.coalesce(SUBMITTED_CTE.count, 0).label(labels.submitted), + func.coalesce(PENDING_CTE.count, 0).label(labels.pending), + func.coalesce(ALL_CTE.count, 0).label(labels.total), + func.coalesce(NOT_RELEVANT_CTE.count, 0).label(labels.not_relevant), + func.coalesce(ERROR_CTE.count, 0).label(labels.error), ) .select_from(Batch) - .outerjoin(LinkBatchURL) - .outerjoin( - URL + .join( + BatchURLStatusMatView, + BatchURLStatusMatView.batch_id == Batch.id, ) ) + for cte in [DUPLICATE_CTE, SUBMITTED_CTE, PENDING_CTE, ALL_CTE, NOT_RELEVANT_CTE, ERROR_CTE]: + query = query.outerjoin( + cte.cte, + Batch.id == cte.batch_id + ) + return query def build(self): query = self.get_core_query() - query = self.apply_pending_urls_filter(query) query = self.apply_collector_type_filter(query) query = self.apply_status_filter(query) query = self.apply_batch_id_filter(query) - query = query.group_by(Batch.id) query = add_page_offset(query, page=self.page) query = query.order_by(Batch.id) self.query = query.cte("url_counts") @@ -67,23 +78,6 @@ def apply_batch_id_filter(self, query: Select): return query return query.where(Batch.id == self.batch_id) - def apply_pending_urls_filter(self, query: Select): - if self.has_pending_urls is None: - return query - pending_url_subquery = ( - exists( - Select(URL).join(LinkBatchURL).where( - and_( - LinkBatchURL.batch_id == Batch.id, - URL.outcome == URLStatus.PENDING.value - ) - ) - ) - ).correlate(Batch) - if self.has_pending_urls: - return query.where(pending_url_subquery) - return query.where(~pending_url_subquery) - def apply_collector_type_filter(self, query: Select): if self.collector_type is None: return query @@ -92,19 +86,4 @@ def apply_collector_type_filter(self, query: Select): def apply_status_filter(self, query: Select): if self.status is None: return query - return query.where(Batch.status == self.status.value) - - @staticmethod - def count_case_url_status( - url_status: URLStatus, - label: str - ) -> Label: - return ( - coalesce( - count( - case( - (URL.outcome == url_status.value, 1) - ) - ) - , 0).label(label) - ) + return query.where(BatchURLStatusMatView.batch_url_status == self.status.value) diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/__init__.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/all.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/all.py new file mode 100644 index 00000000..5cab51cf --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/all.py @@ -0,0 +1,20 @@ +from sqlalchemy import select, func + +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from 
src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +ALL_CTE = URLCountsCTEContainer( + select( + Batch.id, + func.count(LinkBatchURL.url_id).label("total_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .group_by( + Batch.id + ).cte("total_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/duplicate.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/duplicate.py new file mode 100644 index 00000000..906dd49c --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/duplicate.py @@ -0,0 +1,29 @@ +from sqlalchemy import select, func + +from src.collectors.enums import URLStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +DUPLICATE_CTE = URLCountsCTEContainer( + select( + Batch.id, + func.count(URL.id).label("duplicate_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .join( + URL, + URL.id == LinkBatchURL.url_id, + ) + .where( + URL.status == URLStatus.DUPLICATE + ) + .group_by( + Batch.id + ).cte("duplicate_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/error.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/error.py new file mode 100644 index 00000000..b74020c4 --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/error.py @@ -0,0 +1,29 @@ +from sqlalchemy import select, func + +from src.collectors.enums import URLStatus +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +ERROR_CTE = URLCountsCTEContainer( + select( + Batch.id, + func.count(URL.id).label("error_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .join( + URL, + URL.id == LinkBatchURL.url_id, + ) + .where( + URL.status == URLStatus.ERROR + ) + .group_by( + Batch.id + ).cte("error_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/not_relevant.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/not_relevant.py new file mode 100644 index 00000000..3fba94ee --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/not_relevant.py @@ -0,0 +1,34 @@ +from sqlalchemy import select, func + +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +NOT_RELEVANT_CTE = URLCountsCTEContainer( + select( + Batch.id, + 
func.count(URL.id).label("not_relevant_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .join( + URL, + URL.id == LinkBatchURL.url_id, + ) + .join( + FlagURLValidated, + FlagURLValidated.url_id == URL.id, + ) + .where( + FlagURLValidated.type == URLType.NOT_RELEVANT + ) + .group_by( + Batch.id + ).cte("not_relevant_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/pending.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/pending.py new file mode 100644 index 00000000..b7e4594c --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/pending.py @@ -0,0 +1,33 @@ +from sqlalchemy import select, func + +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +PENDING_CTE = URLCountsCTEContainer( + select( + Batch.id, + func.count(URL.id).label("pending_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .join( + URL, + URL.id == LinkBatchURL.url_id, + ) + .outerjoin( + FlagURLValidated, + FlagURLValidated.url_id == URL.id, + ) + .where( + FlagURLValidated.type.is_(None) + ) + .group_by( + Batch.id + ).cte("pending_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/submitted.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/submitted.py new file mode 100644 index 00000000..5ab305cc --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte/submitted.py @@ -0,0 +1,32 @@ + + +from sqlalchemy import select, func + +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.queries.implementations.core.get.recent_batch_summaries.url_counts.cte_container import \ + URLCountsCTEContainer + +SUBMITTED_CTE = URLCountsCTEContainer( + select( + Batch.id, + func.count(URL.id).label("submitted_count") + ) + .join( + LinkBatchURL, + LinkBatchURL.batch_id == Batch.id, + ) + .join( + URL, + URL.id == LinkBatchURL.url_id, + ) + .join( + URLDataSource, + URLDataSource.url_id == URL.id, + ) + .group_by( + Batch.id + ).cte("submitted_count") +) \ No newline at end of file diff --git a/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte_container.py b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte_container.py new file mode 100644 index 00000000..7f769c76 --- /dev/null +++ b/src/db/queries/implementations/core/get/recent_batch_summaries/url_counts/cte_container.py @@ -0,0 +1,18 @@ +from sqlalchemy import CTE, Column + + +class URLCountsCTEContainer: + + def __init__( + self, + cte: CTE + ): + self.cte = cte + + @property + def batch_id(self) -> Column: + return self.cte.columns[0] + + @property + def count(self) -> Column: + return self.cte.columns[1] diff --git a/src/db/queries/implementations/core/metrics/urls/aggregated/pending.py 
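Aside on the new per-status CTE files above (`all.py`, `duplicate.py`, `error.py`, `not_relevant.py`, `pending.py`, `submitted.py`): each wraps a `(batch_id, count)` select in a `URLCountsCTEContainer`, which the core query outer-joins and coalesces to 0 for batches with no matching URLs. A reduced sketch of that composition, with invented `Batch`/`LinkBatchURL` stand-ins:

```python
from sqlalchemy import Column, ForeignKey, Integer, func, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Batch(Base):
    __tablename__ = "batches"
    id = Column(Integer, primary_key=True)

class LinkBatchURL(Base):
    __tablename__ = "link_batch_urls"
    id = Column(Integer, primary_key=True)
    batch_id = Column(Integer, ForeignKey("batches.id"))
    url_id = Column(Integer)

total_cte = (
    select(Batch.id, func.count(LinkBatchURL.url_id).label("total_count"))
    .join(LinkBatchURL, LinkBatchURL.batch_id == Batch.id)
    .group_by(Batch.id)
    .cte("total_count")
)

# Positional access mirrors URLCountsCTEContainer.batch_id / .count;
# batches absent from the CTE fall back to 0.
query = (
    select(Batch.id, func.coalesce(total_cte.columns[1], 0).label("total"))
    .outerjoin(total_cte, total_cte.columns[0] == Batch.id)
)
print(query)
```

Splitting each count into its own CTE avoids the old single-pass `count(case(...))` grouping and lets filters (e.g. `FlagURLValidated`) live next to the count they affect.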
b/src/db/queries/implementations/core/metrics/urls/aggregated/pending.py index 503af6c3..17136cce 100644 --- a/src/db/queries/implementations/core/metrics/urls/aggregated/pending.py +++ b/src/db/queries/implementations/core/metrics/urls/aggregated/pending.py @@ -1,23 +1,23 @@ from typing import Any, Type -from sqlalchemy import select, func, case +from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from src.api.endpoints.metrics.dtos.get.urls.aggregated.pending import GetMetricsURLsAggregatedPendingResponseDTO from src.collectors.enums import URLStatus -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion from src.db.models.mixins import URLDependentMixin from src.db.queries.base.builder import QueryBuilderBase -from src.db.queries.implementations.core.common.annotation_exists import AnnotationExistsCTEQueryBuilder +from src.db.queries.implementations.core.common.annotation_exists_.core import AnnotationExistsCTEQueryBuilder class PendingAnnotationExistsCTEQueryBuilder(AnnotationExistsCTEQueryBuilder): @property def has_user_relevant_annotation(self): - return self.get_exists_for_model(UserRelevantSuggestion) + return self.get_exists_for_model(UserURLTypeSuggestion) @property def has_user_record_type_annotation(self): @@ -44,7 +44,7 @@ async def build(self) -> Any: URL.id == self.url_id ) .where( - URL.outcome == URLStatus.PENDING.value + URL.status == URLStatus.OK.value ).cte("pending") ) diff --git a/src/db/queries/implementations/core/tasks/agency_sync/upsert.py b/src/db/queries/implementations/core/tasks/agency_sync/upsert.py deleted file mode 100644 index cff2044b..00000000 --- a/src/db/queries/implementations/core/tasks/agency_sync/upsert.py +++ /dev/null @@ -1,19 +0,0 @@ -from src.external.pdap.dtos.agencies_sync import AgenciesSyncResponseInnerInfo - - -def get_upsert_agencies_mappings( - agencies: list[AgenciesSyncResponseInnerInfo] -) -> list[dict]: - agency_dicts = [] - for agency in agencies: - agency_dict = { - 'agency_id': agency.agency_id, - 'name': agency.display_name, - 'state': agency.state_name, - 'county': agency.county_name, - 'locality': agency.locality_name, - 'ds_last_updated_at': agency.updated_at - } - agency_dicts.append(agency_dict) - - return agency_dicts \ No newline at end of file diff --git a/src/db/queries/implementations/location/__init__.py b/src/db/queries/implementations/location/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/location/get.py b/src/db/queries/implementations/location/get.py new file mode 100644 index 00000000..7ab3c381 --- /dev/null +++ b/src/db/queries/implementations/location/get.py @@ -0,0 +1,49 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db import Location +from src.db.helpers.session import session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class 
GetLocationQueryBuilder(QueryBuilderBase): + + def __init__( + self, + us_state_id: int, + county_id: int | None = None, + locality_id: int | None = None, + ): + super().__init__() + self.us_state_id = us_state_id + self.county_id = county_id + self.locality_id = locality_id + + async def run(self, session: AsyncSession) -> int | None: + query = ( + select( + Location.id + ) + .where( + Location.state_id == self.us_state_id, + ) + ) + if self.county_id is not None: + query = query.where( + Location.county_id == self.county_id + ) + else: + query = query.where( + Location.county_id.is_(None) + ) + + if self.locality_id is not None: + query = query.where( + Location.locality_id == self.locality_id + ) + else: + query = query.where( + Location.locality_id.is_(None) + ) + + return await sh.one_or_none(session, query=query) diff --git a/src/db/queries/protocols.py b/src/db/queries/protocols.py index 0098e953..b1a2ce20 100644 --- a/src/db/queries/protocols.py +++ b/src/db/queries/protocols.py @@ -6,4 +6,4 @@ class HasQuery(Protocol): def __init__(self): - self.query: Optional[Select] = None + self.query: Select | None = None diff --git a/src/db/statement_composer.py b/src/db/statement_composer.py index 9d5faa97..0ae843b3 100644 --- a/src/db/statement_composer.py +++ b/src/db/statement_composer.py @@ -1,21 +1,22 @@ +from http import HTTPStatus from typing import Any from sqlalchemy import Select, select, exists, func, Subquery, and_, not_, ColumnElement -from sqlalchemy.orm import aliased, selectinload +from sqlalchemy.orm import selectinload from src.collectors.enums import URLStatus from src.core.enums import BatchStatus from src.db.constants import STANDARD_ROW_LIMIT from src.db.enums import TaskType -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.link.link_task_url import LinkTaskURL -from src.db.models.instantiations.task.core import Task -from src.db.models.instantiations.url.html_content import URLHTMLContent -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.batch import Batch -from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion +from src.db.models.impl.batch.sqlalchemy import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.link.task_url import LinkTaskURL +from src.db.models.impl.task.core import Task +from src.db.models.impl.task.enums import TaskStatus +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.scrape_info.sqlalchemy import URLScrapeInfo +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata from src.db.types import UserSuggestionType @@ -25,21 +26,25 @@ class StatementComposer: """ @staticmethod - def pending_urls_without_html_data() -> Select: + def has_non_errored_urls_without_html_data() -> Select: exclude_subquery = ( select(1). select_from(LinkTaskURL). join(Task, LinkTaskURL.task_id == Task.id). where(LinkTaskURL.url_id == URL.id). where(Task.task_type == TaskType.HTML.value). - where(Task.task_status == BatchStatus.READY_TO_LABEL.value) + where(Task.task_status == TaskStatus.COMPLETE.value) ) query = ( - select(URL). 
- outerjoin(URLHTMLContent). - where(URLHTMLContent.id == None). - where(~exists(exclude_subquery)). - where(URL.outcome == URLStatus.PENDING.value) + select(URL) + .join(URLWebMetadata) + .outerjoin(URLScrapeInfo) + .where( + URLScrapeInfo.id == None, + ~exists(exclude_subquery), + URLWebMetadata.status_code == HTTPStatus.OK.value, + URLWebMetadata.content_type.like("%html%"), + ) .options( selectinload(URL.batch) ) @@ -68,31 +73,14 @@ def simple_count_subquery(model, attribute: str, label: str) -> Subquery: func.count(attr_value).label(label) ).group_by(attr_value).subquery() - @staticmethod - def exclude_urls_with_agency_suggestions( - statement: Select - ): - # Aliases for clarity - AutomatedSuggestion = aliased(AutomatedUrlAgencySuggestion) - - # Exclude if automated suggestions exist - statement = statement.where( - ~exists().where(AutomatedSuggestion.url_id == URL.id) - ) - # Exclude if confirmed agencies exist - statement = statement.where( - ~exists().where(ConfirmedURLAgency.url_id == URL.id) - ) - return statement - @staticmethod def pending_urls_missing_miscellaneous_metadata_query() -> Select: query = select(URL).where( and_( - URL.outcome == URLStatus.PENDING.value, - URL.name == None, - URL.description == None, - URLOptionalDataSourceMetadata.url_id == None + URL.status == URLStatus.OK.value, + URL.name == None, + URL.description == None, + URLOptionalDataSourceMetadata.url_id == None ) ).outerjoin( URLOptionalDataSourceMetadata @@ -128,17 +116,3 @@ def user_suggestion_not_exists( @staticmethod def count_distinct(field, label): return func.count(func.distinct(field)).label(label) - - @staticmethod - def sum_distinct(field, label): - return func.sum(func.distinct(field)).label(label) - - @staticmethod - def add_limit_and_page_offset(query: Select, page: int): - zero_offset_page = page - 1 - rows_offset = zero_offset_page * STANDARD_ROW_LIMIT - return query.offset( - rows_offset - ).limit( - STANDARD_ROW_LIMIT - ) diff --git a/src/db/templates/__init__.py b/src/db/templates/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/templates/markers/__init__.py b/src/db/templates/markers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/templates/markers/bulk/__init__.py b/src/db/templates/markers/bulk/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/templates/markers/bulk/delete.py b/src/db/templates/markers/bulk/delete.py new file mode 100644 index 00000000..9da0c980 --- /dev/null +++ b/src/db/templates/markers/bulk/delete.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class BulkDeletableModel(BaseModel): + """Identifies a model that can be used for the bulk_delete function in session_helper.""" + diff --git a/src/db/templates/markers/bulk/insert.py b/src/db/templates/markers/bulk/insert.py new file mode 100644 index 00000000..d147e44f --- /dev/null +++ b/src/db/templates/markers/bulk/insert.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class BulkInsertableModel(BaseModel): + """Identifies a model that can be used for the bulk_insert function in session_helper.""" diff --git a/src/db/templates/markers/bulk/update.py b/src/db/templates/markers/bulk/update.py new file mode 100644 index 00000000..d0476135 --- /dev/null +++ b/src/db/templates/markers/bulk/update.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class BulkUpdatableModel(BaseModel): + """Identifies a model that can be used for the bulk_update function in session_helper.""" diff --git 
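On the new `src/db/templates/markers/bulk/*` files in this region: they are empty Pydantic marker classes that the session helper can use for capability checks. A self-contained sketch of how such a marker might gate a helper; the `URLRow` model and the `bulk_insert` stub are invented for illustration.

```python
from pydantic import BaseModel

class BulkInsertableModel(BaseModel):
    """Identifies a model usable with a bulk_insert helper."""

class URLRow(BulkInsertableModel):
    url: str
    batch_id: int

def bulk_insert(rows: list[BaseModel]) -> None:
    # Reject anything not explicitly marked as bulk-insertable.
    if not all(isinstance(r, BulkInsertableModel) for r in rows):
        raise TypeError("bulk_insert requires BulkInsertableModel instances")
    print(f"would insert {len(rows)} rows")

bulk_insert([URLRow(url="https://example.com", batch_id=1)])
```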
a/src/db/templates/markers/bulk/upsert.py b/src/db/templates/markers/bulk/upsert.py new file mode 100644 index 00000000..86d683bb --- /dev/null +++ b/src/db/templates/markers/bulk/upsert.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class BulkUpsertableModel(BaseModel): + """Identifies a model that can be used for the bulk_upsert function in session_helper.""" \ No newline at end of file diff --git a/src/db/templates/protocols/__init__.py b/src/db/templates/protocols/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/templates/protocols/has_id.py b/src/db/templates/protocols/has_id.py new file mode 100644 index 00000000..fc3519a2 --- /dev/null +++ b/src/db/templates/protocols/has_id.py @@ -0,0 +1,6 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class HasIDProtocol(Protocol): + id: int \ No newline at end of file diff --git a/src/db/templates/protocols/sa_correlated/__init__.py b/src/db/templates/protocols/sa_correlated/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/templates/protocols/sa_correlated/core.py b/src/db/templates/protocols/sa_correlated/core.py new file mode 100644 index 00000000..82475e60 --- /dev/null +++ b/src/db/templates/protocols/sa_correlated/core.py @@ -0,0 +1,15 @@ +from abc import abstractmethod +from typing import Protocol, runtime_checkable + +from src.db.models.templates_.base import Base + + +@runtime_checkable +class SQLAlchemyCorrelatedProtocol(Protocol): + + + @classmethod + @abstractmethod + def sa_model(cls) -> type[Base]: + """Defines the SQLAlchemy model.""" + pass diff --git a/src/db/templates/protocols/sa_correlated/with_id.py b/src/db/templates/protocols/sa_correlated/with_id.py new file mode 100644 index 00000000..7e920e76 --- /dev/null +++ b/src/db/templates/protocols/sa_correlated/with_id.py @@ -0,0 +1,20 @@ +from abc import abstractmethod +from typing import Protocol, runtime_checkable + +from src.db.models.templates_.base import Base + + +@runtime_checkable +class SQLAlchemyCorrelatedWithIDProtocol(Protocol): + + @classmethod + @abstractmethod + def id_field(cls) -> str: + """Defines the field to be used as the primary key.""" + return "id" + + @classmethod + @abstractmethod + def sa_model(cls) -> type[Base]: + """Defines the correlated SQLAlchemy model.""" + pass diff --git a/src/db/templates/requester.py b/src/db/templates/requester.py new file mode 100644 index 00000000..b56af87f --- /dev/null +++ b/src/db/templates/requester.py @@ -0,0 +1,20 @@ +""" +A requester is a class that contains a session and provides methods for +performing database operations. 
+""" +from abc import ABC + +from sqlalchemy.ext.asyncio import AsyncSession + +import src.db.helpers.session.session_helper as sh +from src.db.queries.base.builder import QueryBuilderBase + + +class RequesterBase(ABC): + + def __init__(self, session: AsyncSession): + self.session = session + self.session_helper = sh + + async def run_query_builder(self, query_builder: QueryBuilderBase): + return await query_builder.run(session=self.session) \ No newline at end of file diff --git a/src/db/types.py b/src/db/types.py index dadef2f1..dcee196f 100644 --- a/src/db/types.py +++ b/src/db/types.py @@ -1,10 +1,10 @@ from typing import TypeVar -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion from src.db.queries.base.labels import LabelsBase -UserSuggestionType = UserUrlAgencySuggestion | UserRelevantSuggestion | UserRecordTypeSuggestion +UserSuggestionType = UserUrlAgencySuggestion | UserURLTypeSuggestion | UserRecordTypeSuggestion LabelsType = TypeVar("LabelsType", bound=LabelsBase) \ No newline at end of file diff --git a/src/db/utils/validate.py b/src/db/utils/validate.py new file mode 100644 index 00000000..4837e12c --- /dev/null +++ b/src/db/utils/validate.py @@ -0,0 +1,27 @@ +from typing import Protocol +from urllib.parse import urlparse + +from pydantic import BaseModel + + +def validate_has_protocol(obj: object, protocol: type[Protocol]): + if not isinstance(obj, protocol): + raise TypeError(f"Class must implement {protocol} protocol.") + +def validate_all_models_of_same_type(objects: list[object]): + first_model = objects[0] + if not all(isinstance(model, type(first_model)) for model in objects): + raise TypeError("Models must be of the same type") + +def is_valid_url(url: str) -> bool: + try: + result = urlparse(url) + # If scheme is missing, `netloc` will be empty, so we check path too + if result.scheme in ("http", "https") and result.netloc: + return True + if not result.scheme and result.path: + # no scheme, treat path as potential domain + return "." 
in result.path + return False + except ValueError: + return False diff --git a/src/external/huggingface/hub/__init__.py b/src/external/huggingface/hub/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/external/huggingface/hub/client.py b/src/external/huggingface/hub/client.py new file mode 100644 index 00000000..3ca53ceb --- /dev/null +++ b/src/external/huggingface/hub/client.py @@ -0,0 +1,49 @@ + +from datasets import Dataset +from huggingface_hub import HfApi + +from src.external.huggingface.hub.constants import DATA_SOURCES_RAW_REPO_ID +from src.external.huggingface.hub.format import format_as_huggingface_dataset +from src.core.tasks.scheduled.impl.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput + + +class HuggingFaceHubClient: + + def __init__(self, token: str): + self.token = token + self.api = HfApi(token=token) + + def _push_dataset_to_hub( + self, + repo_id: str, + dataset: Dataset, + idx: int + ) -> None: + """ + Modifies: + - repository on Hugging Face, identified by `repo_id` + """ + dataset.to_parquet(f"part_{idx}.parquet") + self.api.upload_file( + path_or_fileobj=f"part_{idx}.parquet", + path_in_repo=f"data/part_{idx}.parquet", + repo_id=repo_id, + repo_type="dataset", + ) + + def push_data_sources_raw_to_hub( + self, + outputs: list[GetForLoadingToHuggingFaceOutput], + idx: int + ) -> None: + """ + Modifies: + - repository on Hugging Face, identified by `DATA_SOURCES_RAW_REPO_ID` + """ + dataset = format_as_huggingface_dataset(outputs) + print(dataset) + self._push_dataset_to_hub( + repo_id=DATA_SOURCES_RAW_REPO_ID, + dataset=dataset, + idx=idx + ) \ No newline at end of file diff --git a/src/external/huggingface/hub/constants.py b/src/external/huggingface/hub/constants.py new file mode 100644 index 00000000..2cffa4f8 --- /dev/null +++ b/src/external/huggingface/hub/constants.py @@ -0,0 +1,3 @@ + + +DATA_SOURCES_RAW_REPO_ID = "PDAP/data_sources_raw" \ No newline at end of file diff --git a/src/external/huggingface/hub/format.py b/src/external/huggingface/hub/format.py new file mode 100644 index 00000000..e1eb32b6 --- /dev/null +++ b/src/external/huggingface/hub/format.py @@ -0,0 +1,23 @@ +from datasets import Dataset + +from src.core.tasks.scheduled.impl.huggingface.queries.get.model import GetForLoadingToHuggingFaceOutput + + +def format_as_huggingface_dataset(outputs: list[GetForLoadingToHuggingFaceOutput]) -> Dataset: + d = { + 'url_id': [], + 'url': [], + 'relevant': [], + 'record_type_fine': [], + 'record_type_coarse': [], + 'html': [] + } + for output in outputs: + d['url_id'].append(output.url_id) + d['url'].append(output.url) + d['relevant'].append(output.relevant) + d['record_type_fine'].append(output.record_type_fine.value) + d['record_type_coarse'].append(output.record_type_coarse.value) + d['html'].append(output.html) + return Dataset.from_dict(d) + diff --git a/src/external/internet_archives/__init__.py b/src/external/internet_archives/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/external/internet_archives/client.py b/src/external/internet_archives/client.py new file mode 100644 index 00000000..de09eb5b --- /dev/null +++ b/src/external/internet_archives/client.py @@ -0,0 +1,110 @@ +import asyncio +from asyncio import Semaphore + +from aiolimiter import AsyncLimiter +from aiohttp import ClientSession + +from src.external.internet_archives.convert import convert_capture_to_archive_metadata +from src.external.internet_archives.models.capture import IACapture +from 
+limiter = AsyncLimiter(
+    max_rate=50,
+    time_period=50
+)
+sem = Semaphore(10)
+
+
+class InternetArchivesClient:
+
+    def __init__(
+        self,
+        session: ClientSession
+    ):
+        self.session = session
+
+        env = Env()
+        env.read_env()
+
+        self.s3_keys = env.str("INTERNET_ARCHIVE_S3_KEYS")
+
+    async def _get_url_snapshot(self, url: str) -> IACapture | None:
+        params = {
+            "url": url,
+            "output": "json",
+            "limit": "1",
+            "gzip": "false",
+            "filter": "statuscode:200",
+            "fl": "timestamp,original,length,digest"
+        }
+        async with sem:
+            async with limiter:
+                async with self.session.get(
+                    "http://web.archive.org/cdx/search/cdx",
+                    params=params
+                ) as response:
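+                    # The CDX API (output=json) returns an array whose first
+                    # row is the field names and whose later rows are captures,
+                    # e.g. [["timestamp","original","length","digest"],
+                    #       ["20240101000000","http://example.com/","1234","ABC..."]]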
+                    raw_data = await response.json()
+                    if len(raw_data) == 0:
+                        return None
+                    fields = raw_data[0]
+                    values = raw_data[1]
+                    d = dict(zip(fields, values))
+
+        return IACapture(**d)
+
+    async def search_for_url_snapshot(self, url: str) -> InternetArchivesURLMapping:
+        try:
+            capture: IACapture | None = await self._get_url_snapshot(url)
+        except Exception as e:
+            return InternetArchivesURLMapping(
+                url=url,
+                ia_metadata=None,
+                error=f"{e.__class__.__name__}: {e}"
+            )
+
+        if capture is None:
+            return InternetArchivesURLMapping(
+                url=url,
+                ia_metadata=None,
+                error=None
+            )
+
+        metadata = convert_capture_to_archive_metadata(capture)
+        return InternetArchivesURLMapping(
+            url=url,
+            ia_metadata=metadata,
+            error=None
+        )
+
+    async def _save_url(self, url: str) -> int:
+        async with self.session.post(
+            "http://web.archive.org/save",
+            data={
+                "url": url,
+                "skip_first_archive": 1
+            },
+            headers={
+                "Authorization": f"LOW {self.s3_keys}",
+                "Accept": "application/json"
+            }
+        ) as response:
+            response.raise_for_status()
+            return response.status
+
+    async def save_to_internet_archives(self, url: str) -> InternetArchivesSaveResponseInfo:
+        try:
+            _: int = await self._save_url(url)
+        except Exception as e:
+            return InternetArchivesSaveResponseInfo(
+                url=url,
+                error=f"{e.__class__.__name__}: {e}"
+            )
+
+        return InternetArchivesSaveResponseInfo(
+            url=url,
+            error=None
+        )
diff --git a/src/external/internet_archives/constants.py b/src/external/internet_archives/constants.py
new file mode 100644
index 00000000..9ddc48bf
--- /dev/null
+++ b/src/external/internet_archives/constants.py
@@ -0,0 +1,3 @@
+
+
+MAX_CONCURRENT_REQUESTS = 10
\ No newline at end of file
diff --git a/src/external/internet_archives/convert.py b/src/external/internet_archives/convert.py
new file mode 100644
index 00000000..df7079ab
--- /dev/null
+++ b/src/external/internet_archives/convert.py
@@ -0,0 +1,11 @@
+from src.external.internet_archives.models.archive_metadata import IAArchiveMetadata
+from src.external.internet_archives.models.capture import IACapture
+
+
+def convert_capture_to_archive_metadata(capture: IACapture) -> IAArchiveMetadata:
+    archive_url = f"https://web.archive.org/web/{capture.timestamp}/{capture.original}"
+    return IAArchiveMetadata(
+        archive_url=archive_url,
+        length=capture.length,
+        digest=capture.digest
+    )
\ No newline at end of file
diff --git a/src/external/internet_archives/models/__init__.py b/src/external/internet_archives/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/internet_archives/models/archive_metadata.py b/src/external/internet_archives/models/archive_metadata.py
new file mode 100644
index 00000000..2093377c
--- /dev/null
+++ b/src/external/internet_archives/models/archive_metadata.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class IAArchiveMetadata(BaseModel):
+    archive_url: str
+    length: int
+    digest: str
\ No newline at end of file
diff --git a/src/external/internet_archives/models/capture.py b/src/external/internet_archives/models/capture.py
new file mode 100644
index 00000000..839c8ed0
--- /dev/null
+++ b/src/external/internet_archives/models/capture.py
@@ -0,0 +1,8 @@
+from pydantic import BaseModel
+
+
+class IACapture(BaseModel):
+    timestamp: int
+    original: str
+    length: int
+    digest: str
\ No newline at end of file
diff --git a/src/external/internet_archives/models/ia_url_mapping.py b/src/external/internet_archives/models/ia_url_mapping.py
new file mode 100644
index 00000000..21650b0c
--- /dev/null
+++ b/src/external/internet_archives/models/ia_url_mapping.py
@@ -0,0 +1,17 @@
+from pydantic import BaseModel
+
+from src.external.internet_archives.models.archive_metadata import IAArchiveMetadata
+
+
+class InternetArchivesURLMapping(BaseModel):
+    url: str
+    ia_metadata: IAArchiveMetadata | None
+    error: str | None
+
+    @property
+    def has_error(self) -> bool:
+        return self.error is not None
+
+    @property
+    def has_metadata(self) -> bool:
+        return self.ia_metadata is not None
diff --git a/src/external/internet_archives/models/save_response.py b/src/external/internet_archives/models/save_response.py
new file mode 100644
index 00000000..031c0403
--- /dev/null
+++ b/src/external/internet_archives/models/save_response.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+
+class InternetArchivesSaveResponseInfo(BaseModel):
+    url: str
+    error: str | None = None
+
+    @property
+    def has_error(self) -> bool:
+        return self.error is not None
\ No newline at end of file
diff --git a/src/external/pdap/client.py b/src/external/pdap/client.py
index 126e7970..1c950ad3 100644
--- a/src/external/pdap/client.py
+++ b/src/external/pdap/client.py
@@ -1,14 +1,14 @@
-from typing import Optional
+from typing import Any
 
-from pdap_access_manager import AccessManager, DataSourcesNamespaces, RequestInfo, RequestType
+from pdap_access_manager import AccessManager, DataSourcesNamespaces, RequestInfo, RequestType, ResponseInfo
 
-from src.core.tasks.scheduled.operators.agency_sync.dtos.parameters import AgencySyncParameters
-from src.core.tasks.url.operators.submit_approved_url.tdo import SubmitApprovedURLTDO, SubmittedURLInfo
-from src.external.pdap.dtos.agencies_sync import AgenciesSyncResponseInnerInfo, AgenciesSyncResponseInfo
+from src.core.tasks.url.operators.submit_approved.tdo import SubmitApprovedURLTDO, SubmittedURLInfo
 from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo
 from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse
 from src.external.pdap.dtos.unique_url_duplicate import UniqueURLDuplicateInfo
 from src.external.pdap.enums import MatchAgencyResponseStatus
+from src.external.pdap.impl.meta_urls.core import submit_meta_urls
+from src.external.pdap.impl.meta_urls.request import SubmitMetaURLsRequest
 
 
 class PDAPClient:
@@ -22,20 +22,20 @@ def __init__(
     async def match_agency(
         self,
         name: str,
-        state: Optional[str] = None,
-        county: Optional[str] = None,
-        locality: Optional[str] = None
+        state: str | None = None,
+        county: str | None = None,
+        locality: str | None = None
     ) -> MatchAgencyResponse:
         """
         Returns agencies, if any, that match or partially match the search criteria
         """
-        url = self.access_manager.build_url(
+        url: str = self.access_manager.build_url(
             namespace=DataSourcesNamespaces.MATCH,
             subdomains=["agency"]
         )
-        headers = await self.access_manager.jwt_header()
-        headers['Content-Type'] = "application/json"
+        headers: dict[str, str] = await self.access_manager.jwt_header()
+        headers['Content-Type'] = "application/json"
         request_info = RequestInfo(
             type_=RequestType.POST,
             url=url,
@@ -47,15 +47,15 @@ async def match_agency(
                 "locality": locality
             }
         )
-        response_info = await self.access_manager.make_request(request_info)
-        matches = []
+        response_info: ResponseInfo = await self.access_manager.make_request(request_info)
+        matches: list[MatchAgencyInfo] = []
         for agency in response_info.data["agencies"]:
             mai = MatchAgencyInfo(
                 id=agency['id'],
                 submitted_name=agency['name']
             )
             if len(agency['locations']) > 0:
-                first_location = agency['locations'][0]
+                first_location: dict[str, Any] = agency['locations'][0]
                 mai.state = first_location['state']
                 mai.county = first_location['county']
                 mai.locality = first_location['locality']
@@ -73,7 +73,7 @@ async def is_url_duplicate(
         """
         Check if a URL is unique. Returns duplicate info otherwise
         """
-        url = self.access_manager.build_url(
+        url: str = self.access_manager.build_url(
            namespace=DataSourcesNamespaces.CHECK,
            subdomains=["unique-url"]
         )
@@ -84,12 +84,14 @@ async def is_url_duplicate(
                 "url": url_to_check
             }
         )
-        response_info = await self.access_manager.make_request(request_info)
-        duplicates = [UniqueURLDuplicateInfo(**entry) for entry in response_info.data["duplicates"]]
-        is_duplicate = (len(duplicates) != 0)
+        response_info: ResponseInfo = await self.access_manager.make_request(request_info)
+        duplicates: list[UniqueURLDuplicateInfo] = [
+            UniqueURLDuplicateInfo(**entry) for entry in response_info.data["duplicates"]
+        ]
+        is_duplicate: bool = (len(duplicates) != 0)
         return is_duplicate
 
-    async def submit_urls(
+    async def submit_data_source_urls(
         self,
         tdos: list[SubmitApprovedURLTDO]
     ) -> list[SubmittedURLInfo]:
@@ -103,11 +105,11 @@
         )
 
         # Build url-id dictionary
-        url_id_dict = {}
+        url_id_dict: dict[str, int] = {}
         for tdo in tdos:
             url_id_dict[tdo.url] = tdo.url_id
 
-        data_sources_json = []
+        data_sources_json: list[dict[str, Any]] = []
         for tdo in tdos:
             data_sources_json.append(
                 {
@@ -123,7 +125,7 @@
                 }
             )
 
-        headers = await self.access_manager.jwt_header()
+        headers: dict[str, str] = await self.access_manager.jwt_header()
         request_info = RequestInfo(
             type_=RequestType.POST,
             url=request_url,
@@ -132,12 +134,12 @@
                 "data_sources": data_sources_json
             }
         )
-        response_info = await self.access_manager.make_request(request_info)
-        data_sources_response_json = response_info.data["data_sources"]
+        response_info: ResponseInfo = await self.access_manager.make_request(request_info)
+        data_sources_response_json: list[dict[str, Any]] = response_info.data["data_sources"]
 
-        results = []
+        results: list[SubmittedURLInfo] = []
         for data_source in data_sources_response_json:
-            url = data_source["url"]
+            url: str = data_source["url"]
             response_object = SubmittedURLInfo(
                 url_id=url_id_dict[url],
                 data_source_id=data_source["data_source_id"],
@@ -147,32 +149,11 @@
 
         return results
 
-    async def sync_agencies(
+    async def submit_meta_urls(
         self,
-        params: AgencySyncParameters
-    ) -> AgenciesSyncResponseInfo:
-        url =self.access_manager.build_url(
-            namespace=DataSourcesNamespaces.SOURCE_COLLECTOR,
-            subdomains=[
-                "agencies",
-                "sync"
-            ]
-        )
-        headers = await self.access_manager.jwt_header()
-        headers['Content-Type'] = "application/json"
-        request_info = RequestInfo(
-            type_=RequestType.GET,
-            url=url,
-            headers=headers,
-            params={
-                "page": params.page,
-                "update_at": params.cutoff_date
-            }
-        )
-        response_info = await self.access_manager.make_request(request_info)
-        return AgenciesSyncResponseInfo(
-            agencies=[
-                AgenciesSyncResponseInnerInfo(**entry)
-                for entry in response_info.data["agencies"]
-            ]
+        requests: list[SubmitMetaURLsRequest]
+    ):
+        return await submit_meta_urls(
+            self.access_manager,
+            requests=requests
         )
\ No newline at end of file
diff --git a/src/external/pdap/dtos/agencies_sync.py b/src/external/pdap/dtos/agencies_sync.py
deleted file mode 100644
index 7f2b5ad0..00000000
--- a/src/external/pdap/dtos/agencies_sync.py
+++ /dev/null
@@ -1,15 +0,0 @@
-import datetime
-from typing import Optional
-
-from pydantic import BaseModel
-
-class AgenciesSyncResponseInnerInfo(BaseModel):
-    display_name: str
-    agency_id: int
-    state_name: Optional[str]
-    county_name: Optional[str]
-    locality_name: Optional[str]
-    updated_at: datetime.datetime
-
-class AgenciesSyncResponseInfo(BaseModel):
-    agencies: list[AgenciesSyncResponseInnerInfo]
diff --git a/src/external/pdap/dtos/match_agency/post.py b/src/external/pdap/dtos/match_agency/post.py
index 14870796..2be0b90e 100644
--- a/src/external/pdap/dtos/match_agency/post.py
+++ b/src/external/pdap/dtos/match_agency/post.py
@@ -6,6 +6,6 @@ class MatchAgencyInfo(BaseModel):
     id: int
     submitted_name: str
-    state: Optional[str] = None
-    county: Optional[str] = None
-    locality: Optional[str] = None
+    state: str | None = None
+    county: str | None = None
+    locality: str | None = None
diff --git a/src/external/pdap/dtos/search_agency_by_location/__init__.py b/src/external/pdap/dtos/search_agency_by_location/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/pdap/dtos/search_agency_by_location/params.py b/src/external/pdap/dtos/search_agency_by_location/params.py
new file mode 100644
index 00000000..96ebd2fa
--- /dev/null
+++ b/src/external/pdap/dtos/search_agency_by_location/params.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel, Field
+
+
+class SearchAgencyByLocationParams(BaseModel):
+    request_id: int
+    query: str
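+    # Two-letter US state code, e.g. "PA".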
+    iso: str = Field(
+        description="US State ISO Code",
+        max_length=2,
+    )
\ No newline at end of file
diff --git a/src/external/pdap/dtos/search_agency_by_location/response.py b/src/external/pdap/dtos/search_agency_by_location/response.py
new file mode 100644
index 00000000..92242b5a
--- /dev/null
+++ b/src/external/pdap/dtos/search_agency_by_location/response.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel, Field
+
+class SearchAgencyByLocationAgencyInfo(BaseModel):
+    agency_id: int
+    similarity: float = Field(ge=0, le=1)
+
+class SearchAgencyByLocationResponse(BaseModel):
+    request_id: int
+    results: list[SearchAgencyByLocationAgencyInfo] = Field(min_length=1)
+
+class SearchAgencyByLocationOuterResponse(BaseModel):
+    responses: list[SearchAgencyByLocationResponse]
\ No newline at end of file
diff --git a/src/external/pdap/dtos/unique_url_duplicate.py b/src/external/pdap/dtos/unique_url_duplicate.py
index 096622fe..51e327f1 100644
--- a/src/external/pdap/dtos/unique_url_duplicate.py
+++ b/src/external/pdap/dtos/unique_url_duplicate.py
@@ -8,4 +8,4 @@ class UniqueURLDuplicateInfo(BaseModel):
     original_url: str
     approval_status: ApprovalStatus
-    rejection_note: Optional[str] = None
+    rejection_note: str | None = None
diff --git a/src/external/pdap/enums.py b/src/external/pdap/enums.py
index 36111acd..c532f820 100644
--- a/src/external/pdap/enums.py
+++ b/src/external/pdap/enums.py
@@ -12,3 +12,9 @@ class ApprovalStatus(Enum):
     REJECTED = "rejected"
     PENDING = "pending"
     NEEDS_IDENTIFICATION = "needs identification"
+
+class DataSourcesURLStatus(Enum):
+    AVAILABLE = "available"
+    BROKEN = "broken"
+    OK = "ok"
+    NONE_FOUND = "none found"
\ No newline at end of file
diff --git a/src/external/pdap/impl/__init__.py b/src/external/pdap/impl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/pdap/impl/meta_urls/__init__.py b/src/external/pdap/impl/meta_urls/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/pdap/impl/meta_urls/core.py b/src/external/pdap/impl/meta_urls/core.py
new file mode 100644
index 00000000..4a34fbeb
--- /dev/null
+++ b/src/external/pdap/impl/meta_urls/core.py
@@ -0,0 +1,52 @@
+from typing import Any
+
+from pdap_access_manager import AccessManager, DataSourcesNamespaces, RequestInfo, RequestType, ResponseInfo
+
+from src.external.pdap.impl.meta_urls.enums import SubmitMetaURLsStatus
+from src.external.pdap.impl.meta_urls.request import SubmitMetaURLsRequest
+from src.external.pdap.impl.meta_urls.response import SubmitMetaURLsResponse
+
+
+async def submit_meta_urls(
+    access_manager: AccessManager,
+    requests: list[SubmitMetaURLsRequest]
+) -> list[SubmitMetaURLsResponse]:
+
+    meta_urls_json: list[dict[str, Any]] = []
+    for request in requests:
+        meta_urls_json.append(
+            {
+                "url": request.url,
+                "agency_id": request.agency_id
+            }
+        )
+
+    headers: dict[str, str] = await access_manager.jwt_header()
+    url: str = access_manager.build_url(
+        namespace=DataSourcesNamespaces.SOURCE_COLLECTOR,
+        subdomains=["meta-urls"]
+    )
+    request_info = RequestInfo(
+        type_=RequestType.POST,
+        url=url,
+        headers=headers,
+        json_={
+            "meta_urls": meta_urls_json
+        }
+    )
+
+    response_info: ResponseInfo = await access_manager.make_request(request_info)
+    meta_urls_response_json: list[dict[str, Any]] = response_info.data["meta_urls"]
+
+    responses: list[SubmitMetaURLsResponse] = []
+    for meta_url in meta_urls_response_json:
+        responses.append(
+            SubmitMetaURLsResponse(
+                url=meta_url["url"],
+                status=SubmitMetaURLsStatus(meta_url["status"]),
+                agency_id=meta_url["agency_id"],
+                meta_url_id=meta_url["meta_url_id"],
+                error=meta_url["error"]
+            )
+        )
+    return responses
\ No newline at end of file
diff --git a/src/external/pdap/impl/meta_urls/enums.py b/src/external/pdap/impl/meta_urls/enums.py
new file mode 100644
index 00000000..e49e71aa
--- /dev/null
+++ b/src/external/pdap/impl/meta_urls/enums.py
@@ -0,0 +1,7 @@
+from enum import Enum
+
+
+class SubmitMetaURLsStatus(Enum):
+    SUCCESS = "success"
+    FAILURE = "failure"
+    ALREADY_EXISTS = "already_exists"
\ No newline at end of file
diff --git a/src/external/pdap/impl/meta_urls/request.py b/src/external/pdap/impl/meta_urls/request.py
new file mode 100644
index 00000000..ac222aca
--- /dev/null
+++ b/src/external/pdap/impl/meta_urls/request.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class SubmitMetaURLsRequest(BaseModel):
+    url_id: int
+    url: str
+    agency_id: int
diff --git a/src/external/pdap/impl/meta_urls/response.py b/src/external/pdap/impl/meta_urls/response.py
new file mode 100644
index 00000000..96d5ece7
--- /dev/null
+++ b/src/external/pdap/impl/meta_urls/response.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+from src.external.pdap.impl.meta_urls.enums import SubmitMetaURLsStatus
+
+
+class SubmitMetaURLsResponse(BaseModel):
+    url: str
+    status: SubmitMetaURLsStatus
+    meta_url_id: int | None = None
+    agency_id: int | None = None
+    error: str | None = None
\ No newline at end of file
diff --git a/src/core/tasks/url/operators/url_html/scraper/request_interface/README.md b/src/external/url_request/README.md
similarity index 100%
rename from src/core/tasks/url/operators/url_html/scraper/request_interface/README.md
rename to src/external/url_request/README.md
diff --git a/src/external/url_request/__init__.py b/src/external/url_request/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/url_request/constants.py b/src/external/url_request/constants.py
new file mode 100644
index 00000000..178b0fad
--- /dev/null
+++ b/src/external/url_request/constants.py
@@ -0,0 +1,6 @@
+from typing import Literal
+
+HTML_CONTENT_TYPE = "text/html"
+MAX_CONCURRENCY = 5
+
+NETWORK_IDLE: Literal["networkidle"] = "networkidle"
\ No newline at end of file
diff --git a/src/external/url_request/core.py b/src/external/url_request/core.py
new file mode 100644
index 00000000..7a6920fe
--- /dev/null
+++ b/src/external/url_request/core.py
@@ -0,0 +1,25 @@
+from aiohttp import ClientSession, ClientTimeout
+
+from src.external.url_request.dtos.url_response import URLResponseInfo
+from src.external.url_request.probe.core import URLProbeManager
+from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper
+from src.external.url_request.request import fetch_urls
+
+
+class URLRequestInterface:
+
+    @staticmethod
+    async def make_requests_with_html(
+        urls: list[str],
+    ) -> list[URLResponseInfo]:
+        return await fetch_urls(urls)
+
+    @staticmethod
+    async def probe_urls(urls: list[str]) -> list[URLProbeResponseOuterWrapper]:
+        async with ClientSession(timeout=ClientTimeout(total=30)) as session:
+            manager = URLProbeManager(session=session)
+            return await manager.probe_urls(urls=urls)
+
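+# Hypothetical usage, from an async context:
+#   results = await URLRequestInterface.make_requests_with_html(["https://example.com"])
+#   probes = await URLRequestInterface.probe_urls(["https://example.com"])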
diff --git a/src/external/url_request/dtos/__init__.py b/src/external/url_request/dtos/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/url_request/dtos/request_resources.py b/src/external/url_request/dtos/request_resources.py
new file mode 100644
index 00000000..01a5365f
--- /dev/null
+++ b/src/external/url_request/dtos/request_resources.py
@@ -0,0 +1,15 @@
+import asyncio
+from dataclasses import dataclass, field
+
+from aiohttp import ClientSession
+from playwright.async_api import Browser
+
+from src.external.url_request.constants import MAX_CONCURRENCY
+
+
+@dataclass
+class RequestResources:
+    session: ClientSession
+    browser: Browser
+    # default_factory so each instance gets its own semaphore, created lazily
+    semaphore: asyncio.Semaphore = field(default_factory=lambda: asyncio.Semaphore(MAX_CONCURRENCY))
diff --git a/src/external/url_request/dtos/screenshot_response.py b/src/external/url_request/dtos/screenshot_response.py
new file mode 100644
index 00000000..bb36b258
--- /dev/null
+++ b/src/external/url_request/dtos/screenshot_response.py
@@ -0,0 +1,11 @@
+from pydantic import BaseModel
+
+
+class URLScreenshotResponse(BaseModel):
+    url: str
+    screenshot: bytes | None
+    error: str | None = None
+
+    @property
+    def is_success(self) -> bool:
+        return self.error is None
\ No newline at end of file
diff --git a/src/external/url_request/dtos/url_response.py b/src/external/url_request/dtos/url_response.py
new file mode 100644
index 00000000..57303a7c
--- /dev/null
+++ b/src/external/url_request/dtos/url_response.py
@@ -0,0 +1,11 @@
+from http import HTTPStatus
+
+from pydantic import BaseModel
+
+
+class URLResponseInfo(BaseModel):
+    success: bool
+    status: HTTPStatus | None = None
+    html: str | None = None
+    content_type: str | None = None
+    exception: str | None = None
diff --git a/src/external/url_request/probe/__init__.py b/src/external/url_request/probe/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/url_request/probe/convert.py b/src/external/url_request/probe/convert.py
new file mode 100644
index 00000000..3b15268a
--- /dev/null
+++ b/src/external/url_request/probe/convert.py
@@ -0,0 +1,112 @@
+from http import HTTPStatus
+from typing import Sequence
+
+from aiohttp import ClientResponse, ClientResponseError
+
+from src.external.url_request.probe.models.response import URLProbeResponse
+from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair
+from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper
+
+
+def _process_client_response_history(history: Sequence[ClientResponse]) -> list[str]:
+    return [str(cr.url) for cr in history]
+
+
+def _extract_content_type(cr: ClientResponse, error: str | None) -> str | None:
+    if error is None:
+        return cr.content_type
+    return None
+
+
+def _extract_redirect_probe_response(cr: ClientResponse) -> URLProbeResponse | None:
+    """Returns the probe response for the first redirect.
+
+    This is the original URL that was probed."""
+    if len(cr.history) == 0:
+        return None
+
+    all_urls = [str(cr.url) for cr in cr.history]
+    first_url = all_urls[0]
+
+    return URLProbeResponse(
+        url=first_url,
+        status_code=HTTPStatus.FOUND.value,
+        content_type=None,
+        error=None,
+    )
+
+
+def _extract_error(cr: ClientResponse) -> str | None:
+    try:
+        cr.raise_for_status()
+        return None
+    except ClientResponseError as e:
+        return str(e)
+
+def _has_redirect(cr: ClientResponse) -> bool:
+    return len(cr.history) > 0
+
+def _extract_source_url(cr: ClientResponse) -> str:
+    return str(cr.history[0].url)
+
+def _extract_destination_url(cr: ClientResponse) -> str:
+    return str(cr.url)
+
+def convert_client_response_to_probe_response(
+    url: str,
+    cr: ClientResponse
+) -> URLProbeResponse | URLProbeRedirectResponsePair:
+    error = _extract_error(cr)
+    content_type = _extract_content_type(cr, error=error)
+    if not _has_redirect(cr):
+        return URLProbeResponse(
+            url=str(cr.url),
+            status_code=cr.status,
+            content_type=content_type,
+            error=error,
+        )
+
+    # Extract into separate probe responses
+    source_cr = cr.history[0]  # Source CR is the first in the history
+    destination_cr = cr
+
+    destination_url = str(destination_cr.url)
+
+    source_error = _extract_error(source_cr)
+    source_content_type = _extract_content_type(source_cr, error=source_error)
+    source_probe_response = URLProbeResponse(
+        url=url,
+        status_code=source_cr.status,
+        content_type=source_content_type,
+        error=source_error,
+    )
+
+    destination_error = _extract_error(destination_cr)
+    destination_content_type = _extract_content_type(destination_cr, error=destination_error)
+    destination_probe_response = URLProbeResponse(
+        url=destination_url,
+        status_code=destination_cr.status,
+        content_type=destination_content_type,
+        error=destination_error,
+    )
+
+    return URLProbeRedirectResponsePair(
+        source=source_probe_response,
+        destination=destination_probe_response
+    )
+
+def convert_to_error_response(
+    url: str,
+    error: str,
+    status_code: int | None = None
+) -> URLProbeResponseOuterWrapper:
+    return URLProbeResponseOuterWrapper(
+        original_url=url,
+        response=URLProbeResponse(
+            url=url,
+            status_code=status_code,
+            content_type=None,
+            error=error
+        )
+    )
diff --git a/src/external/url_request/probe/core.py b/src/external/url_request/probe/core.py
new file mode 100644
index 00000000..48009381
--- /dev/null
+++ b/src/external/url_request/probe/core.py
@@ -0,0 +1,97 @@
+import asyncio.exceptions
+from http import HTTPStatus
+
+from aiohttp import ClientSession, InvalidUrlClientError, ClientConnectorSSLError, ClientConnectorDNSError, \
+    ClientConnectorCertificateError, ClientResponseError, ClientConnectorError, TooManyRedirects, ClientOSError, \
+    ServerDisconnectedError, ClientConnectionResetError
+from pydantic import ValidationError
+from tqdm.asyncio import tqdm_asyncio
+
+from src.external.url_request.probe.convert import convert_client_response_to_probe_response, convert_to_error_response
+from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper
+from src.util.progress_bar import get_progress_bar_disabled
+
+
+class URLProbeManager:
+
+    def __init__(
+        self,
+        session: ClientSession
+    ):
+        self.session = session
+
+    async def probe_urls(self, urls: list[str]) -> list[URLProbeResponseOuterWrapper]:
+        return await tqdm_asyncio.gather(
+            *[self._probe(url) for url in urls],
+            timeout=60 * 10,  # 10 minutes
+            disable=get_progress_bar_disabled()
+        )
+
+    async def _probe(self, url: str) -> URLProbeResponseOuterWrapper:
+        try:
+            response = await self._head(url)
+            if not response.is_redirect and response.response.status_code == HTTPStatus.OK:
+                return response
+            # HEAD was inconclusive (error, redirect, or non-200): retry with GET
+            return await self._get(url)
+        except InvalidUrlClientError:
+            return convert_to_error_response(url, error="Invalid URL")
+        except (
+            ClientConnectorError,
+            ClientConnectorSSLError,
+            ClientConnectorDNSError,
+            ClientConnectorCertificateError,
+            ServerDisconnectedError,
+            ClientConnectionResetError
+        ) as e:
+            return convert_to_error_response(url, error=str(e))
+        except asyncio.exceptions.TimeoutError:
+            return convert_to_error_response(url, error="Timeout Error")
+        except ValidationError as e:
+            raise ValueError(f"Validation Error for {url}.") from e
+        except ClientOSError as e:
+            return convert_to_error_response(url, error=f"Client OS Error: {e.errno}. {str(e)}")
{str(e)}") + + async def _head(self, url: str) -> URLProbeResponseOuterWrapper: + try: + async with self.session.head(url, allow_redirects=True) as response: + return URLProbeResponseOuterWrapper( + original_url=url, + response=convert_client_response_to_probe_response( + url, + response + ) + ) + except TooManyRedirects: + return convert_to_error_response( + url, + error="Too many redirects (> 10)", + ) + except ClientResponseError as e: + return convert_to_error_response( + url, + error=str(e), + status_code=e.status + ) + + async def _get(self, url: str) -> URLProbeResponseOuterWrapper: + try: + async with self.session.get(url, allow_redirects=True) as response: + return URLProbeResponseOuterWrapper( + original_url=url, + response=convert_client_response_to_probe_response( + url, + response + ) + ) + except TooManyRedirects: + return convert_to_error_response( + url, + error="Too many redirects (> 10)", + ) + except ClientResponseError as e: + return convert_to_error_response( + url, + error=str(e), + status_code=e.status + ) diff --git a/src/external/url_request/probe/format.py b/src/external/url_request/probe/format.py new file mode 100644 index 00000000..b528de4d --- /dev/null +++ b/src/external/url_request/probe/format.py @@ -0,0 +1,7 @@ +from aiohttp import ClientResponse, ClientResponseError + +from src.external.url_request.probe.models.response import URLProbeResponse + + +def format_content_type(content_type: str) -> str: + return content_type.split(";")[0].strip() diff --git a/src/external/url_request/probe/models/__init__.py b/src/external/url_request/probe/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/external/url_request/probe/models/redirect.py b/src/external/url_request/probe/models/redirect.py new file mode 100644 index 00000000..56c9f227 --- /dev/null +++ b/src/external/url_request/probe/models/redirect.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from src.external.url_request.probe.models.response import URLProbeResponse + + +class URLProbeRedirectResponsePair(BaseModel): + source: URLProbeResponse + destination: URLProbeResponse \ No newline at end of file diff --git a/src/external/url_request/probe/models/response.py b/src/external/url_request/probe/models/response.py new file mode 100644 index 00000000..967f1c4f --- /dev/null +++ b/src/external/url_request/probe/models/response.py @@ -0,0 +1,22 @@ +from pydantic import BaseModel, Field, model_validator + + + +class URLProbeResponse(BaseModel): + url: str + status_code: int | None = Field(le=999, ge=100) + content_type: str | None + error: str | None = None + + @model_validator(mode='after') + def check_error_mutually_exclusive_with_content(self): + if self.error is None: + if self.status_code is None: + raise ValueError('Status code required if no error') + return self + + if self.content_type is not None: + raise ValueError('Content type mutually exclusive with error') + + return self + diff --git a/src/external/url_request/probe/models/wrapper.py b/src/external/url_request/probe/models/wrapper.py new file mode 100644 index 00000000..04dbc9c4 --- /dev/null +++ b/src/external/url_request/probe/models/wrapper.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel + +from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair +from src.external.url_request.probe.models.response import URLProbeResponse + + +class URLProbeResponseOuterWrapper(BaseModel): + original_url: str + response: URLProbeResponse | URLProbeRedirectResponsePair + + @property + def 
+    def is_redirect(self) -> bool:
+        return isinstance(self.response, URLProbeRedirectResponsePair)
diff --git a/src/external/url_request/request.py b/src/external/url_request/request.py
new file mode 100644
index 00000000..40fc2dd6
--- /dev/null
+++ b/src/external/url_request/request.py
@@ -0,0 +1,92 @@
+"""Functions for making HTTP requests."""
+from http import HTTPStatus
+
+from aiohttp import ClientSession, ClientResponseError
+from playwright.async_api import async_playwright
+from tqdm.asyncio import tqdm
+
+from src.external.url_request.constants import HTML_CONTENT_TYPE
+from src.external.url_request.dtos.request_resources import RequestResources
+
+from src.external.url_request.dtos.url_response import URLResponseInfo
+
+
+async def execute_get(
+    session: ClientSession,
+    url: str
+) -> URLResponseInfo:
+    try:
+        async with session.get(url, timeout=20) as response:
+            response.raise_for_status()
+            text = await response.text()
+            return URLResponseInfo(
+                success=True,
+                html=text,
+                content_type=response.headers.get("content-type"),
+                status=HTTPStatus(response.status)
+            )
+    except ClientResponseError as e:
+        return URLResponseInfo(success=False, status=HTTPStatus(e.status), exception=str(e))
+
+
+async def get_response(session: ClientSession, url: str) -> URLResponseInfo:
+    try:
+        return await execute_get(session, url)
+    except Exception as e:
+        print(f"An error occurred while fetching {url}: {e}")
+        return URLResponseInfo(success=False, exception=str(e))
+
+
+async def make_simple_requests(urls: list[str]) -> list[URLResponseInfo]:
+    async with ClientSession() as session:
+        tasks = [get_response(session, url) for url in urls]
+        results = await tqdm.gather(*tasks)
+        return results
+
+
+async def get_dynamic_html_content(
+    rr: RequestResources,
+    url: str
+) -> URLResponseInfo:
+    # For HTML responses, attempt to load the page to check for dynamic html content
+    async with rr.semaphore:
+        page = await rr.browser.new_page()
+        try:
+            await page.goto(url)
+            await page.wait_for_load_state("networkidle")
+            html_content = await page.content()
+            return URLResponseInfo(
+                success=True,
+                html=html_content,
+                content_type=HTML_CONTENT_TYPE,
+                status=HTTPStatus.OK
+            )
+        except Exception as e:
+            return URLResponseInfo(success=False, exception=str(e))
+        finally:
+            await page.close()
+
+
+async def fetch_and_render(
+    rr: RequestResources,
+    url: str
+) -> URLResponseInfo:
+    simple_response = await get_response(rr.session, url)
+    if not simple_response.success:
+        return simple_response
+
+    # The header may carry parameters (e.g. "text/html; charset=utf-8"), so match loosely
+    if simple_response.content_type is None or HTML_CONTENT_TYPE not in simple_response.content_type:
+        return simple_response
+
+    return await get_dynamic_html_content(rr, url)
+
+
+async def fetch_urls(urls: list[str]) -> list[URLResponseInfo]:
+    async with ClientSession() as session:
+        async with async_playwright() as playwright:
+            browser = await playwright.chromium.launch(headless=True)
+            request_resources = RequestResources(session=session, browser=browser)
+            tasks = [fetch_and_render(request_resources, url) for url in urls]
+            results = await tqdm.gather(*tasks)
+            return results
diff --git a/src/external/url_request/screenshot_/__init__.py b/src/external/url_request/screenshot_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/external/url_request/screenshot_/constants.py b/src/external/url_request/screenshot_/constants.py
new file mode 100644
index 00000000..fc5c11ea
--- /dev/null
+++ b/src/external/url_request/screenshot_/constants.py
@@ -0,0 +1,8 @@
+
+
+SCREENSHOT_HEIGHT: int = 800
+SCREENSHOT_WIDTH: int = 1200
+
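+# WEBP re-encode quality (0-100): higher means larger files but sharper images.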
+COMPRESSION_QUALITY: int = 80
\ No newline at end of file
diff --git a/src/external/url_request/screenshot_/convert.py b/src/external/url_request/screenshot_/convert.py
new file mode 100644
index 00000000..75b62c92
--- /dev/null
+++ b/src/external/url_request/screenshot_/convert.py
@@ -0,0 +1,13 @@
+from PIL import Image
+from io import BytesIO
+
+from PIL.ImageFile import ImageFile
+
+from src.external.url_request.screenshot_.constants import COMPRESSION_QUALITY
+
+
+def convert_png_to_webp(png: bytes) -> bytes:
+    image: ImageFile = Image.open(BytesIO(png))
+    output = BytesIO()
+    image.save(output, format="WEBP", quality=COMPRESSION_QUALITY)
+    return output.getvalue()
diff --git a/src/external/url_request/screenshot_/core.py b/src/external/url_request/screenshot_/core.py
new file mode 100644
index 00000000..c7e3c3d4
--- /dev/null
+++ b/src/external/url_request/screenshot_/core.py
@@ -0,0 +1,55 @@
+from playwright.async_api import async_playwright, Browser, ViewportSize, Page
+from tqdm.asyncio import tqdm_asyncio
+
+from src.external.url_request.constants import NETWORK_IDLE
+from src.external.url_request.dtos.screenshot_response import URLScreenshotResponse
+from src.external.url_request.screenshot_.constants import SCREENSHOT_HEIGHT, SCREENSHOT_WIDTH
+from src.external.url_request.screenshot_.convert import convert_png_to_webp
+from src.util.progress_bar import get_progress_bar_disabled
+
+
+async def get_screenshots(
+    urls: list[str]
+) -> list[URLScreenshotResponse]:
+    responses: list[URLScreenshotResponse] = []
+    async with async_playwright() as playwright:
+        browser: Browser = await playwright.chromium.launch(headless=True)
+        page: Page = await browser.new_page(
+            viewport=ViewportSize(
+                {
+                    "width": SCREENSHOT_WIDTH,
+                    "height": SCREENSHOT_HEIGHT,
+                }
+            )
+        )
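+        # A single page is reused for every URL, so screenshots are taken sequentially.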
+        for url in tqdm_asyncio(urls, disable=get_progress_bar_disabled()):
+            try:
+                response: URLScreenshotResponse = await get_screenshot(
+                    page=page, url=url
+                )
+                responses.append(response)
+            except Exception as e:
+                responses.append(
+                    URLScreenshotResponse(
+                        url=url,
+                        screenshot=None,
+                        error=str(e)
+                    )
+                )
+        await page.close()
+        await browser.close()
+    return responses
+
+async def get_screenshot(
+    page: Page,
+    url: str,
+) -> URLScreenshotResponse:
+    await page.goto(url)
+    await page.wait_for_load_state(NETWORK_IDLE)
+    screenshot_png: bytes = await page.screenshot(type="png")
+    screenshot_webp: bytes = convert_png_to_webp(screenshot_png)
+    return URLScreenshotResponse(
+        url=url,
+        screenshot=screenshot_webp,
+    )
diff --git a/src/security/manager.py b/src/security/manager.py
index 97bc0da8..16f0519e 100644
--- a/src/security/manager.py
+++ b/src/security/manager.py
@@ -16,9 +16,7 @@ class SecurityManager:
 
-    def __init__(
-        self
-    ):
+    def __init__(self):
         dotenv.load_dotenv()
 
         self.secret_key = os.getenv("DS_APP_SECRET_KEY")
diff --git a/src/util/alembic_helpers.py b/src/util/alembic_helpers.py
index 3eb18773..cb9d8d67 100644
--- a/src/util/alembic_helpers.py
+++ b/src/util/alembic_helpers.py
@@ -1,5 +1,10 @@
+import uuid
+
 from alembic import op
 import sqlalchemy as sa
+from sqlalchemy import text
+from sqlalchemy.dialects.postgresql import ENUM
+
 
 def switch_enum_type(
     table_name,
@@ -8,6 +13,7 @@ def switch_enum_type(
     new_enum_values,
     drop_old_enum=True,
     check_constraints_to_drop: list[str] = None,
+    conversion_mappings: dict[str, str] = None
 ):
     """
     Switches an ENUM type in a PostgreSQL column by:
@@ -21,6 +27,8 @@ def switch_enum_type(
     :param enum_name: Name of the ENUM type in PostgreSQL.
     :param new_enum_values: List of new ENUM values.
     :param drop_old_enum: Whether to drop the old ENUM type.
+    :param check_constraints_to_drop: List of check constraints to drop before switching the ENUM type.
+    :param conversion_mappings: Dictionary of old values to new values for the ENUM type.
     """
 
     # 1. Drop check constraints that reference the enum
@@ -38,7 +46,21 @@ def switch_enum_type(
     new_enum_type.create(op.get_bind())
 
     # Alter the column type to use the new enum type
-    op.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{column_name}" TYPE "{enum_name}" USING "{column_name}"::text::{enum_name}')
+    if conversion_mappings is None:
+        op.execute(f'ALTER TABLE "{table_name}" ALTER COLUMN "{column_name}" TYPE "{enum_name}" USING "{column_name}"::text::{enum_name}')
+    else:
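+        # e.g. conversion_mappings={"old_value": "new_value"} (hypothetical
+        # values) emits the clause: WHEN 'old_value' THEN 'new_value'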
+        case_when: str = ""
+        for old_value, new_value in conversion_mappings.items():
+            case_when += f"WHEN '{old_value}' THEN '{new_value}'\n"
+
+        op.execute(f"""
+            ALTER TABLE "{table_name}"
+            ALTER COLUMN "{column_name}" TYPE "{enum_name}"
+            USING CASE {column_name}::text
+            {case_when}
+            ELSE "{column_name}"::text
+            END::{enum_name};
+        """)
 
     # Drop the old enum type
     if drop_old_enum:
@@ -61,7 +83,8 @@ def id_column() -> sa.Column:
         sa.Integer(),
         primary_key=True,
         autoincrement=True,
-        nullable=False
+        nullable=False,
+        comment='The primary identifier for the row.'
     )
 
 def created_at_column() -> sa.Column:
@@ -70,7 +93,19 @@ def created_at_column() -> sa.Column:
         'created_at',
         sa.DateTime(),
         server_default=sa.text('now()'),
-        nullable=False
+        nullable=False,
+        comment='The time the row was created.'
+    )
+
+def enum_column(
+    column_name,
+    enum_name
+) -> sa.Column:
+    return sa.Column(
+        column_name,
+        ENUM(name=enum_name, create_type=False),
+        nullable=False,
+        comment=f'The {column_name} of the row.'
     )
 
 def updated_at_column() -> sa.Column:
@@ -80,18 +115,53 @@ def updated_at_column() -> sa.Column:
         sa.DateTime(),
         server_default=sa.text('now()'),
         server_onupdate=sa.text('now()'),
-        nullable=False
+        nullable=False,
+        comment='The last time the row was updated.'
+    )
+
+def task_id_column() -> sa.Column:
+    return sa.Column(
+        'task_id',
+        sa.Integer(),
+        sa.ForeignKey(
+            'tasks.id',
+            ondelete='CASCADE'
+        ),
+        nullable=False,
+        comment='A foreign key to the `tasks` table.'
     )
 
-def url_id_column() -> sa.Column:
+def url_id_column(name: str = 'url_id', primary_key: bool = False) -> sa.Column:
     return sa.Column(
-        'url_id',
+        name,
         sa.Integer(),
         sa.ForeignKey(
             'urls.id',
             ondelete='CASCADE'
         ),
-        nullable=False
+        primary_key=primary_key,
+        nullable=False,
+        comment='A foreign key to the `urls` table.'
+    )
+
+def user_id_column(name: str = 'user_id') -> sa.Column:
+    return sa.Column(
+        name,
+        sa.Integer(),
+        nullable=False,
+    )
+
+
+def location_id_column(name: str = 'location_id') -> sa.Column:
+    return sa.Column(
+        name,
+        sa.Integer(),
+        sa.ForeignKey(
+            'locations.id',
+            ondelete='CASCADE'
+        ),
+        nullable=False,
+        comment='A foreign key to the `locations` table.'
     )
 
 def batch_id_column(nullable=False) -> sa.Column:
@@ -102,5 +172,127 @@ def batch_id_column(nullable=False) -> sa.Column:
             'batches.id',
             ondelete='CASCADE'
         ),
-        nullable=nullable
+        nullable=nullable,
+        comment='A foreign key to the `batches` table.'
+    )
+
+def agency_id_column(nullable=False) -> sa.Column:
+    return sa.Column(
+        'agency_id',
+        sa.Integer(),
+        sa.ForeignKey(
+            'agencies.agency_id',
+            ondelete='CASCADE'
+        ),
+        nullable=nullable,
+        comment='A foreign key to the `agencies` table.'
+    )
+
+def add_enum_value(
+    enum_name: str,
+    enum_value: str
+) -> None:
+    op.execute(f"ALTER TYPE {enum_name} ADD VALUE '{enum_value}'")
+
+
+def _q_ident(s: str) -> str:
+    return '"' + s.replace('"', '""') + '"'
+
+
+def _q_label(s: str) -> str:
+    return "'" + s.replace("'", "''") + "'"
+
+
+def remove_enum_value(
+    *,
+    enum_name: str,
+    value_to_remove: str,
+    targets: list[tuple[str, str]],  # (table, column)
+    schema: str = "public",
+) -> None:
+    """
+    Remove `value_to_remove` from ENUM `schema.enum_name` across the given (table, column) pairs.
+    Assumes target columns have **no defaults**.
+    """
+    conn = op.get_bind()
+
+    # 1) Load current labels (ordered)
+    labels = [
+        r[0]
+        for r in conn.execute(
+            text(
+                """
+                SELECT e.enumlabel
+                FROM pg_enum e
+                JOIN pg_type t ON t.oid = e.enumtypid
+                JOIN pg_namespace n ON n.oid = t.typnamespace
+                WHERE t.typname = :enum_name
+                  AND n.nspname = :schema
+                ORDER BY e.enumsortorder
+                """
+            ),
+            {"enum_name": enum_name, "schema": schema},
+        ).fetchall()
+    ]
+    if not labels:
+        raise RuntimeError(f"Enum {schema}.{enum_name!r} not found.")
+    if value_to_remove not in labels:
+        return  # nothing to do
+    new_labels = [l for l in labels if l != value_to_remove]
+    if not new_labels:
+        raise RuntimeError("Refusing to remove the last remaining enum label.")
+
+    # Deduplicate targets while preserving order
+    seen = set()
+    targets = [(t, c) for (t, c) in targets if not ((t, c) in seen or seen.add((t, c)))]
+
+    # 2) Ensure no rows still hold the label
+    for table, col in targets:
+        count = conn.execute(
+            text(
+                f"SELECT COUNT(*) FROM {_q_ident(schema)}.{_q_ident(table)} "
+                f"WHERE {_q_ident(col)} = :v"
+            ),
+            {"v": value_to_remove},
+        ).scalar()
+        if count and count > 0:
+            raise RuntimeError(
+                f"Cannot remove {value_to_remove!r}: {schema}.{table}.{col} "
+                f"has {count} row(s) with that value. UPDATE or DELETE them first."
+            )
+
+    # 3) Create a tmp enum without the value
+    tmp_name = f"{enum_name}__tmp__{uuid.uuid4().hex[:8]}"
+    op.execute(
+        text(
+            f"CREATE TYPE {_q_ident(schema)}.{_q_ident(tmp_name)} AS ENUM ("
+            + ", ".join(_q_label(l) for l in new_labels)
+            + ")"
+        )
+    )
+
+    # 4) For each column: enum -> text -> tmp_enum
+    for table, col in targets:
+        op.execute(
+            text(
+                f"ALTER TABLE {_q_ident(schema)}.{_q_ident(table)} "
+                f"ALTER COLUMN {_q_ident(col)} TYPE TEXT USING {_q_ident(col)}::TEXT"
+            )
+        )
+        op.execute(
+            text(
+                f"ALTER TABLE {_q_ident(schema)}.{_q_ident(table)} "
+                f"ALTER COLUMN {_q_ident(col)} TYPE {_q_ident(schema)}.{_q_ident(tmp_name)} "
+                f"USING {_q_ident(col)}::{_q_ident(schema)}.{_q_ident(tmp_name)}"
+            )
+        )
+
+    # 5) Swap: drop old enum, rename tmp -> original name
+    op.execute(text(f"DROP TYPE {_q_ident(schema)}.{_q_ident(enum_name)}"))
+    op.execute(
+        text(
+            f"ALTER TYPE {_q_ident(schema)}.{_q_ident(tmp_name)} "
+            f"RENAME TO {_q_ident(enum_name)}"
+        )
+    )
\ No newline at end of file
diff --git a/src/util/clean.py b/src/util/clean.py
new file mode 100644
index 00000000..3c0a0f92
--- /dev/null
+++ b/src/util/clean.py
@@ -0,0 +1,11 @@
+
+
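+# e.g. clean_url("https://example.com/page#section") -> "https://example.com/page"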
+def clean_url(url: str) -> str:
+    # Remove leading/trailing whitespace, including non-breaking spaces
+    url = url.strip(" \u00a0")
+
+    # Remove any fragments and everything after them
+    url = url.split("#")[0]
+    return url
+
diff --git a/src/util/db_manager.py b/src/util/db_manager.py
deleted file mode 100644
index b03708a0..00000000
--- a/src/util/db_manager.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-
-import psycopg2
-from dotenv import load_dotenv
-
-
-class DBManager:
-
-    def __init__(self, db_name, user, password, host, port):
-        self.conn = psycopg2.connect(
-            dbname=db_name,
-            user=user,
-            password=password,
-            host=host,
-            port=port
-        )
-        self.cursor = self.conn.cursor()
-
-    def __del__(self):
-        self.conn.close()
-
-    def execute(self, query, params=None):
-        self.cursor.execute(query, params)
-        self.conn.commit()
-        return self.cursor.fetchall()
-
-    def fetchall(self):
-        return self.cursor.fetchall()
-
-    def fetchone(self):
-        return self.cursor.fetchone()
-
-    def fetchmany(self, size):
-        return self.cursor.fetchmany(size)
-
-    def close(self):
-        self.conn.close()
-
-
-if __name__ == "__main__":
-    # Note: This is test code to evaluate whether the connection url works. Will be removed in final version.
-    load_dotenv()
-    conn_url = os.getenv("DIGITAL_OCEAN_DB_CONNECTION_URL")
-    conn = psycopg2.connect(conn_url)
-
-    pass
\ No newline at end of file
diff --git a/src/util/helper_functions.py b/src/util/helper_functions.py
index deb6830b..4e33985f 100644
--- a/src/util/helper_functions.py
+++ b/src/util/helper_functions.py
@@ -16,7 +16,7 @@ def get_project_root(marker_files=(".project-root",)) -> Path:
 def project_path(*parts: str) -> Path:
     return get_project_root().joinpath(*parts)
 
-def get_enum_values(enum: Type[Enum]):
+def get_enum_values(enum: Type[Enum]) -> list[str]:
     return [item.value for item in enum]
 
 def get_from_env(key: str, allow_none: bool = False):
@@ -42,7 +42,11 @@ def load_from_environment(keys: list[str]) -> dict[str, str]:
 def base_model_list_dump(model_list: list[BaseModel]) -> list[dict]:
     return [model.model_dump() for model in model_list]
 
-def update_if_not_none(target: dict, source: dict):
+def update_if_not_none(target: dict, source: dict) -> None:
+    """
+    Modifies:
+        target
+    """
     for key, value in source.items():
         if value is not None:
             target[key] = value
\ No newline at end of file
diff --git a/src/util/miscellaneous_functions.py b/src/util/miscellaneous_functions.py
index 4b0bc88b..88e7a6a7 100644
--- a/src/util/miscellaneous_functions.py
+++ b/src/util/miscellaneous_functions.py
@@ -16,8 +16,8 @@ def create_directories_if_not_exist(file_path: str):
     Create directories if they don't exist
     Args:
         file_path:
-
-    Returns:
+    Modifies:
+        file_path
     """
     directory = os.path.dirname(file_path)
 
diff --git a/src/util/progress_bar.py b/src/util/progress_bar.py
new file mode 100644
index 00000000..615120ba
--- /dev/null
+++ b/src/util/progress_bar.py
@@ -0,0 +1,8 @@
+
+from environs import Env
+
+def get_progress_bar_disabled() -> bool:
+    env = Env()
+    env.read_env()
+    enabled: bool = env.bool("PROGRESS_BAR_FLAG", True)
+    return not enabled
diff --git a/src/util/url_mapper.py b/src/util/url_mapper.py
new file mode 100644
index 00000000..3a399d77
--- /dev/null
+++ b/src/util/url_mapper.py
@@ -0,0 +1,51 @@
+from src.db.dtos.url.mapping import URLMapping
+
+
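+# Bidirectional URL <-> id lookup built from URLMapping rows. Hypothetical usage:
+#   mapper = URLMapper([URLMapping(url_id=1, url="https://example.com")])
+#   mapper.get_id("https://example.com")  # -> 1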
+class URLMapper:
+
+    def __init__(self, mappings: list[URLMapping]):
+        self._url_to_id = {
+            mapping.url: mapping.url_id
+            for mapping in mappings
+        }
+        self._id_to_url = {
+            mapping.url_id: mapping.url
+            for mapping in mappings
+        }
+
+    def get_id(self, url: str) -> int:
+        return self._url_to_id[url]
+
+    def get_ids(self, urls: list[str]) -> list[int]:
+        return [
+            self._url_to_id[url]
+            for url in urls
+        ]
+
+    def get_all_ids(self) -> list[int]:
+        return list(self._url_to_id.values())
+
+    def get_all_urls(self) -> list[str]:
+        return list(self._url_to_id.keys())
+
+    def get_url(self, url_id: int) -> str:
+        return self._id_to_url[url_id]
+
+    def get_mappings_by_url(self, urls: list[str]) -> list[URLMapping]:
+        return [
+            URLMapping(
+                url_id=self._url_to_id[url],
+                url=url
+            ) for url in urls
+        ]
+
+    def add_mapping(self, mapping: URLMapping) -> None:
+        self._url_to_id[mapping.url] = mapping.url_id
+        self._id_to_url[mapping.url_id] = mapping.url
+
+    def add_mappings(self, mappings: list[URLMapping]) -> None:
+        for mapping in mappings:
+            self.add_mapping(mapping)
\ No newline at end of file
diff --git a/start_mirrored_local_app.py b/start_mirrored_local_app.py
index 5199fba2..9190fece 100644
--- a/start_mirrored_local_app.py
+++ b/start_mirrored_local_app.py
@@ -27,15 +27,8 @@ def main():
     # Check cache if exists and
     checker = TimestampChecker()
     data_dump_container = docker_manager.run_container(data_dumper_docker_info)
-    if checker.last_run_within_24_hours():
-        print("Last run within 24 hours, skipping dump...")
-    else:
-        data_dump_container.run_command(
-            DUMP_SH_DOCKER_PATH,
-        )
-    data_dump_container.run_command(
-        RESTORE_SH_DOCKER_PATH,
-    )
+    _run_dump_if_longer_than_24_hours(checker, data_dump_container)
+    _run_database_restore(data_dump_container)
     print("Stopping datadumper container")
     data_dump_container.stop()
     checker.set_last_run_time()
@@ -44,6 +37,10 @@ def main():
     apply_migrations()
 
     # Run `fastapi dev main.py`
+    _run_fast_api(docker_manager)
+
+
+def _run_fast_api(docker_manager: DockerManager) -> None:
     try:
         uvicorn.run(
             "src.api.main:app",
@@ -59,8 +56,22 @@ def main():
     print("Containers stopped.")
 
 
+def _run_database_restore(data_dump_container) -> None:
+    data_dump_container.run_command(
+        RESTORE_SH_DOCKER_PATH,
+    )
+
+
+def _run_dump_if_longer_than_24_hours(
+    checker: TimestampChecker,
+    data_dump_container
+) -> None:
+    if checker.last_run_within_24_hours():
+        print("Last run within 24 hours, skipping dump...")
+        return
+    data_dump_container.run_command(
+        DUMP_SH_DOCKER_PATH,
+    )
 
 if __name__ == "__main__":
diff --git a/tests/alembic/conftest.py b/tests/alembic/conftest.py
index 405f5677..f041e94a 100644
--- a/tests/alembic/conftest.py
+++ b/tests/alembic/conftest.py
@@ -1,34 +1,36 @@
+from typing import Any, Generator
+
 import pytest
 from alembic.config import Config
-from sqlalchemy import create_engine, inspect, MetaData
+from sqlalchemy import create_engine, inspect, MetaData, Engine, Connection
 from sqlalchemy.orm import scoped_session, sessionmaker
 
-from src.db.helpers import get_postgres_connection_string
+from src.db.helpers.connect import get_postgres_connection_string
 from tests.helpers.alembic_runner import AlembicRunner
 
 
 @pytest.fixture()
-def alembic_config():
+def alembic_config() -> Generator[Config, Any, None]:
     alembic_cfg = Config("alembic.ini")
     yield alembic_cfg
 
 
 @pytest.fixture()
-def db_engine():
+def db_engine() -> Generator[Engine, Any, None]:
     engine = create_engine(get_postgres_connection_string())
     yield engine
     engine.dispose()
 
 
 @pytest.fixture()
-def connection(db_engine):
+def connection(db_engine) -> Generator[Connection, Any, None]:
     connection = db_engine.connect()
     yield connection
     connection.close()
 
 
 @pytest.fixture()
-def alembic_runner(connection, alembic_config) -> AlembicRunner:
+def alembic_runner(connection, alembic_config) -> Generator[AlembicRunner, Any, None]:
     alembic_config.attributes["connection"] = connection
     alembic_config.set_main_option(
         "sqlalchemy.url",
@@ -41,17 +43,11 @@ def alembic_runner(connection, alembic_config) -> AlembicRunner:
         connection=connection,
         session=scoped_session(sessionmaker(bind=connection)),
     )
-    try:
-        runner.downgrade("base")
-    except Exception as e:
-        runner.reset_schema()
-        runner.stamp("base")
+    runner.reset_schema()
+    runner.stamp("base")
     print("Running test")
     yield runner
     print("Test complete")
     runner.session.close()
-    try:
-        runner.downgrade("base")
-    except Exception as e:
-        runner.reset_schema()
-        runner.stamp("base")
+    runner.reset_schema()
+    runner.stamp("base")
diff --git a/tests/alembic/helpers.py b/tests/alembic/helpers.py
index 96e7f62a..a284e0fc 100644
--- a/tests/alembic/helpers.py
+++ b/tests/alembic/helpers.py
@@ -13,9 +13,8 @@ def table_creation_check(
     alembic_runner: AlembicRunner,
     tables: list[str],
     end_revision: str,
-    start_revision: Optional[str] = None,
-
-):
+    start_revision: str | None = None,
+) -> None:
     if start_revision is not None:
         alembic_runner.upgrade(start_revision)
     for table_name in tables:
diff --git a/tests/alembic/test_revisions.py b/tests/alembic/test_revisions.py
index 19b5d046..94fa6c5e 100644
--- a/tests/alembic/test_revisions.py
+++ b/tests/alembic/test_revisions.py
@@ -6,4 +6,3 @@ def test_full_upgrade_downgrade(alembic_runner):
 
     # Both should run without error
     alembic_runner.upgrade("head")
-    alembic_runner.downgrade("base")
\ No newline at end of file
diff --git a/tests/automated/integration/api/_helpers/RequestValidator.py b/tests/automated/integration/api/_helpers/RequestValidator.py
index 33c3120d..73293522 100644
--- a/tests/automated/integration/api/_helpers/RequestValidator.py
+++ b/tests/automated/integration/api/_helpers/RequestValidator.py
@@ -1,18 +1,12 @@
 from http import HTTPStatus
 from typing import Optional, Annotated
 
-from fastapi import HTTPException
+from fastapi import HTTPException, Response
 from pydantic import BaseModel
 from starlette.testclient import TestClient
 
-from src.api.endpoints.annotate.agency.get.dto import GetNextURLForAgencyAnnotationResponse
-from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo
-from src.api.endpoints.annotate.all.get.dto import GetNextURLForAllAnnotationResponse
-from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo
-from src.api.endpoints.annotate.dtos.record_type.post import RecordTypeAnnotationPostInfo
-from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseOuterInfo
-from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseOuterInfo
-from src.api.endpoints.annotate.relevance.post.dto import RelevanceAnnotationPostInfo
+from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse
+from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo
 from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse
 from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse
 from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary
@@ -32,14 +26,17 @@
 from src.api.endpoints.review.next.dto import GetNextURLForFinalReviewOuterResponse
 from src.api.endpoints.review.reject.dto import FinalReviewRejectionInfo
 from src.api.endpoints.search.dtos.response import SearchURLResponse
+from src.api.endpoints.submit.url.models.request import URLSubmissionRequest
+from src.api.endpoints.submit.url.models.response import URLSubmissionResponse
 from src.api.endpoints.task.by_id.dto import TaskInfo
-from src.api.endpoints.task.dtos.get.tasks import GetTasksResponse
 from src.api.endpoints.task.dtos.get.task_status import GetTaskStatusResponseInfo
+from src.api.endpoints.task.dtos.get.tasks import GetTasksResponse
 from src.api.endpoints.url.get.dto import GetURLsResponseInfo
-from src.db.enums import TaskType
-from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO
 from src.collectors.enums import CollectorType
+from src.collectors.impl.example.dtos.input import ExampleInputDTO
 from src.core.enums import BatchStatus
+from src.db.enums import TaskType
+from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum
 from src.util.helper_functions import update_if_not_none
 
 
@@ -192,9 +189,8 @@ def delete(
 
     def get_batch_statuses(
         self,
-        collector_type: Optional[CollectorType] = None,
-        status: Optional[BatchStatus] = None,
-        has_pending_urls: Optional[bool] = None
+        collector_type: CollectorType | None = None,
+        status: BatchURLStatusEnum | None = None,
     ) -> GetBatchSummariesResponse:
         params = {}
         update_if_not_none(
@@ -202,7 +198,6 @@ def get_batch_statuses(
             source={
                 "collector_type": collector_type.value if collector_type else None,
                 "status": status.value if status else None,
-                "has_pending_urls": has_pending_urls
             }
         )
         data = self.get(
@@ -250,57 +245,6 @@ def abort_batch(self, batch_id: int) -> MessageResponse:
         )
         return MessageResponse(**data)
 
-    def get_next_relevance_annotation(self) -> GetNextRelevanceAnnotationResponseOuterInfo:
-        data = self.get(
-            url=f"/annotate/relevance"
-        )
-        return GetNextRelevanceAnnotationResponseOuterInfo(**data)
-
-    def get_next_record_type_annotation(self) -> GetNextRecordTypeAnnotationResponseOuterInfo:
-        data = self.get(
-            url=f"/annotate/record-type"
-        )
-        return GetNextRecordTypeAnnotationResponseOuterInfo(**data)
-
-    def post_record_type_annotation_and_get_next(
-        self,
-        url_id: int,
-        record_type_annotation_post_info: RecordTypeAnnotationPostInfo
-    ) -> GetNextRecordTypeAnnotationResponseOuterInfo:
-        data = self.post_v2(
-            url=f"/annotate/record-type/{url_id}",
-            json=record_type_annotation_post_info.model_dump(mode='json')
-        )
-        return GetNextRecordTypeAnnotationResponseOuterInfo(**data)
-
-    def post_relevance_annotation_and_get_next(
-        self,
-        url_id: int,
-        relevance_annotation_post_info: RelevanceAnnotationPostInfo
-    ) -> GetNextRelevanceAnnotationResponseOuterInfo:
-        data = self.post_v2(
-            url=f"/annotate/relevance/{url_id}",
-            json=relevance_annotation_post_info.model_dump(mode='json')
-        )
-        return GetNextRelevanceAnnotationResponseOuterInfo(**data)
-
-    async def get_next_agency_annotation(self) -> GetNextURLForAgencyAnnotationResponse:
-        data = self.get(
-            url=f"/annotate/agency"
-        )
-        return GetNextURLForAgencyAnnotationResponse(**data)
-
-    async def post_agency_annotation_and_get_next(
-        self,
-        url_id: int,
-        agency_annotation_post_info: URLAgencyAnnotationPostInfo
-    ) -> GetNextURLForAgencyAnnotationResponse:
-        data = self.post(
-            url=f"/annotate/agency/{url_id}",
-            json=agency_annotation_post_info.model_dump(mode='json')
-        )
-        return GetNextURLForAgencyAnnotationResponse(**data)
-
     def get_urls(self, page: int = 1, errors: bool = False) -> GetURLsResponseInfo:
         data = self.get(
             url=f"/url",
@@ -373,12 +317,16 @@ async def get_current_task_status(self) -> GetTaskStatusResponseInfo:
 
     async def get_next_url_for_all_annotations(
         self,
-        batch_id: Optional[int] = None
+        batch_id: int | None = None,
+        anno_url_id: int | None = None
    ) -> GetNextURLForAllAnnotationResponse:
         params = {}
         update_if_not_none(
             target=params,
-            source={"batch_id": batch_id}
+            source={
+                "batch_id": batch_id,
+                "anno_url_id": anno_url_id
+            }
         )
         data = self.get(
             url=f"/annotate/all",
@@ -390,12 +338,16 @@ async def post_all_annotations_and_get_next(
         self,
         url_id: int,
         all_annotations_post_info: AllAnnotationPostInfo,
-        batch_id: Optional[int] = None,
+        batch_id: int | None = None,
+        anno_url_id: int | None = None
     ) -> GetNextURLForAllAnnotationResponse:
         params = {}
         update_if_not_none(
             target=params,
-            source={"batch_id": batch_id}
+            source={
+                "batch_id": batch_id,
+                "anno_url_id": anno_url_id
+            }
         )
         data = self.post(
             url=f"/annotate/all/{url_id}",
@@ -462,4 +414,20 @@ async def get_urls_aggregated_pending_metrics(self) -> GetMetricsURLsAggregatedP
         data = self.get_v2(
             url="/metrics/urls/aggregate/pending",
         )
-        return GetMetricsURLsAggregatedPendingResponseDTO(**data)
\ No newline at end of file
+        return GetMetricsURLsAggregatedPendingResponseDTO(**data)
+
+    async def get_url_screenshot(self, url_id: int) -> Response:
+        return self.client.get(
self.client.get(
+            url=f"/url/{url_id}/screenshot",
+            headers={"Authorization": "Bearer token"}
+        )
+
+    async def submit_url(
+        self,
+        request: URLSubmissionRequest
+    ) -> URLSubmissionResponse:
+        response: dict = self.post_v2(
+            url="/submit/url",
+            json=request.model_dump(mode='json')
+        )
+        return URLSubmissionResponse(**response)
\ No newline at end of file
diff --git a/tests/automated/integration/api/annotate/__init__.py b/tests/automated/integration/api/annotate/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/integration/api/annotate/all/__init__.py b/tests/automated/integration/api/annotate/all/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/integration/api/annotate/all/test_happy_path.py b/tests/automated/integration/api/annotate/all/test_happy_path.py
new file mode 100644
index 00000000..48b60b8b
--- /dev/null
+++ b/tests/automated/integration/api/annotate/all/test_happy_path.py
@@ -0,0 +1,168 @@
+import pytest
+
+from src.api.endpoints.annotate.all.get.models.location import LocationAnnotationUserSuggestion
+from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse
+from src.api.endpoints.annotate.all.get.queries.core import GetNextURLForAllAnnotationQueryBuilder
+from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo
+from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo
+from src.api.endpoints.annotate.all.post.models.name import AnnotationPostNameInfo
+from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo
+from src.core.enums import RecordType
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion
+from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion
+from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion
+from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion
+from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion
+from src.db.models.impl.url.suggestion.relevant.user import UserURLTypeSuggestion
+from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo
+from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review
+
+
+@pytest.mark.asyncio
+async def test_annotate_all(
+    api_test_helper,
+    pennsylvania: USStateCreationInfo,
+    california: USStateCreationInfo,
+):
+    """
+    Test the happy path workflow for the all-annotations endpoint.
+    The user should be able to get a valid URL (filtering on batch id if needed),
+    submit a full annotation, and receive another URL.
+    """
+    ath = api_test_helper
+    adb_client = ath.adb_client()
+
+    # Set up URLs
+    setup_info_1 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=True
+    )
+    url_mapping_1 = setup_info_1.url_mapping
+    setup_info_2 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=True
+    )
+    url_mapping_2 = setup_info_2.url_mapping
+
+    # Get a valid URL to annotate
+    get_response_1 = await ath.request_validator.get_next_url_for_all_annotations()
+    assert get_response_1.next_annotation is not None
+    assert len(get_response_1.next_annotation.name_suggestions) == 1
+    name_suggestion = get_response_1.next_annotation.name_suggestions[0]
+    assert name_suggestion.name is not None
+    assert name_suggestion.endorsement_count == 0
+
+    # Apply the second batch id as a filter and see that a different URL is returned
+    get_response_2 = await ath.request_validator.get_next_url_for_all_annotations(
+        batch_id=setup_info_2.batch_id
+    )
+
+    assert get_response_2.next_annotation is not None
+    assert get_response_1.next_annotation.url_info.url_id != get_response_2.next_annotation.url_info.url_id
+
+    # Annotate the first and submit
+    agency_id = await ath.db_data_creator.agency()
+    post_response_1 = await ath.request_validator.post_all_annotations_and_get_next(
+        url_id=url_mapping_1.url_id,
+        all_annotations_post_info=AllAnnotationPostInfo(
+            suggested_status=URLType.DATA_SOURCE,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_info=AnnotationPostAgencyInfo(agency_ids=[agency_id]),
+            location_info=AnnotationPostLocationInfo(
+                location_ids=[
+                    california.location_id,
+                    pennsylvania.location_id,
+                ]
+            ),
+            name_info=AnnotationPostNameInfo(
+                new_name="New Name"
+            )
+        )
+    )
+    assert post_response_1.next_annotation is not None
+
+    # Confirm the second is received
+    assert post_response_1.next_annotation.url_info.url_id == url_mapping_2.url_id
+
+    # Upon submitting the second, confirm that no more URLs are returned through either POST or GET
+    post_response_2 = await ath.request_validator.post_all_annotations_and_get_next(
+        url_id=url_mapping_2.url_id,
+        all_annotations_post_info=AllAnnotationPostInfo(
+            suggested_status=URLType.NOT_RELEVANT,
+            location_info=AnnotationPostLocationInfo(),
+            agency_info=AnnotationPostAgencyInfo(),
+            name_info=AnnotationPostNameInfo(
+                existing_name_id=setup_info_2.name_suggestion_id
+            )
+        )
+    )
+    assert post_response_2.next_annotation is None
+
+    get_response_3 = await ath.request_validator.get_next_url_for_all_annotations()
+    assert get_response_3.next_annotation is None
+
+
+    # Check that all annotations are present in the database
+
+    # Check URL Type Suggestions
+    all_relevance_suggestions: list[UserURLTypeSuggestion] = await adb_client.get_all(UserURLTypeSuggestion)
+    assert len(all_relevance_suggestions) == 4
+    suggested_types: set[URLType] = {sugg.type for sugg in all_relevance_suggestions}
+    assert suggested_types == {URLType.DATA_SOURCE, URLType.NOT_RELEVANT}
+
+    # Should be three agency suggestions, including the one just submitted
+    all_agency_suggestions = await adb_client.get_all(UserUrlAgencySuggestion)
+    assert len(all_agency_suggestions) == 3
+    suggested_agency_ids: set[int] = {sugg.agency_id for sugg in all_agency_suggestions}
+    assert agency_id in suggested_agency_ids
+
+    # Should be three record type suggestions, including the one just submitted
+    all_record_type_suggestions = await adb_client.get_all(UserRecordTypeSuggestion)
+    assert len(all_record_type_suggestions) == 3
+    suggested_record_types: set[RecordType] = {
+        sugg.record_type for sugg in all_record_type_suggestions
+    }
+    assert RecordType.ACCIDENT_REPORTS.value in suggested_record_types
+
+    # Confirm 2 location suggestions, one for California and one for Pennsylvania
+    all_location_suggestions = await adb_client.get_all(UserLocationSuggestion)
+    assert len(all_location_suggestions) == 2
+    location_ids: list[int] = [location_suggestion.location_id for location_suggestion in all_location_suggestions]
+    assert set(location_ids) == {california.location_id, pennsylvania.location_id}
+    # Confirm that all location suggestions are for the correct URL
+    for location_suggestion in all_location_suggestions:
+        assert location_suggestion.url_id == url_mapping_1.url_id
+
+ # Retrieve the same URL (directly from the database, leveraging a different User) + # And confirm the presence of the user annotations + response: GetNextURLForAllAnnotationResponse = await adb_client.run_query_builder( + GetNextURLForAllAnnotationQueryBuilder( + batch_id=None, + user_id=99, + ) + ) + user_suggestions: list[LocationAnnotationUserSuggestion] = \ + response.next_annotation.location_suggestions.user.suggestions + assert len(user_suggestions) == 2 + + response_location_ids: list[int] = [location_suggestion.location_id for location_suggestion in user_suggestions] + assert set(response_location_ids) == {california.location_id, pennsylvania.location_id} + + response_location_names: list[str] = [location_suggestion.location_name for location_suggestion in user_suggestions] + assert set(response_location_names) == { + "California", + "Pennsylvania" + } + + for user_suggestion in user_suggestions: + assert user_suggestion.user_count == 1 + + # Confirm 3 name suggestions + name_suggestions: list[URLNameSuggestion] = await adb_client.get_all(URLNameSuggestion) + assert len(name_suggestions) == 3 + suggested_names: set[str] = {name_suggestion.suggestion for name_suggestion in name_suggestions} + assert "New Name" in suggested_names + + # Confirm 2 link user name suggestions + link_user_name_suggestions: list[LinkUserNameSuggestion] = await adb_client.get_all(LinkUserNameSuggestion) + assert len(link_user_name_suggestions) == 2 + diff --git a/tests/automated/integration/api/annotate/all/test_not_found.py b/tests/automated/integration/api/annotate/all/test_not_found.py new file mode 100644 index 00000000..251b4c0e --- /dev/null +++ b/tests/automated/integration/api/annotate/all/test_not_found.py @@ -0,0 +1,48 @@ +import pytest + +from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo +from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo +from src.api.endpoints.annotate.all.post.models.name import AnnotationPostNameInfo +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.core.enums import RecordType +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.link.user_suggestion_not_found.agency.sqlalchemy import LinkUserSuggestionAgencyNotFound +from src.db.models.impl.link.user_suggestion_not_found.location.sqlalchemy import LinkUserSuggestionLocationNotFound +from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review + + +@pytest.mark.asyncio +async def test_not_found( + api_test_helper, +): + """ + Test that marking a URL as agency or location not found works. 
+ """ + ath = api_test_helper + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=True + ) + + post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( + url_id=setup_info_1.url_mapping.url_id, + all_annotations_post_info=AllAnnotationPostInfo( + suggested_status=URLType.DATA_SOURCE, + record_type=RecordType.ACCIDENT_REPORTS, + agency_info=AnnotationPostAgencyInfo(not_found=True), + location_info=AnnotationPostLocationInfo( + not_found=True, + ), + name_info=AnnotationPostNameInfo( + new_name="New Name" + ) + ) + ) + + adb_client: AsyncDatabaseClient = ath.adb_client() + + not_found_agencies: list[LinkUserSuggestionAgencyNotFound] = await adb_client.get_all(LinkUserSuggestionAgencyNotFound) + assert len(not_found_agencies) == 1 + + not_found_locations: list[LinkUserSuggestionLocationNotFound] = await adb_client.get_all(LinkUserSuggestionLocationNotFound) + assert len(not_found_locations) == 1 \ No newline at end of file diff --git a/tests/automated/integration/api/annotate/all/test_post_batch_filtering.py b/tests/automated/integration/api/annotate/all/test_post_batch_filtering.py new file mode 100644 index 00000000..a770329d --- /dev/null +++ b/tests/automated/integration/api/annotate/all/test_post_batch_filtering.py @@ -0,0 +1,40 @@ +import pytest + +from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo +from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.db.models.impl.flag.url_validated.enums import URLType +from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review + + +@pytest.mark.asyncio +async def test_annotate_all_post_batch_filtering(api_test_helper): + """ + Batch filtering should also work when posting annotations + """ + ath = api_test_helper + adb_client = ath.adb_client() + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_1 = setup_info_1.url_mapping + setup_info_2 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + setup_info_3 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_3 = setup_info_3.url_mapping + + # Submit the first annotation, using the third batch id, and receive the third URL + post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_1.url_id, + batch_id=setup_info_3.batch_id, + all_annotations_post_info=AllAnnotationPostInfo( + suggested_status=URLType.NOT_RELEVANT, + location_info=AnnotationPostLocationInfo(), + agency_info=AnnotationPostAgencyInfo() + ) + ) + + assert post_response_1.next_annotation.url_info.url_id == url_mapping_3.url_id diff --git a/tests/automated/integration/api/annotate/all/test_suspended_url.py b/tests/automated/integration/api/annotate/all/test_suspended_url.py new file mode 100644 index 00000000..3eed8699 --- /dev/null +++ b/tests/automated/integration/api/annotate/all/test_suspended_url.py @@ -0,0 +1,29 @@ +import pytest + +from src.db.models.impl.flag.url_suspended.sqlalchemy import FlagURLSuspended +from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review + + +@pytest.mark.asyncio +async def 
test_annotate_all_suspended_url(
+    api_test_helper,
+):
+    """
+    Test that a suspended URL is not returned for annotation.
+    """
+    ath = api_test_helper
+    setup_info_1 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=True
+    )
+
+    get_response_1 = await ath.request_validator.get_next_url_for_all_annotations()
+    assert get_response_1.next_annotation is not None
+
+    adb_client = ath.adb_client()
+    await adb_client.add(
+        FlagURLSuspended(
+            url_id=setup_info_1.url_mapping.url_id,
+        )
+    )
+    get_response_2 = await ath.request_validator.get_next_url_for_all_annotations()
+    assert get_response_2.next_annotation is None
\ No newline at end of file
diff --git a/tests/automated/integration/api/annotate/all/test_url_filtering.py b/tests/automated/integration/api/annotate/all/test_url_filtering.py
new file mode 100644
index 00000000..6ca36cb5
--- /dev/null
+++ b/tests/automated/integration/api/annotate/all/test_url_filtering.py
@@ -0,0 +1,44 @@
+import pytest
+
+from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.helpers.api_test_helper import APITestHelper
+from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review
+
+
+@pytest.mark.asyncio
+async def test_annotate_all_url_filtering(api_test_helper: APITestHelper):
+    """
+    Test that URL filtering works when getting and posting annotations.
+    """
+    ath = api_test_helper
+    adb_client: AsyncDatabaseClient = ath.adb_client()
+
+    setup_info_1 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=False
+    )
+    url_mapping_1 = setup_info_1.url_mapping
+    setup_info_2 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=False
+    )
+    setup_info_3 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=False
+    )
+    url_mapping_3 = setup_info_3.url_mapping
+
+    get_response_2 = await ath.request_validator.get_next_url_for_all_annotations(
+        batch_id=setup_info_3.batch_id,
+        anno_url_id=url_mapping_3.url_id
+    )
+    assert get_response_2.next_annotation.url_info.url_id == url_mapping_3.url_id
+
+    post_response_3 = await ath.request_validator.post_all_annotations_and_get_next(
+        url_id=url_mapping_1.url_id,
+        anno_url_id=url_mapping_3.url_id,
+        all_annotations_post_info=AllAnnotationPostInfo(
+            suggested_status=URLType.NOT_RELEVANT,
+        )
+    )
+
+    assert post_response_3.next_annotation.url_info.url_id == url_mapping_3.url_id
\ No newline at end of file
diff --git a/tests/automated/integration/api/annotate/all/test_validation_error.py b/tests/automated/integration/api/annotate/all/test_validation_error.py
new file mode 100644
index 00000000..db9e336a
--- /dev/null
+++ b/tests/automated/integration/api/annotate/all/test_validation_error.py
@@ -0,0 +1,32 @@
+import pytest
+
+from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo
+from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo
+from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo
+from src.core.enums import RecordType
+from src.core.exceptions import FailedValidationException
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.helpers.setup.final_review.core import 
setup_for_get_next_url_for_final_review + + +@pytest.mark.asyncio +async def test_annotate_all_validation_error(api_test_helper): + """ + Validation errors in the PostInfo DTO should result in a 400 BAD REQUEST response + """ + ath = api_test_helper + setup_info_1 = await setup_for_get_next_url_for_final_review( + db_data_creator=ath.db_data_creator, include_user_annotations=False + ) + url_mapping_1 = setup_info_1.url_mapping + + with pytest.raises(FailedValidationException) as e: + response = await ath.request_validator.post_all_annotations_and_get_next( + url_id=url_mapping_1.url_id, + all_annotations_post_info=AllAnnotationPostInfo( + suggested_status=URLType.NOT_RELEVANT, + record_type=RecordType.ACCIDENT_REPORTS, + location_info=AnnotationPostLocationInfo(), + agency_info=AnnotationPostAgencyInfo() + ) + ) diff --git a/tests/automated/integration/api/annotate/anonymous/__init__.py b/tests/automated/integration/api/annotate/anonymous/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/annotate/anonymous/helper.py b/tests/automated/integration/api/annotate/anonymous/helper.py new file mode 100644 index 00000000..ccfe518f --- /dev/null +++ b/tests/automated/integration/api/annotate/anonymous/helper.py @@ -0,0 +1,23 @@ +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from tests.automated.integration.api._helpers.RequestValidator import RequestValidator + + +async def get_next_url_for_anonymous_annotation( + request_validator: RequestValidator, +): + data = request_validator.get( + url=f"/annotate/anonymous" + ) + return GetNextURLForAllAnnotationResponse(**data) + +async def post_and_get_next_url_for_anonymous_annotation( + request_validator: RequestValidator, + url_id: int, + all_annotation_post_info: AllAnnotationPostInfo, +): + data = request_validator.post( + url=f"/annotate/anonymous/{url_id}", + json=all_annotation_post_info.model_dump(mode='json') + ) + return GetNextURLForAllAnnotationResponse(**data) \ No newline at end of file diff --git a/tests/automated/integration/api/annotate/anonymous/test_core.py b/tests/automated/integration/api/annotate/anonymous/test_core.py new file mode 100644 index 00000000..4b747363 --- /dev/null +++ b/tests/automated/integration/api/annotate/anonymous/test_core.py @@ -0,0 +1,83 @@ +import pytest + +from src.api.endpoints.annotate.all.get.models.name import NameAnnotationSuggestion +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse +from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo +from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo +from src.api.endpoints.annotate.all.post.models.name import AnnotationPostNameInfo +from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo +from src.core.enums import RecordType +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.url.suggestion.anonymous.agency.sqlalchemy import AnonymousAnnotationAgency +from src.db.models.impl.url.suggestion.anonymous.location.sqlalchemy import AnonymousAnnotationLocation +from src.db.models.impl.url.suggestion.anonymous.record_type.sqlalchemy import AnonymousAnnotationRecordType +from src.db.models.impl.url.suggestion.anonymous.url_type.sqlalchemy import 
AnonymousAnnotationURLType
+from src.db.models.mixins import URLDependentMixin
+from tests.automated.integration.api.annotate.anonymous.helper import get_next_url_for_anonymous_annotation, \
+    post_and_get_next_url_for_anonymous_annotation
+from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo
+from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review
+from tests.helpers.setup.final_review.model import FinalReviewSetupInfo
+
+
+@pytest.mark.asyncio
+async def test_annotate_anonymous(
+    api_test_helper,
+    pennsylvania: USStateCreationInfo,
+):
+    ath = api_test_helper
+    ddc = ath.db_data_creator
+    rv = ath.request_validator
+
+    # Set up URLs
+    setup_info_1 = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=True
+    )
+    url_mapping_1: URLMapping = setup_info_1.url_mapping
+    setup_info_2: FinalReviewSetupInfo = await setup_for_get_next_url_for_final_review(
+        db_data_creator=ath.db_data_creator, include_user_annotations=True
+    )
+    url_mapping_2: URLMapping = setup_info_2.url_mapping
+
+    get_response_1: GetNextURLForAllAnnotationResponse = await get_next_url_for_anonymous_annotation(rv)
+    assert get_response_1.next_annotation is not None
+    assert len(get_response_1.next_annotation.name_suggestions) == 1
+    name_suggestion: NameAnnotationSuggestion = get_response_1.next_annotation.name_suggestions[0]
+    assert name_suggestion.name is not None
+    assert name_suggestion.endorsement_count == 0
+
+    agency_id: int = await ddc.agency()
+
+    post_response_1: GetNextURLForAllAnnotationResponse = await post_and_get_next_url_for_anonymous_annotation(
+        rv,
+        get_response_1.next_annotation.url_info.url_id,
+        AllAnnotationPostInfo(
+            suggested_status=URLType.DATA_SOURCE,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_info=AnnotationPostAgencyInfo(agency_ids=[agency_id]),
+            location_info=AnnotationPostLocationInfo(
+                location_ids=[
+                    pennsylvania.location_id,
+                ]
+            ),
+            name_info=AnnotationPostNameInfo(
+                new_name="New Name"
+            )
+        )
+    )
+
+    assert post_response_1.next_annotation is not None
+    assert post_response_1.next_annotation.url_info.url_id != get_response_1.next_annotation.url_info.url_id
+
+    for model in [
+        AnonymousAnnotationAgency,
+        AnonymousAnnotationLocation,
+        AnonymousAnnotationRecordType,
+        AnonymousAnnotationURLType
+    ]:
+        instances: list[URLDependentMixin] = await ddc.adb_client.get_all(model)
+        assert len(instances) == 1
+        instance: model = instances[0]
+        assert instance.url_id == get_response_1.next_annotation.url_info.url_id
+
diff --git a/tests/automated/integration/api/annotate/helpers.py b/tests/automated/integration/api/annotate/helpers.py
new file mode 100644
index 00000000..39cfedab
--- /dev/null
+++ b/tests/automated/integration/api/annotate/helpers.py
@@ -0,0 +1,22 @@
+from src.core.tasks.url.operators.html.scraper.parser.dtos.response_html import ResponseHTMLInfo
+from src.db.dtos.url.mapping import URLMapping
+
+
+def check_url_mappings_match(
+    map_1: URLMapping,
+    map_2: URLMapping
+):
+    assert map_1.url_id == map_2.url_id
+    assert map_1.url == map_2.url
+
+
+def check_html_info_not_empty(
+    html_info: ResponseHTMLInfo
+):
+    assert not html_info_empty(html_info)
+
+
+def html_info_empty(
+    html_info: ResponseHTMLInfo
+) -> bool:
+    return html_info.description == "" and html_info.title == ""
diff --git a/tests/automated/integration/api/annotate/test_.py b/tests/automated/integration/api/annotate/test_.py
new file mode 100644
index 00000000..e69de29b
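The RequestValidator methods earlier in this diff repeatedly build query parameters with update_if_not_none(target=..., source=...) from src.util.helper_functions. The helper itself is not shown in the patch; the following is a minimal sketch of the behavior its call sites appear to rely on, and the real implementation may differ:

    def update_if_not_none(target: dict, source: dict) -> None:
        # Hypothetical reconstruction: copy only non-None entries into the
        # target dict, so optional filters such as batch_id and anno_url_id
        # are omitted from the query string rather than sent as nulls.
        for key, value in source.items():
            if value is not None:
                target[key] = value

Under that assumption, a call like get_next_url_for_all_annotations(batch_id=None, anno_url_id=5) would send only anno_url_id=5 as a query parameter.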
diff --git a/tests/automated/integration/api/batch/__init__.py b/tests/automated/integration/api/batch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/batch/summaries/__init__.py b/tests/automated/integration/api/batch/summaries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/batch/summaries/test_happy_path.py b/tests/automated/integration/api/batch/summaries/test_happy_path.py new file mode 100644 index 00000000..f6e28238 --- /dev/null +++ b/tests/automated/integration/api/batch/summaries/test_happy_path.py @@ -0,0 +1,96 @@ +import pytest + +from src.core.enums import BatchStatus +from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum +from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters + + +@pytest.mark.asyncio +async def test_get_batch_summaries(api_test_helper): + ath = api_test_helper + + batch_params = [ + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=1, + status=URLCreationEnum.OK + ), + TestURLCreationParameters( + count=2, + status=URLCreationEnum.SUBMITTED + ) + ] + ), + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=4, + status=URLCreationEnum.NOT_RELEVANT + ), + TestURLCreationParameters( + count=3, + status=URLCreationEnum.ERROR + ) + ] + ), + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=7, + status=URLCreationEnum.DUPLICATE + ), + TestURLCreationParameters( + count=1, + status=URLCreationEnum.SUBMITTED + ) + ] + ) + ] + + batch_1_creation_info = await ath.db_data_creator.batch_v2(batch_params[0]) + batch_2_creation_info = await ath.db_data_creator.batch_v2(batch_params[1]) + batch_3_creation_info = await ath.db_data_creator.batch_v2(batch_params[2]) + + batch_1_id = batch_1_creation_info.batch_id + batch_2_id = batch_2_creation_info.batch_id + batch_3_id = batch_3_creation_info.batch_id + + await ath.adb_client().refresh_materialized_views() + + response = ath.request_validator.get_batch_statuses() + results = response.results + + assert len(results) == 3 + + result_1 = results[0] + assert result_1.id == batch_1_id + assert result_1.status == BatchStatus.READY_TO_LABEL + counts_1 = result_1.url_counts + assert counts_1.total == 3 + assert counts_1.pending == 1 + assert counts_1.submitted == 2 + assert counts_1.not_relevant == 0 + assert counts_1.duplicate == 0 + assert counts_1.errored == 0 + + result_2 = results[1] + assert result_2.id == batch_2_id + counts_2 = result_2.url_counts + assert counts_2.total == 7 + assert counts_2.not_relevant == 4 + assert counts_2.errored == 3 + assert counts_2.pending == 3 + assert counts_2.submitted == 0 + assert counts_2.duplicate == 0 + + result_3 = results[2] + assert result_3.id == batch_3_id + counts_3 = result_3.url_counts + assert counts_3.total == 8 + assert counts_3.not_relevant == 0 + assert counts_3.errored == 0 + assert counts_3.pending == 7 + assert counts_3.submitted == 1 + assert counts_3.duplicate == 7 diff --git a/tests/automated/integration/api/batch/summaries/test_pending_url_filter.py b/tests/automated/integration/api/batch/summaries/test_pending_url_filter.py new file mode 100644 index 00000000..c471b6fa --- /dev/null +++ b/tests/automated/integration/api/batch/summaries/test_pending_url_filter.py @@ -0,0 +1,59 @@ +import pytest + +from src.collectors.enums import 
CollectorType
+from src.core.enums import BatchStatus
+from src.db.dtos.url.mapping import URLMapping
+from src.db.models.views.batch_url_status.enums import BatchURLStatusEnum
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.data_creator.core import DBDataCreator
+
+
+@pytest.mark.asyncio
+async def test_get_batch_summaries_pending_url_filter(api_test_helper):
+    ath = api_test_helper
+    dbdc: DBDataCreator = ath.db_data_creator
+
+    # Add an errored-out batch
+    batch_error: int = await dbdc.create_batch(status=BatchStatus.ERROR)
+
+    # Add a batch with pending urls
+    batch_pending = await ath.db_data_creator.batch_and_urls(
+        strategy=CollectorType.EXAMPLE,
+        url_count=2,
+        batch_status=BatchStatus.READY_TO_LABEL,
+        with_html_content=True,
+        url_status=URLCreationEnum.OK
+    )
+
+    # Add a batch with submitted URLs
+    batch_submitted: int = await dbdc.create_batch(status=BatchStatus.READY_TO_LABEL)
+    submitted_url_mappings: list[URLMapping] = await dbdc.create_submitted_urls(count=2)
+    submitted_url_ids: list[int] = [url_mapping.url_id for url_mapping in submitted_url_mappings]
+    await dbdc.create_batch_url_links(
+        batch_id=batch_submitted,
+        url_ids=submitted_url_ids
+    )
+
+    # Add an aborted batch
+    batch_aborted: int = await dbdc.create_batch(status=BatchStatus.ABORTED)
+
+    # Add a batch with validated URLs
+    batch_validated: int = await dbdc.create_batch(status=BatchStatus.READY_TO_LABEL)
+    validated_url_mappings: list[URLMapping] = await dbdc.create_validated_urls(
+        count=2
+    )
+    validated_url_ids: list[int] = [url_mapping.url_id for url_mapping in validated_url_mappings]
+    await dbdc.create_batch_url_links(
+        batch_id=batch_validated,
+        url_ids=validated_url_ids
+    )
+
+    await dbdc.adb_client.refresh_materialized_views()
+
+    # Filter for batches with pending URLs and confirm that only the second batch is returned
+    pending_urls_results = ath.request_validator.get_batch_statuses(
+        status=BatchURLStatusEnum.HAS_UNLABELED_URLS
+    )
+
+    assert len(pending_urls_results.results) == 1
+    assert pending_urls_results.results[0].id == batch_pending.batch_id
diff --git a/tests/automated/integration/api/batch/test_batch.py b/tests/automated/integration/api/batch/test_batch.py
new file mode 100644
index 00000000..f1e3d4f2
--- /dev/null
+++ b/tests/automated/integration/api/batch/test_batch.py
@@ -0,0 +1,47 @@
+from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary
+from src.db.models.impl.batch.pydantic.info import BatchInfo
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.collectors.impl.example.dtos.input import ExampleInputDTO
+from src.core.enums import BatchStatus
+
+def test_get_batch_urls(api_test_helper):
+
+    # Insert batch and urls into database
+    ath = api_test_helper
+    batch_id = ath.db_data_creator.batch()
+    iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=101)
+
+    response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=1)
+    assert len(response.urls) == 100
+    # Check that the first url corresponds to the first url inserted
+    assert response.urls[0].url == iui.url_mappings[0].url
+    # Check that the last url corresponds to the 100th url inserted
+    assert response.urls[-1].url == iui.url_mappings[99].url
+
+
+    # Check that the second page contains only the one remaining url
+    response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=2)
+    assert len(response.urls) == 1
+    # Check that this url corresponds to the last url inserted
+    assert response.urls[0].url == iui.url_mappings[-1].url
+
+def 
test_get_duplicate_urls(api_test_helper): + + # Insert batch and url into database + ath = api_test_helper + batch_id = ath.db_data_creator.batch() + iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=101) + # Get a list of all url ids + url_ids = [url.url_id for url in iui.url_mappings] + + # Create a second batch which will be associated with the duplicates + dup_batch_id = ath.db_data_creator.batch() + + # Insert duplicate urls into database + ath.db_data_creator.duplicate_urls(duplicate_batch_id=dup_batch_id, url_ids=url_ids) + + response = ath.request_validator.get_batch_url_duplicates(batch_id=dup_batch_id, page=1) + assert len(response.duplicates) == 100 + + response = ath.request_validator.get_batch_url_duplicates(batch_id=dup_batch_id, page=2) + assert len(response.duplicates) == 1 \ No newline at end of file diff --git a/tests/automated/integration/api/conftest.py b/tests/automated/integration/api/conftest.py index d07e92d5..fa019469 100644 --- a/tests/automated/integration/api/conftest.py +++ b/tests/automated/integration/api/conftest.py @@ -5,12 +5,11 @@ import pytest_asyncio from starlette.testclient import TestClient -from src.api.endpoints.review.routes import requires_final_review_permission from src.api.main import app from src.core.core import AsyncCore -from src.security.manager import get_access_info from src.security.dtos.access_info import AccessInfo from src.security.enums import Permissions +from src.security.manager import get_access_info from tests.automated.integration.api._helpers.RequestValidator import RequestValidator from tests.helpers.api_test_helper import APITestHelper @@ -36,12 +35,11 @@ def override_access_info() -> AccessInfo: ] ) + @pytest.fixture(scope="session") -def client() -> Generator[TestClient, None, None]: - # Mock environment +def client(disable_task_flags) -> Generator[TestClient, None, None]: with TestClient(app) as c: app.dependency_overrides[get_access_info] = override_access_info - app.dependency_overrides[requires_final_review_permission] = override_access_info async_core: AsyncCore = c.app.state.async_core # Interfaces to the web should be mocked diff --git a/tests/automated/integration/api/metrics/batches/test_aggregated.py b/tests/automated/integration/api/metrics/batches/test_aggregated.py index 084762b9..090896e8 100644 --- a/tests/automated/integration/api/metrics/batches/test_aggregated.py +++ b/tests/automated/integration/api/metrics/batches/test_aggregated.py @@ -2,44 +2,65 @@ from src.collectors.enums import CollectorType, URLStatus from src.core.enums import BatchStatus +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.helpers.connect import get_postgres_connection_string +from src.db.models.impl.flag.url_validated.enums import URLType from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.create import create_batch, create_url_data_sources, create_urls, \ + create_batch_url_links, create_validated_flags +from tests.helpers.setup.wipe import wipe_database @pytest.mark.asyncio -async def test_get_batches_aggregated_metrics(api_test_helper): +async def test_get_batches_aggregated_metrics( + api_test_helper, + wiped_database +): ath = api_test_helper + adb_client: AsyncDatabaseClient = ath.adb_client() # Create successful batches with URLs of different statuses - all_params = 
[] for i in range(3): - params = TestBatchCreationParameters( + batch_id = await create_batch( + adb_client=adb_client, strategy=CollectorType.MANUAL, - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING - ), - TestURLCreationParameters( - count=2, - status=URLStatus.SUBMITTED - ), - TestURLCreationParameters( - count=3, - status=URLStatus.NOT_RELEVANT - ), - TestURLCreationParameters( - count=4, - status=URLStatus.ERROR - ), - TestURLCreationParameters( - count=5, - status=URLStatus.VALIDATED - ) - ] ) - all_params.append(params) - + url_mappings_error: list[URLMapping] = await create_urls( + adb_client=adb_client, + status=URLStatus.ERROR, + count=4, + ) + url_mappings_ok: list[URLMapping] = await create_urls( + adb_client=adb_client, + status=URLStatus.OK, + count=11, + ) + url_mappings_all: list[URLMapping] = url_mappings_error + url_mappings_ok + url_ids_all: list[int] = [url_mapping.url_id for url_mapping in url_mappings_all] + await create_batch_url_links( + adb_client=adb_client, + batch_id=batch_id, + url_ids=url_ids_all, + ) + urls_submitted: list[int] = url_ids_all[:2] + urls_not_relevant: list[int] = url_ids_all[2:5] + urls_validated: list[int] = url_ids_all[5:10] + await create_validated_flags( + adb_client=adb_client, + url_ids=urls_validated + urls_submitted, + validation_type=URLType.DATA_SOURCE, + ) + await create_validated_flags( + adb_client=adb_client, + url_ids=urls_not_relevant, + validation_type=URLType.NOT_RELEVANT, + ) + await create_url_data_sources( + adb_client=adb_client, + url_ids=urls_submitted, + ) + all_params = [] # Create failed batches for i in range(2): params = TestBatchCreationParameters( @@ -66,8 +87,8 @@ async def test_get_batches_aggregated_metrics(api_test_helper): assert inner_dto_manual.count_urls == 45 assert inner_dto_manual.count_successful_batches == 3 assert inner_dto_manual.count_failed_batches == 0 - assert inner_dto_manual.count_urls_pending == 3 + assert inner_dto_manual.count_urls_pending == 15 assert inner_dto_manual.count_urls_submitted == 6 assert inner_dto_manual.count_urls_rejected == 9 assert inner_dto_manual.count_urls_errors == 12 - assert inner_dto_manual.count_urls_validated == 15 + assert inner_dto_manual.count_urls_validated == 30 diff --git a/tests/automated/integration/api/metrics/batches/test_breakdown.py b/tests/automated/integration/api/metrics/batches/test_breakdown.py index 0cce8740..c6ef6e0b 100644 --- a/tests/automated/integration/api/metrics/batches/test_breakdown.py +++ b/tests/automated/integration/api/metrics/batches/test_breakdown.py @@ -1,79 +1,102 @@ +from datetime import datetime, timedelta + import pendulum import pytest from src.collectors.enums import CollectorType, URLStatus from src.core.enums import BatchStatus -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.url_validated.enums import URLType +from tests.helpers.data_creator.create import create_batch, create_urls, create_batch_url_links, create_validated_flags, \ + create_url_data_sources @pytest.mark.asyncio async def test_get_batches_breakdown_metrics(api_test_helper): # Create a different batch for each month, with different URLs - today = pendulum.parse('2021-01-01') + today = datetime.now() ath = api_test_helper + adb_client: AsyncDatabaseClient = 
ath.adb_client() - batch_1_params = TestBatchCreationParameters( + batch_id_1 = await create_batch( + adb_client=adb_client, strategy=CollectorType.MANUAL, - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING - ), - TestURLCreationParameters( - count=2, - status=URLStatus.SUBMITTED - ), - ] ) - batch_1 = await ath.db_data_creator.batch_v2(batch_1_params) - batch_2_params = TestBatchCreationParameters( - strategy=CollectorType.EXAMPLE, - outcome=BatchStatus.ERROR, - created_at=today.subtract(weeks=1), + url_mappings_1: list[URLMapping] = await create_urls( + adb_client=adb_client, + count=3, + ) + url_ids_1: list[int] = [url_mapping.url_id for url_mapping in url_mappings_1] + await create_batch_url_links(adb_client=adb_client, batch_id=batch_id_1, url_ids=url_ids_1) + await create_validated_flags( + adb_client=adb_client, + url_ids=url_ids_1[:2], + validation_type=URLType.DATA_SOURCE + ) + await create_url_data_sources( + adb_client=adb_client, + url_ids=url_ids_1[:2], ) - batch_2 = await ath.db_data_creator.batch_v2(batch_2_params) - batch_3_params = TestBatchCreationParameters( + + batch_id_2 = await create_batch( + adb_client=adb_client, + status=BatchStatus.ERROR, + date_generated=today - timedelta(days=7), + ) + + batch_id_3 = await create_batch( + adb_client=adb_client, strategy=CollectorType.AUTO_GOOGLER, - created_at=today.subtract(weeks=2), - urls=[ - TestURLCreationParameters( - count=3, - status=URLStatus.NOT_RELEVANT - ), - TestURLCreationParameters( - count=4, - status=URLStatus.ERROR - ), - TestURLCreationParameters( - count=5, - status=URLStatus.VALIDATED - ), - ] + date_generated=today - timedelta(days=14) ) - batch_3 = await ath.db_data_creator.batch_v2(batch_3_params) + error_url_mappings: list[URLMapping] = await create_urls( + adb_client=adb_client, + status=URLStatus.ERROR, + count=4, + ) + error_url_ids: list[int] = [url_mapping.url_id for url_mapping in error_url_mappings] + validated_url_mappings: list[URLMapping] = await create_urls( + adb_client=adb_client, + count=8, + ) + validated_url_ids: list[int] = [url_mapping.url_id for url_mapping in validated_url_mappings] + await create_validated_flags( + adb_client=adb_client, + url_ids=validated_url_ids[:3], + validation_type=URLType.NOT_RELEVANT, + ) + await create_validated_flags( + adb_client=adb_client, + url_ids=validated_url_ids[4:9], + validation_type=URLType.DATA_SOURCE, + ) + await create_batch_url_links( + adb_client=adb_client, + batch_id=batch_id_3, + url_ids=error_url_ids + validated_url_ids, + ) + dto_1 = await ath.request_validator.get_batches_breakdown_metrics( page=1 ) assert len(dto_1.batches) == 3 dto_batch_1 = dto_1.batches[2] - assert dto_batch_1.batch_id == batch_1.batch_id + assert dto_batch_1.batch_id == batch_id_1 assert dto_batch_1.strategy == CollectorType.MANUAL assert dto_batch_1.status == BatchStatus.READY_TO_LABEL - assert pendulum.instance(dto_batch_1.created_at) > today assert dto_batch_1.count_url_total == 3 assert dto_batch_1.count_url_pending == 1 assert dto_batch_1.count_url_submitted == 2 assert dto_batch_1.count_url_rejected == 0 assert dto_batch_1.count_url_error == 0 - assert dto_batch_1.count_url_validated == 0 + assert dto_batch_1.count_url_validated == 2 dto_batch_2 = dto_1.batches[1] - assert dto_batch_2.batch_id == batch_2.batch_id + assert dto_batch_2.batch_id == batch_id_2 assert dto_batch_2.status == BatchStatus.ERROR assert dto_batch_2.strategy == CollectorType.EXAMPLE - assert pendulum.instance(dto_batch_2.created_at) == today.subtract(weeks=1) 
assert dto_batch_2.count_url_total == 0 assert dto_batch_2.count_url_submitted == 0 assert dto_batch_2.count_url_pending == 0 @@ -82,16 +105,15 @@ async def test_get_batches_breakdown_metrics(api_test_helper): assert dto_batch_2.count_url_validated == 0 dto_batch_3 = dto_1.batches[0] - assert dto_batch_3.batch_id == batch_3.batch_id + assert dto_batch_3.batch_id == batch_id_3 assert dto_batch_3.status == BatchStatus.READY_TO_LABEL assert dto_batch_3.strategy == CollectorType.AUTO_GOOGLER - assert pendulum.instance(dto_batch_3.created_at) == today.subtract(weeks=2) assert dto_batch_3.count_url_total == 12 - assert dto_batch_3.count_url_pending == 0 + assert dto_batch_3.count_url_pending == 5 assert dto_batch_3.count_url_submitted == 0 assert dto_batch_3.count_url_rejected == 3 assert dto_batch_3.count_url_error == 4 - assert dto_batch_3.count_url_validated == 5 + assert dto_batch_3.count_url_validated == 7 dto_2 = await ath.request_validator.get_batches_breakdown_metrics( page=2 diff --git a/tests/automated/integration/api/metrics/test_backlog.py b/tests/automated/integration/api/metrics/test_backlog.py index a6807a23..da8dccd6 100644 --- a/tests/automated/integration/api/metrics/test_backlog.py +++ b/tests/automated/integration/api/metrics/test_backlog.py @@ -1,11 +1,10 @@ import pendulum import pytest -from src.collectors.enums import CollectorType, URLStatus -from src.core.enums import SuggestedStatus -from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from src.collectors.enums import URLStatus +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.url_validated.enums import URLType +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -14,29 +13,22 @@ async def test_get_backlog_metrics(api_test_helper): ath = api_test_helper adb_client = ath.adb_client() + ddc: DBDataCreator = ath.db_data_creator # Populate the backlog table and test that backlog metrics returned on a monthly basis # Ensure that multiple days in each month are added to the backlog table, with different values - - batch_1_params = TestBatchCreationParameters( - strategy=CollectorType.MANUAL, - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING, - annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT - ) - ), - TestURLCreationParameters( - count=2, - status=URLStatus.SUBMITTED - ), - ] + batch_1_id: int = await ddc.create_batch() + url_mappings_1: list[URLMapping] = await ddc.create_urls(count=3) + url_ids_1: list[int] = [url_mapping.url_id for url_mapping in url_mappings_1] + await ddc.create_batch_url_links(url_ids=url_ids_1, batch_id=batch_1_id) + submitted_url_ids_1: list[int] = url_ids_1[:2] + await ddc.create_validated_flags( + url_ids=submitted_url_ids_1, + validation_type=URLType.DATA_SOURCE ) - batch_1 = await ath.db_data_creator.batch_v2(batch_1_params) + await ddc.create_url_data_sources(url_ids=submitted_url_ids_1) await adb_client.populate_backlog_snapshot( dt=today.subtract(months=3).naive() @@ -46,23 +38,20 @@ async def test_get_backlog_metrics(api_test_helper): dt=today.subtract(months=2, days=3).naive() ) - batch_2_params = TestBatchCreationParameters( - strategy=CollectorType.AUTO_GOOGLER, - urls=[ - TestURLCreationParameters( - count=4, - status=URLStatus.PENDING, - 
annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT - ) - ), - TestURLCreationParameters( - count=2, - status=URLStatus.ERROR - ), - ] + batch_2_id: int = await ddc.create_batch() + not_relevant_url_mappings_2: list[URLMapping] = await ddc.create_urls(count=6) + not_relevant_url_ids_2: list[int] = [url_mapping.url_id for url_mapping in not_relevant_url_mappings_2] + await ddc.create_batch_url_links(url_ids=not_relevant_url_ids_2, batch_id=batch_2_id) + await ddc.create_validated_flags( + url_ids=not_relevant_url_ids_2[:4], + validation_type=URLType.NOT_RELEVANT + ) + error_url_mappings_2: list[URLMapping] = await ddc.create_urls( + status=URLStatus.ERROR, + count=2 ) - batch_2 = await ath.db_data_creator.batch_v2(batch_2_params) + error_url_ids_2: list[int] = [url_mapping.url_id for url_mapping in error_url_mappings_2] + await ddc.create_batch_url_links(url_ids=error_url_ids_2, batch_id=batch_2_id) await adb_client.populate_backlog_snapshot( dt=today.subtract(months=2).naive() @@ -72,23 +61,15 @@ async def test_get_backlog_metrics(api_test_helper): dt=today.subtract(months=1, days=4).naive() ) - batch_3_params = TestBatchCreationParameters( - strategy=CollectorType.AUTO_GOOGLER, - urls=[ - TestURLCreationParameters( - count=7, - status=URLStatus.PENDING, - annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT - ) - ), - TestURLCreationParameters( - count=5, - status=URLStatus.VALIDATED - ), - ] + batch_3_id: int = await ddc.create_batch() + url_mappings_3: list[URLMapping] = await ddc.create_urls(count=12) + url_ids_3: list[int] = [url_mapping.url_id for url_mapping in url_mappings_3] + await ddc.create_batch_url_links(url_ids=url_ids_3, batch_id=batch_3_id) + await ddc.create_validated_flags( + url_ids=url_ids_3[:5], + validation_type=URLType.DATA_SOURCE ) - batch_3 = await ath.db_data_creator.batch_v2(batch_3_params) + await adb_client.populate_backlog_snapshot( dt=today.subtract(months=1).naive() @@ -100,5 +81,5 @@ async def test_get_backlog_metrics(api_test_helper): # Test that the count closest to the beginning of the month is returned for each month assert dto.entries[0].count_pending_total == 1 - assert dto.entries[1].count_pending_total == 5 - assert dto.entries[2].count_pending_total == 12 + assert dto.entries[1].count_pending_total == 3 + assert dto.entries[2].count_pending_total == 10 diff --git a/tests/automated/integration/api/metrics/urls/aggregated/test_core.py b/tests/automated/integration/api/metrics/urls/aggregated/test_core.py index 15b48f1e..64ae5ae4 100644 --- a/tests/automated/integration/api/metrics/urls/aggregated/test_core.py +++ b/tests/automated/integration/api/metrics/urls/aggregated/test_core.py @@ -1,75 +1,70 @@ +from datetime import datetime, timedelta, timezone + import pendulum import pytest from src.collectors.enums import CollectorType, URLStatus +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.url_validated.enums import URLType from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio async def test_get_urls_aggregated_metrics(api_test_helper): ath = api_test_helper - today = pendulum.parse('2021-01-01') + today = datetime.now() + + ddc: DBDataCreator = ath.db_data_creator batch_0_params = 
TestBatchCreationParameters( strategy=CollectorType.MANUAL, - created_at=today.subtract(days=1), + created_at=today - timedelta(days=1), urls=[ TestURLCreationParameters( count=1, - status=URLStatus.PENDING, + status=URLCreationEnum.OK, ), ] ) - batch_0 = await ath.db_data_creator.batch_v2(batch_0_params) - oldest_url_id = batch_0.url_creation_infos[URLStatus.PENDING].url_mappings[0].url_id - + batch_0: int = await ddc.create_batch( + strategy=CollectorType.MANUAL, + date_generated=today - timedelta(days=1) + ) + url_mappings_0: list[URLMapping] = await ddc.create_urls(batch_id=batch_0) + oldest_url_id: int = url_mappings_0[0].url_id - batch_1_params = TestBatchCreationParameters( + batch_1: int = await ddc.create_batch( strategy=CollectorType.MANUAL, - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING, - ), - TestURLCreationParameters( - count=2, - status=URLStatus.SUBMITTED - ), - ] ) - batch_1 = await ath.db_data_creator.batch_v2(batch_1_params) + url_mappings_1_ok: list[URLMapping] = await ddc.create_urls(batch_id=batch_1, count=1) + url_mappings_1_submitted: list[URLMapping] = await ddc.create_submitted_urls(count=2) + url_ids_1_submitted: list[int] = [url_mapping.url_id for url_mapping in url_mappings_1_submitted] + await ddc.create_batch_url_links(url_ids=url_ids_1_submitted, batch_id=batch_1) - batch_2_params = TestBatchCreationParameters( + batch_2: int = await ddc.create_batch( strategy=CollectorType.AUTO_GOOGLER, - urls=[ - TestURLCreationParameters( - count=4, - status=URLStatus.PENDING, - ), - TestURLCreationParameters( - count=2, - status=URLStatus.ERROR - ), - TestURLCreationParameters( - count=1, - status=URLStatus.VALIDATED - ), - TestURLCreationParameters( - count=5, - status=URLStatus.NOT_RELEVANT - ), - ] ) - batch_2 = await ath.db_data_creator.batch_v2(batch_2_params) + url_mappings_2_ok: list[URLMapping] = await ddc.create_urls(batch_id=batch_2, count=4, status=URLStatus.OK) + url_mappings_2_error: list[URLMapping] = await ddc.create_urls(batch_id=batch_2, count=2, status=URLStatus.ERROR) + url_mappings_2_validated: list[URLMapping] = await ddc.create_validated_urls(count=1, validation_type=URLType.DATA_SOURCE) + url_mappings_2_not_relevant: list[URLMapping] = await ddc.create_validated_urls(count=5, validation_type=URLType.NOT_RELEVANT) + url_ids_2_validated: list[int] = [url_mapping.url_id for url_mapping in url_mappings_2_validated] + url_ids_2_not_relevant: list[int] = [url_mapping.url_id for url_mapping in url_mappings_2_not_relevant] + await ddc.create_batch_url_links( + url_ids=url_ids_2_validated + url_ids_2_not_relevant, + batch_id=batch_2 + ) + + await ddc.adb_client.refresh_materialized_views() dto = await ath.request_validator.get_urls_aggregated_metrics() - assert dto.oldest_pending_url_id == oldest_url_id - assert dto.oldest_pending_url_created_at == today.subtract(days=1).in_timezone('UTC').naive() - assert dto.count_urls_pending == 6 - assert dto.count_urls_rejected == 5 - assert dto.count_urls_errors == 2 - assert dto.count_urls_validated == 1 - assert dto.count_urls_submitted == 2 - assert dto.count_urls_total == 16 + assert dto.oldest_pending_url.url_id == oldest_url_id + # assert dto.count_urls_rejected == 5 + # assert dto.count_urls_errors == 2 + # assert dto.count_urls_validated == 8 + # assert dto.count_urls_submitted == 2 + # assert dto.count_urls_total == 16 diff --git a/tests/automated/integration/api/metrics/urls/aggregated/test_pending.py b/tests/automated/integration/api/metrics/urls/aggregated/test_pending.py index 
1b55f04d..fee6ef46 100644 --- a/tests/automated/integration/api/metrics/urls/aggregated/test_pending.py +++ b/tests/automated/integration/api/metrics/urls/aggregated/test_pending.py @@ -1,7 +1,8 @@ import pytest from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.core.enums import SuggestedStatus, RecordType +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters @@ -26,19 +27,19 @@ async def setup_test_batches(db_data_creator): batches = [ create_batch( annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT + user_relevant=URLType.DATA_SOURCE ) ), create_batch( annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_record_type=RecordType.ARREST_RECORDS ), count=2 ), create_batch( annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_record_type=RecordType.CALLS_FOR_SERVICE, user_agency=URLAgencyAnnotationPostInfo( suggested_agency=await db_data_creator.agency() @@ -59,7 +60,7 @@ async def setup_test_batches(db_data_creator): ), create_batch( annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_record_type=RecordType.PERSONNEL_RECORDS, user_agency=URLAgencyAnnotationPostInfo( suggested_agency=await db_data_creator.agency() @@ -69,7 +70,7 @@ async def setup_test_batches(db_data_creator): ), create_batch( annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_agency=URLAgencyAnnotationPostInfo( is_new=True ) diff --git a/tests/automated/integration/api/metrics/urls/breakdown/test_pending.py b/tests/automated/integration/api/metrics/urls/breakdown/test_pending.py index e81d6ec7..3e906a8c 100644 --- a/tests/automated/integration/api/metrics/urls/breakdown/test_pending.py +++ b/tests/automated/integration/api/metrics/urls/breakdown/test_pending.py @@ -2,10 +2,12 @@ import pytest from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.collectors.enums import CollectorType, URLStatus -from src.core.enums import SuggestedStatus, RecordType +from src.collectors.enums import CollectorType +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters @@ -27,14 +29,14 @@ async def test_get_urls_breakdown_pending_metrics(api_test_helper): urls=[ TestURLCreationParameters( count=1, - status=URLStatus.PENDING, + status=URLCreationEnum.OK, annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.NOT_RELEVANT + user_relevant=URLType.NOT_RELEVANT ) ), TestURLCreationParameters( count=2, - status=URLStatus.SUBMITTED + status=URLCreationEnum.SUBMITTED ), ] ) @@ -44,9 +46,9 @@ async def test_get_urls_breakdown_pending_metrics(api_test_helper): urls=[ 
TestURLCreationParameters( count=3, - status=URLStatus.PENDING, + status=URLCreationEnum.OK, annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_record_type=RecordType.CALLS_FOR_SERVICE ) ) @@ -60,17 +62,17 @@ async def test_get_urls_breakdown_pending_metrics(api_test_helper): urls=[ TestURLCreationParameters( count=3, - status=URLStatus.SUBMITTED + status=URLCreationEnum.SUBMITTED ), TestURLCreationParameters( count=4, - status=URLStatus.ERROR + status=URLCreationEnum.ERROR ), TestURLCreationParameters( count=5, - status=URLStatus.PENDING, + status=URLCreationEnum.OK, annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, + user_relevant=URLType.DATA_SOURCE, user_record_type=RecordType.INCARCERATION_RECORDS, user_agency=URLAgencyAnnotationPostInfo( suggested_agency=agency_id diff --git a/tests/automated/integration/api/metrics/urls/breakdown/test_submitted.py b/tests/automated/integration/api/metrics/urls/breakdown/test_submitted.py index 71e00e51..cbd30f8b 100644 --- a/tests/automated/integration/api/metrics/urls/breakdown/test_submitted.py +++ b/tests/automated/integration/api/metrics/urls/breakdown/test_submitted.py @@ -3,6 +3,7 @@ from src.collectors.enums import CollectorType, URLStatus from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters @@ -18,11 +19,11 @@ async def test_get_urls_breakdown_submitted_metrics(api_test_helper): urls=[ TestURLCreationParameters( count=1, - status=URLStatus.PENDING + status=URLCreationEnum.OK ), TestURLCreationParameters( count=2, - status=URLStatus.SUBMITTED + status=URLCreationEnum.SUBMITTED ), ] ) @@ -32,7 +33,7 @@ async def test_get_urls_breakdown_submitted_metrics(api_test_helper): urls=[ TestURLCreationParameters( count=3, - status=URLStatus.SUBMITTED + status=URLCreationEnum.SUBMITTED ) ], created_at=today.subtract(weeks=1), @@ -44,15 +45,15 @@ async def test_get_urls_breakdown_submitted_metrics(api_test_helper): urls=[ TestURLCreationParameters( count=3, - status=URLStatus.SUBMITTED + status=URLCreationEnum.SUBMITTED ), TestURLCreationParameters( count=4, - status=URLStatus.ERROR + status=URLCreationEnum.ERROR ), TestURLCreationParameters( count=5, - status=URLStatus.VALIDATED + status=URLCreationEnum.VALIDATED ), ] ) diff --git a/tests/automated/integration/api/review/conftest.py b/tests/automated/integration/api/review/conftest.py deleted file mode 100644 index e4345821..00000000 --- a/tests/automated/integration/api/review/conftest.py +++ /dev/null @@ -1,45 +0,0 @@ -import pytest_asyncio - -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.collectors.enums import URLStatus -from src.core.enums import SuggestedStatus, RecordType -from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters - - -@pytest_asyncio.fixture -async def batch_url_creation_info(db_data_creator): - simple_parameter_statuses = [ - URLStatus.VALIDATED, - URLStatus.SUBMITTED, - URLStatus.INDIVIDUAL_RECORD, - URLStatus.NOT_RELEVANT, - URLStatus.ERROR, - URLStatus.DUPLICATE, - URLStatus.NOT_FOUND - ] - simple_parameters = [ - 
TestURLCreationParameters( - status=status - ) for status in simple_parameter_statuses - ] - - parameters = TestBatchCreationParameters( - urls=[ - *simple_parameters, - TestURLCreationParameters( - count=2, - status=URLStatus.PENDING, - annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, - user_record_type=RecordType.ARREST_RECORDS, - user_agency=URLAgencyAnnotationPostInfo( - suggested_agency=await db_data_creator.agency() - ) - ) - ) - ] - ) - - return await db_data_creator.batch_v2(parameters=parameters) diff --git a/tests/automated/integration/api/review/rejection/helpers.py b/tests/automated/integration/api/review/rejection/helpers.py deleted file mode 100644 index 8fb26603..00000000 --- a/tests/automated/integration/api/review/rejection/helpers.py +++ /dev/null @@ -1,39 +0,0 @@ -from src.api.endpoints.review.enums import RejectionReason -from src.api.endpoints.review.next.dto import GetNextURLForFinalReviewOuterResponse -from src.api.endpoints.review.reject.dto import FinalReviewRejectionInfo -from src.collectors.enums import URLStatus -from src.db.models.instantiations.url.core import URL -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review - - -async def run_rejection_test( - api_test_helper, - rejection_reason: RejectionReason, - url_status: URLStatus -): - ath = api_test_helper - db_data_creator = ath.db_data_creator - - setup_info = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=3, - include_user_annotations=True - ) - url_mapping = setup_info.url_mapping - - result: GetNextURLForFinalReviewOuterResponse = await ath.request_validator.reject_and_get_next_source_for_review( - review_info=FinalReviewRejectionInfo( - url_id=url_mapping.url_id, - rejection_reason=rejection_reason - ) - ) - - assert result.next_source is None - - adb_client = db_data_creator.adb_client - # Confirm same agency id is listed as rejected - urls: list[URL] = await adb_client.get_all(URL) - assert len(urls) == 1 - url = urls[0] - assert url.id == url_mapping.url_id - assert url.outcome == url_status.value diff --git a/tests/automated/integration/api/review/rejection/test_broken_page.py b/tests/automated/integration/api/review/rejection/test_broken_page.py deleted file mode 100644 index 813e523a..00000000 --- a/tests/automated/integration/api/review/rejection/test_broken_page.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from src.api.endpoints.review.enums import RejectionReason -from src.collectors.enums import URLStatus -from tests.automated.integration.api.review.rejection.helpers import run_rejection_test - - -@pytest.mark.asyncio -async def test_rejection_broken_page(api_test_helper): - await run_rejection_test( - api_test_helper, - rejection_reason=RejectionReason.BROKEN_PAGE_404, - url_status=URLStatus.NOT_FOUND - ) diff --git a/tests/automated/integration/api/review/rejection/test_individual_record.py b/tests/automated/integration/api/review/rejection/test_individual_record.py deleted file mode 100644 index 6e81d378..00000000 --- a/tests/automated/integration/api/review/rejection/test_individual_record.py +++ /dev/null @@ -1,15 +0,0 @@ -import pytest - -from src.api.endpoints.review.enums import RejectionReason -from src.collectors.enums import URLStatus -from tests.automated.integration.api.review.rejection.helpers import run_rejection_test - - -@pytest.mark.asyncio -async def test_rejection_individual_record(api_test_helper): - await run_rejection_test( - api_test_helper, - 
rejection_reason=RejectionReason.INDIVIDUAL_RECORD, - url_status=URLStatus.INDIVIDUAL_RECORD - ) - diff --git a/tests/automated/integration/api/review/rejection/test_not_relevant.py b/tests/automated/integration/api/review/rejection/test_not_relevant.py deleted file mode 100644 index 1ad2847f..00000000 --- a/tests/automated/integration/api/review/rejection/test_not_relevant.py +++ /dev/null @@ -1,14 +0,0 @@ -import pytest - -from src.api.endpoints.review.enums import RejectionReason -from src.collectors.enums import URLStatus -from tests.automated.integration.api.review.rejection.helpers import run_rejection_test - - -@pytest.mark.asyncio -async def test_rejection_not_relevant(api_test_helper): - await run_rejection_test( - api_test_helper, - rejection_reason=RejectionReason.NOT_RELEVANT, - url_status=URLStatus.NOT_RELEVANT - ) diff --git a/tests/automated/integration/api/review/test_approve_and_get_next_source.py b/tests/automated/integration/api/review/test_approve_and_get_next_source.py deleted file mode 100644 index 9afc16d8..00000000 --- a/tests/automated/integration/api/review/test_approve_and_get_next_source.py +++ /dev/null @@ -1,78 +0,0 @@ -import pytest - -from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo -from src.api.endpoints.review.next.dto import GetNextURLForFinalReviewOuterResponse -from src.collectors.enums import URLStatus -from src.core.enums import RecordType -from src.db.constants import PLACEHOLDER_AGENCY_NAME -from src.db.models.instantiations.agency import Agency -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review - - -@pytest.mark.asyncio -async def test_approve_and_get_next_source_for_review(api_test_helper): - ath = api_test_helper - db_data_creator = ath.db_data_creator - - setup_info = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - include_user_annotations=True - ) - url_mapping = setup_info.url_mapping - - # Add confirmed agency - await db_data_creator.confirmed_suggestions([url_mapping.url_id]) - - # Additionally, include an agency not yet included in the database - additional_agency = 999999 - - agency_ids = [await db_data_creator.agency() for _ in range(3)] - agency_ids.append(additional_agency) - - result: GetNextURLForFinalReviewOuterResponse = await ath.request_validator.approve_and_get_next_source_for_review( - approval_info=FinalReviewApprovalInfo( - url_id=url_mapping.url_id, - record_type=RecordType.ARREST_RECORDS, - agency_ids=agency_ids, - name="New Test Name", - description="New Test Description", - record_formats=["New Test Record Format", "New Test Record Format 2"], - data_portal_type="New Test Data Portal Type", - supplying_entity="New Test Supplying Entity" - ) - ) - - assert result.remaining == 0 - assert result.next_source is None - - adb_client = db_data_creator.adb_client - # Confirm same agency id is listed as confirmed - urls = await adb_client.get_all(URL) - assert len(urls) == 1 - url = urls[0] - assert url.id == url_mapping.url_id - assert url.record_type == RecordType.ARREST_RECORDS.value - assert url.outcome == URLStatus.VALIDATED.value - assert url.name == "New Test Name" - assert url.description == "New Test Description" - - optional_metadata = await 
adb_client.get_all(URLOptionalDataSourceMetadata) - assert len(optional_metadata) == 1 - assert optional_metadata[0].data_portal_type == "New Test Data Portal Type" - assert optional_metadata[0].supplying_entity == "New Test Supplying Entity" - assert optional_metadata[0].record_formats == ["New Test Record Format", "New Test Record Format 2"] - - # Get agencies - confirmed_agencies = await adb_client.get_all(ConfirmedURLAgency) - assert len(confirmed_agencies) == 4 - for agency in confirmed_agencies: - assert agency.agency_id in agency_ids - - # Check that created agency has placeholder - agencies = await adb_client.get_all(Agency) - for agency in agencies: - if agency.agency_id == additional_agency: - assert agency.name == PLACEHOLDER_AGENCY_NAME diff --git a/tests/automated/integration/api/review/test_batch_filtering.py b/tests/automated/integration/api/review/test_batch_filtering.py deleted file mode 100644 index 2e8aa63c..00000000 --- a/tests/automated/integration/api/review/test_batch_filtering.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - - -@pytest.mark.asyncio -async def test_batch_filtering( - batch_url_creation_info, - api_test_helper -): - ath = api_test_helper - rv = ath.request_validator - - # Receive null batch info if batch id not provided - outer_result_no_batch_info = await rv.review_next_source() - assert outer_result_no_batch_info.next_source.batch_info is None - - # Get batch info if batch id is provided - outer_result = await ath.request_validator.review_next_source( - batch_id=batch_url_creation_info.batch_id - ) - assert outer_result.remaining == 2 - batch_info = outer_result.next_source.batch_info - assert batch_info.count_reviewed == 4 - assert batch_info.count_ready_for_review == 2 - diff --git a/tests/automated/integration/api/review/test_next_source.py b/tests/automated/integration/api/review/test_next_source.py deleted file mode 100644 index 790914ee..00000000 --- a/tests/automated/integration/api/review/test_next_source.py +++ /dev/null @@ -1,65 +0,0 @@ -import pytest - -from src.core.enums import SuggestedStatus, RecordType -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review - - -@pytest.mark.asyncio -async def test_review_next_source(api_test_helper): - ath = api_test_helper - - setup_info = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, - include_user_annotations=True - ) - url_mapping = setup_info.url_mapping - - await ath.db_data_creator.agency_auto_suggestions( - url_id=url_mapping.url_id, - count=3 - ) - confirmed_agency_id = await ath.db_data_creator.agency_confirmed_suggestion(url_id=url_mapping.url_id) - - outer_result = await ath.request_validator.review_next_source() - assert outer_result.remaining == 1 - - result = outer_result.next_source - - assert result.name == "Test Name" - assert result.description == "Test Description" - - optional_metadata = result.optional_metadata - - assert optional_metadata.data_portal_type == "Test Data Portal Type" - assert optional_metadata.supplying_entity == "Test Supplying Entity" - assert optional_metadata.record_formats == ["Test Record Format", "Test Record Format 2"] - - assert result.url == url_mapping.url - html_info = result.html_info - assert html_info.description == "test description" - assert html_info.title == "test html content" - - annotation_info = result.annotations - relevant_info = annotation_info.relevant - assert relevant_info.auto.is_relevant == True - assert relevant_info.user == SuggestedStatus.NOT_RELEVANT - 
- record_type_info = annotation_info.record_type - assert record_type_info.auto == RecordType.ARREST_RECORDS - assert record_type_info.user == RecordType.ACCIDENT_REPORTS - - agency_info = annotation_info.agency - auto_agency_suggestions = agency_info.auto - assert auto_agency_suggestions.unknown == False - assert len(auto_agency_suggestions.suggestions) == 3 - - # Check user agency suggestions exist and in descending order of count - user_agency_suggestion = agency_info.user - assert user_agency_suggestion.pdap_agency_id == setup_info.user_agency_id - - - # Check confirmed agencies exist - confirmed_agencies = agency_info.confirmed - assert len(confirmed_agencies) == 1 - confirmed_agency = confirmed_agencies[0] - assert confirmed_agency.pdap_agency_id == confirmed_agency_id diff --git a/tests/automated/integration/api/search/__init__.py b/tests/automated/integration/api/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/search/agency/__init__.py b/tests/automated/integration/api/search/agency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/search/agency/test_search.py b/tests/automated/integration/api/search/agency/test_search.py new file mode 100644 index 00000000..cc3fee19 --- /dev/null +++ b/tests/automated/integration/api/search/agency/test_search.py @@ -0,0 +1,63 @@ +import pytest + +from tests.helpers.api_test_helper import APITestHelper +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo + + +@pytest.mark.asyncio +async def test_search_agency( + api_test_helper: APITestHelper, + db_data_creator: DBDataCreator, + pittsburgh_locality: LocalityCreationInfo, + allegheny_county: CountyCreationInfo +): + + agency_a_id: int = await db_data_creator.agency("A Agency") + agency_b_id: int = await db_data_creator.agency("AB Agency") + agency_c_id: int = await db_data_creator.agency("ABC Agency") + + await db_data_creator.link_agencies_to_location( + agency_ids=[agency_a_id, agency_c_id], + location_id=pittsburgh_locality.location_id + ) + await db_data_creator.link_agencies_to_location( + agency_ids=[agency_b_id], + location_id=allegheny_county.location_id + ) + + responses: list[dict] = api_test_helper.request_validator.get_v2( + url="/search/agency", + params={ + "query": "A Agency", + } + ) + assert len(responses) == 3 + assert responses[0]["agency_id"] == agency_a_id + assert responses[1]["agency_id"] == agency_b_id + assert responses[2]["agency_id"] == agency_c_id + + # Filter based on location ID + responses = api_test_helper.request_validator.get_v2( + url="/search/agency", + params={ + "query": "A Agency", + "location_id": pittsburgh_locality.location_id + } + ) + + assert len(responses) == 2 + assert responses[0]["agency_id"] == agency_a_id + assert responses[1]["agency_id"] == agency_c_id + + # Filter again based on location ID but with Allegheny County + # Confirm Pittsburgh agencies are picked up + responses = api_test_helper.request_validator.get_v2( + url="/search/agency", + params={ + "query": "A Agency", + "location_id": allegheny_county.location_id + } + ) + assert len(responses) == 3 diff --git a/tests/automated/integration/api/search/url/__init__.py b/tests/automated/integration/api/search/url/__init__.py new file mode 100644 index 00000000..e69de29b diff --git 
a/tests/automated/integration/api/test_search.py b/tests/automated/integration/api/search/url/test_search.py similarity index 100% rename from tests/automated/integration/api/test_search.py rename to tests/automated/integration/api/search/url/test_search.py diff --git a/tests/automated/integration/api/submit/__init__.py b/tests/automated/integration/api/submit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/submit/test_duplicate.py b/tests/automated/integration/api/submit/test_duplicate.py new file mode 100644 index 00000000..c1ccfd29 --- /dev/null +++ b/tests/automated/integration/api/submit/test_duplicate.py @@ -0,0 +1,24 @@ +import pytest + +from src.api.endpoints.submit.url.enums import URLSubmissionStatus +from src.api.endpoints.submit.url.models.request import URLSubmissionRequest +from src.api.endpoints.submit.url.models.response import URLSubmissionResponse +from src.db.dtos.url.mapping import URLMapping +from tests.helpers.api_test_helper import APITestHelper +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_duplicate( + api_test_helper: APITestHelper, + db_data_creator: DBDataCreator +): + url_mapping: URLMapping = (await db_data_creator.create_urls(count=1))[0] + + response: URLSubmissionResponse = await api_test_helper.request_validator.submit_url( + request=URLSubmissionRequest( + url=url_mapping.url + ) + ) + assert response.status == URLSubmissionStatus.DATABASE_DUPLICATE + assert response.url_id is None \ No newline at end of file diff --git a/tests/automated/integration/api/submit/test_invalid.py b/tests/automated/integration/api/submit/test_invalid.py new file mode 100644 index 00000000..a5ae27e7 --- /dev/null +++ b/tests/automated/integration/api/submit/test_invalid.py @@ -0,0 +1,16 @@ +import pytest + +from src.api.endpoints.submit.url.enums import URLSubmissionStatus +from src.api.endpoints.submit.url.models.request import URLSubmissionRequest +from src.api.endpoints.submit.url.models.response import URLSubmissionResponse +from tests.helpers.api_test_helper import APITestHelper + + +@pytest.mark.asyncio +async def test_invalid(api_test_helper: APITestHelper): + response: URLSubmissionResponse = await api_test_helper.request_validator.submit_url( + request=URLSubmissionRequest( + url="invalid_url" + ) + ) + assert response.status == URLSubmissionStatus.INVALID \ No newline at end of file diff --git a/tests/automated/integration/api/submit/test_needs_cleaning.py b/tests/automated/integration/api/submit/test_needs_cleaning.py new file mode 100644 index 00000000..c6512502 --- /dev/null +++ b/tests/automated/integration/api/submit/test_needs_cleaning.py @@ -0,0 +1,37 @@ +import pytest + +from src.api.endpoints.submit.url.enums import URLSubmissionStatus +from src.api.endpoints.submit.url.models.request import URLSubmissionRequest +from src.api.endpoints.submit.url.models.response import URLSubmissionResponse +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.user_suggestion_not_found.users_submitted_url.sqlalchemy import LinkUserSubmittedURL +from src.db.models.impl.url.core.sqlalchemy import URL +from tests.helpers.api_test_helper import APITestHelper + + +@pytest.mark.asyncio +async def test_needs_cleaning( + api_test_helper: APITestHelper, + adb_client_test: AsyncDatabaseClient +): + response: URLSubmissionResponse = await api_test_helper.request_validator.submit_url( + request=URLSubmissionRequest( + url="www.example.com#fragment" + 
) + ) + + assert response.status == URLSubmissionStatus.ACCEPTED_WITH_CLEANING + assert response.url_id is not None + url_id: int = response.url_id + + adb_client: AsyncDatabaseClient = adb_client_test + urls: list[URL] = await adb_client.get_all(URL) + assert len(urls) == 1 + url: URL = urls[0] + assert url.id == url_id + assert url.url == "www.example.com" + + links: list[LinkUserSubmittedURL] = await adb_client.get_all(LinkUserSubmittedURL) + assert len(links) == 1 + link: LinkUserSubmittedURL = links[0] + assert link.url_id == url_id \ No newline at end of file diff --git a/tests/automated/integration/api/submit/test_url_maximal.py b/tests/automated/integration/api/submit/test_url_maximal.py new file mode 100644 index 00000000..8d1930f5 --- /dev/null +++ b/tests/automated/integration/api/submit/test_url_maximal.py @@ -0,0 +1,85 @@ +import pytest + +from src.api.endpoints.submit.url.enums import URLSubmissionStatus +from src.api.endpoints.submit.url.models.request import URLSubmissionRequest +from src.api.endpoints.submit.url.models.response import URLSubmissionResponse +from src.core.enums import RecordType +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion +from src.db.models.impl.link.user_suggestion_not_found.users_submitted_url.sqlalchemy import LinkUserSubmittedURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.suggestion.agency.user import UserUrlAgencySuggestion +from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion +from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource +from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion +from src.db.models.impl.url.suggestion.record_type.user import UserRecordTypeSuggestion +from tests.helpers.api_test_helper import APITestHelper +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo + + +@pytest.mark.asyncio +async def test_maximal( + api_test_helper: APITestHelper, + adb_client_test: AsyncDatabaseClient, + db_data_creator: DBDataCreator, + pittsburgh_locality: LocalityCreationInfo +): + + agency_id: int = await db_data_creator.agency() + + response: URLSubmissionResponse = await api_test_helper.request_validator.submit_url( + request=URLSubmissionRequest( + url="www.example.com", + record_type=RecordType.INCARCERATION_RECORDS, + name="Example URL", + location_id=pittsburgh_locality.location_id, + agency_id=agency_id, + ) + ) + + assert response.status == URLSubmissionStatus.ACCEPTED_AS_IS + assert response.url_id is not None + url_id: int = response.url_id + + adb_client: AsyncDatabaseClient = adb_client_test + urls: list[URL] = await adb_client.get_all(URL) + assert len(urls) == 1 + url: URL = urls[0] + assert url.id == url_id + assert url.url == "www.example.com" + + links: list[LinkUserSubmittedURL] = await adb_client.get_all(LinkUserSubmittedURL) + assert len(links) == 1 + link: LinkUserSubmittedURL = links[0] + assert link.url_id == url_id + + agen_suggs: list[UserUrlAgencySuggestion] = await adb_client.get_all(UserUrlAgencySuggestion) + assert len(agen_suggs) == 1 + agen_sugg: UserUrlAgencySuggestion = agen_suggs[0] + assert agen_sugg.url_id == url_id + assert agen_sugg.agency_id == agency_id + + loc_suggs: list[UserLocationSuggestion] = await adb_client.get_all(UserLocationSuggestion) + assert len(loc_suggs) == 1 + 
loc_sugg: UserLocationSuggestion = loc_suggs[0] + assert loc_sugg.url_id == url_id + assert loc_sugg.location_id == pittsburgh_locality.location_id + + name_suggs: list[URLNameSuggestion] = await adb_client.get_all(URLNameSuggestion) + assert len(name_suggs) == 1 + name_sugg: URLNameSuggestion = name_suggs[0] + assert name_sugg.url_id == url_id + assert name_sugg.suggestion == "Example URL" + assert name_sugg.source == NameSuggestionSource.USER + + name_link_suggs: list[LinkUserNameSuggestion] = await adb_client.get_all(LinkUserNameSuggestion) + assert len(name_link_suggs) == 1 + name_link_sugg: LinkUserNameSuggestion = name_link_suggs[0] + assert name_link_sugg.suggestion_id == name_sugg.id + + rec_suggs: list[UserRecordTypeSuggestion] = await adb_client.get_all(UserRecordTypeSuggestion) + assert len(rec_suggs) == 1 + rec_sugg: UserRecordTypeSuggestion = rec_suggs[0] + assert rec_sugg.url_id == url_id + assert rec_sugg.record_type == RecordType.INCARCERATION_RECORDS.value diff --git a/tests/automated/integration/api/submit/test_url_minimal.py b/tests/automated/integration/api/submit/test_url_minimal.py new file mode 100644 index 00000000..f1f078f6 --- /dev/null +++ b/tests/automated/integration/api/submit/test_url_minimal.py @@ -0,0 +1,37 @@ +import pytest + +from src.api.endpoints.submit.url.enums import URLSubmissionStatus +from src.api.endpoints.submit.url.models.request import URLSubmissionRequest +from src.api.endpoints.submit.url.models.response import URLSubmissionResponse +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.user_suggestion_not_found.users_submitted_url.sqlalchemy import LinkUserSubmittedURL +from src.db.models.impl.url.core.sqlalchemy import URL +from tests.helpers.api_test_helper import APITestHelper + + +@pytest.mark.asyncio +async def test_minimal( + api_test_helper: APITestHelper, + adb_client_test: AsyncDatabaseClient +): + response: URLSubmissionResponse = await api_test_helper.request_validator.submit_url( + request=URLSubmissionRequest( + url="www.example.com" + ) + ) + + assert response.status == URLSubmissionStatus.ACCEPTED_AS_IS + assert response.url_id is not None + url_id: int = response.url_id + + adb_client: AsyncDatabaseClient = adb_client_test + urls: list[URL] = await adb_client.get_all(URL) + assert len(urls) == 1 + url: URL = urls[0] + assert url.id == url_id + assert url.url == "www.example.com" + + links: list[LinkUserSubmittedURL] = await adb_client.get_all(LinkUserSubmittedURL) + assert len(links) == 1 + link: LinkUserSubmittedURL = links[0] + assert link.url_id == url_id \ No newline at end of file diff --git a/tests/automated/integration/api/test_annotate.py b/tests/automated/integration/api/test_annotate.py deleted file mode 100644 index b0039212..00000000 --- a/tests/automated/integration/api/test_annotate.py +++ /dev/null @@ -1,756 +0,0 @@ -from http import HTTPStatus - -import pytest -from fastapi import HTTPException - -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.post import RecordTypeAnnotationPostInfo -from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseOuterInfo -from src.api.endpoints.annotate.relevance.get.dto import GetNextRelevanceAnnotationResponseOuterInfo -from src.api.endpoints.annotate.relevance.post.dto import RelevanceAnnotationPostInfo -from 
src.core.tasks.url.operators.url_html.scraper.parser.dtos.response_html import ResponseHTMLInfo -from src.db.dtos.url.insert import InsertURLsInfo -from src.db.dtos.url.mapping import URLMapping -from src.db.models.instantiations.url.suggestion.agency.user import UserUrlAgencySuggestion -from src.core.error_manager.enums import ErrorTypes -from src.core.enums import RecordType, SuggestionType, SuggestedStatus -from src.core.exceptions import FailedValidationException -from src.db.models.instantiations.url.suggestion.record_type.user import UserRecordTypeSuggestion -from src.db.models.instantiations.url.suggestion.relevant.user import UserRelevantSuggestion -from tests.helpers.setup.annotate_agency.model import AnnotateAgencySetupInfo -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import BatchURLCreationInfo -from tests.automated.integration.api.conftest import MOCK_USER_ID - -def check_url_mappings_match( - map_1: URLMapping, - map_2: URLMapping -): - assert map_1.url_id == map_2.url_id - assert map_2.url == map_2.url - -def check_html_info_not_empty( - html_info: ResponseHTMLInfo -): - assert not html_info_empty(html_info) - -def html_info_empty( - html_info: ResponseHTMLInfo -) -> bool: - return html_info.description == "" and html_info.title == "" - -@pytest.mark.asyncio -async def test_annotate_relevancy(api_test_helper): - ath = api_test_helper - - batch_id = ath.db_data_creator.batch() - - # Create 2 URLs with outcome `pending` - iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=2) - - url_1 = iui.url_mappings[0] - url_2 = iui.url_mappings[1] - - # Add `Relevancy` attribute with value `True` to 1st URL - await ath.db_data_creator.auto_relevant_suggestions( - url_id=url_1.url_id, - relevant=True - ) - - # Add 'Relevancy' attribute with value `False` to 2nd URL - await ath.db_data_creator.auto_relevant_suggestions( - url_id=url_2.url_id, - relevant=False - ) - - # Add HTML data to both - await ath.db_data_creator.html_data([url_1.url_id, url_2.url_id]) - # Call `GET` `/annotate/relevance` and receive next URL - request_info_1: GetNextRelevanceAnnotationResponseOuterInfo = api_test_helper.request_validator.get_next_relevance_annotation() - inner_info_1 = request_info_1.next_annotation - - check_url_mappings_match(inner_info_1.url_info, url_1) - check_html_info_not_empty(inner_info_1.html_info) - - # Validate that the correct relevant value is returned - assert inner_info_1.annotation.is_relevant is True - - # A second user should see the same URL - - - # Annotate with value 'False' and get next URL - request_info_2: GetNextRelevanceAnnotationResponseOuterInfo = api_test_helper.request_validator.post_relevance_annotation_and_get_next( - url_id=inner_info_1.url_info.url_id, - relevance_annotation_post_info=RelevanceAnnotationPostInfo( - suggested_status=SuggestedStatus.NOT_RELEVANT - ) - ) - - inner_info_2 = request_info_2.next_annotation - - check_url_mappings_match( - inner_info_2.url_info, - url_2 - ) - check_html_info_not_empty(inner_info_2.html_info) - - request_info_3: GetNextRelevanceAnnotationResponseOuterInfo = api_test_helper.request_validator.post_relevance_annotation_and_get_next( - url_id=inner_info_2.url_info.url_id, - relevance_annotation_post_info=RelevanceAnnotationPostInfo( - suggested_status=SuggestedStatus.RELEVANT - ) - ) - - assert request_info_3.next_annotation is None - - # Get all 
URL annotations. Confirm they exist for user - adb_client = ath.adb_client() - results: list[UserRelevantSuggestion] = await adb_client.get_all(UserRelevantSuggestion) - result_1 = results[0] - result_2 = results[1] - - assert result_1.url_id == inner_info_1.url_info.url_id - assert result_1.suggested_status == SuggestedStatus.NOT_RELEVANT.value - - assert result_2.url_id == inner_info_2.url_info.url_id - assert result_2.suggested_status == SuggestedStatus.RELEVANT.value - - # If user submits annotation for same URL, the URL should be overwritten - request_info_4: GetNextRelevanceAnnotationResponseOuterInfo = api_test_helper.request_validator.post_relevance_annotation_and_get_next( - url_id=inner_info_1.url_info.url_id, - relevance_annotation_post_info=RelevanceAnnotationPostInfo( - suggested_status=SuggestedStatus.RELEVANT - ) - ) - - assert request_info_4.next_annotation is None - - results: list[UserRelevantSuggestion] = await adb_client.get_all(UserRelevantSuggestion) - assert len(results) == 2 - - for result in results: - if result.url_id == inner_info_1.url_info.url_id: - assert results[0].suggested_status == SuggestedStatus.RELEVANT.value - -async def post_and_validate_relevancy_annotation(ath, url_id, annotation: SuggestedStatus): - response = ath.request_validator.post_relevance_annotation_and_get_next( - url_id=url_id, - relevance_annotation_post_info=RelevanceAnnotationPostInfo( - suggested_status=annotation - ) - ) - - assert response.next_annotation is None - - results: list[UserRelevantSuggestion] = await ath.adb_client().get_all(UserRelevantSuggestion) - assert len(results) == 1 - assert results[0].suggested_status == annotation.value - -@pytest.mark.asyncio -async def test_annotate_relevancy_broken_page(api_test_helper): - ath = api_test_helper - - creation_info = await ath.db_data_creator.batch_and_urls(url_count=1, with_html_content=False) - - await post_and_validate_relevancy_annotation( - ath, - url_id=creation_info.url_ids[0], - annotation=SuggestedStatus.BROKEN_PAGE_404 - ) - -@pytest.mark.asyncio -async def test_annotate_relevancy_individual_record(api_test_helper): - ath = api_test_helper - - creation_info: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1 - ) - - await post_and_validate_relevancy_annotation( - ath, - url_id=creation_info.url_ids[0], - annotation=SuggestedStatus.INDIVIDUAL_RECORD - ) - -@pytest.mark.asyncio -async def test_annotate_relevancy_already_annotated_by_different_user( - api_test_helper -): - ath = api_test_helper - - creation_info: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1 - ) - - await ath.db_data_creator.user_relevant_suggestion( - url_id=creation_info.url_ids[0], - user_id=2, - relevant=True - ) - - # Annotate with different user (default is 1) and get conflict error - try: - response = await ath.request_validator.post_relevance_annotation_and_get_next( - url_id=creation_info.url_ids[0], - relevance_annotation_post_info=RelevanceAnnotationPostInfo( - suggested_status=SuggestedStatus.NOT_RELEVANT - ) - ) - except HTTPException as e: - assert e.status_code == HTTPStatus.CONFLICT - assert e.detail["detail"]["code"] == ErrorTypes.ANNOTATION_EXISTS.value - assert e.detail["detail"]["message"] == f"Annotation of type RELEVANCE already exists for url {creation_info.url_ids[0]}" - - -@pytest.mark.asyncio -async def test_annotate_relevancy_no_html(api_test_helper): - ath = api_test_helper - - batch_id = ath.db_data_creator.batch() - - # Create 2 URLs with outcome `pending` - iui: 
InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=2) - - url_1 = iui.url_mappings[0] - url_2 = iui.url_mappings[1] - - # Add `Relevancy` attribute with value `True` to 1st URL - await ath.db_data_creator.auto_relevant_suggestions( - url_id=url_1.url_id, - relevant=True - ) - - # Add 'Relevancy' attribute with value `False` to 2nd URL - await ath.db_data_creator.auto_relevant_suggestions( - url_id=url_2.url_id, - relevant=False - ) - - # Call `GET` `/annotate/relevance` and receive next URL - request_info_1: GetNextRelevanceAnnotationResponseOuterInfo = api_test_helper.request_validator.get_next_relevance_annotation() - inner_info_1 = request_info_1.next_annotation - - check_url_mappings_match(inner_info_1.url_info, url_1) - assert html_info_empty(inner_info_1.html_info) - -@pytest.mark.asyncio -async def test_annotate_record_type(api_test_helper): - ath = api_test_helper - - batch_id = ath.db_data_creator.batch() - - # Create 2 URLs with outcome `pending` - iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=2) - - url_1 = iui.url_mappings[0] - url_2 = iui.url_mappings[1] - - # Add record type attribute with value `Accident Reports` to 1st URL - await ath.db_data_creator.auto_record_type_suggestions( - url_id=url_1.url_id, - record_type=RecordType.ACCIDENT_REPORTS - ) - - # Add 'Record Type' attribute with value `Dispatch Recordings` to 2nd URL - await ath.db_data_creator.auto_record_type_suggestions( - url_id=url_2.url_id, - record_type=RecordType.DISPATCH_RECORDINGS - ) - - # Add HTML data to both - await ath.db_data_creator.html_data([url_1.url_id, url_2.url_id]) - - # Call `GET` `/annotate/record-type` and receive next URL - request_info_1: GetNextRecordTypeAnnotationResponseOuterInfo = api_test_helper.request_validator.get_next_record_type_annotation() - inner_info_1 = request_info_1.next_annotation - - check_url_mappings_match(inner_info_1.url_info, url_1) - check_html_info_not_empty(inner_info_1.html_info) - - # Validate that the correct record type is returned - assert inner_info_1.suggested_record_type == RecordType.ACCIDENT_REPORTS - - # Annotate with value 'Personnel Records' and get next URL - request_info_2: GetNextRecordTypeAnnotationResponseOuterInfo = api_test_helper.request_validator.post_record_type_annotation_and_get_next( - url_id=inner_info_1.url_info.url_id, - record_type_annotation_post_info=RecordTypeAnnotationPostInfo( - record_type=RecordType.PERSONNEL_RECORDS - ) - ) - - inner_info_2 = request_info_2.next_annotation - - check_url_mappings_match(inner_info_2.url_info, url_2) - check_html_info_not_empty(inner_info_2.html_info) - - request_info_3: GetNextRecordTypeAnnotationResponseOuterInfo = api_test_helper.request_validator.post_record_type_annotation_and_get_next( - url_id=inner_info_2.url_info.url_id, - record_type_annotation_post_info=RecordTypeAnnotationPostInfo( - record_type=RecordType.ANNUAL_AND_MONTHLY_REPORTS - ) - ) - - assert request_info_3.next_annotation is None - - # Get all URL annotations. 
Confirm they exist for user - adb_client = ath.adb_client() - results: list[UserRecordTypeSuggestion] = await adb_client.get_all(UserRecordTypeSuggestion) - result_1 = results[0] - result_2 = results[1] - - assert result_1.url_id == inner_info_1.url_info.url_id - assert result_1.record_type == RecordType.PERSONNEL_RECORDS.value - - assert result_2.url_id == inner_info_2.url_info.url_id - assert result_2.record_type == RecordType.ANNUAL_AND_MONTHLY_REPORTS.value - - # If user submits annotation for same URL, the URL should be overwritten - - request_info_4: GetNextRecordTypeAnnotationResponseOuterInfo = api_test_helper.request_validator.post_record_type_annotation_and_get_next( - url_id=inner_info_1.url_info.url_id, - record_type_annotation_post_info=RecordTypeAnnotationPostInfo( - record_type=RecordType.BOOKING_REPORTS - ) - ) - - assert request_info_4.next_annotation is None - - results: list[UserRecordTypeSuggestion] = await adb_client.get_all(UserRecordTypeSuggestion) - assert len(results) == 2 - - for result in results: - if result.url_id == inner_info_1.url_info.url_id: - assert result.record_type == RecordType.BOOKING_REPORTS.value - -@pytest.mark.asyncio -async def test_annotate_record_type_already_annotated_by_different_user( - api_test_helper -): - ath = api_test_helper - - creation_info: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1 - ) - - await ath.db_data_creator.user_record_type_suggestion( - url_id=creation_info.url_ids[0], - user_id=2, - record_type=RecordType.ACCIDENT_REPORTS - ) - - # Annotate with different user (default is 1) and get conflict error - try: - response = await ath.request_validator.post_record_type_annotation_and_get_next( - url_id=creation_info.url_ids[0], - record_type_annotation_post_info=RecordTypeAnnotationPostInfo( - record_type=RecordType.ANNUAL_AND_MONTHLY_REPORTS - ) - ) - except HTTPException as e: - assert e.status_code == HTTPStatus.CONFLICT - assert e.detail["detail"]["code"] == ErrorTypes.ANNOTATION_EXISTS.value - assert e.detail["detail"]["message"] == f"Annotation of type RECORD_TYPE already exists for url {creation_info.url_ids[0]}" - - -@pytest.mark.asyncio -async def test_annotate_record_type_no_html_info(api_test_helper): - ath = api_test_helper - - batch_id = ath.db_data_creator.batch() - - # Create 2 URLs with outcome `pending` - iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=2) - - url_1 = iui.url_mappings[0] - url_2 = iui.url_mappings[1] - - # Add record type attribute with value `Accident Reports` to 1st URL - await ath.db_data_creator.auto_record_type_suggestions( - url_id=url_1.url_id, - record_type=RecordType.ACCIDENT_REPORTS - ) - - # Add 'Record Type' attribute with value `Dispatch Recordings` to 2nd URL - await ath.db_data_creator.auto_record_type_suggestions( - url_id=url_2.url_id, - record_type=RecordType.DISPATCH_RECORDINGS - ) - - # Call `GET` `/annotate/record-type` and receive next URL - request_info_1: GetNextRecordTypeAnnotationResponseOuterInfo = api_test_helper.request_validator.get_next_record_type_annotation() - inner_info_1 = request_info_1.next_annotation - - check_url_mappings_match(inner_info_1.url_info, url_1) - assert html_info_empty(inner_info_1.html_info) - -@pytest.mark.asyncio -async def test_annotate_agency_multiple_auto_suggestions(api_test_helper): - """ - Test Scenario: Multiple Auto Suggestions - A URL has multiple Agency Auto Suggestion and has not been annotated by the User - The user should receive all of the auto suggestions with full 
detail - """ - ath = api_test_helper - buci: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1, - with_html_content=True - ) - await ath.db_data_creator.auto_suggestions( - url_ids=buci.url_ids, - num_suggestions=2, - suggestion_type=SuggestionType.AUTO_SUGGESTION - ) - - # User requests next annotation - response = await ath.request_validator.get_next_agency_annotation() - - assert response.next_annotation - next_annotation = response.next_annotation - # Check that url_id matches the one we inserted - assert next_annotation.url_info.url_id == buci.url_ids[0] - - # Check that html data is present - assert next_annotation.html_info.description != "" - assert next_annotation.html_info.title != "" - - # Check that two agency_suggestions exist - assert len(next_annotation.agency_suggestions) == 2 - - for agency_suggestion in next_annotation.agency_suggestions: - assert agency_suggestion.suggestion_type == SuggestionType.AUTO_SUGGESTION - assert agency_suggestion.pdap_agency_id is not None - assert agency_suggestion.agency_name is not None - assert agency_suggestion.state is not None - assert agency_suggestion.county is not None - assert agency_suggestion.locality is not None - - -@pytest.mark.asyncio -async def test_annotate_agency_multiple_auto_suggestions_no_html(api_test_helper): - """ - Test Scenario: Multiple Auto Suggestions - A URL has multiple Agency Auto Suggestion and has not been annotated by the User - The user should receive all of the auto suggestions with full detail - """ - ath = api_test_helper - buci: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1, - with_html_content=False - ) - await ath.db_data_creator.auto_suggestions( - url_ids=buci.url_ids, - num_suggestions=2, - suggestion_type=SuggestionType.AUTO_SUGGESTION - ) - - # User requests next annotation - response = await ath.request_validator.get_next_agency_annotation() - - assert response.next_annotation - next_annotation = response.next_annotation - # Check that url_id matches the one we inserted - assert next_annotation.url_info.url_id == buci.url_ids[0] - - # Check that html data is not present - assert next_annotation.html_info.description == "" - assert next_annotation.html_info.title == "" - -@pytest.mark.asyncio -async def test_annotate_agency_single_unknown_auto_suggestion(api_test_helper): - """ - Test Scenario: Single Unknown Auto Suggestion - A URL has a single Unknown Agency Auto Suggestion and has not been annotated by the User - The user should receive a single Unknown Auto Suggestion lacking other detail - """ - ath = api_test_helper - buci: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1, - with_html_content=True - ) - await ath.db_data_creator.auto_suggestions( - url_ids=buci.url_ids, - num_suggestions=1, - suggestion_type=SuggestionType.UNKNOWN - ) - response = await ath.request_validator.get_next_agency_annotation() - - assert response.next_annotation - next_annotation = response.next_annotation - # Check that url_id matches the one we inserted - assert next_annotation.url_info.url_id == buci.url_ids[0] - - # Check that html data is present - assert next_annotation.html_info.description != "" - assert next_annotation.html_info.title != "" - - # Check that one agency_suggestion exists - assert len(next_annotation.agency_suggestions) == 1 - - agency_suggestion = next_annotation.agency_suggestions[0] - - assert agency_suggestion.suggestion_type == SuggestionType.UNKNOWN - assert agency_suggestion.pdap_agency_id is 
None - assert agency_suggestion.agency_name is None - assert agency_suggestion.state is None - assert agency_suggestion.county is None - assert agency_suggestion.locality is None - - -@pytest.mark.asyncio -async def test_annotate_agency_single_confirmed_agency(api_test_helper): - """ - Test Scenario: Single Confirmed Agency - A URL has a single Confirmed Agency and has not been annotated by the User - The user should not receive this URL to annotate - """ - ath = api_test_helper - buci: BatchURLCreationInfo = await ath.db_data_creator.batch_and_urls( - url_count=1, - with_html_content=True - ) - await ath.db_data_creator.confirmed_suggestions( - url_ids=buci.url_ids, - ) - response = await ath.request_validator.get_next_agency_annotation() - assert response.next_annotation is None - -@pytest.mark.asyncio -async def test_annotate_agency_other_user_annotation(api_test_helper): - """ - Test Scenario: Other User Annotation - A URL has been annotated by another User - Our user should still receive this URL to annotate - """ - ath = api_test_helper - setup_info: AnnotateAgencySetupInfo = await setup_for_annotate_agency( - db_data_creator=ath.db_data_creator, - url_count=1 - ) - url_ids = setup_info.url_ids - - response = await ath.request_validator.get_next_agency_annotation() - - assert response.next_annotation - next_annotation = response.next_annotation - # Check that url_id matches the one we inserted - assert next_annotation.url_info.url_id == url_ids[0] - - # Check that html data is present - assert next_annotation.html_info.description != "" - assert next_annotation.html_info.title != "" - - # Check that one agency_suggestion exists - assert len(next_annotation.agency_suggestions) == 1 - - # Test that another user can insert a suggestion - await ath.db_data_creator.manual_suggestion( - user_id=MOCK_USER_ID + 1, - url_id=url_ids[0], - ) - - # After this, text that our user does not receive this URL - response = await ath.request_validator.get_next_agency_annotation() - assert response.next_annotation is None - -@pytest.mark.asyncio -async def test_annotate_agency_submit_and_get_next(api_test_helper): - """ - Test Scenario: Submit and Get Next (no other URL available) - A URL has been annotated by our User, and no other valid URLs have not been annotated - Our user should not receive another URL to annotate - Until another relevant URL is added - """ - ath = api_test_helper - setup_info: AnnotateAgencySetupInfo = await setup_for_annotate_agency( - db_data_creator=ath.db_data_creator, - url_count=2 - ) - url_ids = setup_info.url_ids - - # User should submit an annotation and receive the next - response = await ath.request_validator.post_agency_annotation_and_get_next( - url_id=url_ids[0], - agency_annotation_post_info=URLAgencyAnnotationPostInfo( - suggested_agency=await ath.db_data_creator.agency(), - is_new=False - ) - - ) - assert response.next_annotation is not None - - # User should submit this annotation and receive none for the next - response = await ath.request_validator.post_agency_annotation_and_get_next( - url_id=url_ids[1], - agency_annotation_post_info=URLAgencyAnnotationPostInfo( - suggested_agency=await ath.db_data_creator.agency(), - is_new=False - ) - ) - assert response.next_annotation is None - - -@pytest.mark.asyncio -async def test_annotate_agency_submit_new(api_test_helper): - """ - Test Scenario: Submit New - Our user receives an annotation and marks it as `NEW` - This should complete successfully - And within the database the annotation should be marked as `NEW` - 
""" - ath = api_test_helper - adb_client = ath.adb_client() - setup_info: AnnotateAgencySetupInfo = await setup_for_annotate_agency( - db_data_creator=ath.db_data_creator, - url_count=1 - ) - url_ids = setup_info.url_ids - - # User should submit an annotation and mark it as New - response = await ath.request_validator.post_agency_annotation_and_get_next( - url_id=url_ids[0], - agency_annotation_post_info=URLAgencyAnnotationPostInfo( - suggested_agency=await ath.db_data_creator.agency(), - is_new=True - ) - ) - assert response.next_annotation is None - - # Within database, the annotation should be marked as `NEW` - all_manual_suggestions = await adb_client.get_all(UserUrlAgencySuggestion) - assert len(all_manual_suggestions) == 1 - assert all_manual_suggestions[0].is_new - -@pytest.mark.asyncio -async def test_annotate_all(api_test_helper): - """ - Test the happy path workflow for the all-annotations endpoint - The user should be able to get a valid URL (filtering on batch id if needed), - submit a full annotation, and receive another URL - """ - ath = api_test_helper - adb_client = ath.adb_client() - setup_info_1 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - url_mapping_1 = setup_info_1.url_mapping - setup_info_2 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - url_mapping_2 = setup_info_2.url_mapping - - # First, get a valid URL to annotate - get_response_1 = await ath.request_validator.get_next_url_for_all_annotations() - - # Apply the second batch id as a filter and see that a different URL is returned - get_response_2 = await ath.request_validator.get_next_url_for_all_annotations( - batch_id=setup_info_2.batch_id - ) - - assert get_response_1.next_annotation.url_info.url_id != get_response_2.next_annotation.url_info.url_id - - # Annotate the first and submit - agency_id = await ath.db_data_creator.agency() - post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( - url_id=url_mapping_1.url_id, - all_annotations_post_info=AllAnnotationPostInfo( - suggested_status=SuggestedStatus.RELEVANT, - record_type=RecordType.ACCIDENT_REPORTS, - agency=URLAgencyAnnotationPostInfo( - is_new=False, - suggested_agency=agency_id - ) - ) - ) - assert post_response_1.next_annotation is not None - - # Confirm the second is received - assert post_response_1.next_annotation.url_info.url_id == url_mapping_2.url_id - - # Upon submitting the second, confirm that no more URLs are returned through either POST or GET - post_response_2 = await ath.request_validator.post_all_annotations_and_get_next( - url_id=url_mapping_2.url_id, - all_annotations_post_info=AllAnnotationPostInfo( - suggested_status=SuggestedStatus.NOT_RELEVANT, - ) - ) - assert post_response_2.next_annotation is None - - get_response_3 = await ath.request_validator.get_next_url_for_all_annotations() - assert get_response_3.next_annotation is None - - - # Check that all annotations are present in the database - - # Should be two relevance annotations, one True and one False - all_relevance_suggestions: list[UserRelevantSuggestion] = await adb_client.get_all(UserRelevantSuggestion) - assert len(all_relevance_suggestions) == 2 - assert all_relevance_suggestions[0].suggested_status == SuggestedStatus.RELEVANT.value - assert all_relevance_suggestions[1].suggested_status == SuggestedStatus.NOT_RELEVANT.value - - # Should be one agency - all_agency_suggestions = await 
adb_client.get_all(UserUrlAgencySuggestion) - assert len(all_agency_suggestions) == 1 - assert all_agency_suggestions[0].is_new == False - assert all_agency_suggestions[0].agency_id == agency_id - - # Should be one record type - all_record_type_suggestions = await adb_client.get_all(UserRecordTypeSuggestion) - assert len(all_record_type_suggestions) == 1 - assert all_record_type_suggestions[0].record_type == RecordType.ACCIDENT_REPORTS.value - -@pytest.mark.asyncio -async def test_annotate_all_post_batch_filtering(api_test_helper): - """ - Batch filtering should also work when posting annotations - """ - ath = api_test_helper - adb_client = ath.adb_client() - setup_info_1 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - url_mapping_1 = setup_info_1.url_mapping - setup_info_2 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - setup_info_3 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - url_mapping_3 = setup_info_3.url_mapping - - # Submit the first annotation, using the third batch id, and receive the third URL - post_response_1 = await ath.request_validator.post_all_annotations_and_get_next( - url_id=url_mapping_1.url_id, - batch_id=setup_info_3.batch_id, - all_annotations_post_info=AllAnnotationPostInfo( - suggested_status=SuggestedStatus.RELEVANT, - record_type=RecordType.ACCIDENT_REPORTS, - agency=URLAgencyAnnotationPostInfo( - is_new=True - ) - ) - ) - - assert post_response_1.next_annotation.url_info.url_id == url_mapping_3.url_id - - -@pytest.mark.asyncio -async def test_annotate_all_validation_error(api_test_helper): - """ - Validation errors in the PostInfo DTO should result in a 400 BAD REQUEST response - """ - ath = api_test_helper - setup_info_1 = await setup_for_get_next_url_for_final_review( - db_data_creator=ath.db_data_creator, include_user_annotations=False - ) - url_mapping_1 = setup_info_1.url_mapping - - with pytest.raises(FailedValidationException) as e: - response = await ath.request_validator.post_all_annotations_and_get_next( - url_id=url_mapping_1.url_id, - all_annotations_post_info=AllAnnotationPostInfo( - suggested_status=SuggestedStatus.NOT_RELEVANT, - record_type=RecordType.ACCIDENT_REPORTS - ) - ) diff --git a/tests/automated/integration/api/test_batch.py b/tests/automated/integration/api/test_batch.py deleted file mode 100644 index eea90bf2..00000000 --- a/tests/automated/integration/api/test_batch.py +++ /dev/null @@ -1,237 +0,0 @@ -import pytest - -from src.db.dtos.batch import BatchInfo -from src.db.dtos.url.insert import InsertURLsInfo -from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO -from src.collectors.enums import CollectorType, URLStatus -from src.core.enums import BatchStatus -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters - - -@pytest.mark.asyncio -async def test_get_batch_summaries(api_test_helper): - ath = api_test_helper - - batch_params = [ - TestBatchCreationParameters( - urls=[ - TestURLCreationParameters( - count=1, - status=URLStatus.PENDING - ), - TestURLCreationParameters( - count=2, - status=URLStatus.SUBMITTED - ) - ] - ), - TestBatchCreationParameters( - urls=[ - TestURLCreationParameters( - count=4, - status=URLStatus.NOT_RELEVANT - ), - 
TestURLCreationParameters( - count=3, - status=URLStatus.ERROR - ) - ] - ), - TestBatchCreationParameters( - urls=[ - TestURLCreationParameters( - count=7, - status=URLStatus.DUPLICATE - ), - TestURLCreationParameters( - count=1, - status=URLStatus.SUBMITTED - ) - ] - ) - ] - - batch_1_creation_info = await ath.db_data_creator.batch_v2(batch_params[0]) - batch_2_creation_info = await ath.db_data_creator.batch_v2(batch_params[1]) - batch_3_creation_info = await ath.db_data_creator.batch_v2(batch_params[2]) - - batch_1_id = batch_1_creation_info.batch_id - batch_2_id = batch_2_creation_info.batch_id - batch_3_id = batch_3_creation_info.batch_id - - - response = ath.request_validator.get_batch_statuses() - results = response.results - - assert len(results) == 3 - - result_1 = results[0] - assert result_1.id == batch_1_id - assert result_1.status == BatchStatus.READY_TO_LABEL - counts_1 = result_1.url_counts - assert counts_1.total == 3 - assert counts_1.pending == 1 - assert counts_1.submitted == 2 - assert counts_1.not_relevant == 0 - assert counts_1.duplicate == 0 - assert counts_1.errored == 0 - - result_2 = results[1] - assert result_2.id == batch_2_id - counts_2 = result_2.url_counts - assert counts_2.total == 7 - assert counts_2.not_relevant == 4 - assert counts_2.errored == 3 - assert counts_2.pending == 0 - assert counts_2.submitted == 0 - assert counts_2.duplicate == 0 - - result_3 = results[2] - assert result_3.id == batch_3_id - counts_3 = result_3.url_counts - assert counts_3.total == 8 - assert counts_3.not_relevant == 0 - assert counts_3.errored == 0 - assert counts_3.pending == 0 - assert counts_3.submitted == 1 - assert counts_3.duplicate == 7 - - - - - - -@pytest.mark.asyncio -async def test_get_batch_summaries_pending_url_filter(api_test_helper): - ath = api_test_helper - - # Add an errored out batch - batch_error = await ath.db_data_creator.batch_and_urls( - strategy=CollectorType.EXAMPLE, - url_count=2, - batch_status=BatchStatus.ERROR - ) - - # Add a batch with pending urls - batch_pending = await ath.db_data_creator.batch_and_urls( - strategy=CollectorType.EXAMPLE, - url_count=2, - batch_status=BatchStatus.READY_TO_LABEL, - with_html_content=True, - url_status=URLStatus.PENDING - ) - - # Add a batch with submitted URLs - batch_submitted = await ath.db_data_creator.batch_and_urls( - strategy=CollectorType.EXAMPLE, - url_count=2, - batch_status=BatchStatus.READY_TO_LABEL, - with_html_content=True, - url_status=URLStatus.SUBMITTED - ) - - # Add an aborted batch - batch_aborted = await ath.db_data_creator.batch_and_urls( - strategy=CollectorType.EXAMPLE, - url_count=2, - batch_status=BatchStatus.ABORTED - ) - - # Add a batch with validated URLs - batch_validated = await ath.db_data_creator.batch_and_urls( - strategy=CollectorType.EXAMPLE, - url_count=2, - batch_status=BatchStatus.READY_TO_LABEL, - with_html_content=True, - url_status=URLStatus.VALIDATED - ) - - # Test filter for pending URLs and only retrieve the second batch - pending_urls_results = ath.request_validator.get_batch_statuses( - has_pending_urls=True - ) - - assert len(pending_urls_results.results) == 1 - assert pending_urls_results.results[0].id == batch_pending.batch_id - - # Test filter without pending URLs and retrieve the other four batches - no_pending_urls_results = ath.request_validator.get_batch_statuses( - has_pending_urls=False - ) - - assert len(no_pending_urls_results.results) == 4 - for result in no_pending_urls_results.results: - assert result.id in [ - batch_error.batch_id, - 
batch_submitted.batch_id, - batch_validated.batch_id, - batch_aborted.batch_id - ] - - # Test no filter for pending URLs and retrieve all batches - no_filter_results = ath.request_validator.get_batch_statuses() - - assert len(no_filter_results.results) == 5 - - - - -def test_abort_batch(api_test_helper): - ath = api_test_helper - - dto = ExampleInputDTO( - sleep_time=1 - ) - - batch_id = ath.request_validator.example_collector(dto=dto)["batch_id"] - - response = ath.request_validator.abort_batch(batch_id=batch_id) - - assert response.message == "Batch aborted." - - bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id) - - assert bi.status == BatchStatus.ABORTED - -def test_get_batch_urls(api_test_helper): - - # Insert batch and urls into database - ath = api_test_helper - batch_id = ath.db_data_creator.batch() - iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=101) - - response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=1) - assert len(response.urls) == 100 - # Check that the first url corresponds to the first url inserted - assert response.urls[0].url == iui.url_mappings[0].url - # Check that the last url corresponds to the 100th url inserted - assert response.urls[-1].url == iui.url_mappings[99].url - - - # Check that a more limited set of urls exist - response = ath.request_validator.get_batch_urls(batch_id=batch_id, page=2) - assert len(response.urls) == 1 - # Check that this url corresponds to the last url inserted - assert response.urls[0].url == iui.url_mappings[-1].url - -def test_get_duplicate_urls(api_test_helper): - - # Insert batch and url into database - ath = api_test_helper - batch_id = ath.db_data_creator.batch() - iui: InsertURLsInfo = ath.db_data_creator.urls(batch_id=batch_id, url_count=101) - # Get a list of all url ids - url_ids = [url.url_id for url in iui.url_mappings] - - # Create a second batch which will be associated with the duplicates - dup_batch_id = ath.db_data_creator.batch() - - # Insert duplicate urls into database - ath.db_data_creator.duplicate_urls(duplicate_batch_id=dup_batch_id, url_ids=url_ids) - - response = ath.request_validator.get_batch_url_duplicates(batch_id=dup_batch_id, page=1) - assert len(response.duplicates) == 100 - - response = ath.request_validator.get_batch_url_duplicates(batch_id=dup_batch_id, page=2) - assert len(response.duplicates) == 1 \ No newline at end of file diff --git a/tests/automated/integration/api/test_example_collector.py b/tests/automated/integration/api/test_example_collector.py deleted file mode 100644 index 1e20362d..00000000 --- a/tests/automated/integration/api/test_example_collector.py +++ /dev/null @@ -1,142 +0,0 @@ -import asyncio -from unittest.mock import AsyncMock - -import pytest - -from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse -from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse -from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.batch import BatchInfo -from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO -from src.collectors.source_collectors.example.core import ExampleCollector -from src.collectors.enums import CollectorType -from src.core.logger import AsyncCoreLogger -from src.core.enums import BatchStatus -from tests.helpers.patch_functions import block_sleep -from tests.automated.integration.api.conftest import disable_task_trigger - - 
-@pytest.mark.asyncio -async def test_example_collector(api_test_helper, monkeypatch): - ath = api_test_helper - - barrier = await block_sleep(monkeypatch) - - # Temporarily disable task trigger - disable_task_trigger(ath) - - - logger = AsyncCoreLogger(adb_client=AsyncDatabaseClient(), flush_interval=1) - await logger.__aenter__() - ath.async_core.collector_manager.logger = logger - - dto = ExampleInputDTO( - sleep_time=1 - ) - - # Request Example Collector - data = ath.request_validator.example_collector( - dto=dto - ) - batch_id = data["batch_id"] - assert batch_id is not None - assert data["message"] == "Started example collector." - - # Yield control so coroutine runs up to the barrier - await asyncio.sleep(0) - - - # Check that batch currently shows as In Process - bsr: GetBatchSummariesResponse = ath.request_validator.get_batch_statuses( - status=BatchStatus.IN_PROCESS - ) - assert len(bsr.results) == 1 - bsi: BatchInfo = bsr.results[0] - - assert bsi.id == batch_id - assert bsi.strategy == CollectorType.EXAMPLE.value - assert bsi.status == BatchStatus.IN_PROCESS - - # Release the barrier to resume execution - barrier.release() - - await ath.wait_for_all_batches_to_complete() - - csr: GetBatchSummariesResponse = ath.request_validator.get_batch_statuses( - collector_type=CollectorType.EXAMPLE, - status=BatchStatus.READY_TO_LABEL - ) - - assert len(csr.results) == 1 - bsi: BatchSummary = csr.results[0] - - assert bsi.id == batch_id - assert bsi.strategy == CollectorType.EXAMPLE.value - assert bsi.status == BatchStatus.READY_TO_LABEL - - bi: BatchSummary = ath.request_validator.get_batch_info(batch_id=batch_id) - assert bi.status == BatchStatus.READY_TO_LABEL - assert bi.parameters == dto.model_dump() - assert bi.strategy == CollectorType.EXAMPLE.value - assert bi.user_id is not None - - # Flush early to ensure logs are written - await logger.flush_all() - - lr: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id) - - assert len(lr.logs) > 0 - - # Check that task was triggered - ath.async_core.collector_manager.\ - post_collection_function_trigger.\ - trigger_or_rerun.assert_called_once() - - await logger.__aexit__(None, None, None) - -@pytest.mark.asyncio -async def test_example_collector_error(api_test_helper, monkeypatch): - """ - Test that when an error occurs in a collector, the batch is properly update - """ - ath = api_test_helper - - logger = AsyncCoreLogger(adb_client=AsyncDatabaseClient(), flush_interval=1) - await logger.__aenter__() - ath.async_core.collector_manager.logger = logger - - # Patch the collector to raise an exception during run_implementation - mock = AsyncMock() - mock.side_effect = Exception("Collector failed!") - monkeypatch.setattr(ExampleCollector, 'run_implementation', mock) - - dto = ExampleInputDTO( - sleep_time=1 - ) - - data = ath.request_validator.example_collector( - dto=dto - ) - batch_id = data["batch_id"] - assert batch_id is not None - assert data["message"] == "Started example collector." - - await ath.wait_for_all_batches_to_complete() - - bi: BatchSummary = ath.request_validator.get_batch_info(batch_id=batch_id) - - assert bi.status == BatchStatus.ERROR - - # Check there are logs - assert not logger.log_queue.empty() - await logger.flush_all() - assert logger.log_queue.empty() - - gbl: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id) - assert gbl.logs[-1].log == "Error: Collector failed!" 
- await logger.__aexit__(None, None, None) - - - - diff --git a/tests/automated/integration/api/test_manual_batch.py b/tests/automated/integration/api/test_manual_batch.py index a7be37e4..dae5ee4f 100644 --- a/tests/automated/integration/api/test_manual_batch.py +++ b/tests/automated/integration/api/test_manual_batch.py @@ -2,10 +2,10 @@ import pytest from src.api.endpoints.collector.dtos.manual_batch.post import ManualBatchInnerInputDTO, ManualBatchInputDTO -from src.db.models.instantiations.link.link_batch_urls import LinkBatchURL -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.batch import Batch +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.batch.sqlalchemy import Batch from src.collectors.enums import CollectorType from src.core.enums import RecordType @@ -94,7 +94,7 @@ def check_link(link: LinkBatchURL): def check_url(url: URL, url_only: bool): assert url.url is not None - other_attributes = ["name", "description", "collector_metadata", "record_type"] + other_attributes = ["name", "description", "collector_metadata"] return check_attributes(url, other_attributes, url_only) diff --git a/tests/automated/integration/api/test_task.py b/tests/automated/integration/api/test_task.py index 95ebe003..bda246dc 100644 --- a/tests/automated/integration/api/test_task.py +++ b/tests/automated/integration/api/test_task.py @@ -9,7 +9,7 @@ async def task_setup(ath: APITestHelper) -> int: url_ids = [url.url_id for url in iui.url_mappings] task_id = await ath.db_data_creator.task(url_ids=url_ids) - await ath.db_data_creator.error_info(url_ids=[url_ids[0]], task_id=task_id) + await ath.db_data_creator.task_errors(url_ids=[url_ids[0]], task_id=task_id) return task_id diff --git a/tests/automated/integration/api/test_url.py b/tests/automated/integration/api/test_url.py deleted file mode 100644 index e59c8299..00000000 --- a/tests/automated/integration/api/test_url.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest - -from src.api.endpoints.url.get.dto import GetURLsResponseInfo -from src.db.dtos.url.insert import InsertURLsInfo - - -@pytest.mark.asyncio -async def test_get_urls(api_test_helper): - # Basic test, no results - data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls() - - assert data.urls == [] - assert data.count == 0 - - db_data_creator = api_test_helper.db_data_creator - - # Create batch with status `in-process` and strategy `example` - batch_id = db_data_creator.batch() - # Create 2 URLs with outcome `pending` - iui: InsertURLsInfo = db_data_creator.urls(batch_id=batch_id, url_count=3) - - url_id_1st = iui.url_mappings[0].url_id - - # Get the latter 2 urls - url_ids = [iui.url_mappings[1].url_id, iui.url_mappings[2].url_id] - - # Add errors - await db_data_creator.error_info(url_ids=url_ids) - - - data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls() - assert data.count == 3 - assert len(data.urls) == 3 - assert data.urls[0].url == iui.url_mappings[0].url - - for i in range(1, 3): - assert data.urls[i].url == iui.url_mappings[i].url - assert len(data.urls[i].errors) == 1 - - # Retrieve data again with errors only - data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls(errors=True) - assert 
data.count == 2 - assert len(data.urls) == 2 - for url in data.urls: - assert url.id != url_id_1st - diff --git a/tests/automated/integration/api/url/__init__.py b/tests/automated/integration/api/url/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/url/by_id/__init__.py b/tests/automated/integration/api/url/by_id/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/url/by_id/snapshot/__init__.py b/tests/automated/integration/api/url/by_id/snapshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/api/url/by_id/snapshot/test_not_found.py b/tests/automated/integration/api/url/by_id/snapshot/test_not_found.py new file mode 100644 index 00000000..cce84649 --- /dev/null +++ b/tests/automated/integration/api/url/by_id/snapshot/test_not_found.py @@ -0,0 +1,10 @@ +import pytest + +from tests.helpers.api_test_helper import APITestHelper +from fastapi import Response + +@pytest.mark.asyncio +async def test_get_url_screenshot_not_found(api_test_helper: APITestHelper): + + response: Response = await api_test_helper.request_validator.get_url_screenshot(url_id=1) + assert response.status_code == 404 \ No newline at end of file diff --git a/tests/automated/integration/api/url/by_id/snapshot/test_success.py b/tests/automated/integration/api/url/by_id/snapshot/test_success.py new file mode 100644 index 00000000..e3ea9d73 --- /dev/null +++ b/tests/automated/integration/api/url/by_id/snapshot/test_success.py @@ -0,0 +1,32 @@ +import pytest + +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from tests.automated.integration.api._helpers.RequestValidator import RequestValidator +from tests.helpers.api_test_helper import APITestHelper +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_get_url_screenshot_success( + api_test_helper: APITestHelper +): + ath: APITestHelper = api_test_helper + ddc: DBDataCreator = api_test_helper.db_data_creator + rv: RequestValidator = ath.request_validator + + url_mapping: URLMapping = (await ddc.create_urls())[0] + url_id: int = url_mapping.url_id + + url_screenshot = URLScreenshot( + url_id=url_id, + content=b"test", + file_size=4 + ) + await ddc.adb_client.add(url_screenshot) + + response = await rv.get_url_screenshot(url_id=url_id) + assert response.status_code == 200 + assert response.headers["Content-Type"] == "image/webp" + assert response.content == b"test" + assert response.headers["Content-Length"] == "4" diff --git a/tests/automated/integration/api/url/test_get.py b/tests/automated/integration/api/url/test_get.py new file mode 100644 index 00000000..8c95c670 --- /dev/null +++ b/tests/automated/integration/api/url/test_get.py @@ -0,0 +1,47 @@ +import pytest + +from src.api.endpoints.url.get.dto import GetURLsResponseInfo +from src.db.dtos.url.insert import InsertURLsInfo +from tests.helpers.api_test_helper import APITestHelper + + +@pytest.mark.asyncio +async def test_get_urls(api_test_helper: APITestHelper): + # Basic test, no results + data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls() + + assert data.urls == [] + assert data.count == 0 + + db_data_creator = api_test_helper.db_data_creator + + # Create batch with status `in-process` and strategy `example` + batch_id = db_data_creator.batch() + # Create 3 URLs with outcome `pending` + iui: InsertURLsInfo =
db_data_creator.urls(batch_id=batch_id, url_count=3) + + url_id_1st = iui.url_mappings[0].url_id + + # Get the latter 2 urls + url_ids = [iui.url_mappings[1].url_id, iui.url_mappings[2].url_id] + + # Add errors + await db_data_creator.task_errors(url_ids=url_ids) + + + data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls() + assert data.count == 3 + assert len(data.urls) == 3 + assert data.urls[0].url == iui.url_mappings[0].url + + for i in range(1, 3): + assert data.urls[i].url == iui.url_mappings[i].url + assert len(data.urls[i].errors) == 1 + + # Retrieve data again with errors only + data: GetURLsResponseInfo = api_test_helper.request_validator.get_urls(errors=True) + assert data.count == 2 + assert len(data.urls) == 2 + for url in data.urls: + assert url.id != url_id_1st + diff --git a/tests/automated/integration/conftest.py b/tests/automated/integration/conftest.py index 7e4fc535..574f35f4 100644 --- a/tests/automated/integration/conftest.py +++ b/tests/automated/integration/conftest.py @@ -1,11 +1,16 @@ from unittest.mock import MagicMock import pytest +import pytest_asyncio from src.collectors.manager import AsyncCollectorManager from src.core.core import AsyncCore from src.core.logger import AsyncCoreLogger from src.db.client.async_ import AsyncDatabaseClient +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo +from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo @pytest.fixture @@ -25,4 +30,67 @@ def test_async_core(adb_client_test): ) yield core core.shutdown() - logger.shutdown() \ No newline at end of file + logger.shutdown() + +@pytest_asyncio.fixture +async def pennsylvania( + db_data_creator: DBDataCreator +) -> USStateCreationInfo: + """Creates Pennsylvania state and returns its state and location ID""" + return await db_data_creator.create_us_state( + name="Pennsylvania", + iso="PA" + ) + +@pytest_asyncio.fixture +async def allegheny_county( + db_data_creator: DBDataCreator, + pennsylvania: USStateCreationInfo +) -> CountyCreationInfo: + return await db_data_creator.create_county( + state_id=pennsylvania.us_state_id, + name="Allegheny" + ) + +@pytest_asyncio.fixture +async def pittsburgh_locality( + db_data_creator: DBDataCreator, + pennsylvania: USStateCreationInfo, + allegheny_county: CountyCreationInfo +) -> LocalityCreationInfo: + return await db_data_creator.create_locality( + state_id=pennsylvania.us_state_id, + county_id=allegheny_county.county_id, + name="Pittsburgh" + ) + +@pytest_asyncio.fixture +async def california( + db_data_creator: DBDataCreator, +) -> USStateCreationInfo: + return await db_data_creator.create_us_state( + name="California", + iso="CA" + ) + +@pytest_asyncio.fixture +async def los_angeles_county( + db_data_creator: DBDataCreator, + california: USStateCreationInfo +) -> CountyCreationInfo: + return await db_data_creator.create_county( + state_id=california.us_state_id, + name="Los Angeles" + ) + +@pytest_asyncio.fixture +async def los_angeles_locality( + db_data_creator: DBDataCreator, + california: USStateCreationInfo, + los_angeles_county: CountyCreationInfo +) -> LocalityCreationInfo: + return await db_data_creator.create_locality( + state_id=california.us_state_id, + county_id=los_angeles_county.county_id, + name="Los Angeles" + ) \ No newline at end of file diff --git 
a/tests/automated/integration/core/async_/conclude_task/helpers.py b/tests/automated/integration/core/async_/conclude_task/helpers.py index 35e106c8..923b3cc9 100644 --- a/tests/automated/integration/core/async_/conclude_task/helpers.py +++ b/tests/automated/integration/core/async_/conclude_task/helpers.py @@ -1,4 +1,4 @@ -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo +from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from src.db.enums import TaskType from tests.automated.integration.core.async_.conclude_task.setup_info import TestAsyncCoreSetupInfo @@ -9,10 +9,9 @@ def setup_run_info( outcome: TaskOperatorOutcome, message: str = "" ): - run_info = URLTaskOperatorRunInfo( + run_info = TaskOperatorRunInfo( task_id=setup_info.task_id, task_type=TaskType.HTML, - linked_url_ids=setup_info.url_ids, outcome=outcome, message=message, ) diff --git a/tests/automated/integration/core/async_/conclude_task/test_error.py b/tests/automated/integration/core/async_/conclude_task/test_error.py index 0f92fd26..1a31b87e 100644 --- a/tests/automated/integration/core/async_/conclude_task/test_error.py +++ b/tests/automated/integration/core/async_/conclude_task/test_error.py @@ -1,13 +1,12 @@ import pytest from src.core.enums import BatchStatus -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome -from src.db.enums import TaskType +from src.db.models.impl.task.enums import TaskStatus from tests.automated.integration.core.async_.conclude_task.helpers import setup_run_info from tests.automated.integration.core.async_.conclude_task.setup_info import TestAsyncCoreSetupInfo from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -27,6 +26,5 @@ async def test_conclude_task_error( task_info = await ddc.adb_client.get_task_info(task_id=setup.task_id) - assert task_info.task_status == BatchStatus.ERROR + assert task_info.task_status == TaskStatus.ERROR assert task_info.error_info == "test error" - assert len(task_info.urls) == 3 diff --git a/tests/automated/integration/core/async_/conclude_task/test_success.py b/tests/automated/integration/core/async_/conclude_task/test_success.py index 19bd0f4f..03cc5b52 100644 --- a/tests/automated/integration/core/async_/conclude_task/test_success.py +++ b/tests/automated/integration/core/async_/conclude_task/test_success.py @@ -1,13 +1,12 @@ import pytest from src.core.enums import BatchStatus -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome -from src.db.enums import TaskType +from src.db.models.impl.task.enums import TaskStatus from tests.automated.integration.core.async_.conclude_task.helpers import setup_run_info from tests.automated.integration.core.async_.conclude_task.setup_info import TestAsyncCoreSetupInfo from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -27,5 +26,4 @@ async def test_conclude_task_success( task_info = await ddc.adb_client.get_task_info(task_id=setup.task_id) - assert task_info.task_status == BatchStatus.READY_TO_LABEL - assert len(task_info.urls) == 3 + assert task_info.task_status == TaskStatus.COMPLETE diff 
--git a/tests/automated/integration/core/async_/run_task/test_break_loop.py b/tests/automated/integration/core/async_/run_task/test_break_loop.py index e438c26d..71b5704f 100644 --- a/tests/automated/integration/core/async_/run_task/test_break_loop.py +++ b/tests/automated/integration/core/async_/run_task/test_break_loop.py @@ -1,13 +1,15 @@ import types -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, create_autospec import pytest +from src.core.tasks.base.run_info import TaskOperatorRunInfo +from src.core.tasks.url.models.entry import URLTaskEntry +from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.db.enums import TaskType -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -19,23 +21,26 @@ async def test_run_task_break_loop(db_data_creator: DBDataCreator): and an alert should be sent to discord """ - async def run_task(self, task_id: int) -> URLTaskOperatorRunInfo: - return URLTaskOperatorRunInfo( - task_id=task_id, + async def run_task(self) -> TaskOperatorRunInfo: + return TaskOperatorRunInfo( + task_id=1, outcome=TaskOperatorOutcome.SUCCESS, - linked_url_ids=[1, 2, 3], task_type=TaskType.HTML ) core = setup_async_core(db_data_creator.adb_client) core.task_manager.conclude_task = AsyncMock() - mock_operator = AsyncMock() + mock_operator = create_autospec(URLTaskOperatorBase, instance=True) mock_operator.meets_task_prerequisites = AsyncMock(return_value=True) mock_operator.task_type = TaskType.HTML mock_operator.run_task = types.MethodType(run_task, mock_operator) + entry = URLTaskEntry( + operator=mock_operator, + enabled=True + ) - core.task_manager.loader.get_task_operators = AsyncMock(return_value=[mock_operator]) + core.task_manager.loader.load_entries = AsyncMock(return_value=[entry]) await core.task_manager.trigger_task_run() core.task_manager.handler.discord_poster.post_to_discord.assert_called_once_with( diff --git a/tests/automated/integration/core/async_/run_task/test_prereq_met.py b/tests/automated/integration/core/async_/run_task/test_prereq_met.py index b171402d..e5425fd9 100644 --- a/tests/automated/integration/core/async_/run_task/test_prereq_met.py +++ b/tests/automated/integration/core/async_/run_task/test_prereq_met.py @@ -1,51 +1,50 @@ import types -from unittest.mock import AsyncMock, call +from unittest.mock import AsyncMock, call, create_autospec import pytest from src.core.enums import BatchStatus -from src.core.tasks.dtos.run_info import URLTaskOperatorRunInfo +from src.core.tasks.base.run_info import TaskOperatorRunInfo from src.core.tasks.url.enums import TaskOperatorOutcome +from src.core.tasks.url.models.entry import URLTaskEntry +from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.db.enums import TaskType -from src.db.models.instantiations.task.core import Task +from src.db.models.impl.task.core import Task from tests.automated.integration.core.async_.helpers import setup_async_core -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio async def test_run_task_prereq_met(db_data_creator: DBDataCreator): """ When a task pre-requisite is met, the task should be run - And a task entry should be created in the database """ - 
async def run_task(self, task_id: int) -> URLTaskOperatorRunInfo: - return URLTaskOperatorRunInfo( - task_id=task_id, + async def run_task(self) -> TaskOperatorRunInfo: + return TaskOperatorRunInfo( + task_id=1, task_type=TaskType.HTML, outcome=TaskOperatorOutcome.SUCCESS, - linked_url_ids=[1, 2, 3] ) core = setup_async_core(db_data_creator.adb_client) core.task_manager.conclude_task = AsyncMock() - mock_operator = AsyncMock() + mock_operator = create_autospec(URLTaskOperatorBase, instance=True) mock_operator.meets_task_prerequisites = AsyncMock( side_effect=[True, False] ) mock_operator.task_type = TaskType.HTML mock_operator.run_task = types.MethodType(run_task, mock_operator) + entry = URLTaskEntry( + operator=mock_operator, + enabled=True + ) - core.task_manager.loader.get_task_operators = AsyncMock(return_value=[mock_operator]) + core.task_manager.loader.load_entries = AsyncMock(return_value=[entry]) await core.run_tasks() # There should be two calls to meets_task_prerequisites mock_operator.meets_task_prerequisites.assert_has_calls([call(), call()]) - results = await db_data_creator.adb_client.get_all(Task) - - assert len(results) == 1 - assert results[0].task_status == BatchStatus.IN_PROCESS.value - core.task_manager.conclude_task.assert_called_once() diff --git a/tests/automated/integration/core/async_/run_task/test_prereq_not_met.py b/tests/automated/integration/core/async_/run_task/test_prereq_not_met.py index ef068cd5..286c14dd 100644 --- a/tests/automated/integration/core/async_/run_task/test_prereq_not_met.py +++ b/tests/automated/integration/core/async_/run_task/test_prereq_not_met.py @@ -1,7 +1,9 @@ -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, create_autospec import pytest +from src.core.tasks.url.models.entry import URLTaskEntry +from src.core.tasks.url.operators.base import URLTaskOperatorBase from tests.automated.integration.core.async_.helpers import setup_async_core @@ -12,9 +14,10 @@ async def test_run_task_prereq_not_met(): """ core = setup_async_core(AsyncMock()) - mock_operator = AsyncMock() + mock_operator = create_autospec(URLTaskOperatorBase, instance=True) mock_operator.meets_task_prerequisites = AsyncMock(return_value=False) - core.task_manager.loader.get_task_operators = AsyncMock(return_value=[mock_operator]) + entry = URLTaskEntry(operator=mock_operator, enabled=True) + core.task_manager.loader.load_entries = AsyncMock(return_value=[entry]) await core.run_tasks() mock_operator.meets_task_prerequisites.assert_called_once() diff --git a/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py b/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py index 33a93998..c419fb70 100644 --- a/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py +++ b/tests/automated/integration/db/client/annotate_url/test_agency_not_in_db.py @@ -1,9 +1,9 @@ import pytest from src.db.constants import PLACEHOLDER_AGENCY_NAME -from src.db.models.instantiations.agency import Agency +from src.db.models.impl.agency.sqlalchemy import Agency from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py b/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py deleted file mode 100644 index ccf76dc8..00000000 --- 
a/tests/automated/integration/db/client/annotate_url/test_marked_not_relevant.py +++ /dev/null @@ -1,66 +0,0 @@ -import pytest - -from src.core.enums import SuggestedStatus -from src.db.dtos.url.mapping import URLMapping -from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_annotate_url_marked_not_relevant(db_data_creator: DBDataCreator): - """ - If a URL is marked not relevant by the user, they should not receive that URL - in calls to get an annotation for record type or agency - Other users should still receive the URL - """ - setup_info = await setup_for_get_next_url_for_annotation( - db_data_creator=db_data_creator, - url_count=2 - ) - adb_client = db_data_creator.adb_client - url_to_mark_not_relevant: URLMapping = setup_info.insert_urls_info.url_mappings[0] - url_to_mark_relevant: URLMapping = setup_info.insert_urls_info.url_mappings[1] - for url_mapping in setup_info.insert_urls_info.url_mappings: - await db_data_creator.agency_auto_suggestions( - url_id=url_mapping.url_id, - count=3 - ) - await adb_client.add_user_relevant_suggestion( - user_id=1, - url_id=url_to_mark_not_relevant.url_id, - suggested_status=SuggestedStatus.NOT_RELEVANT - ) - await adb_client.add_user_relevant_suggestion( - user_id=1, - url_id=url_to_mark_relevant.url_id, - suggested_status=SuggestedStatus.RELEVANT - ) - - # User should not receive the URL for record type annotation - record_type_annotation_info = await adb_client.get_next_url_for_record_type_annotation( - user_id=1, - batch_id=None - ) - assert record_type_annotation_info.url_info.url_id != url_to_mark_not_relevant.url_id - - # Other users also should not receive the URL for record type annotation - record_type_annotation_info = await adb_client.get_next_url_for_record_type_annotation( - user_id=2, - batch_id=None - ) - assert record_type_annotation_info.url_info.url_id != \ - url_to_mark_not_relevant.url_id, "Other users should not receive the URL for record type annotation" - - # User should not receive the URL for agency annotation - agency_annotation_info_user_1 = await adb_client.get_next_url_agency_for_annotation( - user_id=1, - batch_id=None - ) - assert agency_annotation_info_user_1.next_annotation.url_info.url_id != url_to_mark_not_relevant.url_id - - # Other users also should not receive the URL for agency annotation - agency_annotation_info_user_2 = await adb_client.get_next_url_agency_for_annotation( - user_id=2, - batch_id=None - ) - assert agency_annotation_info_user_1.next_annotation.url_info.url_id != url_to_mark_not_relevant.url_id diff --git a/tests/automated/integration/db/client/approve_url/test_basic.py b/tests/automated/integration/db/client/approve_url/test_basic.py index 590f9cd1..c9eb62b1 100644 --- a/tests/automated/integration/db/client/approve_url/test_basic.py +++ b/tests/automated/integration/db/client/approve_url/test_basic.py @@ -3,12 +3,14 @@ from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo from src.collectors.enums import URLStatus from src.core.enums import RecordType -from src.db.models.instantiations.confirmed_url_agency import ConfirmedURLAgency -from src.db.models.instantiations.url.core import URL -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.reviewing_user import ReviewingUserURL +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated 
+from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType +from src.db.models.impl.url.reviewing_user import ReviewingUserURL from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -41,12 +43,21 @@ async def test_approve_url_basic(db_data_creator: DBDataCreator): assert len(urls) == 1 url = urls[0] assert url.id == url_mapping.url_id - assert url.record_type == RecordType.ARREST_RECORDS.value - assert url.outcome == URLStatus.VALIDATED.value + assert url.status == URLStatus.OK assert url.name == "Test Name" assert url.description == "Test Description" - confirmed_agency: list[ConfirmedURLAgency] = await adb_client.get_all(ConfirmedURLAgency) + record_types: list[URLRecordType] = await adb_client.get_all(URLRecordType) + assert len(record_types) == 1 + assert record_types[0].record_type == RecordType.ARREST_RECORDS + + # Confirm presence of validated flag + validated_flags: list[FlagURLValidated] = await adb_client.get_all(FlagURLValidated) + assert len(validated_flags) == 1 + assert validated_flags[0].url_id == url_mapping.url_id + + + confirmed_agency: list[LinkURLAgency] = await adb_client.get_all(LinkURLAgency) assert len(confirmed_agency) == 1 assert confirmed_agency[0].url_id == url_mapping.url_id assert confirmed_agency[0].agency_id == agency_id diff --git a/tests/automated/integration/db/client/approve_url/test_error.py b/tests/automated/integration/db/client/approve_url/test_error.py index 52871e76..352e737a 100644 --- a/tests/automated/integration/db/client/approve_url/test_error.py +++ b/tests/automated/integration/db/client/approve_url/test_error.py @@ -4,7 +4,7 @@ from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo from src.core.enums import RecordType from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -30,10 +30,8 @@ async def test_approval_url_error(db_data_creator: DBDataCreator): # Create kwarg dictionary with all required approval info fields kwarg_dict = { - "record_type": RecordType.ARREST_RECORDS, "agency_ids": [await db_data_creator.agency()], "name": "Test Name", - "description": "Test Description", } # For each keyword, create a copy of the kwargs and set that one to none # Confirm it produces the correct error diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py deleted file mode 100644 index adb48844..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_basic.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest - -from src.core.enums import SuggestedStatus, RecordType -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_final_review_basic(db_data_creator: DBDataCreator): - """ - Test that an annotated URL is returned - """ - - setup_info = await 
setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=1, - include_user_annotations=True - ) - - url_mapping = setup_info.url_mapping - # Add agency auto suggestions - await db_data_creator.agency_auto_suggestions( - url_id=url_mapping.url_id, - count=3 - ) - - - outer_result = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - result = outer_result.next_source - - assert result.url == url_mapping.url - html_info = result.html_info - assert html_info.description == "test description" - assert html_info.title == "test html content" - - annotation_info = result.annotations - relevant_info = annotation_info.relevant - assert relevant_info.auto.is_relevant == True - assert relevant_info.user == SuggestedStatus.NOT_RELEVANT - - record_type_info = annotation_info.record_type - assert record_type_info.auto == RecordType.ARREST_RECORDS - assert record_type_info.user == RecordType.ACCIDENT_REPORTS - - agency_info = annotation_info.agency - auto_agency_suggestions = agency_info.auto - assert auto_agency_suggestions.unknown == False - assert len(auto_agency_suggestions.suggestions) == 3 - - # Check user agency suggestion exists and is correct - assert agency_info.user.pdap_agency_id == setup_info.user_agency_id diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py deleted file mode 100644 index bce7d8e2..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_batch_id_filtering.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_final_review_batch_id_filtering(db_data_creator: DBDataCreator): - setup_info_1 = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=3, - include_user_annotations=True - ) - - setup_info_2 = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=3, - include_user_annotations=True - ) - - url_mapping_1 = setup_info_1.url_mapping - url_mapping_2 = setup_info_2.url_mapping - - # If a batch id is provided, return first valid URL with that batch id - result_with_batch_id = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=setup_info_2.batch_id - ) - - assert result_with_batch_id.next_source.url == url_mapping_2.url - - # If no batch id is provided, return first valid URL - result_no_batch_id =await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - - assert result_no_batch_id.next_source.url == url_mapping_1.url diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py deleted file mode 100644 index 874dba18..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_favor_more_components.py +++ /dev/null @@ -1,42 +0,0 @@ -import pytest - -from src.core.enums import SuggestionType -from tests.helpers.setup.final_review.core import setup_for_get_next_url_for_final_review -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def 
test_get_next_url_for_final_review_favor_more_components(db_data_creator: DBDataCreator): - """ - Test in the case of two URLs, favoring the one with more annotations for more components - i.e., if one has annotations for record type and agency id, that should be favored over one with just record type - """ - - setup_info_without_user_anno = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=3, - include_user_annotations=False - ) - url_mapping_without_user_anno = setup_info_without_user_anno.url_mapping - - setup_info_with_user_anno = await setup_for_get_next_url_for_final_review( - db_data_creator=db_data_creator, - annotation_count=3, - include_user_annotations=True - ) - url_mapping_with_user_anno = setup_info_with_user_anno.url_mapping - - # Have both be listed as unknown - - for url_mapping in [url_mapping_with_user_anno, url_mapping_without_user_anno]: - await db_data_creator.agency_auto_suggestions( - url_id=url_mapping.url_id, - count=3, - suggestion_type=SuggestionType.UNKNOWN - ) - - result = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - - assert result.next_source.id == url_mapping_with_user_anno.url_id diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py deleted file mode 100644 index 4b04d4d1..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_new_agency.py +++ /dev/null @@ -1,41 +0,0 @@ -import pytest - -from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.core.enums import SuggestedStatus, RecordType, SuggestionType -from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo -from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters -from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_final_review_new_agency(db_data_creator: DBDataCreator): - """ - Test that a URL with a new agency is properly returned - """ - - # Apply batch v2 - parameters = TestBatchCreationParameters( - urls=[ - TestURLCreationParameters( - annotation_info=AnnotationInfo( - user_relevant=SuggestedStatus.RELEVANT, - user_agency=URLAgencyAnnotationPostInfo( - is_new=True - ), - user_record_type=RecordType.ARREST_RECORDS - ) - ) - ] - ) - creation_info = await db_data_creator.batch_v2(parameters) - outer_result = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - result = outer_result.next_source - - assert result is not None - user_suggestion = result.annotations.agency.user - assert user_suggestion.suggestion_type == SuggestionType.NEW_AGENCY - assert user_suggestion.pdap_agency_id is None - assert user_suggestion.agency_name is None diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py deleted file mode 100644 index b82ebee2..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_not_annotations.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def 
test_get_next_url_for_final_review_no_annotations(db_data_creator: DBDataCreator): - """ - Test in the case of one URL with no annotations. - No annotations should be returned - """ - batch_id = db_data_creator.batch() - url_mapping = db_data_creator.urls(batch_id=batch_id, url_count=1).url_mappings[0] - - result = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - - assert result.next_source is None diff --git a/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py b/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py deleted file mode 100644 index 6c9a29c8..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_final_review/test_only_confirmed_urls.py +++ /dev/null @@ -1,24 +0,0 @@ -import pytest - -from src.collectors.enums import URLStatus -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_final_review_only_confirmed_urls(db_data_creator: DBDataCreator): - """ - Test in the case of one URL that is submitted - Should not be returned. - """ - batch_id = db_data_creator.batch() - url_mapping = db_data_creator.urls( - batch_id=batch_id, - url_count=1, - outcome=URLStatus.SUBMITTED - ).url_mappings[0] - - result = await db_data_creator.adb_client.get_next_url_for_final_review( - batch_id=None - ) - - assert result.next_source is None diff --git a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py deleted file mode 100644 index 57c6ae35..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_pending.py +++ /dev/null @@ -1,68 +0,0 @@ -import pytest - -from src.core.enums import SuggestedStatus -from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_user_relevance_annotation_pending( - db_data_creator: DBDataCreator -): - """ - Users should receive a valid URL to annotate - All users should receive the same next URL - Once any user annotates that URL, none of the users should receive it again - """ - setup_info = await setup_for_get_next_url_for_annotation( - db_data_creator=db_data_creator, - url_count=2 - ) - - url_1 = setup_info.insert_urls_info.url_mappings[0] - - # Add `Relevancy` attribute with value `True` - await db_data_creator.auto_relevant_suggestions( - url_id=url_1.url_id, - relevant=True - ) - - adb_client = db_data_creator.adb_client - url_1 = await adb_client.get_next_url_for_relevance_annotation( - user_id=1, - batch_id=None - ) - assert url_1 is not None - - url_2 = await adb_client.get_next_url_for_relevance_annotation( - user_id=2, - batch_id=None - ) - assert url_2 is not None - - assert url_1.url_info.url == url_2.url_info.url - - # Annotate this URL, then check that the second URL is returned - await adb_client.add_user_relevant_suggestion( - url_id=url_1.url_info.url_id, - user_id=1, - suggested_status=SuggestedStatus.RELEVANT - ) - - url_3 = await adb_client.get_next_url_for_relevance_annotation( - user_id=1, - batch_id=None - ) - assert url_3 is not None - - assert url_1 != url_3 - - # Check that the second URL is also returned for another user - url_4 = await adb_client.get_next_url_for_relevance_annotation( - user_id=2, - batch_id=None - 
) - assert url_4 is not None - - - assert url_4 == url_3 diff --git a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py b/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py deleted file mode 100644 index 3736c2b8..00000000 --- a/tests/automated/integration/db/client/get_next_url_for_user_relevance_annotation/test_validated.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest - -from src.collectors.enums import URLStatus -from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_user_relevance_annotation_validated( - db_data_creator: DBDataCreator -): - """ - A validated URL should not turn up in get_next_url_for_user_annotation - """ - - setup_info = await setup_for_get_next_url_for_annotation( - db_data_creator=db_data_creator, - url_count=1, - outcome=URLStatus.VALIDATED - ) - - - url_1 = setup_info.insert_urls_info.url_mappings[0] - - # Add `Relevancy` attribute with value `True` - await db_data_creator.auto_relevant_suggestions( - url_id=url_1.url_id, - relevant=True - ) - - adb_client = db_data_creator.adb_client - url = await adb_client.get_next_url_for_relevance_annotation( - user_id=1, - batch_id=None - ) - assert url is None diff --git a/tests/automated/integration/db/client/test_add_url_error_info.py b/tests/automated/integration/db/client/test_add_url_error_info.py deleted file mode 100644 index 34d103ce..00000000 --- a/tests/automated/integration/db/client/test_add_url_error_info.py +++ /dev/null @@ -1,37 +0,0 @@ -import pytest - -from src.db.client.async_ import AsyncDatabaseClient -from src.db.dtos.url.error import URLErrorPydanticInfo -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_add_url_error_info(db_data_creator: DBDataCreator): - batch_id = db_data_creator.batch() - url_mappings = db_data_creator.urls(batch_id=batch_id, url_count=3).url_mappings - url_ids = [url_mapping.url_id for url_mapping in url_mappings] - - adb_client = AsyncDatabaseClient() - task_id = await db_data_creator.task() - - error_infos = [] - for url_mapping in url_mappings: - uei = URLErrorPydanticInfo( - url_id=url_mapping.url_id, - error="test error", - task_id=task_id - ) - - error_infos.append(uei) - - await adb_client.add_url_error_infos( - url_error_infos=error_infos - ) - - results = await adb_client.get_urls_with_errors() - - assert len(results) == 3 - - for result in results: - assert result.url_id in url_ids - assert result.error == "test error" diff --git a/tests/automated/integration/db/client/test_delete_old_logs.py b/tests/automated/integration/db/client/test_delete_old_logs.py index d451af8f..7c2c2b62 100644 --- a/tests/automated/integration/db/client/test_delete_old_logs.py +++ b/tests/automated/integration/db/client/test_delete_old_logs.py @@ -2,8 +2,9 @@ import pytest -from src.db.dtos.log import LogInfo -from tests.helpers.db_data_creator import DBDataCreator +from src.core.tasks.scheduled.impl.delete_logs.operator import DeleteOldLogsTaskOperator +from src.db.models.impl.log.pydantic.info import LogInfo +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -13,13 +14,16 @@ async def test_delete_old_logs(db_data_creator: DBDataCreator): old_datetime = datetime.now() - timedelta(days=7) db_client = db_data_creator.db_client adb_client = db_data_creator.adb_client + 
operator = DeleteOldLogsTaskOperator( + adb_client=adb_client, + ) log_infos = [] for i in range(3): log_infos.append(LogInfo(log="test log", batch_id=batch_id, created_at=old_datetime)) db_client.insert_logs(log_infos=log_infos) logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id) assert len(logs) == 3 - await adb_client.delete_old_logs() + await operator.inner_task_logic() logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id) assert len(logs) == 0 diff --git a/tests/automated/integration/db/client/test_delete_url_updated_at.py b/tests/automated/integration/db/client/test_delete_url_updated_at.py index a6ca731b..3c50c505 100644 --- a/tests/automated/integration/db/client/test_delete_url_updated_at.py +++ b/tests/automated/integration/db/client/test_delete_url_updated_at.py @@ -1,5 +1,5 @@ -from src.db.dtos.url.core import URLInfo -from tests.helpers.db_data_creator import DBDataCreator +from src.db.models.impl.url.core.pydantic.info import URLInfo +from tests.helpers.data_creator.core import DBDataCreator def test_delete_url_updated_at(db_data_creator: DBDataCreator): diff --git a/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py b/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py index 5a402727..86d4a3ee 100644 --- a/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py +++ b/tests/automated/integration/db/client/test_get_next_url_for_annotation_batch_filtering.py @@ -1,8 +1,9 @@ import pytest +from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse from src.core.enums import SuggestionType from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio @@ -31,38 +32,38 @@ def assert_batch_info(batch_info): # Test for relevance # If a batch id is provided, return first valid URL with that batch id - result_with_batch_id = await db_data_creator.adb_client.get_next_url_for_relevance_annotation( + result_with_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=setup_info_2.batch_id ) - assert result_with_batch_id.url_info.url == url_2.url - assert_batch_info(result_with_batch_id.batch_info) + assert result_with_batch_id.next_annotation.url_info.url == url_2.url + assert_batch_info(result_with_batch_id.next_annotation.batch_info) # If no batch id is provided, return first valid URL - result_no_batch_id = await db_data_creator.adb_client.get_next_url_for_relevance_annotation( + result_no_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=None ) - assert result_no_batch_id.url_info.url == url_1.url + assert result_no_batch_id.next_annotation.url_info.url == url_1.url # Test for record type # If a batch id is provided, return first valid URL with that batch id - result_with_batch_id = await db_data_creator.adb_client.get_next_url_for_record_type_annotation( + result_with_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=setup_info_2.batch_id ) - assert result_with_batch_id.url_info.url == url_2.url - assert_batch_info(result_with_batch_id.batch_info) + assert result_with_batch_id.next_annotation.url_info.url == 
url_2.url + assert_batch_info(result_with_batch_id.next_annotation.batch_info) # If no batch id is provided, return first valid URL - result_no_batch_id = await db_data_creator.adb_client.get_next_url_for_record_type_annotation( + result_no_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=None ) - assert result_no_batch_id.url_info.url == url_1.url + assert result_no_batch_id.next_annotation.url_info.url == url_1.url # Test for agency for url in [url_1, url_2]: @@ -73,7 +74,7 @@ def assert_batch_info(batch_info): ) # If a batch id is provided, return first valid URL with that batch id - result_with_batch_id = await db_data_creator.adb_client.get_next_url_agency_for_annotation( + result_with_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=setup_info_2.batch_id ) @@ -82,7 +83,7 @@ def assert_batch_info(batch_info): assert_batch_info(result_with_batch_id.next_annotation.batch_info) # If no batch id is provided, return first valid URL - result_no_batch_id = await db_data_creator.adb_client.get_next_url_agency_for_annotation( + result_no_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( user_id=1, batch_id=None ) @@ -91,16 +92,18 @@ def assert_batch_info(batch_info): # All annotations - result_with_batch_id = await db_data_creator.adb_client.get_next_url_for_all_annotations( - batch_id=setup_info_2.batch_id + result_with_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( + batch_id=setup_info_2.batch_id, + user_id=1 ) assert result_with_batch_id.next_annotation.url_info.url == url_2.url assert_batch_info(result_with_batch_id.next_annotation.batch_info) # If no batch id is provided, return first valid URL - result_no_batch_id = await db_data_creator.adb_client.get_next_url_for_all_annotations( - batch_id=None + result_no_batch_id: GetNextURLForAllAnnotationResponse = await db_data_creator.adb_client.get_next_url_for_all_annotations( + batch_id=None, + user_id=1 ) assert result_no_batch_id.next_annotation.url_info.url == url_1.url diff --git a/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py b/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py deleted file mode 100644 index 8f03286c..00000000 --- a/tests/automated/integration/db/client/test_get_next_url_for_user_agency_annotation.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from tests.helpers.setup.annotate_agency.core import setup_for_annotate_agency -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_user_agency_annotation(db_data_creator: DBDataCreator): - """ - All users should receive the same next valid URL for agency annotation - Once any user annotates that URL, none of the users should receive it - """ - setup_info = await setup_for_annotate_agency( - db_data_creator, - url_count=2 - ) - - # All users should receive the same URL - url_1 = setup_info.url_ids[0] - url_2 = setup_info.url_ids[1] - - adb_client = db_data_creator.adb_client - url_user_1 = await adb_client.get_next_url_agency_for_annotation( - user_id=1, - batch_id=None - ) - assert url_user_1 is not None - - url_user_2 = await adb_client.get_next_url_agency_for_annotation( - user_id=2, - batch_id=None - ) - - assert url_user_2 is not None - - # 
Check that the URLs are the same - assert url_user_1 == url_user_2 - - # Annotate the URL - await adb_client.add_agency_manual_suggestion( - url_id=url_1, - user_id=1, - is_new=True, - agency_id=None - ) - - # Both users should receive the next URL - next_url_user_1 = await adb_client.get_next_url_agency_for_annotation( - user_id=1, - batch_id=None - ) - assert next_url_user_1 is not None - - next_url_user_2 = await adb_client.get_next_url_agency_for_annotation( - user_id=2, - batch_id=None - ) - assert next_url_user_2 is not None - - assert url_user_1 != next_url_user_1 - assert next_url_user_1 == next_url_user_2 diff --git a/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py b/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py deleted file mode 100644 index 292ab33f..00000000 --- a/tests/automated/integration/db/client/test_get_next_url_for_user_record_type_annotation.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest - -from src.core.enums import RecordType -from tests.helpers.setup.annotation.core import setup_for_get_next_url_for_annotation -from tests.helpers.db_data_creator import DBDataCreator - - -@pytest.mark.asyncio -async def test_get_next_url_for_user_record_type_annotation(db_data_creator: DBDataCreator): - """ - All users should receive the same next valid URL for record type annotation - Once any user annotates that URL, none of the users should receive it - """ - setup_info = await setup_for_get_next_url_for_annotation( - db_data_creator, - url_count=2 - ) - - # All users should receive the same URL - url_1 = setup_info.insert_urls_info.url_mappings[0] - url_2 = setup_info.insert_urls_info.url_mappings[1] - - adb_client = db_data_creator.adb_client - - url_user_1 = await adb_client.get_next_url_for_record_type_annotation( - user_id=1, - batch_id=None - ) - assert url_user_1 is not None - - url_user_2 = await adb_client.get_next_url_for_record_type_annotation( - user_id=2, - batch_id=None - ) - - assert url_user_2 is not None - - # Check that the URLs are the same - assert url_user_1 == url_user_2 - - # After annotating, both users should receive a different URL - await adb_client.add_user_record_type_suggestion( - user_id=1, - url_id=url_1.url_id, - record_type=RecordType.ARREST_RECORDS - ) - - next_url_user_1 = await adb_client.get_next_url_for_record_type_annotation( - user_id=1, - batch_id=None - ) - - next_url_user_2 = await adb_client.get_next_url_for_record_type_annotation( - user_id=2, - batch_id=None - ) - - assert next_url_user_1 != url_user_1 - assert next_url_user_1 == next_url_user_2 diff --git a/tests/automated/integration/db/client/test_insert_logs.py b/tests/automated/integration/db/client/test_insert_logs.py index d752c894..5ac9b9be 100644 --- a/tests/automated/integration/db/client/test_insert_logs.py +++ b/tests/automated/integration/db/client/test_insert_logs.py @@ -1,7 +1,7 @@ import pytest -from src.db.dtos.log import LogInfo -from tests.helpers.db_data_creator import DBDataCreator +from src.db.models.impl.log.pydantic.info import LogInfo +from tests.helpers.data_creator.core import DBDataCreator @pytest.mark.asyncio diff --git a/tests/automated/integration/db/client/test_insert_urls.py b/tests/automated/integration/db/client/test_insert_urls.py index 73a88d02..f2d73f00 100644 --- a/tests/automated/integration/db/client/test_insert_urls.py +++ b/tests/automated/integration/db/client/test_insert_urls.py @@ -1,8 +1,11 @@ import pytest from src.core.enums import BatchStatus 
-from src.db.dtos.batch import BatchInfo -from src.db.dtos.url.core import URLInfo +from src.db.models.impl.batch.pydantic.info import BatchInfo +from src.db.models.impl.link.batch_url.sqlalchemy import LinkBatchURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.info import URLInfo +from src.db.models.impl.url.core.sqlalchemy import URL @@ -23,14 +26,17 @@ async def test_insert_urls( URLInfo( url="https://example.com/1", collector_metadata={"name": "example_1"}, + source=URLSource.COLLECTOR ), URLInfo( url="https://example.com/2", + source=URLSource.COLLECTOR ), # Duplicate URLInfo( url="https://example.com/1", collector_metadata={"name": "example_duplicate"}, + source=URLSource.COLLECTOR ) ] insert_urls_info = await adb_client_test.insert_urls( @@ -46,3 +52,11 @@ assert insert_urls_info.original_count == 2 assert insert_urls_info.duplicate_count == 1 + + urls = await adb_client_test.get_all(URL) + assert len(urls) == 2 + + links: list[LinkBatchURL] = await adb_client_test.get_all(LinkBatchURL) + assert len(links) == 2 + for link in links: + assert link.batch_id == batch_id diff --git a/tests/automated/integration/db/structure/README.md b/tests/automated/integration/db/structure/README.md new file mode 100644 index 00000000..2e22a324 --- /dev/null +++ b/tests/automated/integration/db/structure/README.md @@ -0,0 +1,6 @@ +Database structure tests, which in this instance +test the integrity of the database schema and verify that it behaves as expected. + +This includes testing that: +* Enum columns accept only their allowed values (and raise errors on others) +* Column types are correct diff --git a/tests/automated/integration/db/structure/__init__.py b/tests/automated/integration/db/structure/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/db/structure/test_batch.py b/tests/automated/integration/db/structure/test_batch.py new file mode 100644 index 00000000..f905b178 --- /dev/null +++ b/tests/automated/integration/db/structure/test_batch.py @@ -0,0 +1,88 @@ +import sqlalchemy as sa +from sqlalchemy import create_engine +from sqlalchemy.dialects import postgresql + +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus +from src.db.helpers.connect import get_postgres_connection_string +from src.util.helper_functions import get_enum_values +from tests.automated.integration.db.structure.testers.models.column import ColumnTester +from tests.automated.integration.db.structure.testers.table import TableTester + + +def test_batch(wiped_database): + engine = create_engine(get_postgres_connection_string()) + table_tester = TableTester( + table_name="batches", + columns=[ + ColumnTester( + column_name="strategy", + type_=postgresql.ENUM, + allowed_values=get_enum_values(CollectorType), + ), + ColumnTester( + column_name="user_id", + type_=sa.Integer, + allowed_values=[1], + ), + ColumnTester( + column_name="status", + type_=postgresql.ENUM, + allowed_values=get_enum_values(BatchStatus), + ), + ColumnTester( + column_name="total_url_count", + type_=sa.Integer, + allowed_values=[1], + ), + ColumnTester( + column_name="original_url_count", + type_=sa.Integer, + allowed_values=[1], + ), + ColumnTester( + column_name="duplicate_url_count", + type_=sa.Integer, + allowed_values=[1], + ), + ColumnTester( + column_name="strategy_success_rate", + type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="metadata_success_rate", +
type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="agency_match_rate", + type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="record_type_match_rate", + type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="record_category_match_rate", + type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="compute_time", + type_=sa.Float, + allowed_values=[1.0], + ), + ColumnTester( + column_name="parameters", + type_=sa.JSON, + allowed_values=[{}] + ) + + ], + engine=engine + ) + + table_tester.run_column_tests() diff --git a/tests/automated/integration/db/structure/test_html_content.py b/tests/automated/integration/db/structure/test_html_content.py new file mode 100644 index 00000000..936a8a25 --- /dev/null +++ b/tests/automated/integration/db/structure/test_html_content.py @@ -0,0 +1,38 @@ +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from src.db.dtos.url.insert import InsertURLsInfo +from src.db.enums import URLHTMLContentType +from src.util.helper_functions import get_enum_values +from tests.automated.integration.db.structure.testers.models.column import ColumnTester +from tests.automated.integration.db.structure.testers.table import TableTester +from tests.helpers.data_creator.core import DBDataCreator + + +def test_html_content(db_data_creator: DBDataCreator): + batch_id = db_data_creator.batch() + iui: InsertURLsInfo = db_data_creator.urls(batch_id=batch_id, url_count=1) + + table_tester = TableTester( + table_name="url_html_content", + columns=[ + ColumnTester( + column_name="url_id", + type_=sa.Integer, + allowed_values=[iui.url_mappings[0].url_id] + ), + ColumnTester( + column_name="content_type", + type_=postgresql.ENUM, + allowed_values=get_enum_values(URLHTMLContentType) + ), + ColumnTester( + column_name="content", + type_=sa.Text, + allowed_values=["Text"] + ) + ], + engine=db_data_creator.db_client.engine + ) + + table_tester.run_column_tests() diff --git a/tests/automated/integration/db/structure/test_root_url.py b/tests/automated/integration/db/structure/test_root_url.py new file mode 100644 index 00000000..8f8be80b --- /dev/null +++ b/tests/automated/integration/db/structure/test_root_url.py @@ -0,0 +1,32 @@ +import sqlalchemy as sa + +from tests.automated.integration.db.structure.testers.models.column import ColumnTester +from tests.automated.integration.db.structure.testers.table import TableTester +from tests.helpers.data_creator.core import DBDataCreator + + +def test_root_url(db_data_creator: DBDataCreator): + + table_tester = TableTester( + table_name="root_urls", + columns=[ + ColumnTester( + column_name="url", + type_=sa.String, + allowed_values=["https://example.com"] + ), + ColumnTester( + column_name="page_title", + type_=sa.String, + allowed_values=["Text"] + ), + ColumnTester( + column_name="page_description", + type_=sa.String, + allowed_values=["Text"] + ) + ], + engine=db_data_creator.db_client.engine + ) + + table_tester.run_column_tests() diff --git a/tests/automated/integration/db/structure/test_task_enums.py b/tests/automated/integration/db/structure/test_task_enums.py new file mode 100644 index 00000000..709808a3 --- /dev/null +++ b/tests/automated/integration/db/structure/test_task_enums.py @@ -0,0 +1,13 @@ +import pytest + +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType + + +@pytest.mark.asyncio +async def test_task_enums(adb_client_test: AsyncDatabaseClient) -> None: + + for task_type in TaskType: + if 
task_type == TaskType.IDLE: + continue + await adb_client_test.initiate_task(task_type=task_type) \ No newline at end of file diff --git a/tests/automated/integration/db/structure/test_upsert_new_agencies.py b/tests/automated/integration/db/structure/test_upsert_new_agencies.py new file mode 100644 index 00000000..6b377974 --- /dev/null +++ b/tests/automated/integration/db/structure/test_upsert_new_agencies.py @@ -0,0 +1,59 @@ +import pytest + +from src.core.enums import SuggestionType +from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo +from src.db.models.impl.agency.sqlalchemy import Agency +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_upsert_new_agencies( + wiped_database, + db_data_creator: DBDataCreator +): + """ + Check that if the agency doesn't exist, it is added + But if the agency does exist, it is updated with new information + """ + + suggestions = [] + for i in range(3): + suggestion = URLAgencySuggestionInfo( + url_id=1, + suggestion_type=SuggestionType.AUTO_SUGGESTION, + pdap_agency_id=i, + agency_name=f"Test Agency {i}", + state=f"Test State {i}", + county=f"Test County {i}", + locality=f"Test Locality {i}", + user_id=1 + ) + suggestions.append(suggestion) + + adb_client = db_data_creator.adb_client + await adb_client.upsert_new_agencies(suggestions) + + update_suggestion = URLAgencySuggestionInfo( + url_id=1, + suggestion_type=SuggestionType.AUTO_SUGGESTION, + pdap_agency_id=0, + agency_name="Updated Test Agency", + state="Updated Test State", + county="Updated Test County", + locality="Updated Test Locality", + user_id=1 + ) + + await adb_client.upsert_new_agencies([update_suggestion]) + + rows = await adb_client.get_all(Agency, order_by_attribute="agency_id") + + assert len(rows) == 3 + + d = {} + for row in rows: + d[row.agency_id] = row.name + + assert d[0] == "Updated Test Agency" + assert d[1] == "Test Agency 1" + assert d[2] == "Test Agency 2" diff --git a/tests/automated/integration/db/structure/testers/__init__.py b/tests/automated/integration/db/structure/testers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/db/structure/testers/models/__init__.py b/tests/automated/integration/db/structure/testers/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/db/structure/testers/models/column.py b/tests/automated/integration/db/structure/testers/models/column.py new file mode 100644 index 00000000..1b4c5a50 --- /dev/null +++ b/tests/automated/integration/db/structure/testers/models/column.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + +from tests.automated.integration.db.structure.types import SATypes + + +@dataclass +class ColumnTester: + column_name: str + type_: SATypes + allowed_values: list diff --git a/tests/automated/integration/db/structure/testers/models/foreign_key.py b/tests/automated/integration/db/structure/testers/models/foreign_key.py new file mode 100644 index 00000000..517a82a8 --- /dev/null +++ b/tests/automated/integration/db/structure/testers/models/foreign_key.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + + +@dataclass +class ForeignKeyTester: + column_name: str + valid_id: int + invalid_id: int diff --git a/tests/automated/integration/db/structure/testers/models/unique_constraint.py b/tests/automated/integration/db/structure/testers/models/unique_constraint.py new file mode 100644 index 00000000..baa85cbb --- /dev/null 
+++ b/tests/automated/integration/db/structure/testers/models/unique_constraint.py @@ -0,0 +1,6 @@ +from dataclasses import dataclass + + +@dataclass +class UniqueConstraintTester: + columns: list[str] diff --git a/tests/automated/integration/db/structure/testers/table.py b/tests/automated/integration/db/structure/testers/table.py new file mode 100644 index 00000000..a91c0837 --- /dev/null +++ b/tests/automated/integration/db/structure/testers/table.py @@ -0,0 +1,95 @@ +from typing import Optional, Any + +import pytest +import sqlalchemy as sa +from sqlalchemy import create_engine +from sqlalchemy.dialects import postgresql +from sqlalchemy.exc import DataError + +from src.db.helpers.connect import get_postgres_connection_string +from src.db.models.templates_.base import Base +from tests.automated.integration.db.structure.testers.models.column import ColumnTester +from tests.automated.integration.db.structure.types import ConstraintTester, SATypes + + +class TableTester: + + def __init__( + self, + columns: list[ColumnTester], + table_name: str, + engine: Optional[sa.Engine] = None, + constraints: Optional[list[ConstraintTester]] = None, + ): + if engine is None: + engine = create_engine(get_postgres_connection_string(is_async=True)) + self.columns = columns + self.table_name = table_name + self.constraints = constraints + self.engine = engine + + def run_tests(self): + pass + + def setup_row_dict(self, override: Optional[dict[str, Any]] = None): + d = {} + for column in self.columns: + # For row dicts, the first value is the default + d[column.column_name] = column.allowed_values[0] + if override is not None: + d.update(override) + return d + + def run_column_test(self, column: ColumnTester): + if len(column.allowed_values) == 1: + return # It will be tested elsewhere + for value in column.allowed_values: + print(f"Testing column {column.column_name} with value {value}") + row_dict = self.setup_row_dict(override={column.column_name: value}) + table = self.get_table_model() + with self.engine.begin() as conn: + # Delete existing rows + conn.execute(table.delete()) + conn.commit() + with self.engine.begin() as conn: + conn.execute(table.insert(), row_dict) + conn.commit() + conn.close() + self.test_invalid_values(column) + + def generate_invalid_value(self, type_: SATypes): + match type_: + case sa.Integer: + return "not an integer" + case sa.String: + return -1 + case postgresql.ENUM: + return "not an enum value" + case sa.TIMESTAMP: + return "not a timestamp" + + def test_invalid_values(self, column: ColumnTester): + invalid_value = self.generate_invalid_value(type_=column.type_) + row_dict = self.setup_row_dict(override={column.column_name: invalid_value}) + table = self.get_table_model() + print(f"Testing column '{column.column_name}' with invalid value {invalid_value}") + with pytest.raises(DataError): + with self.engine.begin() as conn: + conn.execute(table.delete()) + conn.commit() + with self.engine.begin() as conn: + conn.execute(table.insert(), row_dict) + conn.commit() + conn.close() + + + def get_table_model(self) -> sa.Table: + """ + Retrieve table model from metadata + """ + return sa.Table(self.table_name, Base.metadata, autoload_with=self.engine) + + + def run_column_tests(self): + for column in self.columns: + self.run_column_test(column) diff --git a/tests/automated/integration/db/structure/types.py b/tests/automated/integration/db/structure/types.py new file mode 100644 index 00000000..3124538f --- /dev/null +++ b/tests/automated/integration/db/structure/types.py @@ -0,0 
+1,10 @@ +from typing import TypeAlias + +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +from tests.automated.integration.db.structure.testers.models.foreign_key import ForeignKeyTester +from tests.automated.integration.db.structure.testers.models.unique_constraint import UniqueConstraintTester + +SATypes: TypeAlias = sa.Integer | sa.String | postgresql.ENUM | sa.TIMESTAMP | sa.Text +ConstraintTester: TypeAlias = UniqueConstraintTester | ForeignKeyTester diff --git a/tests/automated/integration/db/test_database_structure.py b/tests/automated/integration/db/test_database_structure.py deleted file mode 100644 index 7b34cebb..00000000 --- a/tests/automated/integration/db/test_database_structure.py +++ /dev/null @@ -1,348 +0,0 @@ -""" -Database Structure tests, in this instance -Test the integrity of the database schema and that it behaves as expected. - -This includes testing that: -* Enum columns allow only allowed values (and throw errors on others) -* Column types are correct -""" - -from dataclasses import dataclass -from typing import TypeAlias, Optional, Any - -import pytest -import sqlalchemy as sa -from sqlalchemy import create_engine -from sqlalchemy.dialects import postgresql -from sqlalchemy.exc import DataError - -from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo -from src.db.dtos.url.insert import InsertURLsInfo -from src.db.enums import URLHTMLContentType -from src.db.helpers import get_postgres_connection_string -from src.db.models.instantiations.agency import Agency -from src.collectors.enums import CollectorType, URLStatus -from src.core.enums import BatchStatus, SuggestionType -from src.db.models.templates import Base -from src.util.helper_functions import get_enum_values -from tests.helpers.db_data_creator import DBDataCreator - -SATypes: TypeAlias = sa.Integer or sa.String or postgresql.ENUM or sa.TIMESTAMP or sa.Text - -@dataclass -class ColumnTester: - column_name: str - type_: SATypes - allowed_values: list - -@dataclass -class UniqueConstraintTester: - columns: list[str] - -@dataclass -class ForeignKeyTester: - column_name: str - valid_id: int - invalid_id: int - -ConstraintTester: TypeAlias = UniqueConstraintTester or ForeignKeyTester - -class TableTester: - - def __init__( - self, - columns: list[ColumnTester], - table_name: str, - engine: Optional[sa.Engine] = None, - constraints: Optional[list[ConstraintTester]] = None, - ): - if engine is None: - engine = create_engine(get_postgres_connection_string(is_async=True)) - self.columns = columns - self.table_name = table_name - self.constraints = constraints - self.engine = engine - - def run_tests(self): - pass - - def setup_row_dict(self, override: Optional[dict[str, Any]] = None): - d = {} - for column in self.columns: - # For row dicts, the first value is the default - d[column.column_name] = column.allowed_values[0] - if override is not None: - d.update(override) - return d - - def run_column_test(self, column: ColumnTester): - if len(column.allowed_values) == 1: - return # It will be tested elsewhere - for value in column.allowed_values: - print(f"Testing column {column.column_name} with value {value}") - row_dict = self.setup_row_dict(override={column.column_name: value}) - table = self.get_table_model() - with self.engine.begin() as conn: - # Delete existing rows - conn.execute(table.delete()) - conn.commit() - with self.engine.begin() as conn: - conn.execute(table.insert(), row_dict) - conn.commit() - conn.close() -
self.test_invalid_values(column) - - def generate_invalid_value(self, type_: SATypes): - match type_: - case sa.Integer: - return "not an integer" - case sa.String: - return -1 - case postgresql.ENUM: - return "not an enum value" - case sa.TIMESTAMP: - return "not a timestamp" - - def test_invalid_values(self, column: ColumnTester): - invalid_value = self.generate_invalid_value(type_=column.type_) - row_dict = self.setup_row_dict(override={column.column_name: invalid_value}) - table = self.get_table_model() - print(f"Testing column '{column.column_name}' with invalid value {invalid_value}") - with pytest.raises(DataError): - with self.engine.begin() as conn: - conn.execute(table.delete()) - conn.commit() - with self.engine.begin() as conn: - conn.execute(table.insert(), row_dict) - conn.commit() - conn.close() - - - def get_table_model(self) -> sa.Table: - """ - Retrieve table model from metadata - """ - return sa.Table(self.table_name, Base.metadata, autoload_with=self.engine) - - - def run_column_tests(self): - for column in self.columns: - self.run_column_test(column) - - -def test_batch(wiped_database): - engine = create_engine(get_postgres_connection_string()) - table_tester = TableTester( - table_name="batches", - columns=[ - ColumnTester( - column_name="strategy", - type_=postgresql.ENUM, - allowed_values=get_enum_values(CollectorType), - ), - ColumnTester( - column_name="user_id", - type_=sa.Integer, - allowed_values=[1], - ), - ColumnTester( - column_name="status", - type_=postgresql.ENUM, - allowed_values=get_enum_values(BatchStatus), - ), - ColumnTester( - column_name="total_url_count", - type_=sa.Integer, - allowed_values=[1], - ), - ColumnTester( - column_name="original_url_count", - type_=sa.Integer, - allowed_values=[1], - ), - ColumnTester( - column_name="duplicate_url_count", - type_=sa.Integer, - allowed_values=[1], - ), - ColumnTester( - column_name="strategy_success_rate", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="metadata_success_rate", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="agency_match_rate", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="record_type_match_rate", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="record_category_match_rate", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="compute_time", - type_=sa.Float, - allowed_values=[1.0], - ), - ColumnTester( - column_name="parameters", - type_=sa.JSON, - allowed_values=[{}] - ) - - ], - engine=engine - ) - - table_tester.run_column_tests() - -def test_url(db_data_creator: DBDataCreator): - batch_id = db_data_creator.batch() - table_tester = TableTester( - table_name="urls", - columns=[ - ColumnTester( - column_name="batch_id", - type_=sa.Integer, - allowed_values=[batch_id], - ), - ColumnTester( - column_name="url", - type_=sa.String, - allowed_values=["https://example.com"], - ), - ColumnTester( - column_name="collector_metadata", - type_=sa.JSON, - allowed_values=[{}] - ), - ColumnTester( - column_name="outcome", - type_=postgresql.ENUM, - allowed_values=get_enum_values(URLStatus) - ), - ColumnTester( - column_name="name", - type_=sa.String, - allowed_values=['test'], - ) - ], - engine=db_data_creator.db_client.engine - ) - - table_tester.run_column_tests() - -def test_html_content(db_data_creator: DBDataCreator): - batch_id = db_data_creator.batch() - iui: InsertURLsInfo = db_data_creator.urls(batch_id=batch_id, url_count=1) - - table_tester = 
TableTester( - table_name="url_html_content", - columns=[ - ColumnTester( - column_name="url_id", - type_=sa.Integer, - allowed_values=[iui.url_mappings[0].url_id] - ), - ColumnTester( - column_name="content_type", - type_=postgresql.ENUM, - allowed_values=get_enum_values(URLHTMLContentType) - ), - ColumnTester( - column_name="content", - type_=sa.Text, - allowed_values=["Text"] - ) - ], - engine=db_data_creator.db_client.engine - ) - - table_tester.run_column_tests() - -def test_root_url(db_data_creator: DBDataCreator): - - table_tester = TableTester( - table_name="root_urls", - columns=[ - ColumnTester( - column_name="url", - type_=sa.String, - allowed_values=["https://example.com"] - ), - ColumnTester( - column_name="page_title", - type_=sa.String, - allowed_values=["Text"] - ), - ColumnTester( - column_name="page_description", - type_=sa.String, - allowed_values=["Text"] - ) - ], - engine=db_data_creator.db_client.engine - ) - - table_tester.run_column_tests() - - -@pytest.mark.asyncio -async def test_upsert_new_agencies(db_data_creator: DBDataCreator): - """ - Check that if the agency doesn't exist, it is added - But if the agency does exist, it is updated with new information - """ - - suggestions = [] - for i in range(3): - suggestion = URLAgencySuggestionInfo( - url_id=1, - suggestion_type=SuggestionType.AUTO_SUGGESTION, - pdap_agency_id=i, - agency_name=f"Test Agency {i}", - state=f"Test State {i}", - county=f"Test County {i}", - locality=f"Test Locality {i}", - user_id=1 - ) - suggestions.append(suggestion) - - adb_client = db_data_creator.adb_client - await adb_client.upsert_new_agencies(suggestions) - - update_suggestion = URLAgencySuggestionInfo( - url_id=1, - suggestion_type=SuggestionType.AUTO_SUGGESTION, - pdap_agency_id=0, - agency_name="Updated Test Agency", - state="Updated Test State", - county="Updated Test County", - locality="Updated Test Locality", - user_id=1 - ) - - await adb_client.upsert_new_agencies([update_suggestion]) - - rows = await adb_client.get_all(Agency) - - assert len(rows) == 3 - - d = {} - for row in rows: - d[row.agency_id] = row.name - - assert d[0] == "Updated Test Agency" - assert d[1] == "Test Agency 1" - assert d[2] == "Test Agency 2" diff --git a/tests/automated/integration/html_tag_collector/test_root_url_cache.py b/tests/automated/integration/html_tag_collector/test_root_url_cache.py deleted file mode 100644 index 151985cf..00000000 --- a/tests/automated/integration/html_tag_collector/test_root_url_cache.py +++ /dev/null @@ -1,19 +0,0 @@ -import pytest - -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.core import RootURLCache -from src.core.tasks.url.operators.url_html.scraper.root_url_cache.dtos.response import RootURLCacheResponseInfo - - -async def mock_get_request(url: str) -> RootURLCacheResponseInfo: - return RootURLCacheResponseInfo(text="
This is an example of HTML content.
+ + + """ + +def setup_url_to_response_info( +) -> dict[str, URLResponseInfo]: + d = {} + for entry in TEST_ENTRIES: + response_info = URLResponseInfo( + success=_get_success(entry), + status=get_http_status(entry), + html=_generate_test_html() if _get_success(entry) else None, + content_type=_get_content_type(entry), + exception=None if _get_success(entry) else "Error" + ) + d[entry.url_info.url] = response_info + return d diff --git a/tests/automated/integration/tasks/url/impl/html/setup/__init__.py b/tests/automated/integration/tasks/url/impl/html/setup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/html/setup/data.py b/tests/automated/integration/tasks/url/impl/html/setup/data.py new file mode 100644 index 00000000..5615392c --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/html/setup/data.py @@ -0,0 +1,94 @@ +from http import HTTPStatus + +from src.collectors.enums import URLStatus +from src.db.models.impl.url.scrape_info.enums import ScrapeStatus +from tests.automated.integration.tasks.url.impl.html.setup.models.entry import TestURLHTMLTaskSetupEntry, TestURLInfo, \ + TestWebMetadataInfo, ExpectedResult, TestErrorType + +TEST_ENTRIES = [ + # URLs that give 200s should be updated with the appropriate scrape status + # and their html should be stored + TestURLHTMLTaskSetupEntry( + url_info=TestURLInfo( + url="https://happy-path.com/pending", + status=URLStatus.OK + ), + web_metadata_info=TestWebMetadataInfo( + accessed=True, + content_type="text/html", + response_code=HTTPStatus.OK, + error_message=None + ), + expected_result=ExpectedResult( + has_html=True, # Test for both compressed HTML and content metadata + scrape_status=ScrapeStatus.SUCCESS + ) + ), + # URLs that give 404s should be updated with the appropriate scrape status + # and their web metadata status should be updated to 404 + TestURLHTMLTaskSetupEntry( + url_info=TestURLInfo( + url="https://not-found-path.com/submitted", + status=URLStatus.ERROR + ), + web_metadata_info=TestWebMetadataInfo( + accessed=True, + content_type="text/html", + response_code=HTTPStatus.OK, + error_message=None + ), + give_error=TestErrorType.HTTP_404, + expected_result=ExpectedResult( + has_html=False, + scrape_status=ScrapeStatus.ERROR, + web_metadata_status_marked_404=True + ) + ), + # URLs that give errors should be updated with the appropriate scrape status + TestURLHTMLTaskSetupEntry( + url_info=TestURLInfo( + url="https://error-path.com/submitted", + status=URLStatus.ERROR + ), + web_metadata_info=TestWebMetadataInfo( + accessed=True, + content_type="text/html", + response_code=HTTPStatus.OK, + error_message=None + ), + give_error=TestErrorType.SCRAPER, + expected_result=ExpectedResult( + has_html=False, + scrape_status=ScrapeStatus.ERROR + ) + ), + # URLs with non-200 web metadata should not be processed + TestURLHTMLTaskSetupEntry( + url_info=TestURLInfo( + url="https://not-200-path.com/submitted", + status=URLStatus.OK + ), + web_metadata_info=TestWebMetadataInfo( + accessed=True, + content_type="text/html", + response_code=HTTPStatus.PERMANENT_REDIRECT, + error_message=None + ), + expected_result=ExpectedResult( + has_html=False, + scrape_status=None + ) + ), + # URLs with no web metadata should not be processed + TestURLHTMLTaskSetupEntry( + url_info=TestURLInfo( + url="https://no-web-metadata.com/submitted", + status=URLStatus.OK + ), + web_metadata_info=None, + expected_result=ExpectedResult( + has_html=False, + scrape_status=None + ) + ) +] \ No newline 
at end of file diff --git a/tests/automated/integration/tasks/url/impl/html/setup/manager.py b/tests/automated/integration/tasks/url/impl/html/setup/manager.py new file mode 100644 index 00000000..986a9f7e --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/html/setup/manager.py @@ -0,0 +1,78 @@ +import types + +from src.core.enums import RecordType +from src.core.tasks.url.operators.html.core import URLHTMLTaskOperator +from src.core.tasks.url.operators.html.scraper.parser.core import HTMLResponseParser +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic +from tests.automated.integration.tasks.url.impl.html.mocks.methods import mock_parse +from tests.automated.integration.tasks.url.impl.html.mocks.url_request_interface.core import MockURLRequestInterface +from tests.automated.integration.tasks.url.impl.html.setup.data import TEST_ENTRIES +from tests.automated.integration.tasks.url.impl.html.setup.models.record import TestURLHTMLTaskSetupRecord + + +class TestURLHTMLTaskSetupManager: + + def __init__(self, adb_client: AsyncDatabaseClient): + self.adb_client = adb_client + + + async def setup(self) -> list[TestURLHTMLTaskSetupRecord]: + + records = await self._setup_urls() + await self.setup_web_metadata(records) + return records + + async def _setup_urls(self) -> list[TestURLHTMLTaskSetupRecord]: + url_insert_models: list[URLInsertModel] = [] + for entry in TEST_ENTRIES: + url_insert_model = URLInsertModel( + status=entry.url_info.status, + url=entry.url_info.url, + name=f"Test for {entry.url_info.url}", + record_type=RecordType.RESOURCES, + source=URLSource.COLLECTOR + ) + url_insert_models.append(url_insert_model) + url_ids = await self.adb_client.bulk_insert(url_insert_models, return_ids=True) + + records = [] + for url_id, entry in zip(url_ids, TEST_ENTRIES): + record = TestURLHTMLTaskSetupRecord( + url_id=url_id, + entry=entry + ) + records.append(record) + return records + + async def setup_web_metadata( + self, + records: list[TestURLHTMLTaskSetupRecord] + ) -> None: + models = [] + for record in records: + entry = record.entry + web_metadata_info = entry.web_metadata_info + if web_metadata_info is None: + continue + model = URLWebMetadataPydantic( + url_id=record.url_id, + accessed=web_metadata_info.accessed, + status_code=web_metadata_info.response_code.value, + content_type=web_metadata_info.content_type, + error_message=web_metadata_info.error_message + ) + models.append(model) + await self.adb_client.bulk_insert(models) + +async def setup_operator() -> URLHTMLTaskOperator: + html_parser = HTMLResponseParser() + html_parser.parse = types.MethodType(mock_parse, html_parser) + operator = URLHTMLTaskOperator( + adb_client=AsyncDatabaseClient(), + url_request_interface=MockURLRequestInterface(), + html_parser=html_parser + ) + return operator diff --git a/tests/automated/integration/tasks/url/impl/html/setup/models/__init__.py b/tests/automated/integration/tasks/url/impl/html/setup/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/html/setup/models/entry.py b/tests/automated/integration/tasks/url/impl/html/setup/models/entry.py new file mode 100644 index 00000000..287bb52c --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/html/setup/models/entry.py @@ -0,0 +1,34 @@ +from enum import 
Enum +from http import HTTPStatus + +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from src.db.models.impl.url.scrape_info.enums import ScrapeStatus + + +class TestErrorType(Enum): + SCRAPER = "scraper" + HTTP_404 = "http-404" + + +class TestWebMetadataInfo(BaseModel): + accessed: bool + content_type: str | None + response_code: HTTPStatus + error_message: str | None + +class TestURLInfo(BaseModel): + url: str + status: URLStatus + +class ExpectedResult(BaseModel): + has_html: bool + scrape_status: ScrapeStatus | None # Does not have scrape info if none + web_metadata_status_marked_404: bool = False + +class TestURLHTMLTaskSetupEntry(BaseModel): + url_info: TestURLInfo + web_metadata_info: TestWebMetadataInfo | None + give_error: TestErrorType | None = None + expected_result: ExpectedResult \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/html/setup/models/record.py b/tests/automated/integration/tasks/url/impl/html/setup/models/record.py new file mode 100644 index 00000000..022c9639 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/html/setup/models/record.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + +from tests.automated.integration.tasks.url.impl.html.setup.models.entry import TestURLHTMLTaskSetupEntry + + +class TestURLHTMLTaskSetupRecord(BaseModel): + url_id: int + entry: TestURLHTMLTaskSetupEntry \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/html/test_task.py b/tests/automated/integration/tasks/url/impl/html/test_task.py new file mode 100644 index 00000000..e7462e65 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/html/test_task.py @@ -0,0 +1,33 @@ +import pytest + +from src.db.client.async_ import AsyncDatabaseClient +from src.db.enums import TaskType +from tests.automated.integration.tasks.url.impl.asserts import assert_prereqs_not_met, assert_prereqs_met, \ + assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.html.check.manager import TestURLHTMLTaskCheckManager +from tests.automated.integration.tasks.url.impl.html.setup.manager import setup_operator, \ + TestURLHTMLTaskSetupManager + + +@pytest.mark.asyncio +async def test_url_html_task(adb_client_test: AsyncDatabaseClient): + setup = TestURLHTMLTaskSetupManager(adb_client_test) + + operator = await setup_operator() + + # No URLs were created, the prereqs should not be met + await assert_prereqs_not_met(operator) + + records = await setup.setup() + await assert_prereqs_met(operator) + + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + checker = TestURLHTMLTaskCheckManager( + adb_client=adb_client_test, + records=records + ) + await checker.check() + + await assert_prereqs_not_met(operator) diff --git a/tests/automated/integration/tasks/url/impl/location_identification/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/conftest.py b/tests/automated/integration/tasks/url/impl/location_identification/conftest.py new file mode 100644 index 00000000..cbfa1c57 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/conftest.py @@ -0,0 +1,23 @@ +from unittest.mock import create_autospec + +import pytest + +from src.core.tasks.url.operators.location_id.core import LocationIdentificationTaskOperator +from 
src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor +from src.core.tasks.url.operators.location_id.subtasks.loader import LocationIdentificationSubtaskLoader +from src.db.client.async_ import AsyncDatabaseClient + + +@pytest.fixture +def operator( + adb_client_test: AsyncDatabaseClient +) -> LocationIdentificationTaskOperator: + + operator = LocationIdentificationTaskOperator( + adb_client=adb_client_test, + loader=LocationIdentificationSubtaskLoader( + adb_client=adb_client_test, + nlp_processor=create_autospec(NLPProcessor) + ) + ) + return operator \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/batch_link/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/batch_link/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/batch_link/test_core.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/batch_link/test_core.py new file mode 100644 index 00000000..ab505627 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/batch_link/test_core.py @@ -0,0 +1,64 @@ +import pytest + +from src.core.tasks.url.operators.location_id.core import LocationIdentificationTaskOperator +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.location_batch.sqlalchemy import LinkLocationBatch +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion +from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters +from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2 +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo +from tests.helpers.run import run_task_and_confirm_success + + +@pytest.mark.asyncio +async def test_batch_link_subtask( + operator: LocationIdentificationTaskOperator, + db_data_creator: DBDataCreator, + pittsburgh_locality: LocalityCreationInfo +): + + adb_client: AsyncDatabaseClient = operator.adb_client + + creation_info: BatchURLCreationInfoV2 = await db_data_creator.batch_v2( + parameters=TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=2 + ) + ] + ) + ) + batch_id: int = creation_info.batch_id + url_ids: list[int] = creation_info.url_ids + + location_id: int = pittsburgh_locality.location_id + + link = LinkLocationBatch( + location_id=location_id, + batch_id=batch_id + ) + await adb_client.add(link) + + assert await operator.meets_task_prerequisites() + assert operator._subtask == LocationIDSubtaskType.BATCH_LINK + + await run_task_and_confirm_success(operator) + + assert not await operator.meets_task_prerequisites() + assert operator._subtask is None + + subtasks: 
list[AutoLocationIDSubtask] = await adb_client.get_all(AutoLocationIDSubtask) + assert len(subtasks) == 2 + subtask: AutoLocationIDSubtask = subtasks[0] + assert subtask.type == LocationIDSubtaskType.BATCH_LINK + assert subtask.locations_found + + suggestions: list[LocationIDSubtaskSuggestion] = await adb_client.get_all(LocationIDSubtaskSuggestion) + assert len(suggestions) == 2 + + assert all(sugg.confidence == 80 for sugg in suggestions) + assert all(sugg.location_id == location_id for sugg in suggestions) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/conftest.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/conftest.py new file mode 100644 index 00000000..766a7ca5 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/conftest.py @@ -0,0 +1,15 @@ +import pytest_asyncio + +from src.db.dtos.url.mapping import URLMapping +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest_asyncio.fixture +async def url_ids( + db_data_creator: DBDataCreator, +) -> list[int]: + # Create 2 URLs with compressed HTML + url_mappings: list[URLMapping] = await db_data_creator.create_urls(count=2) + url_ids: list[int] = [url.url_id for url in url_mappings] + await db_data_creator.html_data(url_ids=url_ids) + return url_ids diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/test_core.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/test_core.py new file mode 100644 index 00000000..f8f0c821 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/end_to_end/test_core.py @@ -0,0 +1,120 @@ +import pytest + +from src.core.tasks.base.run_info import TaskOperatorRunInfo +from src.core.tasks.url.operators.location_id.core import LocationIdentificationTaskOperator +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.core import \ + NLPLocationFrequencySubtaskOperator +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.models.input_ import \ + NLPLocationFrequencySubtaskInput +from src.core.tasks.url.operators.location_id.subtasks.models.subtask import AutoLocationIDSubtaskData +from src.core.tasks.url.operators.location_id.subtasks.models.suggestion import LocationSuggestion +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.task_url import LinkTaskURL +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from src.db.models.impl.url.suggestion.location.auto.subtask.pydantic import AutoLocationIDSubtaskPydantic +from 
src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask +from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from tests.helpers.asserts import assert_task_run_success +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo + + +@pytest.mark.asyncio +async def test_nlp_location_match( + operator: LocationIdentificationTaskOperator, + db_data_creator: DBDataCreator, + url_ids: list[int], + pittsburgh_locality: LocalityCreationInfo, + allegheny_county: CountyCreationInfo, + monkeypatch +): + # Confirm operator meets prerequisites + assert await operator.meets_task_prerequisites() + assert operator._subtask == LocationIDSubtaskType.NLP_LOCATION_FREQUENCY + + happy_path_url_id: int = url_ids[0] + error_url_id: int = url_ids[1] + + async def mock_process_inputs( + self: NLPLocationFrequencySubtaskOperator, + inputs: list[NLPLocationFrequencySubtaskInput], + ) -> list[AutoLocationIDSubtaskData]: + response = [ + AutoLocationIDSubtaskData( + pydantic_model=AutoLocationIDSubtaskPydantic( + task_id=self.task_id, + url_id=happy_path_url_id, + type=LocationIDSubtaskType.NLP_LOCATION_FREQUENCY, + locations_found=True, + ), + suggestions=[ + LocationSuggestion( + location_id=pittsburgh_locality.location_id, + confidence=25 + ), + LocationSuggestion( + location_id=allegheny_county.location_id, + confidence=75 + ) + ] + ), + AutoLocationIDSubtaskData( + pydantic_model=AutoLocationIDSubtaskPydantic( + task_id=self.task_id, + url_id=error_url_id, + type=LocationIDSubtaskType.NLP_LOCATION_FREQUENCY, + locations_found=False, + ), + suggestions=[], + error="Test error" + ) + ] + return response + + # Remove internal processor reference - mock NLP processor instead + monkeypatch.setattr( + NLPLocationFrequencySubtaskOperator, + "_process_inputs", + mock_process_inputs + ) + run_info: TaskOperatorRunInfo = await operator.run_task() + assert_task_run_success(run_info) + + adb_client: AsyncDatabaseClient = operator.adb_client + # Confirm two URLs linked to the task + task_links: list[LinkTaskURL] = await adb_client.get_all(LinkTaskURL) + assert len(task_links) == 2 + assert {task_link.url_id for task_link in task_links} == set(url_ids) + assert {task_link.task_id for task_link in task_links} == {operator._task_id} + + # Confirm two subtasks were created + subtasks: list[AutoLocationIDSubtask] = await adb_client.get_all(AutoLocationIDSubtask) + assert len(subtasks) == 2 + assert {subtask.url_id for subtask in subtasks} == set(url_ids) + assert {subtask.task_id for subtask in subtasks} == {operator._task_id} + assert {subtask.type for subtask in subtasks} == { + LocationIDSubtaskType.NLP_LOCATION_FREQUENCY + } + assert {subtask.locations_found for subtask in subtasks} == {True, False} + + + # Confirm one URL error info + error_infos: list[URLTaskError] = await adb_client.get_all(URLTaskError) + assert len(error_infos) == 1 + assert error_infos[0].task_id == operator._task_id + assert error_infos[0].url_id == error_url_id + assert error_infos[0].error == "Test error" + + # Confirm two suggestions for happy path URL id + suggestions: list[LocationIDSubtaskSuggestion] = await adb_client.get_all(LocationIDSubtaskSuggestion) + assert len(suggestions) == 2 + # Confirm expected 
location ids + assert {suggestion.location_id for suggestion in suggestions} == { + pittsburgh_locality.location_id, + allegheny_county.location_id, + } + # Confirm both have the expected confidence values + assert {suggestion.confidence for suggestion in suggestions} == {25, 75} + diff --git a/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/test_nlp_response_valid.py b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/test_nlp_response_valid.py new file mode 100644 index 00000000..4ad6ec3c --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/subtasks/nlp_location_frequency/test_nlp_response_valid.py @@ -0,0 +1,57 @@ +import pytest + +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.response import \ + NLPLocationMatchResponse +from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.models.us_state import \ + USState + +US_STATE = USState( + name="Pennsylvania", + iso="PA", +) + +SINGLE_LOCATION: list[str] = ["Pittsburgh"] +MULTIPLE_LOCATION: list[str] = ["Pittsburgh", "Allegheny"] + +@pytest.mark.parametrize( + argnames="nlp_response, expected_result", + argvalues=[ + ( + NLPLocationMatchResponse( + locations=SINGLE_LOCATION, + us_state=US_STATE + ), + True, + ), + ( + NLPLocationMatchResponse( + locations=MULTIPLE_LOCATION, + us_state=US_STATE, + ), + True + ), + ( + NLPLocationMatchResponse( + locations=MULTIPLE_LOCATION, + us_state=None, + ), + False, + ), + ( + NLPLocationMatchResponse( + locations=[], + us_state=US_STATE, + ), + False, + ), + ( + NLPLocationMatchResponse( + locations=[], + us_state=None, + ), + False + ) + ], +) +def test_nlp_response_valid(nlp_response: NLPLocationMatchResponse, expected_result: bool): + assert nlp_response.valid == expected_result \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/location_identification/survey/__init__.py b/tests/automated/integration/tasks/url/impl/location_identification/survey/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/location_identification/survey/test_survey_flag.py b/tests/automated/integration/tasks/url/impl/location_identification/survey/test_survey_flag.py new file mode 100644 index 00000000..338c604b --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/location_identification/survey/test_survey_flag.py @@ -0,0 +1,44 @@ +import pytest + +from src.core.tasks.url.operators.location_id.core import LocationIdentificationTaskOperator +from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_survey_flag( + operator: LocationIdentificationTaskOperator, + db_data_creator: DBDataCreator, + monkeypatch +): + """ + Test that the survey correctly disables subtask flags + when the environment variable is set to disable that subtask + """ + + # Run basic survey and confirm no next subtask + assert not await operator.meets_task_prerequisites() + assert operator._subtask is None + + applicable_url_id: int = ( + await db_data_creator.create_urls( + count=1, + collector_metadata={ + "agency_name": "Test Agency" + } + ) + )[0].url_id + + await db_data_creator.add_compressed_html([applicable_url_id]) + + # Confirm prerequisite met and subtask is NLP Location Frequency + assert
await operator.meets_task_prerequisites() + assert operator._subtask == LocationIDSubtaskType.NLP_LOCATION_FREQUENCY + + # Set flag to disable NLP Location Frequency Subtask + monkeypatch.setenv( + "LOCATION_ID_NLP_LOCATION_MATCH_FLAG", "0" + ) + + # Confirm prerequisite no longer met. + assert not await operator.meets_task_prerequisites() diff --git a/tests/automated/integration/tasks/url/impl/probe/__init__.py b/tests/automated/integration/tasks/url/impl/probe/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/check/__init__.py b/tests/automated/integration/tasks/url/impl/probe/check/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/check/manager.py b/tests/automated/integration/tasks/url/impl/probe/check/manager.py new file mode 100644 index 00000000..a8d89ba5 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/check/manager.py @@ -0,0 +1,56 @@ +from sqlalchemy import select + +from src.collectors.enums import URLStatus +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.link.url_redirect_url.sqlalchemy import LinkURLRedirectURL +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata + + +class TestURLProbeCheckManager: + + def __init__( + self, + adb_client: AsyncDatabaseClient + ): + self.adb_client = adb_client + + async def check_url( + self, + url_id: int, + expected_status: URLStatus + ): + url: URL = await self.adb_client.one_or_none(select(URL).where(URL.id == url_id)) + assert url is not None + assert url.status == expected_status + + async def check_web_metadata( + self, + url_id: int, + status_code: int | None, + content_type: str | None, + error: str | None, + accessed: bool + ): + web_metadata: URLWebMetadata = await self.adb_client.one_or_none( + select(URLWebMetadata).where(URLWebMetadata.url_id == url_id) + ) + assert web_metadata is not None + assert web_metadata.url_id == url_id + assert web_metadata.status_code == status_code + assert web_metadata.content_type == content_type + assert web_metadata.error_message == error + assert web_metadata.accessed == accessed + + async def check_redirect( + self, + source_url_id: int, + ) -> int: + """ + Check existence of redirect link using source_url_id and return destination_url_id + """ + redirect: LinkURLRedirectURL = await self.adb_client.one_or_none( + select(LinkURLRedirectURL).where(LinkURLRedirectURL.source_url_id == source_url_id) + ) + assert redirect is not None + return redirect.destination_url_id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/conftest.py b/tests/automated/integration/tasks/url/impl/probe/conftest.py new file mode 100644 index 00000000..1c390288 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/conftest.py @@ -0,0 +1,23 @@ +import pytest + +from src.db.client.async_ import AsyncDatabaseClient +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.fixture +def setup_manager( + adb_client_test: AsyncDatabaseClient +) -> TestURLProbeSetupManager: + return TestURLProbeSetupManager( + adb_client=adb_client_test + ) + + +@pytest.fixture +def check_manager( + adb_client_test: AsyncDatabaseClient +) -> TestURLProbeCheckManager: + return 
TestURLProbeCheckManager( + adb_client=adb_client_test + ) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/constants.py b/tests/automated/integration/tasks/url/impl/probe/constants.py new file mode 100644 index 00000000..6c218e25 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/constants.py @@ -0,0 +1,6 @@ +from src.db.models.impl.url.core.enums import URLSource + +PATCH_ROOT = "src.external.url_request.core.URLProbeManager" +TEST_URL = "https://www.example.com" +TEST_DEST_URL = "https://www.example.com/redirect" +TEST_SOURCE = URLSource.COLLECTOR diff --git a/tests/automated/integration/tasks/url/impl/probe/mocks/__init__.py b/tests/automated/integration/tasks/url/impl/probe/mocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/mocks/url_request_interface.py b/tests/automated/integration/tasks/url/impl/probe/mocks/url_request_interface.py new file mode 100644 index 00000000..cc493274 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/mocks/url_request_interface.py @@ -0,0 +1,22 @@ +from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper + + +class MockURLRequestInterface: + + def __init__( + self, + response_or_responses: URLProbeResponseOuterWrapper | list[URLProbeResponseOuterWrapper] + ): + if not isinstance(response_or_responses, list): + responses = [response_or_responses] + else: + responses = response_or_responses + + self._url_to_response = { + response.original_url: response for response in responses + } + + async def probe_urls(self, urls: list[str]) -> list[URLProbeResponseOuterWrapper]: + return [ + self._url_to_response[url] for url in urls + ] diff --git a/tests/automated/integration/tasks/url/impl/probe/models/__init__.py b/tests/automated/integration/tasks/url/impl/probe/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/models/entry.py b/tests/automated/integration/tasks/url/impl/probe/models/entry.py new file mode 100644 index 00000000..810f40ea --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/models/entry.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + +from src.collectors.enums import URLStatus +from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper + + +class TestURLProbeTaskEntry(BaseModel): + url: str + url_status: URLStatus + planned_response: URLProbeResponseOuterWrapper \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/no_redirect/__init__.py b/tests/automated/integration/tasks/url/impl/probe/no_redirect/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_error.py b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_error.py new file mode 100644 index 00000000..85dd71f5 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_error.py @@ -0,0 +1,55 @@ +import pytest + +from src.collectors.enums import URLStatus +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from 
tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_url_probe_task_error( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager, + db_data_creator: DBDataCreator +): + """ + If a URL returns a 500 error response (or any other error), + the task should add web metadata response to the database + with + - the correct status + - content_type = None + - accessed = True + - the expected error message + """ + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_no_redirect_probe_response( + status_code=500, + content_type=None, + error="Something went wrong" + ) + ) + assert not await operator.meets_task_prerequisites() + url_id: int = await setup_manager.setup_url(URLStatus.OK) + await db_data_creator.create_validated_flags([url_id], validation_type=URLType.DATA_SOURCE) + await db_data_creator.create_url_data_sources([url_id]) + + assert await operator.meets_task_prerequisites() + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + assert not await operator.meets_task_prerequisites() + await check_manager.check_url( + url_id=url_id, + expected_status=URLStatus.OK + ) + + + await check_manager.check_web_metadata( + url_id=url_id, + status_code=500, + content_type=None, + error="Something went wrong", + accessed=True + ) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_not_found.py b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_not_found.py new file mode 100644 index 00000000..31216e23 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_not_found.py @@ -0,0 +1,51 @@ +import pytest + +from src.collectors.enums import URLStatus +from src.db.models.impl.flag.url_validated.enums import URLType +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager +from tests.helpers.data_creator.core import DBDataCreator + + +@pytest.mark.asyncio +async def test_url_probe_task_not_found( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager, + db_data_creator: DBDataCreator +): + """ + If a URL returns a 404 error response, + the task should add web metadata response to the database + with + - the correct status + - content_type = None + - accessed = False + - error_message = "Not found." + """ + + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_no_redirect_probe_response( + status_code=404, + content_type=None, + error="Not found." 
+ ) + ) + assert not await operator.meets_task_prerequisites() + url_id = await setup_manager.setup_url(URLStatus.OK) + await db_data_creator.create_validated_flags([url_id], validation_type=URLType.NOT_RELEVANT) + assert await operator.meets_task_prerequisites() + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + assert not await operator.meets_task_prerequisites() + await check_manager.check_url( + url_id=url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=url_id, + status_code=404, + content_type=None, + error="Not found.", + accessed=False + ) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_ok.py b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_ok.py new file mode 100644 index 00000000..ecaec084 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_ok.py @@ -0,0 +1,51 @@ +import pytest + +from src.collectors.enums import URLStatus +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_url_probe_task_no_redirect_ok( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + """ + If a URL returns a 200 OK response, + the task should add web metadata response to the database + with + - the correct status + - the correct content_type + - accessed = True + - error_message = None + """ + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_no_redirect_probe_response( + status_code=200, + content_type="text/html", + error=None + ) + ) + assert not await operator.meets_task_prerequisites() + url_id = await setup_manager.setup_url(URLStatus.OK) + assert await operator.meets_task_prerequisites() + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + assert not await operator.meets_task_prerequisites() + await check_manager.check_url( + url_id=url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=url_id, + status_code=200, + content_type="text/html", + accessed=True, + error=None + ) + + + + + diff --git a/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_two_urls.py b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_two_urls.py new file mode 100644 index 00000000..cfd1f68f --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/no_redirect/test_two_urls.py @@ -0,0 +1,42 @@ +import pytest + +from src.collectors.enums import URLStatus +from src.db.models.impl.url.core.sqlalchemy import URL +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_two_urls( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + url_1 = "https://example.com/1" + url_2 = "https://example.com/2" + operator = setup_manager.setup_operator( + response_or_responses=[ + setup_manager.setup_no_redirect_probe_response( + status_code=200, + content_type="text/html", + error=None, + 
url=url_1 + ), + setup_manager.setup_no_redirect_probe_response( + status_code=200, + content_type="text/html", + error=None, + url=url_2 + ) + ] + ) + assert not await operator.meets_task_prerequisites() + url_id_1 = await setup_manager.setup_url(URLStatus.OK, url=url_1) + url_id_2 = await setup_manager.setup_url(URLStatus.OK, url=url_2) + assert await operator.meets_task_prerequisites() + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + assert not await operator.meets_task_prerequisites() + + urls = await check_manager.adb_client.get_all(URL) + assert len(urls) == 2 diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/__init__.py b/tests/automated/integration/tasks/url/impl/probe/redirect/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/README.md b/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/README.md new file mode 100644 index 00000000..bb03c102 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/README.md @@ -0,0 +1 @@ +Tests for when the destination is a new URL not in the database. \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/__init__.py b/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/test_dest_ok.py b/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/test_dest_ok.py new file mode 100644 index 00000000..df695021 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/redirect/dest_new/test_dest_ok.py @@ -0,0 +1,56 @@ +import pytest + +from src.collectors.enums import URLStatus +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_url_probe_task_redirect_dest_new_ok( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + """ + If a URL + - returns a redirect response to a new URL, + - and the new URL returns a 200 OK response and does not exist in the database, + the task should + - add the new URL to the database, + - along with a web metadata record for each URL, + - and a link between the original URL and the new URL.
+ """ + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_redirect_probe_response( + redirect_status_code=301, + dest_status_code=200, + dest_content_type="text/html", + dest_error=None + ) + ) + source_url_id = await setup_manager.setup_url(URLStatus.OK) + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + await check_manager.check_url( + url_id=source_url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=source_url_id, + status_code=301, + content_type=None, + error=None, + accessed=True + ) + dest_url_id = await check_manager.check_redirect(source_url_id) + await check_manager.check_url( + url_id=dest_url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=dest_url_id, + status_code=200, + content_type="text/html", + error=None, + accessed=True + ) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/test_dest_exists_in_db.py b/tests/automated/integration/tasks/url/impl/probe/redirect/test_dest_exists_in_db.py new file mode 100644 index 00000000..b52dce6b --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/redirect/test_dest_exists_in_db.py @@ -0,0 +1,70 @@ +import pytest + +from src.collectors.enums import URLStatus +from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.constants import TEST_DEST_URL +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_url_probe_task_redirect_dest_exists_in_db( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + """ + If a URL: + - returns a redirect response to a new URL, + - and the new URL already exists in the database, + the task should add web metadata response to the database URL + and a link between the original URL and the new URL. 
+ + """ + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_redirect_probe_response( + redirect_status_code=302, + dest_status_code=200, + dest_content_type="text/html", + dest_error=None + ) + ) + source_url_id = await setup_manager.setup_url(URLStatus.OK) + dest_url_id = await setup_manager.setup_url(URLStatus.OK, url=TEST_DEST_URL) + # Add web metadata for destination URL, to prevent it from being pulled + web_metadata = URLWebMetadataPydantic( + url_id=dest_url_id, + status_code=200, + content_type="text/html", + error_message=None, + accessed=True + ) + await setup_manager.adb_client.bulk_insert([web_metadata]) + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + await check_manager.check_url( + url_id=source_url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_url( + url_id=dest_url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=source_url_id, + status_code=302, + content_type=None, + error=None, + accessed=True + ) + await check_manager.check_web_metadata( + url_id=dest_url_id, + status_code=200, + content_type="text/html", + error=None, + accessed=True + ) + redirect_url_id = await check_manager.check_redirect( + source_url_id=source_url_id + ) + assert redirect_url_id == dest_url_id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/test_redirect_infinite.py b/tests/automated/integration/tasks/url/impl/probe/redirect/test_redirect_infinite.py new file mode 100644 index 00000000..5a66af3d --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/redirect/test_redirect_infinite.py @@ -0,0 +1,46 @@ +import pytest + +from src.collectors.enums import URLStatus +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.constants import TEST_URL +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_url_probe_task_redirect_infinite( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + """ + If a URL: + - returns a redirect response to itself + The task should add a link that points to itself + as well as web metadata response to the database URL + """ + + operator = setup_manager.setup_operator( + response_or_responses=setup_manager.setup_redirect_probe_response( + redirect_status_code=303, + dest_status_code=303, + dest_content_type=None, + dest_error=None, + redirect_url=TEST_URL + ) + ) + url_id = await setup_manager.setup_url(URLStatus.OK) + run_info = await operator.run_task() + await check_manager.check_url( + url_id=url_id, + expected_status=URLStatus.OK + ) + await check_manager.check_web_metadata( + url_id=url_id, + status_code=303, + content_type=None, + error=None, + accessed=True + ) + redirect_url_id = await check_manager.check_redirect( + source_url_id=url_id, + ) + assert redirect_url_id == url_id diff --git a/tests/automated/integration/tasks/url/impl/probe/redirect/test_two_urls_same_dest.py b/tests/automated/integration/tasks/url/impl/probe/redirect/test_two_urls_same_dest.py new file mode 100644 index 00000000..f0e113ff --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/redirect/test_two_urls_same_dest.py @@ -0,0 +1,56 @@ +import pytest + +from src.collectors.enums import URLStatus +from tests.automated.integration.tasks.url.impl.asserts import 
assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.probe.check.manager import TestURLProbeCheckManager +from tests.automated.integration.tasks.url.impl.probe.setup.manager import TestURLProbeSetupManager + + +@pytest.mark.asyncio +async def test_url_probe_task_redirect_two_urls_same_dest( + setup_manager: TestURLProbeSetupManager, + check_manager: TestURLProbeCheckManager +): + """ + If two URLs: + - return a redirect response to the same URL, + two links to that URL should be added to the database, one for each source URL + """ + + operator = setup_manager.setup_operator( + response_or_responses=[ + setup_manager.setup_redirect_probe_response( + redirect_status_code=307, + dest_status_code=200, + dest_content_type=None, + dest_error=None, + ), + setup_manager.setup_redirect_probe_response( + redirect_status_code=308, + dest_status_code=200, + dest_content_type=None, + dest_error=None, + source_url="https://example.com/2", + ), + ] + ) + source_url_id_1 = await setup_manager.setup_url(URLStatus.OK) + source_url_id_2 = await setup_manager.setup_url(URLStatus.OK, url="https://example.com/2") + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + await check_manager.check_url( + url_id=source_url_id_1, + expected_status=URLStatus.OK + ) + await check_manager.check_url( + url_id=source_url_id_2, + expected_status=URLStatus.OK + ) + redirect_url_id_1 = await check_manager.check_redirect( + source_url_id=source_url_id_1 + ) + redirect_url_id_2 = await check_manager.check_redirect( + source_url_id=source_url_id_2 + ) + assert redirect_url_id_1 == redirect_url_id_2 + diff --git a/tests/automated/integration/tasks/url/impl/probe/setup/__init__.py b/tests/automated/integration/tasks/url/impl/probe/setup/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/probe/setup/manager.py b/tests/automated/integration/tasks/url/impl/probe/setup/manager.py new file mode 100644 index 00000000..50405970 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/probe/setup/manager.py @@ -0,0 +1,100 @@ +from typing import cast, Literal + +from src.collectors.enums import URLStatus +from src.core.tasks.url.operators.probe.core import URLProbeTaskOperator +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.external.url_request.core import URLRequestInterface +from src.external.url_request.probe.models.redirect import URLProbeRedirectResponsePair +from src.external.url_request.probe.models.response import URLProbeResponse +from src.external.url_request.probe.models.wrapper import URLProbeResponseOuterWrapper +from tests.automated.integration.tasks.url.impl.probe.constants import TEST_URL, TEST_DEST_URL, TEST_SOURCE +from tests.automated.integration.tasks.url.impl.probe.mocks.url_request_interface import MockURLRequestInterface + + +class TestURLProbeSetupManager: + + def __init__( + self, + adb_client: AsyncDatabaseClient + ): + self.adb_client = adb_client + + async def setup_url( + self, + url_status: URLStatus, + url: str = TEST_URL + ) -> int: + url_insert_model = URLInsertModel( + url=url, + status=url_status, + source=TEST_SOURCE + ) + return ( + await self.adb_client.bulk_insert( + models=[url_insert_model], + return_ids=True + ) + )[0] + + def setup_operator( + self, + response_or_responses: URLProbeResponseOuterWrapper | list[URLProbeResponseOuterWrapper] + ) -> URLProbeTaskOperator: + operator =
URLProbeTaskOperator( + adb_client=self.adb_client, + url_request_interface=cast( + URLRequestInterface, + MockURLRequestInterface( + response_or_responses=response_or_responses + ) + ) + ) + return operator + + @staticmethod + def setup_no_redirect_probe_response( + status_code: int | None, + content_type: str | None, + error: str | None, + url: str = TEST_URL + ) -> URLProbeResponseOuterWrapper: + return URLProbeResponseOuterWrapper( + original_url=url, + response=URLProbeResponse( + url=url, + status_code=status_code, + content_type=content_type, + error=error + ) + ) + + @staticmethod + def setup_redirect_probe_response( + redirect_status_code: Literal[301, 302, 303, 307, 308], + dest_status_code: int, + dest_content_type: str | None, + dest_error: str | None, + source_url: str = TEST_URL, + redirect_url: str = TEST_DEST_URL + ) -> URLProbeResponseOuterWrapper: + if redirect_status_code not in (301, 302, 303, 307, 308): + raise ValueError('Redirect response must be one of 301, 302, 303, 307, 308') + return URLProbeResponseOuterWrapper( + original_url=source_url, + response=URLProbeRedirectResponsePair( + source=URLProbeResponse( + url=source_url, + status_code=redirect_status_code, + content_type=None, + error=None + ), + destination=URLProbeResponse( + url=redirect_url, + status_code=dest_status_code, + content_type=dest_content_type, + error=dest_error + ) + ) + ) + diff --git a/tests/automated/integration/tasks/url/impl/root_url/__init__.py b/tests/automated/integration/tasks/url/impl/root_url/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/root_url/conftest.py b/tests/automated/integration/tasks/url/impl/root_url/conftest.py new file mode 100644 index 00000000..16b7012e --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/conftest.py @@ -0,0 +1,9 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.client.async_ import AsyncDatabaseClient + + +@pytest.fixture +def operator(adb_client_test: AsyncDatabaseClient) -> URLRootURLTaskOperator: + return URLRootURLTaskOperator(adb_client=adb_client_test) \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/root_url/constants.py b/tests/automated/integration/tasks/url/impl/root_url/constants.py new file mode 100644 index 00000000..dc688797 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/constants.py @@ -0,0 +1,5 @@ + + +ROOT_URL = "https://root.com" +BRANCH_URL = "https://root.com/branch" +SECOND_BRANCH_URL = "https://root.com/second-branch" \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_in_db.py b/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_in_db.py new file mode 100644 index 00000000..7e8af066 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_in_db.py @@ -0,0 +1,60 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.pydantic import FlagRootURLPydantic +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from 
tests.automated.integration.tasks.url.impl.root_url.constants import ROOT_URL, BRANCH_URL + + +@pytest.mark.asyncio +async def test_branch_root_url_in_db( + operator: URLRootURLTaskOperator +): + """ + If a URL is a branch URL + whose root URL is already in the database and flagged, + the branch should be linked to the existing root URL and no new flag should be added + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add URL that is a root URL, and mark as such + url_insert_model_root = URLInsertModel( + url=ROOT_URL, + source=URLSource.DATA_SOURCES + ) + root_url_id = (await operator.adb_client.bulk_insert([url_insert_model_root], return_ids=True))[0] + root_model_flag_insert = FlagRootURLPydantic( + url_id=root_url_id + ) + await operator.adb_client.bulk_insert([root_model_flag_insert]) + + # Add URL that is a branch of the root URL + url_insert_model = URLInsertModel( + url=BRANCH_URL, + source=URLSource.COLLECTOR + ) + branch_url_id = (await operator.adb_client.bulk_insert([url_insert_model], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 1 + assert links[0].url_id == branch_url_id + + # Check for only one flag, for the root URL + flags: list[FlagRootURL] = await operator.adb_client.get_all(FlagRootURL) + assert len(flags) == 1 + assert flags[0].url_id == root_url_id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_not_in_db.py b/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_not_in_db.py new file mode 100644 index 00000000..6c00f8f9 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_branch_root_url_not_in_db.py @@ -0,0 +1,57 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.db.models.impl.url.core.sqlalchemy import URL +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.root_url.constants import BRANCH_URL, ROOT_URL + + +@pytest.mark.asyncio +async def test_branch_root_url_not_in_db( + operator: URLRootURLTaskOperator +): + """ + If a URL is a branch URL + whose root URL is not in the database, + the task should add the root URL, flag it as a root URL, + and link the branch URL to it + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add URL that is a branch of a root URL + url_insert_model = URLInsertModel( + url=BRANCH_URL, + source=URLSource.COLLECTOR + ) + branch_url_id = (await operator.adb_client.bulk_insert([url_insert_model], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + # Check
for presence of root URL with proper source and flag + urls: list[URL] = await operator.adb_client.get_all(URL) + root_url = next(url for url in urls if url.url == ROOT_URL) + assert root_url.source == URLSource.ROOT_URL + + # Check for presence of link for branch URL + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 1 + link = next(link for link in links if link.url_id == branch_url_id) + assert link.root_url_id == root_url.id + + # Check that the only flag belongs to the root URL, not the branch URL + flags: list[FlagRootURL] = await operator.adb_client.get_all(FlagRootURL) + assert len(flags) == 1 + assert flags[0].url_id == root_url.id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_is_root_url.py b/tests/automated/integration/tasks/url/impl/root_url/test_is_root_url.py new file mode 100644 index 00000000..a6a56c7c --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_is_root_url.py @@ -0,0 +1,47 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.root_url.constants import ROOT_URL + + +@pytest.mark.asyncio +async def test_is_root_url( + operator: URLRootURLTaskOperator +): + """ + If a URL is a root URL, + it should be marked as such and not pulled again + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add URL that is a root URL + url_insert_model = URLInsertModel( + url=ROOT_URL, + source=URLSource.DATA_SOURCES + ) + url_id = (await operator.adb_client.bulk_insert([url_insert_model], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + # Check for absence of Link + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 0 + + # Check for presence of Flag + flags: list[FlagRootURL] = await operator.adb_client.get_all(FlagRootURL) + assert len(flags) == 1 + assert flags[0].url_id == url_id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db.py b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db.py new file mode 100644 index 00000000..be67d23e --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db.py @@ -0,0 +1,61 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.pydantic import FlagRootURLPydantic +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from
tests.automated.integration.tasks.url.impl.root_url.constants import ROOT_URL, BRANCH_URL, SECOND_BRANCH_URL + + +@pytest.mark.asyncio +async def test_two_branches_one_root_in_db( + operator: URLRootURLTaskOperator +): + """ + If two URLs are branches of a ROOT URL that is already in the database, + Both URLs should be linked to the ROOT URL + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add root URL and mark as such + url_insert_model_root = URLInsertModel( + url=ROOT_URL, + source=URLSource.DATA_SOURCES + ) + url_id_root = (await operator.adb_client.bulk_insert([url_insert_model_root], return_ids=True))[0] + root_model_flag_insert = FlagRootURLPydantic( + url_id=url_id_root + ) + await operator.adb_client.bulk_insert([root_model_flag_insert]) + + # Add two URLs that are branches of that root URL + url_insert_model_branch_1 = URLInsertModel( + url=BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_1 = (await operator.adb_client.bulk_insert([url_insert_model_branch_1], return_ids=True))[0] + + url_insert_model_branch_2 = URLInsertModel( + url=SECOND_BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_2 = (await operator.adb_client.bulk_insert([url_insert_model_branch_2], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + # Check for presence of separate links for both branch URLs + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 2 + link_url_ids = {link.url_id for link in links} + assert link_url_ids == {url_id_branch_1, url_id_branch_2} diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db_not_flagged.py b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db_not_flagged.py new file mode 100644 index 00000000..614796e9 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_in_db_not_flagged.py @@ -0,0 +1,68 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.pydantic import FlagRootURLPydantic +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.root_url.constants import ROOT_URL, BRANCH_URL, SECOND_BRANCH_URL + + +@pytest.mark.asyncio +async def test_two_branches_one_root_in_db_not_flagged( + operator: URLRootURLTaskOperator +): + """ + If two URLs are branches of a ROOT URL that is already in the database + but not flagged as such, + Both URLs should be linked to the ROOT URL + and the Root URL should be flagged + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add root URL but do not mark as such + url_insert_model_root = URLInsertModel( + url=ROOT_URL, + source=URLSource.DATA_SOURCES + ) + url_id_root = (await operator.adb_client.bulk_insert([url_insert_model_root], return_ids=True))[0] + + # 
Add two URLs that are branches of that root URL + url_insert_model_branch_1 = URLInsertModel( + url=BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_1 = (await operator.adb_client.bulk_insert([url_insert_model_branch_1], return_ids=True))[0] + + url_insert_model_branch_2 = URLInsertModel( + url=SECOND_BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_2 = (await operator.adb_client.bulk_insert([url_insert_model_branch_2], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + # Check for presence of separate links for both branch URLs + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 2 + url_ids = [link.url_id for link in links] + # Check both URLs are present + assert set(url_ids) == {url_id_branch_1, url_id_branch_2} + # Check both URLs are linked to the root URL + assert all(link.root_url_id == url_id_root for link in links) + + flags: list[FlagRootURL] = await operator.adb_client.get_all(FlagRootURL) + assert len(flags) == 1 + assert flags[0].url_id == url_id_root + diff --git a/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_not_in_db.py b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_not_in_db.py new file mode 100644 index 00000000..f68786b9 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/root_url/test_two_branches_one_root_not_in_db.py @@ -0,0 +1,58 @@ +import pytest + +from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator +from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL +from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from tests.automated.integration.tasks.url.impl.asserts import assert_task_ran_without_error +from tests.automated.integration.tasks.url.impl.root_url.constants import BRANCH_URL, SECOND_BRANCH_URL + + +@pytest.mark.asyncio +async def test_two_branches_one_root_not_in_db( + operator: URLRootURLTaskOperator +): + """ + If two URLs are branches of a root URL that is not already in the database, + both URLs, along with the root URL, should be added to the database + and the root URL should be flagged as such + """ + # Check prerequisites not yet met + assert not await operator.meets_task_prerequisites() + + # Add two URLs that are branches of a root URL + url_insert_model_branch_1 = URLInsertModel( + url=BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_1 = (await operator.adb_client.bulk_insert([url_insert_model_branch_1], return_ids=True))[0] + + url_insert_model_branch_2 = URLInsertModel( + url=SECOND_BRANCH_URL, + source=URLSource.COLLECTOR + ) + url_id_branch_2 = (await operator.adb_client.bulk_insert([url_insert_model_branch_2], return_ids=True))[0] + + # Check prerequisites are now met + assert await operator.meets_task_prerequisites() + + # Run task + run_info = await operator.run_task() + assert_task_ran_without_error(run_info) + + # Check task prerequisites no longer met + assert not await operator.meets_task_prerequisites() + + # Check for presence of separate links for both branch URLs + links: list[LinkURLRootURL] = await operator.adb_client.get_all(LinkURLRootURL) + assert len(links) == 2 + assert {link.url_id for link in links} == {url_id_branch_1, url_id_branch_2} + + # Check the newly added root URL is flagged, and that both branches link to it + flags: list[FlagRootURL] = await operator.adb_client.get_all(FlagRootURL) + assert len(flags) == 1 + root_url_id = flags[0].url_id + assert root_url_id not in (url_id_branch_1, url_id_branch_2) + assert all(link.root_url_id == root_url_id for link in links) + diff --git a/tests/automated/integration/tasks/url/impl/screenshot/__init__.py b/tests/automated/integration/tasks/url/impl/screenshot/__init__.py new file mode 100644 index 00000000..e69de29b diff --git
a/tests/automated/integration/tasks/url/impl/screenshot/conftest.py b/tests/automated/integration/tasks/url/impl/screenshot/conftest.py new file mode 100644 index 00000000..41c38366 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/screenshot/conftest.py @@ -0,0 +1,14 @@ +import pytest_asyncio + +from src.core.tasks.url.operators.screenshot.core import URLScreenshotTaskOperator +from src.db.client.async_ import AsyncDatabaseClient + + +@pytest_asyncio.fixture +async def operator( + adb_client_test: AsyncDatabaseClient, +) -> URLScreenshotTaskOperator: + operator = URLScreenshotTaskOperator( + adb_client=adb_client_test, + ) + return operator \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/screenshot/test_core.py b/tests/automated/integration/tasks/url/impl/screenshot/test_core.py new file mode 100644 index 00000000..6f54fbf9 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/screenshot/test_core.py @@ -0,0 +1,74 @@ +from unittest.mock import AsyncMock + +import pytest + +from src.core.tasks.url.operators.screenshot.core import URLScreenshotTaskOperator +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.url.screenshot.sqlalchemy import URLScreenshot +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from src.external.url_request.dtos.screenshot_response import URLScreenshotResponse +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.run import run_task_and_confirm_success + +# src/core/tasks/url/operators/screenshot/get.py +MOCK_ROOT_PATH = "src.core.tasks.url.operators.screenshot.get.get_screenshots" + +@pytest.mark.asyncio +async def test_core( + operator: URLScreenshotTaskOperator, + db_data_creator: DBDataCreator, + monkeypatch +) -> None: + + # Should not yet meet task prerequisites + assert not await operator.meets_task_prerequisites() + + # Add two URLs to database + url_mappings: list[URLMapping] = await db_data_creator.create_urls(count=2) + screenshot_mapping: URLMapping = url_mappings[0] + error_mapping: URLMapping = url_mappings[1] + url_ids: list[int] = [url_mapping.url_id for url_mapping in url_mappings] + + # Add web metadata for 200 responses + await db_data_creator.create_web_metadata( + url_ids=url_ids, + status_code=200, + ) + + # Should now meet task prerequisites + assert await operator.meets_task_prerequisites() + + mock_get_screenshots = AsyncMock(return_value=[ + URLScreenshotResponse( + url=screenshot_mapping.url, + screenshot=bytes(124536), + ), + URLScreenshotResponse( + url=error_mapping.url, + screenshot=None, + error="error", + ) + ]) + + # Mock get_screenshots to return one success and one failure + monkeypatch.setattr( + MOCK_ROOT_PATH, + mock_get_screenshots + ) + + await run_task_and_confirm_success(operator) + + # Get screenshots from database, confirm only one + screenshots: list[URLScreenshot] = await db_data_creator.adb_client.get_all(URLScreenshot) + assert len(screenshots) == 1 + assert screenshots[0].url_id == screenshot_mapping.url_id + + # Get errors from database, confirm only one + errors: list[URLTaskError] = await db_data_creator.adb_client.get_all(URLTaskError) + assert len(errors) == 1 + assert errors[0].url_id == error_mapping.url_id + + + + + diff --git a/tests/automated/integration/tasks/url/impl/submit_approved/__init__.py b/tests/automated/integration/tasks/url/impl/submit_approved/__init__.py new file mode 100644 index 00000000..e69de29b diff --git
a/tests/automated/integration/tasks/url/impl/submit_approved/mock.py b/tests/automated/integration/tasks/url/impl/submit_approved/mock.py new file mode 100644 index 00000000..0e631d5b --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/submit_approved/mock.py @@ -0,0 +1,38 @@ +from http import HTTPStatus +from unittest.mock import AsyncMock + +from pdap_access_manager import ResponseInfo + +from src.core.enums import SubmitResponseStatus +from src.external.pdap.client import PDAPClient + + +def mock_make_request(pdap_client: PDAPClient, urls: list[str]): + assert len(urls) == 3, "Expected 3 urls" + pdap_client.access_manager.make_request = AsyncMock( + return_value=ResponseInfo( + status_code=HTTPStatus.OK, + data={ + "data_sources": [ + { + "url": urls[0], + "status": SubmitResponseStatus.SUCCESS, + "error": None, + "data_source_id": 21, + }, + { + "url": urls[1], + "status": SubmitResponseStatus.SUCCESS, + "error": None, + "data_source_id": 34, + }, + { + "url": urls[2], + "status": SubmitResponseStatus.FAILURE, + "error": "Test Error", + "data_source_id": None + } + ] + } + ) + ) diff --git a/tests/automated/integration/tasks/url/impl/submit_approved/setup.py b/tests/automated/integration/tasks/url/impl/submit_approved/setup.py new file mode 100644 index 00000000..1f9d8915 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/submit_approved/setup.py @@ -0,0 +1,49 @@ +from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo +from src.core.enums import RecordType +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo + + +async def setup_validated_urls(db_data_creator: DBDataCreator, agency_id: int) -> list[str]: + creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls( + url_count=3, + with_html_content=True + ) + + url_1 = creation_info.url_ids[0] + url_2 = creation_info.url_ids[1] + url_3 = creation_info.url_ids[2] + await db_data_creator.adb_client.approve_url( + approval_info=FinalReviewApprovalInfo( + url_id=url_1, + record_type=RecordType.ACCIDENT_REPORTS, + agency_ids=[agency_id], + name="URL 1 Name", + description=None, + record_formats=["Record Format 1", "Record Format 2"], + data_portal_type="Data Portal Type 1", + supplying_entity="Supplying Entity 1" + ), + user_id=1 + ) + await db_data_creator.adb_client.approve_url( + approval_info=FinalReviewApprovalInfo( + url_id=url_2, + record_type=RecordType.INCARCERATION_RECORDS, + agency_ids=[agency_id], + name="URL 2 Name", + description="URL 2 Description", + ), + user_id=2 + ) + await db_data_creator.adb_client.approve_url( + approval_info=FinalReviewApprovalInfo( + url_id=url_3, + record_type=RecordType.ACCIDENT_REPORTS, + agency_ids=[agency_id], + name="URL 3 Name", + description="URL 3 Description", + ), + user_id=3 + ) + return creation_info.urls diff --git a/tests/automated/integration/tasks/url/impl/submit_approved/test_submit_approved_url_task.py b/tests/automated/integration/tasks/url/impl/submit_approved/test_submit_approved_url_task.py new file mode 100644 index 00000000..3d1aec23 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/submit_approved/test_submit_approved_url_task.py @@ -0,0 +1,135 @@ +import pytest +from deepdiff import DeepDiff +from pdap_access_manager import RequestInfo, RequestType, DataSourcesNamespaces + +from src.collectors.enums import URLStatus +from src.core.tasks.url.enums import TaskOperatorOutcome +from 
src.core.tasks.url.operators.submit_approved.core import SubmitApprovedURLTaskOperator +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError +from src.external.pdap.client import PDAPClient +from tests.automated.integration.tasks.url.impl.submit_approved.mock import mock_make_request +from tests.automated.integration.tasks.url.impl.submit_approved.setup import setup_validated_urls + + +@pytest.mark.asyncio +async def test_submit_approved_url_task( + db_data_creator, + mock_pdap_client: PDAPClient, + monkeypatch +): + """ + The submit_approved_url_task should submit + all validated URLs to the PDAP Data Sources App + """ + + + # Get Task Operator + operator = SubmitApprovedURLTaskOperator( + adb_client=db_data_creator.adb_client, + pdap_client=mock_pdap_client + ) + + # Check Task Operator does not yet meet pre-requisites + assert not await operator.meets_task_prerequisites() + + # Create URLs with status 'validated' in database and all requisite URL values + # Ensure they have optional metadata as well + agency_id = await db_data_creator.agency() + urls: list[str] = await setup_validated_urls(db_data_creator, agency_id=agency_id) + mock_make_request(mock_pdap_client, urls) + + # Check Task Operator does meet pre-requisites + assert await operator.meets_task_prerequisites() + + # Run Task + run_info = await operator.run_task() + + # Check Task has been marked as completed + assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message + + # Check Task Operator no longer meets pre-requisites + assert not await operator.meets_task_prerequisites() + + # Get URLs + urls: list[URL] = await db_data_creator.adb_client.get_all(URL, order_by_attribute="id") + url_1: URL = urls[0] + url_2: URL = urls[1] + url_3: URL = urls[2] + + # Get URL Data Source Links + url_data_sources = await db_data_creator.adb_client.get_all(URLDataSource) + assert len(url_data_sources) == 2 + + url_data_source_1 = url_data_sources[0] + url_data_source_2 = url_data_sources[1] + + assert url_data_source_1.url_id == url_1.id + assert url_data_source_1.data_source_id == 21 + + assert url_data_source_2.url_id == url_2.id + assert url_data_source_2.data_source_id == 34 + + # Check that errored URL has entry in url_error_info + url_errors = await db_data_creator.adb_client.get_all(URLTaskError) + assert len(url_errors) == 1 + url_error = url_errors[0] + assert url_error.url_id == url_3.id + assert url_error.error == "Test Error" + + # Check mock method was called with the expected parameters + access_manager = mock_pdap_client.access_manager + access_manager.make_request.assert_called_once() + access_manager.build_url.assert_called_with( + namespace=DataSourcesNamespaces.SOURCE_COLLECTOR, + subdomains=['data-sources'] + ) + + call_1 = access_manager.make_request.call_args_list[0][0][0] + expected_call_1 = RequestInfo( + type_=RequestType.POST, + url="http://example.com", + headers=access_manager.jwt_header.return_value, + json_={ + "data_sources": [ + { + "name": "URL 1 Name", + "source_url": url_1.url, + "record_type": "Accident Reports", + "description": None, + "record_formats": ["Record Format 1", "Record Format 2"], + "data_portal_type": "Data Portal Type 1", + "last_approval_editor": 1, + "supplying_entity": "Supplying Entity 1", + "agency_ids": [agency_id] + }, + { + "name": "URL 2 Name", + "source_url": url_2.url, + "record_type": "Incarceration Records", + "description": "URL
2 Description", + "last_approval_editor": 2, + "supplying_entity": None, + "record_formats": None, + "data_portal_type": None, + "agency_ids": [agency_id] + }, + { + "name": "URL 3 Name", + "source_url": url_3.url, + "record_type": "Accident Reports", + "description": "URL 3 Description", + "last_approval_editor": 3, + "supplying_entity": None, + "record_formats": None, + "data_portal_type": None, + "agency_ids": [agency_id] + } + ] + } + ) + assert call_1.type_ == expected_call_1.type_ + assert call_1.headers == expected_call_1.headers + diff = DeepDiff(call_1.json_, expected_call_1.json_, ignore_order=True) + assert diff == {}, f"Differences found: {diff}" diff --git a/tests/automated/integration/tasks/url/impl/submit_approved/test_validated_meta_url.py b/tests/automated/integration/tasks/url/impl/submit_approved/test_validated_meta_url.py new file mode 100644 index 00000000..76754b29 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/submit_approved/test_validated_meta_url.py @@ -0,0 +1,41 @@ +import pytest + +from src.core.tasks.base.run_info import TaskOperatorRunInfo +from src.core.tasks.url.operators.submit_approved.core import SubmitApprovedURLTaskOperator +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.url.data_source.sqlalchemy import URLDataSource +from src.external.pdap.client import PDAPClient +from tests.helpers.asserts import assert_task_run_success + + +@pytest.mark.asyncio +async def test_validated_meta_url_not_included( + db_data_creator, + mock_pdap_client: PDAPClient, +): + """ + If a validated Meta URL is included in the database + This should not be included in the submit approved task + """ + + # Get Task Operator + operator = SubmitApprovedURLTaskOperator( + adb_client=db_data_creator.adb_client, + pdap_client=mock_pdap_client + ) + + dbdc = db_data_creator + url_1: int = (await dbdc.create_validated_urls( + validation_type=URLType.META_URL + ))[0].url_id + + # Test task operator does not meet prerequisites + assert not await operator.meets_task_prerequisites() + + # Run task and confirm runs without error + run_info: TaskOperatorRunInfo = await operator.run_task() + assert_task_run_success(run_info) + + # Confirm entry not included in database + ds_urls: list[URLDataSource] = await dbdc.adb_client.get_all(URLDataSource) + assert len(ds_urls) == 0 diff --git a/tests/automated/integration/tasks/url/impl/submit_meta_urls/__init__.py b/tests/automated/integration/tasks/url/impl/submit_meta_urls/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/submit_meta_urls/test_core.py b/tests/automated/integration/tasks/url/impl/submit_meta_urls/test_core.py new file mode 100644 index 00000000..37d6e00f --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/submit_meta_urls/test_core.py @@ -0,0 +1,80 @@ +from http import HTTPStatus +from unittest.mock import AsyncMock + +import pytest +from pdap_access_manager import ResponseInfo + +from src.collectors.enums import URLStatus +from src.core.enums import SubmitResponseStatus +from src.core.tasks.url.operators.submit_meta_urls.core import SubmitMetaURLsTaskOperator +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.ds_meta_url.sqlalchemy import URLDSMetaURL +from src.external.pdap.client import PDAPClient +from src.external.pdap.impl.meta_urls.enums import 
SubmitMetaURLsStatus +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.run import run_task_and_confirm_success + + +@pytest.mark.asyncio +async def test_submit_meta_urls( + db_data_creator: DBDataCreator, + mock_pdap_client: PDAPClient, +): + """ + Test Submit Meta URLs Task Operator + """ + + + operator = SubmitMetaURLsTaskOperator( + adb_client=db_data_creator.adb_client, + pdap_client=mock_pdap_client + ) + + assert not await operator.meets_task_prerequisites() + + # Create validated meta url + agency_id: int = (await db_data_creator.create_agencies(count=1))[0] + + mapping: URLMapping = (await db_data_creator.create_validated_urls( + validation_type=URLType.META_URL + ))[0] + await db_data_creator.link_urls_to_agencies( + url_ids=[mapping.url_id], + agency_ids=[agency_id] + ) + + mock_pdap_client.access_manager.make_request = AsyncMock( + return_value=ResponseInfo( + status_code=HTTPStatus.OK, + data={ + "meta_urls": [ + { + "url": mapping.url, + "agency_id": agency_id, + "status": SubmitMetaURLsStatus.SUCCESS.value, + "meta_url_id": 2, + "error": None, + }, + ] + } + ) + ) + + + assert await operator.meets_task_prerequisites() + + await run_task_and_confirm_success(operator) + + urls: list[URL] = await db_data_creator.adb_client.get_all(URL) + assert len(urls) == 1 + url: URL = urls[0] + assert url.status == URLStatus.OK + + url_ds_meta_urls: list[URLDSMetaURL] = await db_data_creator.adb_client.get_all(URLDSMetaURL) + assert len(url_ds_meta_urls) == 1 + url_ds_meta_url: URLDSMetaURL = url_ds_meta_urls[0] + assert url_ds_meta_url.url_id == url.id + assert url_ds_meta_url.ds_meta_url_id == 2 + assert url_ds_meta_url.agency_id == agency_id \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/impl/suspend/__init__.py b/tests/automated/integration/tasks/url/impl/suspend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/suspend/test_core.py b/tests/automated/integration/tasks/url/impl/suspend/test_core.py new file mode 100644 index 00000000..9e1f57d8 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/suspend/test_core.py @@ -0,0 +1,50 @@ +import pytest + +from src.core.tasks.url.operators.suspend.core import SuspendURLTaskOperator +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.flag.url_suspended.sqlalchemy import FlagURLSuspended +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.run import run_task_and_confirm_success + + +@pytest.mark.asyncio +async def test_suspend_task( + adb_client_test: AsyncDatabaseClient, + db_data_creator: DBDataCreator, +): + operator = SuspendURLTaskOperator( + adb_client=adb_client_test + ) + + assert not await operator.meets_task_prerequisites() + + url_id_1: int = (await db_data_creator.create_urls(count=1))[0].url_id + + assert not await operator.meets_task_prerequisites() + + await db_data_creator.not_found_location_suggestion(url_id=url_id_1) + + assert not await operator.meets_task_prerequisites() + + await db_data_creator.not_found_location_suggestion(url_id=url_id_1) + + assert await operator.meets_task_prerequisites() + + await run_task_and_confirm_success(operator) + + url_id_2: int = (await db_data_creator.create_urls(count=1))[0].url_id + + await db_data_creator.not_found_agency_suggestion(url_id=url_id_2) + + assert not await operator.meets_task_prerequisites() + + await db_data_creator.not_found_agency_suggestion(url_id=url_id_2) + + assert await 
operator.meets_task_prerequisites() + + await run_task_and_confirm_success(operator) + + flags: list[FlagURLSuspended] = await adb_client_test.get_all(FlagURLSuspended) + assert len(flags) == 2 + + assert {flag.url_id for flag in flags} == {url_id_1, url_id_2} \ No newline at end of file diff --git a/tests/automated/integration/tasks/url/test_example_task.py b/tests/automated/integration/tasks/url/impl/test_example_task.py similarity index 75% rename from tests/automated/integration/tasks/url/test_example_task.py rename to tests/automated/integration/tasks/url/impl/test_example_task.py index 9a2a2fc9..00ec7c34 100644 --- a/tests/automated/integration/tasks/url/test_example_task.py +++ b/tests/automated/integration/tasks/url/impl/test_example_task.py @@ -5,9 +5,12 @@ from src.db.enums import TaskType from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.base import URLTaskOperatorBase -from tests.helpers.db_data_creator import DBDataCreator +from src.db.models.impl.link.task_url import LinkTaskURL +from tests.helpers.data_creator.core import DBDataCreator -class ExampleTaskOperator(URLTaskOperatorBase): +class ExampleTaskOperator( + URLTaskOperatorBase, +): @property def task_type(self) -> TaskType: @@ -31,14 +34,16 @@ async def test_example_task_success(db_data_creator: DBDataCreator): async def mock_inner_task_logic(self): # Add link to 3 urls - self.linked_url_ids = url_ids + await self.link_urls_to_task(url_ids) operator = ExampleTaskOperator(adb_client=db_data_creator.adb_client) operator.inner_task_logic = types.MethodType(mock_inner_task_logic, operator) - run_info = await operator.run_task(1) + run_info = await operator.run_task() assert run_info.outcome == TaskOperatorOutcome.SUCCESS - assert run_info.linked_url_ids == url_ids + links: list[LinkTaskURL] = await db_data_creator.adb_client.get_all(LinkTaskURL) + assert len(links) == 3 + assert all(link.url_id in url_ids for link in links) @pytest.mark.asyncio @@ -49,7 +54,7 @@ def mock_inner_task_logic(self): raise ValueError("test error") operator.inner_task_logic = types.MethodType(mock_inner_task_logic, operator) - run_info = await operator.run_task(1) + run_info = await operator.run_task() assert run_info.outcome == TaskOperatorOutcome.ERROR diff --git a/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py b/tests/automated/integration/tasks/url/impl/test_url_miscellaneous_metadata_task.py similarity index 93% rename from tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py rename to tests/automated/integration/tasks/url/impl/test_url_miscellaneous_metadata_task.py index e3d7c529..0af83bff 100644 --- a/tests/automated/integration/tasks/url/test_url_miscellaneous_metadata_task.py +++ b/tests/automated/integration/tasks/url/impl/test_url_miscellaneous_metadata_task.py @@ -2,12 +2,12 @@ import pytest -from src.core.tasks.url.operators.url_miscellaneous_metadata.core import URLMiscellaneousMetadataTaskOperator -from src.db.models.instantiations.url.optional_data_source_metadata import URLOptionalDataSourceMetadata -from src.db.models.instantiations.url.core import URL +from src.core.tasks.url.operators.misc_metadata.core import URLMiscellaneousMetadataTaskOperator +from src.db.models.impl.url.optional_data_source_metadata import URLOptionalDataSourceMetadata +from src.db.models.impl.url.core.sqlalchemy import URL from src.collectors.enums import CollectorType from src.core.tasks.url.enums import TaskOperatorOutcome -from 
tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator def batch_and_url( @@ -94,7 +94,7 @@ async def test_url_miscellaneous_metadata_task(db_data_creator: DBDataCreator): assert meets_prereqs # Run task - run_info = await operator.run_task(1) + run_info = await operator.run_task() assert run_info.outcome == TaskOperatorOutcome.SUCCESS # Check that each URL has the expected name/description and optional metadata diff --git a/tests/automated/integration/tasks/url/test_url_record_type_task.py b/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py similarity index 84% rename from tests/automated/integration/tasks/url/test_url_record_type_task.py rename to tests/automated/integration/tasks/url/impl/test_url_record_type_task.py index 514aa716..1373f3fa 100644 --- a/tests/automated/integration/tasks/url/test_url_record_type_task.py +++ b/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py @@ -3,11 +3,11 @@ import pytest from src.db.enums import TaskType -from src.db.models.instantiations.url.suggestion.record_type.auto import AutoRecordTypeSuggestion +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator from src.core.enums import RecordType -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator from src.core.tasks.url.operators.record_type.llm_api.record_classifier.deepseek import DeepSeekRecordClassifier @pytest.mark.asyncio @@ -32,9 +32,8 @@ async def test_url_record_type_task(db_data_creator: DBDataCreator): await db_data_creator.html_data(url_ids) assert await operator.meets_task_prerequisites() - task_id = await db_data_creator.adb_client.initiate_task(task_type=TaskType.RECORD_TYPE) - run_info = await operator.run_task(task_id) + run_info = await operator.run_task() assert run_info.outcome == TaskOperatorOutcome.SUCCESS # Task should have been created @@ -46,7 +45,6 @@ async def test_url_record_type_task(db_data_creator: DBDataCreator): assert len(tasks) == 1 task = tasks[0] assert task.type == TaskType.RECORD_TYPE - assert run_info.linked_url_ids == url_ids assert task.url_error_count == 1 # Get metadata diff --git a/tests/automated/integration/tasks/url/impl/validate/__init__.py b/tests/automated/integration/tasks/url/impl/validate/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/automated/integration/tasks/url/impl/validate/conftest.py b/tests/automated/integration/tasks/url/impl/validate/conftest.py new file mode 100644 index 00000000..0bcc5712 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/validate/conftest.py @@ -0,0 +1,32 @@ +import pytest +import pytest_asyncio + +from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator +from src.db.client.async_ import AsyncDatabaseClient +from tests.automated.integration.tasks.url.impl.validate.helper import TestValidateTaskHelper +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo + + +@pytest.fixture +def operator( + adb_client_test: AsyncDatabaseClient +) -> AutoValidateURLTaskOperator: + return AutoValidateURLTaskOperator( + adb_client=adb_client_test, + ) + +@pytest_asyncio.fixture +async def helper( + db_data_creator: DBDataCreator, + 
pittsburgh_locality: LocalityCreationInfo +) -> TestValidateTaskHelper: + url_id: int = (await db_data_creator.create_urls(count=1, record_type=None))[0].url_id + agency_id: int = await db_data_creator.agency() + return TestValidateTaskHelper( + db_data_creator, + url_id=url_id, + agency_id=agency_id, + location_id=pittsburgh_locality.location_id + ) + diff --git a/tests/automated/integration/tasks/url/impl/validate/helper.py b/tests/automated/integration/tasks/url/impl/validate/helper.py new file mode 100644 index 00000000..6ab44984 --- /dev/null +++ b/tests/automated/integration/tasks/url/impl/validate/helper.py @@ -0,0 +1,145 @@ +from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo +from src.core.enums import RecordType +from src.db.client.async_ import AsyncDatabaseClient +from src.db.models.impl.flag.auto_validated.sqlalchemy import FlagURLAutoValidated +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated +from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.record_type.sqlalchemy import URLRecordType +from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource +from tests.conftest import db_data_creator +from tests.helpers.counter import next_int +from tests.helpers.data_creator.core import DBDataCreator + +DEFAULT_RECORD_TYPE: RecordType = RecordType.INCARCERATION_RECORDS + +class TestValidateTaskHelper: + + def __init__( + self, + db_data_creator: DBDataCreator, + url_id: int, + agency_id: int, + location_id: int + ): + self.db_data_creator = db_data_creator + self.adb_client: AsyncDatabaseClient = db_data_creator.adb_client + self.url_id = url_id + self.agency_id = agency_id + self.location_id = location_id + + + async def check_url_validated( + self, + url_type: URLType, + ) -> None: + validated_flags: list[FlagURLValidated] = await self.adb_client.get_all(FlagURLValidated) + assert len(validated_flags) == 1 + validated_flag: FlagURLValidated = validated_flags[0] + assert validated_flag.url_id == self.url_id + assert validated_flag.type == url_type + + async def check_auto_validated( + self, + ) -> None: + auto_validated_flags: list[FlagURLAutoValidated] = await self.adb_client.get_all(FlagURLAutoValidated) + assert len(auto_validated_flags) == 1 + auto_validated_flag: FlagURLAutoValidated = auto_validated_flags[0] + assert auto_validated_flag.url_id == self.url_id + + async def check_agency_linked( + self + ) -> None: + links: list[LinkURLAgency] = await self.adb_client.get_all(LinkURLAgency) + assert len(links) == 1 + link: LinkURLAgency = links[0] + assert link.url_id == self.url_id + assert link.agency_id == self.agency_id + + async def check_record_type( + self, + record_type: RecordType = DEFAULT_RECORD_TYPE + ): + record_types: list[URLRecordType] = await self.adb_client.get_all(URLRecordType) + assert len(record_types) == 1 + rt: URLRecordType = record_types[0] + assert rt.url_id == self.url_id + assert rt.record_type == record_type + + async def add_url_type_suggestions( + self, + url_type: URLType, + count: int = 1 + ): + for _ in range(count): + await self.db_data_creator.user_relevant_suggestion( + suggested_status=url_type, + url_id=self.url_id, + user_id=next_int() + ) + + async def add_agency_suggestions( + self, + count: int = 1, + agency_id: int | None = None + ): + if agency_id is None: + agency_id = self.agency_id + for i 
+            await self.db_data_creator.agency_user_suggestions(
+                url_id=self.url_id,
+                user_id=next_int(),
+                agency_annotation_info=URLAgencyAnnotationPostInfo(
+                    suggested_agency=agency_id
+                )
+            )
+
+    async def add_location_suggestions(
+        self,
+        count: int = 1,
+        location_id: int | None = None
+    ):
+        if location_id is None:
+            location_id = self.location_id
+        for _ in range(count):
+            await self.db_data_creator.add_user_location_suggestion(
+                url_id=self.url_id,
+                user_id=next_int(),
+                location_id=location_id,
+            )
+
+    async def add_record_type_suggestions(
+        self,
+        count: int = 1,
+        record_type: RecordType = DEFAULT_RECORD_TYPE
+    ):
+        for _ in range(count):
+            await self.db_data_creator.user_record_type_suggestion(
+                url_id=self.url_id,
+                record_type=record_type,
+                user_id=next_int()
+            )
+
+    async def add_name_suggestion(
+        self,
+        count: int = 1,
+    ) -> str:
+        name = "Test Validate Task Name"
+        suggestion_id: int = await self.db_data_creator.name_suggestion(
+            url_id=self.url_id,
+            source=NameSuggestionSource.USER,
+            name=name,
+        )
+        for _ in range(count):
+            await self.db_data_creator.user_name_endorsement(
+                suggestion_id=suggestion_id,
+                user_id=next_int(),
+            )
+        return name
+
+    async def check_name(self) -> None:
+        urls: list[URL] = await self.adb_client.get_all(URL)
+        assert len(urls) == 1
+        url: URL = urls[0]
+        assert url.name == "Test Validate Task Name"
\ No newline at end of file
diff --git a/tests/automated/integration/tasks/url/impl/validate/test_data_source.py b/tests/automated/integration/tasks/url/impl/validate/test_data_source.py
new file mode 100644
index 00000000..82bed288
--- /dev/null
+++ b/tests/automated/integration/tasks/url/impl/validate/test_data_source.py
@@ -0,0 +1,67 @@
+"""
+Add a URL with two of the same suggestions for each of the following:
+- Agency
+- Location
+- Record Type
+- URL Type (DATA SOURCE)
+And confirm it is validated as DATA SOURCE
+"""
+import pytest
+
+from src.core.enums import RecordType
+from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.automated.integration.tasks.url.impl.validate.helper import TestValidateTaskHelper
+from tests.helpers.run import run_task_and_confirm_success
+
+
+@pytest.mark.asyncio
+async def test_data_source(
+    operator: AutoValidateURLTaskOperator,
+    helper: TestValidateTaskHelper
+):
+    await helper.add_url_type_suggestions(
+        url_type=URLType.DATA_SOURCE,
+        count=2
+    )
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_agency_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_location_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_record_type_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_name_suggestion(count=2)
+
+    assert await operator.meets_task_prerequisites()
+
+    # Add different record type suggestion
+    await helper.add_record_type_suggestions(
+        count=2,
+        record_type=RecordType.STOPS
+    )
+
+    # Assert no longer meets task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add tiebreaker
+    await helper.add_record_type_suggestions()
+
+    assert await operator.meets_task_prerequisites()
+
+    await run_task_and_confirm_success(operator)
+
+    await helper.check_url_validated(URLType.DATA_SOURCE)
+    await helper.check_auto_validated()
+    await helper.check_agency_linked()
+    await helper.check_record_type()
+    await helper.check_name()
+
diff --git a/tests/automated/integration/tasks/url/impl/validate/test_individual_record.py b/tests/automated/integration/tasks/url/impl/validate/test_individual_record.py
new file mode 100644
index 00000000..19d025df
--- /dev/null
+++ b/tests/automated/integration/tasks/url/impl/validate/test_individual_record.py
@@ -0,0 +1,58 @@
+import pytest
+
+from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.automated.integration.tasks.url.impl.validate.helper import TestValidateTaskHelper
+from tests.helpers.run import run_task_and_confirm_success
+
+
+@pytest.mark.asyncio
+async def test_individual_record(
+    operator: AutoValidateURLTaskOperator,
+    helper: TestValidateTaskHelper
+):
+    """
+    Add URL with 2 INDIVIDUAL RECORD suggestions. Check validated as INDIVIDUAL RECORD
+    """
+    # Add two INDIVIDUAL record suggestions
+    await helper.add_url_type_suggestions(
+        url_type=URLType.INDIVIDUAL_RECORD,
+        count=2
+    )
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_agency_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_location_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_name_suggestion(count=2)
+
+    assert await operator.meets_task_prerequisites()
+
+    # Add additional agency suggestions to create tie
+    additional_agency_id: int = await helper.db_data_creator.agency()
+    await helper.add_agency_suggestions(
+        count=2,
+        agency_id=additional_agency_id
+    )
+
+    # Confirm no longer meets task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add tiebreaker suggestion
+    await helper.add_agency_suggestions()
+
+    assert await operator.meets_task_prerequisites()
+
+    await run_task_and_confirm_success(operator)
+
+    await helper.check_url_validated(URLType.INDIVIDUAL_RECORD)
+    await helper.check_auto_validated()
+    await helper.check_agency_linked()
+    await helper.check_name()
+
diff --git a/tests/automated/integration/tasks/url/impl/validate/test_meta_url.py b/tests/automated/integration/tasks/url/impl/validate/test_meta_url.py
new file mode 100644
index 00000000..962a2b63
--- /dev/null
+++ b/tests/automated/integration/tasks/url/impl/validate/test_meta_url.py
@@ -0,0 +1,65 @@
+"""
+Add a URL with two of the same suggestions for each of the following:
+- Agency
+- Location
+- URL Type (META URL)
+And confirm it is validated as META URL
+"""
+import pytest
+
+from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.automated.integration.tasks.url.impl.validate.helper import TestValidateTaskHelper
+from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo
+from tests.helpers.run import run_task_and_confirm_success
+
+
+@pytest.mark.asyncio
+async def test_meta_url(
+    operator: AutoValidateURLTaskOperator,
+    helper: TestValidateTaskHelper,
+    allegheny_county: CountyCreationInfo
+):
+    # Add two META URL suggestions
+    await helper.add_url_type_suggestions(URLType.META_URL, count=2)
+
+    # Assert operator does not yet meet task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add two Agency suggestions
+    await helper.add_agency_suggestions(count=2)
+
+    # Assert operator does not yet meet task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add two location suggestions
+    await helper.add_location_suggestions(count=2)
+
+    assert not await operator.meets_task_prerequisites()
+
+    await helper.add_name_suggestion(count=2)
+
+    # Assert operator now meets task prerequisites
+    assert await operator.meets_task_prerequisites()
+
+    # Add additional two location suggestions for different location
+    await helper.add_location_suggestions(
+        count=2,
+        location_id=allegheny_county.location_id
+    )
+
+    # Assert operator no longer meets task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add additional location suggestion as tiebreaker
+    await helper.add_location_suggestions()
+
+    # Assert operator again meets task prerequisites
+    assert await operator.meets_task_prerequisites()
+
+    await run_task_and_confirm_success(operator)
+
+    await helper.check_url_validated(URLType.META_URL)
+    await helper.check_auto_validated()
+    await helper.check_agency_linked()
+    await helper.check_name()
diff --git a/tests/automated/integration/tasks/url/impl/validate/test_not_relevant.py b/tests/automated/integration/tasks/url/impl/validate/test_not_relevant.py
new file mode 100644
index 00000000..288f61e9
--- /dev/null
+++ b/tests/automated/integration/tasks/url/impl/validate/test_not_relevant.py
@@ -0,0 +1,56 @@
+import pytest
+
+from src.core.tasks.url.operators.validate.core import AutoValidateURLTaskOperator
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.automated.integration.tasks.url.impl.validate.helper import TestValidateTaskHelper
+from tests.helpers.run import run_task_and_confirm_success
+
+
+@pytest.mark.asyncio
+async def test_not_relevant(
+    operator: AutoValidateURLTaskOperator,
+    helper: TestValidateTaskHelper
+):
+    """
+    Add URL with 2 NOT RELEVANT suggestions. Check validated as NOT RELEVANT
+    """
+
+    # Assert operator does not yet meet task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add one NOT RELEVANT suggestion
+    await helper.add_url_type_suggestions(
+        url_type=URLType.NOT_RELEVANT,
+    )
+
+    # Assert operator does not yet meet task prerequisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Add second NOT RELEVANT suggestion
+    await helper.add_url_type_suggestions(
+        url_type=URLType.NOT_RELEVANT,
+    )
+
+    # Assert operator now meets task prerequisites
+    assert await operator.meets_task_prerequisites()
+
+    # Add different suggestion to create tie
+    await helper.add_url_type_suggestions(
+        url_type=URLType.META_URL,
+        count=2
+    )
+    assert not await operator.meets_task_prerequisites()
+
+    # Add tiebreaker
+    await helper.add_url_type_suggestions(
+        url_type=URLType.NOT_RELEVANT
+    )
+
+    await run_task_and_confirm_success(operator)
+
+    # Assert URL validated as NOT RELEVANT
+    await helper.check_url_validated(
+        url_type=URLType.NOT_RELEVANT,
+    )
+
+    await helper.check_auto_validated()
diff --git a/tests/automated/integration/tasks/url/loader/__init__.py b/tests/automated/integration/tasks/url/loader/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/integration/tasks/url/loader/conftest.py b/tests/automated/integration/tasks/url/loader/conftest.py
new file mode 100644
index 00000000..a5d39643
--- /dev/null
+++ b/tests/automated/integration/tasks/url/loader/conftest.py
@@ -0,0 +1,26 @@
+from unittest.mock import AsyncMock
+
+import pytest
+
+from src.collectors.impl.muckrock.api_interface.core import MuckrockAPIInterface
+from src.core.tasks.url.loader import URLTaskOperatorLoader
+from src.core.tasks.url.operators.html.scraper.parser.core import HTMLResponseParser
+from src.core.tasks.url.operators.location_id.subtasks.impl.nlp_location_freq.processor.nlp.core import NLPProcessor
+from src.db.client.async_ import AsyncDatabaseClient
+from src.external.huggingface.inference.client import HuggingFaceInferenceClient
+from src.external.pdap.client import PDAPClient
+from src.external.url_request.core import URLRequestInterface
+
+
+@pytest.fixture(scope="session")
+def loader() -> URLTaskOperatorLoader:
+    """Set up a loader with mock dependencies."""
+    return URLTaskOperatorLoader(
+        adb_client=AsyncMock(spec=AsyncDatabaseClient),
+        url_request_interface=AsyncMock(spec=URLRequestInterface),
+        html_parser=AsyncMock(spec=HTMLResponseParser),
+        pdap_client=AsyncMock(spec=PDAPClient),
+        muckrock_api_interface=AsyncMock(spec=MuckrockAPIInterface),
+        hf_inference_client=AsyncMock(spec=HuggingFaceInferenceClient),
+        nlp_processor=AsyncMock(spec=NLPProcessor)
+    )
\ No newline at end of file
diff --git a/tests/automated/integration/tasks/url/loader/test_flags.py b/tests/automated/integration/tasks/url/loader/test_flags.py
new file mode 100644
index 00000000..f812c947
--- /dev/null
+++ b/tests/automated/integration/tasks/url/loader/test_flags.py
@@ -0,0 +1,76 @@
+import pytest
+from pydantic import BaseModel
+
+from src.core.tasks.url.loader import URLTaskOperatorLoader
+from src.core.tasks.url.models.entry import URLTaskEntry
+from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator
+from src.core.tasks.url.operators.auto_name.core import AutoNameURLTaskOperator
+from src.core.tasks.url.operators.auto_relevant.core import URLAutoRelevantTaskOperator
+from src.core.tasks.url.operators.base import URLTaskOperatorBase
+from src.core.tasks.url.operators.html.core import URLHTMLTaskOperator
+from src.core.tasks.url.operators.misc_metadata.core import URLMiscellaneousMetadataTaskOperator
+from src.core.tasks.url.operators.probe.core import URLProbeTaskOperator
+from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator
+from src.core.tasks.url.operators.root_url.core import URLRootURLTaskOperator
+from src.core.tasks.url.operators.submit_approved.core import SubmitApprovedURLTaskOperator
+
+
+class FlagTestParams(BaseModel):
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    env_var: str
+    operator: type[URLTaskOperatorBase]
+
+params = [
+    FlagTestParams(
+        env_var="URL_HTML_TASK_FLAG",
+        operator=URLHTMLTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_RECORD_TYPE_TASK_FLAG",
+        operator=URLRecordTypeTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_AGENCY_IDENTIFICATION_TASK_FLAG",
+        operator=AgencyIdentificationTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_SUBMIT_APPROVED_TASK_FLAG",
+        operator=SubmitApprovedURLTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_MISC_METADATA_TASK_FLAG",
+        operator=URLMiscellaneousMetadataTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_AUTO_RELEVANCE_TASK_FLAG",
+        operator=URLAutoRelevantTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_PROBE_TASK_FLAG",
+        operator=URLProbeTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_ROOT_URL_TASK_FLAG",
+        operator=URLRootURLTaskOperator
+    ),
+    FlagTestParams(
+        env_var="URL_AUTO_NAME_TASK_FLAG",
+        operator=AutoNameURLTaskOperator
+    )
+]
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("flag_test_params", params)
+async def test_flag_enabled(
+    flag_test_params: FlagTestParams,
+    monkeypatch,
+    loader: URLTaskOperatorLoader
+):
+    monkeypatch.setenv(flag_test_params.env_var, "0")
+    entries: list[URLTaskEntry] = await loader.load_entries()
+    for entry in entries:
+        if isinstance(entry.operator, flag_test_params.operator):
+            assert not entry.enabled, f"Flag associated with env_var {flag_test_params.env_var} should be disabled"
diff --git a/tests/automated/integration/tasks/url/loader/test_happy_path.py b/tests/automated/integration/tasks/url/loader/test_happy_path.py
new file mode 100644
index 00000000..a7b02e89
--- /dev/null
+++ b/tests/automated/integration/tasks/url/loader/test_happy_path.py
@@ -0,0 +1,15 @@
+import pytest
+
+from src.core.tasks.url.loader import URLTaskOperatorLoader
+
+NUMBER_OF_TASK_OPERATORS: int = 14
+
+@pytest.mark.asyncio
+async def test_happy_path(
+    loader: URLTaskOperatorLoader
+):
+    """
+    Under normal circumstances, all task operators should be returned
+    """
+    task_operators = await loader.load_entries()
+    assert len(task_operators) == NUMBER_OF_TASK_OPERATORS
\ No newline at end of file
diff --git a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py b/tests/automated/integration/tasks/url/test_agency_preannotation_task.py
deleted file mode 100644
index 03961fe0..00000000
--- a/tests/automated/integration/tasks/url/test_agency_preannotation_task.py
+++ /dev/null
@@ -1,326 +0,0 @@
-from copy import deepcopy
-from typing import Optional
-from unittest.mock import MagicMock, AsyncMock, patch
-
-import pytest
-from aiohttp import ClientSession
-
-from src.collectors.source_collectors.muckrock.api_interface.core import MuckrockAPIInterface
-from src.collectors.source_collectors.muckrock.api_interface.lookup_response import AgencyLookupResponse
-from src.collectors.source_collectors.muckrock.enums import AgencyLookupResponseType
-from src.core.tasks.url.operators.agency_identification.core import AgencyIdentificationTaskOperator
-from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
-from src.db.models.instantiations.url.suggestion.agency.auto import AutomatedUrlAgencySuggestion
-from src.external.pdap.enums import MatchAgencyResponseStatus
-from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters
-from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters
-from src.db.models.instantiations.agency import Agency
-from src.collectors.enums import CollectorType, URLStatus
-from src.core.tasks.url.enums import TaskOperatorOutcome
-from src.core.tasks.url.subtasks.agency_identification.auto_googler import AutoGooglerAgencyIdentificationSubtask
-from src.core.tasks.url.subtasks.agency_identification.ckan import CKANAgencyIdentificationSubtask
-from src.core.tasks.url.subtasks.agency_identification.common_crawler import CommonCrawlerAgencyIdentificationSubtask
-from src.core.tasks.url.subtasks.agency_identification.muckrock import MuckrockAgencyIdentificationSubtask
-from src.core.enums import SuggestionType
-from pdap_access_manager import AccessManager
-from src.external.pdap.dtos.match_agency.response import MatchAgencyResponse
-from src.external.pdap.dtos.match_agency.post import MatchAgencyInfo
-from src.external.pdap.client import PDAPClient
-from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfoV2
-
-sample_agency_suggestions = [
-    URLAgencySuggestionInfo(
-        url_id=-1,  # This will be overwritten
-        suggestion_type=SuggestionType.UNKNOWN,
-        pdap_agency_id=None,
-        agency_name=None,
-        state=None,
-        county=None,
-        locality=None
-    ),
-    URLAgencySuggestionInfo(
-        url_id=-1,  # This will be overwritten
-        suggestion_type=SuggestionType.CONFIRMED,
-        pdap_agency_id=-1,
-        agency_name="Test Agency",
-        state="Test State",
-        county="Test County",
-        locality="Test Locality"
-    ),
-    URLAgencySuggestionInfo(
-        url_id=-1,  # This will be overwritten
-        suggestion_type=SuggestionType.AUTO_SUGGESTION,
-        pdap_agency_id=-1,
-        agency_name="Test Agency 2",
-        state="Test State 2",
-        county="Test County 2",
-        locality="Test Locality 2"
-    )
-]
-
-@pytest.mark.asyncio
-async def test_agency_preannotation_task(db_data_creator: DBDataCreator):
-    async def mock_run_subtask(
-        subtask,
-        url_id: int,
-        collector_metadata: Optional[dict]
-    ):
-        # Deepcopy to prevent using the same instance in memory
-        suggestion = deepcopy(sample_agency_suggestions[url_id % 3])
-        suggestion.url_id = url_id
-        suggestion.pdap_agency_id = (url_id % 3) if suggestion.suggestion_type != SuggestionType.UNKNOWN else None
-        return [suggestion]
-
-    async with ClientSession() as session:
-        mock = MagicMock()
-        access_manager = AccessManager(
-            email=mock.email,
-            password=mock.password,
-            api_key=mock.api_key,
-            session=session
-        )
-        pdap_client = PDAPClient(
-            access_manager=access_manager
-        )
-        muckrock_api_interface = MuckrockAPIInterface(session=session)
-        with patch.object(
-            AgencyIdentificationTaskOperator,
-            "run_subtask",
-            side_effect=mock_run_subtask,
-        ) as mock:
-            operator = AgencyIdentificationTaskOperator(
-                adb_client=db_data_creator.adb_client,
-                pdap_client=pdap_client,
-                muckrock_api_interface=muckrock_api_interface
-            )
-
-            # Confirm does not yet meet prerequisites
-            assert not await operator.meets_task_prerequisites()
-
-            d = {}
-
-            # Create six urls, one from each strategy
-            for strategy in [
-                CollectorType.COMMON_CRAWLER,
-                CollectorType.AUTO_GOOGLER,
-                CollectorType.MUCKROCK_COUNTY_SEARCH,
-                CollectorType.MUCKROCK_SIMPLE_SEARCH,
-                CollectorType.MUCKROCK_ALL_SEARCH,
-                CollectorType.CKAN
-            ]:
-                # Create two URLs for each, one pending and one errored
-                creation_info: BatchURLCreationInfoV2 = await db_data_creator.batch_v2(
-                    parameters=TestBatchCreationParameters(
-                        strategy=strategy,
-                        urls=[
-                            TestURLCreationParameters(
-                                count=1,
-                                status=URLStatus.PENDING,
-                                with_html_content=True
-                            ),
-                            TestURLCreationParameters(
-                                count=1,
-                                status=URLStatus.ERROR,
-                                with_html_content=True
-                            )
-                        ]
-                    )
-                )
-                d[strategy] = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings[0].url_id
-
-            # Confirm meets prerequisites
-            assert await operator.meets_task_prerequisites()
-            # Run task
-            run_info = await operator.run_task(1)
-            assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
-
-            # Confirm tasks are piped into the correct subtasks
-            # * common_crawler into common_crawler_subtask
-            # * auto_googler into auto_googler_subtask
-            # * muckrock_county_search into muckrock_subtask
-            # * muckrock_simple_search into muckrock_subtask
-            # * muckrock_all_search into muckrock_subtask
-            # * ckan into ckan_subtask
-
-            assert mock.call_count == 6
-
-            # Confirm subtask classes are correct for the given urls
-            d2 = {}
-            for call_arg in mock.call_args_list:
-                subtask_class = call_arg[0][0].__class__
-                url_id = call_arg[0][1]
-                d2[url_id] = subtask_class
-
-            subtask_class_collector_type = [
-                (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_ALL_SEARCH),
-                (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_COUNTY_SEARCH),
-                (MuckrockAgencyIdentificationSubtask, CollectorType.MUCKROCK_SIMPLE_SEARCH),
-                (CKANAgencyIdentificationSubtask, CollectorType.CKAN),
-                (CommonCrawlerAgencyIdentificationSubtask, CollectorType.COMMON_CRAWLER),
-                (AutoGooglerAgencyIdentificationSubtask, CollectorType.AUTO_GOOGLER)
-            ]
-
-            for subtask_class, collector_type in subtask_class_collector_type:
-                url_id = d[collector_type]
-                assert d2[url_id] == subtask_class
-
-            # Confirm task again does not meet prerequisites
-            assert not await operator.meets_task_prerequisites()
-
-
-    # Check confirmed and auto suggestions
-    adb_client = db_data_creator.adb_client
-    confirmed_suggestions = await adb_client.get_urls_with_confirmed_agencies()
-    assert len(confirmed_suggestions) == 2
-
-    agencies = await adb_client.get_all(Agency)
-    assert len(agencies) == 2
-
-    auto_suggestions = await adb_client.get_all(AutomatedUrlAgencySuggestion)
-    assert len(auto_suggestions) == 4
-
-    # Of the auto suggestions, 2 should be unknown
-    assert len([s for s in auto_suggestions if s.is_unknown]) == 2
-
-    # Of the auto suggestions, 2 should not be unknown
-    assert len([s for s in auto_suggestions if not s.is_unknown]) == 2
-
-@pytest.mark.asyncio
-async def test_common_crawler_subtask(db_data_creator: DBDataCreator):
-    # Test that common_crawler subtask correctly adds URL to
-    # url_agency_suggestions with label 'Unknown'
-    subtask = CommonCrawlerAgencyIdentificationSubtask()
-    results: list[URLAgencySuggestionInfo] = await subtask.run(url_id=1, collector_metadata={})
-    assert len(results) == 1
-    assert results[0].url_id == 1
-    assert results[0].suggestion_type == SuggestionType.UNKNOWN
-
-
-@pytest.mark.asyncio
-async def test_auto_googler_subtask(db_data_creator: DBDataCreator):
-    # Test that auto_googler subtask correctly adds URL to
-    # url_agency_suggestions with label 'Unknown'
-    subtask = AutoGooglerAgencyIdentificationSubtask()
-    results: list[URLAgencySuggestionInfo] = await subtask.run(url_id=1, collector_metadata={})
-    assert len(results) == 1
-    assert results[0].url_id == 1
-    assert results[0].suggestion_type == SuggestionType.UNKNOWN
-
-@pytest.mark.asyncio
-async def test_muckrock_subtask(db_data_creator: DBDataCreator):
-    # Test that muckrock subtask correctly sends agency name to
-    # MatchAgenciesInterface and adds received suggestions to
-    # url_agency_suggestions
-
-    # Create mock instances for dependency injections
-    muckrock_api_interface_mock = MagicMock(spec=MuckrockAPIInterface)
-    pdap_client_mock = MagicMock(spec=PDAPClient)
-
-    # Set up mock return values for method calls
-    muckrock_api_interface_mock.lookup_agency.return_value = AgencyLookupResponse(
-        type=AgencyLookupResponseType.FOUND,
-        name="Mock Agency Name",
-        error=None
-    )
-
-    pdap_client_mock.match_agency.return_value = MatchAgencyResponse(
-        status=MatchAgencyResponseStatus.PARTIAL_MATCH,
-        matches=[
-            MatchAgencyInfo(
-                id=1,
-                submitted_name="Mock Agency Name",
-            ),
-            MatchAgencyInfo(
-                id=2,
-                submitted_name="Another Mock Agency Name",
-            )
-        ]
-    )
-
-    # Create an instance of MuckrockAgencyIdentificationSubtask with mock dependencies
-    muckrock_agency_identification_subtask = MuckrockAgencyIdentificationSubtask(
-        muckrock_api_interface=muckrock_api_interface_mock,
-        pdap_client=pdap_client_mock
-    )
-
-    # Run the subtask
-    results: list[URLAgencySuggestionInfo] = await muckrock_agency_identification_subtask.run(
-        url_id=1,
-        collector_metadata={
-            "agency": 123
-        }
-    )
-
-    # Verify the results
-    assert len(results) == 2
-    assert results[0].url_id == 1
-    assert results[0].suggestion_type == SuggestionType.AUTO_SUGGESTION
-    assert results[0].pdap_agency_id == 1
-    assert results[0].agency_name == "Mock Agency Name"
-    assert results[1].url_id == 1
-    assert results[1].suggestion_type == SuggestionType.AUTO_SUGGESTION
-    assert results[1].pdap_agency_id == 2
-    assert results[1].agency_name == "Another Mock Agency Name"
-
-    # Assert methods called as expected
-    muckrock_api_interface_mock.lookup_agency.assert_called_once_with(
-        muckrock_agency_id=123
-    )
-    pdap_client_mock.match_agency.assert_called_once_with(
-        name="Mock Agency Name"
-    )
-
-
-@pytest.mark.asyncio
-async def test_ckan_subtask(db_data_creator: DBDataCreator):
-    # Test that ckan subtask correctly sends agency id to
-    # CKANAPIInterface, sends resultant agency name to
-    # PDAPClient and adds received suggestions to
-    # url_agency_suggestions
-
-    pdap_client = AsyncMock()
-    pdap_client.match_agency.return_value = MatchAgencyResponse(
-        status=MatchAgencyResponseStatus.PARTIAL_MATCH,
-        matches=[
-            MatchAgencyInfo(
-                id=1,
-                submitted_name="Mock Agency Name",
-            ),
-            MatchAgencyInfo(
-                id=2,
-                submitted_name="Another Mock Agency Name",
-            )
-        ]
-    )  # Assuming MatchAgencyResponse is a class
-
-    # Create an instance of CKANAgencyIdentificationSubtask
-    task = CKANAgencyIdentificationSubtask(pdap_client)
-
-    # Call the run method with static values
-    collector_metadata = {"agency_name": "Test Agency"}
-    url_id = 1
-
-    # Call the run method
-    result = await task.run(url_id, collector_metadata)
-
-    # Check the result
-    assert len(result) == 2
-    assert result[0].url_id == 1
-    assert result[0].suggestion_type == SuggestionType.AUTO_SUGGESTION
-    assert result[0].pdap_agency_id == 1
-    assert result[0].agency_name == "Mock Agency Name"
-    assert result[1].url_id == 1
-    assert result[1].suggestion_type == SuggestionType.AUTO_SUGGESTION
-    assert result[1].pdap_agency_id == 2
-    assert result[1].agency_name == "Another Mock Agency Name"
-
-    # Assert methods called as expected
-    pdap_client.match_agency.assert_called_once_with(name="Test Agency")
-
diff --git a/tests/automated/integration/tasks/url/test_submit_approved_url_task.py b/tests/automated/integration/tasks/url/test_submit_approved_url_task.py
deleted file mode 100644
index 0bdc3718..00000000
--- a/tests/automated/integration/tasks/url/test_submit_approved_url_task.py
+++ /dev/null
@@ -1,220 +0,0 @@
-from http import HTTPStatus
-from unittest.mock import AsyncMock
-
-import pytest
-from deepdiff import DeepDiff
-
-from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo
-from src.core.tasks.url.operators.submit_approved_url.core import SubmitApprovedURLTaskOperator
-from src.db.enums import TaskType
-from src.db.models.instantiations.url.error_info import URLErrorInfo
-from src.db.models.instantiations.url.data_source import URLDataSource
-from src.db.models.instantiations.url.core import URL
-from src.collectors.enums import URLStatus
-from src.core.tasks.url.enums import TaskOperatorOutcome
-from src.core.enums import RecordType, SubmitResponseStatus
-from tests.helpers.db_data_creator import BatchURLCreationInfo, DBDataCreator
-from pdap_access_manager import RequestInfo, RequestType, ResponseInfo, DataSourcesNamespaces
-from src.external.pdap.client import PDAPClient
-
-
-def mock_make_request(pdap_client: PDAPClient, urls: list[str]):
-    assert len(urls) == 3, "Expected 3 urls"
-    pdap_client.access_manager.make_request = AsyncMock(
-        return_value=ResponseInfo(
-            status_code=HTTPStatus.OK,
-            data={
-                "data_sources": [
-                    {
-                        "url": urls[0],
-                        "status": SubmitResponseStatus.SUCCESS,
-                        "error": None,
-                        "data_source_id": 21,
-                    },
-                    {
-                        "url": urls[1],
-                        "status": SubmitResponseStatus.SUCCESS,
-                        "error": None,
-                        "data_source_id": 34,
-                    },
-                    {
-                        "url": urls[2],
-                        "status": SubmitResponseStatus.FAILURE,
-                        "error": "Test Error",
-                        "data_source_id": None
-                    }
-                ]
-            }
-        )
-    )
-
-
-async def setup_validated_urls(db_data_creator: DBDataCreator) -> list[str]:
-    creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls(
-        url_count=3,
-        with_html_content=True
-    )
-
-    url_1 = creation_info.url_ids[0]
-    url_2 = creation_info.url_ids[1]
-    url_3 = creation_info.url_ids[2]
-    await db_data_creator.adb_client.approve_url(
-        approval_info=FinalReviewApprovalInfo(
-            url_id=url_1,
-            record_type=RecordType.ACCIDENT_REPORTS,
-            agency_ids=[1, 2],
-            name="URL 1 Name",
-            description="URL 1 Description",
-            record_formats=["Record Format 1", "Record Format 2"],
-            data_portal_type="Data Portal Type 1",
-            supplying_entity="Supplying Entity 1"
-        ),
-        user_id=1
-    )
-    await db_data_creator.adb_client.approve_url(
-        approval_info=FinalReviewApprovalInfo(
-            url_id=url_2,
-            record_type=RecordType.INCARCERATION_RECORDS,
-            agency_ids=[3, 4],
-            name="URL 2 Name",
-            description="URL 2 Description",
-        ),
-        user_id=2
-    )
-    await db_data_creator.adb_client.approve_url(
-        approval_info=FinalReviewApprovalInfo(
-            url_id=url_3,
-            record_type=RecordType.ACCIDENT_REPORTS,
-            agency_ids=[5, 6],
-            name="URL 3 Name",
-            description="URL 3 Description",
-        ),
-        user_id=3
-    )
-    return creation_info.urls
-
-@pytest.mark.asyncio
-async def test_submit_approved_url_task(
-    db_data_creator,
-    mock_pdap_client: PDAPClient,
-    monkeypatch
-):
-    """
-    The submit_approved_url_task should submit
-    all validated URLs to the PDAP Data Sources App
-    """
-
-
-    # Get Task Operator
-    operator = SubmitApprovedURLTaskOperator(
-        adb_client=db_data_creator.adb_client,
-        pdap_client=mock_pdap_client
-    )
-
-    # Check Task Operator does not yet meet pre-requisites
-    assert not await operator.meets_task_prerequisites()
-
-    # Create URLs with status 'validated' in database and all requisite URL values
-    # Ensure they have optional metadata as well
-    urls = await setup_validated_urls(db_data_creator)
-    mock_make_request(mock_pdap_client, urls)
-
-    # Check Task Operator does meet pre-requisites
-    assert await operator.meets_task_prerequisites()
-
-    # Run Task
-    task_id = await db_data_creator.adb_client.initiate_task(
-        task_type=TaskType.SUBMIT_APPROVED
-    )
-    run_info = await operator.run_task(task_id=task_id)
-
-    # Check Task has been marked as completed
-    assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
-
-    # Get URLs
-    urls = await db_data_creator.adb_client.get_all(URL, order_by_attribute="id")
-    url_1 = urls[0]
-    url_2 = urls[1]
-    url_3 = urls[2]
-
-    # Check URLs have been marked as 'submitted'
-    assert url_1.outcome == URLStatus.SUBMITTED.value
-    assert url_2.outcome == URLStatus.SUBMITTED.value
-    assert url_3.outcome == URLStatus.ERROR.value
-
-    # Get URL Data Source Links
-    url_data_sources = await db_data_creator.adb_client.get_all(URLDataSource)
-    assert len(url_data_sources) == 2
-
-    url_data_source_1 = url_data_sources[0]
-    url_data_source_2 = url_data_sources[1]
-
-    assert url_data_source_1.url_id == url_1.id
-    assert url_data_source_1.data_source_id == 21
-
-    assert url_data_source_2.url_id == url_2.id
-    assert url_data_source_2.data_source_id == 34
-
-    # Check that errored URL has entry in url_error_info
-    url_errors = await db_data_creator.adb_client.get_all(URLErrorInfo)
-    assert len(url_errors) == 1
-    url_error = url_errors[0]
-    assert url_error.url_id == url_3.id
-    assert url_error.error == "Test Error"
-
-    # Check mock method was called expected parameters
-    access_manager = mock_pdap_client.access_manager
-    access_manager.make_request.assert_called_once()
-    access_manager.build_url.assert_called_with(
-        namespace=DataSourcesNamespaces.SOURCE_COLLECTOR,
-        subdomains=['data-sources']
-    )
-
-    call_1 = access_manager.make_request.call_args_list[0][0][0]
-    expected_call_1 = RequestInfo(
-        type_=RequestType.POST,
-        url="http://example.com",
-        headers=access_manager.jwt_header.return_value,
-        json_={
-            "data_sources": [
-                {
-                    "name": "URL 1 Name",
-                    "source_url": url_1.url,
-                    "record_type": "Accident Reports",
-                    "description": "URL 1 Description",
-                    "record_formats": ["Record Format 1", "Record Format 2"],
-                    "data_portal_type": "Data Portal Type 1",
-                    "last_approval_editor": 1,
-                    "supplying_entity": "Supplying Entity 1",
-                    "agency_ids": [1, 2]
-                },
-                {
-                    "name": "URL 2 Name",
-                    "source_url": url_2.url,
-                    "record_type": "Incarceration Records",
-                    "description": "URL 2 Description",
-                    "last_approval_editor": 2,
-                    "supplying_entity": None,
-                    "record_formats": None,
-                    "data_portal_type": None,
-                    "agency_ids": [3, 4]
-                },
-                {
-                    "name": "URL 3 Name",
-                    "source_url": url_3.url,
-                    "record_type": "Accident Reports",
-                    "description": "URL 3 Description",
-                    "last_approval_editor": 3,
-                    "supplying_entity": None,
-                    "record_formats": None,
-                    "data_portal_type": None,
-                    "agency_ids": [5, 6]
-                }
-            ]
-        }
-    )
-    assert call_1.type_ == expected_call_1.type_
-    assert call_1.headers == expected_call_1.headers
-    diff = DeepDiff(call_1.json_, expected_call_1.json_, ignore_order=True)
-    assert diff == {}, f"Differences found: {diff}"
diff --git a/tests/automated/integration/tasks/url/test_url_404_probe.py b/tests/automated/integration/tasks/url/test_url_404_probe.py
deleted file mode 100644
index 7a88f759..00000000
--- a/tests/automated/integration/tasks/url/test_url_404_probe.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import types
-from http import HTTPStatus
-
-import pendulum
-import pytest
-from aiohttp import ClientResponseError, RequestInfo
-
-from src.core.tasks.url.operators.url_404_probe.core import URL404ProbeTaskOperator
-from src.core.tasks.url.operators.url_html.scraper.request_interface.core import URLRequestInterface
-from src.db.models.instantiations.url.probed_for_404 import URLProbedFor404
-from src.db.models.instantiations.url.core import URL
-from src.collectors.enums import URLStatus
-from src.core.tasks.url.enums import TaskOperatorOutcome
-from src.core.tasks.url.operators.url_html.scraper.request_interface.dtos.url_response import URLResponseInfo
-from tests.helpers.db_data_creator import DBDataCreator
-from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters
-from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters
-
-
-@pytest.mark.asyncio
-async def test_url_404_probe_task(db_data_creator: DBDataCreator):
-
-    mock_html_content = ""
-    mock_content_type = "text/html"
-    adb_client = db_data_creator.adb_client
-
-    async def mock_make_simple_requests(self, urls: list[str]) -> list[URLResponseInfo]:
-        """
-        Mock make_simple_requests so that
-        - the first url returns a 200
-        - the second url returns a 404
-        - the third url returns a general error
-
-        """
-        results = []
-        for idx, url in enumerate(urls):
-            if idx == 1:
-                results.append(
-                    URLResponseInfo(
-                        success=False,
-                        content_type=mock_content_type,
-                        exception=str(ClientResponseError(
-                            request_info=RequestInfo(
-                                url=url,
-                                method="GET",
-                                real_url=url,
-                                headers={},
-                            ),
-                            code=HTTPStatus.NOT_FOUND.value,
-                            history=(None,),
-                        )),
-                        status=HTTPStatus.NOT_FOUND
-                    )
-                )
-            elif idx == 2:
-                results.append(
-                    URLResponseInfo(
-                        success=False,
-                        exception=str(ValueError("test error")),
-                        content_type=mock_content_type
-                    )
-                )
-            else:
-                results.append(URLResponseInfo(
-                    html=mock_html_content, success=True, content_type=mock_content_type))
-        return results
-
-    url_request_interface = URLRequestInterface()
-    url_request_interface.make_simple_requests = types.MethodType(mock_make_simple_requests, url_request_interface)
-
-    operator = URL404ProbeTaskOperator(
-        url_request_interface=url_request_interface,
-        adb_client=adb_client
-    )
-    # Check that initially prerequisites aren't met
-    meets_prereqs = await operator.meets_task_prerequisites()
-    assert not meets_prereqs
-
-    # Add 4 URLs, 3 pending, 1 error
-    creation_info = await db_data_creator.batch_v2(
-        parameters=TestBatchCreationParameters(
-            urls=[
-                TestURLCreationParameters(
-                    count=3,
-                    status=URLStatus.PENDING,
-                    with_html_content=True
-                ),
-                TestURLCreationParameters(
-                    count=1,
-                    status=URLStatus.ERROR,
-                    with_html_content=False
-                ),
-            ]
-        )
-    )
-
-    meets_prereqs = await operator.meets_task_prerequisites()
-    assert meets_prereqs
-
-    # Run task and validate results
-    run_info = await operator.run_task(task_id=1)
-    assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
-
-    pending_url_mappings = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings
-    url_id_success = pending_url_mappings[0].url_id
-    url_id_404 = pending_url_mappings[1].url_id
-    url_id_error = pending_url_mappings[2].url_id
-
-    url_id_initial_error = creation_info.url_creation_infos[URLStatus.ERROR].url_mappings[0].url_id
-
-    # Check that URLProbedFor404 has been appropriately populated
-    probed_for_404_objects: list[URLProbedFor404] = await db_data_creator.adb_client.get_all(URLProbedFor404)
-
-    assert len(probed_for_404_objects) == 3
-    assert probed_for_404_objects[0].url_id == url_id_success
-    assert probed_for_404_objects[1].url_id == url_id_404
-    assert probed_for_404_objects[2].url_id == url_id_error
-
-    # Check that the URLs have been updated appropriated
-    urls: list[URL] = await adb_client.get_all(URL)
-
-    def find_url(url_id: int) -> URL:
-        for url in urls:
-            if url.id == url_id:
-                return url
-        raise Exception(f"URL with id {url_id} not found")
-
-    assert find_url(url_id_success).outcome == URLStatus.PENDING.value
-    assert find_url(url_id_404).outcome == URLStatus.NOT_FOUND.value
-    assert find_url(url_id_error).outcome == URLStatus.PENDING.value
-    assert find_url(url_id_initial_error).outcome == URLStatus.ERROR.value
-
-    # Check that meets_task_prerequisites now returns False
-    meets_prereqs = await operator.meets_task_prerequisites()
-    assert not meets_prereqs
-
-    # Check that meets_task_prerequisites returns True
-    # After setting the last probed for 404 date to 2 months ago
-    two_months_ago = pendulum.now().subtract(months=2).naive()
-    await adb_client.mark_all_as_recently_probed_for_404(
-        [url_id_404, url_id_error],
-        dt=two_months_ago
-    )
-
-    meets_prereqs = await operator.meets_task_prerequisites()
-    assert meets_prereqs
-
-    # Run the task and Ensure all but the URL previously marked as 404 have been checked again
-    run_info = await operator.run_task(task_id=2)
-    assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
-
-    probed_for_404_objects: list[URLProbedFor404] = await db_data_creator.adb_client.get_all(URLProbedFor404)
-
-    assert len(probed_for_404_objects) == 3
-    assert probed_for_404_objects[0].last_probed_at != two_months_ago
-    assert probed_for_404_objects[1].last_probed_at == two_months_ago
-    assert probed_for_404_objects[2].last_probed_at != two_months_ago
-
-
-
-
-
-
diff --git a/tests/automated/unit/api/__init__.py b/tests/automated/unit/api/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/unit/api/test_all_annotation_post_info.py b/tests/automated/unit/api/test_all_annotation_post_info.py
new file mode 100644
index 00000000..cb7bdb41
--- /dev/null
+++ b/tests/automated/unit/api/test_all_annotation_post_info.py
@@ -0,0 +1,108 @@
+import pytest
+from pydantic import BaseModel
+
+from src.api.endpoints.annotate.all.post.models.agency import AnnotationPostAgencyInfo
+from src.api.endpoints.annotate.all.post.models.location import AnnotationPostLocationInfo
+from src.api.endpoints.annotate.all.post.models.request import AllAnnotationPostInfo
+from src.core.enums import RecordType
+from src.core.exceptions import FailedValidationException
+from src.db.models.impl.flag.url_validated.enums import URLType
+
+
+class TestAllAnnotationPostInfoParams(BaseModel):
+    suggested_status: URLType
+    record_type: RecordType | None
+    agency_ids: list[int]
+    location_ids: list[int]
+    raise_exception: bool
+
+@pytest.mark.parametrize(
+    "params",
+    [
+        # Happy Paths
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.META_URL,
+            record_type=None,
+            agency_ids=[1, 2],
+            location_ids=[3, 4],
+            raise_exception=False
+        ),
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.DATA_SOURCE,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_ids=[1, 2],
+            location_ids=[3, 4],
+            raise_exception=False
+        ),
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.NOT_RELEVANT,
+            record_type=None,
+            agency_ids=[],
+            location_ids=[],
+            raise_exception=False
+        ),
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.INDIVIDUAL_RECORD,
+            record_type=None,
+            agency_ids=[1, 2],
+            location_ids=[3, 4],
+            raise_exception=False
+        ),
+        # Error Paths - Meta URL
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.META_URL,
+            record_type=RecordType.ACCIDENT_REPORTS,  # Record Type Included
+            agency_ids=[1, 2],
+            location_ids=[3, 4],
+            raise_exception=True
+        ),
+        # Error Paths - Not Relevant
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.NOT_RELEVANT,
+            record_type=RecordType.ACCIDENT_REPORTS,  # Record Type Included
+            agency_ids=[],
+            location_ids=[],
+            raise_exception=True
+        ),
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.NOT_RELEVANT,
+            record_type=None,
+            agency_ids=[1, 2],  # Agency IDs Included
+            location_ids=[],
+            raise_exception=True
+        ),
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.NOT_RELEVANT,
+            record_type=None,
+            agency_ids=[],
+            location_ids=[1, 2],  # Location IDs included
+            raise_exception=True
+        ),
+        # Error Paths - Individual Record
+        TestAllAnnotationPostInfoParams(
+            suggested_status=URLType.INDIVIDUAL_RECORD,
+            record_type=RecordType.ACCIDENT_REPORTS,  # Record Type Included
+            agency_ids=[],
+            location_ids=[],
+            raise_exception=True
+        ),
+    ]
+)
+def test_all_annotation_post_info(
+    params: TestAllAnnotationPostInfoParams
+):
+    if params.raise_exception:
+        with pytest.raises(FailedValidationException):
+            AllAnnotationPostInfo(
+                suggested_status=params.suggested_status,
+                record_type=params.record_type,
+                agency_info=AnnotationPostAgencyInfo(agency_ids=params.agency_ids),
+                location_info=AnnotationPostLocationInfo(location_ids=params.location_ids)
+            )
+    else:
+        AllAnnotationPostInfo(
+            suggested_status=params.suggested_status,
+            record_type=params.record_type,
+            agency_info=AnnotationPostAgencyInfo(agency_ids=params.agency_ids),
+            location_info=AnnotationPostLocationInfo(location_ids=params.location_ids)
+        )
\ No newline at end of file
diff --git a/tests/automated/unit/core/test_core_logger.py b/tests/automated/unit/core/test_core_logger.py
index f6738011..6c4f0375 100644
--- a/tests/automated/unit/core/test_core_logger.py
+++ b/tests/automated/unit/core/test_core_logger.py
@@ -3,7 +3,7 @@
 
 import pytest
 
-from src.db.dtos.log import LogInfo
+from src.db.models.impl.log.pydantic.info import LogInfo
 from src.core.logger import AsyncCoreLogger
 
diff --git a/tests/automated/unit/db/__init__.py b/tests/automated/unit/db/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/unit/db/utils/__init__.py b/tests/automated/unit/db/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/unit/db/utils/validate/__init__.py b/tests/automated/unit/db/utils/validate/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/unit/db/utils/validate/mock/__init__.py b/tests/automated/unit/db/utils/validate/mock/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/automated/unit/db/utils/validate/mock/class_.py b/tests/automated/unit/db/utils/validate/mock/class_.py
new file mode 100644
index 00000000..87b0d213
--- /dev/null
+++ b/tests/automated/unit/db/utils/validate/mock/class_.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+from tests.automated.unit.db.utils.validate.mock.protocol import MockProtocol
+
+
+class MockClassNoProtocol(BaseModel):
+    mock_attribute: str | None = None
+
+class MockClassWithProtocol(BaseModel, MockProtocol):
+    mock_attribute: str | None = None
\ No newline at end of file
diff --git a/tests/automated/unit/db/utils/validate/mock/protocol.py b/tests/automated/unit/db/utils/validate/mock/protocol.py
new file mode 100644
index 00000000..5a55d0fe
--- /dev/null
+++ b/tests/automated/unit/db/utils/validate/mock/protocol.py
@@ -0,0 +1,7 @@
+from asyncio import Protocol
+
+
+class MockProtocol(Protocol):
+
+    def mock_method(self) -> None:
+        pass
\ No newline at end of file
diff --git a/tests/automated/unit/db/utils/validate/test_all_models_of_same_type.py b/tests/automated/unit/db/utils/validate/test_all_models_of_same_type.py
new file mode 100644
index 00000000..8e325879
--- /dev/null
+++ b/tests/automated/unit/db/utils/validate/test_all_models_of_same_type.py
@@ -0,0 +1,17 @@
+import pytest
+
+from src.db.utils.validate import validate_all_models_of_same_type
+from tests.automated.unit.db.utils.validate.mock.class_ import MockClassNoProtocol, MockClassWithProtocol
+
+
+def test_validate_all_models_of_same_type_happy_path():
+
+    models = [MockClassNoProtocol() for _ in range(3)]
+    validate_all_models_of_same_type(models)
+
+def test_validate_all_models_of_same_type_error_path():
+
+    models = [MockClassNoProtocol() for _ in range(2)]
+    models.append(MockClassWithProtocol())
+    with pytest.raises(TypeError):
+        validate_all_models_of_same_type(models)
\ No newline at end of file
diff --git a/tests/automated/unit/db/utils/validate/test_has_protocol.py b/tests/automated/unit/db/utils/validate/test_has_protocol.py
new file mode 100644
index 00000000..cfb820a3
--- /dev/null
+++ b/tests/automated/unit/db/utils/validate/test_has_protocol.py
@@ -0,0 +1,17 @@
+import pytest
+
+from src.db.utils.validate import validate_has_protocol
+from tests.automated.unit.db.utils.validate.mock.class_ import MockClassWithProtocol, MockClassNoProtocol
+from tests.automated.unit.db.utils.validate.mock.protocol import MockProtocol
+
+
+def test_validate_has_protocol_happy_path():
+
+    model = MockClassWithProtocol()
+    validate_has_protocol(model, MockProtocol)
+
+def test_validate_has_protocol_error_path():
+
+    model = MockClassNoProtocol()
+    with pytest.raises(TypeError):
+        validate_has_protocol(model, MockProtocol)
\ No newline at end of file
diff --git a/tests/automated/unit/dto/test_all_annotation_post_info.py b/tests/automated/unit/dto/test_all_annotation_post_info.py
deleted file mode 100644
index 0778c089..00000000
--- a/tests/automated/unit/dto/test_all_annotation_post_info.py
+++ /dev/null
@@ -1,36 +0,0 @@
-import pytest
-
-from src.api.endpoints.annotate.all.post.dto import AllAnnotationPostInfo
-from src.core.enums import RecordType, SuggestedStatus
-from src.core.exceptions import FailedValidationException
-
-# Mock values to pass
-mock_record_type = RecordType.ARREST_RECORDS.value  # replace with valid RecordType if Enum
-mock_agency = {"is_new": False, "suggested_agency": 1}  # replace with a valid dict for the URLAgencyAnnotationPostInfo model
-
-@pytest.mark.parametrize(
-    "suggested_status, record_type, agency, should_raise",
-    [
-        (SuggestedStatus.RELEVANT, mock_record_type, mock_agency, False),  # valid
-        (SuggestedStatus.RELEVANT, None, mock_agency, True),  # missing record_type
-        (SuggestedStatus.RELEVANT, mock_record_type, None, True),  # missing agency
-        (SuggestedStatus.RELEVANT, None, None, True),  # missing both
-        (SuggestedStatus.NOT_RELEVANT, None, None, False),  # valid
-        (SuggestedStatus.NOT_RELEVANT, mock_record_type, None, True),  # record_type present
-        (SuggestedStatus.NOT_RELEVANT, None, mock_agency, True),  # agency present
-        (SuggestedStatus.NOT_RELEVANT, mock_record_type, mock_agency, True),  # both present
-    ]
-)
-def test_all_annotation_post_info_validation(suggested_status, record_type, agency, should_raise):
-    data = {
-        "suggested_status": suggested_status.value,
-        "record_type": record_type,
-        "agency": agency
-    }
-
-    if should_raise:
-        with pytest.raises(FailedValidationException):
-            AllAnnotationPostInfo(**data)
-    else:
-        model = AllAnnotationPostInfo(**data)
-        assert model.suggested_status == suggested_status
diff --git a/tests/automated/unit/source_collectors/test_autogoogler_collector.py b/tests/automated/unit/source_collectors/test_autogoogler_collector.py
index 96fbf8c4..cc191dc3 100644
--- a/tests/automated/unit/source_collectors/test_autogoogler_collector.py
+++ b/tests/automated/unit/source_collectors/test_autogoogler_collector.py
@@ -2,17 +2,18 @@
 
 import pytest
 
-from src.collectors.source_collectors.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO
-from src.collectors.source_collectors.auto_googler.dtos.input import AutoGooglerInputDTO
+from src.collectors.impl.auto_googler.dtos.query_results import GoogleSearchQueryResultsInnerDTO
+from src.collectors.impl.auto_googler.dtos.input import AutoGooglerInputDTO
 from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.url.core import URLInfo
 from src.core.logger import AsyncCoreLogger
-from src.collectors.source_collectors.auto_googler.collector import AutoGooglerCollector
+from src.collectors.impl.auto_googler.collector import AutoGooglerCollector
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.info import URLInfo
 
 
 @pytest.fixture
 def patch_get_query_results(monkeypatch):
-    patch_path = "src.collectors.source_collectors.auto_googler.searcher.GoogleSearcher.get_query_results"
+    patch_path = "src.collectors.impl.auto_googler.searcher.GoogleSearcher.get_query_results"
     mock = AsyncMock()
     mock.side_effect = [
         [GoogleSearchQueryResultsInnerDTO(url="https://include.com/1", title="keyword", snippet="snippet 1"),],
@@ -37,6 +38,12 @@
     mock.assert_called_once_with("keyword")
 
     collector.adb_client.insert_urls.assert_called_once_with(
-        url_infos=[URLInfo(url="https://include.com/1", collector_metadata={"query": "keyword", "title": "keyword", "snippet": "snippet 1"})],
+        url_infos=[
+            URLInfo(
+                url="https://include.com/1",
+                collector_metadata={"query": "keyword", "title": "keyword", "snippet": "snippet 1"},
+                source=URLSource.COLLECTOR
+            )
+        ],
         batch_id=1
     )
\ No newline at end of file
diff --git a/tests/automated/unit/source_collectors/test_common_crawl_collector.py b/tests/automated/unit/source_collectors/test_common_crawl_collector.py
index 070f9533..0a10680f 100644
--- a/tests/automated/unit/source_collectors/test_common_crawl_collector.py
+++ b/tests/automated/unit/source_collectors/test_common_crawl_collector.py
@@ -2,16 +2,17 @@
 
 import pytest
 
-from src.collectors.source_collectors.common_crawler.input import CommonCrawlerInputDTO
+from src.collectors.impl.common_crawler.input import CommonCrawlerInputDTO
 from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.url.core import URLInfo
 from src.core.logger import AsyncCoreLogger
-from src.collectors.source_collectors.common_crawler.collector import CommonCrawlerCollector
+from src.collectors.impl.common_crawler.collector import CommonCrawlerCollector
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.info import URLInfo
 
 
 @pytest.fixture
 def mock_get_common_crawl_search_results():
-    mock_path = "src.collectors.source_collectors.common_crawler.crawler.get_common_crawl_search_results"
+    mock_path = "src.collectors.impl.common_crawler.crawler.get_common_crawl_search_results"
     # Results contain other keys, but those are not relevant and thus
     # can be ignored
     mock_results = [
@@ -39,8 +40,8 @@
 
     collector.adb_client.insert_urls.assert_called_once_with(
         url_infos=[
-            URLInfo(url="http://keyword.com"),
-            URLInfo(url="http://keyword.com/page3")
+            URLInfo(url="http://keyword.com", source=URLSource.COLLECTOR),
+            URLInfo(url="http://keyword.com/page3", source=URLSource.COLLECTOR),
         ],
         batch_id=1
     )
diff --git a/tests/automated/unit/source_collectors/test_example_collector.py b/tests/automated/unit/source_collectors/test_example_collector.py
index d9d5b17a..632a6293 100644
--- a/tests/automated/unit/source_collectors/test_example_collector.py
+++ b/tests/automated/unit/source_collectors/test_example_collector.py
@@ -1,8 +1,8 @@
 from unittest.mock import AsyncMock
 
 from src.db.client.sync import DatabaseClient
-from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO
-from src.collectors.source_collectors.example.core import ExampleCollector
+from src.collectors.impl.example.dtos.input import ExampleInputDTO
+from src.collectors.impl.example.core import ExampleCollector
 from src.core.logger import AsyncCoreLogger
 
diff --git a/tests/automated/unit/source_collectors/test_muckrock_collectors.py b/tests/automated/unit/source_collectors/test_muckrock_collectors.py
index b3e9fec1..6c845b8e 100644
--- a/tests/automated/unit/source_collectors/test_muckrock_collectors.py
+++ b/tests/automated/unit/source_collectors/test_muckrock_collectors.py
@@ -3,16 +3,17 @@
 
 import pytest
 
-from src.collectors.source_collectors.muckrock.collectors.county.core import MuckrockCountyLevelSearchCollector
-from src.collectors.source_collectors.muckrock.collectors.simple.core import MuckrockSimpleSearchCollector
+from src.collectors.impl.muckrock.collectors.county.core import MuckrockCountyLevelSearchCollector
+from src.collectors.impl.muckrock.collectors.simple.core import MuckrockSimpleSearchCollector
 from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.url.core import URLInfo
 from src.core.logger import AsyncCoreLogger
-from src.collectors.source_collectors.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO
-from src.collectors.source_collectors.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO
-from src.collectors.source_collectors.muckrock.fetch_requests.foia import FOIAFetchRequest
+from src.collectors.impl.muckrock.collectors.county.dto import MuckrockCountySearchCollectorInputDTO
+from src.collectors.impl.muckrock.collectors.simple.dto import MuckrockSimpleSearchCollectorInputDTO
+from src.collectors.impl.muckrock.fetch_requests.foia import FOIAFetchRequest
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.info import URLInfo
 
-PATCH_ROOT = "src.collectors.source_collectors.muckrock"
+PATCH_ROOT = "src.collectors.impl.muckrock"
 
 @pytest.fixture
 def patch_muckrock_fetcher(monkeypatch):
@@ -55,10 +56,12 @@
             URLInfo(
                 url='https://include.com/1',
                 collector_metadata={'absolute_url': 'https://include.com/1', 'title': 'keyword'},
+                source=URLSource.COLLECTOR
             ),
             URLInfo(
                 url='https://include.com/2',
                 collector_metadata={'absolute_url': 'https://include.com/2', 'title': 'keyword'},
+                source=URLSource.COLLECTOR
             )
         ],
         batch_id=1
@@ -111,14 +114,17 @@
             URLInfo(
                 url='https://include.com/1',
                 collector_metadata={'absolute_url': 'https://include.com/1', 'title': 'keyword'},
+                source=URLSource.COLLECTOR
            ),
             URLInfo(
                 url='https://include.com/2',
                 collector_metadata={'absolute_url': 'https://include.com/2', 'title': 'keyword'},
+                source=URLSource.COLLECTOR
             ),
             URLInfo(
                 url='https://include.com/3',
                 collector_metadata={'absolute_url': 'https://include.com/3', 'title': 'lemon'},
+                source=URLSource.COLLECTOR
             ),
         ],
         batch_id=1
diff --git a/tests/conftest.py b/tests/conftest.py
index ee9a6774..8ba93200 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,20 +1,27 @@
 import logging
-from typing import Any, Generator, AsyncGenerator, Coroutine
+import os
+from contextlib import contextmanager
+from typing import Any, Generator, AsyncGenerator
 
 import pytest
 import pytest_asyncio
+from aiohttp import ClientSession
 from alembic.config import Config
 from sqlalchemy import create_engine, inspect, MetaData
 from sqlalchemy.orm import scoped_session, sessionmaker
 
+from src.core.env_var_manager import EnvVarManager
+# Below are to prevent import errors
+from src.db.models.impl.missing import Missing  # noqa: F401
+from src.db.models.impl.log.sqlalchemy import Log  # noqa: F401
+from src.db.models.impl.task.error import TaskError  # noqa: F401
+from src.db.models.impl.url.checked_for_duplicate import URLCheckedForDuplicate  # noqa: F401
 from src.db.client.async_ import AsyncDatabaseClient
 from src.db.client.sync import DatabaseClient
-from src.db.helpers import get_postgres_connection_string
-from src.db.models.templates import Base
-from src.core.env_var_manager import EnvVarManager
+from src.db.helpers.connect import get_postgres_connection_string
 from src.util.helper_functions import load_from_environment
 from tests.helpers.alembic_runner import AlembicRunner
-from tests.helpers.db_data_creator import DBDataCreator
+from tests.helpers.data_creator.core import DBDataCreator
 from tests.helpers.setup.populate import populate_database
 from tests.helpers.setup.wipe import wipe_database
 
@@ -43,7 +50,9 @@ def setup_and_teardown():
         "PDAP_API_URL",
         "DISCORD_WEBHOOK_URL",
         "OPENAI_API_KEY",
-        "HUGGINGFACE_INFERENCE_API_KEY"
+        "HUGGINGFACE_INFERENCE_API_KEY",
+        "HUGGINGFACE_HUB_TOKEN",
+        "INTERNET_ARCHIVE_S3_KEYS",
     ]
     all_env_vars = required_env_vars.copy()
     for env_var in test_env_vars:
@@ -51,41 +60,42 @@ def setup_and_teardown():
 
     EnvVarManager.override(all_env_vars)
 
-    conn = get_postgres_connection_string()
-    engine = create_engine(conn)
-    alembic_cfg = Config("alembic.ini")
-    alembic_cfg.attributes["connection"] = engine.connect()
-    alembic_cfg.set_main_option(
-        "sqlalchemy.url",
-        get_postgres_connection_string()
-    )
-    live_connection = engine.connect()
-    runner = AlembicRunner(
-        alembic_config=alembic_cfg,
-        inspector=inspect(live_connection),
-        metadata=MetaData(),
-        connection=live_connection,
-        session=scoped_session(sessionmaker(bind=live_connection)),
-    )
-    try:
-        runner.upgrade("head")
-    except Exception as e:
print("Exception while upgrading: ", e) - print("Resetting schema") - runner.reset_schema() - runner.stamp("base") - runner.upgrade("head") + with set_env_vars( + { + "INTERNET_ARCHIVE_S3_KEYS": "TEST", + } + ): + conn = get_postgres_connection_string() + engine = create_engine(conn) + alembic_cfg = Config("alembic.ini") + alembic_cfg.attributes["connection"] = engine.connect() + alembic_cfg.set_main_option( + "sqlalchemy.url", + get_postgres_connection_string() + ) + live_connection = engine.connect() + runner = AlembicRunner( + alembic_config=alembic_cfg, + inspector=inspect(live_connection), + metadata=MetaData(), + connection=live_connection, + session=scoped_session(sessionmaker(bind=live_connection)), + ) + try: + runner.upgrade("head") + except Exception as e: + print("Exception while upgrading: ", e) + print("Resetting schema") + runner.reset_schema() + runner.stamp("base") + runner.upgrade("head") + + + yield - yield - try: - runner.downgrade("base") - except Exception as e: - print("Exception while downgrading: ", e) - print("Resetting schema") runner.reset_schema() runner.stamp("base") - finally: live_connection.close() engine.dispose() @@ -123,3 +133,36 @@ def db_data_creator( ): db_data_creator = DBDataCreator(db_client=db_client_test) yield db_data_creator + +@pytest_asyncio.fixture +async def test_client_session() -> AsyncGenerator[ClientSession, Any]: + async with ClientSession() as session: + yield session + + + +@contextmanager +def set_env_vars(env_vars: dict[str, str]): + """Temporarily set multiple environment variables, restoring afterwards.""" + originals = {} + try: + # Save originals and set new values + for key, value in env_vars.items(): + originals[key] = os.environ.get(key) + os.environ[key] = value + yield + finally: + # Restore originals + for key, original in originals.items(): + if original is None: + os.environ.pop(key, None) + else: + os.environ[key] = original + +@pytest.fixture(scope="session") +def disable_task_flags(): + with set_env_vars({ + "SCHEDULED_TASKS_FLAG": "0", + "RUN_URL_TASKS_TASK_FLAG": "0", + }): + yield \ No newline at end of file diff --git a/tests/helpers/alembic_runner.py b/tests/helpers/alembic_runner.py index 53458109..dd1807ba 100644 --- a/tests/helpers/alembic_runner.py +++ b/tests/helpers/alembic_runner.py @@ -23,9 +23,6 @@ def upgrade(self, revision: str): command.upgrade(self.alembic_config, revision) self.reflect() - def downgrade(self, revision: str): - command.downgrade(self.alembic_config, revision) - def stamp(self, revision: str): command.stamp(self.alembic_config, revision) diff --git a/tests/helpers/api_test_helper.py b/tests/helpers/api_test_helper.py index 55a85345..2ff51f98 100644 --- a/tests/helpers/api_test_helper.py +++ b/tests/helpers/api_test_helper.py @@ -5,7 +5,7 @@ from src.core.core import AsyncCore from src.core.enums import BatchStatus from tests.automated.integration.api._helpers.RequestValidator import RequestValidator -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.data_creator.core import DBDataCreator @dataclass diff --git a/tests/helpers/batch_creation_parameters/annotation_info.py b/tests/helpers/batch_creation_parameters/annotation_info.py index f9c9ef2d..cef99f43 100644 --- a/tests/helpers/batch_creation_parameters/annotation_info.py +++ b/tests/helpers/batch_creation_parameters/annotation_info.py @@ -3,11 +3,12 @@ from pydantic import BaseModel from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.core.enums import 
SuggestedStatus, RecordType +from src.core.enums import RecordType +from src.db.models.impl.flag.url_validated.enums import URLType class AnnotationInfo(BaseModel): - user_relevant: Optional[SuggestedStatus] = None + user_relevant: Optional[URLType] = None auto_relevant: Optional[bool] = None user_record_type: Optional[RecordType] = None auto_record_type: Optional[RecordType] = None diff --git a/tests/helpers/batch_creation_parameters/core.py b/tests/helpers/batch_creation_parameters/core.py index dfc33644..4562cbdf 100644 --- a/tests/helpers/batch_creation_parameters/core.py +++ b/tests/helpers/batch_creation_parameters/core.py @@ -9,10 +9,10 @@ class TestBatchCreationParameters(BaseModel): - created_at: Optional[datetime.datetime] = None + created_at: datetime.datetime | None = None outcome: BatchStatus = BatchStatus.READY_TO_LABEL strategy: CollectorType = CollectorType.EXAMPLE - urls: Optional[list[TestURLCreationParameters]] = None + urls: list[TestURLCreationParameters] | None = None @model_validator(mode='after') def validate_urls(self): diff --git a/tests/helpers/batch_creation_parameters/enums.py b/tests/helpers/batch_creation_parameters/enums.py new file mode 100644 index 00000000..d61a2793 --- /dev/null +++ b/tests/helpers/batch_creation_parameters/enums.py @@ -0,0 +1,11 @@ +from enum import Enum + + +class URLCreationEnum(Enum): + OK = "ok" + SUBMITTED = "submitted" + VALIDATED = "validated" + ERROR = "error" + NOT_RELEVANT = "not_relevant" + DUPLICATE = "duplicate" + NOT_FOUND = "not_found" \ No newline at end of file diff --git a/tests/helpers/batch_creation_parameters/url_creation_parameters.py b/tests/helpers/batch_creation_parameters/url_creation_parameters.py index 2e30cca0..701a239b 100644 --- a/tests/helpers/batch_creation_parameters/url_creation_parameters.py +++ b/tests/helpers/batch_creation_parameters/url_creation_parameters.py @@ -1,23 +1,26 @@ from pydantic import BaseModel, model_validator from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo -from src.collectors.enums import URLStatus from src.core.enums import RecordType from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum class TestURLCreationParameters(BaseModel): count: int = 1 - status: URLStatus = URLStatus.PENDING + status: URLCreationEnum = URLCreationEnum.OK with_html_content: bool = False annotation_info: AnnotationInfo = AnnotationInfo() @model_validator(mode='after') def validate_annotation_info(self): - if self.status == URLStatus.NOT_RELEVANT: + if self.status == URLCreationEnum.NOT_RELEVANT: self.annotation_info.final_review_approved = False return self - if self.status != URLStatus.VALIDATED: + if self.status not in ( + URLCreationEnum.SUBMITTED, + URLCreationEnum.VALIDATED + ): return self # Assume is validated diff --git a/tests/helpers/counter.py b/tests/helpers/counter.py new file mode 100644 index 00000000..8d9de1a0 --- /dev/null +++ b/tests/helpers/counter.py @@ -0,0 +1,7 @@ + +from itertools import count + +COUNTER = count(1) + +def next_int() -> int: + return next(COUNTER) \ No newline at end of file diff --git a/tests/helpers/data_creator/__init__.py b/tests/helpers/data_creator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/__init__.py b/tests/helpers/data_creator/commands/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/helpers/data_creator/commands/base.py 
diff --git a/tests/helpers/data_creator/commands/base.py b/tests/helpers/data_creator/commands/base.py
new file mode 100644
index 00000000..84e77621
--- /dev/null
+++ b/tests/helpers/data_creator/commands/base.py
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.client.sync import DatabaseClient
+from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer
+
+
+class DBDataCreatorCommandBase(ABC):
+
+    def __init__(self,):
+        self._clients: DBDataCreatorClientContainer | None = None
+
+    def load_clients(self, clients: DBDataCreatorClientContainer):
+        self._clients = clients
+
+    @property
+    def clients(self) -> DBDataCreatorClientContainer:
+        if self._clients is None:
+            raise Exception("Clients not loaded")
+        return self._clients
+
+    @property
+    def db_client(self) -> DatabaseClient:
+        return self.clients.db
+
+    @property
+    def adb_client(self) -> AsyncDatabaseClient:
+        return self.clients.adb
+
+    def run_command_sync(self, command: "DBDataCreatorCommandBase"):
+        command.load_clients(self._clients)
+        return command.run_sync()
+
+    async def run_command(self, command: "DBDataCreatorCommandBase"):
+        command.load_clients(self._clients)
+        return await command.run()
+
+    @abstractmethod
+    async def run(self):
+        raise NotImplementedError
+
+    async def run_sync(self):
+        raise NotImplementedError
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/__init__.py b/tests/helpers/data_creator/commands/impl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/agency.py b/tests/helpers/data_creator/commands/impl/agency.py
new file mode 100644
index 00000000..0bf04ce6
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/agency.py
@@ -0,0 +1,40 @@
+from random import randint
+from typing import final
+
+from typing_extensions import override
+
+from src.core.enums import SuggestionType
+from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.simple_test_data_functions import generate_test_name
+
+
+@final
+class AgencyCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        name: str | None = None
+    ):
+        super().__init__()
+        if name is None:
+            name = generate_test_name()
+        self.name = name
+
+    @override
+    async def run(self) -> int:
+        agency_id = randint(1, 99999999)
+        await self.adb_client.upsert_new_agencies(
+            suggestions=[
+                URLAgencySuggestionInfo(
+                    url_id=-1,
+                    suggestion_type=SuggestionType.UNKNOWN,
+                    pdap_agency_id=agency_id,
+                    agency_name=self.name,
+                    state=f"Test State {agency_id}",
+                    county=f"Test County {agency_id}",
+                    locality=f"Test Locality {agency_id}"
+                )
+            ]
+        )
+        return agency_id
diff --git a/tests/helpers/data_creator/commands/impl/annotate.py b/tests/helpers/data_creator/commands/impl/annotate.py
new file mode 100644
index 00000000..1f549615
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/annotate.py
@@ -0,0 +1,102 @@
+from typing import final
+
+from typing_extensions import override
+
+from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo
+from src.api.endpoints.review.enums import RejectionReason
+from src.core.enums import SuggestionType
+from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.suggestion.auto.agency_.core import AgencyAutoSuggestionsCommand
+from tests.helpers.data_creator.commands.impl.suggestion.auto.record_type import AutoRecordTypeSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.auto.relevant import AutoRelevantSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.agency import AgencyUserSuggestionsCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.record_type import UserRecordTypeSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.relevant import UserRelevantSuggestionCommand
+
+
+@final
+class AnnotateCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        annotation_info: AnnotationInfo
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.annotation_info = annotation_info
+
+    @override
+    async def run(self) -> None:
+        info = self.annotation_info
+        if info.user_relevant is not None:
+            await self.run_command(
+                UserRelevantSuggestionCommand(
+                    url_id=self.url_id,
+                    suggested_status=info.user_relevant
+                )
+            )
+        if info.auto_relevant is not None:
+            await self.run_command(
+                AutoRelevantSuggestionCommand(
+                    url_id=self.url_id,
+                    relevant=info.auto_relevant
+                )
+            )
+        if info.user_record_type is not None:
+            await self.run_command(
+                UserRecordTypeSuggestionCommand(
+                    url_id=self.url_id,
+                    record_type=info.user_record_type,
+                )
+            )
+        if info.auto_record_type is not None:
+            await self.run_command(
+                AutoRecordTypeSuggestionCommand(
+                    url_id=self.url_id,
+                    record_type=info.auto_record_type
+                )
+            )
+        if info.user_agency is not None:
+            await self.run_command(
+                AgencyUserSuggestionsCommand(
+                    url_id=self.url_id,
+                    agency_annotation_info=info.user_agency
+                )
+            )
+        if info.auto_agency is not None:
+            await self.run_command(
+                AgencyAutoSuggestionsCommand(
+                    url_id=self.url_id,
+                    count=1,
+                    suggestion_type=SuggestionType.AUTO_SUGGESTION
+                )
+            )
+        if info.confirmed_agency is not None:
+            await self.run_command(
+                AgencyAutoSuggestionsCommand(
+                    url_id=self.url_id,
+                    count=1,
+                    suggestion_type=SuggestionType.CONFIRMED
+                )
+            )
+        if info.final_review_approved is not None:
+            if info.final_review_approved:
+                final_review_approval_info = FinalReviewApprovalInfo(
+                    url_id=self.url_id,
+                    record_type=self.annotation_info.user_record_type,
+                    agency_ids=[self.annotation_info.user_agency.suggested_agency]
+                    if self.annotation_info.user_agency is not None else None,
+                    description="Test Description",
+                )
+                await self.adb_client.approve_url(
+                    approval_info=final_review_approval_info,
+                    user_id=1
+                )
+            else:
+                await self.adb_client.reject_url(
+                    url_id=self.url_id,
+                    user_id=1,
+                    rejection_reason=RejectionReason.NOT_RELEVANT
+                )
diff --git a/tests/helpers/data_creator/commands/impl/batch.py b/tests/helpers/data_creator/commands/impl/batch.py
new file mode 100644
index 00000000..6871661d
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/batch.py
@@ -0,0 +1,35 @@
+from datetime import datetime
+from typing import Optional
+
+from src.collectors.enums import CollectorType
+from src.core.enums import BatchStatus
+from src.db.models.impl.batch.pydantic.info import BatchInfo
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+class DBDataCreatorBatchCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        strategy: CollectorType = CollectorType.EXAMPLE,
+        batch_status: BatchStatus = BatchStatus.IN_PROCESS,
+        created_at: Optional[datetime] = None
+    ):
+        super().__init__()
+        self.strategy = strategy
+        self.batch_status = batch_status
+        self.created_at = created_at
+
+    async def run(self) -> int:
+        raise NotImplementedError
+
+    def run_sync(self) -> int:
+        return self.db_client.insert_batch(
+            BatchInfo(
+                strategy=self.strategy.value,
+                status=self.batch_status,
+                parameters={"test_key": "test_value"},
+                user_id=1,
+                date_generated=self.created_at
+            )
+        )
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/batch_v2.py b/tests/helpers/data_creator/commands/impl/batch_v2.py
new file mode 100644
index 00000000..524416da
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/batch_v2.py
@@ -0,0 +1,43 @@
+from src.core.enums import BatchStatus
+from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.batch import DBDataCreatorBatchCommand
+from tests.helpers.data_creator.commands.impl.urls_v2.core import URLsV2Command
+from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2
+
+
+class BatchV2Command(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        parameters: TestBatchCreationParameters
+    ):
+        super().__init__()
+        self.parameters = parameters
+
+    async def run(self) -> BatchURLCreationInfoV2:
+        # Create batch
+        command = DBDataCreatorBatchCommand(
+            strategy=self.parameters.strategy,
+            batch_status=self.parameters.outcome,
+            created_at=self.parameters.created_at
+        )
+        batch_id = self.run_command_sync(command)
+        # Return early if batch would not involve URL creation
+        if self.parameters.outcome in (BatchStatus.ERROR, BatchStatus.ABORTED):
+            return BatchURLCreationInfoV2(
+                batch_id=batch_id,
+            )
+
+        response = await self.run_command(
+            URLsV2Command(
+                parameters=self.parameters.urls,
+                batch_id=batch_id,
+                created_at=self.parameters.created_at
+            )
+        )
+
+        return BatchURLCreationInfoV2(
+            batch_id=batch_id,
+            urls_by_status=response.urls_by_status,
+        )
diff --git a/tests/helpers/data_creator/commands/impl/html_data.py b/tests/helpers/data_creator/commands/impl/html_data.py
new file mode 100644
index 00000000..c548eb5a
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/html_data.py
@@ -0,0 +1,51 @@
+from src.db.dtos.url.html_content import URLHTMLContentInfo
+from src.db.models.impl.url.html.content.enums import HTMLContentType
+from src.db.dtos.url.raw_html import RawHTMLInfo
+from src.db.models.impl.url.scrape_info.enums import ScrapeStatus
+from src.db.models.impl.url.scrape_info.pydantic import URLScrapeInfoInsertModel
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer
+
+
+class HTMLDataCreatorCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_ids: list[int]
+    ):
+        super().__init__()
+        self.url_ids = url_ids
+
+    async def run(self) -> None:
+        html_content_infos = []
+        raw_html_info_list = []
+        scraper_info_list = []
+        for url_id in self.url_ids:
+            html_content_infos.append(
+                URLHTMLContentInfo(
+                    url_id=url_id,
+                    content_type=HTMLContentType.TITLE,
+                    content="test html content"
+                )
+            )
+            html_content_infos.append(
+                URLHTMLContentInfo(
+                    url_id=url_id,
+                    content_type=HTMLContentType.DESCRIPTION,
+                    content="test description"
+                )
+            )
+            raw_html_info = RawHTMLInfo(
+                url_id=url_id,
+                html=""
+            )
+            raw_html_info_list.append(raw_html_info)
+            scraper_info = URLScrapeInfoInsertModel(
+                url_id=url_id,
+                status=ScrapeStatus.SUCCESS,
+            )
+            scraper_info_list.append(scraper_info)
+
+        await self.adb_client.add_raw_html(raw_html_info_list)
+        await self.adb_client.add_html_content_infos(html_content_infos)
+
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py b/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py
new file mode 100644
index 00000000..e096d15e
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/agency_confirmed.py
@@ -0,0 +1,29 @@
+from typing import final
+
+from typing_extensions import override
+
+from src.core.enums import SuggestionType
+from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.agency import AgencyCommand
+
+@final
+class AgencyConfirmedSuggestionCommand(DBDataCreatorCommandBase):
+
+    def __init__(self, url_id: int):
+        super().__init__()
+        self.url_id = url_id
+
+    @override
+    async def run(self) -> int:
+        agency_id = await self.run_command(AgencyCommand())
+        await self.adb_client.add_confirmed_agency_url_links(
+            suggestions=[
+                URLAgencySuggestionInfo(
+                    url_id=self.url_id,
+                    suggestion_type=SuggestionType.CONFIRMED,
+                    pdap_agency_id=agency_id
+                )
+            ]
+        )
+        return agency_id
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/agency_/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/agency_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/agency_/core.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/agency_/core.py
new file mode 100644
index 00000000..fe54c6f9
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/agency_/core.py
@@ -0,0 +1,78 @@
+from typing import final
+
+from typing_extensions import override
+
+from src.core.enums import SuggestionType
+from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
+from src.db.enums import TaskType
+from src.db.models.impl.url.suggestion.agency.subtask.enum import AutoAgencyIDSubtaskType
+from src.db.models.impl.url.suggestion.agency.subtask.pydantic import URLAutoAgencyIDSubtaskPydantic
+from src.db.models.impl.url.suggestion.agency.suggestion.pydantic import AgencyIDSubtaskSuggestionPydantic
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.agency import AgencyCommand
+
+@final
+class AgencyAutoSuggestionsCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        count: int,
+        suggestion_type: SuggestionType = SuggestionType.AUTO_SUGGESTION,
+        subtask_type: AutoAgencyIDSubtaskType = AutoAgencyIDSubtaskType.HOMEPAGE_MATCH,
+        confidence: int = 50
+    ):
+        super().__init__()
+        if suggestion_type == SuggestionType.UNKNOWN:
+            count = 1 # Can only be one auto suggestion if unknown
+            agencies_found = False
+        else:
+            agencies_found = True
+        self.url_id = url_id
+        self.count = count
+        self.suggestion_type = suggestion_type
+        self.subtask_type = subtask_type
+        self.confidence = confidence
+        self.agencies_found = agencies_found
+
+    @override
+    async def run(self) -> None:
+        task_id: int = await self.add_task()
+        subtask_id: int = await self.create_subtask(task_id)
+        if not self.agencies_found:
+            return
+
+        suggestions: list[AgencyIDSubtaskSuggestionPydantic] = []
+        for _ in range(self.count):
+            pdap_agency_id: int = await self.run_command(AgencyCommand())
+
+            suggestion = AgencyIDSubtaskSuggestionPydantic(
+                subtask_id=subtask_id,
+                agency_id=pdap_agency_id,
+                confidence=self.confidence,
+            )
+            suggestions.append(suggestion)
+
+        await self.adb_client.bulk_insert(
+            models=suggestions,
+        )
+
+    async def add_task(self) -> int:
+        task_id: int = await self.adb_client.initiate_task(
+            task_type=TaskType.AGENCY_IDENTIFICATION,
+        )
+        return task_id
+
+    async def create_subtask(self, task_id: int) -> int:
+        obj: URLAutoAgencyIDSubtaskPydantic = URLAutoAgencyIDSubtaskPydantic(
+            task_id=task_id,
+            type=self.subtask_type,
+            url_id=self.url_id,
+            agencies_found=self.agencies_found,
+        )
+        subtask_id: int = (await self.adb_client.bulk_insert(
+            models=[obj],
+            return_ids=True
+        ))[0]
+        return subtask_id
+
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py
new file mode 100644
index 00000000..25ad6e53
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/record_type.py
@@ -0,0 +1,20 @@
+from src.core.enums import RecordType
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+class AutoRecordTypeSuggestionCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        record_type: RecordType
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.record_type = record_type
+
+    async def run(self) -> None:
+        await self.adb_client.add_auto_record_type_suggestion(
+            url_id=self.url_id,
+            record_type=self.record_type
+        )
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py b/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py
new file mode 100644
index 00000000..2e31491d
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/auto/relevant.py
@@ -0,0 +1,24 @@
+from src.db.models.impl.url.suggestion.relevant.auto.pydantic.input import AutoRelevancyAnnotationInput
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+class AutoRelevantSuggestionCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        relevant: bool = True
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.relevant = relevant
+
+    async def run(self) -> None:
+        await self.adb_client.add_auto_relevant_suggestion(
+            input_=AutoRelevancyAnnotationInput(
+                url_id=self.url_id,
+                is_relevant=self.relevant,
+                confidence=0.5,
+                model_name="test_model"
+            )
+        )
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/__init__.py b/tests/helpers/data_creator/commands/impl/suggestion/user/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py b/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py
new file mode 100644
index 00000000..35418679
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/user/agency.py
@@ -0,0 +1,37 @@
+from random import randint
+from typing import final
+
+from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.agency import AgencyCommand
+
+
+@final
+class AgencyUserSuggestionsCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        user_id: int | None = None,
+        agency_annotation_info: URLAgencyAnnotationPostInfo | None = None
+    ):
+        super().__init__()
+        if user_id is None:
+            user_id = randint(1, 99999999)
+        self.url_id = url_id
+        self.user_id = user_id
+        self.agency_annotation_info = agency_annotation_info
+
+    async def run(self) -> None:
+        if self.agency_annotation_info is None:
+            agency_annotation_info = URLAgencyAnnotationPostInfo(
+                suggested_agency=await self.run_command(AgencyCommand())
+            )
+        else:
+            agency_annotation_info = self.agency_annotation_info
+        await self.adb_client.add_agency_manual_suggestion(
+            agency_id=agency_annotation_info.suggested_agency,
+            url_id=self.url_id,
+            user_id=self.user_id,
+            is_new=agency_annotation_info.is_new
+        )
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py b/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py
new file mode 100644
index 00000000..03c7ab0b
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/user/record_type.py
@@ -0,0 +1,25 @@
+from random import randint
+
+from src.core.enums import RecordType
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+class UserRecordTypeSuggestionCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        record_type: RecordType,
+        user_id: int | None = None,
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.user_id = user_id if user_id is not None else randint(1, 99999999)
+        self.record_type = record_type
+
+    async def run(self) -> None:
+        await self.adb_client.add_user_record_type_suggestion(
+            url_id=self.url_id,
+            user_id=self.user_id,
+            record_type=self.record_type
+        )
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py b/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py
new file mode 100644
index 00000000..0dfd5a3f
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/suggestion/user/relevant.py
@@ -0,0 +1,30 @@
+from random import randint
+from typing import final
+
+from typing_extensions import override
+
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+@final
+class UserRelevantSuggestionCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_id: int,
+        user_id: int | None = None,
+        suggested_status: URLType = URLType.DATA_SOURCE
+    ):
+        super().__init__()
+        self.url_id = url_id
+        self.user_id = user_id if user_id is not None else randint(1, 99999999)
+        self.suggested_status = suggested_status
+
+    @override
+    async def run(self) -> None:
+        await self.adb_client.add_user_relevant_suggestion(
+            url_id=self.url_id,
+            user_id=self.user_id,
+            suggested_status=self.suggested_status
+        )
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/url_metadata.py b/tests/helpers/data_creator/commands/impl/url_metadata.py
new file mode 100644
index 00000000..161d5631
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/url_metadata.py
@@ -0,0 +1,31 @@
+from http import HTTPStatus
+
+from src.db.models.impl.url.web_metadata.insert import URLWebMetadataPydantic
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+
+
+class URLMetadataCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        url_ids: list[int],
+        content_type: str = "text/html",
+        status_code: int = HTTPStatus.OK.value
+    ):
+        super().__init__()
+        self.url_ids = url_ids
+        self.content_type = content_type
+        self.status_code = status_code
+
+    async def run(self) -> None:
+        url_metadata_infos = []
+        for url_id in self.url_ids:
+            url_metadata = URLWebMetadataPydantic(
+                url_id=url_id,
+                accessed=True,
+                status_code=self.status_code,
+                content_type=self.content_type,
+                error_message=None
+            )
+            url_metadata_infos.append(url_metadata)
+        await self.adb_client.bulk_insert(url_metadata_infos)
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/urls_/__init__.py b/tests/helpers/data_creator/commands/impl/urls_/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/urls_/convert.py b/tests/helpers/data_creator/commands/impl/urls_/convert.py
new file mode 100644
index 00000000..66747e6c
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/urls_/convert.py
@@ -0,0 +1,34 @@
+from src.collectors.enums import URLStatus
+from src.db.models.impl.flag.url_validated.enums import URLType
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+
+
+def convert_url_creation_enum_to_url_status(url_creation_enum: URLCreationEnum) -> URLStatus:
+    match url_creation_enum:
+        case URLCreationEnum.OK:
+            return URLStatus.OK
+        case URLCreationEnum.SUBMITTED:
+            return URLStatus.OK
+        case URLCreationEnum.VALIDATED:
+            return URLStatus.OK
+        case URLCreationEnum.NOT_RELEVANT:
+            return URLStatus.OK
+        case URLCreationEnum.ERROR:
+            return URLStatus.ERROR
+        case URLCreationEnum.DUPLICATE:
+            return URLStatus.DUPLICATE
+        case _:
+            raise ValueError(f"Unknown URLCreationEnum: {url_creation_enum}")
+
+def convert_url_creation_enum_to_validated_type(
+    url_creation_enum: URLCreationEnum
+) -> URLType:
+    match url_creation_enum:
+        case URLCreationEnum.SUBMITTED:
+            return URLType.DATA_SOURCE
+        case URLCreationEnum.VALIDATED:
+            return URLType.DATA_SOURCE
+        case URLCreationEnum.NOT_RELEVANT:
+            return URLType.NOT_RELEVANT
+        case _:
+            raise ValueError(f"Unknown URLCreationEnum: {url_creation_enum}")
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/urls_/query.py b/tests/helpers/data_creator/commands/impl/urls_/query.py
new file mode 100644
index 00000000..7587abfb
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/urls_/query.py
@@ -0,0 +1,70 @@
+from datetime import datetime
+
+from src.core.tasks.url.operators.submit_approved.tdo import SubmittedURLInfo
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.info import URLInfo
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.urls_.convert import convert_url_creation_enum_to_url_status
+from tests.helpers.simple_test_data_functions import generate_test_urls
+
+
+class URLsDBDataCreatorCommand(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        batch_id: int | None,
+        url_count: int,
+        collector_metadata: dict | None = None,
+        status: URLCreationEnum = URLCreationEnum.OK,
+        created_at: datetime | None = None
+    ):
+        super().__init__()
+        self.batch_id = batch_id
+        self.url_count = url_count
+        self.collector_metadata = collector_metadata
+        self.status = status
+        self.created_at = created_at
+
+    async def run(self) -> InsertURLsInfo:
+        raise NotImplementedError
+
+    def run_sync(self) -> InsertURLsInfo:
+        raw_urls = generate_test_urls(self.url_count)
+        url_infos: list[URLInfo] = []
+        for url in raw_urls:
+            url_infos.append(
+                URLInfo(
+                    url=url,
+                    status=convert_url_creation_enum_to_url_status(self.status),
+                    name="Test Name" if self.status in (
+                        URLCreationEnum.VALIDATED,
+                        URLCreationEnum.SUBMITTED,
+                    ) else None,
+                    collector_metadata=self.collector_metadata,
+                    created_at=self.created_at,
+                    source=URLSource.COLLECTOR
+                )
+            )
+
+        url_insert_info = self.db_client.insert_urls(
+            url_infos=url_infos,
+            batch_id=self.batch_id,
+        )
+
+        # If outcome is submitted, also add entry to DataSourceURL
+        if self.status == URLCreationEnum.SUBMITTED:
+            submitted_url_infos = []
+            for url_id in url_insert_info.url_ids:
+                submitted_url_info = SubmittedURLInfo(
+                    url_id=url_id,
+                    data_source_id=url_id, # Use same ID for convenience,
+                    request_error=None,
+                    submitted_at=self.created_at
+                )
+                submitted_url_infos.append(submitted_url_info)
+            self.db_client.mark_urls_as_submitted(submitted_url_infos)
+
+
+        return url_insert_info
\ No newline at end of file
diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/__init__.py b/tests/helpers/data_creator/commands/impl/urls_v2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/core.py b/tests/helpers/data_creator/commands/impl/urls_v2/core.py
new file mode 100644
index 00000000..f7042720
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/urls_v2/core.py
@@ -0,0 +1,68 @@
+from datetime import datetime
+
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.annotate import AnnotateCommand
+from tests.helpers.data_creator.commands.impl.html_data import HTMLDataCreatorCommand
+from tests.helpers.data_creator.commands.impl.urls_.convert import convert_url_creation_enum_to_validated_type
+from tests.helpers.data_creator.commands.impl.urls_.query import URLsDBDataCreatorCommand
+from tests.helpers.data_creator.commands.impl.urls_v2.response import URLsV2Response
+from tests.helpers.data_creator.generate import generate_validated_flags
+from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo
+
+
+class URLsV2Command(DBDataCreatorCommandBase):
+
+    def __init__(
+        self,
+        parameters: list[TestURLCreationParameters],
+        batch_id: int | None = None,
+        created_at: datetime | None = None
+    ):
+        super().__init__()
+        self.parameters = parameters
+        self.batch_id = batch_id
+        self.created_at = created_at
+
+    async def run(self) -> URLsV2Response:
+        urls_by_status: dict[URLCreationEnum, URLCreationInfo] = {}
+        urls_by_order: list[URLCreationInfo] = []
+        # Create urls
+        for url_parameters in self.parameters:
+            command = URLsDBDataCreatorCommand(
+                batch_id=self.batch_id,
+                url_count=url_parameters.count,
+                status=url_parameters.status,
+                created_at=self.created_at
+            )
+            iui: InsertURLsInfo = self.run_command_sync(command)
+            url_ids = [iui.url_id for iui in iui.url_mappings]
+            if url_parameters.with_html_content:
+                command = HTMLDataCreatorCommand(
+                    url_ids=url_ids
+                )
+                await self.run_command(command)
+            if url_parameters.annotation_info.has_annotations():
+                for url_id in url_ids:
+                    await self.run_command(
+                        AnnotateCommand(
+                            url_id=url_id,
+                            annotation_info=url_parameters.annotation_info
+                        )
+                    )
+
+            creation_info = URLCreationInfo(
+                url_mappings=iui.url_mappings,
+                outcome=url_parameters.status,
+                annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None
+            )
+            urls_by_order.append(creation_info)
+            urls_by_status[url_parameters.status] = creation_info
+
+        return URLsV2Response(
+            urls_by_status=urls_by_status,
+            urls_by_order=urls_by_order
+        )
diff --git a/tests/helpers/data_creator/commands/impl/urls_v2/response.py b/tests/helpers/data_creator/commands/impl/urls_v2/response.py
new file mode 100644
index 00000000..74aa8e20
--- /dev/null
+++ b/tests/helpers/data_creator/commands/impl/urls_v2/response.py
@@ -0,0 +1,10 @@
+from pydantic import BaseModel
+
+from src.collectors.enums import URLStatus
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo
+
+
+class URLsV2Response(BaseModel):
+    urls_by_status: dict[URLCreationEnum, URLCreationInfo] = {}
+    urls_by_order: list[URLCreationInfo] = []
\ No newline at end of file
diff --git a/tests/helpers/data_creator/core.py b/tests/helpers/data_creator/core.py
new file mode 100644
index 00000000..cbeb207f
--- /dev/null
+++ b/tests/helpers/data_creator/core.py
@@ -0,0 +1,725 @@
+from datetime import datetime
+from http import HTTPStatus
+from typing import Optional, Any
+
+from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo
+from src.collectors.enums import CollectorType, URLStatus
+from src.core.enums import BatchStatus, SuggestionType, RecordType
+from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
+from src.core.tasks.url.operators.misc_metadata.tdo import URLMiscellaneousMetadataTDO
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.client.sync import DatabaseClient
+from src.db.dtos.url.insert import InsertURLsInfo
+from src.db.dtos.url.mapping import URLMapping
+from src.db.enums import TaskType
+from src.db.models.impl.agency.enums import AgencyType
+from src.db.models.impl.agency.sqlalchemy import Agency
+from src.db.models.impl.duplicate.pydantic.insert import DuplicateInsertInfo
+from src.db.models.impl.flag.root_url.sqlalchemy import FlagRootURL
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.link.agency_location.sqlalchemy import LinkAgencyLocation
+from src.db.models.impl.link.url_agency.sqlalchemy import LinkURLAgency
+from src.db.models.impl.link.urls_root_url.sqlalchemy import LinkURLRootURL
+from src.db.models.impl.link.user_name_suggestion.sqlalchemy import LinkUserNameSuggestion
+from src.db.models.impl.link.user_suggestion_not_found.agency.sqlalchemy import LinkUserSuggestionAgencyNotFound
+from src.db.models.impl.link.user_suggestion_not_found.location.sqlalchemy import LinkUserSuggestionLocationNotFound
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
+from src.db.models.impl.url.suggestion.location.auto.subtask.enums import LocationIDSubtaskType
+from src.db.models.impl.url.suggestion.location.auto.subtask.sqlalchemy import AutoLocationIDSubtask
+from src.db.models.impl.url.suggestion.location.auto.suggestion.sqlalchemy import LocationIDSubtaskSuggestion
+from src.db.models.impl.url.suggestion.location.user.sqlalchemy import UserLocationSuggestion
+from src.db.models.impl.url.suggestion.name.enums import NameSuggestionSource
+from src.db.models.impl.url.suggestion.name.sqlalchemy import URLNameSuggestion
+from src.db.models.impl.url.task_error.pydantic_.insert import URLTaskErrorPydantic
+from src.db.models.impl.url.web_metadata.sqlalchemy import URLWebMetadata
+from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.batch_creation_parameters.url_creation_parameters import TestURLCreationParameters
+from tests.helpers.counter import next_int
+from tests.helpers.data_creator.commands.base import DBDataCreatorCommandBase
+from tests.helpers.data_creator.commands.impl.agency import AgencyCommand
+from tests.helpers.data_creator.commands.impl.batch import DBDataCreatorBatchCommand
+from tests.helpers.data_creator.commands.impl.batch_v2 import BatchV2Command
+from tests.helpers.data_creator.commands.impl.html_data import HTMLDataCreatorCommand
+from tests.helpers.data_creator.commands.impl.suggestion.agency_confirmed import AgencyConfirmedSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.auto.agency_.core import AgencyAutoSuggestionsCommand
+from tests.helpers.data_creator.commands.impl.suggestion.auto.record_type import AutoRecordTypeSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.auto.relevant import AutoRelevantSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.agency import AgencyUserSuggestionsCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.record_type import UserRecordTypeSuggestionCommand
+from tests.helpers.data_creator.commands.impl.suggestion.user.relevant import UserRelevantSuggestionCommand
+from tests.helpers.data_creator.commands.impl.url_metadata import URLMetadataCommand
+from tests.helpers.data_creator.commands.impl.urls_.query import URLsDBDataCreatorCommand
+from tests.helpers.data_creator.commands.impl.urls_v2.core import URLsV2Command
+from tests.helpers.data_creator.commands.impl.urls_v2.response import URLsV2Response
+from tests.helpers.data_creator.create import create_urls, create_batch, create_batch_url_links, create_validated_flags, \
+    create_url_data_sources, create_state, create_county, create_locality
+from tests.helpers.data_creator.models.clients import DBDataCreatorClientContainer
+from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo
+from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2
+from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo
+from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo
+from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo
+from tests.helpers.simple_test_data_functions import generate_test_name
+
+
+class DBDataCreator:
+    """
+    Assists in the creation of test data
+    """
+    def __init__(self, db_client: Optional[DatabaseClient] = None):
+        if db_client is not None:
+            self.db_client = db_client
+        else:
+            self.db_client = DatabaseClient()
+        self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient()
+        self.clients = DBDataCreatorClientContainer(
+            adb=self.adb_client,
+            db=self.db_client
+        )
+
+    def run_command_sync(self, command: DBDataCreatorCommandBase) -> Any:
+        command.load_clients(self.clients)
+        return command.run_sync()
+
+    async def run_command(self, command: DBDataCreatorCommandBase) -> Any:
+        command.load_clients(self.clients)
+        return await command.run()
+
+    def batch(
+        self,
+        strategy: CollectorType = CollectorType.EXAMPLE,
+        batch_status: BatchStatus = BatchStatus.IN_PROCESS,
+        created_at: Optional[datetime] = None
+    ) -> int:
+        command = DBDataCreatorBatchCommand(
+            strategy=strategy,
+            batch_status=batch_status,
+            created_at=created_at
+        )
+        return self.run_command_sync(command)
+
+    async def task(self, url_ids: Optional[list[int]] = None) -> int:
+        task_id = await self.adb_client.initiate_task(task_type=TaskType.HTML)
+        if url_ids is not None:
+            await self.adb_client.link_urls_to_task(task_id=task_id, url_ids=url_ids)
+        return task_id
+
+    async def batch_v2(
+        self,
+        parameters: TestBatchCreationParameters
+    ) -> BatchURLCreationInfoV2:
+        return await self.run_command(BatchV2Command(parameters))
+
+    async def url_v2(
+        self,
+        parameters: list[TestURLCreationParameters],
+        batch_id: int | None = None,
+        created_at: datetime | None = None
+    ) -> URLsV2Response:
+        return await self.run_command(
+            URLsV2Command(
+                parameters=parameters,
+                batch_id=batch_id,
+                created_at=created_at
+            )
+        )
+
+
+    async def batch_and_urls(
+        self,
+        strategy: CollectorType = CollectorType.EXAMPLE,
+        url_count: int = 3,
+        with_html_content: bool = False,
+        batch_status: BatchStatus = BatchStatus.READY_TO_LABEL,
+        url_status: URLCreationEnum = URLCreationEnum.OK
+    ) -> BatchURLCreationInfo:
+        batch_id = self.batch(
+            strategy=strategy,
+            batch_status=batch_status
+        )
+        if batch_status in (BatchStatus.ERROR, BatchStatus.ABORTED):
+            return BatchURLCreationInfo(
+                batch_id=batch_id,
+                url_ids=[],
+                urls=[]
+            )
+        iuis: InsertURLsInfo = self.urls(
+            batch_id=batch_id,
+            url_count=url_count,
+            outcome=url_status
+        )
+        url_ids = [iui.url_id for iui in iuis.url_mappings]
+        if with_html_content:
+            await self.html_data(url_ids)
+
+        return BatchURLCreationInfo(
+            batch_id=batch_id,
+            url_ids=url_ids,
+            urls=[iui.url for iui in iuis.url_mappings]
+        )
+
+    async def agency(self, name: str | None = None) -> int:
+        return await self.run_command(AgencyCommand(name))
+
+    async def auto_relevant_suggestions(self, url_id: int, relevant: bool = True):
+        await self.run_command(
+            AutoRelevantSuggestionCommand(
+                url_id=url_id,
+                relevant=relevant
+            )
+        )
+
+    async def user_relevant_suggestion(
+        self,
+        url_id: int,
+        user_id: int | None = None,
+        suggested_status: URLType = URLType.DATA_SOURCE
+    ) -> None:
+        await self.run_command(
+            UserRelevantSuggestionCommand(
+                url_id=url_id,
+                user_id=user_id,
+                suggested_status=suggested_status
+            )
+        )
+
+    async def user_record_type_suggestion(
+        self,
+        url_id: int,
+        record_type: RecordType,
+        user_id: Optional[int] = None,
+    ) -> None:
+        await self.run_command(
+            UserRecordTypeSuggestionCommand(
+                url_id=url_id,
+                record_type=record_type,
+                user_id=user_id
+            )
+        )
+
+    async def auto_record_type_suggestions(
+        self,
+        url_id: int,
+        record_type: RecordType
+    ):
+        await self.run_command(
+            AutoRecordTypeSuggestionCommand(
+                url_id=url_id,
+                record_type=record_type
+            )
+        )
+
+    async def auto_suggestions(
+        self,
+        url_ids: list[int],
+        num_suggestions: int,
+        suggestion_type: SuggestionType.AUTO_SUGGESTION or SuggestionType.UNKNOWN
+    ):
+        allowed_suggestion_types = [SuggestionType.AUTO_SUGGESTION, SuggestionType.UNKNOWN]
+        if suggestion_type not in allowed_suggestion_types:
+            raise ValueError(f"suggestion_type must be one of {allowed_suggestion_types}")
+        if suggestion_type == SuggestionType.UNKNOWN and num_suggestions > 1:
+            raise ValueError("num_suggestions must be 1 when suggestion_type is unknown")
+
+        for url_id in url_ids:
+            await self.run_command(
+                AgencyAutoSuggestionsCommand(
+                    url_id=url_id,
+                    count=num_suggestions,
+                    suggestion_type=suggestion_type
+                )
+            )
+
+    async def confirmed_suggestions(self, url_ids: list[int]):
+        for url_id in url_ids:
+            await self.adb_client.add_confirmed_agency_url_links(
+                suggestions=[
+                    URLAgencySuggestionInfo(
+                        url_id=url_id,
+                        suggestion_type=SuggestionType.CONFIRMED,
+                        pdap_agency_id=await self.agency()
+                    )
+                ]
+            )
+
+    async def manual_suggestion(self, user_id: int, url_id: int, is_new: bool = False):
+        await self.adb_client.add_agency_manual_suggestion(
+            agency_id=await self.agency(),
+            url_id=url_id,
+            user_id=user_id,
+            is_new=is_new
+        )
+
+
+    def urls(
+        self,
+        batch_id: int,
+        url_count: int,
+        collector_metadata: dict | None = None,
+        outcome: URLCreationEnum = URLCreationEnum.OK,
+        created_at: datetime | None = None
+    ) -> InsertURLsInfo:
+        command = URLsDBDataCreatorCommand(
+            batch_id=batch_id,
+            url_count=url_count,
+            collector_metadata=collector_metadata,
+            status=outcome,
+            created_at=created_at
+        )
+        return self.run_command_sync(command)
+
+    async def url_miscellaneous_metadata(
+        self,
+        url_id: int,
+        name: str = "Test Name",
+        description: str = "Test Description",
+        record_formats: Optional[list[str]] = None,
+        data_portal_type: Optional[str] = "Test Data Portal Type",
+        supplying_entity: Optional[str] = "Test Supplying Entity"
+    ) -> None:
+        if record_formats is None:
+            record_formats = ["Test Record Format", "Test Record Format 2"]
+
+        tdo = URLMiscellaneousMetadataTDO(
+            url_id=url_id,
+            collector_metadata={},
+            collector_type=CollectorType.EXAMPLE,
+            record_formats=record_formats,
+            name=name,
+            description=description,
+            data_portal_type=data_portal_type,
+            supplying_entity=supplying_entity
+        )
+
+        await self.adb_client.add_miscellaneous_metadata([tdo])
+
+
+    def duplicate_urls(
+        self,
+        duplicate_batch_id: int,
+        url_ids: list[int]
+    ) -> None:
+        """
+        Create duplicates for all given url ids, and associate them
+        with the given batch
+        """
+        duplicate_infos = []
+        for url_id in url_ids:
+            dup_info = DuplicateInsertInfo(
+                batch_id=duplicate_batch_id,
+                original_url_id=url_id
+            )
+            duplicate_infos.append(dup_info)
+
+        self.db_client.insert_duplicates(duplicate_infos)
+
+    async def html_data(self, url_ids: list[int]) -> None:
+        command = HTMLDataCreatorCommand(
+            url_ids=url_ids
+        )
+        await self.run_command(command)
+
+    async def task_errors(
+        self,
+        url_ids: list[int],
+        task_id: Optional[int] = None
+    ) -> None:
+        if task_id is None:
+            task_id = await self.task()
+        task_errors = []
+        for url_id in url_ids:
+            task_error = URLTaskErrorPydantic(
+                url_id=url_id,
+                error="test error",
+                task_id=task_id,
+                task_type=TaskType.HTML
+            )
+            task_errors.append(task_error)
+        await self.adb_client.bulk_insert(task_errors)
+
+
+    async def agency_auto_suggestions(
+        self,
+        url_id: int,
+        count: int,
+        suggestion_type: SuggestionType = SuggestionType.AUTO_SUGGESTION
+    ) -> None:
+        await self.run_command(
+            AgencyAutoSuggestionsCommand(
+                url_id=url_id,
+                count=count,
+                suggestion_type=suggestion_type
+            )
+        )
+
+    async def agency_confirmed_suggestion(
+        self,
+        url_id: int
+    ) -> int:
+        """
+        Create a confirmed agency suggestion and return the auto-generated pdap_agency_id.
+        """
+        return await self.run_command(
+            AgencyConfirmedSuggestionCommand(url_id)
+        )
+
+    async def agency_user_suggestions(
+        self,
+        url_id: int,
+        user_id: int | None = None,
+        agency_annotation_info: URLAgencyAnnotationPostInfo | None = None
+    ) -> None:
+        await self.run_command(
+            AgencyUserSuggestionsCommand(
+                url_id=url_id,
+                user_id=user_id,
+                agency_annotation_info=agency_annotation_info
+            )
+        )
+
+    async def url_metadata(
+        self,
+        url_ids: list[int],
+        content_type: str = "text/html",
+        status_code: int = HTTPStatus.OK.value
+    ) -> None:
+        await self.run_command(
+            URLMetadataCommand(
+                url_ids=url_ids,
+                content_type=content_type,
+                status_code=status_code
+            )
+        )
+
+    async def create_validated_urls(
+        self,
+        record_type: RecordType = RecordType.RESOURCES,
+        validation_type: URLType = URLType.DATA_SOURCE,
+        count: int = 1
+    ) -> list[URLMapping]:
+        url_mappings: list[URLMapping] = await self.create_urls(
+            record_type=record_type,
+            count=count
+        )
+        url_ids: list[int] = [url_mapping.url_id for url_mapping in url_mappings]
+        await self.create_validated_flags(
+            url_ids=url_ids,
+            validation_type=validation_type
+        )
+        return url_mappings
+
+    async def create_submitted_urls(
+        self,
+        record_type: RecordType = RecordType.RESOURCES,
+        count: int = 1
+    ) -> list[URLMapping]:
+        url_mappings: list[URLMapping] = await self.create_urls(
+            record_type=record_type,
+            count=count
+        )
+        url_ids: list[int] = [url_mapping.url_id for url_mapping in url_mappings]
+        await self.create_validated_flags(
+            url_ids=url_ids,
+            validation_type=URLType.DATA_SOURCE
+        )
+        await self.create_url_data_sources(url_ids=url_ids)
+        return url_mappings
+
+
+    async def create_urls(
+        self,
+        status: URLStatus = URLStatus.OK,
+        source: URLSource = URLSource.COLLECTOR,
+        record_type: RecordType | None = RecordType.RESOURCES,
+        collector_metadata: dict | None = None,
+        count: int = 1,
+        batch_id: int | None = None
+    ) -> list[URLMapping]:
+
+        url_mappings: list[URLMapping] = await create_urls(
+            adb_client=self.adb_client,
+            status=status,
+            source=source,
+            record_type=record_type,
+            collector_metadata=collector_metadata,
+            count=count
+        )
+        url_ids: list[int] = [url_mapping.url_id for url_mapping in url_mappings]
+        if batch_id is not None:
+            await self.create_batch_url_links(
+                url_ids=url_ids,
+                batch_id=batch_id
+            )
+        return url_mappings
+
+    async def create_batch(
+        self,
+        status: BatchStatus = BatchStatus.READY_TO_LABEL,
+        strategy: CollectorType = CollectorType.EXAMPLE,
+        date_generated: datetime = datetime.now(),
+    ) -> int:
+        return await create_batch(
+            adb_client=self.adb_client,
+            status=status,
+            strategy=strategy,
+            date_generated=date_generated
+        )
+
+    async def create_batch_url_links(
+        self,
+        url_ids: list[int],
+        batch_id: int,
+    ) -> None:
+        await create_batch_url_links(
+            adb_client=self.adb_client,
+            url_ids=url_ids,
+            batch_id=batch_id
+        )
+
+    async def create_validated_flags(
+        self,
+        url_ids: list[int],
+        validation_type: URLType,
+    ) -> None:
+        await create_validated_flags(
+            adb_client=self.adb_client,
+            url_ids=url_ids,
+            validation_type=validation_type
+        )
+
+    async def create_url_data_sources(
+        self,
+        url_ids: list[int],
+    ) -> None:
+        await create_url_data_sources(
+            adb_client=self.adb_client,
+            url_ids=url_ids
+        )
+
+    async def create_url_agency_links(
+        self,
+        url_ids: list[int],
+        agency_ids: list[int],
+    ) -> None:
+        links: list[LinkURLAgency] = []
+        for url_id in url_ids:
+            for agency_id in agency_ids:
+                link = LinkURLAgency(
+                    url_id=url_id,
+                    agency_id=agency_id,
+                )
+                links.append(link)
+        await self.adb_client.add_all(links)
+
+    async def create_agency(self, agency_id: int = 1) -> None:
+        agency = Agency(
+            agency_id=agency_id,
+            name=generate_test_name(agency_id),
+            agency_type=AgencyType.UNKNOWN
+        )
+        await self.adb_client.add_all([agency])
+
+    async def create_agencies(self, count: int = 3) -> list[int]:
+        agencies: list[Agency] = []
+        agency_ids: list[int] = []
+        for _ in range(count):
+            agency_id = next_int()
+            agency = Agency(
+                agency_id=agency_id,
+                name=generate_test_name(agency_id),
+                agency_type=AgencyType.UNKNOWN
+            )
+            agencies.append(agency)
+            agency_ids.append(agency_id)
+        await self.adb_client.add_all(agencies)
+        return agency_ids
+
+    async def flag_as_root(self, url_ids: list[int]) -> None:
+        flag_root_urls: list[FlagRootURL] = [
+            FlagRootURL(url_id=url_id) for url_id in url_ids
+        ]
+        await self.adb_client.add_all(flag_root_urls)
+
+    async def link_urls_to_root(self, url_ids: list[int], root_url_id: int) -> None:
+        links: list[LinkURLRootURL] = [
+            LinkURLRootURL(url_id=url_id, root_url_id=root_url_id) for url_id in url_ids
+        ]
+        await self.adb_client.add_all(links)
+
+    async def link_urls_to_agencies(self, url_ids: list[int], agency_ids: list[int]) -> None:
+        assert len(url_ids) == len(agency_ids)
+        links: list[LinkURLAgency] = []
+        for url_id, agency_id in zip(url_ids, agency_ids):
+            link = LinkURLAgency(
+                url_id=url_id,
+                agency_id=agency_id
+            )
+            links.append(link)
+        await self.adb_client.add_all(links)
+
+    async def create_web_metadata(
+        self,
+        url_ids: list[int],
+        status_code: int = 200,
+    ):
+        web_metadata: list[URLWebMetadata] = [
+            URLWebMetadata(
+                url_id=url_id,
+                status_code=status_code,
+                accessed=True,
+                content_type="text/html",
+            )
+            for url_id in url_ids
+        ]
+        await self.adb_client.add_all(web_metadata)
+
+    async def create_us_state(
+        self,
+        name: str,
+        iso:str
+    ) -> USStateCreationInfo:
+        return await create_state(
+            adb_client=self.adb_client,
+            name=name,
+            iso=iso,
+        )
+
+    async def create_county(
+        self,
+        state_id: int,
+        name: str,
+    ) -> CountyCreationInfo:
+        return await create_county(
+            adb_client=self.adb_client,
+            state_id=state_id,
+            name=name,
+        )
+
+    async def create_locality(
+        self,
+        state_id: int,
+        county_id: int,
+        name: str,
+    ) -> LocalityCreationInfo:
+        return await create_locality(
+            adb_client=self.adb_client,
+            state_id=state_id,
+            county_id=county_id,
+            name=name,
+        )
+
+    async def add_compressed_html(
+        self,
+        url_ids: list[int],
+    ) -> None:
+        compressed_html_inserts: list[URLCompressedHTML] = [
+            URLCompressedHTML(
+                url_id=url_id,
+                compressed_html=b"Test HTML"
+            )
+            for url_id in url_ids
+        ]
+        await self.adb_client.add_all(compressed_html_inserts)
+
+    async def add_user_location_suggestion(
+        self,
+        url_id: int,
+        user_id: int,
+        location_id: int,
+    ):
+        suggestion = UserLocationSuggestion(
+            url_id=url_id,
+            user_id=user_id,
+            location_id=location_id,
+        )
+        await self.adb_client.add(suggestion)
+
+    async def add_location_suggestion(
+        self,
+        url_id: int,
+        location_ids: list[int],
+        confidence: float,
+        type_: LocationIDSubtaskType = LocationIDSubtaskType.NLP_LOCATION_FREQUENCY
+    ) -> None:
+        locations_found: bool = len(location_ids) > 0
+        task_id: int = await self.task(url_ids=[url_id])
+        subtask = AutoLocationIDSubtask(
+            url_id=url_id,
+            type=type_,
+            task_id=task_id,
+            locations_found=len(location_ids) > 0
+        )
self.adb_client.add(subtask, return_id=True) + if not locations_found: + return + suggestions: list[LocationIDSubtaskSuggestion] = [] + for location_id in location_ids: + suggestion = LocationIDSubtaskSuggestion( + subtask_id=subtask_id, + location_id=location_id, + confidence=confidence + ) + suggestions.append(suggestion) + await self.adb_client.add_all(suggestions) + + async def link_agencies_to_location( + self, + agency_ids: list[int], + location_id: int + ) -> None: + links: list[LinkAgencyLocation] = [ + LinkAgencyLocation( + agency_id=agency_id, + location_id=location_id + ) + for agency_id in agency_ids + ] + await self.adb_client.add_all(links) + + async def name_suggestion( + self, + url_id: int, + source: NameSuggestionSource = NameSuggestionSource.HTML_METADATA_TITLE, + name: str | None = None, + ) -> int: + if name is None: + name = f"Test Name {next_int()}" + suggestion = URLNameSuggestion( + url_id=url_id, + source=source, + suggestion=name, + ) + return await self.adb_client.add(suggestion, return_id=True) + + async def user_name_endorsement( + self, + suggestion_id: int, + user_id: int, + ): + link = LinkUserNameSuggestion( + suggestion_id=suggestion_id, + user_id=user_id, + ) + await self.adb_client.add(link) + + async def not_found_location_suggestion( + self, + url_id: int, + ) -> None: + suggestion = LinkUserSuggestionLocationNotFound( + url_id=url_id, + user_id=next_int(), + ) + await self.adb_client.add(suggestion) + + async def not_found_agency_suggestion( + self, + url_id: int, + ) -> None: + suggestion = LinkUserSuggestionAgencyNotFound( + url_id=url_id, + user_id=next_int(), + ) + await self.adb_client.add(suggestion) \ No newline at end of file diff --git a/tests/helpers/data_creator/create.py b/tests/helpers/data_creator/create.py new file mode 100644 index 00000000..200a34cd --- /dev/null +++ b/tests/helpers/data_creator/create.py @@ -0,0 +1,158 @@ +from datetime import datetime + +from src.collectors.enums import CollectorType, URLStatus +from src.core.enums import BatchStatus, RecordType +from src.db import County, Locality, USState +from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.mapping import URLMapping +from src.db.models.impl.batch.pydantic.insert import BatchInsertModel +from src.db.models.impl.flag.url_validated.enums import URLType +from src.db.models.impl.flag.url_validated.pydantic import FlagURLValidatedPydantic +from src.db.models.impl.link.batch_url.pydantic import LinkBatchURLPydantic +from src.db.models.impl.url.core.enums import URLSource +from src.db.models.impl.url.core.pydantic.insert import URLInsertModel +from src.db.models.impl.url.data_source.pydantic import URLDataSourcePydantic +from src.db.models.impl.url.record_type.pydantic import URLRecordTypePydantic +from tests.helpers.counter import COUNTER, next_int +from tests.helpers.data_creator.generate import generate_batch, generate_urls, generate_validated_flags, \ + generate_url_data_sources, generate_batch_url_links +from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo +from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo +from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo + + +async def create_batch( + adb_client: AsyncDatabaseClient, + status: BatchStatus = BatchStatus.READY_TO_LABEL, + strategy: CollectorType = CollectorType.EXAMPLE, + date_generated: datetime = datetime.now(), +) -> int: + batch: BatchInsertModel = generate_batch(status=status, 
diff --git a/tests/helpers/data_creator/create.py b/tests/helpers/data_creator/create.py
new file mode 100644
index 00000000..200a34cd
--- /dev/null
+++ b/tests/helpers/data_creator/create.py
@@ -0,0 +1,158 @@
+from datetime import datetime
+
+from src.collectors.enums import CollectorType, URLStatus
+from src.core.enums import BatchStatus, RecordType
+from src.db import County, Locality, USState
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.dtos.url.mapping import URLMapping
+from src.db.models.impl.batch.pydantic.insert import BatchInsertModel
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.flag.url_validated.pydantic import FlagURLValidatedPydantic
+from src.db.models.impl.link.batch_url.pydantic import LinkBatchURLPydantic
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.insert import URLInsertModel
+from src.db.models.impl.url.data_source.pydantic import URLDataSourcePydantic
+from src.db.models.impl.url.record_type.pydantic import URLRecordTypePydantic
+from tests.helpers.counter import COUNTER, next_int
+from tests.helpers.data_creator.generate import generate_batch, generate_urls, generate_validated_flags, \
+    generate_url_data_sources, generate_batch_url_links
+from tests.helpers.data_creator.models.creation_info.county import CountyCreationInfo
+from tests.helpers.data_creator.models.creation_info.locality import LocalityCreationInfo
+from tests.helpers.data_creator.models.creation_info.us_state import USStateCreationInfo
+
+
+async def create_batch(
+    adb_client: AsyncDatabaseClient,
+    status: BatchStatus = BatchStatus.READY_TO_LABEL,
+    strategy: CollectorType = CollectorType.EXAMPLE,
+    date_generated: datetime = datetime.now(),
+) -> int:
+    batch: BatchInsertModel = generate_batch(status=status, strategy=strategy, date_generated=date_generated)
+    return (await adb_client.bulk_insert([batch], return_ids=True))[0]
+
+async def create_urls(
+    adb_client: AsyncDatabaseClient,
+    status: URLStatus = URLStatus.OK,
+    source: URLSource = URLSource.COLLECTOR,
+    record_type: RecordType | None = RecordType.RESOURCES,
+    collector_metadata: dict | None = None,
+    count: int = 1
+) -> list[URLMapping]:
+    urls: list[URLInsertModel] = generate_urls(
+        status=status,
+        source=source,
+        collector_metadata=collector_metadata,
+        count=count,
+    )
+    url_ids = await adb_client.bulk_insert(urls, return_ids=True)
+    if record_type is not None:
+        record_types: list[URLRecordTypePydantic] = [
+            URLRecordTypePydantic(
+                url_id=url_id,
+                record_type=record_type,
+            )
+            for url_id in url_ids
+        ]
+        await adb_client.bulk_insert(record_types)
+
+    return [URLMapping(url_id=url_id, url=url.url) for url_id, url in zip(url_ids, urls)]
+
+async def create_validated_flags(
+    adb_client: AsyncDatabaseClient,
+    url_ids: list[int],
+    validation_type: URLType,
+) -> None:
+    validated_flags: list[FlagURLValidatedPydantic] = generate_validated_flags(
+        url_ids=url_ids,
+        validation_type=validation_type,
+    )
+    await adb_client.bulk_insert(validated_flags)
+
+async def create_url_data_sources(
+    adb_client: AsyncDatabaseClient,
+    url_ids: list[int],
+) -> None:
+    url_data_sources: list[URLDataSourcePydantic] = generate_url_data_sources(
+        url_ids=url_ids,
+    )
+    await adb_client.bulk_insert(url_data_sources)
+
+async def create_batch_url_links(
+    adb_client: AsyncDatabaseClient,
+    url_ids: list[int],
+    batch_id: int,
+) -> None:
+    batch_url_links: list[LinkBatchURLPydantic] = generate_batch_url_links(
+        url_ids=url_ids,
+        batch_id=batch_id,
+    )
+    await adb_client.bulk_insert(batch_url_links)
+
+async def create_state(
+    adb_client: AsyncDatabaseClient,
+    name: str,
+    iso: str
+) -> USStateCreationInfo:
+
+    us_state_insert_model = USState(
+        state_name=name,
+        state_iso=iso,
+    )
+    us_state_id: int = await adb_client.add(
+        us_state_insert_model,
+        return_id=True
+    )
+    location_id: int = await adb_client.get_location_id(
+        us_state_id=us_state_id,
+    )
+    return USStateCreationInfo(
+        us_state_id=us_state_id,
+        location_id=location_id,
+    )
+
+async def create_county(
+    adb_client: AsyncDatabaseClient,
+    state_id: int,
+    name: str
+) -> CountyCreationInfo:
+    county_insert_model = County(
+        name=name,
+        state_id=state_id,
+        fips=str(next_int()),
+    )
+    county_id: int = await adb_client.add(
+        county_insert_model,
+        return_id=True
+    )
+    location_id: int = await adb_client.get_location_id(
+        us_state_id=state_id,
+        county_id=county_id
+    )
+    return CountyCreationInfo(
+        county_id=county_id,
+        location_id=location_id,
+    )
+
+async def create_locality(
+    adb_client: AsyncDatabaseClient,
+    state_id: int,
+    county_id: int,
+    name: str
+) -> LocalityCreationInfo:
+    locality_insert_model = Locality(
+        name=name,
+        county_id=county_id,
+    )
+    locality_id: int = await adb_client.add(
+        locality_insert_model,
+        return_id=True
+    )
+    location_id: int = await adb_client.get_location_id(
+        us_state_id=state_id,
+        county_id=county_id,
+        locality_id=locality_id
+    )
+    return LocalityCreationInfo(
+        locality_id=locality_id,
+        location_id=location_id,
+    )
\ No newline at end of file
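The module-level helpers in create.py are written to be composed against a bare AsyncDatabaseClient. A hedged sketch of seeding a batch of validated data-source URLs (the wrapper function name and the choice of URLType.DATA_SOURCE are illustrative):

from src.db.client.async_ import AsyncDatabaseClient
from src.db.models.impl.flag.url_validated.enums import URLType
from tests.helpers.data_creator.create import (
    create_batch,
    create_batch_url_links,
    create_urls,
    create_validated_flags,
)

async def seed_validated_batch(adb_client: AsyncDatabaseClient) -> int:
    # Create a batch, attach three URLs to it, and flag them all as validated.
    batch_id = await create_batch(adb_client)
    mappings = await create_urls(adb_client, count=3)
    url_ids = [mapping.url_id for mapping in mappings]
    await create_batch_url_links(adb_client, url_ids=url_ids, batch_id=batch_id)
    await create_validated_flags(adb_client, url_ids=url_ids, validation_type=URLType.DATA_SOURCE)
    return batch_id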
diff --git a/tests/helpers/data_creator/generate.py b/tests/helpers/data_creator/generate.py
new file mode 100644
index 00000000..1cf0a806
--- /dev/null
+++ b/tests/helpers/data_creator/generate.py
@@ -0,0 +1,80 @@
+from datetime import datetime
+
+from src.collectors.enums import URLStatus, CollectorType
+from src.core.enums import BatchStatus, RecordType
+from src.db.models.impl.batch.pydantic.insert import BatchInsertModel
+from src.db.models.impl.flag.url_validated.enums import URLType
+from src.db.models.impl.flag.url_validated.pydantic import FlagURLValidatedPydantic
+from src.db.models.impl.flag.url_validated.sqlalchemy import FlagURLValidated
+from src.db.models.impl.link.batch_url.pydantic import LinkBatchURLPydantic
+from src.db.models.impl.url.core.enums import URLSource
+from src.db.models.impl.url.core.pydantic.insert import URLInsertModel
+from src.db.models.impl.url.data_source.pydantic import URLDataSourcePydantic
+from tests.helpers.counter import next_int
+
+
+def generate_batch(
+    status: BatchStatus,
+    strategy: CollectorType = CollectorType.EXAMPLE,
+    date_generated: datetime = datetime.now(),
+) -> BatchInsertModel:
+    return BatchInsertModel(
+        strategy=strategy.value,
+        status=status,
+        parameters={},
+        user_id=1,
+        date_generated=date_generated,
+    )
+
+def generate_batch_url_links(
+    url_ids: list[int],
+    batch_id: int
+) -> list[LinkBatchURLPydantic]:
+    return [
+        LinkBatchURLPydantic(
+            url_id=url_id,
+            batch_id=batch_id,
+        )
+        for url_id in url_ids
+    ]
+
+def generate_urls(
+    status: URLStatus = URLStatus.OK,
+    source: URLSource = URLSource.COLLECTOR,
+    collector_metadata: dict | None = None,
+    count: int = 1
+) -> list[URLInsertModel]:
+    results: list[URLInsertModel] = []
+    for i in range(count):
+        val: int = next_int()
+        results.append(URLInsertModel(
+            url=f"http://example.com/{val}",
+            status=status,
+            source=source,
+            name=f"Example {val}",
+            collector_metadata=collector_metadata,
+        ))
+    return results
+
+def generate_validated_flags(
+    url_ids: list[int],
+    validation_type: URLType,
+) -> list[FlagURLValidatedPydantic]:
+    return [
+        FlagURLValidatedPydantic(
+            url_id=url_id,
+            type=validation_type,
+        )
+        for url_id in url_ids
+    ]
+
+def generate_url_data_sources(
+    url_ids: list[int],
+) -> list[URLDataSourcePydantic]:
+    return [
+        URLDataSourcePydantic(
+            url_id=url_id,
+            data_source_id=url_id,
+        )
+        for url_id in url_ids
+    ]
\ No newline at end of file
diff --git a/tests/helpers/data_creator/insert.py b/tests/helpers/data_creator/insert.py
new file mode 100644
index 00000000..06b207e3
--- /dev/null
+++ b/tests/helpers/data_creator/insert.py
@@ -0,0 +1,10 @@
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.templates.markers.bulk.insert import BulkInsertableModel
+
+
+async def bulk_insert_all(
+    adb_client: AsyncDatabaseClient,
+    lists_of_models: list[list[BulkInsertableModel]],
+):
+    for list_of_models in lists_of_models:
+        await adb_client.bulk_insert(list_of_models)
\ No newline at end of file
diff --git a/tests/helpers/data_creator/models/__init__.py b/tests/helpers/data_creator/models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/models/clients.py b/tests/helpers/data_creator/models/clients.py
new file mode 100644
index 00000000..a8256dfc
--- /dev/null
+++ b/tests/helpers/data_creator/models/clients.py
@@ -0,0 +1,12 @@
+from pydantic import BaseModel
+
+from src.db.client.async_ import AsyncDatabaseClient
+from src.db.client.sync import DatabaseClient
+
+
+class DBDataCreatorClientContainer(BaseModel):
+    db: DatabaseClient
+    adb: AsyncDatabaseClient
+
+    class Config:
+        arbitrary_types_allowed = True
diff --git a/tests/helpers/data_creator/models/creation_info/__init__.py b/tests/helpers/data_creator/models/creation_info/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/models/creation_info/batch/__init__.py b/tests/helpers/data_creator/models/creation_info/batch/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/helpers/data_creator/models/creation_info/batch/v1.py b/tests/helpers/data_creator/models/creation_info/batch/v1.py
new file mode 100644
index 00000000..d5451eca
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/batch/v1.py
@@ -0,0 +1,7 @@
+from pydantic import BaseModel
+
+
+class BatchURLCreationInfo(BaseModel):
+    batch_id: int
+    url_ids: list[int]
+    urls: list[str]
diff --git a/tests/helpers/data_creator/models/creation_info/batch/v2.py b/tests/helpers/data_creator/models/creation_info/batch/v2.py
new file mode 100644
index 00000000..52d7e37d
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/batch/v2.py
@@ -0,0 +1,17 @@
+from pydantic import BaseModel
+
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo
+
+
+class BatchURLCreationInfoV2(BaseModel):
+    batch_id: int
+    urls_by_status: dict[URLCreationEnum, URLCreationInfo] = {}
+
+    @property
+    def url_ids(self) -> list[int]:
+        url_creation_infos = self.urls_by_status.values()
+        url_ids = []
+        for url_creation_info in url_creation_infos:
+            url_ids.extend(url_creation_info.url_ids)
+        return url_ids
diff --git a/tests/helpers/data_creator/models/creation_info/county.py b/tests/helpers/data_creator/models/creation_info/county.py
new file mode 100644
index 00000000..4a9511ec
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/county.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class CountyCreationInfo(BaseModel):
+    county_id: int
+    location_id: int
\ No newline at end of file
diff --git a/tests/helpers/data_creator/models/creation_info/locality.py b/tests/helpers/data_creator/models/creation_info/locality.py
new file mode 100644
index 00000000..6e98899d
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/locality.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class LocalityCreationInfo(BaseModel):
+    locality_id: int
+    location_id: int
\ No newline at end of file
diff --git a/tests/helpers/data_creator/models/creation_info/url.py b/tests/helpers/data_creator/models/creation_info/url.py
new file mode 100644
index 00000000..16c45a0a
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/url.py
@@ -0,0 +1,18 @@
+from typing import Optional
+
+from pydantic import BaseModel
+
+from src.collectors.enums import URLStatus
+from src.db.dtos.url.mapping import URLMapping
+from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo
+from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
+
+
+class URLCreationInfo(BaseModel):
+    url_mappings: list[URLMapping]
+    outcome: URLCreationEnum
+    annotation_info: Optional[AnnotationInfo] = None
+
+    @property
+    def url_ids(self) -> list[int]:
+        return [url_mapping.url_id for url_mapping in self.url_mappings]
diff --git a/tests/helpers/data_creator/models/creation_info/us_state.py b/tests/helpers/data_creator/models/creation_info/us_state.py
new file mode 100644
index 00000000..2c8914d6
--- /dev/null
+++ b/tests/helpers/data_creator/models/creation_info/us_state.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class USStateCreationInfo(BaseModel):
+    us_state_id: int
+    location_id: int
\ No newline at end of file
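The url_ids property on BatchURLCreationInfoV2 flattens ids across every outcome bucket, regardless of how each URL was created. A small sketch of that contract (ids and URLs below are illustrative values, not fixtures from this diff):

from src.db.dtos.url.mapping import URLMapping
from tests.helpers.batch_creation_parameters.enums import URLCreationEnum
from tests.helpers.data_creator.models.creation_info.batch.v2 import BatchURLCreationInfoV2
from tests.helpers.data_creator.models.creation_info.url import URLCreationInfo

info = BatchURLCreationInfoV2(
    batch_id=1,
    urls_by_status={
        URLCreationEnum.OK: URLCreationInfo(
            url_mappings=[URLMapping(url_id=10, url="https://test.com/10")],
            outcome=URLCreationEnum.OK,
        ),
    },
)
assert info.url_ids == [10]  # ids gathered from all buckets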
diff --git a/tests/helpers/db_data_creator.py b/tests/helpers/db_data_creator.py
deleted file mode 100644
index 1a1d0a70..00000000
--- a/tests/helpers/db_data_creator.py
+++ /dev/null
@@ -1,529 +0,0 @@
-from datetime import datetime
-from random import randint
-from typing import List, Optional
-
-from pydantic import BaseModel
-
-from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo
-from src.api.endpoints.review.approve.dto import FinalReviewApprovalInfo
-from src.api.endpoints.review.enums import RejectionReason
-from src.core.tasks.url.operators.agency_identification.dtos.suggestion import URLAgencySuggestionInfo
-from src.db.client.async_ import AsyncDatabaseClient
-from src.db.dtos.batch import BatchInfo
-from src.db.dtos.duplicate import DuplicateInsertInfo
-from src.db.dtos.url.annotations.auto.relevancy import AutoRelevancyAnnotationInput
-from src.db.dtos.url.insert import InsertURLsInfo
-from src.db.dtos.url.error import URLErrorPydanticInfo
-from src.db.dtos.url.html_content import URLHTMLContentInfo, HTMLContentType
-from src.db.dtos.url.core import URLInfo
-from src.db.dtos.url.mapping import URLMapping
-from src.db.client.sync import DatabaseClient
-from src.db.dtos.url.raw_html import RawHTMLInfo
-from src.db.enums import TaskType
-from src.collectors.enums import CollectorType, URLStatus
-from src.core.tasks.url.operators.submit_approved_url.tdo import SubmittedURLInfo
-from src.core.tasks.url.operators.url_miscellaneous_metadata.tdo import URLMiscellaneousMetadataTDO
-from src.core.enums import BatchStatus, SuggestionType, RecordType, SuggestedStatus
-from tests.helpers.batch_creation_parameters.annotation_info import AnnotationInfo
-from tests.helpers.batch_creation_parameters.core import TestBatchCreationParameters
-from tests.helpers.simple_test_data_functions import generate_test_urls
-
-
-class URLCreationInfo(BaseModel):
-    url_mappings: list[URLMapping]
-    outcome: URLStatus
-    annotation_info: Optional[AnnotationInfo] = None
-
-    @property
-    def url_ids(self) -> list[int]:
-        return [url_mapping.url_id for url_mapping in self.url_mappings]
-
-class BatchURLCreationInfoV2(BaseModel):
-    batch_id: int
-    url_creation_infos: dict[URLStatus, URLCreationInfo]
-
-    @property
-    def url_ids(self) -> list[int]:
-        url_creation_infos = self.url_creation_infos.values()
-        url_ids = []
-        for url_creation_info in url_creation_infos:
-            url_ids.extend(url_creation_info.url_ids)
-        return url_ids
-
-class BatchURLCreationInfo(BaseModel):
-    batch_id: int
-    url_ids: list[int]
-    urls: list[str]
-
-class DBDataCreator:
-    """
-    Assists in the creation of test data
-    """
-    def __init__(self, db_client: Optional[DatabaseClient] = None):
-        if db_client is not None:
-            self.db_client = db_client
-        else:
-            self.db_client = DatabaseClient()
-        self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient()
-
-    def batch(
-        self,
-        strategy: CollectorType = CollectorType.EXAMPLE,
-        batch_status: BatchStatus = BatchStatus.IN_PROCESS,
-        created_at: Optional[datetime] = None
-    ) -> int:
-        return self.db_client.insert_batch(
-            BatchInfo(
-                strategy=strategy.value,
-                status=batch_status,
-                parameters={"test_key": "test_value"},
-                user_id=1,
-                date_generated=created_at
-            )
-        )
-
-    async def task(self, url_ids: Optional[list[int]] = None) -> int:
-        task_id = await self.adb_client.initiate_task(task_type=TaskType.HTML)
-        if url_ids is not None:
-            await self.adb_client.link_urls_to_task(task_id=task_id, url_ids=url_ids)
-        return task_id
-
-    async def batch_v2(
-        self,
-        parameters: TestBatchCreationParameters
-    ) -> BatchURLCreationInfoV2:
-        batch_id = self.batch(
-            strategy=parameters.strategy,
-            batch_status=parameters.outcome,
-            created_at=parameters.created_at
-        )
-        if parameters.outcome in (BatchStatus.ERROR, BatchStatus.ABORTED):
-            return BatchURLCreationInfoV2(
-                batch_id=batch_id,
-                url_creation_infos={}
-            )
-
-        d: dict[URLStatus, URLCreationInfo] = {}
-        # Create urls
-        for url_parameters in parameters.urls:
-            iui: InsertURLsInfo = self.urls(
-                batch_id=batch_id,
-                url_count=url_parameters.count,
-                outcome=url_parameters.status,
-                created_at=parameters.created_at
-            )
-            url_ids = [iui.url_id for iui in iui.url_mappings]
-            if url_parameters.with_html_content:
-                await self.html_data(url_ids)
-            if url_parameters.annotation_info.has_annotations():
-                for url_id in url_ids:
-                    await self.annotate(
-                        url_id=url_id,
-                        annotation_info=url_parameters.annotation_info
-                    )
-
-            d[url_parameters.status] = URLCreationInfo(
-                url_mappings=iui.url_mappings,
-                outcome=url_parameters.status,
-                annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None
-            )
-        return BatchURLCreationInfoV2(
-            batch_id=batch_id,
-            url_creation_infos=d
-        )
-
-    async def batch_and_urls(
-        self,
-        strategy: CollectorType = CollectorType.EXAMPLE,
-        url_count: int = 3,
-        with_html_content: bool = False,
-        batch_status: BatchStatus = BatchStatus.READY_TO_LABEL,
-        url_status: URLStatus = URLStatus.PENDING
-    ) -> BatchURLCreationInfo:
-        batch_id = self.batch(
-            strategy=strategy,
-            batch_status=batch_status
-        )
-        if batch_status in (BatchStatus.ERROR, BatchStatus.ABORTED):
-            return BatchURLCreationInfo(
-                batch_id=batch_id,
-                url_ids=[],
-                urls=[]
-            )
-        iuis: InsertURLsInfo = self.urls(
-            batch_id=batch_id,
-            url_count=url_count,
-            outcome=url_status
-        )
-        url_ids = [iui.url_id for iui in iuis.url_mappings]
-        if with_html_content:
-            await self.html_data(url_ids)
-
-        return BatchURLCreationInfo(
-            batch_id=batch_id,
-            url_ids=url_ids,
-            urls=[iui.url for iui in iuis.url_mappings]
-        )
-
-    async def agency(self) -> int:
-        agency_id = randint(1, 99999999)
-        await self.adb_client.upsert_new_agencies(
-            suggestions=[
-                URLAgencySuggestionInfo(
-                    url_id=-1,
-                    suggestion_type=SuggestionType.UNKNOWN,
-                    pdap_agency_id=agency_id,
-                    agency_name=f"Test Agency {agency_id}",
-                    state=f"Test State {agency_id}",
-                    county=f"Test County {agency_id}",
-                    locality=f"Test Locality {agency_id}"
-                )
-            ]
-        )
-        return agency_id
-
-    async def auto_relevant_suggestions(self, url_id: int, relevant: bool = True):
-        await self.adb_client.add_auto_relevant_suggestion(
-            input_=AutoRelevancyAnnotationInput(
-                url_id=url_id,
-                is_relevant=relevant,
-                confidence=0.5,
-                model_name="test_model"
-            )
-        )
-
-    async def annotate(
-        self,
-        url_id: int,
-        annotation_info: AnnotationInfo
-    ):
-        info = annotation_info
-        if info.user_relevant is not None:
-            await self.user_relevant_suggestion_v2(url_id=url_id, suggested_status=info.user_relevant)
-        if info.auto_relevant is not None:
-            await self.auto_relevant_suggestions(url_id=url_id, relevant=info.auto_relevant)
-        if info.user_record_type is not None:
-            await self.user_record_type_suggestion(url_id=url_id, record_type=info.user_record_type)
-        if info.auto_record_type is not None:
-            await self.auto_record_type_suggestions(url_id=url_id, record_type=info.auto_record_type)
-        if info.user_agency is not None:
-            await self.agency_user_suggestions(url_id=url_id, agency_annotation_info=info.user_agency)
-        if info.auto_agency is not None:
-            await self.agency_auto_suggestions(url_id=url_id, count=1, suggestion_type=SuggestionType.AUTO_SUGGESTION)
-        if info.confirmed_agency is not None:
-            await self.agency_auto_suggestions(url_id=url_id, count=1, suggestion_type=SuggestionType.CONFIRMED)
-        if info.final_review_approved is not None:
-            if info.final_review_approved:
-                final_review_approval_info = FinalReviewApprovalInfo(
-                    url_id=url_id,
-                    record_type=annotation_info.user_record_type,
-                    agency_ids=[annotation_info.user_agency.suggested_agency]
-                    if annotation_info.user_agency is not None else None,
-                    description="Test Description",
-                )
-                await self.adb_client.approve_url(
-                    approval_info=final_review_approval_info,
-                    user_id=1
-                )
-            else:
-                await self.adb_client.reject_url(
-                    url_id=url_id,
-                    user_id=1,
-                    rejection_reason=RejectionReason.NOT_RELEVANT
-                )
-
-
-    async def user_relevant_suggestion(
-        self,
-        url_id: int,
-        user_id: Optional[int] = None,
-        relevant: bool = True
-    ):
-        await self.user_relevant_suggestion_v2(
-            url_id=url_id,
-            user_id=user_id,
-            suggested_status=SuggestedStatus.RELEVANT if relevant else SuggestedStatus.NOT_RELEVANT
-        )
-
-    async def user_relevant_suggestion_v2(
-        self,
-        url_id: int,
-        user_id: Optional[int] = None,
-        suggested_status: SuggestedStatus = SuggestedStatus.RELEVANT
-    ):
-        if user_id is None:
-            user_id = randint(1, 99999999)
-        await self.adb_client.add_user_relevant_suggestion(
-            url_id=url_id,
-            user_id=user_id,
-            suggested_status=suggested_status
-        )
-
-    async def user_record_type_suggestion(
-        self,
-        url_id: int,
-        record_type: RecordType,
-        user_id: Optional[int] = None,
-    ):
-        if user_id is None:
-            user_id = randint(1, 99999999)
-        await self.adb_client.add_user_record_type_suggestion(
-            url_id=url_id,
-            user_id=user_id,
-            record_type=record_type
-        )
-
-    async def auto_record_type_suggestions(self, url_id: int, record_type: RecordType):
-        await self.adb_client.add_auto_record_type_suggestion(
-            url_id=url_id,
-            record_type=record_type
-        )
-
-
-    async def auto_suggestions(
-        self,
-        url_ids: list[int],
-        num_suggestions: int,
-        suggestion_type: SuggestionType.AUTO_SUGGESTION or SuggestionType.UNKNOWN
-    ):
-        allowed_suggestion_types = [SuggestionType.AUTO_SUGGESTION, SuggestionType.UNKNOWN]
-        if suggestion_type not in allowed_suggestion_types:
-            raise ValueError(f"suggestion_type must be one of {allowed_suggestion_types}")
-        if suggestion_type == SuggestionType.UNKNOWN and num_suggestions > 1:
-            raise ValueError("num_suggestions must be 1 when suggestion_type is unknown")
-
-        for url_id in url_ids:
-            suggestions = []
-            for i in range(num_suggestions):
-                if suggestion_type == SuggestionType.UNKNOWN:
-                    agency_id = None
-                else:
-                    agency_id = await self.agency()
-                suggestion = URLAgencySuggestionInfo(
-                    url_id=url_id,
-                    suggestion_type=suggestion_type,
-                    pdap_agency_id=agency_id
-                )
-                suggestions.append(suggestion)
-
-            await self.adb_client.add_agency_auto_suggestions(
-                suggestions=suggestions
-            )
-
-    async def confirmed_suggestions(self, url_ids: list[int]):
-        for url_id in url_ids:
-            await self.adb_client.add_confirmed_agency_url_links(
-                suggestions=[
-                    URLAgencySuggestionInfo(
-                        url_id=url_id,
-                        suggestion_type=SuggestionType.CONFIRMED,
-                        pdap_agency_id=await self.agency()
-                    )
-                ]
-            )
-
-    async def manual_suggestion(self, user_id: int, url_id: int, is_new: bool = False):
-        await self.adb_client.add_agency_manual_suggestion(
-            agency_id=await self.agency(),
-            url_id=url_id,
-            user_id=user_id,
-            is_new=is_new
-        )
-
-
-    def urls(
-        self,
-        batch_id: int,
-        url_count: int,
-        collector_metadata: Optional[dict] = None,
-        outcome: URLStatus = URLStatus.PENDING,
-        created_at: Optional[datetime] = None
-    ) -> InsertURLsInfo:
-        raw_urls = generate_test_urls(url_count)
-        url_infos: List[URLInfo] = []
-        for url in raw_urls:
-            url_infos.append(
-                URLInfo(
-                    url=url,
-                    outcome=outcome,
-                    name="Test Name" if outcome == URLStatus.VALIDATED else None,
-                    collector_metadata=collector_metadata,
-                    created_at=created_at
-                )
-            )
-
-        url_insert_info = self.db_client.insert_urls(
-            url_infos=url_infos,
-            batch_id=batch_id,
-        )
-
-        # If outcome is submitted, also add entry to DataSourceURL
-        if outcome == URLStatus.SUBMITTED:
-            submitted_url_infos = []
-            for url_id in url_insert_info.url_ids:
-                submitted_url_info = SubmittedURLInfo(
-                    url_id=url_id,
-                    data_source_id=url_id,  # Use same ID for convenience,
-                    request_error=None,
-                    submitted_at=created_at
-                )
-                submitted_url_infos.append(submitted_url_info)
-            self.db_client.mark_urls_as_submitted(submitted_url_infos)
-
-
-        return url_insert_info
-
-    async def url_miscellaneous_metadata(
-        self,
-        url_id: int,
-        name: str = "Test Name",
-        description: str = "Test Description",
-        record_formats: Optional[list[str]] = None,
-        data_portal_type: Optional[str] = "Test Data Portal Type",
-        supplying_entity: Optional[str] = "Test Supplying Entity"
-    ):
-        if record_formats is None:
-            record_formats = ["Test Record Format", "Test Record Format 2"]
-
-        tdo = URLMiscellaneousMetadataTDO(
-            url_id=url_id,
-            collector_metadata={},
-            collector_type=CollectorType.EXAMPLE,
-            record_formats=record_formats,
-            name=name,
-            description=description,
-            data_portal_type=data_portal_type,
-            supplying_entity=supplying_entity
-        )
-
-        await self.adb_client.add_miscellaneous_metadata([tdo])
-
-
-    def duplicate_urls(self, duplicate_batch_id: int, url_ids: list[int]):
-        """
-        Create duplicates for all given url ids, and associate them
-        with the given batch
-        """
-        duplicate_infos = []
-        for url_id in url_ids:
-            dup_info = DuplicateInsertInfo(
-                duplicate_batch_id=duplicate_batch_id,
-                original_url_id=url_id
-            )
-            duplicate_infos.append(dup_info)
-
-        self.db_client.insert_duplicates(duplicate_infos)
-
-    async def html_data(self, url_ids: list[int]):
-        html_content_infos = []
-        raw_html_info_list = []
-        for url_id in url_ids:
-            html_content_infos.append(
-                URLHTMLContentInfo(
-                    url_id=url_id,
-                    content_type=HTMLContentType.TITLE,
-                    content="test html content"
-                )
-            )
-            html_content_infos.append(
-                URLHTMLContentInfo(
-                    url_id=url_id,
-                    content_type=HTMLContentType.DESCRIPTION,
-                    content="test description"
-                )
-            )
-            raw_html_info = RawHTMLInfo(
-                url_id=url_id,
-                html=""
-            )
-            raw_html_info_list.append(raw_html_info)
-
-        await self.adb_client.add_raw_html(raw_html_info_list)
-        await self.adb_client.add_html_content_infos(html_content_infos)
-
-    async def error_info(
-        self,
-        url_ids: list[int],
-        task_id: Optional[int] = None
-    ):
-        if task_id is None:
-            task_id = await self.task()
-        error_infos = []
-        for url_id in url_ids:
-            url_error_info = URLErrorPydanticInfo(
-                url_id=url_id,
-                error="test error",
-                task_id=task_id
-            )
-            error_infos.append(url_error_info)
-        await self.adb_client.add_url_error_infos(error_infos)
-
-
-    async def agency_auto_suggestions(
-        self,
-        url_id: int,
-        count: int,
-        suggestion_type: SuggestionType = SuggestionType.AUTO_SUGGESTION
-    ):
-        if suggestion_type == SuggestionType.UNKNOWN:
-            count = 1  # Can only be one auto suggestion if unknown
-
-        suggestions = []
-        for _ in range(count):
-            if suggestion_type == SuggestionType.UNKNOWN:
-                pdap_agency_id = None
-            else:
-                pdap_agency_id = await self.agency()
-            suggestion = URLAgencySuggestionInfo(
-                url_id=url_id,
-                suggestion_type=suggestion_type,
-                pdap_agency_id=pdap_agency_id,
-                state="Test State",
-                county="Test County",
-                locality="Test Locality"
-            )
-            suggestions.append(suggestion)
-
-        await self.adb_client.add_agency_auto_suggestions(
-            suggestions=suggestions
-        )
-
-    async def agency_confirmed_suggestion(
-        self,
-        url_id: int
-    ) -> int:
-        """
-        Creates a confirmed agency suggestion
-        and returns the auto-generated pdap_agency_id
-        """
-        agency_id = await self.agency()
-        await self.adb_client.add_confirmed_agency_url_links(
-            suggestions=[
-                URLAgencySuggestionInfo(
-                    url_id=url_id,
-                    suggestion_type=SuggestionType.CONFIRMED,
-                    pdap_agency_id=agency_id
-                )
-            ]
-        )
-        return agency_id
-
-    async def agency_user_suggestions(
-        self,
-        url_id: int,
-        user_id: Optional[int] = None,
-        agency_annotation_info: Optional[URLAgencyAnnotationPostInfo] = None
-    ):
-        if user_id is None:
-            user_id = randint(1, 99999999)
-
-        if agency_annotation_info is None:
-            agency_annotation_info = URLAgencyAnnotationPostInfo(
-                suggested_agency=await self.agency()
-            )
-        await self.adb_client.add_agency_manual_suggestion(
-            agency_id=agency_annotation_info.suggested_agency,
-            url_id=url_id,
-            user_id=user_id,
-            is_new=agency_annotation_info.is_new
-        )
diff --git a/tests/helpers/patch_functions.py b/tests/helpers/patch_functions.py
index 8a42c9dc..170a2062 100644
--- a/tests/helpers/patch_functions.py
+++ b/tests/helpers/patch_functions.py
@@ -4,7 +4,7 @@ async def block_sleep(monkeypatch) -> AwaitableBarrier:
     barrier = AwaitableBarrier()
     monkeypatch.setattr(
-        "src.collectors.source_collectors.example.core.ExampleCollector.sleep",
+        "src.collectors.impl.example.core.ExampleCollector.sleep",
         barrier
     )
     return barrier
diff --git a/tests/helpers/run.py b/tests/helpers/run.py
new file mode 100644
index 00000000..aa889f7f
--- /dev/null
+++ b/tests/helpers/run.py
@@ -0,0 +1,15 @@
+from src.core.tasks.base.run_info import TaskOperatorRunInfo
+from src.core.tasks.url.operators.base import URLTaskOperatorBase
+from tests.helpers.asserts import assert_task_run_success
+
+
+async def run_task_and_confirm_success(
+    operator: URLTaskOperatorBase,
+) -> None:
+    """
+    Run task, confirm success, and assert task no longer meets prerequisites.
+ """ + + run_info: TaskOperatorRunInfo = await operator.run_task() + assert_task_run_success(run_info) + assert not await operator.meets_task_prerequisites() \ No newline at end of file diff --git a/tests/helpers/setup/annotate_agency/core.py b/tests/helpers/setup/annotate_agency/core.py index fbd7bc53..6827194d 100644 --- a/tests/helpers/setup/annotate_agency/core.py +++ b/tests/helpers/setup/annotate_agency/core.py @@ -1,5 +1,6 @@ from src.core.enums import SuggestionType -from tests.helpers.db_data_creator import DBDataCreator, BatchURLCreationInfo +from tests.helpers.data_creator.core import DBDataCreator +from tests.helpers.data_creator.models.creation_info.batch.v1 import BatchURLCreationInfo from tests.helpers.setup.annotate_agency.model import AnnotateAgencySetupInfo diff --git a/tests/helpers/setup/annotation/core.py b/tests/helpers/setup/annotation/core.py index d8d3bb0c..70123cb9 100644 --- a/tests/helpers/setup/annotation/core.py +++ b/tests/helpers/setup/annotation/core.py @@ -1,12 +1,13 @@ from src.collectors.enums import URLStatus -from tests.helpers.db_data_creator import DBDataCreator +from tests.helpers.batch_creation_parameters.enums import URLCreationEnum +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.annotation.model import AnnotationSetupInfo async def setup_for_get_next_url_for_annotation( db_data_creator: DBDataCreator, url_count: int, - outcome: URLStatus = URLStatus.PENDING + outcome: URLCreationEnum = URLCreationEnum.OK ) -> AnnotationSetupInfo: batch_id = db_data_creator.batch() insert_urls_info = db_data_creator.urls( diff --git a/tests/helpers/setup/final_review/core.py b/tests/helpers/setup/final_review/core.py index 87c4da59..ababae82 100644 --- a/tests/helpers/setup/final_review/core.py +++ b/tests/helpers/setup/final_review/core.py @@ -2,7 +2,8 @@ from src.api.endpoints.annotate.agency.post.dto import URLAgencyAnnotationPostInfo from src.core.enums import RecordType -from tests.helpers.db_data_creator import DBDataCreator +from src.db.models.impl.flag.url_validated.enums import URLType +from tests.helpers.data_creator.core import DBDataCreator from tests.helpers.setup.final_review.model import FinalReviewSetupInfo @@ -37,7 +38,7 @@ async def add_agency_suggestion() -> int: ) return agency_id - async def add_record_type_suggestion(record_type: RecordType): + async def add_record_type_suggestion(record_type: RecordType) -> None: await db_data_creator.user_record_type_suggestion( url_id=url_mapping.url_id, record_type=record_type @@ -46,7 +47,7 @@ async def add_record_type_suggestion(record_type: RecordType): async def add_relevant_suggestion(relevant: bool): await db_data_creator.user_relevant_suggestion( url_id=url_mapping.url_id, - relevant=relevant + suggested_status=URLType.DATA_SOURCE if relevant else URLType.NOT_RELEVANT ) await db_data_creator.auto_relevant_suggestions( @@ -59,6 +60,10 @@ async def add_relevant_suggestion(relevant: bool): record_type=RecordType.ARREST_RECORDS ) + name_suggestion_id: int = await db_data_creator.name_suggestion( + url_id=url_mapping.url_id, + ) + if include_user_annotations: await add_relevant_suggestion(False) await add_record_type_suggestion(RecordType.ACCIDENT_REPORTS) @@ -69,5 +74,6 @@ async def add_relevant_suggestion(relevant: bool): return FinalReviewSetupInfo( batch_id=batch_id, url_mapping=url_mapping, - user_agency_id=user_agency_id + user_agency_id=user_agency_id, + name_suggestion_id=name_suggestion_id ) diff --git a/tests/helpers/setup/final_review/model.py 
diff --git a/tests/helpers/setup/final_review/model.py b/tests/helpers/setup/final_review/model.py
index c75fb847..a3e57a3c 100644
--- a/tests/helpers/setup/final_review/model.py
+++ b/tests/helpers/setup/final_review/model.py
@@ -8,4 +8,5 @@ class FinalReviewSetupInfo(BaseModel):
     batch_id: int
     url_mapping: URLMapping
-    user_agency_id: Optional[int]
+    user_agency_id: int | None
+    name_suggestion_id: int | None
diff --git a/tests/helpers/setup/populate.py b/tests/helpers/setup/populate.py
index 1741253b..02c364d6 100644
--- a/tests/helpers/setup/populate.py
+++ b/tests/helpers/setup/populate.py
@@ -1,5 +1,5 @@
 from src.db.client.async_ import AsyncDatabaseClient
-from src.db.models.instantiations.url.core import URL
+from src.db.models.impl.url.core.sqlalchemy import URL
 
 
 async def populate_database(adb_client: AsyncDatabaseClient) -> None:
@@ -12,7 +12,7 @@ async def populate_database(adb_client: AsyncDatabaseClient) -> None:
         collector_metadata={
             "source_collector": "test-data",
         },
-        outcome='validated',
+        status='validated',
         record_type="Other"
     )
     await adb_client.add(url)
\ No newline at end of file
diff --git a/tests/helpers/setup/wipe.py b/tests/helpers/setup/wipe.py
index 2145bcf1..e81c266d 100644
--- a/tests/helpers/setup/wipe.py
+++ b/tests/helpers/setup/wipe.py
@@ -1,6 +1,6 @@
 from sqlalchemy import create_engine
 
-from src.db.models.templates import Base
+from src.db.models.templates_.base import Base
 
 
 def wipe_database(connection_string: str) -> None:
@@ -8,5 +8,7 @@ def wipe_database(connection_string: str) -> None:
     engine = create_engine(connection_string)
     with engine.connect() as connection:
         for table in reversed(Base.metadata.sorted_tables):
+            if table.info == "view":
+                continue
             connection.execute(table.delete())
         connection.commit()
diff --git a/tests/helpers/simple_test_data_functions.py b/tests/helpers/simple_test_data_functions.py
index d5f2c313..4d321dc5 100644
--- a/tests/helpers/simple_test_data_functions.py
+++ b/tests/helpers/simple_test_data_functions.py
@@ -4,6 +4,8 @@
 """
 import uuid
 
+from tests.helpers.counter import next_int
+
 
 def generate_test_urls(count: int) -> list[str]:
     results = []
@@ -12,3 +12,18 @@
         results.append(url)
 
     return results
+
+
+def generate_test_url(i: int) -> str:
+    return f"https://test.com/{i}"
+
+def generate_test_name(i: int | None = None) -> str:
+    if i is None:
+        return f"Test Name {next_int()}"
+    return f"Test Name {i}"
+
+def generate_test_description(i: int) -> str:
+    return f"Test description {i}"
+
+def generate_test_html(i: int) -> str:
+    return f"
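The new deterministic generators in simple_test_data_functions.py can be exercised directly; a short sketch of the indexed helpers:

from tests.helpers.simple_test_data_functions import generate_test_name, generate_test_url

assert generate_test_url(7) == "https://test.com/7"
assert generate_test_name(7) == "Test Name 7"
name = generate_test_name()  # with no index, falls back to tests.helpers.counter.next_int()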