From 72caf70625ba664f3ee6982a507a01c0371c72c4 Mon Sep 17 00:00:00 2001
From: Max Chis
Date: Mon, 14 Apr 2025 18:35:16 -0400
Subject: [PATCH] feat(app): make logger async

---
 api/main.py                                        | 14 ++--
 collector_db/AsyncDatabaseClient.py                | 11 ++-
 collector_manager/AsyncCollectorBase.py            | 12 ++--
 collector_manager/AsyncCollectorManager.py         | 14 ++--
 core/AsyncCoreLogger.py                            | 71 +++++++++++++++++++
 core/SourceCollectorCore.py                        |  6 +-
 .../auto_googler/AutoGooglerCollector.py           |  2 +-
 .../muckrock/classes/MuckrockCollector.py          |  4 +-
 .../test_autogoogler_collector.py                  |  4 +-
 .../source_collectors/test_ckan_collector.py       |  5 +-
 .../test_common_crawler_collector.py               |  3 +-
 .../test_muckrock_collectors.py                    |  7 +-
 .../integration/api/conftest.py                    | 13 ++--
 .../integration/api/test_batch.py                  |  2 -
 .../integration/api/test_example_collector.py      | 61 ++++++++++------
 tests/test_automated/integration/conftest.py       | 32 +++++----
 .../core/test_example_collector_lifecycle.py       |  1 +
 .../test_autogoogler_collector.py                  |  3 +-
 .../source_collectors/test_ckan_collector.py       |  3 +-
 .../test_common_crawl_collector.py                 |  3 +-
 .../test_example_collector.py                      |  7 +-
 .../test_muckrock_collectors.py                    |  7 +-
 22 files changed, 199 insertions(+), 86 deletions(-)
 create mode 100644 core/AsyncCoreLogger.py

diff --git a/api/main.py b/api/main.py
index 93e4521b..19f8de8d 100644
--- a/api/main.py
+++ b/api/main.py
@@ -14,7 +14,7 @@
 from collector_db.DatabaseClient import DatabaseClient
 from collector_manager.AsyncCollectorManager import AsyncCollectorManager
 from core.AsyncCore import AsyncCore
-from core.CoreLogger import CoreLogger
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.ScheduledTaskManager import AsyncScheduledTaskManager
 from core.SourceCollectorCore import SourceCollectorCore
 from core.TaskManager import TaskManager
@@ -32,12 +32,10 @@ async def lifespan(app: FastAPI):
     db_client = DatabaseClient()
     adb_client = AsyncDatabaseClient()
     await setup_database(db_client)
