Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions .github/workflows/ci_cov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,7 @@ jobs:

# Install dependencies
- name: Install dependencies
run: uv pip install ruff -e ".[docs,test]"

# Run ruff
- name: Lint with ruff
run: |
ruff check . --output-format=github
run: uv pip install -e ".[docs,test]"

# Run unittest with coverage
- name: Test with unittest and coverage
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/ci_windows.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"

- name: Create virtual environment
shell: pwsh
Expand Down
4 changes: 3 additions & 1 deletion docs/api/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ Error functions

.. autofunction:: hed.errors.error_reporter.sort_issues

.. autofunction:: hed.errors.error_reporter.replace_tag_references
.. autofunction:: hed.errors.error_reporter.separate_issues

.. autofunction:: hed.errors.error_reporter.iter_errors

Error types
-----------
Expand Down
8 changes: 7 additions & 1 deletion hed/errors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Error handling module for HED."""

from .error_reporter import ErrorHandler, get_printable_issue_string, sort_issues, replace_tag_references, iter_errors
from .error_reporter import (
ErrorHandler,
separate_issues,
get_printable_issue_string,
sort_issues,
iter_errors,
)
from .error_types import (
DefinitionErrors,
TemporalErrors,
Expand Down
146 changes: 83 additions & 63 deletions hed/errors/error_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,23 @@ def wrapper(tag, *args, severity=default_severity, **kwargs):
schema_error_messages.mark_as_used = True


def separate_issues(issues_list: list[dict]) -> tuple[list[dict], list[dict]]:
    """Split a list of issues into errors and warnings.

    Parameters:
        issues_list (list[dict]): A list of issue dictionaries. The 'severity' key is
            optional; issues that omit it are treated as errors (ErrorSeverity.ERROR).

    Returns:
        tuple[list[dict], list[dict]]: A tuple of (errors, warnings) where errors contains
        issues with severity <= ErrorSeverity.ERROR and warnings contains issues with
        severity > ErrorSeverity.ERROR.
    """
    errors: list[dict] = []
    warnings: list[dict] = []
    for issue in issues_list:
        # A missing 'severity' key defaults to ERROR so such issues are never dropped.
        if issue.get("severity", ErrorSeverity.ERROR) <= ErrorSeverity.ERROR:
            errors.append(issue)
        else:
            warnings.append(issue)
    return errors, warnings


class ErrorHandler:
"""Class to hold error context and having general error functions."""

Expand Down Expand Up @@ -274,20 +291,6 @@ def format_error_with_context(self, *args, **kwargs):

return error_object

@staticmethod
def filter_issues_by_severity(issues_list: list[dict], severity: int) -> list[dict]:
"""Gather all issues matching or below a given severity.

Parameters:
issues_list (list[dict]): A list of dictionaries containing the full issue list.
severity (int): The level of issues to keep.

Returns:
list[dict]: A list of dictionaries containing the issue list after filtering by severity.

"""
return [issue for issue in issues_list if issue["severity"] <= severity]

@staticmethod
def format_error(error_type: str, *args, actual_error=None, **kwargs) -> list[dict]:
"""Format an error based on the parameters, which vary based on what type of error this is.
Expand Down Expand Up @@ -423,13 +426,46 @@ def val_error_unknown(*args, **kwargs) -> str:
return f"Unknown error. Args: {str(args), str(kwargs)}"

@staticmethod
def filter_issues_by_count(issues, count, by_file=False) -> tuple[list[dict], dict[str, int]]:
def filter_issues_by_severity(issues_list: list[dict], severity: int) -> list[dict]:
    """Gather all issues matching or below a given severity.

    Parameters:
        issues_list (list[dict]): A list of dictionaries containing the full issue list.
        severity (int): The level of issues to keep.

    Returns:
        list[dict]: A list of dictionaries containing the issue list after filtering by severity.

    """
    kept = []
    for issue in issues_list:
        # Issues without an explicit severity are treated as full errors.
        if issue.get("severity", ErrorSeverity.ERROR) <= severity:
            kept.append(issue)
    return kept

