From 0aef8d3f896439069a4a6ea5fd3906cafce48bc3 Mon Sep 17 00:00:00 2001 From: Max Chis Date: Sun, 23 Nov 2025 17:21:13 -0500 Subject: [PATCH] Update Record Task type not to repeat on error. --- .../tasks/url/operators/record_type/core.py | 57 +++++++++++---- .../operators/record_type/queries/__init__.py | 0 .../url/operators/record_type/queries/cte.py | 31 ++++++++ .../url/operators/record_type/queries/get.py | 36 ++++++++++ .../operators/record_type/queries/prereq.py | 18 +++++ src/db/client/async_.py | 71 +------------------ .../url/impl/test_url_record_type_task.py | 7 ++ 7 files changed, 139 insertions(+), 81 deletions(-) create mode 100644 src/core/tasks/url/operators/record_type/queries/__init__.py create mode 100644 src/core/tasks/url/operators/record_type/queries/cte.py create mode 100644 src/core/tasks/url/operators/record_type/queries/get.py create mode 100644 src/core/tasks/url/operators/record_type/queries/prereq.py diff --git a/src/core/tasks/url/operators/record_type/core.py b/src/core/tasks/url/operators/record_type/core.py index 8e31fa8d..9f63a6a5 100644 --- a/src/core/tasks/url/operators/record_type/core.py +++ b/src/core/tasks/url/operators/record_type/core.py @@ -1,9 +1,13 @@ from src.core.enums import RecordType from src.core.tasks.url.operators.base import URLTaskOperatorBase from src.core.tasks.url.operators.record_type.llm_api.record_classifier.openai import OpenAIRecordClassifier +from src.core.tasks.url.operators.record_type.queries.get import GetRecordTypeTaskURLsQueryBuilder +from src.core.tasks.url.operators.record_type.queries.prereq import RecordTypeTaskPrerequisiteQueryBuilder from src.core.tasks.url.operators.record_type.tdo import URLRecordTypeTDO from src.db.client.async_ import AsyncDatabaseClient +from src.db.dtos.url.with_html import URLWithHTML from src.db.enums import TaskType +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall @@ -18,18 +22,22 @@ def __init__( self.classifier = classifier @property - def task_type(self): + def task_type(self) -> TaskType: return TaskType.RECORD_TYPE - async def meets_task_prerequisites(self): - return await self.adb_client.has_urls_with_html_data_and_without_auto_record_type_suggestion() + async def meets_task_prerequisites(self) -> bool: + return await self.run_query_builder( + RecordTypeTaskPrerequisiteQueryBuilder() + ) async def get_tdos(self) -> list[URLRecordTypeTDO]: - urls_with_html = await self.adb_client.get_urls_with_html_data_and_without_auto_record_type_suggestion() + urls_with_html: list[URLWithHTML] = await self.run_query_builder( + GetRecordTypeTaskURLsQueryBuilder() + ) tdos = [URLRecordTypeTDO(url_with_html=url_with_html) for url_with_html in urls_with_html] return tdos - async def inner_task_logic(self): + async def inner_task_logic(self) -> None: # Get pending urls from Source Collector # with HTML data and without Record Type Metadata tdos = await self.get_tdos() @@ -41,7 +49,10 @@ async def inner_task_logic(self): await self.put_results_into_database(success_subset) await self.update_errors_in_database(error_subset) - async def update_errors_in_database(self, tdos: list[URLRecordTypeTDO]): + async def update_errors_in_database( + self, + tdos: list[URLRecordTypeTDO] + ) -> None: task_errors: list[URLTaskErrorSmall] = [] for tdo in tdos: error_info = URLTaskErrorSmall( @@ -51,20 +62,42 @@ async def update_errors_in_database(self, tdos: list[URLRecordTypeTDO]): task_errors.append(error_info) await self.add_task_errors(task_errors) - async def put_results_into_database(self, tdos: list[URLRecordTypeTDO]): - suggestions = [] + async def put_results_into_database( + self, + tdos: list[URLRecordTypeTDO] + ) -> None: + url_and_record_type_list = [] for tdo in tdos: url_id = tdo.url_with_html.url_id record_type = tdo.record_type - suggestions.append((url_id, record_type)) - await self.adb_client.add_auto_record_type_suggestions(suggestions) + url_and_record_type_list.append((url_id, record_type)) + # Add to database + suggestions: list[AutoRecordTypeSuggestion] = [] + for url_id, record_type in url_and_record_type_list: + suggestion = AutoRecordTypeSuggestion( + url_id=url_id, + record_type=record_type.value + ) + suggestions.append(suggestion) + await self.adb_client.add_all(suggestions) - async def separate_success_and_error_subsets(self, tdos: list[URLRecordTypeTDO]): + @staticmethod + async def separate_success_and_error_subsets( + tdos: list[URLRecordTypeTDO] + ) -> tuple[list[URLRecordTypeTDO], list[URLRecordTypeTDO]]: success_subset = [tdo for tdo in tdos if not tdo.is_errored()] error_subset = [tdo for tdo in tdos if tdo.is_errored()] return success_subset, error_subset - async def get_ml_classifications(self, tdos: list[URLRecordTypeTDO]): + async def get_ml_classifications( + self, + tdos: list[URLRecordTypeTDO] + ) -> None: + """ + Modifies: + - tdo.record_type + - tdo.error + """ for tdo in tdos: try: record_type_str = await self.classifier.classify_url(tdo.url_with_html.html_infos) diff --git a/src/core/tasks/url/operators/record_type/queries/__init__.py b/src/core/tasks/url/operators/record_type/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/core/tasks/url/operators/record_type/queries/cte.py b/src/core/tasks/url/operators/record_type/queries/cte.py new file mode 100644 index 00000000..22d3db10 --- /dev/null +++ b/src/core/tasks/url/operators/record_type/queries/cte.py @@ -0,0 +1,31 @@ +from sqlalchemy import select, CTE, Column + +from src.db.enums import TaskType +from src.db.helpers.query import not_exists_url, no_url_task_error +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML +from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion + + +class RecordTypeTaskPrerequisiteCTEContainer: + + def __init__(self): + self.cte: CTE = ( + select( + URL.id + ) + .join( + URLCompressedHTML + ) + .where( + not_exists_url(AutoRecordTypeSuggestion), + no_url_task_error( + TaskType.RECORD_TYPE + ) + ) + .cte("record_type_task_prerequisite") + ) + + @property + def url_id(self) -> Column[int]: + return self.cte.columns.id \ No newline at end of file diff --git a/src/core/tasks/url/operators/record_type/queries/get.py b/src/core/tasks/url/operators/record_type/queries/get.py new file mode 100644 index 00000000..c5b7e7e0 --- /dev/null +++ b/src/core/tasks/url/operators/record_type/queries/get.py @@ -0,0 +1,36 @@ +from typing import Sequence + +from sqlalchemy import select, Row +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from src.core.tasks.url.operators.record_type.queries.cte import RecordTypeTaskPrerequisiteCTEContainer +from src.db.dto_converter import DTOConverter +from src.db.dtos.url.with_html import URLWithHTML +from src.db.models.impl.url.core.sqlalchemy import URL +from src.db.queries.base.builder import QueryBuilderBase + + +class GetRecordTypeTaskURLsQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> list[URLWithHTML]: + cte = RecordTypeTaskPrerequisiteCTEContainer() + query = ( + select( + URL + ) + .join( + cte.cte, + cte.url_id == URL.id + ) + .options( + selectinload(URL.html_content) + ) + .limit(100) + .order_by(URL.id) + ) + urls: Sequence[Row[URL]] = await self.sh.scalars( + session=session, + query=query + ) + return DTOConverter.url_list_to_url_with_html_list(urls) diff --git a/src/core/tasks/url/operators/record_type/queries/prereq.py b/src/core/tasks/url/operators/record_type/queries/prereq.py new file mode 100644 index 00000000..32b70adb --- /dev/null +++ b/src/core/tasks/url/operators/record_type/queries/prereq.py @@ -0,0 +1,18 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.core.tasks.url.operators.record_type.queries.cte import RecordTypeTaskPrerequisiteCTEContainer +from src.db.queries.base.builder import QueryBuilderBase + + +class RecordTypeTaskPrerequisiteQueryBuilder(QueryBuilderBase): + + async def run(self, session: AsyncSession) -> bool: + container = RecordTypeTaskPrerequisiteCTEContainer() + query = ( + select( + container.url_id + ) + ) + return await self.sh.results_exist(session=session, query=query) + diff --git a/src/db/client/async_.py b/src/db/client/async_.py index 10ee5b6c..913a0a35 100644 --- a/src/db/client/async_.py +++ b/src/db/client/async_.py @@ -1,10 +1,9 @@ from datetime import datetime from functools import wraps -from typing import Optional, Type, Any, List, Sequence +from typing import Optional, Any, List -from sqlalchemy import select, func, Select, and_, update, Row, text, Engine +from sqlalchemy import select, func, Select, and_, update, Row, text from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker, AsyncEngine -from sqlalchemy.orm import selectinload from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse from src.api.endpoints.annotate.all.get.queries.core import GetNextURLForAllAnnotationQueryBuilder @@ -50,7 +49,6 @@ from src.db.client.types import UserSuggestionModel from src.db.config_manager import ConfigManager from src.db.constants import PLACEHOLDER_AGENCY_NAME -from src.db.dto_converter import DTOConverter from src.db.dtos.url.html_content import URLHTMLContentInfo from src.db.dtos.url.insert import InsertURLsInfo from src.db.dtos.url.raw_html import RawHTMLInfo @@ -286,18 +284,6 @@ async def add_user_relevant_suggestion( # region record_type - @session_manager - async def add_auto_record_type_suggestions( - self, - session: AsyncSession, - url_and_record_type_list: list[tuple[int, RecordType]] - ): - for url_id, record_type in url_and_record_type_list: - suggestion = AutoRecordTypeSuggestion( - url_id=url_id, - record_type=record_type.value - ) - session.add(suggestion) async def add_auto_record_type_suggestion( self, @@ -381,59 +367,6 @@ async def add_miscellaneous_metadata(self, session: AsyncSession, tdos: list[URL async def get_non_errored_urls_without_html_data(self) -> list[URLInfo]: return await self.run_query_builder(GetPendingURLsWithoutHTMLDataQueryBuilder()) - async def get_urls_with_html_data_and_without_models( - self, - session: AsyncSession, - model: Type[Base] - ): - statement = (select(URL) - .options(selectinload(URL.html_content)) - .where(URL.status == URLStatus.OK.value)) - statement = self.statement_composer.exclude_urls_with_extant_model( - statement=statement, - model=model - ) - statement = statement.limit(100).order_by(URL.id) - raw_result = await session.execute(statement) - urls: Sequence[Row[URL]] = raw_result.unique().scalars().all() - final_results = DTOConverter.url_list_to_url_with_html_list(urls) - - return final_results - - @session_manager - async def get_urls_with_html_data_and_without_auto_record_type_suggestion( - self, - session: AsyncSession - ): - return await self.get_urls_with_html_data_and_without_models( - session=session, - model=AutoRecordTypeSuggestion - ) - - async def has_urls_with_html_data_and_without_models( - self, - session: AsyncSession, - model: Type[Base] - ) -> bool: - statement = (select(URL) - .join(URLCompressedHTML) - .where(URL.status == URLStatus.OK.value)) - # Exclude URLs with auto suggested record types - statement = self.statement_composer.exclude_urls_with_extant_model( - statement=statement, - model=model - ) - statement = statement.limit(1) - scalar_result = await session.scalars(statement) - return bool(scalar_result.first()) - - @session_manager - async def has_urls_with_html_data_and_without_auto_record_type_suggestion(self, session: AsyncSession) -> bool: - return await self.has_urls_with_html_data_and_without_models( - session=session, - model=AutoRecordTypeSuggestion - ) - @session_manager async def one_or_none_model( self, diff --git a/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py b/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py index 1373f3fa..57f41ded 100644 --- a/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py +++ b/tests/automated/integration/tasks/url/impl/test_url_record_type_task.py @@ -7,6 +7,7 @@ from src.core.tasks.url.enums import TaskOperatorOutcome from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator from src.core.enums import RecordType +from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError from tests.helpers.data_creator.core import DBDataCreator from src.core.tasks.url.operators.record_type.llm_api.record_classifier.deepseek import DeepSeekRecordClassifier @@ -52,3 +53,9 @@ async def test_url_record_type_task(db_data_creator: DBDataCreator): for suggestion in suggestions: assert suggestion.record_type == RecordType.ACCIDENT_REPORTS.value + # Get URL Error Tasks + url_error_tasks: list[URLTaskError] = await db_data_creator.adb_client.get_all(URLTaskError) + assert len(url_error_tasks) == 1 + url_error_task = url_error_tasks[0] + assert url_error_task.url_id == url_ids[1] + assert url_error_task.task_type == TaskType.RECORD_TYPE \ No newline at end of file