diff --git a/docs/config-reference.rst b/docs/config-reference.rst index d8090d6f..5c56d3d2 100644 --- a/docs/config-reference.rst +++ b/docs/config-reference.rst @@ -20,6 +20,8 @@ For example `flash_attn.yaml`. .. autopydantic_model:: fromager.packagesettings.ProjectOverride +.. autopydantic_model:: fromager.packagesettings.CreateFile + Global Settings --------------- diff --git a/docs/customization.md b/docs/customization.md index d5bead58..0dc9f63a 100644 --- a/docs/customization.md +++ b/docs/customization.md @@ -97,10 +97,48 @@ support templating. The only supported template variable are: ### Resolver dist The source distribution index server used by the package resolver can -be overriden for a particular package. The resolver can also be told +be overridden for a particular package. The resolver can also be told to whether include wheels or sdist sources while trying to resolve the package. Templating is not supported here. +#### Alternative resolver providers + +By default, fromager resolves package versions from PyPI. The `resolver_dist` +section also supports resolving versions from GitHub releases or GitLab +tags using the `provider` field. + +**GitHub provider:** + +```yaml +resolver_dist: + provider: github + organization: openssl + repo: openssl + tag_matcher: "openssl-(.*)" +``` + +The `organization` and `repo` fields are required for the GitHub provider. + +**GitLab provider:** + +```yaml +resolver_dist: + provider: gitlab + project_path: group/subgroup/project + server_url: https://gitlab.example.com + tag_matcher: "v(.*)" +``` + +For GitLab, you can use either `project_path` (which takes precedence) or +`organization` and `repo`. The `server_url` defaults to `https://gitlab.com`. + +**`tag_matcher`:** + +The `tag_matcher` field is a regular expression pattern used to extract +version numbers from git tags. It must contain exactly one capturing group. +For example, `"v(.*)"` matches tags like `v1.2.3` and extracts `1.2.3` as +the version. 
The `content` field supports template substitution with the `${version}`, `${version_base_version}`, and `${canonicalized_name}` variables.
+ +```yaml +create_files: + - path: requirements-dev.txt + content: "" + - path: src/mypackage/_version.py + content: | + __version__ = "${version}" +``` + +Paths must be relative and must not contain `..` components. + +### Rust vendor ordering + +By default, fromager vendors Rust crate dependencies after applying +patches. If your patches modify vendored `Cargo.lock` or `Cargo.toml` +files, you may need to vendor Rust crates first and then apply patches +on top of the vendored sources. + +Set `vendor_rust_before_patch` to `true` to change the ordering: + +```yaml +vendor_rust_before_patch: true +``` + +When enabled, `cargo vendor` runs before patches are applied instead of +after. The default is `false`. + ## Patching source The `--patches-dir` command line argument specifies a directory containing @@ -291,6 +403,8 @@ The `project_override` configures the `pyproject.toml` auto-fixer. It can automatically create a missing `pyproject.toml` or modify an existing file. Packages are matched by canonical name. +### Build requirements + - `remove_build_requires` is a list of package names. Any build requirement in the list is removed - `update_build_requires` a list of requirement specifiers. Existing specs @@ -321,6 +435,29 @@ Output: requires = ["setuptools>=68.0.0", "torch", "triton"] ``` +### Install dependencies + +The `project_override` section also supports modifying the `[project] +dependencies` (install requirements) in `pyproject.toml`: + +- `remove_install_requires` is a list of package names. Any install + dependency matching the name is removed. +- `update_install_requires` is a list of requirement specifiers. Existing + specs are replaced and missing specs are added. 
+ +```yaml +project_override: + remove_install_requires: + - easyocr + - rapidocr-onnxruntime + update_install_requires: + - "torch>=2.3.0" +``` + +This is useful for removing optional or platform-specific dependencies that +are not needed in your build environment, or for pinning specific versions +of install dependencies. + ## Override plugins Override plugins are documented in [the reference guide](hooks.rst). diff --git a/docs/how-tos/pyproject-overrides.rst b/docs/how-tos/pyproject-overrides.rst index 683f7776..7941200d 100644 --- a/docs/how-tos/pyproject-overrides.rst +++ b/docs/how-tos/pyproject-overrides.rst @@ -73,3 +73,22 @@ All existing numpy requirements are replaced by all the new numpy requirements: "packaging", "setuptools" ] + +Modifying Install Dependencies +------------------------------ + +The same matching behavior applies to install dependencies (``[project] +dependencies``) using ``update_install_requires`` and +``remove_install_requires``. + +.. code-block:: yaml + + project_override: + remove_install_requires: + - easyocr + - rapidocr-onnxruntime + update_install_requires: + - "torch>=2.3.0" + +This removes the ``easyocr`` and ``rapidocr-onnxruntime`` packages from +install dependencies and adds or updates the ``torch`` requirement. diff --git a/src/fromager/packagesettings.py b/src/fromager/packagesettings.py index 7ae96c22..69587da4 100644 --- a/src/fromager/packagesettings.py +++ b/src/fromager/packagesettings.py @@ -201,12 +201,79 @@ class ResolverDist(pydantic.BaseModel): .. versionadded:: 0.70 """ + provider: str | None = None + """Resolver provider type: 'pypi' (default), 'github', 'gitlab' + + When not set, defaults to PyPI provider behavior. + + .. 
versionadded:: 0.XX + """ + + organization: str | None = None + """GitHub organization or GitLab group (required for github/gitlab providers)""" + + repo: str | None = None + """GitHub/GitLab repository name (required for github/gitlab providers)""" + + project_path: str | None = None + """GitLab project path (e.g., 'group/subgroup/project') + + For GitLab, this takes precedence over organization/repo. + + .. versionadded:: 0.XX + """ + + server_url: str | None = None + """GitLab server URL (default: https://gitlab.com) + + Only used with gitlab provider. + + .. versionadded:: 0.XX + """ + + tag_matcher: str | None = None + """Regex pattern for matching version tags + + Applied to git tags to extract version numbers. Must contain + exactly one capturing group that captures the version string. + Example: ``v(\\d+\\.\\d+\\.\\d+)`` + + .. versionadded:: 0.XX + """ + @pydantic.model_validator(mode="after") - def validate_ignore_platform(self) -> typing.Self: + def validate_resolver_dist(self) -> typing.Self: + """Validate resolver_dist configuration.""" if self.ignore_platform and not self.include_wheels: raise ValueError( "'ignore_platforms' has no effect without 'include_wheels'" ) + if self.provider == "github": + if not self.organization or not self.repo: + raise ValueError("GitHub provider requires 'organization' and 'repo'") + elif self.provider == "gitlab": + if not self.project_path and not (self.organization and self.repo): + raise ValueError( + "GitLab provider requires 'project_path' or " + "'organization' and 'repo'" + ) + elif self.provider is not None and self.provider != "pypi": + raise ValueError( + f"Unknown provider: {self.provider!r}. 
" + f"Supported: 'pypi', 'github', 'gitlab'" + ) + if self.tag_matcher is not None: + try: + pattern = re.compile(self.tag_matcher) + except re.error as err: + raise ValueError( + f"Invalid tag_matcher regex: {self.tag_matcher!r}: {err}" + ) from err + if pattern.groups != 1: + raise ValueError( + f"tag_matcher must have exactly 1 capturing group, " + f"got {pattern.groups}" + ) return self @@ -316,9 +383,24 @@ class ProjectOverride(pydantic.BaseModel): ``tomlkit.loads(dist(pkgname).read_text("fromager-build-settings"))``. """ + update_install_requires: list[str] = Field(default_factory=list) + """Add / update requirements in pyproject.toml ``[project] dependencies``""" + + remove_install_requires: list[Package] = Field(default_factory=list) + """Remove requirements from pyproject.toml ``[project] dependencies``""" + @pydantic.field_validator("update_build_requires") @classmethod def validate_update_build_requires(cls, v: list[str]) -> list[str]: + """Validate that each entry is a valid requirement string.""" + for reqstr in v: + Requirement(reqstr) + return v + + @pydantic.field_validator("update_install_requires") + @classmethod + def validate_update_install_requires(cls, v: list[str]) -> list[str]: + """Validate that each entry is a valid requirement string.""" for reqstr in v: Requirement(reqstr) return v @@ -388,6 +470,36 @@ class GitOptions(pydantic.BaseModel): _DictStrAny = dict[str, typing.Any] +class CreateFile(pydantic.BaseModel): + """A file to create in the source tree + + :: + + path: src/mypackage/version.py + content: | + __version__ = "${version}" + """ + + model_config = MODEL_CONFIG + + path: str + """Relative path within the source tree""" + + content: str = "" + """File content (supports template substitution with ${version}, etc.)""" + + @pydantic.field_validator("path") + @classmethod + def validate_path(cls, v: str) -> str: + """Reject absolute paths and path traversal components.""" + p = pathlib.Path(v) + if p.is_absolute(): + raise 
ValueError(f"{v!r} is not a relative path") + if ".." in p.parts: + raise ValueError(f"{v!r} must not contain '..' components") + return v + + class PackageSettings(pydantic.BaseModel): """Package settings @@ -450,6 +562,15 @@ class PackageSettings(pydantic.BaseModel): - "-Dsystem-qhull=true" """ + create_files: list[CreateFile] = Field(default_factory=list) + """Files to create in the source tree before building + + Useful for adding missing __init__.py, version.py, or other files + that some sdists are missing. + + .. versionadded:: 0.XX + """ + env: EnvVars = Field(default_factory=dict) """Common env var for all variants""" @@ -468,6 +589,16 @@ class PackageSettings(pydantic.BaseModel): project_override: ProjectOverride = Field(default_factory=ProjectOverride) """Patch project settings""" + vendor_rust_before_patch: bool = False + """Vendor Rust crates before applying patches instead of after + + When True, ``cargo vendor`` runs before patches are applied. + This is useful when patches modify vendored Cargo.lock or + Cargo.toml files. + + .. 
versionadded:: 0.XX + """ + variants: Mapping[Variant, VariantInfo] = Field(default_factory=dict) """Variant configuration""" @@ -825,6 +956,36 @@ def resolver_ignore_platform(self) -> bool: """Ignore the platform when resolving with wheels?""" return self._ps.resolver_dist.ignore_platform + @property + def resolver_provider(self) -> str | None: + """Resolver provider type.""" + return self._ps.resolver_dist.provider + + @property + def resolver_organization(self) -> str | None: + """GitHub/GitLab organization.""" + return self._ps.resolver_dist.organization + + @property + def resolver_repo(self) -> str | None: + """GitHub/GitLab repository name.""" + return self._ps.resolver_dist.repo + + @property + def resolver_project_path(self) -> str | None: + """GitLab project path.""" + return self._ps.resolver_dist.project_path + + @property + def resolver_server_url(self) -> str | None: + """GitLab server URL.""" + return self._ps.resolver_dist.server_url + + @property + def resolver_tag_matcher(self) -> str | None: + """Tag matcher regex pattern.""" + return self._ps.resolver_dist.tag_matcher + @property def use_pypi_org_metadata(self) -> bool: """Can use metadata from pypi.org JSON / Simple API? @@ -883,13 +1044,16 @@ def get_extra_environ( *, template_env: dict[str, str] | None = None, build_env: build_environment.BuildEnvironment | None = None, + version: Version | None = None, ) -> dict[str, str]: """Get extra environment variables for a variant 1. parallel jobs: ``MAKEFLAGS``, ``MAX_JOBS``, ``CMAKE_BUILD_PARALLEL_LEVEL`` 2. PATH and VIRTUAL_ENV from ``build_env`` (if given) - 3. package's env settings - 4. package variant's env settings + 3. version template variables (``${version}``, ``${version_base_version}``, + ``${version_post}``) when *version* is given + 4. package's env settings + 5. package variant's env settings `template_env` defaults to `os.environ`. 
""" @@ -916,6 +1080,14 @@ def get_extra_environ( template_env.update(venv_environ) extra_environ.update(venv_environ) + # add version template variables if version is provided; + # use setdefault so actual environment variables take precedence + if version is not None: + template_env.setdefault("version", str(version)) + template_env.setdefault("version_base_version", version.base_version) + post_str = str(version.post) if version.post is not None else "" + template_env.setdefault("version_post", post_str) + # chain entries so variant entries can reference general entries entries = list(self._ps.env.items()) vi = self._ps.variants.get(self.variant) @@ -973,6 +1145,16 @@ def git_options(self) -> GitOptions: """Git repository cloning options""" return self._ps.git_options + @property + def create_files(self) -> list[CreateFile]: + """Files to create in the source tree.""" + return list(self._ps.create_files) + + @property + def vendor_rust_before_patch(self) -> bool: + """Should Rust crates be vendored before patching?""" + return self._ps.vendor_rust_before_patch + @property def project_override(self) -> ProjectOverride: return self._ps.project_override @@ -1226,7 +1408,7 @@ def get_extra_environ( ) -> dict[str, str]: """Get extra environment variables from settings and update hook""" pbi = ctx.package_build_info(req) - extra_environ = pbi.get_extra_environ(build_env=build_env) + extra_environ = pbi.get_extra_environ(build_env=build_env, version=version) overrides.find_and_invoke( req.name, "update_extra_environ", diff --git a/src/fromager/pyproject.py b/src/fromager/pyproject.py index 2277bc29..3d7aa90b 100644 --- a/src/fromager/pyproject.py +++ b/src/fromager/pyproject.py @@ -22,20 +22,30 @@ BUILD_SYSTEM = "build-system" BUILD_BACKEND = "build-backend" BUILD_REQUIRES = "requires" +PROJECT = "project" +DEPENDENCIES = "dependencies" class PyprojectFix: """Auto-fixer for pyproject.toml settings - add missing pyproject.toml - - add or update `[build-system] requires` 
+ - add or update ``[build-system] requires`` + - add, update, or remove ``[project] dependencies`` Requirements in `update_build_requires` are added to - `[build-system] requires`. If a requirement name matches an existing + ``[build-system] requires``. If a requirement name matches an existing name, then the requirement is replaced. Requirements in `remove_build_requires` are removed from - `[build-system] requires`. + ``[build-system] requires``. + + Requirements in `update_install_requires` are added to + ``[project] dependencies``. If a requirement name matches an existing + name, then the requirement is replaced. + + Requirements in `remove_install_requires` are removed from + ``[project] dependencies``. """ def __init__( @@ -45,18 +55,24 @@ def __init__( build_dir: pathlib.Path, update_build_requires: list[str], remove_build_requires: list[NormalizedName], + update_install_requires: list[str] | None = None, + remove_install_requires: list[NormalizedName] | None = None, ) -> None: self.req = req self.build_dir = build_dir self.update_requirements = update_build_requires self.remove_requirements = remove_build_requires + self.update_install_requirements = update_install_requires or [] + self.remove_install_requirements = remove_install_requires or [] self.pyproject_toml = self.build_dir / "pyproject.toml" self.setup_py = self.build_dir / "setup.py" def run(self) -> None: + """Load, fix, and save pyproject.toml.""" doc = self._load() build_system = self._default_build_system(doc) self._update_build_requires(build_system) + self._update_install_requires(doc) logger.debug( "pyproject.toml %s: %s=%r, %s=%r", BUILD_SYSTEM, @@ -126,18 +142,70 @@ def _update_build_requires(self, build_system: TomlDict) -> None: new_requires, ) + def _update_install_requires(self, doc: tomlkit.TOMLDocument) -> None: + """Update ``[project] dependencies``.""" + if ( + not self.update_install_requirements + and not self.remove_install_requirements + ): + return + + project: TomlDict | 
# Build a map of canonicalized name -> list of Requirement objects
pbi.build_dir(sdist_root_dir) PyprojectFix( @@ -145,6 +213,8 @@ def apply_project_override( build_dir=build_dir, update_build_requires=update_build_requires, remove_build_requires=remove_build_requires, + update_install_requires=update_install_requires, + remove_install_requires=remove_install_requires, ).run() else: logger.debug("no project_override") diff --git a/src/fromager/sources.py b/src/fromager/sources.py index 3fc93e93..594ee297 100644 --- a/src/fromager/sources.py +++ b/src/fromager/sources.py @@ -4,6 +4,7 @@ import json import logging import pathlib +import re import shutil import tarfile import typing @@ -43,6 +44,7 @@ def get_source_type(ctx: context.WorkContext, req: Requirement) -> SourceType: + """Determine how a requirement's source will be obtained.""" source_type = SourceType.SDIST if req.url: return SourceType.GIT @@ -52,6 +54,7 @@ def get_source_type(ctx: context.WorkContext, req: Requirement) -> SourceType: or overrides.find_override_method(req.name, "resolve_source") or overrides.find_override_method(req.name, "get_resolver_provider") or pbi.download_source_url(resolve_template=False) + or pbi.resolver_provider ): source_type = SourceType.OVERRIDE return source_type @@ -176,11 +179,25 @@ def default_resolve_source( sdist_server_url: str, req_type: RequirementType | None = None, ) -> tuple[str, Version]: - "Return URL to source and its version." + """Return URL to source and its version. + Checks for a YAML-configured resolver provider first. If one is + configured (and is not 'pypi'), the appropriate provider is created + and used. Otherwise falls back to the standard PyPI resolver. 
+ """ pbi = ctx.package_build_info(req) override_sdist_server_url = pbi.resolver_sdist_server_url(sdist_server_url) + # Check if a specific provider is configured in YAML + provider_type = pbi.resolver_provider + if provider_type and provider_type != "pypi": + url, version = _resolve_with_configured_provider( + ctx=ctx, + req=req, + pbi=pbi, + ) + return url, version + url, version = resolver.resolve( ctx=ctx, req=req, @@ -193,6 +210,47 @@ def default_resolve_source( return url, version +def _resolve_with_configured_provider( + *, + ctx: context.WorkContext, + req: Requirement, + pbi: packagesettings.PackageBuildInfo, +) -> tuple[str, Version]: + """Create a resolver provider from YAML configuration and resolve.""" + provider_type = pbi.resolver_provider + tag_matcher_pattern = pbi.resolver_tag_matcher + matcher: re.Pattern[str] | None = None + if tag_matcher_pattern: + matcher = re.compile(tag_matcher_pattern) + + provider: resolver.BaseProvider + if provider_type == "github": + assert pbi.resolver_organization is not None + assert pbi.resolver_repo is not None + provider = resolver.GitHubTagProvider( + organization=pbi.resolver_organization, + repo=pbi.resolver_repo, + constraints=ctx.constraints, + matcher=matcher, + ) + elif provider_type == "gitlab": + project_path = pbi.resolver_project_path + if not project_path: + assert pbi.resolver_organization is not None + assert pbi.resolver_repo is not None + project_path = f"{pbi.resolver_organization}/{pbi.resolver_repo}" + provider = resolver.GitLabTagProvider( + project_path=project_path, + server_url=pbi.resolver_server_url or "https://gitlab.com", + constraints=ctx.constraints, + matcher=matcher, + ) + else: + raise ValueError(f"Unknown provider type: {provider_type!r}") + + return resolver.resolve_from_provider(provider, req) + + def default_download_source( ctx: context.WorkContext, req: Requirement, @@ -571,19 +629,64 @@ def prepare_new_source( ) -> None: """Default steps for new sources + - ensure PKG-INFO 
supports template substitution with ``${version}``, ``${version_base_version}``, and ``${canonicalized_name}``.
+ """ + pbi = ctx.package_build_info(req) + files = pbi.create_files + if not files: + return + + for file_spec in files: + file_path = source_root_dir / file_spec.path + template_env: dict[str, str] = { + "version": str(version), + "version_base_version": version.base_version, + "canonicalized_name": str(pbi.package), + } + content = packagesettings.substitute_template(file_spec.content, template_env) + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.write_text(content) + logger.info("created file %s", file_path) @metrics.timeit(description="build sdist") diff --git a/tests/test_packagesettings.py b/tests/test_packagesettings.py index 07bee01a..ceca15ba 100644 --- a/tests/test_packagesettings.py +++ b/tests/test_packagesettings.py @@ -12,6 +12,7 @@ from fromager.packagesettings import ( Annotations, BuildDirectory, + CreateFile, EnvVars, GitOptions, Package, @@ -51,6 +52,7 @@ ], "cmake.define.BLA_VENDOR": "OpenBLAS", }, + "create_files": [], "download_source": { "destination_filename": "${canonicalized_name}-${version}.tar.gz", "url": "https://egg.test/${canonicalized_name}/v${version}.tar.gz", @@ -73,6 +75,8 @@ "remove_build_requires": ["cmake"], "update_build_requires": ["setuptools>=68.0.0", "torch"], "requires_external": ["openssl-libs"], + "update_install_requires": [], + "remove_install_requires": [], }, "resolver_dist": { "include_sdists": True, @@ -80,6 +84,12 @@ "sdist_server_url": "https://sdist.test/egg", "ignore_platform": True, "use_pypi_org_metadata": True, + "provider": None, + "organization": None, + "repo": None, + "project_path": None, + "server_url": None, + "tag_matcher": None, }, "variants": { "cpu": { @@ -105,6 +115,7 @@ "pre_built": False, }, }, + "vendor_rust_before_patch": False, } EMPTY_EXPECTED: dict[str, typing.Any] = { @@ -119,6 +130,7 @@ }, "changelog": {}, "config_settings": {}, + "create_files": [], "env": {}, "download_source": { "url": None, @@ -133,6 +145,8 @@ "remove_build_requires": [], "update_build_requires": 
[], "requires_external": [], + "update_install_requires": [], + "remove_install_requires": [], }, "resolver_dist": { "sdist_server_url": None, @@ -140,8 +154,15 @@ "include_wheels": False, "ignore_platform": False, "use_pypi_org_metadata": None, + "provider": None, + "organization": None, + "repo": None, + "project_path": None, + "server_url": None, + "tag_matcher": None, }, "variants": {}, + "vendor_rust_before_patch": False, } PREBUILT_PKG_EXPECTED: dict[str, typing.Any] = { @@ -158,6 +179,7 @@ Version("1.0.1"): ["onboard"], }, "config_settings": {}, + "create_files": [], "env": {}, "download_source": { "url": None, @@ -172,6 +194,8 @@ "remove_build_requires": [], "update_build_requires": [], "requires_external": [], + "update_install_requires": [], + "remove_install_requires": [], }, "resolver_dist": { "sdist_server_url": None, @@ -179,6 +203,12 @@ "include_wheels": False, "ignore_platform": False, "use_pypi_org_metadata": None, + "provider": None, + "organization": None, + "repo": None, + "project_path": None, + "server_url": None, + "tag_matcher": None, }, "variants": { "cpu": { @@ -188,6 +218,7 @@ "wheel_server_url": None, }, }, + "vendor_rust_before_patch": False, } @@ -805,3 +836,346 @@ def test_use_pypi_org_metadata(testdata_context: context.WorkContext) -> None: "somepackage_without_customization" ) assert pbi.use_pypi_org_metadata + + +@patch("fromager.packagesettings.get_cpu_count", return_value=1) +@patch("fromager.packagesettings.get_available_memory_gib", return_value=8.0) +def test_get_extra_environ_version_substitution( + _get_mem: Mock, + _get_cpu: Mock, +) -> None: + """Verify ${version} template vars are substituted in env settings.""" + settings_yaml = """ +env: + MY_VERSION: "${version}" + MY_BASE: "${version_base_version}" + MY_POST: "${version_post}" +""" + from fromager.packagesettings import Settings, SettingsFile + + ps = PackageSettings.from_string("version-pkg", settings_yaml) + s = Settings( + settings=SettingsFile(), + 
package_settings=[ps], + variant="cpu", + patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = s.package_build_info("version-pkg") + result = pbi.get_extra_environ(template_env={}, version=Version("1.2.3")) + assert result["MY_VERSION"] == "1.2.3" + assert result["MY_BASE"] == "1.2.3" + assert result["MY_POST"] == "" + + +@patch("fromager.packagesettings.get_cpu_count", return_value=1) +@patch("fromager.packagesettings.get_available_memory_gib", return_value=8.0) +def test_get_extra_environ_version_post_release( + _get_mem: Mock, + _get_cpu: Mock, +) -> None: + """Verify ${version_base_version} and ${version_post} with post-release.""" + settings_yaml = """ +env: + MY_VERSION: "${version}" + MY_BASE: "${version_base_version}" + MY_POST: "${version_post}" +""" + from fromager.packagesettings import Settings, SettingsFile + + ps = PackageSettings.from_string("version-pkg", settings_yaml) + s = Settings( + settings=SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = s.package_build_info("version-pkg") + result = pbi.get_extra_environ(template_env={}, version=Version("1.2.3.post1")) + assert result["MY_VERSION"] == "1.2.3.post1" + assert result["MY_BASE"] == "1.2.3" + assert result["MY_POST"] == "1" + + +@patch("fromager.packagesettings.get_cpu_count", return_value=1) +@patch("fromager.packagesettings.get_available_memory_gib", return_value=8.0) +def test_get_extra_environ_version_none_backward_compat( + _get_mem: Mock, + _get_cpu: Mock, + testdata_context: context.WorkContext, +) -> None: + """Verify backward compatibility when version is None.""" + testdata_context.settings.max_jobs = 1 + pbi = testdata_context.settings.package_build_info(TEST_EMPTY_PKG) + result = pbi.get_extra_environ(template_env={}, version=None) + assert "version" not in result + assert "version_base_version" not in result + assert "version_post" not in result + + +@patch("fromager.packagesettings.get_cpu_count", 
return_value=1) +@patch("fromager.packagesettings.get_available_memory_gib", return_value=8.0) +def test_get_extra_environ_version_env_override( + _get_mem: Mock, + _get_cpu: Mock, +) -> None: + """Verify that actual env variables named 'version' take precedence.""" + settings_yaml = """ +env: + MY_VERSION: "${version}" +""" + from fromager.packagesettings import Settings, SettingsFile + + ps = PackageSettings.from_string("version-pkg", settings_yaml) + s = Settings( + settings=SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = s.package_build_info("version-pkg") + result = pbi.get_extra_environ( + template_env={"version": "from-env"}, + version=Version("1.2.3"), + ) + assert result["MY_VERSION"] == "from-env" + + +def test_create_file_relative_path() -> None: + """Verify CreateFile accepts relative paths.""" + cf = CreateFile(path="src/mypackage/__init__.py", content="") + assert cf.path == "src/mypackage/__init__.py" + assert cf.content == "" + + +def test_create_file_rejects_absolute_path() -> None: + """Verify CreateFile rejects absolute paths.""" + with pytest.raises(pydantic.ValidationError, match="is not a relative path"): + CreateFile(path="/etc/passwd", content="bad") + + +def test_create_file_rejects_path_traversal() -> None: + """Verify CreateFile rejects paths with '..' 
components.""" + with pytest.raises(pydantic.ValidationError, match="must not contain"): + CreateFile(path="../../../etc/passwd", content="bad") + + with pytest.raises(pydantic.ValidationError, match="must not contain"): + CreateFile(path="src/../../etc/passwd", content="bad") + + +def test_create_file_with_content() -> None: + """Verify CreateFile stores content.""" + cf = CreateFile(path="version.py", content='__version__ = "${version}"') + assert cf.content == '__version__ = "${version}"' + + +def test_vendor_rust_before_patch_default() -> None: + """Verify vendor_rust_before_patch defaults to False.""" + settings = PackageSettings.from_default("test-pkg") + assert settings.vendor_rust_before_patch is False + + +def test_vendor_rust_before_patch_from_yaml() -> None: + """Verify vendor_rust_before_patch can be set via YAML.""" + data = "vendor_rust_before_patch: true\n" + settings = PackageSettings.from_string("test-pkg", data) + assert settings.vendor_rust_before_patch is True + + +def test_create_files_from_yaml() -> None: + """Verify create_files can be parsed from YAML.""" + data = """\ +create_files: + - path: src/mypackage/__init__.py + content: "" + - path: src/mypackage/version.py + content: | + __version__ = "${version}" +""" + settings = PackageSettings.from_string("test-pkg", data) + assert len(settings.create_files) == 2 + assert settings.create_files[0].path == "src/mypackage/__init__.py" + assert settings.create_files[0].content == "" + assert settings.create_files[1].path == "src/mypackage/version.py" + assert '__version__ = "${version}"' in settings.create_files[1].content + + +def test_pbi_vendor_rust_before_patch() -> None: + """Verify PackageBuildInfo exposes vendor_rust_before_patch.""" + from fromager.packagesettings import Settings, SettingsFile + + data = "vendor_rust_before_patch: true\n" + ps = PackageSettings.from_string("test-pkg", data) + settings = Settings( + settings=SettingsFile(), + package_settings=[ps], + variant="cpu", + 
patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = settings.package_build_info("test-pkg") + assert pbi.vendor_rust_before_patch is True + + +def test_pbi_create_files() -> None: + """Verify PackageBuildInfo exposes create_files.""" + from fromager.packagesettings import Settings, SettingsFile + + data = """\ +create_files: + - path: src/__init__.py + content: "" +""" + ps = PackageSettings.from_string("test-pkg", data) + settings = Settings( + settings=SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = settings.package_build_info("test-pkg") + assert len(pbi.create_files) == 1 + assert pbi.create_files[0].path == "src/__init__.py" + + +def test_resolver_dist_github_provider() -> None: + """Verify ResolverDist accepts valid github provider config.""" + rd = ResolverDist(provider="github", organization="myorg", repo="myrepo") + assert rd.provider == "github" + assert rd.organization == "myorg" + assert rd.repo == "myrepo" + + +def test_resolver_dist_github_provider_missing_fields() -> None: + """Verify github provider requires organization and repo.""" + with pytest.raises(pydantic.ValidationError, match=r"organization.*repo"): + ResolverDist(provider="github", organization="myorg") + with pytest.raises(pydantic.ValidationError, match=r"organization.*repo"): + ResolverDist(provider="github", repo="myrepo") + with pytest.raises(pydantic.ValidationError, match=r"organization.*repo"): + ResolverDist(provider="github") + + +def test_resolver_dist_gitlab_provider_with_project_path() -> None: + """Verify GitLab provider with project_path.""" + rd = ResolverDist(provider="gitlab", project_path="group/subgroup/project") + assert rd.provider == "gitlab" + assert rd.project_path == "group/subgroup/project" + + +def test_resolver_dist_gitlab_provider_with_org_repo() -> None: + """Verify GitLab provider with organization and repo.""" + rd = ResolverDist(provider="gitlab", organization="myorg", 
repo="myrepo") + assert rd.provider == "gitlab" + assert rd.organization == "myorg" + assert rd.repo == "myrepo" + + +def test_resolver_dist_gitlab_provider_missing_fields() -> None: + """Verify gitlab provider requires project_path or organization+repo.""" + with pytest.raises(pydantic.ValidationError, match="project_path"): + ResolverDist(provider="gitlab") + with pytest.raises(pydantic.ValidationError, match="project_path"): + ResolverDist(provider="gitlab", organization="myorg") + + +def test_resolver_dist_unknown_provider() -> None: + """Verify unknown provider names are rejected.""" + with pytest.raises(pydantic.ValidationError, match="Unknown provider"): + ResolverDist(provider="unknown") + + +def test_resolver_dist_pypi_provider() -> None: + """Verify pypi provider is accepted (explicit or default).""" + rd = ResolverDist(provider="pypi") + assert rd.provider == "pypi" + rd_default = ResolverDist() + assert rd_default.provider is None + + +def test_resolver_dist_tag_matcher_valid() -> None: + """Verify valid tag_matcher regex with one capturing group.""" + rd = ResolverDist( + provider="github", + organization="org", + repo="repo", + tag_matcher=r"v(\d+\.\d+\.\d+)", + ) + assert rd.tag_matcher == r"v(\d+\.\d+\.\d+)" + + +def test_resolver_dist_tag_matcher_invalid_regex() -> None: + """Verify invalid regex in tag_matcher is rejected.""" + with pytest.raises(pydantic.ValidationError, match="Invalid tag_matcher regex"): + ResolverDist( + provider="github", + organization="org", + repo="repo", + tag_matcher=r"v(\d+", + ) + + +def test_resolver_dist_tag_matcher_wrong_groups() -> None: + """Verify tag_matcher with zero or multiple groups is rejected.""" + with pytest.raises(pydantic.ValidationError, match="exactly 1 capturing group"): + ResolverDist( + provider="github", + organization="org", + repo="repo", + tag_matcher=r"v\d+\.\d+\.\d+", + ) + with pytest.raises(pydantic.ValidationError, match="exactly 1 capturing group"): + ResolverDist( + provider="github", + 
organization="org", + repo="repo", + tag_matcher=r"v(\d+)\.(\d+)", + ) + + +def test_resolver_dist_from_yaml() -> None: + """Verify ResolverDist can be parsed from YAML via PackageSettings.""" + yaml_data = """ +resolver_dist: + provider: github + organization: openssl + repo: openssl + tag_matcher: "openssl-(\\\\d+\\\\.\\\\d+\\\\.\\\\d+)" +""" + ps = PackageSettings.from_string("test-resolver-pkg", yaml_data) + assert ps.resolver_dist.provider == "github" + assert ps.resolver_dist.organization == "openssl" + assert ps.resolver_dist.repo == "openssl" + + +def test_pbi_resolver_properties() -> None: + """Verify PackageBuildInfo exposes resolver properties.""" + from fromager.packagesettings import Settings, SettingsFile + + ps = PackageSettings.from_string( + "resolver-test", + """ +resolver_dist: + provider: github + organization: myorg + repo: myrepo + tag_matcher: "v(.*)" +""", + ) + settings = Settings( + settings=SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=pathlib.Path("/tmp"), + max_jobs=1, + ) + pbi = settings.package_build_info("resolver-test") + assert pbi.resolver_provider == "github" + assert pbi.resolver_organization == "myorg" + assert pbi.resolver_repo == "myrepo" + assert pbi.resolver_project_path is None + assert pbi.resolver_server_url is None + assert pbi.resolver_tag_matcher == "v(.*)" diff --git a/tests/test_pyproject.py b/tests/test_pyproject.py index a242da1e..1557a517 100644 --- a/tests/test_pyproject.py +++ b/tests/test_pyproject.py @@ -177,3 +177,147 @@ def test_pyproject_override_multiple_requires(tmp_path: pathlib.Path) -> None: "setuptools", ] ] + + +# --- Tests for install_requires (project dependencies) --- + +PYPROJECT_WITH_DEPS = """ +[build-system] +requires = ["setuptools"] + +[project] +name = "testproject" +dependencies = [ + "requests>=2.28", + "numpy>=1.24", + "nvidia-cublas-cu12", + "nvidia-cuda-runtime-cu12", + "torch>=2.3.0", +] +""" + + +def test_pyproject_remove_install_requires(tmp_path: 
pathlib.Path) -> None: + """Verify dependencies can be removed from [project] dependencies.""" + tmp_path.joinpath("pyproject.toml").write_text(PYPROJECT_WITH_DEPS) + req = Requirement("testproject==1.0.0") + fixer = pyproject.PyprojectFix( + req, + build_dir=tmp_path, + update_build_requires=[], + remove_build_requires=[], + remove_install_requires=[ + canonicalize_name("nvidia-cublas-cu12"), + canonicalize_name("nvidia-cuda-runtime-cu12"), + ], + ) + fixer.run() + doc = tomlkit.loads(tmp_path.joinpath("pyproject.toml").read_text()) + project = dict(doc["project"].items()) # type: ignore[union-attr] + deps: list[str] = list(project["dependencies"]) + assert "nvidia-cublas-cu12" not in deps + assert "nvidia-cuda-runtime-cu12" not in deps + assert str(Requirement("numpy>=1.24")) in deps + assert str(Requirement("requests>=2.28")) in deps + assert str(Requirement("torch>=2.3.0")) in deps + + +def test_pyproject_update_install_requires(tmp_path: pathlib.Path) -> None: + """Verify dependencies can be added/updated in [project] dependencies.""" + tmp_path.joinpath("pyproject.toml").write_text(PYPROJECT_WITH_DEPS) + req = Requirement("testproject==1.0.0") + fixer = pyproject.PyprojectFix( + req, + build_dir=tmp_path, + update_build_requires=[], + remove_build_requires=[], + update_install_requires=["torch>=2.4.0", "click>=8.0"], + ) + fixer.run() + doc = tomlkit.loads(tmp_path.joinpath("pyproject.toml").read_text()) + project = dict(doc["project"].items()) # type: ignore[union-attr] + deps: list[str] = list(project["dependencies"]) + # torch should be replaced, click should be added + assert str(Requirement("torch>=2.4.0")) in deps + assert str(Requirement("click>=8.0")) in deps + # original torch version should be gone + assert str(Requirement("torch>=2.3.0")) not in deps + + +def test_pyproject_install_requires_no_project_section( + tmp_path: pathlib.Path, +) -> None: + """Verify install_requires changes are skipped when no [project] section.""" + 
tmp_path.joinpath("pyproject.toml").write_text( + textwrap.dedent(""" + [build-system] + requires = ["setuptools"] + """) + ) + req = Requirement("testproject==1.0.0") + fixer = pyproject.PyprojectFix( + req, + build_dir=tmp_path, + update_build_requires=[], + remove_build_requires=[], + remove_install_requires=[canonicalize_name("nvidia-cublas-cu12")], + ) + fixer.run() + doc = tomlkit.loads(tmp_path.joinpath("pyproject.toml").read_text()) + assert "project" not in doc + + +def test_pyproject_install_requires_no_dependencies_key( + tmp_path: pathlib.Path, +) -> None: + """Verify install_requires changes are skipped when no dependencies key.""" + tmp_path.joinpath("pyproject.toml").write_text( + textwrap.dedent(""" + [build-system] + requires = ["setuptools"] + + [project] + name = "testproject" + """) + ) + req = Requirement("testproject==1.0.0") + fixer = pyproject.PyprojectFix( + req, + build_dir=tmp_path, + update_build_requires=[], + remove_build_requires=[], + remove_install_requires=[canonicalize_name("nvidia-cublas-cu12")], + ) + fixer.run() + doc = tomlkit.loads(tmp_path.joinpath("pyproject.toml").read_text()) + project = dict(doc["project"].items()) # type: ignore[union-attr] + assert "dependencies" not in project + + +def test_pyproject_install_requires_remove_and_update( + tmp_path: pathlib.Path, +) -> None: + """Verify combination of remove + update on [project] dependencies.""" + tmp_path.joinpath("pyproject.toml").write_text(PYPROJECT_WITH_DEPS) + req = Requirement("testproject==1.0.0") + fixer = pyproject.PyprojectFix( + req, + build_dir=tmp_path, + update_build_requires=[], + remove_build_requires=[], + remove_install_requires=[ + canonicalize_name("nvidia-cublas-cu12"), + canonicalize_name("nvidia-cuda-runtime-cu12"), + ], + update_install_requires=["torch>=2.4.0"], + ) + fixer.run() + doc = tomlkit.loads(tmp_path.joinpath("pyproject.toml").read_text()) + project = dict(doc["project"].items()) # type: ignore[union-attr] + deps: list[str] = 
list(project["dependencies"]) + assert "nvidia-cublas-cu12" not in deps + assert "nvidia-cuda-runtime-cu12" not in deps + assert str(Requirement("torch>=2.4.0")) in deps + assert str(Requirement("torch>=2.3.0")) not in deps + assert str(Requirement("numpy>=1.24")) in deps + assert str(Requirement("requests>=2.28")) in deps diff --git a/tests/test_sources.py b/tests/test_sources.py index 9e216a25..275b0129 100644 --- a/tests/test_sources.py +++ b/tests/test_sources.py @@ -7,6 +7,8 @@ from packaging.version import Version from fromager import context, packagesettings, resolver, sources +from fromager.packagesettings import CreateFile +from fromager.requirements_file import SourceType @patch("fromager.sources.download_url") @@ -218,3 +220,407 @@ def test_validate_sdist_file( else: with pytest.raises(ValueError): sources.validate_sdist_filename(req, version, sdist_file) + + +@patch("fromager.resolver.resolve_from_provider") +@patch("fromager.resolver.GitHubTagProvider") +def test_resolve_with_configured_github_provider( + mock_github_cls: Mock, + mock_resolve_from_provider: Mock, + tmp_context: context.WorkContext, +) -> None: + """Verify _resolve_with_configured_provider creates GitHubTagProvider.""" + mock_provider = Mock() + mock_github_cls.return_value = mock_provider + mock_resolve_from_provider.return_value = ( + "https://github.com/org/repo/archive/v1.0.tar.gz", + Version("1.0"), + ) + + ps = packagesettings.PackageSettings.from_string( + "github-pkg", + """ +resolver_dist: + provider: github + organization: myorg + repo: myrepo + tag_matcher: "v(.*)" +""", + ) + settings = packagesettings.Settings( + settings=packagesettings.SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=tmp_context.settings.patches_dir, + max_jobs=1, + ) + tmp_context.settings = settings + + req = Requirement("github-pkg==1.0") + pbi = tmp_context.package_build_info(req) + url, version = sources._resolve_with_configured_provider( + ctx=tmp_context, + req=req, + 
pbi=pbi, + ) + assert url == "https://github.com/org/repo/archive/v1.0.tar.gz" + assert version == Version("1.0") + mock_github_cls.assert_called_once() + mock_resolve_from_provider.assert_called_once_with(mock_provider, req) + + +@patch("fromager.resolver.resolve_from_provider") +@patch("fromager.resolver.GitLabTagProvider") +def test_resolve_with_configured_gitlab_provider( + mock_gitlab_cls: Mock, + mock_resolve_from_provider: Mock, + tmp_context: context.WorkContext, +) -> None: + """Verify _resolve_with_configured_provider creates GitLabTagProvider.""" + mock_provider = Mock() + mock_gitlab_cls.return_value = mock_provider + mock_resolve_from_provider.return_value = ( + "https://gitlab.com/group/project/-/archive/v1.0/project-v1.0.tar.gz", + Version("1.0"), + ) + + ps = packagesettings.PackageSettings.from_string( + "gitlab-pkg", + """ +resolver_dist: + provider: gitlab + project_path: group/project + server_url: https://gitlab.example.com +""", + ) + settings = packagesettings.Settings( + settings=packagesettings.SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=tmp_context.settings.patches_dir, + max_jobs=1, + ) + tmp_context.settings = settings + + req = Requirement("gitlab-pkg==1.0") + pbi = tmp_context.package_build_info(req) + _url, _version = sources._resolve_with_configured_provider( + ctx=tmp_context, + req=req, + pbi=pbi, + ) + mock_gitlab_cls.assert_called_once_with( + project_path="group/project", + server_url="https://gitlab.example.com", + constraints=tmp_context.constraints, + matcher=None, + ) + mock_resolve_from_provider.assert_called_once_with(mock_provider, req) + + +@patch("fromager.resolver.resolve_from_provider") +@patch("fromager.resolver.GitLabTagProvider") +def test_resolve_gitlab_with_org_repo_fallback( + mock_gitlab_cls: Mock, + mock_resolve_from_provider: Mock, + tmp_context: context.WorkContext, +) -> None: + """Verify GitLab provider uses org/repo when project_path is not set.""" + mock_provider = Mock() + 
mock_gitlab_cls.return_value = mock_provider + mock_resolve_from_provider.return_value = ( + "https://gitlab.com/myorg/myrepo/-/archive/v1.0.tar.gz", + Version("1.0"), + ) + + ps = packagesettings.PackageSettings.from_string( + "gitlab-org-pkg", + """ +resolver_dist: + provider: gitlab + organization: myorg + repo: myrepo +""", + ) + settings = packagesettings.Settings( + settings=packagesettings.SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=tmp_context.settings.patches_dir, + max_jobs=1, + ) + tmp_context.settings = settings + + req = Requirement("gitlab-org-pkg==1.0") + pbi = tmp_context.package_build_info(req) + sources._resolve_with_configured_provider( + ctx=tmp_context, + req=req, + pbi=pbi, + ) + mock_gitlab_cls.assert_called_once_with( + project_path="myorg/myrepo", + server_url="https://gitlab.com", + constraints=tmp_context.constraints, + matcher=None, + ) + + +@patch("fromager.resolver.resolve") +def test_default_resolve_source_with_yaml_provider( + mock_resolve: Mock, + tmp_context: context.WorkContext, +) -> None: + """Verify default_resolve_source skips PyPI when provider is configured.""" + ps = packagesettings.PackageSettings.from_string( + "provider-pkg", + """ +resolver_dist: + provider: github + organization: myorg + repo: myrepo +""", + ) + settings = packagesettings.Settings( + settings=packagesettings.SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=tmp_context.settings.patches_dir, + max_jobs=1, + ) + tmp_context.settings = settings + + req = Requirement("provider-pkg==1.0") + + with patch.object( + sources, + "_resolve_with_configured_provider", + return_value=("https://example.com/archive.tar.gz", Version("1.0")), + ) as mock_configured: + url, version = sources.default_resolve_source( + tmp_context, req, resolver.PYPI_SERVER_URL + ) + + mock_configured.assert_called_once() + mock_resolve.assert_not_called() + assert url == "https://example.com/archive.tar.gz" + assert version == 
Version("1.0") + + +def test_get_source_type_with_yaml_provider( + tmp_context: context.WorkContext, +) -> None: + """Verify get_source_type detects YAML-configured provider.""" + ps = packagesettings.PackageSettings.from_string( + "source-type-pkg", + """ +resolver_dist: + provider: github + organization: myorg + repo: myrepo +""", + ) + settings = packagesettings.Settings( + settings=packagesettings.SettingsFile(), + package_settings=[ps], + variant="cpu", + patches_dir=tmp_context.settings.patches_dir, + max_jobs=1, + ) + tmp_context.settings = settings + + req = Requirement("source-type-pkg==1.0") + source_type = sources.get_source_type(tmp_context, req) + assert source_type == SourceType.OVERRIDE + + +@patch("fromager.vendor_rust.vendor_rust") +@patch("fromager.pyproject.apply_project_override") +@patch("fromager.sources.patch_source") +@patch("fromager.sources.ensure_pkg_info") +def test_prepare_new_source_calls_ensure_pkg_info( + ensure_pkg_info: Mock, + patch_source: Mock, + apply_project_override: Mock, + vendor_rust: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify prepare_new_source calls ensure_pkg_info before patching.""" + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + version = Version("1.0") + + sources.prepare_new_source(tmp_context, req, source_root_dir, version) + + ensure_pkg_info.assert_called_once() + patch_source.assert_called_once() + apply_project_override.assert_called_once() + vendor_rust.assert_called_once() + + +@patch("fromager.vendor_rust.vendor_rust") +@patch("fromager.pyproject.apply_project_override") +@patch("fromager.sources.patch_source") +@patch("fromager.sources.ensure_pkg_info") +def test_prepare_new_source_vendor_rust_default_after_patch( + ensure_pkg_info: Mock, + patch_source: Mock, + apply_project_override: Mock, + vendor_rust: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify 
vendor_rust runs after patch_source by default.""" + call_order: list[str] = [] + patch_source.side_effect = lambda *a, **kw: call_order.append("patch") + vendor_rust.side_effect = lambda *a, **kw: call_order.append("vendor_rust") + + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + + sources.prepare_new_source(tmp_context, req, source_root_dir, Version("1.0")) + + assert call_order == ["patch", "vendor_rust"] + + +@patch("fromager.vendor_rust.vendor_rust") +@patch("fromager.pyproject.apply_project_override") +@patch("fromager.sources.patch_source") +@patch("fromager.sources.ensure_pkg_info") +@patch.object( + packagesettings.PackageBuildInfo, + "vendor_rust_before_patch", + new_callable=lambda: property(lambda self: True), +) +def test_prepare_new_source_vendor_rust_before_patch( + _vendor_rust_prop: Mock, + ensure_pkg_info: Mock, + patch_source: Mock, + apply_project_override: Mock, + vendor_rust: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify vendor_rust runs before patch_source when setting is True.""" + call_order: list[str] = [] + patch_source.side_effect = lambda *a, **kw: call_order.append("patch") + vendor_rust.side_effect = lambda *a, **kw: call_order.append("vendor_rust") + + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + + sources.prepare_new_source(tmp_context, req, source_root_dir, Version("1.0")) + + assert call_order == ["vendor_rust", "patch"] + + +@patch("fromager.vendor_rust.vendor_rust") +@patch("fromager.pyproject.apply_project_override") +@patch("fromager.sources.patch_source") +@patch("fromager.sources.ensure_pkg_info") +@patch.object( + packagesettings.PackageBuildInfo, + "create_files", + new_callable=lambda: property( + lambda self: [ + CreateFile(path="src/pkg/__init__.py", content=""), + CreateFile( + path="src/pkg/version.py", + content='__version__ = "${version}"', + ), + ] + ), +) +def 
test_prepare_new_source_create_files( + _create_files_prop: Mock, + ensure_pkg_info: Mock, + patch_source: Mock, + apply_project_override: Mock, + vendor_rust: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify create_source_files creates files with template substitution.""" + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + + sources.prepare_new_source(tmp_context, req, source_root_dir, Version("1.0")) + + init_file = source_root_dir / "src" / "pkg" / "__init__.py" + assert init_file.exists() + assert init_file.read_text() == "" + + version_file = source_root_dir / "src" / "pkg" / "version.py" + assert version_file.exists() + assert version_file.read_text() == '__version__ = "1.0"' + + +def test_create_source_files_no_files( + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify create_source_files is a no-op when no files are configured.""" + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + + sources.create_source_files(tmp_context, req, source_root_dir, Version("1.0")) + + assert list(source_root_dir.iterdir()) == [] + + +@patch.object( + packagesettings.PackageBuildInfo, + "create_files", + new_callable=lambda: property( + lambda self: [ + CreateFile(path="nested/dir/file.txt", content="hello"), + ] + ), +) +def test_create_source_files_creates_parent_dirs( + _create_files_prop: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify create_source_files creates parent directories.""" + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + + sources.create_source_files(tmp_context, req, source_root_dir, Version("1.0")) + + created_file = source_root_dir / "nested" / "dir" / "file.txt" + assert created_file.exists() + assert created_file.read_text() == "hello" + + +@patch.object( + packagesettings.PackageBuildInfo, 
+ "create_files", + new_callable=lambda: property( + lambda self: [ + CreateFile(path="existing.txt", content="new content"), + ] + ), +) +def test_create_source_files_overwrites_existing( + _create_files_prop: Mock, + tmp_path: pathlib.Path, + tmp_context: context.WorkContext, +) -> None: + """Verify create_source_files overwrites existing files.""" + req = Requirement("foo==1.0") + source_root_dir = tmp_path / "foo-1.0" + source_root_dir.mkdir() + existing = source_root_dir / "existing.txt" + existing.write_text("old content") + + sources.create_source_files(tmp_context, req, source_root_dir, Version("1.0")) + + assert existing.read_text() == "new content"