Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 40 additions & 35 deletions collector_db/AsyncDatabaseClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1018,27 +1018,44 @@ def annotations_exist_subquery(model: Type[Base]):
).subquery()
)

def count_subquery(model: Type[Base]):
    """
    Build a subquery that counts the rows of ``model`` per URL.

    The resulting subquery exposes two columns: ``url_id`` and ``count``,
    where ``count`` is the number of rows sharing that ``url_id``.
    """
    url_id_column = model.url_id
    counts_per_url = select(
        url_id_column,
        func.count(url_id_column).label("count"),
    ).group_by(url_id_column)
    return counts_per_url.subquery()
user_models = [
UserRelevantSuggestion,
UserRecordTypeSuggestion,
UserUrlAgencySuggestion
]

models = [
AutoRelevantSuggestion,
UserRelevantSuggestion,
AutoRecordTypeSuggestion,
UserRecordTypeSuggestion,
AutomatedUrlAgencySuggestion,
UserUrlAgencySuggestion
*user_models
]

# The below relationships are joined directly to the URL
single_join_relationships = [
URL.html_content,
URL.auto_record_type_suggestion,
URL.auto_relevant_suggestion,
URL.user_relevant_suggestion,
URL.user_record_type_suggestion,
URL.optional_data_source_metadata,
]

# The below relationships are joined to entities that are joined to the URL
double_join_relationships = [
(URL.automated_agency_suggestions, AutomatedUrlAgencySuggestion.agency),
(URL.user_agency_suggestion, UserUrlAgencySuggestion.agency),
(URL.confirmed_agencies, ConfirmedURLAgency.agency)
]

exist_subqueries = [
annotations_exist_subquery(model=model)
for model in models
]
user_exist_subqueries = [
annotations_exist_subquery(model=model)
for model in user_models
]

sum_of_exist_subqueries = (
sum(
Expand All @@ -1064,39 +1081,27 @@ def count_subquery(model: Type[Base]):
subquery, URL.id == subquery.c.url_id
)

where_subqueries = [
subquery.c.exists == 1
for subquery in user_exist_subqueries
]

url_query = url_query.where(
URL.outcome == URLStatus.PENDING.value
and_(
URL.outcome == URLStatus.PENDING.value,
*where_subqueries
)
)
if batch_id is not None:
url_query = url_query.where(
URL.batch_id == batch_id
)

# The below relationships are joined directly to the URL
single_join_relationships = [
URL.html_content,
URL.auto_record_type_suggestion,
URL.auto_relevant_suggestion,
URL.user_relevant_suggestion,
URL.user_record_type_suggestion,
URL.optional_data_source_metadata,
]

options = [
joinedload(relationship) for relationship in single_join_relationships
]

# The below relationships are joined to entities that are joined to the URL
double_join_relationships = [
(URL.automated_agency_suggestions, AutomatedUrlAgencySuggestion.agency),
(URL.user_agency_suggestion, UserUrlAgencySuggestion.agency),
(URL.confirmed_agencies, ConfirmedURLAgency.agency)
]
for primary, secondary in double_join_relationships:
options.append(joinedload(primary).joinedload(secondary))

# Apply options
url_query = url_query.options(*options)
url_query = url_query.options(
*[joinedload(relationship) for relationship in single_join_relationships],
*[joinedload(primary).joinedload(secondary) for primary, secondary in double_join_relationships]
)

# Apply order clause
url_query = url_query.order_by(
Expand Down
2 changes: 0 additions & 2 deletions core/TaskManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def __init__(
self.task_trigger = FunctionTrigger(self.run_tasks)
self.task_status: TaskType = TaskType.IDLE



#region Task Operators
async def get_url_html_task_operator(self):
operator = URLHTMLTaskOperator(
Expand Down
19 changes: 2 additions & 17 deletions tests/test_automated/integration/collector_db/test_db_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ async def test_get_next_url_for_final_review_favor_more_components(db_data_creat
async def test_get_next_url_for_final_review_no_annotations(db_data_creator: DBDataCreator):
"""
Test in the case of one URL with no annotations.
Should be returned if it is the only one available.
No annotations should be returned
"""
batch_id = db_data_creator.batch()
url_mapping = db_data_creator.urls(batch_id=batch_id, url_count=1).url_mappings[0]
Expand All @@ -282,22 +282,7 @@ async def test_get_next_url_for_final_review_no_annotations(db_data_creator: DBD
batch_id=None
)

assert result.id == url_mapping.url_id

annotations = result.annotations

agency = annotations.agency
assert agency.confirmed == []
assert agency.auto.unknown is True
assert agency.auto.suggestions == []

record_type = annotations.record_type
assert record_type.auto is None
assert record_type.user is None

relevant = annotations.relevant
assert relevant.auto is None
assert relevant.user is None
assert result is None

@pytest.mark.asyncio
async def test_get_next_url_for_final_review_only_confirmed_urls(db_data_creator: DBDataCreator):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,45 +65,3 @@ async def test_example_collector_lifecycle(

assert url_infos[0].url == "https://example.com"
assert url_infos[1].url == "https://example.com/2"

@pytest.mark.asyncio
async def test_example_collector_lifecycle_multiple_batches(
    test_core: SourceCollectorCore,
    test_async_core: AsyncCore,
    monkeypatch
):
    """
    Start three example collectors one after another and check their
    batch statuses: every batch must be IN_PROCESS before the barrier
    obtained from ``block_sleep`` is released, and READY_TO_LABEL shortly
    after it has been released.
    """
    # NOTE(review): block_sleep presumably patches sleeping so that the
    # collectors wait on the returned barrier -- confirm against the helper.
    barrier = await block_sleep(monkeypatch)

    # Collectors are started sequentially; each one gets its own input DTO.
    start_infos: list[CollectorStartInfo] = [
        await test_async_core.initiate_collector(
            collector_type=CollectorType.EXAMPLE,
            dto=ExampleInputDTO(
                example_field="example_value",
                sleep_time=1
            ),
            user_id=1
        )
        for _ in range(3)
    ]

    # Hand control back to the event loop once before inspecting statuses.
    await asyncio.sleep(0)

    for start_info in start_infos:
        print("Batch ID:", start_info.batch_id)
        assert test_core.get_status(start_info.batch_id) == BatchStatus.IN_PROCESS

    barrier.release()

    # Short pause before the final status check.
    await asyncio.sleep(0.15)

    for start_info in start_infos:
        assert test_core.get_status(start_info.batch_id) == BatchStatus.READY_TO_LABEL