@staticmethod
def aggregate_code_counts(file_code_dict: dict) -> dict:
"""Aggregate the counts of codes across multiple files.

Parameters:
file_code_dict (dict): A dictionary where keys are filenames and values are
dictionaries of code counts.

Returns:
dict: A dictionary with the aggregated counts of codes across all files.
"""
total_counts = defaultdict(int)
for file_dict in file_code_dict.values():
for code, count in file_dict.items():
total_counts[code] += count
return dict(total_counts)

@staticmethod
def filter_issues_by_count(
issues: list[dict], count: int, by_file: bool = False
) -> tuple[list[dict], dict[str, int]]:
"""Filter the issues list to only include the first count issues of each code.

Parameters:
issues (list): A list of dictionaries containing the full issue list.
count (int): The number of issues to keep for each code.
by_file (bool): If True, group by file name.
Parameters:
issues (list[dict]): A list of dictionaries containing the full issue list.
count (int): The number of issues to keep for each code.
by_file (bool): If True, group by file name.

Returns:
tuple[list[dict], dict[str, int]]: A tuple containing:
Expand Down Expand Up @@ -457,22 +493,6 @@ def filter_issues_by_count(issues, count, by_file=False) -> tuple[list[dict], di

return filtered_issues, ErrorHandler.aggregate_code_counts(file_dicts)

@staticmethod
def aggregate_code_counts(file_code_dict) -> dict:
"""Aggregate the counts of codes across multiple files.

Parameters:
file_code_dict (dict): A dictionary where keys are filenames and values are dictionaries of code counts.

Returns:
dict: A dictionary with the aggregated counts of codes across all files.
"""
total_counts = defaultdict(int)
for file_dict in file_code_dict.values():
for code, count in file_dict.items():
total_counts[code] += count
return dict(total_counts)

@staticmethod
def get_code_counts(issues: list[dict]) -> dict[str, int]:
"""Count the occurrences of each error code in the issues list.
Expand All @@ -490,6 +510,34 @@ def get_code_counts(issues: list[dict]) -> dict[str, int]:
code_counts[code] += 1
return dict(code_counts)

@staticmethod
def replace_tag_references(list_or_dict):
"""Utility function to remove any references to tags, strings, etc. from any type of nested list or dict.

Use this if you want to save out issues to a file.

If you'd prefer a copy returned, use ErrorHandler.replace_tag_references(list_or_dict.copy()).

Parameters:
list_or_dict (list or dict): An arbitrarily nested list/dict structure
"""
if isinstance(list_or_dict, dict):
for key, value in list_or_dict.items():
if isinstance(value, (dict, list)):
ErrorHandler.replace_tag_references(value)
elif isinstance(value, (bool, float, int)):
list_or_dict[key] = value
else:
list_or_dict[key] = str(value)
elif isinstance(list_or_dict, list):
for key, value in enumerate(list_or_dict):
if isinstance(value, (dict, list)):
ErrorHandler.replace_tag_references(value)
elif isinstance(value, (bool, float, int)):
list_or_dict[key] = value
else:
list_or_dict[key] = str(value)


