diff --git a/upath/tests/implementations/test_github.py b/upath/tests/implementations/test_github.py index a5318917..4783b823 100644 --- a/upath/tests/implementations/test_github.py +++ b/upath/tests/implementations/test_github.py @@ -18,45 +18,42 @@ ) -def xfail_on_github_rate_limit(func): - """ - Method decorator to mark test as xfail when GitHub rate limit is exceeded. - """ +def xfail_on_github_connection_error(func): + """Method decorator to xfail tests on GitHub rate limit or connection errors.""" @functools.wraps(func) - def wrapped_method(self, *args, **kwargs): - import requests - + def wrapper(self, *args, **kwargs): try: return func(self, *args, **kwargs) - except AssertionError as e: - if "nodename nor servname provided, or not known" in str(e): - pytest.xfail(reason="No internet connection") - raise - except requests.exceptions.ConnectionError: - pytest.xfail(reason="No internet connection") except Exception as e: - if "rate limit exceeded" in str(e): + str_e = str(e) + if "rate limit exceeded" in str_e or "too many requests for url" in str_e: pytest.xfail("GitHub API rate limit exceeded") + elif ( + "nodename nor servname provided, or not known" in str_e + or "Network is unreachable" in str_e + ): + pytest.xfail("No internet connection") else: raise - return wrapped_method + return wrapper -def wrap_github_rate_limit_check(cls): - """ - Class decorator to wrap all test methods with the - xfail_on_github_rate_limit decorator. - """ - for attr_name in dir(cls): - if attr_name.startswith("test_"): - orig_method = getattr(cls, attr_name) - setattr(cls, attr_name, xfail_on_github_rate_limit(orig_method)) - return cls +def wrap_all_tests(decorator): + """Class decorator factory to wrap all test methods with a given decorator.""" + + def class_decorator(cls): + for attr_name in dir(cls): + if attr_name.startswith("test_"): + orig_method = getattr(cls, attr_name) + setattr(cls, attr_name, decorator(orig_method)) + return cls + + return class_decorator -@wrap_github_rate_limit_check +@wrap_all_tests(xfail_on_github_connection_error) class TestUPathGitHubPath(BaseTests): """ Unit-tests for the GitHubPath implementation of UPath. @@ -70,16 +67,6 @@ def path(self): path = "github://ap--:universal_pathlib@test_data/data" self.path = UPath(path) - @pytest.fixture(autouse=True) - def _xfail_on_rate_limit_errors(self): - try: - yield - except Exception as e: - if "rate limit exceeded" in str(e): - pytest.xfail("GitHub API rate limit exceeded") - else: - raise - def test_is_GitHubPath(self): """ Test that the path is a GitHubPath instance.