diff --git a/ENV.md b/ENV.md index 68359348..cdedd288 100644 --- a/ENV.md +++ b/ENV.md @@ -21,4 +21,5 @@ Please ensure these are properly defined in a `.env` file in the root directory. |`PDAP_EMAIL`| An email address for accessing the PDAP API. | `abc123@test.com` | |`PDAP_PASSWORD`| A password for accessing the PDAP API. | `abc123` | |`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` | +|`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications. | `abc123` | diff --git a/api/main.py b/api/main.py index 8feaa165..f39cc7f3 100644 --- a/api/main.py +++ b/api/main.py @@ -20,6 +20,7 @@ from html_tag_collector.RootURLCache import RootURLCache from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface +from util.DiscordNotifier import DiscordPoster from util.helper_functions import get_from_env @@ -40,6 +41,9 @@ async def lifespan(app: FastAPI): url_request_interface=URLRequestInterface(), html_parser=HTMLResponseParser( root_url_cache=RootURLCache() + ), + discord_poster=DiscordPoster( + webhook_url=get_from_env("DISCORD_WEBHOOK_URL") ) ) async_scheduled_task_manager = AsyncScheduledTaskManager(async_core=async_core) diff --git a/core/AsyncCore.py b/core/AsyncCore.py index 28a14fa2..d95efbfe 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -1,22 +1,18 @@ import logging from typing import Optional -from aiohttp import ClientSession from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.TaskInfo import TaskInfo -from collector_db.DTOs.URLAnnotationInfo import URLAnnotationInfo -from collector_db.enums import TaskType, URLMetadataAttributeType +from collector_db.enums import TaskType from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import
GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ URLAgencyAnnotationPostInfo -from core.DTOs.GetNextURLForAnnotationResponse import GetNextURLForAnnotationResponse from core.DTOs.GetTasksResponse import GetTasksResponse from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo -from core.DTOs.AnnotationRequestInfo import AnnotationRequestInfo from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator from core.classes.TaskOperatorBase import TaskOperatorBase @@ -24,8 +20,7 @@ from core.classes.URLMiscellaneousMetadataTaskOperator import URLMiscellaneousMetadataTaskOperator from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator from core.classes.URLRelevanceHuggingfaceTaskOperator import URLRelevanceHuggingfaceTaskOperator -from core.enums import BatchStatus, SuggestionType, RecordType -from html_tag_collector.DataClassTags import convert_to_response_html_info +from core.enums import BatchStatus, RecordType from html_tag_collector.ResponseParser import HTMLResponseParser from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface @@ -33,8 +28,10 @@ from pdap_api_client.AccessManager import AccessManager from pdap_api_client.PDAPClient import PDAPClient from security_manager.SecurityManager import AccessInfo +from util.DiscordNotifier import DiscordPoster from util.helper_functions import get_from_env +TASK_REPEAT_THRESHOLD = 20 class AsyncCore: @@ -44,6 +41,7 @@ def __init__( huggingface_interface: HuggingFaceInterface, url_request_interface: URLRequestInterface, html_parser: HTMLResponseParser, + discord_poster: DiscordPoster ): self.adb_client = 
adb_client self.huggingface_interface = huggingface_interface @@ -52,6 +50,7 @@ def __init__( self.logger = logging.getLogger(__name__) self.logger.addHandler(logging.StreamHandler()) self.logger.setLevel(logging.INFO) + self.discord_poster = discord_poster async def get_urls(self, page: int, errors: bool) -> GetURLsResponseInfo: @@ -119,14 +118,25 @@ async def get_task_operators(self) -> list[TaskOperatorBase]: #region Tasks async def run_tasks(self): + """Run each task operator repeatedly for as long as its prerequisites are met.""" operators = await self.get_task_operators() for operator in operators: + # Per-operator counter: the repeat threshold applies to each task separately. + count = 0 + meets_prereq = await operator.meets_task_prerequisites() - if not meets_prereq: - self.logger.info(f"Skipping {operator.task_type.value} Task") - continue - task_id = await self.initiate_task_in_db(task_type=operator.task_type) - run_info: TaskOperatorRunInfo = await operator.run_task(task_id) - await self.conclude_task(run_info) + while meets_prereq: + if count > TASK_REPEAT_THRESHOLD: + self.discord_poster.post_to_discord( + message=f"Task {operator.task_type.value} has been run" + f" more than {TASK_REPEAT_THRESHOLD} times in a row. " + f"Task loop terminated.") + break + task_id = await self.initiate_task_in_db(task_type=operator.task_type) + run_info: TaskOperatorRunInfo = await operator.run_task(task_id) + await self.conclude_task(run_info) + count += 1 + meets_prereq = await operator.meets_task_prerequisites() + async def conclude_task(self, run_info): await self.adb_client.link_urls_to_task(task_id=run_info.task_id, url_ids=run_info.linked_url_ids)
diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py index d9a504a7..2065463e 100644 --- a/tests/test_automated/integration/api/conftest.py +++ b/tests/test_automated/integration/api/conftest.py @@ -30,7 +30,8 @@ def override_access_info() -> AccessInfo: return AccessInfo(user_id=MOCK_USER_ID, permissions=[Permissions.SOURCE_COLLECTOR]) @pytest.fixture -def client(db_client_test) -> Generator[TestClient, None, None]: +def client(db_client_test, monkeypatch) -> Generator[TestClient, None, None]: + monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com") with TestClient(app) as c: app.dependency_overrides[get_access_info] = override_access_info core: SourceCollectorCore = c.app.state.core diff --git a/tests/test_automated/integration/core/test_async_core.py b/tests/test_automated/integration/core/test_async_core.py index 1bb09809..4aa51b77 100644 --- a/tests/test_automated/integration/core/test_async_core.py +++ b/tests/test_automated/integration/core/test_async_core.py @@ -1,5 +1,5 @@ import types -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, call import pytest @@ -27,7 +27,8 @@ async def test_conclude_task_success(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -53,7 +54,8 @@ async def
test_conclude_task_success(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -80,7 +82,8 @@ async def test_conclude_task_error(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -96,7 +99,8 @@ async def test_run_task_prereq_not_met(): adb_client=AsyncMock(), huggingface_interface=AsyncMock(), url_request_interface=AsyncMock(), - html_parser=AsyncMock() + html_parser=AsyncMock(), + discord_poster=MagicMock() ) mock_operator = AsyncMock() @@ -121,19 +125,22 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo: adb_client=db_data_creator.adb_client, huggingface_interface=AsyncMock(), url_request_interface=AsyncMock(), - html_parser=AsyncMock() + html_parser=AsyncMock(), + discord_poster=MagicMock() ) core.conclude_task = AsyncMock() mock_operator = AsyncMock() - mock_operator.meets_task_prerequisites = AsyncMock(return_value=True) + mock_operator.meets_task_prerequisites = AsyncMock( + side_effect=[True, False] + ) mock_operator.task_type = TaskType.HTML mock_operator.run_task = types.MethodType(run_task, mock_operator) AsyncCore.get_task_operators = AsyncMock(return_value=[mock_operator]) await core.run_tasks() - mock_operator.meets_task_prerequisites.assert_called_once() + mock_operator.meets_task_prerequisites.assert_has_calls([call(), call()]) results = await db_data_creator.adb_client.get_all(Task) @@ -142,3 +149,39 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo: core.conclude_task.assert_called_once() +@pytest.mark.asyncio +async def test_run_task_break_loop(db_data_creator: DBDataCreator): + """ + If the task loop for 
a single task runs more than 20 times in a row, + this is considered suspicious and possibly indicative of a bug. + In this case, the task loop should be terminated + and an alert should be sent to discord + """ + + async def run_task(self, task_id: int) -> TaskOperatorRunInfo: + return TaskOperatorRunInfo( + task_id=task_id, + outcome=TaskOperatorOutcome.SUCCESS, + linked_url_ids=[1, 2, 3] + ) + + core = AsyncCore( + adb_client=db_data_creator.adb_client, + huggingface_interface=AsyncMock(), + url_request_interface=AsyncMock(), + html_parser=AsyncMock(), + discord_poster=MagicMock() + ) + core.conclude_task = AsyncMock() + + mock_operator = AsyncMock() + mock_operator.meets_task_prerequisites = AsyncMock(return_value=True) + mock_operator.task_type = TaskType.HTML + mock_operator.run_task = types.MethodType(run_task, mock_operator) + + AsyncCore.get_task_operators = AsyncMock(return_value=[mock_operator]) + await core.run_tasks() + + core.discord_poster.post_to_discord.assert_called_once_with( + message="Task HTML has been run more than 20 times in a row. Task loop terminated." 
+ ) diff --git a/tests/test_automated/integration/security_manager/test_security_manager.py b/tests/test_automated/integration/security_manager/test_security_manager.py index 3dc676ad..eb7e8506 100644 --- a/tests/test_automated/integration/security_manager/test_security_manager.py +++ b/tests/test_automated/integration/security_manager/test_security_manager.py @@ -18,12 +18,16 @@ def mock_get_secret_key(mocker): VALID_TOKEN = "valid_token" INVALID_TOKEN = "invalid_token" FAKE_PAYLOAD = { - "sub": 1, + "sub": "1", "permissions": [Permissions.SOURCE_COLLECTOR.value] } -def test_api_with_valid_token(mock_get_secret_key): +def test_api_with_valid_token( + mock_get_secret_key, + monkeypatch +): + monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com") token = jwt.encode(FAKE_PAYLOAD, SECRET_KEY, algorithm=ALGORITHM) # Create Test Client diff --git a/util/DiscordNotifier.py b/util/DiscordNotifier.py new file mode 100644 index 00000000..15e74020 --- /dev/null +++ b/util/DiscordNotifier.py @@ -0,0 +1,24 @@ +import logging + +import requests + + +class DiscordPoster: + """Posts notification messages to a Discord webhook.""" + + def __init__(self, webhook_url: str): + """Store the webhook URL; raise ValueError when it is empty.""" + if not webhook_url: + logging.error("DISCORD_WEBHOOK_URL environment variable not set") + raise ValueError("DISCORD_WEBHOOK_URL environment variable not set") + self.webhook_url = webhook_url + + def post_to_discord(self, message, timeout: float = 10.0): + """POST message as the webhook content; request failures are logged instead of raised.""" + try: + response = requests.post( + self.webhook_url, json={"content": message}, timeout=timeout + ) + response.raise_for_status() + except requests.RequestException: + logging.exception("Failed to post message to Discord")