class AwaitableBarrier:
    """Awaitable stand-in that suspends its callers until it is released.

    Calling an instance returns a coroutine that waits on an internal
    ``asyncio.Event``. Positional and keyword arguments are accepted and
    ignored, so an instance can be substituted for any async callable.

    Args:
        timeout: Optional upper bound, in seconds, on how long a single call
            may wait. ``None`` (the default) waits indefinitely, which is the
            original behaviour.
    """

    def __init__(self, timeout: "float | None" = None):
        self._event = asyncio.Event()
        self._timeout = timeout

    async def __call__(self, *args, **kwargs):
        """Wait until ``release`` has been called.

        Raises:
            asyncio.TimeoutError: If a timeout was configured and it elapses
                before the barrier is released.
        """
        if self._timeout is None:
            await self._event.wait()
        else:
            # Bound the wait so a caller that never reaches ``release`` does
            # not leave the awaiting coroutine suspended forever.
            await asyncio.wait_for(self._event.wait(), self._timeout)

    def release(self):
        """Open the barrier: wake current waiters and let later calls pass."""
        self._event.set()
async def wait_for_all_batches_to_complete(
    self,
    *,
    max_attempts: int = 20,
    poll_interval: float = 0.1,
):
    """Poll the batch-status endpoint until no batch is reported as in process.

    Args:
        max_attempts: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep between polls.

    Raises:
        ValueError: If in-process batches are still reported after
            ``max_attempts`` polls.
    """
    for _ in range(max_attempts):
        data: GetBatchStatusResponse = self.request_validator.get_batch_statuses(
            status=BatchStatus.IN_PROCESS
        )
        if not data.results:
            return
        print("Waiting...")
        # Sleeping hands control back to the event loop between polls.
        await asyncio.sleep(poll_interval)
    raise ValueError("Batches did not complete in expected time")