def sort_issues(issues, reverse=False) -> list[dict]:
"""Sort a list of issues by the error context values.
Expand Down Expand Up @@ -822,31 +870,3 @@ def _create_error_tree(error_dict, parent_element=None, add_link=True):
_create_error_tree(value, context_ul, add_link)

return parent_element


def replace_tag_references(list_or_dict):
    """Utility function to remove any references to tags, strings, etc. from any type of nested list or dict.

    Every value that is not a dict, list, bool, float, or int is replaced,
    in place, by str(value), making the structure safe to serialize.

    Use this if you want to save out issues to a file.

    If you'd prefer a copy returned, use replace_tag_references(list_or_dict.copy()).

    Parameters:
        list_or_dict (list or dict): An arbitrarily nested list/dict structure
    """
    if isinstance(list_or_dict, dict):
        entries = list(list_or_dict.items())
    elif isinstance(list_or_dict, list):
        entries = list(enumerate(list_or_dict))
    else:
        # Nothing to do for scalar top-level values.
        return
    for key, value in entries:
        if isinstance(value, (dict, list)):
            replace_tag_references(value)
        elif isinstance(value, (bool, float, int)):
            continue  # primitives are kept as-is
        else:
            list_or_dict[key] = str(value)
33 changes: 21 additions & 12 deletions hed/scripts/schema_script_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from collections import defaultdict
from hed.schema import from_string, load_schema, from_dataframes
from hed.schema import hed_cache
from hed.errors import get_printable_issue_string, HedFileError
from hed.errors.error_types import ErrorSeverity
from hed.errors import get_printable_issue_string, separate_issues, HedFileError
from hed.schema.schema_comparer import SchemaComparer

all_extensions = [".tsv", ".mediawiki", ".xml", ".json"]
Expand Down Expand Up @@ -32,26 +31,32 @@ def _is_prerelease_partner(base_schema) -> bool:
return hed_cache.get_hed_version_path(with_standard, check_prerelease=False) is None


def validate_schema_object(base_schema, schema_name):
"""Validate a schema object by checking non-warning compliance and roundtrip conversion.
def validate_schema_object(base_schema, schema_name, check_warnings=False):
"""Validate a schema object by checking compliance and roundtrip conversion.

Tests the schema for non-warning compliance issues and validates that it can be successfully
Tests the schema for compliance issues and validates that it can be successfully
converted to and reloaded from all four formats (MEDIAWIKI, XML, JSON, TSV).

Parameters:
base_schema (HedSchema): The schema object to validate.
schema_name (str): The name/path of the schema for error reporting.
check_warnings (bool): If True, include warnings in the validation. Default is False.

Returns:
list: A list of validation issue strings. Empty if no issues found.
"""
validation_issues = []
try:
issues = base_schema.check_compliance()
issues = [issue for issue in issues if issue.get("severity", ErrorSeverity.ERROR) == ErrorSeverity.ERROR]
issues = base_schema.check_compliance(check_for_warnings=check_warnings)
if issues and check_warnings:
errors, warnings = separate_issues(issues)
issues = errors + warnings
else:
errors = issues

if issues:
error_message = get_printable_issue_string(issues, title=schema_name)
validation_issues.append(error_message)
validation_issues.append(get_printable_issue_string(issues, title=schema_name))
if errors:
return validation_issues

# If the withStandard partner only exists in the prerelease cache, all unmerged
Expand All @@ -74,14 +79,18 @@ def validate_schema_object(base_schema, schema_name):
return validation_issues


