From 88878d9033dbddc7034a7f24d20033db9f78df80 Mon Sep 17 00:00:00 2001 From: Christian Heimes Date: Mon, 1 Dec 2025 08:24:29 +0100 Subject: [PATCH] feat: add alternative download methods to resolver API Extend the resolver API with alternative download URLs. Resolvers can now return download links to alternative locations or retrieval methods. The `PyPIProvider` now accepts a `override_download_url` parameter. The value overwrites the default PyPI download link. The string can contain a `{version}` format variable. The GitHub and Gitlab tag providers can return git clone URLs for `https` and `ssh` transport. The URLs uses pip's VCS syntax like `git+https://host/repo.git@tag`. The new enum `RetrieveMethod` has a `from_url()` constructor that parses an URL and splits it into method, url, and git ref. Signed-off-by: Christian Heimes --- src/fromager/resolver.py | 76 +++++++++++++++++-- tests/test_resolver.py | 158 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 8 deletions(-) diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index 1442d97e..c8149a42 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -6,6 +6,7 @@ from __future__ import annotations import datetime +import enum import functools import logging import os @@ -14,7 +15,7 @@ from collections.abc import Iterable from operator import attrgetter from platform import python_version -from urllib.parse import quote, unquote, urljoin, urlparse +from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit import pypi_simple import resolvelib @@ -180,11 +181,42 @@ def resolve_from_provider( raise ValueError(f"Unable to resolve {req}") +class RetrieveMethod(enum.StrEnum): + tarball = "tarball" + git_https = "git+https" + git_ssh = "git+ssh" + + @classmethod + def from_url(cls, download_url) -> tuple[RetrieveMethod, str, str | None]: + """Parse a download URL into method, url, reference""" + scheme, netloc, path, query, fragment = urlsplit( + download_url, allow_fragments=False + ) + match scheme: + case "https": + return RetrieveMethod.tarball, download_url, None + case "git+https": + method = RetrieveMethod.git_https + case "git+ssh": + method = RetrieveMethod.git_ssh + case _: + raise ValueError(f"unsupported download URL {download_url!r}") + # remove git+ + scheme = scheme[4:] + # split off @ revision + if "@" not in path: + raise ValueError(f"git download url {download_url!r} is missing '@ref'") + path, ref = path.rsplit("@", 1) + return method, urlunsplit((scheme, netloc, path, query, fragment)), ref + + def get_project_from_pypi( project: str, extras: typing.Iterable[str], sdist_server_url: str, ignore_platform: bool = False, + *, + override_download_url: str | None = None, ) -> Candidates: """Return candidates created from the project name and extras.""" found_candidates: set[str] = set() @@ -345,6 +377,11 @@ def get_project_from_pypi( ignored_candidates.add(dp.filename) continue + if override_download_url is None: + url = dp.url + else: + url = override_download_url.format(version=version) + upload_time = dp.upload_time if upload_time is not None: upload_time = upload_time.astimezone(datetime.UTC) @@ -352,7 +389,7 @@ def get_project_from_pypi( c = Candidate( name=name, version=version, - url=dp.url, + url=url, extras=tuple(sorted(extras)), is_sdist=is_sdist, build_tag=build_tag, @@ -603,6 +640,7 @@ def __init__( ignore_platform: bool = False, *, use_resolver_cache: bool = True, + override_download_url: str | None = None, ): super().__init__( constraints=constraints, @@ -613,6 +651,7 @@ def __init__( self.include_wheels = include_wheels self.sdist_server_url = sdist_server_url self.ignore_platform = ignore_platform + self.override_download_url = override_download_url @property def cache_key(self) -> str: @@ -625,9 +664,10 @@ def cache_key(self) -> str: def find_candidates(self, identifier: str) -> Candidates: return get_project_from_pypi( identifier, - set(), - self.sdist_server_url, - self.ignore_platform, + extras=set(), + sdist_server_url=self.sdist_server_url, + ignore_platform=self.ignore_platform, + override_download_url=self.override_download_url, ) def validate_candidate( @@ -803,6 +843,7 @@ def __init__( *, req_type: RequirementType | None = None, use_resolver_cache: bool = True, + retrieve_method: RetrieveMethod = RetrieveMethod.tarball, ): super().__init__( constraints=constraints, @@ -813,6 +854,7 @@ def __init__( ) self.organization = organization self.repo = repo + self.retrieve_method = retrieve_method @property def cache_key(self) -> str: @@ -847,7 +889,14 @@ def _find_tags( logger.debug(f"{identifier}: match function ignores {tagname}") continue assert isinstance(version, Version) - url = entry["tarball_url"] + + match self.retrieve_method: + case RetrieveMethod.tarball: + url = entry["tarball_url"] + case RetrieveMethod.git_https: + url = f"git+https://{self.host}/{self.organization}/{self.repo}.git@{tagname}" + case RetrieveMethod.git_ssh: + url = f"git+ssh://git@{self.host}/{self.organization}/{self.repo}.git@{tagname}" # Github tag API endpoint does not include commit date information. # It would be too expensive to query every commit API endpoint. @@ -880,6 +929,7 @@ def __init__( *, req_type: RequirementType | None = None, use_resolver_cache: bool = True, + retrieve_method: RetrieveMethod = RetrieveMethod.tarball, ) -> None: super().__init__( constraints=constraints, @@ -889,6 +939,9 @@ def __init__( matcher=matcher, ) self.server_url = server_url.rstrip("/") + self.server_hostname = urlparse(server_url).hostname + if not self.server_hostname: + raise ValueError(f"invalid {server_url=}") self.project_path = project_path.lstrip("/") # URL-encode the project path as required by GitLab API. # The safe="" parameter tells quote() to encode ALL characters, @@ -899,6 +952,7 @@ def __init__( self.api_url = ( f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags" ) + self.retrieve_method = retrieve_method @property def cache_key(self) -> str: @@ -927,8 +981,14 @@ def _find_tags( continue assert isinstance(version, Version) - archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz" - url = urljoin(self.server_url, archive_path) + match self.retrieve_method: + case RetrieveMethod.tarball: + archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz" + url = urljoin(self.server_url, archive_path) + case RetrieveMethod.git_https: + url = f"git+https://{self.server_hostname}/{self.project_path}.git@{tagname}" + case RetrieveMethod.git_ssh: + url = f"git+ssh://git@{self.server_hostname}/{self.project_path}.git@{tagname}" # get tag creation time, fall back to commit creation time created_at_str: str | None = entry.get("created_at") diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 4f4150d4..5fb2ba78 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -370,6 +370,26 @@ def test_provider_constraint_match() -> None: assert str(candidate.version) == "1.2.2" +def test_provider_override_download_url() -> None: + with requests_mock.Mocker() as r: + r.get( + "https://pypi.org/simple/hydra-core/", + text=_hydra_core_simple_response, + ) + + provider = resolver.PyPIProvider( + override_download_url="https://server.test/hydr_core-{version}.tar.gz" + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + rslvr = resolvelib.Resolver(provider, reporter) + + result = rslvr.resolve([Requirement("hydra-core")]) + assert "hydra-core" in result.mapping + + candidate = result.mapping["hydra-core"] + assert candidate.url == "https://server.test/hydr_core-1.3.2.tar.gz" + + _ignore_platform_simple_response = """ @@ -715,6 +735,51 @@ def test_resolve_github() -> None: ) +@pytest.mark.parametrize( + ["retrieve_method", "expected_url"], + [ + ( + resolver.RetrieveMethod.tarball, + "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0", + ), + ( + resolver.RetrieveMethod.git_https, + "git+https://github.com:443/python-wheel-build/fromager.git@0.9.0", + ), + ( + resolver.RetrieveMethod.git_ssh, + "git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0", + ), + ], +) +def test_resolve_github_retrieve_method( + retrieve_method: resolver.RetrieveMethod, expected_url: str +) -> None: + with requests_mock.Mocker() as r: + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager", + text=_github_fromager_repo_response, + ) + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager/tags", + text=_github_fromager_tag_response, + ) + + provider = resolver.GitHubTagProvider( + organization="python-wheel-build", + repo="fromager", + retrieve_method=retrieve_method, + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + rslvr = resolvelib.Resolver(provider, reporter) + + result = rslvr.resolve([Requirement("fromager")]) + assert "fromager" in result.mapping + + candidate = result.mapping["fromager"] + assert candidate.url == expected_url + + def test_github_constraint_mismatch() -> None: constraint = constraints.Constraints() constraint.add_constraint("fromager>=1.0") @@ -922,6 +987,49 @@ def test_resolve_gitlab() -> None: ) +@pytest.mark.parametrize( + ["retrieve_method", "expected_url"], + [ + ( + resolver.RetrieveMethod.tarball, + "https://gitlab.com/mirrors/github/decile-team/submodlib/-/archive/v0.0.3/submodlib-v0.0.3.tar.gz", + ), + ( + resolver.RetrieveMethod.git_https, + "git+https://gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3", + ), + ( + resolver.RetrieveMethod.git_ssh, + "git+ssh://git@gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3", + ), + ], +) +def test_resolve_gitlab_retrieve_method( + retrieve_method: resolver.RetrieveMethod, expected_url: str +) -> None: + with requests_mock.Mocker() as r: + r.get( + "https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags", + text=_gitlab_submodlib_repo_response, + ) + + provider = resolver.GitLabTagProvider( + project_path="mirrors/github/decile-team/submodlib", + server_url="https://gitlab.com", + retrieve_method=retrieve_method, + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + rslvr = resolvelib.Resolver(provider, reporter) + + result = rslvr.resolve([Requirement("submodlib")]) + assert "submodlib" in result.mapping + + candidate = result.mapping["submodlib"] + assert str(candidate.version) == "0.0.3" + + assert candidate.url == expected_url + + def test_gitlab_constraint_mismatch() -> None: constraint = constraints.Constraints() constraint.add_constraint("submodlib>=1.0") @@ -1107,3 +1215,53 @@ def custom_resolver_provider(*args, **kwargs): assert "pypi.org" not in error_message.lower(), ( f"Error message incorrectly mentions PyPI when using GitHub resolver: {error_message}" ) + + +@pytest.mark.parametrize( + ["download_url", "expected_method", "expected_url", "expected_ref"], + [ + ( + "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0", + resolver.RetrieveMethod.tarball, + "https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0", + None, + ), + ( + "git+https://github.com:443/python-wheel-build/fromager.git@0.9.0", + resolver.RetrieveMethod.git_https, + "https://github.com:443/python-wheel-build/fromager.git", + "0.9.0", + ), + ( + "git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0", + resolver.RetrieveMethod.git_ssh, + "ssh://git@github.com:443/python-wheel-build/fromager.git", + "0.9.0", + ), + ], +) +def test_retrieve_method_from_url( + download_url: str, + expected_method: resolver.RetrieveMethod, + expected_url: str, + expected_ref: str | None, +) -> None: + assert resolver.RetrieveMethod.from_url(download_url) == ( + expected_method, + expected_url, + expected_ref, + ) + + +@pytest.mark.parametrize( + ["download_url"], + [ + ["http://insecure.test"], + ["hg+ssh://mercurial.test"], + ], +) +def test_retrieve_method_from_url_error( + download_url: str, +) -> None: + with pytest.raises(ValueError): + resolver.RetrieveMethod.from_url(download_url)