From f08c48756267cd3c1def8906b668dd8979863788 Mon Sep 17 00:00:00 2001 From: Gokul-social Date: Sat, 21 Mar 2026 20:32:52 +0530 Subject: [PATCH 1/4] fix(tasktype): handle malformed api_constraints safely without crashing (#273) --- src/routers/openml/tasktype.py | 93 ++++++++++++- .../openml/test_parse_api_constraints.py | 129 ++++++++++++++++++ 2 files changed, 218 insertions(+), 4 deletions(-) create mode 100644 tests/routers/openml/test_parse_api_constraints.py diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 5355e451..8477745d 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,4 +1,5 @@ import json +import logging from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends @@ -10,6 +11,8 @@ from database.tasks import get_task_type as db_get_task_type from routers.dependencies import expdb_connection +logger = logging.getLogger(__name__) + router = APIRouter(prefix="/tasktype", tags=["tasks"]) @@ -26,6 +29,75 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any return ttype +def parse_api_constraints( + api_constraints: Any, + *, + task_type_id: int, + input_name: str, +) -> str | None: + """Defensively parse api_constraints and extract a valid data_type string. + + Malformed api_constraints will not raise errors; instead they are logged + and ignored for response construction. Returns a non-empty data_type string + on success, or None if the value cannot be parsed or does not contain a + valid data_type. + """ + constraint: dict[str, Any] | None = None + + if api_constraints is None: + return None + + if isinstance(api_constraints, dict): + constraint = api_constraints + elif isinstance(api_constraints, str): + if not api_constraints: + logger.warning( + "api_constraints: empty_string for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + return None + try: + parsed = json.loads(api_constraints) + except json.JSONDecodeError: + logger.warning( + "api_constraints: malformed_json for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + return None + if not isinstance(parsed, dict): + logger.warning( + "api_constraints: non_dict_json for task_type_id=%d, input=%s " + "(got %s)", + task_type_id, + input_name, + type(parsed).__name__, + ) + return None + constraint = parsed + else: + logger.warning( + "api_constraints: unsupported_type for task_type_id=%d, input=%s " + "(got %s)", + task_type_id, + input_name, + type(api_constraints).__name__, + ) + return None + + data_type = constraint.get("data_type") + if not isinstance(data_type, str) or not data_type: + logger.debug( + "api_constraints: missing_data_type for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + return None + + return data_type + + @router.get(path="/list") async def list_task_types( expdb: Annotated[AsyncConnection, Depends(expdb_connection)], @@ -44,6 +116,13 @@ async def get_task_type( task_type_id: int, expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[Literal["task_type"], dict[str, str | None | list[str] | list[dict[str, str]]]]: + """Retrieve a task type by ID. + + Response contract: + - Always returns 200 for valid task types. + - input[].data_type is optional and only included when valid constraints exist. + - Invalid api_constraints never break the response. + """ task_type_record = await db_get_task_type(task_type_id, expdb) if task_type_record is None: msg = f"Task type {task_type_id} not found." @@ -66,10 +145,16 @@ async def get_task_type( if task_type_input.requirement == "required": input_["requirement"] = task_type_input.requirement input_["name"] = task_type_input.name - # api_constraints is for one input only in the test database (TODO: patch db) - if isinstance(task_type_input.api_constraints, str): - constraint = json.loads(task_type_input.api_constraints) - input_["data_type"] = constraint["data_type"] + # data_type is optional and only included when valid constraints exist. + # Malformed api_constraints will not raise errors; instead they are + # logged and ignored for response construction. + data_type = parse_api_constraints( + task_type_input.api_constraints, + task_type_id=task_type_id, + input_name=task_type_input.name, + ) + if data_type is not None: + input_["data_type"] = data_type input_types.append(input_) task_type["input"] = input_types return {"task_type": task_type} diff --git a/tests/routers/openml/test_parse_api_constraints.py b/tests/routers/openml/test_parse_api_constraints.py new file mode 100644 index 00000000..fa691892 --- /dev/null +++ b/tests/routers/openml/test_parse_api_constraints.py @@ -0,0 +1,129 @@ +import logging + +import pytest + +from routers.openml.tasktype import parse_api_constraints + + +@pytest.mark.parametrize( + ("api_constraints", "expected_data_type"), + [ + # 1. Valid JSON string with data_type + ('{"data_type": "matrix"}', "matrix"), + # 2. Valid dict with data_type + ({"data_type": "matrix"}, "matrix"), + # 3. Malformed JSON string → None + ("{bad json", None), + # 4. Empty string → None + ("", None), + # 5. Valid JSON/dict without data_type → None + ('{"other_key": "val"}', None), + ({"other_key": "val"}, None), + # 6. data_type present but empty string → None + ('{"data_type": ""}', None), + ({"data_type": ""}, None), + # 7. Non-dict JSON (list) → None + ('["array"]', None), + # 8. None → None + (None, None), + # 9. Non-dict JSON (int) → None + ("42", None), + # 10. data_type is non-string (int) → None + ('{"data_type": 123}', None), + ({"data_type": 123}, None), + ], + ids=[ + "valid_json_string", + "valid_dict", + "malformed_json", + "empty_string", + "json_missing_data_type", + "dict_missing_data_type", + "json_empty_data_type", + "dict_empty_data_type", + "non_dict_json_list", + "none_value", + "non_dict_json_int", + "json_non_string_data_type", + "dict_non_string_data_type", + ], +) +def test_parse_api_constraints( + api_constraints: object, + expected_data_type: str | None, +) -> None: + result = parse_api_constraints( + api_constraints, + task_type_id=1, + input_name="source_data", + ) + assert result == expected_data_type + + +def test_parse_api_constraints_unsupported_type() -> None: + """Unsupported types (e.g. int, list passed directly) should return None.""" + result = parse_api_constraints( + 12345, + task_type_id=1, + input_name="source_data", + ) + assert result is None + + +class TestParseApiConstraintsLogging: + """Verify correct log levels are emitted for different anomaly types.""" + + def test_malformed_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + "{bad json", + task_type_id=1, + input_name="source_data", + ) + assert any("malformed_json" in r.message for r in caplog.records) + + def test_empty_string_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + "", + task_type_id=1, + input_name="source_data", + ) + assert any("empty_string" in r.message for r in caplog.records) + + def test_non_dict_json_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + '["array"]', + task_type_id=1, + input_name="source_data", + ) + assert any("non_dict_json" in r.message for r in caplog.records) + + def test_unsupported_type_logs_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.WARNING): + parse_api_constraints( + 12345, + task_type_id=1, + input_name="source_data", + ) + assert any("unsupported_type" in r.message for r in caplog.records) + + def test_missing_data_type_logs_debug(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.DEBUG): + parse_api_constraints( + '{"other_key": "val"}', + task_type_id=1, + input_name="source_data", + ) + assert any("missing_data_type" in r.message for r in caplog.records) + + def test_valid_constraint_no_warning(self, caplog: pytest.LogCaptureFixture) -> None: + with caplog.at_level(logging.DEBUG): + result = parse_api_constraints( + '{"data_type": "matrix"}', + task_type_id=1, + input_name="source_data", + ) + assert result == "matrix" + assert len(caplog.records) == 0 From ea00485b37c06d8ed68b998140e8d35a708af7de Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Mar 2026 15:13:04 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/routers/openml/tasktype.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 8477745d..b40c739d 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -68,8 +68,7 @@ def parse_api_constraints( return None if not isinstance(parsed, dict): logger.warning( - "api_constraints: non_dict_json for task_type_id=%d, input=%s " - "(got %s)", + "api_constraints: non_dict_json for task_type_id=%d, input=%s (got %s)", task_type_id, input_name, type(parsed).__name__, @@ -78,8 +77,7 @@ def parse_api_constraints( constraint = parsed else: logger.warning( - "api_constraints: unsupported_type for task_type_id=%d, input=%s " - "(got %s)", + "api_constraints: unsupported_type for task_type_id=%d, input=%s (got %s)", task_type_id, input_name, type(api_constraints).__name__, From 41f336ec2c2fef6dff9022517ae452125df7bb4b Mon Sep 17 00:00:00 2001 From: Gokul-social Date: Tue, 24 Mar 2026 19:45:14 +0530 Subject: [PATCH 3/4] fix: restore Any import and resolve lint/type issues --- src/routers/openml/tasktype.py | 50 ++++++++++--------- .../openml/test_parse_api_constraints.py | 2 +- 2 files changed, 28 insertions(+), 24 deletions(-) diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index b40c739d..1fc895aa 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -1,5 +1,6 @@ import json import logging +from collections.abc import Mapping from typing import Annotated, Any, Literal, cast from fastapi import APIRouter, Depends @@ -30,7 +31,7 @@ def _normalize_task_type(task_type: Row[Any]) -> dict[str, str | None | list[Any def parse_api_constraints( - api_constraints: Any, + api_constraints: str | Mapping[str, object] | None, *, task_type_id: int, input_name: str, @@ -42,12 +43,12 @@ def parse_api_constraints( on success, or None if the value cannot be parsed or does not contain a valid data_type. """ - constraint: dict[str, Any] | None = None - if api_constraints is None: return None - if isinstance(api_constraints, dict): + constraint: Mapping[str, object] | None = None + + if isinstance(api_constraints, Mapping): constraint = api_constraints elif isinstance(api_constraints, str): if not api_constraints: @@ -56,25 +57,26 @@ def parse_api_constraints( task_type_id, input_name, ) - return None - try: - parsed = json.loads(api_constraints) - except json.JSONDecodeError: - logger.warning( - "api_constraints: malformed_json for task_type_id=%d, input=%s", - task_type_id, - input_name, - ) - return None - if not isinstance(parsed, dict): - logger.warning( - "api_constraints: non_dict_json for task_type_id=%d, input=%s (got %s)", - task_type_id, - input_name, - type(parsed).__name__, - ) - return None - constraint = parsed + else: + try: + parsed = json.loads(api_constraints) + except json.JSONDecodeError: + logger.warning( + "api_constraints: malformed_json for task_type_id=%d, input=%s", + task_type_id, + input_name, + ) + else: + if isinstance(parsed, dict): + constraint = parsed + else: + logger.warning( + "api_constraints: non_dict_json for task_type_id=%d, input=%s " + "(got %s)", + task_type_id, + input_name, + type(parsed).__name__, + ) else: logger.warning( "api_constraints: unsupported_type for task_type_id=%d, input=%s (got %s)", @@ -82,6 +84,8 @@ def parse_api_constraints( input_name, type(api_constraints).__name__, ) + + if constraint is None: return None data_type = constraint.get("data_type") diff --git a/tests/routers/openml/test_parse_api_constraints.py b/tests/routers/openml/test_parse_api_constraints.py index fa691892..9be9048c 100644 --- a/tests/routers/openml/test_parse_api_constraints.py +++ b/tests/routers/openml/test_parse_api_constraints.py @@ -126,4 +126,4 @@ def test_valid_constraint_no_warning(self, caplog: pytest.LogCaptureFixture) -> input_name="source_data", ) assert result == "matrix" - assert len(caplog.records) == 0 + assert not any(r.name == "routers.openml.tasktype" for r in caplog.records) From 4fca9fb70217c153a0d0b1a75e181f20373a3c5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Mar 2026 14:46:17 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/routers/openml/tasktype.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/routers/openml/tasktype.py b/src/routers/openml/tasktype.py index 1fc895aa..afd8a5e4 100644 --- a/src/routers/openml/tasktype.py +++ b/src/routers/openml/tasktype.py @@ -71,8 +71,7 @@ def parse_api_constraints( constraint = parsed else: logger.warning( - "api_constraints: non_dict_json for task_type_id=%d, input=%s " - "(got %s)", + "api_constraints: non_dict_json for task_type_id=%d, input=%s (got %s)", task_type_id, input_name, type(parsed).__name__,