Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions src/core/tasks/url/operators/record_type/core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
from src.core.enums import RecordType
from src.core.tasks.url.operators.base import URLTaskOperatorBase
from src.core.tasks.url.operators.record_type.llm_api.record_classifier.openai import OpenAIRecordClassifier
from src.core.tasks.url.operators.record_type.queries.get import GetRecordTypeTaskURLsQueryBuilder
from src.core.tasks.url.operators.record_type.queries.prereq import RecordTypeTaskPrerequisiteQueryBuilder
from src.core.tasks.url.operators.record_type.tdo import URLRecordTypeTDO
from src.db.client.async_ import AsyncDatabaseClient
from src.db.dtos.url.with_html import URLWithHTML
from src.db.enums import TaskType
from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion
from src.db.models.impl.url.task_error.pydantic_.small import URLTaskErrorSmall


Expand All @@ -18,18 +22,22 @@
self.classifier = classifier

@property
def task_type(self) -> TaskType:
    """Return the task type this operator handles: ``TaskType.RECORD_TYPE``."""
    return TaskType.RECORD_TYPE

async def meets_task_prerequisites(self) -> bool:
    """Report whether there is any URL for the record type task to process.

    Delegates to ``RecordTypeTaskPrerequisiteQueryBuilder`` through
    ``self.run_query_builder`` and returns that builder's result unchanged.

    Returns:
        The boolean produced by the prerequisite query builder.
    """
    return await self.run_query_builder(
        RecordTypeTaskPrerequisiteQueryBuilder()
    )

async def get_tdos(self) -> list[URLRecordTypeTDO]:
    """Fetch the URLs selected for this task and wrap each one in a TDO.

    Returns:
        One ``URLRecordTypeTDO`` per ``URLWithHTML`` returned by
        ``GetRecordTypeTaskURLsQueryBuilder``, in the order the query
        returned them.
    """
    urls_with_html: list[URLWithHTML] = await self.run_query_builder(
        GetRecordTypeTaskURLsQueryBuilder()
    )
    return [
        URLRecordTypeTDO(url_with_html=url_with_html)
        for url_with_html in urls_with_html
    ]

async def inner_task_logic(self):
async def inner_task_logic(self) -> None:

Check warning on line 40 in src/core/tasks/url/operators/record_type/core.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/core.py#L40 <102>

Missing docstring in public method
Raw output
./src/core/tasks/url/operators/record_type/core.py:40:1: D102 Missing docstring in public method
# Get pending urls from Source Collector
# with HTML data and without Record Type Metadata
tdos = await self.get_tdos()
Expand All @@ -41,7 +49,10 @@
await self.put_results_into_database(success_subset)
await self.update_errors_in_database(error_subset)

