diff --git a/ENV.md b/ENV.md
index cdedd288..7c09fb64 100644
--- a/ENV.md
+++ b/ENV.md
@@ -14,12 +14,14 @@ Please ensure these are properly defined in a `.env` file in the root directory.
 |`POSTGRES_DB` | The database name for the test database | `source_collector_test_db` |
 |`POSTGRES_HOST` | The host for the test database | `127.0.0.1` |
 |`POSTGRES_PORT` | The port for the test database | `5432` |
-|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret token `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` |
+|`DS_APP_SECRET_KEY`| The secret key used for decoding JWT tokens produced by the Data Sources App. Must match the secret key `JWT_SECRET_KEY` that is used in the Data Sources App for encoding. | `abc123` |
 |`DEV`| Set to any value to run the application in development mode. | `true` |
 |`DEEPSEEK_API_KEY`| The API key required for accessing the DeepSeek API. | `abc123` |
 |`OPENAI_API_KEY`| The API key required for accessing the OpenAI API. | `abc123` |
-|`PDAP_EMAIL`| An email address for accessing the PDAP API. | `abc123@test.com` |
-|`PDAP_PASSWORD`| A password for accessing the PDAP API. | `abc123` |
+|`PDAP_EMAIL`| An email address for accessing the PDAP API.[^1] | `abc123@test.com` |
+|`PDAP_PASSWORD`| A password for accessing the PDAP API.[^1] | `abc123` |
 |`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` |
