Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/routers/mldcat_ap/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Specific queries could be written to fetch e.g., a single feature or quality.
"""

import asyncio
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException
Expand Down Expand Up @@ -46,13 +47,16 @@ async def get_mldcat_ap_distribution(
) -> JsonLDGraph:
assert user_db is not None # noqa: S101
assert expdb is not None # noqa: S101
oml_dataset = await get_dataset(
dataset_id=distribution_id,
user=user,
user_db=user_db,
expdb_db=expdb,
oml_dataset, openml_features, oml_qualities = await asyncio.gather(
get_dataset(
dataset_id=distribution_id,
user=user,
user_db=user_db,
expdb_db=expdb,
),
get_dataset_features(distribution_id, user, expdb),
get_qualities(distribution_id, user, expdb),
)
openml_features = await get_dataset_features(distribution_id, user, expdb)
features = [
Feature(
id_=f"{_server_url}/feature/{distribution_id}/{feature.index}",
Expand All @@ -61,7 +65,6 @@ async def get_mldcat_ap_distribution(
)
for feature in openml_features
]
oml_qualities = await get_qualities(distribution_id, user, expdb)
qualities = [
Quality(
id_=f"{_server_url}/quality/{quality.name}/{distribution_id}",
Expand Down
17 changes: 11 additions & 6 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import re
from datetime import datetime
from enum import StrEnum
Expand Down Expand Up @@ -298,8 +299,10 @@ async def get_dataset_features(
) -> list[Feature]:
assert expdb is not None # noqa: S101
await _get_dataset_raise_otherwise(dataset_id, user, expdb)
features = await database.datasets.get_features(dataset_id, expdb)
ontologies = await database.datasets.get_feature_ontologies(dataset_id, expdb)
features, ontologies = await asyncio.gather(
database.datasets.get_features(dataset_id, expdb),
database.datasets.get_feature_ontologies(dataset_id, expdb),
)
for feature in features:
feature.ontology = ontologies.get(feature.index)

Expand Down Expand Up @@ -402,10 +405,12 @@ async def get_dataset(
msg = f"No data file found for dataset {dataset_id}."
raise DatasetNoDataFileError(msg)

tags = await database.datasets.get_tags_for(dataset_id, expdb_db)
description = await database.datasets.get_description(dataset_id, expdb_db)
processing_result = await _get_processing_information(dataset_id, expdb_db)
status = await database.datasets.get_status(dataset_id, expdb_db)
tags, description, processing_result, status = await asyncio.gather(
database.datasets.get_tags_for(dataset_id, expdb_db),
database.datasets.get_description(dataset_id, expdb_db),
_get_processing_information(dataset_id, expdb_db),
database.datasets.get_status(dataset_id, expdb_db),
)

status_ = DatasetStatus(status.status) if status else DatasetStatus.IN_PREPARATION

Expand Down
10 changes: 6 additions & 4 deletions src/routers/openml/flows.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Annotated, Literal

from fastapi import APIRouter, Depends
Expand Down Expand Up @@ -40,7 +41,11 @@ async def get_flow(
msg = f"Flow with id {flow_id} not found."
raise FlowNotFoundError(msg)

parameter_rows = await database.flows.get_parameters(flow_id, expdb)
parameter_rows, tags, subflow_rows = await asyncio.gather(
database.flows.get_parameters(flow_id, expdb),
database.flows.get_tags(flow_id, expdb),
database.flows.get_subflows(flow_id, expdb),
)
parameters = [
Parameter(
name=parameter.name,
Expand All @@ -53,9 +58,6 @@ async def get_flow(
)
for parameter in parameter_rows
]

tags = await database.flows.get_tags(flow_id, expdb)
subflow_rows = await database.flows.get_subflows(flow_id, expdb)
subflows = []
for subflow in subflow_rows:
subflows.append( # noqa: PERF401
Expand Down
17 changes: 11 additions & 6 deletions src/routers/openml/setups.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""All endpoints that relate to setups."""

import asyncio
from typing import Annotated

from fastapi import APIRouter, Body, Depends
Expand Down Expand Up @@ -27,11 +28,13 @@ async def tag_setup(
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
"""Add tag `tag` to setup with id `setup_id`."""
if not await database.setups.get(setup_id, expdb_db):
setup, setup_tags = await asyncio.gather(
database.setups.get(setup_id, expdb_db),
database.setups.get_tags(setup_id, expdb_db),
)
if not setup:
msg = f"Setup {setup_id} not found."
raise SetupNotFoundError(msg)

setup_tags = await database.setups.get_tags(setup_id, expdb_db)
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)