async def update_errors_in_database(self, tdos: list[URLRecordTypeTDO]):
async def update_errors_in_database(

Check warning on line 52 in src/core/tasks/url/operators/record_type/core.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/core.py#L52 <102>

Missing docstring in public method
Raw output
./src/core/tasks/url/operators/record_type/core.py:52:1: D102 Missing docstring in public method
self,
tdos: list[URLRecordTypeTDO]
) -> None:
task_errors: list[URLTaskErrorSmall] = []
for tdo in tdos:
error_info = URLTaskErrorSmall(
Expand All @@ -51,20 +62,42 @@
task_errors.append(error_info)
await self.add_task_errors(task_errors)

async def put_results_into_database(
    self,
    tdos: list[URLRecordTypeTDO]
) -> None:
    """Persist one automatic record type suggestion per TDO.

    Args:
        tdos: TDOs whose ``record_type`` has been populated. Each one's
            ``record_type.value`` is read, so an unset ``record_type``
            would raise here.
            NOTE(review): presumably callers pass only the success subset;
            confirm against ``inner_task_logic``.
    """
    # Build the ORM rows directly; an intermediate (url_id, record_type)
    # tuple list is not needed.
    suggestions: list[AutoRecordTypeSuggestion] = [
        AutoRecordTypeSuggestion(
            url_id=tdo.url_with_html.url_id,
            record_type=tdo.record_type.value
        )
        for tdo in tdos
    ]
    await self.adb_client.add_all(suggestions)

async def separate_success_and_error_subsets(self, tdos: list[URLRecordTypeTDO]):
@staticmethod
async def separate_success_and_error_subsets(

Check warning on line 85 in src/core/tasks/url/operators/record_type/core.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/core.py#L85 <102>

Missing docstring in public method
Raw output
./src/core/tasks/url/operators/record_type/core.py:85:1: D102 Missing docstring in public method
tdos: list[URLRecordTypeTDO]
) -> tuple[list[URLRecordTypeTDO], list[URLRecordTypeTDO]]:
success_subset = [tdo for tdo in tdos if not tdo.is_errored()]
error_subset = [tdo for tdo in tdos if tdo.is_errored()]
return success_subset, error_subset

async def get_ml_classifications(self, tdos: list[URLRecordTypeTDO]):
async def get_ml_classifications(
self,
tdos: list[URLRecordTypeTDO]
) -> None:
"""
Modifies:
- tdo.record_type
- tdo.error
"""
for tdo in tdos:
try:
record_type_str = await self.classifier.classify_url(tdo.url_with_html.html_infos)
Expand Down
Empty file.
31 changes: 31 additions & 0 deletions src/core/tasks/url/operators/record_type/queries/cte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from sqlalchemy import select, CTE, Column

Check warning on line 1 in src/core/tasks/url/operators/record_type/queries/cte.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/queries/cte.py#L1 <100>

Missing docstring in public module
Raw output
./src/core/tasks/url/operators/record_type/queries/cte.py:1:1: D100 Missing docstring in public module

from src.db.enums import TaskType
from src.db.helpers.query import not_exists_url, no_url_task_error
from src.db.models.impl.url.core.sqlalchemy import URL
from src.db.models.impl.url.html.compressed.sqlalchemy import URLCompressedHTML
from src.db.models.impl.url.suggestion.record_type.auto import AutoRecordTypeSuggestion


class RecordTypeTaskPrerequisiteCTEContainer:
    """Hold the CTE that selects ids of URLs eligible for the record type task.

    The CTE is named ``record_type_task_prerequisite`` and selects ``URL.id``
    joined to ``URLCompressedHTML``, so only URLs with a compressed HTML row
    are included.
    """

    def __init__(self) -> None:
        """Build the CTE and store it on ``self.cte``."""
        self.cte: CTE = (
            select(
                URL.id
            )
            .join(
                URLCompressedHTML
            )
            .where(
                # Presumably excludes URLs that already have an automatic
                # record type suggestion -- confirm against not_exists_url.
                not_exists_url(AutoRecordTypeSuggestion),
                # Presumably excludes URLs with a recorded error for the
                # RECORD_TYPE task -- confirm against no_url_task_error.
                no_url_task_error(
                    TaskType.RECORD_TYPE
                )
            )
            .cte("record_type_task_prerequisite")
        )

    @property
    def url_id(self) -> Column[int]:
        """Return the CTE's ``id`` column (the selected ``URL.id``)."""
        return self.cte.columns.id

Check warning on line 31 in src/core/tasks/url/operators/record_type/queries/cte.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/queries/cte.py#L31 <292>

no newline at end of file
Raw output
./src/core/tasks/url/operators/record_type/queries/cte.py:31:35: W292 no newline at end of file
36 changes: 36 additions & 0 deletions src/core/tasks/url/operators/record_type/queries/get.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Sequence

Check warning on line 1 in src/core/tasks/url/operators/record_type/queries/get.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/queries/get.py#L1 <100>

Missing docstring in public module
Raw output
./src/core/tasks/url/operators/record_type/queries/get.py:1:1: D100 Missing docstring in public module

from sqlalchemy import select, Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from src.core.tasks.url.operators.record_type.queries.cte import RecordTypeTaskPrerequisiteCTEContainer
from src.db.dto_converter import DTOConverter
from src.db.dtos.url.with_html import URLWithHTML
from src.db.models.impl.url.core.sqlalchemy import URL
from src.db.queries.base.builder import QueryBuilderBase


class GetRecordTypeTaskURLsQueryBuilder(QueryBuilderBase):
    """Fetch a batch of URLs, with HTML content, for the record type task."""

    async def run(self, session: AsyncSession) -> list[URLWithHTML]:
        """Select up to 100 eligible URLs ordered by id, with HTML loaded.

        Eligibility is defined by ``RecordTypeTaskPrerequisiteCTEContainer``;
        the query joins ``URL`` to that CTE on the URL id.

        Args:
            session: Async session the query is executed on.

        Returns:
            The selected URLs converted by
            ``DTOConverter.url_list_to_url_with_html_list``.
        """
        cte = RecordTypeTaskPrerequisiteCTEContainer()
        query = (
            select(
                URL
            )
            .join(
                cte.cte,
                cte.url_id == URL.id
            )
            .options(
                # Load html_content with the URLs; the converter below is
                # presumably what reads it.
                selectinload(URL.html_content)
            )
            # Written in the order SQL applies them (ORDER BY, then LIMIT);
            # SQLAlchemy emits the same statement either way.
            .order_by(URL.id)
            .limit(100)
        )
        urls: Sequence[Row[URL]] = await self.sh.scalars(
            session=session,
            query=query
        )
        return DTOConverter.url_list_to_url_with_html_list(urls)
18 changes: 18 additions & 0 deletions src/core/tasks/url/operators/record_type/queries/prereq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from sqlalchemy import select

Check warning on line 1 in src/core/tasks/url/operators/record_type/queries/prereq.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/queries/prereq.py#L1 <100>

Missing docstring in public module
Raw output
./src/core/tasks/url/operators/record_type/queries/prereq.py:1:1: D100 Missing docstring in public module
from sqlalchemy.ext.asyncio import AsyncSession

from src.core.tasks.url.operators.record_type.queries.cte import RecordTypeTaskPrerequisiteCTEContainer
from src.db.queries.base.builder import QueryBuilderBase


class RecordTypeTaskPrerequisiteQueryBuilder(QueryBuilderBase):
    """Check whether any URL is eligible for the record type task."""

    async def run(self, session: AsyncSession) -> bool:
        """Return whether the prerequisite CTE yields any URL id.

        Args:
            session: Async session the query is executed on.

        Returns:
            The result of ``self.sh.results_exist`` for a select over the
            CTE's URL id column (presumably ``True`` when at least one row
            exists -- confirm against the session helper).
        """
        container = RecordTypeTaskPrerequisiteCTEContainer()
        query = (
            select(
                container.url_id
            )
        )
        return await self.sh.results_exist(session=session, query=query)

Check warning on line 18 in src/core/tasks/url/operators/record_type/queries/prereq.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] src/core/tasks/url/operators/record_type/queries/prereq.py#L18 <391>

