diff --git a/api/routes/task.py b/api/routes/task.py index d9cdbeac..44971959 100644 --- a/api/routes/task.py +++ b/api/routes/task.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, Depends, Query, Path from api.dependencies import get_async_core +from collector_db.DTOs.GetTaskStatusResponseInfo import GetTaskStatusResponseInfo from collector_db.DTOs.TaskInfo import TaskInfo from collector_db.enums import TaskType from core.AsyncCore import AsyncCore @@ -39,6 +40,12 @@ async def get_tasks( task_status=task_status ) +@task_router.get("/status") +async def get_task_status( + async_core: AsyncCore = Depends(get_async_core), + access_info: AccessInfo = Depends(get_access_info) +) -> GetTaskStatusResponseInfo: + return await async_core.get_current_task_status() @task_router.get("/{task_id}") async def get_task_info( @@ -46,4 +53,6 @@ async def get_task_info( async_core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info) ) -> TaskInfo: - return await async_core.get_task_info(task_id) \ No newline at end of file + return await async_core.get_task_info(task_id) + + diff --git a/collector_db/DTOs/GetTaskStatusResponseInfo.py b/collector_db/DTOs/GetTaskStatusResponseInfo.py new file mode 100644 index 00000000..f6a8d5fc --- /dev/null +++ b/collector_db/DTOs/GetTaskStatusResponseInfo.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + +from collector_db.enums import TaskType + + +class GetTaskStatusResponseInfo(BaseModel): + status: TaskType \ No newline at end of file diff --git a/collector_db/enums.py b/collector_db/enums.py index 0dd956c5..c12cfde0 100644 --- a/collector_db/enums.py +++ b/collector_db/enums.py @@ -38,6 +38,7 @@ class TaskType(PyEnum): RECORD_TYPE = "Record Type" AGENCY_IDENTIFICATION = "Agency Identification" MISC_METADATA = "Misc Metadata" + IDLE = "Idle" class PGEnum(TypeDecorator): impl = postgresql.ENUM diff --git a/core/AsyncCore.py b/core/AsyncCore.py index b17903db..299a865e 100644 --- a/core/AsyncCore.py +++ b/core/AsyncCore.py @@ -4,6 +4,7 @@ from collector_db.AsyncDatabaseClient import AsyncDatabaseClient from collector_db.DTOs.BatchInfo import BatchInfo +from collector_db.DTOs.GetTaskStatusResponseInfo import GetTaskStatusResponseInfo from collector_db.enums import TaskType from collector_manager.AsyncCollectorManager import AsyncCollectorManager from collector_manager.enums import CollectorType @@ -87,6 +88,8 @@ async def initiate_collector( ) # endregion + async def get_current_task_status(self) -> GetTaskStatusResponseInfo: + return GetTaskStatusResponseInfo(status=self.task_manager.task_status) async def run_tasks(self): await self.task_manager.trigger_task_run() diff --git a/core/TaskManager.py b/core/TaskManager.py index 8ec259f5..64aa57e6 100644 --- a/core/TaskManager.py +++ b/core/TaskManager.py @@ -1,4 +1,5 @@ import logging +from typing import Optional from agency_identifier.MuckrockAPIInterface import MuckrockAPIInterface from collector_db.AsyncDatabaseClient import AsyncDatabaseClient @@ -44,12 +45,12 @@ def __init__( self.logger.addHandler(logging.StreamHandler()) self.logger.setLevel(logging.INFO) self.task_trigger = FunctionTrigger(self.run_tasks) + self.task_status: TaskType = TaskType.IDLE #region Task Operators async def get_url_html_task_operator(self): - self.logger.info("Running URL HTML Task") operator = URLHTMLTaskOperator( adb_client=self.adb_client, url_request_interface=self.url_request_interface, @@ -58,7 +59,6 @@ async def get_url_html_task_operator(self): return operator async def get_url_relevance_huggingface_task_operator(self): - self.logger.info("Running URL Relevance Huggingface Task") operator = URLRelevanceHuggingfaceTaskOperator( adb_client=self.adb_client, huggingface_interface=self.huggingface_interface @@ -106,13 +106,18 @@ async def get_task_operators(self) -> list[TaskOperatorBase]: #endregion #region Tasks + async def set_task_status(self, task_type: TaskType): + self.task_status = task_type + async def run_tasks(self): operators = await self.get_task_operators() count = 0 for operator in operators: + await self.set_task_status(task_type=operator.task_type) meets_prereq = await operator.meets_task_prerequisites() while meets_prereq: + print(f"Running {operator.task_type.value} Task") if count > TASK_REPEAT_THRESHOLD: self.discord_poster.post_to_discord( message=f"Task {operator.task_type.value} has been run" @@ -124,6 +129,7 @@ async def run_tasks(self): await self.conclude_task(run_info) count += 1 meets_prereq = await operator.meets_task_prerequisites() + await self.set_task_status(task_type=TaskType.IDLE) async def trigger_task_run(self): await self.task_trigger.trigger_or_rerun() diff --git a/core/classes/URLHTMLTaskOperator.py b/core/classes/URLHTMLTaskOperator.py index 63321635..ad279f9d 100644 --- a/core/classes/URLHTMLTaskOperator.py +++ b/core/classes/URLHTMLTaskOperator.py @@ -29,7 +29,6 @@ async def meets_task_prerequisites(self): return await self.adb_client.has_pending_urls_without_html_data() async def inner_task_logic(self): - print("Running URL HTML Task...") tdos = await self.get_pending_urls_without_html_data() url_ids = [task_info.url_info.id for task_info in tdos] await self.link_urls_to_task(url_ids=url_ids) diff --git a/tests/test_automated/integration/api/helpers/RequestValidator.py b/tests/test_automated/integration/api/helpers/RequestValidator.py index 02a51b29..f8ada6ae 100644 --- a/tests/test_automated/integration/api/helpers/RequestValidator.py +++ b/tests/test_automated/integration/api/helpers/RequestValidator.py @@ -5,6 +5,7 @@ from starlette.testclient import TestClient from collector_db.DTOs.BatchInfo import BatchInfo +from collector_db.DTOs.GetTaskStatusResponseInfo import GetTaskStatusResponseInfo from collector_db.DTOs.TaskInfo import TaskInfo from collector_db.enums import TaskType from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO @@ -281,4 +282,10 @@ async def reject_and_get_next_source_for_review( url=f"/review/reject-source", json=review_info.model_dump(mode='json') ) - return GetNextURLForFinalReviewOuterResponse(**data) \ No newline at end of file + return GetNextURLForFinalReviewOuterResponse(**data) + + async def get_current_task_status(self) -> GetTaskStatusResponseInfo: + data = self.get( + url=f"/task/status" + ) + return GetTaskStatusResponseInfo(**data) \ No newline at end of file diff --git a/tests/test_automated/integration/api/test_task.py b/tests/test_automated/integration/api/test_task.py index d6e13b1f..547b0eb8 100644 --- a/tests/test_automated/integration/api/test_task.py +++ b/tests/test_automated/integration/api/test_task.py @@ -39,3 +39,17 @@ async def test_get_tasks(api_test_helper): assert task.type == TaskType.HTML assert task.url_count == 3 assert task.url_error_count == 1 + +@pytest.mark.asyncio +async def test_get_task_status(api_test_helper): + ath = api_test_helper + + response = await ath.request_validator.get_current_task_status() + + assert response.status == TaskType.IDLE + + for task in [task for task in TaskType]: + await ath.async_core.task_manager.set_task_status(task) + response = await ath.request_validator.get_current_task_status() + + assert response.status == task