if matched_tag_row:
Expand All @@ -51,11 +54,13 @@ async def untag_setup(
expdb_db: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[str, dict[str, str | list[str]]]:
"""Remove tag `tag` from setup with id `setup_id`."""
if not await database.setups.get(setup_id, expdb_db):
setup, setup_tags = await asyncio.gather(
database.setups.get(setup_id, expdb_db),
database.setups.get_tags(setup_id, expdb_db),
)
if not setup:
msg = f"Setup {setup_id} not found."
raise SetupNotFoundError(msg)

setup_tags = await database.setups.get_tags(setup_id, expdb_db)
matched_tag_row = next((t for t in setup_tags if t.tag.casefold() == tag.casefold()), None)

if not matched_tag_row:
Expand Down
22 changes: 15 additions & 7 deletions src/routers/openml/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import json
import re
from typing import Annotated, cast
Expand Down Expand Up @@ -169,23 +170,30 @@ async def get_task(
msg = f"Task {task_id} has task type {task.ttid}, but task type {task.ttid} is not found."
raise InternalError(msg)

task_input_rows, ttios, tags = await asyncio.gather(
database.tasks.get_input_for_task(task_id, expdb),
database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb),
database.tasks.get_tags(task_id, expdb),
)
task_inputs = {
row.input: int(row.value) if row.value.isdigit() else row.value
for row in await database.tasks.get_input_for_task(task_id, expdb)
row.input: int(row.value) if row.value.isdigit() else row.value for row in task_input_rows
}
ttios = await database.tasks.get_task_type_inout_with_template(task_type.ttid, expdb)
templates = [(tt_io.name, tt_io.io, tt_io.requirement, tt_io.template_api) for tt_io in ttios]
input_templates = [
(name, template) for name, io, required, template in templates if io == "input"
]
filled_templates = await asyncio.gather(
*[fill_template(template, task, task_inputs, expdb) for name, template in input_templates],
)
inputs = [
await fill_template(template, task, task_inputs, expdb) | {"name": name}
for name, io, required, template in templates
if io == "input"
filled | {"name": name}
for (name, _), filled in zip(input_templates, filled_templates, strict=True)
]
outputs = [
convert_template_xml_to_json(template) | {"name": name}
for name, io, required, template in templates
if io == "output"
]
tags = await database.tasks.get_tags(task_id, expdb)
name = f"Task {task_id} ({task_type.name})"
dataset_id = task_inputs.get("source_data")
if isinstance(dataset_id, int) and (dataset := await database.datasets.get(dataset_id, expdb)):
Expand Down
71 changes: 38 additions & 33 deletions tests/routers/openml/migration/setups_migration_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import contextlib
import re
from collections.abc import AsyncGenerator, Callable, Iterable
Expand Down Expand Up @@ -114,14 +115,15 @@ async def test_setup_untag_response_is_identical_setup_doesnt_exist(
tag = "totally_new_tag_for_migration_testing"
api_key = ApiKey.SOME_USER

original = await php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

new = await py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
original, new = await asyncio.gather(
php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
),
py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
),
)

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
Expand All @@ -142,14 +144,15 @@ async def test_setup_untag_response_is_identical_tag_doesnt_exist(
tag = "totally_new_tag_for_migration_testing"
api_key = ApiKey.SOME_USER

original = await php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

new = await py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
original, new = await asyncio.gather(
php_api.post(
"/setup/untag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
),
py_api.post(
f"/setup/untag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
),
)

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
Expand Down Expand Up @@ -223,14 +226,15 @@ async def test_setup_tag_response_is_identical_setup_doesnt_exist(
tag = "totally_new_tag_for_migration_testing"
api_key = ApiKey.SOME_USER

original = await php_api.post(
"/setup/tag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

new = await py_api.post(
f"/setup/tag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
original, new = await asyncio.gather(
php_api.post(
"/setup/tag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
),
py_api.post(
f"/setup/tag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
),
)

assert original.status_code == HTTPStatus.PRECONDITION_FAILED
Expand All @@ -253,15 +257,16 @@ async def test_setup_tag_response_is_identical_tag_already_exists(
api_key = ApiKey.SOME_USER

async with temporary_tags(tags=[tag], setup_id=setup_id, persist=True):
original = await php_api.post(
"/setup/tag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
)

# In Python, since PHP committed it, it's also there for Python test context
new = await py_api.post(
f"/setup/tag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
# Both APIs can be tested in parallel since the tag is already persisted
original, new = await asyncio.gather(
php_api.post(
"/setup/tag",
data={"api_key": api_key, "tag": tag, "setup_id": setup_id},
),
py_api.post(
f"/setup/tag?api_key={api_key}",
json={"setup_id": setup_id, "tag": tag},
),
)

assert original.status_code == HTTPStatus.INTERNAL_SERVER_ERROR
Expand Down
Loading