+|`PDAP_API_URL`| The URL for the PDAP API. | `https://data-sources-v2.pdap.dev/api` |
 |`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications| `abc123` |
+[^1]: The user account in question requires elevated permissions to access certain endpoints. At a minimum, the user requires the `source_collector` and `db_write` permissions.
\ No newline at end of file
diff --git a/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py
new file mode 100644
index 00000000..b92fe1ef
--- /dev/null
+++ b/alembic/versions/2025_03_29_1716-33a546c93441_add_data_source_id_column_to_url_table.py
@@ -0,0 +1,31 @@
+"""Add data source ID column to URL table
+
+Revision ID: 33a546c93441
+Revises: 45271f8fe75d
+Create Date: 2025-03-29 17:16:11.863064
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '33a546c93441' +down_revision: Union[str, None] = '45271f8fe75d' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + 'urls', + sa.Column('data_source_id', sa.Integer(), nullable=True) + ) + # Add unique constraint to data_source_id column + op.create_unique_constraint('uq_data_source_id', 'urls', ['data_source_id']) + + +def downgrade() -> None: + op.drop_column('urls', 'data_source_id') diff --git a/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py b/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py new file mode 100644 index 00000000..e1d5b725 --- /dev/null +++ b/alembic/versions/2025_04_15_1338-b363794fa4e9_add_submit_url_task_type_enum.py @@ -0,0 +1,48 @@ +"""Add Submit URL Task Type Enum + +Revision ID: b363794fa4e9 +Revises: 33a546c93441 +Create Date: 2025-04-15 13:38:58.293627 + +""" +from typing import Sequence, Union + + +from util.alembic_helpers import switch_enum_type + +# revision identifiers, used by Alembic. +revision: str = 'b363794fa4e9' +down_revision: Union[str, None] = '33a546c93441' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + "HTML", + "Relevancy", + "Record Type", + "Agency Identification", + "Misc Metadata", + "Submit Approved URLs" + ] + ) + + +def downgrade() -> None: + switch_enum_type( + table_name='tasks', + column_name='task_type', + enum_name='task_type', + new_enum_values=[ + "HTML", + "Relevancy", + "Record Type", + "Agency Identification", + "Misc Metadata", + ] + ) \ No newline at end of file diff --git a/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py b/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py new file mode 100644 index 00000000..82ce97a4 --- /dev/null +++ b/alembic/versions/2025_04_15_1532-ed06a5633d2e_revert_to_pending_validated_urls_.py @@ -0,0 +1,42 @@ +"""Revert to pending validated URLs without name and add constraint + +Revision ID: ed06a5633d2e +Revises: b363794fa4e9 +Create Date: 2025-04-15 15:32:26.465488 + +""" +from typing import Sequence, Union + +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision: str = 'ed06a5633d2e' +down_revision: Union[str, None] = 'b363794fa4e9' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + + op.execute( + """ + UPDATE public.urls + SET OUTCOME = 'pending' + WHERE OUTCOME = 'validated' AND NAME IS NULL + """ + ) + + op.create_check_constraint( + 'url_name_not_null_when_validated', + 'urls', + "NAME IS NOT NULL OR OUTCOME != 'validated'" + ) + + +def downgrade() -> None: + op.drop_constraint( + 'url_name_not_null_when_validated', + 'urls', + type_='check' + ) diff --git a/api/main.py b/api/main.py index 19f8de8d..40970e4f 100644 --- a/api/main.py +++ b/api/main.py @@ -1,5 +1,6 @@ from contextlib import asynccontextmanager +import aiohttp import uvicorn from fastapi import FastAPI @@ -15,6 +16,7 @@ from collector_manager.AsyncCollectorManager import AsyncCollectorManager from core.AsyncCore import AsyncCore from core.AsyncCoreLogger import AsyncCoreLogger +from core.EnvVarManager import EnvVarManager from core.ScheduledTaskManager import AsyncScheduledTaskManager from core.SourceCollectorCore import SourceCollectorCore from core.TaskManager import TaskManager @@ -22,18 +24,27 @@ from html_tag_collector.RootURLCache import RootURLCache from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface +from pdap_api_client.AccessManager import AccessManager +from pdap_api_client.PDAPClient import PDAPClient from util.DiscordNotifier import DiscordPoster -from util.helper_functions import get_from_env + @asynccontextmanager async def lifespan(app: FastAPI): + env_var_manager = EnvVarManager.get() + # Initialize shared dependencies - db_client = DatabaseClient() - adb_client = AsyncDatabaseClient() + db_client = DatabaseClient( + db_url=env_var_manager.get_postgres_connection_string() + ) + adb_client = AsyncDatabaseClient( + db_url=env_var_manager.get_postgres_connection_string(is_async=True) + ) await setup_database(db_client) core_logger = AsyncCoreLogger(adb_client=adb_client) + session = aiohttp.ClientSession() source_collector_core = SourceCollectorCore( db_client=DatabaseClient(), @@ -46,7 +57,15 @@ async def lifespan(app: FastAPI): root_url_cache=RootURLCache() ), discord_poster=DiscordPoster( - webhook_url=get_from_env("DISCORD_WEBHOOK_URL") + webhook_url=env_var_manager.discord_webhook_url + ), + pdap_client=PDAPClient( + access_manager=AccessManager( + email=env_var_manager.pdap_email, + password=env_var_manager.pdap_password, + api_key=env_var_manager.pdap_api_key, + session=session + ) ) ) async_collector_manager = AsyncCollectorManager( @@ -72,17 +91,17 @@ async def lifespan(app: FastAPI): yield # Code here runs before shutdown # Shutdown logic (if needed) + # Clean up resources, close connections, etc. await core_logger.shutdown() await async_core.shutdown() source_collector_core.shutdown() - # Clean up resources, close connections, etc. 
+ await session.close() pass async def setup_database(db_client): # Initialize database if dev environment, otherwise apply migrations try: - get_from_env("DEV") db_client.init_db() except Exception as e: return @@ -95,6 +114,7 @@ async def setup_database(db_client): lifespan=lifespan ) + routers = [ root_router, collector_router, diff --git a/api/routes/batch.py b/api/routes/batch.py index 23df2394..9d4b62cc 100644 --- a/api/routes/batch.py +++ b/api/routes/batch.py @@ -25,7 +25,7 @@ @batch_router.get("") -def get_batch_status( +async def get_batch_status( collector_type: Optional[CollectorType] = Query( description="Filter by collector type", default=None @@ -38,13 +38,13 @@ def get_batch_status( description="The page number", default=1 ), - core: SourceCollectorCore = Depends(get_core), + core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetBatchStatusResponse: """ Get the status of recent batches """ - return core.get_batch_statuses(collector_type=collector_type, status=status, page=page) + return await core.get_batch_statuses(collector_type=collector_type, status=status, page=page) @batch_router.get("/{batch_id}") @@ -69,28 +69,28 @@ async def get_urls_by_batch( return await core.get_urls_by_batch(batch_id, page=page) @batch_router.get("/{batch_id}/duplicates") -def get_duplicates_by_batch( +async def get_duplicates_by_batch( batch_id: int = Path(description="The batch id"), page: int = Query( description="The page number", default=1 ), - core: SourceCollectorCore = Depends(get_core), + core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetDuplicatesByBatchResponse: - return core.get_duplicate_urls_by_batch(batch_id, page=page) + return await core.get_duplicate_urls_by_batch(batch_id, page=page) @batch_router.get("/{batch_id}/logs") -def get_batch_logs( +async def get_batch_logs( batch_id: int = Path(description="The batch id"), - core: SourceCollectorCore = Depends(get_core), + async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), ) -> GetBatchLogsResponse: """ Retrieve the logs for a recent batch. Note that for later batches, the logs may not be available. 
""" - return core.get_batch_logs(batch_id) + return await async_core.get_batch_logs(batch_id) @batch_router.post("/{batch_id}/abort") async def abort_batch( diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py index c8315fbe..98410b6f 100644 --- a/collector_db/AsyncDatabaseClient.py +++ b/collector_db/AsyncDatabaseClient.py @@ -5,22 +5,21 @@ from sqlalchemy import select, exists, func, case, desc, Select, not_, and_, or_, update, Delete, Insert, asc from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker -from sqlalchemy.orm import selectinload, joinedload, QueryableAttribute +from sqlalchemy.orm import selectinload, joinedload, QueryableAttribute, aliased from sqlalchemy.sql.functions import coalesce from starlette import status from collector_db.ConfigManager import ConfigManager from collector_db.DTOConverter import DTOConverter from collector_db.DTOs.BatchInfo import BatchInfo -from collector_db.DTOs.DuplicateInfo import DuplicateInsertInfo +from collector_db.DTOs.DuplicateInfo import DuplicateInsertInfo, DuplicateInfo from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo -from collector_db.DTOs.LogInfo import LogInfo +from collector_db.DTOs.LogInfo import LogInfo, LogOutputInfo from collector_db.DTOs.TaskInfo import TaskInfo from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo, HTMLContentType from collector_db.DTOs.URLInfo import URLInfo from collector_db.DTOs.URLMapping import URLMapping -from collector_db.DTOs.URLWithHTML import URLWithHTML from collector_db.StatementComposer import StatementComposer from collector_db.constants import PLACEHOLDER_AGENCY_NAME from collector_db.enums import URLMetadataAttributeType, TaskType @@ -42,7 +41,9 @@ GetURLsResponseInnerInfo from core.DTOs.URLAgencySuggestionInfo import URLAgencySuggestionInfo from core.DTOs.task_data_objects.AgencyIdentificationTDO import AgencyIdentificationTDO +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO, SubmittedURLInfo from core.DTOs.task_data_objects.URLMiscellaneousMetadataTDO import URLMiscellaneousMetadataTDO, URLHTMLMetadataInfo +from core.EnvVarManager import EnvVarManager from core.enums import BatchStatus, SuggestionType, RecordType from html_tag_collector.DataClassTags import convert_to_response_html_info @@ -58,7 +59,9 @@ def add_standard_limit_and_offset(statement, page, limit=100): class AsyncDatabaseClient: - def __init__(self, db_url: str = get_postgres_connection_string(is_async=True)): + def __init__(self, db_url: Optional[str] = None): + if db_url is None: + db_url = EnvVarManager.get().get_postgres_connection_string(is_async=True) self.engine = create_async_engine( url=db_url, echo=ConfigManager.get_sqlalchemy_echo(), @@ -1465,3 +1468,144 @@ async def update_batch_post_collection( batch.duplicate_url_count = duplicate_url_count batch.status = batch_status.value batch.compute_time = compute_time + + + @session_manager + async def has_validated_urls(self, session: AsyncSession) -> bool: + query = ( + select(URL) + .where(URL.outcome == URLStatus.VALIDATED.value) + ) + urls = await session.execute(query) + urls = urls.scalars().all() + return len(urls) > 0 + + @session_manager + async def get_validated_urls( + self, + session: AsyncSession + ) -> list[SubmitApprovedURLTDO]: + query = ( + select(URL) + .where(URL.outcome == URLStatus.VALIDATED.value) + .options( + 
selectinload(URL.optional_data_source_metadata), + selectinload(URL.confirmed_agencies), + selectinload(URL.reviewing_user) + ).limit(100) + ) + urls = await session.execute(query) + urls = urls.scalars().all() + results: list[SubmitApprovedURLTDO] = [] + for url in urls: + agency_ids = [] + for agency in url.confirmed_agencies: + agency_ids.append(agency.agency_id) + optional_metadata = url.optional_data_source_metadata + + if optional_metadata is None: + record_formats = None + data_portal_type = None + supplying_entity = None + else: + record_formats = optional_metadata.record_formats + data_portal_type = optional_metadata.data_portal_type + supplying_entity = optional_metadata.supplying_entity + + tdo = SubmitApprovedURLTDO( + url_id=url.id, + url=url.url, + name=url.name, + agency_ids=agency_ids, + description=url.description, + record_type=url.record_type, + record_formats=record_formats, + data_portal_type=data_portal_type, + supplying_entity=supplying_entity, + approving_user_id=url.reviewing_user.user_id + ) + results.append(tdo) + return results + + @session_manager + async def mark_urls_as_submitted(self, session: AsyncSession, infos: list[SubmittedURLInfo]): + for info in infos: + url_id = info.url_id + data_source_id = info.data_source_id + query = ( + update(URL) + .where(URL.id == url_id) + .values( + data_source_id=data_source_id, + outcome=URLStatus.SUBMITTED.value + ) + ) + await session.execute(query) + + @session_manager + async def get_duplicates_by_batch_id(self, session, batch_id: int, page: int) -> List[DuplicateInfo]: + original_batch = aliased(Batch) + duplicate_batch = aliased(Batch) + + query = ( + Select( + URL.url.label("source_url"), + URL.id.label("original_url_id"), + duplicate_batch.id.label("duplicate_batch_id"), + duplicate_batch.parameters.label("duplicate_batch_parameters"), + original_batch.id.label("original_batch_id"), + original_batch.parameters.label("original_batch_parameters"), + ) + .select_from(Duplicate) + .join(URL, Duplicate.original_url_id == URL.id) + .join(duplicate_batch, Duplicate.batch_id == duplicate_batch.id) + .join(original_batch, URL.batch_id == original_batch.id) + .filter(duplicate_batch.id == batch_id) + .limit(100) + .offset((page - 1) * 100) + ) + raw_results = await session.execute(query) + results = raw_results.all() + final_results = [] + for result in results: + final_results.append( + DuplicateInfo( + source_url=result.source_url, + duplicate_batch_id=result.duplicate_batch_id, + duplicate_metadata=result.duplicate_batch_parameters, + original_batch_id=result.original_batch_id, + original_metadata=result.original_batch_parameters, + original_url_id=result.original_url_id + ) + ) + return final_results + + @session_manager + async def get_recent_batch_status_info( + self, + session, + page: int, + collector_type: Optional[CollectorType] = None, + status: Optional[BatchStatus] = None, + ) -> List[BatchInfo]: + # Get only the batch_id, collector_type, status, and created_at + limit = 100 + query = (Select(Batch) + .order_by(Batch.date_generated.desc())) + if collector_type: + query = query.filter(Batch.strategy == collector_type.value) + if status: + query = query.filter(Batch.status == status.value) + query = (query.limit(limit) + .offset((page - 1) * limit)) + raw_results = await session.execute(query) + batches = raw_results.scalars().all() + return [BatchInfo(**batch.__dict__) for batch in batches] + + @session_manager + async def get_logs_by_batch_id(self, session, batch_id: int) -> List[LogOutputInfo]: + query = 
Select(Log).filter_by(batch_id=batch_id).order_by(Log.created_at.asc())
+        raw_results = await session.execute(query)
+        logs = raw_results.scalars().all()
+        return ([LogOutputInfo(**log.__dict__) for log in logs])
+
diff --git a/collector_db/DTOs/URLInfo.py b/collector_db/DTOs/URLInfo.py
index afe6c2f2..c47d2830 100644
--- a/collector_db/DTOs/URLInfo.py
+++ b/collector_db/DTOs/URLInfo.py
@@ -13,3 +13,4 @@ class URLInfo(BaseModel):
     collector_metadata: Optional[dict] = None
     outcome: URLStatus = URLStatus.PENDING
     updated_at: Optional[datetime.datetime] = None
+    name: Optional[str] = None
diff --git a/collector_db/DatabaseClient.py b/collector_db/DatabaseClient.py
index 8d72ef0d..b8547f1d 100644
--- a/collector_db/DatabaseClient.py
+++ b/collector_db/DatabaseClient.py
@@ -16,13 +16,17 @@ from collector_db.helper_functions import get_postgres_connection_string
 from collector_db.models import Base, Batch, URL, Log, Duplicate
 from collector_manager.enums import CollectorType
+from core.EnvVarManager import EnvVarManager
 from core.enums import BatchStatus


 # Database Client
 class DatabaseClient:
-    def __init__(self, db_url: str = get_postgres_connection_string()):
+    def __init__(self, db_url: Optional[str] = None):
         """Initialize the DatabaseClient."""
+        if db_url is None:
+            db_url = EnvVarManager.get().get_postgres_connection_string()
+
         self.engine = create_engine(
             url=db_url,
             echo=ConfigManager.get_sqlalchemy_echo(),
@@ -49,10 +53,6 @@ def wrapper(self, *args, **kwargs):

         return wrapper

-    def row_to_dict(self, row: Row) -> dict:
-        return dict(row._mapping)
-
-
     @session_manager
     def insert_batch(self, session, batch_info: BatchInfo) -> int:
         """Insert a new batch into the database and return its ID."""
@@ -105,24 +105,14 @@ def insert_url(self, session, url_info: URLInfo) -> int:
             batch_id=url_info.batch_id,
             url=url_info.url,
             collector_metadata=url_info.collector_metadata,
-            outcome=url_info.outcome.value
+            outcome=url_info.outcome.value,
+            name=url_info.name
         )
         session.add(url_entry)
         session.commit()
         session.refresh(url_entry)
         return url_entry.id

-    @session_manager
-    def add_duplicate_info(self, session, duplicate_infos: list[DuplicateInfo]):
-        # TODO: Add test for this method when testing CollectorDatabaseProcessor
-        for duplicate_info in duplicate_infos:
-            duplicate = Duplicate(
-                batch_id=duplicate_info.original_batch_id,
-                original_url_id=duplicate_info.original_url_id,
-            )
-            session.add(duplicate)
-
-
     def insert_urls(self, url_infos: List[URLInfo], batch_id: int) -> InsertURLsInfo:
         url_mappings = []
         duplicates = []
@@ -163,83 +153,11 @@ def insert_logs(self, session, log_infos: List[LogInfo]):
             log.created_at = log_info.created_at
             session.add(log)

-    @session_manager
-    def get_logs_by_batch_id(self, session, batch_id: int) -> List[LogOutputInfo]:
-        logs = session.query(Log).filter_by(batch_id=batch_id).order_by(Log.created_at.asc()).all()
-        return ([LogOutputInfo(**log.__dict__) for log in logs])
-
-    @session_manager
-    def get_all_logs(self, session) -> List[LogInfo]:
-        logs = session.query(Log).all()
-        return ([LogInfo(**log.__dict__) for log in logs])
-
     @session_manager
     def get_batch_status(self, session, batch_id: int) -> BatchStatus:
         batch = session.query(Batch).filter_by(id=batch_id).first()
         return BatchStatus(batch.status)

-    @session_manager
-    def get_recent_batch_status_info(
-            self,
-            session,
-            page: int,
-            collector_type: Optional[CollectorType] = None,
-            status: Optional[BatchStatus] = None,
-    ) -> List[BatchInfo]:
-        # Get only the batch_id, collector_type, status, and
created_at - limit = 100 - query = (session.query(Batch) - .order_by(Batch.date_generated.desc())) - if collector_type: - query = query.filter(Batch.strategy == collector_type.value) - if status: - query = query.filter(Batch.status == status.value) - query = (query.limit(limit) - .offset((page - 1) * limit)) - batches = query.all() - return [BatchInfo(**batch.__dict__) for batch in batches] - - @session_manager - def get_duplicates_by_batch_id(self, session, batch_id: int, page: int) -> List[DuplicateInfo]: - original_batch = aliased(Batch) - duplicate_batch = aliased(Batch) - - query = ( - session.query( - URL.url.label("source_url"), - URL.id.label("original_url_id"), - duplicate_batch.id.label("duplicate_batch_id"), - duplicate_batch.parameters.label("duplicate_batch_parameters"), - original_batch.id.label("original_batch_id"), - original_batch.parameters.label("original_batch_parameters"), - ) - .select_from(Duplicate) - .join(URL, Duplicate.original_url_id == URL.id) - .join(duplicate_batch, Duplicate.batch_id == duplicate_batch.id) - .join(original_batch, URL.batch_id == original_batch.id) - .filter(duplicate_batch.id == batch_id) - .limit(100) - .offset((page - 1) * 100) - ) - results = query.all() - final_results = [] - for result in results: - final_results.append( - DuplicateInfo( - source_url=result.source_url, - duplicate_batch_id=result.duplicate_batch_id, - duplicate_metadata=result.duplicate_batch_parameters, - original_batch_id=result.original_batch_id, - original_metadata=result.original_batch_parameters, - original_url_id=result.original_url_id - ) - ) - return final_results - - @session_manager - def delete_all_logs(self, session): - session.query(Log).delete() - @session_manager def delete_old_logs(self, session): """ diff --git a/collector_db/enums.py b/collector_db/enums.py index c12cfde0..a701a847 100644 --- a/collector_db/enums.py +++ b/collector_db/enums.py @@ -38,6 +38,7 @@ class TaskType(PyEnum): RECORD_TYPE = "Record Type" AGENCY_IDENTIFICATION = "Agency Identification" MISC_METADATA = "Misc Metadata" + SUBMIT_APPROVED = "Submit Approved URLs" IDLE = "Idle" class PGEnum(TypeDecorator): diff --git a/collector_db/helper_functions.py b/collector_db/helper_functions.py index dcb161b9..4f99556a 100644 --- a/collector_db/helper_functions.py +++ b/collector_db/helper_functions.py @@ -2,15 +2,8 @@ import dotenv +from core.EnvVarManager import EnvVarManager + def get_postgres_connection_string(is_async = False): - dotenv.load_dotenv() - username = os.getenv("POSTGRES_USER") - password = os.getenv("POSTGRES_PASSWORD") - host = os.getenv("POSTGRES_HOST") - port = os.getenv("POSTGRES_PORT") - database = os.getenv("POSTGRES_DB") - driver = "postgresql" - if is_async: - driver += "+asyncpg" - return f"{driver}://{username}:{password}@{host}:{port}/{database}" \ No newline at end of file + return EnvVarManager.get().get_postgres_connection_string(is_async) diff --git a/collector_db/models.py b/collector_db/models.py index e420961f..e98ef437 100644 --- a/collector_db/models.py +++ b/collector_db/models.py @@ -105,6 +105,7 @@ class URL(Base): record_type = Column(postgresql.ENUM(*record_type_values, name='record_type'), nullable=True) created_at = get_created_at_column() updated_at = get_updated_at_column() + data_source_id = Column(Integer, nullable=True) # Relationships batch = relationship("Batch", back_populates="urls") @@ -128,8 +129,8 @@ class URL(Base): "AutoRelevantSuggestion", uselist=False, back_populates="url") user_relevant_suggestions = relationship( 
"UserRelevantSuggestion", back_populates="url") - reviewing_users = relationship( - "ReviewingUserURL", back_populates="url") + reviewing_user = relationship( + "ReviewingUserURL", uselist=False, back_populates="url") optional_data_source_metadata = relationship( "URLOptionalDataSourceMetadata", uselist=False, back_populates="url") confirmed_agencies = relationship( @@ -163,7 +164,7 @@ class ReviewingUserURL(Base): created_at = get_created_at_column() # Relationships - url = relationship("URL", back_populates="reviewing_users") + url = relationship("URL", uselist=False, back_populates="reviewing_user") class RootURL(Base): __tablename__ = 'root_url_cache' diff --git a/core/AsyncCore.py b/core/AsyncCore.py index 299a865e..cb9a80bc 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -10,6 +10,9 @@ from collector_manager.enums import CollectorType from core.DTOs.CollectorStartInfo import CollectorStartInfo from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo +from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse +from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse +from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ @@ -20,6 +23,7 @@ from core.DTOs.MessageResponse import MessageResponse from core.TaskManager import TaskManager from core.enums import BatchStatus, RecordType + from security_manager.SecurityManager import AccessInfo @@ -55,6 +59,27 @@ async def abort_batch(self, batch_id: int) -> MessageResponse: await self.collector_manager.abort_collector_async(cid=batch_id) return MessageResponse(message=f"Batch aborted.") + async def get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> GetDuplicatesByBatchResponse: + dup_infos = await self.adb_client.get_duplicates_by_batch_id(batch_id, page=page) + return GetDuplicatesByBatchResponse(duplicates=dup_infos) + + async def get_batch_statuses( + self, + collector_type: Optional[CollectorType], + status: Optional[BatchStatus], + page: int + ) -> GetBatchStatusResponse: + results = await self.adb_client.get_recent_batch_status_info( + collector_type=collector_type, + status=status, + page=page + ) + return GetBatchStatusResponse(results=results) + + async def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: + logs = await self.adb_client.get_logs_by_batch_id(batch_id) + return GetBatchLogsResponse(logs=logs) + #endregion # region Collector diff --git a/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py new file mode 100644 index 00000000..c5b002d0 --- /dev/null +++ b/core/DTOs/task_data_objects/SubmitApprovedURLTDO.py @@ -0,0 +1,25 @@ +from typing import Optional + +from pydantic import BaseModel + +from core.enums import RecordType + + +class SubmitApprovedURLTDO(BaseModel): + url_id: int + url: str + record_type: RecordType + agency_ids: list[int] + name: str + description: str + approving_user_id: int + record_formats: Optional[list[str]] = None + data_portal_type: Optional[str] = None + supplying_entity: Optional[str] = None + data_source_id: Optional[int] = None + request_error: Optional[str] = None + +class SubmittedURLInfo(BaseModel): + url_id: int + data_source_id: 
Optional[int] + request_error: Optional[str] \ No newline at end of file diff --git a/core/EnvVarManager.py b/core/EnvVarManager.py new file mode 100644 index 00000000..39e4ce83 --- /dev/null +++ b/core/EnvVarManager.py @@ -0,0 +1,76 @@ +import os + +class EnvVarManager: + _instance = None + _allow_direct_init = False # internal flag + + """ + A class for unified management of environment variables + """ + def __new__(cls, *args, **kwargs): + if not cls._allow_direct_init: + raise RuntimeError("Use `EnvVarManager.get()` or `EnvVarManager.override()` instead.") + return super().__new__(cls) + + def __init__(self, env: dict = os.environ): + self.env = env + self._load() + + def _load(self): + + self.google_api_key = self.require_env("GOOGLE_API_KEY") + self.google_cse_id = self.require_env("GOOGLE_CSE_ID") + + self.pdap_email = self.require_env("PDAP_EMAIL") + self.pdap_password = self.require_env("PDAP_PASSWORD") + self.pdap_api_key = self.require_env("PDAP_API_KEY") + self.pdap_api_url = self.require_env("PDAP_API_URL") + + self.discord_webhook_url = self.require_env("DISCORD_WEBHOOK_URL") + + self.openai_api_key = self.require_env("OPENAI_API_KEY") + + self.postgres_user = self.require_env("POSTGRES_USER") + self.postgres_password = self.require_env("POSTGRES_PASSWORD") + self.postgres_host = self.require_env("POSTGRES_HOST") + self.postgres_port = self.require_env("POSTGRES_PORT") + self.postgres_db = self.require_env("POSTGRES_DB") + + @classmethod + def get(cls): + """ + Get the singleton instance, loading from environment if not yet + instantiated + """ + if cls._instance is None: + cls._allow_direct_init = True + cls._instance = cls(os.environ) + cls._allow_direct_init = False + return cls._instance + + @classmethod + def override(cls, env: dict): + """ + Create singleton instance that + overrides the environment variables with injected values + """ + cls._allow_direct_init = True + cls._instance = cls(env) + cls._allow_direct_init = False + + @classmethod + def reset(cls): + cls._instance = None + + def get_postgres_connection_string(self, is_async = False): + driver = "postgresql" + if is_async: + driver += "+asyncpg" + return (f"{driver}://{self.postgres_user}:{self.postgres_password}" + f"@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}") + + def require_env(self, key: str, allow_none: bool = False): + val = self.env.get(key) + if val is None and not allow_none: + raise ValueError(f"Environment variable {key} is not set") + return val \ No newline at end of file diff --git a/core/README.md b/core/README.md index 25b1cde3..9546f613 100644 --- a/core/README.md +++ b/core/README.md @@ -11,4 +11,5 @@ The Source Collector Core is a directory which integrates: - **Cycle**: Refers to the overall lifecycle for Each URL -- from initial retrieval in a Batch to either disposal or incorporation into the Data Sources App Database - **Task**: A semi-independent operation performed on a set of URLs. These include: Collection, retrieving HTML data, getting metadata via Machine Learning, and so on. - **Task Set**: Refers to a group of URLs that are operated on together as part of a single task. These URLs in a set are not necessarily all from the same batch. URLs in a task set should only be operated on in that task once. -- **Task Operator**: A class which performs a single task on a set of URLs. \ No newline at end of file +- **Task Operator**: A class which performs a single task on a set of URLs. 
+- **Subtask**: A subcomponent of a Task Operator which performs a single operation on a single URL. Often distinguished by the Collector Strategy used for that URL. \ No newline at end of file diff --git a/core/SourceCollectorCore.py b/core/SourceCollectorCore.py index 8002717c..4516ceb5 100644 --- a/core/SourceCollectorCore.py +++ b/core/SourceCollectorCore.py @@ -2,10 +2,6 @@ from collector_db.DatabaseClient import DatabaseClient -from collector_manager.enums import CollectorType -from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse -from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse -from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse from core.ScheduledTaskManager import ScheduledTaskManager from core.enums import BatchStatus @@ -15,38 +11,21 @@ def __init__( self, core_logger: Optional[Any] = None, # Deprecated collector_manager: Optional[Any] = None, # Deprecated - db_client: DatabaseClient = DatabaseClient(), + db_client: Optional[DatabaseClient] = None, dev_mode: bool = False ): + if db_client is None: + db_client = DatabaseClient() self.db_client = db_client if not dev_mode: self.scheduled_task_manager = ScheduledTaskManager(db_client=db_client) else: self.scheduled_task_manager = None - def get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> GetDuplicatesByBatchResponse: - dup_infos = self.db_client.get_duplicates_by_batch_id(batch_id, page=page) - return GetDuplicatesByBatchResponse(duplicates=dup_infos) - - def get_batch_statuses( - self, - collector_type: Optional[CollectorType], - status: Optional[BatchStatus], - page: int - ) -> GetBatchStatusResponse: - results = self.db_client.get_recent_batch_status_info( - collector_type=collector_type, - status=status, - page=page - ) - return GetBatchStatusResponse(results=results) def get_status(self, batch_id: int) -> BatchStatus: return self.db_client.get_batch_status(batch_id) - def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: - logs = self.db_client.get_logs_by_batch_id(batch_id) - return GetBatchLogsResponse(logs=logs) def shutdown(self): if self.scheduled_task_manager is not None: diff --git a/core/TaskManager.py b/core/TaskManager.py index 624cb906..7796e80e 100644 --- a/core/TaskManager.py +++ b/core/TaskManager.py @@ -1,5 +1,4 @@ import logging -from typing import Optional from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface from collector_db.AsyncDatabaseClient import AsyncDatabaseClient @@ -9,6 +8,7 @@ from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome from core.FunctionTrigger import FunctionTrigger from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator +from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator from core.classes.TaskOperatorBase import TaskOperatorBase from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator from core.classes.URLMiscellaneousMetadataTaskOperator import URLMiscellaneousMetadataTaskOperator @@ -19,10 +19,8 @@ from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier -from pdap_api_client.AccessManager import AccessManager from pdap_api_client.PDAPClient import PDAPClient from util.DiscordNotifier import DiscordPoster -from util.helper_functions import get_from_env TASK_REPEAT_THRESHOLD = 20 @@ -35,12 +33,16 @@ def __init__( 
url_request_interface: URLRequestInterface, html_parser: HTMLResponseParser, discord_poster: DiscordPoster, + pdap_client: PDAPClient ): + # Dependencies self.adb_client = adb_client + self.pdap_client = pdap_client self.huggingface_interface = huggingface_interface self.url_request_interface = url_request_interface self.html_parser = html_parser self.discord_poster = discord_poster + self.logger = logging.getLogger(__name__) self.logger.addHandler(logging.StreamHandler()) self.logger.setLevel(logging.INFO) @@ -73,21 +75,21 @@ async def get_url_record_type_task_operator(self): return operator async def get_agency_identification_task_operator(self): - pdap_client = PDAPClient( - access_manager=AccessManager( - email=get_from_env("PDAP_EMAIL"), - password=get_from_env("PDAP_PASSWORD"), - api_key=get_from_env("PDAP_API_KEY"), - ), - ) muckrock_api_interface = MuckrockAPIInterface() operator = AgencyIdentificationTaskOperator( adb_client=self.adb_client, - pdap_client=pdap_client, + pdap_client=self.pdap_client, muckrock_api_interface=muckrock_api_interface ) return operator + async def get_submit_approved_url_task_operator(self): + operator = SubmitApprovedURLTaskOperator( + adb_client=self.adb_client, + pdap_client=self.pdap_client + ) + return operator + async def get_url_miscellaneous_metadata_task_operator(self): operator = URLMiscellaneousMetadataTaskOperator( adb_client=self.adb_client @@ -96,11 +98,12 @@ async def get_url_miscellaneous_metadata_task_operator(self): async def get_task_operators(self) -> list[TaskOperatorBase]: return [ - # await self.get_url_html_task_operator(), + await self.get_url_html_task_operator(), await self.get_url_relevance_huggingface_task_operator(), await self.get_url_record_type_task_operator(), await self.get_agency_identification_task_operator(), - await self.get_url_miscellaneous_metadata_task_operator() + await self.get_url_miscellaneous_metadata_task_operator(), + await self.get_submit_approved_url_task_operator() ] #endregion diff --git a/core/classes/SubmitApprovedURLTaskOperator.py b/core/classes/SubmitApprovedURLTaskOperator.py new file mode 100644 index 00000000..81f0b242 --- /dev/null +++ b/core/classes/SubmitApprovedURLTaskOperator.py @@ -0,0 +1,65 @@ +from collector_db.AsyncDatabaseClient import AsyncDatabaseClient +from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo +from collector_db.enums import TaskType +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO +from core.classes.TaskOperatorBase import TaskOperatorBase +from pdap_api_client.PDAPClient import PDAPClient + + +class SubmitApprovedURLTaskOperator(TaskOperatorBase): + + def __init__( + self, + adb_client: AsyncDatabaseClient, + pdap_client: PDAPClient + ): + super().__init__(adb_client) + self.pdap_client = pdap_client + + @property + def task_type(self): + return TaskType.SUBMIT_APPROVED + + async def meets_task_prerequisites(self): + return await self.adb_client.has_validated_urls() + + async def inner_task_logic(self): + # Retrieve all URLs that are validated and not submitted + tdos: list[SubmitApprovedURLTDO] = await self.adb_client.get_validated_urls() + + # Link URLs to this task + await self.link_urls_to_task(url_ids=[tdo.url_id for tdo in tdos]) + + # Submit each URL, recording errors if they exist + submitted_url_infos = await self.pdap_client.submit_urls(tdos) + + error_infos = await self.get_error_infos(submitted_url_infos) + success_infos = await self.get_success_infos(submitted_url_infos) + + # Update the database for 
successful submissions + await self.adb_client.mark_urls_as_submitted(infos=success_infos) + + # Update the database for failed submissions + await self.adb_client.add_url_error_infos(error_infos) + + async def get_success_infos(self, submitted_url_infos): + success_infos = [ + response_object for response_object in submitted_url_infos + if response_object.data_source_id is not None + ] + return success_infos + + async def get_error_infos(self, submitted_url_infos): + error_infos: list[URLErrorPydanticInfo] = [] + error_response_objects = [ + response_object for response_object in submitted_url_infos + if response_object.request_error is not None + ] + for error_response_object in error_response_objects: + error_info = URLErrorPydanticInfo( + task_id=self.task_id, + url_id=error_response_object.url_id, + error=error_response_object.request_error, + ) + error_infos.append(error_info) + return error_infos diff --git a/core/enums.py b/core/enums.py index 213db47c..cfccbb92 100644 --- a/core/enums.py +++ b/core/enums.py @@ -7,11 +7,10 @@ class BatchStatus(Enum): ERROR = "error" ABORTED = "aborted" -class LabelStudioTaskStatus(Enum): - PENDING = "pending" - COMPLETED = "completed" - class RecordType(Enum): + """ + All available URL record types + """ ACCIDENT_REPORTS = "Accident Reports" ARREST_RECORDS = "Arrest Records" CALLS_FOR_SERVICE = "Calls for Service" @@ -51,8 +50,19 @@ class RecordType(Enum): class SuggestionType(Enum): + """ + Identifies the specific kind of suggestion made for a URL + """ AUTO_SUGGESTION = "Auto Suggestion" MANUAL_SUGGESTION = "Manual Suggestion" UNKNOWN = "Unknown" NEW_AGENCY = "New Agency" CONFIRMED = "Confirmed" + +class SubmitResponseStatus(Enum): + """ + Response statuses from the /source-collector/data-sources endpoint + """ + SUCCESS = "success" + FAILURE = "FAILURE" + ALREADY_EXISTS = "already_exists" \ No newline at end of file diff --git a/html_tag_collector/RootURLCache.py b/html_tag_collector/RootURLCache.py index e306b6e1..165be89d 100644 --- a/html_tag_collector/RootURLCache.py +++ b/html_tag_collector/RootURLCache.py @@ -16,7 +16,9 @@ class RootURLCacheResponseInfo: exception: Optional[Exception] = None class RootURLCache: - def __init__(self, adb_client: AsyncDatabaseClient = AsyncDatabaseClient()): + def __init__(self, adb_client: Optional[AsyncDatabaseClient] = None): + if adb_client is None: + adb_client = AsyncDatabaseClient() self.adb_client = adb_client self.cache = None diff --git a/llm_api_logic/OpenAIRecordClassifier.py b/llm_api_logic/OpenAIRecordClassifier.py index fc20a0e2..cc0829b5 100644 --- a/llm_api_logic/OpenAIRecordClassifier.py +++ b/llm_api_logic/OpenAIRecordClassifier.py @@ -1,17 +1,16 @@ -from typing import Any from openai.types.chat import ParsedChatCompletion +from core.EnvVarManager import EnvVarManager from llm_api_logic.LLMRecordClassifierBase import RecordClassifierBase from llm_api_logic.RecordTypeStructuredOutput import RecordTypeStructuredOutput -from util.helper_functions import get_from_env class OpenAIRecordClassifier(RecordClassifierBase): @property def api_key(self): - return get_from_env("OPENAI_API_KEY") + return EnvVarManager.get().openai_api_key @property def model_name(self): diff --git a/pdap_api_client/AccessManager.py b/pdap_api_client/AccessManager.py index c39ba1e8..aadd8451 100644 --- a/pdap_api_client/AccessManager.py +++ b/pdap_api_client/AccessManager.py @@ -4,9 +4,9 @@ import requests from aiohttp import ClientSession +from core.EnvVarManager import EnvVarManager from pdap_api_client.DTOs import 
RequestType, Namespaces, RequestInfo, ResponseInfo -API_URL = "https://data-sources-v2.pdap.dev/api" request_methods = { RequestType.POST: ClientSession.post, RequestType.PUT: ClientSession.put, @@ -23,7 +23,8 @@ def build_url( namespace: Namespaces, subdomains: Optional[list[str]] = None ): - url = f"{API_URL}/{namespace.value}" + api_url = EnvVarManager.get().pdap_api_url + url = f"{api_url}/{namespace.value}" if subdomains is not None: url = f"{url}/{'/'.join(subdomains)}" return url diff --git a/pdap_api_client/DTOs.py b/pdap_api_client/DTOs.py index 19255a35..93f67839 100644 --- a/pdap_api_client/DTOs.py +++ b/pdap_api_client/DTOs.py @@ -36,6 +36,8 @@ class Namespaces(Enum): AUTH = "auth" MATCH = "match" CHECK = "check" + DATA_SOURCES = "data-sources" + SOURCE_COLLECTOR = "source-collector" class RequestType(Enum): diff --git a/pdap_api_client/PDAPClient.py b/pdap_api_client/PDAPClient.py index b2b89564..24b9d98c 100644 --- a/pdap_api_client/PDAPClient.py +++ b/pdap_api_client/PDAPClient.py @@ -1,5 +1,6 @@ from typing import Optional +from core.DTOs.task_data_objects.SubmitApprovedURLTDO import SubmitApprovedURLTDO, SubmittedURLInfo from pdap_api_client.AccessManager import build_url, AccessManager from pdap_api_client.DTOs import MatchAgencyInfo, UniqueURLDuplicateInfo, UniqueURLResponseInfo, Namespaces, \ RequestType, RequestInfo, MatchAgencyResponse @@ -21,7 +22,6 @@ async def match_agency( county: Optional[str] = None, locality: Optional[str] = None ) -> MatchAgencyResponse: - # TODO: Change to async """ Returns agencies, if any, that match or partially match the search criteria """ @@ -84,3 +84,60 @@ async def is_url_unique( is_unique=is_unique, duplicates=duplicates ) + + async def submit_urls( + self, + tdos: list[SubmitApprovedURLTDO] + ) -> list[SubmittedURLInfo]: + """ + Submits URLs to Data Sources App, + modifying tdos in-place with data source id or error + """ + request_url = build_url( + namespace=Namespaces.SOURCE_COLLECTOR, + subdomains=["data-sources"] + ) + + # Build url-id dictionary + url_id_dict = {} + for tdo in tdos: + url_id_dict[tdo.url] = tdo.url_id + + data_sources_json = [] + for tdo in tdos: + data_sources_json.append({ + "name": tdo.name, + "description": tdo.description, + "source_url": tdo.url, + "record_type": tdo.record_type.value, + "record_formats": tdo.record_formats, + "data_portal_type": tdo.data_portal_type, + "last_approval_editor": tdo.approving_user_id, + "supplying_entity": tdo.supplying_entity, + "agency_ids": tdo.agency_ids + }) + + + headers = await self.access_manager.jwt_header() + request_info = RequestInfo( + type_=RequestType.POST, + url=request_url, + headers=headers, + json={ + "data_sources": data_sources_json + } + ) + response_info = await self.access_manager.make_request(request_info) + data_sources_response_json = response_info.data["data_sources"] + + results = [] + for data_source in data_sources_response_json: + url = data_source["url"] + response_object = SubmittedURLInfo( + url_id=url_id_dict[url], + data_source_id=data_source["data_source_id"], + request_error=data_source["error"] + ) + results.append(response_object) + + return results diff --git a/start_mirrored_local_app.py b/start_mirrored_local_app.py index 48859adc..940c372e 100644 --- a/start_mirrored_local_app.py +++ b/start_mirrored_local_app.py @@ -19,7 +19,7 @@ import docker from docker.errors import APIError, NotFound from docker.models.containers import Container -from pydantic import BaseModel, model_validator, AfterValidator +from pydantic import 
BaseModel, AfterValidator from apply_migrations import apply_migrations from util.helper_functions import get_from_env @@ -193,7 +193,7 @@ def get_image(self, dockerfile_info: DockerfileInfo): def run_container( self, docker_info: DockerInfo, - ): + ) -> Container: print(f"Running container {docker_info.name}") try: container = self.client.containers.get(docker_info.name) @@ -255,24 +255,11 @@ def set_last_run_time(self): with open("local_state/last_run.txt", "w") as f: f.write(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")) - - -def main(): - docker_manager = DockerManager() - # Ensure docker is running, and start if not - if not is_docker_running(): - start_docker_engine() - - - # Ensure Dockerfile for database is running, and if not, start it - database_docker_info = DockerInfo( +def get_database_docker_info() -> DockerInfo: + return DockerInfo( dockerfile_info=DockerfileInfo( image_tag="postgres:15", ), - # volume_info=VolumeInfo( - # host_path="dbscripts", - # container_path="/var/lib/postgresql/data" - # ), name="data_source_identification_db", ports={ "5432/tcp": 5432 @@ -290,12 +277,9 @@ def main(): start_period=2 ) ) - container = docker_manager.run_container(database_docker_info) - wait_for_pg_to_be_ready(container) - - # Start dockerfile for Datadumper - data_dumper_docker_info = DockerInfo( +def get_data_dumper_docker_info() -> DockerInfo: + return DockerInfo( dockerfile_info=DockerfileInfo( image_tag="datadumper", dockerfile_directory="local_database/DataDumper" @@ -320,6 +304,21 @@ def main(): command="bash" ) +def main(): + docker_manager = DockerManager() + # Ensure docker is running, and start if not + if not is_docker_running(): + start_docker_engine() + + # Ensure Dockerfile for database is running, and if not, start it + database_docker_info = get_database_docker_info() + container = docker_manager.run_container(database_docker_info) + wait_for_pg_to_be_ready(container) + + + # Start dockerfile for Datadumper + data_dumper_docker_info = get_data_dumper_docker_info() + # If not last run within 24 hours, run dump operation in Datadumper # Check cache if exists and checker = TimestampChecker() @@ -343,11 +342,20 @@ def main(): apply_migrations() # Run `fastapi dev main.py` - uvicorn.run( - "api.main:app", - host="0.0.0.0", - port=8000 - ) + try: + uvicorn.run( + "api.main:app", + host="0.0.0.0", + port=8000 + ) + finally: + # Add feature to stop all running containers + print("Stopping containers...") + for container in docker_manager.client.containers.list(): + container.stop() + + print("Containers stopped.") + diff --git a/tests/conftest.py b/tests/conftest.py index 8aeb6dc6..d7b1bce7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ import pytest -from alembic import command from alembic.config import Config from sqlalchemy import create_engine, inspect, MetaData from sqlalchemy.orm import scoped_session, sessionmaker @@ -8,12 +7,42 @@ from collector_db.DatabaseClient import DatabaseClient from collector_db.helper_functions import get_postgres_connection_string from collector_db.models import Base +from core.EnvVarManager import EnvVarManager from tests.helpers.AlembicRunner import AlembicRunner from tests.helpers.DBDataCreator import DBDataCreator +from util.helper_functions import load_from_environment @pytest.fixture(autouse=True, scope="session") def setup_and_teardown(): + # Set up environment variables that must be defined + # outside of tests + required_env_vars: dict = load_from_environment( + keys=[ + "POSTGRES_USER", + 
"POSTGRES_PASSWORD", + "POSTGRES_HOST", + "POSTGRES_PORT", + "POSTGRES_DB", + ] + ) + # Add test environment variables + test_env_vars = [ + "GOOGLE_API_KEY", + "GOOGLE_CSE_ID", + "PDAP_EMAIL", + "PDAP_PASSWORD", + "PDAP_API_KEY", + "PDAP_API_URL", + "DISCORD_WEBHOOK_URL", + "OPENAI_API_KEY", + ] + all_env_vars = required_env_vars.copy() + for env_var in test_env_vars: + all_env_vars[env_var] = "TEST" + + EnvVarManager.override(all_env_vars) + conn = get_postgres_connection_string() engine = create_engine(conn) alembic_cfg = Config("alembic.ini") diff --git a/tests/helpers/DBDataCreator.py b/tests/helpers/DBDataCreator.py index 3cbdb11b..613bfe4d 100644 --- a/tests/helpers/DBDataCreator.py +++ b/tests/helpers/DBDataCreator.py @@ -23,13 +23,17 @@ class BatchURLCreationInfo(BaseModel): batch_id: int url_ids: list[int] + urls: list[str] class DBDataCreator: """ Assists in the creation of test data """ - def __init__(self, db_client: DatabaseClient = DatabaseClient()): - self.db_client = db_client + def __init__(self, db_client: Optional[DatabaseClient] = None): + if db_client is not None: + self.db_client = db_client + else: + self.db_client = DatabaseClient() self.adb_client: AsyncDatabaseClient = AsyncDatabaseClient() def batch(self, strategy: CollectorType = CollectorType.EXAMPLE) -> int: @@ -61,7 +65,11 @@ async def batch_and_urls( if with_html_content: await self.html_data(url_ids) - return BatchURLCreationInfo(batch_id=batch_id, url_ids=url_ids) + return BatchURLCreationInfo( + batch_id=batch_id, + url_ids=url_ids, + urls=[iui.url for iui in iuis.url_mappings] + ) async def agency(self) -> int: agency_id = randint(1, 99999999) @@ -186,6 +194,7 @@ def urls( URLInfo( url=url, outcome=outcome, + name="Test Name" if outcome == URLStatus.VALIDATED else None, collector_metadata=collector_metadata ) ) diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py index b466bfbb..ae34b28e 100644 --- a/tests/test_automated/integration/api/conftest.py +++ b/tests/test_automated/integration/api/conftest.py @@ -1,9 +1,6 @@ -import asyncio -import logging -import os from dataclasses import dataclass from typing import Generator -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, patch import pytest import pytest_asyncio @@ -11,7 +8,6 @@ from api.main import app from core.AsyncCore import AsyncCore -from core.AsyncCoreLogger import AsyncCoreLogger from core.SourceCollectorCore import SourceCollectorCore from security_manager.SecurityManager import get_access_info, AccessInfo, Permissions from tests.helpers.DBDataCreator import DBDataCreator @@ -48,9 +44,7 @@ def override_access_info() -> AccessInfo: @pytest.fixture(scope="session") def client() -> Generator[TestClient, None, None]: - # Mock envioronment - _original_env = dict(os.environ) - os.environ["DISCORD_WEBHOOK_URL"] = "https://discord.com" + # Mock environment with TestClient(app) as c: app.dependency_overrides[get_access_info] = override_access_info async_core: AsyncCore = c.app.state.async_core @@ -67,8 +61,6 @@ def client() -> Generator[TestClient, None, None]: yield c # Reset environment variables back to original state - os.environ.clear() - os.environ.update(_original_env) @pytest_asyncio.fixture diff --git a/tests/test_automated/integration/api/test_annotate.py b/tests/test_automated/integration/api/test_annotate.py index 0e462ba5..0501ac1f 100644 --- a/tests/test_automated/integration/api/test_annotate.py +++ 
b/tests/test_automated/integration/api/test_annotate.py
@@ -1,15 +1,12 @@
-from typing import Any

 import pytest

 from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo
 from collector_db.DTOs.URLMapping import URLMapping
-from collector_db.enums import URLMetadataAttributeType, ValidationStatus, ValidationSource
 from collector_db.models import UserUrlAgencySuggestion, UserRelevantSuggestion, UserRecordTypeSuggestion
 from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo
 from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo
 from core.DTOs.GetNextURLForAgencyAnnotationResponse import URLAgencyAnnotationPostInfo
-from core.DTOs.GetNextURLForAnnotationResponse import GetNextURLForAnnotationResponse
 from core.DTOs.RecordTypeAnnotationPostInfo import RecordTypeAnnotationPostInfo
 from core.DTOs.RelevanceAnnotationPostInfo import RelevanceAnnotationPostInfo
 from core.enums import RecordType, SuggestionType
diff --git a/tests/test_automated/integration/api/test_duplicates.py b/tests/test_automated/integration/api/test_duplicates.py
index c42b894d..a5c77b29 100644
--- a/tests/test_automated/integration/api/test_duplicates.py
+++ b/tests/test_automated/integration/api/test_duplicates.py
@@ -1,5 +1,4 @@
 import time
-from unittest.mock import AsyncMock

 from collector_db.DTOs.BatchInfo import BatchInfo
 from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO
@@ -30,7 +29,7 @@ def test_duplicates(api_test_helper):

     assert batch_id_2 is not None

-    time.sleep(2)
+    time.sleep(1.5)

     bi_1: BatchInfo = ath.request_validator.get_batch_info(batch_id_1)
     bi_2: BatchInfo = ath.request_validator.get_batch_info(batch_id_2)
diff --git a/tests/test_automated/integration/collector_db/test_database_structure.py b/tests/test_automated/integration/collector_db/test_database_structure.py
index 2b2fcbca..6d82631c 100644
--- a/tests/test_automated/integration/collector_db/test_database_structure.py
+++ b/tests/test_automated/integration/collector_db/test_database_structure.py
@@ -52,9 +52,11 @@ def __init__(
             self,
             columns: list[ColumnTester],
             table_name: str,
-            engine: sa.Engine = create_engine(get_postgres_connection_string()),
+            engine: Optional[sa.Engine] = None,
             constraints: Optional[list[ConstraintTester]] = None,
     ):
+        if engine is None:
+            engine = create_engine(get_postgres_connection_string())
         self.columns = columns
         self.table_name = table_name
         self.constraints = constraints
@@ -228,6 +230,11 @@ def test_url(db_data_creator: DBDataCreator):
                 column_name="outcome",
                 type_=postgresql.ENUM,
                 allowed_values=get_enum_values(URLStatus)
+            ),
+            ColumnTester(
+                column_name="name",
+                type_=sa.String,
+                allowed_values=['test'],
             )
         ],
         engine=db_data_creator.db_client.engine
diff --git a/tests/test_automated/integration/collector_db/test_db_client.py b/tests/test_automated/integration/collector_db/test_db_client.py
index c78bf57e..7b98728f 100644
--- a/tests/test_automated/integration/collector_db/test_db_client.py
+++ b/tests/test_automated/integration/collector_db/test_db_client.py
@@ -60,11 +60,12 @@ async def test_insert_urls(
     assert insert_urls_info.original_count == 2
     assert insert_urls_info.duplicate_count == 1

-
-def test_insert_logs(db_data_creator: DBDataCreator):
+@pytest.mark.asyncio
+async def test_insert_logs(db_data_creator: DBDataCreator):
     batch_id_1 = db_data_creator.batch()
     batch_id_2 = db_data_creator.batch()

+    adb_client = db_data_creator.adb_client
     db_client = db_data_creator.db_client
db_client.insert_logs( log_infos=[ @@ -74,26 +75,28 @@ def test_insert_logs(db_data_creator: DBDataCreator): ] ) - logs = db_client.get_logs_by_batch_id(batch_id_1) + logs = await adb_client.get_logs_by_batch_id(batch_id_1) assert len(logs) == 2 - logs = db_client.get_logs_by_batch_id(batch_id_2) + logs = await adb_client.get_logs_by_batch_id(batch_id_2) assert len(logs) == 1 -def test_delete_old_logs(db_data_creator: DBDataCreator): +@pytest.mark.asyncio +async def test_delete_old_logs(db_data_creator: DBDataCreator): batch_id = db_data_creator.batch() old_datetime = datetime.now() - timedelta(days=1) db_client = db_data_creator.db_client + adb_client = db_data_creator.adb_client log_infos = [] for i in range(3): log_infos.append(LogInfo(log="test log", batch_id=batch_id, created_at=old_datetime)) db_client.insert_logs(log_infos=log_infos) - logs = db_client.get_logs_by_batch_id(batch_id=batch_id) + logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id) assert len(logs) == 3 db_client.delete_old_logs() - logs = db_client.get_logs_by_batch_id(batch_id=batch_id) + logs = await adb_client.get_logs_by_batch_id(batch_id=batch_id) assert len(logs) == 0 def test_delete_url_updated_at(db_data_creator: DBDataCreator): diff --git a/tests/test_automated/integration/core/test_async_core.py b/tests/test_automated/integration/core/test_async_core.py index b4b8e740..ed314dfd 100644 --- a/tests/test_automated/integration/core/test_async_core.py +++ b/tests/test_automated/integration/core/test_async_core.py @@ -21,6 +21,7 @@ def setup_async_core(adb_client: AsyncDatabaseClient): url_request_interface=AsyncMock(), html_parser=AsyncMock(), discord_poster=AsyncMock(), + pdap_client=AsyncMock() ), collector_manager=AsyncMock() ) diff --git a/tests/test_automated/integration/tasks/test_submit_approved_url_task.py b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py new file mode 100644 index 00000000..b15ff9d5 --- /dev/null +++ b/tests/test_automated/integration/tasks/test_submit_approved_url_task.py @@ -0,0 +1,217 @@ +from http import HTTPStatus +from unittest.mock import MagicMock, AsyncMock + +import pytest + +from collector_db.enums import TaskType +from collector_db.models import URL, URLErrorInfo +from collector_manager.enums import URLStatus +from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo +from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome +from core.classes.SubmitApprovedURLTaskOperator import SubmitApprovedURLTaskOperator +from core.enums import RecordType, SubmitResponseStatus +from tests.helpers.DBDataCreator import BatchURLCreationInfo, DBDataCreator +from pdap_api_client.AccessManager import AccessManager +from pdap_api_client.DTOs import RequestInfo, RequestType, ResponseInfo +from pdap_api_client.PDAPClient import PDAPClient + + +def mock_make_request(pdap_client: PDAPClient, urls: list[str]): + assert len(urls) == 3, "Expected 3 urls" + pdap_client.access_manager.make_request = AsyncMock( + return_value=ResponseInfo( + status_code=HTTPStatus.OK, + data={ + "data_sources": [ + { + "url": urls[0], + "status": SubmitResponseStatus.SUCCESS, + "error": None, + "data_source_id": 21, + }, + { + "url": urls[1], + "status": SubmitResponseStatus.SUCCESS, + "error": None, + "data_source_id": 34, + }, + { + "url": urls[2], + "status": SubmitResponseStatus.FAILURE, + "error": "Test Error", + "data_source_id": None + } + ] + } + ) + ) + +@pytest.fixture +def mock_pdap_client() -> PDAPClient: + mock_access_manager = MagicMock( + 
+        spec=AccessManager
+    )
+    mock_access_manager.jwt_header = AsyncMock(
+        return_value={"Authorization": "Bearer token"}
+    )
+    pdap_client = PDAPClient(
+        access_manager=mock_access_manager
+    )
+    return pdap_client
+
+async def setup_validated_urls(db_data_creator: DBDataCreator) -> list[str]:
+    creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls(
+        url_count=3,
+        with_html_content=True
+    )
+
+    url_1 = creation_info.url_ids[0]
+    url_2 = creation_info.url_ids[1]
+    url_3 = creation_info.url_ids[2]
+    await db_data_creator.adb_client.approve_url(
+        approval_info=FinalReviewApprovalInfo(
+            url_id=url_1,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_ids=[1, 2],
+            name="URL 1 Name",
+            description="URL 1 Description",
+            record_formats=["Record Format 1", "Record Format 2"],
+            data_portal_type="Data Portal Type 1",
+            supplying_entity="Supplying Entity 1"
+        ),
+        user_id=1
+    )
+    await db_data_creator.adb_client.approve_url(
+        approval_info=FinalReviewApprovalInfo(
+            url_id=url_2,
+            record_type=RecordType.INCARCERATION_RECORDS,
+            agency_ids=[3, 4],
+            name="URL 2 Name",
+            description="URL 2 Description",
+        ),
+        user_id=2
+    )
+    await db_data_creator.adb_client.approve_url(
+        approval_info=FinalReviewApprovalInfo(
+            url_id=url_3,
+            record_type=RecordType.ACCIDENT_REPORTS,
+            agency_ids=[5, 6],
+            name="URL 3 Name",
+            description="URL 3 Description",
+        ),
+        user_id=3
+    )
+    return creation_info.urls
+
+@pytest.mark.asyncio
+async def test_submit_approved_url_task(
+    db_data_creator,
+    mock_pdap_client: PDAPClient,
+    monkeypatch
+):
+    """
+    The submit_approved_url_task should submit
+    all validated URLs to the PDAP Data Sources App
+    """
+
+
+    # Get Task Operator
+    operator = SubmitApprovedURLTaskOperator(
+        adb_client=db_data_creator.adb_client,
+        pdap_client=mock_pdap_client
+    )
+
+    # Check Task Operator does not yet meet pre-requisites
+    assert not await operator.meets_task_prerequisites()
+
+    # Create URLs with status 'validated' in database and all requisite URL values
+    # Ensure they have optional metadata as well
+    urls = await setup_validated_urls(db_data_creator)
+    mock_make_request(mock_pdap_client, urls)
+
+    # Check Task Operator does meet pre-requisites
+    assert await operator.meets_task_prerequisites()
+
+    # Run Task
+    task_id = await db_data_creator.adb_client.initiate_task(
+        task_type=TaskType.SUBMIT_APPROVED
+    )
+    run_info = await operator.run_task(task_id=task_id)
+
+    # Check Task has been marked as completed
+    assert run_info.outcome == TaskOperatorOutcome.SUCCESS, run_info.message
+
+    # Get URLs
+    urls = await db_data_creator.adb_client.get_all(URL, order_by_attribute="id")
+    url_1 = urls[0]
+    url_2 = urls[1]
+    url_3 = urls[2]
+
+    # Check URLs have been marked as 'submitted'
+    assert url_1.outcome == URLStatus.SUBMITTED.value
+    assert url_2.outcome == URLStatus.SUBMITTED.value
+    assert url_3.outcome == URLStatus.ERROR.value
+
+    # Check URLs now have data source ids
+    assert url_1.data_source_id == 21
+    assert url_2.data_source_id == 34
+    assert url_3.data_source_id is None
+
+    # Check that errored URL has entry in url_error_info
+    url_errors = await db_data_creator.adb_client.get_all(URLErrorInfo)
+    assert len(url_errors) == 1
+    url_error = url_errors[0]
+    assert url_error.url_id == url_3.id
+    assert url_error.error == "Test Error"
+
+    # Check mock method was called with the expected parameters
+    access_manager = mock_pdap_client.access_manager
+    access_manager.make_request.assert_called_once()
+
+    call_1 = access_manager.make_request.call_args_list[0][0][0]
+    expected_call_1 = RequestInfo(
+        type_=RequestType.POST,
+        url="TEST/source-collector/data-sources",
+        headers=access_manager.jwt_header.return_value,
+        json={
+            "data_sources": [
+                {
+                    "name": "URL 1 Name",
+                    "source_url": url_1.url,
+                    "record_type": "Accident Reports",
+                    "description": "URL 1 Description",
+                    "record_formats": ["Record Format 1", "Record Format 2"],
+                    "data_portal_type": "Data Portal Type 1",
+                    "last_approval_editor": 1,
+                    "supplying_entity": "Supplying Entity 1",
+                    "agency_ids": [1, 2]
+                },
+                {
+                    "name": "URL 2 Name",
+                    "source_url": url_2.url,
+                    "record_type": "Incarceration Records",
+                    "description": "URL 2 Description",
+                    "last_approval_editor": 2,
+                    "supplying_entity": None,
+                    "record_formats": None,
+                    "data_portal_type": None,
+                    "agency_ids": [3, 4]
+                },
+                {
+                    "name": "URL 3 Name",
+                    "source_url": url_3.url,
+                    "record_type": "Accident Reports",
+                    "description": "URL 3 Description",
+                    "last_approval_editor": 3,
+                    "supplying_entity": None,
+                    "record_formats": None,
+                    "data_portal_type": None,
+                    "agency_ids": [5, 6]
+                }
+            ]
+        }
+    )
+    assert call_1.type_ == expected_call_1.type_
+    assert call_1.url == expected_call_1.url
+    assert call_1.headers == expected_call_1.headers
+    assert call_1.json == expected_call_1.json
diff --git a/util/DiscordNotifier.py b/util/DiscordNotifier.py
index 15e74020..6df1aa90 100644
--- a/util/DiscordNotifier.py
+++ b/util/DiscordNotifier.py
@@ -10,4 +10,10 @@ def __init__(self, webhook_url: str):
             raise ValueError("WEBHOOK_URL environment variable not set")
         self.webhook_url = webhook_url
     def post_to_discord(self, message):
-        requests.post(self.webhook_url, json={"content": message})
+        try:
+            requests.post(self.webhook_url, json={"content": message})
+        except Exception as e:
+            logging.error(
+                f"Error posting message to Discord: {e}."
+                f"\n\nMessage: {message}"
+            )
diff --git a/util/helper_functions.py b/util/helper_functions.py
index bf72d39b..7d6c7f8d 100644
--- a/util/helper_functions.py
+++ b/util/helper_functions.py
@@ -16,6 +16,19 @@ def get_from_env(key: str, allow_none: bool = False):
         raise ValueError(f"Environment variable {key} is not set")
     return val
 
+def load_from_environment(keys: list[str]) -> dict[str, str]:
+    """
+    Load selected keys from the environment, returning a dictionary.
+    """
+    original_environment = os.environ.copy()
+    try:
+        load_dotenv()
+        return {key: os.getenv(key) for key in keys}
+    finally:
+        # Restore the original environment
+        os.environ.clear()
+        os.environ.update(original_environment)
+
 
 def base_model_list_dump(model_list: list[BaseModel]) -> list[dict]:
     return [model.model_dump() for model in model_list]