blank line at end of file
Raw output
./src/core/tasks/url/operators/record_type/queries/prereq.py:18:1: W391 blank line at end of file
71 changes: 2 additions & 69 deletions src/db/client/async_.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from datetime import datetime
from functools import wraps
from typing import Optional, Type, Any, List, Sequence
from typing import Optional, Any, List

from sqlalchemy import select, func, Select, and_, update, Row, text, Engine
from sqlalchemy import select, func, Select, and_, update, Row, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker, AsyncEngine
from sqlalchemy.orm import selectinload

from src.api.endpoints.annotate.all.get.models.response import GetNextURLForAllAnnotationResponse
from src.api.endpoints.annotate.all.get.queries.core import GetNextURLForAllAnnotationQueryBuilder
Expand Down Expand Up @@ -50,7 +49,6 @@
from src.db.client.types import UserSuggestionModel
from src.db.config_manager import ConfigManager
from src.db.constants import PLACEHOLDER_AGENCY_NAME
from src.db.dto_converter import DTOConverter
from src.db.dtos.url.html_content import URLHTMLContentInfo
from src.db.dtos.url.insert import InsertURLsInfo
from src.db.dtos.url.raw_html import RawHTMLInfo
Expand Down Expand Up @@ -286,18 +284,6 @@ async def add_user_relevant_suggestion(

# region record_type

@session_manager
async def add_auto_record_type_suggestions(
self,
session: AsyncSession,
url_and_record_type_list: list[tuple[int, RecordType]]
):
for url_id, record_type in url_and_record_type_list:
suggestion = AutoRecordTypeSuggestion(
url_id=url_id,
record_type=record_type.value
)
session.add(suggestion)

async def add_auto_record_type_suggestion(
self,
Expand Down Expand Up @@ -381,59 +367,6 @@ async def add_miscellaneous_metadata(self, session: AsyncSession, tdos: list[URL
async def get_non_errored_urls_without_html_data(self) -> list[URLInfo]:
return await self.run_query_builder(GetPendingURLsWithoutHTMLDataQueryBuilder())

async def get_urls_with_html_data_and_without_models(
self,
session: AsyncSession,
model: Type[Base]
):
statement = (select(URL)
.options(selectinload(URL.html_content))
.where(URL.status == URLStatus.OK.value))
statement = self.statement_composer.exclude_urls_with_extant_model(
statement=statement,
model=model
)
statement = statement.limit(100).order_by(URL.id)
raw_result = await session.execute(statement)
urls: Sequence[Row[URL]] = raw_result.unique().scalars().all()
final_results = DTOConverter.url_list_to_url_with_html_list(urls)

return final_results

@session_manager
async def get_urls_with_html_data_and_without_auto_record_type_suggestion(
self,
session: AsyncSession
):
return await self.get_urls_with_html_data_and_without_models(
session=session,
model=AutoRecordTypeSuggestion
)

async def has_urls_with_html_data_and_without_models(
self,
session: AsyncSession,
model: Type[Base]
) -> bool:
statement = (select(URL)
.join(URLCompressedHTML)
.where(URL.status == URLStatus.OK.value))
# Exclude URLs with auto suggested record types
statement = self.statement_composer.exclude_urls_with_extant_model(
statement=statement,
model=model
)
statement = statement.limit(1)
scalar_result = await session.scalars(statement)
return bool(scalar_result.first())

@session_manager
async def has_urls_with_html_data_and_without_auto_record_type_suggestion(self, session: AsyncSession) -> bool:
return await self.has_urls_with_html_data_and_without_models(
session=session,
model=AutoRecordTypeSuggestion
)

@session_manager
async def one_or_none_model(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.core.tasks.url.enums import TaskOperatorOutcome
from src.core.tasks.url.operators.record_type.core import URLRecordTypeTaskOperator
from src.core.enums import RecordType
from src.db.models.impl.url.task_error.sqlalchemy import URLTaskError
from tests.helpers.data_creator.core import DBDataCreator
from src.core.tasks.url.operators.record_type.llm_api.record_classifier.deepseek import DeepSeekRecordClassifier

Expand Down Expand Up @@ -52,3 +53,9 @@
for suggestion in suggestions:
assert suggestion.record_type == RecordType.ACCIDENT_REPORTS.value

# Get URL Error Tasks
url_error_tasks: list[URLTaskError] = await db_data_creator.adb_client.get_all(URLTaskError)
assert len(url_error_tasks) == 1
url_error_task = url_error_tasks[0]
assert url_error_task.url_id == url_ids[1]
assert url_error_task.task_type == TaskType.RECORD_TYPE

Check warning on line 61 in tests/automated/integration/tasks/url/impl/test_url_record_type_task.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] tests/automated/integration/tasks/url/impl/test_url_record_type_task.py#L61 <292>

no newline at end of file
Raw output
./tests/automated/integration/tasks/url/impl/test_url_record_type_task.py:61:60: W292 no newline at end of file