def validate_schema(file_path):
def validate_schema(file_path, check_warnings=False):
"""Validate a schema file, ensuring it can save/load and passes validation.

Loads the schema from file, checks the file extension is lowercase,
and validates the schema object for compliance and roundtrip conversion.
and validates the schema object for compliance errors and roundtrip conversion.

Parameters:
file_path (str): The path to the schema file to validate.
If loading a TSV file, this should be a single filename where:
Template: basename.tsv, where files are named basename_Struct.tsv, basename_Tag.tsv, etc.
Alternatively, you can point to a directory containing the .tsv files.
check_warnings (bool): If True, include warnings in the validation. Default is False.

Returns:
list: A list of validation issue strings. Empty if no issues found.
Expand All @@ -96,7 +105,7 @@ def validate_schema(file_path):
validation_issues = []
try:
base_schema = load_schema(file_path)
validation_issues = validate_schema_object(base_schema, file_path)
validation_issues = validate_schema_object(base_schema, file_path, check_warnings=check_warnings)
except HedFileError as e:
print(f"Saving/loading error: {file_path} {e.message}")
error_text = e.message
Expand Down
55 changes: 51 additions & 4 deletions tests/errors/test_error_reporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
SchemaWarnings,
get_printable_issue_string,
sort_issues,
replace_tag_references,
separate_issues,
)
from hed.errors.error_reporter import hed_tag_error, get_printable_issue_string_html, iter_errors
from hed import HedString, HedTag
Expand Down Expand Up @@ -230,17 +230,17 @@ def test_replace_tag_references(self):
"b": {"c": 2, "d": [3, {"e": HedString("Hed2", self._schema)}]},
"f": [5, 6],
}
replace_tag_references(nested_dict)
ErrorHandler.replace_tag_references(nested_dict)
self.assertEqual(nested_dict, {"a": "Hed1", "b": {"c": 2, "d": [3, {"e": "Hed2"}]}, "f": [5, 6]})

# Test with mixed data types and HedString in a nested list
nested_list = [HedString("Hed1", self._schema), {"a": 2, "b": [3, {"c": HedString("Hed2", self._schema)}]}]
replace_tag_references(nested_list)
ErrorHandler.replace_tag_references(nested_list)
self.assertEqual(nested_list, ["Hed1", {"a": 2, "b": [3, {"c": "Hed2"}]}])

# Test with mixed data types and HedString in a list within a dict
mixed = {"a": HedString("Hed1", self._schema), "b": [2, 3, {"c": HedString("Hed2", self._schema)}, 4]}
replace_tag_references(mixed)
ErrorHandler.replace_tag_references(mixed)
self.assertEqual(mixed, {"a": "Hed1", "b": [2, 3, {"c": "Hed2"}, 4]})

def test_register_error_twice(self):
Expand Down Expand Up @@ -299,3 +299,50 @@ def test_get_code_counts(self):
result_with_missing = ErrorHandler.get_code_counts(issues_with_missing_code)
expected_with_missing = {"VALID_CODE": 2, "UNKNOWN": 1} # Default for missing code
self.assertEqual(result_with_missing, expected_with_missing)


class TestSeparateIssues(unittest.TestCase):
    """Tests for separate_issues."""

    @staticmethod
    def _make_issue(severity):
        # Minimal issue dict with an explicit severity.
        return {"severity": severity, "message": "test"}

    def test_empty_list(self):
        errors, warnings = separate_issues([])
        self.assertListEqual(errors, [])
        self.assertListEqual(warnings, [])

    def test_only_errors(self):
        issues = [self._make_issue(ErrorSeverity.ERROR) for _ in range(2)]
        errors, warnings = separate_issues(issues)
        self.assertEqual((len(errors), len(warnings)), (2, 0))

    def test_only_warnings(self):
        issues = [self._make_issue(ErrorSeverity.WARNING) for _ in range(2)]
        errors, warnings = separate_issues(issues)
        self.assertEqual((len(errors), len(warnings)), (0, 2))

    def test_mixed(self):
        severities = [ErrorSeverity.ERROR, ErrorSeverity.WARNING, ErrorSeverity.ERROR]
        errors, warnings = separate_issues([self._make_issue(s) for s in severities])
        self.assertEqual(len(errors), 2)
        self.assertEqual(len(warnings), 1)

    def test_original_list_unchanged(self):
        issues = [self._make_issue(ErrorSeverity.ERROR), self._make_issue(ErrorSeverity.WARNING)]
        separate_issues(issues)
        # separate_issues must not mutate its input list.
        self.assertEqual(len(issues), 2)

    def test_missing_severity_treated_as_error(self):
        """Issues without a 'severity' key should be treated as errors, not raise KeyError."""
        issues = [{"message": "no severity"}, self._make_issue(ErrorSeverity.WARNING)]
        errors, warnings = separate_issues(issues)
        self.assertEqual(len(errors), 1, "Issue missing severity should default to ERROR")
        self.assertEqual(len(warnings), 1)
Loading