Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ENV.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,5 @@ Please ensure these are properly defined in a `.env` file in the root directory.
|`PDAP_EMAIL`| An email address for accessing the PDAP API. | `abc123@test.com` |
|`PDAP_PASSWORD`| A password for accessing the PDAP API. | `abc123` |
|`PDAP_API_KEY`| An API key for accessing the PDAP API. | `abc123` |
|`DISCORD_WEBHOOK_URL`| The URL for the Discord webhook used for notifications. | `abc123` |

4 changes: 4 additions & 0 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from html_tag_collector.RootURLCache import RootURLCache
from html_tag_collector.URLRequestInterface import URLRequestInterface
from hugging_face.HuggingFaceInterface import HuggingFaceInterface
from util.DiscordNotifier import DiscordPoster
from util.helper_functions import get_from_env


Expand All @@ -40,6 +41,9 @@ async def lifespan(app: FastAPI):
url_request_interface=URLRequestInterface(),
html_parser=HTMLResponseParser(
root_url_cache=RootURLCache()
),
discord_poster=DiscordPoster(
webhook_url=get_from_env("DISCORD_WEBHOOK_URL")
)
)
async_scheduled_task_manager = AsyncScheduledTaskManager(async_core=async_core)
Expand Down
34 changes: 21 additions & 13 deletions core/AsyncCore.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,37 @@
import logging
from typing import Optional

from aiohttp import ClientSession

from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface
from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
from collector_db.DTOs.TaskInfo import TaskInfo
from collector_db.DTOs.URLAnnotationInfo import URLAnnotationInfo
from collector_db.enums import TaskType, URLMetadataAttributeType
from collector_db.enums import TaskType
from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo
from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseOuterInfo
from core.DTOs.GetNextRelevanceAnnotationResponseInfo import GetNextRelevanceAnnotationResponseOuterInfo
from core.DTOs.GetNextURLForAgencyAnnotationResponse import GetNextURLForAgencyAnnotationResponse, \
URLAgencyAnnotationPostInfo
from core.DTOs.GetNextURLForAnnotationResponse import GetNextURLForAnnotationResponse
from core.DTOs.GetTasksResponse import GetTasksResponse
from core.DTOs.GetURLsResponseInfo import GetURLsResponseInfo
from core.DTOs.AnnotationRequestInfo import AnnotationRequestInfo
from core.DTOs.TaskOperatorRunInfo import TaskOperatorRunInfo, TaskOperatorOutcome
from core.classes.AgencyIdentificationTaskOperator import AgencyIdentificationTaskOperator
from core.classes.TaskOperatorBase import TaskOperatorBase
from core.classes.URLHTMLTaskOperator import URLHTMLTaskOperator
from core.classes.URLMiscellaneousMetadataTaskOperator import URLMiscellaneousMetadataTaskOperator
from core.classes.URLRecordTypeTaskOperator import URLRecordTypeTaskOperator
from core.classes.URLRelevanceHuggingfaceTaskOperator import URLRelevanceHuggingfaceTaskOperator
from core.enums import BatchStatus, SuggestionType, RecordType
from html_tag_collector.DataClassTags import convert_to_response_html_info
from core.enums import BatchStatus, RecordType
from html_tag_collector.ResponseParser import HTMLResponseParser
from html_tag_collector.URLRequestInterface import URLRequestInterface
from hugging_face.HuggingFaceInterface import HuggingFaceInterface
from llm_api_logic.OpenAIRecordClassifier import OpenAIRecordClassifier
from pdap_api_client.AccessManager import AccessManager
from pdap_api_client.PDAPClient import PDAPClient
from security_manager.SecurityManager import AccessInfo
from util.DiscordNotifier import DiscordPoster
from util.helper_functions import get_from_env

TASK_REPEAT_THRESHOLD = 20

class AsyncCore:

