diff --git a/problemtools/config.py b/problemtools/config.py index 8cd470c9..a85493be 100644 --- a/problemtools/config.py +++ b/problemtools/config.py @@ -1,13 +1,15 @@ import collections import os import yaml +from pathlib import Path +from typing import Mapping class ConfigError(Exception): pass -def load_config(configuration_file): +def load_config(configuration_file: str, priority_dirs: list[Path] = []) -> dict: """Load a problemtools configuration file. Args: @@ -15,41 +17,42 @@ def load_config(configuration_file): relative to config directory so typically just a file name without paths, e.g. "languages.yaml". """ - res = None + res: dict | None = None - for dirname in __config_file_paths(): - path = os.path.join(dirname, configuration_file) + for dirname in __config_file_paths() + priority_dirs: + path = dirname / configuration_file new_config = None - if os.path.isfile(path): + if path.is_file(): try: with open(path, 'r') as config: new_config = yaml.safe_load(config.read()) except (yaml.parser.ParserError, yaml.scanner.ScannerError) as err: - raise ConfigError('Config file %s: failed to parse: %s' % (path, err)) + raise ConfigError(f'Config file {path}: failed to parse: {err}') if res is None: if new_config is None: - raise ConfigError('Base configuration file %s not found in %s' % (configuration_file, path)) + raise ConfigError(f'Base configuration file {configuration_file} not found in {path}') res = new_config elif new_config is not None: __update_dict(res, new_config) + assert res is not None, 'Failed to load config (should never happen, we should have hit an error in loop above)' return res -def __config_file_paths(): +def __config_file_paths() -> list[Path]: """ Paths in which to look for config files, by increasing order of priority (i.e., any config in the last path should take precedence over the others). """ return [ - os.path.join(os.path.dirname(__file__), 'config'), - os.path.join('/etc', 'kattis', 'problemtools'), - os.path.join(os.environ.get('XDG_CONFIG_HOME', os.path.join(os.path.expanduser('~'), '.config')), 'problemtools'), + Path(__file__).parent / 'config', + Path('/etc/kattis/problemtools'), + Path(os.environ.get('XDG_CONFIG_HOME', Path.home() / '.config')) / 'problemtools', ] -def __update_dict(orig, update): +def __update_dict(orig: dict, update: Mapping) -> None: """Deep update of a dictionary For each entry (k, v) in update such that both orig[k] and v are diff --git a/problemtools/languages.py b/problemtools/languages.py index 1bcb61f6..93941205 100644 --- a/problemtools/languages.py +++ b/problemtools/languages.py @@ -6,6 +6,7 @@ import fnmatch import re import string +from pathlib import Path from . import config @@ -218,10 +219,10 @@ def update(self, data): priorities[lang.priority] = lang_id -def load_language_config(): +def load_language_config(probdir_parent: Path) -> Languages: """Load language configuration. Returns: Languages object for the set of languages. """ - return Languages(config.load_config('languages.yaml')) + return Languages(config.load_config('languages.yaml', [probdir_parent])) diff --git a/problemtools/verifyproblem.py b/problemtools/verifyproblem.py index 9005e34e..ef424af0 100644 --- a/problemtools/verifyproblem.py +++ b/problemtools/verifyproblem.py @@ -1902,7 +1902,7 @@ def __init__(self, probdir: str, args: argparse.Namespace): self.probdir = os.path.realpath(probdir) self.shortname: str = os.path.basename(self.probdir) super().__init__(self.shortname, self) - self.language_config = languages.load_language_config() + self.language_config = languages.load_language_config(Path(self.probdir).parent) self.testcase_by_infile: dict[str, TestCase] = {} self.loaded = False self._metadata: metadata.Metadata | None = None diff --git a/tests/test_config.py b/tests/test_config.py index 3307122d..35628a1b 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -5,9 +5,9 @@ def config_paths_mock(): - import os + from pathlib import Path - return [os.path.join(os.path.dirname(__file__), 'config1'), os.path.join(os.path.dirname(__file__), 'config2')] + return [Path(__file__).parent / 'config1', Path(__file__).parent / 'config2'] def test_load_basic_config(monkeypatch):