diff --git a/alembic/versions/2025_05_11_1054-9d4002437ebe_set_default_created_at_for_backlog_.py b/alembic/versions/2025_05_11_1054-9d4002437ebe_set_default_created_at_for_backlog_.py
index f45fee4b..fbdb5645 100644
--- a/alembic/versions/2025_05_11_1054-9d4002437ebe_set_default_created_at_for_backlog_.py
+++ b/alembic/versions/2025_05_11_1054-9d4002437ebe_set_default_created_at_for_backlog_.py
@@ -30,7 +30,7 @@ def upgrade() -> None:
 
 def downgrade() -> None:
     op.alter_column(
-        table_name='backlog_snapshots',
+        table_name='backlog_snapshot',
         column_name='created_at',
         existing_type=sa.DateTime(),
         nullable=False,
diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py
index de0bd36a..5d28f70f 100644
--- a/collector_db/AsyncDatabaseClient.py
+++ b/collector_db/AsyncDatabaseClient.py
@@ -776,6 +776,8 @@ async def has_urls_without_agency_suggestions(
         statement = (
            select(
                 URL.id
+            ).where(
+                URL.outcome == URLStatus.PENDING.value
             )
         )
 
@@ -797,6 +799,7 @@ async def get_urls_without_agency_suggestions(self, session: AsyncSession) -> li
         statement = (
             select(URL.id, URL.collector_metadata, Batch.strategy)
+            .where(URL.outcome == URLStatus.PENDING.value)
             .join(Batch)
         )
         statement = self.statement_composer.exclude_urls_with_agency_suggestions(statement)
diff --git a/core/TaskManager.py b/core/TaskManager.py
index 4761a62b..052bdbc8 100644
--- a/core/TaskManager.py
+++ b/core/TaskManager.py
@@ -101,7 +101,7 @@ async def get_task_operators(self) -> list[TaskOperatorBase]:
             await self.get_url_html_task_operator(),
             # await self.get_url_relevance_huggingface_task_operator(),
             await self.get_url_record_type_task_operator(),
-            # await self.get_agency_identification_task_operator(),
+            await self.get_agency_identification_task_operator(),
             await self.get_url_miscellaneous_metadata_task_operator(),
             await self.get_submit_approved_url_task_operator()
         ]
@@ -122,10 +122,9 @@ async def run_tasks(self):
         while meets_prereq:
             print(f"Running {operator.task_type.value} Task")
             if count > TASK_REPEAT_THRESHOLD:
-                self.discord_poster.post_to_discord(
-                    message=f"Task {operator.task_type.value} has been run"
-                            f" more than {TASK_REPEAT_THRESHOLD} times in a row. "
-                            f"Task loop terminated.")
+                message = f"Task {operator.task_type.value} has been run more than {TASK_REPEAT_THRESHOLD} times in a row. Task loop terminated."
+                print(message)
+                self.discord_poster.post_to_discord(message=message)
                 break
             task_id = await self.initiate_task_in_db(task_type=operator.task_type)
             run_info: TaskOperatorRunInfo = await operator.run_task(task_id)
diff --git a/hugging_face/HuggingFaceInterface.py b/hugging_face/HuggingFaceInterface.py
index 9ad11d0b..3dff8ccd 100644
--- a/hugging_face/HuggingFaceInterface.py
+++ b/hugging_face/HuggingFaceInterface.py
@@ -1,5 +1,6 @@
 import asyncio
 import json
+import os
 import sys
 from typing import List
@@ -17,17 +18,22 @@ async def get_url_relevancy_async(urls_with_html: List[URLWithHTML]) -> List[boo
         stdin=asyncio.subprocess.PIPE,
         stdout=asyncio.subprocess.PIPE,
         stderr=asyncio.subprocess.PIPE,
+        env=os.environ.copy(),  # ⬅️ ensure env variables are inherited
     )
 
     stdout, stderr = await proc.communicate(input=input_data.encode("utf-8"))
+    print(stderr)
 
     raw_output = stdout.decode("utf-8").strip()
 
+    if proc.returncode != 0:
+        raise RuntimeError(f"Error running HuggingFace: {stderr}/{raw_output}")
+
     # Try to extract the actual JSON line
     for line in raw_output.splitlines():
         try:
             return json.loads(line)
-        except json.JSONDecodeError:
+        except json.JSONDecodeError as e:
             continue
 
     raise RuntimeError(f"Could not parse JSON from subprocess: {raw_output}")
diff --git a/hugging_face/relevancy_worker.py b/hugging_face/relevancy_worker.py
index 5d07d10f..dd158898 100644
--- a/hugging_face/relevancy_worker.py
+++ b/hugging_face/relevancy_worker.py
@@ -1,3 +1,4 @@
+import os
 import sys
 import json
 from transformers import pipeline
@@ -7,6 +8,13 @@ def main():
     pipe = pipeline("text-classification", model="PDAP/url-relevance")
     results = pipe(urls)
+
+    print("Executable:", sys.executable, file=sys.stderr)
+    print("sys.path:", sys.path, file=sys.stderr)
+    print("PYTHONPATH:", os.getenv("PYTHONPATH"), file=sys.stderr)
+
+    if len(results) != len(urls):
+        raise RuntimeError(f"Expected {len(urls)} results, got {len(results)}")
     bools = [r["score"] >= 0.5 for r in results]
     print(json.dumps(bools))
diff --git a/local_database/classes/DockerContainer.py b/local_database/classes/DockerContainer.py
index ee2ecba9..33b71ce0 100644
--- a/local_database/classes/DockerContainer.py
+++ b/local_database/classes/DockerContainer.py
@@ -17,6 +17,12 @@ def run_command(self, command: str):
     def stop(self):
         self.container.stop()
 
+    def log_to_file(self):
+        logs = self.container.logs(stdout=True, stderr=True)
+        container_name = self.container.name
+        with open(f"{container_name}.log", "wb") as f:
+            f.write(logs)
+
     def wait_for_pg_to_be_ready(self):
         for i in range(30):
             exit_code, output = self.container.exec_run("pg_isready")
diff --git a/tests/helpers/DBDataCreator.py b/tests/helpers/DBDataCreator.py
index 71338d84..38d70cfe 100644
--- a/tests/helpers/DBDataCreator.py
+++ b/tests/helpers/DBDataCreator.py
@@ -28,7 +28,7 @@ class URLCreationInfo(BaseModel):
     url_mappings: list[URLMapping]
     outcome: URLStatus
-    annotation_info: AnnotationInfo
+    annotation_info: Optional[AnnotationInfo] = None
 
 
 class BatchURLCreationInfoV2(BaseModel):
     batch_id: int
@@ -109,7 +109,7 @@ async def batch_v2(
             d[url_parameters.status] = URLCreationInfo(
                 url_mappings=iui.url_mappings,
                 outcome=url_parameters.status,
-                annotation_info=url_parameters.annotation_info
+                annotation_info=url_parameters.annotation_info if url_parameters.annotation_info.has_annotations() else None
             )
         return BatchURLCreationInfoV2(
             batch_id=batch_id,
diff --git a/tests/test_automated/integration/tasks/test_agency_preannotation_task.py b/tests/test_automated/integration/tasks/test_agency_preannotation_task.py
index cd9556cb..6818c683 100644
--- a/tests/test_automated/integration/tasks/test_agency_preannotation_task.py
+++ b/tests/test_automated/integration/tasks/test_agency_preannotation_task.py
@@ -5,9 +5,10 @@
 import pytest
 from aiohttp import ClientSession
 
+from tests.helpers.test_batch_creation_parameters import TestBatchCreationParameters, TestURLCreationParameters
 from source_collectors.muckrock.MuckrockAPIInterface import MuckrockAPIInterface, AgencyLookupResponseType, AgencyLookupResponse
 from collector_db.models import Agency, AutomatedUrlAgencySuggestion
-from collector_manager.enums import CollectorType
+from collector_manager.enums import CollectorType, URLStatus
 from core.DTOs.TaskOperatorRunInfo import TaskOperatorOutcome
 from core.DTOs.URLAgencySuggestionInfo import URLAgencySuggestionInfo
 from core.classes.task_operators.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator
@@ -20,7 +21,7 @@
 from pdap_api_client.DTOs import MatchAgencyResponse, MatchAgencyInfo
 from pdap_api_client.PDAPClient import PDAPClient
 from pdap_api_client.enums import MatchAgencyResponseStatus
-from tests.helpers.DBDataCreator import DBDataCreator, BatchURLCreationInfo
+from tests.helpers.DBDataCreator import DBDataCreator, BatchURLCreationInfo, BatchURLCreationInfoV2
 
 sample_agency_suggestions = [
     URLAgencySuggestionInfo(
@@ -103,8 +104,25 @@ async def mock_run_subtask(
         CollectorType.MUCKROCK_ALL_SEARCH,
         CollectorType.CKAN
     ]:
-        creation_info: BatchURLCreationInfo = await db_data_creator.batch_and_urls(strategy=strategy, url_count=1, with_html_content=True)
-        d[strategy] = creation_info.url_ids[0]
+        # Create two URLs for each, one pending and one errored
+        creation_info: BatchURLCreationInfoV2 = await db_data_creator.batch_v2(
+            parameters=TestBatchCreationParameters(
+                strategy=strategy,
+                urls=[
+                    TestURLCreationParameters(
+                        count=1,
+                        status=URLStatus.PENDING,
+                        with_html_content=True
+                    ),
+                    TestURLCreationParameters(
+                        count=1,
+                        status=URLStatus.ERROR,
+                        with_html_content=True
+                    )
+                ]
+            )
+        )
+        d[strategy] = creation_info.url_creation_infos[URLStatus.PENDING].url_mappings[0].url_id
 
     # Confirm meets prerequisites
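
For reference, a minimal usage sketch of the new DockerContainer.log_to_file helper; the constructor signature and the container name "pdap_test_db" are hypothetical (only the docker SDK calls docker.from_env, containers.get, and container.logs are known APIs):

    import docker

    from local_database.classes.DockerContainer import DockerContainer

    client = docker.from_env()
    # Hypothetical: assumes DockerContainer wraps a docker SDK container object.
    db_container = DockerContainer(client.containers.get("pdap_test_db"))

    db_container.wait_for_pg_to_be_ready()
    # container.logs() returns bytes, hence the binary-mode write in log_to_file;
    # the output lands in the current working directory as "pdap_test_db.log".
    db_container.log_to_file()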