diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5355e45..afd8a5e 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,4 +1,6 @@ import json +import logging +from collections.abc import Mapping from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends @@ -10,6 +12,8 @@ from database.tasks import get_task_type as db_get_task_type from routers.dependencies import expdb_connection +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/tasktype", tags=["tasks"]) @@ -26,6 +30,75 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any return ttype +def parse_api_constraints( + api_constraints: str | Mapping[str, object] | None, + *, + task_type_id: int, + input_name: str, +) -> str | None: + """Defensively parse api_constraints and extract a valid data_type string. + + Malformed api_constraints will not raise errors; instead they are logged + and ignored for response construction. Returns a non-empty data_type string + on success, or None if the value cannot be parsed or does not contain a + valid data_type. + """ + if api_constraints is None: + return None + + constraint: Mapping[str, object] | None = None + + if isinstance(api_constraints, Mapping): + constraint = api_constraints + elif isinstance(api_constraints, str): + if not api_constraints: + logger.warning( + "api_constraints: empty_string for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + else: + try: + parsed = json.loads(api_constraints) + except json.JSONDecodeError: + logger.warning( + "api_constraints: malformed_json for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + else: + if isinstance(parsed, dict): + constraint = parsed + else: + logger.warning( + "api_constraints: non_dict_json for task_type_id=%d, input=%s (got %s)", + task_type_id, + input_name, + type(parsed).__name__, + ) + else: + logger.warning( + "api_constraints: unsupported_type for task_type_id=%d, input=%s (got %s)", + task_type_id, + input_name, + type(api_constraints).__name__, + ) + + if constraint is None: + return None + + data_type = constraint.get("data_type") + if not isinstance(data_type, str) or not data_type: + logger.debug( + "api_constraints: missing_data_type for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + return None + + return data_type + + @router.get(path="/list") async def list_task_types( expdb: Annotated[AsyncConnection, Depends(expdb_connection)], @@ -44,6 +117,13 @@ async def get_task_type( task_type_id: int, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]: + """Retrieve a task type by ID. + + Response contract: + - Always returns 200 for valid task types. + - input[].data_type is optional and only included when valid constraints exist. + - Invalid api_constraints never break the response. + """ task_type_record = await db_get_task_type(task_type_id, expdb) if task_type_record is None: msg = f"Task type {task_type_id} not found." @@ -66,10 +146,16 @@ async def get_task_type( if task_type_input.requirement == "required": input_["requirement"] = task_type_input.requirement input_["name"] = task_type_input.name - # api_constraints is for one input only in the test database (TODO: patch db) - if isinstance(task_type_input.api_constraints, str): - constraint = json.loads(task_type_input.api_constraints) - input_["data_type"] = constraint["data_type"] + # data_type is optional and only included when valid constraints exist. + # Malformed api_constraints will not raise errors; instead they are + # logged and ignored for response construction. + data_type = parse_api_constraints( + task_type_input.api_constraints, + task_type_id=task_type_id, + input_name=task_type_input.name, + ) + if data_type is not None: + input_["data_type"] = data_type input_types.append(input_) task_type["input"] = input_types return {"task_type": task_type} diff --git a/tests/routers/openml/test_parse_api_constraints.py b/tests/routers/openml/test_parse_api_constraints.py new file mode 100644 index 0000000..9be9048 --- /dev/null +++ b/tests/routers/openml/test_parse_api_constraints.py @@ -0,0 +1,129 @@ +import logging + +import pytest + +from routers.openml.tasktype import parse_api_constraints + + +@pytest.mark.parametrize( + ("api_constraints", "expected_data_type"), + [ + # 1. Valid JSON string with data_type + ('{"data_type": "matrix"}', "matrix"), + # 2. Valid dict with data_type + ({"data_type": "matrix"}, "matrix"), + # 3. Malformed JSON string → None + ("{bad json", None), + # 4. Empty string → None + ("", None), + # 5. Valid JSON/dict without data_type → None + ('{"other_key": "val"}', None), + ({"other_key": "val"}, None), + # 6. data_type present but empty string → None + ('{"data_type": ""}', None), + ({"data_type": ""}, None), + # 7. Non-dict JSON (list) → None + ('["array"]', None), + # 8. None → None + (None, None), + # 9. Non-dict JSON (int) → None + ("42", None), + # 10. data_type is non-string (int) → None + ('{"data_type": 123}', None), + ({"data_type": 123}, None), + ], + ids=[ + "valid_json_string", + "valid_dict", + "malformed_json", + "empty_string", + "json_missing_data_type", + "dict_missing_data_type", + "json_empty_data_type", + "dict_empty_data_type", + "non_dict_json_list", + "none_value", + "non_dict_json_int", + "json_non_string_data_type", + "dict_non_string_data_type", + ], +) +def test_parse_api_constraints( + api_constraints: object, + expected_data_type: str | None, +) -> None: + result = parse_api_constraints( + api_constraints, + task_type_id=1, + input_name="source_data", + ) + assert result == expected_data_type + + +def test_parse_api_constraints_unsupported_type() -> None: + """Unsupported types (e.g. int, list passed directly) should return None.""" + result = parse_api_constraints( + 12345, + task_type_id=1, + input_name="source_data", + ) + assert result is None + + +class TestParseApiConstraintsLogging: + """Verify correct log levels are emitted for different anomaly types.""" + + def test_malformed_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + "{bad json", + task_type_id=1, + input_name="source_data", + ) + assert any("malformed_json" in r.message for r in caplog.records) + + def test_empty_string_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + "", + task_type_id=1, + input_name="source_data", + ) + assert any("empty_string" in r.message for r in caplog.records) + + def test_non_dict_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + '["array"]', + task_type_id=1, + input_name="source_data", + ) + assert any("non_dict_json" in r.message for r in caplog.records) + + def test_unsupported_type_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + 12345, + task_type_id=1, + input_name="source_data", + ) + assert any("unsupported_type" in r.message for r in caplog.records) + + def test_missing_data_type_logs_debug(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.DEBUG): + parse_api_constraints( + '{"other_key": "val"}', + task_type_id=1, + input_name="source_data", + ) + assert any("missing_data_type" in r.message for r in caplog.records) + + def test_valid_constraint_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.DEBUG): + result = parse_api_constraints( + '{"data_type": "matrix"}', + task_type_id=1, + input_name="source_data", + ) + assert result == "matrix" + assert not any(r.name == "routers.openml.tasktype" for r in caplog.records)