-    core_logger = CoreLogger(db_client=db_client)
+    core_logger = AsyncCoreLogger(adb_client=adb_client)
+
 
     source_collector_core = SourceCollectorCore(
-        core_logger=CoreLogger(
-            db_client=db_client
-        ),
         db_client=DatabaseClient(),
     )
     task_manager = TaskManager(
@@ -68,13 +66,15 @@ async def lifespan(app: FastAPI):
     app.state.core = source_collector_core
     app.state.async_core = async_core
     app.state.async_scheduled_task_manager = async_scheduled_task_manager
+    app.state.logger = core_logger
 
     # Startup logic
     yield  # Code here runs before shutdown
 
     # Shutdown logic (if needed)
-    core_logger.shutdown()
-    app.state.core.shutdown()
+    await core_logger.shutdown()
+    await async_core.shutdown()
+    source_collector_core.shutdown()
 
     # Clean up resources, close connections, etc.
     pass
diff --git a/collector_db/AsyncDatabaseClient.py b/collector_db/AsyncDatabaseClient.py
index 60fdcdfe..c8315fbe 100644
--- a/collector_db/AsyncDatabaseClient.py
+++ b/collector_db/AsyncDatabaseClient.py
@@ -14,6 +14,7 @@
 from collector_db.DTOs.BatchInfo import BatchInfo
 from collector_db.DTOs.DuplicateInfo import DuplicateInsertInfo
 from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo
+from collector_db.DTOs.LogInfo import LogInfo
 from collector_db.DTOs.TaskInfo import TaskInfo
 from collector_db.DTOs.URLErrorInfos import URLErrorPydanticInfo
 from collector_db.DTOs.URLHTMLContentInfo import URLHTMLContentInfo, HTMLContentType
@@ -27,7 +28,7 @@
 from collector_db.models import URL, URLErrorInfo, URLHTMLContent, Base, \
     RootURL, Task, TaskError, LinkTaskURL, Batch, Agency, AutomatedUrlAgencySuggestion, \
     UserUrlAgencySuggestion, AutoRelevantSuggestion, AutoRecordTypeSuggestion, UserRelevantSuggestion, \
-    UserRecordTypeSuggestion, ReviewingUserURL, URLOptionalDataSourceMetadata, ConfirmedURLAgency, Duplicate
+    UserRecordTypeSuggestion, ReviewingUserURL, URLOptionalDataSourceMetadata, ConfirmedURLAgency, Duplicate, Log
 from collector_manager.enums import URLStatus, CollectorType
 from core.DTOs.FinalReviewApprovalInfo import FinalReviewApprovalInfo
 from core.DTOs.GetNextRecordTypeAnnotationResponseInfo import GetNextRecordTypeAnnotationResponseInfo
@@ -1378,6 +1379,14 @@ async def get_url_info_by_url(self, session: AsyncSession, url: str) -> Optional
         url = raw_result.scalars().first()
         return URLInfo(**url.__dict__)
 
+    @session_manager
+    async def insert_logs(self, session, log_infos: List[LogInfo]):
+        for log_info in log_infos:
+            log = Log(log=log_info.log, batch_id=log_info.batch_id)
+            if log_info.created_at is not None:
+                log.created_at = log_info.created_at
+            session.add(log)
+
     @session_manager
     async def insert_duplicates(self, session, duplicate_infos: list[DuplicateInsertInfo]):
         for duplicate_info in duplicate_infos:
diff --git a/collector_manager/AsyncCollectorBase.py b/collector_manager/AsyncCollectorBase.py
index ec53f4c6..fe260266 100644
--- a/collector_manager/AsyncCollectorBase.py
+++ b/collector_manager/AsyncCollectorBase.py
@@ -10,7 +10,7 @@
 from collector_db.DTOs.InsertURLsInfo import InsertURLsInfo
 from collector_db.DTOs.LogInfo import LogInfo
 from collector_manager.enums import CollectorType
-from core.CoreLogger import CoreLogger
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.FunctionTrigger import FunctionTrigger
 from core.enums import BatchStatus
 from core.preprocessors.PreprocessorBase import PreprocessorBase
@@ -25,7 +25,7 @@ def __init__(
             self,
             batch_id: int,
             dto: BaseModel,
-            logger: CoreLogger,
+            logger: AsyncCoreLogger,
             adb_client: AsyncDatabaseClient,
             raise_error: bool = False,
             post_collection_function_trigger: Optional[FunctionTrigger] = None,
@@ -120,8 +120,12 @@ async def run(self) -> None:
             self.status = BatchStatus.ERROR
             await self.handle_error(e)
 
-    async def log(self, message: str, allow_abort = True) -> None:
-        self.logger.log(LogInfo(
+    async def log(
+        self,
+        message: str,
+        allow_abort = True # Deprecated
+    ) -> None:
+        await self.logger.log(LogInfo(
             batch_id=self.batch_id,
             log=message
         ))
diff --git a/collector_manager/AsyncCollectorManager.py b/collector_manager/AsyncCollectorManager.py
index bf338c88..1851bfc9 100644
--- a/collector_manager/AsyncCollectorManager.py
+++ b/collector_manager/AsyncCollectorManager.py
@@ -10,7 +10,7 @@
 from collector_manager.CollectorManager import InvalidCollectorError
 from collector_manager.collector_mapping import COLLECTOR_MAPPING
 from collector_manager.enums import CollectorType
-from core.CoreLogger import CoreLogger
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.FunctionTrigger import FunctionTrigger
 
 
@@ -18,7 +18,7 @@ class AsyncCollectorManager:
 
     def __init__(
             self,
-            logger: CoreLogger,
+            logger: AsyncCoreLogger,
             adb_client: AsyncDatabaseClient,
             dev_mode: bool = False,
             post_collection_function_trigger: FunctionTrigger = None
@@ -79,10 +79,16 @@ async def abort_collector_async(self, cid: int) -> None:
         self.async_tasks.pop(cid)
 
     async def shutdown_all_collectors(self) -> None:
-        for cid, task in self.async_tasks.items():
+        while self.async_tasks:
+            cid, task = self.async_tasks.popitem()
             if task.done():
                 try:
                     task.result()
                 except Exception as e:
                     raise e
-            await self.abort_collector_async(cid)
\ No newline at end of file
+            else:
+                task.cancel()
+                try:
+                    await task  # Await so cancellation propagates
+                except asyncio.CancelledError:
+                    pass
\ No newline at end of file
diff --git a/core/AsyncCoreLogger.py b/core/AsyncCoreLogger.py
new file mode 100644
index 00000000..70ca06aa
--- /dev/null
+++ b/core/AsyncCoreLogger.py
@@ -0,0 +1,71 @@
+import asyncio
+
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from collector_db.DTOs.LogInfo import LogInfo
+
+
+class AsyncCoreLogger:
+    def __init__(
+        self,
+        adb_client: AsyncDatabaseClient,
+        flush_interval: float = 10,
+        batch_size: int = 100
+    ):
+        self.adb_client = adb_client
+        self.flush_interval = flush_interval
+        self.batch_size = batch_size
+
+        self.log_queue = asyncio.Queue()
+        self.lock = asyncio.Lock()
+        self._flush_task: asyncio.Task | None = None
+        self._stop_event = asyncio.Event()
+
+    async def __aenter__(self):
+        self._stop_event.clear()
+        self._flush_task = asyncio.create_task(self._flush_logs())
+        return self
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        await self.shutdown()
+
+    async def log(self, log_info: LogInfo):
+        await self.log_queue.put(log_info)
+
+    async def _flush_logs(self):
+        while not self._stop_event.is_set():
+            await asyncio.sleep(self.flush_interval)
+            await self.flush()
+
+    async def flush(self):
+        async with self.lock:
+            logs: list[LogInfo] = []
+
+            while not self.log_queue.empty() and len(logs) < self.batch_size:
+                try:
+                    log = self.log_queue.get_nowait()
+                    logs.append(log)
+                except asyncio.QueueEmpty:
+                    break
+
+            if logs:
+                await self.adb_client.insert_logs(log_infos=logs)
+
+    async def clear_log_queue(self):
+        while not self.log_queue.empty():
+            self.log_queue.get_nowait()
+
+    async def flush_all(self):
+        while not self.log_queue.empty():
+            await self.flush()
+
+    async def restart(self):
+        await self.flush_all()
+        await self.shutdown()
+        self._stop_event.clear()
+        self._flush_task = asyncio.create_task(self._flush_logs())
+
+    async def shutdown(self):
+        self._stop_event.set()
+        if self._flush_task:
+            await self._flush_task
+        await self.flush_all()
diff --git a/core/SourceCollectorCore.py b/core/SourceCollectorCore.py
index a0bb34fc..8002717c 100644
--- a/core/SourceCollectorCore.py
+++ b/core/SourceCollectorCore.py
@@ -3,7 +3,6 @@
 from collector_db.DatabaseClient import DatabaseClient
 from collector_manager.enums import CollectorType
-from core.CoreLogger import CoreLogger
 from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse
 from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse
 from core.DTOs.GetDuplicatesByBatchResponse import GetDuplicatesByBatchResponse
 
@@ -14,13 +13,12 @@ class SourceCollectorCore:
 
     def __init__(
             self,
-            core_logger: CoreLogger,
-            collector_manager: Optional[Any] = None,
+            core_logger: Optional[Any] = None, # Deprecated
+            collector_manager: Optional[Any] = None, # Deprecated
             db_client: DatabaseClient = DatabaseClient(),
             dev_mode: bool = False
     ):
         self.db_client = db_client
-        self.core_logger = core_logger
         if not dev_mode:
             self.scheduled_task_manager = ScheduledTaskManager(db_client=db_client)
         else:
diff --git a/source_collectors/auto_googler/AutoGooglerCollector.py b/source_collectors/auto_googler/AutoGooglerCollector.py
index b678f066..1748d911 100644
--- a/source_collectors/auto_googler/AutoGooglerCollector.py
+++ b/source_collectors/auto_googler/AutoGooglerCollector.py
@@ -27,7 +27,7 @@ async def run_to_completion(self) -> AutoGoogler:
             )
         )
         async for log in auto_googler.run():
-            self.log(log)
+            await self.log(log)
         return auto_googler
 
     async def run_implementation(self) -> None:
diff --git a/source_collectors/muckrock/classes/MuckrockCollector.py b/source_collectors/muckrock/classes/MuckrockCollector.py
index 885c0369..0511a21d 100644
--- a/source_collectors/muckrock/classes/MuckrockCollector.py
+++ b/source_collectors/muckrock/classes/MuckrockCollector.py
@@ -47,9 +47,9 @@ async def run_implementation(self) -> None:
                 self.check_for_count_break(results_count, max_count)
             except SearchCompleteException:
                 break
-            self.log(f"Search {search_count}: Found {len(results)} results")
+            await self.log(f"Search {search_count}: Found {len(results)} results")
 
-        self.log(f"Search Complete. Total results: {results_count}")
+        await self.log(f"Search Complete. Total results: {results_count}")
         self.data = {"urls": self.format_results(all_results)}
 
     def format_results(self, results: list[dict]) -> list[dict]:
diff --git a/tests/manual/source_collectors/test_autogoogler_collector.py b/tests/manual/source_collectors/test_autogoogler_collector.py
index 78fc46d7..a51fc883 100644
--- a/tests/manual/source_collectors/test_autogoogler_collector.py
+++ b/tests/manual/source_collectors/test_autogoogler_collector.py
@@ -3,7 +3,7 @@
 import pytest
 
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
-from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.auto_googler.AutoGooglerCollector import AutoGooglerCollector
 from source_collectors.auto_googler.DTOs import AutoGooglerInputDTO
@@ -16,7 +16,7 @@ async def test_autogoogler_collector():
             urls_per_result=5,
             queries=["police"],
         ),
-        logger = MagicMock(spec=CoreLogger),
+        logger = AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/manual/source_collectors/test_ckan_collector.py b/tests/manual/source_collectors/test_ckan_collector.py
index 3bae5d88..f642fd8d 100644
--- a/tests/manual/source_collectors/test_ckan_collector.py
+++ b/tests/manual/source_collectors/test_ckan_collector.py
@@ -5,6 +5,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.ckan.CKANCollector import CKANCollector
 from source_collectors.ckan.DTOs import CKANInputDTO
 
@@ -31,7 +32,7 @@ async def test_ckan_collector_default():
                 "organization_search": organization_search
             }
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
@@ -80,7 +81,7 @@ async def test_ckan_collector_custom():
                 ]
             }
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/manual/source_collectors/test_common_crawler_collector.py b/tests/manual/source_collectors/test_common_crawler_collector.py
index 6c9771f3..872b7710 100644
--- a/tests/manual/source_collectors/test_common_crawler_collector.py
+++ b/tests/manual/source_collectors/test_common_crawler_collector.py
@@ -5,6 +5,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.common_crawler.CommonCrawlerCollector import CommonCrawlerCollector
 from source_collectors.common_crawler.DTOs import CommonCrawlerInputDTO
 