Expand All @@ -44,6 +41,7 @@ def __init__(
huggingface_interface: HuggingFaceInterface,
url_request_interface: URLRequestInterface,
html_parser: HTMLResponseParser,
discord_poster: DiscordPoster
):
self.adb_client = adb_client
self.huggingface_interface = huggingface_interface
Expand All @@ -52,6 +50,7 @@ def __init__(
self.logger = logging.getLogger(__name__)
self.logger.addHandler(logging.StreamHandler())
self.logger.setLevel(logging.INFO)
self.discord_poster = discord_poster


async def get_urls(self, page: int, errors: bool) -> GetURLsResponseInfo:
Expand Down Expand Up @@ -119,14 +118,23 @@ async def get_task_operators(self) -> list[TaskOperatorBase]:
#region Tasks
async def run_tasks(self):
    """
    Run each task operator repeatedly until its prerequisites stop being met.

    For every operator returned by `get_task_operators()`, a task is
    initiated, run and concluded for as long as the operator's
    `meets_task_prerequisites()` keeps returning a truthy value.

    As a safety valve, once a single operator has been run more than
    TASK_REPEAT_THRESHOLD times in a row, its loop is terminated and an
    alert is posted through `self.discord_poster`; processing then moves
    on to the next operator.
    """
    operators = await self.get_task_operators()
    for operator in operators:
        # The run counter is per operator: "in a row" refers to consecutive
        # runs of the same task, so runs of earlier operators must not
        # count towards the threshold of later ones.
        count = 0
        meets_prereq = await operator.meets_task_prerequisites()
        while meets_prereq:
            if count > TASK_REPEAT_THRESHOLD:
                self.discord_poster.post_to_discord(
                    message=f"Task {operator.task_type.value} has been run"
                            f" more than {TASK_REPEAT_THRESHOLD} times in a row. "
                            f"Task loop terminated.")
                break
            task_id = await self.initiate_task_in_db(task_type=operator.task_type)
            run_info: TaskOperatorRunInfo = await operator.run_task(task_id)
            await self.conclude_task(run_info)
            count += 1
            # Re-check after every run so the loop ends as soon as the
            # operator has nothing left to do.
            meets_prereq = await operator.meets_task_prerequisites()


async def conclude_task(self, run_info):
await self.adb_client.link_urls_to_task(task_id=run_info.task_id, url_ids=run_info.linked_url_ids)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_automated/integration/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
return AccessInfo(user_id=MOCK_USER_ID, permissions=[Permissions.SOURCE_COLLECTOR])

@pytest.fixture
def client(db_client_test) -> Generator[TestClient, None, None]:
def client(db_client_test, monkeypatch) -> Generator[TestClient, None, None]:

Check warning on line 33 in tests/test_automated/integration/api/conftest.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/test_automated/integration/api/conftest.py#L33 <103>

Missing docstring in public function
Raw output
./tests/test_automated/integration/api/conftest.py:33:1: D103 Missing docstring in public function

Check warning on line 33 in tests/test_automated/integration/api/conftest.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/test_automated/integration/api/conftest.py#L33 <100>

Unused argument 'db_client_test'
Raw output
./tests/test_automated/integration/api/conftest.py:33:12: U100 Unused argument 'db_client_test'
monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com")
with TestClient(app) as c:
app.dependency_overrides[get_access_info] = override_access_info
core: SourceCollectorCore = c.app.state.core
Expand Down
59 changes: 51 additions & 8 deletions tests/test_automated/integration/core/test_async_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import types
from unittest.mock import MagicMock, AsyncMock
from unittest.mock import MagicMock, AsyncMock, call

import pytest

Expand Down Expand Up @@ -27,7 +27,8 @@ async def test_conclude_task_success(db_data_creator: DBDataCreator):
adb_client=ddc.adb_client,
huggingface_interface=MagicMock(),
url_request_interface=MagicMock(),
html_parser=MagicMock()
html_parser=MagicMock(),
discord_poster=MagicMock()
)
await core.conclude_task(run_info=run_info)

Expand All @@ -53,7 +54,8 @@ async def test_conclude_task_success(db_data_creator: DBDataCreator):
adb_client=ddc.adb_client,
huggingface_interface=MagicMock(),
url_request_interface=MagicMock(),
html_parser=MagicMock()
html_parser=MagicMock(),
discord_poster=MagicMock()
)
await core.conclude_task(run_info=run_info)

Expand All @@ -80,7 +82,8 @@ async def test_conclude_task_error(db_data_creator: DBDataCreator):
adb_client=ddc.adb_client,
huggingface_interface=MagicMock(),
url_request_interface=MagicMock(),
html_parser=MagicMock()
html_parser=MagicMock(),
discord_poster=MagicMock()
)
await core.conclude_task(run_info=run_info)

Expand All @@ -96,7 +99,8 @@ async def test_run_task_prereq_not_met():
adb_client=AsyncMock(),
huggingface_interface=AsyncMock(),
url_request_interface=AsyncMock(),
html_parser=AsyncMock()
html_parser=AsyncMock(),
discord_poster=MagicMock()
)

mock_operator = AsyncMock()
Expand All @@ -121,19 +125,22 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo:
adb_client=db_data_creator.adb_client,
huggingface_interface=AsyncMock(),
url_request_interface=AsyncMock(),
html_parser=AsyncMock()
html_parser=AsyncMock(),
discord_poster=MagicMock()
)
core.conclude_task = AsyncMock()

mock_operator = AsyncMock()
mock_operator.meets_task_prerequisites = AsyncMock(return_value=True)
mock_operator.meets_task_prerequisites = AsyncMock(
side_effect=[True, False]
)
mock_operator.task_type = TaskType.HTML
mock_operator.run_task = types.MethodType(run_task, mock_operator)

