Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 90 additions & 4 deletions src/routers/openml/tasktype.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import json
import logging
from collections.abc import Mapping
from typing import Annotated, Any, Literal, cast

from fastapi import APIRouter, Depends
Expand All @@ -10,6 +12,8 @@
from database.tasks import get_task_type as db_get_task_type
from routers.dependencies import expdb_connection

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/tasktype", tags=["tasks"])


Expand All @@ -26,6 +30,75 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any
return ttype


def parse_api_constraints(
api_constraints: str | Mapping[str, object] | None,
*,
task_type_id: int,
input_name: str,
) -> str | None:
"""Defensively parse api_constraints and extract a valid data_type string.

Malformed api_constraints will not raise errors; instead they are logged
and ignored for response construction. Returns a non-empty data_type string
on success, or None if the value cannot be parsed or does not contain a
valid data_type.
"""
if api_constraints is None:
return None

constraint: Mapping[str, object] | None = None

if isinstance(api_constraints, Mapping):
constraint = api_constraints
elif isinstance(api_constraints, str):
if not api_constraints:
logger.warning(
"api_constraints: empty_string for task_type_id=%d, input=%s",
task_type_id,
input_name,
)
else:
try:
parsed = json.loads(api_constraints)
except json.JSONDecodeError:
logger.warning(
"api_constraints: malformed_json for task_type_id=%d, input=%s",
task_type_id,
input_name,
)
else:
if isinstance(parsed, dict):
constraint = parsed
else:
logger.warning(
"api_constraints: non_dict_json for task_type_id=%d, input=%s (got %s)",
task_type_id,
input_name,
type(parsed).__name__,
)
else:
logger.warning(
"api_constraints: unsupported_type for task_type_id=%d, input=%s (got %s)",
task_type_id,
input_name,
type(api_constraints).__name__,
)

if constraint is None:
return None

data_type = constraint.get("data_type")
if not isinstance(data_type, str) or not data_type:
logger.debug(
"api_constraints: missing_data_type for task_type_id=%d, input=%s",
task_type_id,
input_name,
)
return None

return data_type


@router.get(path="/list")
async def list_task_types(
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
Expand All @@ -44,6 +117,13 @@ async def get_task_type(
task_type_id: int,
expdb: Annotated[AsyncConnection, Depends(expdb_connection)],
) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]:
"""Retrieve a task type by ID.

Response contract:
- Always returns 200 for valid task types.
- input[].data_type is optional and only included when valid constraints exist.
- Invalid api_constraints never break the response.
"""
task_type_record = await db_get_task_type(task_type_id, expdb)
if task_type_record is None:
msg = f"Task type {task_type_id} not found."
Expand All @@ -66,10 +146,16 @@ async def get_task_type(
if task_type_input.requirement == "required":
input_["requirement"] = task_type_input.requirement
input_["name"] = task_type_input.name
# api_constraints is for one input only in the test database (TODO: patch db)
if isinstance(task_type_input.api_constraints, str):
constraint = json.loads(task_type_input.api_constraints)
input_["data_type"] = constraint["data_type"]
# data_type is optional and only included when valid constraints exist.
# Malformed api_constraints will not raise errors; instead they are
# logged and ignored for response construction.
data_type = parse_api_constraints(
task_type_input.api_constraints,
task_type_id=task_type_id,
input_name=task_type_input.name,
)
if data_type is not None:
input_["data_type"] = data_type
input_types.append(input_)
task_type["input"] = input_types
return {"task_type": task_type}
129 changes: 129 additions & 0 deletions tests/routers/openml/test_parse_api_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
import logging

import pytest

from routers.openml.tasktype import parse_api_constraints


@pytest.mark.parametrize(
    "api_constraints, expected_data_type",
    [
        # Happy paths: a data_type is extracted.
        pytest.param('{"data_type": "matrix"}', "matrix", id="valid_json_string"),
        pytest.param({"data_type": "matrix"}, "matrix", id="valid_dict"),
        # Every anomaly class resolves to None without raising.
        pytest.param("{bad json", None, id="malformed_json"),
        pytest.param("", None, id="empty_string"),
        pytest.param('{"other_key": "val"}', None, id="json_missing_data_type"),
        pytest.param({"other_key": "val"}, None, id="dict_missing_data_type"),
        pytest.param('{"data_type": ""}', None, id="json_empty_data_type"),
        pytest.param({"data_type": ""}, None, id="dict_empty_data_type"),
        pytest.param('["array"]', None, id="non_dict_json_list"),
        pytest.param(None, None, id="none_value"),
        pytest.param("42", None, id="non_dict_json_int"),
        pytest.param('{"data_type": 123}', None, id="json_non_string_data_type"),
        pytest.param({"data_type": 123}, None, id="dict_non_string_data_type"),
    ],
)
def test_parse_api_constraints(
    api_constraints: object,
    expected_data_type: str | None,
) -> None:
    """Valid constraints yield their data_type; every anomaly yields None."""
    assert (
        parse_api_constraints(api_constraints, task_type_id=1, input_name="source_data")
        == expected_data_type
    )


def test_parse_api_constraints_unsupported_type() -> None:
    """Values of an unsupported type (e.g. a bare int) resolve to None."""
    assert (
        parse_api_constraints(12345, task_type_id=1, input_name="source_data") is None
    )


class TestParseApiConstraintsLogging:
    """Verify correct log levels are emitted for different anomaly types."""

    @staticmethod
    def _messages(caplog: pytest.LogCaptureFixture) -> list[str]:
        """Return the formatted message of every captured record."""
        return [record.getMessage() for record in caplog.records]

    def test_malformed_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.WARNING):
            parse_api_constraints("{bad json", task_type_id=1, input_name="source_data")
        assert any("malformed_json" in msg for msg in self._messages(caplog))

    def test_empty_string_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.WARNING):
            parse_api_constraints("", task_type_id=1, input_name="source_data")
        assert any("empty_string" in msg for msg in self._messages(caplog))

    def test_non_dict_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.WARNING):
            parse_api_constraints('["array"]', task_type_id=1, input_name="source_data")
        assert any("non_dict_json" in msg for msg in self._messages(caplog))

    def test_unsupported_type_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.WARNING):
            parse_api_constraints(12345, task_type_id=1, input_name="source_data")
        assert any("unsupported_type" in msg for msg in self._messages(caplog))

    def test_missing_data_type_logs_debug(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.DEBUG):
            parse_api_constraints(
                '{"other_key": "val"}', task_type_id=1, input_name="source_data"
            )
        assert any("missing_data_type" in msg for msg in self._messages(caplog))

    def test_valid_constraint_no_warning(self, caplog: pytest.LogCaptureFixture) -> None:
        with caplog.at_level(logging.DEBUG):
            outcome = parse_api_constraints(
                '{"data_type": "matrix"}', task_type_id=1, input_name="source_data"
            )
        assert outcome == "matrix"
        # The tasktype logger must stay silent for a fully valid constraint.
        assert all(record.name != "routers.openml.tasktype" for record in caplog.records)