+@pytest.mark.asyncio +async def test_duplicates(api_test_helper): ath = api_test_helper # Temporarily disable task trigger disable_task_trigger(ath) dto = ExampleInputDTO( - sleep_time=1 + sleep_time=0 ) batch_id_1 = ath.request_validator.example_collector( @@ -21,15 +25,14 @@ def test_duplicates(api_test_helper): assert batch_id_1 is not None - time.sleep(1) - batch_id_2 = ath.request_validator.example_collector( dto=dto )["batch_id"] assert batch_id_2 is not None - time.sleep(1.5) + await ath.wait_for_all_batches_to_complete() + bi_1: BatchInfo = ath.request_validator.get_batch_info(batch_id_1) bi_2: BatchInfo = ath.request_validator.get_batch_info(batch_id_2) diff --git a/tests/test_automated/integration/api/test_example_collector.py b/tests/test_automated/integration/api/test_example_collector.py index b13f7e31..0b3cf30f 100644 --- a/tests/test_automated/integration/api/test_example_collector.py +++ b/tests/test_automated/integration/api/test_example_collector.py @@ -1,6 +1,5 @@ import asyncio -import time -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import AsyncMock import pytest @@ -14,24 +13,29 @@ from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse from core.enums import BatchStatus +from tests.helpers.patch_functions import block_sleep from tests.test_automated.integration.api.conftest import disable_task_trigger @pytest.mark.asyncio -async def test_example_collector(api_test_helper): +async def test_example_collector(api_test_helper, monkeypatch): ath = api_test_helper + barrier = await block_sleep(monkeypatch) + # Temporarily disable task trigger disable_task_trigger(ath) + logger = AsyncCoreLogger(adb_client=AsyncDatabaseClient(), flush_interval=1) await logger.__aenter__() ath.async_core.collector_manager.logger = logger dto = ExampleInputDTO( - sleep_time=1 - ) + sleep_time=1 + ) + # Request Example Collector data = 
ath.request_validator.example_collector( dto=dto ) @@ -39,10 +43,14 @@ async def test_example_collector(api_test_helper): assert batch_id is not None assert data["message"] == "Started example collector." + # Yield control so coroutine runs up to the barrier + await asyncio.sleep(0) + + + # Check that batch currently shows as In Process bsr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses( status=BatchStatus.IN_PROCESS ) - assert len(bsr.results) == 1 bsi: BatchStatusInfo = bsr.results[0] @@ -50,7 +58,10 @@ async def test_example_collector(api_test_helper): assert bsi.strategy == CollectorType.EXAMPLE.value assert bsi.status == BatchStatus.IN_PROCESS - await asyncio.sleep(2) + # Release the barrier to resume execution + barrier.release() + + await ath.wait_for_all_batches_to_complete() csr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses( collector_type=CollectorType.EXAMPLE, @@ -74,7 +85,6 @@ async def test_example_collector(api_test_helper): # Flush early to ensure logs are written await logger.flush_all() - lr: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id) assert len(lr.logs) > 0 @@ -113,7 +123,7 @@ async def test_example_collector_error(api_test_helper, monkeypatch): assert batch_id is not None assert data["message"] == "Started example collector." 
- await asyncio.sleep(1) + await ath.wait_for_all_batches_to_complete() bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id) diff --git a/tests/test_automated/integration/core/test_example_collector_lifecycle.py b/tests/test_automated/integration/core/test_example_collector_lifecycle.py index a9c4900f..65ffc001 100644 --- a/tests/test_automated/integration/core/test_example_collector_lifecycle.py +++ b/tests/test_automated/integration/core/test_example_collector_lifecycle.py @@ -9,11 +9,14 @@ from core.DTOs.CollectorStartInfo import CollectorStartInfo from core.SourceCollectorCore import SourceCollectorCore from core.enums import BatchStatus +from tests.helpers.patch_functions import block_sleep + @pytest.mark.asyncio async def test_example_collector_lifecycle( test_core: SourceCollectorCore, - test_async_core: AsyncCore + test_async_core: AsyncCore, + monkeypatch ): """ Test the flow of an example collector, which generates fake urls @@ -22,6 +25,9 @@ async def test_example_collector_lifecycle( acore = test_async_core core = test_core db_client = core.db_client + + barrier = await block_sleep(monkeypatch) + dto = ExampleInputDTO( example_field="example_value", sleep_time=1 @@ -36,11 +42,13 @@ async def test_example_collector_lifecycle( batch_id = csi.batch_id + # Yield control so coroutine runs up to the barrier + await asyncio.sleep(0) + assert core.get_status(batch_id) == BatchStatus.IN_PROCESS - print("Sleeping for 1.5 seconds...") - await asyncio.sleep(1.5) + # Release the barrier to resume execution + barrier.release() await acore.collector_manager.logger.flush_all() - print("Done sleeping...") assert core.get_status(batch_id) == BatchStatus.READY_TO_LABEL batch_info: BatchInfo = db_client.get_batch_by_id(batch_id) @@ -48,7 +56,7 @@ async def test_example_collector_lifecycle( assert batch_info.status == BatchStatus.READY_TO_LABEL assert batch_info.total_url_count == 2 assert batch_info.parameters == dto.model_dump() - assert 
batch_info.compute_time > 1 + assert batch_info.compute_time > 0 url_infos = db_client.get_urls_by_batch(batch_id) assert len(url_infos) == 2 @@ -61,15 +69,19 @@ async def test_example_collector_lifecycle( @pytest.mark.asyncio async def test_example_collector_lifecycle_multiple_batches( test_core: SourceCollectorCore, - test_async_core: AsyncCore + test_async_core: AsyncCore, + monkeypatch ): """ Test the flow of an example collector, which generates fake urls and saves them to the database """ + barrier = await block_sleep(monkeypatch) acore = test_async_core core = test_core csis: list[CollectorStartInfo] = [] + + for i in range(3): dto = ExampleInputDTO( example_field="example_value", @@ -82,12 +94,16 @@ async def test_example_collector_lifecycle_multiple_batches( ) csis.append(csi) + await asyncio.sleep(0) for csi in csis: print("Batch ID:", csi.batch_id) assert core.get_status(csi.batch_id) == BatchStatus.IN_PROCESS - await asyncio.sleep(3) + barrier.release() + + await asyncio.sleep(0.15) for csi in csis: assert core.get_status(csi.batch_id) == BatchStatus.READY_TO_LABEL + diff --git a/tests/test_automated/unit/core/test_core_logger.py b/tests/test_automated/unit/core/test_core_logger.py index d91ce6cd..b0d52055 100644 --- a/tests/test_automated/unit/core/test_core_logger.py +++ b/tests/test_automated/unit/core/test_core_logger.py @@ -10,14 +10,14 @@ @pytest.mark.asyncio async def test_logger_flush(): mock_adb_client = AsyncMock() - async with AsyncCoreLogger(flush_interval=1, adb_client=mock_adb_client) as logger: + async with AsyncCoreLogger(flush_interval=0.01, adb_client=mock_adb_client) as logger: # Add logs await logger.log(LogInfo(log="Log 1", batch_id=1)) await logger.log(LogInfo(log="Log 2", batch_id=1)) # Wait for the flush interval - await asyncio.sleep(1.5) + await asyncio.sleep(0.02) # Verify logs were flushed mock_adb_client.insert_logs.assert_called_once()