AsyncCore.get_task_operators = AsyncMock(return_value=[mock_operator])
await core.run_tasks()

mock_operator.meets_task_prerequisites.assert_called_once()
mock_operator.meets_task_prerequisites.assert_has_calls([call(), call()])

results = await db_data_creator.adb_client.get_all(Task)

Expand All @@ -142,3 +149,39 @@ async def run_task(self, task_id: int) -> TaskOperatorRunInfo:

core.conclude_task.assert_called_once()

@pytest.mark.asyncio
async def test_run_task_break_loop(db_data_creator: DBDataCreator):
    """
    If the task loop for a single task runs more than 20 times in a row,
    this is considered suspicious and possibly indicative of a bug.
    In this case, the task loop should be terminated
    and an alert should be sent to Discord.
    """

    async def run_task(self, task_id: int) -> TaskOperatorRunInfo:
        return TaskOperatorRunInfo(
            task_id=task_id,
            outcome=TaskOperatorOutcome.SUCCESS,
            linked_url_ids=[1, 2, 3]
        )

    core = AsyncCore(
        adb_client=db_data_creator.adb_client,
        huggingface_interface=AsyncMock(),
        url_request_interface=AsyncMock(),
        html_parser=AsyncMock(),
        discord_poster=MagicMock()
    )
    core.conclude_task = AsyncMock()

    # Prerequisites are always met, so only the threshold can end the loop.
    mock_operator = AsyncMock()
    mock_operator.meets_task_prerequisites = AsyncMock(return_value=True)
    mock_operator.task_type = TaskType.HTML
    mock_operator.run_task = types.MethodType(run_task, mock_operator)

    # Patch the instance rather than the AsyncCore class, so the stub does
    # not leak into other tests that run later in the same session.
    core.get_task_operators = AsyncMock(return_value=[mock_operator])
    await core.run_tasks()

    # The loop breaks once the run count exceeds 20, i.e. after 21 runs.
    assert core.conclude_task.call_count == 21

    core.discord_poster.post_to_discord.assert_called_once_with(
        message="Task HTML has been run more than 20 times in a row. Task loop terminated."
    )
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
VALID_TOKEN = "valid_token"
INVALID_TOKEN = "invalid_token"
FAKE_PAYLOAD = {
"sub": 1,
"sub": "1",
"permissions": [Permissions.SOURCE_COLLECTOR.value]
}

def test_api_with_valid_token(mock_get_secret_key):
def test_api_with_valid_token(

Check warning on line 25 in tests/test_automated/integration/security_manager/test_security_manager.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/test_automated/integration/security_manager/test_security_manager.py#L25 <103>

Missing docstring in public function
Raw output
./tests/test_automated/integration/security_manager/test_security_manager.py:25:1: D103 Missing docstring in public function
mock_get_secret_key,

Check warning on line 26 in tests/test_automated/integration/security_manager/test_security_manager.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/test_automated/integration/security_manager/test_security_manager.py#L26 <100>

Unused argument 'mock_get_secret_key'
Raw output
./tests/test_automated/integration/security_manager/test_security_manager.py:26:9: U100 Unused argument 'mock_get_secret_key'
monkeypatch
):

monkeypatch.setenv("DISCORD_WEBHOOK_URL", "https://discord.com")
token = jwt.encode(FAKE_PAYLOAD, SECRET_KEY, algorithm=ALGORITHM)

# Create Test Client
Expand Down
13 changes: 13 additions & 0 deletions util/DiscordNotifier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import logging

Check warning on line 1 in util/DiscordNotifier.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] util/DiscordNotifier.py#L1 <100>

Missing docstring in public module
Raw output
./util/DiscordNotifier.py:1:1: D100 Missing docstring in public module

import requests


class DiscordPoster:
    """Send plain-text notification messages to a Discord webhook."""

    def __init__(self, webhook_url: str):
        """
        Store the webhook URL that messages will be posted to.

        :param webhook_url: URL of the Discord webhook.
        :raises ValueError: if `webhook_url` is empty or None.
        """
        if not webhook_url:
            logging.error("WEBHOOK_URL environment variable not set")
            raise ValueError("WEBHOOK_URL environment variable not set")
        self.webhook_url = webhook_url

    def post_to_discord(self, message, timeout: float = 10.0):
        """
        POST `message` to the configured webhook as the `content` field.

        :param message: text to send.
        :param timeout: seconds to wait for the webhook before giving up.
            Without a timeout `requests` may block indefinitely, which would
            stall whatever called the notifier.
        """
        requests.post(
            self.webhook_url,
            json={"content": message},
            timeout=timeout,
        )