Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions src/fromager/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from __future__ import annotations

import datetime
import enum
import functools
import logging
import os
Expand All @@ -14,7 +15,7 @@
from collections.abc import Iterable
from operator import attrgetter
from platform import python_version
from urllib.parse import quote, unquote, urljoin, urlparse
from urllib.parse import quote, unquote, urljoin, urlparse, urlsplit, urlunsplit

import pypi_simple
import resolvelib
Expand Down Expand Up @@ -180,11 +181,42 @@ def resolve_from_provider(
raise ValueError(f"Unable to resolve {req}")


class RetrieveMethod(enum.StrEnum):
tarball = "tarball"
git_https = "git+https"
git_ssh = "git+ssh"

@classmethod
def from_url(cls, download_url) -> tuple[RetrieveMethod, str, str | None]:
"""Parse a download URL into method, url, reference"""
scheme, netloc, path, query, fragment = urlsplit(
download_url, allow_fragments=False
)
match scheme:
case "https":
return RetrieveMethod.tarball, download_url, None
case "git+https":
method = RetrieveMethod.git_https
case "git+ssh":
method = RetrieveMethod.git_ssh
case _:
raise ValueError(f"unsupported download URL {download_url!r}")
# remove git+
scheme = scheme[4:]
# split off @ revision
if "@" not in path:
raise ValueError(f"git download url {download_url!r} is missing '@ref'")
path, ref = path.rsplit("@", 1)
return method, urlunsplit((scheme, netloc, path, query, fragment)), ref


def get_project_from_pypi(
project: str,
extras: typing.Iterable[str],
sdist_server_url: str,
ignore_platform: bool = False,
*,
override_download_url: str | None = None,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear how a caller would know what an alternative download site might be, and why that would be anything other than another package index server that could have been passed as the sdist_server_url parameter. What's the use case for using on package index to resolve a version and another server to act as the source of the download? Shouldn't the resolver look at the place where it's going to download the package to see what versions are available there?

) -> Candidates:
"""Return candidates created from the project name and extras."""
found_candidates: set[str] = set()
Expand Down Expand Up @@ -345,14 +377,19 @@ def get_project_from_pypi(
ignored_candidates.add(dp.filename)
continue

if override_download_url is None:
url = dp.url
else:
url = override_download_url.format(version=version)

upload_time = dp.upload_time
if upload_time is not None:
upload_time = upload_time.astimezone(datetime.UTC)

c = Candidate(
name=name,
version=version,
url=dp.url,
url=url,
extras=tuple(sorted(extras)),
is_sdist=is_sdist,
build_tag=build_tag,
Expand Down Expand Up @@ -603,6 +640,7 @@ def __init__(
ignore_platform: bool = False,
*,
use_resolver_cache: bool = True,
override_download_url: str | None = None,
):
super().__init__(
constraints=constraints,
Expand All @@ -613,6 +651,7 @@ def __init__(
self.include_wheels = include_wheels
self.sdist_server_url = sdist_server_url
self.ignore_platform = ignore_platform
self.override_download_url = override_download_url

@property
def cache_key(self) -> str:
Expand All @@ -625,9 +664,10 @@ def cache_key(self) -> str:
def find_candidates(self, identifier: str) -> Candidates:
return get_project_from_pypi(
identifier,
set(),
self.sdist_server_url,
self.ignore_platform,
extras=set(),
sdist_server_url=self.sdist_server_url,
ignore_platform=self.ignore_platform,
override_download_url=self.override_download_url,
)

def validate_candidate(
Expand Down Expand Up @@ -803,6 +843,7 @@ def __init__(
*,
req_type: RequirementType | None = None,
use_resolver_cache: bool = True,
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
):
super().__init__(
constraints=constraints,
Expand All @@ -813,6 +854,7 @@ def __init__(
)
self.organization = organization
self.repo = repo
self.retrieve_method = retrieve_method

@property
def cache_key(self) -> str:
Expand Down Expand Up @@ -847,7 +889,14 @@ def _find_tags(
logger.debug(f"{identifier}: match function ignores {tagname}")
continue
assert isinstance(version, Version)
url = entry["tarball_url"]

match self.retrieve_method:
case RetrieveMethod.tarball:
url = entry["tarball_url"]
case RetrieveMethod.git_https:
url = f"git+https://{self.host}/{self.organization}/{self.repo}.git@{tagname}"
case RetrieveMethod.git_ssh:
url = f"git+ssh://git@{self.host}/{self.organization}/{self.repo}.git@{tagname}"

# Github tag API endpoint does not include commit date information.
# It would be too expensive to query every commit API endpoint.
Expand Down Expand Up @@ -880,6 +929,7 @@ def __init__(
*,
req_type: RequirementType | None = None,
use_resolver_cache: bool = True,
retrieve_method: RetrieveMethod = RetrieveMethod.tarball,
) -> None:
super().__init__(
constraints=constraints,
Expand All @@ -889,6 +939,9 @@ def __init__(
matcher=matcher,
)
self.server_url = server_url.rstrip("/")
self.server_hostname = urlparse(server_url).hostname
if not self.server_hostname:
raise ValueError(f"invalid {server_url=}")
self.project_path = project_path.lstrip("/")
# URL-encode the project path as required by GitLab API.
# The safe="" parameter tells quote() to encode ALL characters,
Expand All @@ -899,6 +952,7 @@ def __init__(
self.api_url = (
f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags"
)
self.retrieve_method = retrieve_method

@property
def cache_key(self) -> str:
Expand Down Expand Up @@ -927,8 +981,14 @@ def _find_tags(
continue
assert isinstance(version, Version)

archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
url = urljoin(self.server_url, archive_path)
match self.retrieve_method:
case RetrieveMethod.tarball:
archive_path: str = f"{self.project_path}/-/archive/{tagname}/{self.project_path.split('/')[-1]}-{tagname}.tar.gz"
url = urljoin(self.server_url, archive_path)
case RetrieveMethod.git_https:
url = f"git+https://{self.server_hostname}/{self.project_path}.git@{tagname}"
case RetrieveMethod.git_ssh:
url = f"git+ssh://git@{self.server_hostname}/{self.project_path}.git@{tagname}"

# get tag creation time, fall back to commit creation time
created_at_str: str | None = entry.get("created_at")
Expand Down
158 changes: 158 additions & 0 deletions tests/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,26 @@ def test_provider_constraint_match() -> None:
assert str(candidate.version) == "1.2.2"


def test_provider_override_download_url() -> None:
with requests_mock.Mocker() as r:
r.get(
"https://pypi.org/simple/hydra-core/",
text=_hydra_core_simple_response,
)

provider = resolver.PyPIProvider(
override_download_url="https://server.test/hydr_core-{version}.tar.gz"
)
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
rslvr = resolvelib.Resolver(provider, reporter)

result = rslvr.resolve([Requirement("hydra-core")])
assert "hydra-core" in result.mapping

candidate = result.mapping["hydra-core"]
assert candidate.url == "https://server.test/hydr_core-1.3.2.tar.gz"


_ignore_platform_simple_response = """
<!DOCTYPE html>
<html>
Expand Down Expand Up @@ -715,6 +735,51 @@ def test_resolve_github() -> None:
)


@pytest.mark.parametrize(
["retrieve_method", "expected_url"],
[
(
resolver.RetrieveMethod.tarball,
"https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
),
(
resolver.RetrieveMethod.git_https,
"git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
),
(
resolver.RetrieveMethod.git_ssh,
"git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
),
],
)
def test_resolve_github_retrieve_method(
retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
with requests_mock.Mocker() as r:
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager",
text=_github_fromager_repo_response,
)
r.get(
"https://api.github.com:443/repos/python-wheel-build/fromager/tags",
text=_github_fromager_tag_response,
)

provider = resolver.GitHubTagProvider(
organization="python-wheel-build",
repo="fromager",
retrieve_method=retrieve_method,
)
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
rslvr = resolvelib.Resolver(provider, reporter)

result = rslvr.resolve([Requirement("fromager")])
assert "fromager" in result.mapping

candidate = result.mapping["fromager"]
assert candidate.url == expected_url


def test_github_constraint_mismatch() -> None:
constraint = constraints.Constraints()
constraint.add_constraint("fromager>=1.0")
Expand Down Expand Up @@ -922,6 +987,49 @@ def test_resolve_gitlab() -> None:
)


@pytest.mark.parametrize(
["retrieve_method", "expected_url"],
[
(
resolver.RetrieveMethod.tarball,
"https://gitlab.com/mirrors/github/decile-team/submodlib/-/archive/v0.0.3/submodlib-v0.0.3.tar.gz",
),
(
resolver.RetrieveMethod.git_https,
"git+https://gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
),
(
resolver.RetrieveMethod.git_ssh,
"git+ssh://git@gitlab.com/mirrors/github/decile-team/submodlib.git@v0.0.3",
),
],
)
def test_resolve_gitlab_retrieve_method(
retrieve_method: resolver.RetrieveMethod, expected_url: str
) -> None:
with requests_mock.Mocker() as r:
r.get(
"https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags",
text=_gitlab_submodlib_repo_response,
)

provider = resolver.GitLabTagProvider(
project_path="mirrors/github/decile-team/submodlib",
server_url="https://gitlab.com",
retrieve_method=retrieve_method,
)
reporter: resolvelib.BaseReporter = resolvelib.BaseReporter()
rslvr = resolvelib.Resolver(provider, reporter)

result = rslvr.resolve([Requirement("submodlib")])
assert "submodlib" in result.mapping

candidate = result.mapping["submodlib"]
assert str(candidate.version) == "0.0.3"

assert candidate.url == expected_url


def test_gitlab_constraint_mismatch() -> None:
constraint = constraints.Constraints()
constraint.add_constraint("submodlib>=1.0")
Expand Down Expand Up @@ -1107,3 +1215,53 @@ def custom_resolver_provider(*args, **kwargs):
assert "pypi.org" not in error_message.lower(), (
f"Error message incorrectly mentions PyPI when using GitHub resolver: {error_message}"
)


@pytest.mark.parametrize(
["download_url", "expected_method", "expected_url", "expected_ref"],
[
(
"https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
resolver.RetrieveMethod.tarball,
"https://api.github.com/repos/python-wheel-build/fromager/tarball/refs/tags/0.9.0",
None,
),
(
"git+https://github.com:443/python-wheel-build/fromager.git@0.9.0",
resolver.RetrieveMethod.git_https,
"https://github.com:443/python-wheel-build/fromager.git",
"0.9.0",
),
(
"git+ssh://git@github.com:443/python-wheel-build/fromager.git@0.9.0",
resolver.RetrieveMethod.git_ssh,
"ssh://git@github.com:443/python-wheel-build/fromager.git",
"0.9.0",
),
],
)
def test_retrieve_method_from_url(
download_url: str,
expected_method: resolver.RetrieveMethod,
expected_url: str,
expected_ref: str | None,
) -> None:
assert resolver.RetrieveMethod.from_url(download_url) == (
expected_method,
expected_url,
expected_ref,
)


@pytest.mark.parametrize(
["download_url"],
[
["http://insecure.test"],
["hg+ssh://mercurial.test"],
],
)
def test_retrieve_method_from_url_error(
download_url: str,
) -> None:
with pytest.raises(ValueError):
resolver.RetrieveMethod.from_url(download_url)
Loading