@@ -18,7 +19,7 @@ async def test_common_crawler_collector():
     collector = CommonCrawlerCollector(
         batch_id=1,
         dto=CommonCrawlerInputDTO(),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/manual/source_collectors/test_muckrock_collectors.py b/tests/manual/source_collectors/test_muckrock_collectors.py
index 8fb80bc4..bfd0ba26 100644
--- a/tests/manual/source_collectors/test_muckrock_collectors.py
+++ b/tests/manual/source_collectors/test_muckrock_collectors.py
@@ -3,6 +3,7 @@
 import pytest
 
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.muckrock.DTOs import MuckrockSimpleSearchCollectorInputDTO, \
     MuckrockCountySearchCollectorInputDTO, MuckrockAllFOIARequestsCollectorInputDTO
@@ -22,7 +23,7 @@ async def test_muckrock_simple_search_collector():
             search_string="police",
             max_results=10
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
@@ -41,7 +42,7 @@ async def test_muckrock_county_level_search_collector():
             parent_jurisdiction_id=ALLEGHENY_COUNTY_MUCKROCK_ID,
             town_names=ALLEGHENY_COUNTY_TOWN_NAMES
        ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
@@ -61,7 +62,7 @@ async def test_muckrock_full_search_collector():
             start_page=1,
             total_pages=2
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/test_automated/integration/api/conftest.py b/tests/test_automated/integration/api/conftest.py
index e51b05dc..b466bfbb 100644
--- a/tests/test_automated/integration/api/conftest.py
+++ b/tests/test_automated/integration/api/conftest.py
@@ -6,10 +6,12 @@
 from unittest.mock import MagicMock, AsyncMock
 
 import pytest
+import pytest_asyncio
 from starlette.testclient import TestClient
 
 from api.main import app
 from core.AsyncCore import AsyncCore
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.SourceCollectorCore import SourceCollectorCore
 from security_manager.SecurityManager import get_access_info, AccessInfo, Permissions
 from tests.helpers.DBDataCreator import DBDataCreator
@@ -51,7 +53,6 @@ def client() -> Generator[TestClient, None, None]:
     os.environ["DISCORD_WEBHOOK_URL"] = "https://discord.com"
     with TestClient(app) as c:
         app.dependency_overrides[get_access_info] = override_access_info
-        core: SourceCollectorCore = c.app.state.core
         async_core: AsyncCore = c.app.state.async_core
 
         # Interfaces to the web should be mocked
@@ -63,17 +64,16 @@ def client() -> Generator[TestClient, None, None]:
         task_manager.logger.disabled = True
         # Set trigger to fail immediately if called, to force it to be manually specified in tests
         task_manager.task_trigger._func = fail_task_trigger
-        # core.shutdown()
         yield c
-        core.shutdown()
+
         # Reset environment variables back to original state
         os.environ.clear()
         os.environ.update(_original_env)
 
 
-@pytest.fixture
-def api_test_helper(client: TestClient, db_data_creator, monkeypatch) -> APITestHelper:
-    return APITestHelper(
+@pytest_asyncio.fixture
+async def api_test_helper(client: TestClient, db_data_creator, monkeypatch) -> APITestHelper:
+    yield APITestHelper(
         request_validator=RequestValidator(client=client),
         core=client.app.state.core,
         async_core=client.app.state.async_core,
@@ -81,3 +81,4 @@ def api_test_helper(client: TestClient, db_data_creator, monkeypatch) -> APITest
         mock_huggingface_interface=MagicMock(),
         mock_label_studio_interface=MagicMock()
     )
+    await client.app.state.async_core.collector_manager.logger.clear_log_queue()
diff --git a/tests/test_automated/integration/api/test_batch.py b/tests/test_automated/integration/api/test_batch.py
index 69c2fcab..604e2d67 100644
--- a/tests/test_automated/integration/api/test_batch.py
+++ b/tests/test_automated/integration/api/test_batch.py
@@ -20,8 +20,6 @@ def test_abort_batch(api_test_helper):
 
     assert response.message == "Batch aborted."
 
-    time.sleep(3)
-
     bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id)
 
     assert bi.status == BatchStatus.ABORTED
diff --git a/tests/test_automated/integration/api/test_example_collector.py b/tests/test_automated/integration/api/test_example_collector.py
index c99119e7..a235d8e8 100644
--- a/tests/test_automated/integration/api/test_example_collector.py
+++ b/tests/test_automated/integration/api/test_example_collector.py
@@ -1,10 +1,15 @@
+import asyncio
 import time
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, AsyncMock
 
+import pytest
+
+from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DTOs.BatchInfo import BatchInfo
 from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO
 from collector_manager.ExampleCollector import ExampleCollector
 from collector_manager.enums import CollectorType
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.DTOs.BatchStatusInfo import BatchStatusInfo
 from core.DTOs.GetBatchLogsResponse import GetBatchLogsResponse
 from core.DTOs.GetBatchStatusResponse import GetBatchStatusResponse
@@ -12,12 +17,17 @@
 from tests.test_automated.integration.api.conftest import disable_task_trigger
 
 
-def test_example_collector(api_test_helper):
+@pytest.mark.asyncio
+async def test_example_collector(api_test_helper):
     ath = api_test_helper
 
     # Temporarily disable task trigger
     disable_task_trigger(ath)
 
+    logger = AsyncCoreLogger(adb_client=AsyncDatabaseClient())
+    await logger.__aenter__()
+    ath.async_core.collector_manager.logger = logger
+
     dto = ExampleInputDTO(
         sleep_time=1
     )
@@ -40,7 +50,7 @@ def test_example_collector(api_test_helper):
     assert bsi.strategy == CollectorType.EXAMPLE.value
     assert bsi.status == BatchStatus.IN_PROCESS
 
-    time.sleep(2)
+    await asyncio.sleep(2)
 
     csr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses(
         collector_type=CollectorType.EXAMPLE,
@@ -62,29 +72,33 @@ def test_example_collector(api_test_helper):
     assert bi.user_id is not None
 
     # Flush early to ensure logs are written
-    # Commented out due to inconsistency in execution
-    # ath.core.core_logger.flush_all()
-    #
-    # time.sleep(10)
-    #
-    # lr: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id)
-    #
-    # assert len(lr.logs) > 0
+    await logger.flush_all()
+
+
+    lr: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id)
+
+    assert len(lr.logs) > 0
 
     # Check that task was triggered
     ath.async_core.collector_manager.\
         post_collection_function_trigger.\
         trigger_or_rerun.assert_called_once()
 
+    await logger.__aexit__(None, None, None)
 
-def test_example_collector_error(api_test_helper, monkeypatch):
+
+@pytest.mark.asyncio
+async def test_example_collector_error(api_test_helper, monkeypatch):
     """
     Test that when an error occurs in a collector, the batch is properly update
     """
     ath = api_test_helper
 
+    logger = AsyncCoreLogger(adb_client=AsyncDatabaseClient())
+    await logger.__aenter__()
+    ath.async_core.collector_manager.logger = logger
+
     # Patch the collector to raise an exception during run_implementation
-    mock = MagicMock()
+    mock = AsyncMock()
     mock.side_effect = Exception("Collector failed!")
     monkeypatch.setattr(ExampleCollector, 'run_implementation', mock)
 
@@ -99,20 +113,21 @@ def test_example_collector_error(api_test_helper, monkeypatch):
     assert batch_id is not None
     assert data["message"] == "Started example collector."
 
-    time.sleep(1)
+    await asyncio.sleep(1)
 
     bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id)
 
     assert bi.status == BatchStatus.ERROR
 
