diff --git a/alembic/versions/2025_06_03_0814-c1380f90f5de_remove_unused_batch_columns.py b/alembic/versions/2025_06_03_0814-c1380f90f5de_remove_unused_batch_columns.py new file mode 100644 index 00000000..f88d4b4c --- /dev/null +++ b/alembic/versions/2025_06_03_0814-c1380f90f5de_remove_unused_batch_columns.py @@ -0,0 +1,43 @@ +"""Remove unused batch columns + +Revision ID: c1380f90f5de +Revises: 00cc949e0347 +Create Date: 2025-06-03 08:14:15.583777 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'c1380f90f5de' +down_revision: Union[str, None] = '00cc949e0347' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + +TABLE_NAME = "batches" +TOTAL_URL_COUNT_COLUMN = "total_url_count" +ORIGINAL_URL_COUNT_COLUMN = "original_url_count" +DUPLICATE_URL_COUNT_COLUMN = "duplicate_url_count" + +def upgrade() -> None: + for column in [ + TOTAL_URL_COUNT_COLUMN, + ORIGINAL_URL_COUNT_COLUMN, + DUPLICATE_URL_COUNT_COLUMN, + ]: + op.drop_column(TABLE_NAME, column) + + +def downgrade() -> None: + for column in [ + TOTAL_URL_COUNT_COLUMN, + ORIGINAL_URL_COUNT_COLUMN, + DUPLICATE_URL_COUNT_COLUMN, + ]: + op.add_column( + TABLE_NAME, + sa.Column(column, sa.Integer(), nullable=False, server_default="0"),  # server_default so the NOT NULL columns can be re-added to a populated table + ) diff --git a/src/api/endpoints/batch/dtos/get/status.py b/src/api/endpoints/batch/dtos/get/status.py deleted file mode 100644 index a591b88e..00000000 --- a/src/api/endpoints/batch/dtos/get/status.py +++ /dev/null @@ -1,7 +0,0 @@ -from pydantic import BaseModel - -from src.db.dtos.batch_info import BatchInfo - - -class GetBatchStatusResponse(BaseModel): - results: list[BatchInfo] \ No newline at end of file diff --git a/src/api/endpoints/batch/dtos/get/summaries/__init__.py b/src/api/endpoints/batch/dtos/get/summaries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/api/endpoints/batch/dtos/get/summaries/counts.py b/src/api/endpoints/batch/dtos/get/summaries/counts.py new file mode 100644 index 00000000..0ce4e468 --- /dev/null +++ b/src/api/endpoints/batch/dtos/get/summaries/counts.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class BatchSummaryURLCounts(BaseModel): + total: int + pending: int + duplicate: int + not_relevant: int + submitted: int + errored: int diff --git a/src/api/endpoints/batch/dtos/get/summaries/response.py b/src/api/endpoints/batch/dtos/get/summaries/response.py new file mode 100644 index 00000000..9dead212 --- /dev/null +++ b/src/api/endpoints/batch/dtos/get/summaries/response.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary + + +class GetBatchSummariesResponse(BaseModel): + results: list[BatchSummary] diff --git a/src/api/endpoints/batch/dtos/get/summaries/summary.py b/src/api/endpoints/batch/dtos/get/summaries/summary.py new file mode 100644 index 00000000..f00a42a5 --- /dev/null +++ b/src/api/endpoints/batch/dtos/get/summaries/summary.py @@ -0,0 +1,18 @@ +import datetime +from typing import Optional + +from pydantic import BaseModel + +from src.api.endpoints.batch.dtos.get.summaries.counts import BatchSummaryURLCounts +from src.core.enums import BatchStatus + + +class BatchSummary(BaseModel): + id: int + strategy: str + status: BatchStatus + parameters: dict + user_id: int + compute_time: Optional[float] + date_generated: datetime.datetime + url_counts: BatchSummaryURLCounts diff --git 
a/src/api/endpoints/batch/routes.py b/src/api/endpoints/batch/routes.py index e79f7f14..bd3282fc 100644 --- a/src/api/endpoints/batch/routes.py +++ b/src/api/endpoints/batch/routes.py @@ -6,15 +6,15 @@ from src.api.dependencies import get_async_core from src.api.endpoints.batch.dtos.get.duplicates import GetDuplicatesByBatchResponse from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse -from src.api.endpoints.batch.dtos.get.status import GetBatchStatusResponse +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.api.endpoints.batch.dtos.get.urls import GetURLsByBatchResponse from src.api.endpoints.batch.dtos.post.abort import MessageResponse -from src.db.dtos.batch_info import BatchInfo from src.collectors.enums import CollectorType from src.core.core import AsyncCore from src.core.enums import BatchStatus -from src.security.manager import get_access_info from src.security.dtos.access_info import AccessInfo +from src.security.manager import get_access_info batch_router = APIRouter( prefix="/batch", @@ -43,7 +43,7 @@ async def get_batch_status( ), core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), -) -> GetBatchStatusResponse: +) -> GetBatchSummariesResponse: """ Get the status of recent batches """ @@ -60,9 +60,8 @@ async def get_batch_info( batch_id: int = Path(description="The batch id"), core: AsyncCore = Depends(get_async_core), access_info: AccessInfo = Depends(get_access_info), -) -> BatchInfo: - result = await core.get_batch_info(batch_id) - return result +) -> BatchSummary: + return await core.get_batch_info(batch_id) @batch_router.get("/{batch_id}/urls") async def get_urls_by_batch( diff --git a/src/core/core.py b/src/core/core.py index f6151a85..a37af7e3 100644 --- a/src/core/core.py +++ b/src/core/core.py @@ -1,5 +1,7 @@ +from http import HTTPStatus from typing import Optional +from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.exc import IntegrityError @@ -11,7 +13,8 @@ from src.api.endpoints.annotate.dtos.relevance.response import GetNextRelevanceAnnotationResponseOuterInfo from src.api.endpoints.batch.dtos.get.duplicates import GetDuplicatesByBatchResponse from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse -from src.api.endpoints.batch.dtos.get.status import GetBatchStatusResponse +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.api.endpoints.batch.dtos.get.urls import GetURLsByBatchResponse from src.api.endpoints.batch.dtos.post.abort import MessageResponse from src.api.endpoints.collector.dtos.collector_start import CollectorStartInfo @@ -62,8 +65,14 @@ async def shutdown(self): await self.collector_manager.shutdown_all_collectors() #region Batch - async def get_batch_info(self, batch_id: int) -> BatchInfo: - return await self.adb_client.get_batch_by_id(batch_id) + async def get_batch_info(self, batch_id: int) -> BatchSummary: + result = await self.adb_client.get_batch_by_id(batch_id) + if result is None: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Batch {batch_id} does not exist" + ) + return result async def get_urls_by_batch(self, batch_id: int, page: int = 1) -> GetURLsByBatchResponse: url_infos = await self.adb_client.get_urls_by_batch(batch_id, page) @@ -77,23 +86,20 @@ async def 
get_duplicate_urls_by_batch(self, batch_id: int, page: int = 1) -> Get dup_infos = await self.adb_client.get_duplicates_by_batch_id(batch_id, page=page) return GetDuplicatesByBatchResponse(duplicates=dup_infos) - async def get_batch_status(self, batch_id: int) -> BatchInfo: - return await self.adb_client.get_batch_by_id(batch_id) - async def get_batch_statuses( self, collector_type: Optional[CollectorType], status: Optional[BatchStatus], has_pending_urls: Optional[bool], page: int - ) -> GetBatchStatusResponse: - results = await self.adb_client.get_recent_batch_status_info( + ) -> GetBatchSummariesResponse: + results = await self.adb_client.get_batch_summaries( collector_type=collector_type, status=status, page=page, has_pending_urls=has_pending_urls ) - return GetBatchStatusResponse(results=results) + return results async def get_batch_logs(self, batch_id: int) -> GetBatchLogsResponse: logs = await self.adb_client.get_logs_by_batch_id(batch_id) diff --git a/src/db/client/async_.py b/src/db/client/async_.py index 05b394e7..1ab930c9 100644 --- a/src/db/client/async_.py +++ b/src/db/client/async_.py @@ -20,6 +20,8 @@ from src.api.endpoints.annotate.dtos.record_type.response import GetNextRecordTypeAnnotationResponseInfo from src.api.endpoints.annotate.dtos.relevance.response import GetNextRelevanceAnnotationResponseInfo from src.api.endpoints.annotate.dtos.shared.batch import AnnotationBatchInfo +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.api.endpoints.collector.dtos.manual_batch.post import ManualBatchInputDTO from src.api.endpoints.collector.dtos.manual_batch.response import ManualBatchResponseDTO from src.api.endpoints.metrics.dtos.get.backlog import GetMetricsBacklogResponseDTO, GetMetricsBacklogResponseInnerDTO @@ -52,8 +54,9 @@ from src.db.dtos.url_html_content_info import URLHTMLContentInfo, HTMLContentType from src.db.dtos.url_info import URLInfo from src.db.dtos.url_mapping import URLMapping +from src.db.queries.implementations.core.get_recent_batch_summaries.builder import GetRecentBatchSummariesQueryBuilder from src.db.statement_composer import StatementComposer -from src.db.constants import PLACEHOLDER_AGENCY_NAME +from src.db.constants import PLACEHOLDER_AGENCY_NAME, STANDARD_ROW_LIMIT from src.db.enums import TaskType from src.db.models.templates import Base from src.db.models.core import URL, URLErrorInfo, URLHTMLContent, \ @@ -1363,12 +1366,17 @@ async def reject_url( session.add(rejecting_user_url) @session_manager - async def get_batch_by_id(self, session, batch_id: int) -> Optional[BatchInfo]: + async def get_batch_by_id(self, session, batch_id: int) -> Optional[BatchSummary]: """Retrieve a batch by ID.""" - query = Select(Batch).where(Batch.id == batch_id) - result = await session.execute(query) - batch = result.scalars().first() - return BatchInfo(**batch.__dict__) + builder = GetRecentBatchSummariesQueryBuilder( + batch_id=batch_id + ) + summaries = await builder.run(session) + if len(summaries) == 0: + return None + batch_summary = summaries[0] + return batch_summary + @session_manager async def get_urls_by_batch(self, session, batch_id: int, page: int = 1) -> List[URLInfo]: @@ -1432,15 +1440,12 @@ async def insert_batch(self, session: AsyncSession, batch_info: BatchInfo) -> in user_id=batch_info.user_id, status=batch_info.status.value, parameters=batch_info.parameters, - total_url_count=batch_info.total_url_count, - 
original_url_count=batch_info.original_url_count, - duplicate_url_count=batch_info.duplicate_url_count, compute_time=batch_info.compute_time, - strategy_success_rate=batch_info.strategy_success_rate, - metadata_success_rate=batch_info.metadata_success_rate, - agency_match_rate=batch_info.agency_match_rate, - record_type_match_rate=batch_info.record_type_match_rate, - record_category_match_rate=batch_info.record_category_match_rate, + strategy_success_rate=0, + metadata_success_rate=0, + agency_match_rate=0, + record_type_match_rate=0, + record_category_match_rate=0, ) if batch_info.date_generated is not None: batch.date_generated = batch_info.date_generated @@ -1618,52 +1623,25 @@ async def get_duplicates_by_batch_id(self, session, batch_id: int, page: int) -> return final_results @session_manager - async def get_recent_batch_status_info( + async def get_batch_summaries( self, session, page: int, collector_type: Optional[CollectorType] = None, status: Optional[BatchStatus] = None, has_pending_urls: Optional[bool] = None - ) -> List[BatchInfo]: + ) -> GetBatchSummariesResponse: # Get only the batch_id, collector_type, status, and created_at - limit = 100 - query = Select(Batch) - if has_pending_urls is not None: - pending_url_subquery = Select(URL).where( - and_( - URL.batch_id == Batch.id, - URL.outcome == URLStatus.PENDING.value - ) - ) - - if has_pending_urls: - # Query for all that have pending URLs - query = query.where(exists( - pending_url_subquery - )) - else: - # Query for all that DO NOT have pending URLs - # (or that have no URLs at all) - query = query.where( - not_( - exists( - pending_url_subquery - ) - ) - ) - if collector_type: - query = query.filter(Batch.strategy == collector_type.value) - if status: - query = query.filter(Batch.status == status.value) - - query = (query. - order_by(Batch.date_generated.desc()). - limit(limit). 
- offset((page - 1) * limit)) - raw_results = await session.execute(query) - batches = raw_results.scalars().all() - return [BatchInfo(**batch.__dict__) for batch in batches] + builder = GetRecentBatchSummariesQueryBuilder( + page=page, + collector_type=collector_type, + status=status, + has_pending_urls=has_pending_urls + ) + summaries = await builder.run(session) + return GetBatchSummariesResponse( + results=summaries + ) @session_manager async def get_logs_by_batch_id(self, session, batch_id: int) -> List[LogOutputInfo]: diff --git a/src/db/client/sync.py b/src/db/client/sync.py index 67d432fc..8957ac92 100644 --- a/src/db/client/sync.py +++ b/src/db/client/sync.py @@ -61,15 +61,12 @@ def insert_batch(self, session, batch_info: BatchInfo) -> int: user_id=batch_info.user_id, status=batch_info.status.value, parameters=batch_info.parameters, - total_url_count=batch_info.total_url_count, - original_url_count=batch_info.original_url_count, - duplicate_url_count=batch_info.duplicate_url_count, compute_time=batch_info.compute_time, - strategy_success_rate=batch_info.strategy_success_rate, - metadata_success_rate=batch_info.metadata_success_rate, - agency_match_rate=batch_info.agency_match_rate, - record_type_match_rate=batch_info.record_type_match_rate, - record_category_match_rate=batch_info.record_category_match_rate, + strategy_success_rate=0, + metadata_success_rate=0, + agency_match_rate=0, + record_type_match_rate=0, + record_category_match_rate=0, ) if batch_info.date_generated is not None: batch.date_generated = batch_info.date_generated diff --git a/src/db/constants.py b/src/db/constants.py index 294c8fd9..1f39857d 100644 --- a/src/db/constants.py +++ b/src/db/constants.py @@ -1,3 +1,5 @@ -PLACEHOLDER_AGENCY_NAME = "PLACEHOLDER_AGENCY_NAME" \ No newline at end of file +PLACEHOLDER_AGENCY_NAME = "PLACEHOLDER_AGENCY_NAME" + +STANDARD_ROW_LIMIT = 100 \ No newline at end of file diff --git a/src/db/dtos/batch_info.py b/src/db/dtos/batch_info.py index db5505bc..3e1d265b 100644 --- a/src/db/dtos/batch_info.py +++ b/src/db/dtos/batch_info.py @@ -12,13 +12,6 @@ class BatchInfo(BaseModel): status: BatchStatus parameters: dict user_id: int - total_url_count: int = 0 - original_url_count: int = 0 - duplicate_url_count: int = 0 - strategy_success_rate: Optional[float] = None - metadata_success_rate: Optional[float] = None - agency_match_rate: Optional[float] = None - record_type_match_rate: Optional[float] = None - record_category_match_rate: Optional[float] = None + total_url_count: Optional[int] = None compute_time: Optional[float] = None date_generated: Optional[datetime] = None diff --git a/src/db/models/core.py b/src/db/models/core.py index 54c9b091..fe64b8df 100644 --- a/src/db/models/core.py +++ b/src/db/models/core.py @@ -43,11 +43,6 @@ class Batch(StandardModel): batch_status_enum, nullable=False ) - # The number of URLs in the batch - # TODO: Add means to update after execution - total_url_count = Column(Integer, nullable=False, default=0) - original_url_count = Column(Integer, nullable=False, default=0) - duplicate_url_count = Column(Integer, nullable=False, default=0) date_generated = Column(TIMESTAMP, nullable=False, server_default=CURRENT_TIME_SERVER_DEFAULT) # How often URLs ended up approved in the database strategy_success_rate = Column(Float) diff --git a/src/db/queries/README.md b/src/db/queries/README.md new file mode 100644 index 00000000..7918b0ba --- /dev/null +++ b/src/db/queries/README.md @@ -0,0 +1 @@ +This directory contains classes for building more complex queries. 
\ No newline at end of file diff --git a/src/db/queries/__init__.py b/src/db/queries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/base/__init__.py b/src/db/queries/base/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/base/builder.py b/src/db/queries/base/builder.py new file mode 100644 index 00000000..5806ef47 --- /dev/null +++ b/src/db/queries/base/builder.py @@ -0,0 +1,41 @@ +from typing import Any, Generic, Optional + +from sqlalchemy import FromClause, ColumnClause +from sqlalchemy.dialects import postgresql +from sqlalchemy.ext.asyncio import AsyncSession + +from src.db.types import LabelsType + + +class QueryBuilderBase(Generic[LabelsType]): + + def __init__(self, labels: Optional[LabelsType] = None): + self.query: Optional[FromClause] = None + self.labels = labels + + def get(self, key: str) -> ColumnClause: + return getattr(self.query.c, key) + + def get_all(self) -> list[Any]: + results = [] + for label in self.labels.get_all_labels(): + results.append(self.get(label)) + return results + + def __getitem__(self, key: str) -> ColumnClause: + return self.get(key) + + async def build(self) -> Any: + raise NotImplementedError + + async def run(self, session: AsyncSession) -> Any: + raise NotImplementedError + + @staticmethod + def compile(query) -> Any: + return query.compile( + dialect=postgresql.dialect(), + compile_kwargs={ + "literal_binds": True + } + ) diff --git a/src/db/queries/base/labels.py b/src/db/queries/base/labels.py new file mode 100644 index 00000000..50f812ce --- /dev/null +++ b/src/db/queries/base/labels.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass, fields + + +@dataclass(frozen=True) +class LabelsBase: + + def get_all_labels(self) -> list[str]: + return [getattr(self, f.name) for f in fields(self)] \ No newline at end of file diff --git a/src/db/queries/helpers.py b/src/db/queries/helpers.py new file mode 100644 index 00000000..c37984cb --- /dev/null +++ b/src/db/queries/helpers.py @@ -0,0 +1,8 @@ +from sqlalchemy.dialects import postgresql + +from src.db.constants import STANDARD_ROW_LIMIT + + +def add_page_offset(statement, page, limit=STANDARD_ROW_LIMIT): + offset = (page - 1) * limit + return statement.limit(limit).offset(offset) diff --git a/src/db/queries/implementations/__init__.py b/src/db/queries/implementations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/__init__.py b/src/db/queries/implementations/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/get_recent_batch_summaries/__init__.py b/src/db/queries/implementations/core/get_recent_batch_summaries/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/get_recent_batch_summaries/builder.py b/src/db/queries/implementations/core/get_recent_batch_summaries/builder.py new file mode 100644 index 00000000..c9958d6f --- /dev/null +++ b/src/db/queries/implementations/core/get_recent_batch_summaries/builder.py @@ -0,0 +1,76 @@ +from typing import Optional + +from sqlalchemy import Select +from sqlalchemy.ext.asyncio import AsyncSession + +from src.api.endpoints.batch.dtos.get.summaries.counts import BatchSummaryURLCounts +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary +from src.collectors.enums import CollectorType +from src.core.enums import BatchStatus +from src.db.models.core import Batch +from src.db.queries.base.builder 
import QueryBuilderBase +from src.db.queries.implementations.core.get_recent_batch_summaries.url_counts.builder import URLCountsCTEQueryBuilder +from src.db.queries.implementations.core.get_recent_batch_summaries.url_counts.labels import URLCountsLabels + + +class GetRecentBatchSummariesQueryBuilder(QueryBuilderBase): + + def __init__( + self, + page: int = 1, + has_pending_urls: Optional[bool] = None, + collector_type: Optional[CollectorType] = None, + status: Optional[BatchStatus] = None, + batch_id: Optional[int] = None, + ): + super().__init__() + self.url_counts_cte = URLCountsCTEQueryBuilder( + page=page, + has_pending_urls=has_pending_urls, + collector_type=collector_type, + status=status, + batch_id=batch_id, + ) + + async def run(self, session: AsyncSession) -> list[BatchSummary]: + self.url_counts_cte.build() + builder = self.url_counts_cte + count_labels: URLCountsLabels = builder.labels + + query = Select( + *builder.get_all(), + Batch.strategy, + Batch.status, + Batch.parameters, + Batch.user_id, + Batch.compute_time, + Batch.date_generated, + ).join( + builder.query, + builder.get(count_labels.batch_id) == Batch.id, + ) + raw_results = await session.execute(query) + + summaries: list[BatchSummary] = [] + for row in raw_results.mappings().all(): + summaries.append( + BatchSummary( + id=row.id, + strategy=row.strategy, + status=row.status, + parameters=row.parameters, + user_id=row.user_id, + compute_time=row.compute_time, + date_generated=row.date_generated, + url_counts=BatchSummaryURLCounts( + total=row[count_labels.total], + duplicate=row[count_labels.duplicate], + not_relevant=row[count_labels.not_relevant], + submitted=row[count_labels.submitted], + errored=row[count_labels.error], + pending=row[count_labels.pending], + ), + ) + ) + + return summaries diff --git a/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/__init__.py b/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/builder.py b/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/builder.py new file mode 100644 index 00000000..bd6c382a --- /dev/null +++ b/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/builder.py @@ -0,0 +1,105 @@ +from typing import Optional + +from sqlalchemy import Select, case, Label, and_, exists +from sqlalchemy.sql.functions import count, coalesce + +from src.collectors.enums import URLStatus, CollectorType +from src.core.enums import BatchStatus +from src.db.models.core import Batch, URL +from src.db.queries.base.builder import QueryBuilderBase +from src.db.queries.helpers import add_page_offset +from src.db.queries.implementations.core.get_recent_batch_summaries.url_counts.labels import URLCountsLabels + + +class URLCountsCTEQueryBuilder(QueryBuilderBase): + + def __init__( + self, + page: int = 1, + has_pending_urls: Optional[bool] = None, + collector_type: Optional[CollectorType] = None, + status: Optional[BatchStatus] = None, + batch_id: Optional[int] = None + ): + super().__init__(URLCountsLabels()) + self.page = page + self.has_pending_urls = has_pending_urls + self.collector_type = collector_type + self.status = status + self.batch_id = batch_id + + + def get_core_query(self): + labels: URLCountsLabels = self.labels + return ( + Select( + Batch.id.label(labels.batch_id), + coalesce(count(URL.id), 0).label(labels.total), + 
self.count_case_url_status(URLStatus.PENDING, labels.pending), + self.count_case_url_status(URLStatus.SUBMITTED, labels.submitted), + self.count_case_url_status(URLStatus.NOT_RELEVANT, labels.not_relevant), + self.count_case_url_status(URLStatus.ERROR, labels.error), + self.count_case_url_status(URLStatus.DUPLICATE, labels.duplicate), + ).outerjoin( + URL + ) + ) + + + def build(self): + query = self.get_core_query() + query = self.apply_pending_urls_filter(query) + query = self.apply_collector_type_filter(query) + query = self.apply_status_filter(query) + query = self.apply_batch_id_filter(query) + query = query.group_by(Batch.id) + query = add_page_offset(query, page=self.page) + query = query.order_by(Batch.id) + self.query = query.cte("url_counts") + + def apply_batch_id_filter(self, query: Select): + if self.batch_id is None: + return query + return query.where(Batch.id == self.batch_id) + + def apply_pending_urls_filter(self, query: Select): + if self.has_pending_urls is None: + return query + pending_url_subquery = ( + exists( + Select(URL).where( + and_( + URL.batch_id == Batch.id, + URL.outcome == URLStatus.PENDING.value + ) + ) + ) + ).correlate(Batch) + if self.has_pending_urls: + return query.where(pending_url_subquery) + return query.where(~pending_url_subquery) + + def apply_collector_type_filter(self, query: Select): + if self.collector_type is None: + return query + return query.where(Batch.strategy == self.collector_type.value) + + def apply_status_filter(self, query: Select): + if self.status is None: + return query + return query.where(Batch.status == self.status.value) + + @staticmethod + def count_case_url_status( + url_status: URLStatus, + label: str + ) -> Label: + return ( + coalesce( + count( + case( + (URL.outcome == url_status.value, 1) + ) + ) + , 0).label(label) + ) diff --git a/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/labels.py b/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/labels.py new file mode 100644 index 00000000..c55d8f45 --- /dev/null +++ b/src/db/queries/implementations/core/get_recent_batch_summaries/url_counts/labels.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +from src.db.queries.base.labels import LabelsBase + + +@dataclass(frozen=True) +class URLCountsLabels(LabelsBase): + batch_id: str = "id" + total: str = "count_total" + pending: str = "count_pending" + submitted: str = "count_submitted" + not_relevant: str = "count_not_relevant" + error: str = "count_error" + duplicate: str = "count_duplicate" + + diff --git a/src/db/queries/protocols.py b/src/db/queries/protocols.py new file mode 100644 index 00000000..0098e953 --- /dev/null +++ b/src/db/queries/protocols.py @@ -0,0 +1,9 @@ +from typing import Protocol, Optional + +from sqlalchemy import Select + + +class HasQuery(Protocol): + + def __init__(self): + self.query: Optional[Select] = None diff --git a/src/db/statement_composer.py b/src/db/statement_composer.py index cd313b59..670bc7a6 100644 --- a/src/db/statement_composer.py +++ b/src/db/statement_composer.py @@ -4,10 +4,11 @@ from sqlalchemy.orm import aliased from src.collectors.enums import URLStatus +from src.core.enums import BatchStatus +from src.db.constants import STANDARD_ROW_LIMIT from src.db.enums import TaskType from src.db.models.core import URL, URLHTMLContent, AutomatedUrlAgencySuggestion, URLOptionalDataSourceMetadata, Batch, \ - ConfirmedURLAgency, LinkTaskURL, Task, UserUrlAgencySuggestion, UserRecordTypeSuggestion, UserRelevantSuggestion -from 
src.core.enums import BatchStatus + ConfirmedURLAgency, LinkTaskURL, Task from src.db.types import UserSuggestionType @@ -18,13 +19,14 @@ class StatementComposer: @staticmethod def pending_urls_without_html_data() -> Select: - exclude_subquery = (select(1). - select_from(LinkTaskURL). - join(Task, LinkTaskURL.task_id == Task.id). - where(LinkTaskURL.url_id == URL.id). - where(Task.task_type == TaskType.HTML.value). - where(Task.task_status == BatchStatus.READY_TO_LABEL.value) - ) + exclude_subquery = ( + select(1). + select_from(LinkTaskURL). + join(Task, LinkTaskURL.task_id == Task.id). + where(LinkTaskURL.url_id == URL.id). + where(Task.task_type == TaskType.HTML.value). + where(Task.task_status == BatchStatus.READY_TO_LABEL.value) + ) query = ( select(URL). outerjoin(URLHTMLContent). @@ -32,15 +34,12 @@ def pending_urls_without_html_data() -> Select: where(~exists(exclude_subquery)). where(URL.outcome == URLStatus.PENDING.value) ) - - return query - @staticmethod def exclude_urls_with_extant_model( - statement: Select, - model: Any + statement: Select, + model: Any ): return (statement.where( ~exists( @@ -51,9 +50,6 @@ def exclude_urls_with_extant_model( ) )) - - - @staticmethod def simple_count_subquery(model, attribute: str, label: str) -> Subquery: attr_value = getattr(model, attribute) @@ -79,7 +75,6 @@ def exclude_urls_with_agency_suggestions( ) return statement - @staticmethod def pending_urls_missing_miscellaneous_metadata_query() -> Select: query = select(URL).where( @@ -109,19 +104,13 @@ def user_suggestion_exists( ) return subquery - @staticmethod def user_suggestion_not_exists( - model_to_exclude: UserUrlAgencySuggestion or - UserRecordTypeSuggestion or - UserRelevantSuggestion + model_to_exclude: UserSuggestionType ) -> ColumnElement[bool]: - # - subquery = not_( - StatementComposer.user_suggestion_exists(model_to_exclude) - ) - + StatementComposer.user_suggestion_exists(model_to_exclude) + ) return subquery @staticmethod @@ -131,3 +120,13 @@ def count_distinct(field, label): @staticmethod def sum_distinct(field, label): return func.sum(func.distinct(field)).label(label) + + @staticmethod + def add_limit_and_page_offset(query: Select, page: int): + zero_offset_page = page - 1 + rows_offset = zero_offset_page * STANDARD_ROW_LIMIT + return query.offset( + rows_offset + ).limit( + STANDARD_ROW_LIMIT + ) diff --git a/src/db/types.py b/src/db/types.py index 40dc9ef3..44350e6d 100644 --- a/src/db/types.py +++ b/src/db/types.py @@ -1,3 +1,8 @@ +from typing import TypeVar + from src.db.models.core import UserUrlAgencySuggestion, UserRecordTypeSuggestion, UserRelevantSuggestion +from src.db.queries.base.labels import LabelsBase + +UserSuggestionType = UserUrlAgencySuggestion | UserRelevantSuggestion | UserRecordTypeSuggestion -UserSuggestionType = UserUrlAgencySuggestion | UserRelevantSuggestion | UserRecordTypeSuggestion \ No newline at end of file +LabelsType = TypeVar("LabelsType", bound=LabelsBase) \ No newline at end of file diff --git a/tests/automated/integration/api/helpers/RequestValidator.py b/tests/automated/integration/api/helpers/RequestValidator.py index 1e94f144..ac9d43f6 100644 --- a/tests/automated/integration/api/helpers/RequestValidator.py +++ b/tests/automated/integration/api/helpers/RequestValidator.py @@ -15,7 +15,8 @@ from src.api.endpoints.annotate.dtos.relevance.response import GetNextRelevanceAnnotationResponseOuterInfo from src.api.endpoints.batch.dtos.get.duplicates import GetDuplicatesByBatchResponse from src.api.endpoints.batch.dtos.get.logs import 
GetBatchLogsResponse -from src.api.endpoints.batch.dtos.get.status import GetBatchStatusResponse +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.api.endpoints.batch.dtos.get.urls import GetURLsByBatchResponse from src.api.endpoints.batch.dtos.post.abort import MessageResponse from src.api.endpoints.collector.dtos.manual_batch.post import ManualBatchInputDTO @@ -194,7 +195,7 @@ def get_batch_statuses( collector_type: Optional[CollectorType] = None, status: Optional[BatchStatus] = None, has_pending_urls: Optional[bool] = None - ) -> GetBatchStatusResponse: + ) -> GetBatchSummariesResponse: params = {} update_if_not_none( target=params, @@ -208,7 +209,7 @@ def get_batch_statuses( url=f"/batch", params=params ) - return GetBatchStatusResponse(**data) + return GetBatchSummariesResponse(**data) def example_collector(self, dto: ExampleInputDTO) -> dict: data = self.post( @@ -217,11 +218,11 @@ def example_collector(self, dto: ExampleInputDTO) -> dict: ) return data - def get_batch_info(self, batch_id: int) -> BatchInfo: + def get_batch_info(self, batch_id: int) -> BatchSummary: data = self.get( url=f"/batch/{batch_id}" ) - return BatchInfo(**data) + return BatchSummary(**data) def get_batch_urls(self, batch_id: int, page: int = 1) -> GetURLsByBatchResponse: data = self.get( diff --git a/tests/automated/integration/api/test_batch.py b/tests/automated/integration/api/test_batch.py index d4900736..2f654b55 100644 --- a/tests/automated/integration/api/test_batch.py +++ b/tests/automated/integration/api/test_batch.py @@ -5,9 +5,104 @@ from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO from src.collectors.enums import CollectorType, URLStatus from src.core.enums import BatchStatus +from tests.helpers.test_batch_creation_parameters import TestBatchCreationParameters, TestURLCreationParameters + + +@pytest.mark.asyncio +async def test_get_batch_summaries(api_test_helper): + ath = api_test_helper + + batch_params = [ + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=1, + status=URLStatus.PENDING + ), + TestURLCreationParameters( + count=2, + status=URLStatus.SUBMITTED + ) + ] + ), + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=4, + status=URLStatus.NOT_RELEVANT + ), + TestURLCreationParameters( + count=3, + status=URLStatus.ERROR + ) + ] + ), + TestBatchCreationParameters( + urls=[ + TestURLCreationParameters( + count=7, + status=URLStatus.DUPLICATE + ), + TestURLCreationParameters( + count=1, + status=URLStatus.SUBMITTED + ) + ] + ) + ] + + batch_1_creation_info = await ath.db_data_creator.batch_v2(batch_params[0]) + batch_2_creation_info = await ath.db_data_creator.batch_v2(batch_params[1]) + batch_3_creation_info = await ath.db_data_creator.batch_v2(batch_params[2]) + + batch_1_id = batch_1_creation_info.batch_id + batch_2_id = batch_2_creation_info.batch_id + batch_3_id = batch_3_creation_info.batch_id + + + response = ath.request_validator.get_batch_statuses() + results = response.results + + assert len(results) == 3 + + result_1 = results[0] + assert result_1.id == batch_1_id + assert result_1.status == BatchStatus.READY_TO_LABEL + counts_1 = result_1.url_counts + assert counts_1.total == 3 + assert counts_1.pending == 1 + assert counts_1.submitted == 2 + assert counts_1.not_relevant == 0 + assert counts_1.duplicate == 0 + assert counts_1.errored == 0 + + result_2 = results[1] + assert 
result_2.id == batch_2_id + counts_2 = result_2.url_counts + assert counts_2.total == 7 + assert counts_2.not_relevant == 4 + assert counts_2.errored == 3 + assert counts_2.pending == 0 + assert counts_2.submitted == 0 + assert counts_2.duplicate == 0 + + result_3 = results[2] + assert result_3.id == batch_3_id + counts_3 = result_3.url_counts + assert counts_3.total == 8 + assert counts_3.not_relevant == 0 + assert counts_3.errored == 0 + assert counts_3.pending == 0 + assert counts_3.submitted == 1 + assert counts_3.duplicate == 7 + + + + + @pytest.mark.asyncio -async def test_get_batch_status_pending_url_filter(api_test_helper): +async def test_get_batch_summaries_pending_url_filter(api_test_helper): ath = api_test_helper # Add an errored out batch diff --git a/tests/automated/integration/api/test_duplicates.py b/tests/automated/integration/api/test_duplicates.py index e1b45be9..49c1e15d 100644 --- a/tests/automated/integration/api/test_duplicates.py +++ b/tests/automated/integration/api/test_duplicates.py @@ -1,5 +1,6 @@ import pytest +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.db.dtos.batch_info import BatchInfo from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO from tests.automated.integration.api.conftest import disable_task_trigger @@ -31,13 +32,13 @@ async def test_duplicates(api_test_helper): await ath.wait_for_all_batches_to_complete() - bi_1: BatchInfo = ath.request_validator.get_batch_info(batch_id_1) - bi_2: BatchInfo = ath.request_validator.get_batch_info(batch_id_2) + bi_1: BatchSummary = ath.request_validator.get_batch_info(batch_id_1) + bi_2: BatchSummary = ath.request_validator.get_batch_info(batch_id_2) - bi_1.total_url_count = 2 - bi_2.total_url_count = 0 - bi_1.duplicate_url_count = 0 - bi_2.duplicate_url_count = 2 + bi_1.url_counts.total = 2 + bi_2.url_counts.total = 0 + bi_1.url_counts.duplicate = 0 + bi_2.url_counts.duplicate = 2 url_info_1 = ath.request_validator.get_batch_urls(batch_id_1) url_info_2 = ath.request_validator.get_batch_urls(batch_id_2) diff --git a/tests/automated/integration/api/test_example_collector.py b/tests/automated/integration/api/test_example_collector.py index 3f7f40fa..a1c5694f 100644 --- a/tests/automated/integration/api/test_example_collector.py +++ b/tests/automated/integration/api/test_example_collector.py @@ -4,7 +4,8 @@ import pytest from src.api.endpoints.batch.dtos.get.logs import GetBatchLogsResponse -from src.api.endpoints.batch.dtos.get.status import GetBatchStatusResponse +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse +from src.api.endpoints.batch.dtos.get.summaries.summary import BatchSummary from src.db.client.async_ import AsyncDatabaseClient from src.db.dtos.batch_info import BatchInfo from src.collectors.source_collectors.example.dtos.input import ExampleInputDTO @@ -47,7 +48,7 @@ async def test_example_collector(api_test_helper, monkeypatch): # Check that batch currently shows as In Process - bsr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses( + bsr: GetBatchSummariesResponse = ath.request_validator.get_batch_statuses( status=BatchStatus.IN_PROCESS ) assert len(bsr.results) == 1 @@ -62,21 +63,21 @@ async def test_example_collector(api_test_helper, monkeypatch): await ath.wait_for_all_batches_to_complete() - csr: GetBatchStatusResponse = ath.request_validator.get_batch_statuses( + csr: GetBatchSummariesResponse = ath.request_validator.get_batch_statuses( 
collector_type=CollectorType.EXAMPLE, status=BatchStatus.READY_TO_LABEL ) assert len(csr.results) == 1 - bsi: BatchInfo = csr.results[0] + bsi: BatchSummary = csr.results[0] assert bsi.id == batch_id assert bsi.strategy == CollectorType.EXAMPLE.value assert bsi.status == BatchStatus.READY_TO_LABEL - bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id) + bi: BatchSummary = ath.request_validator.get_batch_info(batch_id=batch_id) assert bi.status == BatchStatus.READY_TO_LABEL - assert bi.total_url_count == 2 + assert bi.url_counts.total == 2 assert bi.parameters == dto.model_dump() assert bi.strategy == CollectorType.EXAMPLE.value assert bi.user_id is not None @@ -124,7 +125,7 @@ async def test_example_collector_error(api_test_helper, monkeypatch): await ath.wait_for_all_batches_to_complete() - bi: BatchInfo = ath.request_validator.get_batch_info(batch_id=batch_id) + bi: BatchSummary = ath.request_validator.get_batch_info(batch_id=batch_id) assert bi.status == BatchStatus.ERROR diff --git a/tests/automated/integration/collector_db/test_db_client.py b/tests/automated/integration/collector_db/test_db_client.py index 47fa5598..2269b98d 100644 --- a/tests/automated/integration/collector_db/test_db_client.py +++ b/tests/automated/integration/collector_db/test_db_client.py @@ -30,7 +30,7 @@ async def test_insert_urls( parameters={}, user_id=1 ) - batch_id = db_client_test.insert_batch(batch_info) + batch_id = await adb_client_test.insert_batch(batch_info) urls = [ URLInfo( diff --git a/tests/helpers/api_test_helper.py b/tests/helpers/api_test_helper.py index fa577b34..d5de78cd 100644 --- a/tests/helpers/api_test_helper.py +++ b/tests/helpers/api_test_helper.py @@ -1,7 +1,7 @@ import asyncio from dataclasses import dataclass -from src.api.endpoints.batch.dtos.get.status import GetBatchStatusResponse +from src.api.endpoints.batch.dtos.get.summaries.response import GetBatchSummariesResponse from src.core.core import AsyncCore from src.core.enums import BatchStatus from tests.automated.integration.api.helpers.RequestValidator import RequestValidator @@ -19,7 +19,7 @@ def adb_client(self): async def wait_for_all_batches_to_complete(self): for i in range(20): - data: GetBatchStatusResponse = self.request_validator.get_batch_statuses( + data: GetBatchSummariesResponse = self.request_validator.get_batch_statuses( status=BatchStatus.IN_PROCESS ) if len(data.results) == 0: diff --git a/tests/helpers/db_data_creator.py b/tests/helpers/db_data_creator.py index c96946ee..08f1220b 100644 --- a/tests/helpers/db_data_creator.py +++ b/tests/helpers/db_data_creator.py @@ -60,7 +60,6 @@ def batch( BatchInfo( strategy=strategy.value, status=batch_status, - total_url_count=1, parameters={"test_key": "test_value"}, user_id=1, date_generated=created_at