diff --git a/.github/workflows/ci_cov.yaml b/.github/workflows/ci_cov.yaml index ea1ec0a2..6eb390ec 100644 --- a/.github/workflows/ci_cov.yaml +++ b/.github/workflows/ci_cov.yaml @@ -37,12 +37,7 @@ jobs: # Install dependencies - name: Install dependencies - run: uv pip install ruff -e ".[docs,test]" - - # Run ruff - - name: Lint with ruff - run: | - ruff check . --output-format=github + run: uv pip install -e ".[docs,test]" # Run unittest with coverage - name: Test with unittest and coverage diff --git a/.github/workflows/ci_windows.yaml b/.github/workflows/ci_windows.yaml index 97edd0cb..05e9f426 100644 --- a/.github/workflows/ci_windows.yaml +++ b/.github/workflows/ci_windows.yaml @@ -25,6 +25,7 @@ jobs: with: python-version: ${{ matrix.python-version }} enable-cache: true + cache-dependency-glob: "**/pyproject.toml" - name: Create virtual environment shell: pwsh diff --git a/docs/api/errors.rst b/docs/api/errors.rst index 6b66ed5f..95f0f7c8 100644 --- a/docs/api/errors.rst +++ b/docs/api/errors.rst @@ -40,7 +40,9 @@ Error functions .. autofunction:: hed.errors.error_reporter.sort_issues -.. autofunction:: hed.errors.error_reporter.replace_tag_references +.. autofunction:: hed.errors.error_reporter.separate_issues + +.. 
autofunction:: hed.errors.error_reporter.iter_errors Error types ----------- diff --git a/hed/errors/__init__.py b/hed/errors/__init__.py index ed213224..a79bc7f7 100644 --- a/hed/errors/__init__.py +++ b/hed/errors/__init__.py @@ -1,6 +1,12 @@ """Error handling module for HED.""" -from .error_reporter import ErrorHandler, get_printable_issue_string, sort_issues, replace_tag_references, iter_errors +from .error_reporter import ( + ErrorHandler, + separate_issues, + get_printable_issue_string, + sort_issues, + iter_errors, +) from .error_types import ( DefinitionErrors, TemporalErrors, diff --git a/hed/errors/error_reporter.py b/hed/errors/error_reporter.py index e8b289c7..b455ca11 100644 --- a/hed/errors/error_reporter.py +++ b/hed/errors/error_reporter.py @@ -204,6 +204,23 @@ def wrapper(tag, *args, severity=default_severity, **kwargs): schema_error_messages.mark_as_used = True +def separate_issues(issues_list: list[dict]) -> tuple[list[dict], list[dict]]: + """Separate a list of issues into errors and warnings. + + Parameters: + issues_list (list[dict]): A list of issue dictionaries. The 'severity' key is + optional; issues that omit it are treated as errors (ErrorSeverity.ERROR). + + Returns: + tuple[list[dict], list[dict]]: A tuple of (errors, warnings) where errors contains + issues with severity <= ErrorSeverity.ERROR and warnings contains issues with + severity > ErrorSeverity.ERROR. 
+ """ + errors = [issue for issue in issues_list if issue.get("severity", ErrorSeverity.ERROR) <= ErrorSeverity.ERROR] + warnings = [issue for issue in issues_list if issue.get("severity", ErrorSeverity.ERROR) > ErrorSeverity.ERROR] + return errors, warnings + + class ErrorHandler: """Class to hold error context and having general error functions.""" @@ -274,20 +291,6 @@ def format_error_with_context(self, *args, **kwargs): return error_object - @staticmethod - def filter_issues_by_severity(issues_list: list[dict], severity: int) -> list[dict]: - """Gather all issues matching or below a given severity. - - Parameters: - issues_list (list[dict]): A list of dictionaries containing the full issue list. - severity (int): The level of issues to keep. - - Returns: - list[dict]: A list of dictionaries containing the issue list after filtering by severity. - - """ - return [issue for issue in issues_list if issue["severity"] <= severity] - @staticmethod def format_error(error_type: str, *args, actual_error=None, **kwargs) -> list[dict]: """Format an error based on the parameters, which vary based on what type of error this is. @@ -423,13 +426,46 @@ def val_error_unknown(*args, **kwargs) -> str: return f"Unknown error. Args: {str(args), str(kwargs)}" @staticmethod - def filter_issues_by_count(issues, count, by_file=False) -> tuple[list[dict], dict[str, int]]: + def filter_issues_by_severity(issues_list: list[dict], severity: int) -> list[dict]: + """Gather all issues matching or below a given severity. + + Parameters: + issues_list (list[dict]): A list of dictionaries containing the full issue list. + severity (int): The level of issues to keep. + + Returns: + list[dict]: A list of dictionaries containing the issue list after filtering by severity. 
+ + """ + return [issue for issue in issues_list if issue.get("severity", ErrorSeverity.ERROR) <= severity] + + @staticmethod + def aggregate_code_counts(file_code_dict: dict) -> dict: + """Aggregate the counts of codes across multiple files. + + Parameters: + file_code_dict (dict): A dictionary where keys are filenames and values are + dictionaries of code counts. + + Returns: + dict: A dictionary with the aggregated counts of codes across all files. + """ + total_counts = defaultdict(int) + for file_dict in file_code_dict.values(): + for code, count in file_dict.items(): + total_counts[code] += count + return dict(total_counts) + + @staticmethod + def filter_issues_by_count( + issues: list[dict], count: int, by_file: bool = False + ) -> tuple[list[dict], dict[str, int]]: """Filter the issues list to only include the first count issues of each code. - Parameters: - issues (list): A list of dictionaries containing the full issue list. - count (int): The number of issues to keep for each code. - by_file (bool): If True, group by file name. + Parameters: + issues (list[dict]): A list of dictionaries containing the full issue list. + count (int): The number of issues to keep for each code. + by_file (bool): If True, group by file name. Returns: tuple[list[dict], dict[str, int]]: A tuple containing: @@ -457,22 +493,6 @@ def filter_issues_by_count(issues, count, by_file=False) -> tuple[list[dict], di return filtered_issues, ErrorHandler.aggregate_code_counts(file_dicts) - @staticmethod - def aggregate_code_counts(file_code_dict) -> dict: - """Aggregate the counts of codes across multiple files. - - Parameters: - file_code_dict (dict): A dictionary where keys are filenames and values are dictionaries of code counts. - - Returns: - dict: A dictionary with the aggregated counts of codes across all files. 
- """ - total_counts = defaultdict(int) - for file_dict in file_code_dict.values(): - for code, count in file_dict.items(): - total_counts[code] += count - return dict(total_counts) - @staticmethod def get_code_counts(issues: list[dict]) -> dict[str, int]: """Count the occurrences of each error code in the issues list. @@ -490,6 +510,34 @@ def get_code_counts(issues: list[dict]) -> dict[str, int]: code_counts[code] += 1 return dict(code_counts) + @staticmethod + def replace_tag_references(list_or_dict): + """Utility function to remove any references to tags, strings, etc. from any type of nested list or dict. + + Use this if you want to save out issues to a file. + + If you'd prefer a copy returned, use ErrorHandler.replace_tag_references(list_or_dict.copy()). + + Parameters: + list_or_dict (list or dict): An arbitrarily nested list/dict structure + """ + if isinstance(list_or_dict, dict): + for key, value in list_or_dict.items(): + if isinstance(value, (dict, list)): + ErrorHandler.replace_tag_references(value) + elif isinstance(value, (bool, float, int)): + list_or_dict[key] = value + else: + list_or_dict[key] = str(value) + elif isinstance(list_or_dict, list): + for key, value in enumerate(list_or_dict): + if isinstance(value, (dict, list)): + ErrorHandler.replace_tag_references(value) + elif isinstance(value, (bool, float, int)): + list_or_dict[key] = value + else: + list_or_dict[key] = str(value) + def sort_issues(issues, reverse=False) -> list[dict]: """Sort a list of issues by the error context values. @@ -822,31 +870,3 @@ def _create_error_tree(error_dict, parent_element=None, add_link=True): _create_error_tree(value, context_ul, add_link) return parent_element - - -def replace_tag_references(list_or_dict): - """Utility function to remove any references to tags, strings, etc. from any type of nested list or dict. - - Use this if you want to save out issues to a file. - - If you'd prefer a copy returned, use replace_tag_references(list_or_dict.copy()). 
- - Parameters: - list_or_dict (list or dict): An arbitrarily nested list/dict structure - """ - if isinstance(list_or_dict, dict): - for key, value in list_or_dict.items(): - if isinstance(value, (dict, list)): - replace_tag_references(value) - elif isinstance(value, (bool, float, int)): - list_or_dict[key] = value - else: - list_or_dict[key] = str(value) - elif isinstance(list_or_dict, list): - for key, value in enumerate(list_or_dict): - if isinstance(value, (dict, list)): - replace_tag_references(value) - elif isinstance(value, (bool, float, int)): - list_or_dict[key] = value - else: - list_or_dict[key] = str(value) diff --git a/hed/scripts/schema_script_util.py b/hed/scripts/schema_script_util.py index 39c5d637..7af37d04 100644 --- a/hed/scripts/schema_script_util.py +++ b/hed/scripts/schema_script_util.py @@ -2,8 +2,7 @@ from collections import defaultdict from hed.schema import from_string, load_schema, from_dataframes from hed.schema import hed_cache -from hed.errors import get_printable_issue_string, HedFileError -from hed.errors.error_types import ErrorSeverity +from hed.errors import get_printable_issue_string, separate_issues, HedFileError from hed.schema.schema_comparer import SchemaComparer all_extensions = [".tsv", ".mediawiki", ".xml", ".json"] @@ -32,26 +31,32 @@ def _is_prerelease_partner(base_schema) -> bool: return hed_cache.get_hed_version_path(with_standard, check_prerelease=False) is None -def validate_schema_object(base_schema, schema_name): - """Validate a schema object by checking non-warning compliance and roundtrip conversion. +def validate_schema_object(base_schema, schema_name, check_warnings=False): + """Validate a schema object by checking compliance and roundtrip conversion. - Tests the schema for non-warning compliance issues and validates that it can be successfully + Tests the schema for compliance issues and validates that it can be successfully converted to and reloaded from all four formats (MEDIAWIKI, XML, JSON, TSV). 
Parameters: base_schema (HedSchema): The schema object to validate. schema_name (str): The name/path of the schema for error reporting. + check_warnings (bool): If True, include warnings in the validation. Default is False. Returns: list: A list of validation issue strings. Empty if no issues found. """ validation_issues = [] try: - issues = base_schema.check_compliance() - issues = [issue for issue in issues if issue.get("severity", ErrorSeverity.ERROR) == ErrorSeverity.ERROR] + issues = base_schema.check_compliance(check_for_warnings=check_warnings) + if issues and check_warnings: + errors, warnings = separate_issues(issues) + issues = errors + warnings + else: + errors = issues + if issues: - error_message = get_printable_issue_string(issues, title=schema_name) - validation_issues.append(error_message) + validation_issues.append(get_printable_issue_string(issues, title=schema_name)) + if errors: return validation_issues # If the withStandard partner only exists in the prerelease cache, all unmerged @@ -74,14 +79,18 @@ def validate_schema_object(base_schema, schema_name): return validation_issues -def validate_schema(file_path): +def validate_schema(file_path, check_warnings=False): """Validate a schema file, ensuring it can save/load and passes validation. Loads the schema from file, checks the file extension is lowercase, - and validates the schema object for compliance and roundtrip conversion. + and validates the schema object for compliance errors and roundtrip conversion. Parameters: file_path (str): The path to the schema file to validate. + If loading a TSV file, this should be a single base filename following the template basename.tsv, + where the individual files are named basename_Struct.tsv, basename_Tag.tsv, etc. + Alternatively, you can point to a directory containing the .tsv files. + check_warnings (bool): If True, include warnings in the validation. Default is False. Returns: list: A list of validation issue strings. Empty if no issues found. 
@@ -96,7 +105,7 @@ def validate_schema(file_path): validation_issues = [] try: base_schema = load_schema(file_path) - validation_issues = validate_schema_object(base_schema, file_path) + validation_issues = validate_schema_object(base_schema, file_path, check_warnings=check_warnings) except HedFileError as e: print(f"Saving/loading error: {file_path} {e.message}") error_text = e.message diff --git a/spec_tests/hed-examples b/spec_tests/hed-examples index 336a4fcc..a0650c9c 160000 --- a/spec_tests/hed-examples +++ b/spec_tests/hed-examples @@ -1 +1 @@ -Subproject commit 336a4fccaec59b4924c8f37c50967fa480538ccf +Subproject commit a0650c9c290d6d5d200ccdde147b982ebc760317 diff --git a/tests/errors/test_error_reporter.py b/tests/errors/test_error_reporter.py index 0875a89b..ca4bb390 100644 --- a/tests/errors/test_error_reporter.py +++ b/tests/errors/test_error_reporter.py @@ -7,7 +7,7 @@ SchemaWarnings, get_printable_issue_string, sort_issues, - replace_tag_references, + separate_issues, ) from hed.errors.error_reporter import hed_tag_error, get_printable_issue_string_html, iter_errors from hed import HedString, HedTag @@ -230,17 +230,17 @@ def test_replace_tag_references(self): "b": {"c": 2, "d": [3, {"e": HedString("Hed2", self._schema)}]}, "f": [5, 6], } - replace_tag_references(nested_dict) + ErrorHandler.replace_tag_references(nested_dict) self.assertEqual(nested_dict, {"a": "Hed1", "b": {"c": 2, "d": [3, {"e": "Hed2"}]}, "f": [5, 6]}) # Test with mixed data types and HedString in a nested list nested_list = [HedString("Hed1", self._schema), {"a": 2, "b": [3, {"c": HedString("Hed2", self._schema)}]}] - replace_tag_references(nested_list) + ErrorHandler.replace_tag_references(nested_list) self.assertEqual(nested_list, ["Hed1", {"a": 2, "b": [3, {"c": "Hed2"}]}]) # Test with mixed data types and HedString in a list within a dict mixed = {"a": HedString("Hed1", self._schema), "b": [2, 3, {"c": HedString("Hed2", self._schema)}, 4]} - replace_tag_references(mixed) + 
ErrorHandler.replace_tag_references(mixed) self.assertEqual(mixed, {"a": "Hed1", "b": [2, 3, {"c": "Hed2"}, 4]}) def test_register_error_twice(self): @@ -299,3 +299,50 @@ def test_get_code_counts(self): result_with_missing = ErrorHandler.get_code_counts(issues_with_missing_code) expected_with_missing = {"VALID_CODE": 2, "UNKNOWN": 1} # Default for missing code self.assertEqual(result_with_missing, expected_with_missing) + + +class TestSeparateIssues(unittest.TestCase): + """Tests for separate_issues.""" + + @staticmethod + def _make_issue(severity): + return {"severity": severity, "message": "test"} + + def test_empty_list(self): + errors, warnings = separate_issues([]) + self.assertEqual(errors, []) + self.assertEqual(warnings, []) + + def test_only_errors(self): + issues = [self._make_issue(ErrorSeverity.ERROR), self._make_issue(ErrorSeverity.ERROR)] + errors, warnings = separate_issues(issues) + self.assertEqual(len(errors), 2) + self.assertEqual(len(warnings), 0) + + def test_only_warnings(self): + issues = [self._make_issue(ErrorSeverity.WARNING), self._make_issue(ErrorSeverity.WARNING)] + errors, warnings = separate_issues(issues) + self.assertEqual(len(errors), 0) + self.assertEqual(len(warnings), 2) + + def test_mixed(self): + issues = [ + self._make_issue(ErrorSeverity.ERROR), + self._make_issue(ErrorSeverity.WARNING), + self._make_issue(ErrorSeverity.ERROR), + ] + errors, warnings = separate_issues(issues) + self.assertEqual(len(errors), 2) + self.assertEqual(len(warnings), 1) + + def test_original_list_unchanged(self): + issues = [self._make_issue(ErrorSeverity.ERROR), self._make_issue(ErrorSeverity.WARNING)] + separate_issues(issues) + self.assertEqual(len(issues), 2) + + def test_missing_severity_treated_as_error(self): + """Issues without a 'severity' key should be treated as errors, not raise KeyError.""" + issues = [{"message": "no severity"}, self._make_issue(ErrorSeverity.WARNING)] + errors, warnings = separate_issues(issues) + 
self.assertEqual(len(errors), 1, "Issue missing severity should default to ERROR") + self.assertEqual(len(warnings), 1) diff --git a/tests/scripts/test_hed_convert_schema.py b/tests/scripts/test_hed_convert_schema.py index 4126e371..11750713 100644 --- a/tests/scripts/test_hed_convert_schema.py +++ b/tests/scripts/test_hed_convert_schema.py @@ -1,12 +1,13 @@ -import unittest -import shutil +import contextlib import copy +import io import os +import shutil +import unittest from hed import load_schema, load_schema_version from hed.schema import HedSectionKey, HedKey from hed.scripts.schema_script_util import add_extension from hed.scripts.hed_convert_schema import convert_and_update -import contextlib class TestConvertAndUpdate(unittest.TestCase): @@ -25,7 +26,7 @@ def test_schema_conversion_and_update(self): # Assume filenames updated includes just the original schema file for simplicity filenames = [original_name] - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = convert_and_update(filenames, set_ids=False) # Verify no error from convert_and_update and the correct schema version was saved @@ -43,7 +44,7 @@ def test_schema_conversion_and_update(self): schema.save_as_dataframes(tsv_filename) filenames = [os.path.join(tsv_filename, "test_schema_Tag.tsv")] - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = convert_and_update(filenames, set_ids=False) # Verify no error from convert_and_update and the correct schema version was saved @@ -72,7 +73,7 @@ def test_schema_adding_tag(self): # Assume filenames updated includes just the original schema file for simplicity filenames = [add_extension(basename, ".mediawiki")] - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = convert_and_update(filenames, set_ids=False) self.assertEqual(result, 0) @@ -81,7 +82,7 @@ def test_schema_adding_tag(self): self.assertTrue(x) 
self.assertEqual(schema_reloaded, schema_edited) - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = convert_and_update(filenames, set_ids=True) self.assertEqual(result, 0) diff --git a/tests/scripts/test_script_util.py b/tests/scripts/test_script_util.py index 24ec2e2b..5e004d47 100644 --- a/tests/scripts/test_script_util.py +++ b/tests/scripts/test_script_util.py @@ -1,15 +1,19 @@ +import contextlib +import copy +import io import unittest import os import shutil + from hed import load_schema_version from hed.scripts.schema_script_util import ( add_extension, sort_base_schemas, validate_all_schema_formats, validate_schema, + validate_schema_object, validate_all_schemas, ) -import contextlib class TestAddExtension(unittest.TestCase): @@ -107,7 +111,7 @@ def test_mixed_file_types(self): }, "other_schema": {".xml": "other_schema.xml"}, } - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) @@ -118,7 +122,7 @@ def test_tsv_in_correct_subfolder(self): os.path.normpath("hedtsv/wrong_folder/wrong_name_Tag.tsv"), # Should be ignored ] expected = {"test_schema": {".tsv": os.path.normpath("hedtsv/test_schema")}} - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) @@ -131,7 +135,7 @@ def test_tsv_in_correct_subfolder2(self): expected = { os.path.normpath("prerelease/test_schema"): {".tsv": os.path.normpath("prerelease/hedtsv/test_schema")} } - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) @@ -141,14 +145,14 @@ def test_ignored_files(self): os.path.normpath("not_hedtsv/test_schema/test_schema_Tag.tsv"), # Should be ignored ] expected = {"test_schema": {".mediawiki": 
"test_schema.mediawiki"}} - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) def test_empty_input(self): filenames = [] expected = {} - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) @@ -170,7 +174,7 @@ def test_case_insensitive_extensions(self): "case_test_schema": {".mediawiki": "case_test_schema.MEDIAWIKI"}, "case_other_schema": {".xml": "case_other_schema.XML"}, } - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): result = sort_base_schemas(filenames) self.assertEqual(dict(result), expected) finally: @@ -195,21 +199,21 @@ def test_error_no_error(self): schema = load_schema_version("8.4.0") schema.save_as_xml(os.path.join(self.base_path, self.basename + ".xml")) schema.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename)) - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename)) self.assertTrue(issues) self.assertEqual(issues[0], "Error loading schema: No such file or directory") schema.save_as_mediawiki(os.path.join(self.base_path, self.basename + ".mediawiki")) schema.save_as_json(os.path.join(self.base_path, self.basename + ".json")) - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): self.assertEqual(validate_all_schema_formats(os.path.join(self.base_path, self.basename)), []) schema_incorrect = load_schema_version("8.2.0") schema_incorrect.save_as_dataframes(os.path.join(self.base_path, "hedtsv", self.basename)) # Validate and expect errors - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): issues = validate_all_schema_formats(os.path.join(self.base_path, self.basename)) 
self.assertTrue(issues) # self.assertIn("Error loading schema: No columns to parse from file", issues[0]) @@ -223,7 +227,7 @@ def tearDownClass(cls): class TestValidateSchema(unittest.TestCase): def test_load_invalid_extension(self): # Verify capital letters fail validation - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.MEDIAWIKI")[0]) self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.Mediawiki")[0]) self.assertIn("Only fully lowercase extensions ", validate_schema("does_not_matter.XML")[0]) @@ -253,7 +257,7 @@ def test_uppercase_extension_policy_enforcement(self): self.assertEqual(schema_files["policy_test"][".xml"], uppercase_file) # Step 2: validate_all_schemas should use actual path and reject per policy - with contextlib.redirect_stdout(None): + with contextlib.redirect_stdout(io.StringIO()): issues = validate_all_schemas(schema_files) # Should get policy violation, not FileNotFoundError @@ -265,3 +269,65 @@ def test_uppercase_extension_policy_enforcement(self): # Clean up if os.path.exists(uppercase_file): os.remove(uppercase_file) + + +class TestCheckWarnings(unittest.TestCase): + """Tests for the check_warnings parameter in validate_schema_object and validate_schema.""" + + @classmethod + def setUpClass(cls): + clean = load_schema_version("8.3.0") + cls.clean_schema = clean + # Deep-copy so the cached shared instance is not mutated. + # Setting version to a future value triggers SCHEMA_PRERELEASE_VERSION_USED (warning only). 
+ cls.warning_schema = copy.deepcopy(clean) + cls.warning_schema.header_attributes["version"] = "999.0.0" + + def test_clean_schema_check_warnings_false(self): + """A fully compliant schema produces no issues with check_warnings=False.""" + with contextlib.redirect_stdout(io.StringIO()): + issues = validate_schema_object(self.clean_schema, "test", check_warnings=False) + self.assertEqual(issues, []) + + def test_clean_schema_check_warnings_true(self): + """A fully compliant schema produces no issues with check_warnings=True.""" + with contextlib.redirect_stdout(io.StringIO()): + issues = validate_schema_object(self.clean_schema, "test", check_warnings=True) + self.assertEqual(issues, []) + + def test_warning_schema_check_warnings_true_reports_issues(self): + """A prerelease version generates a warning that is reported when check_warnings=True.""" + with contextlib.redirect_stdout(io.StringIO()): + issues = validate_schema_object(self.warning_schema, "test", check_warnings=True) + combined = "\n".join(issues) + self.assertTrue(issues, "Expected at least one issue for prerelease version warning") + self.assertIn( + "SCHEMA_PRERELEASE_VERSION_USED", + combined, + "Expected SCHEMA_PRERELEASE_VERSION_USED warning code in output for prerelease version", + ) + # Warnings must not gate the roundtrip checks — the warning message should be the only + # entry and roundtrip errors (if any) would append further entries, but the key contract + # is that the function does NOT return after reporting a warning-only result. + # We verify this indirectly: issue count must equal exactly 1 (the warning summary) + # because a clean schema with only a version warning should roundtrip without further errors. 
+ self.assertEqual(len(issues), 1, "Roundtrip should have run and produced no additional errors") + + def test_warning_schema_check_warnings_false_suppresses_warnings(self): + """Warnings are suppressed and validation passes when check_warnings=False.""" + with contextlib.redirect_stdout(io.StringIO()): + issues = validate_schema_object(self.warning_schema, "test", check_warnings=False) + self.assertEqual(issues, []) + + def test_validate_schema_default_is_warnings_false(self): + """validate_schema default check_warnings=False matches explicit False.""" + schema_path = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "data", + "schema_tests", + "HED8.2.0.mediawiki", + ) + with contextlib.redirect_stdout(io.StringIO()): + default_issues = validate_schema(schema_path) + explicit_issues = validate_schema(schema_path, check_warnings=False) + self.assertEqual(default_issues, explicit_issues) diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py index b55ba416..dc562546 100644 --- a/tests/test_notebooks.py +++ b/tests/test_notebooks.py @@ -5,10 +5,13 @@ 2. Can be executed without errors (with mock data paths) 3. Import statements work correctly -These tests require the optional 'examples' dependencies: +TestNotebooks uses only the standard library (json module) to inspect notebook +structure and does not require optional dependencies. +TestNotebookExecution requires the optional 'examples' dependencies: pip install -e .[examples] """ +import json import os import unittest from pathlib import Path @@ -17,7 +20,10 @@ class TestNotebooks(unittest.TestCase): - """Test suite for validating example Jupyter notebooks.""" + """Test suite for validating example Jupyter notebooks. + + Uses only standard-library JSON parsing — no nbformat/nbconvert required. 
+ """ @classmethod def setUpClass(cls): @@ -25,22 +31,41 @@ def setUpClass(cls): cls.examples_dir = Path(__file__).parent.parent / "examples" cls.test_data_dir = Path(__file__).parent / "data" - # Check if Jupyter dependencies are available - try: - import nbformat - from nbconvert.preprocessors import ExecutePreprocessor + @staticmethod + def _read_notebook(path): + """Read a Jupyter notebook file and return its parsed JSON dict. - cls.nbformat = nbformat - cls.ExecutePreprocessor = ExecutePreprocessor - cls.has_jupyter = True - except ImportError: - cls.has_jupyter = False - cls.skip_message = "Jupyter dependencies not installed. Run 'pip install -e .[examples]' to install them." + Parameters: + path (Path): Path to the .ipynb file. - def setUp(self): - """Set up each individual test.""" - if not self.has_jupyter: - self.skipTest(self.skip_message) + Returns: + dict: The parsed notebook JSON. + """ + with open(path, "r", encoding="utf-8") as f: + return json.load(f) + + @staticmethod + def _get_code_cells(nb): + """Return source strings for all code cells in a notebook. + + In the raw .ipynb JSON format each cell's ``source`` may be either a + plain string or a list of strings (one per line). This helper + normalises both cases into a single string per cell. + + Parameters: + nb (dict): Parsed notebook JSON. + + Returns: + list[str]: Source strings from code cells. 
+ """ + sources = [] + for cell in nb.get("cells", []): + if cell.get("cell_type") == "code": + src = cell.get("source", "") + if isinstance(src, list): + src = "".join(src) + sources.append(src) + return sources def test_notebooks_directory_exists(self): """Verify the examples directory exists and contains notebooks.""" @@ -56,10 +81,10 @@ def test_all_notebooks_valid_format(self): for notebook_path in notebooks: with self.subTest(notebook=notebook_path.name): try: - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) + nb = self._read_notebook(notebook_path) self.assertIsNotNone(nb, f"Failed to read notebook: {notebook_path.name}") - self.assertGreater(len(nb.cells), 0, f"Notebook has no cells: {notebook_path.name}") + self.assertIn("cells", nb, f"Notebook missing 'cells' key: {notebook_path.name}") + self.assertGreater(len(nb["cells"]), 0, f"Notebook has no cells: {notebook_path.name}") except Exception as e: self.fail(f"Failed to load {notebook_path.name}: {str(e)}") @@ -69,16 +94,10 @@ def test_notebook_imports(self): for notebook_path in notebooks: with self.subTest(notebook=notebook_path.name): - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) + nb = self._read_notebook(notebook_path) + code_cells = self._get_code_cells(nb) - # Extract and test import statements - import_cells = [] - for cell in nb.cells: - if cell.cell_type == "code": - source = cell.source - if "import " in source: - import_cells.append(source) + import_cells = [src for src in code_cells if "import " in src] # Try to validate imports (basic check) for cell_source in import_cells: @@ -102,10 +121,8 @@ def test_notebooks_have_markdown_cells(self): continue with self.subTest(notebook=notebook_path.name): - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - markdown_cells = [c for c in nb.cells if c.cell_type == "markdown"] + nb = 
self._read_notebook(notebook_path) + markdown_cells = [c for c in nb.get("cells", []) if c.get("cell_type") == "markdown"] self.assertGreater(len(markdown_cells), 0, f"Notebook {notebook_path.name} has no markdown cells") def test_specific_notebook_structure(self): @@ -113,12 +130,8 @@ def test_specific_notebook_structure(self): # Test validate_bids_dataset notebook validate_nb = self.examples_dir / "validate_bids_dataset.ipynb" if validate_nb.exists(): - with open(validate_nb, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Should have imports from hed.tools and hed.errors - code_sources = [c.source for c in nb.cells if c.cell_type == "code"] - all_code = "\n".join(code_sources) + nb = self._read_notebook(validate_nb) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("BidsDataset", all_code, "validate_bids_dataset should import BidsDataset") self.assertIn("get_printable_issue_string", all_code, "validate_bids_dataset should import error handling") @@ -126,11 +139,8 @@ def test_specific_notebook_structure(self): # Test summarize_events notebook summarize_nb = self.examples_dir / "summarize_events.ipynb" if summarize_nb.exists(): - with open(summarize_nb, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - code_sources = [c.source for c in nb.cells if c.cell_type == "code"] - all_code = "\n".join(code_sources) + nb = self._read_notebook(summarize_nb) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("TabularSummary", all_code, "summarize_events should import TabularSummary") @@ -140,11 +150,8 @@ def test_notebooks_cell_execution_order(self): for notebook_path in notebooks: with self.subTest(notebook=notebook_path.name): - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Check that code cells exist and have execution count - code_cells = [c for c in nb.cells if c.cell_type == "code"] + nb = self._read_notebook(notebook_path) + 
code_cells = self._get_code_cells(nb) self.assertGreater(len(code_cells), 0, f"Notebook {notebook_path.name} has no code cells") def test_notebook_metadata(self): @@ -153,10 +160,7 @@ def test_notebook_metadata(self): for notebook_path in notebooks: with self.subTest(notebook=notebook_path.name): - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Check for kernel spec + nb = self._read_notebook(notebook_path) self.assertIn("metadata", nb, f"Notebook {notebook_path.name} missing metadata") def test_validate_bids_dataset_notebook(self): @@ -166,11 +170,8 @@ def test_validate_bids_dataset_notebook(self): if not notebook_path.exists(): self.skipTest("validate_bids_dataset.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("from hed.errors import", all_code) self.assertIn("from hed.tools import BidsDataset", all_code) @@ -184,11 +185,8 @@ def test_summarize_events_notebook(self): if not notebook_path.exists(): self.skipTest("summarize_events.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("TabularSummary", all_code) self.assertIn("get_file_list", all_code) @@ -200,11 +198,8 @@ def test_sidecar_to_spreadsheet_notebook(self): if not notebook_path.exists(): self.skipTest("sidecar_to_spreadsheet.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for 
c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("hed_to_df", all_code) @@ -215,11 +210,8 @@ def test_merge_spreadsheet_into_sidecar_notebook(self): if not notebook_path.exists(): self.skipTest("merge_spreadsheet_into_sidecar.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("df_to_hed", all_code) self.assertIn("merge_hed_dict", all_code) @@ -231,11 +223,8 @@ def test_find_event_combinations_notebook(self): if not notebook_path.exists(): self.skipTest("find_event_combinations.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) self.assertIn("KeyMap", all_code) @@ -246,11 +235,8 @@ def test_extract_json_template_notebook(self): if not notebook_path.exists(): self.skipTest("extract_json_template.ipynb not found") - with open(notebook_path, "r", encoding="utf-8") as f: - nb = self.nbformat.read(f, as_version=4) - - # Verify key components - all_code = "\n".join([c.source for c in nb.cells if c.cell_type == "code"]) + nb = self._read_notebook(notebook_path) + all_code = "\n".join(self._get_code_cells(nb)) # Check for either get_new_dataframe or extract_sidecar_template self.assertTrue(