-    #
-    # ath.core.core_logger.flush_all()
-    #
-    # time.sleep(10)
-    #
-    # gbl: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id)
-    # assert gbl.logs[-1].log == "Error: Collector failed!"
-    #
-    #
+    # Check there are logs
+    assert not logger.log_queue.empty()
+    await logger.flush_all()
+    assert logger.log_queue.empty()
+
+    gbl: GetBatchLogsResponse = ath.request_validator.get_batch_logs(batch_id=batch_id)
+    assert gbl.logs[-1].log == "Error: Collector failed!"
+    await logger.__aexit__(None, None, None)
+
+
 
 
diff --git a/tests/test_automated/integration/conftest.py b/tests/test_automated/integration/conftest.py
index cd05cf6f..6be03e86 100644
--- a/tests/test_automated/integration/conftest.py
+++ b/tests/test_automated/integration/conftest.py
@@ -5,6 +5,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_manager.AsyncCollectorManager import AsyncCollectorManager
 from core.AsyncCore import AsyncCore
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from core.SourceCollectorCore import SourceCollectorCore
 
@@ -24,19 +25,20 @@ def test_core(db_client_test):
 
 
 @pytest.fixture
-def test_async_core(db_client_test):
-    with CoreLogger(
-        db_client=db_client_test
-    ) as logger:
-        adb_client = AsyncDatabaseClient()
-        core = AsyncCore(
+def test_async_core(adb_client_test):
+    logger = AsyncCoreLogger(
+        adb_client=adb_client_test
+    )
+    adb_client = AsyncDatabaseClient()
+    core = AsyncCore(
+        adb_client=adb_client,
+        task_manager=MagicMock(),
+        collector_manager=AsyncCollectorManager(
             adb_client=adb_client,
-            task_manager=MagicMock(),
-            collector_manager=AsyncCollectorManager(
-                adb_client=adb_client,
-                logger=logger,
-                dev_mode=True
-            ),
-        )
-        yield core
-        core.shutdown()
\ No newline at end of file
+            logger=logger,
+            dev_mode=True
+        ),
+    )
+    yield core
+    core.shutdown()
+    logger.shutdown()
\ No newline at end of file
diff --git a/tests/test_automated/integration/core/test_example_collector_lifecycle.py b/tests/test_automated/integration/core/test_example_collector_lifecycle.py
index abe8fb7a..d3f3f855 100644
--- a/tests/test_automated/integration/core/test_example_collector_lifecycle.py
+++ b/tests/test_automated/integration/core/test_example_collector_lifecycle.py
@@ -39,6 +39,7 @@ async def test_example_collector_lifecycle(
     assert core.get_status(batch_id) == BatchStatus.IN_PROCESS
     print("Sleeping for 1.5 seconds...")
     await asyncio.sleep(1.5)
+    await acore.collector_manager.logger.flush_all()
     print("Done sleeping...")
     assert core.get_status(batch_id) == BatchStatus.COMPLETE
 
diff --git a/tests/test_automated/unit/source_collectors/test_autogoogler_collector.py b/tests/test_automated/unit/source_collectors/test_autogoogler_collector.py
index 050b1299..2349afe2 100644
--- a/tests/test_automated/unit/source_collectors/test_autogoogler_collector.py
+++ b/tests/test_automated/unit/source_collectors/test_autogoogler_collector.py
@@ -5,6 +5,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DTOs.URLInfo import URLInfo
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.auto_googler.AutoGooglerCollector import AutoGooglerCollector
 from source_collectors.auto_googler.DTOs import GoogleSearchQueryResultsInnerDTO, AutoGooglerInputDTO
@@ -29,7 +30,7 @@ async def test_auto_googler_collector(patch_get_query_results):
         dto=AutoGooglerInputDTO(
             queries=["keyword"]
         ),
-        logger=AsyncMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/test_automated/unit/source_collectors/test_ckan_collector.py b/tests/test_automated/unit/source_collectors/test_ckan_collector.py
index b00ed434..ef7dbee8 100644
--- a/tests/test_automated/unit/source_collectors/test_ckan_collector.py
+++ b/tests/test_automated/unit/source_collectors/test_ckan_collector.py
@@ -6,6 +6,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.ckan.CKANCollector import CKANCollector
 from source_collectors.ckan.DTOs import CKANInputDTO
 
@@ -42,7 +43,7 @@ async def test_ckan_collector(mock_ckan_collector_methods):
     collector = CKANCollector(
         batch_id=1,
         dto=CKANInputDTO(),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/test_automated/unit/source_collectors/test_common_crawl_collector.py b/tests/test_automated/unit/source_collectors/test_common_crawl_collector.py
index 74fe1052..d1f0ccda 100644
--- a/tests/test_automated/unit/source_collectors/test_common_crawl_collector.py
+++ b/tests/test_automated/unit/source_collectors/test_common_crawl_collector.py
@@ -5,6 +5,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DTOs.URLInfo import URLInfo
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.common_crawler.CommonCrawlerCollector import CommonCrawlerCollector
 from source_collectors.common_crawler.DTOs import CommonCrawlerInputDTO
@@ -31,7 +32,7 @@ async def test_common_crawl_collector(mock_get_common_crawl_search_results):
         dto=CommonCrawlerInputDTO(
             search_term="keyword",
         ),
-        logger=mock.MagicMock(spec=CoreLogger),
+        logger=mock.AsyncMock(spec=AsyncCoreLogger),
         adb_client=mock.AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
diff --git a/tests/test_automated/unit/source_collectors/test_example_collector.py b/tests/test_automated/unit/source_collectors/test_example_collector.py
index 17512a6f..26ca601d 100644
--- a/tests/test_automated/unit/source_collectors/test_example_collector.py
+++ b/tests/test_automated/unit/source_collectors/test_example_collector.py
@@ -1,8 +1,9 @@
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, AsyncMock
 
 from collector_db.DatabaseClient import DatabaseClient
 from collector_manager.DTOs.ExampleInputDTO import ExampleInputDTO
 from collector_manager.ExampleCollector import ExampleCollector
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 
 
@@ -12,8 +13,8 @@ def test_example_collector():
         dto=ExampleInputDTO(
             sleep_time=1
         ),
-        logger=MagicMock(spec=CoreLogger),
-        adb_client=MagicMock(spec=DatabaseClient),
+        logger=AsyncMock(spec=AsyncCoreLogger),
+        adb_client=AsyncMock(spec=DatabaseClient),
         raise_error=True
     )
     collector.run()
\ No newline at end of file
diff --git a/tests/test_automated/unit/source_collectors/test_muckrock_collectors.py b/tests/test_automated/unit/source_collectors/test_muckrock_collectors.py
index f74c651e..7e533efa 100644
--- a/tests/test_automated/unit/source_collectors/test_muckrock_collectors.py
+++ b/tests/test_automated/unit/source_collectors/test_muckrock_collectors.py
@@ -6,6 +6,7 @@
 from collector_db.AsyncDatabaseClient import AsyncDatabaseClient
 from collector_db.DTOs.URLInfo import URLInfo
 from collector_db.DatabaseClient import DatabaseClient
+from core.AsyncCoreLogger import AsyncCoreLogger
 from core.CoreLogger import CoreLogger
 from source_collectors.muckrock.DTOs import MuckrockSimpleSearchCollectorInputDTO, \
     MuckrockCountySearchCollectorInputDTO, MuckrockAllFOIARequestsCollectorInputDTO
@@ -40,7 +41,7 @@ async def test_muckrock_simple_collector(patch_muckrock_fetcher):
             search_string="keyword",
             max_results=2
         ),
-        logger=mock.MagicMock(spec=CoreLogger),
+        logger=mock.AsyncMock(spec=AsyncCoreLogger),
         adb_client=mock.AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
@@ -97,7 +98,7 @@ async def test_muckrock_county_search_collector(patch_muckrock_county_level_sear
             parent_jurisdiction_id=1,
             town_names=["test"]
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )
@@ -166,7 +167,7 @@ async def test_muckrock_all_foia_requests_collector(patch_muckrock_full_search_c
             start_page=1,
             total_pages=2
         ),
-        logger=MagicMock(spec=CoreLogger),
+        logger=AsyncMock(spec=AsyncCoreLogger),
         adb_client=AsyncMock(spec=AsyncDatabaseClient),
         raise_error=True
     )