From d68ab306c7fcd4ddf648e60a5626280fc2df1f7c Mon Sep 17 00:00:00 2001 From: Max Chis Date: Thu, 10 Apr 2025 14:46:31 -0400 Subject: [PATCH] feat(app): enable task loop to repeat if prerequisites met The task loop has been modified so that, as long as a task's prerequisites continue to be met, the same task is run again. If a task is run more than 20 times, the task loop is terminated and a notification is posted to Discord as an indicator of potentially unwelcome activity. --- ENV.md | 1 + api/main.py | 4 ++ core/AsyncCore.py | 34 +++++++---- .../integration/api/conftest.py | 3 +- .../integration/core/test_async_core.py | 59 ++++++++++++++++--- .../security_manager/test_security_manager.py | 8 ++- util/DiscordNotifier.py | 13 ++++ 7 files changed, 98 insertions(+), 24 deletions(-) create mode 100644 util/DiscordNotifier.py diff --git a/ENV.md b/ENV.md index 68359348..cdedd288 100644 --- a/ENV.md +++ b/ENV.md @@ -21,4 +21,5 @@ Please ensure these are properly defined in a `.env` file in the root directory. |`PDAP_EMAIL`| An email address for accessing the PDAP API. | `abc123@test.com` | |`PDAP_PASSWORD`| A password for accessing the PDAP API. | `abc123` | |`PDAP_API_KEY`| An API key for accessing the PDAP API. 
| `abc123` | +|`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications| `abc123` | diff --git a/api/main.py b/api/main.py index 8feaa165..f39cc7f3 100644 --- a/api/main.py +++ b/api/main.py @@ -20,6 +20,7 @@ from html_tag_collector.RootURLCache import RootURLCache from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface +from util.DiscordNotifier import DiscordPoster from util.helper_functions import get_from_env @@ -40,6 +41,9 @@ async def lifespan(app: FastAPI): url_request_interface=URLRequestInterface(), html_parser=HTMLResponseParser( root_url_cache=RootURLCache() + ), + discord_poster=DiscordPoster( + webhook_url=get_from_env("DISCORD_WEBHOOK_URL") ) ) async_scheduled_task_manager = AsyncScheduledTaskManager(async_core=async_core) diff --git a/core/AsyncCore.py b/core/AsyncCore.py index 28a14fa2..d95efbfe 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -1,22 +1,18 @@ import logging from typing import Optional -from aiohttp import ClientSession from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.TaskInfo import TaskInfo -from collector_db.DTOs.URLAnnotationInfo import URLAnnotationInfo -from collector_db.enums import TaskType, URLMetadataAttributeType +from collector_db.enums import TaskType from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \ URLAgencyAnnotationPostInfo -from core.DTOs.GetNextURLForAnnotationResponse import GetNextURLForAnnotationResponse from core.DTOs.GetTasksResponse import GetTasksResponse from 
core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo -from core.DTOs.AnnotationRequestInfo import AnnotationRequestInfo from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator from core.classes.TaskOperatorBase import TaskOperatorBase @@ -24,8 +20,7 @@ from core.classes.URLMiscellaneousMetadataTaskOperator import URLMiscellaneousMetadataTaskOperator from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator from core.classes.URLRelevanceHuggingfaceTaskOperator import URLRelevanceHuggingfaceTaskOperator -from core.enums import BatchStatus, SuggestionType, RecordType -from html_tag_collector.DataClassTags import convert_to_response_html_info +from core.enums import BatchStatus, RecordType from html_tag_collector.ResponseParser import HTMLResponseParser from html_tag_collector.URLRequestInterface import URLRequestInterface from hugging_face.HuggingFaceInterface import HuggingFaceInterface @@ -33,8 +28,10 @@ from pdap_api_client.AccessManager import AccessManager from pdap_api_client.PDAPClient import PDAPClient from security_manager.SecurityManager import AccessInfo +from util.DiscordNotifier import DiscordPoster from util.helper_functions import get_from_env +TASK_REPEAT_THRESHOLD = 20 class AsyncCore: @@ -44,6 +41,7 @@ def __init__( huggingface_interface: HuggingFaceInterface, url_request_interface: URLRequestInterface, html_parser: HTMLResponseParser, + discord_poster: DiscordPoster ): self.adb_client = adb_client self.huggingface_interface = huggingface_interface @@ -52,6 +50,7 @@ def __init__( self.logger = logging.getLogger(__name__) self.logger.addHandler(logging.StreamHandler()) self.logger.setLevel(logging.INFO) + self.discord_poster = discord_poster async def get_urls(self, page: int, errors: bool) -> GetURLsResponseInfo: @@ -119,14 +118,23 @@ async def get_task_operators(self) -> list[TaskOperatorBase]: #region Tasks 
async def run_tasks(self): operators = await self.get_task_operators() + count = 0 for operator in operators: + meets_prereq = await operator.meets_task_prerequisites() - if not meets_prereq: - self.logger.info(f"Skipping {operator.task_type.value} Task") - continue - task_id = await self.initiate_task_in_db(task_type=operator.task_type) - run_info: TaskOperatorRunInfo = await operator.run_task(task_id) - await self.conclude_task(run_info) + while meets_prereq: + if count > TASK_REPEAT_THRESHOLD: + self.discord_poster.post_to_discord( + message=f"Task {operator.task_type.value} has been run" + f" more than {TASK_REPEAT_THRESHOLD} times in a row. " + f"Task loop terminated.") + break + task_id = await self.initiate_task_in_db(task_type=operator.task_type) + run_info: TaskOperatorRunInfo = await operator.run_task(task_id) + await self.conclude_task(run_info) + count += 1 + meets_prereq = await operator.meets_task_prerequisites() + async def conclude_task(self, run_info): await self.adb_client.link_urls_to_task(task_id=run_info.task_id, url_ids=run_info.linked_url_ids) diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py index d9a504a7..2065463e 100644 --- a/tests/test_automated/integration/api/conftest.py +++ b/tests/test_automated/integration/api/conftest.py @@ -30,7 +30,8 @@ def override_access_info() -> AccessInfo: return AccessInfo(user_id=MOCK_USER_ID, permissions=[Permissions.SOURCE_COLLECTOR]) @pytest.fixture -def client(db_client_test) -> Generator[TestClient, None, None]: +def client(db_client_test, monkeypatch) -> Generator[TestClient, None, None]: + monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com") with TestClient(app) as c: app.dependency_overrides[get_access_info] = override_access_info core: SourceCollectorCore = c.app.state.core diff --git a/tests/test_automated/integration/core/test_async_core.py b/tests/test_automated/integration/core/test_async_core.py index 
1bb09809..4aa51b77 100644 --- a/tests/test_automated/integration/core/test_async_core.py +++ b/tests/test_automated/integration/core/test_async_core.py @@ -1,5 +1,5 @@ import types -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, call import pytest @@ -27,7 +27,8 @@ async def test_conclude_task_success(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -53,7 +54,8 @@ async def test_conclude_task_success(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -80,7 +82,8 @@ async def test_conclude_task_error(db_data_creator: DBDataCreator): adb_client=ddc.adb_client, huggingface_interface=MagicMock(), url_request_interface=MagicMock(), - html_parser=MagicMock() + html_parser=MagicMock(), + discord_poster=MagicMock() ) await core.conclude_task(run_info=run_info) @@ -96,7 +99,8 @@ async def test_run_task_prereq_not_met(): adb_client=AsyncMock(), huggingface_interface=AsyncMock(), url_request_interface=AsyncMock(), - html_parser=AsyncMock() + html_parser=AsyncMock(), + discord_poster=MagicMock() ) mock_operator = AsyncMock() @@ -121,19 +125,22 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo: adb_client=db_data_creator.adb_client, huggingface_interface=AsyncMock(), url_request_interface=AsyncMock(), - html_parser=AsyncMock() + html_parser=AsyncMock(), + discord_poster=MagicMock() ) core.conclude_task = AsyncMock() mock_operator = AsyncMock() - mock_operator.meets_task_prerequisites = AsyncMock(return_value=True) + mock_operator.meets_task_prerequisites = AsyncMock( + side_effect=[True, False] + ) 
mock_operator.task_type = TaskType.HTML mock_operator.run_task = types.MethodType(run_task, mock_operator) AsyncCore.get_task_operators = AsyncMock(return_value=[mock_operator]) await core.run_tasks() - mock_operator.meets_task_prerequisites.assert_called_once() + mock_operator.meets_task_prerequisites.assert_has_calls([call(), call()]) results = await db_data_creator.adb_client.get_all(Task) @@ -142,3 +149,39 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo: core.conclude_task.assert_called_once() +@pytest.mark.asyncio +async def test_run_task_break_loop(db_data_creator: DBDataCreator): + """ + If the task loop for a single task runs more than 20 times in a row, + this is considered suspicious and possibly indicative of a bug. + In this case, the task loop should be terminated + and an alert should be sent to discord + """ + + async def run_task(self, task_id: int) -> TaskOperatorRunInfo: + return TaskOperatorRunInfo( + task_id=task_id, + outcome=TaskOperatorOutcome.SUCCESS, + linked_url_ids=[1, 2, 3] + ) + + core = AsyncCore( + adb_client=db_data_creator.adb_client, + huggingface_interface=AsyncMock(), + url_request_interface=AsyncMock(), + html_parser=AsyncMock(), + discord_poster=MagicMock() + ) + core.conclude_task = AsyncMock() + + mock_operator = AsyncMock() + mock_operator.meets_task_prerequisites = AsyncMock(return_value=True) + mock_operator.task_type = TaskType.HTML + mock_operator.run_task = types.MethodType(run_task, mock_operator) + + AsyncCore.get_task_operators = AsyncMock(return_value=[mock_operator]) + await core.run_tasks() + + core.discord_poster.post_to_discord.assert_called_once_with( + message="Task HTML has been run more than 20 times in a row. Task loop terminated." 
+ ) diff --git a/tests/test_automated/integration/security_manager/test_security_manager.py b/tests/test_automated/integration/security_manager/test_security_manager.py index 3dc676ad..eb7e8506 100644 --- a/tests/test_automated/integration/security_manager/test_security_manager.py +++ b/tests/test_automated/integration/security_manager/test_security_manager.py @@ -18,12 +18,16 @@ def mock_get_secret_key(mocker): VALID_TOKEN = "valid_token" INVALID_TOKEN = "invalid_token" FAKE_PAYLOAD = { - "sub": 1, + "sub": "1", "permissions": [Permissions.SOURCE_COLLECTOR.value] } -def test_api_with_valid_token(mock_get_secret_key): +def test_api_with_valid_token( + mock_get_secret_key, + monkeypatch +): + monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com") token = jwt.encode(FAKE_PAYLOAD, SECRET_KEY, algorithm=ALGORITHM) # Create Test Client diff --git a/util/DiscordNotifier.py b/util/DiscordNotifier.py new file mode 100644 index 00000000..15e74020 --- /dev/null +++ b/util/DiscordNotifier.py @@ -0,0 +1,13 @@ +import logging + +import requests + + +class DiscordPoster: + def __init__(self, webhook_url: str): + if not webhook_url: + logging.error("WEBHOOK_URL environment variable not set") + raise ValueError("WEBHOOK_URL environment variable not set") + self.webhook_url = webhook_url + def post_to_discord(self, message): + requests.post(self.webhook_url, json={"content": message})